🐛 fix event maybe converted when checking type (#876)

Fix: 修复 event 类型检查会对类型进行自动转换
This commit is contained in:
Ju4tCode 2022-03-20 19:40:43 +08:00 committed by GitHub
parent fcdb05a7e2
commit 45e2e6c280
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 20 deletions

View File

@ -9,7 +9,6 @@ from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger from nonebot.log import logger
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.utils import ( from nonebot.utils import (
@ -160,14 +159,14 @@ class DependParam(Param):
class _BotChecker(Param): class _BotChecker(Param):
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"] field: ModelField = self.extra["field"]
try: if isinstance(bot, field.type_):
return check_field_type(field, bot) return bot
except TypeMisMatch: else:
logger.debug( logger.debug(
f"Bot type {type(bot)} not match " f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise raise TypeMisMatch(field, bot)
class BotParam(Param): class BotParam(Param):
@ -206,14 +205,14 @@ class BotParam(Param):
class _EventChecker(Param): class _EventChecker(Param):
async def _solve(self, event: "Event", **kwargs: Any) -> Any: async def _solve(self, event: "Event", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"] field: ModelField = self.extra["field"]
try: if isinstance(event, field.type_):
return check_field_type(field, event) return event
except TypeMisMatch: else:
logger.debug( logger.debug(
f"Event type {type(event)} not match " f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise raise TypeMisMatch(field, event)
class EventParam(Param): class EventParam(Param):

View File

@ -1,5 +1,13 @@
from nonebot.adapters import Bot from nonebot.adapters import Bot
async def get_bot(b: Bot): async def get_bot(b: Bot) -> Bot:
return b
class SubBot(Bot):
...
async def sub_bot(b: SubBot) -> SubBot:
return b return b

View File

@ -6,6 +6,14 @@ async def event(e: Event) -> Event:
return e return e
class SubEvent(Event):
...
async def sub_event(e: SubEvent) -> SubEvent:
return e
async def event_type(t: str = EventType()) -> str: async def event_type(t: str = EventType()) -> str:
return t return t

View File

@ -35,11 +35,8 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear):
from nonebot.drivers import ForwardDriver, ReverseDriver from nonebot.drivers import ForwardDriver, ReverseDriver
from nonebot import get_app, get_bot, get_asgi, get_bots, get_driver from nonebot import get_app, get_bot, get_asgi, get_bots, get_driver
try: with pytest.raises(ValueError):
get_driver() get_driver()
assert False, "Driver can only be got after initialization"
except ValueError:
assert True
nonebot.init(driver="nonebot.drivers.fastapi") nonebot.init(driver="nonebot.drivers.fastapi")
@ -59,11 +56,8 @@ async def test_get(monkeypatch: pytest.MonkeyPatch, nonebug_clear):
nonebot.run("arg", kwarg="kwarg") nonebot.run("arg", kwarg="kwarg")
assert runned assert runned
try: with pytest.raises(ValueError):
get_bot() get_bot()
assert False
except ValueError:
assert True
monkeypatch.setattr(driver, "_clients", {"test": "test"}) monkeypatch.setattr(driver, "_clients", {"test": "test"})
assert get_bot() == "test" assert get_bot() == "test"

View File

@ -36,19 +36,33 @@ async def test_depend(app: App, load_plugin):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bot(app: App, load_plugin): async def test_bot(app: App, load_plugin):
from nonebot.params import BotParam from nonebot.params import BotParam
from plugins.param.param_bot import get_bot from nonebot.exception import TypeMisMatch
from plugins.param.param_bot import SubBot, get_bot, sub_bot
async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx: async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
ctx.should_return(bot) ctx.should_return(bot)
async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot(base=SubBot)
ctx.pass_params(bot=bot)
ctx.should_return(bot)
with pytest.raises(TypeMisMatch):
async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_event(app: App, load_plugin): async def test_event(app: App, load_plugin):
from nonebot.exception import TypeMisMatch
from nonebot.params import EventParam, DependParam from nonebot.params import EventParam, DependParam
from plugins.param.param_event import ( from plugins.param.param_event import (
SubEvent,
event, event,
sub_event,
event_type, event_type,
event_to_me, event_to_me,
event_message, event_message,
@ -57,11 +71,20 @@ async def test_event(app: App, load_plugin):
fake_message = make_fake_message()("text") fake_message = make_fake_message()("text")
fake_event = make_fake_event(_message=fake_message)() fake_event = make_fake_event(_message=fake_message)()
fake_subevent = make_fake_event(_base=SubEvent)()
async with app.test_dependent(event, allow_types=[EventParam]) as ctx: async with app.test_dependent(event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event) ctx.pass_params(event=fake_event)
ctx.should_return(fake_event) ctx.should_return(fake_event)
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_subevent)
ctx.should_return(fake_subevent)
with pytest.raises(TypeMisMatch):
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
async with app.test_dependent( async with app.test_dependent(
event_type, allow_types=[EventParam, DependParam] event_type, allow_types=[EventParam, DependParam]
) as ctx: ) as ctx:

View File

@ -61,6 +61,7 @@ def make_fake_message():
def make_fake_event( def make_fake_event(
_base: Optional[Type["Event"]] = None,
_type: str = "message", _type: str = "message",
_name: str = "test", _name: str = "test",
_description: str = "test", _description: str = "test",
@ -72,7 +73,7 @@ def make_fake_event(
) -> Type["Event"]: ) -> Type["Event"]:
from nonebot.adapters import Event from nonebot.adapters import Event
_Fake = create_model("_Fake", __base__=Event, **fields) _Fake = create_model("_Fake", __base__=_base or Event, **fields)
class FakeEvent(_Fake): class FakeEvent(_Fake):
def get_type(self) -> str: def get_type(self) -> str: