mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
🐛 fix parent detect error after require (#1121)
This commit is contained in:
parent
48ccef2f06
commit
2192e8cb6d
@ -36,12 +36,12 @@ FrontMatter:
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from contextvars import ContextVar
|
||||
from typing import Set, Dict, List, Optional
|
||||
from typing import Set, Dict, List, Tuple, Optional
|
||||
|
||||
_plugins: Dict[str, "Plugin"] = {}
|
||||
_managers: List["PluginManager"] = []
|
||||
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar(
|
||||
"_current_plugin", default=None
|
||||
_current_plugin_chain: ContextVar[Tuple["Plugin", ...]] = ContextVar(
|
||||
"_current_plugin_chain", default=tuple()
|
||||
)
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ FrontMatter:
|
||||
|
||||
import warnings
|
||||
|
||||
from . import _current_plugin
|
||||
from . import _current_plugin_chain
|
||||
|
||||
|
||||
class Export(dict):
|
||||
@ -58,7 +58,7 @@ def export() -> Export:
|
||||
"See https://github.com/nonebot/nonebot2/issues/935.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
plugin = _current_plugin.get()
|
||||
if not plugin:
|
||||
plugins = _current_plugin_chain.get()
|
||||
if not plugins:
|
||||
raise RuntimeError("Export outside of the plugin!")
|
||||
return plugin.export
|
||||
return plugins[-1].export
|
||||
|
@ -24,7 +24,7 @@ from . import (
|
||||
_managers,
|
||||
_new_plugin,
|
||||
_revert_plugin,
|
||||
_current_plugin,
|
||||
_current_plugin_chain,
|
||||
_module_name_to_plugin_name,
|
||||
)
|
||||
|
||||
@ -223,15 +223,15 @@ class PluginLoader(SourceFileLoader):
|
||||
setattr(module, "__plugin__", plugin)
|
||||
|
||||
# detect parent plugin before entering current plugin context
|
||||
parent_plugin = _current_plugin.get()
|
||||
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index(
|
||||
self.manager
|
||||
):
|
||||
plugin.parent_plugin = parent_plugin
|
||||
parent_plugin.sub_plugins.add(plugin)
|
||||
parent_plugins = _current_plugin_chain.get()
|
||||
for pre_plugin in reversed(parent_plugins):
|
||||
if _managers.index(pre_plugin.manager) < _managers.index(self.manager):
|
||||
plugin.parent_plugin = pre_plugin
|
||||
pre_plugin.sub_plugins.add(plugin)
|
||||
break
|
||||
|
||||
# enter plugin context
|
||||
_plugin_token = _current_plugin.set(plugin)
|
||||
_plugin_token = _current_plugin_chain.set(parent_plugins + (plugin,))
|
||||
|
||||
try:
|
||||
super().exec_module(module)
|
||||
@ -240,7 +240,7 @@ class PluginLoader(SourceFileLoader):
|
||||
raise
|
||||
finally:
|
||||
# leave plugin context
|
||||
_current_plugin.reset(_plugin_token)
|
||||
_current_plugin_chain.reset(_plugin_token)
|
||||
|
||||
# get plugin metadata
|
||||
metadata: Optional[PluginMetadata] = getattr(module, "__plugin_meta__", None)
|
||||
|
@ -26,14 +26,14 @@ from nonebot.rule import (
|
||||
shell_command,
|
||||
)
|
||||
|
||||
from .manager import _current_plugin
|
||||
from .manager import _current_plugin_chain
|
||||
|
||||
|
||||
def _store_matcher(matcher: Type[Matcher]) -> None:
|
||||
plugin = _current_plugin.get()
|
||||
plugins = _current_plugin_chain.get()
|
||||
# only store the matcher defined in the plugin
|
||||
if plugin:
|
||||
plugin.matcher.add(matcher)
|
||||
if plugins:
|
||||
plugins[-1].matcher.add(matcher)
|
||||
|
||||
|
||||
def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
|
||||
@ -70,6 +70,7 @@ def on(
|
||||
block: 是否阻止事件向更低优先级传递
|
||||
state: 默认 state
|
||||
"""
|
||||
plugin_chain = _current_plugin_chain.get()
|
||||
matcher = Matcher.new(
|
||||
type,
|
||||
Rule() & rule,
|
||||
@ -79,7 +80,7 @@ def on(
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
plugin=_current_plugin.get(),
|
||||
plugin=plugin_chain[-1] if plugin_chain else None,
|
||||
module=_get_matcher_module(_depth + 1),
|
||||
default_state=state,
|
||||
)
|
||||
@ -109,6 +110,7 @@ def on_metaevent(
|
||||
block: 是否阻止事件向更低优先级传递
|
||||
state: 默认 state
|
||||
"""
|
||||
plugin_chain = _current_plugin_chain.get()
|
||||
matcher = Matcher.new(
|
||||
"meta_event",
|
||||
Rule() & rule,
|
||||
@ -118,7 +120,7 @@ def on_metaevent(
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
plugin=_current_plugin.get(),
|
||||
plugin=plugin_chain[-1] if plugin_chain else None,
|
||||
module=_get_matcher_module(_depth + 1),
|
||||
default_state=state,
|
||||
)
|
||||
@ -150,6 +152,7 @@ def on_message(
|
||||
block: 是否阻止事件向更低优先级传递
|
||||
state: 默认 state
|
||||
"""
|
||||
plugin_chain = _current_plugin_chain.get()
|
||||
matcher = Matcher.new(
|
||||
"message",
|
||||
Rule() & rule,
|
||||
@ -159,7 +162,7 @@ def on_message(
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
plugin=_current_plugin.get(),
|
||||
plugin=plugin_chain[-1] if plugin_chain else None,
|
||||
module=_get_matcher_module(_depth + 1),
|
||||
default_state=state,
|
||||
)
|
||||
@ -189,6 +192,7 @@ def on_notice(
|
||||
block: 是否阻止事件向更低优先级传递
|
||||
state: 默认 state
|
||||
"""
|
||||
plugin_chain = _current_plugin_chain.get()
|
||||
matcher = Matcher.new(
|
||||
"notice",
|
||||
Rule() & rule,
|
||||
@ -198,7 +202,7 @@ def on_notice(
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
plugin=_current_plugin.get(),
|
||||
plugin=plugin_chain[-1] if plugin_chain else None,
|
||||
module=_get_matcher_module(_depth + 1),
|
||||
default_state=state,
|
||||
)
|
||||
@ -228,6 +232,7 @@ def on_request(
|
||||
block: 是否阻止事件向更低优先级传递
|
||||
state: 默认 state
|
||||
"""
|
||||
plugin_chain = _current_plugin_chain.get()
|
||||
matcher = Matcher.new(
|
||||
"request",
|
||||
Rule() & rule,
|
||||
@ -237,7 +242,7 @@ def on_request(
|
||||
priority=priority,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
plugin=_current_plugin.get(),
|
||||
plugin=plugin_chain[-1] if plugin_chain else None,
|
||||
module=_get_matcher_module(_depth + 1),
|
||||
default_state=state,
|
||||
)
|
||||
|
@ -1,6 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot
|
||||
from nonebot.plugin import PluginManager, _managers
|
||||
|
||||
_sub_plugins = set()
|
||||
_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve()))
|
||||
manager = PluginManager(
|
||||
search_path=[str((Path(__file__).parent / "plugins").resolve())]
|
||||
)
|
||||
_managers.append(manager)
|
||||
|
||||
# test load nested plugin with require
|
||||
manager.load_plugin("nested_subplugin")
|
||||
manager.load_plugin("nested_subplugin2")
|
||||
|
@ -0,0 +1 @@
|
||||
from .nested_subplugin2 import a
|
1
tests/plugins/nested/plugins/nested_subplugin2.py
Normal file
1
tests/plugins/nested/plugins/nested_subplugin2.py
Normal file
@ -0,0 +1 @@
|
||||
a = "required by another subplugin"
|
@ -38,6 +38,19 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
|
||||
assert nonebot.load_plugin("some_plugin_not_exist") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]):
|
||||
import nonebot
|
||||
|
||||
parent_plugin = nonebot.get_plugin("nested")
|
||||
sub_plugin = nonebot.get_plugin("nested_subplugin")
|
||||
sub_plugin2 = nonebot.get_plugin("nested_subplugin2")
|
||||
assert parent_plugin and sub_plugin and sub_plugin2
|
||||
assert sub_plugin.parent_plugin is parent_plugin
|
||||
assert sub_plugin2.parent_plugin is parent_plugin
|
||||
assert parent_plugin.sub_plugins == {sub_plugin, sub_plugin2}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_plugin(app: App):
|
||||
import nonebot
|
||||
|
Loading…
Reference in New Issue
Block a user