fix: 数据库支持

This commit is contained in:
snowy 2024-03-26 21:33:40 +08:00
parent 90c9ef31a1
commit 58e603e1ad
14 changed files with 528 additions and 728 deletions

89
docs/README.md Normal file
View File

@ -0,0 +1,89 @@
<div align=center><h2><font color="#d0e9ff">轻雪</font><font color="#a2d8f4">6.2</font></h2></div>
<div align=center><h4>轻量,高效,易于扩展</h4></div>
- 基于[Nonebot2](https://github.com/nonebot/nonebot2),有良好的生态支持
- 开箱即用,无需复杂配置
- 新的点击交互模式,拒绝手打指令
- 全新可视化`npm`包管理,支持一键安装插件
- 支持一切Onebot标准通信
## 1.安装和部署
1. 安装`Git`和`Python3.10+`
2. 克隆项目`git clone https://github.com/snowykami/LiteyukiBot`
3. 切换目录`cd LiteyukiBot`
4. 安装依赖`pip install -r requirements.txt`(如果多个Python环境请指定后安装`pythonx -m pip install -r requirements.txt`)
5. 启动`python main.py`
## 2. 配置
### 轻雪配置项(Nonebot插件配置项也可以写在此与dotenv格式不同应为小写)
如果不确定字段的含义,请不要修改(部分在自动生成配置文件中未列出,需手动添加)
```yaml
# 生成文件的配置项
command_start: [ "/", " " ] # 指令前缀
host: 127.0.0.1 # 监听地址
port: 20216 # 绑定端口
nickname: [ "liteyuki" ] # 机器人昵称
superusers: [ "1919810" ] # 超级用户
# 未列出的配置项(如要自定义请手动修改)
onebot_access_token: "" # Onebot访问令牌[具体请看](https://onebot.adapters.nonebot.dev/docs/guide/configuration)
default_language: "zh-CN" # 默认语言
log_level: "INFO" # 日志等级
log_icon: true # 是否显示日志等级图标(某些控制台不可用)
auto_report: true # 是否自动上报问题给轻雪服务器,仅包含硬件信息和运行软件版本
# 其他Nonebot插件的配置项
custom_config_1: "custom_value1"
...
```
### Onebot实现端配置
不同的实现端给出的字段可能不同,但是基本上都是一样的,这里给出一个参考值
| 字段 | 参考值 | 说明 |
|-------------|--------------------------|----------------------------------|
| 协议 | 反向WebSocket | 推荐使用反向ws协议进行通信即轻雪作为服务端 |
| 地址 | ws://`addrss`/onebot/v11 | 地址取决于配置文件,本机默认为`127.0.0.1:20216` |
| AccessToken | `""` | 如果你给轻雪配置了`AccessToken`,请在此填写相同的值 |
## 3.其他
### 常见问题
- 设备上Python环境太乱了pip和python不对应怎么办
- 请使用`/path/to/python -m pip install -r requirements.txt`来安装依赖,
然后用`/path/to/python main.py`来启动Bot
其中`/path/to/python`是你要用来运行Bot可执行文件
- 为什么我启动后机器人没有反应?
- 请检查配置文件的`command_start`或`superusers`,确认你有权限使用命令并按照正确的命令发送
- 怎么登录QQ等聊天平台
- 你有这个问题说明你不是很了解这个项目,本项目不负责实现登录功能,只负责处理消息
你需要使用Onebot标准的实现端来连接到轻雪并将消息上报给轻雪下面已经列出一些推荐的实现端
#### 推荐方案(QQ)
1. [Lagrange.OneBot](https://github.com/KonataDev/Lagrange.Core)目前点按交互目前仅支持Lagrange
2. [LiteLoaderQQNT OneBot](https://github.com/LLOneBot/LLOneBot)基于NTQQ的Onebot实现
3. 云崽的`icqq-plugin`和`ws-plugin`进行通信
4. `Go-cqhttp`(目前已经半死不活了)
5. 人工实现的`Onebot`协议自己整一个WebSocket客户端看着QQ的消息然后给轻雪传输数据
#### 推荐方案(Minecraft)
1. 我们有专门为Minecraft开发的服务器Bot支持OnebotV11/12标准详细请看[MinecraftOneBot](https://github.com/snowykami/MinecraftOnebot)
使用其他项目连接请先自行查阅文档若有困难请联系对应开发者而不是Liteyuki的开发者
## 4.用户协议
1. 本项目遵循`MIT`协议,你可以自由使用,修改,分发,但是请保留原作者信息
2. 你可以选择开启`auto_report`(默认开启)轻雪会收集运行环境的设备信息通过安全的方式传输到轻雪服务器用于统计运行时的设备信息帮助我们改进轻雪收集的数据包括但不限于CPU内存插件信息异常信息会话负载(不含隐私部分)
3. 本项目不会收集用户的任何隐私信息,但请注意甄别第三方插件的安全性
## 5.鸣谢

View File

@ -1,12 +1,10 @@
import nonebot
from nonebot.plugin import PluginMetadata
from liteyuki.utils.language import get_default_lang
from liteyuki.utils.data_manager import *
from liteyuki.utils.language import get_default_lang
from .core import *
from .loader import *
from .webdash import *
from .core import *
from liteyuki.utils.config import config
from liteyuki.utils.liteyuki_api import liteyuki_api
__author__ = "snowykami"
__plugin_meta__ = PluginMetadata(
@ -20,8 +18,6 @@ __plugin_meta__ = PluginMetadata(
}
)
auto_migrate() # 自动迁移数据库
sys_lang = get_default_lang()
nonebot.logger.info(sys_lang.get("main.current_language", LANG=sys_lang.get("language.name")))
nonebot.logger.info(sys_lang.get("main.enable_webdash", URL=f"http://127.0.0.1:{config.get('port', 20216)}"))

View File

@ -16,11 +16,10 @@ nonebot.plugin.load_plugins("plugins")
init_log()
installed_plugins = plugin_db.all(InstalledPlugin)
installed_plugins: list[InstalledPlugin] = plugin_db.all(InstalledPlugin())
if installed_plugins:
for installed_plugin in plugin_db.all(InstalledPlugin):
if not check_for_package(installed_plugin.module_name):
for installed_plugin in installed_plugins:
if not installed_plugin.liteyuki and not check_for_package(installed_plugin.module_name):
nonebot.logger.error(f"{installed_plugin.module_name} not installed, but in loading database. please run `npm fixup` in chat to reinstall it.")
else:
print(installed_plugin.module_name)
nonebot.load_plugin(installed_plugin.module_name)

View File

@ -1,38 +1,35 @@
import sys
from typing import Optional
import nonebot
from nonebot import on_message, require
from nonebot.plugin import PluginMetadata
from liteyuki.utils.data import LiteModel
from liteyuki.utils.message import send_markdown
from liteyuki.utils.data import Database, LiteModel
from liteyuki.utils.ly_typing import T_Bot, T_MessageEvent
from liteyuki.utils.data import Database
from liteyuki.utils.message import send_markdown
require("nonebot_plugin_alconna")
from nonebot_plugin_alconna import on_alconna
from arclet.alconna import Arparma, Alconna, Args, Option, Subcommand, Arg
from arclet.alconna import Arparma, Alconna, Args, Option, Subcommand
class Node(LiteModel):
bot_id: str
session_type: str
session_id: str
TABLE_NAME = "node"
bot_id: str = ""
session_type: str = ""
session_id: str = ""
def __str__(self):
return f"{self.bot_id}.{self.session_type}.{self.session_id}"
class Push(LiteModel):
source: Node
target: Node
inde: int
TABLE_NAME = "push"
source: Node = Node()
target: Node = Node()
inde: int = 0
pushes_db = Database("data/pushes.ldb")
pushes_db.auto_migrate(Push, Node)
pushes_db.auto_migrate(Push(), Node())
alc = Alconna(
"lep",
@ -67,7 +64,7 @@ async def _(result: Arparma):
push1 = Push(
source=Node(bot_id=source[0], session_type=source[1], session_id=source[2]),
target=Node(bot_id=target[0], session_type=target[1], session_id=target[2]),
inde=len(pushes_db.all(Push, default=[]))
inde=len(pushes_db.all(Push(), default=[]))
)
pushes_db.upsert(push1)
@ -75,7 +72,7 @@ async def _(result: Arparma):
push2 = Push(
source=Node(bot_id=target[0], session_type=target[1], session_id=target[2]),
target=Node(bot_id=source[0], session_type=source[1], session_id=source[2]),
inde=len(pushes_db.all(Push, default=[]))
inde=len(pushes_db.all(Push(), default=[]))
)
pushes_db.upsert(push2)
await add_push.finish("添加成功")
@ -85,7 +82,7 @@ async def _(result: Arparma):
index = result.subcommands["rm"].args.get("index")
if index is not None:
try:
pushes_db.delete(Push, "inde = ?", index)
pushes_db.delete(Push(), "inde = ?", index)
await add_push.finish("删除成功")
except IndexError:
await add_push.finish("索引错误")
@ -95,19 +92,19 @@ async def _(result: Arparma):
await add_push.finish(
"\n".join([f"{push.inde} {push.source.bot_id}.{push.source.session_type}.{push.source.session_id} -> "
f"{push.target.bot_id}.{push.target.session_type}.{push.target.session_id}" for i, push in
enumerate(pushes_db.all(Push, default=[]))]))
enumerate(pushes_db.all(Push(), default=[]))]))
else:
await add_push.finish("参数错误")
@on_message(block=False).handle()
async def _(event: T_MessageEvent, bot: T_Bot):
for push in pushes_db.all(Push, default=[]):
for push in pushes_db.all(Push(), default=[]):
if str(push.source) == f"{bot.self_id}.{event.message_type}.{event.user_id if event.message_type == 'private' else event.group_id}":
bot2 = nonebot.get_bot(push.target.bot_id)
msg_formatted = ""
for l in str(event.message).split("\n"):
msg_formatted += f"**{l.strip()}**\n"
for line in str(event.message).split("\n"):
msg_formatted += f"**{line.strip()}**\n"
push_message = (
f"> From {event.sender.nickname}@{push.source.session_type}.{push.source.session_id}\n> Bot {bot.self_id}\n\n"
f"{msg_formatted}")

View File

@ -5,7 +5,7 @@ import aiofiles
import nonebot.plugin
from liteyuki.utils.data import Database, LiteModel
from liteyuki.utils.data_manager import GroupChat, InstalledPlugin, User, group_db, plugin_db, user_db
from liteyuki.utils.data_manager import Group, InstalledPlugin, User, group_db, plugin_db, user_db
from liteyuki.utils.ly_typing import T_MessageEvent
LNPM_COMMAND_START = "lnpm"
@ -75,9 +75,9 @@ def get_plugin_session_enable(event: T_MessageEvent, plugin_module_name: str) ->
bool: 插件当前状态
"""
if event.message_type == "group":
session: GroupChat = group_db.first(GroupChat, "group_id = ?", event.group_id, default=GroupChat(group_id=str(event.group_id)))
session: Group = group_db.first(Group(), "group_id = ?", event.group_id, default=Group(group_id=str(event.group_id)))
else:
session: User = user_db.first(User, "user_id = ?", event.user_id, default=User(user_id=str(event.user_id)))
session: User = user_db.first(User(), "user_id = ?", event.user_id, default=User(user_id=str(event.user_id)))
# 默认停用插件在启用列表内表示启用
# 默认停用插件不在启用列表内表示停用
# 默认启用插件在停用列表内表示停用
@ -90,7 +90,11 @@ def get_plugin_session_enable(event: T_MessageEvent, plugin_module_name: str) ->
def get_plugin_global_enable(plugin_module_name: str) -> bool:
return True
return plugin_db.first(
InstalledPlugin(),
"module_name = ?",
plugin_module_name,
default=InstalledPlugin(module_name=plugin_module_name, enabled=True)).enabled
def get_plugin_can_be_toggle(plugin_module_name: str) -> bool:

View File

@ -98,7 +98,7 @@ async def _(result: Arparma, event: T_MessageEvent, bot: T_Bot):
r_load = nonebot.load_plugin(plugin_module_name) # 加载插件
installed_plugin = InstalledPlugin(module_name=plugin_module_name) # 构造插件信息模型
found_in_db_plugin = plugin_db.first(InstalledPlugin, "module_name = ?", plugin_module_name) # 查询数据库中是否已经安装
found_in_db_plugin = plugin_db.first(InstalledPlugin(), "module_name = ?", plugin_module_name) # 查询数据库中是否已经安装
if r_load:
if found_in_db_plugin is None:
@ -131,7 +131,7 @@ async def _(result: Arparma, event: T_MessageEvent, bot: T_Bot):
elif result.subcommands.get("uninstall"):
plugin_module_name: str = result.subcommands["uninstall"].args.get("plugin_name")
found_installed_plugin: InstalledPlugin = plugin_db.first(InstalledPlugin, "module_name = ?", plugin_module_name)
found_installed_plugin: InstalledPlugin = plugin_db.first(InstalledPlugin(), "module_name = ?", plugin_module_name)
if found_installed_plugin:
plugin_db.delete(InstalledPlugin, "module_name = ?", plugin_module_name)
reply = f"{ulang.get('npm.uninstall_success', NAME=found_installed_plugin.module_name)}"

View File

@ -7,7 +7,7 @@ from nonebot.internal.matcher import Matcher
from nonebot.message import run_preprocessor
from nonebot.permission import SUPERUSER
from liteyuki.utils.data_manager import GroupChat, InstalledPlugin, User, group_db, plugin_db, user_db
from liteyuki.utils.data_manager import Group, InstalledPlugin, User, group_db, plugin_db, user_db
from liteyuki.utils.message import Markdown as md, send_markdown
from liteyuki.utils.permission import GROUP_ADMIN, GROUP_OWNER
from liteyuki.utils.ly_typing import T_Bot, T_MessageEvent
@ -26,11 +26,19 @@ list_plugins = on_alconna(
toggle_plugin = on_alconna(
Alconna(
["enable-plugin", "disable-plugin"],
["enable", "disable"],
Args["plugin_name", str],
)
)
toggle_plugin_global = on_alconna(
Alconna(
["enable-global", "disable-global"],
Args["plugin_name", str],
),
permission=SUPERUSER
)
global_toggle = on_alconna(
Alconna(
["toggle-global"],
@ -82,7 +90,7 @@ async def _(event: T_MessageEvent, bot: T_Bot):
if await GROUP_ADMIN(bot, event) or await GROUP_OWNER(bot, event) or await SUPERUSER(bot, event):
# 添加启用/停用插件按钮
cmd_toggle = f"{'disable' if session_enable else 'enable'}-plugin {plugin.module_name}"
cmd_toggle = f"{'disable' if session_enable else 'enable'} {plugin.module_name}"
text_toggle = lang.get("npm.disable" if session_enable else "npm.enable")
can_be_toggle = get_plugin_can_be_toggle(plugin.module_name)
btn_toggle = text_toggle if not can_be_toggle else md.button(text_toggle, cmd_toggle)
@ -90,7 +98,7 @@ async def _(event: T_MessageEvent, bot: T_Bot):
reply += f" {btn_toggle}"
if await SUPERUSER(bot, event):
plugin_in_database = plugin_db.first(InstalledPlugin, "module_name = ?", plugin.module_name)
plugin_in_database = plugin_db.first(InstalledPlugin(), "module_name = ?", plugin.module_name)
# 添加移除插件和全局切换按钮
global_enable = get_plugin_global_enable(plugin.module_name)
btn_uninstall = (
@ -98,7 +106,7 @@ async def _(event: T_MessageEvent, bot: T_Bot):
'npm.uninstall')
btn_toggle_global_text = lang.get("npm.disable_global" if global_enable else "npm.enable_global")
cmd_toggle_global = f"npm toggle-global {plugin.module_name}"
cmd_toggle_global = f"{'disable-global' if global_enable else 'enable-global'} {plugin.module_name}"
btn_toggle_global = btn_toggle_global_text if not can_be_toggle else md.button(btn_toggle_global_text, cmd_toggle_global)
reply += f" {btn_uninstall} {btn_toggle_global}"
@ -131,10 +139,10 @@ async def _(result: Arparma, event: T_MessageEvent, bot: T_Bot):
ulang.get("npm.plugin_already", NAME=plugin_module_name, STATUS=ulang.get("npm.enable") if toggle else ulang.get("npm.disable")))
if event.message_type == "private":
session = user_db.first(User, "user_id = ?", event.user_id, default=User(user_id=event.user_id))
session = user_db.first(User(), "user_id = ?", event.user_id, default=User(user_id=event.user_id))
else:
if await GROUP_ADMIN(bot, event) or await GROUP_OWNER(bot, event) or await SUPERUSER(bot, event):
session = group_db.first(GroupChat, "group_id = ?", event.group_id, default=GroupChat(group_id=str(event.group_id)))
session = group_db.first(Group(), "group_id = ?", event.group_id, default=Group(group_id=str(event.group_id)))
else:
raise FinishedException(ulang.get("Permission Denied"))
try:
@ -170,6 +178,48 @@ async def _(result: Arparma, event: T_MessageEvent, bot: T_Bot):
)
@toggle_plugin_global.handle()
async def _(result: Arparma, event: T_MessageEvent, bot: T_Bot):
if not os.path.exists("data/liteyuki/plugins.json"):
await npm_update()
# 判断会话类型
ulang = get_user_lang(str(event.user_id))
plugin_module_name = result.args.get("plugin_name")
toggle = result.header_result == "enable-global"
can_be_toggled = get_plugin_can_be_toggle(plugin_module_name)
if not can_be_toggled:
await toggle_plugin_global.finish(ulang.get("npm.plugin_cannot_be_toggled", NAME=plugin_module_name))
global_enable = get_plugin_global_enable(plugin_module_name)
if global_enable == toggle:
await toggle_plugin_global.finish(
ulang.get("npm.plugin_already", NAME=plugin_module_name, STATUS=ulang.get("npm.enable") if toggle else ulang.get("npm.disable")))
try:
plugin = plugin_db.first(InstalledPlugin(), "module_name = ?", plugin_module_name, default=InstalledPlugin(module_name=plugin_module_name))
if toggle:
plugin.enabled = True
else:
plugin.enabled = False
plugin_db.upsert(plugin)
except Exception as e:
print(e)
await toggle_plugin_global.finish(
ulang.get(
"npm.toggle_failed",
NAME=plugin_module_name,
STATUS=ulang.get("npm.enable") if toggle else ulang.get("npm.disable"),
ERROR=str(e))
)
await toggle_plugin_global.finish(
ulang.get(
"npm.toggle_success",
NAME=plugin_module_name,
STATUS=ulang.get("npm.enable") if toggle else ulang.get("npm.disable"))
)
@run_preprocessor
async def _(event: T_MessageEvent, matcher: Matcher):
plugin = matcher.plugin

View File

@ -40,7 +40,7 @@ class Profile(LiteModel):
@profile_alc.handle()
async def _(result: Arparma, event: T_MessageEvent, bot: T_Bot):
user: User = user_db.first(User, "user_id = ?", event.user_id, default=User(user_id=str(event.user_id)))
user: User = user_db.first(User(), "user_id = ?", event.user_id, default=User(user_id=str(event.user_id)))
ulang = get_user_lang(str(event.user_id))
if result.subcommands.get("set"):
if result.subcommands["set"].args.get("value"):

View File

@ -12,6 +12,7 @@ import requests
from liteyuki.utils.config import load_from_yaml, config
from .log import init_log
from .data_manager import auto_migrate
major, minor, patch = map(int, __VERSION__.split("."))
__VERSION_I__ = major * 10000 + minor * 100 + patch
@ -52,6 +53,7 @@ def init():
if sys.version_info < (3, 10):
nonebot.logger.error("This project requires Python3.10+ to run, please upgrade your Python Environment.")
exit(1)
auto_migrate()
# 在加载完成语言后再初始化日志
init_log()
nonebot.logger.info("Liteyuki is initializing...")

View File

@ -1,374 +1,358 @@
import json
import os
import pickle
import sqlite3
import types
from abc import ABC
from collections.abc import Iterable
import nonebot
from pydantic import BaseModel
from types import NoneType
from typing import Any
BaseIterable = list | tuple | set | dict
import nonebot
import pydantic
from pydantic import BaseModel
class LiteModel(BaseModel):
"""轻量级模型基类
类型注解统一使用Python3.9的PEP585标准如需使用泛型请使用typing模块的泛型类型
"""
TABLE_NAME: str = None
id: int = None
class BaseORMAdapter(ABC):
def __init__(self):
pass
def auto_migrate(self, *args, **kwargs):
"""自动迁移
Returns:
"""
raise NotImplementedError
def upsert(self, *args, **kwargs):
"""存储数据
Returns:
"""
raise NotImplementedError
def first(self, *args, **kwargs):
"""查询第一条数据
Returns:
"""
raise NotImplementedError
def all(self, *args, **kwargs):
"""查询所有数据
Returns:
"""
raise NotImplementedError
def delete(self, *args, **kwargs):
"""删除数据
Returns:
"""
raise NotImplementedError
def update(self, *args, **kwargs):
"""更新数据
Returns:
"""
raise NotImplementedError
def dump(self, *args, **kwargs):
if pydantic.__version__ < "1.8.2":
return self.dict(by_alias=True)
else:
return self.model_dump(by_alias=True)
class Database(BaseORMAdapter):
"""SQLiteORM适配器严禁使用`FORIEGNID`和`JSON`作为主键前缀,严禁使用`$ID:`作为字符串值前缀
Attributes:
"""
type_map = {
# default: TEXT
str : 'TEXT',
int : 'INTEGER',
float: 'REAL',
bool : 'INTEGER',
list : 'TEXT'
}
DEFAULT_VALUE = {
'TEXT' : '',
'INTEGER': 0,
'REAL' : 0.0
}
FOREIGNID = 'FOREIGNID'
JSON = 'JSON'
LIST = 'LIST'
DICT = 'DICT'
ID = '$ID'
class Database:
def __init__(self, db_name: str):
super().__init__()
if not os.path.exists(os.path.dirname(db_name)):
if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
os.makedirs(os.path.dirname(db_name))
self.db_name = db_name
self.conn = sqlite3.connect(db_name)
self.conn.row_factory = sqlite3.Row
self.cursor = self.conn.cursor()
def auto_migrate(self, *args: type(LiteModel)):
"""自动迁移,检测新模型字段和原有表字段的差异,如有差异自动增删新字段
def first(self, model: LiteModel, condition: str, *args: Any, default: Any = None) -> LiteModel | Any | None:
"""查询第一个
Args:
*args: 模型类
model: 数据模型实例
condition: 查询条件不给定则查询所有
*args: 参数化查询参数
default: 默认值
Returns:
"""
table_name = ''
all_results = self.all(model, condition, *args)
return all_results[0] if all_results else default
def all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel | Any] | None:
"""查询所有
Args:
model: 数据模型实例
condition: 查询条件不给定则查询所有
*args: 参数化查询参数
default: 默认值
Returns:
"""
table_name = model.TABLE_NAME
model_type = type(model)
if not table_name:
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
# condition = f"WHERE {condition}"
# print(f"SELECT * FROM {table_name} {condition}", args)
# if len(args) == 0:
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}").fetchall()
# else:
# results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
if condition:
results = self.cursor.execute(f"SELECT * FROM {table_name} WHERE {condition}", args).fetchall()
else:
results = self.cursor.execute(f"SELECT * FROM {table_name}").fetchall()
fields = [description[0] for description in self.cursor.description]
if not results:
return default
else:
return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
def upsert(self, *args: LiteModel):
"""增/改操作
Args:
*args:
Returns:
"""
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
for model in args:
model: type(LiteModel)
# 检测并创建表若模型未定义id字段则使用自增主键有定义的话使用id字段且id有可能为字符串
table_name = model.__name__
if 'id' in model.__annotations__ and model.__annotations__['id'] is not None:
# 如果模型定义了id字段那么使用模型的id字段
id_type = self.type_map.get(model.__annotations__['id'], 'TEXT')
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id {id_type} PRIMARY KEY)')
if not model.TABLE_NAME:
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
elif model.TABLE_NAME not in table_list:
raise ValueError(f"数据模型 {model.__class__.__name__} 的表 {model.TABLE_NAME} 不存在,请先迁移")
else:
# 如果模型未定义id字段那么使用自增主键
self.cursor.execute(f'CREATE TABLE IF NOT EXISTS {table_name} (id INTEGER PRIMARY KEY AUTOINCREMENT)')
# 获取表字段
self.cursor.execute(f'PRAGMA table_info({table_name})')
table_fields = self.cursor.fetchall()
table_fields = [field[1] for field in table_fields]
self._save(model.dump(by_alias=True))
raw_fields, raw_types = zip(*model.__annotations__.items())
# 获取模型字段若有模型则添加FOREIGNID前缀若为BaseIterable则添加JSON前缀用多行if判断
model_fields = []
model_types = []
for field, r_type in zip(raw_fields, raw_types):
if isinstance(r_type, type(LiteModel)):
model_fields.append(f'{self.FOREIGNID}{field}')
model_types.append('TEXT')
elif r_type in [list[str], list[int], list[float], list[bool], list]:
model_fields.append(f'{self.LIST}{field}')
model_types.append('TEXT')
elif r_type in [dict[str, str], dict[str, int], dict[str, float], dict[str, bool], dict]:
model_fields.append(f'{self.DICT}{field}')
model_types.append('TEXT')
elif isinstance(r_type, types.GenericAlias):
model_fields.append(f'{self.JSON}{field}')
model_types.append('TEXT')
def _save(self, obj: Any) -> Any:
# obj = copy.deepcopy(obj)
if isinstance(obj, dict):
table_name = obj.get("TABLE_NAME")
row_id = obj.get("id")
new_obj = {}
for field, value in obj.items():
if isinstance(value, self.ITERABLE_TYPE):
new_obj[self._get_stored_field_prefix(value) + field] = self._save(value) # self._save(value) # -> bytes
elif isinstance(value, self.BASIC_TYPE):
new_obj[field] = value
else:
model_fields.append(field)
model_types.append(self.type_map.get(r_type, 'TEXT'))
# 检测新字段或字段类型是否有变化,有则增删字段,已经加了前缀类型
for field_changed, type_, r_type in zip(model_fields, model_types, raw_types):
if field_changed not in table_fields:
nonebot.logger.debug(f'ALTER TABLE {table_name} ADD COLUMN {field_changed} {type_}')
self.cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {field_changed} {type_}')
# 在原有的行中添加新字段对应类型的默认值从DEFAULT_TYPE中获取
self.cursor.execute(f'UPDATE {table_name} SET {field_changed} = ? WHERE {field_changed} IS NULL', (self.DEFAULT_VALUE.get(type_, ""),))
# 检测多余字段除了id字段
for field in table_fields:
if field not in model_fields and field != 'id':
nonebot.logger.debug(f'ALTER TABLE {table_name} DROP COLUMN {field}')
self.cursor.execute(f'ALTER TABLE {table_name} DROP COLUMN {field}')
self.conn.commit()
nonebot.logger.debug(f'Table {table_name} migrated successfully')
def upsert(self, *models: LiteModel) -> int | tuple:
"""存储数据检查id字段如果有id字段则更新没有则插入
Args:
models: 数据
Returns:
id: 数据id如果有多个数据则返回id元组
"""
ids = []
for model in models:
table_name = model.__class__.__name__
if not self._detect_for_table(table_name):
raise ValueError(f'{table_name}不存在,请先迁移')
key_list = []
value_list = []
# 处理外键,添加前缀'$IDFieldName'
for field, value in model.__dict__.items():
if isinstance(value, LiteModel):
key_list.append(f'{self.FOREIGNID}{field}')
value_list.append(f'{self.ID}:{value.__class__.__name__}:{self.upsert(value)}')
elif isinstance(value, list):
key_list.append(f'{self.LIST}{field}')
value_list.append(self._flat(value))
elif isinstance(value, dict):
key_list.append(f'{self.DICT}{field}')
value_list.append(self._flat(value))
elif isinstance(value, BaseIterable):
key_list.append(f'{self.JSON}{field}')
value_list.append(self._flat(value))
else:
key_list.append(field)
value_list.append(value)
# 更新或插入数据,用?占位
nonebot.logger.debug(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})')
self.cursor.execute(f'INSERT OR REPLACE INTO {table_name} ({",".join(key_list)}) VALUES ({",".join(["?" for _ in key_list])})', value_list)
ids.append(self.cursor.lastrowid)
self.conn.commit()
return ids[0] if len(ids) == 1 else tuple(ids)
def _flat(self, data: Iterable) -> str:
"""扁平化数据,返回扁平化对象
Args:
data: 数据可迭代对象
Returns: json字符串
"""
if isinstance(data, dict):
return_data = {}
for k, v in data.items():
if isinstance(v, LiteModel):
return_data[f"{self.FOREIGNID}{k}"] = f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}"
elif isinstance(v, list):
return_data[f"{self.LIST}{k}"] = self._flat(v)
elif isinstance(v, dict):
return_data[f"{self.DICT}{k}"] = self._flat(v)
elif isinstance(v, BaseIterable):
return_data[f"{self.JSON}{k}"] = self._flat(v)
else:
return_data[k] = v
elif isinstance(data, list | tuple | set):
return_data = []
for v in data:
if isinstance(v, LiteModel):
return_data.append(f"{self.ID}:{v.__class__.__name__}:{self.upsert(v)}")
elif isinstance(v, list):
return_data.append(self._flat(v))
elif isinstance(v, dict):
return_data.append(self._flat(v))
elif isinstance(v, BaseIterable):
return_data.append(self._flat(v))
else:
return_data.append(v)
else:
raise ValueError("数据类型错误")
return json.dumps(return_data)
def _detect_for_table(self, table_name: str) -> bool:
"""在进行增删查改前检测表是否存在
Args:
table_name: 表名
Returns:
"""
return self.cursor.execute(f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = ?", (table_name,)).fetchone()
def first(self, model: type(LiteModel), conditions, *args, default: Any = None) -> LiteModel | None:
"""查询第一条数据
Args:
model: 模型
conditions: 查询条件
*args: 参数化查询条件参数
default: 未查询到结果默认返回值
Returns: 数据
"""
table_name = model.__name__
if not self._detect_for_table(table_name):
return default
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
if row_data := self.cursor.fetchone():
data = dict(row_data)
return model(**self.convert_to_dict(data))
return default
def all(self, model: type(LiteModel), conditions=None, *args, default: Any = None) -> list[LiteModel] | None:
"""查询所有数据
Args:
model: 模型
conditions: 查询条件
*args: 参数化查询条件参数
default: 未查询到结果默认返回值
Returns: 数据
"""
table_name = model.__name__
if not self._detect_for_table(table_name):
return default
if conditions:
self.cursor.execute(f"SELECT * FROM {table_name} WHERE {conditions}", args)
else:
self.cursor.execute(f"SELECT * FROM {table_name}")
if row_datas := self.cursor.fetchall():
datas = [dict(row_data) for row_data in row_datas]
return [model(**self.convert_to_dict(d)) for d in datas] if datas else default
return default
def delete(self, model: type(LiteModel), conditions, *args):
"""删除数据
Args:
model: 模型
conditions: 查询条件
*args: 参数化查询条件参数
Returns:
"""
table_name = model.__name__
if not self._detect_for_table(table_name):
return
nonebot.logger.debug(f"DELETE FROM {table_name} WHERE {conditions}")
self.cursor.execute(f"DELETE FROM {table_name} WHERE {conditions}", args)
self.conn.commit()
def convert_to_dict(self, data: dict) -> dict:
"""将json字符串转换为字典
Args:
data: json字符串
Returns: 字典
"""
def load(d: BaseIterable) -> BaseIterable:
"""递归加载数据,去除前缀"""
if isinstance(d, dict):
new_d = {}
for k, v in d.items():
if k.startswith(self.FOREIGNID):
new_d[k.replace(self.FOREIGNID, "")] = load(
dict(self.cursor.execute(f"SELECT * FROM {v.split(':', 2)[1]} WHERE id = ?", (v.split(":", 2)[2],)).fetchone()))
elif k.startswith(self.LIST):
if v == '': v = '[]'
new_d[k.replace(self.LIST, '')] = load(json.loads(v))
elif k.startswith(self.DICT):
if v == '': v = '{}'
new_d[k.replace(self.DICT, '')] = load(json.loads(v))
elif k.startswith(self.JSON):
if v == '': v = '[]'
new_d[k.replace(self.JSON, '')] = load(json.loads(v))
else:
new_d[k] = v
elif isinstance(d, list | tuple | set):
new_d = []
for i, v in enumerate(d):
if isinstance(v, str) and v.startswith(self.ID):
new_d.append(load(dict(self.cursor.execute(f'SELECT * FROM {v.split(":", 2)[1]} WHERE id = ?', (v.split(":", 2)[2],)).fetchone())))
elif isinstance(v, BaseIterable):
new_d.append(load(v))
raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
if table_name:
fields, values = [], []
for n_field, n_value in new_obj.items():
if n_field not in ["TABLE_NAME", "id"]:
fields.append(n_field)
values.append(n_value)
# 移除TABLE_NAME和id
fields = list(fields)
values = list(values)
if row_id is not None:
# 如果 _id 不为空,将 'id' 插入到字段列表的开始
fields.insert(0, 'id')
# 将 _id 插入到值列表的开始
values.insert(0, row_id)
fields = ', '.join([f'"{field}"' for field in fields])
placeholders = ', '.join('?' for _ in values)
self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
self.conn.commit()
foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}" # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
else:
new_d = d
return new_d
return pickle.dumps(new_obj) # -> bytes
elif isinstance(obj, (list, set, tuple)):
obj_type = type(obj) # 到时候转回去
new_obj = []
for item in obj:
if isinstance(item, self.ITERABLE_TYPE):
new_obj.append(self._save(item))
elif isinstance(item, self.BASIC_TYPE):
new_obj.append(item)
else:
raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
return pickle.dumps(obj_type(new_obj)) # -> bytes
else:
raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
return load(data)
def _load(self, obj: Any) -> Any:
if isinstance(obj, dict):
new_obj = {}
for field, value in obj.items():
field: str
if field.startswith(self.BYTES_PREFIX):
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
elif field.startswith(self.FOREIGN_KEY_PREFIX):
new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
else:
new_obj[field] = value
return new_obj
elif isinstance(obj, (list, set, tuple)):
print(" - Load as List")
new_obj = []
for item in obj:
print(" - Loading Item", item)
if isinstance(item, bytes):
# 对bytes进行尝试解析解析失败则返回原始bytes
try:
new_obj.append(self._load(pickle.loads(item)))
except Exception as e:
new_obj.append(self._load(item))
print(" - Load as Bytes | Result:", new_obj[-1])
elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
new_obj.append(self._load(self._get_foreign_data(item)))
else:
new_obj.append(self._load(item))
return new_obj
else:
return obj
def delete(self, model: LiteModel, condition: str, *args: Any, allow_empty: bool = False):
"""
删除满足条件的数据
Args:
allow_empty: 允许空条件删除整个表
model:
condition:
*args:
Returns:
"""
table_name = model.TABLE_NAME
if not table_name:
raise ValueError(f"数据模型{model.__class__.__name__}未提供表名")
if not condition and not allow_empty:
raise ValueError("删除操作必须提供条件")
self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}", args)
def auto_migrate(self, *args: LiteModel):
"""
自动迁移模型
Args:
*args: 模型类实例化对象支持空默认值不支持嵌套迁移
Returns:
"""
for model in args:
if not model.TABLE_NAME:
raise ValueError(f"数据模型{type(model).__name__}未提供表名")
# 若无则创建表
self.cursor.execute(
f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
)
# 获取表结构,field -> SqliteType
new_structure = {}
for n_field, n_value in model.dump(by_alias=True).items():
if n_field not in ["TABLE_NAME", "id"]:
new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
# 原有的字段列表
existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
# 检测缺失字段由于SQLite是动态类型所以不需要检测类型
for n_field, n_type in new_structure.items():
if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
print(n_type, self.DEFAULT_MAPPING.get(n_type, ''))
self.cursor.execute(
f"ALTER TABLE '{model.TABLE_NAME}' ADD COLUMN {n_field} {n_type} DEFAULT {self.DEFAULT_MAPPING.get(n_type, '')}"
)
# 检测多余字段进行删除
for e_field in existing_structure.keys():
if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
self.cursor.execute(
f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
)
self.conn.commit()
# 已完成
def _get_stored_field_prefix(self, value) -> str:
"""根据类型获取存储字段前缀,一定在后加上字段名
* -> ""
Args:
value: 储存的值
Returns:
Sqlite3存储字段
"""
if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
return self.FOREIGN_KEY_PREFIX
elif type(value) in self.ITERABLE_TYPE:
return self.BYTES_PREFIX
return ""
def _get_stored_type(self, value) -> str:
"""获取存储类型
Args:
value: 储存的值
Returns:
Sqlite3存储类型
"""
if isinstance(value, dict) and "TABLE_NAME" in value:
# 是一个模型字典,储存外键
return "INTEGER"
return self.TYPE_MAPPING.get(type(value), "TEXT")
def _get_foreign_data(self, foreign_value: str) -> dict:
"""
获取外键数据
Args:
foreign_value:
Returns:
"""
foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
table_name = foreign_value.split("@")[-1]
foreign_id = foreign_value.split("@")[0]
fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
return dict(zip(fields, result))
TYPE_MAPPING = {
int : "INTEGER",
float : "REAL",
str : "TEXT",
bool : "INTEGER",
bytes : "BLOB",
NoneType : "NULL",
# dict : "TEXT",
# list : "TEXT",
# tuple : "TEXT",
# set : "TEXT",
dict : "BLOB", # LITEYUKIDICT{key_name}
list : "BLOB", # LITEYUKILIST{key_name}
tuple : "BLOB", # LITEYUKITUPLE{key_name}
set : "BLOB", # LITEYUKISET{key_name}
LiteModel: "TEXT" # FOREIGN_KEY_{table_name}
}
DEFAULT_MAPPING = {
"TEXT" : "''",
"INTEGER": 0,
"REAL" : 0.0,
"BLOB" : b"",
"NULL" : None
}
# 基础类型
BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
# 可序列化类型
ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
# 外键前缀
FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
# 转换为的字节前缀
BYTES_PREFIX = "PICKLE_BYTES_"
def check_sqlite_keyword(name):
sqlite_keywords = [
"ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
"ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
"CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
"CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
"DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
"ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
"FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
"IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
"INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
"NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
"PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
"RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
"SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
"UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
"WHERE", "WITH", "WITHOUT"
]
return True
# if name.upper() in sqlite_keywords:
# raise ValueError(f"'{name}' 是SQLite保留字不建议使用请更换名称")

