nonebot2/nonebot/plugin/manager.py

176 lines
5.7 KiB
Python
Raw Normal View History

2021-02-19 14:58:26 +08:00
import sys
import pkgutil
import importlib
from pathlib import Path
2021-11-11 17:33:30 +08:00
from itertools import chain
2021-02-19 14:58:26 +08:00
from types import ModuleType
from importlib.abc import MetaPathFinder
2021-03-31 20:38:00 +08:00
from importlib.machinery import PathFinder, SourceFileLoader
2021-11-11 17:33:30 +08:00
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
2021-03-31 20:38:00 +08:00
2021-11-11 17:33:30 +08:00
from nonebot.log import logger
from nonebot.utils import escape_tag
2021-11-08 01:02:35 +08:00
from .plugin import Plugin, _new_plugin
2021-11-11 17:33:30 +08:00
from . import _managers, _current_plugin
2021-02-19 14:58:26 +08:00
2021-11-09 00:57:59 +08:00
2021-02-19 14:58:26 +08:00
class PluginManager:
    """Discover and import plugins from explicit module names and search paths.

    A manager owns two kinds of plugins:

    * ``plugins`` -- module names given directly; they are imported by name
      and need not live in any search path.
    * ``searched_plugins`` -- modules found by scanning the directories in
      ``search_path``, mapped from plugin name to the resolved path of the
      module's origin file.
    """

    def __init__(
        self,
        plugins: Optional[Iterable[str]] = None,
        search_path: Optional[Iterable[str]] = None,
    ):
        # Module names imported as-is (not located through a search path).
        self.plugins: Set[str] = set(plugins or [])
        # Directories scanned by list_plugins().
        self.search_path: Set[str] = set(search_path or [])
        # Scan cache: plugin name -> resolved module file path.
        self.searched_plugins: Dict[str, Path] = {}

        self.list_plugins()

    def _path_to_module_name(self, path: Path) -> str:
        """Turn a module file path into a dotted name relative to the cwd.

        A package's ``__init__`` file maps to the package's own name.
        ``Path.relative_to`` raises ``ValueError`` when *path* is not located
        under the current working directory.
        """
        relative = path.resolve().relative_to(Path(".").resolve())
        components = list(relative.parts[:-1])
        if relative.stem != "__init__":
            components.append(relative.stem)
        return ".".join(components)

    def _previous_plugins(self) -> List[str]:
        """Collect plugin names owned by managers that precede this one.

        "Preceding" means earlier in ``_managers``; when this manager is not
        in ``_managers``, every manager in it counts.
        """
        if self in _managers:
            earlier = _managers[: _managers.index(self)]
        else:
            earlier = _managers[:]

        names: List[str] = []
        for other in earlier:
            names.extend(other.plugins)
            names.extend(other.searched_plugins)
        return names

    def list_plugins(self) -> Set[str]:
        """Rescan ``search_path`` and return all plugin names this manager owns.

        Modules whose name starts with ``_`` are ignored, as are modules whose
        spec or origin cannot be determined. The scan result replaces
        ``self.searched_plugins``.

        Raises:
            RuntimeError: if a scanned name was already seen during this scan
                or is already owned by a preceding manager.
        """
        taken = self._previous_plugins()
        found: Dict[str, Path] = {}

        for info in pkgutil.iter_modules(self.search_path):
            name = info.name
            if name.startswith("_"):
                continue
            if name in found or name in taken:
                raise RuntimeError(
                    f"Plugin already exists: {name}! Check your plugin name"
                )
            spec = info.module_finder.find_spec(name, None)
            origin = spec.origin if spec else None
            if origin:
                found[name] = Path(origin).resolve()

        self.searched_plugins = found
        return self.plugins | set(found)

    def load_plugin(self, name) -> Optional[Plugin]:
        """Import the plugin called *name* and return its ``__plugin__`` object.

        Names in ``self.plugins`` are imported directly; scanned plugins are
        imported through the dotted module name derived from their file path.
        Any exception (including an unknown *name*) is logged rather than
        raised, and ``None`` is returned. ``None`` is also returned when the
        imported module has no ``__plugin__`` attribute.
        """
        try:
            if name in self.plugins:
                module_name = name
            elif name in self.searched_plugins:
                module_name = self._path_to_module_name(
                    self.searched_plugins[name]
                )
            else:
                raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")

            module = importlib.import_module(module_name)
            logger.opt(colors=True).success(
                f'Succeeded to import "<y>{escape_tag(name)}</y>"'
            )
            return getattr(module, "__plugin__", None)
        except Exception as e:
            logger.opt(colors=True, exception=e).error(
                f'<r><bg #f8bbd0>Failed to import "{escape_tag(name)}"</bg #f8bbd0></r>'
            )
            return None

    def load_all_plugins(self) -> Set[Plugin]:
        """Rescan, load every listed plugin, and return the truthy results."""
        loaded: Set[Plugin] = set()
        for name in self.list_plugins():
            plugin = self.load_plugin(name)
            if plugin:
                loaded.add(plugin)
        return loaded
