mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-01-26 18:12:47 +08:00
✨ 添加依赖注入支持,重构函数调用上下文,优化插件加载机制
This commit is contained in:
parent
a0f657b239
commit
a2c4fb220e
@ -18,17 +18,18 @@ from azure.ai.inference.models import (
|
|||||||
from azure.core.credentials import AzureKeyCredential
|
from azure.core.credentials import AzureKeyCredential
|
||||||
from nonebot import get_driver, logger, on_command, on_message
|
from nonebot import get_driver, logger, on_command, on_message
|
||||||
from nonebot.adapters import Bot, Event, Message
|
from nonebot.adapters import Bot, Event, Message
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.params import CommandArg
|
from nonebot.params import CommandArg
|
||||||
from nonebot.permission import SUPERUSER
|
from nonebot.permission import SUPERUSER
|
||||||
from nonebot.rule import Rule, to_me
|
from nonebot.rule import Rule, to_me
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||||
|
|
||||||
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
|
|
||||||
|
|
||||||
from .metadata import metadata
|
from .metadata import metadata
|
||||||
from .models import MarshoContext, MarshoTools
|
from .models import MarshoContext, MarshoTools
|
||||||
from .plugin import _plugins, load_plugins
|
from .plugin import _plugins, load_plugin, load_plugins
|
||||||
|
from .plugin.func_call.caller import get_function_calls
|
||||||
|
from .plugin.func_call.models import SessionContext
|
||||||
from .util import *
|
from .util import *
|
||||||
|
|
||||||
|
|
||||||
@ -115,10 +116,15 @@ async def _preload_plugins():
|
|||||||
"""启动钩子加载插件"""
|
"""启动钩子加载插件"""
|
||||||
if config.marshoai_enable_plugins:
|
if config.marshoai_enable_plugins:
|
||||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||||
|
"""加载内置插件"""
|
||||||
marshoai_plugin_dirs.insert(
|
marshoai_plugin_dirs.insert(
|
||||||
0, Path(__file__).parent / "plugins"
|
0, Path(__file__).parent / "plugins"
|
||||||
) # 预置插件目录
|
) # 预置插件目录
|
||||||
|
"""加载指定目录插件"""
|
||||||
load_plugins(*marshoai_plugin_dirs)
|
load_plugins(*marshoai_plugin_dirs)
|
||||||
|
"""加载sys.path下的包"""
|
||||||
|
for package_name in config.marshoai_plugins:
|
||||||
|
load_plugin(package_name)
|
||||||
logger.info(
|
logger.info(
|
||||||
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
||||||
)
|
)
|
||||||
@ -227,8 +233,10 @@ async def marsho(
|
|||||||
event: Event,
|
event: Event,
|
||||||
bot: Bot,
|
bot: Bot,
|
||||||
state: T_State,
|
state: T_State,
|
||||||
|
matcher: Matcher,
|
||||||
text: Optional[UniMsg] = None,
|
text: Optional[UniMsg] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
global target_list
|
global target_list
|
||||||
if event.get_message().extract_plain_text() and (
|
if event.get_message().extract_plain_text() and (
|
||||||
not text
|
not text
|
||||||
@ -324,7 +332,7 @@ async def marsho(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
|
elif choice["finish_reason"] == CompletionsFinishReason.TOOL_CALLS:
|
||||||
|
# function call
|
||||||
# 需要获取额外信息,调用函数工具
|
# 需要获取额外信息,调用函数工具
|
||||||
tool_msg = []
|
tool_msg = []
|
||||||
while choice.message.tool_calls != None:
|
while choice.message.tool_calls != None:
|
||||||
@ -360,12 +368,14 @@ async def marsho(
|
|||||||
logger.debug(f"调用插件函数 {tool_call.function.name}")
|
logger.debug(f"调用插件函数 {tool_call.function.name}")
|
||||||
# 权限检查,规则检查 TODO
|
# 权限检查,规则检查 TODO
|
||||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||||
caller.event, caller.bot, caller.state = (
|
func_return = await caller.with_ctx(
|
||||||
event,
|
SessionContext(
|
||||||
bot,
|
bot=bot,
|
||||||
state,
|
event=event,
|
||||||
)
|
state=state,
|
||||||
func_return = await caller.call(**function_args)
|
matcher=matcher,
|
||||||
|
)
|
||||||
|
).call(**function_args)
|
||||||
else:
|
else:
|
||||||
logger.error(f"未找到函数 {tool_call.function.name}")
|
logger.error(f"未找到函数 {tool_call.function.name}")
|
||||||
func_return = f"未找到函数 {tool_call.function.name}"
|
func_return = f"未找到函数 {tool_call.function.name}"
|
||||||
|
@ -53,6 +53,8 @@ class ConfigModel(BaseModel):
|
|||||||
"""插件目录(不是工具)"""
|
"""插件目录(不是工具)"""
|
||||||
marshoai_devmode: bool = False
|
marshoai_devmode: bool = False
|
||||||
"""开发者模式"""
|
"""开发者模式"""
|
||||||
|
marshoai_plugins: list[str] = []
|
||||||
|
"""marsho插件的名称列表,从pip安装的使用包名,从本地导入的使用路径"""
|
||||||
|
|
||||||
|
|
||||||
yaml = YAML()
|
yaml = YAML()
|
||||||
|
@ -3,11 +3,13 @@ from typing import Any
|
|||||||
|
|
||||||
from nonebot import logger
|
from nonebot import logger
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.permission import Permission
|
from nonebot.permission import Permission
|
||||||
from nonebot.rule import Rule
|
from nonebot.rule import Rule
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
|
||||||
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||||||
|
from .models import SessionContext, SessionContextDepends
|
||||||
from .utils import async_wrap, is_coroutine_callable
|
from .utils import async_wrap, is_coroutine_callable
|
||||||
|
|
||||||
_caller_data: dict[str, "Caller"] = {}
|
_caller_data: dict[str, "Caller"] = {}
|
||||||
@ -19,10 +21,15 @@ class Caller:
|
|||||||
self._description = description
|
self._description = description
|
||||||
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
||||||
self._parameters: dict[str, Any] = {}
|
self._parameters: dict[str, Any] = {}
|
||||||
"""依赖注入的参数"""
|
"""声明参数"""
|
||||||
self.bot: Bot | None = None
|
|
||||||
self.event: Event | None = None
|
self.di: SessionContextDepends = SessionContextDepends()
|
||||||
self.state: T_State | None = None
|
"""依赖注入的参数信息"""
|
||||||
|
|
||||||
|
self.default: dict[str, Any] = {}
|
||||||
|
"""默认值"""
|
||||||
|
|
||||||
|
self.ctx: SessionContext | None = None
|
||||||
|
|
||||||
self._permission: Permission | None = None
|
self._permission: Permission | None = None
|
||||||
self._rule: Rule | None = None
|
self._rule: Rule | None = None
|
||||||
@ -36,14 +43,20 @@ class Caller:
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def pre_check(self) -> tuple[bool, str]:
|
async def pre_check(self) -> tuple[bool, str]:
|
||||||
if self.bot is None or self.event is None:
|
if self.ctx is None:
|
||||||
|
return False, "上下文为空"
|
||||||
|
if self.ctx.bot is None or self.ctx.event is None:
|
||||||
return False, "Context is None"
|
return False, "Context is None"
|
||||||
if self._permission and not await self._permission(self.bot, self.event):
|
if self._permission and not await self._permission(
|
||||||
|
self.ctx.bot, self.ctx.event
|
||||||
|
):
|
||||||
return False, "告诉用户 Permission Denied 权限不足"
|
return False, "告诉用户 Permission Denied 权限不足"
|
||||||
|
|
||||||
if self.state is None:
|
if self.ctx.state is None:
|
||||||
return False, "State is None"
|
return False, "State is None"
|
||||||
if self._rule and not await self._rule(self.bot, self.event, self.state):
|
if self._rule and not await self._rule(
|
||||||
|
self.ctx.bot, self.ctx.event, self.ctx.state
|
||||||
|
):
|
||||||
return False, "告诉用户 Rule Denied 规则不匹配"
|
return False, "告诉用户 Rule Denied 规则不匹配"
|
||||||
|
|
||||||
return True, ""
|
return True, ""
|
||||||
@ -86,6 +99,35 @@ class Caller:
|
|||||||
self._name = f"{module_name}-{func.__name__}"
|
self._name = f"{module_name}-{func.__name__}"
|
||||||
_caller_data[self._name] = self
|
_caller_data[self._name] = self
|
||||||
|
|
||||||
|
# 检查函数签名,确定依赖注入参数
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
if issubclass(param.annotation, Event) or isinstance(
|
||||||
|
param.annotation, Event
|
||||||
|
):
|
||||||
|
self.di.event = name
|
||||||
|
|
||||||
|
if issubclass(param.annotation, Caller) or isinstance(
|
||||||
|
param.annotation, Caller
|
||||||
|
):
|
||||||
|
self.di.caller = name
|
||||||
|
|
||||||
|
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
|
||||||
|
self.di.bot = name
|
||||||
|
|
||||||
|
if issubclass(param.annotation, Matcher) or isinstance(
|
||||||
|
param.annotation, Matcher
|
||||||
|
):
|
||||||
|
self.di.matcher = name
|
||||||
|
|
||||||
|
if param.annotation == T_State:
|
||||||
|
self.di.state = name
|
||||||
|
|
||||||
|
# 检查默认值情况
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
if param.default is not inspect.Parameter.empty:
|
||||||
|
self.default[name] = param.default
|
||||||
|
|
||||||
if is_coroutine_callable(func):
|
if is_coroutine_callable(func):
|
||||||
self.func = func # type: ignore
|
self.func = func # type: ignore
|
||||||
else:
|
else:
|
||||||
@ -126,11 +168,30 @@ class Caller:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_event(self, event: Event):
|
def set_ctx(self, ctx: SessionContext) -> None:
|
||||||
self.event = event
|
"""设置依赖注入上下文
|
||||||
|
|
||||||
def set_bot(self, bot: Bot):
|
Args:
|
||||||
self.bot = bot
|
ctx (SessionContext): 依赖注入上下文
|
||||||
|
"""
|
||||||
|
ctx.caller = self
|
||||||
|
self.ctx = ctx
|
||||||
|
for type_name, arg_name in self.di.model_dump().items():
|
||||||
|
|
||||||
|
if arg_name:
|
||||||
|
self.default[arg_name] = ctx.__getattribute__(type_name)
|
||||||
|
|
||||||
|
def with_ctx(self, ctx: SessionContext) -> "Caller":
|
||||||
|
"""设置依赖注入上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (SessionContext): 依赖注入上下文
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Caller: Caller对象
|
||||||
|
"""
|
||||||
|
self.set_ctx(ctx)
|
||||||
|
return self
|
||||||
|
|
||||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""调用函数
|
"""调用函数
|
||||||
@ -145,28 +206,11 @@ class Caller:
|
|||||||
|
|
||||||
if self.func is None:
|
if self.func is None:
|
||||||
raise ValueError("未注册函数对象")
|
raise ValueError("未注册函数对象")
|
||||||
sig = inspect.signature(self.func)
|
|
||||||
for name, param in sig.parameters.items():
|
|
||||||
if issubclass(param.annotation, Event) or isinstance(
|
|
||||||
param.annotation, Event
|
|
||||||
):
|
|
||||||
kwargs[name] = self.event
|
|
||||||
|
|
||||||
if issubclass(param.annotation, Caller) or isinstance(
|
# 检查形参是否有默认值或传入,若没有则用default中的默认值填充
|
||||||
param.annotation, Caller
|
for name, value in self.default.items():
|
||||||
):
|
|
||||||
kwargs[name] = self
|
|
||||||
|
|
||||||
if issubclass(param.annotation, Bot) or isinstance(param.annotation, Bot):
|
|
||||||
kwargs[name] = self.bot
|
|
||||||
|
|
||||||
if param.annotation == T_State:
|
|
||||||
kwargs[name] = self.state
|
|
||||||
|
|
||||||
# 检查形参是否有默认值或传入,若没有则用parameters中的默认值填充
|
|
||||||
for name, param in sig.parameters.items():
|
|
||||||
if name not in kwargs:
|
if name not in kwargs:
|
||||||
kwargs[name] = self._parameters.get(name, param.default)
|
kwargs[name] = value
|
||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
34
nonebot_plugin_marshoai/plugin/func_call/models.py
Normal file
34
nonebot_plugin_marshoai/plugin/func_call/models.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from nonebot.adapters import Bot, Event
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot.typing import T_State
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .caller import Caller
|
||||||
|
|
||||||
|
|
||||||
|
class SessionContext(BaseModel):
|
||||||
|
"""依赖注入会话上下文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
BaseModel (_type_): _description_
|
||||||
|
"""
|
||||||
|
|
||||||
|
bot: Bot
|
||||||
|
event: Event
|
||||||
|
matcher: Matcher
|
||||||
|
state: T_State
|
||||||
|
caller: Any = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class SessionContextDepends(BaseModel):
|
||||||
|
bot: str | None = None
|
||||||
|
event: str | None = None
|
||||||
|
matcher: str | None = None
|
||||||
|
state: str | None = None
|
||||||
|
caller: str | None = None
|
Loading…
x
Reference in New Issue
Block a user