🎨 Apply black formatting

This commit is contained in:
远野千束(神羽) 2024-12-14 04:43:03 +08:00
parent 59e0871840
commit a9938d30ed
16 changed files with 361 additions and 82 deletions

View File

@ -25,6 +25,7 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from .metadata import metadata from .metadata import metadata
from .models import MarshoContext, MarshoTools from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
from .util import * from .util import *
@ -85,6 +86,7 @@ target_list = [] # 记录需保存历史上下文的列表
@driver.on_startup @driver.on_startup
async def _preload_tools(): async def _preload_tools():
"""启动钩子加载工具"""
tools_dir = store.get_plugin_data_dir() / "tools" tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True) os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools: if config.marshoai_enable_tools:
@ -98,6 +100,15 @@ async def _preload_tools():
) )
@driver.on_startup
async def _preload_plugins():
"""启动钩子加载插件"""
marshoai_plugin_dirs = config.marshoai_plugin_dirs
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
load_plugins(*marshoai_plugin_dirs)
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")
@add_usermsg_cmd.handle() @add_usermsg_cmd.handle()
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()): async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
if msg := arg.extract_plain_text(): if msg := arg.extract_plain_text():

View File

@ -48,6 +48,8 @@ class ConfigModel(BaseModel):
marshoai_tencent_secretid: str | None = None marshoai_tencent_secretid: str | None = None
marshoai_tencent_secretkey: str | None = None marshoai_tencent_secretkey: str | None = None
marshoai_plugin_dirs: list[str] = []
yaml = YAML() yaml = YAML()

View File

