Feature: 优先使用 Annotated 的最后一个子依赖 (#2360)

Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
This commit is contained in:
Bryan不可思议 2023-09-14 00:14:45 +08:00 committed by GitHub
parent 0030bf725e
commit 7aaa66c8ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 3 deletions

View File

@ -142,7 +142,7 @@ class DependParam(Param):
if get_origin(param.annotation) is Annotated: if get_origin(param.annotation) is Annotated:
type_annotation, *extra_args = get_args(param.annotation) type_annotation, *extra_args = get_args(param.annotation)
depends_inner = next( depends_inner = next(
(x for x in extra_args if isinstance(x, DependsInner)), None (x for x in reversed(extra_args) if isinstance(x, DependsInner)), None
) )
# param default value takes higher priority # param default value takes higher priority
@ -462,7 +462,7 @@ class ArgParam(Param):
Required, key=param.default.key or param.name, type=param.default.type Required, key=param.default.key or param.name, type=param.default.type
) )
elif get_origin(param.annotation) is Annotated: elif get_origin(param.annotation) is Annotated:
for arg in get_args(param.annotation): for arg in get_args(param.annotation)[:0:-1]:
if isinstance(arg, ArgInner): if isinstance(arg, ArgInner):
return cls(Required, key=arg.key or param.name, type=arg.type) return cls(Required, key=arg.key or param.name, type=arg.type)

View File

@ -26,3 +26,14 @@ async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str:
async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key return key
# test dependency priority
async def annotated_prior_arg(key: Annotated[str, ArgStr("foo")] = ArgPlainText()):
return key
async def annotated_multi_arg(
key: Annotated[Annotated[str, ArgStr("foo")], ArgPlainText()]
):
return key

View File

@ -79,6 +79,12 @@ async def annotated_prior_depend(
return x return x
async def annotated_multi_depend(
x: Annotated[Annotated[int, Depends(lambda: 2)], Depends(dependency)]
):
return x
# test sub dependency type mismatch # test sub dependency type mismatch
async def sub_type_mismatch(b: FooBot = Depends(sub_bot)): async def sub_type_mismatch(b: FooBot = Depends(sub_bot)):
return b return b

View File

@ -51,6 +51,7 @@ async def test_depend(app: App):
sub_type_mismatch, sub_type_mismatch,
validate_field_fail, validate_field_fail,
annotated_class_depend, annotated_class_depend,
annotated_multi_depend,
annotated_prior_depend, annotated_prior_depend,
) )
@ -81,7 +82,13 @@ async def test_depend(app: App):
annotated_prior_depend, allow_types=[DependParam] annotated_prior_depend, allow_types=[DependParam]
) as ctx: ) as ctx:
ctx.should_return(1) ctx.should_return(1)
assert runned == [1, 1]
async with app.test_dependent(
annotated_multi_depend, allow_types=[DependParam]
) as ctx:
ctx.should_return(1)
assert runned == [1, 1, 1]
async with app.test_dependent( async with app.test_dependent(
annotated_class_depend, allow_types=[DependParam] annotated_class_depend, allow_types=[DependParam]
@ -474,6 +481,8 @@ async def test_arg(app: App):
annotated_arg, annotated_arg,
arg_plain_text, arg_plain_text,
annotated_arg_str, annotated_arg_str,
annotated_multi_arg,
annotated_prior_arg,
annotated_arg_plain_text, annotated_arg_plain_text,
) )
@ -507,6 +516,14 @@ async def test_arg(app: App):
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text()) ctx.should_return(message.extract_plain_text())
async with app.test_dependent(annotated_multi_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
async with app.test_dependent(annotated_prior_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exception(app: App): async def test_exception(app: App):