From 45e2e6c280bfd7529d753df5dfc877ac46105de4 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 20 Mar 2022 19:40:43 +0800 Subject: [PATCH] :bug: fix event maybe converted when checking type (#876) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix: 修复 event 类型检查会对类型进行自动转换 --- nonebot/internal/params.py | 17 ++++++++--------- tests/plugins/param/param_bot.py | 10 +++++++++- tests/plugins/param/param_event.py | 8 ++++++++ tests/test_init.py | 10 ++-------- tests/test_param.py | 25 ++++++++++++++++++++++++- tests/utils.py | 3 ++- 6 files changed, 53 insertions(+), 20 deletions(-) diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 23fd4176..207a8e10 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -9,7 +9,6 @@ from pydantic.fields import Required, Undefined, ModelField from nonebot.log import logger from nonebot.exception import TypeMisMatch -from nonebot.dependencies.utils import check_field_type from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.utils import ( @@ -160,14 +159,14 @@ class DependParam(Param): class _BotChecker(Param): async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: field: ModelField = self.extra["field"] - try: - return check_field_type(field, bot) - except TypeMisMatch: + if isinstance(bot, field.type_): + return bot + else: logger.debug( f"Bot type {type(bot)} not match " f"annotation {field._type_display()}, ignored" ) - raise + raise TypeMisMatch(field, bot) class BotParam(Param): @@ -206,14 +205,14 @@ class BotParam(Param): class _EventChecker(Param): async def _solve(self, event: "Event", **kwargs: Any) -> Any: field: ModelField = self.extra["field"] - try: - return check_field_type(field, event) - except TypeMisMatch: + if isinstance(event, field.type_): + return event + else: logger.debug( f"Event type {type(event)} not match " f"annotation {field._type_display()}, ignored" ) - raise + raise TypeMisMatch(field, event) class EventParam(Param): diff --git a/tests/plugins/param/param_bot.py b/tests/plugins/param/param_bot.py index a6befdda..08129673 100644 --- a/tests/plugins/param/param_bot.py +++ b/tests/plugins/param/param_bot.py @@ -1,5 +1,13 @@ from nonebot.adapters import Bot -async def get_bot(b: Bot): +async def get_bot(b: Bot) -> Bot: + return b + + +class SubBot(Bot): + ... + + +async def sub_bot(b: SubBot) -> SubBot: return b diff --git a/tests/plugins/param/param_event.py b/tests/plugins/param/param_event.py index 3cc04570..a526cf79 100644 --- a/tests/plugins/param/param_event.py +++ b/tests/plugins/param/param_event.py @@ -6,6 +6,14 @@ async def event(e: Event) -> Event: return e +class SubEvent(Event): + ... + + +async def sub_event(e: SubEvent) -> SubEvent: + return e + + async def event_type(t: str = EventType()) -> str: return t diff --git a/tests/test_init.py b/tests/test_init.py index 6d36dccd..a62520d0 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -35,11 +35,8 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear): from nonebot.drivers import ForwardDriver, ReverseDriver from nonebot import get_app, get_bot, get_asgi, get_bots, get_driver - try: + with pytest.raises(ValueError): get_driver() - assert False, "Driver can only be got after initialization" - except ValueError: - assert True nonebot.init(driver="nonebot.drivers.fastapi") @@ -59,11 +56,8 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear): nonebot.run("arg", kwarg="kwarg") assert runned - try: + with pytest.raises(ValueError): get_bot() - assert False - except ValueError: - assert True monkeypatch.setattr(driver, "_clients", {"test": "test"}) assert get_bot() == "test" diff --git a/tests/test_param.py b/tests/test_param.py index e5157635..52521a41 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -36,19 +36,33 @@ async def test_depend(app: App, load_plugin): @pytest.mark.asyncio async def test_bot(app: App, load_plugin): from nonebot.params import BotParam - from plugins.param.param_bot import get_bot + from nonebot.exception import TypeMisMatch + from plugins.param.param_bot import SubBot, get_bot, sub_bot async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx: bot = ctx.create_bot() ctx.pass_params(bot=bot) ctx.should_return(bot) + async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot(base=SubBot) + ctx.pass_params(bot=bot) + ctx.should_return(bot) + + with pytest.raises(TypeMisMatch): + async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + @pytest.mark.asyncio async def test_event(app: App, load_plugin): + from nonebot.exception import TypeMisMatch from nonebot.params import EventParam, DependParam from plugins.param.param_event import ( + SubEvent, event, + sub_event, event_type, event_to_me, event_message, @@ -57,11 +71,20 @@ async def test_event(app: App, load_plugin): fake_message = make_fake_message()("text") fake_event = make_fake_event(_message=fake_message)() + fake_subevent = make_fake_event(_base=SubEvent)() async with app.test_dependent(event, allow_types=[EventParam]) as ctx: ctx.pass_params(event=fake_event) ctx.should_return(fake_event) + async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_subevent) + ctx.should_return(fake_subevent) + + with pytest.raises(TypeMisMatch): + async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_event) + async with app.test_dependent( event_type, allow_types=[EventParam, DependParam] ) as ctx: diff --git a/tests/utils.py b/tests/utils.py index 31939be7..d0715e0e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -61,6 +61,7 @@ def make_fake_message(): def make_fake_event( + _base: Optional[Type["Event"]] = None, _type: str = "message", _name: str = "test", _description: str = "test", @@ -72,7 +73,7 @@ def make_fake_event( ) -> Type["Event"]: from nonebot.adapters import Event - _Fake = create_model("_Fake", __base__=Event, **fields) + _Fake = create_model("_Fake", __base__=_base or Event, **fields) class FakeEvent(_Fake): def get_type(self) -> str: