Files

1087 lines
38 KiB
Python
Raw Permalink Normal View History

2026-04-07 02:04:22 +05:30
# coding: utf-8
import os
import sqlite3 as Engine
import uuid
from functools import reduce
from itertools import chain
from typing import Optional, TypeVar, Generic, Any, List, Dict, Generator, Iterable, Type
# Prefer the faster ujson; if it is missing, try a one-off install through the
# panel's pip wrapper, and otherwise fall back to the standard-library json.
try:
    import ujson as json
except ImportError:
    try:
        os.system("btpip install ujson")
        import ujson as json
    except BaseException:
        import json
from public.aaModel.fields import COMPARE
from public.exceptions import HintException, PanelError
from public.sqlite_easy import Db
# Public API of this module.
__all__ = [
    "aaManager",
    "Q",
]

# Type variable for concrete model classes; the bound is a forward reference
# to aaModel, which is not imported here.
# noinspection PyUnresolvedReferences
M = TypeVar("M", bound="aaModel")
# ==================== Patch ==================
def _builtin(check_engine: Any = None) -> bool:
if not check_engine:
check_engine = Engine
try:
conn = check_engine.connect(":memory:")
cursor = conn.cursor()
cursor.execute("SELECT json_extract('{\"a\": 1}', '$.a')")
cursor.execute("SELECT COUNT(*) FROM json_each('{\"a\":1, \"b\":2}')")
conn.close()
return True
except:
return False
def _get_engine() -> tuple[bool, Any]:
try:
import pysqlite3 as engine
flag = True
except:
try:
os.system("btpip install pysqlite3-binary")
import pysqlite3 as engine
flag = True
except:
engine = Engine
flag = False
return flag, engine
# Engine selection, evaluated once at import time:
#   _ORG     - the stdlib sqlite3 module supports JSON1 itself
#   _INSTEAD - a replacement engine was loaded AND passed the JSON1 probe
#   _ENGINE  - the DB-API module handed to Db(...)
_ORG = _builtin()
if _ORG:
    _INSTEAD, _ENGINE = False, Engine
else:
    _INSTEAD, _ENGINE = _get_engine()
    # Only trust json_* SQL on the replacement if it really supports it;
    # otherwise Operator uses its pure-Python fallback.
    _INSTEAD = bool(_INSTEAD and _builtin(_ENGINE))
# ==================== Patch End ==================
def db_file_connect(db_name: str) -> Db:
    """
    Open a Db handle for a logical database name.

    ``"default"`` is passed straight to Db. Any other name is used as a file
    name under ``/www/server/panel/data/db`` (the directory is created if
    missing), and the handle is connected to that file.

    :param db_name: logical database name.
    :return: a connected Db instance using the selected sqlite engine.
    :raises PanelError: if ``db_name`` is empty.
    """
    if not db_name:
        raise PanelError("db_file_connect_generator error, db_name is empty")
    if db_name == "default":
        return Db(db_name=db_name, engine=_ENGINE)

    db_root_dir = "/www/server/panel/data/db"
    os.makedirs(db_root_dir, exist_ok=True)
    handle = Db(db_name="", engine=_ENGINE, auto_connect=False)
    # Db stores its path in a name-mangled private attribute; point it at the
    # custom file before opening the connection.
    setattr(handle, "_Db__DB_NAME", os.path.join(db_root_dir, f"{db_name}"))
    handle._connect()
    return handle
