From 99c113833eb150dd2528947b247de546bb8d1bd4 Mon Sep 17 00:00:00 2001 From: Snowykami Date: Tue, 17 Dec 2024 20:51:42 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20=E6=94=AF=E6=8C=81=E5=BC=80?= =?UTF-8?q?=E5=8F=91=E7=83=AD=E9=87=8D=E8=BD=BD=E6=8F=92=E4=BB=B6=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E7=8B=AC=E7=AB=8B=E6=B5=8B=E8=AF=95=E5=87=BD?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot_plugin_marshoai/azure.py | 2 +- nonebot_plugin_marshoai/dev.py | 58 ++++++++++- nonebot_plugin_marshoai/models.py | 7 ++ nonebot_plugin_marshoai/observer.py | 99 +++++++++++++++++++ .../plugin/func_call/caller.py | 8 +- nonebot_plugin_marshoai/plugin/load.py | 41 +++++++- nonebot_plugin_marshoai/plugin/models.py | 2 +- .../plugins/random_number_generator.py | 5 + pyproject.toml | 3 +- 9 files changed, 211 insertions(+), 14 deletions(-) create mode 100755 nonebot_plugin_marshoai/observer.py diff --git a/nonebot_plugin_marshoai/azure.py b/nonebot_plugin_marshoai/azure.py index c1257fc4..51ade472 100644 --- a/nonebot_plugin_marshoai/azure.py +++ b/nonebot_plugin_marshoai/azure.py @@ -112,7 +112,7 @@ async def _preload_tools(): @driver.on_startup -async def _preload_plugins(): +async def _(): """启动钩子加载插件""" if config.marshoai_enable_plugins: marshoai_plugin_dirs = config.marshoai_plugin_dirs # 外部插件目录列表 diff --git a/nonebot_plugin_marshoai/dev.py b/nonebot_plugin_marshoai/dev.py index 5109c668..0f114348 100644 --- a/nonebot_plugin_marshoai/dev.py +++ b/nonebot_plugin_marshoai/dev.py @@ -1,9 +1,16 @@ -from nonebot import require +import os +from pathlib import Path + +from nonebot import get_driver, logger, require from nonebot.adapters import Bot, Event from nonebot.matcher import Matcher from nonebot.typing import T_State -from nonebot_plugin_marshoai.plugin.func_call.models import SessionContext +from nonebot_plugin_marshoai.plugin.load import reload_plugin + +from .azure import context +from .config import config +from .plugin.func_call.models import SessionContext require("nonebot_plugin_alconna") @@ -17,8 +24,12 @@ from nonebot_plugin_alconna import ( on_alconna, ) +from .observer import * +from .plugin import get_plugin, get_plugins from .plugin.func_call.caller import get_function_calls +driver = get_driver() + function_call = on_alconna( command=Alconna( "marsho-function-call", @@ -66,8 +77,13 @@ async def call_function( ): function = get_function_calls().get(function_name) if function is None: - await UniMessage(f"未找到函数 {function_name}").send() - return + for f in get_function_calls().values(): + if f.short_name == function_name: + function = f + break + else: + await UniMessage(f"未找到函数 {function_name}").send() + return await UniMessage( str( await function.with_ctx( @@ -75,3 +91,37 @@ async def call_function( ).call(**{i.split("=", 1)[0]: i.split("=", 1)[1] for i in kwargs}) ) ).send() + + +@on_file_system_event( + (str(Path(__file__).parent / "plugins"), *config.marshoai_plugin_dirs), + recursive=True, +) +def on_plugin_file_change(event): + if event.src_path.endswith(".py"): + logger.info(f"文件变动: {event.src_path}") + # 层层向上查找到插件目录 + dir_list: list[str] = event.src_path.split("/") # type: ignore + dir_list[-1] = dir_list[-1].split(".", 1)[0] + dir_list.reverse() + for plugin_name in dir_list: + if plugin := get_plugin(plugin_name): + if plugin.module_path.endswith("__init__.py"): + # 包插件 + if os.path.dirname(plugin.module_path).replace( + "\\", "/" + ) in event.src_path.replace("\\", "/"): + logger.debug(f"找到变动插件: {plugin.name},正在重新加载") + reload_plugin(plugin) + context.reset_all() + break + else: + # 单文件插件 + if plugin.module_path == event.src_path: + logger.debug(f"找到变动插件: {plugin.name},正在重新加载") + reload_plugin(plugin) + context.reset_all() + break + else: + logger.debug("未找到变动插件") + return diff --git a/nonebot_plugin_marshoai/models.py b/nonebot_plugin_marshoai/models.py index 04ae473e..2d29ad59 100755 --- a/nonebot_plugin_marshoai/models.py +++ b/nonebot_plugin_marshoai/models.py @@ -47,6 +47,13 @@ class MarshoContext: if target_id in target_dict: target_dict[target_id].clear() + def reset_all(self): + """ + 重置所有上下文 + """ + self.contents["private"].clear() + self.contents["non-private"].clear() + def build(self, target_id: str, is_private: bool) -> list: """ 构建返回的上下文,不包括系统消息 diff --git a/nonebot_plugin_marshoai/observer.py b/nonebot_plugin_marshoai/observer.py new file mode 100755 index 00000000..8d7573a4 --- /dev/null +++ b/nonebot_plugin_marshoai/observer.py @@ -0,0 +1,99 @@ +""" +此模块用于注册观察者函数,使用watchdog监控文件变化并重启bot +启用该模块需要在配置文件中设置`dev_mode`为True +""" + +import time +from typing import Callable, TypeAlias + +from nonebot import get_driver, logger +from watchdog.events import FileSystemEvent, FileSystemEventHandler +from watchdog.observers import Observer + +from .config import config + +CALLBACK_FUNC: TypeAlias = Callable[[FileSystemEvent], None] # 位置1为FileSystemEvent +FILTER_FUNC: TypeAlias = Callable[[FileSystemEvent], bool] # 位置1为FileSystemEvent + +observer = Observer() + +driver = get_driver() + + +def debounce(wait): + """ + 防抖函数 + """ + + def decorator(func): + def wrapper(*args, **kwargs): + nonlocal last_call_time + current_time = time.time() + if (current_time - last_call_time) > wait: + last_call_time = current_time + return func(*args, **kwargs) + + last_call_time = None + return wrapper + + return decorator + + +@driver.on_startup +async def check_for_reloader(): + if config.marshoai_devmode: + logger.debug("Marsho Reload enabled, watching for file changes...") + observer.start() + + +class CodeModifiedHandler(FileSystemEventHandler): + """ + Handler for code file changes + """ + + @debounce(1) + def on_modified(self, event): + raise NotImplementedError("on_modified must be implemented") + + def on_created(self, event): + self.on_modified(event) + + def on_deleted(self, event): + self.on_modified(event) + + def on_moved(self, event): + self.on_modified(event) + + def on_any_event(self, event): + self.on_modified(event) + + +def on_file_system_event( + directories: tuple[str, ...], + recursive: bool = True, + event_filter: FILTER_FUNC | None = None, +) -> Callable[[CALLBACK_FUNC], CALLBACK_FUNC]: + """ + 注册文件系统变化监听器 + Args: + directories: 监听目录们 + recursive: 是否递归监听子目录 + event_filter: 事件过滤器, 返回True则执行回调函数 + Returns: + 装饰器,装饰一个函数在接收到数据后执行 + """ + + def decorator(func: CALLBACK_FUNC) -> CALLBACK_FUNC: + def wrapper(event: FileSystemEvent): + + if event_filter is not None and not event_filter(event): + return + func(event) + + code_modified_handler = CodeModifiedHandler() + code_modified_handler.on_modified = wrapper + for directory in directories: + observer.schedule(code_modified_handler, directory, recursive=recursive) + return func + + return decorator diff --git a/nonebot_plugin_marshoai/plugin/func_call/caller.py b/nonebot_plugin_marshoai/plugin/func_call/caller.py index 1c4c2cef..ad4c77a3 100644 --- a/nonebot_plugin_marshoai/plugin/func_call/caller.py +++ b/nonebot_plugin_marshoai/plugin/func_call/caller.py @@ -27,7 +27,7 @@ class Caller: self.func: ASYNC_FUNCTION_CALL_FUNC | None = None """函数对象""" self.module_name: str = "" - """模块名""" + """模块名,仅为父级模块名,不一定是插件顶级模块名""" self._parameters: dict[str, Any] = {} """声明参数""" @@ -137,14 +137,14 @@ class Caller: self.func = async_wrap(func) # type: ignore if module := inspect.getmodule(func): - module_name = module.__name__.split(".")[-1] + "." + module_name = module.__name__.split(".")[-1] else: module_name = "" self.module_name = module_name _caller_data[self.full_name] = self logger.opt(colors=True).debug( - f"加载函数 {module_name}{func.__name__}: {self._description}" + f"加载函数 {self.full_name}: {self._description}" ) return func @@ -238,7 +238,7 @@ class Caller: @property def full_name(self) -> str: """完整名""" - return self.module_name + self._name + return self.module_name + "." + self._name @property def short_info(self) -> str: diff --git a/nonebot_plugin_marshoai/plugin/load.py b/nonebot_plugin_marshoai/plugin/load.py index 7bc10c15..f94569db 100755 --- a/nonebot_plugin_marshoai/plugin/load.py +++ b/nonebot_plugin_marshoai/plugin/load.py @@ -45,7 +45,9 @@ def get_plugins() -> dict[str, Plugin]: return _plugins -def load_plugin(module_path: str | Path) -> Optional[Plugin]: +def load_plugin( + module_path: str | Path, allow_reload: bool = False +) -> Optional[Plugin]: """加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。 该函数产生的副作用在于将插件加载到 `_plugins` 中。 @@ -63,12 +65,15 @@ def load_plugin(module_path: str | Path) -> Optional[Plugin]: try: module = import_module(module_path) # 导入模块对象 plugin = Plugin( - name=module.__name__, + name=module.__name__.split(".")[-1], module=module, module_name=module_path, module_path=module.__file__, ) - _plugins[plugin.name] = plugin + if plugin.name in _plugins and not allow_reload: + raise ValueError(f"插件名称重复: {plugin.name}") + else: + _plugins[plugin.name] = plugin plugin.metadata = getattr(module, "__marsho_meta__", None) @@ -118,3 +123,33 @@ def load_plugins(*plugin_dirs: str) -> set[Plugin]: if module_name and (plugin := load_plugin(module_name)): plugins.add(plugin) return plugins + + +def reload_plugin(plugin: Plugin) -> Optional[Plugin]: + """开发模式下的重新加载插件 + 该方法无法保证没有副作用,因为插件可能会有自己的初始化方法 + 如果出现异常请重启即可 + Args: + plugin: 插件对象 + Returns: + Optional[Plugin]: 插件对象 + """ + try: + if plugin.module_path: + if new_plugin := load_plugin(plugin.module_name, True): + logger.opt(colors=True).debug( + f'重新加载插件 "{new_plugin.name}" 成功, 若出现异常或副作用请重启' + ) + return new_plugin + else: + logger.opt(colors=True).error( + f'重新加载插件失败 "{plugin.name}"' + ) + return None + else: + logger.opt(colors=True).error(f'插件不支持重载 "{plugin.name}"') + return None + except Exception as e: + logger.opt(colors=True).error(f'重新加载插件失败 "{plugin.name}"') + traceback.print_exc() + return None diff --git a/nonebot_plugin_marshoai/plugin/models.py b/nonebot_plugin_marshoai/plugin/models.py index bc042386..ce7a4f6a 100755 --- a/nonebot_plugin_marshoai/plugin/models.py +++ b/nonebot_plugin_marshoai/plugin/models.py @@ -57,7 +57,7 @@ class Plugin(BaseModel): module: ModuleType """插件模块对象""" module_name: str - """点分割模块路径 例如a.b.c""" + """点分或/割模块路径 例如a.b.c""" module_path: str | None """实际路径,单文件为.py的路径,包为__init__.py路径""" metadata: PluginMetadata | None = None diff --git a/nonebot_plugin_marshoai/plugins/random_number_generator.py b/nonebot_plugin_marshoai/plugins/random_number_generator.py index 500ef04b..019d0101 100644 --- a/nonebot_plugin_marshoai/plugins/random_number_generator.py +++ b/nonebot_plugin_marshoai/plugins/random_number_generator.py @@ -16,3 +16,8 @@ async def generate_random_numbers(count: int) -> str: # 该插件由MarshoAI自举编写 + + +@on_function_call(description="重载测试") +def test_reload(): + return 1 diff --git a/pyproject.toml b/pyproject.toml index d78cf9d6..bbc08e5f 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ dependencies = [ "lxml[html_clean]>=5.3.0", "aiofiles>=24.1.0", "sumy>=0.11.0", - "azure-ai-inference>=1.0.0b6" + "azure-ai-inference>=1.0.0b6", + "watchdog>=6.0.0" ] license = { text = "MIT, Mulan PSL v2" }