🐛 fix plugin matcher data

This commit is contained in:
yanyongyu 2021-03-22 16:41:29 +08:00
parent f0a8b47c06
commit 45199a247b
2 changed files with 22 additions and 45 deletions

View File

@ -5,8 +5,6 @@
该模块实现事件响应器的创建与运行并提供一些快捷方法来帮助用户更好的与机器人进行对话 该模块实现事件响应器的创建与运行并提供一些快捷方法来帮助用户更好的与机器人进行对话
""" """
import sys
import inspect
from functools import wraps from functools import wraps
from datetime import datetime from datetime import datetime
from contextvars import ContextVar from contextvars import ContextVar

View File

@ -8,6 +8,7 @@ import re
import json import json
from types import ModuleType from types import ModuleType
from dataclasses import dataclass from dataclasses import dataclass
from collections import defaultdict
from contextvars import Context, ContextVar, copy_context from contextvars import Context, ContextVar, copy_context
from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING
@ -33,7 +34,7 @@ PLUGIN_NAMESPACE = "nonebot.loaded_plugins"
_export: ContextVar["Export"] = ContextVar("_export") _export: ContextVar["Export"] = ContextVar("_export")
# FIXME: tmp matchers context var will be removed # FIXME: tmp matchers context var will be removed
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers") _plugin_matchers: Dict[str, Set[Type[Matcher]]] = defaultdict(set)
class Export(dict): class Export(dict):
@ -93,17 +94,25 @@ class Plugin(object):
- **类型**: ``ModuleType`` - **类型**: ``ModuleType``
- **说明**: 插件模块对象 - **说明**: 插件模块对象
""" """
matcher: Set[Type[Matcher]]
"""
- **类型**: ``Set[Type[Matcher]]``
- **说明**: 插件内定义的 ``Matcher``
"""
export: Export export: Export
""" """
- **类型**: ``Export`` - **类型**: ``Export``
- **说明**: 插件内定义的导出内容 - **说明**: 插件内定义的导出内容
""" """
@property
def matcher(self) -> Set[Type[Matcher]]:
"""
- **类型**: ``Set[Type[Matcher]]``
- **说明**: 插件内定义的 ``Matcher``
"""
return _plugin_matchers[self.name]
def _store_matcher(matcher: Type[Matcher]):
plugin_name = matcher.module.split(".", maxsplit=1)[0]
_plugin_matchers[plugin_name].add(matcher)
def on(type: str = "", def on(type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
@ -146,7 +155,7 @@ def on(type: str = "",
module=_current_plugin.get(), module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _store_matcher(matcher)
return matcher return matcher
@ -188,7 +197,7 @@ def on_metaevent(
module=_current_plugin.get(), module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _store_matcher(matcher)
return matcher return matcher
@ -231,7 +240,7 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
module=_current_plugin.get(), module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _store_matcher(matcher)
return matcher return matcher
@ -272,7 +281,7 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
module=_current_plugin.get(), module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _store_matcher(matcher)
return matcher return matcher
@ -313,7 +322,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
module=_current_plugin.get(), module=_current_plugin.get(),
default_state=state, default_state=state,
default_state_factory=state_factory) default_state_factory=state_factory)
_tmp_matchers.get().add(matcher) _store_matcher(matcher)
return matcher return matcher
@ -957,7 +966,6 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
if plugin_name.startswith("_"): if plugin_name.startswith("_"):
return None return None
_tmp_matchers.set(set())
_export.set(Export()) _export.set(Export())
if plugin_name in plugins: if plugin_name in plugins:
@ -966,8 +974,7 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
try: try:
module = manager.load_plugin(plugin_name) module = manager.load_plugin(plugin_name)
# FIXME: store matchers using new method plugin = Plugin(plugin_name, module, _export.get())
plugin = Plugin(plugin_name, module, _tmp_matchers.get(), _export.get())
plugins[plugin_name] = plugin plugins[plugin_name] = plugin
logger.opt( logger.opt(
colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"') colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"')
@ -1038,39 +1045,11 @@ def load_all_plugins(module_path: Set[str],
- ``Set[Plugin]`` - ``Set[Plugin]``
""" """
def _load_plugin(plugin_name: str) -> Optional[Plugin]:
if plugin_name.startswith("_"):
return None
_tmp_matchers.set(set())
_export.set(Export())
if plugin_name in plugins:
return None
try:
module = manager.load_plugin(plugin_name)
for m in _tmp_matchers.get():
m.module = plugin_name
plugin = Plugin(plugin_name, module, _tmp_matchers.get(),
_export.get())
plugins[plugin_name] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{plugin_name}"</bg #f8bbd0></r>'
)
return None
loaded_plugins = set() loaded_plugins = set()
manager = PluginManager(PLUGIN_NAMESPACE, module_path, plugin_dir) manager = PluginManager(PLUGIN_NAMESPACE, module_path, plugin_dir)
for plugin_name in manager.list_plugins(): for plugin_name in manager.list_plugins():
context: Context = copy_context() context: Context = copy_context()
result = context.run(_load_plugin, plugin_name) result = context.run(_load_plugin, manager, plugin_name)
if result: if result:
loaded_plugins.add(result) loaded_plugins.add(result)
return loaded_plugins return loaded_plugins