🎨 Apply black formatting

This commit is contained in:
远野千束(神羽) 2024-12-14 04:43:03 +08:00
parent 59e0871840
commit a9938d30ed
16 changed files with 361 additions and 82 deletions

View File

@ -25,6 +25,7 @@ from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugins
from .util import *
@ -85,6 +86,7 @@ target_list = [] # 记录需保存历史上下文的列表
@driver.on_startup
async def _preload_tools():
"""启动钩子加载工具"""
tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools:
@ -98,6 +100,15 @@ async def _preload_tools():
)
@driver.on_startup
async def _preload_plugins():
"""启动钩子加载插件"""
marshoai_plugin_dirs = config.marshoai_plugin_dirs
marshoai_plugin_dirs.insert(0, Path(__file__).parent / "plugins")
load_plugins(*marshoai_plugin_dirs)
logger.opt(colors=True).info(f"已加载 <c>{len(_plugins)}</c> 个小棉插件")
@add_usermsg_cmd.handle()
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
if msg := arg.extract_plain_text():

View File

@ -48,6 +48,8 @@ class ConfigModel(BaseModel):
marshoai_tencent_secretid: str | None = None
marshoai_tencent_secretkey: str | None = None
marshoai_plugin_dirs: list[str] = []
yaml = YAML()

View File

