mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
⚗️ add export require option
This commit is contained in:
parent
689180ebe8
commit
2b10f81326
@ -240,4 +240,4 @@ async def _start_scheduler():
|
||||
from nonebot.plugin import on_message, on_notice, on_request, on_metaevent, CommandGroup
|
||||
from nonebot.plugin import on_startswith, on_endswith, on_keyword, on_command, on_regex
|
||||
from nonebot.plugin import load_plugin, load_plugins, load_builtin_plugins
|
||||
from nonebot.plugin import get_plugin, get_loaded_plugins
|
||||
from nonebot.plugin import export, require, get_plugin, get_loaded_plugins
|
||||
|
@ -11,6 +11,7 @@ import pkgutil
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from importlib._bootstrap import _load
|
||||
from contextvars import Context, ContextVar, copy_context
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.matcher import Matcher
|
||||
@ -25,7 +26,45 @@ plugins: Dict[str, "Plugin"] = {}
|
||||
:说明: 已加载的插件
|
||||
"""
|
||||
|
||||
_tmp_matchers: Set[Type[Matcher]] = set()
|
||||
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers")
|
||||
_export: ContextVar["Export"] = ContextVar("_export")
|
||||
|
||||
|
||||
class Export(dict):
|
||||
"""
|
||||
:说明:
|
||||
插件导出内容以使得其他插件可以获得。
|
||||
:示例:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
nonebot.export().default = "bar"
|
||||
|
||||
@nonebot.export()
|
||||
def some_function():
|
||||
pass
|
||||
|
||||
@nonebot.export().sub
|
||||
def something_else():
|
||||
pass
|
||||
"""
|
||||
|
||||
def __call__(self, func, **kwargs):
|
||||
self[func.__name__] = func
|
||||
self.update(kwargs)
|
||||
return func
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
super().__setitem__(key,
|
||||
Export(value) if isinstance(value, dict) else value)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self[name] = Export(value) if isinstance(value, dict) else value
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name not in self:
|
||||
self[name] = Export()
|
||||
return self[name]
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
@ -46,6 +85,7 @@ class Plugin(object):
|
||||
- **类型**: ``Set[Type[Matcher]]``
|
||||
- **说明**: 插件内定义的 ``Matcher``
|
||||
"""
|
||||
export: Export
|
||||
|
||||
|
||||
def on(type: str = "",
|
||||
@ -80,7 +120,7 @@ def on(type: str = "",
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
_tmp_matchers.get().add(matcher)
|
||||
return matcher
|
||||
|
||||
|
||||
@ -112,7 +152,7 @@ def on_metaevent(rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
_tmp_matchers.get().add(matcher)
|
||||
return matcher
|
||||
|
||||
|
||||
@ -146,7 +186,7 @@ def on_message(rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
_tmp_matchers.get().add(matcher)
|
||||
return matcher
|
||||
|
||||
|
||||
@ -178,7 +218,7 @@ def on_notice(rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
_tmp_matchers.get().add(matcher)
|
||||
return matcher
|
||||
|
||||
|
||||
@ -210,7 +250,7 @@ def on_request(rule: Optional[Union[Rule, RuleChecker]] = None,
|
||||
block=block,
|
||||
handlers=handlers,
|
||||
default_state=state)
|
||||
_tmp_matchers.add(matcher)
|
||||
_tmp_matchers.get().add(matcher)
|
||||
return matcher
|
||||
|
||||
|
||||
@ -387,27 +427,35 @@ def load_plugin(module_path: str) -> Optional[Plugin]:
|
||||
:返回:
|
||||
- ``Optional[Plugin]``
|
||||
"""
|
||||
try:
|
||||
_tmp_matchers.clear()
|
||||
if module_path in plugins:
|
||||
return plugins[module_path]
|
||||
elif module_path in sys.modules:
|
||||
logger.warning(
|
||||
f"Module {module_path} has been loaded by other plugins! Ignored"
|
||||
|
||||
def _load_plugin(module_path: str) -> Optional[Plugin]:
|
||||
try:
|
||||
_tmp_matchers.set(set())
|
||||
_export.set(Export())
|
||||
if module_path in plugins:
|
||||
return plugins[module_path]
|
||||
elif module_path in sys.modules:
|
||||
logger.warning(
|
||||
f"Module {module_path} has been loaded by other plugins! Ignored"
|
||||
)
|
||||
return
|
||||
module = importlib.import_module(module_path)
|
||||
for m in _tmp_matchers.get():
|
||||
m.module = module_path
|
||||
plugin = Plugin(module_path, module, _tmp_matchers.get(),
|
||||
_export.get())
|
||||
plugins[module_path] = plugin
|
||||
logger.opt(
|
||||
colors=True).info(f'Succeeded to import "<y>{module_path}</y>"')
|
||||
return plugin
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>'
|
||||
)
|
||||
return
|
||||
module = importlib.import_module(module_path)
|
||||
for m in _tmp_matchers:
|
||||
m.module = module_path
|
||||
plugin = Plugin(module_path, module, _tmp_matchers.copy())
|
||||
plugins[module_path] = plugin
|
||||
logger.opt(
|
||||
colors=True).info(f'Succeeded to import "<y>{module_path}</y>"')
|
||||
return plugin
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f'<r><bg #f8bbd0>Failed to import "{module_path}"</bg #f8bbd0></r>')
|
||||
return None
|
||||
return None
|
||||
|
||||
context: Context = copy_context()
|
||||
return context.run(_load_plugin, module_path)
|
||||
|
||||
|
||||
def load_plugins(*plugin_dir: str) -> Set[Plugin]:
|
||||
@ -419,33 +467,42 @@ def load_plugins(*plugin_dir: str) -> Set[Plugin]:
|
||||
:返回:
|
||||
- ``Set[Plugin]``
|
||||
"""
|
||||
loaded_plugins = set()
|
||||
for module_info in pkgutil.iter_modules(plugin_dir):
|
||||
_tmp_matchers.clear()
|
||||
|
||||
def _load_plugin(module_info) -> Optional[Plugin]:
|
||||
_tmp_matchers.set(set())
|
||||
_export.set(Export())
|
||||
name = module_info.name
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
return
|
||||
|
||||
spec = module_info.module_finder.find_spec(name, None)
|
||||
if spec.name in plugins:
|
||||
continue
|
||||
return
|
||||
elif spec.name in sys.modules:
|
||||
logger.warning(
|
||||
f"Module {spec.name} has been loaded by other plugin! Ignored")
|
||||
continue
|
||||
return
|
||||
|
||||
try:
|
||||
module = _load(spec)
|
||||
|
||||
for m in _tmp_matchers:
|
||||
for m in _tmp_matchers.get():
|
||||
m.module = name
|
||||
plugin = Plugin(name, module, _tmp_matchers.copy())
|
||||
plugin = Plugin(name, module, _tmp_matchers.get(), _export.get())
|
||||
plugins[name] = plugin
|
||||
loaded_plugins.add(plugin)
|
||||
logger.opt(colors=True).info(f'Succeeded to import "<y>{name}</y>"')
|
||||
return plugin
|
||||
except Exception as e:
|
||||
logger.opt(colors=True, exception=e).error(
|
||||
f'<r><bg #f8bbd0>Failed to import "{name}"</bg #f8bbd0></r>')
|
||||
return None
|
||||
|
||||
loaded_plugins = set()
|
||||
for module_info in pkgutil.iter_modules(plugin_dir):
|
||||
context: Context = copy_context()
|
||||
result = context.run(_load_plugin, module_info)
|
||||
if result:
|
||||
loaded_plugins.add(result)
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
@ -479,3 +536,12 @@ def get_loaded_plugins() -> Set[Plugin]:
|
||||
- ``Set[Plugin]``
|
||||
"""
|
||||
return set(plugins.values())
|
||||
|
||||
|
||||
def export() -> Export:
|
||||
return _export.get()
|
||||
|
||||
|
||||
def require(name: str) -> Optional[Export]:
|
||||
plugin = get_plugin(name)
|
||||
return plugin.export if plugin else None
|
||||
|
@ -1,17 +1,32 @@
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
|
||||
from nonebot.typing import Rule, Matcher, Handler, Permission, RuleChecker
|
||||
from nonebot.typing import Set, List, Dict, Type, Tuple, Union, Optional, ModuleType
|
||||
|
||||
plugins: Dict[str, "Plugin"] = ...
|
||||
|
||||
_tmp_matchers: Set[Type[Matcher]] = ...
|
||||
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ...
|
||||
_export: ContextVar["Export"] = ...
|
||||
|
||||
|
||||
class Export(dict):
|
||||
|
||||
def __call__(self, func, **kwargs):
|
||||
...
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
...
|
||||
|
||||
def __getattr__(self, name):
|
||||
...
|
||||
|
||||
|
||||
class Plugin(object):
|
||||
name: str
|
||||
module: ModuleType
|
||||
matcher: Set[Type[Matcher]]
|
||||
export: Export
|
||||
|
||||
|
||||
def on(type: str = ...,
|
||||
@ -149,6 +164,14 @@ def get_loaded_plugins() -> Set[Plugin]:
|
||||
...
|
||||
|
||||
|
||||
def export() -> Export:
|
||||
...
|
||||
|
||||
|
||||
def require(name: str) -> Export:
|
||||
...
|
||||
|
||||
|
||||
class CommandGroup:
|
||||
|
||||
def __init__(self,
|
||||
|
@ -9,6 +9,8 @@ sidebar: auto
|
||||
- 修复 cqhttp 检查 to me 时出现 IndexError
|
||||
- 修复已失效的事件响应器仍会运行一次的 bug
|
||||
- 修改 cqhttp 检查 reply 时未去除后续 at 以及空格
|
||||
- 添加 get_plugin 获取插件函数
|
||||
- 添加插件 export, require 方法
|
||||
|
||||
## v2.0.0a6
|
||||
|
||||
|
@ -22,6 +22,8 @@ nonebot.load_builtin_plugins()
|
||||
# load local plugins
|
||||
nonebot.load_plugins("test_plugins")
|
||||
|
||||
print(nonebot.require("test_export"))
|
||||
|
||||
# modify some config / config depends on loaded configs
|
||||
config = nonebot.get_driver().config
|
||||
config.custom_config3 = config.custom_config1
|
||||
|
15
tests/test_plugins/test_export.py
Normal file
15
tests/test_plugins/test_export.py
Normal file
@ -0,0 +1,15 @@
|
||||
import nonebot
|
||||
|
||||
export = nonebot.export()
|
||||
export.foo = "bar"
|
||||
export["bar"] = "foo"
|
||||
|
||||
|
||||
@export
|
||||
def a():
|
||||
pass
|
||||
|
||||
|
||||
@export.sub
|
||||
def b():
|
||||
pass
|
Loading…
Reference in New Issue
Block a user