Feature: 命令匹配支持强制指定空白符 (#1748)

This commit is contained in:
Ju4tCode 2023-02-27 00:11:24 +08:00 committed by GitHub
parent f8c67ebdf6
commit 433c672130
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 19 deletions

View File

@ -30,6 +30,8 @@ CMD_ARG_KEY: Literal["command_arg"] = "command_arg"
"""命令参数存储 key""" """命令参数存储 key"""
CMD_START_KEY: Literal["command_start"] = "command_start" CMD_START_KEY: Literal["command_start"] = "command_start"
"""命令开头存储 key""" """命令开头存储 key"""
CMD_WHITESPACE_KEY: Literal["command_whitespace"] = "command_whitespace"
"""命令与参数间空白符存储 key"""
SHELL_ARGS: Literal["_args"] = "_args" SHELL_ARGS: Literal["_args"] = "_args"
"""shell 命令 parse 后参数字典存储 key""" """shell 命令 parse 后参数字典存储 key"""

View File

@ -39,6 +39,7 @@ from nonebot.consts import (
FULLMATCH_KEY, FULLMATCH_KEY,
REGEX_MATCHED, REGEX_MATCHED,
STARTSWITH_KEY, STARTSWITH_KEY,
CMD_WHITESPACE_KEY,
) )
@ -114,6 +115,15 @@ def CommandStart() -> str:
return Depends(_command_start) return Depends(_command_start)
def _command_whitespace(state: T_State) -> str:
return state[PREFIX_KEY][CMD_WHITESPACE_KEY]
def CommandWhitespace() -> str:
"""消息命令与参数之间的空白"""
return Depends(_command_whitespace)
def _shell_command_args(state: T_State) -> Any: def _shell_command_args(state: T_State) -> Any:
return state[SHELL_ARGS] # Namespace or ParserExit return state[SHELL_ARGS] # Namespace or ParserExit

View File

@ -349,6 +349,7 @@ def on_command(
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = None, rule: Optional[Union[Rule, T_RuleChecker]] = None,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
force_whitespace: Optional[Union[str, bool]] = None,
_depth: int = 0, _depth: int = 0,
**kwargs, **kwargs,
) -> Type[Matcher]: ) -> Type[Matcher]:
@ -360,6 +361,7 @@ def on_command(
cmd: 指定命令内容 cmd: 指定命令内容
rule: 事件响应规则 rule: 事件响应规则
aliases: 命令别名 aliases: 命令别名
force_whitespace: 是否强制命令后必须有指定空白符
permission: 事件响应权限 permission: 事件响应权限
handlers: 事件处理函数列表 handlers: 事件处理函数列表
temp: 是否为临时事件响应器仅执行一次 temp: 是否为临时事件响应器仅执行一次
@ -372,7 +374,10 @@ def on_command(
commands = {cmd} | (aliases or set()) commands = {cmd} | (aliases or set())
block = kwargs.pop("block", False) block = kwargs.pop("block", False)
return on_message( return on_message(
command(*commands) & rule, block=block, **kwargs, _depth=_depth + 1 command(*commands, force_whitespace=force_whitespace) & rule,
block=block,
**kwargs,
_depth=_depth + 1,
) )
@ -518,6 +523,7 @@ class CommandGroup(_Group):
参数: 参数:
cmd: 指定命令内容 cmd: 指定命令内容
aliases: 命令别名 aliases: 命令别名
force_whitespace: 是否强制命令后必须有指定空白符
rule: 事件响应规则 rule: 事件响应规则
permission: 事件响应权限 permission: 事件响应权限
handlers: 事件处理函数列表 handlers: 事件处理函数列表
@ -736,6 +742,7 @@ class MatcherGroup(_Group):
self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None,
force_whitespace: Optional[Union[str, bool]] = None,
**kwargs, **kwargs,
) -> Type[Matcher]: ) -> Type[Matcher]:
"""注册一个消息事件响应器,并且当消息以指定命令开头时响应。 """注册一个消息事件响应器,并且当消息以指定命令开头时响应。
@ -745,6 +752,7 @@ class MatcherGroup(_Group):
参数: 参数:
cmd: 指定命令内容 cmd: 指定命令内容
aliases: 命令别名 aliases: 命令别名
force_whitespace: 是否强制命令后必须有指定空白符
rule: 事件响应规则 rule: 事件响应规则
permission: 事件响应权限 permission: 事件响应权限
handlers: 事件处理函数列表 handlers: 事件处理函数列表
@ -755,7 +763,9 @@ class MatcherGroup(_Group):
state: 默认 state state: 默认 state
""" """
final_kwargs = self._get_final_kwargs(kwargs, exclude={"type"}) final_kwargs = self._get_final_kwargs(kwargs, exclude={"type"})
matcher = on_command(cmd, aliases=aliases, **final_kwargs) matcher = on_command(
cmd, aliases=aliases, force_whitespace=force_whitespace, **final_kwargs
)
self.matchers.append(matcher) self.matchers.append(matcher)
return matcher return matcher

