From 433c672130fae2c298f11e7a0d1dc34f27f2201e Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Mon, 27 Feb 2023 00:11:24 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E5=91=BD=E4=BB=A4?= =?UTF-8?q?=E5=8C=B9=E9=85=8D=E6=94=AF=E6=8C=81=E5=BC=BA=E5=88=B6=E6=8C=87?= =?UTF-8?q?=E5=AE=9A=E7=A9=BA=E7=99=BD=E7=AC=A6=20(#1748)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/consts.py | 2 + nonebot/params.py | 10 +++++ nonebot/plugin/on.py | 14 +++++- nonebot/plugin/on.pyi | 3 ++ nonebot/rule.py | 62 ++++++++++++++++++++++----- tests/plugins/param/param_state.py | 5 +++ tests/test_param.py | 9 ++++ tests/test_rule.py | 68 +++++++++++++++++++++++++++--- 8 files changed, 154 insertions(+), 19 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index e0e4bb6b..c19c5feb 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -30,6 +30,8 @@ CMD_ARG_KEY: Literal["command_arg"] = "command_arg" """命令参数存储 key""" CMD_START_KEY: Literal["command_start"] = "command_start" """命令开头存储 key""" +CMD_WHITESPACE_KEY: Literal["command_whitespace"] = "command_whitespace" +"""命令与参数间空白符存储 key""" SHELL_ARGS: Literal["_args"] = "_args" """shell 命令 parse 后参数字典存储 key""" diff --git a/nonebot/params.py b/nonebot/params.py index 47e068ae..1a03f82e 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -39,6 +39,7 @@ from nonebot.consts import ( FULLMATCH_KEY, REGEX_MATCHED, STARTSWITH_KEY, + CMD_WHITESPACE_KEY, ) @@ -114,6 +115,15 @@ def CommandStart() -> str: return Depends(_command_start) +def _command_whitespace(state: T_State) -> str: + return state[PREFIX_KEY][CMD_WHITESPACE_KEY] + + +def CommandWhitespace() -> str: + """消息命令与参数之间的空白""" + return Depends(_command_whitespace) + + def _shell_command_args(state: T_State) -> Any: return state[SHELL_ARGS] # Namespace or ParserExit diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index d81dadea..5a7920d6 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -349,6 +349,7 @@ def on_command( cmd: Union[str, Tuple[str, ...]], rule: Optional[Union[Rule, T_RuleChecker]] = None, aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, + force_whitespace: Optional[Union[str, bool]] = None, _depth: int = 0, **kwargs, ) -> Type[Matcher]: @@ -360,6 +361,7 @@ def on_command( cmd: 指定命令内容 rule: 事件响应规则 aliases: 命令别名 + force_whitespace: 是否强制命令后必须有指定空白符 permission: 事件响应权限 handlers: 事件处理函数列表 temp: 是否为临时事件响应器(仅执行一次) @@ -372,7 +374,10 @@ def on_command( commands = {cmd} | (aliases or set()) block = kwargs.pop("block", False) return on_message( - command(*commands) & rule, block=block, **kwargs, _depth=_depth + 1 + command(*commands, force_whitespace=force_whitespace) & rule, + block=block, + **kwargs, + _depth=_depth + 1, ) @@ -518,6 +523,7 @@ class CommandGroup(_Group): 参数: cmd: 指定命令内容 aliases: 命令别名 + force_whitespace: 是否强制命令后必须有指定空白符 rule: 事件响应规则 permission: 事件响应权限 handlers: 事件处理函数列表 @@ -736,6 +742,7 @@ class MatcherGroup(_Group): self, cmd: Union[str, Tuple[str, ...]], aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = None, + force_whitespace: Optional[Union[str, bool]] = None, **kwargs, ) -> Type[Matcher]: """注册一个消息事件响应器,并且当消息以指定命令开头时响应。 @@ -745,6 +752,7 @@ class MatcherGroup(_Group): 参数: cmd: 指定命令内容 aliases: 命令别名 + force_whitespace: 是否强制命令后必须有指定空白符 rule: 事件响应规则 permission: 事件响应权限 handlers: 事件处理函数列表 @@ -755,7 +763,9 @@ class MatcherGroup(_Group): state: 默认 state """ final_kwargs = self._get_final_kwargs(kwargs, exclude={"type"}) - matcher = on_command(cmd, aliases=aliases, **final_kwargs) + matcher = on_command( + cmd, aliases=aliases, force_whitespace=force_whitespace, **final_kwargs + ) self.matchers.append(matcher) return matcher diff --git a/nonebot/plugin/on.pyi b/nonebot/plugin/on.pyi index ae399bad..fed781fe 100644 --- a/nonebot/plugin/on.pyi +++ b/nonebot/plugin/on.pyi @@ -117,6 +117,7 @@ def on_command( cmd: Union[str, Tuple[str, ...]], rule: Optional[Union[Rule, T_RuleChecker]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., + force_whitespace: Optional[Union[str, bool]] = ..., *, permission: Optional[Union[Permission, T_PermissionChecker]] = ..., handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., @@ -186,6 +187,7 @@ class CommandGroup: *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., + force_whitespace: Optional[Union[str, bool]] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ..., handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., @@ -341,6 +343,7 @@ class MatcherGroup: self, cmd: Union[str, Tuple[str, ...]], aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., + force_whitespace: Optional[Union[str, bool]] = ..., *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ..., diff --git a/nonebot/rule.py b/nonebot/rule.py index f49bb3c8..98ea0418 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -39,8 +39,8 @@ from nonebot.log import logger from nonebot.typing import T_State from nonebot.exception import ParserExit from nonebot.internal.rule import Rule as Rule -from nonebot.params import Command, EventToMe, CommandArg from nonebot.adapters import Bot, Event, Message, MessageSegment +from nonebot.params import Command, EventToMe, CommandArg, CommandWhitespace from nonebot.consts import ( CMD_KEY, REGEX_STR, @@ -57,6 +57,7 @@ from nonebot.consts import ( FULLMATCH_KEY, REGEX_MATCHED, STARTSWITH_KEY, + CMD_WHITESPACE_KEY, ) T = TypeVar("T") @@ -68,6 +69,7 @@ CMD_RESULT = TypedDict( "raw_command": Optional[str], "command_arg": Optional[Message[MessageSegment]], "command_start": Optional[str], + "command_whitespace": Optional[str], }, ) @@ -91,7 +93,11 @@ class TrieRule: @classmethod def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT: prefix = CMD_RESULT( - command=None, raw_command=None, command_arg=None, command_start=None + command=None, + raw_command=None, + command_arg=None, + command_start=None, + command_whitespace=None, ) state[PREFIX_KEY] = prefix if event.get_type() != "message": @@ -106,11 +112,25 @@ class TrieRule: prefix[RAW_CMD_KEY] = pf.key prefix[CMD_START_KEY] = value.command_start prefix[CMD_KEY] = value.command + msg = message.copy() msg.pop(0) - new_message = msg.__class__(segment_text[len(pf.key) :].lstrip()) - for new_segment in reversed(new_message): - msg.insert(0, new_segment) + + # check whitespace + arg_str = segment_text[len(pf.key) :] + arg_str_stripped = arg_str.lstrip() + has_arg = arg_str_stripped or msg + if ( + has_arg + and (stripped_len := len(arg_str) - len(arg_str_stripped)) > 0 + ): + prefix[CMD_WHITESPACE_KEY] = arg_str[:stripped_len] + + # construct command arg + if arg_str_stripped: + new_message = msg.__class__(arg_str_stripped) + for new_segment in reversed(new_message): + msg.insert(0, new_segment) prefix[CMD_ARG_KEY] = msg return prefix @@ -339,12 +359,18 @@ class CommandRule: 参数: cmds: 指定命令元组列表 + force_whitespace: 是否强制命令后必须有指定空白符 """ - __slots__ = ("cmds",) + __slots__ = ("cmds", "force_whitespace") - def __init__(self, cmds: List[Tuple[str, ...]]): + def __init__( + self, + cmds: List[Tuple[str, ...]], + force_whitespace: Optional[Union[str, bool]] = None, + ): self.cmds = tuple(cmds) + self.force_whitespace = force_whitespace def __repr__(self) -> str: return f"Command(cmds={self.cmds})" @@ -357,11 +383,24 @@ class CommandRule: def __hash__(self) -> int: return hash((frozenset(self.cmds),)) - async def __call__(self, cmd: Optional[Tuple[str, ...]] = Command()) -> bool: - return cmd in self.cmds + async def __call__( + self, + cmd: Optional[Tuple[str, ...]] = Command(), + cmd_whitespace: Optional[str] = CommandWhitespace(), + ) -> bool: + if cmd not in self.cmds: + return False + if self.force_whitespace is None: + return True + if isinstance(self.force_whitespace, str): + return self.force_whitespace == cmd_whitespace + return self.force_whitespace == (cmd_whitespace is not None) -def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: +def command( + *cmds: Union[str, Tuple[str, ...]], + force_whitespace: Optional[Union[str, bool]] = None, +) -> Rule: """匹配消息命令。 根据配置里提供的 {ref}``command_start` `, @@ -373,6 +412,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: 参数: cmds: 命令文本或命令元组 + force_whitespace: 是否强制命令后必须有指定空白符 用法: 使用默认 `command_start`, `command_sep` 配置 @@ -404,7 +444,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: f"{start}{sep.join(command)}", TRIE_VALUE(start, command) ) - return Rule(CommandRule(commands)) + return Rule(CommandRule(commands, force_whitespace)) class ArgumentParser(ArgParser): diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index 1baed2b3..77375064 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -17,6 +17,7 @@ from nonebot.params import ( RegexMatched, ShellCommandArgs, ShellCommandArgv, + CommandWhitespace, ) @@ -48,6 +49,10 @@ async def command_start(start: str = CommandStart()) -> str: return start +async def command_whitespace(whitespace: str = CommandWhitespace()) -> str: + return whitespace + + async def shell_command_args( shell_command_args: dict = ShellCommandArgs(), ) -> dict: diff --git a/tests/test_param.py b/tests/test_param.py index a8ba0560..b4171312 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -30,6 +30,7 @@ from nonebot.consts import ( FULLMATCH_KEY, REGEX_MATCHED, STARTSWITH_KEY, + CMD_WHITESPACE_KEY, ) @@ -202,6 +203,7 @@ async def test_state(app: App): command_start, regex_matched, not_legacy_state, + command_whitespace, shell_command_args, shell_command_argv, ) @@ -213,6 +215,7 @@ async def test_state(app: App): RAW_CMD_KEY: "/cmd", CMD_START_KEY: "/", CMD_ARG_KEY: fake_message, + CMD_WHITESPACE_KEY: " ", }, SHELL_ARGV: ["-h"], SHELL_ARGS: {"help": True}, @@ -264,6 +267,12 @@ async def test_state(app: App): ctx.pass_params(state=fake_state) ctx.should_return(fake_state[PREFIX_KEY][CMD_START_KEY]) + async with app.test_dependent( + command_whitespace, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[PREFIX_KEY][CMD_WHITESPACE_KEY]) + async with app.test_dependent( shell_command_argv, allow_types=[StateParam, DependParam] ) as ctx: diff --git a/tests/test_rule.py b/tests/test_rule.py index fe11fd7d..727e6ef6 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -21,10 +21,14 @@ from nonebot.consts import ( FULLMATCH_KEY, REGEX_MATCHED, STARTSWITH_KEY, + CMD_WHITESPACE_KEY, ) from nonebot.rule import ( + CMD_RESULT, + TRIE_VALUE, Rule, ToMeRule, + TrieRule, Namespace, RegexRule, IsTypeRule, @@ -79,6 +83,44 @@ async def test_rule(app: App): assert await Rule(truthy, skipped)(bot, event, {}) == False +@pytest.mark.asyncio +async def test_trie(app: App): + TrieRule.add_prefix("/fake-prefix", TRIE_VALUE("/", ("fake-prefix",))) + + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + async with app.test_api() as ctx: + bot = ctx.create_bot() + message = Message("/fake-prefix some args") + event = make_fake_event(_message=message)() + state = {} + TrieRule.get_value(bot, event, state) + assert state[PREFIX_KEY] == CMD_RESULT( + command=("fake-prefix",), + raw_command="/fake-prefix", + command_arg=Message("some args"), + command_start="/", + command_whitespace=" ", + ) + + message = MessageSegment.text("/fake-prefix ") + MessageSegment.image( + "fake url" + ) + event = make_fake_event(_message=message)() + state = {} + TrieRule.get_value(bot, event, state) + assert state[PREFIX_KEY] == CMD_RESULT( + command=("fake-prefix",), + raw_command="/fake-prefix", + command_arg=Message(MessageSegment.image("fake url")), + command_start="/", + command_whitespace=" ", + ) + + del TrieRule.prefix["/fake-prefix"] + + @pytest.mark.asyncio @pytest.mark.parametrize( "msg, ignorecase, type, text, expected", @@ -229,19 +271,33 @@ async def test_keyword( @pytest.mark.asyncio @pytest.mark.parametrize( - "cmds", [(("help",),), (("help", "foo"),), (("help",), ("foo",))] + "cmds, cmd, force_whitespace, whitespace, expected", + [ + [(("help",),), ("help",), None, None, True], + [(("help",),), ("foo",), None, None, False], + [(("help", "foo"),), ("help", "foo"), True, " ", True], + [(("help",), ("foo",)), ("help",), " ", " ", True], + [(("help",),), ("help",), False, " ", False], + [(("help",),), ("help",), True, None, False], + [(("help",),), ("help",), "\n", " ", False], + ], ) -async def test_command(cmds: Tuple[Tuple[str, ...]]): - test_command = command(*cmds) +async def test_command( + cmds: Tuple[Tuple[str, ...]], + cmd: Tuple[str, ...], + force_whitespace: Optional[Union[str, bool]], + whitespace: Optional[str], + expected: bool, +): + test_command = command(*cmds, force_whitespace=force_whitespace) dependent = list(test_command.checkers)[0] checker = dependent.call assert isinstance(checker, CommandRule) assert checker.cmds == cmds - for cmd in cmds: - state = {PREFIX_KEY: {CMD_KEY: cmd}} - assert await dependent(state=state) + state = {PREFIX_KEY: {CMD_KEY: cmd, CMD_WHITESPACE_KEY: whitespace}} + assert await dependent(state=state) == expected @pytest.mark.asyncio