@ -14,6 +14,7 @@ MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details. See the Mulan PSL v2 for more details.
""" """
import asyncio
import time import time
from typing import Literal, Optional, Tuple from typing import Literal, Optional, Tuple
@ -35,7 +36,7 @@ class ConvertChannel:
return False, "请勿直接调用母类" return False, "请勿直接调用母类"
@staticmethod @staticmethod
def channel_test() -> int: async def channel_test() -> int:
return -1 return -1
@ -90,14 +91,15 @@ class L2PChannel(ConvertChannel):
return False, "未知错误" return False, "未知错误"
@staticmethod @staticmethod
def channel_test() -> int: async def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client: async with httpx.AsyncClient(timeout=5, verify=False) as client:
try: try:
start_time = time.time_ns() start_time = time.time_ns()
latex2png = ( latex2png = (
client.get( await client.get(
"http://www.latex2png.com{}" "http://www.latex2png.com{}"
+ client.post( + (
await client.post(
"http://www.latex2png.com/api/convert", "http://www.latex2png.com/api/convert",
json={ json={
"auth": {"user": "guest", "password": "guest"}, "auth": {"user": "guest", "password": "guest"},
@ -105,6 +107,7 @@ class L2PChannel(ConvertChannel):
"resolution": 600, "resolution": 600,
"color": "000000", "color": "000000",
}, },
)
).json()["url"] ).json()["url"]
), ),
time.time_ns() - start_time, time.time_ns() - start_time,
@ -156,12 +159,12 @@ class CDCChannel(ConvertChannel):
return False, "未知错误" return False, "未知错误"
@staticmethod @staticmethod
def channel_test() -> int: async def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client: async with httpx.AsyncClient(timeout=5, verify=False) as client:
try: try:
start_time = time.time_ns() start_time = time.time_ns()
codecogs = ( codecogs = (
client.get( await client.get(
r"https://latex.codecogs.com/png.image?\huge%20\dpi{600}\\int_{a}^{b}x^2\\,dx=\\frac{b^3}{3}-\\frac{a^3}{5}" r"https://latex.codecogs.com/png.image?\huge%20\dpi{600}\\int_{a}^{b}x^2\\,dx=\\frac{b^3}{3}-\\frac{a^3}{5}"
), ),
time.time_ns() - start_time, time.time_ns() - start_time,
@ -223,19 +226,21 @@ class JRTChannel(ConvertChannel):
return False, "未知错误" return False, "未知错误"
@staticmethod @staticmethod
def channel_test() -> int: async def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client: async with httpx.AsyncClient(timeout=5, verify=False) as client:
try: try:
start_time = time.time_ns() start_time = time.time_ns()
joeraut = ( joeraut = (
client.get( await client.get(
client.post( (
await client.post(
"http://www.latex2png.com/api/convert", "http://www.latex2png.com/api/convert",
json={ json={
"latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}", "latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
"outputFormat": "PNG", "outputFormat": "PNG",
"outputScale": "1000%", "outputScale": "1000%",
}, },
)
).json()["imageUrl"] ).json()["imageUrl"]
), ),
time.time_ns() - start_time, time.time_ns() - start_time,
@ -255,11 +260,14 @@ class ConvertLatex:
channel: ConvertChannel channel: ConvertChannel
def __init__(self, channel: Optional[ConvertChannel] = None) -> None: def __init__(self, channel: Optional[ConvertChannel] = None):
logger.info("LaTeX 转换服务将在 Bot 连接时异步加载")
async def load_channel(self, channel: ConvertChannel | None = None) -> None:
if channel is None: if channel is None:
logger.info("正在选择 LaTeX 转换服务频道,请稍等...") logger.info("正在选择 LaTeX 转换服务频道,请稍等...")
self.channel = self.auto_choose_channel() self.channel = await self.auto_choose_channel()
logger.info(f"已选择 {self.channel.__class__.__name__} 服务频道")
else: else:
self.channel = channel self.channel = channel
@ -297,9 +305,15 @@ class ConvertLatex:
) )
@staticmethod @staticmethod
def auto_choose_channel() -> ConvertChannel: async def auto_choose_channel() -> ConvertChannel:
async def channel_test_wrapper(
channel: type[ConvertChannel],
) -> Tuple[int, type[ConvertChannel]]:
score = await channel.channel_test()
return score, channel
return min( results = await asyncio.gather(
channel_list, *(channel_test_wrapper(channel) for channel in channel_list)
key=lambda channel: channel.channel_test(), )
)() best_channel = min(results, key=lambda x: x[0])[1]
return best_channel()

View File

@ -0,0 +1,7 @@
"""该功能目前正在开发中,暂时不可用,受影响的文件夹 `plugin`, `plugins`
"""
from .load import *
from .models import *
from .register import *
from .utils import *

View File

@ -23,6 +23,26 @@ __all__ = [
] ]
def get_plugin(name: str) -> Plugin | None:
"""获取插件对象
Args:
name: 插件名称
Returns:
Optional[Plugin]: 插件对象
"""
return _plugins.get(name)
def get_plugins() -> dict[str, Plugin]:
"""获取所有插件
Returns:
dict[str, Plugin]: 插件集合
"""
return _plugins
def load_plugin(module_path: str | Path) -> Optional[Plugin]: def load_plugin(module_path: str | Path) -> Optional[Plugin]:
"""加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。 """加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。
该函数产生的副作用在于将插件加载到 `_plugins` 该函数产生的副作用在于将插件加载到 `_plugins`
@ -45,20 +65,23 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]:
module=module, module=module,
module_name=module_path, module_name=module_path,
) )
_plugins[plugin.name] = plugin
plugin.metadata = getattr(module, "__marsho_meta__", None) plugin.metadata = getattr(module, "__marsho_meta__", None)
_plugins[plugin.name] = plugin if plugin.metadata is None:
logger.opt(colors=True).warning(
logger.opt(colors=True).success( f"成功加载小棉插件 <y>{plugin.name}</y>, 但是没有定义元数据"
f'Succeeded to load liteyuki plugin "{plugin.name}"'
) )
return _plugins[module.__name__] else:
logger.opt(colors=True).success(
f'成功加载小棉插件 <c>"{plugin.metadata.name}"</c>'
)
return plugin
except Exception as e: except Exception as e:
logger.opt(colors=True).success( logger.opt(colors=True).success(f'加载小棉插件失败 "<r>{module_path}</r>"')
f'Failed to load liteyuki plugin "<r>{module_path}</r>"'
)
traceback.print_exc() traceback.print_exc()
return None return None

