2024-12-14 19:16:59 +08:00

6.5 KiB
Executable File
Raw Permalink Blame History

title
models

模块 nonebot_plugin_marshoai.models

class MarshoContext


func __init__(self)

源代码在GitHub上查看
def __init__(self):
    """Create an empty context store with two buckets.

    Each bucket maps a target id to that target's list of messages; which
    bucket is used is decided elsewhere by ``_get_target_dict(is_private)``
    (presumably 'private' for private chats — confirm in that helper).
    """
    private_histories = {}
    shared_histories = {}
    self.contents = {'private': private_histories, 'non-private': shared_histories}

func append(self, content, target_id: str, is_private: bool)

说明: 往上下文中添加消息

源代码在GitHub上查看
def append(self, content, target_id: str, is_private: bool):
    """Append one message to the history of ``target_id``.

    The history list is created on first use. ``is_private`` is passed to
    ``_get_target_dict`` to pick the bucket the history lives in.

    Args:
        content: The message object to store (stored as-is).
        target_id: Key identifying the conversation.
        is_private: Forwarded to ``_get_target_dict`` to select the bucket.
    """
    self._get_target_dict(is_private).setdefault(target_id, []).append(content)

func set_context(self, contexts, target_id: str, is_private: bool)

说明: 设置上下文

源代码在GitHub上查看
def set_context(self, contexts, target_id: str, is_private: bool):
    """Replace the whole history of ``target_id`` with ``contexts``.

    ``contexts`` is stored by reference (no copy), so later appends to this
    target mutate the caller's object.

    Args:
        contexts: The new message history for the target.
        target_id: Key identifying the conversation.
        is_private: Forwarded to ``_get_target_dict`` to select the bucket.
    """
    bucket = self._get_target_dict(is_private)
    bucket.update({target_id: contexts})

func reset(self, target_id: str, is_private: bool)

说明: 重置上下文

源代码在GitHub上查看
def reset(self, target_id: str, is_private: bool):
    """Empty the history of ``target_id`` in place.

    The existing list object is cleared rather than replaced, so any reference
    previously handed out (e.g. by ``build``) also becomes empty. Unknown
    targets are ignored; no entry is created for them.

    Args:
        target_id: Key identifying the conversation.
        is_private: Forwarded to ``_get_target_dict`` to select the bucket.
    """
    histories = self._get_target_dict(is_private)
    if target_id not in histories:
        return
    histories[target_id].clear()

func build(self, target_id: str, is_private: bool) -> list

说明: 返回该目标的上下文列表（不存在时先创建空列表），不包括系统消息

源代码在GitHub上查看
def build(self, target_id: str, is_private: bool) -> list:
    """Return the message history for ``target_id``, creating it if absent.

    The returned list is the live stored object, not a copy: mutating it
    changes the stored context.

    Args:
        target_id: Key identifying the conversation.
        is_private: Forwarded to ``_get_target_dict`` to select the bucket.

    Returns:
        The (possibly newly created, empty) history list.
    """
    return self._get_target_dict(is_private).setdefault(target_id, [])

class MarshoTools


func __init__(self)

源代码在GitHub上查看
def __init__(self):
    """Start with no tool definitions and no loaded toolkit modules."""
    # Tool definitions read from each toolkit's tools.json (filled by load_tools).
    self.tools_list = list()
    # Toolkit directory name -> executed module object (filled by load_tools).
    self.imported_packages = dict()

func load_tools(self, tools_dir)

说明: 从指定路径加载工具包

