diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index 5e5a841..022d89f 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -58,6 +58,7 @@ token = config.marshoai_token endpoint = config.marshoai_azure_endpoint client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token)) target_list = [] +loaded_target_list = [] @add_usermsg_cmd.handle() @@ -90,7 +91,7 @@ async def contexts(target: MsgTarget): async def save_context(target: MsgTarget, arg: Message = CommandArg()): contexts_data = context.build(target.id, target.private)[1:] if msg := arg.extract_plain_text(): - await save_context_to_json(msg, contexts_data, "context") + await save_context_to_json(msg, contexts_data, "contexts") await save_context_cmd.finish("已保存上下文") @@ -98,7 +99,7 @@ async def save_context(target: MsgTarget, arg: Message = CommandArg()): async def load_context(target: MsgTarget, arg: Message = CommandArg()): if msg := arg.extract_plain_text(): context.set_context( - await load_context_from_json(msg, "context"), target.id, target.private + await load_context_from_json(msg, "contexts"), target.id, target.private ) await load_context_cmd.finish("已加载并覆盖上下文") @@ -142,7 +143,7 @@ async def refresh_data(): @marsho_cmd.handle() async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None): - global target_list + global target_list, loaded_target_list if not text: # 发送说明 await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send() @@ -180,9 +181,9 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) elif config.marshoai_enable_support_image_tip: await UniMessage("*此模型不支持图片处理。").send() context_msg = context.build(target.id, target.private) - if not context_msg: - context_msg = list(await load_context_from_json(f"back_up_context_{target.id}", "context/backup")) - await save_context_to_json(f"back_up_context_{target.id}", [], "context/backup") + if not context_msg and target.id not in loaded_target_list: + context_msg = list(await load_context_from_json(f"back_up_context_{target.id}", "contexts/backup")) + loaded_target_list.append(target.id) msg_prompt = get_prompt() context_msg = [msg_prompt] + context_msg print(str(context_msg)) @@ -255,4 +256,4 @@ async def save_context(): for target_info in target_list: target_id, target_private = target_info contexts_data = context.build(target_id, target_private)[1:] - await save_context_to_json(f"back_up_context_{target_id}", contexts_data, "context/backup") + await save_context_to_json(f"back_up_context_{target_id}", contexts_data, "contexts/backup")