Initial YakPanel commit

This commit is contained in:
Niranjan
2026-04-07 02:04:22 +05:30
commit 2826d3e7f3
5359 changed files with 1390724 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
# coding: utf-8
from .config_manager import *
from .fields import *
from .manager import Q
from .model import aaModel
__version__ = "1.2.0"
__all__ = [
"__version__",
"aaModel",
"Q",
] + fields.__all__ + config_manager.__all__

View File

@@ -0,0 +1,387 @@
# coding: utf-8
# -------------------------------------------------------------------
# YakPanel
# -------------------------------------------------------------------
# Copyright (c) 2014-2099 YakPanel(www.yakpanel.com) All rights reserved.
# -------------------------------------------------------------------
# Author: yakpanel
# -------------------------------------------------------------------
# ------------------------------
# config app
# ------------------------------
import copy
import os
import threading
try:
import ujson as json
except ImportError:
import json
__all__ = [
"DictConfig",
"ListConfig",
]
class _Ctx:
"""轻量锁+加载"""
__slots__ = ("_mgr", "_save")
def __init__(self, mgr: "SimpleConfig", save: bool = False):
self._mgr = mgr
self._save = save
def __enter__(self):
mgr = self._mgr
mgr._lock.acquire()
try:
if not mgr._loaded:
mgr._do_load()
except:
mgr._lock.release()
raise
return mgr._cache
def __exit__(self, exc_type, exc_val, exc_tb):
try:
if self._save and exc_type is None:
self._mgr._save()
finally:
self._mgr._lock.release()
class SimpleConfig(object):
"""json"""
__slots__ = ("_path", "_tmp_path", "_lock", "_cache", "_loaded", "_default")
def __init__(self, path: str, default=None):
"""配置文件的绝对路径"""
self._path: str = path
self._default = default
self._tmp_path: str = path + ".tmp"
self._lock = threading.RLock()
self._cache = None
self._loaded: bool = False
if default is not None and not os.path.exists(path):
with self._lock:
self._cache = copy.deepcopy(default)
self._loaded = True
self._save()
def __bool__(self) -> bool:
"""判断当前数据是否非空"""
with self._ctx() as c:
return bool(c)
def _ctx(self, save: bool = False) -> _Ctx:
return _Ctx(self, save=save)
def _do_load(self):
"""需持锁"""
if os.path.exists(self._path):
try:
with open(self._path, "r", encoding="utf-8") as f:
self._cache = json.load(f)
self._loaded = True
return
except (ValueError, OSError):
pass
self._cache = copy.deepcopy(self._default) if self._default is not None else self._default_data()
self._loaded = True
def _save(self):
"""需持锁"""
dirname = os.path.dirname(self._path)
if dirname:
os.makedirs(dirname, exist_ok=True)
try:
with open(self._tmp_path, "w", encoding="utf-8") as f:
json.dump(self._cache, f, ensure_ascii=False)
f.flush()
os.replace(self._tmp_path, self._path)
except OSError:
self._loaded = False
try:
os.remove(self._tmp_path)
except FileNotFoundError:
pass
raise
def _default_data(self):
raise NotImplementedError
# ----------- public -----------
def reload(self):
"""强制重新加载"""
with self._lock:
self._loaded = False
self._do_load()
def clear(self):
"""清空并持久化"""
with self._ctx(save=True):
self._cache = copy.deepcopy(self._default) if self._default is not None else self._default_data()
def save(self):
"""手动持久化"""
with self._ctx():
if self._loaded:
self._save()
def atomic(self) -> _Ctx:
"""原子操作上下文, 事务"""
return self._ctx(save=True)
@property
def path(self) -> str:
return self._path
def exists(self) -> bool:
"""配置文件是否存在于磁盘"""
return os.path.exists(self._path)
def delete_config(self):
with self._lock:
for p in (self._path, self._tmp_path):
try:
os.remove(p)
except FileNotFoundError:
pass
self._cache = copy.deepcopy(self._default) if self._default is not None else self._default_data()
self._loaded = True
# ---------------------------------------------------------------------------
class DictConfig(SimpleConfig):
"""
cfg = DictConfig("/path/to/config.json", default={"a": 1})
cfg.set("key", "value")
cfg.get("key") # "value"
cfg.get("missing", 0) # 0
cfg["key"] = "new_value"
del cfg["key"]
cfg.update({"a": 1, "b": 2}) # 浅合并,整体覆盖同名 key
cfg.merge({"a": {"x": 1}}) # 深合并,更新嵌套中若干字段而非整体覆盖
cfg.keys() / cfg.values() / cfg.items()
"key" in cfg
len(cfg)
cfg.pop("key", None)
cfg.setdefault("key", default_val)
cfg.as_dict()
"""
def _default_data(self) -> dict:
return {}
def get(self, key: str, default=None):
with self._ctx() as c:
return c.get(key, default)
def __getitem__(self, key: str):
with self._ctx() as c:
return c[key]
def __contains__(self, key: str) -> bool:
with self._ctx() as c:
return key in c
def __len__(self) -> int:
with self._ctx() as c:
return len(c)
def __iter__(self):
with self._ctx() as c:
return iter(list(c.keys()))
def keys(self):
with self._ctx() as c:
return list(c.keys())
def values(self):
with self._ctx() as c:
return list(c.values())
def items(self):
with self._ctx() as c:
return list(c.items())
def as_dict(self) -> dict:
with self._ctx() as c:
return dict(c)
def __repr__(self) -> str:
with self._ctx() as c:
return f"DictConfig({self._path!r}, {c!r})"
def set(self, key: str, value) -> None:
with self._ctx(save=True) as c:
c[key] = value
def __setitem__(self, key: str, value) -> None:
with self._ctx(save=True) as c:
c[key] = value
def update(self, data: dict) -> None:
if not data:
return
with self._ctx(save=True) as c:
c.update(data)
def merge(self, data: dict) -> None:
"""合并data到配置, 对嵌套dict递归合并, 非update覆盖"""
if not data:
return
def _deep_merge(base: dict, patch: dict) -> None:
for k, v in patch.items():
if k in base and isinstance(base[k], dict) and isinstance(v, dict):
_deep_merge(base[k], v)
else:
base[k] = v
with self._ctx(save=True) as c:
_deep_merge(c, data)
def setdefault(self, key: str, default=None):
with self._ctx(save=True) as c:
if key not in c:
c[key] = default
return c[key]
def delete(self, key: str) -> None:
with self._ctx(save=True) as c:
if key in c:
del c[key]
def __delitem__(self, key: str) -> None:
with self._ctx(save=True) as c:
del c[key]
def pop(self, key: str, *args):
with self._ctx(save=True) as c:
return c.pop(key, *args) if args else c.pop(key)
# ---------------------------------------------------------------------------
class ListConfig(SimpleConfig):
"""
cfg = ListConfig("/path/to/list.json", default=[1,2,3])
cfg.append("item") / cfg.insert(0, "x") / cfg.extend([...])
cfg.get(0) / cfg[0] / cfg[0] = "v"
cfg.remove("x") / cfg.pop(0) / del cfg[0]
cfg.index("x") / cfg.count("x")
"x" in cfg / len(cfg) / iter(cfg)
cfg.sort() / cfg.reverse()
cfg.unique()
cfg.as_list()
"""
def _default_data(self) -> list:
return []
def get(self, index: int, default=None):
with self._ctx() as c:
try:
return c[index]
except IndexError:
return default
def __getitem__(self, index):
with self._ctx() as c:
return c[index]
def __contains__(self, item) -> bool:
with self._ctx() as c:
return item in c
def __len__(self) -> int:
with self._ctx() as c:
return len(c)
def __iter__(self):
with self._ctx() as c:
return iter(list(c))
def count(self, item) -> int:
with self._ctx() as c:
return c.count(item)
def index(self, item, *args) -> int:
with self._ctx() as c:
return c.index(item, *args)
def as_list(self) -> list:
with self._ctx() as c:
return list(c)
def __repr__(self) -> str:
with self._ctx() as c:
return f"ListConfig({self._path!r}, {c!r})"
def set(self, index: int, value) -> None:
with self._ctx(save=True) as c:
c[index] = value
def __setitem__(self, index, value) -> None:
with self._ctx(save=True) as c:
c[index] = value
def __delitem__(self, index) -> None:
with self._ctx(save=True) as c:
del c[index]
def append(self, item) -> None:
with self._ctx(save=True) as c:
c.append(item)
def insert(self, index: int, item) -> None:
with self._ctx(save=True) as c:
c.insert(index, item)
def extend(self, items) -> None:
items = list(items)
if not items:
return
with self._ctx(save=True) as c:
c.extend(items)
def remove(self, item) -> None:
with self._ctx(save=True) as c:
c.remove(item)
def pop(self, index: int = -1):
with self._ctx(save=True) as c:
return c.pop(index)
def sort(self, *, key=None, reverse: bool = False) -> None:
with self._ctx(save=True) as c:
c.sort(key=key, reverse=reverse)
def reverse(self) -> None:
with self._ctx(save=True) as c:
c.reverse()
def unique(self) -> None:
"""保序去重"""
with self._ctx(save=True) as c:
seen: set = set()
result = []
for item in c:
try:
key = item
hash(key)
except TypeError:
key = id(item)
if key not in seen:
seen.add(key)
result.append(item)
c[:] = result

