nonebot2/nonebot/plugin/manager.py
2024-10-22 10:33:48 +08:00

273 lines
8.9 KiB
Python

"""本模块实现插件加载流程。
参考: [import hooks](https://docs.python.org/3/reference/import.html#import-hooks), [PEP302](https://www.python.org/dev/peps/pep-0302/)
FrontMatter:
    mdx:
        format: md
    sidebar_position: 5
    description: nonebot.plugin.manager 模块
"""
import sys
import pkgutil
import importlib
from pathlib import Path
from itertools import chain
from typing import Optional
from types import ModuleType
from importlib.abc import MetaPathFinder
from collections.abc import Iterable, Sequence
from importlib.machinery import PathFinder, SourceFileLoader
from nonebot.log import logger
from nonebot.utils import escape_tag, path_to_module_name
from .model import Plugin, PluginMetadata
from . import (
_managers,
_new_plugin,
_revert_plugin,
_current_plugin,
_module_name_to_plugin_id,
)
class PluginManager:
    """Plugin manager.

    Collects plugin module names from two sources and assigns each of them a
    plugin id when the manager is constructed:

    - ``plugins``: standalone plugin module names, imported as given.
    - ``search_path``: directories listed with ``pkgutil.iter_modules``; every
      module found there whose name does not start with ``_`` (and that has a
      spec with an origin) is treated as a plugin.

    Parameters:
        plugins: Collection of standalone plugin module names.
        search_path: Plugin search paths (directories), relative to the
            current working directory.

    Raises:
        RuntimeError: If a plugin id is already taken, either inside this
            manager or by a manager that precedes it in ``_managers``.
    """

    def __init__(
        self,
        plugins: Optional[Iterable[str]] = None,
        search_path: Optional[Iterable[str]] = None,
    ):
        # standalone plugins that are not located in a search path
        self.plugins: set[str] = set(plugins or [])
        self.search_path: set[str] = set(search_path or [])

        # plugin id -> module name caches, filled by _prepare_plugins()
        self._third_party_plugin_ids: dict[str, str] = {}
        self._searched_plugin_ids: dict[str, str] = {}
        self._prepare_plugins()

    def __repr__(self) -> str:
        return f"PluginManager(available_plugins={self.controlled_modules})"

    @property
    def third_party_plugins(self) -> set[str]:
        """Return the ids of all standalone plugins."""
        return set(self._third_party_plugin_ids.keys())

    @property
    def searched_plugins(self) -> set[str]:
        """Return the ids of the plugins found in the search paths."""
        return set(self._searched_plugin_ids.keys())

    @property
    def available_plugins(self) -> set[str]:
        """Return the ids of all plugins available in this manager."""
        return self.third_party_plugins | self.searched_plugins

    @property
    def controlled_modules(self) -> dict[str, str]:
        """Return a ``plugin id -> module name`` mapping for this manager."""
        return dict(
            chain(
                self._third_party_plugin_ids.items(), self._searched_plugin_ids.items()
            )
        )

    def _previous_controlled_modules(self) -> dict[str, str]:
        """Return ``plugin id -> module name`` for the preceding managers.

        These are the managers listed before this one in ``_managers``, or all
        of ``_managers`` when this manager is not registered there.
        """
        _pre_managers: list[PluginManager]
        if self in _managers:
            _pre_managers = _managers[: _managers.index(self)]
        else:
            _pre_managers = _managers[:]

        return {
            plugin_id: module_name
            for manager in _pre_managers
            for plugin_id, module_name in manager.controlled_modules.items()
        }

    def _prepare_plugins(self) -> set[str]:
        """Discover plugins and cache their plugin ids.

        Returns:
            The ids of all plugins available in this manager.

        Raises:
            RuntimeError: If a plugin id is already taken.
        """
        # plugins already claimed by the preceding managers
        previous_plugin_ids = self._previous_controlled_modules()

        def get_controlled_modules() -> dict[str, str]:
            # If this manager is not registered in ``_managers`` yet, merge in
            # the modules it has collected so far. This is re-evaluated for
            # every plugin because the caches grow inside the loops below.
            return (
                previous_plugin_ids
                if self in _managers
                else {**previous_plugin_ids, **self.controlled_modules}
            )

        # Iterate in sorted order. ``self.plugins`` and ``self.search_path``
        # are sets, and the iteration order of a set of strings can change
        # between interpreter runs (hash randomization). The id computed for
        # each plugin is given the modules registered so far, so walking the
        # sets unordered could make plugin ids and duplicate errors
        # non-deterministic.

        # check standalone plugins
        for plugin in sorted(self.plugins):
            plugin_id = _module_name_to_plugin_id(plugin, get_controlled_modules())
            if (
                plugin_id in self._third_party_plugin_ids
                or plugin_id in previous_plugin_ids
            ):
                raise RuntimeError(
                    f"Plugin already exists: {plugin_id}! Check your plugin name"
                )
            self._third_party_plugin_ids[plugin_id] = plugin

        # check plugins in search path
        for module_info in pkgutil.iter_modules(sorted(self.search_path)):
            # ignore modules whose name starts with "_"
            if module_info.name.startswith("_"):
                continue

            if not (
                module_spec := module_info.module_finder.find_spec(
                    module_info.name, None
                )
            ):
                continue
            if not module_spec.origin:
                continue

            # pkgutil only reports the bare name, so derive the real module
            # name from the file path
            module_path = Path(module_spec.origin).resolve()
            module_name = path_to_module_name(module_path)
            plugin_id = _module_name_to_plugin_id(module_name, get_controlled_modules())
            if (
                plugin_id in previous_plugin_ids
                or plugin_id in self._third_party_plugin_ids
                or plugin_id in self._searched_plugin_ids
            ):
                raise RuntimeError(
                    f"Plugin already exists: {plugin_id}! Check your plugin name"
                )
            self._searched_plugin_ids[plugin_id] = module_name

        return self.available_plugins

    def load_plugin(self, name: str) -> Optional[Plugin]:
        """Load the given plugin.

        The plugin may be referenced by its full module name or by its
        plugin id.

        Parameters:
            name: Plugin module name or plugin id.

        Returns:
            The loaded plugin, or ``None`` if loading failed. Failures are
            logged instead of raised.
        """
        try:
            # load using plugin id
            if name in self._third_party_plugin_ids:
                module = importlib.import_module(self._third_party_plugin_ids[name])
            elif name in self._searched_plugin_ids:
                module = importlib.import_module(self._searched_plugin_ids[name])
            # load using module name
            elif (
                name in self._third_party_plugin_ids.values()
                or name in self._searched_plugin_ids.values()
            ):
                module = importlib.import_module(name)
            else:
                raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")

            # ``__plugin__`` is attached by PluginLoader.exec_module; a module
            # executed without it (per the message below, e.g. one imported
            # before loading) does not carry a Plugin instance.
            if (
                plugin := getattr(module, "__plugin__", None)
            ) is None or not isinstance(plugin, Plugin):
                raise RuntimeError(
                    f"Module {module.__name__} is not loaded as a plugin! "
                    "Make sure not to import it before loading."
                )
            logger.opt(colors=True).success(
                f'Succeeded to load plugin "<y>{escape_tag(plugin.id_)}</y>"'
                + (
                    f' from "<m>{escape_tag(plugin.module_name)}</m>"'
                    if plugin.module_name != plugin.id_
                    else ""
                )
            )
            return plugin
        except Exception as e:
            logger.opt(colors=True, exception=e).error(
                f'<r><bg #f8bbd0>Failed to import "{escape_tag(name)}"</bg #f8bbd0></r>'
            )
            return None

    def load_all_plugins(self) -> set[Plugin]:
        """Load every available plugin and return those that loaded successfully."""
        return set(
            filter(None, (self.load_plugin(name) for name in self.available_plugins))
        )