View File

@ -117,6 +117,7 @@ def on_command(
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
force_whitespace: Optional[Union[str, bool]] = ...,
*, *,
permission: Optional[Union[Permission, T_PermissionChecker]] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., handlers: Optional[List[Union[T_Handler, Dependent]]] = ...,
@ -186,6 +187,7 @@ class CommandGroup:
*, *,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
force_whitespace: Optional[Union[str, bool]] = ...,
permission: Optional[Union[Permission, T_PermissionChecker]] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ...,
handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., handlers: Optional[List[Union[T_Handler, Dependent]]] = ...,
temp: bool = ..., temp: bool = ...,
@ -341,6 +343,7 @@ class MatcherGroup:
self, self,
cmd: Union[str, Tuple[str, ...]], cmd: Union[str, Tuple[str, ...]],
aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ...,
force_whitespace: Optional[Union[str, bool]] = ...,
*, *,
rule: Optional[Union[Rule, T_RuleChecker]] = ..., rule: Optional[Union[Rule, T_RuleChecker]] = ...,
permission: Optional[Union[Permission, T_PermissionChecker]] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ...,

View File

@ -39,8 +39,8 @@ from nonebot.log import logger
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.exception import ParserExit from nonebot.exception import ParserExit
from nonebot.internal.rule import Rule as Rule from nonebot.internal.rule import Rule as Rule
from nonebot.params import Command, EventToMe, CommandArg
from nonebot.adapters import Bot, Event, Message, MessageSegment from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.params import Command, EventToMe, CommandArg, CommandWhitespace
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
REGEX_STR, REGEX_STR,
@ -57,6 +57,7 @@ from nonebot.consts import (
FULLMATCH_KEY, FULLMATCH_KEY,
REGEX_MATCHED, REGEX_MATCHED,
STARTSWITH_KEY, STARTSWITH_KEY,
CMD_WHITESPACE_KEY,
) )
T = TypeVar("T") T = TypeVar("T")
@ -68,6 +69,7 @@ CMD_RESULT = TypedDict(
"raw_command": Optional[str], "raw_command": Optional[str],
"command_arg": Optional[Message[MessageSegment]], "command_arg": Optional[Message[MessageSegment]],
"command_start": Optional[str], "command_start": Optional[str],
"command_whitespace": Optional[str],
}, },
) )
@ -91,7 +93,11 @@ class TrieRule:
@classmethod @classmethod
def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT: def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
prefix = CMD_RESULT( prefix = CMD_RESULT(
command=None, raw_command=None, command_arg=None, command_start=None command=None,
raw_command=None,
command_arg=None,
command_start=None,
command_whitespace=None,
) )
state[PREFIX_KEY] = prefix state[PREFIX_KEY] = prefix
if event.get_type() != "message": if event.get_type() != "message":
@ -106,11 +112,25 @@ class TrieRule:
prefix[RAW_CMD_KEY] = pf.key prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_START_KEY] = value.command_start prefix[CMD_START_KEY] = value.command_start
prefix[CMD_KEY] = value.command prefix[CMD_KEY] = value.command
msg = message.copy() msg = message.copy()
msg.pop(0) msg.pop(0)
new_message = msg.__class__(segment_text[len(pf.key) :].lstrip())
for new_segment in reversed(new_message): # check whitespace
msg.insert(0, new_segment) arg_str = segment_text[len(pf.key) :]
arg_str_stripped = arg_str.lstrip()
has_arg = arg_str_stripped or msg
if (
has_arg
and (stripped_len := len(arg_str) - len(arg_str_stripped)) > 0
):
prefix[CMD_WHITESPACE_KEY] = arg_str[:stripped_len]
# construct command arg
if arg_str_stripped:
new_message = msg.__class__(arg_str_stripped)
for new_segment in reversed(new_message):
msg.insert(0, new_segment)
prefix[CMD_ARG_KEY] = msg prefix[CMD_ARG_KEY] = msg
return prefix return prefix
@ -339,12 +359,18 @@ class CommandRule:
参数: 参数:
cmds: 指定命令元组列表 cmds: 指定命令元组列表
force_whitespace: 是否强制命令后必须有指定空白符
""" """
__slots__ = ("cmds",) __slots__ = ("cmds", "force_whitespace")
def __init__(self, cmds: List[Tuple[str, ...]]): def __init__(
self,
cmds: List[Tuple[str, ...]],
force_whitespace: Optional[Union[str, bool]] = None,
):
self.cmds = tuple(cmds) self.cmds = tuple(cmds)
self.force_whitespace = force_whitespace
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Command(cmds={self.cmds})" return f"Command(cmds={self.cmds})"
@ -357,11 +383,24 @@ class CommandRule:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((frozenset(self.cmds),)) return hash((frozenset(self.cmds),))
async def __call__(self, cmd: Optional[Tuple[str, ...]] = Command()) -> bool: async def __call__(
return cmd in self.cmds self,
cmd: Optional[Tuple[str, ...]] = Command(),
cmd_whitespace: Optional[str] = CommandWhitespace(),
) -> bool:
if cmd not in self.cmds:
return False
if self.force_whitespace is None:
return True
if isinstance(self.force_whitespace, str):
return self.force_whitespace == cmd_whitespace
return self.force_whitespace == (cmd_whitespace is not None)
def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: def command(
*cmds: Union[str, Tuple[str, ...]],
force_whitespace: Optional[Union[str, bool]] = None,
) -> Rule:
"""匹配消息命令。 """匹配消息命令。
根据配置里提供的 {ref}``command_start` <nonebot.config.Config.command_start>`, 根据配置里提供的 {ref}``command_start` <nonebot.config.Config.command_start>`,
@ -373,6 +412,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
参数: 参数:
cmds: 命令文本或命令元组 cmds: 命令文本或命令元组
force_whitespace: 是否强制命令后必须有指定空白符
用法: 用法:
使用默认 `command_start`, `command_sep` 配置 使用默认 `command_start`, `command_sep` 配置
@ -404,7 +444,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
f"{start}{sep.join(command)}", TRIE_VALUE(start, command) f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
) )
return Rule(CommandRule(commands)) return Rule(CommandRule(commands, force_whitespace))
class ArgumentParser(ArgParser): class ArgumentParser(ArgParser):

