重构Caller类,移除泛型参数;添加函数签名复制装饰器

This commit is contained in:
远野千束 2024-12-15 17:08:02 +08:00
parent af9a5e3c96
commit 0379789bec
8 changed files with 208 additions and 72 deletions

View File

@ -23,6 +23,8 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
@ -103,10 +105,9 @@ async def _preload_tools():
@driver.on_startup
async def _preload_plugins():
"""启动钩子加载插件"""
marshoai_plugin_dirs = config.marshoai_plugin_dirs
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") # 预置插件目录
load_plugins(*marshoai_plugin_dirs)
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")
@add_usermsg_cmd.handle()
@ -266,7 +267,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
client=client,
model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
tools=tools.get_tools_list(),
tools=tools.get_tools_list()
+ list(
map(lambda v: v.data(), get_function_calls().values())
), # TODO 临时追加函数,后期优化
)
# await UniMessage(str(response)).send()
choice = response.choices[0]
@ -315,9 +319,23 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
await UniMessage(
f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
).send()
# TODO 临时追加插件函数,若工具中没有则调用插件函数
if tools.has_function(tool_call.function.name):
logger.debug(f"调用工具函数 {tool_call.function.name}")
func_return = await tools.call(
tool_call.function.name, function_args
) # 获取返回值
else:
if caller := get_function_calls().get(
tool_call.function.name
):
logger.debug(f"调用插件函数 {tool_call.function.name}")
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
caller.event = event
func_return = await caller.call(**function_args)
else:
logger.error(f"未找到函数 {tool_call.function.name}")
func_return = f"未找到函数 {tool_call.function.name}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
)

View File

@ -90,6 +90,7 @@ class MarshoTools:
with open(json_path, "r", encoding="utf-8") as json_file:
data = json.load(json_file)
for i in data:
self.tools_list.append(i)
spec = importlib.util.spec_from_file_location(
@ -136,6 +137,21 @@ class MarshoTools:
else:
logger.error(f"工具包 '{package_name}' 未导入")
def has_function(self, full_function_name: str) -> bool:
"""
检查是否存在指定的函数
"""
try:
for t in self.tools_list:
if t["function"]["name"].replace(
"-", "_"
) == full_function_name.replace("-", "_"):
return True
return False
except Exception as e:
logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
return False
def get_tools_list(self):
if not self.tools_list or not config.marshoai_enable_tools:
return None

View File

@ -1,36 +1,34 @@
from typing import Generic, TypeVar
import inspect
from typing import Any
from nonebot import logger
from nonebot.adapters import Event
from ..typing import FUNCTION_CALL_FUNC
from .params import P
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
from .utils import async_wrap, is_coroutine_callable
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
_caller_data: dict[str, "Caller"] = {}
class Caller(Generic[P]):
class Caller:
def __init__(self, name: str | None = None, description: str | None = None):
self._name = name
self._description = description
self._parameters: dict[str, P] = {}
self.func: FUNCTION_CALL_FUNC | None = None
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
self._parameters: dict[str, Any] = {}
"""依赖注入的参数"""
self.event: Event | None = None
def params(self, **kwargs: P) -> "Caller":
"""设置多个函数参数
Args:
**kwargs: 参数字典
Returns:
Caller: Caller对象
"""
def params(self, **kwargs: Any) -> "Caller":
self._parameters.update(kwargs)
return self
def param(self, name: str, param: P) -> "Caller":
def param(self, name: str, param: Any) -> "Caller":
"""设置一个函数参数
Args:
name (str): 参数名
param (P): 参数对象
param (Any): 参数对象
Returns:
Caller: Caller对象
@ -51,14 +49,6 @@ class Caller(Generic[P]):
return self
def description(self, description: str) -> "Caller":
"""设置函数描述
Args:
description (str): 函数描述
Returns:
Caller: Caller对象
"""
self._description = description
return self
@ -71,12 +61,78 @@ class Caller(Generic[P]):
Returns:
F: 函数对象
"""
global _caller_data
if self._name is None:
if module := inspect.getmodule(func):
module_name = module.__name__.split(".")[-1]
else:
module_name = "global"
self._name = f"{module_name}-{func.__name__}"
_caller_data[self._name] = self
if is_coroutine_callable(func):
self.func = func # type: ignore
else:
self.func = async_wrap(func) # type: ignore
if module := inspect.getmodule(func):
module_name = module.__name__ + "."
else:
module_name = ""
logger.opt(colors=True).info(
f"<y>加载函数 {func.__name__} {self._description}</y>"
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
)
self.func = func
return func
def data(self) -> dict[str, Any]:
"""返回函数的json数据
Returns:
dict[str, Any]: 函数的json数据
"""
return {
"type": "function",
"function": {
"name": self._name,
"description": self._description,
"parameters": {
"type": "object",
"properties": {
key: value.data() for key, value in self._parameters.items()
},
},
"required": [
key
for key, value in self._parameters.items()
if value.default is None
],
},
}
def set_event(self, event: Event):
self.event = event
async def call(self, *args: Any, **kwargs: Any) -> Any:
"""调用函数
Returns:
Any: 函数返回值
"""
if self.func is None:
raise ValueError("未注册函数对象")
sig = inspect.signature(self.func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
kwargs[name] = self.event
if issubclass(param.annotation, Caller) or isinstance(
param.annotation, Caller
):
kwargs[name] = self
return await self.func(*args, **kwargs)
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
"""返回一个Caller类可用于装饰一个函数使其注册为一个可被AI调用的function call函数
@ -87,5 +143,14 @@ def on_function_call(name: str | None = None, description: str | None = None) ->
Returns:
Caller: Caller对象
"""
caller = Caller(name=name, description=description)
return caller
return Caller(name=name, description=description)
def get_function_calls() -> dict[str, Caller]:
"""获取所有已注册的function call函数
Returns:
dict[str, Caller]: 所有已注册的function call函数
"""
return _caller_data

View File

@ -0,0 +1,52 @@
import inspect
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable
from ..typing import F
def copy_signature(func: F) -> Callable[[Callable[..., Any]], F]:
"""复制函数签名和文档字符串的装饰器"""
def decorator(wrapper: Callable[..., Any]) -> F:
@wraps(func)
def wrapped(*args: Any, **kwargs: Any) -> Any:
return wrapper(*args, **kwargs)
return wrapped # type: ignore
return decorator
def async_wrap(func: F) -> F:
"""装饰器,将同步函数包装为异步函数
Args:
func (F): 函数对象
Returns:
F: 包装后的函数对象
"""
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
return wrapper # type: ignore
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
"""
判断是否为async def 函数
请注意是否为 async def 函数与该函数是否能被await调用是两个不同的概念具体取决于函数返回值是否为awaitable对象
Args:
call: 可调用对象
Returns:
bool: 是否为async def函数
"""
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
func_ = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(func_)

View File

@ -1,5 +1,7 @@
from typing import Any, Callable, Coroutine, TypeAlias
from typing import Any, Callable, Coroutine, TypeAlias, TypeVar
SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str]
ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]]
FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)

