diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 61f7b9a9..312a8c67 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -240,4 +240,4 @@ async def _start_scheduler(): from nonebot.plugin import on_message, on_notice, on_request, on_metaevent, CommandGroup from nonebot.plugin import on_startswith, on_endswith, on_keyword, on_command, on_regex from nonebot.plugin import load_plugin, load_plugins, load_builtin_plugins -from nonebot.plugin import get_plugin, get_loaded_plugins +from nonebot.plugin import export, require, get_plugin, get_loaded_plugins diff --git a/nonebot/plugin.py b/nonebot/plugin.py index d9aba299..fd9195b7 100644 --- a/nonebot/plugin.py +++ b/nonebot/plugin.py @@ -11,6 +11,7 @@ import pkgutil import importlib from dataclasses import dataclass from importlib._bootstrap import _load +from contextvars import Context, ContextVar, copy_context from nonebot.log import logger from nonebot.matcher import Matcher @@ -25,7 +26,45 @@ plugins: Dict[str, "Plugin"] = {} :说明: 已加载的插件 """ -_tmp_matchers: Set[Type[Matcher]] = set() +_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers") +_export: ContextVar["Export"] = ContextVar("_export") + + +class Export(dict): + """ + :说明: + 插件导出内容以使得其他插件可以获得。 + :示例: + + .. code-block:: python + + nonebot.export().default = "bar" + + @nonebot.export() + def some_function(): + pass + + @nonebot.export().sub + def something_else(): + pass + """ + + def __call__(self, func, **kwargs): + self[func.__name__] = func + self.update(kwargs) + return func + + def __setitem__(self, key, value): + super().__setitem__(key, + Export(value) if isinstance(value, dict) else value) + + def __setattr__(self, name, value): + self[name] = Export(value) if isinstance(value, dict) else value + + def __getattr__(self, name): + if name not in self: + self[name] = Export() + return self[name] @dataclass(eq=False) @@ -46,6 +85,7 @@ class Plugin(object): - **类型**: ``Set[Type[Matcher]]`` - **说明**: 插件内定义的 ``Matcher`` """ + export: Export def on(type: str = "", @@ -80,7 +120,7 @@ def on(type: str = "", block=block, handlers=handlers, default_state=state) - _tmp_matchers.add(matcher) + _tmp_matchers.get().add(matcher) return matcher @@ -112,7 +152,7 @@ def on_metaevent(rule: Optional[Union[Rule, RuleChecker]] = None, block=block, handlers=handlers, default_state=state) - _tmp_matchers.add(matcher) + _tmp_matchers.get().add(matcher) return matcher @@ -146,7 +186,7 @@ def on_message(rule: Optional[Union[Rule, RuleChecker]] = None, block=block, handlers=handlers, default_state=state) - _tmp_matchers.add(matcher) + _tmp_matchers.get().add(matcher) return matcher @@ -178,7 +218,7 @@ def on_notice(rule: Optional[Union[Rule, RuleChecker]] = None, block=block, handlers=handlers, default_state=state) - _tmp_matchers.add(matcher) + _tmp_matchers.get().add(matcher) return matcher @@ -210,7 +250,7 @@ def on_request(rule: Optional[Union[Rule, RuleChecker]] = None, block=block, handlers=handlers, default_state=state) - _tmp_matchers.add(matcher) + _tmp_matchers.get().add(matcher) return matcher @@ -387,27 +427,35 @@ def load_plugin(module_path: str) -> Optional[Plugin]: :返回: - ``Optional[Plugin]`` """ - try: - _tmp_matchers.clear() - if module_path in plugins: - return plugins[module_path] - elif module_path in sys.modules: - logger.warning( - f"Module {module_path} has been loaded by other plugins! Ignored" + + def _load_plugin(module_path: str) -> Optional[Plugin]: + try: + _tmp_matchers.set(set()) + _export.set(Export()) + if module_path in plugins: + return plugins[module_path] + elif module_path in sys.modules: + logger.warning( + f"Module {module_path} has been loaded by other plugins! Ignored" + ) + return + module = importlib.import_module(module_path) + for m in _tmp_matchers.get(): + m.module = module_path + plugin = Plugin(module_path, module, _tmp_matchers.get(), + _export.get()) + plugins[module_path] = plugin + logger.opt( + colors=True).info(f'Succeeded to import "{module_path}"') + return plugin + except Exception as e: + logger.opt(colors=True, exception=e).error( + f'Failed to import "{module_path}"' ) - return - module = importlib.import_module(module_path) - for m in _tmp_matchers: - m.module = module_path - plugin = Plugin(module_path, module, _tmp_matchers.copy()) - plugins[module_path] = plugin - logger.opt( - colors=True).info(f'Succeeded to import "{module_path}"') - return plugin - except Exception as e: - logger.opt(colors=True, exception=e).error( - f'Failed to import "{module_path}"') - return None + return None + + context: Context = copy_context() + return context.run(_load_plugin, module_path) def load_plugins(*plugin_dir: str) -> Set[Plugin]: @@ -419,33 +467,42 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]: :返回: - ``Set[Plugin]`` """ - loaded_plugins = set() - for module_info in pkgutil.iter_modules(plugin_dir): - _tmp_matchers.clear() + + def _load_plugin(module_info) -> Optional[Plugin]: + _tmp_matchers.set(set()) + _export.set(Export()) name = module_info.name if name.startswith("_"): - continue + return spec = module_info.module_finder.find_spec(name, None) if spec.name in plugins: - continue + return elif spec.name in sys.modules: logger.warning( f"Module {spec.name} has been loaded by other plugin! Ignored") - continue + return try: module = _load(spec) - for m in _tmp_matchers: + for m in _tmp_matchers.get(): m.module = name - plugin = Plugin(name, module, _tmp_matchers.copy()) + plugin = Plugin(name, module, _tmp_matchers.get(), _export.get()) plugins[name] = plugin - loaded_plugins.add(plugin) logger.opt(colors=True).info(f'Succeeded to import "{name}"') + return plugin except Exception as e: logger.opt(colors=True, exception=e).error( f'Failed to import "{name}"') + return None + + loaded_plugins = set() + for module_info in pkgutil.iter_modules(plugin_dir): + context: Context = copy_context() + result = context.run(_load_plugin, module_info) + if result: + loaded_plugins.add(result) return loaded_plugins @@ -479,3 +536,12 @@ def get_loaded_plugins() -> Set[Plugin]: - ``Set[Plugin]`` """ return set(plugins.values()) + + +def export() -> Export: + return _export.get() + + +def require(name: str) -> Optional[Export]: + plugin = get_plugin(name) + return plugin.export if plugin else None diff --git a/nonebot/plugin.pyi b/nonebot/plugin.pyi index 68bb41fd..37d775d6 100644 --- a/nonebot/plugin.pyi +++ b/nonebot/plugin.pyi @@ -1,17 +1,32 @@ import re +from contextvars import ContextVar from nonebot.typing import Rule, Matcher, Handler, Permission, RuleChecker from nonebot.typing import Set, List, Dict, Type, Tuple, Union, Optional, ModuleType plugins: Dict[str, "Plugin"] = ... -_tmp_matchers: Set[Type[Matcher]] = ... +_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ... +_export: ContextVar["Export"] = ... + + +class Export(dict): + + def __call__(self, func, **kwargs): + ... + + def __setattr__(self, name, value): + ... + + def __getattr__(self, name): + ... class Plugin(object): name: str module: ModuleType matcher: Set[Type[Matcher]] + export: Export def on(type: str = ..., @@ -149,6 +164,14 @@ def get_loaded_plugins() -> Set[Plugin]: ... +def export() -> Export: + ... + + +def require(name: str) -> Export: + ... + + class CommandGroup: def __init__(self, diff --git a/pages/changelog.md b/pages/changelog.md index c64953d6..859be509 100644 --- a/pages/changelog.md +++ b/pages/changelog.md @@ -9,6 +9,8 @@ sidebar: auto - 修复 cqhttp 检查 to me 时出现 IndexError - 修复已失效的事件响应器仍会运行一次的 bug - 修改 cqhttp 检查 reply 时未去除后续 at 以及空格 +- 添加 get_plugin 获取插件函数 +- 添加插件 export, require 方法 ## v2.0.0a6 diff --git a/tests/bot.py b/tests/bot.py index 7a294564..68a4e399 100644 --- a/tests/bot.py +++ b/tests/bot.py @@ -22,6 +22,8 @@ nonebot.load_builtin_plugins() # load local plugins nonebot.load_plugins("test_plugins") +print(nonebot.require("test_export")) + # modify some config / config depends on loaded configs config = nonebot.get_driver().config config.custom_config3 = config.custom_config1 diff --git a/tests/test_plugins/test_export.py b/tests/test_plugins/test_export.py new file mode 100644 index 00000000..ec549571 --- /dev/null +++ b/tests/test_plugins/test_export.py @@ -0,0 +1,15 @@ +import nonebot + +export = nonebot.export() +export.foo = "bar" +export["bar"] = "foo" + + +@export +def a(): + pass + + +@export.sub +def b(): + pass