mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
♻️ rewrite builtin rule and permission
This commit is contained in:
parent
8fb394e4c3
commit
e3aba26080
@ -11,23 +11,19 @@ r"""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import (
|
||||
Any,
|
||||
Set,
|
||||
Dict,
|
||||
Tuple,
|
||||
Union,
|
||||
Callable,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Coroutine,
|
||||
)
|
||||
from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
|
||||
|
||||
from nonebot import params
|
||||
from nonebot.adapters import Bot, Event
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import SkippedException
|
||||
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
|
||||
from nonebot.params import (
|
||||
BotParam,
|
||||
EventType,
|
||||
EventParam,
|
||||
DependParam,
|
||||
DefaultParam,
|
||||
)
|
||||
|
||||
|
||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
||||
@ -56,10 +52,10 @@ class Permission:
|
||||
__slots__ = ("checkers",)
|
||||
|
||||
HANDLER_PARAM_TYPES = [
|
||||
params.DependParam,
|
||||
params.BotParam,
|
||||
params.EventParam,
|
||||
params.DefaultParam,
|
||||
DependParam,
|
||||
BotParam,
|
||||
EventParam,
|
||||
DefaultParam,
|
||||
]
|
||||
|
||||
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
|
||||
@ -142,23 +138,23 @@ class Permission:
|
||||
|
||||
|
||||
class Message:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "message"
|
||||
async def __call__(self, type: str = EventType()) -> bool:
|
||||
return type == "message"
|
||||
|
||||
|
||||
class Notice:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "notice"
|
||||
async def __call__(self, type: str = EventType()) -> bool:
|
||||
return type == "notice"
|
||||
|
||||
|
||||
class Request:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "request"
|
||||
async def __call__(self, type: str = EventType()) -> bool:
|
||||
return type == "request"
|
||||
|
||||
|
||||
class MetaEvent:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.get_type() == "meta_event"
|
||||
async def __call__(self, type: str = EventType()) -> bool:
|
||||
return type == "meta_event"
|
||||
|
||||
|
||||
MESSAGE = Permission(Message())
|
||||
|
106
nonebot/rule.py
106
nonebot/rule.py
@ -21,12 +21,12 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
|
||||
|
||||
from pygtrie import CharTrie
|
||||
|
||||
from nonebot import get_driver
|
||||
from nonebot.log import logger
|
||||
from nonebot import params, get_driver
|
||||
from nonebot.dependencies import Dependent
|
||||
from nonebot.exception import ParserExit, SkippedException
|
||||
from nonebot.adapters import Bot, Event, Message, MessageSegment
|
||||
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache
|
||||
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
|
||||
from nonebot.consts import (
|
||||
CMD_KEY,
|
||||
PREFIX_KEY,
|
||||
@ -38,6 +38,19 @@ from nonebot.consts import (
|
||||
REGEX_GROUP,
|
||||
REGEX_MATCHED,
|
||||
)
|
||||
from nonebot.params import (
|
||||
State,
|
||||
Command,
|
||||
BotParam,
|
||||
EventToMe,
|
||||
EventType,
|
||||
EventParam,
|
||||
StateParam,
|
||||
DependParam,
|
||||
DefaultParam,
|
||||
EventMessage,
|
||||
EventPlainText,
|
||||
)
|
||||
|
||||
CMD_RESULT = TypedDict(
|
||||
"CMD_RESULT",
|
||||
@ -68,11 +81,11 @@ class Rule:
|
||||
__slots__ = ("checkers",)
|
||||
|
||||
HANDLER_PARAM_TYPES = [
|
||||
params.DependParam,
|
||||
params.BotParam,
|
||||
params.EventParam,
|
||||
params.StateParam,
|
||||
params.DefaultParam,
|
||||
DependParam,
|
||||
BotParam,
|
||||
EventParam,
|
||||
StateParam,
|
||||
DefaultParam,
|
||||
]
|
||||
|
||||
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
|
||||
@ -189,15 +202,16 @@ class TrieRule:
|
||||
return prefix
|
||||
|
||||
|
||||
class Startswith:
|
||||
class StartswithRule:
|
||||
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||
self.msg = msg
|
||||
self.ignorecase = ignorecase
|
||||
|
||||
async def __call__(self, event: Event) -> Any:
|
||||
if event.get_type() != "message":
|
||||
async def __call__(
|
||||
self, type: str = EventType(), text: str = EventPlainText()
|
||||
) -> Any:
|
||||
if type != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(
|
||||
re.match(
|
||||
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
|
||||
@ -220,18 +234,19 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
|
||||
if isinstance(msg, str):
|
||||
msg = (msg,)
|
||||
|
||||
return Rule(Startswith(msg, ignorecase))
|
||||
return Rule(StartswithRule(msg, ignorecase))
|
||||
|
||||
|
||||
class Endswith:
|
||||
class EndswithRule:
|
||||
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||
self.msg = msg
|
||||
self.ignorecase = ignorecase
|
||||
|
||||
async def __call__(self, event: Event) -> Any:
|
||||
if event.get_type() != "message":
|
||||
async def __call__(
|
||||
self, type: str = EventType(), text: str = EventPlainText()
|
||||
) -> Any:
|
||||
if type != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(
|
||||
re.search(
|
||||
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
|
||||
@ -254,17 +269,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
|
||||
if isinstance(msg, str):
|
||||
msg = (msg,)
|
||||
|
||||
return Rule(Endswith(msg, ignorecase))
|
||||
return Rule(EndswithRule(msg, ignorecase))
|
||||
|
||||
|
||||
class Keywords:
|
||||
class KeywordsRule:
|
||||
def __init__(self, *keywords: str):
|
||||
self.keywords = keywords
|
||||
|
||||
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
async def __call__(
|
||||
self, type: str = EventType(), text: str = EventPlainText()
|
||||
) -> bool:
|
||||
if type != "message":
|
||||
return False
|
||||
text = event.get_plaintext()
|
||||
return bool(text and any(keyword in text for keyword in self.keywords))
|
||||
|
||||
|
||||
@ -279,15 +295,15 @@ def keyword(*keywords: str) -> Rule:
|
||||
* ``*keywords: str``: 关键词
|
||||
"""
|
||||
|
||||
return Rule(Keywords(*keywords))
|
||||
return Rule(KeywordsRule(*keywords))
|
||||
|
||||
|
||||
class Command:
|
||||
class CommandRule:
|
||||
def __init__(self, cmds: List[Tuple[str, ...]]):
|
||||
self.cmds = cmds
|
||||
|
||||
async def __call__(self, state: T_State) -> bool:
|
||||
return state[PREFIX_KEY][CMD_KEY] in self.cmds
|
||||
async def __call__(self, cmd: Tuple[str, ...] = Command()) -> bool:
|
||||
return cmd in self.cmds
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Command {self.cmds}>"
|
||||
@ -334,7 +350,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
return Rule(Command(commands))
|
||||
return Rule(CommandRule(commands))
|
||||
|
||||
|
||||
class ArgumentParser(ArgParser):
|
||||
@ -365,14 +381,19 @@ class ArgumentParser(ArgParser):
|
||||
return super().parse_args(args=args, namespace=namespace) # type: ignore
|
||||
|
||||
|
||||
class ShellCommand:
|
||||
class ShellCommandRule:
|
||||
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
|
||||
self.cmds = cmds
|
||||
self.parser = parser
|
||||
|
||||
async def __call__(self, event: Event, state: T_State) -> bool:
|
||||
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
|
||||
message = str(event.get_message())
|
||||
async def __call__(
|
||||
self,
|
||||
cmd: Tuple[str, ...] = Command(),
|
||||
msg: Message = EventMessage(),
|
||||
state: T_State = State(),
|
||||
) -> bool:
|
||||
if cmd in self.cmds:
|
||||
message = str(msg)
|
||||
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
||||
state[SHELL_ARGV] = shlex.split(strip_message)
|
||||
if self.parser:
|
||||
@ -442,18 +463,23 @@ def shell_command(
|
||||
for start, sep in product(command_start, command_sep):
|
||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||
|
||||
return Rule(ShellCommand(commands, parser))
|
||||
return Rule(ShellCommandRule(commands, parser))
|
||||
|
||||
|
||||
class Regex:
|
||||
class RegexRule:
|
||||
def __init__(self, regex: str, flags: int = 0):
|
||||
self.regex = regex
|
||||
self.flags = flags
|
||||
|
||||
async def __call__(self, event: Event, state: T_State) -> bool:
|
||||
if event.get_type() != "message":
|
||||
async def __call__(
|
||||
self,
|
||||
type: str = EventType(),
|
||||
msg: Message = EventMessage(),
|
||||
state: T_State = State(),
|
||||
) -> bool:
|
||||
if type != "message":
|
||||
return False
|
||||
matched = re.search(self.regex, str(event.get_message()), self.flags)
|
||||
matched = re.search(self.regex, str(msg), self.flags)
|
||||
if matched:
|
||||
state[REGEX_MATCHED] = matched.group()
|
||||
state[REGEX_GROUP] = matched.groups()
|
||||
@ -482,12 +508,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
||||
\:\:\:
|
||||
"""
|
||||
|
||||
return Rule(Regex(regex, flags))
|
||||
return Rule(RegexRule(regex, flags))
|
||||
|
||||
|
||||
class ToMe:
|
||||
async def __call__(self, event: Event) -> bool:
|
||||
return event.is_tome()
|
||||
class ToMeRule:
|
||||
async def __call__(self, to_me: bool = EventToMe()) -> bool:
|
||||
return to_me
|
||||
|
||||
|
||||
def to_me() -> Rule:
|
||||
@ -501,4 +527,4 @@ def to_me() -> Rule:
|
||||
* 无
|
||||
"""
|
||||
|
||||
return Rule(ToMe())
|
||||
return Rule(ToMeRule())
|
||||
|
Loading…
Reference in New Issue
Block a user