mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2024-11-30 09:14:52 +08:00
✨新增加载外部工具集的配置项,修复依赖
This commit is contained in:
parent
c7e55cc803
commit
075a529aa1
@ -168,7 +168,7 @@ _✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_
|
||||
| MARSHOAI_ENABLE_PRAISES | `bool` | `true` | 是否启用夸赞名单功能 |
|
||||
| MARSHOAI_ENABLE_TOOLS | `bool` | `true` | 是否启用小棉工具 |
|
||||
| MARSHOAI_LOAD_BUILTIN_TOOLS | `bool` | `true` | 是否加载内置工具包 |
|
||||
|
||||
| MARSHOAI_TOOLSET_DIR | `list` | `[]` | 外部工具集路径列表 |
|
||||
|
||||
## ❤ 鸣谢&版权说明
|
||||
|
||||
|
@ -80,9 +80,13 @@ target_list = [] # 记录需保存历史上下文的列表
|
||||
async def _preload_tools():
|
||||
tools_dir = store.get_plugin_data_dir() / "tools"
|
||||
os.makedirs(tools_dir, exist_ok=True)
|
||||
if config.marshoai_load_builtin_tools:
|
||||
tools.load_tools(Path(__file__).parent / "tools")
|
||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||
if config.marshoai_enable_tools:
|
||||
if config.marshoai_load_builtin_tools:
|
||||
tools.load_tools(Path(__file__).parent / "tools")
|
||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||
for tool_dir in config.marshoai_toolset_dir:
|
||||
tools.load_tools(tool_dir)
|
||||
logger.info("如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。")
|
||||
|
||||
|
||||
@add_usermsg_cmd.handle()
|
||||
@ -250,11 +254,11 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
||||
while choice.message.tool_calls != None:
|
||||
tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls))
|
||||
for tool_call in choice.message.tool_calls:
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
if isinstance(tool_call, ChatCompletionsToolCall): # 循环调用工具直到不需要调用
|
||||
function_args = json.loads(tool_call.function.arguments.replace("'", '"'))
|
||||
logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}")
|
||||
await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send()
|
||||
func_return = await tools.call(tool_call.function.name, function_args)
|
||||
func_return = await tools.call(tool_call.function.name, function_args) # 获取返回值
|
||||
tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return))
|
||||
response = await make_chat(
|
||||
client=client,
|
||||
@ -263,12 +267,17 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
||||
tools=tools.get_tools_list()
|
||||
)
|
||||
choice = response.choices[0]
|
||||
context.append(
|
||||
UserMessage(content=usermsg).as_dict(), target.id, target.private
|
||||
)
|
||||
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
|
||||
context.append(
|
||||
UserMessage(content=usermsg).as_dict(), target.id, target.private
|
||||
)
|
||||
# context.append(tool_msg, target.id, target.private)
|
||||
context.append(choice.message.as_dict(), target.id, target.private)
|
||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||
context.append(choice.message.as_dict(), target.id, target.private)
|
||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||
else:
|
||||
await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}")
|
||||
else:
|
||||
await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}")
|
||||
except Exception as e:
|
||||
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
||||
traceback.print_exc()
|
||||
|
@ -33,6 +33,7 @@ class ConfigModel(BaseModel):
|
||||
marshoai_enable_time_prompt: bool = True
|
||||
marshoai_enable_tools: bool = True
|
||||
marshoai_load_builtin_tools: bool = True
|
||||
marshoai_toolset_dir: list = []
|
||||
marshoai_azure_endpoint: str = "https://models.inference.ai.azure.com"
|
||||
marshoai_temperature: float | None = None
|
||||
marshoai_max_tokens: int | None = None
|
||||
|
@ -34,6 +34,8 @@ marshoai_enable_tools: true # 是否启用工具支持。
|
||||
|
||||
marshoai_load_builtin_tools: true # 是否加载内置工具。
|
||||
|
||||
marshoai_toolset_dir: [] # 工具集路径。
|
||||
|
||||
marshoai_azure_endpoint: "https://models.inference.ai.azure.com" # OpenAI 标准格式 API 的端点。
|
||||
|
||||
# 模型参数配置
|
||||
|
@ -87,7 +87,7 @@ class MarshoTools:
|
||||
package = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(package)
|
||||
self.imported_packages[package_name] = package
|
||||
logger.info(f"成功加载工具包 {package_name}")
|
||||
logger.success(f"成功加载工具包 {package_name}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解码 JSON {json_path} 时发生错误: {e}")
|
||||
except Exception as e:
|
||||
@ -114,10 +114,10 @@ class MarshoTools:
|
||||
try:
|
||||
function = getattr(package, function_name)
|
||||
return await function(**args)
|
||||
except AttributeError:
|
||||
logger.error(f"函数 '{function_name}' 在 '{package_name}' 中找不到。")
|
||||
except TypeError as e:
|
||||
logger.error(f"调用函数 '{function_name}' 时发生错误: {e}")
|
||||
except Exception as e:
|
||||
errinfo = f"调用函数 '{function_name}'时发生错误:{e}"
|
||||
logger.error(errinfo)
|
||||
return errinfo
|
||||
else:
|
||||
logger.error(f"工具包 '{package_name}' 未导入")
|
||||
|
||||
|
@ -13,7 +13,8 @@ dependencies = [
|
||||
"zhDatetime>=1.1.1",
|
||||
"aiohttp>=3.9",
|
||||
"httpx>=0.27.0",
|
||||
"ruamel.yaml>=0.18.6"
|
||||
"ruamel.yaml>=0.18.6",
|
||||
"pyyaml>=6.0.2"
|
||||
|
||||
]
|
||||
license = { text = "MIT" }
|
||||
|
Loading…
Reference in New Issue
Block a user