diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 341333d1..0fd5bbc3 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -1,8 +1,10 @@ import asyncio import inspect +from typing_extensions import Annotated from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast +from pydantic.typing import get_args, get_origin from pydantic.fields import Required, Undefined, ModelField from nonebot.dependencies.utils import check_field_type @@ -78,21 +80,33 @@ class DependParam(Param): def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["DependParam"]: - if isinstance(param.default, DependsInner): - dependency: T_Handler - if param.default.dependency is None: - assert param.annotation is not param.empty, "Dependency cannot be empty" - dependency = param.annotation - else: - dependency = param.default.dependency - sub_dependent = Dependent[Any].parse( - call=dependency, - allow_types=allow_types, - ) - return cls( - Required, use_cache=param.default.use_cache, dependent=sub_dependent + type_annotation, depends_inner = param.annotation, None + if get_origin(param.annotation) is Annotated: + type_annotation, *extra_args = get_args(param.annotation) + depends_inner = next( + (x for x in extra_args if isinstance(x, DependsInner)), None ) + depends_inner = ( + param.default if isinstance(param.default, DependsInner) else depends_inner + ) + if depends_inner is None: + return + + dependency: T_Handler + if depends_inner.dependency is None: + assert ( + type_annotation is not inspect.Signature.empty + ), "Dependency cannot be empty" + dependency = type_annotation + else: + dependency = depends_inner.dependency + sub_dependent = Dependent[Any].parse( + call=dependency, + allow_types=allow_types, + ) + return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent) + @classmethod def _check_parameterless( cls, value: Any, allow_types: Tuple[Type[Param], ...] diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py index 5e4c438b..dd87199d 100644 --- a/tests/plugins/param/param_depend.py +++ b/tests/plugins/param/param_depend.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing_extensions import Annotated from nonebot import on_message from nonebot.params import Depends @@ -47,3 +48,17 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)): async def class_depend(c: ClassDependency = Depends()): return c + + +async def annotated_depend(x: Annotated[int, Depends(dependency)]): + return x + + +async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]): + return c + + +async def annotated_prior_depend( + x: Annotated[int, Depends(lambda: 2)] = Depends(dependency) +): + return x diff --git a/tests/test_param.py b/tests/test_param.py index b4171312..c2d81567 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -42,6 +42,9 @@ async def test_depend(app: App): depends, class_depend, test_depends, + annotated_depend, + annotated_class_depend, + annotated_prior_depend, ) async with app.test_dependent(depends, allow_types=[DependParam]) as ctx: @@ -63,6 +66,20 @@ async def test_depend(app: App): async with app.test_dependent(class_depend, allow_types=[DependParam]) as ctx: ctx.should_return(ClassDependency(x=1, y=2)) + async with app.test_dependent(annotated_depend, allow_types=[DependParam]) as ctx: + ctx.should_return(1) + + async with app.test_dependent( + annotated_prior_depend, allow_types=[DependParam] + ) as ctx: + ctx.should_return(1) + assert runned == [1, 1] + + async with app.test_dependent( + annotated_class_depend, allow_types=[DependParam] + ) as ctx: + ctx.should_return(ClassDependency(x=1, y=2)) + @pytest.mark.asyncio async def test_bot(app: App):