Featue: load_plugin 支持 pathlib.Path (#1194)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
Lan 2022-08-31 10:07:14 +08:00 committed by GitHub
parent 4e76518a58
commit 1cfdee2645
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 25 deletions

View File

@ -5,24 +5,30 @@ FrontMatter:
description: nonebot.plugin.load 模块
"""
import json
import warnings
from pathlib import Path
from types import ModuleType
from typing import Set, Iterable, Optional
from typing import Set, Union, Iterable, Optional
import tomlkit
from nonebot.utils import path_to_module_name
from .plugin import Plugin
from .manager import PluginManager
from . import _managers, get_plugin, _module_name_to_plugin_name
def load_plugin(module_path: str) -> Optional[Plugin]:
def load_plugin(module_path: Union[str, Path]) -> Optional[Plugin]:
"""加载单个插件,可以是本地插件或是通过 `pip` 安装的插件。
参数:
module_path: 插件名称 `path.to.your.plugin`
module_path: 插件名称 `path.to.your.plugin` 或插件路径 `pathlib.Path(path/to/your/plugin)`
"""
module_path = (
path_to_module_name(module_path)
if isinstance(module_path, Path)
else module_path
)
manager = PluginManager([module_path])
_managers.append(manager)
return manager.load_plugin(module_path)

View File

@ -17,7 +17,7 @@ from importlib.machinery import PathFinder, SourceFileLoader
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.utils import escape_tag, path_to_module_name
from .plugin import Plugin, PluginMetadata
from . import (
@ -66,13 +66,6 @@ class PluginManager:
"""返回当前插件管理器中可用的插件名称。"""
return self.third_party_plugins | self.searched_plugins
def _path_to_module_name(self, path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
def _previous_plugins(self) -> Set[str]:
_pre_managers: List[PluginManager]
if self in _managers:
@ -86,7 +79,6 @@ class PluginManager:
def prepare_plugins(self) -> Set[str]:
"""搜索插件并缓存插件名称。"""
# get all previous ready to load plugins
previous_plugins = self._previous_plugins()
searched_plugins: Dict[str, Path] = {}
@ -118,11 +110,13 @@ class PluginManager:
f"Plugin already exists: {module_info.name}! Check your plugin name"
)
module_spec = module_info.module_finder.find_spec(module_info.name, None)
if not module_spec:
if not (
module_spec := module_info.module_finder.find_spec(
module_info.name, None
)
):
continue
module_path = module_spec.origin
if not module_path:
if not (module_path := module_spec.origin):
continue
searched_plugins[module_info.name] = Path(module_path).resolve()
@ -146,7 +140,7 @@ class PluginManager:
module = importlib.import_module(self._third_party_plugin_names[name])
elif name in self._searched_plugin_names:
module = importlib.import_module(
self._path_to_module_name(self._searched_plugin_names[name])
path_to_module_name(self._searched_plugin_names[name])
)
else:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
@ -154,8 +148,7 @@ class PluginManager:
logger.opt(colors=True).success(
f'Succeeded to import "<y>{escape_tag(name)}</y>"'
)
plugin = getattr(module, "__plugin__", None)
if plugin is None:
if (plugin := getattr(module, "__plugin__", None)) is None:
raise RuntimeError(
f"Module {module.__name__} is not loaded as a plugin! "
"Make sure not to import it before loading."

View File

@ -10,6 +10,7 @@ import json
import asyncio
import inspect
import dataclasses
from pathlib import Path
from functools import wraps, partial
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, get_args, get_origin
@ -165,6 +166,14 @@ def get_name(obj: Any) -> str:
return obj.__class__.__name__
def path_to_module_name(path: Path) -> str:
rel_path = path.resolve().relative_to(Path(".").resolve())
if rel_path.stem == "__init__":
return ".".join(rel_path.parts[:-1])
else:
return ".".join(rel_path.parts[:-1] + (rel_path.stem,))
class DataclassEncoder(json.JSONEncoder):
"""在JSON序列化 {re}`nonebot.adapters._message.Message` (List[Dataclass]) 时使用的 `JSONEncoder`"""

View File

@ -1,4 +1,5 @@
import sys
from pathlib import Path
from dataclasses import asdict
from typing import TYPE_CHECKING, Set
@ -10,7 +11,21 @@ if TYPE_CHECKING:
@pytest.mark.asyncio
async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
async def test_load_plugin(app: App):
import nonebot
# check regular
assert nonebot.load_plugin("plugins.metadata")
# check path
assert nonebot.load_plugin(Path("plugins/export"))
# check not found
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_load_plugins(app: App, load_plugin: Set["Plugin"]):
import nonebot
from nonebot.plugin import PluginManager
@ -34,9 +49,6 @@ async def test_load_plugin(app: App, load_plugin: Set["Plugin"]):
with pytest.raises(RuntimeError):
PluginManager(search_path=["plugins"]).load_all_plugins()
# check not found
assert nonebot.load_plugin("some_plugin_not_exist") is None
@pytest.mark.asyncio
async def test_load_nested_plugin(app: App, load_plugin: Set["Plugin"]):