2024-12-15 17:08:02 +08:00
|
|
|
|
import inspect
|
|
|
|
|
from typing import Any
|
2024-12-15 02:59:06 +08:00
|
|
|
|
|
|
|
|
|
from nonebot import logger
|
2024-12-15 18:27:30 +08:00
|
|
|
|
from nonebot.adapters import Bot, Event
|
2024-12-17 13:25:30 +08:00
|
|
|
|
from nonebot.matcher import Matcher
|
2024-12-15 18:27:30 +08:00
|
|
|
|
from nonebot.permission import Permission
|
|
|
|
|
from nonebot.rule import Rule
|
|
|
|
|
from nonebot.typing import T_State
|
2024-12-15 02:51:37 +08:00
|
|
|
|
|
2024-12-17 19:32:51 +08:00
|
|
|
|
from ..models import Plugin
|
2024-12-15 17:08:02 +08:00
|
|
|
|
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
2024-12-17 13:25:30 +08:00
|
|
|
|
from .models import SessionContext, SessionContextDepends
|
2024-12-15 17:08:02 +08:00
|
|
|
|
from .utils import async_wrap, is_coroutine_callable
|
2024-12-15 02:59:06 +08:00
|
|
|
|
|
2024-12-15 17:08:02 +08:00
|
|
|
|
# Global registry of every decorated Caller, keyed by its AI-facing name.
_caller_data: dict[str, "Caller"] = dict()
|
2024-12-15 02:51:37 +08:00
|
|
|
|
|
|
|
|
|
|
2024-12-15 17:08:02 +08:00
|
|
|
|
class Caller:
    """An AI-callable "function call" wrapper around a Python function.

    Configure an instance fluently (``params``, ``permission``, ``rule``,
    ``name``, ``description``), then use it as a decorator. Decorating a
    function registers the instance in the module-level ``_caller_data``
    registry under :attr:`aifc_name`.
    """

    def __init__(self, name: str = "", description: str | None = None):
        # Function name; filled from ``func.__name__`` on decoration if empty.
        self._name: str = name
        # Function description exposed to the model.
        self._description = description
        # Owning plugin object (not assigned anywhere inside this class).
        self._plugin: Plugin | None = None
        # The wrapped callable; always awaitable (sync functions get wrapped).
        self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
        # Last component of the defining module's name only; not necessarily
        # the plugin's top-level module name.
        self.module_name: str = ""
        # Declared model-facing parameters; each value must provide
        # ``.data()`` and ``.default`` (both are used by ``data``).
        self._parameters: dict[str, Any] = {}
        # Names of the function's parameters that receive injected objects.
        self.di: SessionContextDepends = SessionContextDepends()
        # Keyword values used by ``call`` when not supplied: signature
        # defaults plus injected context objects (see ``set_ctx``).
        self.default: dict[str, Any] = {}
        # Current dependency-injection context, set by ``set_ctx``.
        self.ctx: SessionContext | None = None
        self._permission: Permission | None = None
        self._rule: Rule | None = None

    def params(self, **kwargs: Any) -> "Caller":
        """Declare (or update) the parameters exposed to the model.

        Returns:
            Caller: ``self``, for chaining.
        """
        self._parameters.update(kwargs)
        return self

    def permission(self, permission: Permission) -> "Caller":
        """Add a permission requirement.

        Repeated calls combine the permissions with ``|`` (NoneBot semantics:
        the combined permission passes if any of its checkers passes).

        Returns:
            Caller: ``self``, for chaining.
        """
        # The previous ``self._permission or permission`` kept only the first
        # permission and silently dropped later ones.
        self._permission = (
            permission if self._permission is None else self._permission | permission
        )
        return self

    async def pre_check(self) -> tuple[bool, str]:
        """Check the context, permission and rule before calling the function.

        Returns:
            tuple[bool, str]: ``(True, "")`` when the call may proceed,
            otherwise ``(False, reason)``; ``call`` returns that reason
            instead of invoking the function.
        """
        if self.ctx is None:
            return False, "上下文为空"
        if self.ctx.bot is None or self.ctx.event is None:
            return False, "Context is None"
        if self._permission and not await self._permission(
            self.ctx.bot, self.ctx.event
        ):
            return False, "告诉用户 Permission Denied 权限不足"
        if self.ctx.state is None:
            return False, "State is None"
        if self._rule and not await self._rule(
            self.ctx.bot, self.ctx.event, self.ctx.state
        ):
            return False, "告诉用户 Rule Denied 规则不匹配"
        return True, ""

    def rule(self, rule: Rule) -> "Caller":
        """Add a rule requirement.

        Repeated calls combine the rules with ``&`` (all must pass).

        Returns:
            Caller: ``self``, for chaining.
        """
        # The previous ``self._rule and rule`` evaluated to ``None`` whenever
        # no rule was set yet, so rules were never actually applied.
        self._rule = rule if self._rule is None else self._rule & rule
        return self

    def name(self, name: str) -> "Caller":
        """Set the function name.

        Args:
            name (str): Function name.

        Returns:
            Caller: ``self``, for chaining.
        """
        self._name = name
        return self

    def description(self, description: str) -> "Caller":
        """Set the function description.

        Returns:
            Caller: ``self``, for chaining.
        """
        self._description = description
        return self

    @staticmethod
    def _annotation_is(annotation: Any, cls: type) -> bool:
        """Return True if *annotation* is *cls*, a subclass of it, or an instance of it."""
        try:
            if issubclass(annotation, cls):
                return True
        except TypeError:
            # Non-class annotations (``dict[Any, Any]`` i.e. ``T_State``,
            # ``X | None``, string annotations) make ``issubclass`` raise.
            pass
        return isinstance(annotation, cls)

    def __call__(self, func: F) -> F:
        """Decorate *func*, registering it as an AI-callable function call.

        Fills in the name (``func.__name__``) and the description (``func``'s
        docstring) when they were not given. Records which parameters take
        injected objects and which have defaults, and wraps sync functions so
        ``self.func`` is awaitable. Finally, registers the caller under
        :attr:`aifc_name`, overwriting (with a warning) any caller already
        registered under that name.

        Args:
            func (F): The function to register.

        Returns:
            F: ``func`` itself, unchanged.
        """
        if not self._name:
            self._name = func.__name__
        if self._description is None:
            # Fallback documented by ``on_function_call``.
            self._description = inspect.getdoc(func)

        # Inspect the signature: dependency-injected parameters and defaults.
        for arg_name, param in inspect.signature(func).parameters.items():
            annotation = param.annotation
            if self._annotation_is(annotation, Event):
                self.di.event = arg_name
            if self._annotation_is(annotation, Caller):
                self.di.caller = arg_name
            if self._annotation_is(annotation, Bot):
                self.di.bot = arg_name
            if self._annotation_is(annotation, Matcher):
                self.di.matcher = arg_name
            if annotation == T_State:
                self.di.state = arg_name
            if param.default is not inspect.Parameter.empty:
                self.default[arg_name] = param.default

        if is_coroutine_callable(func):
            self.func = func  # type: ignore
        else:
            self.func = async_wrap(func)  # type: ignore

        module = inspect.getmodule(func)
        self.module_name = module.__name__.split(".")[-1] if module else ""

        key = self.aifc_name
        previous = _caller_data.get(key)
        if previous is not None and previous is not self:
            logger.warning(f"Function call {key} is already registered; overwriting it")
        _caller_data[key] = self
        # Pass values as arguments so tag-like text in a description is not
        # parsed as color markup.
        logger.opt(colors=True).debug(
            "<y>加载函数 {}: {}</y>", self.full_name, self._description
        )
        return func

    def data(self) -> dict[str, Any]:
        """Build the tool schema describing this function to the model.

        Returns:
            dict[str, Any]: A ``{"type": "function", "function": {...}}``
            entry whose ``parameters`` is a JSON-Schema object.
        """
        properties = {key: value.data() for key, value in self._parameters.items()}
        if not properties:
            properties["placeholder"] = {
                "type": "string",
                "description": "占位符,用于显示在对话框中",  # dummy parameter added for compatibility
            }
        # Parameters whose declared default is None are required.
        required = [
            key for key, value in self._parameters.items() if value.default is None
        ]
        return {
            "type": "function",
            "function": {
                "name": self.aifc_name,
                "description": self._description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    # JSON Schema places ``required`` inside the object schema.
                    "required": required,
                },
                # Legacy location, kept so existing consumers keep working.
                "required": list(required),
            },
        }

    def set_ctx(self, ctx: SessionContext) -> None:
        """Attach a dependency-injection context.

        Sets ``ctx.caller`` to this caller and stores ``ctx``. Then copies each
        injected object into ``self.default``, under the parameter name
        recorded in ``self.di``.

        NOTE(review): this mutates the caller itself, which is a single
        instance shared through the global registry; concurrent sessions may
        overwrite each other's context between ``set_ctx`` and ``call``.

        Args:
            ctx (SessionContext): Dependency-injection context.
        """
        ctx.caller = self
        self.ctx = ctx
        for type_name, arg_name in self.di.model_dump().items():
            if arg_name:
                self.default[arg_name] = getattr(ctx, type_name)

    def with_ctx(self, ctx: SessionContext) -> "Caller":
        """Attach a dependency-injection context and return ``self``.

        Args:
            ctx (SessionContext): Dependency-injection context.

        Returns:
            Caller: ``self``, for chaining.
        """
        self.set_ctx(ctx)
        return self

    def __str__(self) -> str:
        return f"{self._name}({self._description})\n" + "\n".join(
            f" - {key}: {value}" for key, value in self._parameters.items()
        )

    async def call(self, *args: Any, **kwargs: Any) -> Any:
        """Run the pre-checks, fill in missing keywords, then await the function.

        Returns:
            Any: The function's return value, or the pre-check failure reason
            string when a check fails.

        Raises:
            ValueError: If no function has been decorated yet.
        """
        ok, reason = await self.pre_check()
        if not ok:
            logger.debug(f"Function {self._name} pre_check failed: {reason}")
            return reason

        if self.func is None:
            raise ValueError("未注册函数对象")

        # Any keyword the caller did not supply is taken from ``self.default``
        # (signature defaults and injected context objects).
        for name, value in self.default.items():
            kwargs.setdefault(name, value)

        return await self.func(*args, **kwargs)

    @property
    def short_name(self) -> str:
        """The function's own name (last dotted component of the name)."""
        return self._name.split(".")[-1]

    @property
    def aifc_name(self) -> str:
        """AI-facing name: ``full_name`` with every dot replaced by a hyphen."""
        return self.full_name.replace(".", "-")

    @property
    def full_name(self) -> str:
        """Qualified name ``<module_name>.<name>``, or just ``<name>`` when the
        module name is unknown (avoids a leading dot)."""
        if not self.module_name:
            return self._name
        return f"{self.module_name}.{self._name}"

    @property
    def short_info(self) -> str:
        """One-line summary: ``full_name(description)``."""
        return f"{self.full_name}({self._description})"
|
|
|
|
|
|
2024-12-15 02:51:37 +08:00
|
|
|
|
|
2024-12-17 19:32:51 +08:00
|
|
|
|
def on_function_call(name: str = "", description: str | None = None) -> Caller:
    """Create a Caller for use as a decorator.

    Decorating a function with the returned Caller registers that function as
    an AI-callable function call.

    Args:
        name: Name to register under; an empty string means the decorated
            function's ``__name__`` is used.
        description: Description exposed to the model; ``None`` means none was
            given explicitly (the decorator may derive one from the function).

    Returns:
        Caller: A new Caller instance carrying ``name`` and ``description``.
    """
    return Caller(description=description, name=name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_function_calls() -> dict[str, Caller]:
    """Return every registered function call.

    The result is the live module-level registry itself (not a copy), keyed by
    each caller's AI-facing name.

    Returns:
        dict[str, Caller]: Mapping of registered name to Caller.
    """
    registry = _caller_data
    return registry
|