forked from bot/app
✨ orm框架新增@db.on_save回调函数,用于检测数据库更新时的变动
This commit is contained in:
parent
1787ef4db7
commit
4162ea33ff
@ -257,7 +257,6 @@ async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher
|
||||
"""
|
||||
api_name = result.main_args.get("api")
|
||||
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20
|
||||
print(args)
|
||||
args_dict = {}
|
||||
|
||||
for arg in args:
|
||||
@ -392,6 +391,7 @@ need_group_id = (
|
||||
"send_msg",
|
||||
"set_group_card",
|
||||
"set_group_name",
|
||||
|
||||
"set_group_special_title",
|
||||
"get_group_member_info",
|
||||
"get_group_member_list",
|
||||
|
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
import nonebot
|
||||
from nonebot import Bot, on_message, get_driver, require
|
||||
from nonebot.internal.matcher import Matcher
|
||||
from nonebot.permission import SUPERUSER
|
||||
@ -38,8 +39,6 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
|
||||
if get_message_type(event) == "group":
|
||||
group_id = event.group_id
|
||||
probability = result.main_args.get("probability")
|
||||
# 保存到内存
|
||||
group_reply_probability[group_id] = probability
|
||||
# 保存到数据库
|
||||
group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id)))
|
||||
group.config["reply_probability"] = probability
|
||||
@ -49,6 +48,19 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
|
||||
return
|
||||
|
||||
|
||||
@group_db.on_save
|
||||
def _(model: Group):
|
||||
"""
|
||||
在数据库更新时,更新内存中的回复概率
|
||||
Args:
|
||||
model:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
group_reply_probability[model.group_id] = model.config.get("reply_probability", default_reply_probability)
|
||||
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: Bot):
|
||||
global nicknames
|
||||
|
@ -2,9 +2,9 @@ import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
from types import NoneType
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
from packaging.version import parse
|
||||
|
||||
import inspect
|
||||
import nonebot
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
@ -31,6 +31,8 @@ class Database:
|
||||
self.conn = sqlite3.connect(db_name)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
self._on_save_callbacks = []
|
||||
|
||||
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
|
||||
"""查询第一个
|
||||
Args:
|
||||
@ -94,6 +96,9 @@ class Database:
|
||||
else:
|
||||
self._save(model.dump(by_alias=True))
|
||||
|
||||
for callback in self._on_save_callbacks:
|
||||
callback(model)
|
||||
|
||||
def _save(self, obj: Any) -> Any:
|
||||
# obj = copy.deepcopy(obj)
|
||||
if isinstance(obj, dict):
|
||||
@ -301,6 +306,32 @@ class Database:
|
||||
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
|
||||
return dict(zip(fields, result))
|
||||
|
||||
def on_save(self, func: Callable[[LiteModel | Any], None]):
|
||||
"""
|
||||
装饰一个可调用对象使其在储存数据模型时被调用
|
||||
Args:
|
||||
func:
|
||||
Returns:
|
||||
"""
|
||||
|
||||
def wrapper(model):
|
||||
# 检查被装饰函数声明的model类型和传入的model类型是否一致
|
||||
sign = inspect.signature(func)
|
||||
if param := sign.parameters.get("model"):
|
||||
if isinstance(model, param.annotation):
|
||||
pass
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return
|
||||
result = func(model)
|
||||
for callback in self._on_save_callbacks:
|
||||
callback(result)
|
||||
return result
|
||||
|
||||
self._on_save_callbacks.append(wrapper)
|
||||
return wrapper
|
||||
|
||||
TYPE_MAPPING = {
|
||||
int : "INTEGER",
|
||||
float : "REAL",
|
||||
|
Loading…
Reference in New Issue
Block a user