# NOTE: the import order is kept as in the original file on purpose: the star
# import comes first so that the explicit imports below take precedence over
# any same-named symbol exported by ``.util``.
from .util import *
from .config import config
import os
import re
import json
import importlib
# ``importlib.util`` is a submodule and must be imported explicitly; a bare
# ``import importlib`` does not guarantee that ``importlib.util`` is available.
import importlib.util
import traceback
from nonebot import logger


class MarshoContext:
    """
    Conversation context store for Marsho.

    Message histories are kept in ``self.contents`` under two separate
    namespaces, ``"private"`` and ``"non-private"``, each mapping a target id
    to the list of messages recorded for that target.
    """

    def __init__(self):
        self.contents = {
            "private": {},
            "non-private": {}
        }

    def _get_target_dict(self, is_private):
        """Return the id -> messages mapping selected by ``is_private``."""
        return self.contents["private"] if is_private else self.contents["non-private"]

    def append(self, content, target_id: str, is_private: bool):
        """
        Append a message to the context of ``target_id``.

        The message list for the target is created on first use.
        """
        target_dict = self._get_target_dict(is_private)
        target_dict.setdefault(target_id, []).append(content)

    def set_context(self, contexts, target_id: str, is_private: bool):
        """
        Replace the whole context of ``target_id`` with ``contexts``.
        """
        target_dict = self._get_target_dict(is_private)
        target_dict[target_id] = contexts

    def reset(self, target_id: str, is_private: bool):
        """
        Clear the context of ``target_id``; unknown targets are ignored.

        The stored list is cleared in place, so references previously handed
        out by ``build`` observe the reset as well.
        """
        target_dict = self._get_target_dict(is_private)
        if target_id in target_dict:
            target_dict[target_id].clear()

    def build(self, target_id: str, is_private: bool) -> list:
        """
        Return the stored context of ``target_id`` (system message excluded).

        An empty list is created and stored for targets seen for the first
        time. The returned list is the stored object itself, not a copy.
        """
        target_dict = self._get_target_dict(is_private)
        return target_dict.setdefault(target_id, [])


class MarshoTools:
    """
    Tool registry for Marsho.

    ``tools_list`` collects the tool descriptions read from each package's
    ``tools.json``; ``imported_packages`` maps a package name to its loaded
    module so that tool functions can be looked up by name in ``call``.
    """

    def __init__(self):
        self.tools_list = []
        self.imported_packages = {}

    def load_tools(self, tools_dir):
        """
        Load every tool package found directly under ``tools_dir``.

        A tool package is a sub-directory that contains both ``__init__.py``
        and ``tools.json``. Entries that do not qualify are skipped with a
        warning. Failures while loading one package are logged and do not
        prevent the remaining packages from being loaded.
        """
        if not os.path.exists(tools_dir):
            logger.error(f"工具集目录 {tools_dir} 不存在。")
            return

        for package_name in os.listdir(tools_dir):
            package_path = os.path.join(tools_dir, package_name)
            init_path = os.path.join(package_path, '__init__.py')
            if not (os.path.isdir(package_path) and os.path.exists(init_path)):
                logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。")
                continue

            json_path = os.path.join(package_path, 'tools.json')
            if not os.path.exists(json_path):
                logger.warning(f"在工具包 {package_path} 下找不到tools.json,跳过加载。")
                continue

            try:
                with open(json_path, 'r', encoding="utf-8") as json_file:
                    data = json.load(json_file)

                # Import the package from its __init__.py.
                spec = importlib.util.spec_from_file_location(package_name, init_path)
                package = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(package)

                # Register the tool descriptions only after the package was
                # imported successfully, so that tools whose code failed to
                # load are never advertised as callable.
                self.tools_list.extend(data)
                self.imported_packages[package_name] = package
                logger.info(f"成功加载工具包 {package_name}")
            except json.JSONDecodeError as e:
                logger.error(f"解码 JSON {json_path} 时发生错误: {e}")
            except Exception as e:
                # Best-effort loading: one broken package must not abort the
                # whole scan.
                logger.error(f"加载工具包时发生错误: {e}")
                traceback.print_exc()

    async def call(self, full_function_name: str, args: dict):
        """
        Call a tool function by its qualified name.

        ``full_function_name`` must have the form ``<package>__<function>``
        (exactly one ``"__"`` separator). The function is looked up on the
        imported package and awaited with ``args`` as keyword arguments.

        Returns the awaited result, or ``None`` after logging an error when
        the name is malformed, the package is not imported, the function is
        missing, or the call raises ``TypeError``.
        """
        # Split the qualified name into package name and function name.
        parts = full_function_name.split("__")
        if len(parts) != 2:
            logger.error("函数名无效")
            # Stop here: the names below would otherwise be unbound.
            return None
        package_name, function_name = parts

        if package_name not in self.imported_packages:
            logger.error(f"工具包 '{package_name}' 未导入")
            return None

        package = self.imported_packages[package_name]
        try:
            function = getattr(package, function_name)
            return await function(**args)
        except AttributeError:
            logger.error(f"函数 '{function_name}' 在 '{package_name}' 中找不到。")
        except TypeError as e:
            logger.error(f"调用函数 '{function_name}' 时发生错误: {e}")
        return None

    def get_tools_list(self):
        """
        Return the loaded tool descriptions.

        Returns ``None`` when no tools are loaded or when
        ``config.marshoai_enable_tools`` is falsy.
        """
        if not self.tools_list or not config.marshoai_enable_tools:
            return None
        return self.tools_list