diff --git a/nonebot_plugin_marshoai/decos.py b/nonebot_plugin_marshoai/decos.py index 66a12e30..ac59531e 100644 --- a/nonebot_plugin_marshoai/decos.py +++ b/nonebot_plugin_marshoai/decos.py @@ -2,14 +2,36 @@ from .instances import cache def from_cache(key): + """ + 当缓存中有数据时,直接返回缓存中的数据,否则执行函数并将结果存入缓存 + """ + def decorator(func): - def wrapper(*args, **kwargs): + async def wrapper(*args, **kwargs): cached = cache.get(key) if cached: return cached else: - result = func(*args, **kwargs) + result = await func(*args, **kwargs) cache.set(key, result) return result return wrapper + + return decorator + + +def update_to_cache(key): + """ + 执行函数并将结果存入缓存 + """ + + def decorator(func): + async def wrapper(*args, **kwargs): + result = await func(*args, **kwargs) + cache.set(key, result) + return result + + return wrapper + + return decorator diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 0fbd6ec0..33f8e4fe 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -23,10 +23,11 @@ from ._types import DeveloperMessage from .config import config from .constants import CODE_BLOCK_PATTERN, IMG_LATEX_PATTERN, OPENAI_NEW_MODELS from .deal_latex import ConvertLatex +from .decos import from_cache, update_to_cache from .instances import cache -nickname_json = None # 记录昵称 -praises_json = None # 记录夸赞名单 +# nickname_json = None # 记录昵称 +# praises_json = None # 记录夸赞名单 loaded_target_list: List[str] = [] # 记录已恢复备份的上下文的列表 NOT_GIVEN = NotGiven() @@ -156,30 +157,29 @@ async def make_chat_openai( ) +@from_cache("praises") def get_praises(): - global praises_json - if praises_json is None: - praises_file = store.get_plugin_data_file( - "praises.json" - ) # 夸赞名单文件使用localstore存储 - if not praises_file.exists(): - with open(praises_file, "w", encoding="utf-8") as f: - json.dump(_praises_init_data, f, ensure_ascii=False, indent=4) - with open(praises_file, "r", encoding="utf-8") as f: - data = json.load(f) - praises_json = data + praises_file = store.get_plugin_data_file( + "praises.json" + ) # 夸赞名单文件使用localstore存储 + if not praises_file.exists(): + with open(praises_file, "w", encoding="utf-8") as f: + json.dump(_praises_init_data, f, ensure_ascii=False, indent=4) + with open(praises_file, "r", encoding="utf-8") as f: + data = json.load(f) + praises_json = data return praises_json +@update_to_cache("praises") async def refresh_praises_json(): - global praises_json praises_file = store.get_plugin_data_file("praises.json") if not praises_file.exists(): with open(praises_file, "w", encoding="utf-8") as f: json.dump(_praises_init_data, f, ensure_ascii=False, indent=4) # 异步? async with aiofiles.open(praises_file, "r", encoding="utf-8") as f: data = json.loads(await f.read()) - praises_json = data + return data def build_praises() -> str: @@ -211,22 +211,21 @@ async def load_context_from_json(name: str, path: str) -> list: return [] +@from_cache("nickname") async def get_nicknames(): - """获取nickname_json, 优先来源于全局变量""" - global nickname_json - if nickname_json is None: - filename = store.get_plugin_data_file("nickname.json") - # noinspection PyBroadException - try: - async with aiofiles.open(filename, "r", encoding="utf-8") as f: - nickname_json = json.loads(await f.read()) - except (json.JSONDecodeError, FileNotFoundError): - nickname_json = {} + """获取nickname_json, 优先来源于缓存""" + filename = store.get_plugin_data_file("nickname.json") + # noinspection PyBroadException + try: + async with aiofiles.open(filename, "r", encoding="utf-8") as f: + nickname_json = json.loads(await f.read()) + except (json.JSONDecodeError, FileNotFoundError): + nickname_json = {} return nickname_json +@update_to_cache("nickname") async def set_nickname(user_id: str, name: str): - global nickname_json filename = store.get_plugin_data_file("nickname.json") if not filename.exists(): data = {} @@ -238,18 +237,19 @@ async def set_nickname(user_id: str, name: str): del data[user_id] with open(filename, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=4) - nickname_json = data + return data +@update_to_cache("nickname") async def refresh_nickname_json(): - """强制刷新nickname_json, 刷新全局变量""" - global nickname_json + """强制刷新nickname_json""" # noinspection PyBroadException try: async with aiofiles.open( store.get_plugin_data_file("nickname.json"), "r", encoding="utf-8" ) as f: nickname_json = json.loads(await f.read()) + return nickname_json except (json.JSONDecodeError, FileNotFoundError): logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件")