diff --git a/nonebot_plugin_marshoai/handler.py b/nonebot_plugin_marshoai/handler.py index 654f7cad..65243f46 100644 --- a/nonebot_plugin_marshoai/handler.py +++ b/nonebot_plugin_marshoai/handler.py @@ -1,29 +1,42 @@ +import json from typing import Optional, Union from azure.ai.inference.models import ( - AssistantMessage, + CompletionsFinishReason, ImageContentItem, ImageUrl, TextContentItem, ToolMessage, UserMessage, ) -from nonebot.adapters import Event +from nonebot.adapters import Bot, Event from nonebot.log import logger -from nonebot.matcher import Matcher, current_event, current_matcher +from nonebot.matcher import ( + Matcher, + current_bot, + current_event, + current_handler, + current_matcher, +) +from nonebot.typing import T_State from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg from openai import AsyncOpenAI from openai.types.chat import ChatCompletion from .config import config from .constants import SUPPORT_IMAGE_MODELS +from .instances import target_list, tools from .models import MarshoContext +from .plugin.func_call.caller import get_function_calls +from .plugin.func_call.models import SessionContext from .util import ( + extract_content_and_think, get_backup_context, get_image_b64, get_nickname_by_user_id, get_prompt, make_chat_openai, + parse_richtext, ) @@ -35,7 +48,9 @@ class MarshoHandler: ): self.client = client self.context = context + self.bot: Bot = current_bot.get() self.event: Event = current_event.get() + self.state: T_State = current_handler.get().state self.matcher: Matcher = current_matcher.get() self.message_id: str = UniMessage.get_message_id(self.event) self.target = UniMessage.get_target(self.event) @@ -44,7 +59,7 @@ class MarshoHandler: self, user_input: UniMsg, model_name: str ) -> Union[str, list]: """ - 处理用户输入 + 处理用户输入为可输入 API 的格式,并添加昵称提示 """ is_support_image_model = ( model_name.lower() @@ -88,12 +103,7 @@ class MarshoHandler: """ 处理单条聊天 """ - backup_context = await get_backup_context(self.target.id, self.target.private) - if backup_context: - self.context.set_context( - backup_context, self.target.id, self.target.private - ) # 加载历史记录 - logger.info(f"已恢复会话 {self.target.id} 的上下文备份~") + context_msg = ( get_prompt(model_name) + (self.context.build(self.target.id, self.target.private)) @@ -108,6 +118,109 @@ class MarshoHandler: ) return response + async def handle_function_call( + self, + completion: ChatCompletion, + ): + # function call + # 需要获取额外信息,调用函数工具 + tool_msg = [] + choice = completion.choices[0] + while choice.message.tool_calls is not None: + # await UniMessage(str(response)).send() + tool_calls = choice.message.tool_calls + # try: + # if tool_calls[0]["function"]["name"].startswith("$"): + # choice.message.tool_calls[0][ + # "type" + # ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案 + # except: + # pass + tool_msg.append(choice.message) + for tool_call in tool_calls: + try: + function_args = json.loads(tool_call.function.arguments) + except json.JSONDecodeError: + function_args = json.loads( + tool_call.function.arguments.replace("'", '"') + ) + # 删除args的placeholder参数 + if "placeholder" in function_args: + del function_args["placeholder"] + logger.info( + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + ) + await UniMessage( + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + ).send() + # TODO 临时追加插件函数,若工具中没有则调用插件函数 + if tools.has_function(tool_call.function.name): + logger.debug(f"调用工具函数 {tool_call.function.name}") + func_return = await tools.call( + tool_call.function.name, function_args + ) # 获取返回值 + else: + if caller := get_function_calls().get(tool_call.function.name): + logger.debug(f"调用插件函数 {caller.full_name}") + # 权限检查,规则检查 TODO + # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 + func_return = await caller.with_ctx( + SessionContext( + bot=self.bot, + event=self.event, + state=self.state, + matcher=self.matcher, + ) + ).call(**function_args) + else: + logger.error( + f"未找到函数 {tool_call.function.name.replace('-', '.')}" + ) + func_return = ( + f"未找到函数 {tool_call.function.name.replace('-', '.')}" + ) + tool_msg.append( + ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore + ) + # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" + # await UniMessage(str(tool_msg)).send() + request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore + response = await make_chat_openai( + client=client, + model_name=model_name, + msg=request_msg, # type: ignore + tools=( + tools_lists if tools_lists else None + ), # TODO 临时追加函数,后期优化 + ) + choice = response.choices[0] + # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS + if choice.message.tool_calls is not None: + choice.finish_reason = CompletionsFinishReason.TOOL_CALLS + if choice.finish_reason == CompletionsFinishReason.STOPPED: + + # 对话成功 添加上下文 + context.append( + UserMessage(content=usermsg).as_dict(), self.target.id, self.target.private # type: ignore + ) + # context.append(tool_msg, self.target.id, self.target.private) + choice_msg_dict = choice.message.to_dict() + if "reasoning_content" in choice_msg_dict: + del choice_msg_dict["reasoning_content"] + context.append(choice_msg_dict, self.target.id, self.target.private) + + # 发送消息 + if config.marshoai_enable_richtext_parse: + await (await parse_richtext(str(choice.message.content))).send( + reply_to=True + ) + else: + await UniMessage(str(choice.message.content)).send(reply_to=True) + else: + await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}") + async def handle_common_chat( self, user_message: Union[str, list], @@ -119,6 +232,7 @@ class MarshoHandler: """ 处理一般聊天 """ + global target_list if stream: raise NotImplementedError response = await self.handle_single_chat( @@ -127,4 +241,43 @@ class MarshoHandler: tools=tools, with_context=with_context, ) - return response + choice = response.choices[0] + # Sprint(choice) + # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS + if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls: + choice.finish_reason = "tool_calls" + logger.info(f"完成原因:{choice.finish_reason}") + if choice.finish_reason == CompletionsFinishReason.STOPPED: + + ##### DeepSeek-R1 兼容部分 ##### + choice_msg_content, choice_msg_thinking, choice_msg_after = ( + extract_content_and_think(choice.message) + ) + if choice_msg_thinking and config.marshoai_send_thinking: + await UniMessage("思维链:\n" + choice_msg_thinking).send() + ##### 兼容部分结束 ##### + + if [self.target.id, self.target.private] not in target_list: + target_list.append([self.target.id, self.target.private]) + + # 对话成功发送消息 + if config.marshoai_enable_richtext_parse: + await (await parse_richtext(str(choice_msg_content))).send( + reply_to=True + ) + else: + await UniMessage(str(choice_msg_content)).send(reply_to=True) + return (UserMessage(context=user_message), choice_msg_after) + elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED: + + # 对话失败,消息过滤 + + await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send( + reply_to=True + ) + return None + elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: + pass + else: + await UniMessage(f"意外的完成原因:{choice.finish_reason}").send() + return None diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py index bf1b0fe1..b7eb4d63 100644 --- a/nonebot_plugin_marshoai/marsho.py +++ b/nonebot_plugin_marshoai/marsho.py @@ -233,6 +233,12 @@ async def marsho( # 发送说明 # await UniMessage(metadata.usage + "\n当前使用的模型:" + model_name).send() await marsho_cmd.finish(INTRODUCTION) + backup_context = await get_backup_context(target.id, target.private) + if backup_context: + context.set_context( + backup_context, target.id, target.private + ) # 加载历史记录 + logger.info(f"已恢复会话 {target.id} 的上下文备份~") handler = MarshoHandler(client, context) try: user_nickname = await get_nickname_by_user_id(event.get_user_id()) @@ -261,143 +267,7 @@ async def marsho( usermsg, model_name, tools_lists, with_context=True ) # await UniMessage(str(response)).send() - choice = response.choices[0] - # Sprint(choice) - # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS - if choice.message.tool_calls is not None and config.marshoai_fix_toolcalls: - choice.finish_reason = "tool_calls" - logger.info(f"完成原因:{choice.finish_reason}") - if choice.finish_reason == CompletionsFinishReason.STOPPED: - # 当对话成功时,将dict的上下文添加到上下文类中 - context.append( - UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore - ) - ##### DeepSeek-R1 兼容部分 ##### - choice_msg_content, choice_msg_thinking, choice_msg_after = ( - extract_content_and_think(choice.message) - ) - if choice_msg_thinking and config.marshoai_send_thinking: - await UniMessage("思维链:\n" + choice_msg_thinking).send() - ##### 兼容部分结束 ##### - - context.append(choice_msg_after.to_dict(), target.id, target.private) - if [target.id, target.private] not in target_list: - target_list.append([target.id, target.private]) - - # 对话成功发送消息 - if config.marshoai_enable_richtext_parse: - await (await parse_richtext(str(choice_msg_content))).send( - reply_to=True - ) - else: - await UniMessage(str(choice_msg_content)).send(reply_to=True) - elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED: - - # 对话失败,消息过滤 - - await UniMessage("*已被内容过滤器过滤。请调整聊天内容后重试。").send( - reply_to=True - ) - return - elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: - # function call - # 需要获取额外信息,调用函数工具 - tool_msg = [] - while choice.message.tool_calls is not None: - # await UniMessage(str(response)).send() - tool_calls = choice.message.tool_calls - # try: - # if tool_calls[0]["function"]["name"].startswith("$"): - # choice.message.tool_calls[0][ - # "type" - # ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案 - # except: - # pass - tool_msg.append(choice.message) - for tool_call in tool_calls: - try: - function_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - function_args = json.loads( - tool_call.function.arguments.replace("'", '"') - ) - # 删除args的placeholder参数 - if "placeholder" in function_args: - del function_args["placeholder"] - logger.info( - f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" - + "\n".join([f"{k}={v}" for k, v in function_args.items()]) - ) - await UniMessage( - f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" - + "\n".join([f"{k}={v}" for k, v in function_args.items()]) - ).send() - # TODO 临时追加插件函数,若工具中没有则调用插件函数 - if tools.has_function(tool_call.function.name): - logger.debug(f"调用工具函数 {tool_call.function.name}") - func_return = await tools.call( - tool_call.function.name, function_args - ) # 获取返回值 - else: - if caller := get_function_calls().get(tool_call.function.name): - logger.debug(f"调用插件函数 {caller.full_name}") - # 权限检查,规则检查 TODO - # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 - func_return = await caller.with_ctx( - SessionContext( - bot=bot, - event=event, - state=state, - matcher=matcher, - ) - ).call(**function_args) - else: - logger.error( - f"未找到函数 {tool_call.function.name.replace('-', '.')}" - ) - func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}" - tool_msg.append( - ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore - ) - # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" - # await UniMessage(str(tool_msg)).send() - request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore - response = await make_chat_openai( - client=client, - model_name=model_name, - msg=request_msg, # type: ignore - tools=( - tools_lists if tools_lists else None - ), # TODO 临时追加函数,后期优化 - ) - choice = response.choices[0] - # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS - if choice.message.tool_calls is not None: - choice.finish_reason = CompletionsFinishReason.TOOL_CALLS - if choice.finish_reason == CompletionsFinishReason.STOPPED: - - # 对话成功 添加上下文 - context.append( - UserMessage(content=usermsg).as_dict(), target.id, target.private # type: ignore - ) - # context.append(tool_msg, target.id, target.private) - choice_msg_dict = choice.message.to_dict() - if "reasoning_content" in choice_msg_dict: - del choice_msg_dict["reasoning_content"] - context.append(choice_msg_dict, target.id, target.private) - - # 发送消息 - if config.marshoai_enable_richtext_parse: - await (await parse_richtext(str(choice.message.content))).send( - reply_to=True - ) - else: - await UniMessage(str(choice.message.content)).send(reply_to=True) - else: - await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}") - else: - await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}") except Exception as e: await UniMessage(str(e) + suggest_solution(str(e))).send() traceback.print_exc()