From 7aaa66c8baf3fd32011ae8203c6ffb5ea5cbbc6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bryan=E4=B8=8D=E5=8F=AF=E6=80=9D=E8=AE=AE?= Date: Thu, 14 Sep 2023 00:14:45 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E4=BC=98=E5=85=88?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=20`Annotated`=20=E7=9A=84=E6=9C=80=E5=90=8E?= =?UTF-8?q?=E4=B8=80=E4=B8=AA=E5=AD=90=E4=BE=9D=E8=B5=96=20(#2360)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> --- nonebot/internal/params.py | 4 ++-- tests/plugins/param/param_arg.py | 11 +++++++++++ tests/plugins/param/param_depend.py | 6 ++++++ tests/test_param.py | 19 ++++++++++++++++++- 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index ca2dc8ad..c1820499 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -142,7 +142,7 @@ class DependParam(Param): if get_origin(param.annotation) is Annotated: type_annotation, *extra_args = get_args(param.annotation) depends_inner = next( - (x for x in extra_args if isinstance(x, DependsInner)), None + (x for x in reversed(extra_args) if isinstance(x, DependsInner)), None ) # param default value takes higher priority @@ -462,7 +462,7 @@ class ArgParam(Param): Required, key=param.default.key or param.name, type=param.default.type ) elif get_origin(param.annotation) is Annotated: - for arg in get_args(param.annotation): + for arg in get_args(param.annotation)[:0:-1]: if isinstance(arg, ArgInner): return cls(Required, key=arg.key or param.name, type=arg.type) diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index 4c81cacd..b38541d2 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -26,3 +26,14 @@ async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str: async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: return key + + +# test dependency priority +async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()): + return key + + +async def annotated_multi_arg( + key: Annotated[Annotated[str, ArgStr("foo")], ArgPlainText()] +): + return key diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py index 9a7f3fc2..63870b62 100644 --- a/tests/plugins/param/param_depend.py +++ b/tests/plugins/param/param_depend.py @@ -79,6 +79,12 @@ async def annotated_prior_depend( return x +async def annotated_multi_depend( + x: Annotated[Annotated[int, Depends(lambda: 2)], Depends(dependency)] +): + return x + + # test sub dependency type mismatch async def sub_type_mismatch(b: FooBot = Depends(sub_bot)): return b diff --git a/tests/test_param.py b/tests/test_param.py index 4795e2a6..3bbf70ae 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -51,6 +51,7 @@ async def test_depend(app: App): sub_type_mismatch, validate_field_fail, annotated_class_depend, + annotated_multi_depend, annotated_prior_depend, ) @@ -81,7 +82,13 @@ async def test_depend(app: App): annotated_prior_depend, allow_types=[DependParam] ) as ctx: ctx.should_return(1) - assert runned == [1, 1] + + async with app.test_dependent( + annotated_multi_depend, allow_types=[DependParam] + ) as ctx: + ctx.should_return(1) + + assert runned == [1, 1, 1] async with app.test_dependent( annotated_class_depend, allow_types=[DependParam] @@ -474,6 +481,8 @@ async def test_arg(app: App): annotated_arg, arg_plain_text, annotated_arg_str, + annotated_multi_arg, + annotated_prior_arg, annotated_arg_plain_text, ) @@ -507,6 +516,14 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message.extract_plain_text()) + + async with app.test_dependent(annotated_prior_arg, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message.extract_plain_text()) + @pytest.mark.asyncio async def test_exception(app: App):