diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index 6f972352..65952e9b 100755 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -23,6 +23,8 @@ from nonebot.permission import SUPERUSER from nonebot.rule import Rule, to_me from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna +from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls + from .metadata import metadata from .models import MarshoContext, MarshoTools from .plugin import _plugins, load_plugins @@ -103,10 +105,9 @@ async def _preload_tools(): @driver.on_startup async def _preload_plugins(): """启动钩子加载插件""" - marshoai_plugin_dirs = config.marshoai_plugin_dirs - marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") + marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表 + marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") # 预置插件目录 load_plugins(*marshoai_plugin_dirs) - logger.opt(colors=True).info(f"已加载 {len(_plugins)} 个小棉插件") @add_usermsg_cmd.handle() @@ -266,7 +267,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) client=client, model_name=model_name, msg=context_msg + [UserMessage(content=usermsg)], # type: ignore - tools=tools.get_tools_list(), + tools=tools.get_tools_list() + + list( + map(lambda v: v.data(), get_function_calls().values()) + ), # TODO 临时追加函数,后期优化 ) # await UniMessage(str(response)).send() choice = response.choices[0] @@ -315,9 +319,23 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None) await UniMessage( f"调用函数 {tool_call.function.name} ,参数为 {function_args}" ).send() - func_return = await tools.call( - tool_call.function.name, function_args - ) # 获取返回值 + # TODO 临时追加插件函数,若工具中没有则调用插件函数 + if tools.has_function(tool_call.function.name): + logger.debug(f"调用工具函数 {tool_call.function.name}") + func_return = await tools.call( + tool_call.function.name, function_args + ) # 获取返回值 + else: + if caller := get_function_calls().get( + tool_call.function.name + ): + logger.debug(f"调用插件函数 {tool_call.function.name}") + # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 + caller.event = event + func_return = await caller.call(**function_args) + else: + logger.error(f"未找到函数 {tool_call.function.name}") + func_return = f"未找到函数 {tool_call.function.name}" tool_msg.append( ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore ) diff --git a/nonebot_plugin_marshoai/models.py b/nonebot_plugin_marshoai/models.py index 160b743d..04ae473e 100755 --- a/nonebot_plugin_marshoai/models.py +++ b/nonebot_plugin_marshoai/models.py @@ -90,6 +90,7 @@ class MarshoTools: with open(json_path, "r", encoding="utf-8") as json_file: data = json.load(json_file) for i in data: + self.tools_list.append(i) spec = importlib.util.spec_from_file_location( @@ -136,6 +137,21 @@ class MarshoTools: else: logger.error(f"工具包 '{package_name}' 未导入") + def has_function(self, full_function_name: str) -> bool: + """ + 检查是否存在指定的函数 + """ + try: + for t in self.tools_list: + if t["function"]["name"].replace( + "-", "_" + ) == full_function_name.replace("-", "_"): + return True + return False + except Exception as e: + logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}") + return False + def get_tools_list(self): if not self.tools_list or not config.marshoai_enable_tools: return None diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py index cc8fd767..19181662 100644 --- a/nonebot_plugin_marshoai/plugin/func_call/caller.py +++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py @@ -1,36 +1,34 @@ -from typing import Generic, TypeVar +import inspect +from typing import Any from nonebot import logger +from nonebot.adapters import Event -from ..typing import FUNCTION_CALL_FUNC -from .params import P +from ..typing import ASYNC_FUNCTION_CALL_FUNC, F +from .utils import async_wrap, is_coroutine_callable -F = TypeVar("F", bound=FUNCTION_CALL_FUNC) +_caller_data: dict[str, "Caller"] = {} -class Caller(Generic[P]): +class Caller: def __init__(self, name: str | None = None, description: str | None = None): self._name = name self._description = description - self._parameters: dict[str, P] = {} - self.func: FUNCTION_CALL_FUNC | None = None + self.func: ASYNC_FUNCTION_CALL_FUNC | None = None + self._parameters: dict[str, Any] = {} + """依赖注入的参数""" + self.event: Event | None = None - def params(self, **kwargs: P) -> "Caller": - """设置多个函数参数 - Args: - **kwargs: 参数字典 - Returns: - Caller: Caller对象 - """ + def params(self, **kwargs: Any) -> "Caller": self._parameters.update(kwargs) return self - def param(self, name: str, param: P) -> "Caller": + def param(self, name: str, param: Any) -> "Caller": """设置一个函数参数 Args: name (str): 参数名 - param (P): 参数对象 + param (Any): 参数对象 Returns: Caller: Caller对象 @@ -51,14 +49,6 @@ class Caller(Generic[P]): return self def description(self, description: str) -> "Caller": - """设置函数描述 - - Args: - description (str): 函数描述 - - Returns: - Caller: Caller对象 - """ self._description = description return self @@ -71,12 +61,78 @@ class Caller(Generic[P]): Returns: F: 函数对象 """ + global _caller_data + if self._name is None: + if module := inspect.getmodule(func): + module_name = module.__name__.split(".")[-1] + else: + module_name = "global" + self._name = f"{module_name}-{func.__name__}" + _caller_data[self._name] = self + + if is_coroutine_callable(func): + self.func = func # type: ignore + else: + self.func = async_wrap(func) # type: ignore + + if module := inspect.getmodule(func): + module_name = module.__name__ + "." + else: + module_name = "" logger.opt(colors=True).info( - f"加载函数 {func.__name__} {self._description}" + f"加载函数 {module_name}{func.__name__}: {self._description}" ) - self.func = func + return func + def data(self) -> dict[str, Any]: + """返回函数的json数据 + + Returns: + dict[str, Any]: 函数的json数据 + """ + return { + "type": "function", + "function": { + "name": self._name, + "description": self._description, + "parameters": { + "type": "object", + "properties": { + key: value.data() for key, value in self._parameters.items() + }, + }, + "required": [ + key + for key, value in self._parameters.items() + if value.default is None + ], + }, + } + + def set_event(self, event: Event): + self.event = event + + async def call(self, *args: Any, **kwargs: Any) -> Any: + """调用函数 + + Returns: + Any: 函数返回值 + """ + if self.func is None: + raise ValueError("未注册函数对象") + sig = inspect.signature(self.func) + for name, param in sig.parameters.items(): + if issubclass(param.annotation, Event) or isinstance( + param.annotation, Event + ): + kwargs[name] = self.event + if issubclass(param.annotation, Caller) or isinstance( + param.annotation, Caller + ): + kwargs[name] = self + return await self.func(*args, **kwargs) + def on_function_call(name: str | None = None, description: str | None = None) -> Caller: """返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数 @@ -87,5 +143,14 @@ def on_function_call(name: str | None = None, description: str | None = None) -> Returns: Caller: Caller对象 """ + caller = Caller(name=name, description=description) + return caller - return Caller(name=name, description=description) + +def get_function_calls() -> dict[str, Caller]: + """获取所有已注册的function call函数 + + Returns: + dict[str, Caller]: 所有已注册的function call函数 + """ + return _caller_data diff --git a/nonebot_plugin_marshoai/plugin/func_call/utils.py b/nonebot_plugin_marshoai/plugin/func_call/utils.py new file mode 100644 index 00000000..1ce908a8 --- /dev/null +++ b/nonebot_plugin_marshoai/plugin/func_call/utils.py @@ -0,0 +1,52 @@ +import inspect +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable + +from ..typing import F + + +def copy_signature(func: F) -> Callable[[Callable[..., Any]], F]: + """复制函数签名和文档字符串的装饰器""" + + def decorator(wrapper: Callable[..., Any]) -> F: + @wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + return wrapper(*args, **kwargs) + + return wrapped # type: ignore + + return decorator + + +def async_wrap(func: F) -> F: + """装饰器,将同步函数包装为异步函数 + + Args: + func (F): 函数对象 + + Returns: + F: 包装后的函数对象 + """ + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) + + return wrapper # type: ignore + + +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + """ + 判断是否为async def 函数 + 请注意:是否为 async def 函数与该函数是否能被await调用是两个不同的概念,具体取决于函数返回值是否为awaitable对象 + Args: + call: 可调用对象 + Returns: + bool: 是否为async def函数 + """ + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + func_ = getattr(call, "__call__", None) + return inspect.iscoroutinefunction(func_) diff --git a/nonebot_plugin_marshoai/plugin/typing.py b/nonebot_plugin_marshoai/plugin/typing.py index 1618dc26..54554e2a 100755 --- a/nonebot_plugin_marshoai/plugin/typing.py +++ b/nonebot_plugin_marshoai/plugin/typing.py @@ -1,5 +1,7 @@ -from typing import Any, Callable, Coroutine, TypeAlias +from typing import Any, Callable, Coroutine, TypeAlias, TypeVar SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str] ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]] FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC + +F = TypeVar("F", bound=FUNCTION_CALL_FUNC) diff --git a/nonebot_plugin_marshoai/plugin/utils.py b/nonebot_plugin_marshoai/plugin/utils.py index 55dd7e1c..dc02c670 100755 --- a/nonebot_plugin_marshoai/plugin/utils.py +++ b/nonebot_plugin_marshoai/plugin/utils.py @@ -18,21 +18,5 @@ def path_to_module_name(path: Path) -> str: return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) -def is_coroutine_callable(call: Callable[..., Any]) -> bool: - """ - 判断是否为async def 函数 - Args: - call: 可调用对象 - Returns: - bool: 是否为协程可调用对象 - """ - if inspect.isroutine(call): - return inspect.iscoroutinefunction(call) - if inspect.isclass(call): - return False - func_ = getattr(call, "__call__", None) - return inspect.iscoroutinefunction(func_) - - def parse_function_docsring(): pass diff --git a/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py b/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py index 47837a0e..d851cb1d 100644 --- a/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py +++ b/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py @@ -1,3 +1,5 @@ +from nonebot.adapters.onebot.v11 import MessageEvent + from nonebot_plugin_marshoai.plugin import ( Integer, Parameter, @@ -5,6 +7,7 @@ from nonebot_plugin_marshoai.plugin import ( String, on_function_call, ) +from nonebot_plugin_marshoai.plugin.func_call.caller import Caller __marsho_meta__ = PluginMetadata( name="SnowyKami 测试插件", @@ -19,16 +22,7 @@ __marsho_meta__ = PluginMetadata( gender=String(enum=["男", "女"], description="性别"), ) async def fortune_telling(age: int, name: str, gender: str) -> str: - """使用姓名,年龄,性别进行算命 - - Args: - age (int): _description_ - name (str): _description_ - gender (str): _description_ - - Returns: - str: _description_ - """ + """使用姓名,年龄,性别进行算命""" # 进行一系列算命操作... @@ -41,17 +35,22 @@ async def fortune_telling(age: int, name: str, gender: str) -> str: unit=String(enum=["摄氏度", "华氏度"], description="温度单位"), ) async def get_weather(location: str, days: int, unit: str) -> str: - """获取一个地点未来一段时间的天气 - - Args: - location (str): 地点名称,可以是城市名、地区名等 - days (int): 天数 - unit (str): 温度单位 - - Returns: - str: 天气信息 - """ + """获取一个地点未来一段时间的天气""" # 进行一系列获取天气操作... return f"{location}未来{days}天的天气信息..." + + +@on_function_call(description="获取设备物理地理位置") +async def get_location() -> str: + """获取设备物理地理位置""" + + # 进行一系列获取地理位置操作... + + return "日本 东京都 世田谷区" + + +@on_function_call(description="获取聊天者个人信息") +async def get_user_info(e: MessageEvent, c: Caller) -> str: + return f"用户信息:{e.user_id} {e.sender.nickname}, {c._parameters}" diff --git a/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json b/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json index adab49a4..73e1c22b 100755 --- a/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json +++ b/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json @@ -2,7 +2,7 @@ { "type": "function", "function": { - "name": "marshoai-memory__write_memory", + "name": "marshoai_memory__write_memory", "description": "当你想记住有关与你对话的人的一些信息的时候,调用此函数。", "parameters": { "type": "object",