mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-30 17:15:08 +08:00
🐛 Fix: State ForwardRef 检测错误 (#2698)
This commit is contained in:
parent
41b59cff06
commit
723fa4b3d8
@ -17,8 +17,14 @@ from pydantic.fields import FieldInfo as PydanticFieldInfo
|
|||||||
|
|
||||||
from nonebot.dependencies import Param, Dependent
|
from nonebot.dependencies import Param, Dependent
|
||||||
from nonebot.dependencies.utils import check_field_type
|
from nonebot.dependencies.utils import check_field_type
|
||||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
|
||||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined, extract_field_info
|
||||||
|
from nonebot.typing import (
|
||||||
|
_STATE_FLAG,
|
||||||
|
T_State,
|
||||||
|
T_Handler,
|
||||||
|
T_DependencyCache,
|
||||||
|
origin_is_annotated,
|
||||||
|
)
|
||||||
from nonebot.utils import (
|
from nonebot.utils import (
|
||||||
get_name,
|
get_name,
|
||||||
run_sync,
|
run_sync,
|
||||||
@ -349,7 +355,9 @@ class StateParam(Param):
|
|||||||
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: tuple[type[Param], ...]
|
||||||
) -> Optional[Self]:
|
) -> Optional[Self]:
|
||||||
# param type is T_State
|
# param type is T_State
|
||||||
if param.annotation is T_State:
|
if origin_is_annotated(
|
||||||
|
get_origin(param.annotation)
|
||||||
|
) and _STATE_FLAG in get_args(param.annotation):
|
||||||
return cls()
|
return cls()
|
||||||
# legacy: param is named "state" and has no type annotation
|
# legacy: param is named "state" and has no type annotation
|
||||||
elif param.annotation == param.empty and param.name == "state":
|
elif param.annotation == param.empty and param.name == "state":
|
||||||
|
@ -108,7 +108,15 @@ def evaluate_forwardref(
|
|||||||
|
|
||||||
|
|
||||||
# state
|
# state
|
||||||
T_State: TypeAlias = dict[t.Any, t.Any]
|
# use annotated flag to avoid ForwardRef recreate generic type (py >= 3.11)
|
||||||
|
class StateFlag:
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "StateFlag()"
|
||||||
|
|
||||||
|
|
||||||
|
_STATE_FLAG = StateFlag()
|
||||||
|
|
||||||
|
T_State: TypeAlias = t.Annotated[dict[t.Any, t.Any], _STATE_FLAG]
|
||||||
"""事件处理状态 State 类型"""
|
"""事件处理状态 State 类型"""
|
||||||
|
|
||||||
_DependentCallable: TypeAlias = t.Union[
|
_DependentCallable: TypeAlias = t.Union[
|
||||||
|
@ -7,6 +7,10 @@ async def get_bot(b: Bot) -> Bot:
|
|||||||
return b
|
return b
|
||||||
|
|
||||||
|
|
||||||
|
async def postpone_bot(b: "Bot") -> Bot:
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
async def legacy_bot(bot):
|
async def legacy_bot(bot):
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
|
@ -8,6 +8,10 @@ async def event(e: Event) -> Event:
|
|||||||
return e
|
return e
|
||||||
|
|
||||||
|
|
||||||
|
async def postpone_event(e: "Event") -> Event:
|
||||||
|
return e
|
||||||
|
|
||||||
|
|
||||||
async def legacy_event(event):
|
async def legacy_event(event):
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -9,6 +9,10 @@ async def matcher(m: Matcher) -> Matcher:
|
|||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
async def postpone_matcher(m: "Matcher") -> Matcher:
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
async def legacy_matcher(matcher):
|
async def legacy_matcher(matcher):
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
@ -27,7 +31,7 @@ class BarMatcher(Matcher): ...
|
|||||||
|
|
||||||
|
|
||||||
async def union_matcher(
|
async def union_matcher(
|
||||||
m: Union[FooMatcher, BarMatcher]
|
m: Union[FooMatcher, BarMatcher],
|
||||||
) -> Union[FooMatcher, BarMatcher]:
|
) -> Union[FooMatcher, BarMatcher]:
|
||||||
return m
|
return m
|
||||||
|
|
||||||
|
@ -25,6 +25,10 @@ async def state(x: T_State) -> T_State:
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
async def postpone_state(x: "T_State") -> T_State:
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
async def legacy_state(state):
|
async def legacy_state(state):
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
@ -129,6 +129,7 @@ async def test_bot(app: App):
|
|||||||
union_bot,
|
union_bot,
|
||||||
legacy_bot,
|
legacy_bot,
|
||||||
generic_bot,
|
generic_bot,
|
||||||
|
postpone_bot,
|
||||||
not_legacy_bot,
|
not_legacy_bot,
|
||||||
generic_bot_none,
|
generic_bot_none,
|
||||||
)
|
)
|
||||||
@ -138,6 +139,11 @@ async def test_bot(app: App):
|
|||||||
ctx.pass_params(bot=bot)
|
ctx.pass_params(bot=bot)
|
||||||
ctx.should_return(bot)
|
ctx.should_return(bot)
|
||||||
|
|
||||||
|
async with app.test_dependent(postpone_bot, allow_types=[BotParam]) as ctx:
|
||||||
|
bot = ctx.create_bot()
|
||||||
|
ctx.pass_params(bot=bot)
|
||||||
|
ctx.should_return(bot)
|
||||||
|
|
||||||
async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx:
|
async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx:
|
||||||
bot = ctx.create_bot()
|
bot = ctx.create_bot()
|
||||||
ctx.pass_params(bot=bot)
|
ctx.pass_params(bot=bot)
|
||||||
@ -188,6 +194,7 @@ async def test_event(app: App):
|
|||||||
legacy_event,
|
legacy_event,
|
||||||
event_message,
|
event_message,
|
||||||
generic_event,
|
generic_event,
|
||||||
|
postpone_event,
|
||||||
event_plain_text,
|
event_plain_text,
|
||||||
not_legacy_event,
|
not_legacy_event,
|
||||||
generic_event_none,
|
generic_event_none,
|
||||||
@ -201,6 +208,10 @@ async def test_event(app: App):
|
|||||||
ctx.pass_params(event=fake_event)
|
ctx.pass_params(event=fake_event)
|
||||||
ctx.should_return(fake_event)
|
ctx.should_return(fake_event)
|
||||||
|
|
||||||
|
async with app.test_dependent(postpone_event, allow_types=[EventParam]) as ctx:
|
||||||
|
ctx.pass_params(event=fake_event)
|
||||||
|
ctx.should_return(fake_event)
|
||||||
|
|
||||||
async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx:
|
async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx:
|
||||||
ctx.pass_params(event=fake_event)
|
ctx.pass_params(event=fake_event)
|
||||||
ctx.should_return(fake_event)
|
ctx.should_return(fake_event)
|
||||||
@ -273,6 +284,7 @@ async def test_state(app: App):
|
|||||||
legacy_state,
|
legacy_state,
|
||||||
command_start,
|
command_start,
|
||||||
regex_matched,
|
regex_matched,
|
||||||
|
postpone_state,
|
||||||
not_legacy_state,
|
not_legacy_state,
|
||||||
command_whitespace,
|
command_whitespace,
|
||||||
shell_command_args,
|
shell_command_args,
|
||||||
@ -302,6 +314,10 @@ async def test_state(app: App):
|
|||||||
ctx.pass_params(state=fake_state)
|
ctx.pass_params(state=fake_state)
|
||||||
ctx.should_return(fake_state)
|
ctx.should_return(fake_state)
|
||||||
|
|
||||||
|
async with app.test_dependent(postpone_state, allow_types=[StateParam]) as ctx:
|
||||||
|
ctx.pass_params(state=fake_state)
|
||||||
|
ctx.should_return(fake_state)
|
||||||
|
|
||||||
async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx:
|
async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx:
|
||||||
ctx.pass_params(state=fake_state)
|
ctx.pass_params(state=fake_state)
|
||||||
ctx.should_return(fake_state)
|
ctx.should_return(fake_state)
|
||||||
@ -414,6 +430,7 @@ async def test_matcher(app: App):
|
|||||||
union_matcher,
|
union_matcher,
|
||||||
legacy_matcher,
|
legacy_matcher,
|
||||||
generic_matcher,
|
generic_matcher,
|
||||||
|
postpone_matcher,
|
||||||
not_legacy_matcher,
|
not_legacy_matcher,
|
||||||
generic_matcher_none,
|
generic_matcher_none,
|
||||||
)
|
)
|
||||||
@ -425,6 +442,10 @@ async def test_matcher(app: App):
|
|||||||
ctx.pass_params(matcher=fake_matcher)
|
ctx.pass_params(matcher=fake_matcher)
|
||||||
ctx.should_return(fake_matcher)
|
ctx.should_return(fake_matcher)
|
||||||
|
|
||||||
|
async with app.test_dependent(postpone_matcher, allow_types=[MatcherParam]) as ctx:
|
||||||
|
ctx.pass_params(matcher=fake_matcher)
|
||||||
|
ctx.should_return(fake_matcher)
|
||||||
|
|
||||||
async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx:
|
async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx:
|
||||||
ctx.pass_params(matcher=fake_matcher)
|
ctx.pass_params(matcher=fake_matcher)
|
||||||
ctx.should_return(fake_matcher)
|
ctx.should_return(fake_matcher)
|
||||||
|
Loading…
Reference in New Issue
Block a user