diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 2d3885f5..9320742c 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -389,6 +389,10 @@ class ArgParam(Param): return cls( Required, key=param.default.key or param.name, type=param.default.type ) + elif get_origin(param.annotation) is Annotated: + for arg in get_args(param.annotation): + if isinstance(arg, ArgInner): + return cls(Required, key=arg.key or param.name, type=arg.type) async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: key: str = self.extra["key"] diff --git a/tests/plugins/param/param_arg.py b/tests/plugins/param/param_arg.py index b137ce91..4c81cacd 100644 --- a/tests/plugins/param/param_arg.py +++ b/tests/plugins/param/param_arg.py @@ -1,3 +1,5 @@ +from typing_extensions import Annotated + from nonebot.adapters import Message from nonebot.params import Arg, ArgStr, ArgPlainText @@ -12,3 +14,15 @@ async def arg_str(key: str = ArgStr()) -> str: async def arg_plain_text(key: str = ArgPlainText()) -> str: return key + + +async def annotated_arg(key: Annotated[Message, Arg()]) -> Message: + return key + + +async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str: + return key + + +async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str: + return key diff --git a/tests/test_param.py b/tests/test_param.py index 1ee0997b..9ea7f4c3 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -441,7 +441,14 @@ async def test_matcher(app: App): @pytest.mark.asyncio async def test_arg(app: App): - from plugins.param.param_arg import arg, arg_str, arg_plain_text + from plugins.param.param_arg import ( + arg, + arg_str, + annotated_arg, + arg_plain_text, + annotated_arg_str, + annotated_arg_plain_text, + ) matcher = Matcher() message = FakeMessage("text") @@ -459,6 +466,20 @@ async def test_arg(app: App): ctx.pass_params(matcher=matcher) ctx.should_return(message.extract_plain_text()) + async with app.test_dependent(annotated_arg, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message) + + async with app.test_dependent(annotated_arg_str, allow_types=[ArgParam]) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(str(message)) + + async with app.test_dependent( + annotated_arg_plain_text, allow_types=[ArgParam] + ) as ctx: + ctx.pass_params(matcher=matcher) + ctx.should_return(message.extract_plain_text()) + @pytest.mark.asyncio async def test_exception(app: App):