update: Liteyuki ORM

This commit is contained in:
远野千束 2024-03-01 23:24:36 +08:00
parent a3f63e383d
commit 8303514fa0

View File

@ -1,241 +1,176 @@
import json
import sqlite3 import sqlite3
import uuid from abc import ABC
from collections.abc import Iterable
from typing import Any from typing import Any
from pymongo import MongoClient
from pydantic import BaseModel from pydantic import BaseModel
BaseIterable = list | tuple | set | dict
class LiteModel(BaseModel): class LiteModel(BaseModel):
pass pass
class BaseORM: class BaseORMAdapter(ABC):
def __init__(self):
def __init__(self, *args, **kwargs):
pass pass
def auto_migrate(self, *args, **kwargs): def auto_migrate(self, *args, **kwargs):
"""自动迁移数据库 """自动迁移
Args:
*args:
**kwargs:
Returns: Returns:
""" """
raise NotImplementedError raise NotImplementedError
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""创建数据 """存储数据
Args:
*args:
**kwargs:
Returns: Returns:
"""
raise NotImplementedError
def update(self, *args, **kwargs):
"""更新数据
Args:
*args:
**kwargs:
Returns:
"""
raise NotImplementedError
def delete(self, *args, **kwargs):
"""删除数据
Args:
*args:
**kwargs:
Returns:
""" """
raise NotImplementedError raise NotImplementedError
def first(self, *args, **kwargs): def first(self, *args, **kwargs):
"""查询第一条数据 """查询第一条数据
Args:
*args:
**kwargs:
Returns: Returns:
"""
raise NotImplementedError
def where(self, *args, **kwargs):
"""查询数据
Args:
*args:
**kwargs:
Returns:
""" """
raise NotImplementedError raise NotImplementedError
def all(self, *args, **kwargs): def all(self, *args, **kwargs):
"""查询所有数据 """查询所有数据
Args:
*args:
**kwargs:
Returns: Returns:
"""
raise NotImplementedError
def delete(self, *args, **kwargs):
"""删除数据
Returns:
"""
raise NotImplementedError
def update(self, *args, **kwargs):
"""更新数据
Returns:
""" """
raise NotImplementedError raise NotImplementedError
class SqliteORM(BaseORM): class SqliteORMAdapter(BaseORMAdapter):
"""同步sqlite数据库操作"""
type_map = { type_map = {
int: 'INTEGER', # default: TEXT
str : 'TEXT',
int : 'INTEGER',
float: 'REAL', float: 'REAL',
str: 'TEXT', bool : 'INTEGER',
bool: 'INTEGER', list : 'TEXT'
} }
def __init__(self, db, *args, **kwargs): def __init__(self, db_name: str):
super().__init__(*args, **kwargs) super().__init__()
self.db = sqlite3.connect(db) self.conn = sqlite3.connect(db_name)
self.cursor = self.conn.cursor()
@staticmethod def auto_migrate(self, *args: LiteModel):
def get_model_table_name(model: type(LiteModel) | LiteModel | str) -> str: """自动迁移,检测新模型字段和原有表字段的差异,如有差异自动增删新字段
"""获取模型对应的表名"""
if isinstance(model, str):
return model
elif isinstance(model, LiteModel):
return model.__class__.__name__
elif isinstance(model, type(LiteModel)):
return model.__name__
def auto_migrate(self, *args: type(LiteModel) | LiteModel | str, **kwargs):
"""自动迁移数据库
Args: Args:
*args: BaseModel *args:
**kwargs:
delete_old_columns: bool = False # 是否删除旧字段
add_new_columns: bool = True # 添加新字段
Returns: Returns:
""" """
for model in args: for model in args:
# 获取模型对应的表名 # 检测并创建表
table_name = self.get_model_table_name(model) self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {model.__name__}(id INTEGER PRIMARY KEY AUTOINCREMENT)')
# 获取表字段
self.cursor.execute(f'PRAGMA table_info({model.__name__})')
table_fields = self.cursor.fetchall()
table_fields = [field[1] for field in table_fields]
# 获取表中已有的字段 # 获取模型字段
existing_columns = set() model_fields = model.__dict__.keys()
cursor = self.db.execute(f"PRAGMA table_info({table_name})")
for column_info in cursor.fetchall():
existing_columns.add(column_info[1])
# 创建表,如果不存在的话 # 获取模型字段类型
self.db.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)') model_types = [self.type_map.get(type(model.__dict__[field]), 'TEXT') for field in model_fields]
# 检测模型中的字段并添加新字段,按照类型添加 # 获取模型字段默认值
for field_name, field_type in model.__annotations__.items(): model_defaults = [model.__dict__[field] for field in model_fields]
if field_name not in existing_columns:
self.db.execute(f'ALTER TABLE {table_name} ADD COLUMN {field_name} {self.type_map.get(field_type, "TEXT")}')
# 提交事务 # 检测新字段
self.db.commit() for field, type_, default in zip(model_fields, model_types, model_defaults):
if field not in table_fields:
self.cursor.execute(f'ALTER TABLE {model.__name__} ADD COLUMN {field} {type_} DEFAULT {default}')
# 检测多余字段
for field in table_fields:
if field not in model_fields:
self.cursor.execute(f'ALTER TABLE {model.__name__} DROP COLUMN {field}')
self.conn.commit()
def save(self, *models: LiteModel) -> int:
"""存储数据
def save(self, model: LiteModel) -> int:
"""保存或创建数据对嵌套模型扁平化处理加特殊前缀表示为模型实际储存模型id嵌套对象单独储存
Args: Args:
model: BaseModel models: 数据
Returns: id主键
"""
# 先检测表是否存在,不存在则创建
table_name = self.get_model_table_name(model)
self.db.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
# 构建插入语句 Returns:
column_list = [] id: 数据id多个数据返回最后一个数据id
"""
_id = 0
for model in models:
table_name = model.__class__.__name__
key_list = []
value_list = [] value_list = []
for key, value in model.dict().items(): # 处理外键,添加前缀'$IDFieldName'
for field, value in model.__dict__.items():
if isinstance(value, LiteModel): if isinstance(value, LiteModel):
# 如果是嵌套模型,先保存嵌套模型 key_list.append(f'$id:{field}')
nested_model_id = self.save(value) value_list.append(f'{value.__class__.__name__}:{self.save(value)}')
# 保存嵌套模型的id, 以特殊前缀表示为模型 elif isinstance(value, list | tuple | set):
column_list.append(f'$id_{key}') key_list.append(field)
value_list.append(f'{value.__class__.__name__}_{nested_model_id}') value_list.append(json.dumps(value))
elif isinstance(value, list):
# 如果是列表,先保存列表中的所有嵌套模型,有可能有多种类型的嵌套模型
# 列表内存'ModelType_ModelId',以特殊前缀表示为模型类型
nested_model_ids = []
for nested_model in value:
nested_model_id = self.save(nested_model)
nested_model_ids.append(f'{nested_model.__class__.__name__}_{nested_model_id}')
column_list.append(f'$ids_{key}')
value_list.append(nested_model_ids)
columns = ', '.join(column_list)
placeholders = ', '.join(['?' for _ in value_list])
values = tuple(value_list)
print(model.dict())
print(table_name, columns, placeholders, values)
# 插入数据
self.db.execute(f'INSERT INTO {table_name} ({columns}) VALUES ({placeholders})', values)
self.db.commit()
return self.db.execute(f'SELECT last_insert_rowid()').fetchone()[0]
def where(self, model_type: type(LiteModel) | str, conditions: str, *args, objectify: bool = True) -> list[LiteModel]:
"""查询数据
Args:
objectify: bool: 是否将查询结果转换为模型
model_type: BaseModel
conditions: str
*args:
Returns:
"""
table_name = self.get_model_table_name(model_type)
self.db.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
return [self._convert_to_model(model_type, item) for item in self.db.execute(f'SELECT * FROM {table_name} WHERE {conditions}', args).fetchall()]
def first(self, model_type: type(LiteModel) | str, conditions: str, *args, objectify: bool = True):
"""查询第一条数据"""
table_name = self.get_model_table_name(model_type)
self.db.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
return self._convert_to_model(model_type, self.db.execute(f'SELECT * FROM {table_name} WHERE {conditions}', args).fetchone())
def all(self, model_type: type(LiteModel) | str, objectify: bool = True):
"""查询所有数据"""
table_name = self.get_model_table_name(model_type)
self.db.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
return [self._convert_to_model(model_type, item) for item in self.db.execute(f'SELECT * FROM {table_name}').fetchall()]
def update(self, model_type: type(LiteModel) | str, operation: str, conditions: str, *args):
"""更新数据
Args:
model_type: BaseModel
operation: str: 更新操作
conditions: str: 查询条件
*args:
Returns:
"""
table_name = self.get_model_table_name(model_type)
self.db.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
self.db.execute(f'UPDATE {table_name} SET {operation} WHERE {conditions}', args)
self.db.commit()
def _convert_to_model(self, model_type: type(LiteModel), item: tuple) -> LiteModel:
"""将查询结果转换为模型,处理嵌套模型"""
# 获取表中已有的字段,再用字段值构建字典
table_name = self.get_model_table_name(model_type)
cursor = self.db.execute(f"PRAGMA table_info({table_name})")
columns = [column_info[1] for column_info in cursor.fetchall()]
item_dict = dict(zip(columns, item))
# 遍历字典,处理嵌套模型
new_item_dict = {}
for key, value in item_dict.items():
if key.startswith('$id_'):
# 处理单个嵌套模型类型时从model_type中获取键
new_item_dict[key.replace('$id_', '')] = self.first(model_type.__annotations__[key.replace('$id_', '')], 'id = ?', value.split('_')[-1])
elif key.startswith('$ids_'):
# 处理多个嵌套模型类型使用eval获取数据库对应索引的键
new_item_dict[key.replace('$ids_', '')] = [self.first(eval(type_id.split('_')[0]), 'id = ?', type_id.split('_')[-1]) for type_id in value]
else: else:
new_item_dict[key] = value key_list.append(field)
value_list.append(value)
return model_type(**new_item_dict) def flat(self, data: Iterable) -> Any:
if isinstance(data, dict):
for k, v in data.items():
if isinstance(v, dict | list | tuple | set):
self.flat(v)
else:
print(k, v)
def first(self, model: type(LiteModel), *args, **kwargs):
self.cursor.execute(f'SELECT * FROM {model.__name__} WHERE {args[0]}', args[1])
return self.convert2dict(model, self.cursor.fetchone())
def convert2dict(self, data: dict) -> dict | list:
"""将模型转换为dict
Args:
data: 数据库查询结果
Returns:
模型对象
"""
for field, value in data.items():
if field.startswith('$id'):
# 外键处理
table_name = value[1:].split('_')[0]
id_ = value.split('_')[1]
key_tuple = self.cursor.execute(f'PRAGMA table_info({table_name})').fetchall()
value_tuple = self.cursor.execute(f'SELECT * FROM {table_name} WHERE id = ?', id_).fetchone()
field_dict = dict(zip([field[1] for field in key_tuple], value_tuple))