diff --git a/liteyuki/liteyuki_main/core.py b/liteyuki/liteyuki_main/core.py index 0b33f57..058684c 100644 --- a/liteyuki/liteyuki_main/core.py +++ b/liteyuki/liteyuki_main/core.py @@ -257,7 +257,6 @@ async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher """ api_name = result.main_args.get("api") args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20 - print(args) args_dict = {} for arg in args: @@ -392,8 +391,9 @@ need_group_id = ( "send_msg", "set_group_card", "set_group_name", + "set_group_special_title", "get_group_member_info", "get_group_member_list", "get_group_honor_info" -) +) \ No newline at end of file diff --git a/liteyuki/plugins/liteyuki_smart_reply/matchers.py b/liteyuki/plugins/liteyuki_smart_reply/matchers.py index 8e18ca8..a503917 100644 --- a/liteyuki/plugins/liteyuki_smart_reply/matchers.py +++ b/liteyuki/plugins/liteyuki_smart_reply/matchers.py @@ -1,6 +1,7 @@ import asyncio import random +import nonebot from nonebot import Bot, on_message, get_driver, require from nonebot.internal.matcher import Matcher from nonebot.permission import SUPERUSER @@ -38,8 +39,6 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher): if get_message_type(event) == "group": group_id = event.group_id probability = result.main_args.get("probability") - # 保存到内存 - group_reply_probability[group_id] = probability # 保存到数据库 group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id))) group.config["reply_probability"] = probability @@ -49,6 +48,19 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher): return +@group_db.on_save +def _(model: Group): + """ + 在数据库更新时,更新内存中的回复概率 + Args: + model: + + Returns: + + """ + group_reply_probability[model.group_id] = model.config.get("reply_probability", default_reply_probability) + + @driver.on_bot_connect async def _(bot: Bot): global nicknames @@ -85,7 +97,7 @@ async def _(event: T_MessageEvent, bot: Bot, state: T_State, matcher: Matcher): reply = reply.replace("。", "||").replace(",", "||").replace("!", "||").replace("?", "||") replies = reply.split("||") for r in replies: - if r: # 防止空字符串 + if r: # 防止空字符串 await asyncio.sleep(random.random() * 2) await matcher.send(r) else: diff --git a/liteyuki/utils/base/data.py b/liteyuki/utils/base/data.py index e5e3309..453690e 100644 --- a/liteyuki/utils/base/data.py +++ b/liteyuki/utils/base/data.py @@ -2,9 +2,9 @@ import os import pickle import sqlite3 from types import NoneType -from typing import Any +from typing import Any, Callable from packaging.version import parse - +import inspect import nonebot import pydantic from pydantic import BaseModel @@ -31,6 +31,8 @@ class Database: self.conn = sqlite3.connect(db_name) self.cursor = self.conn.cursor() + self._on_save_callbacks = [] + def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None: """查询第一个 Args: @@ -94,6 +96,9 @@ class Database: else: self._save(model.dump(by_alias=True)) + for callback in self._on_save_callbacks: + callback(model) + def _save(self, obj: Any) -> Any: # obj = copy.deepcopy(obj) if isinstance(obj, dict): @@ -156,7 +161,7 @@ class Database: if field.startswith(self.BYTES_PREFIX): if isinstance(value, bytes): new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value)) - else: # 从value字段可能为None,fix at 2024/6/13 + else: # 从value字段可能为None,fix at 2024/6/13 pass # 暂时不作处理,后面再修 @@ -301,6 +306,32 @@ class Database: result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone() return dict(zip(fields, result)) + def on_save(self, func: Callable[[LiteModel | Any], None]): + """ + 装饰一个可调用对象使其在储存数据模型时被调用 + Args: + func: + Returns: + """ + + def wrapper(model): + # 检查被装饰函数声明的model类型和传入的model类型是否一致 + sign = inspect.signature(func) + if param := sign.parameters.get("model"): + if isinstance(model, param.annotation): + pass + else: + return + else: + return + result = func(model) + for callback in self._on_save_callbacks: + callback(result) + return result + + self._on_save_callbacks.append(wrapper) + return wrapper + TYPE_MAPPING = { int : "INTEGER", float : "REAL",