From 533e99418c9ecb315663c273171f0fed5c1e576d Mon Sep 17 00:00:00 2001 From: MeetWq Date: Wed, 20 Apr 2022 14:43:29 +0800 Subject: [PATCH] =?UTF-8?q?Feat:=20=E6=B7=BB=E5=8A=A0=20`CommandStart`=20?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E6=B3=A8=E5=85=A5=E5=8F=82=E6=95=B0=20(#915)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: yanyongyu <42488585+yanyongyu@users.noreply.github.com> --- nonebot/consts.py | 2 ++ nonebot/params.py | 10 ++++++ nonebot/rule.py | 34 +++++++++++++------ tests/plugins/param/param_state.py | 5 +++ tests/test_param.py | 15 +++++++- .../docs/tutorial/plugin/create-handler.md | 19 +++++++++++ 6 files changed, 74 insertions(+), 11 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index 8672bd48..3fe93486 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -28,6 +28,8 @@ RAW_CMD_KEY: Literal["raw_command"] = "raw_command" """命令文本存储 key""" CMD_ARG_KEY: Literal["command_arg"] = "command_arg" """命令参数存储 key""" +CMD_START_KEY: Literal["command_start"] = "command_start" +"""命令开头存储 key""" SHELL_ARGS: Literal["_args"] = "_args" """shell 命令 parse 后参数字典存储 key""" diff --git a/nonebot/params.py b/nonebot/params.py index 9dca0ab2..a2c8f265 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -32,6 +32,7 @@ from nonebot.consts import ( CMD_ARG_KEY, RAW_CMD_KEY, REGEX_GROUP, + CMD_START_KEY, REGEX_MATCHED, ) @@ -99,6 +100,15 @@ def CommandArg() -> Any: return Depends(_command_arg) +def _command_start(state: T_State) -> str: + return state[PREFIX_KEY][CMD_START_KEY] + + +def CommandStart() -> str: + """消息命令开头""" + return Depends(_command_start) + + def _shell_command_args(state: T_State) -> Any: return state[SHELL_ARGS] diff --git a/nonebot/rule.py b/nonebot/rule.py index 2e0b0c9d..f10fdff8 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -14,7 +14,7 @@ from itertools import product from argparse import Namespace from typing_extensions import TypedDict from argparse import ArgumentParser as ArgParser -from typing import Any, List, Tuple, Union, Optional, Sequence +from typing import Any, List, Tuple, Union, Optional, Sequence, NamedTuple from pygtrie import CharTrie @@ -41,6 +41,7 @@ from nonebot.consts import ( CMD_ARG_KEY, RAW_CMD_KEY, REGEX_GROUP, + CMD_START_KEY, REGEX_MATCHED, ) @@ -50,15 +51,20 @@ CMD_RESULT = TypedDict( "command": Optional[Tuple[str, ...]], "raw_command": Optional[str], "command_arg": Optional[Message[MessageSegment]], + "command_start": Optional[str], }, ) +TRIE_VALUE = NamedTuple( + "TRIE_VALUE", [("command_start", str), ("command", Tuple[str, ...])] +) + class TrieRule: prefix: CharTrie = CharTrie() @classmethod - def add_prefix(cls, prefix: str, value: Any): + def add_prefix(cls, prefix: str, value: TRIE_VALUE) -> None: if prefix in cls.prefix: logger.warning(f'Duplicated prefix rule "{prefix}"') return @@ -66,7 +72,9 @@ class TrieRule: @classmethod def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT: - prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None) + prefix = CMD_RESULT( + command=None, raw_command=None, command_arg=None, command_start=None + ) state[PREFIX_KEY] = prefix if event.get_type() != "message": return prefix @@ -76,9 +84,11 @@ class TrieRule: if message_seg.is_text(): segment_text = str(message_seg).lstrip() pf = cls.prefix.longest_prefix(segment_text) - prefix[RAW_CMD_KEY] = pf.key - prefix[CMD_KEY] = pf.value - if pf.key: + if pf: + value: TRIE_VALUE = pf.value + prefix[RAW_CMD_KEY] = pf.key + prefix[CMD_START_KEY] = value.command_start + prefix[CMD_KEY] = value.command msg = message.copy() msg.pop(0) new_message = msg.__class__(segment_text[len(pf.key) :].lstrip()) @@ -292,10 +302,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: if len(command) == 1: for start in command_start: - TrieRule.add_prefix(f"{start}{command[0]}", command) + TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command)) else: for start, sep in product(command_start, command_sep): - TrieRule.add_prefix(f"{start}{sep.join(command)}", command) + TrieRule.add_prefix( + f"{start}{sep.join(command)}", TRIE_VALUE(start, command) + ) return Rule(CommandRule(commands)) @@ -416,10 +428,12 @@ def shell_command( if len(command) == 1: for start in command_start: - TrieRule.add_prefix(f"{start}{command[0]}", command) + TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command)) else: for start, sep in product(command_start, command_sep): - TrieRule.add_prefix(f"{start}{sep.join(command)}", command) + TrieRule.add_prefix( + f"{start}{sep.join(command)}", TRIE_VALUE(start, command) + ) return Rule(ShellCommandRule(commands, parser)) diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index 636015fb..9d800c93 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -8,6 +8,7 @@ from nonebot.params import ( CommandArg, RawCommand, RegexGroup, + CommandStart, RegexMatched, ShellCommandArgs, ShellCommandArgv, @@ -30,6 +31,10 @@ async def command_arg(cmd_arg: Message = CommandArg()) -> Message: return cmd_arg +async def command_start(start: str = CommandStart()) -> str: + return start + + async def shell_command_args( shell_command_args: dict = ShellCommandArgs(), ) -> dict: diff --git a/tests/test_param.py b/tests/test_param.py index 52521a41..eaf07f61 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -122,6 +122,7 @@ async def test_state(app: App, load_plugin): CMD_ARG_KEY, RAW_CMD_KEY, REGEX_GROUP, + CMD_START_KEY, REGEX_MATCHED, ) from plugins.param.param_state import ( @@ -131,6 +132,7 @@ async def test_state(app: App, load_plugin): command_arg, raw_command, regex_group, + command_start, regex_matched, shell_command_args, shell_command_argv, @@ -138,7 +140,12 @@ async def test_state(app: App, load_plugin): fake_message = make_fake_message()("text") fake_state = { - PREFIX_KEY: {CMD_KEY: ("cmd",), RAW_CMD_KEY: "/cmd", CMD_ARG_KEY: fake_message}, + PREFIX_KEY: { + CMD_KEY: ("cmd",), + RAW_CMD_KEY: "/cmd", + CMD_START_KEY: "/", + CMD_ARG_KEY: fake_message, + }, SHELL_ARGV: ["-h"], SHELL_ARGS: {"help": True}, REGEX_MATCHED: "[cq:test,arg=value]", @@ -168,6 +175,12 @@ async def test_state(app: App, load_plugin): ctx.pass_params(state=fake_state) ctx.should_return(fake_state[PREFIX_KEY][CMD_ARG_KEY]) + async with app.test_dependent( + command_start, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[PREFIX_KEY][CMD_START_KEY]) + async with app.test_dependent( shell_command_argv, allow_types=[StateParam, DependParam] ) as ctx: diff --git a/website/docs/tutorial/plugin/create-handler.md b/website/docs/tutorial/plugin/create-handler.md index 097a42bb..30649ebf 100644 --- a/website/docs/tutorial/plugin/create-handler.md +++ b/website/docs/tutorial/plugin/create-handler.md @@ -256,6 +256,25 @@ async def _(foo: Message = CommandArg()): ... 命令详情只能在首次接收到命令型消息时获取,如果在事件处理后续流程中获取,则会获取到不同的值。 ::: +### CommandStart + +获取命令型消息命令前缀。 + +```python {8} +from nonebot import on_command +from nonebot.adapters import Message +from nonebot.params import CommandStart + +matcher = on_command("cmd") + +@matcher.handle() +async def _(foo: str = CommandStart()): ... +``` + +:::tip 提示 +命令详情只能在首次接收到命令型消息时获取,如果在事件处理后续流程中获取,则会获取到不同的值。 +::: + ### ShellCommandArgs 获取 shell 命令解析后的参数。