mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🐛 Fix: 修复 ArgParam
不支持 Annotated
(#2124)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3cdbf35dc6
commit
8de25447b3
@ -389,6 +389,10 @@ class ArgParam(Param):
|
|||||||
return cls(
|
return cls(
|
||||||
Required, key=param.default.key or param.name, type=param.default.type
|
Required, key=param.default.key or param.name, type=param.default.type
|
||||||
)
|
)
|
||||||
|
elif get_origin(param.annotation) is Annotated:
|
||||||
|
for arg in get_args(param.annotation):
|
||||||
|
if isinstance(arg, ArgInner):
|
||||||
|
return cls(Required, key=arg.key or param.name, type=arg.type)
|
||||||
|
|
||||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||||
key: str = self.extra["key"]
|
key: str = self.extra["key"]
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from nonebot.adapters import Message
|
from nonebot.adapters import Message
|
||||||
from nonebot.params import Arg, ArgStr, ArgPlainText
|
from nonebot.params import Arg, ArgStr, ArgPlainText
|
||||||
|
|
||||||
@ -12,3 +14,15 @@ async def arg_str(key: str = ArgStr()) -> str:
|
|||||||
|
|
||||||
async def arg_plain_text(key: str = ArgPlainText()) -> str:
|
async def arg_plain_text(key: str = ArgPlainText()) -> str:
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
async def annotated_arg(key: Annotated[Message, Arg()]) -> Message:
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
async def annotated_arg_str(key: Annotated[str, ArgStr()]) -> str:
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
async def annotated_arg_plain_text(key: Annotated[str, ArgPlainText()]) -> str:
|
||||||
|
return key
|
||||||
|
@ -441,7 +441,14 @@ async def test_matcher(app: App):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_arg(app: App):
|
async def test_arg(app: App):
|
||||||
from plugins.param.param_arg import arg, arg_str, arg_plain_text
|
from plugins.param.param_arg import (
|
||||||
|
arg,
|
||||||
|
arg_str,
|
||||||
|
annotated_arg,
|
||||||
|
arg_plain_text,
|
||||||
|
annotated_arg_str,
|
||||||
|
annotated_arg_plain_text,
|
||||||
|
)
|
||||||
|
|
||||||
matcher = Matcher()
|
matcher = Matcher()
|
||||||
message = FakeMessage("text")
|
message = FakeMessage("text")
|
||||||
@ -459,6 +466,20 @@ async def test_arg(app: App):
|
|||||||
ctx.pass_params(matcher=matcher)
|
ctx.pass_params(matcher=matcher)
|
||||||
ctx.should_return(message.extract_plain_text())
|
ctx.should_return(message.extract_plain_text())
|
||||||
|
|
||||||
|
async with app.test_dependent(annotated_arg, allow_types=[ArgParam]) as ctx:
|
||||||
|
ctx.pass_params(matcher=matcher)
|
||||||
|
ctx.should_return(message)
|
||||||
|
|
||||||
|
async with app.test_dependent(annotated_arg_str, allow_types=[ArgParam]) as ctx:
|
||||||
|
ctx.pass_params(matcher=matcher)
|
||||||
|
ctx.should_return(str(message))
|
||||||
|
|
||||||
|
async with app.test_dependent(
|
||||||
|
annotated_arg_plain_text, allow_types=[ArgParam]
|
||||||
|
) as ctx:
|
||||||
|
ctx.pass_params(matcher=matcher)
|
||||||
|
ctx.should_return(message.extract_plain_text())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exception(app: App):
|
async def test_exception(app: App):
|
||||||
|
Loading…
Reference in New Issue
Block a user