mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-01-26 18:12:47 +08:00
🐛 简化代码
This commit is contained in:
parent
b417a5c8d0
commit
80ed7692a4
@ -28,40 +28,31 @@ class MarshoContext:
|
|||||||
往上下文中添加消息
|
往上下文中添加消息
|
||||||
"""
|
"""
|
||||||
target_dict = self._get_target_dict(is_private)
|
target_dict = self._get_target_dict(is_private)
|
||||||
if target_id not in target_dict:
|
target_dict.setdefault(target_id, []).append(content)
|
||||||
target_dict[target_id] = []
|
|
||||||
target_dict[target_id].append(content)
|
|
||||||
|
|
||||||
def set_context(self, contexts, target_id: str, is_private: bool):
|
def set_context(self, contexts, target_id: str, is_private: bool):
|
||||||
"""
|
"""
|
||||||
设置上下文
|
设置上下文
|
||||||
"""
|
"""
|
||||||
target_dict = self._get_target_dict(is_private)
|
self._get_target_dict(is_private)[target_id] = contexts
|
||||||
target_dict[target_id] = contexts
|
|
||||||
|
|
||||||
def reset(self, target_id: str, is_private: bool):
|
def reset(self, target_id: str, is_private: bool):
|
||||||
"""
|
"""
|
||||||
重置上下文
|
重置上下文
|
||||||
"""
|
"""
|
||||||
target_dict = self._get_target_dict(is_private)
|
self._get_target_dict(is_private).pop(target_id, None)
|
||||||
if target_id in target_dict:
|
|
||||||
target_dict[target_id].clear()
|
|
||||||
|
|
||||||
def reset_all(self):
|
def reset_all(self):
|
||||||
"""
|
"""
|
||||||
重置所有上下文
|
重置所有上下文
|
||||||
"""
|
"""
|
||||||
self.contents["private"].clear()
|
self.contents = {"private": {}, "non-private": {}}
|
||||||
self.contents["non-private"].clear()
|
|
||||||
|
|
||||||
def build(self, target_id: str, is_private: bool) -> list:
|
def build(self, target_id: str, is_private: bool) -> list:
|
||||||
"""
|
"""
|
||||||
构建返回的上下文,不包括系统消息
|
构建返回的上下文,不包括系统消息
|
||||||
"""
|
"""
|
||||||
target_dict = self._get_target_dict(is_private)
|
return self._get_target_dict(is_private).setdefault(target_id, [])
|
||||||
if target_id not in target_dict:
|
|
||||||
target_dict[target_id] = []
|
|
||||||
return target_dict[target_id]
|
|
||||||
|
|
||||||
|
|
||||||
class MarshoTools:
|
class MarshoTools:
|
||||||
@ -84,21 +75,24 @@ class MarshoTools:
|
|||||||
for package_name in os.listdir(tools_dir):
|
for package_name in os.listdir(tools_dir):
|
||||||
package_path = os.path.join(tools_dir, package_name)
|
package_path = os.path.join(tools_dir, package_name)
|
||||||
|
|
||||||
# logger.info(f"尝试加载工具包 {package_name}")
|
|
||||||
if package_name in config.marshoai_disabled_toolkits:
|
if package_name in config.marshoai_disabled_toolkits:
|
||||||
logger.info(f"工具包 {package_name} 已被禁用。")
|
logger.info(f"工具包 {package_name} 已被禁用。")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if os.path.isdir(package_path) and os.path.exists(
|
if os.path.isdir(package_path) and os.path.exists(
|
||||||
os.path.join(package_path, "__init__.py")
|
os.path.join(package_path, "__init__.py")
|
||||||
):
|
):
|
||||||
|
self._load_package(package_name, package_path)
|
||||||
|
else:
|
||||||
|
logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。")
|
||||||
|
|
||||||
|
def _load_package(self, package_name, package_path):
|
||||||
json_path = os.path.join(package_path, "tools.json")
|
json_path = os.path.join(package_path, "tools.json")
|
||||||
if os.path.exists(json_path):
|
if os.path.exists(json_path):
|
||||||
try:
|
try:
|
||||||
with open(json_path, "r", encoding="utf-8") as json_file:
|
with open(json_path, "r", encoding="utf-8") as json_file:
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
for i in data:
|
self.tools_list.extend(data)
|
||||||
|
|
||||||
self.tools_list.append(i)
|
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(
|
||||||
package_name, os.path.join(package_path, "__init__.py")
|
package_name, os.path.join(package_path, "__init__.py")
|
||||||
@ -115,23 +109,18 @@ class MarshoTools:
|
|||||||
logger.error(f"加载工具包时发生错误: {e}")
|
logger.error(f"加载工具包时发生错误: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(f"在工具包 {package_path} 下找不到tools.json,跳过加载。")
|
||||||
f"在工具包 {package_path} 下找不到tools.json,跳过加载。"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"{package_path} 不是有效的工具包路径,跳过加载。")
|
|
||||||
|
|
||||||
async def call(self, full_function_name: str, args: dict):
|
async def call(self, full_function_name: str, args: dict):
|
||||||
"""
|
"""
|
||||||
调用指定的函数
|
调用指定的函数
|
||||||
"""
|
"""
|
||||||
# 分割包名和函数名
|
|
||||||
parts = full_function_name.split("__")
|
parts = full_function_name.split("__")
|
||||||
if len(parts) == 2:
|
if len(parts) != 2:
|
||||||
package_name = parts[0]
|
|
||||||
function_name = parts[1]
|
|
||||||
else:
|
|
||||||
logger.error("函数名无效")
|
logger.error("函数名无效")
|
||||||
|
return
|
||||||
|
|
||||||
|
package_name, function_name = parts
|
||||||
if package_name in self.imported_packages:
|
if package_name in self.imported_packages:
|
||||||
package = self.imported_packages[package_name]
|
package = self.imported_packages[package_name]
|
||||||
try:
|
try:
|
||||||
@ -149,12 +138,11 @@ class MarshoTools:
|
|||||||
检查是否存在指定的函数
|
检查是否存在指定的函数
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
for t in self.tools_list:
|
return any(
|
||||||
if t["function"]["name"].replace(
|
t["function"]["name"].replace("-", "_")
|
||||||
"-", "_"
|
== full_function_name.replace("-", "_")
|
||||||
) == full_function_name.replace("-", "_"):
|
for t in self.tools_list
|
||||||
return True
|
)
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
|
logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
|
||||||
return False
|
return False
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from zhDateTime import DateTime
|
from zhDateTime import DateTime # type: ignore
|
||||||
|
|
||||||
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
|
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
|
||||||
from nonebot_plugin_marshoai.plugin.func_call.params import String
|
from nonebot_plugin_marshoai.plugin.func_call.params import String
|
||||||
|
@ -90,7 +90,8 @@ async def make_chat(
|
|||||||
参数:
|
参数:
|
||||||
client: 用于与AI模型进行通信
|
client: 用于与AI模型进行通信
|
||||||
msg: 消息内容
|
msg: 消息内容
|
||||||
model_name: 指定AI模型名"""
|
model_name: 指定AI模型名
|
||||||
|
tools: 工具列表"""
|
||||||
return await client.complete(
|
return await client.complete(
|
||||||
messages=msg,
|
messages=msg,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user