添加依赖注入支持,重构函数调用上下文,优化插件加载机制

This commit is contained in:
远野千束(神羽) 2024-12-17 13:25:30 +08:00
parent a0f657b239
commit a2c4fb220e
4 changed files with 132 additions and 42 deletions

View File

@ -18,17 +18,18 @@ from azure.ai.inference.models import (
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver, logger, on_command, on_message from nonebot import get_driver, logger, on_command, on_message
from nonebot.adapters import Bot, Event, Message from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher
from nonebot.params import CommandArg from nonebot.params import CommandArg
from nonebot.permission import SUPERUSER from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me from nonebot.rule import Rule, to_me
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
from .metadata import metadata from .metadata import metadata
from .models import MarshoContext, MarshoTools from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins from .plugin import _plugins, load_plugin, load_plugins
from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext
from .util import * from .util import *
@ -115,10 +116,15 @@ async def _preload_plugins():
"""启动钩子加载插件""" """启动钩子加载插件"""
if config.marshoai_enable_plugins: if config.marshoai_enable_plugins:
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表 marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
"""加载内置插件"""
marshoai_plugin_dirs.insert( marshoai_plugin_dirs.insert(
0, Path(__file__).parent / "plugins" 0, Path(__file__).parent / "plugins"
) # 预置插件目录 ) # 预置插件目录
"""加载指定目录插件"""
load_plugins(*marshoai_plugin_dirs) load_plugins(*marshoai_plugin_dirs)
"""加载sys.path下的包"""
for package_name in config.marshoai_plugins:
load_plugin(package_name)
logger.info( logger.info(
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。" "如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
) )
@ -227,8 +233,10 @@ async def marsho(
event: Event, event: Event,
bot: Bot, bot: Bot,
state: T_State, state: T_State,
matcher: Matcher,
text: Optional[UniMsg] = None, text: Optional[UniMsg] = None,
): ):
global target_list global target_list
if event.get_message().extract_plain_text() and ( if event.get_message().extract_plain_text() and (
not text not text
@ -324,7 +332,7 @@ async def marsho(
) )
return return
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS: elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
# function call
# 需要获取额外信息,调用函数工具 # 需要获取额外信息,调用函数工具
tool_msg = [] tool_msg = []
while choice.message.tool_calls != None: while choice.message.tool_calls != None:
@ -360,12 +368,14 @@ async def marsho(
logger.debug(f"调用插件函数 {tool_call.function.name}") logger.debug(f"调用插件函数 {tool_call.function.name}")
# 权限检查,规则检查 TODO # 权限检查,规则检查 TODO
# 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入 # 实现依赖注入检查函数参数及参数注解类型对Event类型的参数进行注入
caller.event, caller.bot, caller.state = ( func_return = await caller.with_ctx(
event, SessionContext(
bot, bot=bot,
state, event=event,
) state=state,
func_return = await caller.call(**function_args) matcher=matcher,
)
).call(**function_args)
else: else:
logger.error(f"未找到函数 {tool_call.function.name}") logger.error(f"未找到函数 {tool_call.function.name}")
func_return = f"未找到函数 {tool_call.function.name}" func_return = f"未找到函数 {tool_call.function.name}"

View File

@ -53,6 +53,8 @@ class ConfigModel(BaseModel):
"""插件目录(不是工具)""" """插件目录(不是工具)"""
marshoai_devmode: bool = False marshoai_devmode: bool = False
"""开发者模式""" """开发者模式"""
marshoai_plugins: list[str] = []
"""marsho插件的名称列表从pip安装的使用包名从本地导入的使用路径"""
yaml = YAML() yaml = YAML()

View File

@ -3,11 +3,13 @@ from typing import Any
from nonebot import logger from nonebot import logger
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.rule import Rule from nonebot.rule import Rule
from nonebot.typing import T_State from nonebot.typing import T_State
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
from .models import SessionContext, SessionContextDepends
from .utils import async_wrap, is_coroutine_callable from .utils import async_wrap, is_coroutine_callable
_caller_data: dict[str, "Caller"] = {} _caller_data: dict[str, "Caller"] = {}
@ -19,10 +21,15 @@ class Caller:
self._description = description self._description = description
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
self._parameters: dict[str, Any] = {} self._parameters: dict[str, Any] = {}
"""依赖注入的参数""" """声明参数"""
self.bot: Bot | None = None
self.event: Event | None = None self.di: SessionContextDepends = SessionContextDepends()
self.state: T_State | None = None """依赖注入的参数信息"""
self.default: dict[str, Any] = {}
"""默认值"""
self.ctx: SessionContext | None = None
self._permission: Permission | None = None self._permission: Permission | None = None
self._rule: Rule | None = None self._rule: Rule | None = None
@ -36,14 +43,20 @@ class Caller:
return self return self
async def pre_check(self) -> tuple[bool, str]: async def pre_check(self) -> tuple[bool, str]:
if self.bot is None or self.event is None: if self.ctx is None:
return False, "上下文为空"
if self.ctx.bot is None or self.ctx.event is None:
return False, "Context is None" return False, "Context is None"
if self._permission and not await self._permission(self.bot, self.event): if self._permission and not await self._permission(
self.ctx.bot, self.ctx.event
):
return False, "告诉用户 Permission Denied 权限不足" return False, "告诉用户 Permission Denied 权限不足"
if self.state is None: if self.ctx.state is None:
return False, "State is None" return False, "State is None"
if self._rule and not await self._rule(self.bot, self.event, self.state): if self._rule and not await self._rule(
self.ctx.bot, self.ctx.event, self.ctx.state
):
return False, "告诉用户 Rule Denied 规则不匹配" return False, "告诉用户 Rule Denied 规则不匹配"
return True, "" return True, ""
@ -86,6 +99,35 @@ class Caller:
self._name = f"{module_name}-{func.__name__}" self._name = f"{module_name}-{func.__name__}"
_caller_data[self._name] = self _caller_data[self._name] = self
# 检查函数签名,确定依赖注入参数
sig = inspect.signature(func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
self.di.event = name
if issubclass(param.annotation, Caller) or isinstance(
param.annotation, Caller
):
self.di.caller = name
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
self.di.bot = name
if issubclass(param.annotation, Matcher) or isinstance(
param.annotation, Matcher
):
self.di.matcher = name
if param.annotation == T_State:
self.di.state = name
# 检查默认值情况
for name, param in sig.parameters.items():
if param.default is not inspect.Parameter.empty:
self.default[name] = param.default
if is_coroutine_callable(func): if is_coroutine_callable(func):
self.func = func # type: ignore self.func = func # type: ignore
else: else:
@ -126,11 +168,30 @@ class Caller:
}, },
} }
def set_event(self, event: Event): def set_ctx(self, ctx: SessionContext) -> None:
self.event = event """设置依赖注入上下文
def set_bot(self, bot: Bot): Args:
self.bot = bot ctx (SessionContext): 依赖注入上下文
"""
ctx.caller = self
self.ctx = ctx
for type_name, arg_name in self.di.model_dump().items():
if arg_name:
self.default[arg_name] = ctx.__getattribute__(type_name)
def with_ctx(self, ctx: SessionContext) -> "Caller":
"""设置依赖注入上下文
Args:
ctx (SessionContext): 依赖注入上下文
Returns:
Caller: Caller对象
"""
self.set_ctx(ctx)
return self
async def call(self, *args: Any, **kwargs: Any) -> Any: async def call(self, *args: Any, **kwargs: Any) -> Any:
"""调用函数 """调用函数
@ -145,28 +206,11 @@ class Caller:
if self.func is None: if self.func is None:
raise ValueError("未注册函数对象") raise ValueError("未注册函数对象")
sig = inspect.signature(self.func)
for name, param in sig.parameters.items():
if issubclass(param.annotation, Event) or isinstance(
param.annotation, Event
):
kwargs[name] = self.event
if issubclass(param.annotation, Caller) or isinstance( # 检查形参是否有默认值或传入若没有则用default中的默认值填充
param.annotation, Caller for name, value in self.default.items():
):
kwargs[name] = self
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
kwargs[name] = self.bot
if param.annotation == T_State:
kwargs[name] = self.state
# 检查形参是否有默认值或传入若没有则用parameters中的默认值填充
for name, param in sig.parameters.items():
if name not in kwargs: if name not in kwargs:
kwargs[name] = self._parameters.get(name, param.default) kwargs[name] = value
return await self.func(*args, **kwargs) return await self.func(*args, **kwargs)

View File

@ -0,0 +1,34 @@
from typing import TYPE_CHECKING, Any
from nonebot.adapters import Bot, Event
from nonebot.matcher import Matcher
from nonebot.typing import T_State
from pydantic import BaseModel
if TYPE_CHECKING:
from .caller import Caller
class SessionContext(BaseModel):
"""依赖注入会话上下文
Args:
BaseModel (_type_): _description_
"""
bot: Bot
event: Event
matcher: Matcher
state: T_State
caller: Any = None
class Config:
arbitrary_types_allowed = True
class SessionContextDepends(BaseModel):
bot: str | None = None
event: str | None = None
matcher: str | None = None
state: str | None = None
caller: str | None = None