diff --git a/nonebot/params.py b/nonebot/params.py index 514290a1..c5b8c751 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -5,16 +5,16 @@ FrontMatter: description: nonebot.params 模块 """ -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Tuple, Union, Optional from nonebot.typing import T_State from nonebot.matcher import Matcher -from nonebot.adapters import Event, Message from nonebot.internal.params import Arg as Arg from nonebot.internal.params import ArgStr as ArgStr from nonebot.internal.params import Depends as Depends from nonebot.internal.params import ArgParam as ArgParam from nonebot.internal.params import BotParam as BotParam +from nonebot.adapters import Event, Message, MessageSegment from nonebot.internal.params import EventParam as EventParam from nonebot.internal.params import StateParam as StateParam from nonebot.internal.params import DependParam as DependParam @@ -109,15 +109,15 @@ def CommandStart() -> str: def _shell_command_args(state: T_State) -> Any: - return state[SHELL_ARGS] + return state[SHELL_ARGS] # Namespace or ParserExit -def ShellCommandArgs(): +def ShellCommandArgs() -> Any: """shell 命令解析后的参数字典""" return Depends(_shell_command_args, use_cache=False) -def _shell_command_argv(state: T_State) -> List[str]: +def _shell_command_argv(state: T_State) -> List[Union[str, MessageSegment]]: return state[SHELL_ARGV] diff --git a/nonebot/rule.py b/nonebot/rule.py index 2fedfa06..03b2df56 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -10,10 +10,26 @@ FrontMatter: import re import shlex -from itertools import product -from argparse import Namespace +from argparse import Action +from argparse import ArgumentError +from itertools import chain, product +from argparse import Namespace as Namespace from argparse import ArgumentParser as ArgParser -from typing import Any, List, Tuple, Union, Optional, Sequence, TypedDict, NamedTuple +from typing import ( + IO, + TYPE_CHECKING, + Any, + List, + Tuple, + Union, + TypeVar, + Optional, + Sequence, + TypedDict, + NamedTuple, + cast, + overload, +) from pygtrie import CharTrie @@ -44,6 +60,8 @@ from nonebot.consts import ( REGEX_MATCHED, ) +T = TypeVar("T") + CMD_RESULT = TypedDict( "CMD_RESULT", { @@ -318,25 +336,48 @@ class ArgumentParser(ArgParser): 参考文档: [argparse](https://docs.python.org/3/library/argparse.html) """ - def _print_message(self, message, file=None): - old_message: str = getattr(self, "message", "") - if old_message: - old_message += "\n" - old_message += message - setattr(self, "message", old_message) + if TYPE_CHECKING: - def exit(self, status: int = 0, message: Optional[str] = None): - raise ParserExit( - status=status, message=message or getattr(self, "message", None) + @overload + def parse_args( + self, args: Optional[Sequence[Union[str, MessageSegment]]] = ... + ) -> Namespace: + ... + + @overload + def parse_args( + self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: None + ) -> Namespace: + ... # type: ignore[misc] + + @overload + def parse_args( + self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: T + ) -> T: + ... + + def parse_args( + self, + args: Optional[Sequence[Union[str, MessageSegment]]] = None, + namespace: Optional[T] = None, + ) -> Union[Namespace, T]: + ... + + def _parse_optional( + self, arg_string: Union[str, MessageSegment] + ) -> Optional[Tuple[Optional[Action], str, Optional[str]]]: + return ( + super()._parse_optional(arg_string) if isinstance(arg_string, str) else None ) - def parse_args( - self, - args: Optional[Sequence[str]] = None, - namespace: Optional[Namespace] = None, - ) -> Namespace: - setattr(self, "message", "") - return super().parse_args(args=args, namespace=namespace) # type: ignore + def _print_message(self, message: str, file: Optional[IO[str]] = None): + if message: + setattr(self, "_message", getattr(self, "_message", "") + message) + + def exit(self, status: int = 0, message: Optional[str] = None): + if message: + self._print_message(message) + raise ParserExit(status=status, message=getattr(self, "_message", None)) class ShellCommandRule: @@ -359,19 +400,26 @@ class ShellCommandRule: cmd: Optional[Tuple[str, ...]] = Command(), msg: Optional[Message] = CommandArg(), ) -> bool: - if cmd in self.cmds and msg is not None: - message = str(msg) - state[SHELL_ARGV] = shlex.split(message) - if self.parser: - try: - args = self.parser.parse_args(state[SHELL_ARGV]) - state[SHELL_ARGS] = args - except ParserExit as e: - state[SHELL_ARGS] = e - return True - else: + if cmd not in self.cmds or msg is None: return False + state[SHELL_ARGV] = list( + chain.from_iterable( + shlex.split(str(seg)) if cast(MessageSegment, seg).is_text() else (seg,) + for seg in msg + ) + ) + + if self.parser: + try: + args = self.parser.parse_args(state[SHELL_ARGV]) + state[SHELL_ARGS] = args + except ArgumentError as e: + state[SHELL_ARGS] = ParserExit(status=2, message=str(e)) + except ParserExit as e: + state[SHELL_ARGS] = e + return True + def shell_command( *cmds: Union[str, Tuple[str, ...]], parser: Optional[ArgumentParser] = None diff --git a/tests/test_rule.py b/tests/test_rule.py index 683ef35c..4aac218f 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -1,3 +1,4 @@ +import sys from typing import Tuple, Union import pytest @@ -202,7 +203,104 @@ async def test_command(app: App, cmds: Tuple[Tuple[str, ...]]): assert await dependent(state=state) -# TODO: shell command +@pytest.mark.asyncio +async def test_shell_command(app: App): + from nonebot.typing import T_State + from nonebot.exception import ParserExit + from nonebot.consts import CMD_KEY, PREFIX_KEY, SHELL_ARGS, SHELL_ARGV, CMD_ARG_KEY + from nonebot.rule import Namespace, ArgumentParser, ShellCommandRule, shell_command + + state: T_State + CMD = ("test",) + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + test_not_cmd = shell_command(CMD) + dependent = list(test_not_cmd.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = Message() + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: ("not",), CMD_ARG_KEY: message}} + assert not await dependent(event=event, state=state) + + test_no_parser = shell_command(CMD) + dependent = list(test_no_parser.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = Message() + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == [] + assert SHELL_ARGS not in state + + parser = ArgumentParser("test") + parser.add_argument("-a", required=True) + + test_simple_parser = shell_command(CMD, parser=parser) + dependent = list(test_simple_parser.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = Message("-a 1") + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == ["-a", "1"] + assert state[SHELL_ARGS] == Namespace(a="1") + + test_parser_help = shell_command(CMD, parser=parser) + dependent = list(test_parser_help.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = Message("-h") + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == ["-h"] + assert isinstance(state[SHELL_ARGS], ParserExit) + assert state[SHELL_ARGS].status == 0 + assert state[SHELL_ARGS].message == parser.format_help() + + test_parser_error = shell_command(CMD, parser=parser) + dependent = list(test_parser_error.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = Message() + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == [] + assert isinstance(state[SHELL_ARGS], ParserExit) + assert state[SHELL_ARGS].status != 0 + + test_message_parser = shell_command(CMD, parser=parser) + dependent = list(test_message_parser.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = MessageSegment.text("-a") + MessageSegment.image("test") + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == ["-a", MessageSegment.image("test")] + assert state[SHELL_ARGS] == Namespace(a=MessageSegment.image("test")) + + if sys.version_info >= (3, 9): + parser = ArgumentParser("test", exit_on_error=False) + parser.add_argument("-a", required=True) + + test_not_exit = shell_command(CMD, parser=parser) + dependent = list(test_not_exit.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = Message() + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == [] + assert isinstance(state[SHELL_ARGS], ParserExit) + assert state[SHELL_ARGS].status != 0 + # TODO: regex diff --git a/website/docs/tutorial/plugin/create-handler.md b/website/docs/tutorial/plugin/create-handler.md index 758bd8f3..c6aef475 100644 --- a/website/docs/tutorial/plugin/create-handler.md +++ b/website/docs/tutorial/plugin/create-handler.md @@ -277,7 +277,7 @@ async def _(foo: str = CommandStart()): ... ### ShellCommandArgs -获取 shell 命令解析后的参数。 +获取 shell 命令解析后的参数,支持 MessageSegment 富文本(如:图片)。 :::tip 提示 如果参数解析失败,则为 [`ParserExit`](../../api/exception.md#ParserExit) 异常,并携带错误码与错误信息。 @@ -288,21 +288,28 @@ async def _(foo: str = CommandStart()): ... ```python {8,12} from nonebot import on_shell_command from nonebot.params import ShellCommandArgs +from nonebot.rule import Namespace, ArgumentParser +parser = ArgumentParser("demo") +# parser.add_argument ... matcher = on_shell_command("cmd", parser) # 解析失败 @matcher.handle() -async def _(foo: ParserExit = ShellCommandArgs()): ... +async def _(foo: ParserExit = ShellCommandArgs()): + if foo.status == 0: + foo.message # help message + else: + foo.message # error message # 解析成功 @matcher.handle() -async def _(foo: Dict[str, Any] = ShellCommandArgs()): ... +async def _(foo: Namespace = ShellCommandArgs()): ... ``` ### ShellCommandArgv -获取 shell 命令解析前的参数列表。 +获取 shell 命令解析前的参数列表,支持 MessageSegment 富文本(如:图片)。 ```python {7} from nonebot import on_shell_command @@ -311,7 +318,7 @@ from nonebot.params import ShellCommandArgs matcher = on_shell_command("cmd") @matcher.handle() -async def _(foo: List[str] = ShellCommandArgv()): ... +async def _(foo: List[Union[str, MessageSegment]] = ShellCommandArgv()): ... ``` ### RegexMatched