improve plugin system (#1011)

This commit is contained in:
Ju4tCode 2022-05-26 16:35:47 +08:00 committed by GitHub
parent 579839f2a4
commit fa3ed2b58c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 254 additions and 106 deletions

View File

@ -26,7 +26,9 @@
- `load_builtin_plugin` => {ref}``load_builtin_plugin` <nonebot.plugin.load.load_builtin_plugin>` - `load_builtin_plugin` => {ref}``load_builtin_plugin` <nonebot.plugin.load.load_builtin_plugin>`
- `load_builtin_plugins` => {ref}``load_builtin_plugins` <nonebot.plugin.load.load_builtin_plugins>` - `load_builtin_plugins` => {ref}``load_builtin_plugins` <nonebot.plugin.load.load_builtin_plugins>`
- `get_plugin` => {ref}``get_plugin` <nonebot.plugin.plugin.get_plugin>` - `get_plugin` => {ref}``get_plugin` <nonebot.plugin.plugin.get_plugin>`
- `get_plugin_by_module_name` => {ref}``get_plugin_by_module_name` <nonebot.plugin.plugin.get_plugin_by_module_name>`
- `get_loaded_plugins` => {ref}``get_loaded_plugins` <nonebot.plugin.plugin.get_loaded_plugins>` - `get_loaded_plugins` => {ref}``get_loaded_plugins` <nonebot.plugin.plugin.get_loaded_plugins>`
- `get_available_plugin_names` => {ref}``get_available_plugin_names` <nonebot.plugin.plugin.get_available_plugin_names>`
- `export` => {ref}``export` <nonebot.plugin.export.export>` - `export` => {ref}``export` <nonebot.plugin.export.export>`
- `require` => {ref}``require` <nonebot.plugin.load.require>` - `require` => {ref}``require` <nonebot.plugin.load.require>`
@ -284,5 +286,7 @@ from nonebot.plugin import on_shell_command as on_shell_command
from nonebot.plugin import get_loaded_plugins as get_loaded_plugins from nonebot.plugin import get_loaded_plugins as get_loaded_plugins
from nonebot.plugin import load_builtin_plugin as load_builtin_plugin from nonebot.plugin import load_builtin_plugin as load_builtin_plugin
from nonebot.plugin import load_builtin_plugins as load_builtin_plugins from nonebot.plugin import load_builtin_plugins as load_builtin_plugins
from nonebot.plugin import get_plugin_by_module_name as get_plugin_by_module_name
from nonebot.plugin import get_available_plugin_names as get_available_plugin_names
__autodoc__ = {"internal": False} __autodoc__ = {"internal": False}

View File

@ -35,14 +35,77 @@ FrontMatter:
description: nonebot.plugin 模块 description: nonebot.plugin 模块
""" """
from typing import List, Optional from itertools import chain
from types import ModuleType
from contextvars import ContextVar from contextvars import ContextVar
from typing import Set, Dict, List, Optional
_plugins: Dict[str, "Plugin"] = {}
_managers: List["PluginManager"] = [] _managers: List["PluginManager"] = []
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar( _current_plugin: ContextVar[Optional["Plugin"]] = ContextVar(
"_current_plugin", default=None "_current_plugin", default=None
) )
def _module_name_to_plugin_name(module_name: str) -> str:
return module_name.rsplit(".", 1)[-1]
def _new_plugin(
module_name: str, module: ModuleType, manager: "PluginManager"
) -> "Plugin":
plugin_name = _module_name_to_plugin_name(module_name)
if plugin_name in _plugins:
raise RuntimeError("Plugin already exists! Check your plugin name.")
plugin = Plugin(plugin_name, module, module_name, manager)
_plugins[plugin_name] = plugin
return plugin
def _revert_plugin(plugin: "Plugin") -> None:
if plugin.name not in _plugins:
raise RuntimeError("Plugin not found!")
del _plugins[plugin.name]
def get_plugin(name: str) -> Optional["Plugin"]:
"""获取已经导入的某个插件。
如果为 `load_plugins` 文件夹导入的插件则为文件()
参数:
name: 插件名 {ref}`nonebot.plugin.plugin.Plugin.name`
"""
return _plugins.get(name)
def get_plugin_by_module_name(module_name: str) -> Optional["Plugin"]:
"""通过模块名获取已经导入的某个插件。
如果提供的模块名为某个插件的子模块同样会返回该插件
参数:
module_name: 模块名 {ref}`nonebot.plugin.plugin.Plugin.module_name`
"""
splits = module_name.split(".")
loaded = {plugin.module_name: plugin for plugin in _plugins.values()}
while splits:
name = ".".join(splits)
if name in loaded:
return loaded[name]
splits.pop()
def get_loaded_plugins() -> Set["Plugin"]:
"""获取当前已导入的所有插件。"""
return set(_plugins.values())
def get_available_plugin_names() -> Set[str]:
"""获取当前所有可用的插件名(包含尚未加载的插件)。"""
return {*chain.from_iterable(manager.available_plugins for manager in _managers)}
from .on import on as on from .on import on as on
from .manager import PluginManager from .manager import PluginManager
from .export import Export as Export from .export import Export as Export
@ -61,7 +124,6 @@ from .on import CommandGroup as CommandGroup
from .on import MatcherGroup as MatcherGroup from .on import MatcherGroup as MatcherGroup
from .on import on_fullmatch as on_fullmatch from .on import on_fullmatch as on_fullmatch
from .on import on_metaevent as on_metaevent from .on import on_metaevent as on_metaevent
from .plugin import get_plugin as get_plugin
from .load import load_plugins as load_plugins from .load import load_plugins as load_plugins
from .on import on_startswith as on_startswith from .on import on_startswith as on_startswith
from .load import load_from_json as load_from_json from .load import load_from_json as load_from_json
@ -69,5 +131,4 @@ from .load import load_from_toml as load_from_toml
from .on import on_shell_command as on_shell_command from .on import on_shell_command as on_shell_command
from .load import load_all_plugins as load_all_plugins from .load import load_all_plugins as load_all_plugins
from .load import load_builtin_plugin as load_builtin_plugin from .load import load_builtin_plugin as load_builtin_plugin
from .plugin import get_loaded_plugins as get_loaded_plugins
from .load import load_builtin_plugins as load_builtin_plugins from .load import load_builtin_plugins as load_builtin_plugins

View File

@ -10,10 +10,10 @@ from typing import Set, Iterable, Optional
import tomlkit import tomlkit
from . import _managers
from .export import Export from .export import Export
from .plugin import Plugin
from .manager import PluginManager from .manager import PluginManager
from .plugin import Plugin, get_plugin from . import _managers, get_plugin, _module_name_to_plugin_name
def load_plugin(module_path: str) -> Optional[Plugin]: def load_plugin(module_path: str) -> Optional[Plugin]:
@ -128,7 +128,7 @@ def load_builtin_plugin(name: str) -> Optional[Plugin]:
return load_plugin(f"nonebot.plugins.{name}") return load_plugin(f"nonebot.plugins.{name}")
def load_builtin_plugins(*plugins) -> Set[Plugin]: def load_builtin_plugins(*plugins: str) -> Set[Plugin]:
"""导入多个 NoneBot 内置插件。 """导入多个 NoneBot 内置插件。
参数: 参数:
@ -154,7 +154,7 @@ def require(name: str) -> Export:
异常: 异常:
RuntimeError: 插件无法加载 RuntimeError: 插件无法加载
""" """
plugin = get_plugin(name.rsplit(".", 1)[-1]) plugin = get_plugin(_module_name_to_plugin_name(name))
if not plugin: if not plugin:
manager = _find_manager_by_name(name) manager = _find_manager_by_name(name)
if manager: if manager:

View File

@ -19,23 +19,52 @@ from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from . import _managers, _current_plugin from .plugin import Plugin
from .plugin import Plugin, _new_plugin, _confirm_plugin from . import (
_managers,
_new_plugin,
_revert_plugin,
_current_plugin,
_module_name_to_plugin_name,
)
class PluginManager: class PluginManager:
"""插件管理器。
参数:
plugins: 独立插件模块名集合
search_path: 插件搜索路径文件夹
"""
def __init__( def __init__(
self, self,
plugins: Optional[Iterable[str]] = None, plugins: Optional[Iterable[str]] = None,
search_path: Optional[Iterable[str]] = None, search_path: Optional[Iterable[str]] = None,
): ):
# simple plugin not in search path # simple plugin not in search path
self.plugins: Set[str] = set(plugins or []) self.plugins: Set[str] = set(plugins or [])
self.search_path: Set[str] = set(search_path or []) self.search_path: Set[str] = set(search_path or [])
# cache plugins # cache plugins
self.searched_plugins: Dict[str, Path] = {} self._third_party_plugin_names: Dict[str, str] = {}
self.list_plugins() self._searched_plugin_names: Dict[str, Path] = {}
self.prepare_plugins()
@property
def third_party_plugins(self) -> Set[str]:
"""返回所有独立插件名称。"""
return set(self._third_party_plugin_names.keys())
@property
def searched_plugins(self) -> Set[str]:
"""返回已搜索到的插件名称。"""
return set(self._searched_plugin_names.keys())
@property
def available_plugins(self) -> Set[str]:
"""返回当前插件管理器中可用的插件名称。"""
return self.third_party_plugins | self.searched_plugins
def _path_to_module_name(self, path: Path) -> str: def _path_to_module_name(self, path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve()) rel_path = path.resolve().relative_to(Path(".").resolve())
@ -44,48 +73,51 @@ class PluginManager:
else: else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def _previous_plugins(self) -> List[str]: def _previous_plugins(self) -> Set[str]:
_pre_managers: List[PluginManager] _pre_managers: List[PluginManager]
if self in _managers: if self in _managers:
_pre_managers = _managers[: _managers.index(self)] _pre_managers = _managers[: _managers.index(self)]
else: else:
_pre_managers = _managers[:] _pre_managers = _managers[:]
return [ return {
*chain.from_iterable( *chain.from_iterable(manager.available_plugins for manager in _pre_managers)
[ }
*map(lambda x: x.rsplit(".", 1)[-1], manager.plugins),
*manager.searched_plugins.keys(), def prepare_plugins(self) -> Set[str]:
] """搜索插件并缓存插件名称。"""
for manager in _pre_managers
)
]
def list_plugins(self) -> Set[str]:
# get all previous ready to load plugins # get all previous ready to load plugins
previous_plugins = self._previous_plugins() previous_plugins = self._previous_plugins()
searched_plugins: Dict[str, Path] = {} searched_plugins: Dict[str, Path] = {}
third_party_plugins: Set[str] = set() third_party_plugins: Dict[str, str] = {}
# check third party plugins
for plugin in self.plugins: for plugin in self.plugins:
name = plugin.rsplit(".", 1)[-1] name = _module_name_to_plugin_name(plugin)
if name in third_party_plugins or name in previous_plugins: if name in third_party_plugins or name in previous_plugins:
raise RuntimeError( raise RuntimeError(
f"Plugin already exists: {name}! Check your plugin name" f"Plugin already exists: {name}! Check your plugin name"
) )
third_party_plugins.add(plugin) third_party_plugins[name] = plugin
self._third_party_plugin_names = third_party_plugins
# check plugins in search path
for module_info in pkgutil.iter_modules(self.search_path): for module_info in pkgutil.iter_modules(self.search_path):
# ignore if startswith "_"
if module_info.name.startswith("_"): if module_info.name.startswith("_"):
continue continue
if ( if (
module_info.name in searched_plugins.keys() module_info.name in searched_plugins
or module_info.name in previous_plugins or module_info.name in previous_plugins
or module_info.name in third_party_plugins or module_info.name in third_party_plugins
): ):
raise RuntimeError( raise RuntimeError(
f"Plugin already exists: {module_info.name}! Check your plugin name" f"Plugin already exists: {module_info.name}! Check your plugin name"
) )
module_spec = module_info.module_finder.find_spec(module_info.name, None) module_spec = module_info.module_finder.find_spec(module_info.name, None)
if not module_spec: if not module_spec:
continue continue
@ -94,17 +126,27 @@ class PluginManager:
continue continue
searched_plugins[module_info.name] = Path(module_path).resolve() searched_plugins[module_info.name] = Path(module_path).resolve()
self.searched_plugins = searched_plugins self._searched_plugin_names = searched_plugins
return third_party_plugins | set(self.searched_plugins.keys()) return self.available_plugins
def load_plugin(self, name: str) -> Optional[Plugin]: def load_plugin(self, name: str) -> Optional[Plugin]:
"""加载指定插件。
对于独立插件可以使用完整插件模块名或者插件名称
参数:
name: 插件名称
"""
try: try:
if name in self.plugins: if name in self.plugins:
module = importlib.import_module(name) module = importlib.import_module(name)
elif name in self.searched_plugins: elif name in self._third_party_plugin_names:
module = importlib.import_module(self._third_party_plugin_names[name])
elif name in self._searched_plugin_names:
module = importlib.import_module( module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name]) self._path_to_module_name(self._searched_plugin_names[name])
) )
else: else:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name") raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
@ -125,8 +167,10 @@ class PluginManager:
) )
def load_all_plugins(self) -> Set[Plugin]: def load_all_plugins(self) -> Set[Plugin]:
"""加载所有可用插件。"""
return set( return set(
filter(None, (self.load_plugin(name) for name in self.list_plugins())) filter(None, (self.load_plugin(name) for name in self.available_plugins))
) )
@ -147,9 +191,10 @@ class PluginFinder(MetaPathFinder):
module_path = Path(module_origin).resolve() module_path = Path(module_origin).resolve()
for manager in reversed(_managers): for manager in reversed(_managers):
# use path instead of name in case of submodule name conflict
if ( if (
fullname in manager.plugins fullname in manager.plugins
or module_path in manager.searched_plugins.values() or module_path in manager._searched_plugin_names.values()
): ):
module_spec.loader = PluginLoader(manager, fullname, module_origin) module_spec.loader = PluginLoader(manager, fullname, module_origin)
return module_spec return module_spec
@ -173,7 +218,11 @@ class PluginLoader(SourceFileLoader):
if self.loaded: if self.loaded:
return return
# create plugin before executing
plugin = _new_plugin(self.name, module, self.manager) plugin = _new_plugin(self.name, module, self.manager)
setattr(module, "__plugin__", plugin)
# detect parent plugin before entering current plugin context
parent_plugin = _current_plugin.get() parent_plugin = _current_plugin.get()
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index( if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index(
self.manager self.manager
@ -181,21 +230,18 @@ class PluginLoader(SourceFileLoader):
plugin.parent_plugin = parent_plugin plugin.parent_plugin = parent_plugin
parent_plugin.sub_plugins.add(plugin) parent_plugin.sub_plugins.add(plugin)
# enter plugin context
_plugin_token = _current_plugin.set(plugin) _plugin_token = _current_plugin.set(plugin)
setattr(module, "__plugin__", plugin) try:
super().exec_module(module)
except Exception:
_revert_plugin(plugin)
raise
finally:
# leave plugin context
_current_plugin.reset(_plugin_token)
# try:
# super().exec_module(module)
# except Exception as e:
# raise ImportError(
# f"Error when executing module {module_name} from {module.__file__}."
# ) from e
super().exec_module(module)
_confirm_plugin(plugin)
_current_plugin.reset(_plugin_token)
return return

View File

@ -5,7 +5,6 @@ FrontMatter:
description: nonebot.plugin.on 模块 description: nonebot.plugin.on 模块
""" """
import re import re
import sys
import inspect import inspect
from types import ModuleType from types import ModuleType
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
@ -41,8 +40,7 @@ def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
if current_frame is None: if current_frame is None:
return None return None
frame = inspect.getouterframes(current_frame)[depth + 1].frame frame = inspect.getouterframes(current_frame)[depth + 1].frame
module_name = frame.f_globals["__name__"] return inspect.getmodule(frame)
return sys.modules.get(module_name)
def on( def on(

View File

@ -6,18 +6,16 @@ FrontMatter:
""" """
from types import ModuleType from types import ModuleType
from dataclasses import field, dataclass from dataclasses import field, dataclass
from typing import TYPE_CHECKING, Set, Dict, Type, Optional from typing import TYPE_CHECKING, Set, Type, Optional
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from .export import Export from .export import Export
from . import _plugins as plugins # FIXME: backport for nonebug
if TYPE_CHECKING: if TYPE_CHECKING:
from .manager import PluginManager from .manager import PluginManager
plugins: Dict[str, "Plugin"] = {}
"""已加载的插件"""
@dataclass(eq=False) @dataclass(eq=False)
class Plugin(object): class Plugin(object):
@ -32,40 +30,10 @@ class Plugin(object):
manager: "PluginManager" manager: "PluginManager"
"""导入该插件的插件管理器""" """导入该插件的插件管理器"""
export: Export = field(default_factory=Export) export: Export = field(default_factory=Export)
"""插件内定义的导出内容""" """**Deprecated:** 插件内定义的导出内容"""
matcher: Set[Type[Matcher]] = field(default_factory=set) matcher: Set[Type[Matcher]] = field(default_factory=set)
"""插件内定义的 `Matcher`""" """插件内定义的 `Matcher`"""
parent_plugin: Optional["Plugin"] = None parent_plugin: Optional["Plugin"] = None
"""父插件""" """父插件"""
sub_plugins: Set["Plugin"] = field(default_factory=set) sub_plugins: Set["Plugin"] = field(default_factory=set)
"""子插件集合""" """子插件集合"""
def get_plugin(name: str) -> Optional[Plugin]:
"""获取已经导入的某个插件。
如果为 `load_plugins` 文件夹导入的插件则为文件()
参数:
name: 插件名 {ref}`nonebot.plugin.plugin.Plugin.name`
"""
return plugins.get(name)
def get_loaded_plugins() -> Set[Plugin]:
"""获取当前已导入的所有插件。"""
return set(plugins.values())
def _new_plugin(fullname: str, module: ModuleType, manager: "PluginManager") -> Plugin:
name = fullname.rsplit(".", 1)[-1] if "." in fullname else fullname
if name in plugins:
raise RuntimeError("Plugin already exists! Check your plugin name.")
plugin = Plugin(name, module, fullname, manager)
return plugin
def _confirm_plugin(plugin: Plugin) -> None:
if plugin.name in plugins:
raise RuntimeError("Plugin already exists! Check your plugin name.")
plugins[plugin.name] = plugin

View File

@ -0,0 +1,6 @@
import nonebot
plugin = nonebot.get_plugin("bad_plugin")
assert plugin
x = 1 / 0

1
tests/plugins/_hidden.py Normal file
View File

@ -0,0 +1 @@
assert False

View File

@ -3,4 +3,4 @@ from nonebot import export
@export() @export()
def test(): def test():
... return "export"

View File

@ -0,0 +1,6 @@
from pathlib import Path
import nonebot
_sub_plugins = set()
_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve()))

View File

@ -1,8 +1,7 @@
from nonebot import require from nonebot import require
from plugins.export import test
from .export import test as test_related
test_require = require("export").test test_require = require("export").test
assert test is test_related and test is test_require, "Export Require Error" from plugins.export import test
assert test is test_require and test() == "export", "Export Require Error"

View File

@ -0,0 +1,39 @@
from typing import TYPE_CHECKING, Set
import pytest
from nonebug import App
if TYPE_CHECKING:
from nonebot.plugin import Plugin
@pytest.mark.asyncio
async def test_get_plugin(app: App, load_plugin: Set["Plugin"]):
import nonebot
# check simple plugin
plugin = nonebot.get_plugin("export")
assert plugin
assert plugin.module_name == "plugins.export"
# check sub plugin
plugin = nonebot.get_plugin("nested_subplugin")
assert plugin
assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"
# check get plugin by module name
plugin = nonebot.get_plugin_by_module_name("plugins.nested.utils")
assert plugin
assert plugin.module_name == "plugins.nested"
@pytest.mark.asyncio
async def test_get_available_plugin(app: App):
import nonebot
from nonebot.plugin import PluginManager, _managers
_managers.append(PluginManager(["plugins.export", "plugin.require"]))
# check get available plugins
plugin_names = nonebot.get_available_plugin_names()
assert plugin_names == {"export", "require"}

View File

@ -9,27 +9,40 @@ if TYPE_CHECKING:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_plugin(load_plugin: Set["Plugin"]): async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
import nonebot import nonebot
from nonebot.plugin import PluginManager
loaded_plugins = set( loaded_plugins = {
plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
) }
assert loaded_plugins == load_plugin assert loaded_plugins == load_plugin
plugin = nonebot.get_plugin("export")
assert plugin # check simple plugin
assert plugin.module_name == "plugins.export"
assert "plugins.export" in sys.modules assert "plugins.export" in sys.modules
try: # check sub plugin
nonebot.load_plugin("plugins.export") assert "plugins.nested.plugins.nested_subplugin" in sys.modules
assert False
except RuntimeError:
assert True
# check load again
with pytest.raises(RuntimeError):
PluginManager(plugins=["plugins.export"]).load_all_plugins()
with pytest.raises(RuntimeError):
PluginManager(search_path=["plugins"]).load_all_plugins()
# check not found
assert nonebot.load_plugin("some_plugin_not_exist") is None assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_bad_plugin(app: App):
import nonebot
nonebot.load_plugins("bad_plugins")
assert nonebot.get_plugin("bad_plugins") is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch): async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
import nonebot import nonebot
@ -47,8 +60,7 @@ async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_not_loaded(app: App, monkeypatch: pytest.MonkeyPatch): async def test_require_not_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
import nonebot import nonebot
from nonebot.plugin import _managers from nonebot.plugin import PluginManager, _managers
from nonebot.plugin.manager import PluginManager
m = PluginManager(["plugins.export"]) m = PluginManager(["plugins.export"])
_managers.append(m) _managers.append(m)
@ -80,10 +92,6 @@ async def test_require_not_declared(app: App):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_require_not_found(app: App): async def test_require_not_found(app: App):
import nonebot import nonebot
from nonebot.plugin import _managers
try: with pytest.raises(RuntimeError):
nonebot.require("some_plugin_not_exist") nonebot.require("some_plugin_not_exist")
assert False
except RuntimeError:
assert True

View File

@ -0,0 +1,12 @@
import pytest
from nonebug import App
@pytest.mark.asyncio
async def test_load_plugin_name(app: App):
from nonebot.plugin import PluginManager
m = PluginManager(plugins=["plugins.export"])
module1 = m.load_plugin("export")
module2 = m.load_plugin("plugins.export")
assert module1 is module2