mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-02-07 18:36:09 +08:00
✨ 重构Marsho插件,优化模块导入,钩子函数与类实例化,全局变量独立为模块
This commit is contained in:
parent
5f7d82ae29
commit
aca5c2bd04
@ -6,11 +6,11 @@ require("nonebot_plugin_localstore")
|
||||
import nonebot_plugin_localstore as store # type: ignore
|
||||
from nonebot import get_driver, logger # type: ignore
|
||||
|
||||
from .azure import *
|
||||
from .config import config
|
||||
|
||||
# from .hunyuan import *
|
||||
from .dev import *
|
||||
from .marsho import *
|
||||
from .metadata import metadata
|
||||
|
||||
__author__ = "Asankilp"
|
||||
|
@ -8,8 +8,8 @@ from nonebot.typing import T_State
|
||||
|
||||
from nonebot_plugin_marshoai.plugin.load import reload_plugin
|
||||
|
||||
from .azure import context
|
||||
from .config import config
|
||||
from .marsho import context
|
||||
from .plugin.func_call.models import SessionContext
|
||||
|
||||
require("nonebot_plugin_alconna")
|
||||
|
65
nonebot_plugin_marshoai/hooks.py
Normal file
65
nonebot_plugin_marshoai/hooks.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Marsho 的钩子函数
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot_plugin_localstore as store
|
||||
from nonebot import logger
|
||||
|
||||
from .config import config
|
||||
from .instances import *
|
||||
from .plugin import load_plugin, load_plugins
|
||||
from .util import get_backup_context, save_context_to_json
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _preload_tools():
|
||||
"""启动钩子加载工具"""
|
||||
tools_dir = store.get_plugin_data_dir() / "tools"
|
||||
os.makedirs(tools_dir, exist_ok=True)
|
||||
if config.marshoai_enable_tools:
|
||||
if config.marshoai_load_builtin_tools:
|
||||
tools.load_tools(Path(__file__).parent / "tools")
|
||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||
for tool_dir in config.marshoai_toolset_dir:
|
||||
tools.load_tools(tool_dir)
|
||||
logger.info(
|
||||
"如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
|
||||
)
|
||||
logger.opt(colors=True).warning(
|
||||
"<y>小棉工具已被弃用,可能会在未来版本中移除。</y>"
|
||||
)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
"""启动钩子加载插件"""
|
||||
if config.marshoai_enable_plugins:
|
||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||
"""加载内置插件"""
|
||||
for p in os.listdir(Path(__file__).parent / "plugins"):
|
||||
load_plugin(f"{__package__}.plugins.{p}")
|
||||
|
||||
"""加载指定目录插件"""
|
||||
load_plugins(*marshoai_plugin_dirs)
|
||||
|
||||
"""加载sys.path下的包, 包括从pip安装的包"""
|
||||
for package_name in config.marshoai_plugins:
|
||||
load_plugin(package_name)
|
||||
logger.info(
|
||||
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
||||
)
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def auto_backup_context():
|
||||
for target_info in target_list:
|
||||
target_id, target_private = target_info
|
||||
contexts_data = context.build(target_id, target_private)
|
||||
if target_private:
|
||||
target_uid = "private_" + target_id
|
||||
else:
|
||||
target_uid = "group_" + target_id
|
||||
await save_context_to_json(
|
||||
f"back_up_context_{target_uid}", contexts_data, "contexts/backup"
|
||||
)
|
||||
logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~")
|
18
nonebot_plugin_marshoai/instances.py
Normal file
18
nonebot_plugin_marshoai/instances.py
Normal file
@ -0,0 +1,18 @@
|
||||
# Marsho 的类实例以及全局变量
|
||||
from azure.ai.inference.aio import ChatCompletionsClient
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from nonebot import get_driver
|
||||
|
||||
from .config import config
|
||||
from .models import MarshoContext, MarshoTools
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
command_start = driver.config.command_start
|
||||
model_name = config.marshoai_default_model
|
||||
context = MarshoContext()
|
||||
tools = MarshoTools()
|
||||
token = config.marshoai_token
|
||||
endpoint = config.marshoai_azure_endpoint
|
||||
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
||||
target_list: list[list] = [] # 记录需保存历史上下文的列表
|
@ -1,9 +1,7 @@
|
||||
import contextlib
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import nonebot_plugin_localstore as store
|
||||
from arclet.alconna import Alconna, AllParam, Args
|
||||
from azure.ai.inference.models import (
|
||||
AssistantMessage,
|
||||
@ -15,8 +13,7 @@ from azure.ai.inference.models import (
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from nonebot import get_driver, logger, on_command, on_message
|
||||
from nonebot import logger, on_command, on_message
|
||||
from nonebot.adapters import Bot, Event, Message
|
||||
from nonebot.matcher import Matcher
|
||||
from nonebot.params import CommandArg
|
||||
@ -25,9 +22,9 @@ from nonebot.rule import Rule, to_me
|
||||
from nonebot.typing import T_State
|
||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||
|
||||
from .hooks import *
|
||||
from .instances import *
|
||||
from .metadata import metadata
|
||||
from .models import MarshoContext, MarshoTools
|
||||
from .plugin import _plugins, load_plugin, load_plugins
|
||||
from .plugin.func_call.caller import get_function_calls
|
||||
from .plugin.func_call.models import SessionContext
|
||||
from .util import *
|
||||
@ -37,8 +34,6 @@ async def at_enable():
|
||||
return config.marshoai_at
|
||||
|
||||
|
||||
driver = get_driver()
|
||||
|
||||
changemodel_cmd = on_command(
|
||||
"changemodel", permission=SUPERUSER, priority=10, block=True
|
||||
)
|
||||
@ -93,54 +88,6 @@ refresh_data_cmd = on_command(
|
||||
"refresh_data", permission=SUPERUSER, priority=10, block=True
|
||||
)
|
||||
|
||||
command_start = driver.config.command_start
|
||||
model_name = config.marshoai_default_model
|
||||
context = MarshoContext()
|
||||
tools = MarshoTools()
|
||||
token = config.marshoai_token
|
||||
endpoint = config.marshoai_azure_endpoint
|
||||
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
||||
target_list = [] # 记录需保存历史上下文的列表
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _preload_tools():
|
||||
"""启动钩子加载工具"""
|
||||
tools_dir = store.get_plugin_data_dir() / "tools"
|
||||
os.makedirs(tools_dir, exist_ok=True)
|
||||
if config.marshoai_enable_tools:
|
||||
if config.marshoai_load_builtin_tools:
|
||||
tools.load_tools(Path(__file__).parent / "tools")
|
||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||
for tool_dir in config.marshoai_toolset_dir:
|
||||
tools.load_tools(tool_dir)
|
||||
logger.info(
|
||||
"如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
|
||||
)
|
||||
logger.opt(colors=True).warning(
|
||||
"<y>小棉工具已被弃用,可能会在未来版本中移除。</y>"
|
||||
)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
"""启动钩子加载插件"""
|
||||
if config.marshoai_enable_plugins:
|
||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||
"""加载内置插件"""
|
||||
for p in os.listdir(Path(__file__).parent / "plugins"):
|
||||
load_plugin(f"{__package__}.plugins.{p}")
|
||||
|
||||
"""加载指定目录插件"""
|
||||
load_plugins(*marshoai_plugin_dirs)
|
||||
|
||||
"""加载sys.path下的包, 包括从pip安装的包"""
|
||||
for package_name in config.marshoai_plugins:
|
||||
load_plugin(package_name)
|
||||
logger.info(
|
||||
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
||||
)
|
||||
|
||||
|
||||
@add_usermsg_cmd.handle()
|
||||
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
|
||||
@ -470,7 +417,7 @@ async def marsho(
|
||||
with contextlib.suppress(ImportError): # 优化先不做()
|
||||
import nonebot.adapters.onebot.v11 # type: ignore
|
||||
|
||||
from .azure_onebot import poke_notify
|
||||
from .marsho_onebot import poke_notify
|
||||
|
||||
@poke_notify.handle()
|
||||
async def poke(event: Event):
|
||||
@ -499,18 +446,3 @@ with contextlib.suppress(ImportError): # 优化先不做()
|
||||
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
@driver.on_shutdown
|
||||
async def auto_backup_context():
|
||||
for target_info in target_list:
|
||||
target_id, target_private = target_info
|
||||
contexts_data = context.build(target_id, target_private)
|
||||
if target_private:
|
||||
target_uid = "private_" + target_id
|
||||
else:
|
||||
target_uid = "group_" + target_id
|
||||
await save_context_to_json(
|
||||
f"back_up_context_{target_uid}", contexts_data, "contexts/backup"
|
||||
)
|
||||
logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~")
|
0
nonebot_plugin_marshoai/azure_onebot.py → nonebot_plugin_marshoai/marsho_onebot.py
Executable file → Normal file
0
nonebot_plugin_marshoai/azure_onebot.py → nonebot_plugin_marshoai/marsho_onebot.py
Executable file → Normal file
@ -9,7 +9,7 @@ require("nonebot_plugin_apscheduler")
|
||||
require("nonebot_plugin_marshoai")
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
from nonebot_plugin_marshoai.azure import client
|
||||
from nonebot_plugin_marshoai.instances import client
|
||||
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
|
||||
from nonebot_plugin_marshoai.plugin.func_call.params import String
|
||||
|
||||
@ -84,7 +84,7 @@ if plugin_config.marshoai_plugin_memory_scheduler:
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
logger.info("小绵定时记忆整理已启动!")
|
||||
logger.info("小棉定时记忆整理已启动!")
|
||||
scheduler.add_job(
|
||||
organize_memories,
|
||||
"cron",
|
||||
|
Loading…
x
Reference in New Issue
Block a user