From 723fa4b3d8f4a4d1c590e800c66c932c8668943c Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Thu, 9 May 2024 15:08:49 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20Fix:=20State=20ForwardRef=20=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E9=94=99=E8=AF=AF=20(#2698)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/internal/params.py | 12 ++++++++++-- nonebot/typing.py | 10 +++++++++- tests/plugins/param/param_bot.py | 4 ++++ tests/plugins/param/param_event.py | 4 ++++ tests/plugins/param/param_matcher.py | 6 +++++- tests/plugins/param/param_state.py | 4 ++++ tests/test_param.py | 21 +++++++++++++++++++++ 7 files changed, 57 insertions(+), 4 deletions(-) diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index d48a6852..0d5b8dd5 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -17,8 +17,14 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo from nonebot.dependencies import Param, Dependent from nonebot.dependencies.utils import check_field_type -from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info +from nonebot.typing import ( + _STATE_FLAG, + T_State, + T_Handler, + T_DependencyCache, + origin_is_annotated, +) from nonebot.utils import ( get_name, run_sync, @@ -349,7 +355,9 @@ class StateParam(Param): cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...] ) -> Optional[Self]: # param type is T_State - if param.annotation is T_State: + if origin_is_annotated( + get_origin(param.annotation) + ) and _STATE_FLAG in get_args(param.annotation): return cls() # legacy: param is named "state" and has no type annotation elif param.annotation == param.empty and param.name == "state": diff --git a/nonebot/typing.py b/nonebot/typing.py index fcfb3763..600def25 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -108,7 +108,15 @@ def evaluate_forwardref( # state -T_State: TypeAlias = dict[t.Any, t.Any] +# use annotated flag to avoid ForwardRef recreate generic type (py >= 3.11) +class StateFlag: + def __repr__(self) -> str: + return "StateFlag()" + + +_STATE_FLAG = StateFlag() + +T_State: TypeAlias = t.Annotated[dict[t.Any, t.Any], _STATE_FLAG] """事件处理状态 State 类型""" _DependentCallable: TypeAlias = t.Union[ diff --git a/tests/plugins/param/param_bot.py b/tests/plugins/param/param_bot.py index 893693f9..1a10f805 100644 --- a/tests/plugins/param/param_bot.py +++ b/tests/plugins/param/param_bot.py @@ -7,6 +7,10 @@ async def get_bot(b: Bot) -> Bot: return b +async def postpone_bot(b: "Bot") -> Bot: + return b + + async def legacy_bot(bot): return bot diff --git a/tests/plugins/param/param_event.py b/tests/plugins/param/param_event.py index 840a2012..549d0711 100644 --- a/tests/plugins/param/param_event.py +++ b/tests/plugins/param/param_event.py @@ -8,6 +8,10 @@ async def event(e: Event) -> Event: return e +async def postpone_event(e: "Event") -> Event: + return e + + async def legacy_event(event): return event diff --git a/tests/plugins/param/param_matcher.py b/tests/plugins/param/param_matcher.py index 6534c7fd..514c52db 100644 --- a/tests/plugins/param/param_matcher.py +++ b/tests/plugins/param/param_matcher.py @@ -9,6 +9,10 @@ async def matcher(m: Matcher) -> Matcher: return m +async def postpone_matcher(m: "Matcher") -> Matcher: + return m + + async def legacy_matcher(matcher): return matcher @@ -27,7 +31,7 @@ class BarMatcher(Matcher): ... async def union_matcher( - m: Union[FooMatcher, BarMatcher] + m: Union[FooMatcher, BarMatcher], ) -> Union[FooMatcher, BarMatcher]: return m diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index 05cc2dbc..8d118771 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -25,6 +25,10 @@ async def state(x: T_State) -> T_State: return x +async def postpone_state(x: "T_State") -> T_State: + return x + + async def legacy_state(state): return state diff --git a/tests/test_param.py b/tests/test_param.py index 8a8b9008..1be92b6c 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -129,6 +129,7 @@ async def test_bot(app: App): union_bot, legacy_bot, generic_bot, + postpone_bot, not_legacy_bot, generic_bot_none, ) @@ -138,6 +139,11 @@ async def test_bot(app: App): ctx.pass_params(bot=bot) ctx.should_return(bot) + async with app.test_dependent(postpone_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + ctx.should_return(bot) + async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx: bot = ctx.create_bot() ctx.pass_params(bot=bot) @@ -188,6 +194,7 @@ async def test_event(app: App): legacy_event, event_message, generic_event, + postpone_event, event_plain_text, not_legacy_event, generic_event_none, @@ -201,6 +208,10 @@ async def test_event(app: App): ctx.pass_params(event=fake_event) ctx.should_return(fake_event) + async with app.test_dependent(postpone_event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event) + async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx: ctx.pass_params(event=fake_event) ctx.should_return(fake_event) @@ -273,6 +284,7 @@ async def test_state(app: App): legacy_state, command_start, regex_matched, + postpone_state, not_legacy_state, command_whitespace, shell_command_args, @@ -302,6 +314,10 @@ async def test_state(app: App): ctx.pass_params(state=fake_state) ctx.should_return(fake_state) + async with app.test_dependent(postpone_state, allow_types=[StateParam]) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state) + async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx: ctx.pass_params(state=fake_state) ctx.should_return(fake_state) @@ -414,6 +430,7 @@ async def test_matcher(app: App): union_matcher, legacy_matcher, generic_matcher, + postpone_matcher, not_legacy_matcher, generic_matcher_none, ) @@ -425,6 +442,10 @@ async def test_matcher(app: App): ctx.pass_params(matcher=fake_matcher) ctx.should_return(fake_matcher) + async with app.test_dependent(postpone_matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(fake_matcher) + async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx: ctx.pass_params(matcher=fake_matcher) ctx.should_return(fake_matcher)