class Operator:
SIMPLE_OP = {
"gt": ">", "lt": "<", "gte": ">=", "lte": "<=", "ne": "!=",
}
LIKE_OP = {
"like": "%{}%", "startswith": "{}%", "endswith": "%{}",
}
def __init__(self, model_class: M, query: "Db.query"):
self._model_class: M = model_class
self._query = query
self._tb = self._model_class.__table_name__
self._fields = self._model_class._get_fields()
self._serializes = self._model_class._get_serialized()
self._flag = _ORG or _INSTEAD # fk flag
def _q_error(self, key: str, act: str, val: Any, sp_act: tuple):
raise HintException(
"field: '%s' is not support '%s', you can use: %s" % (key, act, sp_act)
)
def _deep_equal(self, obj1: Any, obj2: Any) -> bool:
if type(obj1) != type(obj2):
return False
if isinstance(obj1, dict):
if set(obj1.keys()) != set(obj2.keys()):
return False
return all(self._deep_equal(obj1[k], obj2[k]) for k in obj1)
if isinstance(obj1, list):
if len(obj1) != len(obj2):
return False
return all(self._deep_equal(i1, i2) for i1, i2 in zip(obj1, obj2))
return obj1 == obj2
def _deep_equal_in(self, item: Any, container_list: List[Any]) -> bool:
for elem in container_list:
if self._deep_equal(item, elem):
return True
return False
def _python_compare(self, v: Any, op_str: str, q_v: Any) -> bool:
try:
if op_str == "gt":
return v > q_v
if op_str == "lt":
return v < q_v
if op_str == "gte":
return v >= q_v
if op_str == "lte":
return v <= q_v
if op_str == "ne":
return not self._deep_equal(v, q_v)
if op_str == "like":
return isinstance(v, str) and isinstance(q_v, str) and q_v in v
if op_str == "startswith":
return isinstance(v, str) and isinstance(q_v, str) and v.startswith(q_v)
if op_str == "endswith":
return isinstance(v, str) and isinstance(q_v, str) and v.endswith(q_v)
if op_str == "in":
return isinstance(q_v, list) and self._deep_equal_in(v, q_v)
if op_str == "not_in":
return isinstance(q_v, list) and not self._deep_equal_in(v, q_v)
if op_str == "contains" or op_str == "any_contains":
items_to_search = q_v
if not isinstance(items_to_search, list):
items_to_search = [items_to_search]
if not items_to_search: # 如果搜索列表为空
return True if op_str == "contains" else False # AND(empty)=True, OR(empty)=False
match_results = []
for item in items_to_search:
found = False
if isinstance(v, list):
found = self._deep_equal_in(item, v)
elif isinstance(v, str) and isinstance(item, str):
found = item in v
elif isinstance(v, dict): # 检查item是否为dict中的值
found = self._deep_equal_in(item, list(v.values()))
match_results.append(found)
return all(match_results) if op_str == "contains" else any(match_results)
except TypeError: # 比较不兼容的类型
return False
except Exception:
return False
return False
def __generate_road(self, road):
path = "$"
for r in road:
if r.isdigit():
path += f"[{r}]"
else:
path += f".{r}"
return path
def __is_json(self, val: Any):
if isinstance(val, str):
val_str = val.strip()
if val_str.startswith(("{", "[")) and val_str.endswith(("}", "]")):
try:
json.loads(val_str)
return True, val
except json.JSONDecodeError:
pass
if isinstance(val, (dict, list)):
return True, json.dumps(val)
return False, val
def __navigate_path(self, json_data: Any, road: List[str]) -> tuple[Any, bool]:
current_val = json_data
for r in road:
if isinstance(current_val, dict):
if r in current_val:
current_val = current_val[r]
else:
return None, False
elif isinstance(current_val, list) and r.isdigit():
idx = int(r)
if 0 <= idx < len(current_val):
current_val = current_val[idx]
else:
return None, False
else:
return None, False
return current_val, True
def __compare_operator(self, key: str, compare: str, val: Any, is_json: bool, sp_compare: tuple):
def __contains_and_or(v: list, connector: str):
if not v:
return ("1=1" if connector == "AND" else "1=0"), []
conditions = []
params = []
for item in v:
if self._flag is True and not isinstance(item, (dict, list)):
# simple val
conditions.append(f"EXISTS(SELECT 1 FROM json_each({key}) WHERE value = ?)")
params.append(item)
else:
# other | complicated val
conditions.append(f"instr({key}, ?) > 0")
params.append(json.dumps(item))
return f" {connector} ".join(conditions), params
if compare in self.SIMPLE_OP:
if compare == "ne" and self._flag and is_json:
return f"NOT (json({key}) = json(?))", [val]
return f"{key} {self.SIMPLE_OP[compare]} ?", [val]
elif compare in self.LIKE_OP:
return f"{key} LIKE ?", [self.LIKE_OP[compare].format(val)]
elif compare in ("contains", "any_contains"):
if is_json:
try:
val = json.loads(val)
except (json.JSONDecodeError, TypeError):
pass
if not isinstance(val, list):
val = [val]
if compare == "contains":
return __contains_and_or(val, "AND")
else:
return __contains_and_or(val, "OR")
elif compare in ("in", "not_in"):
if is_json:
try:
val = json.loads(val)
except (json.JSONDecodeError, TypeError):
pass
if not isinstance(val, (list, tuple, set)):
raise HintException(f"{key}__{compare} expects a list/tuple/set")
val = list(val)
if len(val) == 0:
return ("1=0" if compare == "in" else "1=1"), []
placeholders = ", ".join(["?"] * len(val))
op = "IN" if compare == "in" else "NOT IN"
return f"{key} {op} ({placeholders})", val
self._q_error(key, compare, val, sp_compare)
return None, None
def __compare_reducer(self, key: str, compare: str, road: list, val: Any):
sp_compare = getattr(self._fields.get(key), "compare")
compare = compare.lower()
# 没路径, 正常拦截op
# 有路径, 最终字段类型是不确定的, 查询结果不可控
if not sp_compare or compare not in [c.lower() for c in sp_compare]:
if road:
self._q_error(f"{key}__{road}", compare, val, sp_compare)
else:
self._q_error(key, compare, val, sp_compare)
is_json, val = self.__is_json(val)
if not road:
sql, params = self.__compare_operator(
key=f"{self._tb}.{key}",
compare=compare,
val=val,
is_json=is_json,
sp_compare=sp_compare
)
return sql, params
else:
if self._flag:
sql, params = self.__compare_operator(
key=f"json_extract({self._tb}.{key}, '{self.__generate_road(road)}')",
compare=compare,
val=val,
is_json=is_json,
sp_compare=sp_compare
)
return sql, params
# 退化分支
if is_json:
try:
val = json.loads(val)
except:
pass
pk_name = self._model_class.__primary_key__
q_fork = self._query.fork()
q_fork._SqliteEasy__OPT_LIMIT.clear()
q_fork._SqliteEasy__OPT_ORDER.clear()
q_fork._SqliteEasy__OPT_GROUP.clear()
q_fork._SqliteEasy__OPT_HAVING.clear()
q_fork._SqliteEasy__OPT_FIELD.clear()
q_fork.field(f"{self._tb}.{pk_name}", f"{self._tb}.{key}")
db_rows = q_fork.select()
matching_pks = []
for row_dict in db_rows:
field_json_str = row_dict.get(key)
if field_json_str is None:
continue
if isinstance(field_json_str, str):
try:
current_data = json.loads(field_json_str)
except (json.JSONDecodeError, TypeError):
continue
elif isinstance(field_json_str, (dict, list)):
current_data = field_json_str
else:
continue
target_val, path_found = self.__navigate_path(current_data, road)
if path_found:
if self._python_compare(target_val, compare, val):
pk_value = row_dict.get(pk_name)
if pk_value is not None:
matching_pks.append(pk_value)
if matching_pks:
placeholders = ",".join(["?"] * len(matching_pks))
return f"{self._tb}.{pk_name} IN ({placeholders})", matching_pks
else:
return "1=0", []
def __equal_reducer(self, key: str, road: list, val: Any):
is_json, val = self.__is_json(val)
if not road: # normal field
if val is not None:
if self._flag is True and is_json:
return f"json({self._tb}.{key}) = json(?)", [val]
return f"{self._tb}.{key} = ?", [val]
return f"{self._tb}.{key} IS NULL", []
if self._flag:
path = self.__generate_road(road)
if val is not None:
if is_json:
return f"json_extract({self._tb}.{key}, ?) = json(?)", [path, val]
return f"json_extract({self._tb}.{key}, ?) = ?", [path, val]
return f"json_extract({self._tb}.{key}, ?) IS NULL", [path]
# 退化分支
pk_name = self._model_class.__primary_key__
q_fork = self._query.fork()
q_fork._SqliteEasy__OPT_LIMIT.clear()
q_fork._SqliteEasy__OPT_ORDER.clear()
q_fork._SqliteEasy__OPT_GROUP.clear()
q_fork._SqliteEasy__OPT_HAVING.clear()
q_fork._SqliteEasy__OPT_FIELD.clear()
q_fork.field(f"{self._tb}.{pk_name}", f"{self._tb}.{key}")
db_rows = q_fork.select()
matching_pks = []
new_val = json.loads(val) if is_json else val
for row in db_rows:
field_value = row.get(key)
if field_value is None:
continue
try:
if isinstance(field_value, str):
try:
field_value = json.loads(field_value)
except (json.JSONDecodeError, TypeError):
continue
elif isinstance(field_value, (dict, list)):
field_value = field_value
else:
continue
target, path_found = self.__navigate_path(field_value, road)
if path_found and self._deep_equal(target, new_val):
pk_value = row.get(pk_name)
if pk_value is not None:
matching_pks.append(pk_value)
except:
continue
if matching_pks:
placeholders = ",".join(["?"] * len(matching_pks))
return f"{self._tb}.{pk_name} IN ({placeholders})", matching_pks
return "1=0", []
def __parse_condition(self, condition: Dict[str, Any]):
"""
解析 key, compare, road, val
field 字段
compare 运算符, None为=
road 路径
val
"""
for k, v in condition.items():
parts = k.split("__")
field = parts[0]
if not field or not self._fields.get(field):
raise HintException("%s's fields is not found: '%s'" % (self._model_class.__name__, k))
compare = None
roads = []
for part in parts[1:]:
if part in COMPARE:
compare = part
break
roads.append(part)
yield field, compare, roads, v
def reducer_process(self, condition: Dict[str, Any]) -> Generator[tuple[str, list[Any] | Any], Any, None]:
for key, compare, road, val in self.__parse_condition(condition):
if self._serializes and key in self._serializes:
val = self._serializes[key].serialized(value=val, forward=True)
if not compare:
sql, params = self.__equal_reducer(key=key, road=road, val=val)
else:
if val is None:
raise HintException("do not try to use 'None' value to compare.")
sql, params = self.__compare_reducer(key=key, compare=compare, road=road, val=val)
if sql:
yield sql, params
class Q:
    """
    Composable query condition (Django-style).

    Children are plain ``{condition: value}`` dicts or nested ``Q`` objects,
    joined by this node's ``_connector`` (AND / OR). When nodes are combined
    with ``&`` / ``|``, an operand that already uses the same connector has
    its children spliced in rather than being nested.

    AND binds tighter than OR; parentheses change precedence.
    example: model.object.filter( Q(a=1) & (Q(b=2) | Q(c=3)) )
    """
    AND = "AND"
    OR = "OR"

    def __init__(self, *args, _connector=None, **kwargs):
        self.children: list = []
        self._connector = _connector if _connector else self.AND
        for operand in args:
            if not isinstance(operand, (Q, dict)):
                raise HintException(f"unsupported operand type(s) for Q: '{type(operand)}'")
            if isinstance(operand, Q) and operand._connector == self._connector:
                # same connector: flatten instead of nesting
                self.children.extend(operand.children)
            else:
                self.children.append(operand)
        if kwargs:
            self.children.append(kwargs)

    def _combine(self, other, connector, symbol):
        """Join ``self`` and ``other`` under ``connector``; ``symbol`` is only for the error text."""
        if not isinstance(other, Q):
            raise HintException(f"unsupported operand type(s) for {symbol}: 'Q' and '{type(other)}'")
        return Q(self, other, _connector=connector)

    def __and__(self, other):
        return self._combine(other, Q.AND, "&")

    def __or__(self, other):
        return self._combine(other, Q.OR, "|")

    def resolve(self, operator, query):
        """
        Emit this node's conditions onto ``query``.

        Dict children go through ``operator.reducer_process`` and are added with
        ``query.where``. A child Q with the same connector is resolved inline.
        A child with a different connector is resolved inside
        ``query.where_nest(logic=<this node's connector>)``.
        """
        for child in self.children:
            if isinstance(child, dict):
                for sql, params in operator.reducer_process(child):
                    if sql:
                        query.where(sql, params)
                continue
            if not isinstance(child, Q):
                raise HintException(f"Invalid child type: {type(child)}")
            if child._connector == self._connector:
                child.resolve(operator, query)
            else:
                with query.where_nest(logic=self._connector) as nested:
                    child.resolve(operator, nested)
