From a9938d30edd95b87a201f8bde07dba7e921fba3c Mon Sep 17 00:00:00 2001 From: Snowykami Date: Sat, 14 Dec 2024 04:43:03 +0800 Subject: [PATCH] :art: Apply black formatting --- nonebot_plugin_marshoai/azure.py | 11 +++ nonebot_plugin_marshoai/config.py | 2 + nonebot_plugin_marshoai/deal_latex.py | 78 +++++++++++-------- nonebot_plugin_marshoai/plugin/__init__.py | 7 ++ .../{tool => plugin}/load.py | 39 ++++++++-- .../{tool => plugin}/models.py | 61 ++++++++------- nonebot_plugin_marshoai/plugin/register.py | 55 +++++++++++++ nonebot_plugin_marshoai/plugin/utils.py | 34 ++++++++ .../plugins/marshoai_bangumi/__init__.py | 54 +++++++++++++ .../plugins/marshoai_bangumi/tools.json | 9 +++ .../plugins/marshoai_basic/__init__.py | 24 ++++++ .../plugins/marshoai_basic/tools.json | 9 +++ .../plugins/marshoai_basic/tools_test.json | 39 ++++++++++ nonebot_plugin_marshoai/tool/__init__.py | 0 nonebot_plugin_marshoai/tool/utils.py | 16 ---- nonebot_plugin_marshoai/util.py | 5 ++ 16 files changed, 361 insertions(+), 82 deletions(-) create mode 100755 nonebot_plugin_marshoai/plugin/__init__.py rename nonebot_plugin_marshoai/{tool => plugin}/load.py (74%) rename nonebot_plugin_marshoai/{tool => plugin}/models.py (82%) create mode 100644 nonebot_plugin_marshoai/plugin/register.py create mode 100644 nonebot_plugin_marshoai/plugin/utils.py create mode 100755 nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py create mode 100755 nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json create mode 100755 nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py create mode 100755 nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json create mode 100755 nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json delete mode 100755 nonebot_plugin_marshoai/tool/__init__.py delete mode 100644 nonebot_plugin_marshoai/tool/utils.py diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index 957b698a..6f972352 100755 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -25,6 +25,7 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna from .metadata import metadata from .models import MarshoContext, MarshoTools +from .plugin import _plugins, load_plugins from .util import * @@ -85,6 +86,7 @@ target_list = [] # 记录需保存历史上下文的列表 @driver.on_startup async def _preload_tools(): + """启动钩子加载工具""" tools_dir = store.get_plugin_data_dir() / "tools" os.makedirs(tools_dir, exist_ok=True) if config.marshoai_enable_tools: @@ -98,6 +100,15 @@ async def _preload_tools(): ) +@driver.on_startup +async def _preload_plugins(): + """启动钩子加载插件""" + marshoai_plugin_dirs = config.marshoai_plugin_dirs + marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins") + load_plugins(*marshoai_plugin_dirs) + logger.opt(colors=True).info(f"已加载 {len(_plugins)} 个小棉插件") + + @add_usermsg_cmd.handle() async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()): if msg := arg.extract_plain_text(): diff --git a/nonebot_plugin_marshoai/config.py b/nonebot_plugin_marshoai/config.py index db4743c7..5516640f 100755 --- a/nonebot_plugin_marshoai/config.py +++ b/nonebot_plugin_marshoai/config.py @@ -48,6 +48,8 @@ class ConfigModel(BaseModel): marshoai_tencent_secretid: str | None = None marshoai_tencent_secretkey: str | None = None + marshoai_plugin_dirs: list[str] = [] + yaml = YAML() diff --git a/nonebot_plugin_marshoai/deal_latex.py b/nonebot_plugin_marshoai/deal_latex.py index 5eeec9e8..2b8ba91a 100755 --- a/nonebot_plugin_marshoai/deal_latex.py +++ b/nonebot_plugin_marshoai/deal_latex.py @@ -14,6 +14,7 @@ MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. See the Mulan PSL v2 for more details. """ +import asyncio import time from typing import Literal, Optional, Tuple @@ -35,7 +36,7 @@ class ConvertChannel: return False, "请勿直接调用母类" @staticmethod - def channel_test() -> int: + async def channel_test() -> int: return -1 @@ -90,21 +91,23 @@ class L2PChannel(ConvertChannel): return False, "未知错误" @staticmethod - def channel_test() -> int: - with httpx.Client(timeout=5, verify=False) as client: + async def channel_test() -> int: + async with httpx.AsyncClient(timeout=5, verify=False) as client: try: start_time = time.time_ns() latex2png = ( - client.get( + await client.get( "http://www.latex2png.com{}" - + client.post( - "http://www.latex2png.com/api/convert", - json={ - "auth": {"user": "guest", "password": "guest"}, - "latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n", - "resolution": 600, - "color": "000000", - }, + + ( + await client.post( + "http://www.latex2png.com/api/convert", + json={ + "auth": {"user": "guest", "password": "guest"}, + "latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n", + "resolution": 600, + "color": "000000", + }, + ) ).json()["url"] ), time.time_ns() - start_time, @@ -156,12 +159,12 @@ class CDCChannel(ConvertChannel): return False, "未知错误" @staticmethod - def channel_test() -> int: - with httpx.Client(timeout=5, verify=False) as client: + async def channel_test() -> int: + async with httpx.AsyncClient(timeout=5, verify=False) as client: try: start_time = time.time_ns() codecogs = ( - client.get( + await client.get( r"https://latex.codecogs.com/png.image?\huge%20\dpi{600}\\int_{a}^{b}x^2\\,dx=\\frac{b^3}{3}-\\frac{a^3}{5}" ), time.time_ns() - start_time, @@ -223,19 +226,21 @@ class JRTChannel(ConvertChannel): return False, "未知错误" @staticmethod - def channel_test() -> int: - with httpx.Client(timeout=5, verify=False) as client: + async def channel_test() -> int: + async with httpx.AsyncClient(timeout=5, verify=False) as client: try: start_time = time.time_ns() joeraut = ( - client.get( - client.post( - "http://www.latex2png.com/api/convert", - json={ - "latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}", - "outputFormat": "PNG", - "outputScale": "1000%", - }, + await client.get( + ( + await client.post( + "http://www.latex2png.com/api/convert", + json={ + "latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}", + "outputFormat": "PNG", + "outputScale": "1000%", + }, + ) ).json()["imageUrl"] ), time.time_ns() - start_time, @@ -255,11 +260,14 @@ class ConvertLatex: channel: ConvertChannel - def __init__(self, channel: Optional[ConvertChannel] = None) -> None: + def __init__(self, channel: Optional[ConvertChannel] = None): + logger.info("LaTeX 转换服务将在 Bot 连接时异步加载") + async def load_channel(self, channel: ConvertChannel | None = None) -> None: if channel is None: logger.info("正在选择 LaTeX 转换服务频道,请稍等...") - self.channel = self.auto_choose_channel() + self.channel = await self.auto_choose_channel() + logger.info(f"已选择 {self.channel.__class__.__name__} 服务频道") else: self.channel = channel @@ -297,9 +305,15 @@ class ConvertLatex: ) @staticmethod - def auto_choose_channel() -> ConvertChannel: + async def auto_choose_channel() -> ConvertChannel: + async def channel_test_wrapper( + channel: type[ConvertChannel], + ) -> Tuple[int, type[ConvertChannel]]: + score = await channel.channel_test() + return score, channel - return min( - channel_list, - key=lambda channel: channel.channel_test(), - )() + results = await asyncio.gather( + *(channel_test_wrapper(channel) for channel in channel_list) + ) + best_channel = min(results, key=lambda x: x[0])[1] + return best_channel() diff --git a/nonebot_plugin_marshoai/plugin/__init__.py b/nonebot_plugin_marshoai/plugin/__init__.py new file mode 100755 index 00000000..315134e9 --- /dev/null +++ b/nonebot_plugin_marshoai/plugin/__init__.py @@ -0,0 +1,7 @@ +"""该功能目前正在开发中,暂时不可用,受影响的文件夹 `plugin`, `plugins` +""" + +from .load import * +from .models import * +from .register import * +from .utils import * diff --git a/nonebot_plugin_marshoai/tool/load.py b/nonebot_plugin_marshoai/plugin/load.py similarity index 74% rename from nonebot_plugin_marshoai/tool/load.py rename to nonebot_plugin_marshoai/plugin/load.py index 9dc8e502..7f2a5d5f 100755 --- a/nonebot_plugin_marshoai/tool/load.py +++ b/nonebot_plugin_marshoai/plugin/load.py @@ -23,6 +23,26 @@ __all__ = [ ] +def get_plugin(name: str) -> Plugin | None: + """获取插件对象 + + Args: + name: 插件名称 + Returns: + Optional[Plugin]: 插件对象 + """ + return _plugins.get(name) + + +def get_plugins() -> dict[str, Plugin]: + """获取所有插件 + + Returns: + dict[str, Plugin]: 插件集合 + """ + return _plugins + + def load_plugin(module_path: str | Path) -> Optional[Plugin]: """加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。 该函数产生的副作用在于将插件加载到 `_plugins` 中。 @@ -45,20 +65,23 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]: module=module, module_name=module_path, ) + _plugins[plugin.name] = plugin plugin.metadata = getattr(module, "__marsho_meta__", None) - _plugins[plugin.name] = plugin + if plugin.metadata is None: + logger.opt(colors=True).warning( + f"成功加载小棉插件 {plugin.name}, 但是没有定义元数据" + ) + else: + logger.opt(colors=True).success( + f'成功加载小棉插件 "{plugin.metadata.name}"' + ) - logger.opt(colors=True).success( - f'Succeeded to load liteyuki plugin "{plugin.name}"' - ) - return _plugins[module.__name__] + return plugin except Exception as e: - logger.opt(colors=True).success( - f'Failed to load liteyuki plugin "{module_path}"' - ) + logger.opt(colors=True).success(f'加载小棉插件失败 "{module_path}"') traceback.print_exc() return None diff --git a/nonebot_plugin_marshoai/tool/models.py b/nonebot_plugin_marshoai/plugin/models.py similarity index 82% rename from nonebot_plugin_marshoai/tool/models.py rename to nonebot_plugin_marshoai/plugin/models.py index c553b648..c0e0ca81 100644 --- a/nonebot_plugin_marshoai/tool/models.py +++ b/nonebot_plugin_marshoai/plugin/models.py @@ -4,32 +4,6 @@ from typing import Any from pydantic import BaseModel -class Plugin(BaseModel): - """ - 存储插件信息 - - Attributes: - ---------- - name: str - 包名称 例如marsho_test - module: ModuleType - 插件模块对象 - module_name: str - 点分割模块路径 例如a.b.c - metadata: "PluginMeta" | None - 元 - """ - - name: str - """包名称 例如marsho_test""" - module: ModuleType - """插件模块对象""" - module_name: str - """点分割模块路径 例如a.b.c""" - metadata: "PluginMetadata" | None = None - """元""" - - class PluginMetadata(BaseModel): """ Marsho 插件 对象元数据 @@ -58,3 +32,38 @@ class PluginMetadata(BaseModel): author: str = "" homepage: str = "" extra: dict[str, Any] = {} + + +class Plugin(BaseModel): + """ + 存储插件信息 + + Attributes: + ---------- + name: str + 包名称 例如marsho_test + module: ModuleType + 插件模块对象 + module_name: str + 点分割模块路径 例如a.b.c + metadata: "PluginMeta" | None + 元 + """ + + name: str + """包名称 例如marsho_test""" + module: ModuleType + """插件模块对象""" + module_name: str + """点分割模块路径 例如a.b.c""" + metadata: PluginMetadata | None = None + """元""" + + class Config: + arbitrary_types_allowed = True + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: Any) -> bool: + return self.name == other.name diff --git a/nonebot_plugin_marshoai/plugin/register.py b/nonebot_plugin_marshoai/plugin/register.py new file mode 100644 index 00000000..609bf3ce --- /dev/null +++ b/nonebot_plugin_marshoai/plugin/register.py @@ -0,0 +1,55 @@ +"""此模块用于获取function call中函数定义信息以及注册函数 +""" + +import inspect +from typing import Any, Callable, Coroutine, TypeAlias + +import nonebot + +from .utils import is_coroutine_callable + +SYNC_FUNCTION_CALL: TypeAlias = Callable[..., str] +ASYNC_FUNCTION_CALL: TypeAlias = Callable[..., Coroutine[str, Any, str]] +FUNCTION_CALL: TypeAlias = SYNC_FUNCTION_CALL | ASYNC_FUNCTION_CALL + +_loaded_functions: dict[str, FUNCTION_CALL] = {} + + +def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL: + """将同步函数包装为异步函数,但是不会真正异步执行,仅用于统一调用及函数签名 + + Args: + func: 同步函数 + + Returns: + ASYNC_FUNCTION_CALL: 异步函数 + """ + + async def wrapper(*args, **kwargs) -> str: + return func(*args, **kwargs) + + return wrapper + + +def function_call(*funcs: FUNCTION_CALL): + """返回一个装饰器,装饰一个函数, 使其注册为一个可被AI调用的function call函数 + + Args: + func: 函数对象,要有完整的 Google Style Docstring + + Returns: + str: 函数定义信息 + """ + for func in funcs: + if module := inspect.getmodule(func): + module_name = module.__name__ + "." + else: + module_name = "" + name = func.__name__ + if not is_coroutine_callable(func): + func = async_wrapper(func) # type: ignore + + _loaded_functions[name] = func + nonebot.logger.opt(colors=True).info( + f"加载 function call: {module_name}{name}" + ) diff --git a/nonebot_plugin_marshoai/plugin/utils.py b/nonebot_plugin_marshoai/plugin/utils.py new file mode 100644 index 00000000..030c4fea --- /dev/null +++ b/nonebot_plugin_marshoai/plugin/utils.py @@ -0,0 +1,34 @@ +import inspect +from pathlib import Path +from typing import Any, Callable + + +def path_to_module_name(path: Path) -> str: + """ + 转换路径为模块名 + Args: + path: 路径a/b/c/d -> a.b.c.d + Returns: + str: 模块名 + """ + rel_path = path.resolve().relative_to(Path.cwd().resolve()) + if rel_path.stem == "__init__": + return ".".join(rel_path.parts[:-1]) + else: + return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) + + +def is_coroutine_callable(call: Callable[..., Any]) -> bool: + """ + 判断是否为async def 函数 + Args: + call: 可调用对象 + Returns: + bool: 是否为协程可调用对象 + """ + if inspect.isroutine(call): + return inspect.iscoroutinefunction(call) + if inspect.isclass(call): + return False + func_ = getattr(call, "__call__", None) + return inspect.iscoroutinefunction(func_) diff --git a/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py new file mode 100755 index 00000000..96d3194a --- /dev/null +++ b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/__init__.py @@ -0,0 +1,54 @@ +import traceback + +import httpx + +from nonebot_plugin_marshoai.plugin import PluginMetadata, function_call + +__marsho_meta__ = PluginMetadata( + name="Bangumi 番剧信息", + description="Bangumi 番剧信息", + usage="Bangumi 番剧信息", + author="Liteyuki", + homepage="", +) + + +async def fetch_calendar(): + url = "https://api.bgm.tv/calendar" + headers = { + "User-Agent": "LiteyukiStudio/nonebot-plugin-marshoai (https://github.com/LiteyukiStudio/nonebot-plugin-marshoai)" + } + async with httpx.AsyncClient() as client: + response = await client.get(url, headers=headers) + # print(response.text) + return response.json() + + +@function_call +async def get_bangumi_news() -> str: + """获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。 + + Returns: + _type_: _description_ + """ + result = await fetch_calendar() + info = "" + try: + for i in result: + weekday = i["weekday"]["cn"] + # print(weekday) + info += f"{weekday}:" + items = i["items"] + for item in items: + name = item["name_cn"] + info += f"《{name}》" + info += "\n" + return info + except Exception as e: + traceback.print_exc() + return "" + + +@function_call +def test_sync() -> str: + return "sync" diff --git a/nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json new file mode 100755 index 00000000..a814f53a --- /dev/null +++ b/nonebot_plugin_marshoai/plugins/marshoai_bangumi/tools.json @@ -0,0 +1,9 @@ +[ + { + "type": "function", + "function": { + "name": "marshoai-bangumi__get_bangumi_news", + "description": "获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。" + } + } +] diff --git a/nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py b/nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py new file mode 100755 index 00000000..a76a3333 --- /dev/null +++ b/nonebot_plugin_marshoai/plugins/marshoai_basic/__init__.py @@ -0,0 +1,24 @@ +import os + +from zhDateTime import DateTime + + +async def get_weather(location: str): + return f"{location}的温度是114514℃。" + + +async def get_current_env(): + ver = os.popen("uname -a").read() + return str(ver) + + +async def get_current_time(): + current_time = DateTime.now().strftime("%Y.%m.%d %H:%M:%S") + current_weekday = DateTime.now().weekday() + + weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"] + current_weekday_name = weekdays[current_weekday] + + current_lunar_date = DateTime.now().to_lunar().date_hanzify()[5:] + time_prompt = f"现在的时间是{current_time},{current_weekday_name},农历{current_lunar_date}。" + return time_prompt diff --git a/nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json new file mode 100755 index 00000000..47b477ed --- /dev/null +++ b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools.json @@ -0,0 +1,9 @@ +[ + { + "type": "function", + "function": { + "name": "marshoai-basic__get_current_time", + "description": "获取现在的日期,时间和星期。" + } + } +] diff --git a/nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json new file mode 100755 index 00000000..833ef7cc --- /dev/null +++ b/nonebot_plugin_marshoai/plugins/marshoai_basic/tools_test.json @@ -0,0 +1,39 @@ +[ + { + "type": "function", + "function": { + "name": "marshoai-basic__get_weather", + "description": "当你想查询指定城市的天气时非常有用。", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "城市或县区,比如北京市、杭州市、余杭区等。" + } + } + }, + "required": [ + "location" + ] + } + }, + { + "type": "function", + "function": { + "name": "marshoai-basic__get_current_env", + "description": "获取当前的运行环境。", + "parameters": { + } + } + }, + { + "type": "function", + "function": { + "name": "marshoai-basic__get_current_time", + "description": "获取现在的时间。", + "parameters": { + } + } + } +] diff --git a/nonebot_plugin_marshoai/tool/__init__.py b/nonebot_plugin_marshoai/tool/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/nonebot_plugin_marshoai/tool/utils.py b/nonebot_plugin_marshoai/tool/utils.py deleted file mode 100644 index ca63a25c..00000000 --- a/nonebot_plugin_marshoai/tool/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -from pathlib import Path - - -def path_to_module_name(path: Path) -> str: - """ - 转换路径为模块名 - Args: - path: 路径a/b/c/d -> a.b.c.d - Returns: - str: 模块名 - """ - rel_path = path.resolve().relative_to(Path.cwd().resolve()) - if rel_path.stem == "__init__": - return ".".join(rel_path.parts[:-1]) - else: - return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) diff --git a/nonebot_plugin_marshoai/util.py b/nonebot_plugin_marshoai/util.py index c1b89778..9cda5e3e 100755 --- a/nonebot_plugin_marshoai/util.py +++ b/nonebot_plugin_marshoai/util.py @@ -11,6 +11,7 @@ import nonebot_plugin_localstore as store # from zhDateTime import DateTime from azure.ai.inference.aio import ChatCompletionsClient from azure.ai.inference.models import SystemMessage +from nonebot import get_driver from nonebot.log import logger from nonebot_plugin_alconna import Image as ImageMsg from nonebot_plugin_alconna import Text as TextMsg @@ -280,6 +281,10 @@ if config.marshoai_enable_richtext_parse: latex_convert = ConvertLatex() # 开启一个转换实例 + @get_driver().on_bot_connect + async def load_latex_convert(): + await latex_convert.load_channel(None) + async def get_uuid_back2codeblock( msg: str, code_blank_uuid_map: list[tuple[str, str]] ):