View File

@ -17,6 +17,7 @@ from nonebot.params import (
RegexMatched, RegexMatched,
ShellCommandArgs, ShellCommandArgs,
ShellCommandArgv, ShellCommandArgv,
CommandWhitespace,
) )
@ -48,6 +49,10 @@ async def command_start(start: str = CommandStart()) -> str:
return start return start
async def command_whitespace(whitespace: str = CommandWhitespace()) -> str:
return whitespace
async def shell_command_args( async def shell_command_args(
shell_command_args: dict = ShellCommandArgs(), shell_command_args: dict = ShellCommandArgs(),
) -> dict: ) -> dict:

View File

@ -30,6 +30,7 @@ from nonebot.consts import (
FULLMATCH_KEY, FULLMATCH_KEY,
REGEX_MATCHED, REGEX_MATCHED,
STARTSWITH_KEY, STARTSWITH_KEY,
CMD_WHITESPACE_KEY,
) )
@ -202,6 +203,7 @@ async def test_state(app: App):
command_start, command_start,
regex_matched, regex_matched,
not_legacy_state, not_legacy_state,
command_whitespace,
shell_command_args, shell_command_args,
shell_command_argv, shell_command_argv,
) )
@ -213,6 +215,7 @@ async def test_state(app: App):
RAW_CMD_KEY: "/cmd", RAW_CMD_KEY: "/cmd",
CMD_START_KEY: "/", CMD_START_KEY: "/",
CMD_ARG_KEY: fake_message, CMD_ARG_KEY: fake_message,
CMD_WHITESPACE_KEY: " ",
}, },
SHELL_ARGV: ["-h"], SHELL_ARGV: ["-h"],
SHELL_ARGS: {"help": True}, SHELL_ARGS: {"help": True},
@ -264,6 +267,12 @@ async def test_state(app: App):
ctx.pass_params(state=fake_state) ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][CMD_START_KEY]) ctx.should_return(fake_state[PREFIX_KEY][CMD_START_KEY])
async with app.test_dependent(
command_whitespace, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][CMD_WHITESPACE_KEY])
async with app.test_dependent( async with app.test_dependent(
shell_command_argv, allow_types=[StateParam, DependParam] shell_command_argv, allow_types=[StateParam, DependParam]
) as ctx: ) as ctx:

