mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-02-07 05:56:11 +08:00
297 lines
8.8 KiB
Python
297 lines
8.8 KiB
Python
import inspect
|
||
from typing import Any
|
||
|
||
from nonebot import logger
|
||
from nonebot.adapters import Bot, Event
|
||
from nonebot.matcher import Matcher
|
||
from nonebot.permission import Permission
|
||
from nonebot.rule import Rule
|
||
from nonebot.typing import T_State
|
||
|
||
from ..models import Plugin
|
||
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||
from .models import SessionContext, SessionContextDepends
|
||
from .utils import async_wrap, is_coroutine_callable
|
||
|
||
# Module-wide registry of decorated function-call handlers, keyed by each
# Caller's ``aifc_name`` (entries are added by ``Caller.__call__``).
_caller_data: dict[str, "Caller"] = dict()
|
||
|
||
|
||
class Caller:
    """Registration record for one AI-callable "function call" handler.

    Instances are normally created by ``on_function_call`` and then applied as
    a decorator. Decoration records the wrapped function, inspects its
    signature for dependency-injection parameters and default values, and
    stores the instance in the module registry under ``aifc_name``.
    """

    def __init__(
        self,
        name: str = "",
        description: str | None = None,
        func_type: str = "function",
        no_module_name: bool = False,
    ):
        # Function name; taken from ``func.__name__`` at decoration time if empty.
        self._name: str = name
        # Description sent to the model; falls back to the function docstring.
        self._description = description
        # Tool type reported by ``data()`` (default "function").
        self._func_type = func_type
        # When True, ``aifc_name`` is the bare name without the module prefix.
        self.no_module_name = no_module_name
        # Owning plugin object; not assigned anywhere in this class
        # (presumably set by the plugin loader — TODO confirm).
        self._plugin: Plugin | None = None
        # Awaitable implementation invoked by ``call``; set at decoration time.
        self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
        # Last dotted component of the defining module's name; this is the
        # immediate module, not necessarily the plugin's top-level package.
        self.module_name: str = ""
        # Declared AI-facing parameters, keyed by parameter name.
        self._parameters: dict[str, Any] = {}
        # Names of the function parameters that receive injected context objects.
        self.di: SessionContextDepends = SessionContextDepends()
        # Keyword arguments ``call`` fills in when the caller omits them:
        # signature defaults plus injected context values from ``set_ctx``.
        self.default: dict[str, Any] = {}
        # Session context from the most recent ``set_ctx`` call.
        self.ctx: SessionContext | None = None
        self._permission: Permission | None = None
        self._rule: Rule | None = None

    def params(self, **kwargs: Any) -> "Caller":
        """Declare or update the AI-facing parameters of this function.

        Each value must provide ``data()`` (its schema fragment) and a
        ``default`` attribute; a ``default`` of ``None`` marks the parameter
        as required in ``data()``.

        Returns:
            Caller: ``self``, for chaining.
        """
        self._parameters.update(kwargs)
        return self

    def permission(self, permission: Permission) -> "Caller":
        """Restrict who may trigger this function.

        Only the first permission set takes effect; later calls are ignored.

        Returns:
            Caller: ``self``, for chaining.
        """
        # NOTE(review): first-wins semantics kept as-is; confirm whether later
        # permissions should instead be OR-combined with ``|``.
        if self._permission is None:
            self._permission = permission
        return self

    async def pre_check(self) -> tuple[bool, str]:
        """Check the session context, permission and rule before a call.

        Returns:
            tuple[bool, str]: ``(True, "")`` when the call may proceed, otherwise
            ``(False, reason)``; ``call`` returns ``reason`` in place of a result.
        """
        if self.ctx is None:
            return False, "上下文为空"
        if self.ctx.bot is None or self.ctx.event is None:
            return False, "Context is None"
        if self._permission and not await self._permission(
            self.ctx.bot, self.ctx.event
        ):
            return False, "告诉用户 Permission Denied 权限不足"

        if self.ctx.state is None:
            return False, "State is None"
        if self._rule and not await self._rule(
            self.ctx.bot, self.ctx.event, self.ctx.state
        ):
            return False, "告诉用户 Rule Denied 规则不匹配"

        return True, ""

    def rule(self, rule: Rule) -> "Caller":
        """Add a rule that must pass before the function may run.

        Repeated calls AND-combine the rules.

        Returns:
            Caller: ``self``, for chaining.
        """
        # ``Rule & Rule`` merges checkers; with no rule yet, start from ``rule``.
        self._rule = rule if self._rule is None else self._rule & rule
        return self

    def name(self, name: str) -> "Caller":
        """Set the function name.

        Args:
            name (str): Function name.

        Returns:
            Caller: ``self``, for chaining.
        """
        self._name = name
        return self

    def description(self, description: str) -> "Caller":
        """Set the function description.

        Returns:
            Caller: ``self``, for chaining.
        """
        self._description = description
        return self

    @staticmethod
    def _annotation_matches(annotation: Any, base: type) -> bool:
        """Return True if ``annotation`` is ``base``, a subclass, or an instance of it.

        Non-class annotations (``str | None``, ``dict[Any, Any]``, string
        annotations, ...) make ``issubclass`` raise ``TypeError``; such
        annotations simply do not match.
        """
        if isinstance(annotation, base):
            return True
        try:
            return inspect.isclass(annotation) and issubclass(annotation, base)
        except TypeError:
            return False

    def _detect_injections(self, sig: inspect.Signature) -> None:
        """Record in ``self.di`` which parameters receive injected context objects."""
        for name, param in sig.parameters.items():
            annotation = param.annotation
            if self._annotation_matches(annotation, Event):
                self.di.event = name
            if self._annotation_matches(annotation, Caller):
                self.di.caller = name
            if self._annotation_matches(annotation, Bot):
                self.di.bot = name
            if self._annotation_matches(annotation, Matcher):
                self.di.matcher = name
            if annotation == T_State:
                self.di.state = name

    def __call__(self, func: F) -> F:
        """Register ``func`` as an AI-callable function and return it unchanged.

        Fills in the name (from ``func.__name__``) and the description (from
        the docstring) when they were not given. It also detects
        dependency-injection parameters, records signature defaults, and wraps
        sync functions so they can be awaited. Finally it stores this Caller
        in the registry under ``aifc_name``, replacing any previous entry with
        the same key.

        Args:
            func (F): Function to register.

        Returns:
            F: The same function object, so it remains directly callable.
        """
        if not self._name:
            self._name = func.__name__
        if self._description is None:
            self._description = inspect.getdoc(func)

        sig = inspect.signature(func)
        self._detect_injections(sig)
        for name, param in sig.parameters.items():
            if param.default is not inspect.Parameter.empty:
                self.default[name] = param.default

        if is_coroutine_callable(func):
            self.func = func  # type: ignore
        else:
            self.func = async_wrap(func)  # type: ignore

        module = inspect.getmodule(func)
        self.module_name = module.__name__.split(".")[-1] if module else ""

        _caller_data[self.aifc_name] = self
        # Pass values as format args so markup-like text (e.g. "<...>") in a
        # description is not parsed as a colour tag.
        logger.opt(colors=True).debug(
            "<y>加载函数 {}: {}</y>", self.full_name, self._description
        )
        return func

    def data(self) -> dict[str, Any]:
        """Build the tool/function schema sent to the model.

        Returns:
            dict[str, Any]: ``{"type", "function": {"name", "description",
            "parameters": {"type": "object", "properties", "required"}}}``.
            ``required`` lists the declared parameters whose ``default`` is
            ``None``, placed inside the ``parameters`` JSON-Schema object.
        """
        properties = {key: value.data() for key, value in self._parameters.items()}
        if not properties:
            # Unused dummy property kept for compatibility with providers
            # that expect a non-empty ``properties`` object.
            properties["placeholder"] = {
                "type": "string",
                "description": "占位符,用于显示在对话框中",
            }
        required = [
            key for key, value in self._parameters.items() if value.default is None
        ]
        return {
            "type": self._func_type,
            "function": {
                "name": self.aifc_name,
                "description": self._description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    def set_ctx(self, ctx: SessionContext) -> None:
        """Attach a session context and stage its values for injection.

        Sets ``ctx.caller`` to this Caller. For every ``self.di`` entry with a
        parameter name, it copies the matching attribute of ``ctx`` into
        ``self.default``.

        Args:
            ctx (SessionContext): Session context to inject.
        """
        ctx.caller = self
        self.ctx = ctx
        for type_name, arg_name in self.di.model_dump().items():
            if arg_name:
                self.default[arg_name] = getattr(ctx, type_name)

    def with_ctx(self, ctx: SessionContext) -> "Caller":
        """Chainable form of ``set_ctx``.

        Args:
            ctx (SessionContext): Session context to inject.

        Returns:
            Caller: ``self``.
        """
        self.set_ctx(ctx)
        return self

    def __str__(self) -> str:
        return f"{self._name}({self._description})\n" + "\n".join(
            f" - {key}: {value}" for key, value in self._parameters.items()
        )

    async def call(self, *args: Any, **kwargs: Any) -> Any:
        """Run the pre-checks, then invoke the registered function.

        Any name in ``self.default`` that is missing from ``kwargs`` is filled
        in from there before the call.

        Returns:
            Any: The function's return value, or the ``pre_check`` failure
            reason string if a check fails.

        Raises:
            ValueError: If no function has been registered via decoration.
        """
        ok, reason = await self.pre_check()
        if not ok:
            logger.debug(f"Function {self._name} pre_check failed: {reason}")
            return reason

        if self.func is None:
            raise ValueError("未注册函数对象")

        for name, value in self.default.items():
            kwargs.setdefault(name, value)

        return await self.func(*args, **kwargs)

    @property
    def short_name(self) -> str:
        """Bare function name: the last dotted component of ``_name``."""
        return self._name.split(".")[-1]

    @property
    def aifc_name(self) -> str:
        """Name exposed to the AI; ``full_name`` with dots replaced by dashes."""
        if self.no_module_name:
            return self._name
        return self.full_name.replace(".", "-")

    @property
    def full_name(self) -> str:
        """``module_name`` and ``_name`` joined with a dot."""
        return self.module_name + "." + self._name

    @property
    def short_info(self) -> str:
        return f"{self.full_name}({self._description})"
|
||
|
||
|
||
def on_function_call(
    name: str = "",
    description: str | None = None,
    func_type: str = "function",
    no_module_name: bool = False,
) -> Caller:
    """Create a Caller meant to decorate an AI-callable "function call" function.

    Args:
        name: Function name; leave empty to let the Caller take ``__name__``
            from the decorated function.
        description: Function description, forwarded to the Caller unchanged.
        func_type: Tool type, ``"function"`` by default; use
            ``"builtin_function"`` for a Moonshot AI built-in function.
        no_module_name: Omit the module prefix from the AI-facing name; set
            this to True for Moonshot AI built-in functions.

    Returns:
        Caller: A new Caller; it is registered only once it decorates a function.
    """
    return Caller(
        name=name,
        description=description,
        func_type=func_type,
        no_module_name=no_module_name,
    )
|
||
|
||
|
||
def get_function_calls() -> dict[str, Caller]:
    """Return the registry of every registered function-call handler.

    The mapping is keyed by each Caller's ``aifc_name``. It is the live
    module-level registry, not a copy: later registrations show up in it, and
    mutating it changes the registry.

    Returns:
        dict[str, Caller]: All registered function-call handlers.
    """
    registry = _caller_data
    return registry
|