View File

@ -4,32 +4,6 @@ from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
class Plugin(BaseModel):
"""
存储插件信息
Attributes:
----------
name: str
包名称 例如marsho_test
module: ModuleType
插件模块对象
module_name: str
点分割模块路径 例如a.b.c
metadata: "PluginMeta" | None
"""
name: str
"""包名称 例如marsho_test"""
module: ModuleType
"""插件模块对象"""
module_name: str
"""点分割模块路径 例如a.b.c"""
metadata: "PluginMetadata" | None = None
""""""
class PluginMetadata(BaseModel): class PluginMetadata(BaseModel):
""" """
Marsho 插件 对象元数据 Marsho 插件 对象元数据
@ -58,3 +32,38 @@ class PluginMetadata(BaseModel):
author: str = "" author: str = ""
homepage: str = "" homepage: str = ""
extra: dict[str, Any] = {} extra: dict[str, Any] = {}
class Plugin(BaseModel):
"""
存储插件信息
Attributes:
----------
name: str
包名称 例如marsho_test
module: ModuleType
插件模块对象
module_name: str
点分割模块路径 例如a.b.c
metadata: "PluginMeta" | None
"""
name: str
"""包名称 例如marsho_test"""
module: ModuleType
"""插件模块对象"""
module_name: str
"""点分割模块路径 例如a.b.c"""
metadata: PluginMetadata | None = None
""""""
class Config:
arbitrary_types_allowed = True
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other: Any) -> bool:
return self.name == other.name

View File

@ -0,0 +1,55 @@
"""此模块用于获取function call中函数定义信息以及注册函数
"""
import inspect
from typing import Any, Callable, Coroutine, TypeAlias
import nonebot
from .utils import is_coroutine_callable
SYNC_FUNCTION_CALL: TypeAlias = Callable[..., str]
ASYNC_FUNCTION_CALL: TypeAlias = Callable[..., Coroutine[str, Any, str]]
FUNCTION_CALL: TypeAlias = SYNC_FUNCTION_CALL | ASYNC_FUNCTION_CALL
_loaded_functions: dict[str, FUNCTION_CALL] = {}
def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL:
"""将同步函数包装为异步函数,但是不会真正异步执行,仅用于统一调用及函数签名
Args:
func: 同步函数
Returns:
ASYNC_FUNCTION_CALL: 异步函数
"""
async def wrapper(*args, **kwargs) -> str:
return func(*args, **kwargs)
return wrapper
def function_call(*funcs: FUNCTION_CALL):
"""返回一个装饰器,装饰一个函数, 使其注册为一个可被AI调用的function call函数
Args:
func: 函数对象要有完整的 Google Style Docstring
Returns:
str: 函数定义信息
"""
for func in funcs:
if module := inspect.getmodule(func):
module_name = module.__name__ + "."
else:
module_name = ""
name = func.__name__
if not is_coroutine_callable(func):
func = async_wrapper(func) # type: ignore
_loaded_functions[name] = func
nonebot.logger.opt(colors=True).info(
f"加载 function call: <c>{module_name}{name}</c>"
)

View File

