🐛 fix error matcher module when import

This commit is contained in:
yanyongyu 2021-03-22 01:15:15 +08:00
parent d738f8674d
commit 6371cd6bfe
2 changed files with 30 additions and 13 deletions

View File

@ -19,7 +19,7 @@ from nonebot.permission import Permission
from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
from .manager import PluginManager from .manager import PluginManager, _current_plugin
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
@ -32,6 +32,7 @@ plugins: Dict[str, "Plugin"] = {}
PLUGIN_NAMESPACE = "nonebot.loaded_plugins" PLUGIN_NAMESPACE = "nonebot.loaded_plugins"
_export: ContextVar["Export"] = ContextVar("_export") _export: ContextVar["Export"] = ContextVar("_export")
# FIXME: tmp matchers context var will be removed
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers") _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers")
@ -142,6 +143,7 @@ def on(type: str = "",
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _tmp_matchers.get().add(matcher)
@ -183,6 +185,7 @@ def on_metaevent(
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _tmp_matchers.get().add(matcher)
@ -225,6 +228,7 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _tmp_matchers.get().add(matcher)
@ -265,6 +269,7 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _tmp_matchers.get().add(matcher)
@ -305,6 +310,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
priority=priority, priority=priority,
block=block, block=block,
handlers=handlers, handlers=handlers,
module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _tmp_matchers.get().add(matcher)
@ -960,8 +966,7 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
try: try:
module = manager.load_plugin(plugin_name) module = manager.load_plugin(plugin_name)
# for m in _tmp_matchers.get(): # FIXME: store matchers using new method
# m.module = plugin_name
plugin = Plugin(plugin_name, module, _tmp_matchers.get(), _export.get()) plugin = Plugin(plugin_name, module, _tmp_matchers.get(), _export.get())
plugins[plugin_name] = plugin plugins[plugin_name] = plugin
logger.opt( logger.opt(

View File

@ -5,10 +5,14 @@ import importlib
from hashlib import md5 from hashlib import md5
from types import ModuleType from types import ModuleType
from collections import Counter from collections import Counter
from contextvars import ContextVar
from importlib.abc import MetaPathFinder from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader from importlib.machinery import PathFinder, FrozenImporter, SourceFileLoader
from typing import Set, List, Optional, Iterable from typing import Set, List, Optional, Iterable
_current_plugin: ContextVar[Optional[str]] = ContextVar("_current_plugin",
default=None)
_internal_space = ModuleType(__name__ + "._internal") _internal_space = ModuleType(__name__ + "._internal")
_internal_space.__path__ = [] # type: ignore _internal_space.__path__ = [] # type: ignore
sys.modules[_internal_space.__name__] = _internal_space sys.modules[_internal_space.__name__] = _internal_space
@ -138,6 +142,7 @@ class PluginManager:
def load_plugin(self, name) -> ModuleType: def load_plugin(self, name) -> ModuleType:
if name in self.plugins: if name in self.plugins:
with self:
return importlib.import_module(name) return importlib.import_module(name)
if "." in name: if "." in name:
@ -150,14 +155,15 @@ class PluginManager:
return [self.load_plugin(name) for name in self.list_plugins()] return [self.load_plugin(name) for name in self.list_plugins()]
def _rewrite_module_name(self, module_name) -> Optional[str]: def _rewrite_module_name(self, module_name) -> Optional[str]:
if module_name == self.namespace: prefix = f"{self.internal_module.__name__}."
return self.internal_module.__name__ if module_name.startswith(self.namespace + "."):
elif module_name.startswith(self.namespace + "."):
path = module_name.split(".") path = module_name.split(".")
length = self.namespace.count(".") + 1 length = self.namespace.count(".") + 1
return f"{self.internal_module.__name__}.{'.'.join(path[length:])}" return f"{prefix}{'.'.join(path[length:])}"
elif module_name in self.search_plugins(): elif module_name in self.search_plugins():
return f"{self.internal_module.__name__}.{module_name}" return f"{prefix}{module_name}"
elif module_name in self.plugins or module_name.startswith(prefix):
return module_name
return None return None
@ -170,9 +176,8 @@ class PluginFinder(MetaPathFinder):
manager = _manager_stack[index] manager = _manager_stack[index]
newname = manager._rewrite_module_name(fullname) newname = manager._rewrite_module_name(fullname)
if newname: if newname:
spec = PathFinder.find_spec(newname, spec = PathFinder.find_spec(
list(manager.search_path), newname, [*manager.search_path, *(path or [])], target)
target)
if spec: if spec:
spec.loader = PluginLoader(manager, newname, spec.loader = PluginLoader(manager, newname,
spec.origin) spec.origin)
@ -186,12 +191,17 @@ class PluginLoader(SourceFileLoader):
def __init__(self, manager: PluginManager, fullname: str, path) -> None: def __init__(self, manager: PluginManager, fullname: str, path) -> None:
self.manager = manager self.manager = manager
self.loaded = False self.loaded = False
self._context_token = None
super().__init__(fullname, path) super().__init__(fullname, path)
def create_module(self, spec) -> Optional[ModuleType]: def create_module(self, spec) -> Optional[ModuleType]:
if self.name in sys.modules: if self.name in sys.modules:
self.loaded = True self.loaded = True
return sys.modules[self.name] return sys.modules[self.name]
prefix = self.manager.internal_module.__name__
plugin_name = self.name[len(prefix):] if self.name.startswith(
prefix) else self.name
self._context_token = _current_plugin.set(plugin_name.lstrip("."))
# return None to use default module creation # return None to use default module creation
return super().create_module(spec) return super().create_module(spec)
@ -200,6 +210,8 @@ class PluginLoader(SourceFileLoader):
return return
setattr(module, "__manager__", self.manager) setattr(module, "__manager__", self.manager)
super().exec_module(module) super().exec_module(module)
if self._context_token:
_current_plugin.reset(self._context_token)
return return