🐛 fix parent detect error after require (#1121)

This commit is contained in:
Ju4tCode 2022-08-04 13:39:20 +08:00 committed by GitHub
parent 48ccef2f06
commit 2192e8cb6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 54 additions and 27 deletions

View File

@ -36,12 +36,12 @@ FrontMatter:
from itertools import chain from itertools import chain
from types import ModuleType from types import ModuleType
from contextvars import ContextVar from contextvars import ContextVar
from typing import Set, Dict, List, Optional from typing import Set, Dict, List, Tuple, Optional
_plugins: Dict[str, "Plugin"] = {} _plugins: Dict[str, "Plugin"] = {}
_managers: List["PluginManager"] = [] _managers: List["PluginManager"] = []
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar( _current_plugin_chain: ContextVar[Tuple["Plugin", ...]] = ContextVar(
"_current_plugin", default=None "_current_plugin_chain", default=tuple()
) )

View File

@ -9,7 +9,7 @@ FrontMatter:
import warnings import warnings
from . import _current_plugin from . import _current_plugin_chain
class Export(dict): class Export(dict):
@ -58,7 +58,7 @@ def export() -> Export:
"See https://github.com/nonebot/nonebot2/issues/935.", "See https://github.com/nonebot/nonebot2/issues/935.",
DeprecationWarning, DeprecationWarning,
) )
plugin = _current_plugin.get() plugins = _current_plugin_chain.get()
if not plugin: if not plugins:
raise RuntimeError("Export outside of the plugin!") raise RuntimeError("Export outside of the plugin!")
return plugin.export return plugins[-1].export

View File

@ -24,7 +24,7 @@ from . import (
_managers, _managers,
_new_plugin, _new_plugin,
_revert_plugin, _revert_plugin,
_current_plugin, _current_plugin_chain,
_module_name_to_plugin_name, _module_name_to_plugin_name,
) )
@ -223,15 +223,15 @@ class PluginLoader(SourceFileLoader):
setattr(module, "__plugin__", plugin) setattr(module, "__plugin__", plugin)
# detect parent plugin before entering current plugin context # detect parent plugin before entering current plugin context
parent_plugin = _current_plugin.get() parent_plugins = _current_plugin_chain.get()
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index( for pre_plugin in reversed(parent_plugins):
self.manager if _managers.index(pre_plugin.manager) < _managers.index(self.manager):
): plugin.parent_plugin = pre_plugin
plugin.parent_plugin = parent_plugin pre_plugin.sub_plugins.add(plugin)
parent_plugin.sub_plugins.add(plugin) break
# enter plugin context # enter plugin context
_plugin_token = _current_plugin.set(plugin) _plugin_token = _current_plugin_chain.set(parent_plugins + (plugin,))
try: try:
super().exec_module(module) super().exec_module(module)
@ -240,7 +240,7 @@ class PluginLoader(SourceFileLoader):
raise raise
finally: finally:
# leave plugin context # leave plugin context
_current_plugin.reset(_plugin_token) _current_plugin_chain.reset(_plugin_token)
# get plugin metadata # get plugin metadata
metadata: Optional[PluginMetadata] = getattr(module, "__plugin_meta__", None) metadata: Optional[PluginMetadata] = getattr(module, "__plugin_meta__", None)

View File

@ -26,14 +26,14 @@ from nonebot.rule import (
shell_command, shell_command,
) )
from .manager import _current_plugin from .manager import _current_plugin_chain
def _store_matcher(matcher: Type[Matcher]) -> None: def _store_matcher(matcher: Type[Matcher]) -> None:
plugin = _current_plugin.get() plugins = _current_plugin_chain.get()
# only store the matcher defined in the plugin # only store the matcher defined in the plugin
if plugin: if plugins:
plugin.matcher.add(matcher) plugins[-1].matcher.add(matcher)
def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]: def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
@ -70,6 +70,7 @@ def on(
block: 是否阻止事件向更低优先级传递 block: 是否阻止事件向更低优先级传递
state: 默认 state state: 默认 state
""" """
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new( matcher = Matcher.new(
type, type,
Rule() & rule, Rule() & rule,
@ -79,7 +80,7 @@ def on(
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
plugin=_current_plugin.get(), plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
) )
@ -109,6 +110,7 @@ def on_metaevent(
block: 是否阻止事件向更低优先级传递 block: 是否阻止事件向更低优先级传递
state: 默认 state state: 默认 state
""" """
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new( matcher = Matcher.new(
"meta_event", "meta_event",
Rule() & rule, Rule() & rule,
@ -118,7 +120,7 @@ def on_metaevent(
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
plugin=_current_plugin.get(), plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
) )
@ -150,6 +152,7 @@ def on_message(
block: 是否阻止事件向更低优先级传递 block: 是否阻止事件向更低优先级传递
state: 默认 state state: 默认 state
""" """
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new( matcher = Matcher.new(
"message", "message",
Rule() & rule, Rule() & rule,
@ -159,7 +162,7 @@ def on_message(
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
plugin=_current_plugin.get(), plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
) )
@ -189,6 +192,7 @@ def on_notice(
block: 是否阻止事件向更低优先级传递 block: 是否阻止事件向更低优先级传递
state: 默认 state state: 默认 state
""" """
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new( matcher = Matcher.new(
"notice", "notice",
Rule() & rule, Rule() & rule,
@ -198,7 +202,7 @@ def on_notice(
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
plugin=_current_plugin.get(), plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
) )
@ -228,6 +232,7 @@ def on_request(
block: 是否阻止事件向更低优先级传递 block: 是否阻止事件向更低优先级传递
state: 默认 state state: 默认 state
""" """
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new( matcher = Matcher.new(
"request", "request",
Rule() & rule, Rule() & rule,
@ -237,7 +242,7 @@ def on_request(
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
plugin=_current_plugin.get(), plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1), module=_get_matcher_module(_depth + 1),
default_state=state, default_state=state,
) )

View File

@ -1,6 +1,13 @@
from pathlib import Path from pathlib import Path
import nonebot import nonebot
from nonebot.plugin import PluginManager, _managers
_sub_plugins = set() manager = PluginManager(
_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve())) search_path=[str((Path(__file__).parent / "plugins").resolve())]
)
_managers.append(manager)
# test load nested plugin with require
manager.load_plugin("nested_subplugin")
manager.load_plugin("nested_subplugin2")

View File

@ -0,0 +1 @@
from .nested_subplugin2 import a

View File

@ -0,0 +1 @@
a = "required by another subplugin"

View File

@ -38,6 +38,19 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
assert nonebot.load_plugin("some_plugin_not_exist") is None assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]):
import nonebot
parent_plugin = nonebot.get_plugin("nested")
sub_plugin = nonebot.get_plugin("nested_subplugin")
sub_plugin2 = nonebot.get_plugin("nested_subplugin2")
assert parent_plugin and sub_plugin and sub_plugin2
assert sub_plugin.parent_plugin is parent_plugin
assert sub_plugin2.parent_plugin is parent_plugin
assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bad_plugin(app: App): async def test_bad_plugin(app: App):
import nonebot import nonebot