app/src/api/data.py

177 lines
4.9 KiB
Python
Raw Normal View History

2024-03-01 23:24:36 +08:00
import json
2024-03-01 00:07:49 +08:00
import sqlite3
2024-03-01 23:24:36 +08:00
from abc import ABC
from collections.abc import Iterable
2024-03-01 00:07:49 +08:00
from typing import Any
from pydantic import BaseModel
2024-03-01 23:24:36 +08:00
# PEP 604 union of the built-in container types (requires Python 3.10+);
# usable directly as the second argument of isinstance().
BaseIterable = list | tuple | set | dict
2024-03-01 00:07:49 +08:00
class LiteModel(BaseModel):
    # Common pydantic base class for the models the ORM adapters in this module store.
    # (Kept as a comment, not a docstring, so pydantic's schema description is unchanged.)
    ...
2024-03-01 23:24:36 +08:00
class BaseORMAdapter(ABC):
    """Interface shared by the storage back-ends in this module.

    Every operation here simply raises ``NotImplementedError``; concrete
    adapters override the operations they support. None of the methods is an
    ``@abstractmethod``, so this base class can itself be instantiated and a
    subclass may leave some operations unimplemented.
    """

    def __init__(self):
        """Initialise the adapter; the base class holds no state."""

    def auto_migrate(self, *args, **kwargs):
        """Bring the storage schema in line with the given models.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def save(self, *args, **kwargs):
        """Persist data.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def first(self, *args, **kwargs):
        """Return the first matching record.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def all(self, *args, **kwargs):
        """Return every matching record.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def delete(self, *args, **kwargs):
        """Delete records.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def update(self, *args, **kwargs):
        """Update records.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
2024-03-01 23:24:36 +08:00
class SqliteORMAdapter(BaseORMAdapter):
2024-03-01 00:07:49 +08:00
type_map = {
    # Python value type -> SQLite column type used by auto_migrate when it
    # adds a column; any type not listed here falls back to TEXT.
    str : 'TEXT',
    int : 'INTEGER',
    float: 'REAL',
    bool : 'INTEGER',  # SQLite has no boolean type; True/False are stored as 1/0
    list : 'TEXT'  # save() writes list values as JSON text
}
2024-03-01 23:24:36 +08:00
def __init__(self, db_name: str):
    """Open the SQLite database and keep one cursor shared by the query methods.

    Args:
        db_name: database path handed to ``sqlite3.connect``
            (``":memory:"`` opens an in-memory database).
    """
    super().__init__()
    connection = sqlite3.connect(db_name)
    self.cursor = connection.cursor()
    self.conn = connection
def auto_migrate(self, *args: LiteModel | type[LiteModel]):
    """Auto-migrate: make each model's table columns match the model's fields.

    For every model, the table named after the model class is created if it
    does not exist (with an ``id INTEGER PRIMARY KEY AUTOINCREMENT`` column).
    A column is then added for each model field the table lacks, and every
    non-primary-key column the model no longer has is dropped. All changes
    are committed once at the end.

    Args:
        *args: model instances, or model classes that can be instantiated
            without arguments. Each field's value (instance) or default
            (class) picks the column type via ``type_map`` and becomes the
            SQL ``DEFAULT`` of a newly added column.

    Raises:
        sqlite3.OperationalError: if SQLite rejects a statement
            (``DROP COLUMN`` requires SQLite >= 3.35).
        Passing a class whose fields lack defaults fails when it is
        instantiated (pydantic validation error).

    Note:
        A field whose value is itself a ``LiteModel`` is mapped to the column
        ``$id:<field>``, the column name ``save`` writes foreign keys under.
    """
    def quote_ident(name: str) -> str:
        # Double-quote identifiers so names such as "$id:user" or SQL keywords are valid.
        return '"' + str(name).replace('"', '""') + '"'

    def sql_literal(value) -> str:
        # DDL statements cannot take bound parameters, so DEFAULT values must be
        # rendered as SQL literals instead of raw Python reprs.
        if value is None:
            return 'NULL'
        if isinstance(value, bool):  # check before int: bool is an int subclass
            return '1' if value else '0'
        if isinstance(value, (int, float)):
            return repr(value)
        if isinstance(value, set):
            value = list(value)  # json cannot encode sets
        if isinstance(value, (list, tuple, dict)):
            value = json.dumps(value)
        return "'" + str(value).replace("'", "''") + "'"

    for model in args:
        # A pydantic class keeps no field values in its own __dict__, so classes
        # are instantiated to obtain their field defaults.
        instance = model() if isinstance(model, type) else model
        table = quote_ident(type(instance).__name__)

        # Create the table if it does not exist yet.
        self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table}(id INTEGER PRIMARY KEY AUTOINCREMENT)')

        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        self.cursor.execute(f'PRAGMA table_info({table})')
        table_columns = {row[1]: row[5] for row in self.cursor.fetchall()}

        # Columns the model needs: name -> (SQL type, DEFAULT literal).
        wanted = {}
        for field, value in instance.__dict__.items():
            if isinstance(value, LiteModel):
                wanted[f'$id:{field}'] = ('TEXT', 'NULL')
            else:
                wanted[field] = (self.type_map.get(type(value), 'TEXT'), sql_literal(value))

        # Add columns the table is missing.
        for column, (type_, default) in wanted.items():
            if column not in table_columns:
                self.cursor.execute(
                    f'ALTER TABLE {table} ADD COLUMN {quote_ident(column)} {type_} DEFAULT {default}'
                )

        # Drop surplus columns, but never a primary key (SQLite refuses to drop it).
        for column, pk in table_columns.items():
            if not pk and column not in wanted:
                self.cursor.execute(f'ALTER TABLE {table} DROP COLUMN {quote_ident(column)}')
    self.conn.commit()
