diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py
index 957b698a..6f972352 100755
--- a/nonebot_plugin_marshoai/azure.py
+++ b/nonebot_plugin_marshoai/azure.py
@@ -25,6 +25,7 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from .metadata import metadata
from .models import MarshoContext, MarshoTools
+from .plugin import _plugins, load_plugins
from .util import *
@@ -85,6 +86,7 @@ target_list = [] # 记录需保存历史上下文的列表
@driver.on_startup
async def _preload_tools():
+ """启动钩子加载工具"""
tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools:
@@ -98,6 +100,15 @@ async def _preload_tools():
)
+@driver.on_startup
+async def _preload_plugins():
+ """启动钩子加载插件"""
+ marshoai_plugin_dirs = config.marshoai_plugin_dirs
+ marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
+ load_plugins(*marshoai_plugin_dirs)
+ logger.opt(colors=True).info(f"已加载 {len(_plugins)} 个小棉插件")
+
+
@add_usermsg_cmd.handle()
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
if msg := arg.extract_plain_text():
diff --git a/nonebot_plugin_marshoai/config.py b/nonebot_plugin_marshoai/config.py
index db4743c7..5516640f 100755
--- a/nonebot_plugin_marshoai/config.py
+++ b/nonebot_plugin_marshoai/config.py
@@ -48,6 +48,8 @@ class ConfigModel(BaseModel):
marshoai_tencent_secretid: str | None = None
marshoai_tencent_secretkey: str | None = None
+ marshoai_plugin_dirs: list[str] = []
+
yaml = YAML()
diff --git a/nonebot_plugin_marshoai/deal_latex.py b/nonebot_plugin_marshoai/deal_latex.py
index 5eeec9e8..2b8ba91a 100755
--- a/nonebot_plugin_marshoai/deal_latex.py
+++ b/nonebot_plugin_marshoai/deal_latex.py
@@ -14,6 +14,7 @@ MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
+import asyncio
import time
from typing import Literal, Optional, Tuple
@@ -35,7 +36,7 @@ class ConvertChannel:
return False, "请勿直接调用母类"
@staticmethod
- def channel_test() -> int:
+ async def channel_test() -> int:
return -1
@@ -90,21 +91,23 @@ class L2PChannel(ConvertChannel):
return False, "未知错误"
@staticmethod
- def channel_test() -> int:
- with httpx.Client(timeout=5, verify=False) as client:
+ async def channel_test() -> int:
+ async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
latex2png = (
- client.get(
+ await client.get(
"http://www.latex2png.com{}"
- + client.post(
- "http://www.latex2png.com/api/convert",
- json={
- "auth": {"user": "guest", "password": "guest"},
- "latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n",
- "resolution": 600,
- "color": "000000",
- },
+ + (
+ await client.post(
+ "http://www.latex2png.com/api/convert",
+ json={
+ "auth": {"user": "guest", "password": "guest"},
+ "latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n",
+ "resolution": 600,
+ "color": "000000",
+ },
+ )
).json()["url"]
),
time.time_ns() - start_time,
@@ -156,12 +159,12 @@ class CDCChannel(ConvertChannel):
return False, "未知错误"
@staticmethod
- def channel_test() -> int:
- with httpx.Client(timeout=5, verify=False) as client:
+ async def channel_test() -> int:
+ async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
codecogs = (
- client.get(
+ await client.get(
r"https://latex.codecogs.com/png.image?\huge%20\dpi{600}\\int_{a}^{b}x^2\\,dx=\\frac{b^3}{3}-\\frac{a^3}{5}"
),
time.time_ns() - start_time,
@@ -223,19 +226,21 @@ class JRTChannel(ConvertChannel):
return False, "未知错误"
@staticmethod
- def channel_test() -> int:
- with httpx.Client(timeout=5, verify=False) as client:
+ async def channel_test() -> int:
+ async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
joeraut = (
- client.get(
- client.post(
- "http://www.latex2png.com/api/convert",
- json={
- "latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
- "outputFormat": "PNG",
- "outputScale": "1000%",
- },
+ await client.get(
+ (
+ await client.post(
+ "http://www.latex2png.com/api/convert",
+ json={
+ "latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
+ "outputFormat": "PNG",
+ "outputScale": "1000%",
+ },
+ )
).json()["imageUrl"]
),
time.time_ns() - start_time,
@@ -255,11 +260,14 @@ class ConvertLatex:
channel: ConvertChannel
- def __init__(self, channel: Optional[ConvertChannel] = None) -> None:
+ def __init__(self, channel: Optional[ConvertChannel] = None):
+ logger.info("LaTeX 转换服务将在 Bot 连接时异步加载")
+ async def load_channel(self, channel: ConvertChannel | None = None) -> None:
if channel is None:
logger.info("正在选择 LaTeX 转换服务频道,请稍等...")
- self.channel = self.auto_choose_channel()
+ self.channel = await self.auto_choose_channel()
+ logger.info(f"已选择 {self.channel.__class__.__name__} 服务频道")
else:
self.channel = channel
@@ -297,9 +305,15 @@ class ConvertLatex:
)
@staticmethod
- def auto_choose_channel() -> ConvertChannel:
+ async def auto_choose_channel() -> ConvertChannel:
+ async def channel_test_wrapper(
+ channel: type[ConvertChannel],
+ ) -> Tuple[int, type[ConvertChannel]]:
+ score = await channel.channel_test()
+ return score, channel
- return min(
- channel_list,
- key=lambda channel: channel.channel_test(),
- )()
+ results = await asyncio.gather(
+ *(channel_test_wrapper(channel) for channel in channel_list)
+ )
+ best_channel = min(results, key=lambda x: x[0])[1]
+ return best_channel()
diff --git a/nonebot_plugin_marshoai/plugin/__init__.py b/nonebot_plugin_marshoai/plugin/__init__.py
new file mode 100755
index 00000000..315134e9
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugin/__init__.py
@@ -0,0 +1,7 @@
+"""该功能目前正在开发中,暂时不可用,受影响的文件夹 `plugin`, `plugins`
+"""
+
+from .load import *
+from .models import *
+from .register import *
+from .utils import *
diff --git a/nonebot_plugin_marshoai/tool/load.py b/nonebot_plugin_marshoai/plugin/load.py
similarity index 74%
rename from nonebot_plugin_marshoai/tool/load.py
rename to nonebot_plugin_marshoai/plugin/load.py
index 9dc8e502..7f2a5d5f 100755
--- a/nonebot_plugin_marshoai/tool/load.py
+++ b/nonebot_plugin_marshoai/plugin/load.py
@@ -23,6 +23,26 @@ __all__ = [
]
+def get_plugin(name: str) -> Plugin | None:
+ """获取插件对象
+
+ Args:
+ name: 插件名称
+ Returns:
+ Optional[Plugin]: 插件对象
+ """
+ return _plugins.get(name)
+
+
+def get_plugins() -> dict[str, Plugin]:
+ """获取所有插件
+
+ Returns:
+ dict[str, Plugin]: 插件集合
+ """
+ return _plugins
+
+
def load_plugin(module_path: str | Path) -> Optional[Plugin]:
"""加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。
该函数产生的副作用在于将插件加载到 `_plugins` 中。
@@ -45,20 +65,23 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]:
module=module,
module_name=module_path,
)
+ _plugins[plugin.name] = plugin
plugin.metadata = getattr(module, "__marsho_meta__", None)
- _plugins[plugin.name] = plugin
+ if plugin.metadata is None:
+ logger.opt(colors=True).warning(
+ f"成功加载小棉插件 {plugin.name}, 但是没有定义元数据"
+ )
+ else:
+ logger.opt(colors=True).success(
+ f'成功加载小棉插件 "{plugin.metadata.name}"'
+ )
- logger.opt(colors=True).success(
- f'Succeeded to load liteyuki plugin "{plugin.name}"'
- )
- return _plugins[module.__name__]
+ return plugin
except Exception as e:
- logger.opt(colors=True).success(
- f'Failed to load liteyuki plugin "{module_path}"'
- )
+ logger.opt(colors=True).success(f'加载小棉插件失败 "{module_path}"')
traceback.print_exc()
return None
diff --git a/nonebot_plugin_marshoai/tool/models.py b/nonebot_plugin_marshoai/plugin/models.py
similarity index 82%
rename from nonebot_plugin_marshoai/tool/models.py
rename to nonebot_plugin_marshoai/plugin/models.py
index c553b648..c0e0ca81 100644
--- a/nonebot_plugin_marshoai/tool/models.py
+++ b/nonebot_plugin_marshoai/plugin/models.py
@@ -4,32 +4,6 @@ from typing import Any
from pydantic import BaseModel
-class Plugin(BaseModel):
- """
- 存储插件信息
-
- Attributes:
- ----------
- name: str
- 包名称 例如marsho_test
- module: ModuleType
- 插件模块对象
- module_name: str
- 点分割模块路径 例如a.b.c
- metadata: "PluginMeta" | None
- 元
- """
-
- name: str
- """包名称 例如marsho_test"""
- module: ModuleType
- """插件模块对象"""
- module_name: str
- """点分割模块路径 例如a.b.c"""
- metadata: "PluginMetadata" | None = None
- """元"""
-
-
class PluginMetadata(BaseModel):
"""
Marsho 插件 对象元数据
@@ -58,3 +32,38 @@ class PluginMetadata(BaseModel):
author: str = ""
homepage: str = ""
extra: dict[str, Any] = {}
+
+
+class Plugin(BaseModel):
+ """
+ 存储插件信息
+
+ Attributes:
+ ----------
+ name: str
+ 包名称 例如marsho_test
+ module: ModuleType
+ 插件模块对象
+ module_name: str
+ 点分割模块路径 例如a.b.c
+ metadata: "PluginMeta" | None
+ 元
+ """
+
+ name: str
+ """包名称 例如marsho_test"""
+ module: ModuleType
+ """插件模块对象"""
+ module_name: str
+ """点分割模块路径 例如a.b.c"""
+ metadata: PluginMetadata | None = None
+ """元"""
+
+ class Config:
+ arbitrary_types_allowed = True
+
+ def __hash__(self) -> int:
+ return hash(self.name)
+
+ def __eq__(self, other: Any) -> bool:
+ return self.name == other.name
diff --git a/nonebot_plugin_marshoai/plugin/register.py b/nonebot_plugin_marshoai/plugin/register.py
new file mode 100644
index 00000000..609bf3ce
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugin/register.py
@@ -0,0 +1,55 @@
+"""此模块用于获取function call中函数定义信息以及注册函数
+"""
+
+import inspect
+from typing import Any, Callable, Coroutine, TypeAlias
+
+import nonebot
+
+from .utils import is_coroutine_callable
+
+SYNC_FUNCTION_CALL: TypeAlias = Callable[..., str]
+ASYNC_FUNCTION_CALL: TypeAlias = Callable[..., Coroutine[str, Any, str]]
+FUNCTION_CALL: TypeAlias = SYNC_FUNCTION_CALL | ASYNC_FUNCTION_CALL
+
+_loaded_functions: dict[str, FUNCTION_CALL] = {}
+
+
+def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL:
+ """将同步函数包装为异步函数,但是不会真正异步执行,仅用于统一调用及函数签名
+
+ Args:
+ func: 同步函数
+
+ Returns:
+ ASYNC_FUNCTION_CALL: 异步函数
+ """
+
+ async def wrapper(*args, **kwargs) -> str:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def function_call(*funcs: FUNCTION_CALL):
+ """返回一个装饰器,装饰一个函数, 使其注册为一个可被AI调用的function call函数
+
+ Args:
+ func: 函数对象,要有完整的 Google Style Docstring
+
+ Returns:
+ str: 函数定义信息
+ """
+ for func in funcs:
+ if module := inspect.getmodule(func):
+ module_name = module.__name__ + "."
+ else:
+ module_name = ""
+ name = func.__name__
+ if not is_coroutine_callable(func):
+ func = async_wrapper(func) # type: ignore
+
+ _loaded_functions[name] = func
+ nonebot.logger.opt(colors=True).info(
+ f"加载 function call: {module_name}{name}"
+ )
diff --git a/nonebot_plugin_marshoai/plugin/utils.py b/nonebot_plugin_marshoai/plugin/utils.py
new file mode 100644
index 00000000..030c4fea
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugin/utils.py
@@ -0,0 +1,34 @@
+import inspect
+from pathlib import Path
+from typing import Any, Callable
+
+
+def path_to_module_name(path: Path) -> str:
+ """
+ 转换路径为模块名
+ Args:
+ path: 路径a/b/c/d -> a.b.c.d
+ Returns:
+ str: 模块名
+ """
+ rel_path = path.resolve().relative_to(Path.cwd().resolve())
+ if rel_path.stem == "__init__":
+ return ".".join(rel_path.parts[:-1])
+ else:
+ return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
+
+
+def is_coroutine_callable(call: Callable[..., Any]) -> bool:
+ """
+ 判断是否为async def 函数
+ Args:
+ call: 可调用对象
+ Returns:
+ bool: 是否为协程可调用对象
+ """
+ if inspect.isroutine(call):
+ return inspect.iscoroutinefunction(call)
+ if inspect.isclass(call):
+ return False
+ func_ = getattr(call, "__call__", None)
+ return inspect.iscoroutinefunction(func_)
diff --git a/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py
new file mode 100755
index 00000000..96d3194a
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py
@@ -0,0 +1,54 @@
+import traceback
+
+import httpx
+
+from nonebot_plugin_marshoai.plugin import PluginMetadata, function_call
+
+__marsho_meta__ = PluginMetadata(
+ name="Bangumi 番剧信息",
+ description="Bangumi 番剧信息",
+ usage="Bangumi 番剧信息",
+ author="Liteyuki",
+ homepage="",
+)
+
+
+async def fetch_calendar():
+ url = "https://api.bgm.tv/calendar"
+ headers = {
+ "User-Agent": "LiteyukiStudio/nonebot-plugin-marshoai (https://github.com/LiteyukiStudio/nonebot-plugin-marshoai)"
+ }
+ async with httpx.AsyncClient() as client:
+ response = await client.get(url, headers=headers)
+ # print(response.text)
+ return response.json()
+
+
+@function_call
+async def get_bangumi_news() -> str:
+ """获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。
+
+ Returns:
+ _type_: _description_
+ """
+ result = await fetch_calendar()
+ info = ""
+ try:
+ for i in result:
+ weekday = i["weekday"]["cn"]
+ # print(weekday)
+ info += f"{weekday}:"
+ items = i["items"]
+ for item in items:
+ name = item["name_cn"]
+ info += f"《{name}》"
+ info += "\n"
+ return info
+ except Exception as e:
+ traceback.print_exc()
+ return ""
+
+
+@function_call
+def test_sync() -> str:
+ return "sync"
diff --git a/nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json
new file mode 100755
index 00000000..a814f53a
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json
@@ -0,0 +1,9 @@
+[
+ {
+ "type": "function",
+ "function": {
+ "name": "marshoai-bangumi__get_bangumi_news",
+ "description": "获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。"
+ }
+ }
+]
diff --git a/nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py b/nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py
new file mode 100755
index 00000000..a76a3333
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py
@@ -0,0 +1,24 @@
+import os
+
+from zhDateTime import DateTime
+
+
+async def get_weather(location: str):
+ return f"{location}的温度是114514℃。"
+
+
+async def get_current_env():
+ ver = os.popen("uname -a").read()
+ return str(ver)
+
+
+async def get_current_time():
+ current_time = DateTime.now().strftime("%Y.%m.%d %H:%M:%S")
+ current_weekday = DateTime.now().weekday()
+
+ weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]
+ current_weekday_name = weekdays[current_weekday]
+
+ current_lunar_date = DateTime.now().to_lunar().date_hanzify()[5:]
+ time_prompt = f"现在的时间是{current_time},{current_weekday_name},农历{current_lunar_date}。"
+ return time_prompt
diff --git a/nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json
new file mode 100755
index 00000000..47b477ed
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json
@@ -0,0 +1,9 @@
+[
+ {
+ "type": "function",
+ "function": {
+ "name": "marshoai-basic__get_current_time",
+ "description": "获取现在的日期,时间和星期。"
+ }
+ }
+]
diff --git a/nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json
new file mode 100755
index 00000000..833ef7cc
--- /dev/null
+++ b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json
@@ -0,0 +1,39 @@
+[
+ {
+ "type": "function",
+ "function": {
+ "name": "marshoai-basic__get_weather",
+ "description": "当你想查询指定城市的天气时非常有用。",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "城市或县区,比如北京市、杭州市、余杭区等。"
+ }
+ }
+ },
+ "required": [
+ "location"
+ ]
+ }
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "marshoai-basic__get_current_env",
+ "description": "获取当前的运行环境。",
+ "parameters": {
+ }
+ }
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "marshoai-basic__get_current_time",
+ "description": "获取现在的时间。",
+ "parameters": {
+ }
+ }
+ }
+]
diff --git a/nonebot_plugin_marshoai/tool/__init__.py b/nonebot_plugin_marshoai/tool/__init__.py
deleted file mode 100755
index e69de29b..00000000
diff --git a/nonebot_plugin_marshoai/tool/utils.py b/nonebot_plugin_marshoai/tool/utils.py
deleted file mode 100644
index ca63a25c..00000000
--- a/nonebot_plugin_marshoai/tool/utils.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from pathlib import Path
-
-
-def path_to_module_name(path: Path) -> str:
- """
- 转换路径为模块名
- Args:
- path: 路径a/b/c/d -> a.b.c.d
- Returns:
- str: 模块名
- """
- rel_path = path.resolve().relative_to(Path.cwd().resolve())
- if rel_path.stem == "__init__":
- return ".".join(rel_path.parts[:-1])
- else:
- return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py
index c1b89778..9cda5e3e 100755
--- a/nonebot_plugin_marshoai/util.py
+++ b/nonebot_plugin_marshoai/util.py
@@ -11,6 +11,7 @@ import nonebot_plugin_localstore as store
# from zhDateTime import DateTime
from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage
+from nonebot import get_driver
from nonebot.log import logger
from nonebot_plugin_alconna import Image as ImageMsg
from nonebot_plugin_alconna import Text as TextMsg
@@ -280,6 +281,10 @@ if config.marshoai_enable_richtext_parse:
latex_convert = ConvertLatex() # 开启一个转换实例
+ @get_driver().on_bot_connect
+ async def load_latex_convert():
+ await latex_convert.load_channel(None)
+
async def get_uuid_back2codeblock(
msg: str, code_blank_uuid_map: list[tuple[str, str]]
):