From f03a41d38a09de5e262b4b11a70a65b4097f5e5d Mon Sep 17 00:00:00 2001 From: MoeSnowyFox Date: Sun, 17 Nov 2024 16:55:41 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20=E4=BD=BF=E7=94=A8set=5Fcontext?= =?UTF-8?q?=E9=87=8D=E5=86=99=E5=8E=86=E5=8F=B2=E8=AE=B0=E5=BD=95=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD,=20=E4=BF=AE=E5=A4=8D=E4=BD=BF=E7=94=A8contexts?= =?UTF-8?q?=E6=8C=87=E4=BB=A4=E4=BD=86=E6=98=AF=E6=9C=AA=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E6=97=B6=E6=9C=AA=E5=8A=A0=E8=BD=BD=E5=8E=86?= =?UTF-8?q?=E5=8F=B2=E8=AE=B0=E5=BD=95=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_marshoai/azure.py | 28 ++++++++++++++-------------- nonebot_plugin_marshoai/util.py | 22 +++++++++++++++++++--- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index 6f4aef3..c6a628e 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -56,8 +56,7 @@ context = MarshoContext() token = config.marshoai_token endpoint = config.marshoai_azure_endpoint client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) -target_list = [] -loaded_target_list = [] +target_list = [] # 记录需保存历史记录的列表 @add_usermsg_cmd.handle() @@ -83,12 +82,17 @@ async def praises(): @contexts_cmd.handle() async def contexts(target: MsgTarget): + context.set_context( + await get_backup_context(target.id, target.private), target.id, target.private + ) # 加载历史记录 await contexts_cmd.finish(str(context.build(target.id, target.private))) @save_context_cmd.handle() async def save_context(target: MsgTarget, arg: Message = CommandArg()): contexts_data = context.build(target.id, target.private) + if not context: + await save_context_cmd.finish("暂无上下文可以保存") if msg := arg.extract_plain_text(): await save_context_to_json(msg, contexts_data, "contexts") await save_context_cmd.finish("已保存上下文") @@ -142,7 +146,7 @@ async def refresh_data(): @marsho_cmd.handle() async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None): - global target_list, loaded_target_list + global target_list if not text: # 发送说明 await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send() @@ -179,18 +183,14 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) ) elif config.marshoai_enable_support_image_tip: await UniMessage("*此模型不支持图片处理。").send() + context.set_context( + await get_backup_context(target.id, target.private), target.id, target.private + ) # 加载历史记录 context_msg = context.build(target.id, target.private) - if not context_msg and target.id not in loaded_target_list: - if target.private: - channel_id = "private_" + target.id - else: - channel_id = "group_" + target.id - context_msg = list(await load_context_from_json(f"back_up_context_{channel_id}", "contexts/backup")) - loaded_target_list.append(target.id) target_list.append([target.id, target.private]) if not is_reasoning_model: context_msg = [get_prompt()] + context_msg - # o1等推理模型不支持系统提示词 + # o1等推理模型不支持系统提示词, 故不添加 response = await make_chat( client=client, model_name=model_name, @@ -257,7 +257,7 @@ async def save_context(): target_id, target_private = target_info contexts_data = context.build(target_id, target_private) if target_private: - channel_id = "private_" + target_id + target_uid = "private_" + target_id else: - channel_id = "group_" + target_id - await save_context_to_json(f"back_up_context_{channel_id}", contexts_data, "contexts/backup") + target_uid = "group_" + target_id + await save_context_to_json(f"back_up_context_{target_uid}", contexts_data, "contexts/backup") diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 8e2db39..64d4447 100644 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -13,8 +13,9 @@ from azure.ai.inference.aio import ChatCompletionsClient from azure.ai.inference.models import SystemMessage from .config import config -nickname_json = None -praises_json = None +nickname_json = None # 记录昵称 +praises_json = None # 记录赞扬名单 +loaded_target_list = [] # 记录已恢复历史记录的列表 async def get_image_b64(url): @@ -111,7 +112,8 @@ async def save_context_to_json(name: str, context: Any, path: str): json.dump(context, json_file, ensure_ascii=False, indent=4) -async def load_context_from_json(name: str, path:str): +async def load_context_from_json(name: str, path: str) -> list: + """从指定路径加载历史记录""" context_dir = store.get_plugin_data_dir() / path os.makedirs(context_dir, exist_ok=True) file_path = os.path.join(context_dir, f"{name}.json") @@ -165,6 +167,7 @@ async def refresh_nickname_json(): def get_prompt(): + """获取系统提示词""" prompts = "" prompts += config.marshoai_additional_prompt if config.marshoai_enable_praises: @@ -199,3 +202,16 @@ def suggest_solution(errinfo: str) -> str: return f"\n{suggestion}" return "" + + +async def get_backup_context(target_id: str, target_private: bool) -> list: + """获取历史记录""" + global loaded_target_list + if target_private: + target_uid = f"private_{target_id}" + else: + target_uid = f"group_{target_id}" + if target_uid not in loaded_target_list: + loaded_target_list.append(target_uid) + return await load_context_from_json(f"back_up_context_{target_uid}", "contexts/backup") + return []