2024-03-01 23:24:36 +08:00
|
|
|
|
import json
import sqlite3
from abc import ABC
from collections.abc import Iterable
from typing import Any, get_origin

from pydantic import BaseModel
|
|
|
|
|
|
2024-03-01 23:24:36 +08:00
|
|
|
|
# Union of the built-in container types: list, tuple, set and dict.
# NOTE(review): not referenced in the visible part of this file - confirm usage.
BaseIterable = list | tuple | set | dict
|
|
|
|
|
|
2024-03-01 00:07:49 +08:00
|
|
|
|
|
|
|
|
|
class LiteModel(BaseModel):
    """Base class for the models handled by the ORM adapters in this module.

    It adds nothing to ``pydantic.BaseModel``; ``SqliteORMAdapter.save`` uses
    it in an ``isinstance`` check to recognise nested models.
    """
|
|
|
|
|
|
|
|
|
|
|
2024-03-01 23:24:36 +08:00
|
|
|
|
class BaseORMAdapter(ABC):
    """Common interface of the ORM adapters.

    Every operation raises ``NotImplementedError`` until a subclass overrides
    it. The methods are not decorated with ``@abstractmethod``, so a subclass
    that overrides only some of them can still be instantiated.
    """

    def __init__(self):
        """Initialise the adapter; the base class keeps no state."""
        pass

    def auto_migrate(self, *args, **kwargs):
        """Migrate the storage schema automatically.

        Raises:
            NotImplementedError: Always; subclasses provide the behaviour.
        """
        raise NotImplementedError

    def save(self, *args, **kwargs):
        """Store data.

        Raises:
            NotImplementedError: Always; subclasses provide the behaviour.
        """
        raise NotImplementedError

    def first(self, *args, **kwargs):
        """Query the first matching record.

        Raises:
            NotImplementedError: Always; subclasses provide the behaviour.
        """
        raise NotImplementedError

    def all(self, *args, **kwargs):
        """Query all matching records.

        Raises:
            NotImplementedError: Always; subclasses provide the behaviour.
        """
        raise NotImplementedError

    def delete(self, *args, **kwargs):
        """Delete data.

        Raises:
            NotImplementedError: Always; subclasses provide the behaviour.
        """
        raise NotImplementedError

    def update(self, *args, **kwargs):
        """Update data.

        Raises:
            NotImplementedError: Always; subclasses provide the behaviour.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
2024-03-01 23:24:36 +08:00
|
|
|
|
class SqliteORMAdapter(BaseORMAdapter):
|
2024-03-01 00:07:49 +08:00
|
|
|
|
# Python type -> SQLite column type. Types that are not listed here are
# looked up with a 'TEXT' fallback by auto_migrate.
type_map = {
    str: 'TEXT',
    int: 'INTEGER',
    float: 'REAL',
    bool: 'INTEGER',
    list: 'TEXT',
}
|
|
|
|
|
|
2024-03-01 23:24:36 +08:00
|
|
|
|
def __init__(self, db_name: str):
    """Open the SQLite database and keep one connection and one cursor.

    Args:
        db_name: Database path passed unchanged to ``sqlite3.connect``.
    """
    super().__init__()
    self.conn = sqlite3.connect(db_name)
    # A single cursor is shared by every method of the adapter.
    self.cursor = self.conn.cursor()
|
|
|
|
|
|
|
|
|
|
def auto_migrate(self, *args: "type[LiteModel]"):
    """Create or alter one table per model class so it matches the model.

    For every model class a table named after the class is created if it is
    missing (with an auto-increment ``id`` primary key). Columns are then
    added for model fields the table lacks, and columns that existed before
    the call but are no longer model fields are dropped. The ``id`` column
    is never dropped. All changes are committed once at the end.

    Fields are read from ``model_fields`` (pydantic v2) or ``__fields__``
    (pydantic v1); for any other class the public, non-callable class
    attributes are used.

    Args:
        *args: Model classes (not instances) to migrate.
    """

    def quote(name) -> str:
        # Double-quote identifiers so reserved words remain valid names.
        return '"' + str(name).replace('"', '""') + '"'

    def literal(value) -> str:
        # Render a Python default as an SQL literal. Values with no literal
        # form (None, "undefined" sentinels, arbitrary objects) become NULL.
        if isinstance(value, bool):
            return '1' if value else '0'
        if isinstance(value, (int, float)):
            return repr(value)
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"
        if isinstance(value, (list, tuple, set, dict)):
            # Containers are stored as JSON text; sets are not JSON-serialisable.
            payload = list(value) if isinstance(value, set) else value
            return literal(json.dumps(payload))
        return 'NULL'

    def describe(model) -> dict:
        # Map field name -> (declared type or None, default value).
        declared = getattr(model, 'model_fields', None)
        if declared is None:
            declared = getattr(model, '__fields__', None)
        if declared is not None:
            return {
                name: (
                    getattr(info, 'annotation', getattr(info, 'outer_type_', None)),
                    getattr(info, 'default', None),
                )
                for name, info in declared.items()
            }
        return {
            name: (None, value)
            for name, value in vars(model).items()
            if not name.startswith('_') and not callable(value)
        }

    for model in args:
        table = quote(model.__name__)

        # Make sure the table exists.
        self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table}(id INTEGER PRIMARY KEY AUTOINCREMENT)')

        # Columns present before this migration; row[1] is the column name.
        self.cursor.execute(f'PRAGMA table_info({table})')
        existing = [row[1] for row in self.cursor.fetchall()]

        wanted = describe(model)

        # Add the columns the model has but the table lacks.
        for name, (declared_type, default) in wanted.items():
            if name in existing:
                continue
            # list[int] -> list, so parametrised generics still hit type_map.
            base_type = get_origin(declared_type) or declared_type
            sql_type = self.type_map.get(base_type) or self.type_map.get(type(default), 'TEXT')
            self.cursor.execute(
                f'ALTER TABLE {table} ADD COLUMN {quote(name)} {sql_type} DEFAULT {literal(default)}'
            )

        # Drop the columns the model no longer declares. The primary key is
        # kept: it is not a model field and SQLite refuses to drop it.
        for name in existing:
            if name == 'id' or name in wanted:
                continue
            self.cursor.execute(f'ALTER TABLE {table} DROP COLUMN {quote(name)}')

    self.conn.commit()
|
|
|
|
|
|
|
|
|
|
def save(self, *models: "LiteModel") -> int:
    """Insert each model as a new row in the table named after its class.

    Field values are taken from ``model.__dict__``:

    * a nested ``LiteModel`` is saved first and stored under the column
      ``$id:<field>`` as the text ``<ClassName>:<row id>``;
    * a list, tuple, set or dict is stored as JSON text;
    * any other value is bound as-is.

    The target tables and columns must already exist. All inserts are
    committed before returning.

    Args:
        models: Model instances to store.

    Returns:
        The row id of the last inserted model, or 0 when no model was given.
    """
    _id = 0
    for model in models:
        table_name = model.__class__.__name__
        key_list = []
        value_list = []
        for field, value in model.__dict__.items():
            if isinstance(value, LiteModel):
                # Foreign key: persist the nested model and keep a reference.
                key_list.append(f'$id:{field}')
                value_list.append(f'{value.__class__.__name__}:{self.save(value)}')
            elif isinstance(value, (list, tuple, set, dict)):
                key_list.append(field)
                # json cannot serialise a set, so it is stored as a list.
                payload = list(value) if isinstance(value, set) else value
                value_list.append(json.dumps(payload))
            else:
                key_list.append(field)
                value_list.append(value)

        if key_list:
            # Names such as "$id:owner" are only valid when quoted.
            columns = ', '.join('"' + key.replace('"', '""') + '"' for key in key_list)
            marks = ', '.join('?' for _ in key_list)
            self.cursor.execute(
                f'INSERT INTO "{table_name}" ({columns}) VALUES ({marks})',
                value_list,
            )
        else:
            self.cursor.execute(f'INSERT INTO "{table_name}" DEFAULT VALUES')
        _id = self.cursor.lastrowid

    self.conn.commit()
    return _id
|
|
|
|
|
|
|
|
|
|
def flat(self, data: Iterable) -> Any:
    """Print the leaf ``key value`` pairs of a nested dict.

    Only dicts are walked. A nested list, tuple or set is handed to the
    recursive call, which ignores it, so nothing inside it is printed.

    Args:
        data: Structure to walk; anything that is not a dict is ignored.

    Returns:
        Always ``None``; the output goes to stdout.
    """
    if not isinstance(data, dict):
        return
    for key, value in data.items():
        if isinstance(value, (dict, list, tuple, set)):
            self.flat(value)
        else:
            print(key, value)
|
|
|
|
|
|
|
|
|
|
def first(self, model: "type[LiteModel]", *args, **kwargs):
    """Fetch the first row of the model's table that matches a condition.

    Args:
        model: Model class; its name is used as the table name.
        *args: Optionally ``args[0]``, an SQL condition placed verbatim
            after ``WHERE`` (it must come from trusted code; use ``?``
            placeholders for values), and ``args[1]``, the sequence of
            values bound to those placeholders. Without a condition the
            first row of the table is fetched.
        **kwargs: Accepted but not used.

    Returns:
        The result of ``self.convert2dict`` applied to a
        ``{column name: value}`` dict of the row, or ``None`` when no row
        matches.
    """
    condition = args[0] if args else None
    params = args[1] if len(args) > 1 else ()

    sql = f'SELECT * FROM "{model.__name__}"'
    if condition:
        sql += f' WHERE {condition}'
    self.cursor.execute(sql, params)

    row = self.cursor.fetchone()
    if row is None:
        return None

    # cursor.description holds one 7-tuple per column; item 0 is the name.
    columns = [column[0] for column in self.cursor.description]
    return self.convert2dict(dict(zip(columns, row)))
|
|
|
|
|
|
|
|
|
|
def convert2dict(self, data: dict) -> dict | list:
|
|
|
|
|
"""将模型转换为dict
|
|
|
|
|
|
2024-03-01 00:07:49 +08:00
|
|
|
|
Args:
|
2024-03-01 23:24:36 +08:00
|
|
|
|
data: 数据库查询结果
|
|
|
|
|
|
2024-03-01 00:07:49 +08:00
|
|
|
|
Returns:
|
2024-03-01 23:24:36 +08:00
|
|
|
|
模型对象
|
2024-03-01 00:07:49 +08:00
|
|
|
|
"""
|
2024-03-01 23:24:36 +08:00
|
|
|
|
for field, value in data.items():
|
|
|
|
|
if field.startswith('$id'):
|
|
|
|
|
# 外键处理
|
|
|
|
|
table_name = value[1:].split('_')[0]
|
|
|
|
|
id_ = value.split('_')[1]
|
|
|
|
|
key_tuple = self.cursor.execute(f'PRAGMA table_info({table_name})').fetchall()
|
|
|
|
|
value_tuple = self.cursor.execute(f'SELECT * FROM {table_name} WHERE id = ?', id_).fetchone()
|
|
|
|
|
field_dict = dict(zip([field[1] for field in key_tuple], value_tuple))
|