diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 6e71bfa9..9c8d62ac 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -26,7 +26,9 @@ - `load_builtin_plugin` => {ref}``load_builtin_plugin` ` - `load_builtin_plugins` => {ref}``load_builtin_plugins` ` - `get_plugin` => {ref}``get_plugin` ` +- `get_plugin_by_module_name` => {ref}``get_plugin_by_module_name` ` - `get_loaded_plugins` => {ref}``get_loaded_plugins` ` +- `get_available_plugin_names` => {ref}``get_available_plugin_names` ` - `export` => {ref}``export` ` - `require` => {ref}``require` ` @@ -284,5 +286,7 @@ from nonebot.plugin import on_shell_command as on_shell_command from nonebot.plugin import get_loaded_plugins as get_loaded_plugins from nonebot.plugin import load_builtin_plugin as load_builtin_plugin from nonebot.plugin import load_builtin_plugins as load_builtin_plugins +from nonebot.plugin import get_plugin_by_module_name as get_plugin_by_module_name +from nonebot.plugin import get_available_plugin_names as get_available_plugin_names __autodoc__ = {"internal": False} diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index c7b48c3b..6358a884 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -35,14 +35,77 @@ FrontMatter: description: nonebot.plugin 模块 """ -from typing import List, Optional +from itertools import chain +from types import ModuleType from contextvars import ContextVar +from typing import Set, Dict, List, Optional +_plugins: Dict[str, "Plugin"] = {} _managers: List["PluginManager"] = [] _current_plugin: ContextVar[Optional["Plugin"]] = ContextVar( "_current_plugin", default=None ) + +def _module_name_to_plugin_name(module_name: str) -> str: + return module_name.rsplit(".", 1)[-1] + + +def _new_plugin( + module_name: str, module: ModuleType, manager: "PluginManager" +) -> "Plugin": + plugin_name = _module_name_to_plugin_name(module_name) + if plugin_name in _plugins: + raise RuntimeError("Plugin already exists! Check your plugin name.") + plugin = Plugin(plugin_name, module, module_name, manager) + _plugins[plugin_name] = plugin + return plugin + + +def _revert_plugin(plugin: "Plugin") -> None: + if plugin.name not in _plugins: + raise RuntimeError("Plugin not found!") + del _plugins[plugin.name] + + +def get_plugin(name: str) -> Optional["Plugin"]: + """获取已经导入的某个插件。 + + 如果为 `load_plugins` 文件夹导入的插件,则为文件(夹)名。 + + 参数: + name: 插件名,即 {ref}`nonebot.plugin.plugin.Plugin.name`。 + """ + return _plugins.get(name) + + +def get_plugin_by_module_name(module_name: str) -> Optional["Plugin"]: + """通过模块名获取已经导入的某个插件。 + + 如果提供的模块名为某个插件的子模块,同样会返回该插件。 + + 参数: + module_name: 模块名,即 {ref}`nonebot.plugin.plugin.Plugin.module_name`。 + """ + splits = module_name.split(".") + loaded = {plugin.module_name: plugin for plugin in _plugins.values()} + while splits: + name = ".".join(splits) + if name in loaded: + return loaded[name] + splits.pop() + + +def get_loaded_plugins() -> Set["Plugin"]: + """获取当前已导入的所有插件。""" + return set(_plugins.values()) + + +def get_available_plugin_names() -> Set[str]: + """获取当前所有可用的插件名(包含尚未加载的插件)。""" + return {*chain.from_iterable(manager.available_plugins for manager in _managers)} + + from .on import on as on from .manager import PluginManager from .export import Export as Export @@ -61,7 +124,6 @@ from .on import CommandGroup as CommandGroup from .on import MatcherGroup as MatcherGroup from .on import on_fullmatch as on_fullmatch from .on import on_metaevent as on_metaevent -from .plugin import get_plugin as get_plugin from .load import load_plugins as load_plugins from .on import on_startswith as on_startswith from .load import load_from_json as load_from_json @@ -69,5 +131,4 @@ from .load import load_from_toml as load_from_toml from .on import on_shell_command as on_shell_command from .load import load_all_plugins as load_all_plugins from .load import load_builtin_plugin as load_builtin_plugin -from .plugin import get_loaded_plugins as get_loaded_plugins from .load import load_builtin_plugins as load_builtin_plugins diff --git a/nonebot/plugin/load.py b/nonebot/plugin/load.py index 7e762b69..a241804f 100644 --- a/nonebot/plugin/load.py +++ b/nonebot/plugin/load.py @@ -10,10 +10,10 @@ from typing import Set, Iterable, Optional import tomlkit -from . import _managers from .export import Export +from .plugin import Plugin from .manager import PluginManager -from .plugin import Plugin, get_plugin +from . import _managers, get_plugin, _module_name_to_plugin_name def load_plugin(module_path: str) -> Optional[Plugin]: @@ -128,7 +128,7 @@ def load_builtin_plugin(name: str) -> Optional[Plugin]: return load_plugin(f"nonebot.plugins.{name}") -def load_builtin_plugins(*plugins) -> Set[Plugin]: +def load_builtin_plugins(*plugins: str) -> Set[Plugin]: """导入多个 NoneBot 内置插件。 参数: @@ -154,7 +154,7 @@ def require(name: str) -> Export: 异常: RuntimeError: 插件无法加载 """ - plugin = get_plugin(name.rsplit(".", 1)[-1]) + plugin = get_plugin(_module_name_to_plugin_name(name)) if not plugin: manager = _find_manager_by_name(name) if manager: diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index 7e8b2826..4b85111a 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -19,23 +19,52 @@ from typing import Set, Dict, List, Union, Iterable, Optional, Sequence from nonebot.log import logger from nonebot.utils import escape_tag -from . import _managers, _current_plugin -from .plugin import Plugin, _new_plugin, _confirm_plugin +from .plugin import Plugin +from . import ( + _managers, + _new_plugin, + _revert_plugin, + _current_plugin, + _module_name_to_plugin_name, +) class PluginManager: + """插件管理器。 + + 参数: + plugins: 独立插件模块名集合。 + search_path: 插件搜索路径(文件夹)。 + """ + def __init__( self, plugins: Optional[Iterable[str]] = None, search_path: Optional[Iterable[str]] = None, ): - # simple plugin not in search path self.plugins: Set[str] = set(plugins or []) self.search_path: Set[str] = set(search_path or []) + # cache plugins - self.searched_plugins: Dict[str, Path] = {} - self.list_plugins() + self._third_party_plugin_names: Dict[str, str] = {} + self._searched_plugin_names: Dict[str, Path] = {} + self.prepare_plugins() + + @property + def third_party_plugins(self) -> Set[str]: + """返回所有独立插件名称。""" + return set(self._third_party_plugin_names.keys()) + + @property + def searched_plugins(self) -> Set[str]: + """返回已搜索到的插件名称。""" + return set(self._searched_plugin_names.keys()) + + @property + def available_plugins(self) -> Set[str]: + """返回当前插件管理器中可用的插件名称。""" + return self.third_party_plugins | self.searched_plugins def _path_to_module_name(self, path: Path) -> str: rel_path = path.resolve().relative_to(Path(".").resolve()) @@ -44,48 +73,51 @@ class PluginManager: else: return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) - def _previous_plugins(self) -> List[str]: + def _previous_plugins(self) -> Set[str]: _pre_managers: List[PluginManager] if self in _managers: _pre_managers = _managers[: _managers.index(self)] else: _pre_managers = _managers[:] - return [ - *chain.from_iterable( - [ - *map(lambda x: x.rsplit(".", 1)[-1], manager.plugins), - *manager.searched_plugins.keys(), - ] - for manager in _pre_managers - ) - ] + return { + *chain.from_iterable(manager.available_plugins for manager in _pre_managers) + } + + def prepare_plugins(self) -> Set[str]: + """搜索插件并缓存插件名称。""" - def list_plugins(self) -> Set[str]: # get all previous ready to load plugins previous_plugins = self._previous_plugins() searched_plugins: Dict[str, Path] = {} - third_party_plugins: Set[str] = set() + third_party_plugins: Dict[str, str] = {} + # check third party plugins for plugin in self.plugins: - name = plugin.rsplit(".", 1)[-1] + name = _module_name_to_plugin_name(plugin) if name in third_party_plugins or name in previous_plugins: raise RuntimeError( f"Plugin already exists: {name}! Check your plugin name" ) - third_party_plugins.add(plugin) + third_party_plugins[name] = plugin + self._third_party_plugin_names = third_party_plugins + + # check plugins in search path for module_info in pkgutil.iter_modules(self.search_path): + # ignore if startswith "_" if module_info.name.startswith("_"): continue + if ( - module_info.name in searched_plugins.keys() + module_info.name in searched_plugins or module_info.name in previous_plugins or module_info.name in third_party_plugins ): raise RuntimeError( f"Plugin already exists: {module_info.name}! Check your plugin name" ) + module_spec = module_info.module_finder.find_spec(module_info.name, None) if not module_spec: continue @@ -94,17 +126,27 @@ class PluginManager: continue searched_plugins[module_info.name] = Path(module_path).resolve() - self.searched_plugins = searched_plugins + self._searched_plugin_names = searched_plugins - return third_party_plugins | set(self.searched_plugins.keys()) + return self.available_plugins def load_plugin(self, name: str) -> Optional[Plugin]: + """加载指定插件。 + + 对于独立插件,可以使用完整插件模块名或者插件名称。 + + 参数: + name: 插件名称。 + """ + try: if name in self.plugins: module = importlib.import_module(name) - elif name in self.searched_plugins: + elif name in self._third_party_plugin_names: + module = importlib.import_module(self._third_party_plugin_names[name]) + elif name in self._searched_plugin_names: module = importlib.import_module( - self._path_to_module_name(self.searched_plugins[name]) + self._path_to_module_name(self._searched_plugin_names[name]) ) else: raise RuntimeError(f"Plugin not found: {name}! Check your plugin name") @@ -125,8 +167,10 @@ class PluginManager: ) def load_all_plugins(self) -> Set[Plugin]: + """加载所有可用插件。""" + return set( - filter(None, (self.load_plugin(name) for name in self.list_plugins())) + filter(None, (self.load_plugin(name) for name in self.available_plugins)) ) @@ -147,9 +191,10 @@ class PluginFinder(MetaPathFinder): module_path = Path(module_origin).resolve() for manager in reversed(_managers): + # use path instead of name in case of submodule name conflict if ( fullname in manager.plugins - or module_path in manager.searched_plugins.values() + or module_path in manager._searched_plugin_names.values() ): module_spec.loader = PluginLoader(manager, fullname, module_origin) return module_spec @@ -173,7 +218,11 @@ class PluginLoader(SourceFileLoader): if self.loaded: return + # create plugin before executing plugin = _new_plugin(self.name, module, self.manager) + setattr(module, "__plugin__", plugin) + + # detect parent plugin before entering current plugin context parent_plugin = _current_plugin.get() if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index( self.manager @@ -181,21 +230,18 @@ class PluginLoader(SourceFileLoader): plugin.parent_plugin = parent_plugin parent_plugin.sub_plugins.add(plugin) + # enter plugin context _plugin_token = _current_plugin.set(plugin) - setattr(module, "__plugin__", plugin) + try: + super().exec_module(module) + except Exception: + _revert_plugin(plugin) + raise + finally: + # leave plugin context + _current_plugin.reset(_plugin_token) - # try: - # super().exec_module(module) - # except Exception as e: - # raise ImportError( - # f"Error when executing module {module_name} from {module.__file__}." - # ) from e - super().exec_module(module) - - _confirm_plugin(plugin) - - _current_plugin.reset(_plugin_token) return diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 2340255d..40afd83c 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -5,7 +5,6 @@ FrontMatter: description: nonebot.plugin.on 模块 """ import re -import sys import inspect from types import ModuleType from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional @@ -41,8 +40,7 @@ def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]: if current_frame is None: return None frame = inspect.getouterframes(current_frame)[depth + 1].frame - module_name = frame.f_globals["__name__"] - return sys.modules.get(module_name) + return inspect.getmodule(frame) def on( diff --git a/nonebot/plugin/plugin.py b/nonebot/plugin/plugin.py index 10e43723..658c073b 100644 --- a/nonebot/plugin/plugin.py +++ b/nonebot/plugin/plugin.py @@ -6,18 +6,16 @@ FrontMatter: """ from types import ModuleType from dataclasses import field, dataclass -from typing import TYPE_CHECKING, Set, Dict, Type, Optional +from typing import TYPE_CHECKING, Set, Type, Optional from nonebot.matcher import Matcher from .export import Export +from . import _plugins as plugins # FIXME: backport for nonebug if TYPE_CHECKING: from .manager import PluginManager -plugins: Dict[str, "Plugin"] = {} -"""已加载的插件""" - @dataclass(eq=False) class Plugin(object): @@ -32,40 +30,10 @@ class Plugin(object): manager: "PluginManager" """导入该插件的插件管理器""" export: Export = field(default_factory=Export) - """插件内定义的导出内容""" + """**Deprecated:** 插件内定义的导出内容""" matcher: Set[Type[Matcher]] = field(default_factory=set) """插件内定义的 `Matcher`""" parent_plugin: Optional["Plugin"] = None """父插件""" sub_plugins: Set["Plugin"] = field(default_factory=set) """子插件集合""" - - -def get_plugin(name: str) -> Optional[Plugin]: - """获取已经导入的某个插件。 - - 如果为 `load_plugins` 文件夹导入的插件,则为文件(夹)名。 - - 参数: - name: 插件名,即 {ref}`nonebot.plugin.plugin.Plugin.name`。 - """ - return plugins.get(name) - - -def get_loaded_plugins() -> Set[Plugin]: - """获取当前已导入的所有插件。""" - return set(plugins.values()) - - -def _new_plugin(fullname: str, module: ModuleType, manager: "PluginManager") -> Plugin: - name = fullname.rsplit(".", 1)[-1] if "." in fullname else fullname - if name in plugins: - raise RuntimeError("Plugin already exists! Check your plugin name.") - plugin = Plugin(name, module, fullname, manager) - return plugin - - -def _confirm_plugin(plugin: Plugin) -> None: - if plugin.name in plugins: - raise RuntimeError("Plugin already exists! Check your plugin name.") - plugins[plugin.name] = plugin diff --git a/tests/bad_plugins/bad_plugin.py b/tests/bad_plugins/bad_plugin.py new file mode 100644 index 00000000..1c84f256 --- /dev/null +++ b/tests/bad_plugins/bad_plugin.py @@ -0,0 +1,6 @@ +import nonebot + +plugin = nonebot.get_plugin("bad_plugin") +assert plugin + +x = 1 / 0 diff --git a/tests/plugins/_hidden.py b/tests/plugins/_hidden.py new file mode 100644 index 00000000..2cfffa40 --- /dev/null +++ b/tests/plugins/_hidden.py @@ -0,0 +1 @@ +assert False diff --git a/tests/plugins/export.py b/tests/plugins/export.py index f7570e92..149b8360 100644 --- a/tests/plugins/export.py +++ b/tests/plugins/export.py @@ -3,4 +3,4 @@ from nonebot import export @export() def test(): - ... + return "export" diff --git a/tests/plugins/nested/__init__.py b/tests/plugins/nested/__init__.py new file mode 100644 index 00000000..eb4cf3e4 --- /dev/null +++ b/tests/plugins/nested/__init__.py @@ -0,0 +1,6 @@ +from pathlib import Path + +import nonebot + +_sub_plugins = set() +_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve())) diff --git a/tests/plugins/nested/plugins/nested_subplugin.py b/tests/plugins/nested/plugins/nested_subplugin.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/plugins/require.py b/tests/plugins/require.py index 5764431c..8a9084d3 100644 --- a/tests/plugins/require.py +++ b/tests/plugins/require.py @@ -1,8 +1,7 @@ from nonebot import require -from plugins.export import test - -from .export import test as test_related test_require = require("export").test -assert test is test_related and test is test_require, "Export Require Error" +from plugins.export import test + +assert test is test_require and test() == "export", "Export Require Error" diff --git a/tests/test_plugin/test_get.py b/tests/test_plugin/test_get.py new file mode 100644 index 00000000..0e068643 --- /dev/null +++ b/tests/test_plugin/test_get.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING, Set + +import pytest +from nonebug import App + +if TYPE_CHECKING: + from nonebot.plugin import Plugin + + +@pytest.mark.asyncio +async def test_get_plugin(app: App, load_plugin: Set["Plugin"]): + import nonebot + + # check simple plugin + plugin = nonebot.get_plugin("export") + assert plugin + assert plugin.module_name == "plugins.export" + + # check sub plugin + plugin = nonebot.get_plugin("nested_subplugin") + assert plugin + assert plugin.module_name == "plugins.nested.plugins.nested_subplugin" + + # check get plugin by module name + plugin = nonebot.get_plugin_by_module_name("plugins.nested.utils") + assert plugin + assert plugin.module_name == "plugins.nested" + + +@pytest.mark.asyncio +async def test_get_available_plugin(app: App): + import nonebot + from nonebot.plugin import PluginManager, _managers + + _managers.append(PluginManager(["plugins.export", "plugin.require"])) + + # check get available plugins + plugin_names = nonebot.get_available_plugin_names() + assert plugin_names == {"export", "require"} diff --git a/tests/test_plugin/test_load.py b/tests/test_plugin/test_load.py index c0c0c8f6..209ee8e0 100644 --- a/tests/test_plugin/test_load.py +++ b/tests/test_plugin/test_load.py @@ -9,27 +9,40 @@ if TYPE_CHECKING: @pytest.mark.asyncio -async def test_load_plugin(load_plugin: Set["Plugin"]): +async def test_load_plugin(app: App, load_plugin: Set["Plugin"]): import nonebot + from nonebot.plugin import PluginManager - loaded_plugins = set( + loaded_plugins = { plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin - ) + } assert loaded_plugins == load_plugin - plugin = nonebot.get_plugin("export") - assert plugin - assert plugin.module_name == "plugins.export" + + # check simple plugin assert "plugins.export" in sys.modules - try: - nonebot.load_plugin("plugins.export") - assert False - except RuntimeError: - assert True + # check sub plugin + assert "plugins.nested.plugins.nested_subplugin" in sys.modules + # check load again + with pytest.raises(RuntimeError): + PluginManager(plugins=["plugins.export"]).load_all_plugins() + with pytest.raises(RuntimeError): + PluginManager(search_path=["plugins"]).load_all_plugins() + + # check not found assert nonebot.load_plugin("some_plugin_not_exist") is None +@pytest.mark.asyncio +async def test_bad_plugin(app: App): + import nonebot + + nonebot.load_plugins("bad_plugins") + + assert nonebot.get_plugin("bad_plugins") is None + + @pytest.mark.asyncio async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch): import nonebot @@ -47,8 +60,7 @@ async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch): @pytest.mark.asyncio async def test_require_not_loaded(app: App, monkeypatch: pytest.MonkeyPatch): import nonebot - from nonebot.plugin import _managers - from nonebot.plugin.manager import PluginManager + from nonebot.plugin import PluginManager, _managers m = PluginManager(["plugins.export"]) _managers.append(m) @@ -80,10 +92,6 @@ async def test_require_not_declared(app: App): @pytest.mark.asyncio async def test_require_not_found(app: App): import nonebot - from nonebot.plugin import _managers - try: + with pytest.raises(RuntimeError): nonebot.require("some_plugin_not_exist") - assert False - except RuntimeError: - assert True diff --git a/tests/test_plugin/test_manager.py b/tests/test_plugin/test_manager.py new file mode 100644 index 00000000..e5758f78 --- /dev/null +++ b/tests/test_plugin/test_manager.py @@ -0,0 +1,12 @@ +import pytest +from nonebug import App + + +@pytest.mark.asyncio +async def test_load_plugin_name(app: App): + from nonebot.plugin import PluginManager + + m = PluginManager(plugins=["plugins.export"]) + module1 = m.load_plugin("export") + module2 = m.load_plugin("plugins.export") + assert module1 is module2