class QuerySet(Generic[M]):
    """
    Lazy, chainable query over a single model's table.

    Builder methods (``filter``, ``limit``, ``offset``, ``order_by``,
    ``fields``) modify the wrapped query in place and return ``self``.
    Rows are fetched on first iteration / ``as_list`` / unbounded slicing and
    cached in ``_cache``; ``update`` and ``delete`` clear the cache.
    """

    def __init__(self, model_class: Type[M], query: "Db.query"):
        self._model_class: Type[M] = model_class
        self._tb = self._model_class.__table_name__
        self._query = query
        self._cache = None  # list of model instances once executed
        self._field_filter = None  # field names chosen with fields()

    def __len__(self):
        """Length of the executed result; use ``count()`` before execution."""
        # an executed but empty result ([]) still has a valid length
        if self._cache is not None:
            return len(self._cache)
        raise RuntimeError("QuerySet is not executed, use count() instead")

    def __bool__(self):
        if self._cache is not None:
            return bool(self._cache)
        return self.exists()

    def __iter__(self) -> Generator[M, None, None]:
        yield from self.__execute()

    def __getitem__(self, index: Optional[int | slice]) -> Optional[M | List[M]]:
        """
        Index or slice the result (negative indexes/starts are not supported).

        ``qs[i]`` returns one model or None. A bounded slice ``qs[a:b]`` is
        pushed down as OFFSET/LIMIT. Open-ended, end-relative or stepped
        slices, and any slice of an already executed queryset, are applied to
        the cached result with normal Python slice semantics.
        """
        if isinstance(index, int):
            if index < 0:
                raise HintException("index is not supported")
            if self._cache is not None:
                try:
                    return self._cache[index]
                except IndexError:
                    raise HintException("list index out of range")
            new_q = self._clone_q.limit(1).skip(index)
            temp = new_q.find()
            return self._gen_M(temp) if temp else None
        if isinstance(index, slice):
            start = index.start or 0
            if start < 0:
                raise HintException("start index is not supported")
            stop, step = index.stop, index.step
            if self._cache is not None or stop is None or stop < 0 or step not in (None, 1):
                self.__execute()
                return self._cache[index]
            if stop <= start:
                # empty window: never send LIMIT 0 / negative limits to the DB
                return []
            q = self._clone_q.skip(start).limit(stop - start)
            return [self._gen_M(r) for r in q.select() or []]
        return None

    def __add__(self, other: "QuerySet") -> Iterable:
        """
        Concatenate two querysets of the same model.

        :return: an iterator over both results (each queryset is executed).
        """
        if not isinstance(other, QuerySet):
            raise HintException(f"not support: 'QuerySet' and '{type(other)}'")
        if self._model_class != other._model_class:
            raise HintException("not the same model class cant be merged")
        return chain(self.__execute() or [], other.__execute() or [])

    @property
    def _clone_q(self) -> "Db.query":
        """A fork of the wrapped query that can be modified without side effects."""
        return self._query.fork()

    def _gen_M(self, data) -> M:
        """Build a model instance from a raw row, honouring the field filter."""
        return self._model_class(
            _field_filter=self._field_filter,
            **self._model_class._serialized_data(data, self._field_filter)
        )

    def __execute(self) -> Optional[List[M]]:
        """Run the query once and cache the resulting model instances."""
        if self._cache is None:
            try:
                if len(self._query._SqliteEasy__OPT_FIELD._Field__FIELDS) == 0:
                    self._query.field(f"`{self._tb}`.*")
                self._cache = [
                    self._gen_M(i) for i in self._query.select() or []
                ]
            except Exception as e:
                print("db query error => %s" % str(e))
                raise HintException(e) from e
        return self._cache

    def filter(self, *args, **kwargs) -> "QuerySet[M]":
        """
        Add conditions: Q objects and/or condition dicts as positional
        arguments, plus keyword conditions. Everything is ANDed together.

        :return: QuerySet
        """
        operator = Operator(model_class=self._model_class, query=self._query)
        for cond in args:
            if isinstance(cond, Q):
                if cond._connector == Q.OR:
                    # Parenthesise a top-level OR group; otherwise it would mix
                    # with conditions added before/after it (AND binds tighter).
                    with self._query.where_nest(logic=Q.AND) as nest:
                        cond.resolve(operator, nest)
                else:
                    cond.resolve(operator, self._query)
            elif isinstance(cond, dict):
                for s, p in operator.reducer_process(cond):
                    if s:
                        self._query.where(s, p)
            else:
                raise HintException(f"Invalid filter argument: {type(cond)}")
        for s, p in operator.reducer_process(kwargs):
            if s:
                self._query.where(s, p)
        return self

    def limit(self, num: int) -> "QuerySet[M]":
        """
        Limit the number of rows.
        :return: QuerySet
        """
        self._query.limit(num)
        return self

    def offset(self, num: int) -> "QuerySet[M]":
        """
        Skip the first ``num`` rows.
        :return: QuerySet
        """
        self._query.skip(num)
        return self

    def distinct(self) -> "QuerySet[M]":
        """
        De-duplicate by the given fields.
        """
        # TODO: not implemented yet; returns the queryset unchanged
        return self

    def order_by(self, *args) -> "QuerySet[M]":
        """
        Order the result.
        :param args: "field" for ASC, "-field" for DESC
        :return: QuerySet
        """
        q = self._query
        for name in args:
            if name[:1] == "-":
                q = q.order(f"{self._tb}.{name[1:]}", "DESC")
            else:
                q = q.order(f"{self._tb}.{name}")
        return self

    def values(self, *args) -> "QuerySet[M]":
        # TODO: not implemented
        raise NotImplementedError("values")

    def fields(self, *args) -> "QuerySet[M]":
        """
        Restrict the selected columns to ``args`` (unknown names are dropped).

        The primary key is always included so instances stay addressable.
        """
        if not args:
            return self
        field_set = set(args)
        field_set.add(self._model_class.__primary_key__)
        known = self._model_class._get_fields()
        field_set = [f for f in field_set if f in known]
        self._field_filter = field_set  # model level
        self._query.field(*(f"{self._tb}.{f}" for f in field_set))  # db level
        return self

    def first(self) -> "Optional[M]":
        """
        Return the first row as a model instance, or None.
        """
        if self._cache is not None:
            return self._cache[0] if self._cache else None
        # use a fork so the shared query is not modified here
        # (matches exists() and __getitem__)
        q = self._clone_q
        if len(q._SqliteEasy__OPT_FIELD._Field__FIELDS) == 0:
            q.field(f"`{self._tb}`.*")
        data = q.find()
        return self._gen_M(data) if data else None

    def get_field(self, key_name: str) -> Optional[Any]:
        """
        Return the value of ``key_name`` in the first row (None if there is no row).
        """
        f = self.first()
        return f.as_dict().get(key_name) if f else None

    def update(self, *args, **kwargs) -> int:
        """
        Update the matching rows.

        Accepts either one dict positional argument or keyword values, not
        both. If the model's ``update_time`` field has ``auto_now`` set and no
        value is given, the current time is filled in.

        :return: int, the count returned by the underlying ``update``.
        """
        self._cache = None
        if args and kwargs:
            raise HintException("args and kwargs can not be used at the same time")
        if args:
            if len(args) != 1:
                raise HintException("%s too many args" % (args,))
            if not isinstance(args[0], dict):
                raise HintException("%s must be a dict" % (args[0],))
            # copy so the auto-filled update_time never leaks into the caller's dict
            target = dict(args[0])
        else:
            target = dict(kwargs)
        if not target:
            return 0
        if "update_time" not in target:
            try:
                fields = getattr(self._model_class, "__fields__", {}) or {}
                if getattr(fields.get("update_time"), "auto_now", False) is True:
                    import time
                    from public.aaModel.fields import ACCURACY
                    target["update_time"] = round(time.time() * ACCURACY)
            except Exception:
                # best-effort as before: the timestamp must not block the update
                pass
        serlz = self._model_class._get_serialized()
        body = {
            k: (serlz[k].serialized(v, True) if k in serlz else v) for k, v in target.items()
        }
        return self._query.update(body)

    def delete(self) -> int:
        """
        Delete the matching rows.
        :return: int
        """
        self._cache = None
        return self._query.delete()

    def exists(self) -> bool:
        """
        Whether at least one row matches (selects only the primary key, LIMIT 1).
        """
        q_fk = self._clone_q
        q_fk.field(self._model_class.__primary_key__)
        q_fk.limit(1)
        return bool(q_fk.find())

    def count(self) -> int:
        """
        Number of matching rows; ignores limit/order/group/having.
        Uses the cache if the queryset was already executed.
        """
        if self._cache is not None:
            return len(self._cache)
        q = self._clone_q
        q._SqliteEasy__OPT_LIMIT.clear()
        q._SqliteEasy__OPT_ORDER.clear()
        q._SqliteEasy__OPT_GROUP.clear()
        q._SqliteEasy__OPT_HAVING.clear()
        return q.count()

    def as_list(self) -> list:
        """
        Execute (if needed) and return the rows as a list of dicts.
        """
        self.__execute()
        return [x.as_dict() for x in self._cache]