def save(self, *models: LiteModel) -> int:
    """Insert each model as a new row of the table named after its class.

    Field handling:
      * a nested ``LiteModel`` value is saved first (recursively) and stored in
        column ``$id:<field>`` as the string ``"<ClassName>:<id>"``;
      * list/tuple/set/dict values are stored as JSON text (sets as lists);
      * anything else is bound as-is.

    Args:
        models: model instances to store.

    Returns:
        Row id of the last inserted model (the last argument), or 0 when no
        models are given.
    """
    def quote_ident(name: str) -> str:
        # Column names such as "$id:user" are only valid SQL when quoted.
        return '"' + name.replace('"', '""') + '"'

    _id = 0
    for model in models:
        table_name = model.__class__.__name__
        key_list = []
        value_list = []
        for field, value in model.__dict__.items():
            if isinstance(value, LiteModel):
                # Foreign key: save the nested model and keep "<ClassName>:<id>".
                key_list.append(f'$id:{field}')
                value_list.append(f'{value.__class__.__name__}:{self.save(value)}')
            elif isinstance(value, list | tuple | set | dict):
                key_list.append(field)
                # json.dumps cannot encode sets; sqlite3 cannot bind containers.
                value_list.append(json.dumps(list(value) if isinstance(value, set) else value))
            else:
                key_list.append(field)
                value_list.append(value)

        table = quote_ident(table_name)
        if key_list:
            columns = ', '.join(quote_ident(key) for key in key_list)
            placeholders = ', '.join('?' for _ in key_list)
            sql = f'INSERT INTO {table} ({columns}) VALUES ({placeholders})'
        else:
            sql = f'INSERT INTO {table} DEFAULT VALUES'
        # Values are bound as parameters, never interpolated into the SQL text.
        self.cursor.execute(sql, value_list)
        _id = self.cursor.lastrowid
    self.conn.commit()
    return _id
def flat(self, data: Iterable) -> Any:
if isinstance(data, dict):
for k, v in data.items():
if isinstance(v, dict | list | tuple | set):
self.flat(v)
else:
print(k, v)
def first(self, model: type[LiteModel], *args, **kwargs):
    """Fetch the first row of *model*'s table, optionally filtered.

    Args:
        model: model class; its ``__name__`` is the table name.
        *args: optional ``(where_clause, params)``. ``where_clause`` is raw SQL
            placed after ``WHERE``; it must not contain untrusted input, so put
            ``?`` placeholders in it and pass the values as ``params``.
            ``params`` is a sequence or mapping of bound values; a single
            scalar (including a str) is wrapped in a 1-tuple.
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        ``self.convert2dict`` applied to the row as a ``{column: value}`` dict,
        or ``None`` when no row matches.
    """
    table = '"' + model.__name__.replace('"', '""') + '"'
    sql = f'SELECT * FROM {table}'
    params = ()
    if args:
        sql += f' WHERE {args[0]}'
        if len(args) > 1:
            params = args[1]
            if not isinstance(params, (list, tuple, dict)):
                params = (params,)
    self.cursor.execute(sql, params)
    row = self.cursor.fetchone()
    if row is None:
        return None
    # convert2dict takes a single {column: value} mapping, so build it from
    # cursor.description rather than passing (model, tuple).
    columns = [description[0] for description in self.cursor.description]
    return self.convert2dict(dict(zip(columns, row)))
def convert2dict(self, data: dict) -> dict | list:
"""将模型转换为dict
2024-03-01 00:07:49 +08:00
Args:
2024-03-01 23:24:36 +08:00
data: 数据库查询结果
2024-03-01 00:07:49 +08:00
Returns:
2024-03-01 23:24:36 +08:00
模型对象
2024-03-01 00:07:49 +08:00
"""
2024-03-01 23:24:36 +08:00
for field, value in data.items():
if field.startswith('$id'):
# 外键处理
table_name = value[1:].split('_')[0]
id_ = value.split('_')[1]
key_tuple = self.cursor.execute(f'PRAGMA table_info({table_name})').fetchall()
value_tuple = self.cursor.execute(f'SELECT * FROM {table_name} WHERE id = ?', id_).fetchone()
field_dict = dict(zip([field[1] for field in key_tuple], value_tuple))