diff --git a/nonebot_plugin_marshoai/__init__.py b/nonebot_plugin_marshoai/__init__.py index 7b1cc1a3..a4d082a8 100755 --- a/nonebot_plugin_marshoai/__init__.py +++ b/nonebot_plugin_marshoai/__init__.py @@ -6,9 +6,11 @@ require("nonebot_plugin_localstore") import nonebot_plugin_localstore as store # type: ignore from nonebot import get_driver, logger # type: ignore -# from .hunyuan import * from .azure import * from .config import config + +# from .hunyuan import * +from .dev import * from .metadata import metadata __author__ = "Asankilp" diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index 76420444..c1257fc4 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -350,11 +350,11 @@ async def marsho( tool_call.function.arguments.replace("'", '"') ) logger.info( - f"调用函数 {tool_call.function.name.replace("-", ".")}\n参数:" + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + "\n".join([f"{k}={v}" for k, v in function_args.items()]) ) await UniMessage( - f"调用函数 {tool_call.function.name.replace("-", ".")}\n参数:" + f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:" + "\n".join([f"{k}={v}" for k, v in function_args.items()]) ).send() # TODO 临时追加插件函数,若工具中没有则调用插件函数 @@ -365,9 +365,9 @@ async def marsho( ) # 获取返回值 else: if caller := get_function_calls().get( - tool_call.function.name + tool_call.function.name.replace("-", ".") ): - logger.debug(f"调用插件函数 {tool_call.function.name}") + logger.debug(f"调用插件函数 {caller.full_name}") # 权限检查,规则检查 TODO # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 func_return = await caller.with_ctx( @@ -379,8 +379,10 @@ async def marsho( ) ).call(**function_args) else: - logger.error(f"未找到函数 {tool_call.function.name}") - func_return = f"未找到函数 {tool_call.function.name}" + logger.error( + f"未找到函数 {tool_call.function.name.replace('-', '.')}" + ) + func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}" tool_msg.append( ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore ) diff --git a/nonebot_plugin_marshoai/config.py b/nonebot_plugin_marshoai/config.py index 86d1b71e..361245fd 100644 --- a/nonebot_plugin_marshoai/config.py +++ b/nonebot_plugin_marshoai/config.py @@ -52,7 +52,7 @@ class ConfigModel(BaseModel): marshoai_plugin_dirs: list[str] = [] """插件目录(不是工具)""" marshoai_devmode: bool = False - """开发者模式""" + """开发者模式,启用本地插件插件重载""" marshoai_plugins: list[str] = [] """marsho插件的名称列表,从pip安装的使用包名,从本地导入的使用路径""" diff --git a/nonebot_plugin_marshoai/dev.py b/nonebot_plugin_marshoai/dev.py index 7af58bba..8fc5468c 100644 --- a/nonebot_plugin_marshoai/dev.py +++ b/nonebot_plugin_marshoai/dev.py @@ -1,7 +1,77 @@ from nonebot import require +from nonebot.adapters import Bot, Event +from nonebot.matcher import Matcher +from nonebot.typing import T_State + +from nonebot_plugin_marshoai.plugin.func_call.models import SessionContext require("nonebot_plugin_alconna") -from nonebot_plugin_alconna import Alconna, on_alconna +from nonebot.permission import SUPERUSER +from nonebot_plugin_alconna import ( + Alconna, + Args, + MultiVar, + Subcommand, + UniMessage, + on_alconna, +) -function_call = on_alconna("marshocall") +from .plugin.func_call.caller import get_function_calls + +function_call = on_alconna( + command=Alconna( + "marsho-function-call", + Subcommand( + "call", + Args["function_name", str]["kwargs", MultiVar(str), []], + alias={"c"}, + ), + Subcommand( + "list", + alias={"l"}, + ), + Subcommand("info", Args["function_name", str], alias={"i"}), + ), + aliases={"mfc"}, + permission=SUPERUSER, +) + + +@function_call.assign("list") +async def list_functions(): + reply = "共有如下可调用函数:\n" + for function in get_function_calls().values(): + reply += f"- {function.name}({function.description}))\n" + await UniMessage(reply).send() + + +@function_call.assign("info") +async def function_info(function_name: str): + function = get_function_calls().get(function_name) + if function is None: + await UniMessage(f"未找到函数 {function_name}").send() + return + await UniMessage(str(function)).send() + + +@function_call.assign("call") +async def call_function( + function_name: str, + kwargs: list[str], + event: Event, + bot: Bot, + matcher: Matcher, + state: T_State, +): + function = get_function_calls().get(function_name) + if function is None: + await UniMessage(f"未找到函数 {function_name}").send() + return + await UniMessage( + str( + await function.with_ctx( + SessionContext(event=event, bot=bot, matcher=matcher, state=state) + ).call(**{i.split("=", 1)[0]: i.split("=", 1)[1] for i in kwargs}) + ) + ).send() diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py index 727793f0..dde80b4c 100644 --- a/nonebot_plugin_marshoai/plugin/func_call/caller.py +++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py @@ -8,6 +8,7 @@ from nonebot.permission import Permission from nonebot.rule import Rule from nonebot.typing import T_State +from ..models import Plugin from ..typing import ASYNC_FUNCTION_CALL_FUNC, F from .models import SessionContext, SessionContextDepends from .utils import async_wrap, is_coroutine_callable @@ -16,10 +17,17 @@ _caller_data: dict[str, "Caller"] = {} class Caller: - def __init__(self, name: str | None = None, description: str | None = None): - self._name = name + def __init__(self, name: str = "", description: str | None = None): + self._name: str = name + """函数名称""" self._description = description + """函数描述""" + self._plugin: Plugin | None = None + """所属插件对象,装饰时声明""" self.func: ASYNC_FUNCTION_CALL_FUNC | None = None + """函数对象""" + self.module_name: str = "" + """模块名""" self._parameters: dict[str, Any] = {} """声明参数""" @@ -91,13 +99,8 @@ class Caller: F: 函数对象 """ global _caller_data - if self._name is None: - if module := inspect.getmodule(func): - module_name = module.__name__.split(".")[-1] - else: - module_name = "" - self._name = f"{module_name}-{func.__name__}" - _caller_data[self._name] = self + if not self._name: + self._name = func.__name__ # 检查函数签名,确定依赖注入参数 sig = inspect.signature(func) @@ -137,6 +140,9 @@ class Caller: module_name = module.__name__.split(".")[-1] + "." else: module_name = "" + + self.module_name = module_name + _caller_data[self.full_name] = self logger.opt(colors=True).debug( f"加载函数 {module_name}{func.__name__}: {self._description}" ) @@ -152,7 +158,7 @@ class Caller: return { "type": "function", "function": { - "name": self._name, + "name": self.aifc_name, "description": self._description, "parameters": { "type": "object", @@ -193,6 +199,11 @@ class Caller: self.set_ctx(ctx) return self + def __str__(self) -> str: + return f"{self._name}({self._description})\n" + "\n".join( + f" - {key}: {value}" for key, value in self._parameters.items() + ) + async def call(self, *args: Any, **kwargs: Any) -> Any: """调用函数 @@ -214,8 +225,23 @@ class Caller: return await self.func(*args, **kwargs) + @property + def short_name(self) -> str: + """函数本名""" + return self._name.split(".")[-1] -def on_function_call(name: str | None = None, description: str | None = None) -> Caller: + @property + def aifc_name(self) -> str: + """AI调用名,没有点""" + return self._name.replace(".", "-") + + @property + def full_name(self) -> str: + """完整名""" + return self.module_name + self._name + + +def on_function_call(name: str = "", description: str | None = None) -> Caller: """返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数 Args: diff --git a/nonebot_plugin_marshoai/plugin/load.py b/nonebot_plugin_marshoai/plugin/load.py index 895c94e3..7bc10c15 100755 --- a/nonebot_plugin_marshoai/plugin/load.py +++ b/nonebot_plugin_marshoai/plugin/load.py @@ -66,6 +66,7 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]: name=module.__name__, module=module, module_name=module_path, + module_path=module.__file__, ) _plugins[plugin.name] = plugin diff --git a/nonebot_plugin_marshoai/plugin/models.py b/nonebot_plugin_marshoai/plugin/models.py index 8a2541f9..bc042386 100755 --- a/nonebot_plugin_marshoai/plugin/models.py +++ b/nonebot_plugin_marshoai/plugin/models.py @@ -58,6 +58,8 @@ class Plugin(BaseModel): """插件模块对象""" module_name: str """点分割模块路径 例如a.b.c""" + module_path: str | None + """实际路径,单文件为.py的路径,包为__init__.py路径""" metadata: PluginMetadata | None = None """元""" @@ -69,3 +71,6 @@ class Plugin(BaseModel): def __eq__(self, other: Any) -> bool: return self.name == other.name + + def __str__(self) -> str: + return f"Plugin({self.name}({self.module_path}))"