nonebot-plugin-marshoai/nonebot_plugin_marshoai/models.py

134 lines
4.6 KiB
Python
Raw Normal View History

from .util import *
from .config import config
import os
import re
import json
import importlib
# import importlib.util
import traceback
from nonebot import logger
class MarshoContext:
"""
Marsho 的上下文类
"""
def __init__(self):
self.contents = {"private": {}, "non-private": {}}
def _get_target_dict(self, is_private):
return self.contents["private"] if is_private else self.contents["non-private"]
2024-10-03 15:16:32 +08:00
def append(self, content, target_id: str, is_private: bool):
"""
往上下文中添加消息
"""
target_dict = self._get_target_dict(is_private)
if target_id not in target_dict:
target_dict[target_id] = []
target_dict[target_id].append(content)
2024-10-03 15:16:32 +08:00
def set_context(self, contexts, target_id: str, is_private: bool):
"""
设置上下文
"""
target_dict = self._get_target_dict(is_private)
target_dict[target_id] = contexts
def reset(self, target_id: str, is_private: bool):
"""
重置上下文
"""
target_dict = self._get_target_dict(is_private)
2024-11-17 02:48:17 +08:00
if target_id in target_dict:
target_dict[target_id].clear()
2024-10-03 15:16:32 +08:00
def build(self, target_id: str, is_private: bool) -> list:
"""
构建返回的上下文不包括系统消息
"""
target_dict = self._get_target_dict(is_private)
if target_id not in target_dict:
target_dict[target_id] = []
return target_dict[target_id]
class MarshoTools:
"""
Marsho 的工具类
"""
def __init__(self):
self.tools_list = []
self.imported_packages = {}
def load_tools(self, tools_dir):
"""
从指定路径加载工具包
"""
if not os.path.exists(tools_dir):
logger.error(f"工具集目录 {tools_dir} 不存在。")
return
for package_name in os.listdir(tools_dir):
package_path = os.path.join(tools_dir, package_name)
if os.path.isdir(package_path) and os.path.exists(
os.path.join(package_path, "__init__.py")
):
json_path = os.path.join(package_path, "tools.json")
if os.path.exists(json_path):
try:
with open(json_path, "r", encoding="utf-8") as json_file:
2024-11-24 01:44:42 +08:00
data = json.load(json_file)
for i in data:
self.tools_list.append(i)
# 导入包
spec = importlib.util.spec_from_file_location(
package_name, os.path.join(package_path, "__init__.py")
)
package = importlib.util.module_from_spec(spec)
spec.loader.exec_module(package)
self.imported_packages[package_name] = package
logger.success(f"成功加载工具包 {package_name}")
except json.JSONDecodeError as e:
logger.error(f"解码 JSON {json_path} 时发生错误: {e}")
except Exception as e:
logger.error(f"加载工具包时发生错误: {e}")
traceback.print_exc()
else:
logger.warning(
f"在工具包 {package_path} 下找不到tools.json跳过加载。"
)
else:
logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。")
async def call(self, full_function_name: str, args: dict):
"""
调用指定的函数
"""
# 分割包名和函数名
parts = full_function_name.split("__")
if len(parts) == 2:
package_name = parts[0]
function_name = parts[1]
else:
logger.error("函数名无效")
if package_name in self.imported_packages:
package = self.imported_packages[package_name]
try:
function = getattr(package, function_name)
return await function(**args)
except Exception as e:
errinfo = f"调用函数 '{function_name}'时发生错误:{e}"
logger.error(errinfo)
return errinfo
else:
logger.error(f"工具包 '{package_name}' 未导入")
def get_tools_list(self):
if not self.tools_list or not config.marshoai_enable_tools:
return None
return self.tools_list