diff --git a/db_test2.py b/db_test2.py new file mode 100644 index 0000000..84b4d4a --- /dev/null +++ b/db_test2.py @@ -0,0 +1,63 @@ +import json +import random + +from src.api.data import LiteModel, SqliteORMAdapter + + +class Score(LiteModel): + subject: str + score: int + + +class Student(LiteModel): + name: str + age: str + sex: str + scores: list[Score] = [] + + +class Teacher(LiteModel): + name: str + age: str + + +class Class(LiteModel): + name: str + students: list[Student] + test_json: dict[str, Student] = {'张三': Student} + + +class University(LiteModel): + name: str + address: str + rank: int = 0 + classes: list[Class] = [] + students: list[Student] = [] + + +cqupt = University(name='重庆邮电大学', address='重庆', rank=100) +pku = University(name='北京大学', address='北京', rank=1) +cqu = University(name='重庆大学', address='重庆', rank=10) + +student_name_list = ['张三', '李四', '王五', '赵六', '钱七', '孙八', '周九', '吴十'] +student_list = [Student(name=name, age=19, sex=random.choice(['男', '女'])) for name in student_name_list] +# 构建复杂模型进行数据库测试 +cqupt.students = student_list +cqupt.classes = [Class(name='软件工程', students=student_list, test_json={name: student for name, student in zip(student_name_list, student_list)})] +for student in student_list: + student.scores = [Score(subject='语文', score=random.randint(60, 100)), Score(subject='数学', score=random.randint(60, 100)), + Score(subject='英语', score=random.randint(60, 100))] + student.scores.append(Score(subject='物理', score=random.randint(60, 100))) + student.scores.append(Score(subject='化学', score=random.randint(60, 100))) + student.scores.append(Score(subject='生物', score=random.randint(60, 100))) + student.scores.append(Score(subject='历史', score=random.randint(60, 100))) + student.scores.append(Score(subject='地理', score=random.randint(60, 100))) + student.scores.append(Score(subject='政治', score=random.randint(60, 100)) + ) +print(json.dumps(cqupt.dict(), indent=4, ensure_ascii=False)) +db = SqliteORMAdapter('test2.db') +db.auto_migrate(University, Class, Student, Score) +db.save(cqupt) + +# 查询测试 +db.first(University, 'name = ?', '重庆邮电大学') \ No newline at end of file diff --git a/src/api/data.py b/src/api/data.py index 4f7889b..6dc4ff8 100644 --- a/src/api/data.py +++ b/src/api/data.py @@ -1,7 +1,10 @@ import json import sqlite3 +import types +import typing from abc import ABC from collections.abc import Iterable +from copy import deepcopy from typing import Any from pydantic import BaseModel @@ -10,7 +13,10 @@ BaseIterable = list | tuple | set | dict class LiteModel(BaseModel): - pass + """轻量级模型基类 + 类型注解统一使用Python3.9的PEP585标准,如需使用泛型请使用typing模块的泛型类型 + """ + id: Any = None class BaseORMAdapter(ABC): @@ -67,6 +73,11 @@ class BaseORMAdapter(ABC): class SqliteORMAdapter(BaseORMAdapter): + """SQLiteORM适配器,严禁使用FORIEGNID和JSON作为主键前缀,严禁使用$ID:作为字符串值前缀 + + Attributes: + + """ type_map = { # default: TEXT str : 'TEXT', @@ -75,60 +86,79 @@ class SqliteORMAdapter(BaseORMAdapter): bool : 'INTEGER', list : 'TEXT' } + FOREIGNID = 'FOREIGNID' + JSON = 'JSON' + ID = '$ID' def __init__(self, db_name: str): super().__init__() self.conn = sqlite3.connect(db_name) + self.conn.row_factory = sqlite3.Row self.cursor = self.conn.cursor() - def auto_migrate(self, *args: LiteModel): + def auto_migrate(self, *args: type(LiteModel)): """自动迁移,检测新模型字段和原有表字段的差异,如有差异自动增删新字段 Args: - *args: + *args: 模型类 Returns: """ for model in args: - # 检测并创建表 - self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {model.__name__}(id INTEGER PRIMARY KEY AUTOINCREMENT)') + model: type(LiteModel) + # 检测并创建表,若模型未定义id字段则使用自增主键,有定义的话使用id字段,且id有可能为字符串 + table_name = model.__name__ + if 'id' in model.__annotations__ and model.__annotations__['id'] is not None: + # 如果模型定义了id字段,那么使用模型的id字段 + id_type = self.type_map.get(model.__annotations__['id'], 'TEXT') + self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id {id_type} PRIMARY KEY)') + else: + # 如果模型未定义id字段,那么使用自增主键 + self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)') # 获取表字段 - self.cursor.execute(f'PRAGMA table_info({model.__name__})') + self.cursor.execute(f'PRAGMA table_info({table_name})') table_fields = self.cursor.fetchall() table_fields = [field[1] for field in table_fields] - # 获取模型字段 - model_fields = model.__dict__.keys() - - # 获取模型字段类型 - model_types = [self.type_map.get(type(model.__dict__[field]), 'TEXT') for field in model_fields] - - # 获取模型字段默认值 - model_defaults = [model.__dict__[field] for field in model_fields] + raw_fields = model.__annotations__.keys() + # 获取模型字段,若有模型则添加FOREIGNID前缀,若为BaseIterable则添加JSON前缀,用多行if判断 + model_fields = [] + model_types = [] + for field in raw_fields: + if isinstance(model.__annotations__[field], type(LiteModel)): + model_fields.append(f'{self.FOREIGNID}{field}') + model_types.append('TEXT') + elif isinstance(model.__annotations__[field], types.GenericAlias): + model_fields.append(f'{self.JSON}{field}') + model_types.append('TEXT') + else: + model_fields.append(field) + model_types.append(self.type_map.get(model.__annotations__[field], 'TEXT')) # 检测新字段 - for field, type_, default in zip(model_fields, model_types, model_defaults): + for field, type_ in zip(model_fields, model_types): if field not in table_fields: - self.cursor.execute(f'ALTER TABLE {model.__name__} ADD COLUMN {field} {type_} DEFAULT {default}') + print(f'ALTER TABLE {table_name} ADD COLUMN {field} {type_}') + self.cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {field} {type_}') - # 检测多余字段 + # 检测多余字段,除了id字段 for field in table_fields: - if field not in model_fields: - self.cursor.execute(f'ALTER TABLE {model.__name__} DROP COLUMN {field}') + if field not in model_fields and field != 'id': + self.cursor.execute(f'ALTER TABLE {table_name} DROP COLUMN {field}') self.conn.commit() - def save(self, *models: LiteModel) -> int: - """存储数据 + def save(self, *models: LiteModel) -> int | tuple: + """存储数据,检查id字段,如果有id字段则更新,没有则插入 Args: models: 数据 Returns: - id: 数据id,多个数据返回最后一个数据id + id: 数据id,如果有多个数据则返回id元组 """ - _id = 0 + ids = [] for model in models: table_name = model.__class__.__name__ key_list = [] @@ -136,41 +166,145 @@ class SqliteORMAdapter(BaseORMAdapter): # 处理外键,添加前缀'$IDFieldName' for field, value in model.__dict__.items(): if isinstance(value, LiteModel): - key_list.append(f'$id:{field}') - value_list.append(f'{value.__class__.__name__}:{self.save(value)}') - elif isinstance(value, list | tuple | set): - key_list.append(field) - value_list.append(json.dumps(value)) + key_list.append(f'{self.FOREIGNID}{field}') + value_list.append(f'{self.ID}:{value.__class__.__name__}:{self.save(value)}') + elif isinstance(value, BaseIterable): + key_list.append(f'{self.JSON}{field}') + value_list.append(self.flat(value)) else: key_list.append(field) value_list.append(value) + # 更新或插入数据,用?占位 + self.cursor.execute(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})', value_list) - def flat(self, data: Iterable) -> Any: - if isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, dict | list | tuple | set): - self.flat(v) - else: - print(k, v) + ids.append(self.cursor.lastrowid) + self.conn.commit() + return ids[0] if len(ids) == 1 else tuple(ids) - def first(self, model: type(LiteModel), *args, **kwargs): - self.cursor.execute(f'SELECT * FROM {model.__name__} WHERE {args[0]}', args[1]) - return self.convert2dict(model, self.cursor.fetchone()) - - def convert2dict(self, data: dict) -> dict | list: - """将模型转换为dict + def flat(self, data: Iterable) -> str: + """扁平化数据,返回扁平化对象 Args: - data: 数据库查询结果 + data: 数据,可迭代对象 + + Returns: json字符串 + """ + if isinstance(data, dict): + return_data = {} + for k, v in data.items(): + if isinstance(v, LiteModel): + return_data[f'{self.FOREIGNID}{k}'] = f'{self.ID}:{v.__class__.__name__}:{self.save(v)}' + elif isinstance(v, BaseIterable): + return_data[f'{self.JSON}{k}'] = self.flat(v) + else: + return_data[k] = v + + elif isinstance(data, list | tuple | set): + return_data = [] + for v in data: + if isinstance(v, LiteModel): + return_data.append(f'{self.ID}:{v.__class__.__name__}:{self.save(v)}') + elif isinstance(v, BaseIterable): + return_data.append(self.flat(v)) + else: + return_data.append(v) + else: + raise ValueError('数据类型错误') + + return json.dumps(return_data) + + def first(self, model: type(LiteModel), conditions, *args, default: Any = None) -> LiteModel | None: + """查询第一条数据 + + Args: + model: 模型 + conditions: 查询条件 + *args: 参数化查询条件参数 + default: 未查询到结果默认返回值 + + Returns: 数据 + """ + table_name = model.__name__ + self.cursor.execute(f'SELECT * FROM {table_name} WHERE {conditions}', args) + data = dict(self.cursor.fetchone()) + return model(**self.convert_to_dict(data)) if data else default + + def all(self, model: type(LiteModel), conditions, *args, default: Any = None) -> list[LiteModel] | None: + """查询所有数据 + + Args: + model: 模型 + conditions: 查询条件 + *args: 参数化查询条件参数 + default: 未查询到结果默认返回值 + + Returns: 数据 + """ + table_name = model.__name__ + self.cursor.execute(f'SELECT * FROM {table_name} WHERE {conditions}', args) + data = self.cursor.fetchall() + return [model(**self.convert_to_dict(d)) for d in data] if data else default + + def delete(self, model: type(LiteModel), conditions, *args): + """删除数据 + + Args: + model: 模型 + conditions: 查询条件 + *args: 参数化查询条件参数 Returns: - 模型对象 + """ - for field, value in data.items(): - if field.startswith('$id'): - # 外键处理 - table_name = value[1:].split('_')[0] - id_ = value.split('_')[1] - key_tuple = self.cursor.execute(f'PRAGMA table_info({table_name})').fetchall() - value_tuple = self.cursor.execute(f'SELECT * FROM {table_name} WHERE id = ?', id_).fetchone() - field_dict = dict(zip([field[1] for field in key_tuple], value_tuple)) + table_name = model.__name__ + self.cursor.execute(f'DELETE FROM {table_name} WHERE {conditions}', args) + self.conn.commit() + + def update(self, model: type(LiteModel), conditions: str, *args, operation: str): + """更新数据 + + Args: + model: 模型 + conditions: 查询条件 + *args: 参数化查询条件参数 + operation: 更新操作 + + Returns: + + """ + table_name = model.__name__ + self.cursor.execute(f'UPDATE {table_name} SET {operation} WHERE {conditions}', args) + self.conn.commit() + + def convert_to_dict(self, data: dict) -> dict: + """将json字符串转换为字典 + + Args: + data: json字符串 + + Returns: 字典 + """ + + def load(d: BaseIterable) -> BaseIterable: + """递归加载数据,去除前缀""" + if isinstance(d, dict): + new_d = {} + for k, v in d.items(): + if k.startswith(self.FOREIGNID): + new_d[k.replace(self.FOREIGNID, '')] = load(dict(self.cursor.execute(f'SELECT * FROM {v.split(":")[1]} WHERE id = ?', (v.split(":")[2],)).fetchone())) + elif k.startswith(self.JSON): + new_d[k.replace(self.JSON, '')] = load(json.loads(v)) + else: + new_d[k] = v + elif isinstance(d, list | tuple | set): + new_d = [] + for i, v in enumerate(d): + if isinstance(v, str) and v.startswith(self.ID): + new_d.append(load(dict(self.cursor.execute(f'SELECT * FROM {v.split(":")[1]} WHERE id = ?', (v.split(":")[2],)).fetchone()))) + elif isinstance(v, BaseIterable): + new_d.append(load(v)) + else: + new_d = d + return new_d + + return load(data)