diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index ded978b4..b17a8190 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -19,7 +19,7 @@ from nonebot.permission import Permission from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex -from .manager import PluginManager +from .manager import PluginManager, _current_plugin if TYPE_CHECKING: from nonebot.adapters import Bot, Event @@ -32,6 +32,7 @@ plugins: Dict[str, "Plugin"] = {} PLUGIN_NAMESPACE = "nonebot.loaded_plugins" _export: ContextVar["Export"] = ContextVar("_export") +# FIXME: tmp matchers context var will be removed _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers") @@ -142,6 +143,7 @@ def on(type: str = "", priority=priority, block=block, handlers=handlers, + module=_current_plugin.get(), default_state=state, default_state_factory=state_factory) _tmp_matchers.get().add(matcher) @@ -183,6 +185,7 @@ def on_metaevent( priority=priority, block=block, handlers=handlers, + module=_current_plugin.get(), default_state=state, default_state_factory=state_factory) _tmp_matchers.get().add(matcher) @@ -225,6 +228,7 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None, priority=priority, block=block, handlers=handlers, + module=_current_plugin.get(), default_state=state, default_state_factory=state_factory) _tmp_matchers.get().add(matcher) @@ -265,6 +269,7 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None, priority=priority, block=block, handlers=handlers, + module=_current_plugin.get(), default_state=state, default_state_factory=state_factory) _tmp_matchers.get().add(matcher) @@ -305,6 +310,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None, priority=priority, block=block, handlers=handlers, + module=_current_plugin.get(), default_state=state, default_state_factory=state_factory) _tmp_matchers.get().add(matcher) @@ -960,8 +966,7 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]: try: module = manager.load_plugin(plugin_name) - # for m in _tmp_matchers.get(): - # m.module = plugin_name + # FIXME: store matchers using new method plugin = Plugin(plugin_name, module, _tmp_matchers.get(), _export.get()) plugins[plugin_name] = plugin logger.opt( diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index bac90246..07c47115 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -5,10 +5,14 @@ import importlib from hashlib import md5 from types import ModuleType from collections import Counter +from contextvars import ContextVar from importlib.abc import MetaPathFinder -from importlib.machinery import PathFinder, SourceFileLoader +from importlib.machinery import PathFinder, FrozenImporter, SourceFileLoader from typing import Set, List, Optional, Iterable +_current_plugin: ContextVar[Optional[str]] = ContextVar("_current_plugin", + default=None) + _internal_space = ModuleType(__name__ + "._internal") _internal_space.__path__ = [] # type: ignore sys.modules[_internal_space.__name__] = _internal_space @@ -138,7 +142,8 @@ class PluginManager: def load_plugin(self, name) -> ModuleType: if name in self.plugins: - return importlib.import_module(name) + with self: + return importlib.import_module(name) if "." in name: raise ValueError("Plugin name cannot contain '.'") @@ -150,14 +155,15 @@ class PluginManager: return [self.load_plugin(name) for name in self.list_plugins()] def _rewrite_module_name(self, module_name) -> Optional[str]: - if module_name == self.namespace: - return self.internal_module.__name__ - elif module_name.startswith(self.namespace + "."): + prefix = f"{self.internal_module.__name__}." + if module_name.startswith(self.namespace + "."): path = module_name.split(".") length = self.namespace.count(".") + 1 - return f"{self.internal_module.__name__}.{'.'.join(path[length:])}" + return f"{prefix}{'.'.join(path[length:])}" elif module_name in self.search_plugins(): - return f"{self.internal_module.__name__}.{module_name}" + return f"{prefix}{module_name}" + elif module_name in self.plugins or module_name.startswith(prefix): + return module_name return None @@ -170,9 +176,8 @@ class PluginFinder(MetaPathFinder): manager = _manager_stack[index] newname = manager._rewrite_module_name(fullname) if newname: - spec = PathFinder.find_spec(newname, - list(manager.search_path), - target) + spec = PathFinder.find_spec( + newname, [*manager.search_path, *(path or [])], target) if spec: spec.loader = PluginLoader(manager, newname, spec.origin) @@ -186,12 +191,17 @@ class PluginLoader(SourceFileLoader): def __init__(self, manager: PluginManager, fullname: str, path) -> None: self.manager = manager self.loaded = False + self._context_token = None super().__init__(fullname, path) def create_module(self, spec) -> Optional[ModuleType]: if self.name in sys.modules: self.loaded = True return sys.modules[self.name] + prefix = self.manager.internal_module.__name__ + plugin_name = self.name[len(prefix):] if self.name.startswith( + prefix) else self.name + self._context_token = _current_plugin.set(plugin_name.lstrip(".")) # return None to use default module creation return super().create_module(spec) @@ -200,6 +210,8 @@ class PluginLoader(SourceFileLoader): return setattr(module, "__manager__", self.manager) super().exec_module(module) + if self._context_token: + _current_plugin.reset(self._context_token) return