新增加载外部工具集的配置项,修复依赖

This commit is contained in:
Asankilp 2024-11-27 13:38:11 +08:00
parent c7e55cc803
commit 075a529aa1
6 changed files with 31 additions and 18 deletions

View File

@ -168,7 +168,7 @@ _✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_
| MARSHOAI_ENABLE_PRAISES | `bool` | `true` | 是否启用夸赞名单功能 |
| MARSHOAI_ENABLE_TOOLS | `bool` | `true` | 是否启用小棉工具 |
| MARSHOAI_LOAD_BUILTIN_TOOLS | `bool` | `true` | 是否加载内置工具包 |
| MARSHOAI_TOOLSET_DIR | `list` | `[]` | 外部工具集路径列表 |
## ❤ 鸣谢&版权说明

View File

@ -80,9 +80,13 @@ target_list = [] # 记录需保存历史上下文的列表
async def _preload_tools():
tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools:
if config.marshoai_load_builtin_tools:
tools.load_tools(Path(__file__).parent / "tools")
tools.load_tools(store.get_plugin_data_dir() / "tools")
for tool_dir in config.marshoai_toolset_dir:
tools.load_tools(tool_dir)
logger.info("如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。")
@add_usermsg_cmd.handle()
@ -250,11 +254,11 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
while choice.message.tool_calls != None:
tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls))
for tool_call in choice.message.tool_calls:
if isinstance(tool_call, ChatCompletionsToolCall):
if isinstance(tool_call, ChatCompletionsToolCall): # 循环调用工具直到不需要调用
function_args = json.loads(tool_call.function.arguments.replace("'", '"'))
logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}")
await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send()
func_return = await tools.call(tool_call.function.name, function_args)
func_return = await tools.call(tool_call.function.name, function_args) # 获取返回值
tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return))
response = await make_chat(
client=client,
@ -263,12 +267,17 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
tools=tools.get_tools_list()
)
choice = response.choices[0]
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
context.append(
UserMessage(content=usermsg).as_dict(), target.id, target.private
)
# context.append(tool_msg, target.id, target.private)
context.append(choice.message.as_dict(), target.id, target.private)
await UniMessage(str(choice.message.content)).send(reply_to=True)
else:
await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}")
else:
await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}")
except Exception as e:
await UniMessage(str(e) + suggest_solution(str(e))).send()
traceback.print_exc()

View File

@ -33,6 +33,7 @@ class ConfigModel(BaseModel):
marshoai_enable_time_prompt: bool = True
marshoai_enable_tools: bool = True
marshoai_load_builtin_tools: bool = True
marshoai_toolset_dir: list = []
marshoai_azure_endpoint: str = "https://models.inference.ai.azure.com"
marshoai_temperature: float | None = None
marshoai_max_tokens: int | None = None

View File

@ -34,6 +34,8 @@ marshoai_enable_tools: true # 是否启用工具支持。
marshoai_load_builtin_tools: true # 是否加载内置工具。
marshoai_toolset_dir: [] # 工具集路径。
marshoai_azure_endpoint: "https://models.inference.ai.azure.com" # OpenAI 标准格式 API 的端点。
# 模型参数配置

View File

@ -87,7 +87,7 @@ class MarshoTools:
package = importlib.util.module_from_spec(spec)
spec.loader.exec_module(package)
self.imported_packages[package_name] = package
logger.info(f"成功加载工具包 {package_name}")
logger.success(f"成功加载工具包 {package_name}")
except json.JSONDecodeError as e:
logger.error(f"解码 JSON {json_path} 时发生错误: {e}")
except Exception as e:
@ -114,10 +114,10 @@ class MarshoTools:
try:
function = getattr(package, function_name)
return await function(**args)
except AttributeError:
logger.error(f"函数 '{function_name}''{package_name}' 中找不到。")
except TypeError as e:
logger.error(f"调用函数 '{function_name}' 时发生错误: {e}")
except Exception as e:
errinfo = f"调用函数 '{function_name}'时发生错误:{e}"
logger.error(errinfo)
return errinfo
else:
logger.error(f"工具包 '{package_name}' 未导入")

View File

@ -13,7 +13,8 @@ dependencies = [
"zhDatetime>=1.1.1",
"aiohttp>=3.9",
"httpx>=0.27.0",
"ruamel.yaml>=0.18.6"
"ruamel.yaml>=0.18.6",
"pyyaml>=6.0.2"
]
license = { text = "MIT" }