diff --git a/nonebot_plugin_marshoai/cache/decos.py b/nonebot_plugin_marshoai/cache/decos.py
new file mode 100644
index 00000000..dd15e79a
--- /dev/null
+++ b/nonebot_plugin_marshoai/cache/decos.py
@@ -0,0 +1,39 @@
+from ..models import Cache
+
+cache = Cache()
+
+
+def from_cache(key):
+    """
+    当缓存中有数据时,直接返回缓存中的数据,否则执行函数并将结果存入缓存
+    """
+
+    def decorator(func):
+        async def wrapper(*args, **kwargs):
+            cached = cache.get(key)
+            if cached:
+                return cached
+            else:
+                result = await func(*args, **kwargs)
+                cache.set(key, result)
+                return result
+
+        return wrapper
+
+    return decorator
+
+
+def update_to_cache(key):
+    """
+    执行函数并将结果存入缓存
+    """
+
+    def decorator(func):
+        async def wrapper(*args, **kwargs):
+            result = await func(*args, **kwargs)
+            cache.set(key, result)
+            return result
+
+        return wrapper
+
+    return decorator
diff --git a/nonebot_plugin_marshoai/decos.py b/nonebot_plugin_marshoai/decos.py
deleted file mode 100644
index 66a12e30..00000000
--- a/nonebot_plugin_marshoai/decos.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from .instances import cache
-
-
-def from_cache(key):
-    def decorator(func):
-        def wrapper(*args, **kwargs):
-            cached = cache.get(key)
-            if cached:
-                return cached
-            else:
-                result = func(*args, **kwargs)
-                cache.set(key, result)
-                return result
-
-        return wrapper
diff --git a/nonebot_plugin_marshoai/handler.py b/nonebot_plugin_marshoai/handler.py
new file mode 100644
index 00000000..e67a96c5
--- /dev/null
+++ b/nonebot_plugin_marshoai/handler.py
@@ -0,0 +1,241 @@
+import json
+from typing import Optional, Tuple, Union
+
+from azure.ai.inference.models import (
+    CompletionsFinishReason,
+    ImageContentItem,
+    ImageUrl,
+    TextContentItem,
+    ToolMessage,
+    UserMessage,
+)
+from nonebot.adapters import Bot, Event
+from nonebot.log import logger
+from nonebot.matcher import (
+    Matcher,
+    current_bot,
+    current_event,
+    current_matcher,
+)
+from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg
+from openai import AsyncOpenAI
+from openai.types.chat import ChatCompletion, ChatCompletionMessage
+
+from .config import config
+from .constants import SUPPORT_IMAGE_MODELS
+from .instances import target_list
+from .models import MarshoContext
+from .plugin.func_call.caller import get_function_calls
+from .plugin.func_call.models import SessionContext
+from .util import (
+    extract_content_and_think,
+    get_image_b64,
+    get_nickname_by_user_id,
+    get_prompt,
+    make_chat_openai,
+    parse_richtext,
+)
+
+
+class MarshoHandler:
+    def __init__(
+        self,
+        client: AsyncOpenAI,
+        context: MarshoContext,
+    ):
+        self.client = client
+        self.context = context
+        self.bot: Bot = current_bot.get()
+        self.event: Event = current_event.get()
+        # self.state: T_State = current_handler.get().state
+        self.matcher: Matcher = current_matcher.get()
+        self.message_id: str = UniMessage.get_message_id(self.event)
+        self.target = UniMessage.get_target(self.event)
+
+    async def process_user_input(
+        self, user_input: UniMsg, model_name: str
+    ) -> Union[str, list]:
+        """
+        处理用户输入为可输入 API 的格式,并添加昵称提示
+        """
+        is_support_image_model = (
+            model_name.lower()
+            in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
+        )
+        usermsg = [] if is_support_image_model else ""
+        user_nickname = await get_nickname_by_user_id(self.event.get_user_id())
+        if user_nickname:
+            nickname_prompt = f"\n此消息的说话者为: {user_nickname}"
+        else:
+            nickname_prompt = ""
+        for i in user_input:  # type: ignore
+            if i.type == "text":
+                if is_support_image_model:
+                    usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()]  # type: ignore
+                else:
+                    usermsg += str(i.data["text"] + nickname_prompt)  # type: ignore
+            elif i.type == "image":
+                if is_support_image_model:
+                    usermsg.append(  # type: ignore
+                        ImageContentItem(
+                            image_url=ImageUrl(  # type: ignore
+                                url=str(await get_image_b64(i.data["url"]))  # type: ignore
+                            )  # type: ignore
+                        ).as_dict()  # type: ignore
+                    )  # type: ignore
+                    logger.info(f"输入图片 {i.data['url']}")
+                elif config.marshoai_enable_support_image_tip:
+                    await UniMessage(
+                        "*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。"
+                    ).send()
+        return usermsg  # type: ignore
+
+    async def handle_single_chat(
+        self,
+        user_message: Union[str, list],
+        model_name: str,
+        tools_list: list,
+        tool_message: Optional[list] = None,
+    ) -> ChatCompletion:
+        """
+        处理单条聊天
+        """
+
+        context_msg = get_prompt(model_name) + (
+            self.context.build(self.target.id, self.target.private)
+        )
+        response = await make_chat_openai(
+            client=self.client,
+            msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []),  # type: ignore
+            model_name=model_name,
+            tools=tools_list if tools_list else None,
+        )
+        return response
+
+    async def handle_function_call(
+        self,
+        completion: ChatCompletion,
+        user_message: Union[str, list],
+        model_name: str,
+        tools_list: list,
+    ):
+        # function call
+        # 需要获取额外信息,调用函数工具
+        tool_msg = []
+        choice = completion.choices[0]
+        # await UniMessage(str(response)).send()
+        tool_calls = choice.message.tool_calls
+        # try:
+        #     if tool_calls[0]["function"]["name"].startswith("$"):
+        #         choice.message.tool_calls[0][
+        #             "type"
+        #         ] = "builtin_function"  # 兼容 moonshot AI 内置函数的临时方案
+        # except:
+        #     pass
+        tool_msg.append(choice.message)
+        for tool_call in tool_calls:
+            try:
+                function_args = json.loads(tool_call.function.arguments)
+            except json.JSONDecodeError:
+                function_args = json.loads(
+                    tool_call.function.arguments.replace("'", '"')
+                )
+            # 删除args的placeholder参数
+            if "placeholder" in function_args:
+                del function_args["placeholder"]
+            logger.info(
+                f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+                + "\n".join([f"{k}={v}" for k, v in function_args.items()])
+            )
+            await UniMessage(
+                f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
+                + "\n".join([f"{k}={v}" for k, v in function_args.items()])
+            ).send()
+            if caller := get_function_calls().get(tool_call.function.name):
+                logger.debug(f"调用插件函数 {caller.full_name}")
+                # 权限检查,规则检查 TODO
+                # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
+                func_return = await caller.with_ctx(
+                    SessionContext(
+                        bot=self.bot,
+                        event=self.event,
+                        matcher=self.matcher,
+                    )
+                ).call(**function_args)
+            else:
+                logger.error(f"未找到函数 {tool_call.function.name.replace('-', '.')}")
+                func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
+            tool_msg.append(
+                ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict()  # type: ignore
+            )
+        # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
+        # await UniMessage(str(tool_msg)).send()
+        return await self.handle_common_chat(
+            user_message=user_message,
+            model_name=model_name,
+            tools_list=tools_list,
+            tool_message=tool_msg,
+        )
+
+    async def handle_common_chat(
+        self,
+        user_message: Union[str, list],
+        model_name: str,
+        tools_list: list,
+        stream: bool = False,
+        tool_message: Optional[list] = None,
+    ) -> Union[Tuple[UserMessage, ChatCompletionMessage], None]:
+        """
+        处理一般聊天
+        """
+        global target_list
+        if stream:
+            raise NotImplementedError
+        response = await self.handle_single_chat(
+            user_message=user_message,
+            model_name=model_name,
+            tools_list=tools_list,
+            tool_message=tool_message,
+        )
+        choice = response.choices[0]
+        # Sprint(choice)
+        # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
+        if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
+            choice.finish_reason = "tool_calls"
+        logger.info(f"完成原因:{choice.finish_reason}")
+        if choice.finish_reason == CompletionsFinishReason.STOPPED:
+
+            ##### DeepSeek-R1 兼容部分 #####
+            choice_msg_content, choice_msg_thinking, choice_msg_after = (
+                extract_content_and_think(choice.message)
+            )
+            if choice_msg_thinking and config.marshoai_send_thinking:
+                await UniMessage("思维链:\n" + choice_msg_thinking).send()
+            ##### 兼容部分结束 #####
+
+            if [self.target.id, self.target.private] not in target_list:
+                target_list.append([self.target.id, self.target.private])
+
+            # 对话成功发送消息
+            if config.marshoai_enable_richtext_parse:
+                await (await parse_richtext(str(choice_msg_content))).send(
+                    reply_to=True
+                )
+            else:
+                await UniMessage(str(choice_msg_content)).send(reply_to=True)
+            return UserMessage(content=user_message), choice_msg_after
+        elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
+
+            # 对话失败,消息过滤
+
+            await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
+                reply_to=True
+            )
+            return None
+        elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
+            return await self.handle_function_call(
+                response, user_message, model_name, tools_list
+            )
+        else:
+            await UniMessage(f"意外的完成原因:{choice.finish_reason}").send()
+            return None
diff --git a/nonebot_plugin_marshoai/hooks.py b/nonebot_plugin_marshoai/hooks.py
index 8b6ddd0e..6e6f439e 100644
--- a/nonebot_plugin_marshoai/hooks.py
+++ b/nonebot_plugin_marshoai/hooks.py
@@ -6,7 +6,7 @@ import nonebot_plugin_localstore as store
 from nonebot import logger
 
 from .config import config
-from .instances import *
+from .instances import context, driver, target_list, tools
 from .plugin import load_plugin, load_plugins
 from .util import get_backup_context, save_context_to_json
 
diff --git a/nonebot_plugin_marshoai/instances.py b/nonebot_plugin_marshoai/instances.py
index d450d2d2..d5aaf709 100644
--- a/nonebot_plugin_marshoai/instances.py
+++ b/nonebot_plugin_marshoai/instances.py
@@ -3,7 +3,7 @@ from nonebot import get_driver
 from openai import AsyncOpenAI
 
 from .config import config
-from .models import Cache, MarshoContext, MarshoTools
+from .models import MarshoContext, MarshoTools
 
 driver = get_driver()
 
@@ -11,7 +11,6 @@ command_start = driver.config.command_start
 model_name = config.marshoai_default_model
 context = MarshoContext()
 tools = MarshoTools()
-cache = Cache()
 token = config.marshoai_token
 endpoint = config.marshoai_azure_endpoint
 # client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py
index ce20fba6..a922cc80 100644
--- a/nonebot_plugin_marshoai/marsho.py
+++ b/nonebot_plugin_marshoai/marsho.py
@@ -1,5 +1,4 @@
 import contextlib
-import json
 import traceback
 from typing import Optional
 
@@ -7,10 +6,6 @@ from arclet.alconna import Alconna, AllParam, Args
 from azure.ai.inference.models import (
     AssistantMessage,
     CompletionsFinishReason,
-    ImageContentItem,
-    ImageUrl,
-    TextContentItem,
-    ToolMessage,
     UserMessage,
 )
 from nonebot import logger, on_command, on_message
@@ -24,11 +19,11 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
 
 from .config import config
 from .constants import INTRODUCTION, SUPPORT_IMAGE_MODELS
+from .handler import MarshoHandler
 from .hooks import *
 from .instances import client, context, model_name, target_list, tools
 from .metadata import metadata
 from .plugin.func_call.caller import get_function_calls
-from .plugin.func_call.models import SessionContext
 from .util import *
 
 
@@ -232,16 +227,16 @@ async def marsho(
         # 发送说明
         # await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send()
         await marsho_cmd.finish(INTRODUCTION)
+    backup_context = await get_backup_context(target.id, target.private)
+    if backup_context:
+        context.set_context(
+            backup_context, target.id, target.private
+        )  # 加载历史记录
+        logger.info(f"已恢复会话 {target.id} 的上下文备份~")
+    handler = MarshoHandler(client, context)
     try:
-        user_id = event.get_user_id()
-        nicknames = await get_nicknames()
-        user_nickname = nicknames.get(user_id, "")
-        if user_nickname != "":
-            nickname_prompt = (
-                f"\n*此消息的说话者id为:{user_id},名字为:{user_nickname}*"
-            )
-        else:
-            nickname_prompt = ""
+        user_nickname = await get_nickname_by_user_id(event.get_user_id())
+        if not user_nickname:
             # 用户名无法获取,暂时注释
             # user_nickname = event.sender.nickname  # 未设置昵称时获取用户名
             # nickname_prompt = f"\n*此消息的说话者:{user_nickname}"
@@ -255,188 +250,21 @@
                 "※你未设置自己的昵称。推荐使用「nickname [昵称]」命令设置昵称来获得个性化(可能)回答。"
             ).send()
 
-        is_support_image_model = (
-            model_name.lower()
-            in SUPPORT_IMAGE_MODELS + config.marshoai_additional_image_models
-        )
-        usermsg = [] if is_support_image_model else ""
-        for i in text:  # type: ignore
-            if i.type == "text":
-                if is_support_image_model:
-                    usermsg += [TextContentItem(text=i.data["text"] + nickname_prompt).as_dict()]  # type: ignore
-                else:
-                    usermsg += str(i.data["text"] + nickname_prompt)  # type: ignore
-            elif i.type == "image":
-                if is_support_image_model:
-                    usermsg.append(  # type: ignore
-                        ImageContentItem(
-                            image_url=ImageUrl(  # type: ignore
-                                url=str(await get_image_b64(i.data["url"]))  # type: ignore
-                            )  # type: ignore
-                        ).as_dict()  # type: ignore
-                    )  # type: ignore
-                    logger.info(f"输入图片 {i.data['url']}")
-                elif config.marshoai_enable_support_image_tip:
-                    await UniMessage(
-                        "*此模型不支持图片处理或管理员未启用此模型的图片支持。图片将被忽略。"
-                    ).send()
-        backup_context = await get_backup_context(target.id, target.private)
-        if backup_context:
-            context.set_context(
-                backup_context, target.id, target.private
-            )  # 加载历史记录
-            logger.info(f"已恢复会话 {target.id} 的上下文备份~")
-        context_msg = get_prompt(model_name) + context.build(target.id, target.private)
+        usermsg = await handler.process_user_input(text, model_name)
         tools_lists = tools.tools_list + list(
             map(lambda v: v.data(), get_function_calls().values())
         )
         logger.info(f"正在获取回答,模型:{model_name}")
         # logger.info(f"上下文:{context_msg}")
-        response = await make_chat_openai(
-            client=client,
-            model_name=model_name,
-            msg=context_msg + [UserMessage(content=usermsg).as_dict()],  # type: ignore
-            tools=tools_lists if tools_lists else None,  # TODO 临时追加函数,后期优化
-        )
+        response = await handler.handle_common_chat(usermsg, model_name, tools_lists)
         # await UniMessage(str(response)).send()
-        choice = response.choices[0]
-        # Sprint(choice)
-        # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
-        if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls:
-            choice.finish_reason = "tool_calls"
-        logger.info(f"完成原因:{choice.finish_reason}")
-        if choice.finish_reason == CompletionsFinishReason.STOPPED:
-            # 当对话成功时,将dict的上下文添加到上下文类中
-            context.append(
-                UserMessage(content=usermsg).as_dict(), target.id, target.private  # type: ignore
-            )
-
-            ##### DeepSeek-R1 兼容部分 #####
-            choice_msg_content, choice_msg_thinking, choice_msg_after = (
-                extract_content_and_think(choice.message)
-            )
-            if choice_msg_thinking and config.marshoai_send_thinking:
-                await UniMessage("思维链:\n" + choice_msg_thinking).send()
-            ##### 兼容部分结束 #####
-
-            context.append(choice_msg_after.to_dict(), target.id, target.private)
-            if [target.id, target.private] not in target_list:
-                target_list.append([target.id, target.private])
-
-            # 对话成功发送消息
-            if config.marshoai_enable_richtext_parse:
-                await (await parse_richtext(str(choice_msg_content))).send(
-                    reply_to=True
-                )
-            else:
-                await UniMessage(str(choice_msg_content)).send(reply_to=True)
-        elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED:
-
-            # 对话失败,消息过滤
-
-            await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send(
-                reply_to=True
-            )
-            return
-        elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS:
-            # function call
-            # 需要获取额外信息,调用函数工具
-            tool_msg = []
-            while choice.message.tool_calls is not None:
-                # await UniMessage(str(response)).send()
-                tool_calls = choice.message.tool_calls
-                # try:
-                #     if tool_calls[0]["function"]["name"].startswith("$"):
-                #         choice.message.tool_calls[0][
-                #             "type"
-                #         ] = "builtin_function"  # 兼容 moonshot AI 内置函数的临时方案
-                # except:
-                #     pass
-                tool_msg.append(choice.message)
-                for tool_call in tool_calls:
-                    try:
-                        function_args = json.loads(tool_call.function.arguments)
-                    except json.JSONDecodeError:
-                        function_args = json.loads(
-                            tool_call.function.arguments.replace("'", '"')
-                        )
-                    # 删除args的placeholder参数
-                    if "placeholder" in function_args:
-                        del function_args["placeholder"]
-                    logger.info(
-                        f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
-                        + "\n".join([f"{k}={v}" for k, v in function_args.items()])
-                    )
-                    await UniMessage(
-                        f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
-                        + "\n".join([f"{k}={v}" for k, v in function_args.items()])
-                    ).send()
-                    # TODO 临时追加插件函数,若工具中没有则调用插件函数
-                    if tools.has_function(tool_call.function.name):
-                        logger.debug(f"调用工具函数 {tool_call.function.name}")
-                        func_return = await tools.call(
-                            tool_call.function.name, function_args
-                        )  # 获取返回值
-                    else:
-                        if caller := get_function_calls().get(tool_call.function.name):
-                            logger.debug(f"调用插件函数 {caller.full_name}")
-                            # 权限检查,规则检查 TODO
-                            # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
-                            func_return = await caller.with_ctx(
-                                SessionContext(
-                                    bot=bot,
-                                    event=event,
-                                    state=state,
-                                    matcher=matcher,
-                                )
-                            ).call(**function_args)
-                        else:
-                            logger.error(
-                                f"未找到函数 {tool_call.function.name.replace('-', '.')}"
-                            )
-                            func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
-                    tool_msg.append(
-                        ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict()  # type: ignore
-                    )
-                # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function"
-                # await UniMessage(str(tool_msg)).send()
-                request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg  # type: ignore
-                response = await make_chat_openai(
-                    client=client,
-                    model_name=model_name,
-                    msg=request_msg,  # type: ignore
-                    tools=(
-                        tools_lists if tools_lists else None
-                    ),  # TODO 临时追加函数,后期优化
-                )
-                choice = response.choices[0]
-                # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS
-                if choice.message.tool_calls is not None:
-                    choice.finish_reason = CompletionsFinishReason.TOOL_CALLS
-            if choice.finish_reason == CompletionsFinishReason.STOPPED:
-
-                # 对话成功 添加上下文
-                context.append(
-                    UserMessage(content=usermsg).as_dict(), target.id, target.private  # type: ignore
-                )
-                # context.append(tool_msg, target.id, target.private)
-                choice_msg_dict = choice.message.to_dict()
-                if "reasoning_content" in choice_msg_dict:
-                    del choice_msg_dict["reasoning_content"]
-                context.append(choice_msg_dict, target.id, target.private)
-
-                # 发送消息
-                if config.marshoai_enable_richtext_parse:
-                    await (await parse_richtext(str(choice.message.content))).send(
-                        reply_to=True
-                    )
-                else:
-                    await UniMessage(str(choice.message.content)).send(reply_to=True)
-            else:
-                await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
+        if response is not None:
+            context_user, context_assistant = response
+            context.append(context_user.as_dict(), target.id, target.private)
+            context.append(context_assistant.to_dict(), target.id, target.private)
         else:
-            await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}")
+            await UniMessage("没有回答").send()
     except Exception as e:
         await UniMessage(str(e) + suggest_solution(str(e))).send()
         traceback.print_exc()
@@ -451,12 +279,10 @@ with contextlib.suppress(ImportError):  # 优化先不做()
 
     @poke_notify.handle()
     async def poke(event: Event):
-        user_id = event.get_user_id()
-        nicknames = await get_nicknames()
-        user_nickname = nicknames.get(user_id, "")
+        user_nickname = await get_nickname_by_user_id(event.get_user_id())
         try:
             if config.marshoai_poke_suffix != "":
-                logger.info(f"收到戳一戳,用户昵称:{user_nickname},用户ID:{user_id}")
+                logger.info(f"收到戳一戳,用户昵称:{user_nickname}")
                 response = await make_chat_openai(
                     client=client,
                     model_name=model_name,
diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py
index cf634e3a..04ab747c 100644
--- a/nonebot_plugin_marshoai/plugin/func_call/caller.py
+++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py
@@ -70,11 +70,9 @@
         ):
             return False, "告诉用户 Permission Denied 权限不足"
 
-        if self.ctx.state is None:
-            return False, "State is None"
-        if self._rule and not await self._rule(
-            self.ctx.bot, self.ctx.event, self.ctx.state
-        ):
+        # if self.ctx.state is None:
+        #     return False, "State is None"
+        if self._rule and not await self._rule(self.ctx.bot, self.ctx.event):
             return False, "告诉用户 Rule Denied 规则不匹配"
 
         return True, ""
@@ -115,6 +113,10 @@
         # 检查函数签名,确定依赖注入参数
         sig = inspect.signature(func)
         for name, param in sig.parameters.items():
+            if param.annotation == T_State:
+                self.di.state = name
+                continue  # 防止后续判断T_State子类时报错
+
             if issubclass(param.annotation, Event) or isinstance(
                 param.annotation, Event
             ):
@@ -133,9 +135,6 @@
             ):
                 self.di.matcher = name
 
-            if param.annotation == T_State:
-                self.di.state = name
-
         # 检查默认值情况
         for name, param in sig.parameters.items():
             if param.default is not inspect.Parameter.empty:
diff --git a/nonebot_plugin_marshoai/plugin/func_call/models.py b/nonebot_plugin_marshoai/plugin/func_call/models.py
index 379388e7..3eeefe05 100644
--- a/nonebot_plugin_marshoai/plugin/func_call/models.py
+++ b/nonebot_plugin_marshoai/plugin/func_call/models.py
@@ -19,7 +19,7 @@ class SessionContext(BaseModel):
     bot: Bot
     event: Event
     matcher: Matcher
-    state: T_State
+    # state: T_State
     caller: Any = None
 
     class Config:
@@ -30,5 +30,5 @@ class SessionContextDepends(BaseModel):
     bot: str | None = None
     event: str | None = None
     matcher: str | None = None
-    state: str | None = None
+    # state: str | None = None
     caller: str | None = None
diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py
index 0fbd6ec0..1a35320f 100755
--- a/nonebot_plugin_marshoai/util.py
+++ b/nonebot_plugin_marshoai/util.py
@@ -20,13 +20,13 @@ from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from zhDateTime import DateTime
 
 from ._types import DeveloperMessage
+from .cache.decos import *
 from .config import config
 from .constants import CODE_BLOCK_PATTERN, IMG_LATEX_PATTERN, OPENAI_NEW_MODELS
 from .deal_latex import ConvertLatex
-from .instances import cache
 
-nickname_json = None  # 记录昵称
-praises_json = None  # 记录夸赞名单
+# nickname_json = None  # 记录昵称
+# praises_json = None  # 记录夸赞名单
 loaded_target_list: List[str] = []  # 记录已恢复备份的上下文的列表
 
 NOT_GIVEN = NotGiven()
@@ -156,30 +156,29 @@
     )
 
 
+@from_cache("praises")
 def get_praises():
-    global praises_json
-    if praises_json is None:
-        praises_file = store.get_plugin_data_file(
-            "praises.json"
-        )  # 夸赞名单文件使用localstore存储
-        if not praises_file.exists():
-            with open(praises_file, "w", encoding="utf-8") as f:
-                json.dump(_praises_init_data, f, ensure_ascii=False, indent=4)
-        with open(praises_file, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        praises_json = data
+    praises_file = store.get_plugin_data_file(
+        "praises.json"
+    )  # 夸赞名单文件使用localstore存储
+    if not praises_file.exists():
+        with open(praises_file, "w", encoding="utf-8") as f:
+            json.dump(_praises_init_data, f, ensure_ascii=False, indent=4)
+    with open(praises_file, "r", encoding="utf-8") as f:
+        data = json.load(f)
+    praises_json = data
    return praises_json
 
 
+@update_to_cache("praises")
 async def refresh_praises_json():
-    global praises_json
     praises_file = store.get_plugin_data_file("praises.json")
     if not praises_file.exists():
         with open(praises_file, "w", encoding="utf-8") as f:
             json.dump(_praises_init_data, f, ensure_ascii=False, indent=4)  # 异步?
     async with aiofiles.open(praises_file, "r", encoding="utf-8") as f:
         data = json.loads(await f.read())
-    praises_json = data
+    return data
 
 
 def build_praises() -> str:
@@ -211,22 +210,21 @@ async def load_context_from_json(name: str, path: str) -> list:
         return []
 
 
+@from_cache("nickname")
 async def get_nicknames():
-    """获取nickname_json, 优先来源于全局变量"""
-    global nickname_json
-    if nickname_json is None:
-        filename = store.get_plugin_data_file("nickname.json")
-        # noinspection PyBroadException
-        try:
-            async with aiofiles.open(filename, "r", encoding="utf-8") as f:
-                nickname_json = json.loads(await f.read())
-        except (json.JSONDecodeError, FileNotFoundError):
-            nickname_json = {}
+    """获取nickname_json, 优先来源于缓存"""
+    filename = store.get_plugin_data_file("nickname.json")
+    # noinspection PyBroadException
+    try:
+        async with aiofiles.open(filename, "r", encoding="utf-8") as f:
+            nickname_json = json.loads(await f.read())
+    except (json.JSONDecodeError, FileNotFoundError):
+        nickname_json = {}
     return nickname_json
 
 
+@update_to_cache("nickname")
 async def set_nickname(user_id: str, name: str):
-    global nickname_json
     filename = store.get_plugin_data_file("nickname.json")
     if not filename.exists():
         data = {}
@@ -238,18 +236,24 @@
         del data[user_id]
     with open(filename, "w", encoding="utf-8") as f:
         json.dump(data, f, ensure_ascii=False, indent=4)
-    nickname_json = data
+    return data
 
 
+async def get_nickname_by_user_id(user_id: str):
+    nickname_json = await get_nicknames()
+    return nickname_json.get(user_id, "")
+
+
+@update_to_cache("nickname")
 async def refresh_nickname_json():
-    """强制刷新nickname_json, 刷新全局变量"""
-    global nickname_json
+    """强制刷新nickname_json"""
     # noinspection PyBroadException
     try:
         async with aiofiles.open(
             store.get_plugin_data_file("nickname.json"), "r", encoding="utf-8"
         ) as f:
             nickname_json = json.loads(await f.read())
+            return nickname_json
     except (json.JSONDecodeError, FileNotFoundError):
         logger.error("刷新 nickname_json 表错误:无法载入 nickname.json 文件")
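
For context: the refactor above swaps the module-level nickname_json / praises_json globals in util.py for the from_cache / update_to_cache decorators added in cache/decos.py. from_cache("key") returns the cached value when one exists and only runs the wrapped coroutine on a miss, while update_to_cache("key") always runs the coroutine and overwrites the cache entry with its return value. Below is a minimal, self-contained sketch of those semantics; the _SketchCache class and the *_sketch functions are illustrative stand-ins (the real decorators use ..models.Cache and read nickname.json via localstore/aiofiles), not part of the plugin.

# Illustrative sketch only; _SketchCache is a hypothetical stand-in for ..models.Cache
# that exposes just the get()/set() calls the decorators rely on.
import asyncio


class _SketchCache:
    def __init__(self):
        self._data = {}

    def get(self, key):
        return self._data.get(key)

    def set(self, key, value):
        self._data[key] = value


cache = _SketchCache()
calls = {"load": 0}


def from_cache(key):
    # On a cache hit, return the stored value without running the wrapped coroutine.
    def decorator(func):
        async def wrapper(*args, **kwargs):
            cached = cache.get(key)
            if cached:
                return cached
            result = await func(*args, **kwargs)
            cache.set(key, result)
            return result

        return wrapper

    return decorator


def update_to_cache(key):
    # Always run the wrapped coroutine and overwrite the cache entry with its result.
    def decorator(func):
        async def wrapper(*args, **kwargs):
            result = await func(*args, **kwargs)
            cache.set(key, result)
            return result

        return wrapper

    return decorator


@from_cache("nickname")
async def get_nicknames_sketch():
    calls["load"] += 1
    return {"10001": "Marsho"}  # stands in for reading nickname.json from disk


@update_to_cache("nickname")
async def set_nickname_sketch(user_id, name):
    data = dict(cache.get("nickname") or {})
    data[user_id] = name
    return data  # the decorator stores this back under the "nickname" key


async def main():
    await get_nicknames_sketch()
    await get_nicknames_sketch()
    assert calls["load"] == 1  # second call was served from the cache
    await set_nickname_sketch("10002", "Tomoya")
    assert "10002" in (await get_nicknames_sketch())  # cache reflects the update


asyncio.run(main())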