diff --git a/nonebot_plugin_marshoai/plugin/models.py b/nonebot_plugin_marshoai/plugin/models.py index c0e0ca81..ab96e426 100644 --- a/nonebot_plugin_marshoai/plugin/models.py +++ b/nonebot_plugin_marshoai/plugin/models.py @@ -3,6 +3,8 @@ from typing import Any from pydantic import BaseModel +from .typing import ASYNC_FUNCTION_CALL_FUNC + class PluginMetadata(BaseModel): """ @@ -67,3 +69,75 @@ class Plugin(BaseModel): def __eq__(self, other: Any) -> bool: return self.name == other.name + + +class FunctionCallArgument(BaseModel): + """ + 插件函数参数对象 + + Attributes: + ---------- + name: str + 参数名称 + type: str + 参数类型 string integer等 + description: str + 参数描述 + """ + + type_: str + """参数类型描述 string integer等""" + description: str + """参数描述""" + default: Any = None + """默认值""" + + def data(self) -> dict[str, Any]: + return {"type": self.type_, "description": self.description} + + +class FunctionCall(BaseModel): + """ + 插件函数对象 + + Attributes: + ---------- + name: str + 函数名称 + func: "FUNCTION_CALL" + 函数对象 + """ + + name: str + """函数名称 module.func""" + description: str + """函数描述 这个函数用于获取天气信息""" + arguments: dict[str, FunctionCallArgument] + """函数参数信息""" + function: ASYNC_FUNCTION_CALL_FUNC + """函数对象""" + + class Config: + arbitrary_types_allowed = True + + def __hash__(self) -> int: + return hash(self.name) + + def data(self) -> dict[str, Any]: + """生成函数描述信息 + + Returns: + dict[str, Any]: 函数描述信息 字典 + """ + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": {k: v.data() for k, v in self.arguments.items()}, + }, + "required": [k for k, v in self.arguments.items() if v.default is None], + }, + } diff --git a/nonebot_plugin_marshoai/plugin/register.py b/nonebot_plugin_marshoai/plugin/register.py index 609bf3ce..1f281332 100644 --- a/nonebot_plugin_marshoai/plugin/register.py +++ b/nonebot_plugin_marshoai/plugin/register.py @@ -4,18 +4,20 @@ import inspect from typing import Any, Callable, Coroutine, TypeAlias -import nonebot +from nonebot import logger +from .models import FunctionCall, FunctionCallArgument +from .typing import ( + ASYNC_FUNCTION_CALL_FUNC, + FUNCTION_CALL_FUNC, + SYNC_FUNCTION_CALL_FUNC, +) from .utils import is_coroutine_callable -SYNC_FUNCTION_CALL: TypeAlias = Callable[..., str] -ASYNC_FUNCTION_CALL: TypeAlias = Callable[..., Coroutine[str, Any, str]] -FUNCTION_CALL: TypeAlias = SYNC_FUNCTION_CALL | ASYNC_FUNCTION_CALL - -_loaded_functions: dict[str, FUNCTION_CALL] = {} +_loaded_functions: dict[str, FUNCTION_CALL_FUNC] = {} -def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL: +def async_wrapper(func: SYNC_FUNCTION_CALL_FUNC) -> ASYNC_FUNCTION_CALL_FUNC: """将同步函数包装为异步函数,但是不会真正异步执行,仅用于统一调用及函数签名 Args: @@ -31,7 +33,7 @@ def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL: return wrapper -def function_call(*funcs: FUNCTION_CALL): +def function_call(*funcs: FUNCTION_CALL_FUNC) -> None: """返回一个装饰器,装饰一个函数, 使其注册为一个可被AI调用的function call函数 Args: @@ -41,15 +43,20 @@ def function_call(*funcs: FUNCTION_CALL): str: 函数定义信息 """ for func in funcs: - if module := inspect.getmodule(func): - module_name = module.__name__ + "." - else: - module_name = "" - name = func.__name__ - if not is_coroutine_callable(func): - func = async_wrapper(func) # type: ignore + function_call = get_function_info(func) + # TODO: 注册函数 - _loaded_functions[name] = func - nonebot.logger.opt(colors=True).info( - f"加载 function call: {module_name}{name}" - ) + +def get_function_info(func: FUNCTION_CALL_FUNC): + """获取函数信息 + + Args: + func: 函数对象 + + Returns: + FunctionCall: 函数信息对象模型 + """ + name = func.__name__ + description = func.__doc__ + logger.info(f"注册函数: {name} {description}") + # TODO: 获取函数参数信息 diff --git a/nonebot_plugin_marshoai/plugin/typing.py b/nonebot_plugin_marshoai/plugin/typing.py new file mode 100644 index 00000000..1618dc26 --- /dev/null +++ b/nonebot_plugin_marshoai/plugin/typing.py @@ -0,0 +1,5 @@ +from typing import Any, Callable, Coroutine, TypeAlias + +SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str] +ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]] +FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC