重构Marsho插件,优化模块导入,钩子函数与类实例化,全局变量独立为模块

This commit is contained in:
Asankilp 2024-12-31 00:26:23 +08:00
parent 5f7d82ae29
commit aca5c2bd04
7 changed files with 91 additions and 76 deletions

View File

@ -6,11 +6,11 @@ require("nonebot_plugin_localstore")
import nonebot_plugin_localstore as store # type: ignore
from nonebot import get_driver, logger # type: ignore
from .azure import *
from .config import config
# from .hunyuan import *
from .dev import *
from .marsho import *
from .metadata import metadata
__author__ = "Asankilp"

View File

@ -8,8 +8,8 @@ from nonebot.typing import T_State
from nonebot_plugin_marshoai.plugin.load import reload_plugin
from .azure import context
from .config import config
from .marsho import context
from .plugin.func_call.models import SessionContext
require("nonebot_plugin_alconna")

View File

@ -0,0 +1,65 @@
# Marsho 的钩子函数
import os
from pathlib import Path
import nonebot_plugin_localstore as store
from nonebot import logger
from .config import config
from .instances import *
from .plugin import load_plugin, load_plugins
from .util import get_backup_context, save_context_to_json
@driver.on_startup
async def _preload_tools():
"""启动钩子加载工具"""
tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools:
if config.marshoai_load_builtin_tools:
tools.load_tools(Path(__file__).parent / "tools")
tools.load_tools(store.get_plugin_data_dir() / "tools")
for tool_dir in config.marshoai_toolset_dir:
tools.load_tools(tool_dir)
logger.info(
"如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
)
logger.opt(colors=True).warning(
"<y>小棉工具已被弃用,可能会在未来版本中移除。</y>"
)
@driver.on_startup
async def _():
"""启动钩子加载插件"""
if config.marshoai_enable_plugins:
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
"""加载内置插件"""
for p in os.listdir(Path(__file__).parent / "plugins"):
load_plugin(f"{__package__}.plugins.{p}")
"""加载指定目录插件"""
load_plugins(*marshoai_plugin_dirs)
"""加载sys.path下的包, 包括从pip安装的包"""
for package_name in config.marshoai_plugins:
load_plugin(package_name)
logger.info(
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
)
@driver.on_shutdown
async def auto_backup_context():
for target_info in target_list:
target_id, target_private = target_info
contexts_data = context.build(target_id, target_private)
if target_private:
target_uid = "private_" + target_id
else:
target_uid = "group_" + target_id
await save_context_to_json(
f"back_up_context_{target_uid}", contexts_data, "contexts/backup"
)
logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~")

View File

@ -0,0 +1,18 @@
# Marsho 的类实例以及全局变量
from azure.ai.inference.aio import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver
from .config import config
from .models import MarshoContext, MarshoTools
driver = get_driver()
command_start = driver.config.command_start
model_name = config.marshoai_default_model
context = MarshoContext()
tools = MarshoTools()
token = config.marshoai_token
endpoint = config.marshoai_azure_endpoint
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
target_list: list[list] = [] # 记录需保存历史上下文的列表

View File

@ -1,9 +1,7 @@
import contextlib
import traceback
from pathlib import Path
from typing import Optional
import nonebot_plugin_localstore as store
from arclet.alconna import Alconna, AllParam, Args
from azure.ai.inference.models import (
AssistantMessage,
@ -15,8 +13,7 @@ from azure.ai.inference.models import (
ToolMessage,
UserMessage,
)
from azure.core.credentials import AzureKeyCredential
from nonebot import get_driver, logger, on_command, on_message
from nonebot import logger, on_command, on_message
from nonebot.adapters import Bot, Event, Message
from nonebot.matcher import Matcher
from nonebot.params import CommandArg
@ -25,9 +22,9 @@ from nonebot.rule import Rule, to_me
from nonebot.typing import T_State
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
from .hooks import *
from .instances import *
from .metadata import metadata
from .models import MarshoContext, MarshoTools
from .plugin import _plugins, load_plugin, load_plugins
from .plugin.func_call.caller import get_function_calls
from .plugin.func_call.models import SessionContext
from .util import *
@ -37,8 +34,6 @@ async def at_enable():
return config.marshoai_at
driver = get_driver()
changemodel_cmd = on_command(
"changemodel", permission=SUPERUSER, priority=10, block=True
)
@ -93,54 +88,6 @@ refresh_data_cmd = on_command(
"refresh_data", permission=SUPERUSER, priority=10, block=True
)
command_start = driver.config.command_start
model_name = config.marshoai_default_model
context = MarshoContext()
tools = MarshoTools()
token = config.marshoai_token
endpoint = config.marshoai_azure_endpoint
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
target_list = [] # 记录需保存历史上下文的列表
@driver.on_startup
async def _preload_tools():
"""启动钩子加载工具"""
tools_dir = store.get_plugin_data_dir() / "tools"
os.makedirs(tools_dir, exist_ok=True)
if config.marshoai_enable_tools:
if config.marshoai_load_builtin_tools:
tools.load_tools(Path(__file__).parent / "tools")
tools.load_tools(store.get_plugin_data_dir() / "tools")
for tool_dir in config.marshoai_toolset_dir:
tools.load_tools(tool_dir)
logger.info(
"如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
)
logger.opt(colors=True).warning(
"<y>小棉工具已被弃用,可能会在未来版本中移除。</y>"
)
@driver.on_startup
async def _():
"""启动钩子加载插件"""
if config.marshoai_enable_plugins:
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
"""加载内置插件"""
for p in os.listdir(Path(__file__).parent / "plugins"):
load_plugin(f"{__package__}.plugins.{p}")
"""加载指定目录插件"""
load_plugins(*marshoai_plugin_dirs)
"""加载sys.path下的包, 包括从pip安装的包"""
for package_name in config.marshoai_plugins:
load_plugin(package_name)
logger.info(
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
)
@add_usermsg_cmd.handle()
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
@ -470,7 +417,7 @@ async def marsho(
with contextlib.suppress(ImportError): # 优化先不做()
import nonebot.adapters.onebot.v11 # type: ignore
from .azure_onebot import poke_notify
from .marsho_onebot import poke_notify
@poke_notify.handle()
async def poke(event: Event):
@ -499,18 +446,3 @@ with contextlib.suppress(ImportError): # 优化先不做()
await UniMessage(str(e) + suggest_solution(str(e))).send()
traceback.print_exc()
return
@driver.on_shutdown
async def auto_backup_context():
for target_info in target_list:
target_id, target_private = target_info
contexts_data = context.build(target_id, target_private)
if target_private:
target_uid = "private_" + target_id
else:
target_uid = "group_" + target_id
await save_context_to_json(
f"back_up_context_{target_uid}", contexts_data, "contexts/backup"
)
logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~")

View File

@ -9,7 +9,7 @@ require("nonebot_plugin_apscheduler")
require("nonebot_plugin_marshoai")
from nonebot_plugin_apscheduler import scheduler
from nonebot_plugin_marshoai.azure import client
from nonebot_plugin_marshoai.instances import client
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
from nonebot_plugin_marshoai.plugin.func_call.params import String
@ -84,7 +84,7 @@ if plugin_config.marshoai_plugin_memory_scheduler:
@driver.on_startup
async def _():
logger.info("定时记忆整理已启动!")
logger.info("定时记忆整理已启动!")
scheduler.add_job(
organize_memories,
"cron",