From 1cfdee26459f58a2598b44f2cf2c6d622f80bfe5 Mon Sep 17 00:00:00 2001 From: Lan <59906398+Lancercmd@users.noreply.github.com> Date: Wed, 31 Aug 2022 10:07:14 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Featue:=20`load=5Fplugin`=20?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20`pathlib.Path`=20(#1194)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> --- nonebot/plugin/load.py | 16 +++++++++++----- nonebot/plugin/manager.py | 25 +++++++++---------------- nonebot/utils.py | 9 +++++++++ tests/test_plugin/test_load.py | 20 ++++++++++++++++---- 4 files changed, 45 insertions(+), 25 deletions(-) diff --git a/nonebot/plugin/load.py b/nonebot/plugin/load.py index 1a6075b3..c51ba1e0 100644 --- a/nonebot/plugin/load.py +++ b/nonebot/plugin/load.py @@ -5,24 +5,30 @@ FrontMatter: description: nonebot.plugin.load 模块 """ import json -import warnings +from pathlib import Path from types import ModuleType -from typing import Set, Iterable, Optional +from typing import Set, Union, Iterable, Optional import tomlkit +from nonebot.utils import path_to_module_name + from .plugin import Plugin from .manager import PluginManager from . import _managers, get_plugin, _module_name_to_plugin_name -def load_plugin(module_path: str) -> Optional[Plugin]: +def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]: """加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。 参数: - module_path: 插件名称 `path.to.your.plugin` + module_path: 插件名称 `path.to.your.plugin` 或插件路径 `pathlib.Path(path/to/your/plugin)` """ - + module_path = ( + path_to_module_name(module_path) + if isinstance(module_path, Path) + else module_path + ) manager = PluginManager([module_path]) _managers.append(manager) return manager.load_plugin(module_path) diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index ece76620..43f76845 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -17,7 +17,7 @@ from importlib.machinery import PathFinder, SourceFileLoader from typing import Set, Dict, List, Union, Iterable, Optional, Sequence from nonebot.log import logger -from nonebot.utils import escape_tag +from nonebot.utils import escape_tag, path_to_module_name from .plugin import Plugin, PluginMetadata from . import ( @@ -66,13 +66,6 @@ class PluginManager: """返回当前插件管理器中可用的插件名称。""" return self.third_party_plugins | self.searched_plugins - def _path_to_module_name(self, path: Path) -> str: - rel_path = path.resolve().relative_to(Path(".").resolve()) - if rel_path.stem == "__init__": - return ".".join(rel_path.parts[:-1]) - else: - return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) - def _previous_plugins(self) -> Set[str]: _pre_managers: List[PluginManager] if self in _managers: @@ -86,7 +79,6 @@ class PluginManager: def prepare_plugins(self) -> Set[str]: """搜索插件并缓存插件名称。""" - # get all previous ready to load plugins previous_plugins = self._previous_plugins() searched_plugins: Dict[str, Path] = {} @@ -118,11 +110,13 @@ class PluginManager: f"Plugin already exists: {module_info.name}! Check your plugin name" ) - module_spec = module_info.module_finder.find_spec(module_info.name, None) - if not module_spec: + if not ( + module_spec := module_info.module_finder.find_spec( + module_info.name, None + ) + ): continue - module_path = module_spec.origin - if not module_path: + if not (module_path := module_spec.origin): continue searched_plugins[module_info.name] = Path(module_path).resolve() @@ -146,7 +140,7 @@ class PluginManager: module = importlib.import_module(self._third_party_plugin_names[name]) elif name in self._searched_plugin_names: module = importlib.import_module( - self._path_to_module_name(self._searched_plugin_names[name]) + path_to_module_name(self._searched_plugin_names[name]) ) else: raise RuntimeError(f"Plugin not found: {name}! Check your plugin name") @@ -154,8 +148,7 @@ class PluginManager: logger.opt(colors=True).success( f'Succeeded to import "{escape_tag(name)}"' ) - plugin = getattr(module, "__plugin__", None) - if plugin is None: + if (plugin := getattr(module, "__plugin__", None)) is None: raise RuntimeError( f"Module {module.__name__} is not loaded as a plugin! " "Make sure not to import it before loading." diff --git a/nonebot/utils.py b/nonebot/utils.py index b2aac548..8957aebc 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -10,6 +10,7 @@ import json import asyncio import inspect import dataclasses +from pathlib import Path from functools import wraps, partial from contextlib import asynccontextmanager from typing_extensions import ParamSpec, get_args, get_origin @@ -165,6 +166,14 @@ def get_name(obj: Any) -> str: return obj.__class__.__name__ +def path_to_module_name(path: Path) -> str: + rel_path = path.resolve().relative_to(Path(".").resolve()) + if rel_path.stem == "__init__": + return ".".join(rel_path.parts[:-1]) + else: + return ".".join(rel_path.parts[:-1] + (rel_path.stem,)) + + class DataclassEncoder(json.JSONEncoder): """在JSON序列化 {re}`nonebot.adapters._message.Message` (List[Dataclass]) 时使用的 `JSONEncoder`""" diff --git a/tests/test_plugin/test_load.py b/tests/test_plugin/test_load.py index ee4ce66b..a5e56566 100644 --- a/tests/test_plugin/test_load.py +++ b/tests/test_plugin/test_load.py @@ -1,4 +1,5 @@ import sys +from pathlib import Path from dataclasses import asdict from typing import TYPE_CHECKING, Set @@ -10,7 +11,21 @@ if TYPE_CHECKING: @pytest.mark.asyncio -async def test_load_plugin(app: App, load_plugin: Set["Plugin"]): +async def test_load_plugin(app: App): + import nonebot + + # check regular + assert nonebot.load_plugin("plugins.metadata") + + # check path + assert nonebot.load_plugin(Path("plugins/export")) + + # check not found + assert nonebot.load_plugin("some_plugin_not_exist") is None + + +@pytest.mark.asyncio +async def test_load_plugins(app: App, load_plugin: Set["Plugin"]): import nonebot from nonebot.plugin import PluginManager @@ -34,9 +49,6 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]): with pytest.raises(RuntimeError): PluginManager(search_path=["plugins"]).load_all_plugins() - # check not found - assert nonebot.load_plugin("some_plugin_not_exist") is None - @pytest.mark.asyncio async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]):