diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index 77d3118d..349961d1 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -36,12 +36,12 @@ FrontMatter: from itertools import chain from types import ModuleType from contextvars import ContextVar -from typing import Set, Dict, List, Optional +from typing import Set, Dict, List, Tuple, Optional _plugins: Dict[str, "Plugin"] = {} _managers: List["PluginManager"] = [] -_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar( - "_current_plugin", default=None +_current_plugin_chain: ContextVar[Tuple["Plugin", ...]] = ContextVar( + "_current_plugin_chain", default=tuple() ) diff --git a/nonebot/plugin/export.py b/nonebot/plugin/export.py index f27740ed..a64339f7 100644 --- a/nonebot/plugin/export.py +++ b/nonebot/plugin/export.py @@ -9,7 +9,7 @@ FrontMatter: import warnings -from . import _current_plugin +from . import _current_plugin_chain class Export(dict): @@ -58,7 +58,7 @@ def export() -> Export: "See https://github.com/nonebot/nonebot2/issues/935.", DeprecationWarning, ) - plugin = _current_plugin.get() - if not plugin: + plugins = _current_plugin_chain.get() + if not plugins: raise RuntimeError("Export outside of the plugin!") - return plugin.export + return plugins[-1].export diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index b136f27d..ece76620 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -24,7 +24,7 @@ from . import ( _managers, _new_plugin, _revert_plugin, - _current_plugin, + _current_plugin_chain, _module_name_to_plugin_name, ) @@ -223,15 +223,15 @@ class PluginLoader(SourceFileLoader): setattr(module, "__plugin__", plugin) # detect parent plugin before entering current plugin context - parent_plugin = _current_plugin.get() - if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index( - self.manager - ): - plugin.parent_plugin = parent_plugin - parent_plugin.sub_plugins.add(plugin) + parent_plugins = _current_plugin_chain.get() + for pre_plugin in reversed(parent_plugins): + if _managers.index(pre_plugin.manager) < _managers.index(self.manager): + plugin.parent_plugin = pre_plugin + pre_plugin.sub_plugins.add(plugin) + break # enter plugin context - _plugin_token = _current_plugin.set(plugin) + _plugin_token = _current_plugin_chain.set(parent_plugins + (plugin,)) try: super().exec_module(module) @@ -240,7 +240,7 @@ class PluginLoader(SourceFileLoader): raise finally: # leave plugin context - _current_plugin.reset(_plugin_token) + _current_plugin_chain.reset(_plugin_token) # get plugin metadata metadata: Optional[PluginMetadata] = getattr(module, "__plugin_meta__", None) diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 54a15078..390d7a27 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -26,14 +26,14 @@ from nonebot.rule import ( shell_command, ) -from .manager import _current_plugin +from .manager import _current_plugin_chain def _store_matcher(matcher: Type[Matcher]) -> None: - plugin = _current_plugin.get() + plugins = _current_plugin_chain.get() # only store the matcher defined in the plugin - if plugin: - plugin.matcher.add(matcher) + if plugins: + plugins[-1].matcher.add(matcher) def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]: @@ -70,6 +70,7 @@ def on( block: 是否阻止事件向更低优先级传递 state: 默认 state """ + plugin_chain = _current_plugin_chain.get() matcher = Matcher.new( type, Rule() & rule, @@ -79,7 +80,7 @@ def on( priority=priority, block=block, handlers=handlers, - plugin=_current_plugin.get(), + plugin=plugin_chain[-1] if plugin_chain else None, module=_get_matcher_module(_depth + 1), default_state=state, ) @@ -109,6 +110,7 @@ def on_metaevent( block: 是否阻止事件向更低优先级传递 state: 默认 state """ + plugin_chain = _current_plugin_chain.get() matcher = Matcher.new( "meta_event", Rule() & rule, @@ -118,7 +120,7 @@ def on_metaevent( priority=priority, block=block, handlers=handlers, - plugin=_current_plugin.get(), + plugin=plugin_chain[-1] if plugin_chain else None, module=_get_matcher_module(_depth + 1), default_state=state, ) @@ -150,6 +152,7 @@ def on_message( block: 是否阻止事件向更低优先级传递 state: 默认 state """ + plugin_chain = _current_plugin_chain.get() matcher = Matcher.new( "message", Rule() & rule, @@ -159,7 +162,7 @@ def on_message( priority=priority, block=block, handlers=handlers, - plugin=_current_plugin.get(), + plugin=plugin_chain[-1] if plugin_chain else None, module=_get_matcher_module(_depth + 1), default_state=state, ) @@ -189,6 +192,7 @@ def on_notice( block: 是否阻止事件向更低优先级传递 state: 默认 state """ + plugin_chain = _current_plugin_chain.get() matcher = Matcher.new( "notice", Rule() & rule, @@ -198,7 +202,7 @@ def on_notice( priority=priority, block=block, handlers=handlers, - plugin=_current_plugin.get(), + plugin=plugin_chain[-1] if plugin_chain else None, module=_get_matcher_module(_depth + 1), default_state=state, ) @@ -228,6 +232,7 @@ def on_request( block: 是否阻止事件向更低优先级传递 state: 默认 state """ + plugin_chain = _current_plugin_chain.get() matcher = Matcher.new( "request", Rule() & rule, @@ -237,7 +242,7 @@ def on_request( priority=priority, block=block, handlers=handlers, - plugin=_current_plugin.get(), + plugin=plugin_chain[-1] if plugin_chain else None, module=_get_matcher_module(_depth + 1), default_state=state, ) diff --git a/tests/plugins/nested/__init__.py b/tests/plugins/nested/__init__.py index eb4cf3e4..6f514513 100644 --- a/tests/plugins/nested/__init__.py +++ b/tests/plugins/nested/__init__.py @@ -1,6 +1,13 @@ from pathlib import Path import nonebot +from nonebot.plugin import PluginManager, _managers -_sub_plugins = set() -_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve())) +manager = PluginManager( + search_path=[str((Path(__file__).parent / "plugins").resolve())] +) +_managers.append(manager) + +# test load nested plugin with require +manager.load_plugin("nested_subplugin") +manager.load_plugin("nested_subplugin2") diff --git a/tests/plugins/nested/plugins/nested_subplugin.py b/tests/plugins/nested/plugins/nested_subplugin.py index e69de29b..8cae942f 100644 --- a/tests/plugins/nested/plugins/nested_subplugin.py +++ b/tests/plugins/nested/plugins/nested_subplugin.py @@ -0,0 +1 @@ +from .nested_subplugin2 import a diff --git a/tests/plugins/nested/plugins/nested_subplugin2.py b/tests/plugins/nested/plugins/nested_subplugin2.py new file mode 100644 index 00000000..cbf42e7a --- /dev/null +++ b/tests/plugins/nested/plugins/nested_subplugin2.py @@ -0,0 +1 @@ +a = "required by another subplugin" diff --git a/tests/test_plugin/test_load.py b/tests/test_plugin/test_load.py index a99852ec..53b1d1df 100644 --- a/tests/test_plugin/test_load.py +++ b/tests/test_plugin/test_load.py @@ -38,6 +38,19 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]): assert nonebot.load_plugin("some_plugin_not_exist") is None +@pytest.mark.asyncio +async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]): + import nonebot + + parent_plugin = nonebot.get_plugin("nested") + sub_plugin = nonebot.get_plugin("nested_subplugin") + sub_plugin2 = nonebot.get_plugin("nested_subplugin2") + assert parent_plugin and sub_plugin and sub_plugin2 + assert sub_plugin.parent_plugin is parent_plugin + assert sub_plugin2.parent_plugin is parent_plugin + assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2} + + @pytest.mark.asyncio async def test_bad_plugin(app: App): import nonebot