improve plugin system (#1011)

This commit is contained in:
Ju4tCode 2022-05-26 16:35:47 +08:00 committed by GitHub
parent 579839f2a4
commit fa3ed2b58c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 254 additions and 106 deletions

View File

@ -26,7 +26,9 @@
- `load_builtin_plugin` => {ref}``load_builtin_plugin` <nonebot.plugin.load.load_builtin_plugin>`
- `load_builtin_plugins` => {ref}``load_builtin_plugins` <nonebot.plugin.load.load_builtin_plugins>`
- `get_plugin` => {ref}``get_plugin` <nonebot.plugin.plugin.get_plugin>`
- `get_plugin_by_module_name` => {ref}``get_plugin_by_module_name` <nonebot.plugin.plugin.get_plugin_by_module_name>`
- `get_loaded_plugins` => {ref}``get_loaded_plugins` <nonebot.plugin.plugin.get_loaded_plugins>`
- `get_available_plugin_names` => {ref}``get_available_plugin_names` <nonebot.plugin.plugin.get_available_plugin_names>`
- `export` => {ref}``export` <nonebot.plugin.export.export>`
- `require` => {ref}``require` <nonebot.plugin.load.require>`
@ -284,5 +286,7 @@ from nonebot.plugin import on_shell_command as on_shell_command
from nonebot.plugin import get_loaded_plugins as get_loaded_plugins
from nonebot.plugin import load_builtin_plugin as load_builtin_plugin
from nonebot.plugin import load_builtin_plugins as load_builtin_plugins
from nonebot.plugin import get_plugin_by_module_name as get_plugin_by_module_name
from nonebot.plugin import get_available_plugin_names as get_available_plugin_names
__autodoc__ = {"internal": False}

View File

@ -35,14 +35,77 @@ FrontMatter:
description: nonebot.plugin 模块
"""
from typing import List, Optional
from itertools import chain
from types import ModuleType
from contextvars import ContextVar
from typing import Set, Dict, List, Optional
_plugins: Dict[str, "Plugin"] = {}
_managers: List["PluginManager"] = []
_current_plugin: ContextVar[Optional["Plugin"]] = ContextVar(
"_current_plugin", default=None
)
def _module_name_to_plugin_name(module_name: str) -> str:
return module_name.rsplit(".", 1)[-1]
def _new_plugin(
module_name: str, module: ModuleType, manager: "PluginManager"
) -> "Plugin":
plugin_name = _module_name_to_plugin_name(module_name)
if plugin_name in _plugins:
raise RuntimeError("Plugin already exists! Check your plugin name.")
plugin = Plugin(plugin_name, module, module_name, manager)
_plugins[plugin_name] = plugin
return plugin
def _revert_plugin(plugin: "Plugin") -> None:
if plugin.name not in _plugins:
raise RuntimeError("Plugin not found!")
del _plugins[plugin.name]
def get_plugin(name: str) -> Optional["Plugin"]:
"""获取已经导入的某个插件。
如果为 `load_plugins` 文件夹导入的插件则为文件()
参数:
name: 插件名 {ref}`nonebot.plugin.plugin.Plugin.name`
"""
return _plugins.get(name)
def get_plugin_by_module_name(module_name: str) -> Optional["Plugin"]:
"""通过模块名获取已经导入的某个插件。
如果提供的模块名为某个插件的子模块同样会返回该插件
参数:
module_name: 模块名 {ref}`nonebot.plugin.plugin.Plugin.module_name`
"""
splits = module_name.split(".")
loaded = {plugin.module_name: plugin for plugin in _plugins.values()}
while splits:
name = ".".join(splits)
if name in loaded:
return loaded[name]
splits.pop()
def get_loaded_plugins() -> Set["Plugin"]:
"""获取当前已导入的所有插件。"""
return set(_plugins.values())
def get_available_plugin_names() -> Set[str]:
"""获取当前所有可用的插件名(包含尚未加载的插件)。"""
return {*chain.from_iterable(manager.available_plugins for manager in _managers)}
from .on import on as on
from .manager import PluginManager
from .export import Export as Export
@ -61,7 +124,6 @@ from .on import CommandGroup as CommandGroup
from .on import MatcherGroup as MatcherGroup
from .on import on_fullmatch as on_fullmatch
from .on import on_metaevent as on_metaevent
from .plugin import get_plugin as get_plugin
from .load import load_plugins as load_plugins
from .on import on_startswith as on_startswith
from .load import load_from_json as load_from_json
@ -69,5 +131,4 @@ from .load import load_from_toml as load_from_toml
from .on import on_shell_command as on_shell_command
from .load import load_all_plugins as load_all_plugins
from .load import load_builtin_plugin as load_builtin_plugin
from .plugin import get_loaded_plugins as get_loaded_plugins
from .load import load_builtin_plugins as load_builtin_plugins

View File

@ -10,10 +10,10 @@ from typing import Set, Iterable, Optional
import tomlkit
from . import _managers
from .export import Export
from .plugin import Plugin
from .manager import PluginManager
from .plugin import Plugin, get_plugin
from . import _managers, get_plugin, _module_name_to_plugin_name
def load_plugin(module_path: str) -> Optional[Plugin]:
@ -128,7 +128,7 @@ def load_builtin_plugin(name: str) -> Optional[Plugin]:
return load_plugin(f"nonebot.plugins.{name}")
def load_builtin_plugins(*plugins) -> Set[Plugin]:
def load_builtin_plugins(*plugins: str) -> Set[Plugin]:
"""导入多个 NoneBot 内置插件。
参数:
@ -154,7 +154,7 @@ def require(name: str) -> Export:
异常:
RuntimeError: 插件无法加载
"""
plugin = get_plugin(name.rsplit(".", 1)[-1])
plugin = get_plugin(_module_name_to_plugin_name(name))
if not plugin:
manager = _find_manager_by_name(name)
if manager:

View File

@ -19,23 +19,52 @@ from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
from nonebot.log import logger
from nonebot.utils import escape_tag
from . import _managers, _current_plugin
from .plugin import Plugin, _new_plugin, _confirm_plugin
from .plugin import Plugin
from . import (
_managers,
_new_plugin,
_revert_plugin,
_current_plugin,
_module_name_to_plugin_name,
)
class PluginManager:
"""插件管理器。
参数:
plugins: 独立插件模块名集合
search_path: 插件搜索路径文件夹
"""
def __init__(
self,
plugins: Optional[Iterable[str]] = None,
search_path: Optional[Iterable[str]] = None,
):
# simple plugin not in search path
self.plugins: Set[str] = set(plugins or [])
self.search_path: Set[str] = set(search_path or [])
# cache plugins
self.searched_plugins: Dict[str, Path] = {}
self.list_plugins()
self._third_party_plugin_names: Dict[str, str] = {}
self._searched_plugin_names: Dict[str, Path] = {}
self.prepare_plugins()
@property
def third_party_plugins(self) -> Set[str]:
"""返回所有独立插件名称。"""
return set(self._third_party_plugin_names.keys())
@property
def searched_plugins(self) -> Set[str]:
"""返回已搜索到的插件名称。"""
return set(self._searched_plugin_names.keys())
@property
def available_plugins(self) -> Set[str]:
"""返回当前插件管理器中可用的插件名称。"""
return self.third_party_plugins | self.searched_plugins
def _path_to_module_name(self, path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve())
@ -44,48 +73,51 @@ class PluginManager:
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def _previous_plugins(self) -> List[str]:
def _previous_plugins(self) -> Set[str]:
_pre_managers: List[PluginManager]
if self in _managers:
_pre_managers = _managers[: _managers.index(self)]
else:
_pre_managers = _managers[:]
return [
*chain.from_iterable(
[
*map(lambda x: x.rsplit(".", 1)[-1], manager.plugins),
*manager.searched_plugins.keys(),
]
for manager in _pre_managers
)
]
return {
*chain.from_iterable(manager.available_plugins for manager in _pre_managers)
}
def prepare_plugins(self) -> Set[str]:
"""搜索插件并缓存插件名称。"""
def list_plugins(self) -> Set[str]:
# get all previous ready to load plugins
previous_plugins = self._previous_plugins()
searched_plugins: Dict[str, Path] = {}
third_party_plugins: Set[str] = set()
third_party_plugins: Dict[str, str] = {}
# check third party plugins
for plugin in self.plugins:
name = plugin.rsplit(".", 1)[-1]
name = _module_name_to_plugin_name(plugin)
if name in third_party_plugins or name in previous_plugins:
raise RuntimeError(
f"Plugin already exists: {name}! Check your plugin name"
)
third_party_plugins.add(plugin)
third_party_plugins[name] = plugin
self._third_party_plugin_names = third_party_plugins
# check plugins in search path
for module_info in pkgutil.iter_modules(self.search_path):
# ignore if startswith "_"
if module_info.name.startswith("_"):
continue
if (
module_info.name in searched_plugins.keys()
module_info.name in searched_plugins
or module_info.name in previous_plugins
or module_info.name in third_party_plugins
):
raise RuntimeError(
f"Plugin already exists: {module_info.name}! Check your plugin name"
)
module_spec = module_info.module_finder.find_spec(module_info.name, None)
if not module_spec:
continue
@ -94,17 +126,27 @@ class PluginManager:
continue
searched_plugins[module_info.name] = Path(module_path).resolve()
self.searched_plugins = searched_plugins
self._searched_plugin_names = searched_plugins
return third_party_plugins | set(self.searched_plugins.keys())
return self.available_plugins
def load_plugin(self, name: str) -> Optional[Plugin]:
"""加载指定插件。
对于独立插件可以使用完整插件模块名或者插件名称
参数:
name: 插件名称
"""
try:
if name in self.plugins:
module = importlib.import_module(name)
elif name in self.searched_plugins:
elif name in self._third_party_plugin_names:
module = importlib.import_module(self._third_party_plugin_names[name])
elif name in self._searched_plugin_names:
module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name])
self._path_to_module_name(self._searched_plugin_names[name])
)
else:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
@ -125,8 +167,10 @@ class PluginManager:
)
def load_all_plugins(self) -> Set[Plugin]:
"""加载所有可用插件。"""
return set(
filter(None, (self.load_plugin(name) for name in self.list_plugins()))
filter(None, (self.load_plugin(name) for name in self.available_plugins))
)
@ -147,9 +191,10 @@ class PluginFinder(MetaPathFinder):
module_path = Path(module_origin).resolve()
for manager in reversed(_managers):
# use path instead of name in case of submodule name conflict
if (
fullname in manager.plugins
or module_path in manager.searched_plugins.values()
or module_path in manager._searched_plugin_names.values()
):
module_spec.loader = PluginLoader(manager, fullname, module_origin)
return module_spec
@ -173,7 +218,11 @@ class PluginLoader(SourceFileLoader):
if self.loaded:
return
# create plugin before executing
plugin = _new_plugin(self.name, module, self.manager)
setattr(module, "__plugin__", plugin)
# detect parent plugin before entering current plugin context
parent_plugin = _current_plugin.get()
if parent_plugin and _managers.index(parent_plugin.manager) < _managers.index(
self.manager
@ -181,21 +230,18 @@ class PluginLoader(SourceFileLoader):
plugin.parent_plugin = parent_plugin
parent_plugin.sub_plugins.add(plugin)
# enter plugin context
_plugin_token = _current_plugin.set(plugin)
setattr(module, "__plugin__", plugin)
try:
super().exec_module(module)
except Exception:
_revert_plugin(plugin)
raise
finally:
# leave plugin context
_current_plugin.reset(_plugin_token)
# try:
# super().exec_module(module)
# except Exception as e:
# raise ImportError(
# f"Error when executing module {module_name} from {module.__file__}."
# ) from e
super().exec_module(module)
_confirm_plugin(plugin)
_current_plugin.reset(_plugin_token)
return

View File

@ -5,7 +5,6 @@ FrontMatter:
description: nonebot.plugin.on 模块
"""
import re
import sys
import inspect
from types import ModuleType
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
@ -41,8 +40,7 @@ def _get_matcher_module(depth: int = 1) -> Optional[ModuleType]:
if current_frame is None:
return None
frame = inspect.getouterframes(current_frame)[depth + 1].frame
module_name = frame.f_globals["__name__"]
return sys.modules.get(module_name)
return inspect.getmodule(frame)
def on(

View File

@ -6,18 +6,16 @@ FrontMatter:
"""
from types import ModuleType
from dataclasses import field, dataclass
from typing import TYPE_CHECKING, Set, Dict, Type, Optional
from typing import TYPE_CHECKING, Set, Type, Optional
from nonebot.matcher import Matcher
from .export import Export
from . import _plugins as plugins # FIXME: backport for nonebug
if TYPE_CHECKING:
from .manager import PluginManager
plugins: Dict[str, "Plugin"] = {}
"""已加载的插件"""
@dataclass(eq=False)
class Plugin(object):
@ -32,40 +30,10 @@ class Plugin(object):
manager: "PluginManager"
"""导入该插件的插件管理器"""
export: Export = field(default_factory=Export)
"""插件内定义的导出内容"""
"""**Deprecated:** 插件内定义的导出内容"""
matcher: Set[Type[Matcher]] = field(default_factory=set)
"""插件内定义的 `Matcher`"""
parent_plugin: Optional["Plugin"] = None
"""父插件"""
sub_plugins: Set["Plugin"] = field(default_factory=set)
"""子插件集合"""
def get_plugin(name: str) -> Optional[Plugin]:
"""获取已经导入的某个插件。
如果为 `load_plugins` 文件夹导入的插件则为文件()
参数:
name: 插件名 {ref}`nonebot.plugin.plugin.Plugin.name`
"""
return plugins.get(name)
def get_loaded_plugins() -> Set[Plugin]:
"""获取当前已导入的所有插件。"""
return set(plugins.values())
def _new_plugin(fullname: str, module: ModuleType, manager: "PluginManager") -> Plugin:
name = fullname.rsplit(".", 1)[-1] if "." in fullname else fullname
if name in plugins:
raise RuntimeError("Plugin already exists! Check your plugin name.")
plugin = Plugin(name, module, fullname, manager)
return plugin
def _confirm_plugin(plugin: Plugin) -> None:
if plugin.name in plugins:
raise RuntimeError("Plugin already exists! Check your plugin name.")
plugins[plugin.name] = plugin

View File

@ -0,0 +1,6 @@
import nonebot
plugin = nonebot.get_plugin("bad_plugin")
assert plugin
x = 1 / 0

1
tests/plugins/_hidden.py Normal file
View File

@ -0,0 +1 @@
assert False

View File

@ -3,4 +3,4 @@ from nonebot import export
@export()
def test():
...
return "export"

View File

@ -0,0 +1,6 @@
from pathlib import Path
import nonebot
_sub_plugins = set()
_sub_plugins |= nonebot.load_plugins(str((Path(__file__).parent / "plugins").resolve()))

View File

@ -1,8 +1,7 @@
from nonebot import require
from plugins.export import test
from .export import test as test_related
test_require = require("export").test
assert test is test_related and test is test_require, "Export Require Error"
from plugins.export import test
assert test is test_require and test() == "export", "Export Require Error"

View File

@ -0,0 +1,39 @@
from typing import TYPE_CHECKING, Set
import pytest
from nonebug import App
if TYPE_CHECKING:
from nonebot.plugin import Plugin
@pytest.mark.asyncio
async def test_get_plugin(app: App, load_plugin: Set["Plugin"]):
import nonebot
# check simple plugin
plugin = nonebot.get_plugin("export")
assert plugin
assert plugin.module_name == "plugins.export"
# check sub plugin
plugin = nonebot.get_plugin("nested_subplugin")
assert plugin
assert plugin.module_name == "plugins.nested.plugins.nested_subplugin"
# check get plugin by module name
plugin = nonebot.get_plugin_by_module_name("plugins.nested.utils")
assert plugin
assert plugin.module_name == "plugins.nested"
@pytest.mark.asyncio
async def test_get_available_plugin(app: App):
import nonebot
from nonebot.plugin import PluginManager, _managers
_managers.append(PluginManager(["plugins.export", "plugin.require"]))
# check get available plugins
plugin_names = nonebot.get_available_plugin_names()
assert plugin_names == {"export", "require"}

View File

@ -9,27 +9,40 @@ if TYPE_CHECKING:
@pytest.mark.asyncio
async def test_load_plugin(load_plugin: Set["Plugin"]):
async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
import nonebot
from nonebot.plugin import PluginManager
loaded_plugins = set(
loaded_plugins = {
plugin for plugin in nonebot.get_loaded_plugins() if not plugin.parent_plugin
)
}
assert loaded_plugins == load_plugin
plugin = nonebot.get_plugin("export")
assert plugin
assert plugin.module_name == "plugins.export"
# check simple plugin
assert "plugins.export" in sys.modules
try:
nonebot.load_plugin("plugins.export")
assert False
except RuntimeError:
assert True
# check sub plugin
assert "plugins.nested.plugins.nested_subplugin" in sys.modules
# check load again
with pytest.raises(RuntimeError):
PluginManager(plugins=["plugins.export"]).load_all_plugins()
with pytest.raises(RuntimeError):
PluginManager(search_path=["plugins"]).load_all_plugins()
# check not found
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_bad_plugin(app: App):
import nonebot
nonebot.load_plugins("bad_plugins")
assert nonebot.get_plugin("bad_plugins") is None
@pytest.mark.asyncio
async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
import nonebot
@ -47,8 +60,7 @@ async def test_require_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
@pytest.mark.asyncio
async def test_require_not_loaded(app: App, monkeypatch: pytest.MonkeyPatch):
import nonebot
from nonebot.plugin import _managers
from nonebot.plugin.manager import PluginManager
from nonebot.plugin import PluginManager, _managers
m = PluginManager(["plugins.export"])
_managers.append(m)
@ -80,10 +92,6 @@ async def test_require_not_declared(app: App):
@pytest.mark.asyncio
async def test_require_not_found(app: App):
import nonebot
from nonebot.plugin import _managers
try:
with pytest.raises(RuntimeError):
nonebot.require("some_plugin_not_exist")
assert False
except RuntimeError:
assert True

View File

@ -0,0 +1,12 @@
import pytest
from nonebug import App
@pytest.mark.asyncio
async def test_load_plugin_name(app: App):
from nonebot.plugin import PluginManager
m = PluginManager(plugins=["plugins.export"])
module1 = m.load_plugin("export")
module2 = m.load_plugin("plugins.export")
assert module1 is module2