mirror of
https://github.com/LiteyukiStudio/nonebot-plugin-marshoai.git
synced 2025-02-07 21:46:10 +08:00
✨ 重构Marsho插件,优化模块导入,钩子函数与类实例化,全局变量独立为模块
This commit is contained in:
parent
5f7d82ae29
commit
aca5c2bd04
@ -6,11 +6,11 @@ require("nonebot_plugin_localstore")
|
|||||||
import nonebot_plugin_localstore as store # type: ignore
|
import nonebot_plugin_localstore as store # type: ignore
|
||||||
from nonebot import get_driver, logger # type: ignore
|
from nonebot import get_driver, logger # type: ignore
|
||||||
|
|
||||||
from .azure import *
|
|
||||||
from .config import config
|
from .config import config
|
||||||
|
|
||||||
# from .hunyuan import *
|
# from .hunyuan import *
|
||||||
from .dev import *
|
from .dev import *
|
||||||
|
from .marsho import *
|
||||||
from .metadata import metadata
|
from .metadata import metadata
|
||||||
|
|
||||||
__author__ = "Asankilp"
|
__author__ = "Asankilp"
|
||||||
|
@ -8,8 +8,8 @@ from nonebot.typing import T_State
|
|||||||
|
|
||||||
from nonebot_plugin_marshoai.plugin.load import reload_plugin
|
from nonebot_plugin_marshoai.plugin.load import reload_plugin
|
||||||
|
|
||||||
from .azure import context
|
|
||||||
from .config import config
|
from .config import config
|
||||||
|
from .marsho import context
|
||||||
from .plugin.func_call.models import SessionContext
|
from .plugin.func_call.models import SessionContext
|
||||||
|
|
||||||
require("nonebot_plugin_alconna")
|
require("nonebot_plugin_alconna")
|
||||||
|
65
nonebot_plugin_marshoai/hooks.py
Normal file
65
nonebot_plugin_marshoai/hooks.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Marsho 的钩子函数
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import nonebot_plugin_localstore as store
|
||||||
|
from nonebot import logger
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
from .instances import *
|
||||||
|
from .plugin import load_plugin, load_plugins
|
||||||
|
from .util import get_backup_context, save_context_to_json
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def _preload_tools():
|
||||||
|
"""启动钩子加载工具"""
|
||||||
|
tools_dir = store.get_plugin_data_dir() / "tools"
|
||||||
|
os.makedirs(tools_dir, exist_ok=True)
|
||||||
|
if config.marshoai_enable_tools:
|
||||||
|
if config.marshoai_load_builtin_tools:
|
||||||
|
tools.load_tools(Path(__file__).parent / "tools")
|
||||||
|
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
||||||
|
for tool_dir in config.marshoai_toolset_dir:
|
||||||
|
tools.load_tools(tool_dir)
|
||||||
|
logger.info(
|
||||||
|
"如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
|
||||||
|
)
|
||||||
|
logger.opt(colors=True).warning(
|
||||||
|
"<y>小棉工具已被弃用,可能会在未来版本中移除。</y>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def _():
|
||||||
|
"""启动钩子加载插件"""
|
||||||
|
if config.marshoai_enable_plugins:
|
||||||
|
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
||||||
|
"""加载内置插件"""
|
||||||
|
for p in os.listdir(Path(__file__).parent / "plugins"):
|
||||||
|
load_plugin(f"{__package__}.plugins.{p}")
|
||||||
|
|
||||||
|
"""加载指定目录插件"""
|
||||||
|
load_plugins(*marshoai_plugin_dirs)
|
||||||
|
|
||||||
|
"""加载sys.path下的包, 包括从pip安装的包"""
|
||||||
|
for package_name in config.marshoai_plugins:
|
||||||
|
load_plugin(package_name)
|
||||||
|
logger.info(
|
||||||
|
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
|
async def auto_backup_context():
|
||||||
|
for target_info in target_list:
|
||||||
|
target_id, target_private = target_info
|
||||||
|
contexts_data = context.build(target_id, target_private)
|
||||||
|
if target_private:
|
||||||
|
target_uid = "private_" + target_id
|
||||||
|
else:
|
||||||
|
target_uid = "group_" + target_id
|
||||||
|
await save_context_to_json(
|
||||||
|
f"back_up_context_{target_uid}", contexts_data, "contexts/backup"
|
||||||
|
)
|
||||||
|
logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~")
|
18
nonebot_plugin_marshoai/instances.py
Normal file
18
nonebot_plugin_marshoai/instances.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Marsho 的类实例以及全局变量
|
||||||
|
from azure.ai.inference.aio import ChatCompletionsClient
|
||||||
|
from azure.core.credentials import AzureKeyCredential
|
||||||
|
from nonebot import get_driver
|
||||||
|
|
||||||
|
from .config import config
|
||||||
|
from .models import MarshoContext, MarshoTools
|
||||||
|
|
||||||
|
driver = get_driver()
|
||||||
|
|
||||||
|
command_start = driver.config.command_start
|
||||||
|
model_name = config.marshoai_default_model
|
||||||
|
context = MarshoContext()
|
||||||
|
tools = MarshoTools()
|
||||||
|
token = config.marshoai_token
|
||||||
|
endpoint = config.marshoai_azure_endpoint
|
||||||
|
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
||||||
|
target_list: list[list] = [] # 记录需保存历史上下文的列表
|
@ -1,9 +1,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import nonebot_plugin_localstore as store
|
|
||||||
from arclet.alconna import Alconna, AllParam, Args
|
from arclet.alconna import Alconna, AllParam, Args
|
||||||
from azure.ai.inference.models import (
|
from azure.ai.inference.models import (
|
||||||
AssistantMessage,
|
AssistantMessage,
|
||||||
@ -15,8 +13,7 @@ from azure.ai.inference.models import (
|
|||||||
ToolMessage,
|
ToolMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from azure.core.credentials import AzureKeyCredential
|
from nonebot import logger, on_command, on_message
|
||||||
from nonebot import get_driver, logger, on_command, on_message
|
|
||||||
from nonebot.adapters import Bot, Event, Message
|
from nonebot.adapters import Bot, Event, Message
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.params import CommandArg
|
from nonebot.params import CommandArg
|
||||||
@ -25,9 +22,9 @@ from nonebot.rule import Rule, to_me
|
|||||||
from nonebot.typing import T_State
|
from nonebot.typing import T_State
|
||||||
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
from nonebot_plugin_alconna import MsgTarget, UniMessage, UniMsg, on_alconna
|
||||||
|
|
||||||
|
from .hooks import *
|
||||||
|
from .instances import *
|
||||||
from .metadata import metadata
|
from .metadata import metadata
|
||||||
from .models import MarshoContext, MarshoTools
|
|
||||||
from .plugin import _plugins, load_plugin, load_plugins
|
|
||||||
from .plugin.func_call.caller import get_function_calls
|
from .plugin.func_call.caller import get_function_calls
|
||||||
from .plugin.func_call.models import SessionContext
|
from .plugin.func_call.models import SessionContext
|
||||||
from .util import *
|
from .util import *
|
||||||
@ -37,8 +34,6 @@ async def at_enable():
|
|||||||
return config.marshoai_at
|
return config.marshoai_at
|
||||||
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
|
|
||||||
changemodel_cmd = on_command(
|
changemodel_cmd = on_command(
|
||||||
"changemodel", permission=SUPERUSER, priority=10, block=True
|
"changemodel", permission=SUPERUSER, priority=10, block=True
|
||||||
)
|
)
|
||||||
@ -93,54 +88,6 @@ refresh_data_cmd = on_command(
|
|||||||
"refresh_data", permission=SUPERUSER, priority=10, block=True
|
"refresh_data", permission=SUPERUSER, priority=10, block=True
|
||||||
)
|
)
|
||||||
|
|
||||||
command_start = driver.config.command_start
|
|
||||||
model_name = config.marshoai_default_model
|
|
||||||
context = MarshoContext()
|
|
||||||
tools = MarshoTools()
|
|
||||||
token = config.marshoai_token
|
|
||||||
endpoint = config.marshoai_azure_endpoint
|
|
||||||
client = ChatCompletionsClient(endpoint=endpoint, credential=AzureKeyCredential(token))
|
|
||||||
target_list = [] # 记录需保存历史上下文的列表
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def _preload_tools():
|
|
||||||
"""启动钩子加载工具"""
|
|
||||||
tools_dir = store.get_plugin_data_dir() / "tools"
|
|
||||||
os.makedirs(tools_dir, exist_ok=True)
|
|
||||||
if config.marshoai_enable_tools:
|
|
||||||
if config.marshoai_load_builtin_tools:
|
|
||||||
tools.load_tools(Path(__file__).parent / "tools")
|
|
||||||
tools.load_tools(store.get_plugin_data_dir() / "tools")
|
|
||||||
for tool_dir in config.marshoai_toolset_dir:
|
|
||||||
tools.load_tools(tool_dir)
|
|
||||||
logger.info(
|
|
||||||
"如果启用小棉工具后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_TOOLS 设为 false。"
|
|
||||||
)
|
|
||||||
logger.opt(colors=True).warning(
|
|
||||||
"<y>小棉工具已被弃用,可能会在未来版本中移除。</y>"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def _():
|
|
||||||
"""启动钩子加载插件"""
|
|
||||||
if config.marshoai_enable_plugins:
|
|
||||||
marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表
|
|
||||||
"""加载内置插件"""
|
|
||||||
for p in os.listdir(Path(__file__).parent / "plugins"):
|
|
||||||
load_plugin(f"{__package__}.plugins.{p}")
|
|
||||||
|
|
||||||
"""加载指定目录插件"""
|
|
||||||
load_plugins(*marshoai_plugin_dirs)
|
|
||||||
|
|
||||||
"""加载sys.path下的包, 包括从pip安装的包"""
|
|
||||||
for package_name in config.marshoai_plugins:
|
|
||||||
load_plugin(package_name)
|
|
||||||
logger.info(
|
|
||||||
"如果启用小棉插件后使用的模型出现报错,请尝试将 MARSHOAI_ENABLE_PLUGINS 设为 false。"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@add_usermsg_cmd.handle()
|
@add_usermsg_cmd.handle()
|
||||||
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
|
async def add_usermsg(target: MsgTarget, arg: Message = CommandArg()):
|
||||||
@ -470,7 +417,7 @@ async def marsho(
|
|||||||
with contextlib.suppress(ImportError): # 优化先不做()
|
with contextlib.suppress(ImportError): # 优化先不做()
|
||||||
import nonebot.adapters.onebot.v11 # type: ignore
|
import nonebot.adapters.onebot.v11 # type: ignore
|
||||||
|
|
||||||
from .azure_onebot import poke_notify
|
from .marsho_onebot import poke_notify
|
||||||
|
|
||||||
@poke_notify.handle()
|
@poke_notify.handle()
|
||||||
async def poke(event: Event):
|
async def poke(event: Event):
|
||||||
@ -499,18 +446,3 @@ with contextlib.suppress(ImportError): # 优化先不做()
|
|||||||
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
await UniMessage(str(e) + suggest_solution(str(e))).send()
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
@driver.on_shutdown
|
|
||||||
async def auto_backup_context():
|
|
||||||
for target_info in target_list:
|
|
||||||
target_id, target_private = target_info
|
|
||||||
contexts_data = context.build(target_id, target_private)
|
|
||||||
if target_private:
|
|
||||||
target_uid = "private_" + target_id
|
|
||||||
else:
|
|
||||||
target_uid = "group_" + target_id
|
|
||||||
await save_context_to_json(
|
|
||||||
f"back_up_context_{target_uid}", contexts_data, "contexts/backup"
|
|
||||||
)
|
|
||||||
logger.info(f"已保存会话 {target_id} 的上下文备份,将在下次对话时恢复~")
|
|
0
nonebot_plugin_marshoai/azure_onebot.py → nonebot_plugin_marshoai/marsho_onebot.py
Executable file → Normal file
0
nonebot_plugin_marshoai/azure_onebot.py → nonebot_plugin_marshoai/marsho_onebot.py
Executable file → Normal file
@ -9,7 +9,7 @@ require("nonebot_plugin_apscheduler")
|
|||||||
require("nonebot_plugin_marshoai")
|
require("nonebot_plugin_marshoai")
|
||||||
from nonebot_plugin_apscheduler import scheduler
|
from nonebot_plugin_apscheduler import scheduler
|
||||||
|
|
||||||
from nonebot_plugin_marshoai.azure import client
|
from nonebot_plugin_marshoai.instances import client
|
||||||
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
|
from nonebot_plugin_marshoai.plugin import PluginMetadata, on_function_call
|
||||||
from nonebot_plugin_marshoai.plugin.func_call.params import String
|
from nonebot_plugin_marshoai.plugin.func_call.params import String
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ if plugin_config.marshoai_plugin_memory_scheduler:
|
|||||||
|
|
||||||
@driver.on_startup
|
@driver.on_startup
|
||||||
async def _():
|
async def _():
|
||||||
logger.info("小绵定时记忆整理已启动!")
|
logger.info("小棉定时记忆整理已启动!")
|
||||||
scheduler.add_job(
|
scheduler.add_job(
|
||||||
organize_memories,
|
organize_memories,
|
||||||
"cron",
|
"cron",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user