import importlib import json import os import sys # import importlib.util import traceback from nonebot import logger from .config import config from .util import * class MarshoContext: """ Marsho 的上下文类 """ def __init__(self): self.contents = {"private": {}, "non-private": {}} def _get_target_dict(self, is_private): return self.contents["private"] if is_private else self.contents["non-private"] def append(self, content, target_id: str, is_private: bool): """ 往上下文中添加消息 """ target_dict = self._get_target_dict(is_private) target_dict.setdefault(target_id, []).append(content) def set_context(self, contexts, target_id: str, is_private: bool): """ 设置上下文 """ self._get_target_dict(is_private)[target_id] = contexts def reset(self, target_id: str, is_private: bool): """ 重置上下文 """ self._get_target_dict(is_private).pop(target_id, None) def reset_all(self): """ 重置所有上下文 """ self.contents = {"private": {}, "non-private": {}} def build(self, target_id: str, is_private: bool) -> list: """ 构建返回的上下文,不包括系统消息 """ return self._get_target_dict(is_private).setdefault(target_id, []) class MarshoTools: """ Marsho 的工具类 """ def __init__(self): self.tools_list = [] self.imported_packages = {} def load_tools(self, tools_dir): """ 从指定路径加载工具包 """ if not os.path.exists(tools_dir): logger.error(f"工具集目录 {tools_dir} 不存在。") return for package_name in os.listdir(tools_dir): package_path = os.path.join(tools_dir, package_name) if package_name in config.marshoai_disabled_toolkits: logger.info(f"工具包 {package_name} 已被禁用。") continue if os.path.isdir(package_path) and os.path.exists( os.path.join(package_path, "__init__.py") ): self._load_package(package_name, package_path) else: logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。") def _load_package(self, package_name, package_path): json_path = os.path.join(package_path, "tools.json") if os.path.exists(json_path): try: with open(json_path, "r", encoding="utf-8") as json_file: data = json.load(json_file) self.tools_list.extend(data) spec = importlib.util.spec_from_file_location( package_name, os.path.join(package_path, "__init__.py") ) package = importlib.util.module_from_spec(spec) self.imported_packages[package_name] = package sys.modules[package_name] = package spec.loader.exec_module(package) logger.success(f"成功加载工具包 {package_name}") except json.JSONDecodeError as e: logger.error(f"解码 JSON {json_path} 时发生错误: {e}") except Exception as e: logger.error(f"加载工具包时发生错误: {e}") traceback.print_exc() else: logger.warning(f"在工具包 {package_path} 下找不到tools.json,跳过加载。") async def call(self, full_function_name: str, args: dict): """ 调用指定的函数 """ parts = full_function_name.split("__") if len(parts) != 2: logger.error("函数名无效") return package_name, function_name = parts if package_name in self.imported_packages: package = self.imported_packages[package_name] try: function = getattr(package, function_name) return await function(**args) except Exception as e: errinfo = f"调用函数 '{function_name}'时发生错误:{e}" logger.error(errinfo) return errinfo else: logger.error(f"工具包 '{package_name}' 未导入") def has_function(self, full_function_name: str) -> bool: """ 检查是否存在指定的函数 """ try: return any( t["function"]["name"].replace("-", "_") == full_function_name.replace("-", "_") for t in self.tools_list ) except Exception as e: logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}") return False def get_tools_list(self): if not self.tools_list or not config.marshoai_enable_tools: return None return self.tools_list