forked from bot/app
✨ orm框架新增@db.on_save回调函数,用于检测数据库更新时的变动
This commit is contained in:
parent
1787ef4db7
commit
4162ea33ff
@ -257,7 +257,6 @@ async def _(result: Arparma, bot: T_Bot, event: T_MessageEvent, matcher: Matcher
|
|||||||
"""
|
"""
|
||||||
api_name = result.main_args.get("api")
|
api_name = result.main_args.get("api")
|
||||||
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20
|
args: tuple[str] = result.main_args.get("args", ()) # 类似于url参数,但每个参数间用空格分隔,空格是%20
|
||||||
print(args)
|
|
||||||
args_dict = {}
|
args_dict = {}
|
||||||
|
|
||||||
for arg in args:
|
for arg in args:
|
||||||
@ -392,6 +391,7 @@ need_group_id = (
|
|||||||
"send_msg",
|
"send_msg",
|
||||||
"set_group_card",
|
"set_group_card",
|
||||||
"set_group_name",
|
"set_group_name",
|
||||||
|
|
||||||
"set_group_special_title",
|
"set_group_special_title",
|
||||||
"get_group_member_info",
|
"get_group_member_info",
|
||||||
"get_group_member_list",
|
"get_group_member_list",
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import nonebot
|
||||||
from nonebot import Bot, on_message, get_driver, require
|
from nonebot import Bot, on_message, get_driver, require
|
||||||
from nonebot.internal.matcher import Matcher
|
from nonebot.internal.matcher import Matcher
|
||||||
from nonebot.permission import SUPERUSER
|
from nonebot.permission import SUPERUSER
|
||||||
@ -38,8 +39,6 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
|
|||||||
if get_message_type(event) == "group":
|
if get_message_type(event) == "group":
|
||||||
group_id = event.group_id
|
group_id = event.group_id
|
||||||
probability = result.main_args.get("probability")
|
probability = result.main_args.get("probability")
|
||||||
# 保存到内存
|
|
||||||
group_reply_probability[group_id] = probability
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id)))
|
group: Group = group_db.where_one(Group(), "group_id = ?", group_id, default=Group(group_id=str(group_id)))
|
||||||
group.config["reply_probability"] = probability
|
group.config["reply_probability"] = probability
|
||||||
@ -49,6 +48,19 @@ async def _(result: Arparma, event: T_MessageEvent, matcher: Matcher):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@group_db.on_save
|
||||||
|
def _(model: Group):
|
||||||
|
"""
|
||||||
|
在数据库更新时,更新内存中的回复概率
|
||||||
|
Args:
|
||||||
|
model:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
group_reply_probability[model.group_id] = model.config.get("reply_probability", default_reply_probability)
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_connect
|
@driver.on_bot_connect
|
||||||
async def _(bot: Bot):
|
async def _(bot: Bot):
|
||||||
global nicknames
|
global nicknames
|
||||||
@ -85,7 +97,7 @@ async def _(event: T_MessageEvent, bot: Bot, state: T_State, matcher: Matcher):
|
|||||||
reply = reply.replace("。", "||").replace(",", "||").replace("!", "||").replace("?", "||")
|
reply = reply.replace("。", "||").replace(",", "||").replace("!", "||").replace("?", "||")
|
||||||
replies = reply.split("||")
|
replies = reply.split("||")
|
||||||
for r in replies:
|
for r in replies:
|
||||||
if r: # 防止空字符串
|
if r: # 防止空字符串
|
||||||
await asyncio.sleep(random.random() * 2)
|
await asyncio.sleep(random.random() * 2)
|
||||||
await matcher.send(r)
|
await matcher.send(r)
|
||||||
else:
|
else:
|
||||||
|
@ -2,9 +2,9 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
from typing import Any
|
from typing import Any, Callable
|
||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
|
import inspect
|
||||||
import nonebot
|
import nonebot
|
||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@ -31,6 +31,8 @@ class Database:
|
|||||||
self.conn = sqlite3.connect(db_name)
|
self.conn = sqlite3.connect(db_name)
|
||||||
self.cursor = self.conn.cursor()
|
self.cursor = self.conn.cursor()
|
||||||
|
|
||||||
|
self._on_save_callbacks = []
|
||||||
|
|
||||||
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
|
def where_one(self, model: LiteModel, condition: str = "", *args: Any, default: Any = None) -> LiteModel | Any | None:
|
||||||
"""查询第一个
|
"""查询第一个
|
||||||
Args:
|
Args:
|
||||||
@ -94,6 +96,9 @@ class Database:
|
|||||||
else:
|
else:
|
||||||
self._save(model.dump(by_alias=True))
|
self._save(model.dump(by_alias=True))
|
||||||
|
|
||||||
|
for callback in self._on_save_callbacks:
|
||||||
|
callback(model)
|
||||||
|
|
||||||
def _save(self, obj: Any) -> Any:
|
def _save(self, obj: Any) -> Any:
|
||||||
# obj = copy.deepcopy(obj)
|
# obj = copy.deepcopy(obj)
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
@ -156,7 +161,7 @@ class Database:
|
|||||||
if field.startswith(self.BYTES_PREFIX):
|
if field.startswith(self.BYTES_PREFIX):
|
||||||
if isinstance(value, bytes):
|
if isinstance(value, bytes):
|
||||||
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
|
new_obj[field.replace(self.BYTES_PREFIX, "")] = self._load(pickle.loads(value))
|
||||||
else: # 从value字段可能为None,fix at 2024/6/13
|
else: # 从value字段可能为None,fix at 2024/6/13
|
||||||
pass
|
pass
|
||||||
# 暂时不作处理,后面再修
|
# 暂时不作处理,后面再修
|
||||||
|
|
||||||
@ -301,6 +306,32 @@ class Database:
|
|||||||
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
|
result = self.cursor.execute(f"SELECT * FROM {table_name} WHERE id = ?", (foreign_id,)).fetchone()
|
||||||
return dict(zip(fields, result))
|
return dict(zip(fields, result))
|
||||||
|
|
||||||
|
def on_save(self, func: Callable[[LiteModel | Any], None]):
|
||||||
|
"""
|
||||||
|
装饰一个可调用对象使其在储存数据模型时被调用
|
||||||
|
Args:
|
||||||
|
func:
|
||||||
|
Returns:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(model):
|
||||||
|
# 检查被装饰函数声明的model类型和传入的model类型是否一致
|
||||||
|
sign = inspect.signature(func)
|
||||||
|
if param := sign.parameters.get("model"):
|
||||||
|
if isinstance(model, param.annotation):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
result = func(model)
|
||||||
|
for callback in self._on_save_callbacks:
|
||||||
|
callback(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
self._on_save_callbacks.append(wrapper)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
TYPE_MAPPING = {
|
TYPE_MAPPING = {
|
||||||
int : "INTEGER",
|
int : "INTEGER",
|
||||||
float : "REAL",
|
float : "REAL",
|
||||||
|
Loading…
Reference in New Issue
Block a user