class aaObjects(Generic[M]):
    """
    Per-model manager: entry point for inserts and queries.

    Creating a manager runs the table/index migration for its model, at most
    once per (database, table) per process; migrated tables are tracked in
    ``__m_map__``.
    """
    _queryset_class = QuerySet
    __m_map__ = {
        "default": set(),  # tables migrated in default.db
        "db": dict(),  # {db_file: set of tables migrated in that file}
    }

    def __new__(cls, args):
        if hasattr(args, "__table_name__"):
            if args.__db_name__ == "default":
                migrated = cls.__m_map__["default"]
            else:
                migrated = cls.__m_map__["db"].setdefault(args.__db_name__, set())
            # Track per table, not per db file: previously only the first model
            # of a non-default db file was migrated, leaving later tables missing.
            if args.__table_name__ not in migrated:
                migrated.add(aaMigrate(args).run_migrate())
        return super(aaObjects, cls).__new__(cls)

    def __init__(self, model: Type[M]):
        self._model = model
        self.__q = None

    def _get_queryset(self) -> "QuerySet[M]":
        """Return a fresh QuerySet over a fork of this manager's base query."""
        return self._queryset_class(self._model, self._query.fork())

    @property
    def _query(self) -> "Db.query":
        """Lazily opened base query bound to the model's table."""
        if not self.__q:
            q = db_file_connect(self._model.__db_name__).query()
            self.__q = q.table(self._model.__table_name__)
        return self.__q

    def _insert(self, val_data) -> int:
        return self._query.insert(val_data)

    def _update(self, cdt: dict, val_data: dict) -> int:
        """Update rows whose columns equal every ``cdt`` value (ANDed); 0 if either dict is empty."""
        if not cdt or not val_data:
            return 0
        q = self._query.fork()
        conditions = [f"`{k}` = ?" for k in cdt]
        params = list(cdt.values())
        q.where(" AND ".join(conditions), params)
        return q.update(val_data)

    def insert(self, data: Dict[str, Any], raise_exp: bool = True) -> dict:
        """
        Insert a single row.
        :data dict
        :raise_exp bool  raise field-validation errors
        :return the inserted data, including the primary key
        """
        model_obj = self._model(**data)
        insert_res = self._insert(
            model_obj._validate(raise_exp=raise_exp)
        )
        if insert_res:
            return {
                self._model.__primary_key__: insert_res, **model_obj.as_dict()
            }
        if raise_exp:
            raise HintException(insert_res)
        return {}

    def insert_many(self, data: List[Dict[str, Any]], raise_exp: bool = True) -> int:
        """
        Insert many rows.
        :data list of dicts (non-dict / empty entries are skipped)
        :raise_exp bool  if False, rows that fail validation are skipped
        :return: int affected rows
        """
        valid_list = []
        for i in data:
            if i and isinstance(i, dict):
                temp = self._model(**i)._validate(raise_exp=raise_exp)
                if temp:
                    valid_list.append(temp)
        if not valid_list:
            return 0
        return self._query.insert_all(valid_list)

    def find_one(self, **kwargs) -> Optional[M]:
        """
        Return the first row matching ``kwargs``.
        :kwargs dict
        :return: model instance | None
        """
        return self._get_queryset().filter(**kwargs).first()

    def filter(self, *args, **kwargs) -> "QuerySet[M]":
        """
        Filter rows.
        :kwargs dict
        :return: QuerySet
        """
        return self._get_queryset().filter(*args, **kwargs)

    def all(self) -> "QuerySet[M]":
        """
        All rows.
        return: QuerySet
        """
        return self._get_queryset()
