Feat: 添加 CommandStart 依赖注入参数 (#915)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: yanyongyu <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
MeetWq 2022-04-20 14:43:29 +08:00 committed by GitHub
parent f989710cd6
commit 533e99418c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 11 deletions

View File

@ -28,6 +28,8 @@ RAW_CMD_KEY: Literal["raw_command"] = "raw_command"
"""命令文本存储 key"""
CMD_ARG_KEY: Literal["command_arg"] = "command_arg"
"""命令参数存储 key"""
CMD_START_KEY: Literal["command_start"] = "command_start"
"""命令开头存储 key"""
SHELL_ARGS: Literal["_args"] = "_args"
"""shell 命令 parse 后参数字典存储 key"""

View File

@ -32,6 +32,7 @@ from nonebot.consts import (
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
CMD_START_KEY,
REGEX_MATCHED,
)
@ -99,6 +100,15 @@ def CommandArg() -> Any:
return Depends(_command_arg)
def _command_start(state: T_State) -> str:
return state[PREFIX_KEY][CMD_START_KEY]
def CommandStart() -> str:
"""消息命令开头"""
return Depends(_command_start)
def _shell_command_args(state: T_State) -> Any:
return state[SHELL_ARGS]

View File

@ -14,7 +14,7 @@ from itertools import product
from argparse import Namespace
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import Any, List, Tuple, Union, Optional, Sequence
from typing import Any, List, Tuple, Union, Optional, Sequence, NamedTuple
from pygtrie import CharTrie
@ -41,6 +41,7 @@ from nonebot.consts import (
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
CMD_START_KEY,
REGEX_MATCHED,
)
@ -50,15 +51,20 @@ CMD_RESULT = TypedDict(
"command": Optional[Tuple[str, ...]],
"raw_command": Optional[str],
"command_arg": Optional[Message[MessageSegment]],
"command_start": Optional[str],
},
)
TRIE_VALUE = NamedTuple(
"TRIE_VALUE", [("command_start", str), ("command", Tuple[str, ...])]
)
class TrieRule:
prefix: CharTrie = CharTrie()
@classmethod
def add_prefix(cls, prefix: str, value: Any):
def add_prefix(cls, prefix: str, value: TRIE_VALUE) -> None:
if prefix in cls.prefix:
logger.warning(f'Duplicated prefix rule "{prefix}"')
return
@ -66,7 +72,9 @@ class TrieRule:
@classmethod
def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
prefix = CMD_RESULT(command=None, raw_command=None, command_arg=None)
prefix = CMD_RESULT(
command=None, raw_command=None, command_arg=None, command_start=None
)
state[PREFIX_KEY] = prefix
if event.get_type() != "message":
return prefix
@ -76,9 +84,11 @@ class TrieRule:
if message_seg.is_text():
segment_text = str(message_seg).lstrip()
pf = cls.prefix.longest_prefix(segment_text)
if pf:
value: TRIE_VALUE = pf.value
prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_KEY] = pf.value
if pf.key:
prefix[CMD_START_KEY] = value.command_start
prefix[CMD_KEY] = value.command
msg = message.copy()
msg.pop(0)
new_message = msg.__class__(segment_text[len(pf.key) :].lstrip())
@ -292,10 +302,12 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
if len(command) == 1:
for start in command_start:
TrieRule.add_prefix(f"{start}{command[0]}", command)
TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command))
else:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
TrieRule.add_prefix(
f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
)
return Rule(CommandRule(commands))
@ -416,10 +428,12 @@ def shell_command(
if len(command) == 1:
for start in command_start:
TrieRule.add_prefix(f"{start}{command[0]}", command)
TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command))
else:
for start, sep in product(command_start, command_sep):
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
TrieRule.add_prefix(
f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
)
return Rule(ShellCommandRule(commands, parser))

View File

@ -8,6 +8,7 @@ from nonebot.params import (
CommandArg,
RawCommand,
RegexGroup,
CommandStart,
RegexMatched,
ShellCommandArgs,
ShellCommandArgv,
@ -30,6 +31,10 @@ async def command_arg(cmd_arg: Message = CommandArg()) -> Message:
return cmd_arg
async def command_start(start: str = CommandStart()) -> str:
return start
async def shell_command_args(
shell_command_args: dict = ShellCommandArgs(),
) -> dict:

View File

@ -122,6 +122,7 @@ async def test_state(app: App, load_plugin):
CMD_ARG_KEY,
RAW_CMD_KEY,
REGEX_GROUP,
CMD_START_KEY,
REGEX_MATCHED,
)
from plugins.param.param_state import (
@ -131,6 +132,7 @@ async def test_state(app: App, load_plugin):
command_arg,
raw_command,
regex_group,
command_start,
regex_matched,
shell_command_args,
shell_command_argv,
@ -138,7 +140,12 @@ async def test_state(app: App, load_plugin):
fake_message = make_fake_message()("text")
fake_state = {
PREFIX_KEY: {CMD_KEY: ("cmd",), RAW_CMD_KEY: "/cmd", CMD_ARG_KEY: fake_message},
PREFIX_KEY: {
CMD_KEY: ("cmd",),
RAW_CMD_KEY: "/cmd",
CMD_START_KEY: "/",
CMD_ARG_KEY: fake_message,
},
SHELL_ARGV: ["-h"],
SHELL_ARGS: {"help": True},
REGEX_MATCHED: "[cq:test,arg=value]",
@ -168,6 +175,12 @@ async def test_state(app: App, load_plugin):
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][CMD_ARG_KEY])
async with app.test_dependent(
command_start, allow_types=[StateParam, DependParam]
) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state[PREFIX_KEY][CMD_START_KEY])
async with app.test_dependent(
shell_command_argv, allow_types=[StateParam, DependParam]
) as ctx:

View File

@ -256,6 +256,25 @@ async def _(foo: Message = CommandArg()): ...
命令详情只能在首次接收到命令型消息时获取,如果在事件处理后续流程中获取,则会获取到不同的值。
:::
### CommandStart
获取命令型消息命令前缀。
```python {8}
from nonebot import on_command
from nonebot.adapters import Message
from nonebot.params import CommandStart
matcher = on_command("cmd")
@matcher.handle()
async def _(foo: str = CommandStart()): ...
```
:::tip 提示
命令详情只能在首次接收到命令型消息时获取,如果在事件处理后续流程中获取,则会获取到不同的值。
:::
### ShellCommandArgs
获取 shell 命令解析后的参数。