mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-03-03 12:43:40 +08:00
✨ 重构Caller类,移除泛型参数;添加函数签名复制装饰器
This commit is contained in:
parent
af9a5e3c96
commit
0379789bec
@ -23,6 +23,8 @@ from nonebot.permission import SUPERUSER
|
||||
from nonebot.rule import Rule, to_me
|
||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||
|
||||
from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
|
||||
|
||||
from .metadata import metadata
|
||||
from .models import MarshoContext, MarshoTools
|
||||
from .plugin import _plugins, load_plugins
|
||||
@ -103,10 +105,9 @@ async def _preload_tools():
|
||||
@driver.on_startup
|
||||
async def _preload_plugins():
|
||||
"""启动钩子加载插件"""
|
||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs
|
||||
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
|
||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") # 预置插件目录
|
||||
load_plugins(*marshoai_plugin_dirs)
|
||||
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")
|
||||
|
||||
|
||||
@add_usermsg_cmd.handle()
|
||||
@ -266,7 +267,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
||||
client=client,
|
||||
model_name=model_name,
|
||||
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
|
||||
tools=tools.get_tools_list(),
|
||||
tools=tools.get_tools_list()
|
||||
+ list(
|
||||
map(lambda v: v.data(), get_function_calls().values())
|
||||
), # TODO 临时追加函数,后期优化
|
||||
)
|
||||
# await UniMessage(str(response)).send()
|
||||
choice = response.choices[0]
|
||||
@ -315,9 +319,23 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
|
||||
await UniMessage(
|
||||
f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
|
||||
).send()
|
||||
func_return = await tools.call(
|
||||
tool_call.function.name, function_args
|
||||
) # 获取返回值
|
||||
# TODO 临时追加插件函数,若工具中没有则调用插件函数
|
||||
if tools.has_function(tool_call.function.name):
|
||||
logger.debug(f"调用工具函数 {tool_call.function.name}")
|
||||
func_return = await tools.call(
|
||||
tool_call.function.name, function_args
|
||||
) # 获取返回值
|
||||
else:
|
||||
if caller := get_function_calls().get(
|
||||
tool_call.function.name
|
||||
):
|
||||
logger.debug(f"调用插件函数 {tool_call.function.name}")
|
||||
# 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
|
||||
caller.event = event
|
||||
func_return = await caller.call(**function_args)
|
||||
else:
|
||||
logger.error(f"未找到函数 {tool_call.function.name}")
|
||||
func_return = f"未找到函数 {tool_call.function.name}"
|
||||
tool_msg.append(
|
||||
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
|
||||
)
|
||||
|
@ -90,6 +90,7 @@ class MarshoTools:
|
||||
with open(json_path, "r", encoding="utf-8") as json_file:
|
||||
data = json.load(json_file)
|
||||
for i in data:
|
||||
|
||||
self.tools_list.append(i)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
@ -136,6 +137,21 @@ class MarshoTools:
|
||||
else:
|
||||
logger.error(f"工具包 '{package_name}' 未导入")
|
||||
|
||||
def has_function(self, full_function_name: str) -> bool:
|
||||
"""
|
||||
检查是否存在指定的函数
|
||||
"""
|
||||
try:
|
||||
for t in self.tools_list:
|
||||
if t["function"]["name"].replace(
|
||||
"-", "_"
|
||||
) == full_function_name.replace("-", "_"):
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
|
||||
return False
|
||||
|
||||
def get_tools_list(self):
|
||||
if not self.tools_list or not config.marshoai_enable_tools:
|
||||
return None
|
||||
|
@ -1,36 +1,34 @@
|
||||
from typing import Generic, TypeVar
|
||||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from nonebot import logger
|
||||
from nonebot.adapters import Event
|
||||
|
||||
from ..typing import FUNCTION_CALL_FUNC
|
||||
from .params import P
|
||||
from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
|
||||
from .utils import async_wrap, is_coroutine_callable
|
||||
|
||||
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
|
||||
_caller_data: dict[str, "Caller"] = {}
|
||||
|
||||
|
||||
class Caller(Generic[P]):
|
||||
class Caller:
|
||||
def __init__(self, name: str | None = None, description: str | None = None):
|
||||
self._name = name
|
||||
self._description = description
|
||||
self._parameters: dict[str, P] = {}
|
||||
self.func: FUNCTION_CALL_FUNC | None = None
|
||||
self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
|
||||
self._parameters: dict[str, Any] = {}
|
||||
"""依赖注入的参数"""
|
||||
self.event: Event | None = None
|
||||
|
||||
def params(self, **kwargs: P) -> "Caller":
|
||||
"""设置多个函数参数
|
||||
Args:
|
||||
**kwargs: 参数字典
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
"""
|
||||
def params(self, **kwargs: Any) -> "Caller":
|
||||
self._parameters.update(kwargs)
|
||||
return self
|
||||
|
||||
def param(self, name: str, param: P) -> "Caller":
|
||||
def param(self, name: str, param: Any) -> "Caller":
|
||||
"""设置一个函数参数
|
||||
|
||||
Args:
|
||||
name (str): 参数名
|
||||
param (P): 参数对象
|
||||
param (Any): 参数对象
|
||||
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
@ -51,14 +49,6 @@ class Caller(Generic[P]):
|
||||
return self
|
||||
|
||||
def description(self, description: str) -> "Caller":
|
||||
"""设置函数描述
|
||||
|
||||
Args:
|
||||
description (str): 函数描述
|
||||
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
"""
|
||||
self._description = description
|
||||
return self
|
||||
|
||||
@ -71,12 +61,78 @@ class Caller(Generic[P]):
|
||||
Returns:
|
||||
F: 函数对象
|
||||
"""
|
||||
global _caller_data
|
||||
if self._name is None:
|
||||
if module := inspect.getmodule(func):
|
||||
module_name = module.__name__.split(".")[-1]
|
||||
else:
|
||||
module_name = "global"
|
||||
self._name = f"{module_name}-{func.__name__}"
|
||||
_caller_data[self._name] = self
|
||||
|
||||
if is_coroutine_callable(func):
|
||||
self.func = func # type: ignore
|
||||
else:
|
||||
self.func = async_wrap(func) # type: ignore
|
||||
|
||||
if module := inspect.getmodule(func):
|
||||
module_name = module.__name__ + "."
|
||||
else:
|
||||
module_name = ""
|
||||
logger.opt(colors=True).info(
|
||||
f"<y>加载函数 {func.__name__} {self._description}</y>"
|
||||
f"<y>加载函数 {module_name}{func.__name__}: {self._description}</y>"
|
||||
)
|
||||
self.func = func
|
||||
|
||||
return func
|
||||
|
||||
def data(self) -> dict[str, Any]:
|
||||
"""返回函数的json数据
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 函数的json数据
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self._name,
|
||||
"description": self._description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
key: value.data() for key, value in self._parameters.items()
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
key
|
||||
for key, value in self._parameters.items()
|
||||
if value.default is None
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
def set_event(self, event: Event):
|
||||
self.event = event
|
||||
|
||||
async def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""调用函数
|
||||
|
||||
Returns:
|
||||
Any: 函数返回值
|
||||
"""
|
||||
if self.func is None:
|
||||
raise ValueError("未注册函数对象")
|
||||
sig = inspect.signature(self.func)
|
||||
for name, param in sig.parameters.items():
|
||||
if issubclass(param.annotation, Event) or isinstance(
|
||||
param.annotation, Event
|
||||
):
|
||||
kwargs[name] = self.event
|
||||
if issubclass(param.annotation, Caller) or isinstance(
|
||||
param.annotation, Caller
|
||||
):
|
||||
kwargs[name] = self
|
||||
return await self.func(*args, **kwargs)
|
||||
|
||||
|
||||
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
|
||||
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
|
||||
@ -87,5 +143,14 @@ def on_function_call(name: str | None = None, description: str | None = None) ->
|
||||
Returns:
|
||||
Caller: Caller对象
|
||||
"""
|
||||
caller = Caller(name=name, description=description)
|
||||
return caller
|
||||
|
||||
return Caller(name=name, description=description)
|
||||
|
||||
def get_function_calls() -> dict[str, Caller]:
|
||||
"""获取所有已注册的function call函数
|
||||
|
||||
Returns:
|
||||
dict[str, Caller]: 所有已注册的function call函数
|
||||
"""
|
||||
return _caller_data
|
||||
|
52
nonebot_plugin_marshoai/plugin/func_call/utils.py
Normal file
52
nonebot_plugin_marshoai/plugin/func_call/utils.py
Normal file
@ -0,0 +1,52 @@
|
||||
import inspect
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from ..typing import F
|
||||
|
||||
|
||||
def copy_signature(func: F) -> Callable[[Callable[..., Any]], F]:
|
||||
"""复制函数签名和文档字符串的装饰器"""
|
||||
|
||||
def decorator(wrapper: Callable[..., Any]) -> F:
|
||||
@wraps(func)
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
return wrapper(*args, **kwargs)
|
||||
|
||||
return wrapped # type: ignore
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def async_wrap(func: F) -> F:
|
||||
"""装饰器,将同步函数包装为异步函数
|
||||
|
||||
Args:
|
||||
func (F): 函数对象
|
||||
|
||||
Returns:
|
||||
F: 包装后的函数对象
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
"""
|
||||
判断是否为async def 函数
|
||||
请注意:是否为 async def 函数与该函数是否能被await调用是两个不同的概念,具体取决于函数返回值是否为awaitable对象
|
||||
Args:
|
||||
call: 可调用对象
|
||||
Returns:
|
||||
bool: 是否为async def函数
|
||||
"""
|
||||
if inspect.isroutine(call):
|
||||
return inspect.iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
func_ = getattr(call, "__call__", None)
|
||||
return inspect.iscoroutinefunction(func_)
|
@ -1,5 +1,7 @@
|
||||
from typing import Any, Callable, Coroutine, TypeAlias
|
||||
from typing import Any, Callable, Coroutine, TypeAlias, TypeVar
|
||||
|
||||
SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str]
|
||||
ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]]
|
||||
FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC
|
||||
|
||||
F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
|
||||
|
@ -18,21 +18,5 @@ def path_to_module_name(path: Path) -> str:
|
||||
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
|
||||
|
||||
|
||||
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
|
||||
"""
|
||||
判断是否为async def 函数
|
||||
Args:
|
||||
call: 可调用对象
|
||||
Returns:
|
||||
bool: 是否为协程可调用对象
|
||||
"""
|
||||
if inspect.isroutine(call):
|
||||
return inspect.iscoroutinefunction(call)
|
||||
if inspect.isclass(call):
|
||||
return False
|
||||
func_ = getattr(call, "__call__", None)
|
||||
return inspect.iscoroutinefunction(func_)
|
||||
|
||||
|
||||
def parse_function_docsring():
|
||||
pass
|
||||
|
@ -1,3 +1,5 @@
|
||||
from nonebot.adapters.onebot.v11 import MessageEvent
|
||||
|
||||
from nonebot_plugin_marshoai.plugin import (
|
||||
Integer,
|
||||
Parameter,
|
||||
@ -5,6 +7,7 @@ from nonebot_plugin_marshoai.plugin import (
|
||||
String,
|
||||
on_function_call,
|
||||
)
|
||||
from nonebot_plugin_marshoai.plugin.func_call.caller import Caller
|
||||
|
||||
__marsho_meta__ = PluginMetadata(
|
||||
name="SnowyKami 测试插件",
|
||||
@ -19,16 +22,7 @@ __marsho_meta__ = PluginMetadata(
|
||||
gender=String(enum=["男", "女"], description="性别"),
|
||||
)
|
||||
async def fortune_telling(age: int, name: str, gender: str) -> str:
|
||||
"""使用姓名,年龄,性别进行算命
|
||||
|
||||
Args:
|
||||
age (int): _description_
|
||||
name (str): _description_
|
||||
gender (str): _description_
|
||||
|
||||
Returns:
|
||||
str: _description_
|
||||
"""
|
||||
"""使用姓名,年龄,性别进行算命"""
|
||||
|
||||
# 进行一系列算命操作...
|
||||
|
||||
@ -41,17 +35,22 @@ async def fortune_telling(age: int, name: str, gender: str) -> str:
|
||||
unit=String(enum=["摄氏度", "华氏度"], description="温度单位"),
|
||||
)
|
||||
async def get_weather(location: str, days: int, unit: str) -> str:
|
||||
"""获取一个地点未来一段时间的天气
|
||||
|
||||
Args:
|
||||
location (str): 地点名称,可以是城市名、地区名等
|
||||
days (int): 天数
|
||||
unit (str): 温度单位
|
||||
|
||||
Returns:
|
||||
str: 天气信息
|
||||
"""
|
||||
"""获取一个地点未来一段时间的天气"""
|
||||
|
||||
# 进行一系列获取天气操作...
|
||||
|
||||
return f"{location}未来{days}天的天气信息..."
|
||||
|
||||
|
||||
@on_function_call(description="获取设备物理地理位置")
|
||||
async def get_location() -> str:
|
||||
"""获取设备物理地理位置"""
|
||||
|
||||
# 进行一系列获取地理位置操作...
|
||||
|
||||
return "日本 东京都 世田谷区"
|
||||
|
||||
|
||||
@on_function_call(description="获取聊天者个人信息")
|
||||
async def get_user_info(e: MessageEvent, c: Caller) -> str:
|
||||
return f"用户信息:{e.user_id} {e.sender.nickname}, {c._parameters}"
|
||||
|
@ -2,7 +2,7 @@
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "marshoai-memory__write_memory",
|
||||
"name": "marshoai_memory__write_memory",
|
||||
"description": "当你想记住有关与你对话的人的一些信息的时候,调用此函数。",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
Loading…
x
Reference in New Issue
Block a user