View File

@ -13,6 +13,7 @@ common_db = DB(os.path.join(DATA_PATH, "common.ldb"))
class User(LiteModel):
TABLE_NAME = "user"
user_id: str = Field(str(), alias="user_id")
username: str = Field(str(), alias="username")
profile: dict[str, str] = Field(dict(), alias="profile")
@ -20,7 +21,8 @@ class User(LiteModel):
disabled_plugins: list[str] = Field(list(), alias="disabled_plugins")
class GroupChat(LiteModel):
class Group(LiteModel):
TABLE_NAME = "group_chat"
# Group是一个关键字所以这里用GroupChat
group_id: str = Field(str(), alias="group_id")
group_name: str = Field(str(), alias="group_name")
@ -29,17 +31,22 @@ class GroupChat(LiteModel):
class InstalledPlugin(LiteModel):
liteyuki: bool = Field(True, alias="liteyuki") # 是否为LiteYuki插件
enabled: bool = Field(True, alias="enabled") # 全局启用
TABLE_NAME = "installed_plugin"
module_name: str = Field(str(), alias="module_name")
version: str = Field(str(), alias="version")
class GlobalPlugin(LiteModel):
TABLE_NAME = "global_plugin"
module_name: str = Field(str(), alias="module_name")
enabled: bool = Field(True, alias="enabled")
def auto_migrate():
user_db.auto_migrate(User)
group_db.auto_migrate(GroupChat)
plugin_db.auto_migrate(InstalledPlugin)
common_db.auto_migrate(GlobalPlugin)
print("Migrating databases...")
user_db.auto_migrate(User())
group_db.auto_migrate(Group())
plugin_db.auto_migrate(InstalledPlugin())
common_db.auto_migrate(GlobalPlugin())