@ -14,6 +14,7 @@ MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
import asyncio
import time
from typing import Literal, Optional, Tuple
@ -35,7 +36,7 @@ class ConvertChannel:
return False, "请勿直接调用母类"
@staticmethod
def channel_test() -> int:
async def channel_test() -> int:
return -1
@ -90,21 +91,23 @@ class L2PChannel(ConvertChannel):
return False, "未知错误"
@staticmethod
def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client:
async def channel_test() -> int:
async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
latex2png = (
client.get(
await client.get(
"http://www.latex2png.com{}"
+ client.post(
"http://www.latex2png.com/api/convert",
json={
"auth": {"user": "guest", "password": "guest"},
"latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n",
"resolution": 600,
"color": "000000",
},
+ (
await client.post(
"http://www.latex2png.com/api/convert",
json={
"auth": {"user": "guest", "password": "guest"},
"latex": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}\n",
"resolution": 600,
"color": "000000",
},
)
).json()["url"]
),
time.time_ns() - start_time,
@ -156,12 +159,12 @@ class CDCChannel(ConvertChannel):
return False, "未知错误"
@staticmethod
def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client:
async def channel_test() -> int:
async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
codecogs = (
client.get(
await client.get(
r"https://latex.codecogs.com/png.image?\huge%20\dpi{600}\\int_{a}^{b}x^2\\,dx=\\frac{b^3}{3}-\\frac{a^3}{5}"
),
time.time_ns() - start_time,
@ -223,19 +226,21 @@ class JRTChannel(ConvertChannel):
return False, "未知错误"
@staticmethod
def channel_test() -> int:
with httpx.Client(timeout=5, verify=False) as client:
async def channel_test() -> int:
async with httpx.AsyncClient(timeout=5, verify=False) as client:
try:
start_time = time.time_ns()
joeraut = (
client.get(
client.post(
"http://www.latex2png.com/api/convert",
json={
"latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
"outputFormat": "PNG",
"outputScale": "1000%",
},
await client.get(
(
await client.post(
"http://www.latex2png.com/api/convert",
json={
"latexInput": "\\\\int_{a}^{b} x^2 \\\\, dx = \\\\frac{b^3}{3} - \\\\frac{a^3}{5}",
"outputFormat": "PNG",
"outputScale": "1000%",
},
)
).json()["imageUrl"]
),
time.time_ns() - start_time,
@ -255,11 +260,14 @@ class ConvertLatex:
channel: ConvertChannel
def __init__(self, channel: Optional[ConvertChannel] = None) -> None:
def __init__(self, channel: Optional[ConvertChannel] = None):
logger.info("LaTeX 转换服务将在 Bot 连接时异步加载")
async def load_channel(self, channel: ConvertChannel | None = None) -> None:
if channel is None:
logger.info("正在选择 LaTeX 转换服务频道,请稍等...")
self.channel = self.auto_choose_channel()
self.channel = await self.auto_choose_channel()
logger.info(f"已选择 {self.channel.__class__.__name__} 服务频道")
else:
self.channel = channel
@ -297,9 +305,15 @@ class ConvertLatex:
)
@staticmethod
def auto_choose_channel() -> ConvertChannel:
async def auto_choose_channel() -> ConvertChannel:
async def channel_test_wrapper(
channel: type[ConvertChannel],
) -> Tuple[int, type[ConvertChannel]]:
score = await channel.channel_test()
return score, channel
return min(
channel_list,
key=lambda channel: channel.channel_test(),
)()
results = await asyncio.gather(
*(channel_test_wrapper(channel) for channel in channel_list)
)
best_channel = min(results, key=lambda x: x[0])[1]
return best_channel()

View File

@ -0,0 +1,7 @@
"""该功能目前正在开发中,暂时不可用,受影响的文件夹 `plugin`, `plugins`
"""
from .load import *
from .models import *
from .register import *
from .utils import *

View File

@ -23,6 +23,26 @@ __all__ = [
]
def get_plugin(name: str) -> Plugin | None:
"""获取插件对象
Args:
name: 插件名称
Returns:
Optional[Plugin]: 插件对象
"""
return _plugins.get(name)
def get_plugins() -> dict[str, Plugin]:
"""获取所有插件
Returns:
dict[str, Plugin]: 插件集合
"""
return _plugins
def load_plugin(module_path: str | Path) -> Optional[Plugin]:
"""加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。
该函数产生的副作用在于将插件加载到 `_plugins`
@ -45,20 +65,23 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]:
module=module,
module_name=module_path,
)
_plugins[plugin.name] = plugin
plugin.metadata = getattr(module, "__marsho_meta__", None)
_plugins[plugin.name] = plugin
if plugin.metadata is None:
logger.opt(colors=True).warning(
f"成功加载小棉插件 <y>{plugin.name}</y>, 但是没有定义元数据"
)
else:
logger.opt(colors=True).success(
f'成功加载小棉插件 <c>"{plugin.metadata.name}"</c>'
)
logger.opt(colors=True).success(
f'Succeeded to load liteyuki plugin "{plugin.name}"'
)
return _plugins[module.__name__]
return plugin
except Exception as e:
logger.opt(colors=True).success(
f'Failed to load liteyuki plugin "<r>{module_path}</r>"'
)
logger.opt(colors=True).success(f'加载小棉插件失败 "<r>{module_path}</r>"')
traceback.print_exc()
return None

View File

@ -4,32 +4,6 @@ from typing import Any
from pydantic import BaseModel
class Plugin(BaseModel):
"""
存储插件信息
Attributes:
----------
name: str
包名称 例如marsho_test
module: ModuleType
插件模块对象
module_name: str
点分割模块路径 例如a.b.c
metadata: "PluginMeta" | None
"""
name: str
"""包名称 例如marsho_test"""
module: ModuleType
"""插件模块对象"""
module_name: str
"""点分割模块路径 例如a.b.c"""
metadata: "PluginMetadata" | None = None
""""""
class PluginMetadata(BaseModel):
"""
Marsho 插件 对象元数据
@ -58,3 +32,38 @@ class PluginMetadata(BaseModel):
author: str = ""
homepage: str = ""
extra: dict[str, Any] = {}
class Plugin(BaseModel):
"""
存储插件信息
Attributes:
----------
name: str
包名称 例如marsho_test
module: ModuleType
插件模块对象
module_name: str
点分割模块路径 例如a.b.c
metadata: "PluginMeta" | None
"""
name: str
"""包名称 例如marsho_test"""
module: ModuleType
"""插件模块对象"""
module_name: str
"""点分割模块路径 例如a.b.c"""
metadata: PluginMetadata | None = None
""""""
class Config:
arbitrary_types_allowed = True
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other: Any) -> bool:
return self.name == other.name

View File

@ -0,0 +1,55 @@
"""此模块用于获取function call中函数定义信息以及注册函数
"""
import inspect
from typing import Any, Callable, Coroutine, TypeAlias
import nonebot
from .utils import is_coroutine_callable
SYNC_FUNCTION_CALL: TypeAlias = Callable[..., str]
ASYNC_FUNCTION_CALL: TypeAlias = Callable[..., Coroutine[str, Any, str]]
FUNCTION_CALL: TypeAlias = SYNC_FUNCTION_CALL | ASYNC_FUNCTION_CALL
_loaded_functions: dict[str, FUNCTION_CALL] = {}
def async_wrapper(func: SYNC_FUNCTION_CALL) -> ASYNC_FUNCTION_CALL:
"""将同步函数包装为异步函数,但是不会真正异步执行,仅用于统一调用及函数签名
Args:
func: 同步函数
Returns:
ASYNC_FUNCTION_CALL: 异步函数
"""
async def wrapper(*args, **kwargs) -> str:
return func(*args, **kwargs)
return wrapper
def function_call(*funcs: FUNCTION_CALL):
"""返回一个装饰器,装饰一个函数, 使其注册为一个可被AI调用的function call函数
Args:
func: 函数对象要有完整的 Google Style Docstring
Returns:
str: 函数定义信息
"""
for func in funcs:
if module := inspect.getmodule(func):
module_name = module.__name__ + "."
else:
module_name = ""
name = func.__name__
if not is_coroutine_callable(func):
func = async_wrapper(func) # type: ignore
_loaded_functions[name] = func
nonebot.logger.opt(colors=True).info(
f"加载 function call: <c>{module_name}{name}</c>"
)

View File

@ -0,0 +1,34 @@
import inspect
from pathlib import Path
from typing import Any, Callable
def path_to_module_name(path: Path) -> str:
"""
转换路径为模块名
Args:
path: 路径a/b/c/d -> a.b.c.d
Returns:
str: 模块名
"""
rel_path = path.resolve().relative_to(Path.cwd().resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
"""
判断是否为async def 函数
Args:
call: 可调用对象
Returns:
bool: 是否为协程可调用对象
"""
if inspect.isroutine(call):
return inspect.iscoroutinefunction(call)
if inspect.isclass(call):
return False
func_ = getattr(call, "__call__", None)
return inspect.iscoroutinefunction(func_)

View File

@ -0,0 +1,54 @@
import traceback
import httpx
from nonebot_plugin_marshoai.plugin import PluginMetadata, function_call
__marsho_meta__ = PluginMetadata(
name="Bangumi 番剧信息",
description="Bangumi 番剧信息",
usage="Bangumi 番剧信息",
author="Liteyuki",
homepage="",
)
async def fetch_calendar():
url = "https://api.bgm.tv/calendar"
headers = {
"User-Agent": "LiteyukiStudio/nonebot-plugin-marshoai (https://github.com/LiteyukiStudio/nonebot-plugin-marshoai)"
}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers)
# print(response.text)
return response.json()
@function_call
async def get_bangumi_news() -> str:
"""获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。
Returns:
_type_: _description_
"""
result = await fetch_calendar()
info = ""
try:
for i in result:
weekday = i["weekday"]["cn"]
# print(weekday)
info += f"{weekday}:"
items = i["items"]
for item in items:
name = item["name_cn"]
info += f"{name}"
info += "\n"
return info
except Exception as e:
traceback.print_exc()
return ""
@function_call
def test_sync() -> str:
return "sync"

View File

@ -0,0 +1,9 @@
[
{
"type": "function",
"function": {
"name": "marshoai-bangumi__get_bangumi_news",
"description": "获取今天的新番(动漫)列表,在调用之前,你需要知道今天星期几。"
}
}
]

View File

@ -0,0 +1,24 @@
import os
from zhDateTime import DateTime
async def get_weather(location: str):
return f"{location}的温度是114514℃。"
async def get_current_env():
ver = os.popen("uname -a").read()
return str(ver)
async def get_current_time():
current_time = DateTime.now().strftime("%Y.%m.%d %H:%M:%S")
current_weekday = DateTime.now().weekday()
weekdays = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]
current_weekday_name = weekdays[current_weekday]
current_lunar_date = DateTime.now().to_lunar().date_hanzify()[5:]
time_prompt = f"现在的时间是{current_time}{current_weekday_name},农历{current_lunar_date}"
return time_prompt

View File

@ -0,0 +1,9 @@
[
{
"type": "function",
"function": {
"name": "marshoai-basic__get_current_time",
"description": "获取现在的日期,时间和星期。"
}
}
]

View File

@ -0,0 +1,39 @@
[
{
"type": "function",
"function": {
"name": "marshoai-basic__get_weather",
"description": "当你想查询指定城市的天气时非常有用。",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "城市或县区,比如北京市、杭州市、余杭区等。"
}
}
},
"required": [
"location"
]
}
},
{
"type": "function",
"function": {
"name": "marshoai-basic__get_current_env",
"description": "获取当前的运行环境。",
"parameters": {
}
}
},
{
"type": "function",
"function": {
"name": "marshoai-basic__get_current_time",
"description": "获取现在的时间。",
"parameters": {
}
}
}
]

View File

@ -1,16 +0,0 @@
from pathlib import Path
def path_to_module_name(path: Path) -> str:
"""
转换路径为模块名
Args:
path: 路径a/b/c/d -> a.b.c.d
Returns:
str: 模块名
"""
rel_path = path.resolve().relative_to(Path.cwd().resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))

View File

@ -11,6 +11,7 @@ import nonebot_plugin_localstore as store
# from zhDateTime import DateTime
from azure.ai.inference.aio import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage
from nonebot import get_driver
from nonebot.log import logger
from nonebot_plugin_alconna import Image as ImageMsg
from nonebot_plugin_alconna import Text as TextMsg
@ -280,6 +281,10 @@ if config.marshoai_enable_richtext_parse:
latex_convert = ConvertLatex() # 开启一个转换实例
@get_driver().on_bot_connect
async def load_latex_convert():
await latex_convert.load_channel(None)
async def get_uuid_back2codeblock(
msg: str, code_blank_uuid_map: list[tuple[str, str]]
):