源代码在GitHub上查看
def load_tools(self, tools_dir):
    """Load every toolkit package found directly under ``tools_dir``.

    A toolkit is a sub-directory containing both ``__init__.py`` and
    ``tools.json``. Toolkits named in ``config.marshoai_disabled_toolkits``
    are skipped. For each valid toolkit, the entries of its JSON file are
    appended to ``self.tools_list``. The executed module is stored in
    ``self.imported_packages`` and ``sys.modules`` under the directory name.

    A toolkit is published only if both its JSON and its module load
    successfully. This means a failed import never leaves tools advertised
    that cannot be called. Per-toolkit errors (any ``Exception``) are logged
    and the scan continues with the next toolkit.

    Args:
        tools_dir: Directory whose immediate children are toolkit packages.
    """
    if not os.path.exists(tools_dir):
        logger.error(f'工具集目录 {tools_dir} 不存在。')
        return
    # os.listdir order is arbitrary; sort so the tools_list order is reproducible.
    for package_name in sorted(os.listdir(tools_dir)):
        package_path = os.path.join(tools_dir, package_name)
        if package_name in config.marshoai_disabled_toolkits:
            logger.info(f'工具包 {package_name} 已被禁用。')
            continue
        init_path = os.path.join(package_path, '__init__.py')
        if not (os.path.isdir(package_path) and os.path.exists(init_path)):
            logger.warning(f'{package_path} 不是有效的工具包路径,跳过加载。')
            continue
        json_path = os.path.join(package_path, 'tools.json')
        if not os.path.exists(json_path):
            logger.warning(f'在工具包 {package_path} 下找不到tools.json跳过加载。')
            continue
        try:
            # Read and close the JSON file before running any toolkit code.
            with open(json_path, 'r', encoding='utf-8') as json_file:
                tools = list(json.load(json_file))
            spec = importlib.util.spec_from_file_location(package_name, init_path)
            package = importlib.util.module_from_spec(spec)
            # Register before executing (importlib's recommended order) so the
            # package can import itself / its submodules while it runs.
            sys.modules[package_name] = package
            try:
                spec.loader.exec_module(package)
            except BaseException:
                # Don't leave a half-initialised module importable.
                sys.modules.pop(package_name, None)
                raise
            # Publish the toolkit only after everything above succeeded.
            self.imported_packages[package_name] = package
            self.tools_list.extend(tools)
            logger.success(f'成功加载工具包 {package_name}')
        except json.JSONDecodeError as e:
            logger.error(f'解码 JSON {json_path} 时发生错误: {e}')
        except Exception as e:
            logger.error(f'加载工具包时发生错误: {e}')
            traceback.print_exc()

async func call(self, full_function_name: str, args: dict)

说明: 调用指定的工具函数，函数名格式为 `工具包名__函数名`

源代码在GitHub上查看
async def call(self, full_function_name: str, args: dict):
    """Invoke a loaded tool function by its fully-qualified name.

    ``full_function_name`` must have the form ``<package>__<function>``.
    ``<package>`` is a key of ``self.imported_packages``, and ``<function>``
    is an attribute of that module. It is awaited with ``args`` unpacked as
    keyword arguments.

    Args:
        full_function_name: ``"<package>__<function>"``; exactly one ``__``.
        args: Keyword arguments passed to the tool function.

    Returns:
        Whatever the awaited tool function returns. If looking up or calling
        the function raises, returns an error message string. Returns ``None``
        if the name is malformed or the package has not been imported (both
        cases are logged).
    """
    parts = full_function_name.split('__')
    if len(parts) != 2:
        # Must bail out here: otherwise package_name/function_name below would
        # be unbound and raise UnboundLocalError.
        logger.error('函数名无效')
        return None
    package_name, function_name = parts
    if package_name not in self.imported_packages:
        logger.error(f"工具包 '{package_name}' 未导入")
        return None
    package = self.imported_packages[package_name]
    try:
        function = getattr(package, function_name)
        return await function(**args)
    except Exception as e:
        # Return the error text so the caller gets a readable result instead of a crash.
        errinfo = f"调用函数 '{function_name}'时发生错误:{e}"
        logger.error(errinfo)
        return errinfo

func get_tools_list(self)

源代码在GitHub上查看
def get_tools_list(self):
    """Return the loaded tool definitions, or ``None`` if tools are unavailable.

    Returns ``None`` when no tool definitions were loaded or when
    ``config.marshoai_enable_tools`` is falsy; otherwise returns
    ``self.tools_list`` itself (not a copy).
    """
    # The list is tested first, so config is only consulted when tools exist.
    if self.tools_list and config.marshoai_enable_tools:
        return self.tools_list
    return None