♻️ rewrite builtin rule and permission

This commit is contained in:
yanyongyu 2021-12-23 17:50:59 +08:00
parent 8fb394e4c3
commit e3aba26080
2 changed files with 86 additions and 64 deletions

View File

@ -11,23 +11,19 @@ r"""
import asyncio
from contextlib import AsyncExitStack
from typing import (
Any,
Set,
Dict,
Tuple,
Union,
Callable,
NoReturn,
Optional,
Coroutine,
)
from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
from nonebot import params
from nonebot.adapters import Bot, Event
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
from nonebot.params import (
BotParam,
EventType,
EventParam,
DependParam,
DefaultParam,
)
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
@ -56,10 +52,10 @@ class Permission:
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
params.DependParam,
params.BotParam,
params.EventParam,
params.DefaultParam,
DependParam,
BotParam,
EventParam,
DefaultParam,
]
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
@ -142,23 +138,23 @@ class Permission:
class Message:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "message"
async def __call__(self, type: str = EventType()) -> bool:
return type == "message"
class Notice:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "notice"
async def __call__(self, type: str = EventType()) -> bool:
return type == "notice"
class Request:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "request"
async def __call__(self, type: str = EventType()) -> bool:
return type == "request"
class MetaEvent:
async def __call__(self, event: Event) -> bool:
return event.get_type() == "meta_event"
async def __call__(self, type: str = EventType()) -> bool:
return type == "meta_event"
MESSAGE = Permission(Message())

View File

@ -21,12 +21,12 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger
from nonebot import params, get_driver
from nonebot.dependencies import Dependent
from nonebot.exception import ParserExit, SkippedException
from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
@ -38,6 +38,19 @@ from nonebot.consts import (
REGEX_GROUP,
REGEX_MATCHED,
)
from nonebot.params import (
State,
Command,
BotParam,
EventToMe,
EventType,
EventParam,
StateParam,
DependParam,
DefaultParam,
EventMessage,
EventPlainText,
)
CMD_RESULT = TypedDict(
"CMD_RESULT",
@ -68,11 +81,11 @@ class Rule:
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES = [
params.DependParam,
params.BotParam,
params.EventParam,
params.StateParam,
params.DefaultParam,
DependParam,
BotParam,
EventParam,
StateParam,
DefaultParam,
]
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
@ -189,15 +202,16 @@ class TrieRule:
return prefix
class Startswith:
class StartswithRule:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg
self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any:
if event.get_type() != "message":
async def __call__(
self, type: str = EventType(), text: str = EventPlainText()
) -> Any:
if type != "message":
return False
text = event.get_plaintext()
return bool(
re.match(
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
@ -220,18 +234,19 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
if isinstance(msg, str):
msg = (msg,)
return Rule(Startswith(msg, ignorecase))
return Rule(StartswithRule(msg, ignorecase))
class Endswith:
class EndswithRule:
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
self.msg = msg
self.ignorecase = ignorecase
async def __call__(self, event: Event) -> Any:
if event.get_type() != "message":
async def __call__(
self, type: str = EventType(), text: str = EventPlainText()
) -> Any:
if type != "message":
return False
text = event.get_plaintext()
return bool(
re.search(
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
@ -254,17 +269,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
if isinstance(msg, str):
msg = (msg,)
return Rule(Endswith(msg, ignorecase))
return Rule(EndswithRule(msg, ignorecase))
class Keywords:
class KeywordsRule:
def __init__(self, *keywords: str):
self.keywords = keywords
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
async def __call__(
self, type: str = EventType(), text: str = EventPlainText()
) -> bool:
if type != "message":
return False
text = event.get_plaintext()
return bool(text and any(keyword in text for keyword in self.keywords))
@ -279,15 +295,15 @@ def keyword(*keywords: str) -> Rule:
* ``*keywords: str``: 关键词
"""
return Rule(Keywords(*keywords))
return Rule(KeywordsRule(*keywords))
class Command:
class CommandRule:
def __init__(self, cmds: List[Tuple[str, ...]]):
self.cmds = cmds
async def __call__(self, state: T_State) -> bool:
return state[PREFIX_KEY][CMD_KEY] in self.cmds
async def __call__(self, cmd: Tuple[str, ...] = Command()) -> bool:
return cmd in self.cmds
def __repr__(self):
return f"<Command {self.cmds}>"
@ -334,7 +350,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
return Rule(Command(commands))
return Rule(CommandRule(commands))
class ArgumentParser(ArgParser):
@ -365,14 +381,19 @@ class ArgumentParser(ArgParser):
return super().parse_args(args=args, namespace=namespace) # type: ignore
class ShellCommand:
class ShellCommandRule:
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
self.cmds = cmds
self.parser = parser
async def __call__(self, event: Event, state: T_State) -> bool:
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
message = str(event.get_message())
async def __call__(
self,
cmd: Tuple[str, ...] = Command(),
msg: Message = EventMessage(),
state: T_State = State(),
) -> bool:
if cmd in self.cmds:
message = str(msg)
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
state[SHELL_ARGV] = shlex.split(strip_message)
if self.parser:
@ -442,18 +463,23 @@ def shell_command(
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
return Rule(ShellCommand(commands, parser))
return Rule(ShellCommandRule(commands, parser))
class Regex:
class RegexRule:
def __init__(self, regex: str, flags: int = 0):
self.regex = regex
self.flags = flags
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
async def __call__(
self,
type: str = EventType(),
msg: Message = EventMessage(),
state: T_State = State(),
) -> bool:
if type != "message":
return False
matched = re.search(self.regex, str(event.get_message()), self.flags)
matched = re.search(self.regex, str(msg), self.flags)
if matched:
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
@ -482,12 +508,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
\:\:\:
"""
return Rule(Regex(regex, flags))
return Rule(RegexRule(regex, flags))
class ToMe:
async def __call__(self, event: Event) -> bool:
return event.is_tome()
class ToMeRule:
async def __call__(self, to_me: bool = EventToMe()) -> bool:
return to_me
def to_me() -> Rule:
@ -501,4 +527,4 @@ def to_me() -> Rule:
*
"""
return Rule(ToMe())
return Rule(ToMeRule())