orm框架新增@db.on_save回调函数,用于检测数据库更新时的变动

This commit is contained in:
snowy 2024-06-22 14:17:14 +08:00
parent 1787ef4db7
commit 4162ea33ff
3 changed files with 51 additions and 8 deletions

View File

@ -257,7 +257,6 @@ async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher
""" """
api_name = result.main_args.get("api") api_name = result.main_args.get("api")
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数但每个参数间用空格分隔空格是%20 args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数但每个参数间用空格分隔空格是%20
print(args)
args_dict = {} args_dict = {}
for arg in args: for arg in args:
@ -392,6 +391,7 @@ need_group_id = (
"send_msg", "send_msg",
"set_group_card", "set_group_card",
"set_group_name", "set_group_name",
"set_group_special_title", "set_group_special_title",
"get_group_member_info", "get_group_member_info",
"get_group_member_list", "get_group_member_list",

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import random import random
import nonebot
from nonebot import Bot, on_message, get_driver, require from nonebot import Bot, on_message, get_driver, require
from nonebot.internal.matcher import Matcher from nonebot.internal.matcher import Matcher
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
@ -38,8 +39,6 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
if get_message_type(event) == "group": if get_message_type(event) == "group":
group_id = event.group_id group_id = event.group_id
probability = result.main_args.get("probability") probability = result.main_args.get("probability")
# 保存到内存
group_reply_probability[group_id] = probability
# 保存到数据库 # 保存到数据库
group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id))) group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id)))
group.config["reply_probability"] = probability group.config["reply_probability"] = probability
@ -49,6 +48,19 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
return return
@group_db.on_save
def _(model: Group):
"""
在数据库更新时更新内存中的回复概率
Args:
model:
Returns:
"""
group_reply_probability[model.group_id] = model.config.get("reply_probability", default_reply_probability)
@driver.on_bot_connect @driver.on_bot_connect
async def _(bot: Bot): async def _(bot: Bot):
global nicknames global nicknames
@ -85,7 +97,7 @@ async def _(event: T_MessageEvent, bot: Bot, state: T_State, matcher: Matcher):
reply = reply.replace("", "||").replace("", "||").replace("", "||").replace("", "||") reply = reply.replace("", "||").replace("", "||").replace("", "||").replace("", "||")
replies = reply.split("||") replies = reply.split("||")
for r in replies: for r in replies:
if r: # 防止空字符串 if r: # 防止空字符串
await asyncio.sleep(random.random() * 2) await asyncio.sleep(random.random() * 2)
await matcher.send(r) await matcher.send(r)
else: else:

View File

@ -2,9 +2,9 @@ import os
import pickle import pickle
import sqlite3 import sqlite3
from types import NoneType from types import NoneType
from typing import Any from typing import Any, Callable
from packaging.version import parse from packaging.version import parse
import inspect
import nonebot import nonebot
import pydantic import pydantic
from pydantic import BaseModel from pydantic import BaseModel
@ -31,6 +31,8 @@ class Database:
self.conn = sqlite3.connect(db_name) self.conn = sqlite3.connect(db_name)
self.cursor = self.conn.cursor() self.cursor = self.conn.cursor()
self._on_save_callbacks = []
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None: def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
"""查询第一个 """查询第一个
Args: Args:
@ -94,6 +96,9 @@ class Database:
else: else:
self._save(model.dump(by_alias=True)) self._save(model.dump(by_alias=True))
for callback in self._on_save_callbacks:
callback(model)
def _save(self, obj: Any) -> Any: def _save(self, obj: Any) -> Any:
# obj = copy.deepcopy(obj) # obj = copy.deepcopy(obj)
if isinstance(obj, dict): if isinstance(obj, dict):
@ -156,7 +161,7 @@ class Database:
if field.startswith(self.BYTES_PREFIX): if field.startswith(self.BYTES_PREFIX):
if isinstance(value, bytes): if isinstance(value, bytes):
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value)) new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
else: # 从value字段可能为Nonefix at 2024/6/13 else: # 从value字段可能为Nonefix at 2024/6/13
pass pass
# 暂时不作处理,后面再修 # 暂时不作处理,后面再修
@ -301,6 +306,32 @@ class Database:
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone() result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
return dict(zip(fields, result)) return dict(zip(fields, result))
def on_save(self, func: Callable[[LiteModel | Any], None]):
"""
装饰一个可调用对象使其在储存数据模型时被调用
Args:
func:
Returns:
"""
def wrapper(model):
# 检查被装饰函数声明的model类型和传入的model类型是否一致
sign = inspect.signature(func)
if param := sign.parameters.get("model"):
if isinstance(model, param.annotation):
pass
else:
return
else:
return
result = func(model)
for callback in self._on_save_callbacks:
callback(result)
return result
self._on_save_callbacks.append(wrapper)
return wrapper
TYPE_MAPPING = { TYPE_MAPPING = {
int : "INTEGER", int : "INTEGER",
float : "REAL", float : "REAL",