View File

@ -21,10 +21,14 @@ from nonebot.consts import (
FULLMATCH_KEY, FULLMATCH_KEY,
REGEX_MATCHED, REGEX_MATCHED,
STARTSWITH_KEY, STARTSWITH_KEY,
CMD_WHITESPACE_KEY,
) )
from nonebot.rule import ( from nonebot.rule import (
CMD_RESULT,
TRIE_VALUE,
Rule, Rule,
ToMeRule, ToMeRule,
TrieRule,
Namespace, Namespace,
RegexRule, RegexRule,
IsTypeRule, IsTypeRule,
@ -79,6 +83,44 @@ async def test_rule(app: App):
assert await Rule(truthy, skipped)(bot, event, {}) == False assert await Rule(truthy, skipped)(bot, event, {}) == False
@pytest.mark.asyncio
async def test_trie(app: App):
TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",)))
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
async with app.test_api() as ctx:
bot = ctx.create_bot()
message = Message("/fake-prefix some args")
event = make_fake_event(_message=message)()
state = {}
TrieRule.get_value(bot, event, state)
assert state[PREFIX_KEY] == CMD_RESULT(
command=("fake-prefix",),
raw_command="/fake-prefix",
command_arg=Message("some args"),
command_start="/",
command_whitespace=" ",
)
message = MessageSegment.text("/fake-prefix ") + MessageSegment.image(
"fake url"
)
event = make_fake_event(_message=message)()
state = {}
TrieRule.get_value(bot, event, state)
assert state[PREFIX_KEY] == CMD_RESULT(
command=("fake-prefix",),
raw_command="/fake-prefix",
command_arg=Message(MessageSegment.image("fake url")),
command_start="/",
command_whitespace=" ",
)
del TrieRule.prefix["/fake-prefix"]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"msg, ignorecase, type, text, expected", "msg, ignorecase, type, text, expected",
@ -229,19 +271,33 @@ async def test_keyword(
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"cmds", [(("help",),), (("help", "foo"),), (("help",), ("foo",))] "cmds, cmd, force_whitespace, whitespace, expected",
[
[(("help",),), ("help",), None, None, True],
[(("help",),), ("foo",), None, None, False],
[(("help", "foo"),), ("help", "foo"), True, " ", True],
[(("help",), ("foo",)), ("help",), " ", " ", True],
[(("help",),), ("help",), False, " ", False],
[(("help",),), ("help",), True, None, False],
[(("help",),), ("help",), "\n", " ", False],
],
) )
async def test_command(cmds: Tuple[Tuple[str, ...]]): async def test_command(
test_command = command(*cmds) cmds: Tuple[Tuple[str, ...]],
cmd: Tuple[str, ...],
force_whitespace: Optional[Union[str, bool]],
whitespace: Optional[str],
expected: bool,
):
test_command = command(*cmds, force_whitespace=force_whitespace)
dependent = list(test_command.checkers)[0] dependent = list(test_command.checkers)[0]
checker = dependent.call checker = dependent.call
assert isinstance(checker, CommandRule) assert isinstance(checker, CommandRule)
assert checker.cmds == cmds assert checker.cmds == cmds
for cmd in cmds: state = {PREFIX_KEY: {CMD_KEY: cmd, CMD_WHITESPACE_KEY: whitespace}}
state = {PREFIX_KEY: {CMD_KEY: cmd}} assert await dependent(state=state) == expected
assert await dependent(state=state)
@pytest.mark.asyncio @pytest.mark.asyncio