mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-02-07 18:36:09 +08:00
✨ 添加函数调用支持,重构插件加载机制,优化函数描述和模块路径管理
This commit is contained in:
parent
339d0e05bf
commit
7893f28259
@ -6,9 +6,11 @@ require("nonebot_plugin_localstore")
|
|||||||
import nonebot_plugin_localstore as store # type: ignore
|
import nonebot_plugin_localstore as store # type: ignore
|
||||||
from nonebot import get_driver, logger # type: ignore
|
from nonebot import get_driver, logger # type: ignore
|
||||||
|
|
||||||
# from .hunyuan import *
|
|
||||||
from .azure import *
|
from .azure import *
|
||||||
from .config import config
|
from .config import config
|
||||||
|
|
||||||
|
# from .hunyuan import *
|
||||||
|
from .dev import *
|
||||||
from .metadata import metadata
|
from .metadata import metadata
|
||||||
|
|
||||||
__author__ = "Asankilp"
|
__author__ = "Asankilp"
|
||||||
|
@ -350,11 +350,11 @@ async def marsho(
|
|||||||
tool_call.function.arguments.replace("'", '"')
|
tool_call.function.arguments.replace("'", '"')
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"调用函数 {tool_call.function.name.replace("-", ".")}\n参数:"
|
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
|
||||||
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
|
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
|
||||||
)
|
)
|
||||||
await UniMessage(
|
await UniMessage(
|
||||||
f"调用函数 {tool_call.function.name.replace("-", ".")}\n参数:"
|
f"调用函数 {tool_call.function.name.replace('-', '.')}\n参数:"
|
||||||
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
|
+ "\n".join([f"{k}={v}" for k, v in function_args.items()])
|
||||||
).send()
|
).send()
|
||||||
# TODO 临时追加插件函数,若工具中没有则调用插件函数
|
# TODO 临时追加插件函数,若工具中没有则调用插件函数
|
||||||
@ -365,9 +365,9 @@ async def marsho(
|
|||||||
) # 获取返回值
|
) # 获取返回值
|
||||||
else:
|
else:
|
||||||
if caller := get_function_calls().get(
|
if caller := get_function_calls().get(
|
||||||
tool_call.function.name
|
tool_call.function.name.replace("-", ".")
|
||||||
):
|
):
|
||||||
logger.debug(f"调用插件函数 {tool_call.function.name}")
|
logger.debug(f"调用插件函数 {caller.full_name}")
|
||||||
# 权限检查,规则检查 TODO
|
# 权限检查,规则检查 TODO
|
||||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||||
func_return = await caller.with_ctx(
|
func_return = await caller.with_ctx(
|
||||||
@ -379,8 +379,10 @@ async def marsho(
|
|||||||
)
|
)
|
||||||
).call(**function_args)
|
).call(**function_args)
|
||||||
else:
|
else:
|
||||||
logger.error(f"未找到函数 {tool_call.function.name}")
|
logger.error(
|
||||||
func_return = f"未找到函数 {tool_call.function.name}"
|
f"未找到函数 {tool_call.function.name.replace('-', '.')}"
|
||||||
|
)
|
||||||
|
func_return = f"未找到函数 {tool_call.function.name.replace('-', '.')}"
|
||||||
tool_msg.append(
|
tool_msg.append(
|
||||||
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
|
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
|
||||||
)
|
)
|
||||||
|
@ -52,7 +52,7 @@ class ConfigModel(BaseModel):
|
|||||||
marshoai_plugin_dirs: list[str] = []
|
marshoai_plugin_dirs: list[str] = []
|
||||||
"""插件目录(不是工具)"""
|
"""插件目录(不是工具)"""
|
||||||
marshoai_devmode: bool = False
|
marshoai_devmode: bool = False
|
||||||
"""开发者模式"""
|
"""开发者模式,启用本地插件插件重载"""
|
||||||
marshoai_plugins: list[str] = []
|
marshoai_plugins: list[str] = []
|
||||||
"""marsho插件的名称列表,从pip安装的使用包名,从本地导入的使用路径"""
|
"""marsho插件的名称列表,从pip安装的使用包名,从本地导入的使用路径"""
|
||||||
|
|
||||||
|
@ -1,7 +1,77 @@
|
|||||||
from nonebot import require
|
from nonebot import require
|
||||||
|
from nonebot.adapters import Bot, Event
|
||||||
|
from nonebot.matcher import Matcher
|
||||||
|
from nonebot.typing import T_State
|
||||||
|
|
||||||
|
from nonebot_plugin_marshoai.plugin.func_call.models import SessionContext
|
||||||
|
|
||||||
require("nonebot_plugin_alconna")
|
require("nonebot_plugin_alconna")
|
||||||
|
|
||||||
from nonebot_plugin_alconna import Alconna, on_alconna
|
from nonebot.permission import SUPERUSER
|
||||||
|
from nonebot_plugin_alconna import (
|
||||||
|
Alconna,
|
||||||
|
Args,
|
||||||
|
MultiVar,
|
||||||
|
Subcommand,
|
||||||
|
UniMessage,
|
||||||
|
on_alconna,
|
||||||
|
)
|
||||||
|
|
||||||
function_call = on_alconna("marshocall")
|
from .plugin.func_call.caller import get_function_calls
|
||||||
|
|
||||||
|
function_call = on_alconna(
|
||||||
|
command=Alconna(
|
||||||
|
"marsho-function-call",
|
||||||
|
Subcommand(
|
||||||
|
"call",
|
||||||
|
Args["function_name", str]["kwargs", MultiVar(str), []],
|
||||||
|
alias={"c"},
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"list",
|
||||||
|
alias={"l"},
|
||||||
|
),
|
||||||
|
Subcommand("info", Args["function_name", str], alias={"i"}),
|
||||||
|
),
|
||||||
|
aliases={"mfc"},
|
||||||
|
permission=SUPERUSER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@function_call.assign("list")
|
||||||
|
async def list_functions():
|
||||||
|
reply = "共有如下可调用函数:\n"
|
||||||
|
for function in get_function_calls().values():
|
||||||
|
reply += f"- {function.name}({function.description}))\n"
|
||||||
|
await UniMessage(reply).send()
|
||||||
|
|
||||||
|
|
||||||
|
@function_call.assign("info")
|
||||||
|
async def function_info(function_name: str):
|
||||||
|
function = get_function_calls().get(function_name)
|
||||||
|
if function is None:
|
||||||
|
await UniMessage(f"未找到函数 {function_name}").send()
|
||||||
|
return
|
||||||
|
await UniMessage(str(function)).send()
|
||||||
|
|
||||||
|
|
||||||
|
@function_call.assign("call")
|
||||||
|
async def call_function(
|
||||||
|
function_name: str,
|
||||||
|
kwargs: list[str],
|
||||||
|
event: Event,
|
||||||
|
bot: Bot,
|
||||||
|
matcher: Matcher,
|
||||||
|
state: T_State,
|
||||||
|
):
|
||||||
|
function = get_function_calls().get(function_name)
|
||||||
|
if function is None:
|
||||||
|
await UniMessage(f"未找到函数 {function_name}").send()
|
||||||
|
return
|
||||||
|
await UniMessage(
|
||||||
|
str(
|
||||||
|
await function.with_ctx(
|
||||||
|
SessionContext(event=event, bot=bot, matcher=matcher, state=state)
|
||||||
|
).call(**{i.split("=", 1)[0]: i.split("=", 1)[1] for i in kwargs})
|
||||||
|
)
|
||||||
|
).send()
|
||||||
|
@ -8,6 +8,7 @@ from nonebot.permission import Permission
|
|||||||
from nonebot.rule import Rule
|
from nonebot.rule import Rule
|
||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
|
|
||||||
|
from ..models import Plugin
|
||||||
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||||||
from .models import SessionContext, SessionContextDepends
|
from .models import SessionContext, SessionContextDepends
|
||||||
from .utils import async_wrap, is_coroutine_callable
|
from .utils import async_wrap, is_coroutine_callable
|
||||||
@ -16,10 +17,17 @@ _caller_data: dict[str, "Caller"] = {}
|
|||||||
|
|
||||||
|
|
||||||
class Caller:
|
class Caller:
|
||||||
def __init__(self, name: str | None = None, description: str | None = None):
|
def __init__(self, name: str = "", description: str | None = None):
|
||||||
self._name = name
|
self._name: str = name
|
||||||
|
"""函数名称"""
|
||||||
self._description = description
|
self._description = description
|
||||||
|
"""函数描述"""
|
||||||
|
self._plugin: Plugin | None = None
|
||||||
|
"""所属插件对象,装饰时声明"""
|
||||||
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
||||||
|
"""函数对象"""
|
||||||
|
self.module_name: str = ""
|
||||||
|
"""模块名"""
|
||||||
self._parameters: dict[str, Any] = {}
|
self._parameters: dict[str, Any] = {}
|
||||||
"""声明参数"""
|
"""声明参数"""
|
||||||
|
|
||||||
@ -91,13 +99,8 @@ class Caller:
|
|||||||
F: 函数对象
|
F: 函数对象
|
||||||
"""
|
"""
|
||||||
global _caller_data
|
global _caller_data
|
||||||
if self._name is None:
|
if not self._name:
|
||||||
if module := inspect.getmodule(func):
|
self._name = func.__name__
|
||||||
module_name = module.__name__.split(".")[-1]
|
|
||||||
else:
|
|
||||||
module_name = ""
|
|
||||||
self._name = f"{module_name}-{func.__name__}"
|
|
||||||
_caller_data[self._name] = self
|
|
||||||
|
|
||||||
# 检查函数签名,确定依赖注入参数
|
# 检查函数签名,确定依赖注入参数
|
||||||
sig = inspect.signature(func)
|
sig = inspect.signature(func)
|
||||||
@ -137,6 +140,9 @@ class Caller:
|
|||||||
module_name = module.__name__.split(".")[-1] + "."
|
module_name = module.__name__.split(".")[-1] + "."
|
||||||
else:
|
else:
|
||||||
module_name = ""
|
module_name = ""
|
||||||
|
|
||||||
|
self.module_name = module_name
|
||||||
|
_caller_data[self.full_name] = self
|
||||||
logger.opt(colors=True).debug(
|
logger.opt(colors=True).debug(
|
||||||
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
|
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
|
||||||
)
|
)
|
||||||
@ -152,7 +158,7 @@ class Caller:
|
|||||||
return {
|
return {
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": self._name,
|
"name": self.aifc_name,
|
||||||
"description": self._description,
|
"description": self._description,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@ -193,6 +199,11 @@ class Caller:
|
|||||||
self.set_ctx(ctx)
|
self.set_ctx(ctx)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self._name}({self._description})\n" + "\n".join(
|
||||||
|
f" - {key}: {value}" for key, value in self._parameters.items()
|
||||||
|
)
|
||||||
|
|
||||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
"""调用函数
|
"""调用函数
|
||||||
|
|
||||||
@ -214,8 +225,23 @@ class Caller:
|
|||||||
|
|
||||||
return await self.func(*args, **kwargs)
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def short_name(self) -> str:
|
||||||
|
"""函数本名"""
|
||||||
|
return self._name.split(".")[-1]
|
||||||
|
|
||||||
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
|
@property
|
||||||
|
def aifc_name(self) -> str:
|
||||||
|
"""AI调用名,没有点"""
|
||||||
|
return self._name.replace(".", "-")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def full_name(self) -> str:
|
||||||
|
"""完整名"""
|
||||||
|
return self.module_name + self._name
|
||||||
|
|
||||||
|
|
||||||
|
def on_function_call(name: str = "", description: str | None = None) -> Caller:
|
||||||
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
|
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -66,6 +66,7 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]:
|
|||||||
name=module.__name__,
|
name=module.__name__,
|
||||||
module=module,
|
module=module,
|
||||||
module_name=module_path,
|
module_name=module_path,
|
||||||
|
module_path=module.__file__,
|
||||||
)
|
)
|
||||||
_plugins[plugin.name] = plugin
|
_plugins[plugin.name] = plugin
|
||||||
|
|
||||||
|
@ -58,6 +58,8 @@ class Plugin(BaseModel):
|
|||||||
"""插件模块对象"""
|
"""插件模块对象"""
|
||||||
module_name: str
|
module_name: str
|
||||||
"""点分割模块路径 例如a.b.c"""
|
"""点分割模块路径 例如a.b.c"""
|
||||||
|
module_path: str | None
|
||||||
|
"""实际路径,单文件为.py的路径,包为__init__.py路径"""
|
||||||
metadata: PluginMetadata | None = None
|
metadata: PluginMetadata | None = None
|
||||||
"""元"""
|
"""元"""
|
||||||
|
|
||||||
@ -69,3 +71,6 @@ class Plugin(BaseModel):
|
|||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
return self.name == other.name
|
return self.name == other.name
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"Plugin({self.name}({self.module_path}))"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user