diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index ca2dc8ad..c1820499 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -142,7 +142,7 @@ class DependParam(Param): if get_origin(param.annotation) is Annotated: type_annotation, *extra_args = get_args(param.annotation) depends_inner = next( - (x for x in extra_args if isinstance(x, DependsInner)), None + (x for x in reversed(extra_args) if isinstance(x, DependsInner)), None ) # param default value takes higher priority @@ -462,7 +462,7 @@ class ArgParam(Param): Required, key=param.default.key or param.name, type=param.default.type ) elif get_origin(param.annotation) is Annotated: - for arg in get_args(param.annotation): + for arg in get_args(param.annotation)[:0:-1]: if isinstance(arg, ArgInner): return cls(Required, key=arg.key or param.name, type=arg.type) diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index 4c81cacd..b38541d2 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -26,3 +26,14 @@ async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str: async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: return key + + +# test dependency priority +async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()): + return key + + +async def annotated_multi_arg( + key: Annotated[Annotated[str, ArgStr("foo")], ArgPlainText()] +): + return key diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py index 9a7f3fc2..63870b62 100644 --- a/tests/plugins/param/param_depend.py +++ b/tests/plugins/param/param_depend.py @@ -79,6 +79,12 @@ async def annotated_prior_depend( return x +async def annotated_multi_depend( + x: Annotated[Annotated[int, Depends(lambda: 2)], Depends(dependency)] +): + return x + + # test sub dependency type mismatch async def sub_type_mismatch(b: FooBot = Depends(sub_bot)): return b diff --git a/tests/test_param.py b/tests/test_param.py index 4795e2a6..3bbf70ae 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -51,6 +51,7 @@ async def test_depend(app: App): sub_type_mismatch, validate_field_fail, annotated_class_depend, + annotated_multi_depend, annotated_prior_depend, ) @@ -81,7 +82,13 @@ async def test_depend(app: App): annotated_prior_depend, allow_types=[DependParam] ) as ctx: ctx.should_return(1) - assert runned == [1, 1] + + async with app.test_dependent( + annotated_multi_depend, allow_types=[DependParam] + ) as ctx: + ctx.should_return(1) + + assert runned == [1, 1, 1] async with app.test_dependent( annotated_class_depend, allow_types=[DependParam] @@ -474,6 +481,8 @@ async def test_arg(app: App): annotated_arg, arg_plain_text, annotated_arg_str, + annotated_multi_arg, + annotated_prior_arg, annotated_arg_plain_text, ) @@ -507,6 +516,14 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message.extract_plain_text()) + + async with app.test_dependent(annotated_prior_arg, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message.extract_plain_text()) + @pytest.mark.asyncio async def test_exception(app: App):