2021-02-19 14:58:26 +08:00
class PluginFinder(MetaPathFinder):
    """Meta path finder that hands plugin modules to :class:`PluginLoader`.

    For every import it asks the standard ``PathFinder`` for a spec. If the
    module is owned by a registered manager -- by exact name in
    ``manager.plugins`` or by resolved origin path among
    ``manager.searched_plugins`` values -- the spec's loader is replaced with
    a ``PluginLoader`` for that manager. Managers are checked from the last
    one in ``_managers`` to the first. In every other case ``None`` is
    returned so the remaining finders handle the import.
    """

    def find_spec(
        self,
        fullname: str,
        path: Optional[Sequence[Union[bytes, str]]],
        target: Optional[ModuleType] = None,
    ):
        if not _managers:
            return None

        spec = PathFinder.find_spec(fullname, path, target)
        if not spec:
            return None
        origin = spec.origin
        if not origin:
            return None

        resolved = Path(origin).resolve()
        for manager in reversed(_managers):
            owned_by_name = fullname in manager.plugins
            if owned_by_name or resolved in manager.searched_plugins.values():
                spec.loader = PluginLoader(manager, fullname, origin)
                return spec
        return None
2021-02-19 14:58:26 +08:00
2021-03-13 18:21:56 +08:00
class PluginLoader(SourceFileLoader):
    """Source-file loader that wraps a module in a ``Plugin`` while it executes.

    ``manager`` is the :class:`PluginManager` that claimed the module.
    """

    def __init__(self, manager: PluginManager, fullname: str, path) -> None:
        self.manager = manager
        # Set by create_module when the module already exists in sys.modules,
        # so exec_module neither re-runs it nor registers a second plugin.
        self.loaded = False
        super().__init__(fullname, path)

    def create_module(self, spec) -> Optional[ModuleType]:
        # Reuse an already-imported module instead of creating a new one.
        if self.name in sys.modules:
            self.loaded = True
            return sys.modules[self.name]
        # return None to use default module creation
        return super().create_module(spec)

    def exec_module(self, module: ModuleType) -> None:
        """Create the module's plugin, link it to its parent, and run the module.

        While the module body executes, ``_current_plugin`` is set to the new
        plugin. Plugins imported from inside the module therefore pick it up as
        their ``parent_plugin``. The previous value is restored even if the
        module raises; the exception still propagates to the importer.
        """
        if self.loaded:
            return

        plugin = _new_plugin(self.name, module)
        parent_plugin = _current_plugin.get()
        if parent_plugin:
            plugin.parent_plugin = parent_plugin
            parent_plugin.sub_plugins.add(plugin)

        _plugin_token = _current_plugin.set(plugin)
        setattr(module, "__plugin__", plugin)
        try:
            super().exec_module(module)
        finally:
            # Always leave the plugin context. Otherwise a plugin that failed
            # to execute would stay "current", and the next plugin loaded
            # (load_plugin swallows the error) would be attached to it as a
            # sub-plugin.
            # NOTE(review): on failure the plugin created by _new_plugin is not
            # unregistered here -- confirm whether that is intended.
            _current_plugin.reset(_plugin_token)
2021-03-13 18:21:56 +08:00
2021-02-19 14:58:26 +08:00
# Install the plugin finder at the front of sys.meta_path when this module is
# imported, so it is consulted before the other finders on the meta path.
sys.meta_path.insert(0, PluginFinder())