diff --git a/nonebot/rule.py b/nonebot/rule.py index 7c9cebde..db292a45 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -11,6 +11,7 @@ FrontMatter: import re import shlex from argparse import Action +from gettext import gettext from argparse import ArgumentError from contextvars import ContextVar from itertools import chain, product @@ -450,30 +451,61 @@ class ArgumentParser(ArgParser): if TYPE_CHECKING: @overload - def parse_args( - self, args: Optional[Sequence[Union[str, MessageSegment]]] = ... - ) -> Namespace: + def parse_known_args( + self, + args: Optional[Sequence[Union[str, MessageSegment]]] = None, + namespace: None = None, + ) -> Tuple[Namespace, List[Union[str, MessageSegment]]]: ... @overload - def parse_args( - self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: None - ) -> Namespace: - ... # type: ignore[misc] - - @overload - def parse_args( + def parse_known_args( self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: T - ) -> T: + ) -> Tuple[T, List[Union[str, MessageSegment]]]: ... - def parse_args( + @overload + def parse_known_args( + self, *, namespace: T + ) -> Tuple[T, List[Union[str, MessageSegment]]]: + ... + + def parse_known_args( self, args: Optional[Sequence[Union[str, MessageSegment]]] = None, namespace: Optional[T] = None, - ) -> Union[Namespace, T]: + ) -> Tuple[Union[Namespace, T], List[Union[str, MessageSegment]]]: ... + @overload + def parse_args( + self, + args: Optional[Sequence[Union[str, MessageSegment]]] = None, + namespace: None = None, + ) -> Namespace: + ... + + @overload + def parse_args( + self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: T + ) -> T: + ... + + @overload + def parse_args(self, *, namespace: T) -> T: + ... + + def parse_args( + self, + args: Optional[Sequence[Union[str, MessageSegment]]] = None, + namespace: Optional[T] = None, + ) -> Union[Namespace, T]: + result, argv = self.parse_known_args(args, namespace) + if argv: + msg = gettext("unrecognized arguments: %s") + self.error(msg % " ".join(map(str, argv))) + return cast(Union[Namespace, T], result) + def _parse_optional( self, arg_string: Union[str, MessageSegment] ) -> Optional[Tuple[Optional[Action], str, Optional[str]]]: diff --git a/tests/test_rule.py b/tests/test_rule.py index f6a17d79..09bd2ada 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -371,6 +371,19 @@ async def test_shell_command(): assert state[SHELL_ARGS].status != 0 assert state[SHELL_ARGS].message.startswith(parser.format_usage() + "test: error:") + test_parser_remain_args = shell_command(CMD, parser=parser) + dependent = list(test_parser_remain_args.checkers)[0] + checker = dependent.call + assert isinstance(checker, ShellCommandRule) + message = MessageSegment.text("-a 1 2") + MessageSegment.image("test") + event = make_fake_event(_message=message)() + state = {PREFIX_KEY: {CMD_KEY: CMD, CMD_ARG_KEY: message}} + assert await dependent(event=event, state=state) + assert state[SHELL_ARGV] == ["-a", "1", "2", MessageSegment.image("test")] + assert isinstance(state[SHELL_ARGS], ParserExit) + assert state[SHELL_ARGS].status != 0 + assert state[SHELL_ARGS].message.startswith(parser.format_usage() + "test: error:") + test_message_parser = shell_command(CMD, parser=parser) dependent = list(test_message_parser.checkers)[0] checker = dependent.call