From e3aba260804e89be9c014de323267a64fbb5851a Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Thu, 23 Dec 2021 17:50:59 +0800 Subject: [PATCH] :recycle: rewrite builtin rule and permission --- nonebot/permission.py | 44 ++++++++---------- nonebot/rule.py | 106 ++++++++++++++++++++++++++---------------- 2 files changed, 86 insertions(+), 64 deletions(-) diff --git a/nonebot/permission.py b/nonebot/permission.py index a2ce1778..79aa87aa 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -11,23 +11,19 @@ r""" import asyncio from contextlib import AsyncExitStack -from typing import ( - Any, - Set, - Dict, - Tuple, - Union, - Callable, - NoReturn, - Optional, - Coroutine, -) +from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine -from nonebot import params from nonebot.adapters import Bot, Event from nonebot.dependencies import Dependent from nonebot.exception import SkippedException from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker +from nonebot.params import ( + BotParam, + EventType, + EventParam, + DependParam, + DefaultParam, +) async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]): @@ -56,10 +52,10 @@ class Permission: __slots__ = ("checkers",) HANDLER_PARAM_TYPES = [ - params.DependParam, - params.BotParam, - params.EventParam, - params.DefaultParam, + DependParam, + BotParam, + EventParam, + DefaultParam, ] def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None: @@ -142,23 +138,23 @@ class Permission: class Message: - async def __call__(self, event: Event) -> bool: - return event.get_type() == "message" + async def __call__(self, type: str = EventType()) -> bool: + return type == "message" class Notice: - async def __call__(self, event: Event) -> bool: - return event.get_type() == "notice" + async def __call__(self, type: str = EventType()) -> bool: + return type == "notice" class Request: - async def __call__(self, event: Event) -> bool: - return event.get_type() == "request" + async def __call__(self, type: str = EventType()) -> bool: + return type == "request" class MetaEvent: - async def __call__(self, event: Event) -> bool: - return event.get_type() == "meta_event" + async def __call__(self, type: str = EventType()) -> bool: + return type == "meta_event" MESSAGE = Permission(Message()) diff --git a/nonebot/rule.py b/nonebot/rule.py index 8224e31f..33ab18d5 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -21,12 +21,12 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence from pygtrie import CharTrie +from nonebot import get_driver from nonebot.log import logger -from nonebot import params, get_driver from nonebot.dependencies import Dependent from nonebot.exception import ParserExit, SkippedException from nonebot.adapters import Bot, Event, Message, MessageSegment -from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache +from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache from nonebot.consts import ( CMD_KEY, PREFIX_KEY, @@ -38,6 +38,19 @@ from nonebot.consts import ( REGEX_GROUP, REGEX_MATCHED, ) +from nonebot.params import ( + State, + Command, + BotParam, + EventToMe, + EventType, + EventParam, + StateParam, + DependParam, + DefaultParam, + EventMessage, + EventPlainText, +) CMD_RESULT = TypedDict( "CMD_RESULT", @@ -68,11 +81,11 @@ class Rule: __slots__ = ("checkers",) HANDLER_PARAM_TYPES = [ - params.DependParam, - params.BotParam, - params.EventParam, - params.StateParam, - params.DefaultParam, + DependParam, + BotParam, + EventParam, + StateParam, + DefaultParam, ] def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None: @@ -189,15 +202,16 @@ class TrieRule: return prefix -class Startswith: +class StartswithRule: def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): self.msg = msg self.ignorecase = ignorecase - async def __call__(self, event: Event) -> Any: - if event.get_type() != "message": + async def __call__( + self, type: str = EventType(), text: str = EventPlainText() + ) -> Any: + if type != "message": return False - text = event.get_plaintext() return bool( re.match( f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})", @@ -220,18 +234,19 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru if isinstance(msg, str): msg = (msg,) - return Rule(Startswith(msg, ignorecase)) + return Rule(StartswithRule(msg, ignorecase)) -class Endswith: +class EndswithRule: def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): self.msg = msg self.ignorecase = ignorecase - async def __call__(self, event: Event) -> Any: - if event.get_type() != "message": + async def __call__( + self, type: str = EventType(), text: str = EventPlainText() + ) -> Any: + if type != "message": return False - text = event.get_plaintext() return bool( re.search( f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$", @@ -254,17 +269,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule if isinstance(msg, str): msg = (msg,) - return Rule(Endswith(msg, ignorecase)) + return Rule(EndswithRule(msg, ignorecase)) -class Keywords: +class KeywordsRule: def __init__(self, *keywords: str): self.keywords = keywords - async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: - if event.get_type() != "message": + async def __call__( + self, type: str = EventType(), text: str = EventPlainText() + ) -> bool: + if type != "message": return False - text = event.get_plaintext() return bool(text and any(keyword in text for keyword in self.keywords)) @@ -279,15 +295,15 @@ def keyword(*keywords: str) -> Rule: * ``*keywords: str``: 关键词 """ - return Rule(Keywords(*keywords)) + return Rule(KeywordsRule(*keywords)) -class Command: +class CommandRule: def __init__(self, cmds: List[Tuple[str, ...]]): self.cmds = cmds - async def __call__(self, state: T_State) -> bool: - return state[PREFIX_KEY][CMD_KEY] in self.cmds + async def __call__(self, cmd: Tuple[str, ...] = Command()) -> bool: + return cmd in self.cmds def __repr__(self): return f"" @@ -334,7 +350,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - return Rule(Command(commands)) + return Rule(CommandRule(commands)) class ArgumentParser(ArgParser): @@ -365,14 +381,19 @@ class ArgumentParser(ArgParser): return super().parse_args(args=args, namespace=namespace) # type: ignore -class ShellCommand: +class ShellCommandRule: def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]): self.cmds = cmds self.parser = parser - async def __call__(self, event: Event, state: T_State) -> bool: - if state[PREFIX_KEY][CMD_KEY] in self.cmds: - message = str(event.get_message()) + async def __call__( + self, + cmd: Tuple[str, ...] = Command(), + msg: Message = EventMessage(), + state: T_State = State(), + ) -> bool: + if cmd in self.cmds: + message = str(msg) strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() state[SHELL_ARGV] = shlex.split(strip_message) if self.parser: @@ -442,18 +463,23 @@ def shell_command( for start, sep in product(command_start, command_sep): TrieRule.add_prefix(f"{start}{sep.join(command)}", command) - return Rule(ShellCommand(commands, parser)) + return Rule(ShellCommandRule(commands, parser)) -class Regex: +class RegexRule: def __init__(self, regex: str, flags: int = 0): self.regex = regex self.flags = flags - async def __call__(self, event: Event, state: T_State) -> bool: - if event.get_type() != "message": + async def __call__( + self, + type: str = EventType(), + msg: Message = EventMessage(), + state: T_State = State(), + ) -> bool: + if type != "message": return False - matched = re.search(self.regex, str(event.get_message()), self.flags) + matched = re.search(self.regex, str(msg), self.flags) if matched: state[REGEX_MATCHED] = matched.group() state[REGEX_GROUP] = matched.groups() @@ -482,12 +508,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: \:\:\: """ - return Rule(Regex(regex, flags)) + return Rule(RegexRule(regex, flags)) -class ToMe: - async def __call__(self, event: Event) -> bool: - return event.is_tome() +class ToMeRule: + async def __call__(self, to_me: bool = EventToMe()) -> bool: + return to_me def to_me() -> Rule: @@ -501,4 +527,4 @@ def to_me() -> Rule: * 无 """ - return Rule(ToMe()) + return Rule(ToMeRule())