♻️ rewrite builtin rule and permission

This commit is contained in:
yanyongyu 2021-12-23 17:50:59 +08:00
parent 8fb394e4c3
commit e3aba26080
2 changed files with 86 additions and 64 deletions

View File

@ -11,23 +11,19 @@ r"""
import asyncio import asyncio
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import ( from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
Any,
Set,
Dict,
Tuple,
Union,
Callable,
NoReturn,
Optional,
Coroutine,
)
from nonebot import params
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException from nonebot.exception import SkippedException
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
from nonebot.params import (
BotParam,
EventType,
EventParam,
DependParam,
DefaultParam,
)
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]): async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
@ -56,10 +52,10 @@ class Permission:
__slots__ = ("checkers",) __slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [
params.DependParam, DependParam,
params.BotParam, BotParam,
params.EventParam, EventParam,
params.DefaultParam, DefaultParam,
] ]
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None: def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
@ -142,23 +138,23 @@ class Permission:
class Message: class Message:
async def __call__(self, event: Event) -> bool: async def __call__(self, type: str = EventType()) -> bool:
return event.get_type() == "message" return type == "message"
class Notice: class Notice:
async def __call__(self, event: Event) -> bool: async def __call__(self, type: str = EventType()) -> bool:
return event.get_type() == "notice" return type == "notice"
class Request: class Request:
async def __call__(self, event: Event) -> bool: async def __call__(self, type: str = EventType()) -> bool:
return event.get_type() == "request" return type == "request"
class MetaEvent: class MetaEvent:
async def __call__(self, event: Event) -> bool: async def __call__(self, type: str = EventType()) -> bool:
return event.get_type() == "meta_event" return type == "meta_event"
MESSAGE = Permission(Message()) MESSAGE = Permission(Message())

View File

@ -21,12 +21,12 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
from pygtrie import CharTrie from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot import params, get_driver
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.exception import ParserExit, SkippedException from nonebot.exception import ParserExit, SkippedException
from nonebot.adapters import Bot, Event, Message, MessageSegment from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
PREFIX_KEY, PREFIX_KEY,
@ -38,6 +38,19 @@ from nonebot.consts import (
REGEX_GROUP, REGEX_GROUP,
REGEX_MATCHED, REGEX_MATCHED,
) )
from nonebot.params import (
State,
Command,
BotParam,
EventToMe,
EventType,
EventParam,
StateParam,
DependParam,
DefaultParam,
EventMessage,
EventPlainText,
)
CMD_RESULT = TypedDict( CMD_RESULT = TypedDict(
"CMD_RESULT", "CMD_RESULT",
@ -68,11 +81,11 @@ class Rule:
__slots__ = ("checkers",) __slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [
params.DependParam, DependParam,
params.BotParam, BotParam,
params.EventParam, EventParam,
params.StateParam, StateParam,
params.DefaultParam, DefaultParam,
] ]
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None: def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
@ -189,15 +202,16 @@ class TrieRule:
return prefix return prefix
class Startswith: class StartswithRule:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg self.msg = msg
self.ignorecase = ignorecase self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any: async def __call__(
if event.get_type() != "message": self, type: str = EventType(), text: str = EventPlainText()
) -> Any:
if type != "message":
return False return False
text = event.get_plaintext()
return bool( return bool(
re.match( re.match(
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})", f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
@ -220,18 +234,19 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
if isinstance(msg, str): if isinstance(msg, str):
msg = (msg,) msg = (msg,)
return Rule(Startswith(msg, ignorecase)) return Rule(StartswithRule(msg, ignorecase))
class Endswith: class EndswithRule:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg self.msg = msg
self.ignorecase = ignorecase self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any: async def __call__(
if event.get_type() != "message": self, type: str = EventType(), text: str = EventPlainText()
) -> Any:
if type != "message":
return False return False
text = event.get_plaintext()
return bool( return bool(
re.search( re.search(
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$", f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
@ -254,17 +269,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
if isinstance(msg, str): if isinstance(msg, str):
msg = (msg,) msg = (msg,)
return Rule(Endswith(msg, ignorecase)) return Rule(EndswithRule(msg, ignorecase))
class Keywords: class KeywordsRule:
def __init__(self, *keywords: str): def __init__(self, *keywords: str):
self.keywords = keywords self.keywords = keywords
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool: async def __call__(
if event.get_type() != "message": self, type: str = EventType(), text: str = EventPlainText()
) -> bool:
if type != "message":
return False return False
text = event.get_plaintext()
return bool(text and any(keyword in text for keyword in self.keywords)) return bool(text and any(keyword in text for keyword in self.keywords))
@ -279,15 +295,15 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词 * ``*keywords: str``: 关键词
""" """
return Rule(Keywords(*keywords)) return Rule(KeywordsRule(*keywords))
class Command: class CommandRule:
def __init__(self, cmds: List[Tuple[str, ...]]): def __init__(self, cmds: List[Tuple[str, ...]]):
self.cmds = cmds self.cmds = cmds
async def __call__(self, state: T_State) -> bool: async def __call__(self, cmd: Tuple[str, ...] = Command()) -> bool:
return state[PREFIX_KEY][CMD_KEY] in self.cmds return cmd in self.cmds
def __repr__(self): def __repr__(self):
return f"<Command {self.cmds}>" return f"<Command {self.cmds}>"
@ -334,7 +350,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep): for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
return Rule(Command(commands)) return Rule(CommandRule(commands))
class ArgumentParser(ArgParser): class ArgumentParser(ArgParser):
@ -365,14 +381,19 @@ class ArgumentParser(ArgParser):
return super().parse_args(args=args, namespace=namespace) # type: ignore return super().parse_args(args=args, namespace=namespace) # type: ignore
class ShellCommand: class ShellCommandRule:
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]): def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
self.cmds = cmds self.cmds = cmds
self.parser = parser self.parser = parser
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(
if state[PREFIX_KEY][CMD_KEY] in self.cmds: self,
message = str(event.get_message()) cmd: Tuple[str, ...] = Command(),
msg: Message = EventMessage(),
state: T_State = State(),
) -> bool:
if cmd in self.cmds:
message = str(msg)
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip() strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message) state[SHELL_ARGV] = shlex.split(strip_message)
if self.parser: if self.parser:
@ -442,18 +463,23 @@ def shell_command(
for start, sep in product(command_start, command_sep): for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
return Rule(ShellCommand(commands, parser)) return Rule(ShellCommandRule(commands, parser))
class Regex: class RegexRule:
def __init__(self, regex: str, flags: int = 0): def __init__(self, regex: str, flags: int = 0):
self.regex = regex self.regex = regex
self.flags = flags self.flags = flags
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(
if event.get_type() != "message": self,
type: str = EventType(),
msg: Message = EventMessage(),
state: T_State = State(),
) -> bool:
if type != "message":
return False return False
matched = re.search(self.regex, str(event.get_message()), self.flags) matched = re.search(self.regex, str(msg), self.flags)
if matched: if matched:
state[REGEX_MATCHED] = matched.group() state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups() state[REGEX_GROUP] = matched.groups()
@ -482,12 +508,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
\:\:\: \:\:\:
""" """
return Rule(Regex(regex, flags)) return Rule(RegexRule(regex, flags))
class ToMe: class ToMeRule:
async def __call__(self, event: Event) -> bool: async def __call__(self, to_me: bool = EventToMe()) -> bool:
return event.is_tome() return to_me
def to_me() -> Rule: def to_me() -> Rule:
@ -501,4 +527,4 @@ def to_me() -> Rule:
* *
""" """
return Rule(ToMe()) return Rule(ToMeRule())