orm框架新增@db.on_save回调函数,用于检测数据库更新时的变动

This commit is contained in:
远野千束(神羽) 2024-06-22 14:17:14 +08:00
parent 1787ef4db7
commit 4162ea33ff
3 changed files with 51 additions and 8 deletions

View File

@ -257,7 +257,6 @@ async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher
"""
api_name = result.main_args.get("api")
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数但每个参数间用空格分隔空格是%20
print(args)
args_dict = {}
for arg in args:
@ -392,8 +391,9 @@ need_group_id = (
"send_msg",
"set_group_card",
"set_group_name",
"set_group_special_title",
"get_group_member_info",
"get_group_member_list",
"get_group_honor_info"
)
)

View File

@ -1,6 +1,7 @@
import asyncio
import random
import nonebot
from nonebot import Bot, on_message, get_driver, require
from nonebot.internal.matcher import Matcher
from nonebot.permission import SUPERUSER
@ -38,8 +39,6 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
if get_message_type(event) == "group":
group_id = event.group_id
probability = result.main_args.get("probability")
# 保存到内存
group_reply_probability[group_id] = probability
# 保存到数据库
group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id)))
group.config["reply_probability"] = probability
@ -49,6 +48,19 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
return
@group_db.on_save
def _(model: Group):
"""
在数据库更新时更新内存中的回复概率
Args:
model:
Returns:
"""
group_reply_probability[model.group_id] = model.config.get("reply_probability", default_reply_probability)
@driver.on_bot_connect
async def _(bot: Bot):
global nicknames
@ -85,7 +97,7 @@ async def _(event: T_MessageEvent, bot: Bot, state: T_State, matcher: Matcher):
reply = reply.replace("", "||").replace("", "||").replace("", "||").replace("", "||")
replies = reply.split("||")
for r in replies:
if r: # 防止空字符串
if r: # 防止空字符串
await asyncio.sleep(random.random() * 2)
await matcher.send(r)
else:

View File

@ -2,9 +2,9 @@ import os
import pickle
import sqlite3
from types import NoneType
from typing import Any
from typing import Any, Callable
from packaging.version import parse
import inspect
import nonebot
import pydantic
from pydantic import BaseModel
@ -31,6 +31,8 @@ class Database:
self.conn = sqlite3.connect(db_name)
self.cursor = self.conn.cursor()
self._on_save_callbacks = []
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
"""查询第一个
Args:
@ -94,6 +96,9 @@ class Database:
else:
self._save(model.dump(by_alias=True))
for callback in self._on_save_callbacks:
callback(model)
def _save(self, obj: Any) -> Any:
# obj = copy.deepcopy(obj)
if isinstance(obj, dict):
@ -156,7 +161,7 @@ class Database:
if field.startswith(self.BYTES_PREFIX):
if isinstance(value, bytes):
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
else: # 从value字段可能为Nonefix at 2024/6/13
else: # 从value字段可能为Nonefix at 2024/6/13
pass
# 暂时不作处理,后面再修
@ -301,6 +306,32 @@ class Database:
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
return dict(zip(fields, result))
def on_save(self, func: Callable[[LiteModel | Any], None]):
"""
装饰一个可调用对象使其在储存数据模型时被调用
Args:
func:
Returns:
"""
def wrapper(model):
# 检查被装饰函数声明的model类型和传入的model类型是否一致
sign = inspect.signature(func)
if param := sign.parameters.get("model"):
if isinstance(model, param.annotation):
pass
else:
return
else:
return
result = func(model)
for callback in self._on_save_callbacks:
callback(result)
return result
self._on_save_callbacks.append(wrapper)
return wrapper
TYPE_MAPPING = {
int : "INTEGER",
float : "REAL",