🐛 fix error matcher module when import

This commit is contained in:
yanyongyu 2021-03-22 01:15:15 +08:00
parent d738f8674d
commit 6371cd6bfe
2 changed files with 30 additions and 13 deletions

View File

@ -19,7 +19,7 @@ from nonebot.permission import Permission
from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
from .manager import PluginManager
from .manager import PluginManager, _current_plugin
if TYPE_CHECKING:
from nonebot.adapters import Bot, Event
@ -32,6 +32,7 @@ plugins: Dict[str, "Plugin"] = {}
PLUGIN_NAMESPACE = "nonebot.loaded_plugins"
_export: ContextVar["Export"] = ContextVar("_export")
# FIXME: tmp matchers context var will be removed
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers")
@ -142,6 +143,7 @@ def on(type: str = "",
priority=priority,
block=block,
handlers=handlers,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
@ -183,6 +185,7 @@ def on_metaevent(
priority=priority,
block=block,
handlers=handlers,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
@ -225,6 +228,7 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
priority=priority,
block=block,
handlers=handlers,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
@ -265,6 +269,7 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
priority=priority,
block=block,
handlers=handlers,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
@ -305,6 +310,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
priority=priority,
block=block,
handlers=handlers,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
@ -960,8 +966,7 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
try:
module = manager.load_plugin(plugin_name)
# for m in _tmp_matchers.get():
# m.module = plugin_name
# FIXME: store matchers using new method
plugin = Plugin(plugin_name, module, _tmp_matchers.get(), _export.get())
plugins[plugin_name] = plugin
logger.opt(

View File

@ -5,10 +5,14 @@ import importlib
from hashlib import md5
from types import ModuleType
from collections import Counter
from contextvars import ContextVar
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader
from importlib.machinery import PathFinder, FrozenImporter, SourceFileLoader
from typing import Set, List, Optional, Iterable
_current_plugin: ContextVar[Optional[str]] = ContextVar("_current_plugin",
default=None)
_internal_space = ModuleType(__name__ + "._internal")
_internal_space.__path__ = [] # type: ignore
sys.modules[_internal_space.__name__] = _internal_space
@ -138,6 +142,7 @@ class PluginManager:
def load_plugin(self, name) -> ModuleType:
if name in self.plugins:
with self:
return importlib.import_module(name)
if "." in name:
@ -150,14 +155,15 @@ class PluginManager:
return [self.load_plugin(name) for name in self.list_plugins()]
def _rewrite_module_name(self, module_name) -> Optional[str]:
if module_name == self.namespace:
return self.internal_module.__name__
elif module_name.startswith(self.namespace + "."):
prefix = f"{self.internal_module.__name__}."
if module_name.startswith(self.namespace + "."):
path = module_name.split(".")
length = self.namespace.count(".") + 1
return f"{self.internal_module.__name__}.{'.'.join(path[length:])}"
return f"{prefix}{'.'.join(path[length:])}"
elif module_name in self.search_plugins():
return f"{self.internal_module.__name__}.{module_name}"
return f"{prefix}{module_name}"
elif module_name in self.plugins or module_name.startswith(prefix):
return module_name
return None
@ -170,9 +176,8 @@ class PluginFinder(MetaPathFinder):
manager = _manager_stack[index]
newname = manager._rewrite_module_name(fullname)
if newname:
spec = PathFinder.find_spec(newname,
list(manager.search_path),
target)
spec = PathFinder.find_spec(
newname, [*manager.search_path, *(path or [])], target)
if spec:
spec.loader = PluginLoader(manager, newname,
spec.origin)
@ -186,12 +191,17 @@ class PluginLoader(SourceFileLoader):
def __init__(self, manager: PluginManager, fullname: str, path) -> None:
self.manager = manager
self.loaded = False
self._context_token = None
super().__init__(fullname, path)
def create_module(self, spec) -> Optional[ModuleType]:
if self.name in sys.modules:
self.loaded = True
return sys.modules[self.name]
prefix = self.manager.internal_module.__name__
plugin_name = self.name[len(prefix):] if self.name.startswith(
prefix) else self.name
self._context_token = _current_plugin.set(plugin_name.lstrip("."))
# return None to use default module creation
return super().create_module(spec)
@ -200,6 +210,8 @@ class PluginLoader(SourceFileLoader):
return
setattr(module, "__manager__", self.manager)
super().exec_module(module)
if self._context_token:
_current_plugin.reset(self._context_token)
return