diff --git a/nonebot_plugin_marshoai/models.py b/nonebot_plugin_marshoai/models.py index 2d29ad59..85b1c08c 100755 --- a/nonebot_plugin_marshoai/models.py +++ b/nonebot_plugin_marshoai/models.py @@ -28,40 +28,31 @@ class MarshoContext: 往上下文中添加消息 """ target_dict = self._get_target_dict(is_private) - if target_id not in target_dict: - target_dict[target_id] = [] - target_dict[target_id].append(content) + target_dict.setdefault(target_id, []).append(content) def set_context(self, contexts, target_id: str, is_private: bool): """ 设置上下文 """ - target_dict = self._get_target_dict(is_private) - target_dict[target_id] = contexts + self._get_target_dict(is_private)[target_id] = contexts def reset(self, target_id: str, is_private: bool): """ 重置上下文 """ - target_dict = self._get_target_dict(is_private) - if target_id in target_dict: - target_dict[target_id].clear() + self._get_target_dict(is_private).pop(target_id, None) def reset_all(self): """ 重置所有上下文 """ - self.contents["private"].clear() - self.contents["non-private"].clear() + self.contents = {"private": {}, "non-private": {}} def build(self, target_id: str, is_private: bool) -> list: """ 构建返回的上下文,不包括系统消息 """ - target_dict = self._get_target_dict(is_private) - if target_id not in target_dict: - target_dict[target_id] = [] - return target_dict[target_id] + return self._get_target_dict(is_private).setdefault(target_id, []) class MarshoTools: @@ -84,54 +75,52 @@ class MarshoTools: for package_name in os.listdir(tools_dir): package_path = os.path.join(tools_dir, package_name) - # logger.info(f"尝试加载工具包 {package_name}") if package_name in config.marshoai_disabled_toolkits: logger.info(f"工具包 {package_name} 已被禁用。") continue + if os.path.isdir(package_path) and os.path.exists( os.path.join(package_path, "__init__.py") ): - json_path = os.path.join(package_path, "tools.json") - if os.path.exists(json_path): - try: - with open(json_path, "r", encoding="utf-8") as json_file: - data = json.load(json_file) - for i in data: - - self.tools_list.append(i) - - spec = importlib.util.spec_from_file_location( - package_name, os.path.join(package_path, "__init__.py") - ) - package = importlib.util.module_from_spec(spec) - self.imported_packages[package_name] = package - sys.modules[package_name] = package - spec.loader.exec_module(package) - - logger.success(f"成功加载工具包 {package_name}") - except json.JSONDecodeError as e: - logger.error(f"解码 JSON {json_path} 时发生错误: {e}") - except Exception as e: - logger.error(f"加载工具包时发生错误: {e}") - traceback.print_exc() - else: - logger.warning( - f"在工具包 {package_path} 下找不到tools.json,跳过加载。" - ) + self._load_package(package_name, package_path) else: logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。") + def _load_package(self, package_name, package_path): + json_path = os.path.join(package_path, "tools.json") + if os.path.exists(json_path): + try: + with open(json_path, "r", encoding="utf-8") as json_file: + data = json.load(json_file) + self.tools_list.extend(data) + + spec = importlib.util.spec_from_file_location( + package_name, os.path.join(package_path, "__init__.py") + ) + package = importlib.util.module_from_spec(spec) + self.imported_packages[package_name] = package + sys.modules[package_name] = package + spec.loader.exec_module(package) + + logger.success(f"成功加载工具包 {package_name}") + except json.JSONDecodeError as e: + logger.error(f"解码 JSON {json_path} 时发生错误: {e}") + except Exception as e: + logger.error(f"加载工具包时发生错误: {e}") + traceback.print_exc() + else: + logger.warning(f"在工具包 {package_path} 下找不到tools.json,跳过加载。") + async def call(self, full_function_name: str, args: dict): """ 调用指定的函数 """ - # 分割包名和函数名 parts = full_function_name.split("__") - if len(parts) == 2: - package_name = parts[0] - function_name = parts[1] - else: + if len(parts) != 2: logger.error("函数名无效") + return + + package_name, function_name = parts if package_name in self.imported_packages: package = self.imported_packages[package_name] try: @@ -149,12 +138,11 @@ class MarshoTools: 检查是否存在指定的函数 """ try: - for t in self.tools_list: - if t["function"]["name"].replace( - "-", "_" - ) == full_function_name.replace("-", "_"): - return True - return False + return any( + t["function"]["name"].replace("-", "_") + == full_function_name.replace("-", "_") + for t in self.tools_list + ) except Exception as e: logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}") return False diff --git a/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py index 19891571..27194005 100644 --- a/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py +++ b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py @@ -1,7 +1,7 @@ import traceback import httpx -from zhDateTime import DateTime +from zhDateTime import DateTime # type: ignore from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call from nonebot_plugin_marshoai.plugin.func_call.params import String diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index 9cda5e3e..828cf101 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -90,7 +90,8 @@ async def make_chat( 参数: client: 用于与AI模型进行通信 msg: 消息内容 - model_name: 指定AI模型名""" + model_name: 指定AI模型名 + tools: 工具列表""" return await client.complete( messages=msg, model=model_name,