class PluginFinder(MetaPathFinder):
    """Meta path finder that hands plugin modules to ``PluginLoader``.

    An import is intercepted only when its full module name is one of the
    module names controlled by a manager in ``_managers``. For any other
    import this finder returns ``None``, so the import system moves on to the
    next finder on ``sys.meta_path``.
    """

    def find_spec(
        self,
        fullname: str,
        path: Optional[Sequence[str]],
        target: Optional[ModuleType] = None,
    ):
        """Return a module spec whose loader is a ``PluginLoader``, or ``None``.

        Later entries of ``_managers`` take precedence: the first manager in
        ``reversed(_managers)`` that controls ``fullname`` owns the module.
        """
        # This finder is installed at the front of sys.meta_path, so it sees
        # every import in the process. Run the cheap in-memory ownership check
        # first, and only perform the path-based search (PathFinder) for
        # modules that some manager actually controls.
        owner = next(
            (
                manager
                for manager in reversed(_managers)
                if fullname in manager.controlled_modules.values()
            ),
            None,
        )
        if owner is None:
            return None

        module_spec = PathFinder.find_spec(fullname, path, target)
        if not module_spec:
            return None
        module_origin = module_spec.origin
        if not module_origin:
            return None

        module_spec.loader = PluginLoader(owner, fullname, module_origin)
        return module_spec
class PluginLoader(SourceFileLoader):
    """Source file loader that executes a module as a plugin.

    Before the module body runs, a plugin object is created for it and
    attached as ``__plugin__``. The body then executes with that plugin set in
    ``_current_plugin``. After execution, ``__plugin_meta__`` (if the module
    defines it) is copied to ``plugin.metadata``.
    """

    def __init__(self, manager: PluginManager, fullname: str, path: str) -> None:
        self.manager = manager
        # flipped by create_module() when the module is already in sys.modules
        self.loaded = False
        super().__init__(fullname, path)

    def create_module(self, spec) -> Optional[ModuleType]:
        """Reuse an existing ``sys.modules`` entry, else use default creation."""
        if self.name not in sys.modules:
            # the base implementation returns None -> default module creation
            return super().create_module(spec)
        self.loaded = True
        return sys.modules[self.name]

    def exec_module(self, module: ModuleType) -> None:
        """Execute ``module`` inside its plugin context.

        Does nothing if ``create_module`` reused an existing module. If the
        module body raises, ``_revert_plugin`` is called on the new plugin and
        the exception is re-raised. The plugin context is always reset.
        """
        if self.loaded:
            return

        # the plugin object must exist before the module body runs
        plugin = _new_plugin(self.name, module, self.manager)
        setattr(module, "__plugin__", plugin)

        token = _current_plugin.set(plugin)
        try:
            super().exec_module(module)
        except Exception:
            # presumably undoes what _new_plugin registered — confirm in _revert_plugin
            _revert_plugin(plugin)
            raise
        finally:
            _current_plugin.reset(token)

        meta: Optional[PluginMetadata] = getattr(module, "__plugin_meta__", None)
        plugin.metadata = meta
# Place the plugin finder ahead of every other meta path finder so imports of
# manager-controlled modules are resolved by PluginFinder (and thus loaded by
# PluginLoader) first.
sys.meta_path[:0] = [PluginFinder()]