163 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import inspect
from typing import Any
from nonebot import logger
from nonebot.adapters import Event
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
from .utils import async_wrap, is_coroutine_callable
_caller_data: dict[str, "Caller"] = {}
class Caller:
def __init__(self, name: str | None = None, description: str | None = None):
self._name = name
self._description = description
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
self._parameters: dict[str, Any] = {}
"""依赖注入的参数"""
self.event: Event | None = None
def params(self, **kwargs: Any) -> "Caller":
self._parameters.update(kwargs)
return self
def param(self, name: str, param: Any) -> "Caller":
"""设置一个函数参数
Args:
name (str): 参数名
param (Any): 参数对象
Returns:
Caller: Caller对象
"""
self._parameters[name] = param
return self
def name(self, name: str) -> "Caller":
"""设置函数名称
Args:
name (str): 函数名称
Returns:
Caller: Caller对象
"""
self._name = name
return self
def description(self, description: str) -> "Caller":
self._description = description
return self
def __call__(self, func: F) -> F:
"""装饰函数注册为一个可被AI调用的function call函数
Args:
func (F): 函数对象
Returns:
F: 函数对象
"""
global _caller_data
if self._name is None:
if module := inspect.getmodule(func):
module_name = module.__name__.split(".")[-1]
else:
module_name = "global"
self._name = f"{module_name}-{func.__name__}"
_caller_data[self._name] = self
if is_coroutine_callable(func):
self.func = func # type: ignore
else:
self.func = async_wrap(func) # type: ignore
if module := inspect.getmodule(func):
module_name = module.__name__ + "."
else:
module_name = ""
logger.opt(colors=True).info(
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
)
return func
def data(self) -> dict[str, Any]:
"""返回函数的json数据
Returns:
dict[str, Any]: 函数的json数据
"""
return {
"type": "function",
"function": {
"name": self._name,
"description": self._description,
"parameters": {
"type": "object",
"properties": {
key: value.data() for key, value in self._parameters.items()
},
},
"required": [
key
for key, value in self._parameters.items()
if value.default is None
],
},
}
def set_event(self, event: Event):
self.event = event
async def call(self, *args: Any, **kwargs: Any) -> Any:
"""调用函数
Returns:
Any: 函数返回值
"""
if self.func is None:
raise ValueError("未注册函数对象")
sig = inspect.signature(self.func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
kwargs[name] = self.event
if issubclass(param.annotation, Caller) or isinstance(
param.annotation, Caller
):
kwargs[name] = self
# 检查形参是否有默认值或传入若没有则用parameters中的默认值填充
for name, param in sig.parameters.items():
if name not in kwargs:
kwargs[name] = self._parameters.get(name, param.default)
return await self.func(*args, **kwargs)
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
"""返回一个Caller类可用于装饰一个函数使其注册为一个可被AI调用的function call函数
Args:
description: 函数描述若为None则从函数的docstring中获取
Returns:
Caller: Caller对象
"""
caller = Caller(name=name, description=description)
return caller
def get_function_calls() -> dict[str, Caller]:
"""获取所有已注册的function call函数
Returns:
dict[str, Caller]: 所有已注册的function call函数
"""
return _caller_data