From 5b315c46b1813737dedfb90467737ec2624807dc Mon Sep 17 00:00:00 2001 From: Asankilp Date: Sun, 23 Feb 2025 14:50:35 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20=E9=87=8D=E5=86=99=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E5=AE=8C=E6=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_marshoai/handler.py | 185 +++++++----------- nonebot_plugin_marshoai/instances.py | 1 - nonebot_plugin_marshoai/marsho.py | 11 +- .../plugin/func_call/caller.py | 15 +- .../plugin/func_call/models.py | 4 +- 5 files changed, 89 insertions(+), 127 deletions(-) diff --git a/nonebot_plugin_marshoai/handler.py b/nonebot_plugin_marshoai/handler.py index 65243f46..3dd3402d 100644 --- a/nonebot_plugin_marshoai/handler.py +++ b/nonebot_plugin_marshoai/handler.py @@ -1,5 +1,5 @@ import json -from typing import Optional, Union +from typing import Optional, Tuple, Union from azure.ai.inference.models import ( CompletionsFinishReason, @@ -21,11 +21,11 @@ from nonebot.matcher import ( from nonebot.typing import T_State from nonebot_plugin_alconna.uniseg import UniMessage, UniMsg from openai import AsyncOpenAI -from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletion, ChatCompletionMessage from .config import config from .constants import SUPPORT_IMAGE_MODELS -from .instances import target_list, tools +from .instances import target_list from .models import MarshoContext from .plugin.func_call.caller import get_function_calls from .plugin.func_call.models import SessionContext @@ -50,7 +50,7 @@ class MarshoHandler: self.context = context self.bot: Bot = current_bot.get() self.event: Event = current_event.get() - self.state: T_State = current_handler.get().state + # self.state: T_State = current_handler.get().state self.matcher: Matcher = current_matcher.get() self.message_id: str = UniMessage.get_message_id(self.event) self.target = UniMessage.get_target(self.event) @@ -97,138 +97,97 @@ class MarshoHandler: self, user_message: Union[str, list], model_name: str, - tools: list, - with_context: bool = True, + tools_list: list, + tool_message: Optional[list] = None, ) -> ChatCompletion: """ 处理单条聊天 """ - context_msg = ( - get_prompt(model_name) - + (self.context.build(self.target.id, self.target.private)) - if with_context - else "" + context_msg = get_prompt(model_name) + ( + self.context.build(self.target.id, self.target.private) ) response = await make_chat_openai( client=self.client, - msg=context_msg + [UserMessage(content=user_message).as_dict()], # type: ignore + msg=context_msg + [UserMessage(content=user_message).as_dict()] + (tool_message if tool_message else []), # type: ignore model_name=model_name, - tools=tools, + tools=tools_list if tools_list else None, ) return response async def handle_function_call( self, completion: ChatCompletion, + user_message: Union[str, list], + model_name: str, + tools_list: list, ): # function call # 需要获取额外信息,调用函数工具 tool_msg = [] choice = completion.choices[0] - while choice.message.tool_calls is not None: - # await UniMessage(str(response)).send() - tool_calls = choice.message.tool_calls - # try: - # if tool_calls[0]["function"]["name"].startswith("$"): - # choice.message.tool_calls[0][ - # "type" - # ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案 - # except: - # pass - tool_msg.append(choice.message) - for tool_call in tool_calls: - try: - function_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError: - function_args = json.loads( - tool_call.function.arguments.replace("'", '"') + # await UniMessage(str(response)).send() + tool_calls = choice.message.tool_calls + # try: + # if tool_calls[0]["function"]["name"].startswith("$"): + # choice.message.tool_calls[0][ + # "type" + # ] = "builtin_function" # 兼容 moonshot AI 内置函数的临时方案 + # except: + # pass + tool_msg.append(choice.message) + for tool_call in tool_calls: + try: + function_args = json.loads(tool_call.function.arguments) + except json.JSONDecodeError: + function_args = json.loads( + tool_call.function.arguments.replace("'", '"') + ) + # 删除args的placeholder参数 + if "placeholder" in function_args: + del function_args["placeholder"] + logger.info( + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + ) + await UniMessage( + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + + "\n".join([f"{k}={v}" for k, v in function_args.items()]) + ).send() + if caller := get_function_calls().get(tool_call.function.name): + logger.debug(f"调用插件函数 {caller.full_name}") + # 权限检查,规则检查 TODO + # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 + func_return = await caller.with_ctx( + SessionContext( + bot=self.bot, + event=self.event, + matcher=self.matcher, ) - # 删除args的placeholder参数 - if "placeholder" in function_args: - del function_args["placeholder"] - logger.info( - f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" - + "\n".join([f"{k}={v}" for k, v in function_args.items()]) - ) - await UniMessage( - f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" - + "\n".join([f"{k}={v}" for k, v in function_args.items()]) - ).send() - # TODO 临时追加插件函数,若工具中没有则调用插件函数 - if tools.has_function(tool_call.function.name): - logger.debug(f"调用工具函数 {tool_call.function.name}") - func_return = await tools.call( - tool_call.function.name, function_args - ) # 获取返回值 - else: - if caller := get_function_calls().get(tool_call.function.name): - logger.debug(f"调用插件函数 {caller.full_name}") - # 权限检查,规则检查 TODO - # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 - func_return = await caller.with_ctx( - SessionContext( - bot=self.bot, - event=self.event, - state=self.state, - matcher=self.matcher, - ) - ).call(**function_args) - else: - logger.error( - f"未找到函数 {tool_call.function.name.replace('-', '.')}" - ) - func_return = ( - f"未找到函数 {tool_call.function.name.replace('-', '.')}" - ) - tool_msg.append( - ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore - ) - # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" - # await UniMessage(str(tool_msg)).send() - request_msg = context_msg + [UserMessage(content=usermsg).as_dict()] + tool_msg # type: ignore - response = await make_chat_openai( - client=client, - model_name=model_name, - msg=request_msg, # type: ignore - tools=( - tools_lists if tools_lists else None - ), # TODO 临时追加函数,后期优化 - ) - choice = response.choices[0] - # 当tool_calls非空时,将finish_reason设置为TOOL_CALLS - if choice.message.tool_calls is not None: - choice.finish_reason = CompletionsFinishReason.TOOL_CALLS - if choice.finish_reason == CompletionsFinishReason.STOPPED: - - # 对话成功 添加上下文 - context.append( - UserMessage(content=usermsg).as_dict(), self.target.id, self.target.private # type: ignore - ) - # context.append(tool_msg, self.target.id, self.target.private) - choice_msg_dict = choice.message.to_dict() - if "reasoning_content" in choice_msg_dict: - del choice_msg_dict["reasoning_content"] - context.append(choice_msg_dict, self.target.id, self.target.private) - - # 发送消息 - if config.marshoai_enable_richtext_parse: - await (await parse_richtext(str(choice.message.content))).send( - reply_to=True - ) + ).call(**function_args) else: - await UniMessage(str(choice.message.content)).send(reply_to=True) - else: - await marsho_cmd.finish(f"意外的完成原因:{choice.finish_reason}") + logger.error(f"未找到函数 {tool_call.function.name.replace('-', '.')}") + func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}" + tool_msg.append( + ToolMessage(tool_call_id=tool_call.id, content=func_return).as_dict() # type: ignore + ) + # tool_msg[0]["tool_calls"][0]["type"] = "builtin_function" + # await UniMessage(str(tool_msg)).send() + return await self.handle_common_chat( + user_message=user_message, + model_name=model_name, + tools_list=tools_list, + tool_message=tool_msg, + ) async def handle_common_chat( self, user_message: Union[str, list], model_name: str, - tools: list, - with_context: bool = True, + tools_list: list, stream: bool = False, - ) -> ChatCompletion: + tool_message: Optional[list] = None, + ) -> Union[Tuple[UserMessage, ChatCompletionMessage], None]: """ 处理一般聊天 """ @@ -238,8 +197,8 @@ class MarshoHandler: response = await self.handle_single_chat( user_message=user_message, model_name=model_name, - tools=tools, - with_context=with_context, + tools_list=tools_list, + tool_message=tool_message, ) choice = response.choices[0] # Sprint(choice) @@ -267,7 +226,7 @@ class MarshoHandler: ) else: await UniMessage(str(choice_msg_content)).send(reply_to=True) - return (UserMessage(context=user_message), choice_msg_after) + return UserMessage(content=user_message), choice_msg_after elif choice.finish_reason == CompletionsFinishReason.CONTENT_FILTERED: # 对话失败,消息过滤 @@ -277,7 +236,9 @@ class MarshoHandler: ) return None elif choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: - pass + return await self.handle_function_call( + response, user_message, model_name, tools_list + ) else: await UniMessage(f"意外的完成原因:{choice.finish_reason}").send() return None diff --git a/nonebot_plugin_marshoai/instances.py b/nonebot_plugin_marshoai/instances.py index b4cba256..d5aaf709 100644 --- a/nonebot_plugin_marshoai/instances.py +++ b/nonebot_plugin_marshoai/instances.py @@ -3,7 +3,6 @@ from nonebot import get_driver from openai import AsyncOpenAI from .config import config -from .handler import MarshoHandler from .models import MarshoContext, MarshoTools driver = get_driver() diff --git a/nonebot_plugin_marshoai/marsho.py b/nonebot_plugin_marshoai/marsho.py index b7eb4d63..e24bff4b 100644 --- a/nonebot_plugin_marshoai/marsho.py +++ b/nonebot_plugin_marshoai/marsho.py @@ -263,11 +263,14 @@ async def marsho( ) logger.info(f"正在获取回答,模型:{model_name}") # logger.info(f"上下文:{context_msg}") - response = await handler.handle_single_chat( - usermsg, model_name, tools_lists, with_context=True - ) + response = await handler.handle_common_chat(usermsg, model_name, tools_lists) # await UniMessage(str(response)).send() - + if response is not None: + context_user, context_assistant = response + context.append(context_user.as_dict(), target.id, target.private) + context.append(context_assistant.to_dict(), target.id, target.private) + else: + await UniMessage("没有回答").send() except Exception as e: await UniMessage(str(e) + suggest_solution(str(e))).send() traceback.print_exc() diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py index cf634e3a..04ab747c 100644 --- a/nonebot_plugin_marshoai/plugin/func_call/caller.py +++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py @@ -70,11 +70,9 @@ class Caller: ): return False, "告诉用户 Permission Denied 权限不足" - if self.ctx.state is None: - return False, "State is None" - if self._rule and not await self._rule( - self.ctx.bot, self.ctx.event, self.ctx.state - ): + # if self.ctx.state is None: + # return False, "State is None" + if self._rule and not await self._rule(self.ctx.bot, self.ctx.event): return False, "告诉用户 Rule Denied 规则不匹配" return True, "" @@ -115,6 +113,10 @@ class Caller: # 检查函数签名,确定依赖注入参数 sig = inspect.signature(func) for name, param in sig.parameters.items(): + if param.annotation == T_State: + self.di.state = name + continue # 防止后续判断T_State子类时报错 + if issubclass(param.annotation, Event) or isinstance( param.annotation, Event ): @@ -133,9 +135,6 @@ class Caller: ): self.di.matcher = name - if param.annotation == T_State: - self.di.state = name - # 检查默认值情况 for name, param in sig.parameters.items(): if param.default is not inspect.Parameter.empty: diff --git a/nonebot_plugin_marshoai/plugin/func_call/models.py b/nonebot_plugin_marshoai/plugin/func_call/models.py index 379388e7..3eeefe05 100644 --- a/nonebot_plugin_marshoai/plugin/func_call/models.py +++ b/nonebot_plugin_marshoai/plugin/func_call/models.py @@ -19,7 +19,7 @@ class SessionContext(BaseModel): bot: Bot event: Event matcher: Matcher - state: T_State + # state: T_State caller: Any = None class Config: @@ -30,5 +30,5 @@ class SessionContextDepends(BaseModel): bot: str | None = None event: str | None = None matcher: str | None = None - state: str | None = None + # state: str | None = None caller: str | None = None