View File

@@ -0,0 +1,588 @@
# coding: utf-8
import copy
import itertools
import time
try:
import ujson as json
except ImportError:
try:
os.system("btpip install ujson")
import ujson as json
except:
import json
from collections.abc import Callable
from dataclasses import dataclass, field as dataclass_field
from datetime import datetime
from typing import Any, TypeVar, List, Optional, Iterable, TYPE_CHECKING
from public.exceptions import HintException
__all__ = [
"StrField",
"IntField",
"FloatField",
"BlobField",
"ListField",
"DictField",
"DateTimeStrField",
]
if TYPE_CHECKING:
from .model import aaModel
M = TypeVar("M", bound="aaModel")
def json_func(v_type: type, value: Any, forward: bool = True):
try:
if forward is True:
if isinstance(value, v_type):
return json.dumps(value)
else:
if isinstance(value, str):
return json.loads(value)
return value
except TypeError as t:
print("type error %s" % t)
return value
except Exception as e:
print("error %s" % e)
raise e
def _wrap_value(value, on_change_callback, field_name):
"""wrap Tracked"""
if isinstance(value, list) and not isinstance(value, TrackedList):
return TrackedList(value, on_change=on_change_callback, field_name=field_name)
if isinstance(value, dict) and not isinstance(value, TrackedDict):
return TrackedDict(value, on_change=on_change_callback, field_name=field_name)
return value
class TrackedList(list):
"""override list, track fields dirty"""
__slots__ = ("_on_change", "_field_name")
def __init__(self, iterable: Iterable = (), *, on_change: Callable = None, field_name: str | None = None):
self._on_change = on_change
self._field_name = field_name
wrapped = [
_wrap_value(v, on_change, field_name) for v in (iterable or [])
]
super().__init__(wrapped)
def _notify_change(self):
"""call back"""
if self._on_change and self._field_name:
self._on_change(self._field_name)
def __deepcopy__(self, memo):
return list(copy.deepcopy(item, memo) for item in self)
def __setitem__(self, key, value):
if isinstance(key, slice):
value = [_wrap_value(v, self._on_change, self._field_name) for v in value]
else:
value = _wrap_value(value, self._on_change, self._field_name)
super().__setitem__(key, value)
self._notify_change()
def append(self, item):
super().append(_wrap_value(item, self._on_change, self._field_name))
self._notify_change()
def insert(self, index: int, item: Any):
super().insert(index, _wrap_value(item, self._on_change, self._field_name))
self._notify_change()
def remove(self, item):
super().remove(item)
self._notify_change()
def pop(self, *args, **kwargs):
result = super().pop(*args, **kwargs)
self._notify_change()
return result
def clear(self):
super().clear()
self._notify_change()
def extend(self, iterable):
if not iterable:
return
wrapped_iterable = [
_wrap_value(v, self._on_change, self._field_name) for v in iterable
]
super().extend(wrapped_iterable)
self._notify_change()
def sort(self, *args, **kwargs):
super().sort(*args, **kwargs)
self._notify_change()
def reverse(self):
super().reverse()
self._notify_change()
class TrackedDict(dict):
"""override dict, track fields dirty"""
__slots__ = ("_on_change", "_field_name")
def __init__(self, *args, on_change: Callable = None, field_name: str | None = None, **kwargs):
super().__init__(*args, **kwargs)
self._on_change = on_change
self._field_name = field_name
items_to_update = {}
for key, value in list(self.items()):
# list() to avoid "dictionary changed size during iteration"
items_to_update[key] = _wrap_value(value, self._on_change, self._field_name)
super().update(items_to_update)
def _notify_change(self):
"""call back"""
if self._on_change and self._field_name:
self._on_change(self._field_name)
def __deepcopy__(self, memo):
return {k: copy.deepcopy(v, memo) for k, v in self.items()}
def __setitem__(self, key, value):
wrapped_value = _wrap_value(value, self._on_change, self._field_name)
super().__setitem__(key, wrapped_value)
self._notify_change()
def __delitem__(self, key):
super().__delitem__(key)
self._notify_change()
def pop(self, *args, **kwargs):
result = super().pop(*args, **kwargs)
self._notify_change()
return result
def popitem(self):
result = super().popitem()
self._notify_change()
return result
def clear(self):
super().clear()
self._notify_change()
def update(self, *args, **kwargs):
other = dict(*args, **kwargs)
if not other:
return
wrapped_other = {}
for key, value in other.items():
wrapped_other[key] = _wrap_value(value, self._on_change, self._field_name)
super().update(wrapped_other)
self._notify_change()
def setdefault(self, key, default=None):
if key not in self:
wrapped_default = _wrap_value(default, self._on_change, self._field_name)
result = super().setdefault(key, wrapped_default)
self._notify_change()
else:
result = super().get(key)
return result
@dataclass
class aaField(object):
"""
字段基类
default 默认值
ps 字段说明
null 是否null
primary_key 是否主键
foreign_key 外键
field_name 字段key
field_type sql类型
py_type py类型
compare 比较
transform 转换工具
"""
default: Any = None
ps: str = None
null: bool = False
primary_key: bool = False
field_name: str = None
field_type: str = None
py_type: type = None
compare: tuple = None
serialized: Callable = None
def __set_name__(self, owner: M, name: str):
self.__model = owner
self.field_name = str(name)
def __get__(self, instance: object, owner):
if instance is None:
return self
return instance.__dict__.get(self.field_name, self.get_default_val())
def __set__(self, instance: M, value: Any):
# base type field, check new set value
current_value = instance.__dict__.get(self.field_name)
if current_value is value:
# base type field, not Tracker
return
if hasattr(instance, "_mark_dirty"):
instance._mark_dirty(self.field_name)
instance.__dict__[self.field_name] = value
def __delete__(self, instance):
try:
del instance.__dict__[self.field_name]
except KeyError:
raise AttributeError(f"{instance} dont have attr '{self.field_name}'")
def _raise_error(self, raise_exp: bool = True) -> bool:
if raise_exp is True:
err = f"'{self.field_name}' TypeError! It should be '{self.py_type.__name__}'"
raise HintException(err)
else:
return False
def _check_type(self, target: Any, raise_exp=True) -> Optional[bool]:
target = target if not isinstance(target, Callable) else target()
if any(isinstance(target, x) for x in self.py_types):
return True
return self._raise_error(raise_exp=raise_exp)
@property
def default_val_sql(self) -> str:
default_v = self.get_default_val(check=True)
default_v = default_v if self.serialized is None else self.serialized(default_v)
return f"DEFAULT '{default_v}'" if isinstance(default_v, str) else f"DEFAULT {default_v}"
@property
def py_types(self) -> List[type]:
if self.null is True:
original = [type(self.default), self.py_type, type(None)]
else:
original = [type(self.default), self.py_type]
return list(set(original))
def get_default_val(self, check: bool = False) -> Optional[Any]:
if check:
self.check_org_type(raise_exp=check)
if self.default is not None:
return self.default if not isinstance(self.default, Callable) else self.default()
else:
if self.null is True:
return None
else:
raise TypeError(
f"\n1: field '{self.field_name}' is not null, must have a default value"
f"\n2: you can add the '{self.field_name}' field's params null=True"
)
def model_check_type(self, target: Any, raise_exp=True) -> bool:
"""
检查模型当前类型结构
"""
return self._check_type(target, raise_exp=raise_exp)
def check_org_type(self, raise_exp: bool = True) -> bool:
"""
检查初始化的类型结构
"""
return self._check_type(self.default, raise_exp=raise_exp)
@dataclass()
class StrField(aaField):
"""
String field
field__like="a",
field__ne="a",
field__in=["a", "b", "c"]
field__not_in=["a", "b", "c"]
field__startswith="a"
field__endswith="a"
"""
default: str | None = ""
field_type: str = "TEXT"
py_type: type = str
max_length: int = 255 # not limit now
min_length: int = 0 # not limit now
compare: tuple[str] = (
"like",
"ne",
"in",
"not_in",
"startswith",
"endswith",
)
@dataclass
class IntField(aaField):
"""
Int field
field__gt=1,
field__gte=1,
field__lt=1,
field__lte=1,
field__ne=1,
field__in=[1, 2, 3]
field__not_in=[1, 2, 3]
"""
default: int | None | Any = 0
field_type: str = "INTEGER"
py_type: type = int
max: int = 0 # not limit now
min: int = 0 # not limit now
compare: tuple[str] = (
"gt",
"lt",
"gte",
"lte",
"ne",
"in",
"not_in",
)
@dataclass
class FloatField(aaField):
"""
Float field
"""
default: float | None = 0.0
field_type: str = "REAL"
py_type: type = float
max: float = 0.0 # not limit now
min: float = 0.0 # not limit now
compare: tuple[str] = (
"gt",
"lt",
"gte",
"lte",
"ne",
"in",
"not_in",
)
@dataclass
class BlobField(aaField):
"""
Blob field
"""
default: bytes | None = b''
field_type: str = "BLOB"
py_type: type = bytes
@dataclass
class ListField(aaField):
"""
List field
"""
# override __get__ to return tracker
def __get__(self, instance: M, owner):
if instance is None:
return self
value: Iterable[Any] = instance.__dict__.get(self.field_name)
if value is None:
value = self.get_default_val()
# init default val for the first time
instance.__dict__[self.field_name] = value
if not isinstance(value, TrackedList):
# generate call back
value = TrackedList(
value,
on_change=instance._mark_dirty,
field_name=self.field_name,
)
instance.__dict__[self.field_name] = value # update instance's attr
return value
def __set__(self, instance: M, value: Any):
"""override, other update handled by TrackedList"""
if not isinstance(value, list):
raise TypeError(f"Field '{self.field_name}' expects a list, but got {type(value).__name__}")
tracked_value = TrackedList(
value,
on_change=instance._mark_dirty,
field_name=self.field_name,
)
instance._mark_dirty(self.field_name)
super().__set__(instance, tracked_value)
@staticmethod
def _serialized(value: list | str, forward: bool = True) -> list | Any:
return json_func(list, value, forward)
default: list = dataclass_field(default_factory=list)
field_type: str = "TEXT"
serialized: Callable = _serialized
py_type: type = list
compare: tuple[str] = (
"like",
"contains",
"any_contains",
)
update: tuple[str] = (
"append",
)
@dataclass
class DictField(aaField):
"""
Dict field
"""
def __get__(self, instance: M, owner):
if instance is None:
return self
value: dict = instance.__dict__.get(self.field_name)
if value is None:
value = self.get_default_val()
instance.__dict__[self.field_name] = value
if not isinstance(value, TrackedDict):
# generate tracker call back
value = TrackedDict(
value,
on_change=instance._mark_dirty,
field_name=self.field_name,
)
instance.__dict__[self.field_name] = value # update instance's attr
return value
def __set__(self, instance: M, value: Any):
"""override, other update handled by TrackedDict"""
if not isinstance(value, dict):
raise TypeError(f"Field '{self.field_name}' expects a dict, but got {type(value).__name__}")
tracked_value = TrackedDict(
value,
on_change=instance._mark_dirty,
field_name=self.field_name,
)
instance._mark_dirty(self.field_name)
super().__set__(instance, tracked_value)
@staticmethod
def _serialized(value: dict | str, forward: bool = True) -> dict | Any:
return json_func(dict, value, forward)
default: dict = dataclass_field(default_factory=dict)
field_type: str = "TEXT"
serialized: Callable = _serialized
py_type: type = dict
compare: tuple[str] = (
# "has_key",
# "has_value",
# "has_key_value",
"lt",
"lte",
"gt",
"gte",
"ne",
"like",
"startswith",
"endswith",
)
update: tuple[str] = (
"update",
)
ACCURACY = 1000
@dataclass
class DateTimeStrField(aaField):
"""
时间戳
auto_now_add=True 创建时间自动添加
auto_now=True 更新时间自动更新
"""
@classmethod
def _current_timestamp(cls):
return time.strftime(cls.format, time.localtime())
@staticmethod
def _dynamic(obj, val):
if hasattr(obj, "auto_now_add") and obj.auto_now_add is True: # 创建时间
return val
elif hasattr(obj, "auto_now") and obj.auto_now is True: # 更新时间
return obj.get_default_val()
else:
return val
@staticmethod
def _serialized(value: int | str, forward: bool = True) -> str | int:
try:
if forward is True:
if isinstance(value, str): # save will be str
return int(
datetime.strptime(value, DateTimeStrField.format).timestamp() * DateTimeStrField.accuracy)
else:
if isinstance(value, int):
return datetime.fromtimestamp(value / DateTimeStrField.accuracy).strftime(DateTimeStrField.format)
return value
except Exception as e:
print("type error %s" % e)
return value
serialized: Callable = _serialized
default: str | Callable = ""
field_type: str = "INTEGER"
py_type: type = str
dynamic: bool = True
auto_now_add: bool = False
auto_now: bool = False
accuracy: int = ACCURACY
format: str = "%Y-%m-%d %H:%M:%S"
compare: tuple[str] = (
"gt",
"lt",
"gte",
"lte",
)
def __post_init__(self):
if self.auto_now_add is True and self.auto_now is True:
raise TypeError("auto_now_add and auto_now can not be used at the same time")
if self.auto_now is True:
self.default = self._current_timestamp
elif self.auto_now_add is True:
self.default = self._current_timestamp
COMPARE = tuple(set(
itertools.chain(
*[
StrField.compare,
IntField.compare,
FloatField.compare,
ListField.compare,
DictField.compare,
DateTimeStrField.compare,
]
)
))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,343 @@
# coding: utf-8
import copy
from dataclasses import replace as replace_dataclass
from functools import lru_cache
from typing import Self, Generator, Optional, Dict, Any
from .fields import aaField, COMPARE
from .manager import aaManager
__all__ = ["aaModel"]
from public.exceptions import HintException
@lru_cache(maxsize=16)
def generate_table_name(class_name: str) -> str:
"""
驼峰名转表名
"""
return ''.join(['_' + c.lower() if c.isupper() else c for c in class_name]).lstrip('_')
class aaMetaClass(type):
__abstract__: bool
__db_name__: str
__table_name__: str
__fields__: dict
__primary_key__: str
__serializes__: dict
__index_keys__: list
def __new__(cls, name, bases, attrs):
if attrs.get("__abstract__") is True:
return super().__new__(cls, name, bases, attrs)
attrs.update({"__abstract__": False})
new_class = super().__new__(cls, name, bases, attrs)
cls.__fields_process(obj=new_class, name=name, attrs=attrs)
cls.__database_process(obj=new_class, name=name, attrs=attrs)
return new_class
def __setattr__(cls, key, value):
if key == '__abstract__':
raise AttributeError("can't set attribute '__abstract__'")
return super().__setattr__(key, value)
@classmethod
def __fields_process(cls, obj: "aaMetaClass", name: str, attrs: dict):
pk = ""
fields = {}
for k, v in attrs.items():
if isinstance(v, aaField):
if k in fields:
raise HintException(f"model {name} field '{k}' is already defined")
if k in COMPARE:
raise HintException(f"model {name} field '{k}' is compare field, please change the name")
if k.startswith("_"):
raise HintException(f"model {name} field '{k}' is not support start with '_'")
fields[k] = v
if v.primary_key:
if pk:
raise HintException(f"model {name} can only have one primary key")
else:
pk = k
if not pk:
raise HintException(f"sth wrong with {name}'s primary key, please check the model")
setattr(obj, "__primary_key__", pk)
setattr(obj, "__fields__", fields)
@classmethod
def __database_process(cls, obj: "aaMetaClass", name: str, attrs: dict):
db_name, tb_name, idx = "default", generate_table_name(name), []
meta = attrs.get("_Meta")
if meta:
if hasattr(meta, "db_name"):
db_name = meta.db_name
if hasattr(meta, "table_name"):
tb_name = meta.table_name
if hasattr(meta, "index"):
idx = meta.index
setattr(obj, "__db_name__", db_name)
setattr(obj, "__table_name__", tb_name)
setattr(obj, "__index_keys__", idx)
class aaCusModel(metaclass=aaMetaClass):
__abstract__ = True
objects = aaManager()
_dirty_fields: Optional[set] = None
def __init__(self, **kwargs):
if self.__abstract__:
raise RuntimeError(f'{self.__class__.__name__} class can not be init')
self._field_filter = kwargs.pop("_field_filter", None)
for f, v in self._generate_init(kwargs, all_flag=True):
setattr(self, f.field_name, v)
# after init set, init dirty fields set
self._dirty_fields = set()
def _mark_dirty(self, field_name: str):
if self._dirty_fields is None:
return
self._dirty_fields.add(field_name)
def _generate_init(self, val_data: dict, all_flag: bool = False) -> Generator:
fields_map = self._get_fields() if all_flag else {
k: v for k, v in self._get_fields().items()
if k in val_data or (hasattr(v, "dynamic") and v.auto_now is True)
}
for name, field in fields_map.items():
default_val = field.get_default_val()
val = val_data.get(name, default_val)
if field.primary_key is True and val == 0:
continue # skip default id val
if field.primary_key is True and val != 0:
try:
val = int(val)
except Exception:
pass
yield field, val
# if val_data: # other field
# raise AttributeError(f"model '{self.__class__.__name__}' has no field {val_data}")
# pass
# =========================================
@classmethod
@lru_cache(maxsize=32)
def _get_fields(cls):
return cls.__fields__
@classmethod
@lru_cache(maxsize=32)
def _get_serialized(cls) -> dict:
return {
k: replace_dataclass(v) for k, v in cls._get_fields().items() if v.serialized is not None
}
@classmethod
def _get_serialized_fields_fz(cls) -> frozenset:
return frozenset(cls._get_serialized().keys())
class aaModel(aaCusModel):
"""
基础模型
:example:
class MyTestModel(aaModel):
id = IntField(primary_key=True)
name = StrField(ps="名字")
status = BoolField(default=True, ps="状态")
float_number = FloatField(default=0.05, ps="浮点")
class _Meta:
db_name = "default" 默认为 default.db 文件
table_name = "my_table" 默认为类名驼峰转表名 my_test_model
index = ["status"] 索引
"""
__db_name__: str
__table_name__: str
__fields__: dict
__primary_key__: str
__serializes__: dict
__index_keys__: list
__abstract__: bool = True
__destroyed: bool = False
id: int = None
def __repr__(self):
return f"<{self.__class__.__name__}: {self.as_dict()}>"
@staticmethod
def _check_destroyed(func):
def wrapper(self, *args, **kwargs):
if getattr(self, "__destroyed", False):
raise RuntimeError(f"Cannot call {func.__name__}() on destroyed object")
return func(self, *args, **kwargs)
return wrapper
@classmethod
@_check_destroyed
def _output(cls, data: dict, _field_filter=None) -> dict:
serlz = cls._get_serialized()
serlz_fields = cls._get_serialized_fields_fz()
if _field_filter is not None:
data = {k: v for k, v in data.items() if k in _field_filter}
return {
k: serlz[k].serialized(v, False) if k in serlz_fields else v
for k, v in data.items()
}
@classmethod
@_check_destroyed
def _serialized_data(cls, data: Optional[dict | list], _field_filter=None) -> Optional[dict | list]:
if isinstance(data, list):
return [cls._output(d, _field_filter) for d in data]
elif isinstance(data, dict):
return cls._output(data, _field_filter)
else:
return data
@_check_destroyed
def _validate(self, target: dict = None, raise_exp: bool = True) -> Optional[Dict[str, Any]]:
"""
模型验证, 返回序列化后的结果
"""
body = {}
for f, cur_val in self._generate_init(target or copy.deepcopy(self.__dict__)):
try:
# 1, dynamic generated
if hasattr(f, "dynamic") and f.dynamic is True:
cur_val = f._dynamic(f, cur_val)
setattr(self, f.field_name, cur_val)
# 2, check type and return serialized
if f.model_check_type(target=cur_val, raise_exp=raise_exp) is True:
body[f.field_name] = f.serialized(cur_val, True) if f.serialized else cur_val
else:
# if not raise_exp and check is False, return {}
return None
except HintException as e1:
if raise_exp:
raise e1
return None
except Exception as e:
raise Exception(e)
return body
def _before_save(self):
# override
pass
def _after_save(self):
# override
pass
def _before_update(self):
# override
pass
def _after_update(self):
# override
pass
@_check_destroyed
def save(self, raise_exp: bool = True) -> Optional[Self]:
"""
模型数据, 不存在则 保存 , 存在则 更新, 仅更新变动字段
:raise_exp 抛异常
:return: model object 字段类型异常等问题返回 None
"""
if self.__class__.__abstract__:
raise RuntimeError(f'{self.__class__.__name__} class can not be save')
try:
cls = self.__class__
primary_key = cls.__primary_key__
pk = int(self.__dict__.get(primary_key, 0))
# not changed & not insert.
if not self._dirty_fields and pk != 0:
return self
dirtys = {
k: v for k, v in self.__dict__.items() if k in self._dirty_fields
}
if pk == 0:
# insert, all fields default
validate = self._validate(raise_exp=raise_exp)
else:
if "update_time" in cls._get_fields() and "update_time" not in dirtys:
dirtys["update_time"] = None
# for field_name, field_obj in cls._get_fields().items():
# if hasattr(field_obj, "auto_now") and field_obj.auto_now:
# dirtys[field_name] = None
# update, olnly validate dirty fields
validate = self._validate(target=dirtys, raise_exp=raise_exp)
if not validate:
if raise_exp:
raise HintException("validate error")
return None
if pk == 0: # insert
self._before_save()
new_id = cls.objects._insert(validate)
if not new_id:
if raise_exp:
raise HintException("insert failed")
return None
self.__dict__[primary_key] = new_id
self._after_save()
else: # update
self._before_update()
if cls.objects._update({primary_key: pk}, validate) == 1:
self._after_update()
else: # update failed
if raise_exp:
raise HintException("update failed")
return None
# reset in finally block
return self
except (TypeError, AttributeError) as t:
if raise_exp:
raise t
return None
except Exception as e:
import traceback
print(traceback.format_exc())
raise HintException(e)
finally:
if self._dirty_fields:
self._dirty_fields.clear()
@_check_destroyed
def delete(self) -> int:
try:
self.__class__.objects._query.where(
f"{self.__class__.__primary_key__}=?", (self.id,)
).delete()
setattr(self, "__destroyed", True)
except Exception as e:
print(e)
return 0
return 1
@_check_destroyed
def as_dict(self) -> dict:
"""
转字典
"""
result = {}
for k, v in self.__dict__.items():
if k.startswith("_"):
continue
if self._field_filter is not None and k not in self._field_filter:
continue
result[k] = v
return result