mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2024-11-30 09:14:52 +08:00
✨新增加载外部工具集的配置项,修复依赖
This commit is contained in:
parent
c7e55cc803
commit
075a529aa1
@ -168,7 +168,7 @@ _✨ 使用 OpenAI 标准格式 API 的聊天机器人插件 ✨_
|
|||||||
| MARSHOAI_ENABLE_PRAISES | `bool` | `true` | 是否启用夸赞名单功能 |
|
| MARSHOAI_ENABLE_PRAISES | `bool` | `true` | 是否启用夸赞名单功能 |
|
||||||
| MARSHOAI_ENABLE_TOOLS | `bool` | `true` | 是否启用小棉工具 |
|
| MARSHOAI_ENABLE_TOOLS | `bool` | `true` | 是否启用小棉工具 |
|
||||||
| MARSHOAI_LOAD_BUILTIN_TOOLS | `bool` | `true` | 是否加载内置工具包 |
|
| MARSHOAI_LOAD_BUILTIN_TOOLS | `bool` | `true` | 是否加载内置工具包 |
|
||||||
|
| MARSHOAI_TOOLSET_DIR | `list` | `[]` | 外部工具集路径列表 |
|
||||||
|
|
||||||
## ❤ 鸣谢&版权说明
|
## ❤ 鸣谢&版权说明
|
||||||
|
|
||||||
|
@ -80,9 +80,13 @@ target_list = [] # 记录需保存历史上下文的列表
|
|||||||
async def _preload_tools():
|
async def _preload_tools():
|
||||||
tools_dir = store.get_plugin_data_dir() / "tools"
|
tools_dir = store.get_plugin_data_dir() / "tools"
|
||||||
os.makedirs(tools_dir, exist_ok=True)
|
os.makedirs(tools_dir, exist_ok=True)
|
||||||
|
if config.marshoai_enable_tools:
|
||||||
if config.marshoai_load_builtin_tools:
|
if config.marshoai_load_builtin_tools:
|
||||||
tools.load_tools(Path(__file__).parent / "tools")
|
tools.load_tools(Path(__file__).parent / "tools")
|
||||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||||
|
for tool_dir in config.marshoai_toolset_dir:
|
||||||
|
tools.load_tools(tool_dir)
|
||||||
|
logger.info("如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。")
|
||||||
|
|
||||||
|
|
||||||
@add_usermsg_cmd.handle()
|
@add_usermsg_cmd.handle()
|
||||||
@ -250,11 +254,11 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
|||||||
while choice.message.tool_calls != None:
|
while choice.message.tool_calls != None:
|
||||||
tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls))
|
tool_msg.append(AssistantMessage(tool_calls=response.choices[0].message.tool_calls))
|
||||||
for tool_call in choice.message.tool_calls:
|
for tool_call in choice.message.tool_calls:
|
||||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
if isinstance(tool_call, ChatCompletionsToolCall): # 循环调用工具直到不需要调用
|
||||||
function_args = json.loads(tool_call.function.arguments.replace("'", '"'))
|
function_args = json.loads(tool_call.function.arguments.replace("'", '"'))
|
||||||
logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}")
|
logger.info(f"调用函数 {tool_call.function.name} ,参数为 {function_args}")
|
||||||
await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send()
|
await UniMessage(f"调用函数 {tool_call.function.name} ,参数为 {function_args}").send()
|
||||||
func_return = await tools.call(tool_call.function.name, function_args)
|
func_return = await tools.call(tool_call.function.name, function_args) # 获取返回值
|
||||||
tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return))
|
tool_msg.append(ToolMessage(tool_call_id=tool_call.id, content=func_return))
|
||||||
response = await make_chat(
|
response = await make_chat(
|
||||||
client=client,
|
client=client,
|
||||||
@ -263,12 +267,17 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
|||||||
tools=tools.get_tools_list()
|
tools=tools.get_tools_list()
|
||||||
)
|
)
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
|
if choice["finish_reason"] == CompletionsFinishReason.STOPPED:
|
||||||
context.append(
|
context.append(
|
||||||
UserMessage(content=usermsg).as_dict(), target.id, target.private
|
UserMessage(content=usermsg).as_dict(), target.id, target.private
|
||||||
)
|
)
|
||||||
# context.append(tool_msg, target.id, target.private)
|
# context.append(tool_msg, target.id, target.private)
|
||||||
context.append(choice.message.as_dict(), target.id, target.private)
|
context.append(choice.message.as_dict(), target.id, target.private)
|
||||||
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
await UniMessage(str(choice.message.content)).send(reply_to=True)
|
||||||
|
else:
|
||||||
|
await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}")
|
||||||
|
else:
|
||||||
|
await marsho_cmd.finish(f"意外的完成原因:{choice['finish_reason']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
@ -33,6 +33,7 @@ class ConfigModel(BaseModel):
|
|||||||
marshoai_enable_time_prompt: bool = True
|
marshoai_enable_time_prompt: bool = True
|
||||||
marshoai_enable_tools: bool = True
|
marshoai_enable_tools: bool = True
|
||||||
marshoai_load_builtin_tools: bool = True
|
marshoai_load_builtin_tools: bool = True
|
||||||
|
marshoai_toolset_dir: list = []
|
||||||
marshoai_azure_endpoint: str = "https://models.inference.ai.azure.com"
|
marshoai_azure_endpoint: str = "https://models.inference.ai.azure.com"
|
||||||
marshoai_temperature: float | None = None
|
marshoai_temperature: float | None = None
|
||||||
marshoai_max_tokens: int | None = None
|
marshoai_max_tokens: int | None = None
|
||||||
|
@ -34,6 +34,8 @@ marshoai_enable_tools: true # 是否启用工具支持。
|
|||||||
|
|
||||||
marshoai_load_builtin_tools: true # 是否加载内置工具。
|
marshoai_load_builtin_tools: true # 是否加载内置工具。
|
||||||
|
|
||||||
|
marshoai_toolset_dir: [] # 工具集路径。
|
||||||
|
|
||||||
marshoai_azure_endpoint: "https://models.inference.ai.azure.com" # OpenAI 标准格式 API 的端点。
|
marshoai_azure_endpoint: "https://models.inference.ai.azure.com" # OpenAI 标准格式 API 的端点。
|
||||||
|
|
||||||
# 模型参数配置
|
# 模型参数配置
|
||||||
|
@ -87,7 +87,7 @@ class MarshoTools:
|
|||||||
package = importlib.util.module_from_spec(spec)
|
package = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(package)
|
spec.loader.exec_module(package)
|
||||||
self.imported_packages[package_name] = package
|
self.imported_packages[package_name] = package
|
||||||
logger.info(f"成功加载工具包 {package_name}")
|
logger.success(f"成功加载工具包 {package_name}")
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.error(f"解码 JSON {json_path} 时发生错误: {e}")
|
logger.error(f"解码 JSON {json_path} 时发生错误: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -114,10 +114,10 @@ class MarshoTools:
|
|||||||
try:
|
try:
|
||||||
function = getattr(package, function_name)
|
function = getattr(package, function_name)
|
||||||
return await function(**args)
|
return await function(**args)
|
||||||
except AttributeError:
|
except Exception as e:
|
||||||
logger.error(f"函数 '{function_name}' 在 '{package_name}' 中找不到。")
|
errinfo = f"调用函数 '{function_name}'时发生错误:{e}"
|
||||||
except TypeError as e:
|
logger.error(errinfo)
|
||||||
logger.error(f"调用函数 '{function_name}' 时发生错误: {e}")
|
return errinfo
|
||||||
else:
|
else:
|
||||||
logger.error(f"工具包 '{package_name}' 未导入")
|
logger.error(f"工具包 '{package_name}' 未导入")
|
||||||
|
|
||||||
|
@ -13,7 +13,8 @@ dependencies = [
|
|||||||
"zhDatetime>=1.1.1",
|
"zhDatetime>=1.1.1",
|
||||||
"aiohttp>=3.9",
|
"aiohttp>=3.9",
|
||||||
"httpx>=0.27.0",
|
"httpx>=0.27.0",
|
||||||
"ruamel.yaml>=0.18.6"
|
"ruamel.yaml>=0.18.6",
|
||||||
|
"pyyaml>=6.0.2"
|
||||||
|
|
||||||
]
|
]
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
|
Loading…
Reference in New Issue
Block a user