View File

@ -1,326 +0,0 @@
import os
import pickle
import sqlite3
from types import NoneType
from typing import Any
import pydantic
from pydantic import BaseModel
class LiteModel(BaseModel):
TABLE_NAME: str = None
id: int = None
class Database:
def __init__(self, db_name: str):
if os.path.dirname(db_name) != "" and not os.path.exists(os.path.dirname(db_name)):
os.makedirs(os.path.dirname(db_name))
self.db_name = db_name
self.conn = sqlite3.connect(db_name)
self.cursor = self.conn.cursor()
def first(self, model: LiteModel, condition: str, *args: Any, default: Any = None) -> LiteModel | Any | None:
"""查询第一个
Args:
model: 数据模型实例
condition: 查询条件不给定则查询所有
*args: 参数化查询参数
default: 默认值
Returns:
"""
all_results = self.all(model, condition, *args, default=default)
return all_results[0] if all_results else default
def all(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> list[LiteModel] | list[Any] | None:
"""查询所有
Args:
model: 数据模型实例
condition: 查询条件不给定则查询所有
*args: 参数化查询参数
default: 默认值
Returns:
"""
table_name = model.TABLE_NAME
model_type = type(model)
if not table_name:
raise ValueError(f"数据模型{model_type.__name__}未提供表名")
condition = f"WHERE {condition}"
results = self.cursor.execute(f"SELECT * FROM {table_name} {condition}", args).fetchall()
fields = [description[0] for description in self.cursor.description]
if not results:
return default
else:
return [model_type(**self._load(dict(zip(fields, result)))) for result in results]
def upsert(self, *args: LiteModel):
"""增/改操作
Args:
*args:
Returns:
"""
table_list = [item[0] for item in self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()]
for model in args:
if not model.TABLE_NAME:
raise ValueError(f"数据模型 {model.__class__.__name__} 未提供表名")
elif model.TABLE_NAME not in table_list:
raise ValueError(f"数据模型 {model.__class__.__name__} 的表 {model.TABLE_NAME} 不存在,请先迁移")
else:
if pydantic.__version__ < "1.8.2":
# 兼容pydantic 1.8.2以下版本
model_dict = model.dict(by_alias=True)
else:
model_dict = model.model_dump(by_alias=True)
self._save(model_dict)
def _save(self, obj: Any) -> Any:
# obj = copy.deepcopy(obj)
if isinstance(obj, dict):
table_name = obj.get("TABLE_NAME")
row_id = obj.get("id")
new_obj = {}
for field, value in obj.items():
if isinstance(value, self.ITERABLE_TYPE):
new_obj[self._get_stored_field_prefix(value) + field] = self._save(value) # self._save(value) # -> bytes
elif isinstance(value, self.BASIC_TYPE):
new_obj[field] = value
else:
raise ValueError(f"数据模型{table_name}包含不支持的数据类型,字段:{field} 值:{value} 值类型:{type(value)}")
if table_name:
fields, values = [], []
for n_field, n_value in new_obj.items():
if n_field not in ["TABLE_NAME", "id"]:
fields.append(n_field)
values.append(n_value)
# 移除TABLE_NAME和id
fields = list(fields)
values = list(values)
if row_id is not None:
# 如果 _id 不为空,将 'id' 插入到字段列表的开始
fields.insert(0, 'id')
# 将 _id 插入到值列表的开始
values.insert(0, row_id)
fields = ', '.join([f'"{field}"' for field in fields])
placeholders = ', '.join('?' for _ in values)
self.cursor.execute(f"INSERT OR REPLACE INTO {table_name}({fields}) VALUES ({placeholders})", tuple(values))
self.conn.commit()
foreign_id = self.cursor.execute("SELECT last_insert_rowid()").fetchone()[0]
return f"{self.FOREIGN_KEY_PREFIX}{foreign_id}@{table_name}" # -> FOREIGN_KEY_123456@{table_name} id@{table_name}
else:
return pickle.dumps(new_obj) # -> bytes
elif isinstance(obj, (list, set, tuple)):
obj_type = type(obj) # 到时候转回去
new_obj = []
for item in obj:
if isinstance(item, self.ITERABLE_TYPE):
new_obj.append(self._save(item))
elif isinstance(item, self.BASIC_TYPE):
new_obj.append(item)
else:
raise ValueError(f"数据模型包含不支持的数据类型,值:{item} 值类型:{type(item)}")
return pickle.dumps(obj_type(new_obj)) # -> bytes
else:
raise ValueError(f"数据模型包含不支持的数据类型,值:{obj} 值类型:{type(obj)}")
def _load(self, obj: Any) -> Any:
if isinstance(obj, dict):
new_obj = {}
for field, value in obj.items():
field: str
if field.startswith(self.BYTES_PREFIX):
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
elif field.startswith(self.FOREIGN_KEY_PREFIX):
new_obj[field.replace(self.FOREIGN_KEY_PREFIX, "")] = self._load(self._get_foreign_data(value))
else:
new_obj[field] = value
return new_obj
elif isinstance(obj, (list, set, tuple)):
print(" - Load as List")
new_obj = []
for item in obj:
print(" - Loading Item", item)
if isinstance(item, bytes):
# 对bytes进行尝试解析解析失败则返回原始bytes
try:
new_obj.append(self._load(pickle.loads(item)))
except Exception as e:
new_obj.append(self._load(item))
print(" - Load as Bytes | Result:", new_obj[-1])
elif isinstance(item, str) and item.startswith(self.FOREIGN_KEY_PREFIX):
new_obj.append(self._load(self._get_foreign_data(item)))
else:
new_obj.append(self._load(item))
return new_obj
else:
return obj
def delete(self, model: LiteModel, condition: str, *args: Any):
pass
def auto_migrate(self, *args: LiteModel):
"""
自动迁移模型
Args:
*args: 模型类实例化对象支持空默认值不支持嵌套迁移
Returns:
"""
for model in args:
if not model.TABLE_NAME:
raise ValueError(f"数据模型{type(model).__name__}未提供表名")
# 若无则创建表
self.cursor.execute(
f'CREATE TABLE IF NOT EXISTS "{model.TABLE_NAME}" (id INTEGER PRIMARY KEY AUTOINCREMENT)'
)
# 获取表结构,field -> SqliteType
new_structure = {}
for n_field, n_value in model.model_dump(by_alias=True).items():
if n_field not in ["TABLE_NAME", "id"]:
new_structure[self._get_stored_field_prefix(n_value) + n_field] = self._get_stored_type(n_value)
# 原有的字段列表
existing_structure = dict([(column[1], column[2]) for column in self.cursor.execute(f'PRAGMA table_info({model.TABLE_NAME})').fetchall()])
# 检测缺失字段由于SQLite是动态类型所以不需要检测类型
for n_field, n_type in new_structure.items():
if n_field not in existing_structure.keys() and n_field.lower() not in ["id", "table_name"]:
self.cursor.execute(
f'ALTER TABLE "{model.TABLE_NAME}" ADD COLUMN "{n_field}" {n_type}'
)
# 检测多余字段进行删除
for e_field in existing_structure.keys():
if e_field not in new_structure.keys() and e_field.lower() not in ['id']:
self.cursor.execute(
f'ALTER TABLE "{model.TABLE_NAME}" DROP COLUMN "{e_field}"'
)
self.conn.commit()
# 已完成
def _get_stored_field_prefix(self, value) -> str:
"""根据类型获取存储字段前缀,一定在后加上字段名
* -> ""
Args:
value: 储存的值
Returns:
Sqlite3存储字段
"""
if isinstance(value, LiteModel) or isinstance(value, dict) and "TABLE_NAME" in value:
return self.FOREIGN_KEY_PREFIX
elif type(value) in self.ITERABLE_TYPE:
return self.BYTES_PREFIX
return ""
def _get_stored_type(self, value) -> str:
"""获取存储类型
Args:
value: 储存的值
Returns:
Sqlite3存储类型
"""
if isinstance(value, dict) and "TABLE_NAME" in value:
# 是一个模型字典,储存外键
return "INTEGER"
return self.TYPE_MAPPING.get(type(value), "TEXT")
def _get_foreign_data(self, foreign_value: str) -> dict:
"""
获取外键数据
Args:
foreign_value:
Returns:
"""
foreign_value = foreign_value.replace(self.FOREIGN_KEY_PREFIX, "")
table_name = foreign_value.split("@")[-1]
foreign_id = foreign_value.split("@")[0]
fields = [description[1] for description in self.cursor.execute(f"PRAGMA table_info({table_name})").fetchall()]
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
return dict(zip(fields, result))
TYPE_MAPPING = {
int : "INTEGER",
float : "REAL",
str : "TEXT",
bool : "INTEGER",
bytes : "BLOB",
NoneType : "NULL",
# dict : "TEXT",
# list : "TEXT",
# tuple : "TEXT",
# set : "TEXT",
dict : "BLOB", # LITEYUKIDICT{key_name}
list : "BLOB", # LITEYUKILIST{key_name}
tuple : "BLOB", # LITEYUKITUPLE{key_name}
set : "BLOB", # LITEYUKISET{key_name}
LiteModel: "INTEGER" # FOREIGN_KEY_{table_name}
}
# 基础类型
BASIC_TYPE = (int, float, str, bool, bytes, NoneType)
# 可序列化类型
ITERABLE_TYPE = (dict, list, tuple, set, LiteModel)
# 外键前缀
FOREIGN_KEY_PREFIX = "FOREIGN_KEY_"
# 转换为的字节前缀
BYTES_PREFIX = "PICKLE_BYTES_"
def check_sqlite_keyword(name):
sqlite_keywords = [
"ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ANALYZE", "AND", "AS", "ASC",
"ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE", "CASE",
"CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT", "CREATE",
"CROSS", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP", "DATABASE", "DEFAULT",
"DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH", "DISTINCT", "DROP", "EACH",
"ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FOR",
"FOREIGN", "FROM", "FULL", "GLOB", "GROUP", "HAVING", "IF", "IGNORE", "IMMEDIATE",
"IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT", "INSTEAD", "INTERSECT",
"INTO", "IS", "ISNULL", "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "MATCH", "NATURAL",
"NO", "NOT", "NOTNULL", "NULL", "OF", "OFFSET", "ON", "OR", "ORDER", "OUTER", "PLAN",
"PRAGMA", "PRIMARY", "QUERY", "RAISE", "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX",
"RELEASE", "RENAME", "REPLACE", "RESTRICT", "RIGHT", "ROLLBACK", "ROW", "SAVEPOINT",
"SELECT", "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TO", "TRANSACTION", "TRIGGER",
"UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW", "VIRTUAL", "WHEN",
"WHERE", "WITH", "WITHOUT"
]
return True
# if name.upper() in sqlite_keywords:
# raise ValueError(f"'{name}' 是SQLite保留字不建议使用请更换名称")

View File

@ -135,7 +135,7 @@ def get_user_lang(user_id: str) -> Language:
"""
获取用户的语言代码
"""
user = user_db.first(User, "user_id = ?", user_id, default=User(
user = user_db.first(User(), "user_id = ?", user_id, default=User(
user_id=user_id,
username="Unknown"
))

View File

@ -3,8 +3,6 @@ from urllib.parse import quote
import nonebot
from nonebot.adapters.onebot import v11, v12
from typing import Any
from .tools import encode_url
from .ly_typing import T_Bot, T_MessageEvent