🐛 fix import hook export

This commit is contained in:
yanyongyu 2021-03-31 20:38:00 +08:00
parent d1e8925fe0
commit ca08c56df7
3 changed files with 80 additions and 69 deletions

View File

@ -20,6 +20,7 @@ from nonebot.permission import Permission
from nonebot.typing import T_State, T_StateFactory, T_Handler, T_RuleChecker
from nonebot.rule import Rule, startswith, endswith, keyword, command, shell_command, ArgumentParser, regex
from .export import Export, export, _export
from .manager import PluginManager, _current_plugin
if TYPE_CHECKING:
@ -32,55 +33,9 @@ plugins: Dict[str, "Plugin"] = {}
"""
PLUGIN_NAMESPACE = "nonebot.loaded_plugins"
_export: ContextVar["Export"] = ContextVar("_export")
# FIXME: tmp matchers context var will be removed
_plugin_matchers: Dict[str, Set[Type[Matcher]]] = defaultdict(set)
class Export(dict):
"""
:说明:
插件导出内容以使得其他插件可以获得
:示例:
.. code-block:: python
nonebot.export().default = "bar"
@nonebot.export()
def some_function():
pass
# this doesn't work before python 3.9
# use
# export = nonebot.export(); @export.sub
# instead
# See also PEP-614: https://www.python.org/dev/peps/pep-0614/
@nonebot.export().sub
def something_else():
pass
"""
def __call__(self, func, **kwargs):
self[func.__name__] = func
self.update(kwargs)
return func
def __setitem__(self, key, value):
super().__setitem__(key,
Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value
def __getattr__(self, name):
if name not in self:
self[name] = Export()
return self[name]
@dataclass(eq=False)
class Plugin(object):
"""存储插件信息"""
@ -966,15 +921,14 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
if plugin_name.startswith("_"):
return None
_export.set(Export())
if plugin_name in plugins:
return None
try:
module = manager.load_plugin(plugin_name)
plugin = Plugin(plugin_name, module, _export.get())
plugin = Plugin(plugin_name, module,
getattr(module, "__export__", Export()))
plugins[plugin_name] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"')
@ -1153,19 +1107,6 @@ def get_loaded_plugins() -> Set[Plugin]:
return set(plugins.values())
def export() -> Export:
"""
:说明:
获取插件的导出内容对象
:返回:
- ``Export``
"""
return _export.get()
def require(name: str) -> Optional[Export]:
"""
:说明:

60
nonebot/plugin/export.py Normal file
View File

@ -0,0 +1,60 @@
from contextvars import ContextVar
_export: ContextVar["Export"] = ContextVar("_export")
class Export(dict):
"""
:说明:
插件导出内容以使得其他插件可以获得
:示例:
.. code-block:: python
nonebot.export().default = "bar"
@nonebot.export()
def some_function():
pass
# this doesn't work before python 3.9
# use
# export = nonebot.export(); @export.sub
# instead
# See also PEP-614: https://www.python.org/dev/peps/pep-0614/
@nonebot.export().sub
def something_else():
pass
"""
def __call__(self, func, **kwargs):
self[func.__name__] = func
self.update(kwargs)
return func
def __setitem__(self, key, value):
super().__setitem__(key,
Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value
def __getattr__(self, name):
if name not in self:
self[name] = Export()
return self[name]
def export() -> Export:
"""
:说明:
获取插件的导出内容对象
:返回:
- ``Export``
"""
return _export.get()

View File

@ -7,8 +7,10 @@ from types import ModuleType
from collections import Counter
from contextvars import ContextVar
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, FrozenImporter, SourceFileLoader
from typing import Set, List, Optional, Iterable
from importlib.machinery import PathFinder, SourceFileLoader
from .export import _export, Export
_current_plugin: ContextVar[Optional[str]] = ContextVar("_current_plugin",
default=None)
@ -160,10 +162,10 @@ class PluginManager:
path = module_name.split(".")
length = self.namespace.count(".") + 1
return f"{prefix}{'.'.join(path[length:])}"
elif module_name in self.search_plugins():
return f"{prefix}{module_name}"
elif module_name in self.plugins or module_name.startswith(prefix):
return module_name
elif module_name in self.search_plugins():
return f"{prefix}{module_name}"
return None
@ -191,7 +193,8 @@ class PluginLoader(SourceFileLoader):
def __init__(self, manager: PluginManager, fullname: str, path) -> None:
self.manager = manager
self.loaded = False
self._context_token = None
self._plugin_token = None
self._export_token = None
super().__init__(fullname, path)
def create_module(self, spec) -> Optional[ModuleType]:
@ -201,7 +204,8 @@ class PluginLoader(SourceFileLoader):
prefix = self.manager.internal_module.__name__
plugin_name = self.name[len(prefix):] if self.name.startswith(
prefix) else self.name
self._context_token = _current_plugin.set(plugin_name.lstrip("."))
self._plugin_token = _current_plugin.set(plugin_name.lstrip("."))
self._export_token = _export.set(Export())
# return None to use default module creation
return super().create_module(spec)
@ -210,9 +214,15 @@ class PluginLoader(SourceFileLoader):
return
# really need?
# setattr(module, "__manager__", self.manager)
if self._export_token:
setattr(module, "__export__", _export.get())
super().exec_module(module)
if self._context_token:
_current_plugin.reset(self._context_token)
if self._plugin_token:
_current_plugin.reset(self._plugin_token)
if self._export_token:
_export.reset(self._export_token)
return