mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2024-11-26 23:05:04 +08:00
🐛 使用set_context重写历史记录加载, 修复使用contexts指令但是未进行对话时未加载历史记录的bug
This commit is contained in:
parent
8df614163a
commit
f03a41d38a
@ -56,8 +56,7 @@ context = MarshoContext()
|
|||||||
token = config.marshoai_token
|
token = config.marshoai_token
|
||||||
endpoint = config.marshoai_azure_endpoint
|
endpoint = config.marshoai_azure_endpoint
|
||||||
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
||||||
target_list = []
|
target_list = [] # 记录需保存历史记录的列表
|
||||||
loaded_target_list = []
|
|
||||||
|
|
||||||
|
|
||||||
@add_usermsg_cmd.handle()
|
@add_usermsg_cmd.handle()
|
||||||
@ -83,12 +82,17 @@ async def praises():
|
|||||||
|
|
||||||
@contexts_cmd.handle()
|
@contexts_cmd.handle()
|
||||||
async def contexts(target: MsgTarget):
|
async def contexts(target: MsgTarget):
|
||||||
|
context.set_context(
|
||||||
|
await get_backup_context(target.id, target.private), target.id, target.private
|
||||||
|
) # 加载历史记录
|
||||||
await contexts_cmd.finish(str(context.build(target.id, target.private)))
|
await contexts_cmd.finish(str(context.build(target.id, target.private)))
|
||||||
|
|
||||||
|
|
||||||
@save_context_cmd.handle()
|
@save_context_cmd.handle()
|
||||||
async def save_context(target: MsgTarget, arg: Message = CommandArg()):
|
async def save_context(target: MsgTarget, arg: Message = CommandArg()):
|
||||||
contexts_data = context.build(target.id, target.private)
|
contexts_data = context.build(target.id, target.private)
|
||||||
|
if not context:
|
||||||
|
await save_context_cmd.finish("暂无上下文可以保存")
|
||||||
if msg := arg.extract_plain_text():
|
if msg := arg.extract_plain_text():
|
||||||
await save_context_to_json(msg, contexts_data, "contexts")
|
await save_context_to_json(msg, contexts_data, "contexts")
|
||||||
await save_context_cmd.finish("已保存上下文")
|
await save_context_cmd.finish("已保存上下文")
|
||||||
@ -142,7 +146,7 @@ async def refresh_data():
|
|||||||
|
|
||||||
@marsho_cmd.handle()
|
@marsho_cmd.handle()
|
||||||
async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None):
|
async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None):
|
||||||
global target_list, loaded_target_list
|
global target_list
|
||||||
if not text:
|
if not text:
|
||||||
# 发送说明
|
# 发送说明
|
||||||
await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send()
|
await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send()
|
||||||
@ -179,18 +183,14 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
|||||||
)
|
)
|
||||||
elif config.marshoai_enable_support_image_tip:
|
elif config.marshoai_enable_support_image_tip:
|
||||||
await UniMessage("*此模型不支持图片处理。").send()
|
await UniMessage("*此模型不支持图片处理。").send()
|
||||||
|
context.set_context(
|
||||||
|
await get_backup_context(target.id, target.private), target.id, target.private
|
||||||
|
) # 加载历史记录
|
||||||
context_msg = context.build(target.id, target.private)
|
context_msg = context.build(target.id, target.private)
|
||||||
if not context_msg and target.id not in loaded_target_list:
|
|
||||||
if target.private:
|
|
||||||
channel_id = "private_" + target.id
|
|
||||||
else:
|
|
||||||
channel_id = "group_" + target.id
|
|
||||||
context_msg = list(await load_context_from_json(f"back_up_context_{channel_id}", "contexts/backup"))
|
|
||||||
loaded_target_list.append(target.id)
|
|
||||||
target_list.append([target.id, target.private])
|
target_list.append([target.id, target.private])
|
||||||
if not is_reasoning_model:
|
if not is_reasoning_model:
|
||||||
context_msg = [get_prompt()] + context_msg
|
context_msg = [get_prompt()] + context_msg
|
||||||
# o1等推理模型不支持系统提示词
|
# o1等推理模型不支持系统提示词, 故不添加
|
||||||
response = await make_chat(
|
response = await make_chat(
|
||||||
client=client,
|
client=client,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -257,7 +257,7 @@ async def save_context():
|
|||||||
target_id, target_private = target_info
|
target_id, target_private = target_info
|
||||||
contexts_data = context.build(target_id, target_private)
|
contexts_data = context.build(target_id, target_private)
|
||||||
if target_private:
|
if target_private:
|
||||||
channel_id = "private_" + target_id
|
target_uid = "private_" + target_id
|
||||||
else:
|
else:
|
||||||
channel_id = "group_" + target_id
|
target_uid = "group_" + target_id
|
||||||
await save_context_to_json(f"back_up_context_{channel_id}", contexts_data, "contexts/backup")
|
await save_context_to_json(f"back_up_context_{target_uid}", contexts_data, "contexts/backup")
|
||||||
|
@ -13,8 +13,9 @@ from azure.ai.inference.aio import ChatCompletionsClient
|
|||||||
from azure.ai.inference.models import SystemMessage
|
from azure.ai.inference.models import SystemMessage
|
||||||
from .config import config
|
from .config import config
|
||||||
|
|
||||||
nickname_json = None
|
nickname_json = None # 记录昵称
|
||||||
praises_json = None
|
praises_json = None # 记录赞扬名单
|
||||||
|
loaded_target_list = [] # 记录已恢复历史记录的列表
|
||||||
|
|
||||||
|
|
||||||
async def get_image_b64(url):
|
async def get_image_b64(url):
|
||||||
@ -111,7 +112,8 @@ async def save_context_to_json(name: str, context: Any, path: str):
|
|||||||
json.dump(context, json_file, ensure_ascii=False, indent=4)
|
json.dump(context, json_file, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
|
|
||||||
async def load_context_from_json(name: str, path:str):
|
async def load_context_from_json(name: str, path: str) -> list:
|
||||||
|
"""从指定路径加载历史记录"""
|
||||||
context_dir = store.get_plugin_data_dir() / path
|
context_dir = store.get_plugin_data_dir() / path
|
||||||
os.makedirs(context_dir, exist_ok=True)
|
os.makedirs(context_dir, exist_ok=True)
|
||||||
file_path = os.path.join(context_dir, f"{name}.json")
|
file_path = os.path.join(context_dir, f"{name}.json")
|
||||||
@ -165,6 +167,7 @@ async def refresh_nickname_json():
|
|||||||
|
|
||||||
|
|
||||||
def get_prompt():
|
def get_prompt():
|
||||||
|
"""获取系统提示词"""
|
||||||
prompts = ""
|
prompts = ""
|
||||||
prompts += config.marshoai_additional_prompt
|
prompts += config.marshoai_additional_prompt
|
||||||
if config.marshoai_enable_praises:
|
if config.marshoai_enable_praises:
|
||||||
@ -199,3 +202,16 @@ def suggest_solution(errinfo: str) -> str:
|
|||||||
return f"\n{suggestion}"
|
return f"\n{suggestion}"
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_backup_context(target_id: str, target_private: bool) -> list:
|
||||||
|
"""获取历史记录"""
|
||||||
|
global loaded_target_list
|
||||||
|
if target_private:
|
||||||
|
target_uid = f"private_{target_id}"
|
||||||
|
else:
|
||||||
|
target_uid = f"group_{target_id}"
|
||||||
|
if target_uid not in loaded_target_list:
|
||||||
|
loaded_target_list.append(target_uid)
|
||||||
|
return await load_context_from_json(f"back_up_context_{target_uid}", "contexts/backup")
|
||||||
|
return []
|
||||||
|
Loading…
Reference in New Issue
Block a user