mirror of
https://github.com/TriM-Organization/LiteyukiBot-TriM.git
synced 2024-11-25 08:35:03 +08:00
208 lines
7.5 KiB
Python
208 lines
7.5 KiB
Python
|
import os
|
|||
|
import pickle
|
|||
|
import sqlite3
|
|||
|
from types import NoneType
|
|||
|
from typing import Any
|
|||
|
|
|||
|
import nonebot
|
|||
|
from pydantic import BaseModel, Field
|
|||
|
|
|||
|
|
|||
|
class LiteModel(BaseModel):
    """Lightweight base class for models persisted by ``Database``.

    Type annotations should follow the PEP 585 style available since
    Python 3.9; when generics are required, use the generic types from the
    ``typing`` module.

    Field names must not be ``id``, ``table_name`` or any other SQLite
    keyword, and must not be ``JSON`` or ``ID``. Every field has to declare
    a default value, and the type of that default must match the field's
    annotation.
    """

    # Row id of the stored model, exposed under the alias "id".
    __ID__: int = Field(default=None, alias="id")
    # Name of the table the model is stored in, exposed as "table_name".
    __TABLE_NAME__: str = Field(default=None, alias="table_name")
|
|||
|
|
|||
|
|
|||
|
class Database:
|
|||
|
# SQLite column type declared for a value of each supported Python type.
TYPE_MAPPING = {
    int: "INTEGER",
    float: "REAL",
    str: "TEXT",
    bool: "INTEGER",
    bytes: "BLOB",
    NoneType: "NULL",
    # Containers are declared as BLOB; the trailing comments show the
    # column-name pattern (prefix followed by the field name).
    dict: "BLOB",  # LITEYUKIDICT{key_name}
    list: "BLOB",  # LITEYUKILIST{key_name}
    tuple: "BLOB",  # LITEYUKITUPLE{key_name}
    set: "BLOB",  # LITEYUKISET{key_name}
}

# Basic types: values stored as they are.
BASIC_TYPE = [
    int,
    float,
    str,
    bool,
    bytes,
    NoneType,
]
# Serializable container types.
ITERABLE_TYPE = [
    dict,
    list,
    tuple,
    set,
]

LITEYUKI = "LITEYUKI"

