🐛 Fix: 修复 ArgParam 不支持 Annotated (#2124)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
eya46 2023-06-24 19:18:24 +08:00 committed by GitHub
parent 3cdbf35dc6
commit 8de25447b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 1 deletions

View File

@ -389,6 +389,10 @@ class ArgParam(Param):
return cls( return cls(
Required, key=param.default.key or param.name, type=param.default.type Required, key=param.default.key or param.name, type=param.default.type
) )
elif get_origin(param.annotation) is Annotated:
for arg in get_args(param.annotation):
if isinstance(arg, ArgInner):
return cls(Required, key=arg.key or param.name, type=arg.type)
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
key: str = self.extra["key"] key: str = self.extra["key"]

View File

@ -1,3 +1,5 @@
from typing_extensions import Annotated
from nonebot.adapters import Message from nonebot.adapters import Message
from nonebot.params import Arg, ArgStr, ArgPlainText from nonebot.params import Arg, ArgStr, ArgPlainText
@ -12,3 +14,15 @@ async def arg_str(key: str = ArgStr()) -> str:
async def arg_plain_text(key: str = ArgPlainText()) -> str: async def arg_plain_text(key: str = ArgPlainText()) -> str:
return key return key
async def annotated_arg(key: Annotated[Message, Arg()]) -> Message:
return key
async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str:
return key
async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
return key

View File

@ -441,7 +441,14 @@ async def test_matcher(app: App):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_arg(app: App): async def test_arg(app: App):
from plugins.param.param_arg import arg, arg_str, arg_plain_text from plugins.param.param_arg import (
arg,
arg_str,
annotated_arg,
arg_plain_text,
annotated_arg_str,
annotated_arg_plain_text,
)
matcher = Matcher() matcher = Matcher()
message = FakeMessage("text") message = FakeMessage("text")
@ -459,6 +466,20 @@ async def test_arg(app: App):
ctx.pass_params(matcher=matcher) ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text()) ctx.should_return(message.extract_plain_text())
async with app.test_dependent(annotated_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message)
async with app.test_dependent(annotated_arg_str, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(str(message))
async with app.test_dependent(
annotated_arg_plain_text, allow_types=[ArgParam]
) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message.extract_plain_text())
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exception(app: App): async def test_exception(app: App):