From 0a1ae75b702d53834a256e45314f565923936fac Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Sun, 14 Nov 2021 18:51:23 +0800 Subject: [PATCH] :sparkles: finish matcher process --- nonebot/permission.py | 25 ++++---- nonebot/plugin/on.py | 11 ++-- nonebot/processor/__init__.py | 9 ++- nonebot/processor/handler.py | 19 +++++- nonebot/processor/matcher.py | 98 ++++++++++++++++-------------- nonebot/processor/utils.py | 18 ++++-- nonebot/rule.py | 32 +++++----- tests/test_plugins/test_depends.py | 17 ++++++ 8 files changed, 139 insertions(+), 90 deletions(-) create mode 100644 tests/test_plugins/test_depends.py diff --git a/nonebot/permission.py b/nonebot/permission.py index 675f546d..7564bb13 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -10,14 +10,12 @@ r""" """ import asyncio -from typing import TYPE_CHECKING, Union, Callable, NoReturn, Optional, Awaitable +from typing import Union, Callable, NoReturn, Optional, Awaitable from nonebot.utils import run_sync +from nonebot.adapters import Bot, Event from nonebot.typing import T_PermissionChecker -if TYPE_CHECKING: - from nonebot.adapters import Bot, Event - class Permission: """ @@ -36,9 +34,8 @@ class Permission: """ __slots__ = ("checkers",) - def __init__( - self, *checkers: Callable[["Bot", "Event"], - Awaitable[bool]]) -> None: + def __init__(self, *checkers: Callable[[Bot, Event], + Awaitable[bool]]) -> None: """ :参数: @@ -55,7 +52,7 @@ class Permission: * ``Set[Callable[[Bot, Event], Awaitable[bool]]]`` """ - async def __call__(self, bot: "Bot", event: "Event") -> bool: + async def __call__(self, bot: Bot, event: Event) -> bool: """ :说明: @@ -94,19 +91,19 @@ class Permission: return Permission(*checkers) -async def _message(bot: "Bot", event: "Event") -> bool: +async def _message(bot: Bot, event: Event) -> bool: return event.get_type() == "message" -async def _notice(bot: "Bot", event: "Event") -> bool: +async def _notice(bot: Bot, event: Event) -> bool: return event.get_type() == "notice" -async def _request(bot: "Bot", event: "Event") -> bool: +async def _request(bot: Bot, event: Event) -> bool: return event.get_type() == "request" -async def _metaevent(bot: "Bot", event: "Event") -> bool: +async def _metaevent(bot: Bot, event: Event) -> bool: return event.get_type() == "meta_event" @@ -140,14 +137,14 @@ def USER(*user: str, perm: Optional[Permission] = None): * ``perm: Optional[Permission]``: 需要同时满足的权限 """ - async def _user(bot: "Bot", event: "Event") -> bool: + async def _user(bot: Bot, event: Event) -> bool: return bool(event.get_session_id() in user and (perm is None or await perm(bot, event))) return Permission(_user) -async def _superuser(bot: "Bot", event: "Event") -> bool: +async def _superuser(bot: Bot, event: Event) -> bool: return (event.get_type() == "message" and event.get_user_id() in bot.config.superusers) diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 46a5ecfe..40038984 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -2,19 +2,16 @@ import re import sys import inspect from types import ModuleType -from typing import (TYPE_CHECKING, Any, Set, Dict, List, Type, Tuple, Union, - Optional) +from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional from .manager import _current_plugin +from nonebot.adapters import Bot, Event from nonebot.permission import Permission from nonebot.processor import Handler, Matcher from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, endswith, startswith, shell_command) -if TYPE_CHECKING: - from nonebot.adapters import Bot, Event - def _store_matcher(matcher: Type[Matcher]) -> None: plugin = _current_plugin.get() @@ -375,7 +372,7 @@ def on_command(cmd: Union[str, Tuple[str, ...]], - ``Type[Matcher]`` """ - async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): + async def _strip_cmd(bot: Bot, event: Event, state: T_State): message = event.get_message() if len(message) < 1: return @@ -432,7 +429,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]], - ``Type[Matcher]`` """ - async def _strip_cmd(bot: "Bot", event: "Event", state: T_State): + async def _strip_cmd(bot: Bot, event: Event, state: T_State): message = event.get_message() segment = message.pop(0) new_message = message.__class__( diff --git a/nonebot/processor/__init__.py b/nonebot/processor/__init__.py index 1ef0def8..70689631 100644 --- a/nonebot/processor/__init__.py +++ b/nonebot/processor/__init__.py @@ -3,6 +3,7 @@ from itertools import chain from typing import Any, Dict, List, Tuple, Callable, Optional, cast from .models import Dependent +from nonebot.log import logger from nonebot.typing import T_State from nonebot.adapters import Bot, Event from .models import Depends as DependsClass @@ -70,7 +71,7 @@ def get_dependent(*, f"{dependent.event_param_name} / {param_name}") dependent.event_param_name = param_name dependent.event_param_type = generic_get_types(param.annotation) - elif generic_check_issubclass(param.annotation, dict): + elif generic_check_issubclass(param.annotation, Dict): if dependent.state_param_name is not None: raise ValueError(f"{func} has more than one State parameter: " f"{dependent.state_param_name} / {param_name}") @@ -114,9 +115,15 @@ async def solve_dependencies( # check bot and event type if sub_dependent.bot_param_type and not isinstance( bot, sub_dependent.bot_param_type): + logger.debug( + f"Matcher {matcher} bot type {type(bot)} not match depends {func} " + f"annotation {sub_dependent.bot_param_type}, ignored") return values, dependency_cache, True elif sub_dependent.event_param_type and not isinstance( event, sub_dependent.event_param_type): + logger.debug( + f"Matcher {matcher} event type {type(event)} not match depends {func} " + f"annotation {sub_dependent.event_param_type}, ignored") return values, dependency_cache, True # dependency overrides diff --git a/nonebot/processor/handler.py b/nonebot/processor/handler.py index f21a9b37..0aafc8c9 100644 --- a/nonebot/processor/handler.py +++ b/nonebot/processor/handler.py @@ -8,6 +8,7 @@ import asyncio from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Callable, Optional +from nonebot.log import logger from .models import Depends, Dependent from nonebot.utils import get_name, run_sync from nonebot.typing import T_State, T_Handler @@ -48,7 +49,18 @@ class Handler: self.dependency_overrides_provider = dependency_overrides_provider self.dependent = get_dependent(func=func) - async def __call__(self, matcher: "Matcher", bot: Bot, event: Event, + def __repr__(self) -> str: + return ( + f"") + + def __str__(self) -> str: + return repr(self) + + async def __call__(self, matcher: "Matcher", bot: "Bot", event: "Event", state: T_State): values, _, ignored = await solve_dependencies( dependent=self.dependent, @@ -68,9 +80,14 @@ class Handler: # check bot and event type if self.dependent.bot_param_type and not isinstance( bot, self.dependent.bot_param_type): + logger.debug(f"Matcher {matcher} bot type {type(bot)} not match " + f"annotation {self.dependent.bot_param_type}, ignored") return elif self.dependent.event_param_type and not isinstance( event, self.dependent.event_param_type): + logger.debug( + f"Matcher {matcher} event type {type(event)} not match " + f"annotation {self.dependent.event_param_type}, ignored") return if asyncio.iscoroutinefunction(self.func): diff --git a/nonebot/processor/matcher.py b/nonebot/processor/matcher.py index 612b6d36..9fbceaa9 100644 --- a/nonebot/processor/matcher.py +++ b/nonebot/processor/matcher.py @@ -17,8 +17,9 @@ from .handler import Handler from nonebot.rule import Rule from nonebot import get_driver from nonebot.log import logger -from nonebot.adapters import MessageTemplate from nonebot.permission import USER, Permission +from nonebot.adapters import (Bot, Event, Message, MessageSegment, + MessageTemplate) from nonebot.exception import (PausedException, StopPropagation, FinishedException, RejectedException) from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, @@ -26,15 +27,14 @@ from nonebot.typing import (T_State, T_Handler, T_ArgsParser, T_TypeUpdater, if TYPE_CHECKING: from nonebot.plugin import Plugin - from nonebot.adapters import Bot, Event, Message, MessageSegment matchers: Dict[int, List[Type["Matcher"]]] = defaultdict(list) """ :类型: ``Dict[int, List[Type[Matcher]]]`` :说明: 用于存储当前所有的事件响应器 """ -current_bot: ContextVar["Bot"] = ContextVar("current_bot") -current_event: ContextVar["Event"] = ContextVar("current_event") +current_bot: ContextVar[Bot] = ContextVar("current_bot") +current_event: ContextVar[Event] = ContextVar("current_event") current_state: ContextVar[T_State] = ContextVar("current_state") @@ -259,7 +259,7 @@ class Matcher(metaclass=MatcherMeta): return NewMatcher @classmethod - async def check_perm(cls, bot: "Bot", event: "Event") -> bool: + async def check_perm(cls, bot: Bot, event: Event) -> bool: """ :说明: @@ -279,8 +279,7 @@ class Matcher(metaclass=MatcherMeta): await cls.permission(bot, event)) @classmethod - async def check_rule(cls, bot: "Bot", event: "Event", - state: T_State) -> bool: + async def check_rule(cls, bot: Bot, event: Event, state: T_State) -> bool: """ :说明: @@ -383,18 +382,21 @@ class Matcher(metaclass=MatcherMeta): * 无 """ + async def _receive(state: T_State) -> Union[None, NoReturn]: + if state.get(_receive): + return + state[_receive] = True + raise RejectedException + def _decorator(func: T_Handler) -> T_Handler: - async def _receive() -> NoReturn: - func_handler.remove_dependency(depend) - raise PausedException - depend = Depends(_receive) + if cls.handlers and cls.handlers[-1].func is func: func_handler = cls.handlers[-1] func_handler.prepend_dependency(depend) else: - func_handler = cls.append_handler( + cls.append_handler( func, dependencies=[depend] if cls.handlers else []) return func @@ -405,7 +407,7 @@ class Matcher(metaclass=MatcherMeta): def got( cls, key: str, - prompt: Optional[Union[str, "Message", "MessageSegment", + prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, args_parser: Optional[T_ArgsParser] = None ) -> Callable[[T_Handler], T_Handler]: @@ -421,32 +423,36 @@ class Matcher(metaclass=MatcherMeta): * ``args_parser: Optional[T_ArgsParser]``: 可选参数解析函数,空则使用默认解析函数 """ + async def _key_getter(bot: Bot, event: Event, state: T_State): + if state.get(f"_{key}_prompted"): + return + + state["_current_key"] = key + state[f"_{key}_prompted"] = True + if key not in state: + if prompt is not None: + if isinstance(prompt, MessageTemplate): + _prompt = prompt.format(**state) + else: + _prompt = prompt + await bot.send(event=event, message=_prompt) + raise RejectedException + else: + state[f"_{key}_parsed"] = True + + async def _key_parser(bot: Bot, event: Event, state: T_State): + if key in state and state.get(f"_{key}_parsed"): + return + + parser = args_parser or cls._default_parser + if parser: + await parser(bot, event, state) + else: + state[key] = str(event.get_message()) + state[f"_{key}_parsed"] = True + def _decorator(func: T_Handler) -> T_Handler: - async def _key_getter(bot: "Bot", event: "Event", state: T_State): - func_handler.remove_dependency(get_depend) - state["_current_key"] = key - if key not in state: - if prompt is not None: - if isinstance(prompt, MessageTemplate): - _prompt = prompt.format(**state) - else: - _prompt = prompt - await bot.send(event=event, message=_prompt) - raise PausedException - else: - state["_skip_key"] = True - - async def _key_parser(bot: "Bot", event: "Event", state: T_State): - if key in state and state.get("_skip_key"): - del state["_skip_key"] - return - parser = args_parser or cls._default_parser - if parser: - await parser(bot, event, state) - else: - state[state["_current_key"]] = str(event.get_message()) - get_depend = Depends(_key_getter) parser_depend = Depends(_key_parser) @@ -455,15 +461,15 @@ class Matcher(metaclass=MatcherMeta): func_handler.prepend_dependency(parser_depend) func_handler.prepend_dependency(get_depend) else: - func_handler = cls.append_handler( - func, dependencies=[get_depend, parser_depend]) + cls.append_handler(func, + dependencies=[get_depend, parser_depend]) return func return _decorator @classmethod - async def send(cls, message: Union[str, "Message", "MessageSegment", + async def send(cls, message: Union[str, Message, MessageSegment, MessageTemplate], **kwargs) -> Any: """ :说明: @@ -486,7 +492,7 @@ class Matcher(metaclass=MatcherMeta): @classmethod async def finish(cls, - message: Optional[Union[str, "Message", "MessageSegment", + message: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, **kwargs) -> NoReturn: """ @@ -512,7 +518,7 @@ class Matcher(metaclass=MatcherMeta): @classmethod async def pause(cls, - prompt: Optional[Union[str, "Message", "MessageSegment", + prompt: Optional[Union[str, Message, MessageSegment, MessageTemplate]] = None, **kwargs) -> NoReturn: """ @@ -538,8 +544,8 @@ class Matcher(metaclass=MatcherMeta): @classmethod async def reject(cls, - prompt: Optional[Union[str, "Message", - "MessageSegment"]] = None, + prompt: Optional[Union[str, Message, + MessageSegment]] = None, **kwargs) -> NoReturn: """ :说明: @@ -554,6 +560,8 @@ class Matcher(metaclass=MatcherMeta): bot = current_bot.get() event = current_event.get() state = current_state.get() + if "_current_key" in state and f"_{state['_current_key']}_parsed" in state: + del state[f"_{state['_current_key']}_parsed"] if isinstance(prompt, MessageTemplate): _prompt = prompt.format(**state) else: @@ -571,7 +579,7 @@ class Matcher(metaclass=MatcherMeta): self.block = True # 运行handlers - async def run(self, bot: "Bot", event: "Event", state: T_State): + async def run(self, bot: Bot, event: Event, state: T_State): b_t = current_bot.set(bot) e_t = current_event.set(event) s_t = current_state.set(self.state) diff --git a/nonebot/processor/utils.py b/nonebot/processor/utils.py index bdae3aea..6f13e96c 100644 --- a/nonebot/processor/utils.py +++ b/nonebot/processor/utils.py @@ -1,8 +1,9 @@ import inspect from typing import Any, Dict, Type, Tuple, Union, Callable +from typing_extensions import GenericAlias, get_args, get_origin # type: ignore -from pydantic.typing import (ForwardRef, get_args, get_origin, - evaluate_forwardref) +from loguru import logger +from pydantic.typing import ForwardRef, evaluate_forwardref def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: @@ -25,7 +26,13 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, annotation = param.annotation if isinstance(annotation, str): annotation = ForwardRef(annotation) - annotation = evaluate_forwardref(annotation, globalns, globalns) + try: + annotation = evaluate_forwardref(annotation, globalns, globalns) + except Exception as e: + logger.opt(colors=True, exception=e).warning( + f"Unknown ForwardRef[\"{param.annotation}\"] for parameter {param.name}" + ) + return inspect.Parameter.empty return annotation @@ -33,13 +40,16 @@ def generic_check_issubclass( cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...]]) -> bool: try: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) + return issubclass(cls, class_or_tuple) except TypeError: if get_origin(cls) is Union: for type_ in get_args(cls): if not generic_check_issubclass(type_, class_or_tuple): return False return True + elif isinstance(cls, GenericAlias): + origin = get_origin(cls) + return bool(origin and issubclass(origin, class_or_tuple)) raise diff --git a/nonebot/rule.py b/nonebot/rule.py index 863339df..40fc1b43 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -15,20 +15,18 @@ import asyncio from itertools import product from argparse import Namespace from argparse import ArgumentParser as ArgParser -from typing import (TYPE_CHECKING, Any, Dict, Tuple, Union, Callable, NoReturn, - Optional, Sequence, Awaitable) +from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional, + Sequence, Awaitable) from pygtrie import CharTrie from nonebot import get_driver from nonebot.log import logger from nonebot.utils import run_sync +from nonebot.adapters import Bot, Event from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker -if TYPE_CHECKING: - from nonebot.adapters import Bot, Event - class Rule: """ @@ -48,8 +46,8 @@ class Rule: __slots__ = ("checkers",) def __init__( - self, *checkers: Callable[["Bot", "Event", T_State], - Awaitable[bool]]) -> None: + self, *checkers: Callable[[Bot, Event, T_State], + Awaitable[bool]]) -> None: """ :参数: @@ -67,8 +65,7 @@ class Rule: * ``Set[Callable[[Bot, Event, T_State], Awaitable[bool]]]`` """ - async def __call__(self, bot: "Bot", event: "Event", - state: T_State) -> bool: + async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: """ :说明: @@ -123,7 +120,7 @@ class TrieRule: cls.suffix[suffix[::-1]] = value @classmethod - def get_value(cls, bot: "Bot", event: "Event", + def get_value(cls, bot: Bot, event: Event, state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]: if event.get_type() != "message": state["_prefix"] = {"raw_command": None, "command": None} @@ -195,7 +192,7 @@ def startswith(msg: Union[str, Tuple[str, ...]], f"^(?:{'|'.join(re.escape(prefix) for prefix in msg)})", re.IGNORECASE if ignorecase else 0) - async def _startswith(bot: "Bot", event: "Event", state: T_State) -> bool: + async def _startswith(bot: Bot, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False text = event.get_plaintext() @@ -222,7 +219,7 @@ def endswith(msg: Union[str, Tuple[str, ...]], f"(?:{'|'.join(re.escape(prefix) for prefix in msg)})$", re.IGNORECASE if ignorecase else 0) - async def _endswith(bot: "Bot", event: "Event", state: T_State) -> bool: + async def _endswith(bot: Bot, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False text = event.get_plaintext() @@ -242,7 +239,7 @@ def keyword(*keywords: str) -> Rule: * ``*keywords: str``: 关键词 """ - async def _keyword(bot: "Bot", event: "Event", state: T_State) -> bool: + async def _keyword(bot: Bot, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False text = event.get_plaintext() @@ -290,7 +287,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - async def _command(bot: "Bot", event: "Event", state: T_State) -> bool: + async def _command(bot: Bot, event: Event, state: T_State) -> bool: return state["_prefix"]["command"] in commands return Rule(_command) @@ -376,8 +373,7 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]], for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - async def _shell_command(bot: "Bot", event: "Event", - state: T_State) -> bool: + async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool: if state["_prefix"]["command"] in commands: message = str(event.get_message()) strip_message = message[len(state["_prefix"]["raw_command"] @@ -417,7 +413,7 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: pattern = re.compile(regex, flags) - async def _regex(bot: "Bot", event: "Event", state: T_State) -> bool: + async def _regex(bot: Bot, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False matched = pattern.search(str(event.get_message())) @@ -443,7 +439,7 @@ def to_me() -> Rule: * 无 """ - async def _to_me(bot: "Bot", event: "Event", state: T_State) -> bool: + async def _to_me(bot: Bot, event: Event, state: T_State) -> bool: return event.is_tome() return Rule(_to_me) diff --git a/tests/test_plugins/test_depends.py b/tests/test_plugins/test_depends.py new file mode 100644 index 00000000..77580374 --- /dev/null +++ b/tests/test_plugins/test_depends.py @@ -0,0 +1,17 @@ +from nonebot import on_command +from nonebot.log import logger +from nonebot.processor import Depends + +test = on_command("123") + + +def depend(state: dict): + return state + + +@test.got("a", prompt="a") +@test.got("b", prompt="b") +@test.receive() +@test.got("c", prompt="c") +async def _(state: dict = Depends(depend)): + logger.info(f"=======, {state}")