diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py
index 6f972352..65952e9b 100755
--- a/nonebot_plugin_marshoai/azure.py
+++ b/nonebot_plugin_marshoai/azure.py
@@ -23,6 +23,8 @@ from nonebot.permission import SUPERUSER
from nonebot.rule import Rule, to_me
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
+from nonebot_plugin_marshoai.plugin.func_call.caller import get_function_calls
+
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
@@ -103,10 +105,9 @@ async def _preload_tools():
@driver.on_startup
async def _preload_plugins():
"""启动钩子加载插件"""
- marshoai_plugin_dirs = config.marshoai_plugin_dirs
- marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
+ marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
+ marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") # 预置插件目录
load_plugins(*marshoai_plugin_dirs)
- logger.opt(colors=True).info(f"已加载 {len(_plugins)} 个小棉插件")
@add_usermsg_cmd.handle()
@@ -266,7 +267,10 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
client=client,
model_name=model_name,
msg=context_msg + [UserMessage(content=usermsg)], # type: ignore
- tools=tools.get_tools_list(),
+ tools=tools.get_tools_list()
+ + list(
+ map(lambda v: v.data(), get_function_calls().values())
+ ), # TODO 临时追加函数,后期优化
)
# await UniMessage(str(response)).send()
choice = response.choices[0]
@@ -315,9 +319,23 @@ async def marsho(target: MsgTarget, event: Event, text: Optional[UniMsg] = None)
await UniMessage(
f"调用函数 {tool_call.function.name} ,参数为 {function_args}"
).send()
- func_return = await tools.call(
- tool_call.function.name, function_args
- ) # 获取返回值
+ # TODO 临时追加插件函数,若工具中没有则调用插件函数
+ if tools.has_function(tool_call.function.name):
+ logger.debug(f"调用工具函数 {tool_call.function.name}")
+ func_return = await tools.call(
+ tool_call.function.name, function_args
+ ) # 获取返回值
+ else:
+ if caller := get_function_calls().get(
+ tool_call.function.name
+ ):
+ logger.debug(f"调用插件函数 {tool_call.function.name}")
+ # 实现依赖注入,检查函数参数及参数注解类型,对Event类型的参数进行注入
+ caller.event = event
+ func_return = await caller.call(**function_args)
+ else:
+ logger.error(f"未找到函数 {tool_call.function.name}")
+ func_return = f"未找到函数 {tool_call.function.name}"
tool_msg.append(
ToolMessage(tool_call_id=tool_call.id, content=func_return) # type: ignore
)
diff --git a/nonebot_plugin_marshoai/models.py b/nonebot_plugin_marshoai/models.py
index 160b743d..04ae473e 100755
--- a/nonebot_plugin_marshoai/models.py
+++ b/nonebot_plugin_marshoai/models.py
@@ -90,6 +90,7 @@ class MarshoTools:
with open(json_path, "r", encoding="utf-8") as json_file:
data = json.load(json_file)
for i in data:
+
self.tools_list.append(i)
spec = importlib.util.spec_from_file_location(
@@ -136,6 +137,21 @@ class MarshoTools:
else:
logger.error(f"工具包 '{package_name}' 未导入")
+ def has_function(self, full_function_name: str) -> bool:
+ """
+ 检查是否存在指定的函数
+ """
+ try:
+ for t in self.tools_list:
+ if t["function"]["name"].replace(
+ "-", "_"
+ ) == full_function_name.replace("-", "_"):
+ return True
+ return False
+ except Exception as e:
+ logger.error(f"检查函数 '{full_function_name}' 时发生错误:{e}")
+ return False
+
def get_tools_list(self):
if not self.tools_list or not config.marshoai_enable_tools:
return None
diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py
index cc8fd767..19181662 100644
--- a/nonebot_plugin_marshoai/plugin/func_call/caller.py
+++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py
@@ -1,36 +1,34 @@
-from typing import Generic, TypeVar
+import inspect
+from typing import Any
from nonebot import logger
+from nonebot.adapters import Event
-from ..typing import FUNCTION_CALL_FUNC
-from .params import P
+from ..typing import ASYNC_FUNCTION_CALL_FUNC, F
+from .utils import async_wrap, is_coroutine_callable
-F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
+_caller_data: dict[str, "Caller"] = {}
-class Caller(Generic[P]):
+class Caller:
def __init__(self, name: str | None = None, description: str | None = None):
self._name = name
self._description = description
- self._parameters: dict[str, P] = {}
- self.func: FUNCTION_CALL_FUNC | None = None
+ self.func: ASYNC_FUNCTION_CALL_FUNC | None = None
+ self._parameters: dict[str, Any] = {}
+ """依赖注入的参数"""
+ self.event: Event | None = None
- def params(self, **kwargs: P) -> "Caller":
- """设置多个函数参数
- Args:
- **kwargs: 参数字典
- Returns:
- Caller: Caller对象
- """
+ def params(self, **kwargs: Any) -> "Caller":
self._parameters.update(kwargs)
return self
- def param(self, name: str, param: P) -> "Caller":
+ def param(self, name: str, param: Any) -> "Caller":
"""设置一个函数参数
Args:
name (str): 参数名
- param (P): 参数对象
+ param (Any): 参数对象
Returns:
Caller: Caller对象
@@ -51,14 +49,6 @@ class Caller(Generic[P]):
return self
def description(self, description: str) -> "Caller":
- """设置函数描述
-
- Args:
- description (str): 函数描述
-
- Returns:
- Caller: Caller对象
- """
self._description = description
return self
@@ -71,12 +61,78 @@ class Caller(Generic[P]):
Returns:
F: 函数对象
"""
+ global _caller_data
+ if self._name is None:
+ if module := inspect.getmodule(func):
+ module_name = module.__name__.split(".")[-1]
+ else:
+ module_name = "global"
+ self._name = f"{module_name}-{func.__name__}"
+ _caller_data[self._name] = self
+
+ if is_coroutine_callable(func):
+ self.func = func # type: ignore
+ else:
+ self.func = async_wrap(func) # type: ignore
+
+ if module := inspect.getmodule(func):
+ module_name = module.__name__ + "."
+ else:
+ module_name = ""
logger.opt(colors=True).info(
- f"加载函数 {func.__name__} {self._description}"
+ f"加载函数 {module_name}{func.__name__}: {self._description}"
)
- self.func = func
+
return func
+ def data(self) -> dict[str, Any]:
+ """返回函数的json数据
+
+ Returns:
+ dict[str, Any]: 函数的json数据
+ """
+ return {
+ "type": "function",
+ "function": {
+ "name": self._name,
+ "description": self._description,
+ "parameters": {
+ "type": "object",
+ "properties": {
+ key: value.data() for key, value in self._parameters.items()
+ },
+ },
+ "required": [
+ key
+ for key, value in self._parameters.items()
+ if value.default is None
+ ],
+ },
+ }
+
+ def set_event(self, event: Event):
+ self.event = event
+
+ async def call(self, *args: Any, **kwargs: Any) -> Any:
+ """调用函数
+
+ Returns:
+ Any: 函数返回值
+ """
+ if self.func is None:
+ raise ValueError("未注册函数对象")
+ sig = inspect.signature(self.func)
+ for name, param in sig.parameters.items():
+ if issubclass(param.annotation, Event) or isinstance(
+ param.annotation, Event
+ ):
+ kwargs[name] = self.event
+ if issubclass(param.annotation, Caller) or isinstance(
+ param.annotation, Caller
+ ):
+ kwargs[name] = self
+ return await self.func(*args, **kwargs)
+
def on_function_call(name: str | None = None, description: str | None = None) -> Caller:
"""返回一个Caller类,可用于装饰一个函数,使其注册为一个可被AI调用的function call函数
@@ -87,5 +143,14 @@ def on_function_call(name: str | None = None, description: str | None = None) ->
Returns:
Caller: Caller对象
"""
+ caller = Caller(name=name, description=description)
+ return caller
- return Caller(name=name, description=description)
+
+def get_function_calls() -> dict[str, Caller]:
+ """获取所有已注册的function call函数
+
+ Returns:
+ dict[str, Caller]: 所有已注册的function call函数
+ """
+ return _caller_data
diff --git a/nonebot_plugin_marshoai/plugin/func_call/utils.py b/nonebot_plugin_marshoai/plugin/func_call/utils.py
new file mode 100644
index 00000000..1ce908a8
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugin/func_call/utils.py
@@ -0,0 +1,52 @@
+import inspect
+from functools import wraps
+from typing import TYPE_CHECKING, Any, Callable
+
+from ..typing import F
+
+
+def copy_signature(func: F) -> Callable[[Callable[..., Any]], F]:
+ """复制函数签名和文档字符串的装饰器"""
+
+ def decorator(wrapper: Callable[..., Any]) -> F:
+ @wraps(func)
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
+ return wrapper(*args, **kwargs)
+
+ return wrapped # type: ignore
+
+ return decorator
+
+
+def async_wrap(func: F) -> F:
+ """装饰器,将同步函数包装为异步函数
+
+ Args:
+ func (F): 函数对象
+
+ Returns:
+ F: 包装后的函数对象
+ """
+
+ @wraps(func)
+ async def wrapper(*args: Any, **kwargs: Any) -> Any:
+ return func(*args, **kwargs)
+
+ return wrapper # type: ignore
+
+
+def is_coroutine_callable(call: Callable[..., Any]) -> bool:
+ """
+ 判断是否为async def 函数
+ 请注意:是否为 async def 函数与该函数是否能被await调用是两个不同的概念,具体取决于函数返回值是否为awaitable对象
+ Args:
+ call: 可调用对象
+ Returns:
+ bool: 是否为async def函数
+ """
+ if inspect.isroutine(call):
+ return inspect.iscoroutinefunction(call)
+ if inspect.isclass(call):
+ return False
+ func_ = getattr(call, "__call__", None)
+ return inspect.iscoroutinefunction(func_)
diff --git a/nonebot_plugin_marshoai/plugin/typing.py b/nonebot_plugin_marshoai/plugin/typing.py
index 1618dc26..54554e2a 100755
--- a/nonebot_plugin_marshoai/plugin/typing.py
+++ b/nonebot_plugin_marshoai/plugin/typing.py
@@ -1,5 +1,7 @@
-from typing import Any, Callable, Coroutine, TypeAlias
+from typing import Any, Callable, Coroutine, TypeAlias, TypeVar
SYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., str]
ASYNC_FUNCTION_CALL_FUNC: TypeAlias = Callable[..., Coroutine[str, Any, str]]
FUNCTION_CALL_FUNC: TypeAlias = SYNC_FUNCTION_CALL_FUNC | ASYNC_FUNCTION_CALL_FUNC
+
+F = TypeVar("F", bound=FUNCTION_CALL_FUNC)
diff --git a/nonebot_plugin_marshoai/plugin/utils.py b/nonebot_plugin_marshoai/plugin/utils.py
index 55dd7e1c..dc02c670 100755
--- a/nonebot_plugin_marshoai/plugin/utils.py
+++ b/nonebot_plugin_marshoai/plugin/utils.py
@@ -18,21 +18,5 @@ def path_to_module_name(path: Path) -> str:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
-def is_coroutine_callable(call: Callable[..., Any]) -> bool:
- """
- 判断是否为async def 函数
- Args:
- call: 可调用对象
- Returns:
- bool: 是否为协程可调用对象
- """
- if inspect.isroutine(call):
- return inspect.iscoroutinefunction(call)
- if inspect.isclass(call):
- return False
- func_ = getattr(call, "__call__", None)
- return inspect.iscoroutinefunction(func_)
-
-
def parse_function_docsring():
pass
diff --git a/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py b/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py
index 47837a0e..d851cb1d 100644
--- a/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py
+++ b/nonebot_plugin_marshoai/plugins/snowykami_testplugin/__init__.py
@@ -1,3 +1,5 @@
+from nonebot.adapters.onebot.v11 import MessageEvent
+
from nonebot_plugin_marshoai.plugin import (
Integer,
Parameter,
@@ -5,6 +7,7 @@ from nonebot_plugin_marshoai.plugin import (
String,
on_function_call,
)
+from nonebot_plugin_marshoai.plugin.func_call.caller import Caller
__marsho_meta__ = PluginMetadata(
name="SnowyKami 测试插件",
@@ -19,16 +22,7 @@ __marsho_meta__ = PluginMetadata(
gender=String(enum=["男", "女"], description="性别"),
)
async def fortune_telling(age: int, name: str, gender: str) -> str:
- """使用姓名,年龄,性别进行算命
-
- Args:
- age (int): _description_
- name (str): _description_
- gender (str): _description_
-
- Returns:
- str: _description_
- """
+ """使用姓名,年龄,性别进行算命"""
# 进行一系列算命操作...
@@ -41,17 +35,22 @@ async def fortune_telling(age: int, name: str, gender: str) -> str:
unit=String(enum=["摄氏度", "华氏度"], description="温度单位"),
)
async def get_weather(location: str, days: int, unit: str) -> str:
- """获取一个地点未来一段时间的天气
-
- Args:
- location (str): 地点名称,可以是城市名、地区名等
- days (int): 天数
- unit (str): 温度单位
-
- Returns:
- str: 天气信息
- """
+ """获取一个地点未来一段时间的天气"""
# 进行一系列获取天气操作...
return f"{location}未来{days}天的天气信息..."
+
+
+@on_function_call(description="获取设备物理地理位置")
+async def get_location() -> str:
+ """获取设备物理地理位置"""
+
+ # 进行一系列获取地理位置操作...
+
+ return "日本 东京都 世田谷区"
+
+
+@on_function_call(description="获取聊天者个人信息")
+async def get_user_info(e: MessageEvent, c: Caller) -> str:
+ return f"用户信息:{e.user_id} {e.sender.nickname}, {c._parameters}"
diff --git a/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json b/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json
index adab49a4..73e1c22b 100755
--- a/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json
+++ b/nonebot_plugin_marshoai/tools_wip/marshoai_memory/tools.json
@@ -2,7 +2,7 @@
{
"type": "function",
"function": {
- "name": "marshoai-memory__write_memory",
+ "name": "marshoai_memory__write_memory",
"description": "当你想记住有关与你对话的人的一些信息的时候,调用此函数。",
"parameters": {
"type": "object",