From 4162ea33ff0c2c0f9a808edbdc2d3da9f7e27f58 Mon Sep 17 00:00:00 2001 From: snowy Date: Sat, 22 Jun 2024 14:17:14 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20orm=E6=A1=86=E6=9E=B6=E6=96=B0?= =?UTF-8?q?=E5=A2=9E@db.on=5Fsave=E5=9B=9E=E8=B0=83=E5=87=BD=E6=95=B0?= =?UTF-8?q?=EF=BC=8C=E7=94=A8=E4=BA=8E=E6=A3=80=E6=B5=8B=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E6=9B=B4=E6=96=B0=E6=97=B6=E7=9A=84=E5=8F=98=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- liteyuki/liteyuki_main/core.py | 4 +- .../plugins/liteyuki_smart_reply/matchers.py | 18 +++++++-- liteyuki/utils/base/data.py | 37 +++++++++++++++++-- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/liteyuki/liteyuki_main/core.py b/liteyuki/liteyuki_main/core.py index 0b33f578..058684c7 100644 --- a/liteyuki/liteyuki_main/core.py +++ b/liteyuki/liteyuki_main/core.py @@ -257,7 +257,6 @@ async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher """ api_name = result.main_args.get("api") args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20 - print(args) args_dict = {} for arg in args: @@ -392,8 +391,9 @@ need_group_id = ( "send_msg", "set_group_card", "set_group_name", + "set_group_special_title", "get_group_member_info", "get_group_member_list", "get_group_honor_info" -) +) \ No newline at end of file diff --git a/liteyuki/plugins/liteyuki_smart_reply/matchers.py b/liteyuki/plugins/liteyuki_smart_reply/matchers.py index 8e18ca89..a503917c 100644 --- a/liteyuki/plugins/liteyuki_smart_reply/matchers.py +++ b/liteyuki/plugins/liteyuki_smart_reply/matchers.py @@ -1,6 +1,7 @@ import asyncio import random +import nonebot from nonebot import Bot, on_message, get_driver, require from nonebot.internal.matcher import Matcher from nonebot.permission import SUPERUSER @@ -38,8 +39,6 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher): if get_message_type(event) == "group": group_id = event.group_id probability = result.main_args.get("probability") - # 保存到内存 - group_reply_probability[group_id] = probability # 保存到数据库 group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id))) group.config["reply_probability"] = probability @@ -49,6 +48,19 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher): return +@group_db.on_save +def _(model: Group): + """ + 在数据库更新时,更新内存中的回复概率 + Args: + model: + + Returns: + + """ + group_reply_probability[model.group_id] = model.config.get("reply_probability", default_reply_probability) + + @driver.on_bot_connect async def _(bot: Bot): global nicknames @@ -85,7 +97,7 @@ async def _(event: T_MessageEvent, bot: Bot, state: T_State, matcher: Matcher): reply = reply.replace("。", "||").replace(",", "||").replace("!", "||").replace("?", "||") replies = reply.split("||") for r in replies: - if r: # 防止空字符串 + if r: # 防止空字符串 await asyncio.sleep(random.random() * 2) await matcher.send(r) else: diff --git a/liteyuki/utils/base/data.py b/liteyuki/utils/base/data.py index e5e33098..453690ea 100644 --- a/liteyuki/utils/base/data.py +++ b/liteyuki/utils/base/data.py @@ -2,9 +2,9 @@ import os import pickle import sqlite3 from types import NoneType -from typing import Any +from typing import Any, Callable from packaging.version import parse - +import inspect import nonebot import pydantic from pydantic import BaseModel @@ -31,6 +31,8 @@ class Database: self.conn = sqlite3.connect(db_name) self.cursor = self.conn.cursor() + self._on_save_callbacks = [] + def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None: """查询第一个 Args: @@ -94,6 +96,9 @@ class Database: else: self._save(model.dump(by_alias=True)) + for callback in self._on_save_callbacks: + callback(model) + def _save(self, obj: Any) -> Any: # obj = copy.deepcopy(obj) if isinstance(obj, dict): @@ -156,7 +161,7 @@ class Database: if field.startswith(self.BYTES_PREFIX): if isinstance(value, bytes): new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value)) - else: # 从value字段可能为None,fix at 2024/6/13 + else: # 从value字段可能为None,fix at 2024/6/13 pass # 暂时不作处理,后面再修 @@ -301,6 +306,32 @@ class Database: result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone() return dict(zip(fields, result)) + def on_save(self, func: Callable[[LiteModel | Any], None]): + """ + 装饰一个可调用对象使其在储存数据模型时被调用 + Args: + func: + Returns: + """ + + def wrapper(model): + # 检查被装饰函数声明的model类型和传入的model类型是否一致 + sign = inspect.signature(func) + if param := sign.parameters.get("model"): + if isinstance(model, param.annotation): + pass + else: + return + else: + return + result = func(model) + for callback in self._on_save_callbacks: + callback(result) + return result + + self._on_save_callbacks.append(wrapper) + return wrapper + TYPE_MAPPING = { int : "INTEGER", float : "REAL",