View File

@ -18,21 +18,5 @@ def path_to_module_name(path: Path) -> str:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
"""
判断是否为async def 函数
Args:
call: 可调用对象
Returns:
bool: 是否为协程可调用对象
"""
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
func_ = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(func_)
def parse_function_docsring():
pass

View File

@ -1,3 +1,5 @@
from nonebot.adapters.onebot.v11 import MessageEvent
from nonebot_plugin_marshoai.plugin import (
Integer,
Parameter,
@ -5,6 +7,7 @@ from nonebot_plugin_marshoai.plugin import (
String,
on_function_call,
)
from nonebot_plugin_marshoai.plugin.func_call.caller import Caller
__marsho_meta__ = PluginMetadata(
name="SnowyKami 测试插件",
@ -19,16 +22,7 @@ __marsho_meta__ = PluginMetadata(
gender=String(enum=["", ""], description="性别"),
)
async def fortune_telling(age: int, name: str, gender: str) -> str:
"""使用姓名,年龄,性别进行算命
Args:
age (int): _description_
name (str): _description_
gender (str): _description_
Returns:
str: _description_
"""
"""使用姓名,年龄,性别进行算命"""
# 进行一系列算命操作...
@ -41,17 +35,22 @@ async def fortune_telling(age: int, name: str, gender: str) -> str:
unit=String(enum=["摄氏度", "华氏度"], description="温度单位"),
)
async def get_weather(location: str, days: int, unit: str) -> str:
"""获取一个地点未来一段时间的天气
Args:
location (str): 地点名称可以是城市名地区名等
days (int): 天数
unit (str): 温度单位
Returns:
str: 天气信息
"""
"""获取一个地点未来一段时间的天气"""
# 进行一系列获取天气操作...
return f"{location}未来{days}天的天气信息..."
@on_function_call(description="获取设备物理地理位置")
async def get_location() -> str:
"""获取设备物理地理位置"""
# 进行一系列获取地理位置操作...
return "日本 东京都 世田谷区"
@on_function_call(description="获取聊天者个人信息")
async def get_user_info(e: MessageEvent, c: Caller) -> str:
return f"用户信息:{e.user_id} {e.sender.nickname}, {c._parameters}"

View File

@ -2,7 +2,7 @@
{
"type": "function",
"function": {
"name": "marshoai-memory__write_memory",
"name": "marshoai_memory__write_memory",
"description": "当你想记住有关与你对话的人的一些信息的时候,调用此函数。",
"parameters": {
"type": "object",