# coding: utf-8 import os import sqlite3 as Engine import uuid from functools import reduce from itertools import chain from typing import Optional, TypeVar, Generic, Any, List, Dict, Generator, Iterable, Type try: import ujson as json except ImportError: try: os.system("btpip install ujson") import ujson as json except: import json from public.aaModel.fields import COMPARE from public.exceptions import HintException, PanelError from public.sqlite_easy import Db __all__ = ["aaManager", "Q"] # noinspection PyUnresolvedReferences M = TypeVar("M", bound="aaModel") # ==================== Patch ================== def _builtin(check_engine: Any = None) -> bool: if not check_engine: check_engine = Engine try: conn = check_engine.connect(":memory:") cursor = conn.cursor() cursor.execute("SELECT json_extract('{\"a\": 1}', '$.a')") cursor.execute("SELECT COUNT(*) FROM json_each('{\"a\":1, \"b\":2}')") conn.close() return True except: return False def _get_engine() -> tuple[bool, Any]: try: import pysqlite3 as engine flag = True except: try: os.system("btpip install pysqlite3-binary") import pysqlite3 as engine flag = True except: engine = Engine flag = False return flag, engine _ENGINE = None _INSTEAD = False _ORG = _builtin() if not _ORG: _INSTEAD, _ENGINE = _get_engine() else: _ENGINE = Engine # ==================== Patch End ================== def db_file_connect(db_name: str) -> Db: if not db_name: raise PanelError("db_file_connect_generator error, db_name is empty") if db_name == "default": return Db(db_name=db_name, engine=_ENGINE) else: db_root_dir = "/www/server/panel/data/db" os.makedirs(db_root_dir, exist_ok=True) db_path = os.path.join(db_root_dir, f"{db_name}") new = Db(db_name="", engine=_ENGINE, auto_connect=False) setattr(new, "_Db__DB_NAME", db_path) new._connect() return new class Operator: SIMPLE_OP = { "gt": ">", "lt": "<", "gte": ">=", "lte": "<=", "ne": "!=", } LIKE_OP = { "like": "%{}%", "startswith": "{}%", "endswith": "%{}", } def __init__(self, model_class: M, query: "Db.query"): self._model_class: M = model_class self._query = query self._tb = self._model_class.__table_name__ self._fields = self._model_class._get_fields() self._serializes = self._model_class._get_serialized() self._flag = _ORG or _INSTEAD # fk flag def _q_error(self, key: str, act: str, val: Any, sp_act: tuple): raise HintException( "field: '%s' is not support '%s', you can use: %s" % (key, act, sp_act) ) def _deep_equal(self, obj1: Any, obj2: Any) -> bool: if type(obj1) != type(obj2): return False if isinstance(obj1, dict): if set(obj1.keys()) != set(obj2.keys()): return False return all(self._deep_equal(obj1[k], obj2[k]) for k in obj1) if isinstance(obj1, list): if len(obj1) != len(obj2): return False return all(self._deep_equal(i1, i2) for i1, i2 in zip(obj1, obj2)) return obj1 == obj2 def _deep_equal_in(self, item: Any, container_list: List[Any]) -> bool: for elem in container_list: if self._deep_equal(item, elem): return True return False def _python_compare(self, v: Any, op_str: str, q_v: Any) -> bool: try: if op_str == "gt": return v > q_v if op_str == "lt": return v < q_v if op_str == "gte": return v >= q_v if op_str == "lte": return v <= q_v if op_str == "ne": return not self._deep_equal(v, q_v) if op_str == "like": return isinstance(v, str) and isinstance(q_v, str) and q_v in v if op_str == "startswith": return isinstance(v, str) and isinstance(q_v, str) and v.startswith(q_v) if op_str == "endswith": return isinstance(v, str) and isinstance(q_v, str) and v.endswith(q_v) if op_str == "in": return isinstance(q_v, list) and self._deep_equal_in(v, q_v) if op_str == "not_in": return isinstance(q_v, list) and not self._deep_equal_in(v, q_v) if op_str == "contains" or op_str == "any_contains": items_to_search = q_v if not isinstance(items_to_search, list): items_to_search = [items_to_search] if not items_to_search: # 如果搜索列表为空 return True if op_str == "contains" else False # AND(empty)=True, OR(empty)=False match_results = [] for item in items_to_search: found = False if isinstance(v, list): found = self._deep_equal_in(item, v) elif isinstance(v, str) and isinstance(item, str): found = item in v elif isinstance(v, dict): # 检查item是否为dict中的值 found = self._deep_equal_in(item, list(v.values())) match_results.append(found) return all(match_results) if op_str == "contains" else any(match_results) except TypeError: # 比较不兼容的类型 return False except Exception: return False return False def __generate_road(self, road): path = "$" for r in road: if r.isdigit(): path += f"[{r}]" else: path += f".{r}" return path def __is_json(self, val: Any): if isinstance(val, str): val_str = val.strip() if val_str.startswith(("{", "[")) and val_str.endswith(("}", "]")): try: json.loads(val_str) return True, val except json.JSONDecodeError: pass if isinstance(val, (dict, list)): return True, json.dumps(val) return False, val def __navigate_path(self, json_data: Any, road: List[str]) -> tuple[Any, bool]: current_val = json_data for r in road: if isinstance(current_val, dict): if r in current_val: current_val = current_val[r] else: return None, False elif isinstance(current_val, list) and r.isdigit(): idx = int(r) if 0 <= idx < len(current_val): current_val = current_val[idx] else: return None, False else: return None, False return current_val, True def __compare_operator(self, key: str, compare: str, val: Any, is_json: bool, sp_compare: tuple): def __contains_and_or(v: list, connector: str): if not v: return ("1=1" if connector == "AND" else "1=0"), [] conditions = [] params = [] for item in v: if self._flag is True and not isinstance(item, (dict, list)): # simple val conditions.append(f"EXISTS(SELECT 1 FROM json_each({key}) WHERE value = ?)") params.append(item) else: # other | complicated val conditions.append(f"instr({key}, ?) > 0") params.append(json.dumps(item)) return f" {connector} ".join(conditions), params if compare in self.SIMPLE_OP: if compare == "ne" and self._flag and is_json: return f"NOT (json({key}) = json(?))", [val] return f"{key} {self.SIMPLE_OP[compare]} ?", [val] elif compare in self.LIKE_OP: return f"{key} LIKE ?", [self.LIKE_OP[compare].format(val)] elif compare in ("contains", "any_contains"): if is_json: try: val = json.loads(val) except (json.JSONDecodeError, TypeError): pass if not isinstance(val, list): val = [val] if compare == "contains": return __contains_and_or(val, "AND") else: return __contains_and_or(val, "OR") elif compare in ("in", "not_in"): if is_json: try: val = json.loads(val) except (json.JSONDecodeError, TypeError): pass if not isinstance(val, (list, tuple, set)): raise HintException(f"{key}__{compare} expects a list/tuple/set") val = list(val) if len(val) == 0: return ("1=0" if compare == "in" else "1=1"), [] placeholders = ", ".join(["?"] * len(val)) op = "IN" if compare == "in" else "NOT IN" return f"{key} {op} ({placeholders})", val self._q_error(key, compare, val, sp_compare) return None, None def __compare_reducer(self, key: str, compare: str, road: list, val: Any): sp_compare = getattr(self._fields.get(key), "compare") compare = compare.lower() # 没路径, 正常拦截op # 有路径, 最终字段类型是不确定的, 查询结果不可控 if not sp_compare or compare not in [c.lower() for c in sp_compare]: if road: self._q_error(f"{key}__{road}", compare, val, sp_compare) else: self._q_error(key, compare, val, sp_compare) is_json, val = self.__is_json(val) if not road: sql, params = self.__compare_operator( key=f"{self._tb}.{key}", compare=compare, val=val, is_json=is_json, sp_compare=sp_compare ) return sql, params else: if self._flag: sql, params = self.__compare_operator( key=f"json_extract({self._tb}.{key}, '{self.__generate_road(road)}')", compare=compare, val=val, is_json=is_json, sp_compare=sp_compare ) return sql, params # 退化分支 if is_json: try: val = json.loads(val) except: pass pk_name = self._model_class.__primary_key__ q_fork = self._query.fork() q_fork._SqliteEasy__OPT_LIMIT.clear() q_fork._SqliteEasy__OPT_ORDER.clear() q_fork._SqliteEasy__OPT_GROUP.clear() q_fork._SqliteEasy__OPT_HAVING.clear() q_fork._SqliteEasy__OPT_FIELD.clear() q_fork.field(f"{self._tb}.{pk_name}", f"{self._tb}.{key}") db_rows = q_fork.select() matching_pks = [] for row_dict in db_rows: field_json_str = row_dict.get(key) if field_json_str is None: continue if isinstance(field_json_str, str): try: current_data = json.loads(field_json_str) except (json.JSONDecodeError, TypeError): continue elif isinstance(field_json_str, (dict, list)): current_data = field_json_str else: continue target_val, path_found = self.__navigate_path(current_data, road) if path_found: if self._python_compare(target_val, compare, val): pk_value = row_dict.get(pk_name) if pk_value is not None: matching_pks.append(pk_value) if matching_pks: placeholders = ",".join(["?"] * len(matching_pks)) return f"{self._tb}.{pk_name} IN ({placeholders})", matching_pks else: return "1=0", [] def __equal_reducer(self, key: str, road: list, val: Any): is_json, val = self.__is_json(val) if not road: # normal field if val is not None: if self._flag is True and is_json: return f"json({self._tb}.{key}) = json(?)", [val] return f"{self._tb}.{key} = ?", [val] return f"{self._tb}.{key} IS NULL", [] if self._flag: path = self.__generate_road(road) if val is not None: if is_json: return f"json_extract({self._tb}.{key}, ?) = json(?)", [path, val] return f"json_extract({self._tb}.{key}, ?) = ?", [path, val] return f"json_extract({self._tb}.{key}, ?) IS NULL", [path] # 退化分支 pk_name = self._model_class.__primary_key__ q_fork = self._query.fork() q_fork._SqliteEasy__OPT_LIMIT.clear() q_fork._SqliteEasy__OPT_ORDER.clear() q_fork._SqliteEasy__OPT_GROUP.clear() q_fork._SqliteEasy__OPT_HAVING.clear() q_fork._SqliteEasy__OPT_FIELD.clear() q_fork.field(f"{self._tb}.{pk_name}", f"{self._tb}.{key}") db_rows = q_fork.select() matching_pks = [] new_val = json.loads(val) if is_json else val for row in db_rows: field_value = row.get(key) if field_value is None: continue try: if isinstance(field_value, str): try: field_value = json.loads(field_value) except (json.JSONDecodeError, TypeError): continue elif isinstance(field_value, (dict, list)): field_value = field_value else: continue target, path_found = self.__navigate_path(field_value, road) if path_found and self._deep_equal(target, new_val): pk_value = row.get(pk_name) if pk_value is not None: matching_pks.append(pk_value) except: continue if matching_pks: placeholders = ",".join(["?"] * len(matching_pks)) return f"{self._tb}.{pk_name} IN ({placeholders})", matching_pks return "1=0", [] def __parse_condition(self, condition: Dict[str, Any]): """ 解析 key, compare, road, val field 字段 compare 运算符, None为= road 路径 val 值 """ for k, v in condition.items(): parts = k.split("__") field = parts[0] if not field or not self._fields.get(field): raise HintException("%s's fields is not found: '%s'" % (self._model_class.__name__, k)) compare = None roads = [] for part in parts[1:]: if part in COMPARE: compare = part break roads.append(part) yield field, compare, roads, v def reducer_process(self, condition: Dict[str, Any]) -> Generator[tuple[str, list[Any] | Any], Any, None]: for key, compare, road, val in self.__parse_condition(condition): if self._serializes and key in self._serializes: val = self._serializes[key].serialized(value=val, forward=True) if not compare: sql, params = self.__equal_reducer(key=key, road=road, val=val) else: if val is None: raise HintException("do not try to use 'None' value to compare.") sql, params = self.__compare_reducer(key=key, compare=compare, road=road, val=val) if sql: yield sql, params class Q: """ 嵌套查询 AND优先级大于OR, 括号改变优先级 example: model.object.filter( Q(a=1) & (Q(b=2) | Q(c=3)) ) """ AND = "AND" OR = "OR" def __init__(self, *args, _connector=None, **kwargs): self.children: list = [] self._connector = _connector or self.AND for arg in args: if isinstance(arg, Q) and arg._connector == self._connector: self.children.extend(arg.children) elif isinstance(arg, (Q, dict)): self.children.append(arg) else: raise HintException(f"unsupported operand type(s) for Q: '{type(arg)}'") if kwargs: self.children.append(kwargs) def __and__(self, other): if not isinstance(other, Q): raise HintException(f"unsupported operand type(s) for &: 'Q' and '{type(other)}'") return Q(self, other, _connector=Q.AND) def __or__(self, other): if not isinstance(other, Q): raise HintException(f"unsupported operand type(s) for |: 'Q' and '{type(other)}'") return Q(self, other, _connector=Q.OR) def resolve(self, operator, query): for child in self.children: if isinstance(child, dict): for s, p in operator.reducer_process(child): if s: query.where(s, p) elif isinstance(child, Q): if child._connector == self._connector: child.resolve(operator, query) else: with query.where_nest(logic=self._connector) as n: child.resolve(operator, n) else: raise HintException(f"Invalid child type: {type(child)}") class QuerySet(Generic[M]): """ 查询集 """ def __init__(self, model_class: Type[M], query: "Db.query"): self._model_class: Type[M] = model_class self._tb = self._model_class.__table_name__ self._query = query self._cache = None self._field_filter = None def __len__(self): if self._cache: return len(self._cache) raise RuntimeError("QuerySet is not executed, use count() instead") def __bool__(self): if self._cache is not None: return bool(self._cache) return self.exists() def __iter__(self) -> Generator[M, None, None]: yield from self.__execute() def __getitem__(self, index: Optional[int | slice]) -> Optional[M | List[M]]: """ 查询结果切片 """ if isinstance(index, int): if index < 0: raise HintException("index is not supported") if self._cache is not None: try: return self._cache[index] except IndexError: raise HintException("list index out of range") else: new_q = self._clone_q.limit(1).skip(index) temp = new_q.find() return self._gen_M(temp) if temp else None elif isinstance(index, slice): start = index.start or 0 if start < 0: raise HintException("start index is not supported") if index.stop is None: # not stop, get all self.__execute() return self._cache[start: index.stop] limit = max(0, index.stop - start) q = self._clone_q.skip(start).limit(limit) return [ self._gen_M(r) for r in q.select() or [] ] return None def __add__(self, other: "QuerySet") -> Iterable: """ 合并两个querset :return: 生成器 """ if not isinstance(other, QuerySet): raise HintException(f"nou support: 'QuerySet' and '{type(other)}'") if self._model_class != other._model_class: raise HintException("not the same model class cant be merged") return chain(self.__execute() or [], other.__execute() or []) @property def _clone_q(self) -> "Db.query": return self._query.fork() def _gen_M(self, data) -> M: return self._model_class( _field_filter=self._field_filter, **self._model_class._serialized_data(data, self._field_filter) ) def __execute(self) -> Optional[List[M]]: if self._cache is None: try: if len(self._query._SqliteEasy__OPT_FIELD._Field__FIELDS) == 0: self._query.field(f"`{self._tb}`.*") self._cache = [ self._gen_M(i) for i in self._query.select() or [] ] except Exception as e: print("db query error => %s" % str(e)) raise HintException(e) return self._cache def filter(self, *args, **kwargs) -> "QuerySet[M]": """ 过滤 :return: QuerySet """ operator = Operator(model_class=self._model_class, query=self._query) # args for i in args: if isinstance(i, Q): i.resolve(operator, self._query) elif isinstance(i, dict): for s, p in operator.reducer_process(i): if s: self._query.where(s, p) else: raise HintException(f"Invalid filter argument: {type(i)}") # kwargs for s, p in operator.reducer_process(kwargs): if s: self._query.where(s, p) return self def limit(self, num: int) -> "QuerySet[M]": """ 限制 :return: QuerySet """ self._query.limit(num) return self def offset(self, num: int) -> "QuerySet[M]": """ 偏移量 :return: QuerySet """ self._query.skip(num) return self def distinct(self) -> "QuerySet[M]": """ 以指定字段去重 """ # todo return self def order_by(self, *args) -> "QuerySet[M]": """ 排序 :param args: "filed" ASC "-filed" DESC :return: QuerySet """ reduce( lambda q, c: q.order(f"{self._tb}.{c[1:]}", "DESC") if c[:1] == "-" else q.order(f"{self._tb}.{c}"), args, self._query ) return self def values(self, *args) -> "QuerySet[M]": # todo raise NotImplementedError("values") def fields(self, *args) -> "QuerySet[M]": if not args: return self field_set = set(args) # make suer pk pk = self._model_class.__primary_key__ if pk not in field_set: field_set.add(pk) field_set = [f for f in field_set if f in self._model_class._get_fields()] self._field_filter = field_set # model level self._query.field(*(f"{self._tb}.{f}" for f in field_set)) # db level return self def first(self) -> "Optional[M]": """ 获取第一条数据 :return: QuerySet """ if self._cache is None: if len(self._query._SqliteEasy__OPT_FIELD._Field__FIELDS) == 0: self._query.field(f"`{self._tb}`.*") data = self._query.find() if not data: return None return self._gen_M(data) else: return self._cache[0] if len(self._cache) != 0 else None def get_field(self, key_name: str) -> Optional[Any]: """ 获取第一条数据的指定字段的值 :param key_name: 字段名 :return: Any """ f = self.first() return f.as_dict().get(key_name) if f else None def update(self, *args, **kwargs) -> int: """ 更新数据 :return: int """ self._cache = None if args and kwargs: raise HintException("args and kwargs can not be used at the same time") if args: if len(args) != 1: raise HintException("%s too many args" % (args,)) elif not isinstance(args[0], dict): raise HintException("%s must be a dict" % (args[0],)) target = args[0] elif kwargs: target = kwargs else: target = None if not target: return 0 try: if "update_time" not in target: fields = getattr(self._model_class, "__fields__", {}) or {} auto_now = None for name, f in fields.items(): if name == "update_time" and getattr(f, "auto_now", False) is True: auto_now = name break if auto_now: import time from public.aaModel.fields import ACCURACY target[auto_now] = round(time.time() * ACCURACY) except: pass serlz = self._model_class._get_serialized() body = { k: (serlz[k].serialized(v, True) if k in serlz else v) for k, v in target.items() } return self._query.update(body) def delete(self) -> int: """ 删除数据 :return: int """ self._cache = None count = self._query.delete() return count def exists(self) -> bool: """ 存在数据 :return: bool """ q_fk = self._clone_q q_fk.field(self._model_class.__primary_key__) q_fk.limit(1) return bool(q_fk.find()) def count(self) -> int: """ 获取数量 :return: int """ if self._cache is not None: return len(self._cache) q = self._clone_q q._SqliteEasy__OPT_LIMIT.clear() q._SqliteEasy__OPT_ORDER.clear() q._SqliteEasy__OPT_GROUP.clear() q._SqliteEasy__OPT_HAVING.clear() return q.count() def as_list(self) -> list: """ 转列表 :return: list """ if self._cache is None: self.__execute() return [x.as_dict() for x in self._cache] class aaObjects(Generic[M]): """ 管理器 """ _queryset_class = QuerySet __m_map__ = { "default": set(), # 默认使用default.db, default = 表集合 "db": dict() # 指定了db_file的只放在db路径下, db = {db_file: 表集合} } def __new__(cls, args): if hasattr(args, "__table_name__"): if args.__db_name__ == "default": if args.__table_name__ not in cls.__m_map__["default"]: cls.__m_map__["default"].add(aaMigrate(args).run_migrate()) else: if args.__db_name__ not in cls.__m_map__["db"]: cls.__m_map__["db"][args.__db_name__] = set() cls.__m_map__["db"][args.__db_name__].add(aaMigrate(args).run_migrate()) return super(aaObjects, cls).__new__(cls) def __init__(self, model: Type[M]): self._model = model self.__q = None # @classmethod # def _as_manager(cls): # """自定义管理器""" # return aaManager(obj_cls=cls) def _get_queryset(self) -> "QuerySet[M]": """获取管理器关联的QuerySet""" return self._queryset_class(self._model, self._query.fork()) @property def _query(self) -> "Db.query": if not self.__q: q = db_file_connect(self._model.__db_name__).query() self.__q = q.table(self._model.__table_name__) return self.__q def _insert(self, val_data) -> int: return self._query.insert(val_data) def _update(self, cdt: dict, val_data: dict) -> int: if not cdt or not val_data: return 0 q = self._query.fork() conditions = [] params = [] for k, v in cdt.items(): conditions.append(f"`{k}` = ?") params.append(v) q.where(" AND ".join(conditions), params) return q.update(val_data) def insert(self, data: Dict[str, Any], raise_exp: bool = True) -> dict: """ 插入单条数据 :data dict :raise_exp bool 抛字段类型检查异常 :return 插入的数据 """ model_obj = self._model(**data) insert_res = self._insert( model_obj._validate(raise_exp=raise_exp) ) if insert_res: return { self._model.__primary_key__: insert_res, **model_obj.as_dict() } else: if raise_exp: raise HintException(insert_res) else: return {} def insert_many(self, data: List[Dict[str, Any]], raise_exp: bool = True) -> int: """ 批量插入数据 :data list :raise_exp bool 不抛异常则跳过异常继续插入 :return: int 影响行数 """ valid_list = [] for i in data: if i and isinstance(i, dict): temp = self._model(**i)._validate(raise_exp=raise_exp) if temp: valid_list.append(temp) if not valid_list: return 0 return self._query.insert_all(valid_list) def find_one(self, **kwargs) -> Optional[M]: """ 过滤查询一行数据 :kwargs dict :return: QuerySet | None """ return self._get_queryset().filter(**kwargs).first() def filter(self, *args, **kwargs) -> "QuerySet[M]": """ 过滤 :kwargs dict :return: QuerySet """ return self._get_queryset().filter(*args, **kwargs) def all(self) -> "QuerySet[M]": """ 所有数据 return: QuerySet """ return self._get_queryset() class aaMigrate: """ 同步表字段 """ NULL_MAP = {False: "NOT NULL", True: "NULL"} def __init__(self, model: M): self.__model = model self.__table = self.__model.__table_name__ self.__fields = self.__model._get_fields() self.__client = None self.__query = None def run_migrate(self) -> str | None: """ 迁移 """ if not self.__model: raise PanelError("Model is None") if not hasattr(self.__model, '__db_name__'): raise PanelError(f"{self.__model.__class__.__name__} need 'db_name'") if not hasattr(self.__model, '__table_name__'): raise PanelError(f"{self.__model.__class__.__name__} need 'table_name'") if not hasattr(self.__model, '__fields__'): raise PanelError(f"{self.__model.__class__.__name__} need 'fields'") try: self.__client: Db = db_file_connect(self.__model.__db_name__) self.__table_exists() self.__index_exists() except Exception as e: raise PanelError(e) finally: if self.__query: self.__query.close() if self.__client: self.__client.close() return self.__model.__table_name__ def __new_tb_transform_sql(self, tb_name: str) -> str: """ 转sql """ field_sql = "" pk_flag = 0 for key, val in self.__fields.items(): if key == "index": raise PanelError("'%s' is a reserved word in SQL. do not use it" % key) if val.primary_key is False: field_sql += f"`{key}` {val.field_type} {self.NULL_MAP.get(val.null)} {val.default_val_sql}, " else: # is primary_key pk_flag += 1 if val.field_type != "INTEGER": raise PanelError("'primary_key' only support IntegerField now") field_sql += f"`{key}` {val.field_type} PRIMARY KEY AUTOINCREMENT, " if not field_sql: return "" if pk_flag != 1: raise PanelError("primary_key not found, and must be only one") field_sql = field_sql.rstrip(", ") sql = f"""CREATE TABLE IF NOT EXISTS `{tb_name}` ({field_sql});""" return sql def __fields_exist(self, add_fields_map: dict = None, del_fields: set = None, set_db: set = None) -> None: """ 字段处理 """ if not del_fields: for k, v in add_fields_map.items(): add_sql = (f"ALTER TABLE `{self.__table}` " f"ADD COLUMN `{k}` {v.field_type} {v.default_val_sql} {self.NULL_MAP.get(v.null)};") self.__query.execute(add_sql) else: if set_db: temp_tb = f"table_{uuid.uuid4().hex}" new = self.__new_tb_transform_sql(temp_tb) if new: try: self.__query.autocommit(autocommit=False) self.__query.execute("BEGIN;") self.__query.execute(new) # rename fields will be loss old data now format_keys = ", ".join( [f"`{k}`" for k in set_db if k not in del_fields] ) copy_sql = (f"INSERT INTO `{temp_tb}` ({format_keys}) " f"SELECT {format_keys} FROM `{self.__table}`;") self.__query.execute(copy_sql) self.__query.execute(f"DROP TABLE IF EXISTS `{self.__table}`;") self.__query.execute(f"ALTER TABLE `{temp_tb}` RENAME TO `{self.__table}`;") self.__query.commit() except Exception as e: import traceback print(traceback.format_exc()) self.__query.rollback() raise e def __table_exists(self) -> None: """ 表迁移 """ self.__query = self.__client.query().table("sqlite_master") if self.__query.where("type=? AND name=?", ("table", self.__table)).count() != 1: sql = self.__new_tb_transform_sql(self.__table) if sql: self.__query.execute(sql) else: # has table self.__query.table(self.__table) set_cur = set(self.__fields.keys()) set_db = set(self.__query.get_columns()) add_fields = set_cur - set_db del_fields = set_db - set_cur add_fields_map = {k: v for k, v in self.__fields.items() if k in add_fields} self.__fields_exist(add_fields_map, del_fields, set_db) def __trans_index_key(self, index_info: tuple | str) -> str: def __if_raise_error(item: str): if not self.__model._get_fields().get(item): raise PanelError(f"create index error, '{item}' is not in model's fields") col_sql = "" if isinstance(index_info, tuple): for item in index_info: __if_raise_error(item) col_sql += f"`{item}`," elif isinstance(index_info, str): __if_raise_error(index_info) col_sql = f"`{index_info}`" else: raise PanelError("model's index error, should be like ['key1', ('key2', 'key3')]") return col_sql.rstrip(",") def __index_exists(self) -> bool | None: """ 索引 """ try: if not hasattr(self.__model, "__index_keys__"): return True self.__query.table(self.__table) cur = self.__query.query(f"PRAGMA index_list(`{self.__table}`);") or [] current_index = [ x.get("name") for x in cur if str(x.get("origin", "c")).lower() == "c" ] if cur else [] sql_statements = [] wanted = set() for index_info in self.__model.__index_keys__: """ todo index_info 对应字段 字段类型不为json: 普通索引 如果为list: 分表, 索引, CURD触发, 存在性同步. 如果为dict: 路径虚拟列索引 复合索引 index_info tuple 如果移除索引, 检查上述步骤 """ col_sql = self.__trans_index_key(index_info) cols = [col.strip("` ") for col in col_sql.split(",")] index_name = f"idx_{self.__table}_{'_'.join(cols)}" wanted.add(index_name) if index_name not in current_index: sql_statements.append( f"CREATE INDEX IF NOT EXISTS `{index_name}` ON `{self.__table}` ({col_sql}); " ) for index in current_index: if index not in wanted: sql_statements.append(f"DROP INDEX IF EXISTS `{index}`;") if sql_statements: self.__query.execute_script( " ".join(sql_statements) ) return True except: pass class aaManager: def __init__(self, obj_cls=aaObjects, qs_cls=QuerySet): self._objects_class = obj_cls self._queryset_class = qs_cls self._cache = {} def __get__(self, instance, cls: Type[M]): if instance is not None: raise AttributeError( f"object manager can't accessible from '{cls.__name__}' instances" ) try: manager = self._cache.get(cls) if manager is None: manager = self._objects_class(cls) setattr(manager, "_queryset_class", self._queryset_class) self._cache[cls] = manager return manager except Exception: import traceback raise PanelError(traceback.format_exc())