⚗️ add export require option

This commit is contained in:
yanyongyu 2020-11-21 20:40:09 +08:00 committed by pull[bot]
parent 689180ebe8
commit 2b10f81326
6 changed files with 145 additions and 37 deletions

View File

@ -240,4 +240,4 @@ async def _start_scheduler():
from nonebot.plugin import on_message, on_notice, on_request, on_metaevent, CommandGroup from nonebot.plugin import on_message, on_notice, on_request, on_metaevent, CommandGroup
from nonebot.plugin import on_startswith, on_endswith, on_keyword, on_command, on_regex from nonebot.plugin import on_startswith, on_endswith, on_keyword, on_command, on_regex
from nonebot.plugin import load_plugin, load_plugins, load_builtin_plugins from nonebot.plugin import load_plugin, load_plugins, load_builtin_plugins
from nonebot.plugin import get_plugin, get_loaded_plugins from nonebot.plugin import export, require, get_plugin, get_loaded_plugins

View File

@ -11,6 +11,7 @@ import pkgutil
import importlib import importlib
from dataclasses import dataclass from dataclasses import dataclass
from importlib._bootstrap import _load from importlib._bootstrap import _load
from contextvars import Context, ContextVar, copy_context
from nonebot.log import logger from nonebot.log import logger
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
@ -25,7 +26,45 @@ plugins: Dict[str, "Plugin"] = {}
:说明: 已加载的插件 :说明: 已加载的插件
""" """
_tmp_matchers: Set[Type[Matcher]] = set() _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers")
_export: ContextVar["Export"] = ContextVar("_export")
class Export(dict):
"""
:说明:
插件导出内容以使得其他插件可以获得
:示例:
.. code-block:: python
nonebot.export().default = "bar"
@nonebot.export()
def some_function():
pass
@nonebot.export().sub
def something_else():
pass
"""
def __call__(self, func, **kwargs):
self[func.__name__] = func
self.update(kwargs)
return func
def __setitem__(self, key, value):
super().__setitem__(key,
Export(value) if isinstance(value, dict) else value)
def __setattr__(self, name, value):
self[name] = Export(value) if isinstance(value, dict) else value
def __getattr__(self, name):
if name not in self:
self[name] = Export()
return self[name]
@dataclass(eq=False) @dataclass(eq=False)
@ -46,6 +85,7 @@ class Plugin(object):
- **类型**: ``Set[Type[Matcher]]`` - **类型**: ``Set[Type[Matcher]]``
- **说明**: 插件内定义的 ``Matcher`` - **说明**: 插件内定义的 ``Matcher``
""" """
export: Export
def on(type: str = "", def on(type: str = "",
@ -80,7 +120,7 @@ def on(type: str = "",
block=block, block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.get().add(matcher)
return matcher return matcher
@ -112,7 +152,7 @@ def on_metaevent(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block, block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.get().add(matcher)
return matcher return matcher
@ -146,7 +186,7 @@ def on_message(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block, block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.get().add(matcher)
return matcher return matcher
@ -178,7 +218,7 @@ def on_notice(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block, block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.get().add(matcher)
return matcher return matcher
@ -210,7 +250,7 @@ def on_request(rule: Optional[Union[Rule, RuleChecker]] = None,
block=block, block=block,
handlers=handlers, handlers=handlers,
default_state=state) default_state=state)
_tmp_matchers.add(matcher) _tmp_matchers.get().add(matcher)
return matcher return matcher
@ -387,27 +427,35 @@ def load_plugin(module_path: str) -> Optional[Plugin]:
:返回: :返回:
- ``Optional[Plugin]`` - ``Optional[Plugin]``
""" """
try:
_tmp_matchers.clear() def _load_plugin(module_path: str) -> Optional[Plugin]:
if module_path in plugins: try:
return plugins[module_path] _tmp_matchers.set(set())
elif module_path in sys.modules: _export.set(Export())
logger.warning( if module_path in plugins:
f"Module {module_path} has been loaded by other plugins! Ignored" return plugins[module_path]
elif module_path in sys.modules:
logger.warning(
f"Module {module_path} has been loaded by other plugins! Ignored"
)
return
module = importlib.import_module(module_path)
for m in _tmp_matchers.get():
m.module = module_path
plugin = Plugin(module_path, module, _tmp_matchers.get(),
_export.get())
plugins[module_path] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{module_path}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>'
) )
return return None
module = importlib.import_module(module_path)
for m in _tmp_matchers: context: Context = copy_context()
m.module = module_path return context.run(_load_plugin, module_path)
plugin = Plugin(module_path, module, _tmp_matchers.copy())
plugins[module_path] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{module_path}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>')
return None
def load_plugins(*plugin_dir: str) -> Set[Plugin]: def load_plugins(*plugin_dir: str) -> Set[Plugin]:
@ -419,33 +467,42 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]:
:返回: :返回:
- ``Set[Plugin]`` - ``Set[Plugin]``
""" """
loaded_plugins = set()
for module_info in pkgutil.iter_modules(plugin_dir): def _load_plugin(module_info) -> Optional[Plugin]:
_tmp_matchers.clear() _tmp_matchers.set(set())
_export.set(Export())
name = module_info.name name = module_info.name
if name.startswith("_"): if name.startswith("_"):
continue return
spec = module_info.module_finder.find_spec(name, None) spec = module_info.module_finder.find_spec(name, None)
if spec.name in plugins: if spec.name in plugins:
continue return
elif spec.name in sys.modules: elif spec.name in sys.modules:
logger.warning( logger.warning(
f"Module {spec.name} has been loaded by other plugin! Ignored") f"Module {spec.name} has been loaded by other plugin! Ignored")
continue return
try: try:
module = _load(spec) module = _load(spec)
for m in _tmp_matchers: for m in _tmp_matchers.get():
m.module = name m.module = name
plugin = Plugin(name, module, _tmp_matchers.copy()) plugin = Plugin(name, module, _tmp_matchers.get(), _export.get())
plugins[name] = plugin plugins[name] = plugin
loaded_plugins.add(plugin)
logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"') logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"')
return plugin
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>') f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>')
return None
loaded_plugins = set()
for module_info in pkgutil.iter_modules(plugin_dir):
context: Context = copy_context()
result = context.run(_load_plugin, module_info)
if result:
loaded_plugins.add(result)
return loaded_plugins return loaded_plugins
@ -479,3 +536,12 @@ def get_loaded_plugins() -> Set[Plugin]:
- ``Set[Plugin]`` - ``Set[Plugin]``
""" """
return set(plugins.values()) return set(plugins.values())
def export() -> Export:
return _export.get()
def require(name: str) -> Optional[Export]:
plugin = get_plugin(name)
return plugin.export if plugin else None

View File

@ -1,17 +1,32 @@
import re import re
from contextvars import ContextVar
from nonebot.typing import Rule, Matcher, Handler, Permission, RuleChecker from nonebot.typing import Rule, Matcher, Handler, Permission, RuleChecker
from nonebot.typing import Set, List, Dict, Type, Tuple, Union, Optional, ModuleType from nonebot.typing import Set, List, Dict, Type, Tuple, Union, Optional, ModuleType
plugins: Dict[str, "Plugin"] = ... plugins: Dict[str, "Plugin"] = ...
_tmp_matchers: Set[Type[Matcher]] = ... _tmp_matchers: ContextVar[Set[Type[Matcher]]] = ...
_export: ContextVar["Export"] = ...
class Export(dict):
def __call__(self, func, **kwargs):
...
def __setattr__(self, name, value):
...
def __getattr__(self, name):
...
class Plugin(object): class Plugin(object):
name: str name: str
module: ModuleType module: ModuleType
matcher: Set[Type[Matcher]] matcher: Set[Type[Matcher]]
export: Export
def on(type: str = ..., def on(type: str = ...,
@ -149,6 +164,14 @@ def get_loaded_plugins() -> Set[Plugin]:
... ...
def export() -> Export:
...
def require(name: str) -> Export:
...
class CommandGroup: class CommandGroup:
def __init__(self, def __init__(self,

View File

@ -9,6 +9,8 @@ sidebar: auto
- 修复 cqhttp 检查 to me 时出现 IndexError - 修复 cqhttp 检查 to me 时出现 IndexError
- 修复已失效的事件响应器仍会运行一次的 bug - 修复已失效的事件响应器仍会运行一次的 bug
- 修改 cqhttp 检查 reply 时未去除后续 at 以及空格 - 修改 cqhttp 检查 reply 时未去除后续 at 以及空格
- 添加 get_plugin 获取插件函数
- 添加插件 export, require 方法
## v2.0.0a6 ## v2.0.0a6

View File

@ -22,6 +22,8 @@ nonebot.load_builtin_plugins()
# load local plugins # load local plugins
nonebot.load_plugins("test_plugins") nonebot.load_plugins("test_plugins")
print(nonebot.require("test_export"))
# modify some config / config depends on loaded configs # modify some config / config depends on loaded configs
config = nonebot.get_driver().config config = nonebot.get_driver().config
config.custom_config3 = config.custom_config1 config.custom_config3 = config.custom_config1

View File

@ -0,0 +1,15 @@
import nonebot
export = nonebot.export()
export.foo = "bar"
export["bar"] = "foo"
@export
def a():
pass
@export.sub
def b():
pass