From 075a529aa132fbf512230dd5d9bbd0615277f666 Mon Sep 17 00:00:00 2001 From: Asankilp Date: Wed, 27 Nov 2024 13:38:11 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=E6=96=B0=E5=A2=9E=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E5=A4=96=E9=83=A8=E5=B7=A5=E5=85=B7=E9=9B=86=E7=9A=84=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E9=A1=B9=EF=BC=8C=E4=BF=AE=E5=A4=8D=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- nonebot_plugin_marshoai/azure.py | 29 ++++++++++++++------- nonebot_plugin_marshoai/config.py | 3 ++- nonebot_plugin_marshoai/config_example.yaml | 2 ++ nonebot_plugin_marshoai/models.py | 10 +++---- pyproject.toml | 3 ++- 6 files changed, 31 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 4975fa9..75baa8b 100644 --- a/README.md +++ b/README.md @@ -168,7 +168,7 @@ _✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_ | MARSHOAI_ENABLE_PRAISES | `bool` | `true` | 是否启用夸赞名单功能 | | MARSHOAI_ENABLE_TOOLS | `bool` | `true` | 是否启用小棉工具 | | MARSHOAI_LOAD_BUILTIN_TOOLS | `bool` | `true` | 是否加载内置工具包 | - +| MARSHOAI_TOOLSET_DIR | `list` | `[]` | 外部工具集路径列表 | ## ❤ 鸣谢&版权说明 diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index ed5251b..458b860 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -80,9 +80,13 @@ target_list = [] # 记录需保存历史上下文的列表 async def _preload_tools(): tools_dir = store.get_plugin_data_dir() / "tools" os.makedirs(tools_dir, exist_ok=True) - if config.marshoai_load_builtin_tools: - tools.load_tools(Path(__file__).parent / "tools") - tools.load_tools(store.get_plugin_data_dir() / "tools") + if config.marshoai_enable_tools: + if config.marshoai_load_builtin_tools: + tools.load_tools(Path(__file__).parent / "tools") + tools.load_tools(store.get_plugin_data_dir() / "tools") + for tool_dir in config.marshoai_toolset_dir: + tools.load_tools(tool_dir) + logger.info("如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。") @add_usermsg_cmd.handle() @@ -250,11 +254,11 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) while choice.message.tool_calls != None: tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls)) for tool_call in choice.message.tool_calls: - if isinstance(tool_call, ChatCompletionsToolCall): + if isinstance(tool_call, ChatCompletionsToolCall): # 循环调用工具直到不需要调用 function_args = json.loads(tool_call.function.arguments.replace("'", '"')) logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}") await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send() - func_return = await tools.call(tool_call.function.name, function_args) + func_return = await tools.call(tool_call.function.name, function_args) # 获取返回值 tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return)) response = await make_chat( client=client, @@ -263,12 +267,17 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) tools=tools.get_tools_list() ) choice = response.choices[0] - context.append( - UserMessage(content=usermsg).as_dict(), target.id, target.private - ) + if choice["finish_reason"] == CompletionsFinishReason.STOPPED: + context.append( + UserMessage(content=usermsg).as_dict(), target.id, target.private + ) # context.append(tool_msg, target.id, target.private) - context.append(choice.message.as_dict(), target.id, target.private) - await UniMessage(str(choice.message.content)).send(reply_to=True) + context.append(choice.message.as_dict(), target.id, target.private) + await UniMessage(str(choice.message.content)).send(reply_to=True) + else: + await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}") + else: + await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}") except Exception as e: await UniMessage(str(e) + suggest_solution(str(e))).send() traceback.print_exc() diff --git a/nonebot_plugin_marshoai/config.py b/nonebot_plugin_marshoai/config.py index 65f3289..e1d0490 100644 --- a/nonebot_plugin_marshoai/config.py +++ b/nonebot_plugin_marshoai/config.py @@ -33,6 +33,7 @@ class ConfigModel(BaseModel): marshoai_enable_time_prompt: bool = True marshoai_enable_tools: bool = True marshoai_load_builtin_tools: bool = True + marshoai_toolset_dir: list = [] marshoai_azure_endpoint: str = "https://models.inference.ai.azure.com" marshoai_temperature: float | None = None marshoai_max_tokens: int | None = None @@ -117,4 +118,4 @@ if config.marshoai_use_yaml_config: config = ConfigModel(**yaml_config) else: - logger.info("MarshoAI 支持新的 YAML 配置系统,若要使用,请将 MARSHOAI_USE_YAML_CONFIG 配置项设置为 true。") \ No newline at end of file + logger.info("MarshoAI 支持新的 YAML 配置系统,若要使用,请将 MARSHOAI_USE_YAML_CONFIG 配置项设置为 true。") diff --git a/nonebot_plugin_marshoai/config_example.yaml b/nonebot_plugin_marshoai/config_example.yaml index 27b23e1..d3350c7 100644 --- a/nonebot_plugin_marshoai/config_example.yaml +++ b/nonebot_plugin_marshoai/config_example.yaml @@ -34,6 +34,8 @@ marshoai_enable_tools: true # 是否启用工具支持。 marshoai_load_builtin_tools: true # 是否加载内置工具。 +marshoai_toolset_dir: [] # 工具集路径。 + marshoai_azure_endpoint: "https://models.inference.ai.azure.com" # OpenAI 标准格式 API 的端点。 # 模型参数配置 diff --git a/nonebot_plugin_marshoai/models.py b/nonebot_plugin_marshoai/models.py index 61d6ecb..eb38e32 100644 --- a/nonebot_plugin_marshoai/models.py +++ b/nonebot_plugin_marshoai/models.py @@ -87,7 +87,7 @@ class MarshoTools: package = importlib.util.module_from_spec(spec) spec.loader.exec_module(package) self.imported_packages[package_name] = package - logger.info(f"成功加载工具包 {package_name}") + logger.success(f"成功加载工具包 {package_name}") except json.JSONDecodeError as e: logger.error(f"解码 JSON {json_path} 时发生错误: {e}") except Exception as e: @@ -114,10 +114,10 @@ class MarshoTools: try: function = getattr(package, function_name) return await function(**args) - except AttributeError: - logger.error(f"函数 '{function_name}' 在 '{package_name}' 中找不到。") - except TypeError as e: - logger.error(f"调用函数 '{function_name}' 时发生错误: {e}") + except Exception as e: + errinfo = f"调用函数 '{function_name}'时发生错误:{e}" + logger.error(errinfo) + return errinfo else: logger.error(f"工具包 '{package_name}' 未导入") diff --git a/pyproject.toml b/pyproject.toml index 9da2407..21cdc91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,8 @@ dependencies = [ "zhDatetime>=1.1.1", "aiohttp>=3.9", "httpx>=0.27.0", - "ruamel.yaml>=0.18.6" + "ruamel.yaml>=0.18.6", + "pyyaml>=6.0.2" ] license = { text = "MIT" }