🐛 fix union validation error (#1001)

This commit is contained in:
Ju4tCode 2022-05-22 19:42:30 +08:00 committed by GitHub
parent fe43cc92a5
commit 6feed0610b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 144 additions and 24 deletions

View File

@ -1,4 +1,5 @@
import abc import abc
from typing import Any, Type, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
@ -6,6 +7,8 @@ from nonebot.utils import DataclassEncoder
from .message import Message from .message import Message
E = TypeVar("E", bound="Event")
class Event(abc.ABC, BaseModel): class Event(abc.ABC, BaseModel):
"""Event 基类。提供获取关键信息的方法,其余信息可直接获取。""" """Event 基类。提供获取关键信息的方法,其余信息可直接获取。"""
@ -14,6 +17,12 @@ class Event(abc.ABC, BaseModel):
extra = "allow" extra = "allow"
json_encoders = {Message: DataclassEncoder} json_encoders = {Message: DataclassEncoder}
@classmethod
def validate(cls: Type["E"], value: Any) -> "E":
if isinstance(value, Event) and not isinstance(value, cls):
raise TypeError(f"{value} is incompatible with Event type {cls}")
return super().validate(value)
@abc.abstractmethod @abc.abstractmethod
def get_type(self) -> str: def get_type(self) -> str:
"""获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。""" """获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。"""

View File

@ -9,6 +9,7 @@ from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger from nonebot.log import logger
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.utils import ( from nonebot.utils import (
@ -159,14 +160,14 @@ class DependParam(Param):
class _BotChecker(Param): class _BotChecker(Param):
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"] field: ModelField = self.extra["field"]
if isinstance(bot, field.type_): try:
return bot return check_field_type(field, bot)
else: except TypeMisMatch:
logger.debug( logger.debug(
f"Bot type {type(bot)} not match " f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise TypeMisMatch(field, bot) raise
class BotParam(Param): class BotParam(Param):
@ -205,14 +206,14 @@ class BotParam(Param):
class _EventChecker(Param): class _EventChecker(Param):
async def _solve(self, event: "Event", **kwargs: Any) -> Any: async def _solve(self, event: "Event", **kwargs: Any) -> Any:
field: ModelField = self.extra["field"] field: ModelField = self.extra["field"]
if isinstance(event, field.type_): try:
return event return check_field_type(field, event)
else: except TypeMisMatch:
logger.debug( logger.debug(
f"Event type {type(event)} not match " f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise TypeMisMatch(field, event) raise
class EventParam(Param): class EventParam(Param):

View File

@ -62,12 +62,10 @@ def generic_check_issubclass(
except TypeError: except TypeError:
origin = get_origin(cls) origin = get_origin(cls)
if is_union(origin): if is_union(origin):
for type_ in get_args(cls): return all(
if not is_none_type(type_) and not generic_check_issubclass( is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple)
type_, class_or_tuple for type_ in get_args(cls)
): )
return False
return True
elif origin: elif origin:
return issubclass(origin, class_or_tuple) return issubclass(origin, class_or_tuple)
return False return False

View File

@ -1,3 +1,5 @@
from typing import Union
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -5,9 +7,29 @@ async def get_bot(b: Bot) -> Bot:
return b return b
class SubBot(Bot): async def legacy_bot(bot):
return bot
async def not_legacy_bot(bot: int):
... ...
async def sub_bot(b: SubBot) -> SubBot: class FooBot(Bot):
...
async def sub_bot(b: FooBot) -> FooBot:
return b return b
class BarBot(Bot):
...
async def union_bot(b: Union[FooBot, BarBot]) -> Union[FooBot, BarBot]:
return b
async def not_bot(b: Union[int, Bot]):
...

View File

@ -1,3 +1,5 @@
from typing import Union
from nonebot.adapters import Event, Message from nonebot.adapters import Event, Message
from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText
@ -6,14 +8,34 @@ async def event(e: Event) -> Event:
return e return e
class SubEvent(Event): async def legacy_event(event):
return event
async def not_legacy_event(event: int):
... ...
async def sub_event(e: SubEvent) -> SubEvent: class FooEvent(Event):
...
async def sub_event(e: FooEvent) -> FooEvent:
return e return e
class BarEvent(Event):
...
async def union_event(e: Union[FooEvent, BarEvent]) -> Union[FooEvent, BarEvent]:
return e
async def not_event(e: Union[int, Event]):
...
async def event_type(t: str = EventType()) -> str: async def event_type(t: str = EventType()) -> str:
return t return t

View File

@ -19,6 +19,14 @@ async def state(x: T_State) -> T_State:
return x return x
async def legacy_state(state):
return state
async def not_legacy_state(state: int):
...
async def command(cmd: Tuple[str, ...] = Command()) -> Tuple[str, ...]: async def command(cmd: Tuple[str, ...] = Command()) -> Tuple[str, ...]:
return cmd return cmd

View File

@ -37,15 +37,32 @@ async def test_depend(app: App, load_plugin):
async def test_bot(app: App, load_plugin): async def test_bot(app: App, load_plugin):
from nonebot.params import BotParam from nonebot.params import BotParam
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from plugins.param.param_bot import SubBot, get_bot, sub_bot from plugins.param.param_bot import (
FooBot,
get_bot,
not_bot,
sub_bot,
union_bot,
legacy_bot,
not_legacy_bot,
)
async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx: async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
ctx.should_return(bot) ctx.should_return(bot)
async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
ctx.should_return(bot)
with pytest.raises(ValueError):
async with app.test_dependent(not_legacy_bot, allow_types=[BotParam]) as ctx:
...
async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx: async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot(base=SubBot) bot = ctx.create_bot(base=FooBot)
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
ctx.should_return(bot) ctx.should_return(bot)
@ -54,37 +71,68 @@ async def test_bot(app: App, load_plugin):
bot = ctx.create_bot() bot = ctx.create_bot()
ctx.pass_params(bot=bot) ctx.pass_params(bot=bot)
async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot(base=FooBot)
ctx.pass_params(bot=bot)
ctx.should_return(bot)
with pytest.raises(ValueError):
async with app.test_dependent(not_bot, allow_types=[BotParam]) as ctx:
...
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_event(app: App, load_plugin): async def test_event(app: App, load_plugin):
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from nonebot.params import EventParam, DependParam from nonebot.params import EventParam, DependParam
from plugins.param.param_event import ( from plugins.param.param_event import (
SubEvent, FooEvent,
event, event,
not_event,
sub_event, sub_event,
event_type, event_type,
event_to_me, event_to_me,
union_event,
legacy_event,
event_message, event_message,
event_plain_text, event_plain_text,
not_legacy_event,
) )
fake_message = make_fake_message()("text") fake_message = make_fake_message()("text")
fake_event = make_fake_event(_message=fake_message)() fake_event = make_fake_event(_message=fake_message)()
fake_subevent = make_fake_event(_base=SubEvent)() fake_fooevent = make_fake_event(_base=FooEvent)()
async with app.test_dependent(event, allow_types=[EventParam]) as ctx: async with app.test_dependent(event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event) ctx.pass_params(event=fake_event)
ctx.should_return(fake_event) ctx.should_return(fake_event)
async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)
with pytest.raises(ValueError):
async with app.test_dependent(
not_legacy_event, allow_types=[EventParam]
) as ctx:
...
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_subevent) ctx.pass_params(event=fake_fooevent)
ctx.should_return(fake_subevent) ctx.should_return(fake_fooevent)
with pytest.raises(TypeMisMatch): with pytest.raises(TypeMisMatch):
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx: async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event) ctx.pass_params(event=fake_event)
async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_fooevent)
ctx.should_return(fake_event)
with pytest.raises(ValueError):
async with app.test_dependent(not_event, allow_types=[EventParam]) as ctx:
...
async with app.test_dependent( async with app.test_dependent(
event_type, allow_types=[EventParam, DependParam] event_type, allow_types=[EventParam, DependParam]
) as ctx: ) as ctx:
@ -132,8 +180,10 @@ async def test_state(app: App, load_plugin):
command_arg, command_arg,
raw_command, raw_command,
regex_group, regex_group,
legacy_state,
command_start, command_start,
regex_matched, regex_matched,
not_legacy_state,
shell_command_args, shell_command_args,
shell_command_argv, shell_command_argv,
) )
@ -157,6 +207,16 @@ async def test_state(app: App, load_plugin):
ctx.pass_params(state=fake_state) ctx.pass_params(state=fake_state)
ctx.should_return(fake_state) ctx.should_return(fake_state)
async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)
with pytest.raises(ValueError):
async with app.test_dependent(
not_legacy_state, allow_types=[StateParam]
) as ctx:
...
async with app.test_dependent( async with app.test_dependent(
command, allow_types=[StateParam, DependParam] command, allow_types=[StateParam, DependParam]
) as ctx: ) as ctx: