194 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import inspect
from typing import Any
from nonebot import logger
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission
from nonebot.rule import Rule
from nonebot.typing import T_State
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
from .utils import async_wrap, is_coroutine_callable
_caller_data: dict[str, "Caller"] = {}
class Caller:
def __init__(self, name: str | None = None, description: str | None = None):
self._name = name
self._description = description
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
self._parameters: dict[str, Any] = {}
"""依赖注入的参数"""
self.bot: Bot | None = None
self.event: Event | None = None
self.state: T_State | None = None
self._permission: Permission | None = None
self._rule: Rule | None = None
def params(self, **kwargs: Any) -> "Caller":
self._parameters.update(kwargs)
return self
def permission(self, permission: Permission) -> "Caller":
self._permission = self._permission or permission
return self
async def pre_check(self) -> tuple[bool, str]:
if self.bot is None or self.event is None:
return False, "Context is None"
if self._permission and not await self._permission(self.bot, self.event):
return False, "告诉用户 Permission Denied 权限不足"
if self.state is None:
return False, "State is None"
if self._rule and not await self._rule(self.bot, self.event, self.state):
return False, "告诉用户 Rule Denied 规则不匹配"
return True, ""
def rule(self, rule: Rule) -> "Caller":
self._rule = self._rule and rule
return self
def name(self, name: str) -> "Caller":
"""设置函数名称
Args:
name (str): 函数名称
Returns:
Caller: Caller对象
"""
self._name = name
return self
def description(self, description: str) -> "Caller":
self._description = description
return self
def __call__(self, func: F) -> F:
"""装饰函数注册为一个可被AI调用的function call函数
Args:
func (F): 函数对象
Returns:
F: 函数对象
"""
global _caller_data
if self._name is None:
if module := inspect.getmodule(func):
module_name = module.__name__.split(".")[-1]
else:
module_name = ""
self._name = f"{module_name}-{func.__name__}"
_caller_data[self._name] = self
if is_coroutine_callable(func):
self.func = func # type: ignore
else:
self.func = async_wrap(func) # type: ignore
if module := inspect.getmodule(func):
module_name = module.__name__.split(".")[-1] + "."
else:
module_name = ""
logger.opt(colors=True).debug(
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
)
return func
def data(self) -> dict[str, Any]:
"""返回函数的json数据
Returns:
dict[str, Any]: 函数的json数据
"""
return {
"type": "function",
"function": {
"name": self._name,
"description": self._description,
"parameters": {
"type": "object",
"properties": {
key: value.data() for key, value in self._parameters.items()
},
},
"required": [
key
for key, value in self._parameters.items()
if value.default is None
],
},
}
def set_event(self, event: Event):
self.event = event
def set_bot(self, bot: Bot):
self.bot = bot
async def call(self, *args: Any, **kwargs: Any) -> Any:
"""调用函数
Returns:
Any: 函数返回值
"""
y, r = await self.pre_check()
if not y:
logger.debug(f"Function {self._name} pre_check failed: {r}")
return r
if self.func is None:
raise ValueError("未注册函数对象")
sig = inspect.signature(self.func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
kwargs[name] = self.event
if issubclass(param.annotation, Caller) or isinstance(
param.annotation, Caller
):
kwargs[name] = self
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
kwargs[name] = self.bot
if param.annotation == T_State:
kwargs[name] = self.state
# 检查形参是否有默认值或传入若没有则用parameters中的默认值填充
for name, param in sig.parameters.items():
if name not in kwargs:
kwargs[name] = self._parameters.get(name, param.default)
return await self.func(*args, **kwargs)
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
"""返回一个Caller类可用于装饰一个函数使其注册为一个可被AI调用的function call函数
Args:
description: 函数描述若为None则从函数的docstring中获取
Returns:
Caller: Caller对象
"""
caller = Caller(name=name, description=description)
return caller
def get_function_calls() -> dict[str, Caller]:
"""获取所有已注册的function call函数
Returns:
dict[str, Caller]: 所有已注册的function call函数
"""
return _caller_data