class aaMigrate:
    """
    Synchronize a model's table columns and indexes with its field definitions.
    """
    NULL_MAP = {False: "NOT NULL", True: "NULL"}

    def __init__(self, model: M):
        # Validate before reading model attributes, so a misconfigured model
        # fails with a descriptive PanelError instead of an AttributeError.
        self.__check_model(model)
        self.__model = model
        self.__table = self.__model.__table_name__
        self.__fields = self.__model._get_fields()
        self.__client = None
        self.__query = None

    @staticmethod
    def __check_model(model) -> None:
        """Raise PanelError if ``model`` is missing or lacks the required attributes."""
        if not model:
            raise PanelError("Model is None")
        # the model is a class: report its own name, not its metaclass's
        name = getattr(model, "__name__", type(model).__name__)
        for attr, label in (
            ("__db_name__", "db_name"),
            ("__table_name__", "table_name"),
            ("__fields__", "fields"),
        ):
            if not hasattr(model, attr):
                raise PanelError(f"{name} need '{label}'")

    def run_migrate(self) -> str | None:
        """
        Create or alter the table and sync its indexes.

        :return: the model's table name.
        :raises PanelError: wrapping any error from the migration.
        """
        try:
            self.__client = db_file_connect(self.__model.__db_name__)
            self.__table_exists()
            self.__index_exists()
        except Exception as e:
            raise PanelError(e) from e
        finally:
            if self.__query:
                self.__query.close()
            if self.__client:
                self.__client.close()
        return self.__model.__table_name__

    def __new_tb_transform_sql(self, tb_name: str) -> str:
        """
        Build the CREATE TABLE statement for the model's fields under ``tb_name``.

        :return: the SQL, or "" if the model has no fields.
        :raises PanelError: for a field named ``index``, a non-INTEGER primary
            key, or a primary-key count other than one.
        """
        columns = []
        pk_flag = 0
        for key, val in self.__fields.items():
            if key == "index":
                raise PanelError("'%s' is a reserved word in SQL. do not use it" % key)
            if val.primary_key is False:
                columns.append(f"`{key}` {val.field_type} {self.NULL_MAP.get(val.null)} {val.default_val_sql}")
            else:
                pk_flag += 1
                if val.field_type != "INTEGER":
                    raise PanelError("'primary_key' only support IntegerField now")
                columns.append(f"`{key}` {val.field_type} PRIMARY KEY AUTOINCREMENT")
        if not columns:
            return ""
        if pk_flag != 1:
            raise PanelError("primary_key not found, and must be only one")
        return f"""CREATE TABLE IF NOT EXISTS `{tb_name}` ({", ".join(columns)});"""

    def __fields_exist(self, add_fields_map: dict = None, del_fields: set = None, set_db: set = None) -> None:
        """
        Apply column changes to an existing table.

        Only additions: one ``ALTER TABLE ... ADD COLUMN`` per new field.
        Any removal: create a temp table from the current model, copy the
        surviving columns, drop the old table and rename the temp table,
        inside one transaction. A renamed field is seen as drop + add, so its
        old data is lost.
        """
        if not del_fields:
            for k, v in (add_fields_map or {}).items():
                add_sql = (f"ALTER TABLE `{self.__table}` "
                           f"ADD COLUMN `{k}` {v.field_type} {v.default_val_sql} {self.NULL_MAP.get(v.null)};")
                self.__query.execute(add_sql)
            return
        if not set_db:
            return
        temp_tb = f"table_{uuid.uuid4().hex}"
        new = self.__new_tb_transform_sql(temp_tb)
        if not new:
            return
        keep_cols = [f"`{k}`" for k in set_db if k not in del_fields]
        try:
            self.__query.autocommit(autocommit=False)
            self.__query.execute("BEGIN;")
            self.__query.execute(new)
            if keep_cols:
                # skip the copy when no old column survives ("INSERT ... ()" is invalid SQL)
                format_keys = ", ".join(keep_cols)
                copy_sql = (f"INSERT INTO `{temp_tb}` ({format_keys}) "
                            f"SELECT {format_keys} FROM `{self.__table}`;")
                self.__query.execute(copy_sql)
            self.__query.execute(f"DROP TABLE IF EXISTS `{self.__table}`;")
            self.__query.execute(f"ALTER TABLE `{temp_tb}` RENAME TO `{self.__table}`;")
            self.__query.commit()
        except Exception:
            import traceback
            print(traceback.format_exc())
            self.__query.rollback()
            raise

    def __table_exists(self) -> None:
        """
        Create the table if it is missing; otherwise reconcile its columns.
        """
        self.__query = self.__client.query().table("sqlite_master")
        if self.__query.where("type=? AND name=?", ("table", self.__table)).count() != 1:
            sql = self.__new_tb_transform_sql(self.__table)
            if sql:
                self.__query.execute(sql)
            return
        self.__query.table(self.__table)
        set_cur = set(self.__fields.keys())
        set_db = set(self.__query.get_columns())
        add_fields = set_cur - set_db
        del_fields = set_db - set_cur
        add_fields_map = {k: v for k, v in self.__fields.items() if k in add_fields}
        self.__fields_exist(add_fields_map, del_fields, set_db)

    def __trans_index_key(self, index_info: tuple | str) -> str:
        """
        Turn one ``__index_keys__`` entry (a field name or a tuple of names)
        into a backquoted column list for CREATE INDEX.

        :raises PanelError: for unknown fields or an unsupported entry type.
        """
        def __if_raise_error(item: str):
            if not self.__model._get_fields().get(item):
                raise PanelError(f"create index error, '{item}' is not in model's fields")

        if isinstance(index_info, tuple):
            for item in index_info:
                __if_raise_error(item)
            return ",".join(f"`{item}`" for item in index_info)
        if isinstance(index_info, str):
            __if_raise_error(index_info)
            return f"`{index_info}`"
        raise PanelError("model's index error, should be like ['key1', ('key2', 'key3')]")

    def __index_exists(self) -> bool | None:
        """
        Sync indexes with ``__index_keys__``: create missing ``idx_<table>_<cols>``
        indexes and drop other CREATE INDEX-origin indexes.

        Best-effort: failures are reported and ``None`` is returned.
        """
        try:
            if not hasattr(self.__model, "__index_keys__"):
                return True
            self.__query.table(self.__table)
            cur = self.__query.query(f"PRAGMA index_list(`{self.__table}`);") or []
            # origin "c" = created with CREATE INDEX (not UNIQUE/PK constraints)
            current_index = [
                x.get("name") for x in cur if str(x.get("origin", "c")).lower() == "c"
            ] if cur else []
            sql_statements = []
            wanted = set()
            for index_info in self.__model.__index_keys__:
                # TODO, depending on the field type behind index_info:
                #   non-JSON field: ordinary index
                #   list value: separate table, index, CRUD triggers, existence sync
                #   dict value: index on a virtual column for the JSON path
                #   tuple index_info: composite index
                #   when an index is removed, revisit the steps above
                col_sql = self.__trans_index_key(index_info)
                cols = [col.strip("` ") for col in col_sql.split(",")]
                index_name = f"idx_{self.__table}_{'_'.join(cols)}"
                wanted.add(index_name)
                if index_name not in current_index:
                    sql_statements.append(
                        f"CREATE INDEX IF NOT EXISTS `{index_name}` ON `{self.__table}` ({col_sql}); "
                    )
            for index in current_index:
                if index not in wanted:
                    sql_statements.append(f"DROP INDEX IF EXISTS `{index}`;")
            if sql_statements:
                self.__query.execute_script(
                    " ".join(sql_statements)
                )
            return True
        except Exception as e:
            # Kept best-effort (as before) so index problems do not abort the
            # table migration, but no longer silent.
            print("index migrate error => %s" % e)
            return None
class aaManager:
    """
    Descriptor that hands out a lazily built, per-model-class manager.

    Presumably assigned as a class attribute on a model (e.g.
    ``objects = aaManager()``) — confirm against aaModel. Access is only
    allowed through the class; instance access raises AttributeError.
    """

    def __init__(self, obj_cls=aaObjects, qs_cls=QuerySet):
        """
        :param obj_cls: manager class instantiated once per model class.
        :param qs_cls: QuerySet class the manager should produce.
        """
        self._objects_class = obj_cls
        self._queryset_class = qs_cls
        self._cache = {}  # {model class: manager}

    def __get__(self, instance, cls: Type[M]):
        if instance is not None:
            raise AttributeError(
                f"object manager can't accessible from '{cls.__name__}' instances"
            )
        manager = self._cache.get(cls)
        if manager is not None:
            return manager
        try:
            # building the manager runs the model's migration and may fail
            manager = self._objects_class(cls)
        except Exception as exc:
            import traceback
            raise PanelError(traceback.format_exc()) from exc
        setattr(manager, "_queryset_class", self._queryset_class)
        self._cache[cls] = manager
        return manager