diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index e628ebd3..95cf4635 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -18,17 +18,18 @@ from azure.ai.inference.models import ( from azure.core.credentials import AzureKeyCredential from nonebot import get_driver, logger, on_command, on_message from nonebot.adapters import Bot, Event, Message +from nonebot.matcher import Matcher from nonebot.params import CommandArg from nonebot.permission import SUPERUSER from nonebot.rule import Rule, to_me from nonebot.typing import T_State from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna -from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls - from .metadata import metadata from .models import MarshoContext, MarshoTools -from .plugin import _plugins, load_plugins +from .plugin import _plugins, load_plugin, load_plugins +from .plugin.func_call.caller import get_function_calls +from .plugin.func_call.models import SessionContext from .util import * @@ -115,10 +116,15 @@ async def _preload_plugins(): """启动钩子加载插件""" if config.marshoai_enable_plugins: marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表 + """加载内置插件""" marshoai_plugin_dirs.insert( 0, Path(__file__).parent / "plugins" ) # 预置插件目录 + """加载指定目录插件""" load_plugins(*marshoai_plugin_dirs) + """加载sys.path下的包""" + for package_name in config.marshoai_plugins: + load_plugin(package_name) logger.info( "如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。" ) @@ -227,8 +233,10 @@ async def marsho( event: Event, bot: Bot, state: T_State, + matcher: Matcher, text: Optional[UniMsg] = None, ): + global target_list if event.get_message().extract_plain_text() and ( not text @@ -324,7 +332,7 @@ async def marsho( ) return elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS: - + # function call # 需要获取额外信息,调用函数工具 tool_msg = [] while choice.message.tool_calls != None: @@ -360,12 +368,14 @@ async def marsho( logger.debug(f"调用插件函数 {tool_call.function.name}") # 权限检查,规则检查 TODO # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入 - caller.event, caller.bot, caller.state = ( - event, - bot, - state, - ) - func_return = await caller.call(**function_args) + func_return = await caller.with_ctx( + SessionContext( + bot=bot, + event=event, + state=state, + matcher=matcher, + ) + ).call(**function_args) else: logger.error(f"未找到函数 {tool_call.function.name}") func_return = f"未找到函数 {tool_call.function.name}" diff --git a/nonebot_plugin_marshoai/config.py b/nonebot_plugin_marshoai/config.py index 941569c5..86d1b71e 100644 --- a/nonebot_plugin_marshoai/config.py +++ b/nonebot_plugin_marshoai/config.py @@ -53,6 +53,8 @@ class ConfigModel(BaseModel): """插件目录(不是工具)""" marshoai_devmode: bool = False """开发者模式""" + marshoai_plugins: list[str] = [] + """marsho插件的名称列表,从pip安装的使用包名,从本地导入的使用路径""" yaml = YAML() diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py index c8c090b7..727793f0 100644 --- a/nonebot_plugin_marshoai/plugin/func_call/caller.py +++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py @@ -3,11 +3,13 @@ from typing import Any from nonebot import logger from nonebot.adapters import Bot, Event +from nonebot.matcher import Matcher from nonebot.permission import Permission from nonebot.rule import Rule from nonebot.typing import T_State from ..typing import ASYNC_FUNCTION_CALL_FUNC, F +from .models import SessionContext, SessionContextDepends from .utils import async_wrap, is_coroutine_callable _caller_data: dict[str, "Caller"] = {} @@ -19,10 +21,15 @@ class Caller: self._description = description self.func: ASYNC_FUNCTION_CALL_FUNC | None = None self._parameters: dict[str, Any] = {} - """依赖注入的参数""" - self.bot: Bot | None = None - self.event: Event | None = None - self.state: T_State | None = None + """声明参数""" + + self.di: SessionContextDepends = SessionContextDepends() + """依赖注入的参数信息""" + + self.default: dict[str, Any] = {} + """默认值""" + + self.ctx: SessionContext | None = None self._permission: Permission | None = None self._rule: Rule | None = None @@ -36,14 +43,20 @@ class Caller: return self async def pre_check(self) -> tuple[bool, str]: - if self.bot is None or self.event is None: + if self.ctx is None: + return False, "上下文为空" + if self.ctx.bot is None or self.ctx.event is None: return False, "Context is None" - if self._permission and not await self._permission(self.bot, self.event): + if self._permission and not await self._permission( + self.ctx.bot, self.ctx.event + ): return False, "告诉用户 Permission Denied 权限不足" - if self.state is None: + if self.ctx.state is None: return False, "State is None" - if self._rule and not await self._rule(self.bot, self.event, self.state): + if self._rule and not await self._rule( + self.ctx.bot, self.ctx.event, self.ctx.state + ): return False, "告诉用户 Rule Denied 规则不匹配" return True, "" @@ -86,6 +99,35 @@ class Caller: self._name = f"{module_name}-{func.__name__}" _caller_data[self._name] = self + # 检查函数签名,确定依赖注入参数 + sig = inspect.signature(func) + for name, param in sig.parameters.items(): + if issubclass(param.annotation, Event) or isinstance( + param.annotation, Event + ): + self.di.event = name + + if issubclass(param.annotation, Caller) or isinstance( + param.annotation, Caller + ): + self.di.caller = name + + if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot): + self.di.bot = name + + if issubclass(param.annotation, Matcher) or isinstance( + param.annotation, Matcher + ): + self.di.matcher = name + + if param.annotation == T_State: + self.di.state = name + + # 检查默认值情况 + for name, param in sig.parameters.items(): + if param.default is not inspect.Parameter.empty: + self.default[name] = param.default + if is_coroutine_callable(func): self.func = func # type: ignore else: @@ -126,11 +168,30 @@ class Caller: }, } - def set_event(self, event: Event): - self.event = event + def set_ctx(self, ctx: SessionContext) -> None: + """设置依赖注入上下文 - def set_bot(self, bot: Bot): - self.bot = bot + Args: + ctx (SessionContext): 依赖注入上下文 + """ + ctx.caller = self + self.ctx = ctx + for type_name, arg_name in self.di.model_dump().items(): + + if arg_name: + self.default[arg_name] = ctx.__getattribute__(type_name) + + def with_ctx(self, ctx: SessionContext) -> "Caller": + """设置依赖注入上下文 + + Args: + ctx (SessionContext): 依赖注入上下文 + + Returns: + Caller: Caller对象 + """ + self.set_ctx(ctx) + return self async def call(self, *args: Any, **kwargs: Any) -> Any: """调用函数 @@ -145,28 +206,11 @@ class Caller: if self.func is None: raise ValueError("未注册函数对象") - sig = inspect.signature(self.func) - for name, param in sig.parameters.items(): - if issubclass(param.annotation, Event) or isinstance( - param.annotation, Event - ): - kwargs[name] = self.event - if issubclass(param.annotation, Caller) or isinstance( - param.annotation, Caller - ): - kwargs[name] = self - - if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot): - kwargs[name] = self.bot - - if param.annotation == T_State: - kwargs[name] = self.state - - # 检查形参是否有默认值或传入,若没有则用parameters中的默认值填充 - for name, param in sig.parameters.items(): + # 检查形参是否有默认值或传入,若没有则用default中的默认值填充 + for name, value in self.default.items(): if name not in kwargs: - kwargs[name] = self._parameters.get(name, param.default) + kwargs[name] = value return await self.func(*args, **kwargs) diff --git a/nonebot_plugin_marshoai/plugin/func_call/models.py b/nonebot_plugin_marshoai/plugin/func_call/models.py new file mode 100644 index 00000000..379388e7 --- /dev/null +++ b/nonebot_plugin_marshoai/plugin/func_call/models.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING, Any + +from nonebot.adapters import Bot, Event +from nonebot.matcher import Matcher +from nonebot.typing import T_State +from pydantic import BaseModel + +if TYPE_CHECKING: + from .caller import Caller + + +class SessionContext(BaseModel): + """依赖注入会话上下文 + + Args: + BaseModel (_type_): _description_ + """ + + bot: Bot + event: Event + matcher: Matcher + state: T_State + caller: Any = None + + class Config: + arbitrary_types_allowed = True + + +class SessionContextDepends(BaseModel): + bot: str | None = None + event: str | None = None + matcher: str | None = None + state: str | None = None + caller: str | None = None