From 6feed0610b61b2b17b58ecd40ec34b3b2b7175d0 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 22 May 2022 19:42:30 +0800 Subject: [PATCH] :bug: fix union validation error (#1001) --- nonebot/internal/adapter/event.py | 9 ++++ nonebot/internal/params.py | 17 +++---- nonebot/utils.py | 10 ++--- tests/plugins/param/param_bot.py | 26 ++++++++++- tests/plugins/param/param_event.py | 26 ++++++++++- tests/plugins/param/param_state.py | 8 ++++ tests/test_param.py | 72 +++++++++++++++++++++++++++--- 7 files changed, 144 insertions(+), 24 deletions(-) diff --git a/nonebot/internal/adapter/event.py b/nonebot/internal/adapter/event.py index 8ef3123b..f839e88a 100644 --- a/nonebot/internal/adapter/event.py +++ b/nonebot/internal/adapter/event.py @@ -1,4 +1,5 @@ import abc +from typing import Any, Type, TypeVar from pydantic import BaseModel @@ -6,6 +7,8 @@ from nonebot.utils import DataclassEncoder from .message import Message +E = TypeVar("E", bound="Event") + class Event(abc.ABC, BaseModel): """Event 基类。提供获取关键信息的方法,其余信息可直接获取。""" @@ -14,6 +17,12 @@ class Event(abc.ABC, BaseModel): extra = "allow" json_encoders = {Message: DataclassEncoder} + @classmethod + def validate(cls: Type["E"], value: Any) -> "E": + if isinstance(value, Event) and not isinstance(value, cls): + raise TypeError(f"{value} is incompatible with Event type {cls}") + return super().validate(value) + @abc.abstractmethod def get_type(self) -> str: """获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。""" diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 207a8e10..23fd4176 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -9,6 +9,7 @@ from pydantic.fields import Required, Undefined, ModelField from nonebot.log import logger from nonebot.exception import TypeMisMatch +from nonebot.dependencies.utils import check_field_type from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.utils import ( @@ -159,14 +160,14 @@ class DependParam(Param): class _BotChecker(Param): async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: field: ModelField = self.extra["field"] - if isinstance(bot, field.type_): - return bot - else: + try: + return check_field_type(field, bot) + except TypeMisMatch: logger.debug( f"Bot type {type(bot)} not match " f"annotation {field._type_display()}, ignored" ) - raise TypeMisMatch(field, bot) + raise class BotParam(Param): @@ -205,14 +206,14 @@ class BotParam(Param): class _EventChecker(Param): async def _solve(self, event: "Event", **kwargs: Any) -> Any: field: ModelField = self.extra["field"] - if isinstance(event, field.type_): - return event - else: + try: + return check_field_type(field, event) + except TypeMisMatch: logger.debug( f"Event type {type(event)} not match " f"annotation {field._type_display()}, ignored" ) - raise TypeMisMatch(field, event) + raise class EventParam(Param): diff --git a/nonebot/utils.py b/nonebot/utils.py index 33e8e0bc..652a9924 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -62,12 +62,10 @@ def generic_check_issubclass( except TypeError: origin = get_origin(cls) if is_union(origin): - for type_ in get_args(cls): - if not is_none_type(type_) and not generic_check_issubclass( - type_, class_or_tuple - ): - return False - return True + return all( + is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple) + for type_ in get_args(cls) + ) elif origin: return issubclass(origin, class_or_tuple) return False diff --git a/tests/plugins/param/param_bot.py b/tests/plugins/param/param_bot.py index 08129673..40c343f5 100644 --- a/tests/plugins/param/param_bot.py +++ b/tests/plugins/param/param_bot.py @@ -1,3 +1,5 @@ +from typing import Union + from nonebot.adapters import Bot @@ -5,9 +7,29 @@ async def get_bot(b: Bot) -> Bot: return b -class SubBot(Bot): +async def legacy_bot(bot): + return bot + + +async def not_legacy_bot(bot: int): ... -async def sub_bot(b: SubBot) -> SubBot: +class FooBot(Bot): + ... + + +async def sub_bot(b: FooBot) -> FooBot: return b + + +class BarBot(Bot): + ... + + +async def union_bot(b: Union[FooBot, BarBot]) -> Union[FooBot, BarBot]: + return b + + +async def not_bot(b: Union[int, Bot]): + ... diff --git a/tests/plugins/param/param_event.py b/tests/plugins/param/param_event.py index a526cf79..3f4005fe 100644 --- a/tests/plugins/param/param_event.py +++ b/tests/plugins/param/param_event.py @@ -1,3 +1,5 @@ +from typing import Union + from nonebot.adapters import Event, Message from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText @@ -6,14 +8,34 @@ async def event(e: Event) -> Event: return e -class SubEvent(Event): +async def legacy_event(event): + return event + + +async def not_legacy_event(event: int): ... -async def sub_event(e: SubEvent) -> SubEvent: +class FooEvent(Event): + ... + + +async def sub_event(e: FooEvent) -> FooEvent: return e +class BarEvent(Event): + ... + + +async def union_event(e: Union[FooEvent, BarEvent]) -> Union[FooEvent, BarEvent]: + return e + + +async def not_event(e: Union[int, Event]): + ... + + async def event_type(t: str = EventType()) -> str: return t diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index 9d800c93..d9a2e21d 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -19,6 +19,14 @@ async def state(x: T_State) -> T_State: return x +async def legacy_state(state): + return state + + +async def not_legacy_state(state: int): + ... + + async def command(cmd: Tuple[str, ...] = Command()) -> Tuple[str, ...]: return cmd diff --git a/tests/test_param.py b/tests/test_param.py index eaf07f61..e904ba14 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -37,15 +37,32 @@ async def test_depend(app: App, load_plugin): async def test_bot(app: App, load_plugin): from nonebot.params import BotParam from nonebot.exception import TypeMisMatch - from plugins.param.param_bot import SubBot, get_bot, sub_bot + from plugins.param.param_bot import ( + FooBot, + get_bot, + not_bot, + sub_bot, + union_bot, + legacy_bot, + not_legacy_bot, + ) async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx: bot = ctx.create_bot() ctx.pass_params(bot=bot) ctx.should_return(bot) + async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + ctx.should_return(bot) + + with pytest.raises(ValueError): + async with app.test_dependent(not_legacy_bot, allow_types=[BotParam]) as ctx: + ... + async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx: - bot = ctx.create_bot(base=SubBot) + bot = ctx.create_bot(base=FooBot) ctx.pass_params(bot=bot) ctx.should_return(bot) @@ -54,37 +71,68 @@ async def test_bot(app: App, load_plugin): bot = ctx.create_bot() ctx.pass_params(bot=bot) + async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot(base=FooBot) + ctx.pass_params(bot=bot) + ctx.should_return(bot) + + with pytest.raises(ValueError): + async with app.test_dependent(not_bot, allow_types=[BotParam]) as ctx: + ... + @pytest.mark.asyncio async def test_event(app: App, load_plugin): from nonebot.exception import TypeMisMatch from nonebot.params import EventParam, DependParam from plugins.param.param_event import ( - SubEvent, + FooEvent, event, + not_event, sub_event, event_type, event_to_me, + union_event, + legacy_event, event_message, event_plain_text, + not_legacy_event, ) fake_message = make_fake_message()("text") fake_event = make_fake_event(_message=fake_message)() - fake_subevent = make_fake_event(_base=SubEvent)() + fake_fooevent = make_fake_event(_base=FooEvent)() async with app.test_dependent(event, allow_types=[EventParam]) as ctx: ctx.pass_params(event=fake_event) ctx.should_return(fake_event) + async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event) + + with pytest.raises(ValueError): + async with app.test_dependent( + not_legacy_event, allow_types=[EventParam] + ) as ctx: + ... + async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: - ctx.pass_params(event=fake_subevent) - ctx.should_return(fake_subevent) + ctx.pass_params(event=fake_fooevent) + ctx.should_return(fake_fooevent) with pytest.raises(TypeMisMatch): async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: ctx.pass_params(event=fake_event) + async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_fooevent) + ctx.should_return(fake_event) + + with pytest.raises(ValueError): + async with app.test_dependent(not_event, allow_types=[EventParam]) as ctx: + ... + async with app.test_dependent( event_type, allow_types=[EventParam, DependParam] ) as ctx: @@ -132,8 +180,10 @@ async def test_state(app: App, load_plugin): command_arg, raw_command, regex_group, + legacy_state, command_start, regex_matched, + not_legacy_state, shell_command_args, shell_command_argv, ) @@ -157,6 +207,16 @@ async def test_state(app: App, load_plugin): ctx.pass_params(state=fake_state) ctx.should_return(fake_state) + async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state) + + with pytest.raises(ValueError): + async with app.test_dependent( + not_legacy_state, allow_types=[StateParam] + ) as ctx: + ... + async with app.test_dependent( command, allow_types=[StateParam, DependParam] ) as ctx: