From 8de25447b30d4499a383106a9085cb2b674988ae Mon Sep 17 00:00:00 2001 From: eya46 <61458340+eya46@users.noreply.github.com> Date: Sat, 24 Jun 2023 19:18:24 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20Fix:=20=E4=BF=AE=E5=A4=8D=20`ArgParam`?= =?UTF-8?q?=20=E4=B8=8D=E6=94=AF=E6=8C=81=20`Annotated`=20(#2124)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- nonebot/internal/params.py | 4 ++++ tests/plugins/param/param_arg.py | 14 ++++++++++++++ tests/test_param.py | 23 ++++++++++++++++++++++- 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 2d3885f5..9320742c 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -389,6 +389,10 @@ class ArgParam(Param): return cls( Required, key=param.default.key or param.name, type=param.default.type ) + elif get_origin(param.annotation) is Annotated: + for arg in get_args(param.annotation): + if isinstance(arg, ArgInner): + return cls(Required, key=arg.key or param.name, type=arg.type) async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: key: str = self.extra["key"] diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index b137ce91..4c81cacd 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -1,3 +1,5 @@ +from typing_extensions import Annotated + from nonebot.adapters import Message from nonebot.params import Arg, ArgStr, ArgPlainText @@ -12,3 +14,15 @@ async def arg_str(key: str = ArgStr()) -> str: async def arg_plain_text(key: str = ArgPlainText()) -> str: return key + + +async def annotated_arg(key: Annotated[Message, Arg()]) -> Message: + return key + + +async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str: + return key + + +async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: + return key diff --git a/tests/test_param.py b/tests/test_param.py index 1ee0997b..9ea7f4c3 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -441,7 +441,14 @@ async def test_matcher(app: App): @pytest.mark.asyncio async def test_arg(app: App): - from plugins.param.param_arg import arg, arg_str, arg_plain_text + from plugins.param.param_arg import ( + arg, + arg_str, + annotated_arg, + arg_plain_text, + annotated_arg_str, + annotated_arg_plain_text, + ) matcher = Matcher() message = FakeMessage("text") @@ -459,6 +466,20 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + async with app.test_dependent(annotated_arg, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message) + + async with app.test_dependent(annotated_arg_str, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(str(message)) + + async with app.test_dependent( + annotated_arg_plain_text, allow_types=[ArgParam] + ) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message.extract_plain_text()) + @pytest.mark.asyncio async def test_exception(app: App):