🐛 fix plugin matcher data

This commit is contained in:
yanyongyu 2021-03-22 16:41:29 +08:00
parent f0a8b47c06
commit 45199a247b
2 changed files with 22 additions and 45 deletions

View File

@ -5,8 +5,6 @@
该模块实现事件响应器的创建与运行并提供一些快捷方法来帮助用户更好的与机器人进行对话
"""
import sys
import inspect
from functools import wraps
from datetime import datetime
from contextvars import ContextVar

View File

@ -8,6 +8,7 @@ import re
import json
from types import ModuleType
from dataclasses import dataclass
from collections import defaultdict
from contextvars import Context, ContextVar, copy_context
from typing import Any, Set, List, Dict, Type, Tuple, Union, Optional, TYPE_CHECKING
@ -33,7 +34,7 @@ PLUGIN_NAMESPACE = "nonebot.loaded_plugins"
_export: ContextVar["Export"] = ContextVar("_export")
# FIXME: tmp matchers context var will be removed
_tmp_matchers: ContextVar[Set[Type[Matcher]]] = ContextVar("_tmp_matchers")
_plugin_matchers: Dict[str, Set[Type[Matcher]]] = defaultdict(set)
class Export(dict):
@ -93,17 +94,25 @@ class Plugin(object):
- **类型**: ``ModuleType``
- **说明**: 插件模块对象
"""
matcher: Set[Type[Matcher]]
"""
- **类型**: ``Set[Type[Matcher]]``
- **说明**: 插件内定义的 ``Matcher``
"""
export: Export
"""
- **类型**: ``Export``
- **说明**: 插件内定义的导出内容
"""
@property
def matcher(self) -> Set[Type[Matcher]]:
"""
- **类型**: ``Set[Type[Matcher]]``
- **说明**: 插件内定义的 ``Matcher``
"""
return _plugin_matchers[self.name]
def _store_matcher(matcher: Type[Matcher]):
plugin_name = matcher.module.split(".", maxsplit=1)[0]
_plugin_matchers[plugin_name].add(matcher)
def on(type: str = "",
rule: Optional[Union[Rule, T_RuleChecker]] = None,
@ -146,7 +155,7 @@ def on(type: str = "",
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
_store_matcher(matcher)
return matcher
@ -188,7 +197,7 @@ def on_metaevent(
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
_store_matcher(matcher)
return matcher
@ -231,7 +240,7 @@ def on_message(rule: Optional[Union[Rule, T_RuleChecker]] = None,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
_store_matcher(matcher)
return matcher
@ -272,7 +281,7 @@ def on_notice(rule: Optional[Union[Rule, T_RuleChecker]] = None,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
_store_matcher(matcher)
return matcher
@ -313,7 +322,7 @@ def on_request(rule: Optional[Union[Rule, T_RuleChecker]] = None,
module=_current_plugin.get(),
default_state=state,
default_state_factory=state_factory)
_tmp_matchers.get().add(matcher)
_store_matcher(matcher)
return matcher
@ -957,7 +966,6 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
if plugin_name.startswith("_"):
return None
_tmp_matchers.set(set())
_export.set(Export())
if plugin_name in plugins:
@ -966,8 +974,7 @@ def _load_plugin(manager: PluginManager, plugin_name: str) -> Optional[Plugin]:
try:
module = manager.load_plugin(plugin_name)
# FIXME: store matchers using new method
plugin = Plugin(plugin_name, module, _tmp_matchers.get(), _export.get())
plugin = Plugin(plugin_name, module, _export.get())
plugins[plugin_name] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"')
@ -1038,39 +1045,11 @@ def load_all_plugins(module_path: Set[str],
- ``Set[Plugin]``
"""
def _load_plugin(plugin_name: str) -> Optional[Plugin]:
if plugin_name.startswith("_"):
return None
_tmp_matchers.set(set())
_export.set(Export())
if plugin_name in plugins:
return None
try:
module = manager.load_plugin(plugin_name)
for m in _tmp_matchers.get():
m.module = plugin_name
plugin = Plugin(plugin_name, module, _tmp_matchers.get(),
_export.get())
plugins[plugin_name] = plugin
logger.opt(
colors=True).info(f'Succeeded to import "<y>{plugin_name}</y>"')
return plugin
except Exception as e:
logger.opt(colors=True, exception=e).error(
f'<r><bg #f8bbd0>Failed to import "{plugin_name}"</bg #f8bbd0></r>'
)
return None
loaded_plugins = set()
manager = PluginManager(PLUGIN_NAMESPACE, module_path, plugin_dir)
for plugin_name in manager.list_plugins():
context: Context = copy_context()
result = context.run(_load_plugin, plugin_name)
result = context.run(_load_plugin, manager, plugin_name)
if result:
loaded_plugins.add(result)
return loaded_plugins