mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-03-04 01:23:39 +08:00
✨ 重构Caller类,移除泛型参数;添加函数签名复制装饰器
This commit is contained in:
parent
af9a5e3c96
commit
0379789bec
@ -23,6 +23,8 @@ from nonebot.permission import SUPERUSER
|
|||||||
from nonebot.rule import Rule, to_me
|
from nonebot.rule import Rule, to_me
|
||||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||||
|
|
||||||
|
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
|
||||||
|
|
||||||
from .metadata import metadata
|
from .metadata import metadata
|
||||||
from .models import MarshoContext, MarshoTools
|
from .models import MarshoContext, MarshoTools
|
||||||
from .plugin import _plugins, load_plugins
|
from .plugin import _plugins, load_plugins
|
||||||
@ -103,10 +105,9 @@ async def _preload_tools():
|
|||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
async def _preload_plugins():
|
async def _preload_plugins():
|
||||||
"""启动钩子加载插件"""
|
"""启动钩子加载插件"""
|
||||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs
|
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||||
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
|
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") # 预置插件目录
|
||||||
load_plugins(*marshoai_plugin_dirs)
|
load_plugins(*marshoai_plugin_dirs)
|
||||||
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")
|
|
||||||
|
|
||||||
|
|
||||||
@add_usermsg_cmd.handle()
|
@add_usermsg_cmd.handle()
|
||||||
@ -266,7 +267,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
|||||||
client=client,
|
client=client,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
|
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
|
||||||
tools=tools.get_tools_list(),
|
tools=tools.get_tools_list()
|
||||||
|
+ list(
|
||||||
|
map(lambda v: v.data(), get_function_calls().values())
|
||||||
|
), # TODO 临时追加函数,后期优化
|
||||||
)
|
)
|
||||||
# await UniMessage(str(response)).send()
|
# await UniMessage(str(response)).send()
|
||||||
choice = response.choices[0]
|
choice = response.choices[0]
|
||||||
@ -315,9 +319,23 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
|||||||
await UniMessage(
|
await UniMessage(
|
||||||
f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
|
f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
|
||||||
).send()
|
).send()
|
||||||
|
# TODO 临时追加插件函数,若工具中没有则调用插件函数
|
||||||
|
if tools.has_function(tool_call.function.name):
|
||||||
|
logger.debug(f"调用工具函数 {tool_call.function.name}")
|
||||||
func_return = await tools.call(
|
func_return = await tools.call(
|
||||||
tool_call.function.name, function_args
|
tool_call.function.name, function_args
|
||||||
) # 获取返回值
|
) # 获取返回值
|
||||||
|
else:
|
||||||
|
if caller := get_function_calls().get(
|
||||||
|
tool_call.function.name
|
||||||
|
):
|
||||||
|
logger.debug(f"调用插件函数 {tool_call.function.name}")
|
||||||
|
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||||
|
caller.event = event
|
||||||
|
func_return = await caller.call(**function_args)
|
||||||
|
else:
|
||||||
|
logger.error(f"未找到函数 {tool_call.function.name}")
|
||||||
|
func_return = f"未找到函数 {tool_call.function.name}"
|
||||||
tool_msg.append(
|
tool_msg.append(
|
||||||
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
|
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
|
||||||
)
|
)
|
||||||
|
@ -90,6 +90,7 @@ class MarshoTools:
|
|||||||
with open(json_path, "r", encoding="utf-8") as json_file:
|
with open(json_path, "r", encoding="utf-8") as json_file:
|
||||||
data = json.load(json_file)
|
data = json.load(json_file)
|
||||||
for i in data:
|
for i in data:
|
||||||
|
|
||||||
self.tools_list.append(i)
|
self.tools_list.append(i)
|
||||||
|
|
||||||
spec = importlib.util.spec_from_file_location(
|
spec = importlib.util.spec_from_file_location(
|
||||||
@ -136,6 +137,21 @@ class MarshoTools:
|
|||||||
else:
|
else:
|
||||||
logger.error(f"工具包 '{package_name}' 未导入")
|
logger.error(f"工具包 '{package_name}' 未导入")
|
||||||
|
|
||||||
|
def has_function(self, full_function_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查是否存在指定的函数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for t in self.tools_list:
|
||||||
|
if t["function"]["name"].replace(
|
||||||
|
"-", "_"
|
||||||
|
) == full_function_name.replace("-", "_"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
|
||||||
|
return False
|
||||||
|
|
||||||
def get_tools_list(self):
|
def get_tools_list(self):
|
||||||
if not self.tools_list or not config.marshoai_enable_tools:
|
if not self.tools_list or not config.marshoai_enable_tools:
|
||||||
return None
|
return None
|
||||||
|
@ -1,36 +1,34 @@
|
|||||||
from typing import Generic, TypeVar
|
import inspect
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from nonebot import logger
|
from nonebot import logger
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
from ..typing import FUNCTION_CALL_FUNC
|
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||||||
from .params import P
|
from .utils import async_wrap, is_coroutine_callable
|
||||||
|
|
||||||
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
|
_caller_data: dict[str, "Caller"] = {}
|
||||||
|
|
||||||
|
|
||||||
class Caller(Generic[P]):
|
class Caller:
|
||||||
def __init__(self, name: str | None = None, description: str | None = None):
|
def __init__(self, name: str | None = None, description: str | None = None):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._description = description
|
self._description = description
|
||||||
self._parameters: dict[str, P] = {}
|
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
||||||
self.func: FUNCTION_CALL_FUNC | None = None
|
self._parameters: dict[str, Any] = {}
|
||||||
|
"""依赖注入的参数"""
|
||||||
|
self.event: Event | None = None
|
||||||
|
|
||||||
def params(self, **kwargs: P) -> "Caller":
|
def params(self, **kwargs: Any) -> "Caller":
|
||||||
"""设置多个函数参数
|
|
||||||
Args:
|
|
||||||
**kwargs: 参数字典
|
|
||||||
Returns:
|
|
||||||
Caller: Caller对象
|
|
||||||
"""
|
|
||||||
self._parameters.update(kwargs)
|
self._parameters.update(kwargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def param(self, name: str, param: P) -> "Caller":
|
def param(self, name: str, param: Any) -> "Caller":
|
||||||
"""设置一个函数参数
|
"""设置一个函数参数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): 参数名
|
name (str): 参数名
|
||||||
param (P): 参数对象
|
param (Any): 参数对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Caller: Caller对象
|
Caller: Caller对象
|
||||||
@ -51,14 +49,6 @@ class Caller(Generic[P]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def description(self, description: str) -> "Caller":
|
def description(self, description: str) -> "Caller":
|
||||||
"""设置函数描述
|
|
||||||
|
|
||||||
Args:
|
|
||||||
description (str): 函数描述
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Caller: Caller对象
|
|
||||||
"""
|
|
||||||
self._description = description
|
self._description = description
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -71,12 +61,78 @@ class Caller(Generic[P]):
|
|||||||
Returns:
|
Returns:
|
||||||
F: 函数对象
|
F: 函数对象
|
||||||
"""
|
"""
|
||||||
|
global _caller_data
|
||||||
|
if self._name is None:
|
||||||
|
if module := inspect.getmodule(func):
|
||||||
|
module_name = module.__name__.split(".")[-1]
|
||||||
|
else:
|
||||||
|
module_name = "global"
|
||||||
|
self._name = f"{module_name}-{func.__name__}"
|
||||||
|
_caller_data[self._name] = self
|
||||||
|
|
||||||
|
if is_coroutine_callable(func):
|
||||||
|
self.func = func # type: ignore
|
||||||
|
else:
|
||||||
|
self.func = async_wrap(func) # type: ignore
|
||||||
|
|
||||||
|
if module := inspect.getmodule(func):
|
||||||
|
module_name = module.__name__ + "."
|
||||||
|
else:
|
||||||
|
module_name = ""
|
||||||
logger.opt(colors=True).info(
|
logger.opt(colors=True).info(
|
||||||
f"<y>加载函数 {func.__name__} {self._description}</y>"
|
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
|
||||||
)
|
)
|
||||||
self.func = func
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
def data(self) -> dict[str, Any]:
|
||||||
|
"""返回函数的json数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: 函数的json数据
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self._name,
|
||||||
|
"description": self._description,
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
key: value.data() for key, value in self._parameters.items()
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
key
|
||||||
|
for key, value in self._parameters.items()
|
||||||
|
if value.default is None
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_event(self, event: Event):
|
||||||
|
self.event = event
|
||||||
|
|
||||||
|
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""调用函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 函数返回值
|
||||||
|
"""
|
||||||
|
if self.func is None:
|
||||||
|
raise ValueError("未注册函数对象")
|
||||||
|
sig = inspect.signature(self.func)
|
||||||
|
for name, param in sig.parameters.items():
|
||||||
|
if issubclass(param.annotation, Event) or isinstance(
|
||||||
|
param.annotation, Event
|
||||||
|
):
|
||||||
|
kwargs[name] = self.event
|
||||||
|
if issubclass(param.annotation, Caller) or isinstance(
|
||||||
|
param.annotation, Caller
|
||||||
|
):
|
||||||
|
kwargs[name] = self
|
||||||
|
return await self.func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
|
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
|
||||||
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
|
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
|
||||||
@ -87,5 +143,14 @@ def on_function_call(name: str | None = None, description: str | None = None) ->
|
|||||||
Returns:
|
Returns:
|
||||||
Caller: Caller对象
|
Caller: Caller对象
|
||||||
"""
|
"""
|
||||||
|
caller = Caller(name=name, description=description)
|
||||||
|
return caller
|
||||||
|
|
||||||
return Caller(name=name, description=description)
|
|
||||||
|
def get_function_calls() -> dict[str, Caller]:
|
||||||
|
"""获取所有已注册的function call函数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Caller]: 所有已注册的function call函数
|
||||||
|
"""
|
||||||
|
return _caller_data
|
||||||
|
52
nonebot_plugin_marshoai/plugin/func_call/utils.py
Normal file
52
nonebot_plugin_marshoai/plugin/func_call/utils.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
import inspect
|
||||||
|
from functools import wraps
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
|
|
||||||
|
from ..typing import F
|
||||||
|
|
||||||
|
|
||||||
|
def copy_signature(func: F) -> Callable[[Callable[..., Any]], F]:
|
||||||
|
"""复制函数签名和文档字符串的装饰器"""
|
||||||
|
|
||||||
|
def decorator(wrapper: Callable[..., Any]) -> F:
|
||||||
|
@wraps(func)
|
||||||
|
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
return wrapper(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped # type: ignore
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def async_wrap(func: F) -> F:
|
||||||
|
"""装饰器,将同步函数包装为异步函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func (F): 函数对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
F: 包装后的函数对象
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||||
|
"""
|
||||||
|
判断是否为async def 函数
|
||||||
|
请注意:是否为 async def 函数与该函数是否能被await调用是两个不同的概念,具体取决于函数返回值是否为awaitable对象
|
||||||
|
Args:
|
||||||
|
call: 可调用对象
|
||||||
|
Returns:
|
||||||
|
bool: 是否为async def函数
|
||||||
|
"""
|
||||||
|
if inspect.isroutine(call):
|
||||||
|
return inspect.iscoroutinefunction(call)
|
||||||
|
if inspect.isclass(call):
|
||||||
|
return False
|
||||||
|
func_ = getattr(call, "__call__", None)
|
||||||
|
return inspect.iscoroutinefunction(func_)
|
@ -1,5 +1,7 @@
|
|||||||
from typing import Any, Callable, Coroutine, TypeAlias
|
from typing import Any, Callable, Coroutine, TypeAlias, TypeVar
|
||||||
|
|
||||||
SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str]
|
SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str]
|
||||||
ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]]
|
ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]]
|
||||||
FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC
|
FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
|
||||||
|
@ -18,21 +18,5 @@ def path_to_module_name(path: Path) -> str:
|
|||||||
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
|
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
|
||||||
|
|
||||||
|
|
||||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
|
||||||
"""
|
|
||||||
判断是否为async def 函数
|
|
||||||
Args:
|
|
||||||
call: 可调用对象
|
|
||||||
Returns:
|
|
||||||
bool: 是否为协程可调用对象
|
|
||||||
"""
|
|
||||||
if inspect.isroutine(call):
|
|
||||||
return inspect.iscoroutinefunction(call)
|
|
||||||
if inspect.isclass(call):
|
|
||||||
return False
|
|
||||||
func_ = getattr(call, "__call__", None)
|
|
||||||
return inspect.iscoroutinefunction(func_)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_function_docsring():
|
def parse_function_docsring():
|
||||||
pass
|
pass
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from nonebot.adapters.onebot.v11 import MessageEvent
|
||||||
|
|
||||||
from nonebot_plugin_marshoai.plugin import (
|
from nonebot_plugin_marshoai.plugin import (
|
||||||
Integer,
|
Integer,
|
||||||
Parameter,
|
Parameter,
|
||||||
@ -5,6 +7,7 @@ from nonebot_plugin_marshoai.plugin import (
|
|||||||
String,
|
String,
|
||||||
on_function_call,
|
on_function_call,
|
||||||
)
|
)
|
||||||
|
from nonebot_plugin_marshoai.plugin.func_call.caller import Caller
|
||||||
|
|
||||||
__marsho_meta__ = PluginMetadata(
|
__marsho_meta__ = PluginMetadata(
|
||||||
name="SnowyKami 测试插件",
|
name="SnowyKami 测试插件",
|
||||||
@ -19,16 +22,7 @@ __marsho_meta__ = PluginMetadata(
|
|||||||
gender=String(enum=["男", "女"], description="性别"),
|
gender=String(enum=["男", "女"], description="性别"),
|
||||||
)
|
)
|
||||||
async def fortune_telling(age: int, name: str, gender: str) -> str:
|
async def fortune_telling(age: int, name: str, gender: str) -> str:
|
||||||
"""使用姓名,年龄,性别进行算命
|
"""使用姓名,年龄,性别进行算命"""
|
||||||
|
|
||||||
Args:
|
|
||||||
age (int): _description_
|
|
||||||
name (str): _description_
|
|
||||||
gender (str): _description_
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: _description_
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 进行一系列算命操作...
|
# 进行一系列算命操作...
|
||||||
|
|
||||||
@ -41,17 +35,22 @@ async def fortune_telling(age: int, name: str, gender: str) -> str:
|
|||||||
unit=String(enum=["摄氏度", "华氏度"], description="温度单位"),
|
unit=String(enum=["摄氏度", "华氏度"], description="温度单位"),
|
||||||
)
|
)
|
||||||
async def get_weather(location: str, days: int, unit: str) -> str:
|
async def get_weather(location: str, days: int, unit: str) -> str:
|
||||||
"""获取一个地点未来一段时间的天气
|
"""获取一个地点未来一段时间的天气"""
|
||||||
|
|
||||||
Args:
|
|
||||||
location (str): 地点名称,可以是城市名、地区名等
|
|
||||||
days (int): 天数
|
|
||||||
unit (str): 温度单位
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 天气信息
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 进行一系列获取天气操作...
|
# 进行一系列获取天气操作...
|
||||||
|
|
||||||
return f"{location}未来{days}天的天气信息..."
|
return f"{location}未来{days}天的天气信息..."
|
||||||
|
|
||||||
|
|
||||||
|
@on_function_call(description="获取设备物理地理位置")
|
||||||
|
async def get_location() -> str:
|
||||||
|
"""获取设备物理地理位置"""
|
||||||
|
|
||||||
|
# 进行一系列获取地理位置操作...
|
||||||
|
|
||||||
|
return "日本 东京都 世田谷区"
|
||||||
|
|
||||||
|
|
||||||
|
@on_function_call(description="获取聊天者个人信息")
|
||||||
|
async def get_user_info(e: MessageEvent, c: Caller) -> str:
|
||||||
|
return f"用户信息:{e.user_id} {e.sender.nickname}, {c._parameters}"
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": "marshoai-memory__write_memory",
|
"name": "marshoai_memory__write_memory",
|
||||||
"description": "当你想记住有关与你对话的人的一些信息的时候,调用此函数。",
|
"description": "当你想记住有关与你对话的人的一些信息的时候,调用此函数。",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user