mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
♻️ rewrite builtin rule and permission
This commit is contained in:
parent
8fb394e4c3
commit
e3aba26080
@ -11,23 +11,19 @@ r"""
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import AsyncExitStack
|
from contextlib import AsyncExitStack
|
||||||
from typing import (
|
from typing import Any, Set, Tuple, Union, NoReturn, Optional, Coroutine
|
||||||
Any,
|
|
||||||
Set,
|
|
||||||
Dict,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
Callable,
|
|
||||||
NoReturn,
|
|
||||||
Optional,
|
|
||||||
Coroutine,
|
|
||||||
)
|
|
||||||
|
|
||||||
from nonebot import params
|
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
|
from nonebot.typing import T_Handler, T_DependencyCache, T_PermissionChecker
|
||||||
|
from nonebot.params import (
|
||||||
|
BotParam,
|
||||||
|
EventType,
|
||||||
|
EventParam,
|
||||||
|
DependParam,
|
||||||
|
DefaultParam,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
async def _run_coro_with_catch(coro: Coroutine[Any, Any, Any]):
|
||||||
@ -56,10 +52,10 @@ class Permission:
|
|||||||
__slots__ = ("checkers",)
|
__slots__ = ("checkers",)
|
||||||
|
|
||||||
HANDLER_PARAM_TYPES = [
|
HANDLER_PARAM_TYPES = [
|
||||||
params.DependParam,
|
DependParam,
|
||||||
params.BotParam,
|
BotParam,
|
||||||
params.EventParam,
|
EventParam,
|
||||||
params.DefaultParam,
|
DefaultParam,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
|
def __init__(self, *checkers: Union[T_PermissionChecker, Dependent[bool]]) -> None:
|
||||||
@ -142,23 +138,23 @@ class Permission:
|
|||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
async def __call__(self, event: Event) -> bool:
|
async def __call__(self, type: str = EventType()) -> bool:
|
||||||
return event.get_type() == "message"
|
return type == "message"
|
||||||
|
|
||||||
|
|
||||||
class Notice:
|
class Notice:
|
||||||
async def __call__(self, event: Event) -> bool:
|
async def __call__(self, type: str = EventType()) -> bool:
|
||||||
return event.get_type() == "notice"
|
return type == "notice"
|
||||||
|
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
async def __call__(self, event: Event) -> bool:
|
async def __call__(self, type: str = EventType()) -> bool:
|
||||||
return event.get_type() == "request"
|
return type == "request"
|
||||||
|
|
||||||
|
|
||||||
class MetaEvent:
|
class MetaEvent:
|
||||||
async def __call__(self, event: Event) -> bool:
|
async def __call__(self, type: str = EventType()) -> bool:
|
||||||
return event.get_type() == "meta_event"
|
return type == "meta_event"
|
||||||
|
|
||||||
|
|
||||||
MESSAGE = Permission(Message())
|
MESSAGE = Permission(Message())
|
||||||
|
106
nonebot/rule.py
106
nonebot/rule.py
@ -21,12 +21,12 @@ from typing import Any, Set, List, Tuple, Union, NoReturn, Optional, Sequence
|
|||||||
|
|
||||||
from pygtrie import CharTrie
|
from pygtrie import CharTrie
|
||||||
|
|
||||||
|
from nonebot import get_driver
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot import params, get_driver
|
|
||||||
from nonebot.dependencies import Dependent
|
from nonebot.dependencies import Dependent
|
||||||
from nonebot.exception import ParserExit, SkippedException
|
from nonebot.exception import ParserExit, SkippedException
|
||||||
from nonebot.adapters import Bot, Event, Message, MessageSegment
|
from nonebot.adapters import Bot, Event, Message, MessageSegment
|
||||||
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_DependencyCache
|
from nonebot.typing import T_State, T_RuleChecker, T_DependencyCache
|
||||||
from nonebot.consts import (
|
from nonebot.consts import (
|
||||||
CMD_KEY,
|
CMD_KEY,
|
||||||
PREFIX_KEY,
|
PREFIX_KEY,
|
||||||
@ -38,6 +38,19 @@ from nonebot.consts import (
|
|||||||
REGEX_GROUP,
|
REGEX_GROUP,
|
||||||
REGEX_MATCHED,
|
REGEX_MATCHED,
|
||||||
)
|
)
|
||||||
|
from nonebot.params import (
|
||||||
|
State,
|
||||||
|
Command,
|
||||||
|
BotParam,
|
||||||
|
EventToMe,
|
||||||
|
EventType,
|
||||||
|
EventParam,
|
||||||
|
StateParam,
|
||||||
|
DependParam,
|
||||||
|
DefaultParam,
|
||||||
|
EventMessage,
|
||||||
|
EventPlainText,
|
||||||
|
)
|
||||||
|
|
||||||
CMD_RESULT = TypedDict(
|
CMD_RESULT = TypedDict(
|
||||||
"CMD_RESULT",
|
"CMD_RESULT",
|
||||||
@ -68,11 +81,11 @@ class Rule:
|
|||||||
__slots__ = ("checkers",)
|
__slots__ = ("checkers",)
|
||||||
|
|
||||||
HANDLER_PARAM_TYPES = [
|
HANDLER_PARAM_TYPES = [
|
||||||
params.DependParam,
|
DependParam,
|
||||||
params.BotParam,
|
BotParam,
|
||||||
params.EventParam,
|
EventParam,
|
||||||
params.StateParam,
|
StateParam,
|
||||||
params.DefaultParam,
|
DefaultParam,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
|
def __init__(self, *checkers: Union[T_RuleChecker, Dependent[bool]]) -> None:
|
||||||
@ -189,15 +202,16 @@ class TrieRule:
|
|||||||
return prefix
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
class Startswith:
|
class StartswithRule:
|
||||||
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.ignorecase = ignorecase
|
self.ignorecase = ignorecase
|
||||||
|
|
||||||
async def __call__(self, event: Event) -> Any:
|
async def __call__(
|
||||||
if event.get_type() != "message":
|
self, type: str = EventType(), text: str = EventPlainText()
|
||||||
|
) -> Any:
|
||||||
|
if type != "message":
|
||||||
return False
|
return False
|
||||||
text = event.get_plaintext()
|
|
||||||
return bool(
|
return bool(
|
||||||
re.match(
|
re.match(
|
||||||
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
|
f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
|
||||||
@ -220,18 +234,19 @@ def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Ru
|
|||||||
if isinstance(msg, str):
|
if isinstance(msg, str):
|
||||||
msg = (msg,)
|
msg = (msg,)
|
||||||
|
|
||||||
return Rule(Startswith(msg, ignorecase))
|
return Rule(StartswithRule(msg, ignorecase))
|
||||||
|
|
||||||
|
|
||||||
class Endswith:
|
class EndswithRule:
|
||||||
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False):
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.ignorecase = ignorecase
|
self.ignorecase = ignorecase
|
||||||
|
|
||||||
async def __call__(self, event: Event) -> Any:
|
async def __call__(
|
||||||
if event.get_type() != "message":
|
self, type: str = EventType(), text: str = EventPlainText()
|
||||||
|
) -> Any:
|
||||||
|
if type != "message":
|
||||||
return False
|
return False
|
||||||
text = event.get_plaintext()
|
|
||||||
return bool(
|
return bool(
|
||||||
re.search(
|
re.search(
|
||||||
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
|
f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$",
|
||||||
@ -254,17 +269,18 @@ def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule
|
|||||||
if isinstance(msg, str):
|
if isinstance(msg, str):
|
||||||
msg = (msg,)
|
msg = (msg,)
|
||||||
|
|
||||||
return Rule(Endswith(msg, ignorecase))
|
return Rule(EndswithRule(msg, ignorecase))
|
||||||
|
|
||||||
|
|
||||||
class Keywords:
|
class KeywordsRule:
|
||||||
def __init__(self, *keywords: str):
|
def __init__(self, *keywords: str):
|
||||||
self.keywords = keywords
|
self.keywords = keywords
|
||||||
|
|
||||||
async def __call__(self, bot: Bot, event: Event, state: T_State) -> bool:
|
async def __call__(
|
||||||
if event.get_type() != "message":
|
self, type: str = EventType(), text: str = EventPlainText()
|
||||||
|
) -> bool:
|
||||||
|
if type != "message":
|
||||||
return False
|
return False
|
||||||
text = event.get_plaintext()
|
|
||||||
return bool(text and any(keyword in text for keyword in self.keywords))
|
return bool(text and any(keyword in text for keyword in self.keywords))
|
||||||
|
|
||||||
|
|
||||||
@ -279,15 +295,15 @@ def keyword(*keywords: str) -> Rule:
|
|||||||
* ``*keywords: str``: 关键词
|
* ``*keywords: str``: 关键词
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return Rule(Keywords(*keywords))
|
return Rule(KeywordsRule(*keywords))
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
class CommandRule:
|
||||||
def __init__(self, cmds: List[Tuple[str, ...]]):
|
def __init__(self, cmds: List[Tuple[str, ...]]):
|
||||||
self.cmds = cmds
|
self.cmds = cmds
|
||||||
|
|
||||||
async def __call__(self, state: T_State) -> bool:
|
async def __call__(self, cmd: Tuple[str, ...] = Command()) -> bool:
|
||||||
return state[PREFIX_KEY][CMD_KEY] in self.cmds
|
return cmd in self.cmds
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Command {self.cmds}>"
|
return f"<Command {self.cmds}>"
|
||||||
@ -334,7 +350,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
|
|||||||
for start, sep in product(command_start, command_sep):
|
for start, sep in product(command_start, command_sep):
|
||||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||||
|
|
||||||
return Rule(Command(commands))
|
return Rule(CommandRule(commands))
|
||||||
|
|
||||||
|
|
||||||
class ArgumentParser(ArgParser):
|
class ArgumentParser(ArgParser):
|
||||||
@ -365,14 +381,19 @@ class ArgumentParser(ArgParser):
|
|||||||
return super().parse_args(args=args, namespace=namespace) # type: ignore
|
return super().parse_args(args=args, namespace=namespace) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class ShellCommand:
|
class ShellCommandRule:
|
||||||
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
|
def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]):
|
||||||
self.cmds = cmds
|
self.cmds = cmds
|
||||||
self.parser = parser
|
self.parser = parser
|
||||||
|
|
||||||
async def __call__(self, event: Event, state: T_State) -> bool:
|
async def __call__(
|
||||||
if state[PREFIX_KEY][CMD_KEY] in self.cmds:
|
self,
|
||||||
message = str(event.get_message())
|
cmd: Tuple[str, ...] = Command(),
|
||||||
|
msg: Message = EventMessage(),
|
||||||
|
state: T_State = State(),
|
||||||
|
) -> bool:
|
||||||
|
if cmd in self.cmds:
|
||||||
|
message = str(msg)
|
||||||
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]) :].lstrip()
|
||||||
state[SHELL_ARGV] = shlex.split(strip_message)
|
state[SHELL_ARGV] = shlex.split(strip_message)
|
||||||
if self.parser:
|
if self.parser:
|
||||||
@ -442,18 +463,23 @@ def shell_command(
|
|||||||
for start, sep in product(command_start, command_sep):
|
for start, sep in product(command_start, command_sep):
|
||||||
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
|
||||||
|
|
||||||
return Rule(ShellCommand(commands, parser))
|
return Rule(ShellCommandRule(commands, parser))
|
||||||
|
|
||||||
|
|
||||||
class Regex:
|
class RegexRule:
|
||||||
def __init__(self, regex: str, flags: int = 0):
|
def __init__(self, regex: str, flags: int = 0):
|
||||||
self.regex = regex
|
self.regex = regex
|
||||||
self.flags = flags
|
self.flags = flags
|
||||||
|
|
||||||
async def __call__(self, event: Event, state: T_State) -> bool:
|
async def __call__(
|
||||||
if event.get_type() != "message":
|
self,
|
||||||
|
type: str = EventType(),
|
||||||
|
msg: Message = EventMessage(),
|
||||||
|
state: T_State = State(),
|
||||||
|
) -> bool:
|
||||||
|
if type != "message":
|
||||||
return False
|
return False
|
||||||
matched = re.search(self.regex, str(event.get_message()), self.flags)
|
matched = re.search(self.regex, str(msg), self.flags)
|
||||||
if matched:
|
if matched:
|
||||||
state[REGEX_MATCHED] = matched.group()
|
state[REGEX_MATCHED] = matched.group()
|
||||||
state[REGEX_GROUP] = matched.groups()
|
state[REGEX_GROUP] = matched.groups()
|
||||||
@ -482,12 +508,12 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
|
|||||||
\:\:\:
|
\:\:\:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return Rule(Regex(regex, flags))
|
return Rule(RegexRule(regex, flags))
|
||||||
|
|
||||||
|
|
||||||
class ToMe:
|
class ToMeRule:
|
||||||
async def __call__(self, event: Event) -> bool:
|
async def __call__(self, to_me: bool = EventToMe()) -> bool:
|
||||||
return event.is_tome()
|
return to_me
|
||||||
|
|
||||||
|
|
||||||
def to_me() -> Rule:
|
def to_me() -> Rule:
|
||||||
@ -501,4 +527,4 @@ def to_me() -> Rule:
|
|||||||
* 无
|
* 无
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return Rule(ToMe())
|
return Rule(ToMeRule())
|
||||||
|
Loading…
Reference in New Issue
Block a user