Featue: load_plugin 支持 pathlib.Path (#1194)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
Lan 2022-08-31 10:07:14 +08:00 committed by GitHub
parent 4e76518a58
commit 1cfdee2645
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 25 deletions

View File

@ -5,24 +5,30 @@ FrontMatter:
description: nonebot.plugin.load 模块 description: nonebot.plugin.load 模块
""" """
import json import json
import warnings from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Set, Iterable, Optional from typing import Set, Union, Iterable, Optional
import tomlkit import tomlkit
from nonebot.utils import path_to_module_name
from .plugin import Plugin from .plugin import Plugin
from .manager import PluginManager from .manager import PluginManager
from . import _managers, get_plugin, _module_name_to_plugin_name from . import _managers, get_plugin, _module_name_to_plugin_name
def load_plugin(module_path: str) -> Optional[Plugin]: def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:
"""加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。 """加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。
参数: 参数:
module_path: 插件名称 `path.to.your.plugin` module_path: 插件名称 `path.to.your.plugin` 或插件路径 `pathlib.Path(path/to/your/plugin)`
""" """
module_path = (
path_to_module_name(module_path)
if isinstance(module_path, Path)
else module_path
)
manager = PluginManager([module_path]) manager = PluginManager([module_path])
_managers.append(manager) _managers.append(manager)
return manager.load_plugin(module_path) return manager.load_plugin(module_path)

View File

@ -17,7 +17,7 @@ from importlib.machinery import PathFinder, SourceFileLoader
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag from nonebot.utils import escape_tag, path_to_module_name
from .plugin import Plugin, PluginMetadata from .plugin import Plugin, PluginMetadata
from . import ( from . import (
@ -66,13 +66,6 @@ class PluginManager:
"""返回当前插件管理器中可用的插件名称。""" """返回当前插件管理器中可用的插件名称。"""
return self.third_party_plugins | self.searched_plugins return self.third_party_plugins | self.searched_plugins
def _path_to_module_name(self, path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def _previous_plugins(self) -> Set[str]: def _previous_plugins(self) -> Set[str]:
_pre_managers: List[PluginManager] _pre_managers: List[PluginManager]
if self in _managers: if self in _managers:
@ -86,7 +79,6 @@ class PluginManager:
def prepare_plugins(self) -> Set[str]: def prepare_plugins(self) -> Set[str]:
"""搜索插件并缓存插件名称。""" """搜索插件并缓存插件名称。"""
# get all previous ready to load plugins # get all previous ready to load plugins
previous_plugins = self._previous_plugins() previous_plugins = self._previous_plugins()
searched_plugins: Dict[str, Path] = {} searched_plugins: Dict[str, Path] = {}
@ -118,11 +110,13 @@ class PluginManager:
f"Plugin already exists: {module_info.name}! Check your plugin name" f"Plugin already exists: {module_info.name}! Check your plugin name"
) )
module_spec = module_info.module_finder.find_spec(module_info.name, None) if not (
if not module_spec: module_spec := module_info.module_finder.find_spec(
module_info.name, None
)
):
continue continue
module_path = module_spec.origin if not (module_path := module_spec.origin):
if not module_path:
continue continue
searched_plugins[module_info.name] = Path(module_path).resolve() searched_plugins[module_info.name] = Path(module_path).resolve()
@ -146,7 +140,7 @@ class PluginManager:
module = importlib.import_module(self._third_party_plugin_names[name]) module = importlib.import_module(self._third_party_plugin_names[name])
elif name in self._searched_plugin_names: elif name in self._searched_plugin_names:
module = importlib.import_module( module = importlib.import_module(
self._path_to_module_name(self._searched_plugin_names[name]) path_to_module_name(self._searched_plugin_names[name])
) )
else: else:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name") raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
@ -154,8 +148,7 @@ class PluginManager:
logger.opt(colors=True).success( logger.opt(colors=True).success(
f'Succeeded to import "<y>{escape_tag(name)}</y>"' f'Succeeded to import "<y>{escape_tag(name)}</y>"'
) )
plugin = getattr(module, "__plugin__", None) if (plugin := getattr(module, "__plugin__", None)) is None:
if plugin is None:
raise RuntimeError( raise RuntimeError(
f"Module {module.__name__} is not loaded as a plugin! " f"Module {module.__name__} is not loaded as a plugin! "
"Make sure not to import it before loading." "Make sure not to import it before loading."

View File

@ -10,6 +10,7 @@ import json
import asyncio import asyncio
import inspect import inspect
import dataclasses import dataclasses
from pathlib import Path
from functools import wraps, partial from functools import wraps, partial
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, get_args, get_origin from typing_extensions import ParamSpec, get_args, get_origin
@ -165,6 +166,14 @@ def get_name(obj: Any) -> str:
return obj.__class__.__name__ return obj.__class__.__name__
def path_to_module_name(path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
class DataclassEncoder(json.JSONEncoder): class DataclassEncoder(json.JSONEncoder):
"""在JSON序列化 {re}`nonebot.adapters._message.Message` (List[Dataclass]) 时使用的 `JSONEncoder`""" """在JSON序列化 {re}`nonebot.adapters._message.Message` (List[Dataclass]) 时使用的 `JSONEncoder`"""

View File

@ -1,4 +1,5 @@
import sys import sys
from pathlib import Path
from dataclasses import asdict from dataclasses import asdict
from typing import TYPE_CHECKING, Set from typing import TYPE_CHECKING, Set
@ -10,7 +11,21 @@ if TYPE_CHECKING:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_plugin(app: App, load_plugin: Set["Plugin"]): async def test_load_plugin(app: App):
import nonebot
# check regular
assert nonebot.load_plugin("plugins.metadata")
# check path
assert nonebot.load_plugin(Path("plugins/export"))
# check not found
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_load_plugins(app: App, load_plugin: Set["Plugin"]):
import nonebot import nonebot
from nonebot.plugin import PluginManager from nonebot.plugin import PluginManager
@ -34,9 +49,6 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
PluginManager(search_path=["plugins"]).load_all_plugins() PluginManager(search_path=["plugins"]).load_all_plugins()
# check not found
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]): async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]):