# Column-name prefix per value type; basic types have no entry and
# therefore get the empty prefix.
# NOTE(review): the last key is ``type(LiteModel)``, i.e. the metaclass of
# the model classes rather than the model class itself, so a lookup by
# ``type(model_instance)`` does not appear to hit it - confirm the intent.
FIELD_PREFIX_MAPPING = {
    dict: LITEYUKI + "DICT",
    list: LITEYUKI + "LIST",
    tuple: LITEYUKI + "TUPLE",
    set: LITEYUKI + "SET",
    type(LiteModel): LITEYUKI + "MODEL",
}
|
|||
|
|
|||
|
def __init__(self, db_name: str):
|
|||
|
if not os.path.exists(os.path.dirname(db_name)):
|
|||
|
os.makedirs(os.path.dirname(db_name))
|
|||
|
self.conn = sqlite3.connect(db_name) # 连接对象
|
|||
|
self.conn.row_factory = sqlite3.Row # 以字典形式返回查询结果
|
|||
|
self.cursor = self.conn.cursor() # 游标对象
|
|||
|
|
|||
|
def auto_migrate(self, *args: LiteModel):
|
|||
|
"""
|
|||
|
自动迁移模型
|
|||
|
Args:
|
|||
|
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
|
|||
|
|
|||
|
Returns:
|
|||
|
|
|||
|
"""
|
|||
|
for model in args:
|
|||
|
if not model.__TABLE_NAME__:
|
|||
|
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
|
|||
|
|
|||
|
# 若无则创建表
|
|||
|
self.cursor.execute(
|
|||
|
f'CREATE TABLE IF NOT EXISTS {model.__TABLE_NAME__} (id INTEGER PRIMARY KEY AUTOINCREMENT)'
|
|||
|
)
|
|||
|
|
|||
|
# 获取表结构
|
|||
|
new_fields, new_stored_types = (
|
|||
|
zip(
|
|||
|
*[(self._get_stored_field_prefix(model.__getattribute__(field)) + field, self._get_stored_type(model.__getattribute__(field)))
|
|||
|
for field in model.__annotations__]
|
|||
|
)
|
|||
|
)
|
|||
|
|
|||
|
# 原有的字段列表
|
|||
|
existing_fields = self.cursor.execute(f'PRAGMA table_info({model.__TABLE_NAME__})').fetchall()
|
|||
|
existing_types = [field['name'] for field in existing_fields]
|
|||
|
|
|||
|
# 检测缺失字段,由于SQLite是动态类型,所以不需要检测类型
|
|||
|
for n_field, n_type in zip(new_fields, new_stored_types):
|
|||
|
if n_field not in existing_types:
|
|||
|
nonebot.logger.debug(f'ALTER TABLE {model.__TABLE_NAME__} ADD COLUMN {n_field} {n_type}')
|
|||
|
self.cursor.execute(
|
|||
|
f'ALTER TABLE {model.__TABLE_NAME__} ADD COLUMN {n_field} {n_type}'
|
|||
|
)
|
|||
|
|
|||
|
# 检测多余字段进行删除
|
|||
|
for e_field in existing_types:
|
|||
|
if e_field not in new_fields and e_field not in ['id']:
|
|||
|
nonebot.logger.debug(f'ALTER TABLE {model.__TABLE_NAME__} DROP COLUMN {e_field}')
|
|||
|
self.cursor.execute(
|
|||
|
f'ALTER TABLE {model.__TABLE_NAME__} DROP COLUMN {e_field}'
|
|||
|
)
|
|||
|
|
|||
|
self.conn.commit()
|
|||
|
|
|||
|
def save(self, *args: LiteModel) -> [int | tuple[int, ...]]:
|
|||
|
"""
|
|||
|
保存或更新模型
|
|||
|
Args:
|
|||
|
*args: 模型类实例化对象,支持空默认值,不支持嵌套迁移
|
|||
|
Returns:
|
|||
|
|
|||
|
"""
|
|||
|
ids = []
|
|||
|
for model in args:
|
|||
|
if not model.__TABLE_NAME__:
|
|||
|
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
|
|||
|
if not self.cursor.execute(f'PRAGMA table_info({model.__TABLE_NAME__})').fetchall():
|
|||
|
raise ValueError(f"数据表{model.__TABLE_NAME__}不存在,请先迁移{model.__class__.__name__}模型")
|
|||
|
|
|||
|
stored_fields, stored_values = [], []
|
|||
|
for r_field in model.__annotations__:
|
|||
|
r_value = model.__getattribute__(r_field)
|
|||
|
stored_fields.append(self._get_stored_field_prefix(r_value) + r_field)
|
|||
|
|
|||
|
if type(r_value) in Database.BASIC_TYPE:
|
|||
|
# int str float bool bytes NoneType
|
|||
|
stored_values.append(r_value)
|
|||
|
|
|||
|
elif type(r_value) in Database.ITERABLE_TYPE:
|
|||
|
# dict list tuple set
|
|||
|
stored_values.append(pickle.dumps(self._flat_save(r_value)))
|
|||
|
|
|||
|
elif isinstance(r_value, LiteModel):
|
|||
|
# LiteModel TABLE_NAME:ID
|
|||
|
stored_values.append(f"{r_value.__TABLE_NAME__}:{self.save(r_value)}")
|
|||
|
|
|||
|
else:
|
|||
|
raise ValueError(f"不支持的数据类型{type(r_value)}")
|
|||
|
nonebot.logger.debug(f"INSERT OR REPLACE INTO {model.__TABLE_NAME__} ({','.join(stored_fields)}) VALUES ({','.join([_ for _ in stored_values])})")
|
|||
|
self.cursor.execute(
|
|||
|
f"INSERT OR REPLACE INTO {model.__TABLE_NAME__} ({','.join(stored_fields)}) VALUES ({','.join(['?' for _ in stored_values])})",
|
|||
|
stored_values
|
|||
|
)
|
|||
|
ids.append(self.cursor.lastrowid)
|
|||
|
self.conn.commit()
|
|||
|
return tuple(ids) if len(ids) > 1 else ids[0]
|
|||
|
|
|||
|
# 检测id字段是否有1,有则更新,无则插入
|
|||
|
|
|||
|
def _flat_save(self, obj) -> Any:
|
|||
|
"""扁平化存储
|
|||
|
|
|||
|
Args:
|
|||
|
obj: 需要存储的对象
|
|||
|
|
|||
|
Returns:
|
|||
|
存储的字节流
|
|||
|
"""
|
|||
|
# TODO 递归扁平化存储
|
|||
|
if type(obj) in Database.ITERABLE_TYPE:
|
|||
|
for i, item in enumerate(obj) if type(obj) in [list, tuple, set] else obj.items():
|
|||
|
if type(item) in Database.BASIC_TYPE:
|
|||
|
continue
|
|||
|
elif type(item) in Database.ITERABLE_TYPE:
|
|||
|
obj[i] = pickle.dumps(self._flat_save(item))
|
|||
|
elif isinstance(item, LiteModel):
|
|||
|
obj[i] = f"{item.__TABLE_NAME__}:{self.save(item)}"
|
|||
|
else:
|
|||
|
raise ValueError(f"不支持的数据类型{type(item)}")
|
|||
|
else:
|
|||
|
raise ValueError(f"不支持的数据类型{type(obj)}")
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def _get_stored_field_prefix(value) -> str:
|
|||
|
"""获取存储字段前缀,一定在后加上字段名
|
|||
|
|
|||
|
LiteModel -> LITEYUKIID
|
|||
|
|
|||
|
dict -> LITEYUKIDICT
|
|||
|
|
|||
|
list -> LITEYUKILIST
|
|||
|
|
|||
|
tuple -> LITEYUKITUPLE
|
|||
|
|
|||
|
set -> LITEYUKISET
|
|||
|
|
|||
|
* -> ""
|
|||
|
Args:
|
|||
|
value: 储存的值
|
|||
|
|
|||
|
Returns:
|
|||
|
Sqlite3存储字段
|
|||
|
"""
|
|||
|
return Database.FIELD_PREFIX_MAPPING.get(type(value), "")
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def _get_stored_type(value) -> str:
|
|||
|
"""获取存储类型
|
|||
|
|
|||
|
Args:
|
|||
|
value: 储存的值
|
|||
|
|
|||
|
Returns:
|
|||
|
Sqlite3存储类型
|
|||
|
"""
|
|||
|
return Database.TYPE_MAPPING.get(type(value), "TEXT")
|