🐛 简化代码

This commit is contained in:
Asankilp 2024-12-30 00:01:57 +08:00
parent b417a5c8d0
commit 80ed7692a4
3 changed files with 44 additions and 55 deletions

View File

@ -28,40 +28,31 @@ class MarshoContext:
往上下文中添加消息 往上下文中添加消息
""" """
target_dict = self._get_target_dict(is_private) target_dict = self._get_target_dict(is_private)
if target_id not in target_dict: target_dict.setdefault(target_id, []).append(content)
target_dict[target_id] = []
target_dict[target_id].append(content)
def set_context(self, contexts, target_id: str, is_private: bool): def set_context(self, contexts, target_id: str, is_private: bool):
""" """
设置上下文 设置上下文
""" """
target_dict = self._get_target_dict(is_private) self._get_target_dict(is_private)[target_id] = contexts
target_dict[target_id] = contexts
def reset(self, target_id: str, is_private: bool): def reset(self, target_id: str, is_private: bool):
""" """
重置上下文 重置上下文
""" """
target_dict = self._get_target_dict(is_private) self._get_target_dict(is_private).pop(target_id, None)
if target_id in target_dict:
target_dict[target_id].clear()
def reset_all(self): def reset_all(self):
""" """
重置所有上下文 重置所有上下文
""" """
self.contents["private"].clear() self.contents = {"private": {}, "non-private": {}}
self.contents["non-private"].clear()
def build(self, target_id: str, is_private: bool) -> list: def build(self, target_id: str, is_private: bool) -> list:
""" """
构建返回的上下文不包括系统消息 构建返回的上下文不包括系统消息
""" """
target_dict = self._get_target_dict(is_private) return self._get_target_dict(is_private).setdefault(target_id, [])
if target_id not in target_dict:
target_dict[target_id] = []
return target_dict[target_id]
class MarshoTools: class MarshoTools:
@ -84,21 +75,24 @@ class MarshoTools:
for package_name in os.listdir(tools_dir): for package_name in os.listdir(tools_dir):
package_path = os.path.join(tools_dir, package_name) package_path = os.path.join(tools_dir, package_name)
# logger.info(f"尝试加载工具包 {package_name}")
if package_name in config.marshoai_disabled_toolkits: if package_name in config.marshoai_disabled_toolkits:
logger.info(f"工具包 {package_name} 已被禁用。") logger.info(f"工具包 {package_name} 已被禁用。")
continue continue
if os.path.isdir(package_path) and os.path.exists( if os.path.isdir(package_path) and os.path.exists(
os.path.join(package_path, "__init__.py") os.path.join(package_path, "__init__.py")
): ):
self._load_package(package_name, package_path)
else:
logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。")
def _load_package(self, package_name, package_path):
json_path = os.path.join(package_path, "tools.json") json_path = os.path.join(package_path, "tools.json")
if os.path.exists(json_path): if os.path.exists(json_path):
try: try:
with open(json_path, "r", encoding="utf-8") as json_file: with open(json_path, "r", encoding="utf-8") as json_file:
data = json.load(json_file) data = json.load(json_file)
for i in data: self.tools_list.extend(data)
self.tools_list.append(i)
spec = importlib.util.spec_from_file_location( spec = importlib.util.spec_from_file_location(
package_name, os.path.join(package_path, "__init__.py") package_name, os.path.join(package_path, "__init__.py")
@ -115,23 +109,18 @@ class MarshoTools:
logger.error(f"加载工具包时发生错误: {e}") logger.error(f"加载工具包时发生错误: {e}")
traceback.print_exc() traceback.print_exc()
else: else:
logger.warning( logger.warning(f"在工具包 {package_path} 下找不到tools.json跳过加载。")
f"在工具包 {package_path} 下找不到tools.json跳过加载。"
)
else:
logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。")
async def call(self, full_function_name: str, args: dict): async def call(self, full_function_name: str, args: dict):
""" """
调用指定的函数 调用指定的函数
""" """
# 分割包名和函数名
parts = full_function_name.split("__") parts = full_function_name.split("__")
if len(parts) == 2: if len(parts) != 2:
package_name = parts[0]
function_name = parts[1]
else:
logger.error("函数名无效") logger.error("函数名无效")
return
package_name, function_name = parts
if package_name in self.imported_packages: if package_name in self.imported_packages:
package = self.imported_packages[package_name] package = self.imported_packages[package_name]
try: try:
@ -149,12 +138,11 @@ class MarshoTools:
检查是否存在指定的函数 检查是否存在指定的函数
""" """
try: try:
for t in self.tools_list: return any(
if t["function"]["name"].replace( t["function"]["name"].replace("-", "_")
"-", "_" == full_function_name.replace("-", "_")
) == full_function_name.replace("-", "_"): for t in self.tools_list
return True )
return False
except Exception as e: except Exception as e:
logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}") logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
return False return False

View File

@ -1,7 +1,7 @@
import traceback import traceback
import httpx import httpx
from zhDateTime import DateTime from zhDateTime import DateTime # type: ignore
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
from nonebot_plugin_marshoai.plugin.func_call.params import String from nonebot_plugin_marshoai.plugin.func_call.params import String

View File

@ -90,7 +90,8 @@ async def make_chat(
参数: 参数:
client: 用于与AI模型进行通信 client: 用于与AI模型进行通信
msg: 消息内容 msg: 消息内容
model_name: 指定AI模型名""" model_name: 指定AI模型名
tools: 工具列表"""
return await client.complete( return await client.complete(
messages=msg, messages=msg,
model=model_name, model=model_name,