@ -0,0 +1,34 @@
import inspect
from pathlib import Path
from typing import Any, Callable
def path_to_module_name(path: Path) -> str:
"""
转换路径为模块名
Args:
path: 路径a/b/c/d -> a.b.c.d
Returns:
str: 模块名
"""
rel_path = path.resolve().relative_to(Path.cwd().resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
"""
判断是否为async def 函数
Args:
call: 可调用对象
Returns:
bool: 是否为协程可调用对象
"""
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
func_ = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(func_)

View File

@ -0,0 +1,54 @@
import traceback
import httpx
from nonebot_plugin_marshoai.plugin import PluginMetadata, function_call
__marsho_meta__ = PluginMetadata(
name="Bangumi 番剧信息",
description="Bangumi 番剧信息",
usage="Bangumi 番剧信息",
author="Liteyuki",
homepage="",
)
async def fetch_calendar():
url = "https://api.bgm.tv/calendar"
headers = {
"User-Agent": "LiteyukiStudio/nonebot-plugin-marshoai (https://github.com/LiteyukiStudio/nonebot-plugin-marshoai)"
}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers)
# print(response.text)
return response.json()
@function_call
async def get_bangumi_news() -> str:
"""获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。
Returns:
_type_: _description_
"""
result = await fetch_calendar()
info = ""
try:
for i in result:
weekday = i["weekday"]["cn"]
# print(weekday)
info += f"{weekday}:"
items = i["items"]
for item in items:
name = item["name_cn"]
info += f"{name}"
info += "\n"
return info
except Exception as e:
traceback.print_exc()
return ""
@function_call
def test_sync() -> str:
return "sync"

View File

@ -0,0 +1,9 @@
[
{
"type": "function",
"function": {
"name": "marshoai-bangumi__get_bangumi_news",
"description": "获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。"
}
}
]

View File

@ -0,0 +1,24 @@
import os
from zhDateTime import DateTime
async def get_weather(location: str):
return f"{location}的温度是114514℃。"
async def get_current_env():
ver = os.popen("uname -a").read()
return str(ver)
async def get_current_time():
current_time = DateTime.now().strftime("%Y.%m.%d %H:%M:%S")
current_weekday = DateTime.now().weekday()
weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]
current_weekday_name = weekdays[current_weekday]
current_lunar_date = DateTime.now().to_lunar().date_hanzify()[5:]
time_prompt = f"现在的时间是{current_time}{current_weekday_name},农历{current_lunar_date}"
return time_prompt

View File

@ -0,0 +1,9 @@
[
{
"type": "function",
"function": {
"name": "marshoai-basic__get_current_time",
"description": "获取现在的日期,时间和星期。"
}
}
]

View File

@ -0,0 +1,39 @@
[
{
"type": "function",
"function": {
"name": "marshoai-basic__get_weather",
"description": "当你想查询指定城市的天气时非常有用。",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "城市或县区,比如北京市、杭州市、余杭区等。"
}
}
},
"required": [
"location"
]
}
},
{
"type": "function",
"function": {
"name": "marshoai-basic__get_current_env",
"description": "获取当前的运行环境。",
"parameters": {
}
}
},
{
"type": "function",
"function": {
"name": "marshoai-basic__get_current_time",
"description": "获取现在的时间。",
"parameters": {
}
}
}
]

View File

@ -1,16 +0,0 @@
from pathlib import Path
def path_to_module_name(path: Path) -> str:
"""
转换路径为模块名
Args:
path: 路径a/b/c/d -> a.b.c.d
Returns:
str: 模块名
"""
rel_path = path.resolve().relative_to(Path.cwd().resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))

View File

@ -11,6 +11,7 @@ import nonebot_plugin_localstore as store
# from zhDateTime import DateTime # from zhDateTime import DateTime
from azure.ai.inference.aio import ChatCompletionsClient from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage from azure.ai.inference.models import SystemMessage
from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot_plugin_alconna import Image as ImageMsg from nonebot_plugin_alconna import Image as ImageMsg
from nonebot_plugin_alconna import Text as TextMsg from nonebot_plugin_alconna import Text as TextMsg
@ -280,6 +281,10 @@ if config.marshoai_enable_richtext_parse:
latex_convert = ConvertLatex() # 开启一个转换实例 latex_convert = ConvertLatex() # 开启一个转换实例
@get_driver().on_bot_connect
async def load_latex_convert():
await latex_convert.load_channel(None)
async def get_uuid_back2codeblock( async def get_uuid_back2codeblock(
msg: str, code_blank_uuid_map: list[tuple[str, str]] msg: str, code_blank_uuid_map: list[tuple[str, str]]
): ):