diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 7f996323..8a5f60d6 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -178,6 +178,7 @@ class Matcher(metaclass=MatcherMeta): block: bool = False, *, plugin: Optional["Plugin"] = None, + module: Optional[ModuleType] = None, expire_time: Optional[datetime] = None, default_state: Optional[T_State] = None, default_state_factory: Optional[T_StateFactory] = None, @@ -200,6 +201,7 @@ class Matcher(metaclass=MatcherMeta): * ``priority: int``: 响应优先级 * ``block: bool``: 是否阻止事件向更低优先级的响应器传播 * ``plugin: Optional[Plugin]``: 事件响应器所在插件 + * ``module: Optional[ModuleType]``: 事件响应器所在模块 * ``default_state: Optional[T_State]``: 默认状态 ``state`` * ``default_state_factory: Optional[T_StateFactory]``: 默认状态 ``state`` 的工厂函数 * ``expire_time: Optional[datetime]``: 事件响应器最终有效时间点,过时即被删除 @@ -210,18 +212,15 @@ class Matcher(metaclass=MatcherMeta): """ NewMatcher = type( - "Matcher", - (Matcher,), - { + "Matcher", (Matcher,), { "plugin": plugin, "module": - plugin and plugin. - module, # FIXME: matcher module may different from plugin module + module, "plugin_name": plugin and plugin.name, "module_name": - plugin and plugin.module_name, + module and module.__name__, "type": type_, "rule": diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index 825e3844..d0243c91 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -1,8 +1,6 @@ import sys -import uuid import pkgutil import importlib -from hashlib import md5 from pathlib import Path from types import ModuleType from collections import Counter @@ -14,23 +12,17 @@ from .export import Export from . import _current_plugin from .plugin import Plugin, _new_plugin +_manager_stack: List["PluginManager"] = [] + # TODO class PluginManager: - def __init__(self, - namespace: str, - plugins: Optional[Iterable[str]] = None, - search_path: Optional[Iterable[str]] = None, - *, - id: Optional[str] = None): - self.namespace: str = namespace - self.namespace_module: ModuleType = self._setup_namespace(namespace) - - self.id: str = id or str(uuid.uuid4()) - self.internal_id: str = md5( - ((self.namespace or "") + self.id).encode()).hexdigest() - self.internal_module = self._setup_internal_module(self.internal_id) + def __init__( + self, + plugins: Optional[Iterable[str]] = None, + search_path: Optional[Iterable[str]] = None, + ): # simple plugin not in search path self.plugins: Set[str] = set(plugins or []) @@ -38,55 +30,6 @@ class PluginManager: # ensure can be loaded self.list_plugins() - def _setup_namespace(self, namespace: str) -> ModuleType: - try: - module = importlib.import_module(namespace) - except ImportError: - module = _NamespaceModule(namespace) - if "." in namespace: - parent = importlib.import_module(namespace.rsplit(".", 1)[0]) - setattr(parent, namespace.rsplit(".", 1)[1], module) - - sys.modules[namespace] = module - return module - - def _setup_internal_module(self, internal_id: str) -> ModuleType: - if hasattr(_internal_space, internal_id): - raise RuntimeError("Plugin manager already exists!") - - index = 2 - prefix: str = _internal_space.__name__ - while True: - try: - frame = sys._getframe(index) - except ValueError: - break - # check if is called in plugin - if "__plugin_name__" not in frame.f_globals: - index += 1 - continue - prefix = frame.f_globals.get("__name__", _internal_space.__name__) - break - - if not prefix.startswith(_internal_space.__name__): - prefix = _internal_space.__name__ - module = _InternalModule(prefix, self) - sys.modules[module.__name__] = module # type: ignore - setattr(_internal_space, internal_id, module) - return module - - def __enter__(self): - if self in _manager_stack: - raise RuntimeError("Plugin manager already activated!") - _manager_stack.append(self) - return self - - def __exit__(self, exc_type, exc_value, traceback): - try: - _manager_stack.pop() - except IndexError: - pass - def search_plugins(self) -> List[str]: return [ module_info.name @@ -116,14 +59,12 @@ class PluginManager: def load_plugin(self, name) -> ModuleType: if name in self.plugins: - with self: - return importlib.import_module(name) + return importlib.import_module(name) if "." in name: raise ValueError("Plugin name cannot contain '.'") - with self: - return importlib.import_module(f"{self.namespace}.{name}") + return importlib.import_module(f"{self.namespace}.{name}") def load_all_plugins(self) -> List[ModuleType]: return [self.load_plugin(name) for name in self.list_plugins()] diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index aa648b7a..f8d22167 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -1,4 +1,7 @@ import re +import sys +import inspect +from types import ModuleType from typing import (TYPE_CHECKING, Any, Set, Dict, List, Type, Tuple, Union, Optional) @@ -22,6 +25,15 @@ def _store_matcher(matcher: Type[Matcher]) -> None: plugin.matcher.add(matcher) +def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]: + current_frame = inspect.currentframe() + if current_frame is None: + return None + frame = inspect.getouterframes(current_frame)[depth + 1].frame + module_name = frame.f_globals["__name__"] + return sys.modules.get(module_name) + + def on(type: str = "", rule: Optional[Union[Rule, T_RuleChecker]] = None, permission: Optional[Permission] = None, @@ -61,6 +73,7 @@ def on(type: str = "", block=block, handlers=handlers, plugin=_current_plugin.get(), + module=_get_matcher_module(), default_state=state, default_state_factory=state_factory) _store_matcher(matcher) @@ -103,6 +116,7 @@ def on_metaevent( block=block, handlers=handlers, plugin=_current_plugin.get(), + module=_get_matcher_module(), default_state=state, default_state_factory=state_factory) _store_matcher(matcher) @@ -146,6 +160,7 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None, block=block, handlers=handlers, plugin=_current_plugin.get(), + module=_get_matcher_module(), default_state=state, default_state_factory=state_factory) _store_matcher(matcher) @@ -187,6 +202,7 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None, block=block, handlers=handlers, plugin=_current_plugin.get(), + module=_get_matcher_module(), default_state=state, default_state_factory=state_factory) _store_matcher(matcher) @@ -228,6 +244,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None, block=block, handlers=handlers, plugin=_current_plugin.get(), + module=_get_matcher_module(), default_state=state, default_state_factory=state_factory) _store_matcher(matcher)