🐛 fix require error

This commit is contained in:
yanyongyu 2022-01-26 15:06:53 +08:00
parent 956ee7e321
commit 5c73c80c65
3 changed files with 30 additions and 13 deletions

View File

@ -137,6 +137,12 @@ def load_builtin_plugins(*plugins) -> Set[Plugin]:
return load_all_plugins([f"nonebot.plugins.{p}" for p in plugins], []) return load_all_plugins([f"nonebot.plugins.{p}" for p in plugins], [])
def _find_manager_by_name(name: str) -> Optional[PluginManager]:
for manager in reversed(_managers):
if name in manager.plugins or name in manager.searched_plugins:
return manager
def require(name: str) -> Export: def require(name: str) -> Export:
"""获取一个插件的导出内容。 """获取一个插件的导出内容。
@ -148,7 +154,13 @@ def require(name: str) -> Export:
异常: 异常:
RuntimeError: 插件无法加载 RuntimeError: 插件无法加载
""" """
plugin = get_plugin(name) or load_plugin(name) plugin = get_plugin(name.rsplit(".", 1)[-1])
if not plugin: if not plugin:
raise RuntimeError(f'Cannot load plugin "{name}"!') manager = _find_manager_by_name(name)
if manager:
plugin = manager.load_plugin(name)
else:
plugin = load_plugin(name)
if not plugin:
raise RuntimeError(f'Cannot load plugin "{name}"!')
return plugin.export return plugin.export

View File

@ -53,7 +53,10 @@ class PluginManager:
return [ return [
*chain.from_iterable( *chain.from_iterable(
[*manager.plugins, *manager.searched_plugins.keys()] [
*map(lambda x: x.rsplit(".", 1)[-1], manager.plugins),
*manager.searched_plugins.keys(),
]
for manager in _pre_managers for manager in _pre_managers
) )
] ]
@ -65,7 +68,7 @@ class PluginManager:
third_party_plugins: Set[str] = set() third_party_plugins: Set[str] = set()
for plugin in self.plugins: for plugin in self.plugins:
name = plugin.rsplit(".", 1)[-1] if "." in plugin else plugin name = plugin.rsplit(".", 1)[-1]
if name in third_party_plugins or name in previous_plugins: if name in third_party_plugins or name in previous_plugins:
raise RuntimeError( raise RuntimeError(
f"Plugin already exists: {name}! Check your plugin name" f"Plugin already exists: {name}! Check your plugin name"
@ -99,17 +102,23 @@ class PluginManager:
try: try:
if name in self.plugins: if name in self.plugins:
module = importlib.import_module(name) module = importlib.import_module(name)
elif name not in self.searched_plugins: elif name in self.searched_plugins:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
else:
module = importlib.import_module( module = importlib.import_module(
self._path_to_module_name(self.searched_plugins[name]) self._path_to_module_name(self.searched_plugins[name])
) )
else:
raise RuntimeError(f"Plugin not found: {name}! Check your plugin name")
logger.opt(colors=True).success( logger.opt(colors=True).success(
f'Succeeded to import "<y>{escape_tag(name)}</y>"' f'Succeeded to import "<y>{escape_tag(name)}</y>"'
) )
return getattr(module, "__plugin__", None) plugin = getattr(module, "__plugin__", None)
if plugin is None:
raise RuntimeError(
f"Module {module.__name__} is not loaded as a plugin! "
"Make sure not to import it before loading."
)
return plugin
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{escape_tag(name)}"</bg #f8bbd0></r>' f'<r><bg #f8bbd0>Failed to import "{escape_tag(name)}"</bg #f8bbd0></r>'

View File

@ -95,8 +95,4 @@ async def test_load_plugin(load_plugin: Set["Plugin"]):
except RuntimeError: except RuntimeError:
assert True assert True
try: assert nonebot.load_plugin("some_plugin_no_exist") is None
nonebot.load_plugin("some_plugin_no_exist")
assert False
except Exception:
assert nonebot.get_plugin("some_plugin_no_exist") is None