From 41c5ac0ac7baf824975a1fb56f5d5c51f38abc4f Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 10 May 2021 18:39:59 +0800 Subject: [PATCH] :wheelchair: improve plugin matcher system --- nonebot/matcher.py | 31 +++++++++++---- nonebot/plugin/__init__.py | 13 ++++--- nonebot/plugin/manager.py | 79 +++++++++++++++++++------------------- 3 files changed, 70 insertions(+), 53 deletions(-) diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 399ae8a4..ba6ca7b8 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -6,6 +6,7 @@ """ from functools import wraps +from types import ModuleType from datetime import datetime from contextvars import ContextVar from collections import defaultdict @@ -36,6 +37,9 @@ current_event: ContextVar = ContextVar("current_event") class MatcherMeta(type): if TYPE_CHECKING: module: Optional[str] + plugin_name: Optional[str] + module_name: Optional[str] + module_prefix: Optional[str] type: str rule: Rule permission: Permission @@ -46,7 +50,7 @@ class MatcherMeta(type): expire_time: Optional[datetime] def __repr__(self) -> str: - return (f"") @@ -56,11 +60,17 @@ class MatcherMeta(type): class Matcher(metaclass=MatcherMeta): """事件响应器类""" - module: Optional[str] = None + module: Optional[ModuleType] = None """ - :类型: ``Optional[str]`` - :说明: 事件响应器所在模块名称 + :类型: ``Optional[ModuleType]`` + :说明: 事件响应器所在模块 """ + plugin_name: Optional[str] = module and getattr(module, "__plugin_name__", + None) + module_name: Optional[str] = module and getattr(module, "__module_name__", + None) + module_prefix: Optional[str] = module and getattr(module, + "__module_prefix__", None) type: str = "" """ @@ -136,8 +146,9 @@ class Matcher(metaclass=MatcherMeta): self.state = self._default_state.copy() def __repr__(self) -> str: - return (f"") + return ( + f"") def __str__(self) -> str: return repr(self) @@ -153,7 +164,7 @@ class Matcher(metaclass=MatcherMeta): priority: int = 1, block: bool = False, *, - module: Optional[str] = None, + module: Optional[ModuleType] = None, default_state: Optional[T_State] = None, default_state_factory: Optional[T_StateFactory] = None, expire_time: Optional[datetime] = None) -> Type["Matcher"]: @@ -185,6 +196,12 @@ class Matcher(metaclass=MatcherMeta): "Matcher", (Matcher,), { "module": module, + "plugin_name": + module and getattr(module, "__plugin_name__", None), + "module_name": + module and getattr(module, "__module_name__", None), + "module_prefix": + module and getattr(module, "__module_prefix__", None), "type": type_, "rule": diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index bf1d2827..425590d4 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -65,15 +65,16 @@ class Plugin(object): - **类型**: ``Set[Type[Matcher]]`` - **说明**: 插件内定义的 ``Matcher`` """ - return reduce( - lambda x, y: x | _plugin_matchers[y], - filter(lambda x: x.startswith(self.name), _plugin_matchers.keys()), - set()) + # return reduce( + # lambda x, y: x | _plugin_matchers[y], + # filter(lambda x: x.startswith(self.name), _plugin_matchers.keys()), + # set()) + return _plugin_matchers.get(self.name, set()) def _store_matcher(matcher: Type[Matcher]): - if matcher.module: - _plugin_matchers[matcher.module].add(matcher) + if matcher.plugin_name: + _plugin_matchers[matcher.plugin_name].add(matcher) def on(type: str = "", diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index f4121900..51091274 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -12,8 +12,8 @@ from importlib.machinery import PathFinder, SourceFileLoader from .export import _export, Export -_current_plugin: ContextVar[Optional[str]] = ContextVar("_current_plugin", - default=None) +_current_plugin: ContextVar[Optional[ModuleType]] = ContextVar( + "_current_plugin", default=None) _internal_space = ModuleType(__name__ + "._internal") _internal_space.__path__ = [] # type: ignore @@ -53,14 +53,13 @@ class _InternalModule(ModuleType): class PluginManager: def __init__(self, - namespace: Optional[str] = None, + namespace: str, plugins: Optional[Iterable[str]] = None, search_path: Optional[Iterable[str]] = None, *, id: Optional[str] = None): - self.namespace: Optional[str] = namespace - self.namespace_module: Optional[ModuleType] = self._setup_namespace( - namespace) + self.namespace: str = namespace + self.namespace_module: ModuleType = self._setup_namespace(namespace) self.id: str = id or str(uuid.uuid4()) self.internal_id: str = md5( @@ -73,12 +72,7 @@ class PluginManager: # ensure can be loaded self.list_plugins() - def _setup_namespace(self, - namespace: Optional[str] = None - ) -> Optional[ModuleType]: - if not namespace: - return None - + def _setup_namespace(self, namespace: str) -> ModuleType: try: module = importlib.import_module(namespace) except ImportError: @@ -156,14 +150,18 @@ class PluginManager: def load_all_plugins(self) -> List[ModuleType]: return [self.load_plugin(name) for name in self.list_plugins()] - def _rewrite_module_name(self, module_name) -> Optional[str]: + def _rewrite_module_name(self, module_name: str) -> Optional[str]: prefix = f"{self.internal_module.__name__}." - if module_name.startswith(self.namespace + "."): - path = module_name.split(".") - length = self.namespace.count(".") + 1 - return f"{prefix}{'.'.join(path[length:])}" + raw_name = module_name[len(self.namespace) + + 1:] if module_name.startswith(self.namespace + + ".") else None + # dir plugins + if raw_name and raw_name.split(".")[0] in self.search_plugins(): + return f"{prefix}{raw_name}" + # third party plugin or renamed dir plugins elif module_name in self.plugins or module_name.startswith(prefix): return module_name + # dir plugins elif module_name in self.search_plugins(): return f"{prefix}{module_name}" return None @@ -194,43 +192,44 @@ class PluginLoader(SourceFileLoader): def __init__(self, manager: PluginManager, fullname: str, path) -> None: self.manager = manager self.loaded = False - self._plugin_token = None - self._export_token = None super().__init__(fullname, path) def create_module(self, spec) -> Optional[ModuleType]: if self.name in sys.modules: self.loaded = True return sys.modules[self.name] - prefix = self.manager.internal_module.__name__ - plugin_name = self.name[len(prefix):] if self.name.startswith( - prefix) else self.name - self._plugin_token = _current_plugin.set(plugin_name.lstrip(".")) - self._export_token = _export.set(Export()) # return None to use default module creation return super().create_module(spec) def exec_module(self, module: ModuleType) -> None: if self.loaded: return - # really need? - # setattr(module, "__manager__", self.manager) - if self._plugin_token: - setattr(module, "__plugin_name__", - _current_plugin.get(self._plugin_token)) - if self._export_token: - setattr(module, "__export__", _export.get()) - try: - super().exec_module(module) - except Exception as e: - raise ImportError( - f"Error when executing module {self.name}.") from e + export = Export() + _export_token = _export.set(export) - if self._plugin_token: - _current_plugin.reset(self._plugin_token) - if self._export_token: - _export.reset(self._export_token) + prefix = self.manager.internal_module.__name__ + is_dir_plugin = self.name.startswith(prefix + ".") + module_name = self.name[len(prefix) + + 1:] if is_dir_plugin else self.name + _plugin_token = _current_plugin.set(module) + + setattr(module, "__export__", export) + setattr(module, "__plugin_name__", + module_name.split(".")[0] if is_dir_plugin else module_name) + setattr(module, "__module_name__", module_name) + setattr(module, "__module_prefix__", prefix if is_dir_plugin else "") + + # try: + # super().exec_module(module) + # except Exception as e: + # raise ImportError( + # f"Error when executing module {module_name} from {module.__file__}." + # ) from e + super().exec_module(module) + + _current_plugin.reset(_plugin_token) + _export.reset(_export_token) return