🐛 fix parent detect error after require (#1121)

This commit is contained in:
Ju4tCode 2022-08-04 13:39:20 +08:00 committed by GitHub
parent 48ccef2f06
commit 2192e8cb6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 54 additions and 27 deletions

View File

@ -36,12 +36,12 @@ FrontMatter:
from itertools import chain
from types import ModuleType
from contextvars import ContextVar
from typing import Set, Dict, List, Optional
from typing import Set, Dict, List, Tuple, Optional
_plugins: Dict[str, "Plugin"] = {}
_managers: List["PluginManager"] = []
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar(
"_current_plugin", default=None
_current_plugin_chain: ContextVar[Tuple["Plugin", ...]] = ContextVar(
"_current_plugin_chain", default=tuple()
)

View File

@ -9,7 +9,7 @@ FrontMatter:
import warnings
from . import _current_plugin
from . import _current_plugin_chain
class Export(dict):
@ -58,7 +58,7 @@ def export() -> Export:
"See https://github.com/nonebot/nonebot2/issues/935.",
DeprecationWarning,
)
plugin = _current_plugin.get()
if not plugin:
plugins = _current_plugin_chain.get()
if not plugins:
raise RuntimeError("Export outside of the plugin!")
return plugin.export
return plugins[-1].export

View File

@ -24,7 +24,7 @@ from . import (
_managers,
_new_plugin,
_revert_plugin,
_current_plugin,
_current_plugin_chain,
_module_name_to_plugin_name,
)
@ -223,15 +223,15 @@ class PluginLoader(SourceFileLoader):
setattr(module, "__plugin__", plugin)
# detect parent plugin before entering current plugin context
parent_plugin = _current_plugin.get()
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index(
self.manager
):
plugin.parent_plugin = parent_plugin
parent_plugin.sub_plugins.add(plugin)
parent_plugins = _current_plugin_chain.get()
for pre_plugin in reversed(parent_plugins):
if _managers.index(pre_plugin.manager) < _managers.index(self.manager):
plugin.parent_plugin = pre_plugin
pre_plugin.sub_plugins.add(plugin)
break
# enter plugin context
_plugin_token = _current_plugin.set(plugin)
_plugin_token = _current_plugin_chain.set(parent_plugins + (plugin,))
try:
super().exec_module(module)
@ -240,7 +240,7 @@ class PluginLoader(SourceFileLoader):
raise
finally:
# leave plugin context
_current_plugin.reset(_plugin_token)
_current_plugin_chain.reset(_plugin_token)
# get plugin metadata
metadata: Optional[PluginMetadata] = getattr(module, "__plugin_meta__", None)

View File

@ -26,14 +26,14 @@ from nonebot.rule import (
shell_command,
)
from .manager import _current_plugin
from .manager import _current_plugin_chain
def _store_matcher(matcher: Type[Matcher]) -> None:
plugin = _current_plugin.get()
plugins = _current_plugin_chain.get()
# only store the matcher defined in the plugin
if plugin:
plugin.matcher.add(matcher)
if plugins:
plugins[-1].matcher.add(matcher)
def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
@ -70,6 +70,7 @@ def on(
block: 是否阻止事件向更低优先级传递
state: 默认 state
"""
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new(
type,
Rule() & rule,
@ -79,7 +80,7 @@ def on(
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1),
default_state=state,
)
@ -109,6 +110,7 @@ def on_metaevent(
block: 是否阻止事件向更低优先级传递
state: 默认 state
"""
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new(
"meta_event",
Rule() & rule,
@ -118,7 +120,7 @@ def on_metaevent(
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1),
default_state=state,
)
@ -150,6 +152,7 @@ def on_message(
block: 是否阻止事件向更低优先级传递
state: 默认 state
"""
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new(
"message",
Rule() & rule,
@ -159,7 +162,7 @@ def on_message(
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1),
default_state=state,
)
@ -189,6 +192,7 @@ def on_notice(
block: 是否阻止事件向更低优先级传递
state: 默认 state
"""
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new(
"notice",
Rule() & rule,
@ -198,7 +202,7 @@ def on_notice(
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1),
default_state=state,
)
@ -228,6 +232,7 @@ def on_request(
block: 是否阻止事件向更低优先级传递
state: 默认 state
"""
plugin_chain = _current_plugin_chain.get()
matcher = Matcher.new(
"request",
Rule() & rule,
@ -237,7 +242,7 @@ def on_request(
priority=priority,
block=block,
handlers=handlers,
plugin=_current_plugin.get(),
plugin=plugin_chain[-1] if plugin_chain else None,
module=_get_matcher_module(_depth + 1),
default_state=state,
)

View File

@ -1,6 +1,13 @@
from pathlib import Path
import nonebot
from nonebot.plugin import PluginManager, _managers
_sub_plugins = set()
_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve()))
manager = PluginManager(
search_path=[str((Path(__file__).parent / "plugins").resolve())]
)
_managers.append(manager)
# test load nested plugin with require
manager.load_plugin("nested_subplugin")
manager.load_plugin("nested_subplugin2")

View File

@ -0,0 +1 @@
from .nested_subplugin2 import a

View File

@ -0,0 +1 @@
a = "required by another subplugin"

View File

@ -38,6 +38,19 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]):
import nonebot
parent_plugin = nonebot.get_plugin("nested")
sub_plugin = nonebot.get_plugin("nested_subplugin")
sub_plugin2 = nonebot.get_plugin("nested_subplugin2")
assert parent_plugin and sub_plugin and sub_plugin2
assert sub_plugin.parent_plugin is parent_plugin
assert sub_plugin2.parent_plugin is parent_plugin
assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2}
@pytest.mark.asyncio
async def test_bad_plugin(app: App):
import nonebot