mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🐛 fix union validation error (#1001)
This commit is contained in:
parent
fe43cc92a5
commit
6feed0610b
@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from typing import Any, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -6,6 +7,8 @@ from nonebot.utils import DataclassEncoder
|
||||
|
||||
from .message import Message
|
||||
|
||||
E = TypeVar("E", bound="Event")
|
||||
|
||||
|
||||
class Event(abc.ABC, BaseModel):
|
||||
"""Event 基类。提供获取关键信息的方法,其余信息可直接获取。"""
|
||||
@ -14,6 +17,12 @@ class Event(abc.ABC, BaseModel):
|
||||
extra = "allow"
|
||||
json_encoders = {Message: DataclassEncoder}
|
||||
|
||||
@classmethod
|
||||
def validate(cls: Type["E"], value: Any) -> "E":
|
||||
if isinstance(value, Event) and not isinstance(value, cls):
|
||||
raise TypeError(f"{value} is incompatible with Event type {cls}")
|
||||
return super().validate(value)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_type(self) -> str:
|
||||
"""获取事件类型的方法,类型通常为 NoneBot 内置的四种类型。"""
|
||||
|
@ -9,6 +9,7 @@ from pydantic.fields import Required, Undefined, ModelField
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from nonebot.dependencies.utils import check_field_type
|
||||
from nonebot.dependencies import Param, Dependent, CustomConfig
|
||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
||||
from nonebot.utils import (
|
||||
@ -159,14 +160,14 @@ class DependParam(Param):
|
||||
class _BotChecker(Param):
|
||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||
field: ModelField = self.extra["field"]
|
||||
if isinstance(bot, field.type_):
|
||||
return bot
|
||||
else:
|
||||
try:
|
||||
return check_field_type(field, bot)
|
||||
except TypeMisMatch:
|
||||
logger.debug(
|
||||
f"Bot type {type(bot)} not match "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise TypeMisMatch(field, bot)
|
||||
raise
|
||||
|
||||
|
||||
class BotParam(Param):
|
||||
@ -205,14 +206,14 @@ class BotParam(Param):
|
||||
class _EventChecker(Param):
|
||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||
field: ModelField = self.extra["field"]
|
||||
if isinstance(event, field.type_):
|
||||
return event
|
||||
else:
|
||||
try:
|
||||
return check_field_type(field, event)
|
||||
except TypeMisMatch:
|
||||
logger.debug(
|
||||
f"Event type {type(event)} not match "
|
||||
f"annotation {field._type_display()}, ignored"
|
||||
)
|
||||
raise TypeMisMatch(field, event)
|
||||
raise
|
||||
|
||||
|
||||
class EventParam(Param):
|
||||
|
@ -62,12 +62,10 @@ def generic_check_issubclass(
|
||||
except TypeError:
|
||||
origin = get_origin(cls)
|
||||
if is_union(origin):
|
||||
for type_ in get_args(cls):
|
||||
if not is_none_type(type_) and not generic_check_issubclass(
|
||||
type_, class_or_tuple
|
||||
):
|
||||
return False
|
||||
return True
|
||||
return all(
|
||||
is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple)
|
||||
for type_ in get_args(cls)
|
||||
)
|
||||
elif origin:
|
||||
return issubclass(origin, class_or_tuple)
|
||||
return False
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
from nonebot.adapters import Bot
|
||||
|
||||
|
||||
@ -5,9 +7,29 @@ async def get_bot(b: Bot) -> Bot:
|
||||
return b
|
||||
|
||||
|
||||
class SubBot(Bot):
|
||||
async def legacy_bot(bot):
|
||||
return bot
|
||||
|
||||
|
||||
async def not_legacy_bot(bot: int):
|
||||
...
|
||||
|
||||
|
||||
async def sub_bot(b: SubBot) -> SubBot:
|
||||
class FooBot(Bot):
|
||||
...
|
||||
|
||||
|
||||
async def sub_bot(b: FooBot) -> FooBot:
|
||||
return b
|
||||
|
||||
|
||||
class BarBot(Bot):
|
||||
...
|
||||
|
||||
|
||||
async def union_bot(b: Union[FooBot, BarBot]) -> Union[FooBot, BarBot]:
|
||||
return b
|
||||
|
||||
|
||||
async def not_bot(b: Union[int, Bot]):
|
||||
...
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Union
|
||||
|
||||
from nonebot.adapters import Event, Message
|
||||
from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText
|
||||
|
||||
@ -6,14 +8,34 @@ async def event(e: Event) -> Event:
|
||||
return e
|
||||
|
||||
|
||||
class SubEvent(Event):
|
||||
async def legacy_event(event):
|
||||
return event
|
||||
|
||||
|
||||
async def not_legacy_event(event: int):
|
||||
...
|
||||
|
||||
|
||||
async def sub_event(e: SubEvent) -> SubEvent:
|
||||
class FooEvent(Event):
|
||||
...
|
||||
|
||||
|
||||
async def sub_event(e: FooEvent) -> FooEvent:
|
||||
return e
|
||||
|
||||
|
||||
class BarEvent(Event):
|
||||
...
|
||||
|
||||
|
||||
async def union_event(e: Union[FooEvent, BarEvent]) -> Union[FooEvent, BarEvent]:
|
||||
return e
|
||||
|
||||
|
||||
async def not_event(e: Union[int, Event]):
|
||||
...
|
||||
|
||||
|
||||
async def event_type(t: str = EventType()) -> str:
|
||||
return t
|
||||
|
||||
|
@ -19,6 +19,14 @@ async def state(x: T_State) -> T_State:
|
||||
return x
|
||||
|
||||
|
||||
async def legacy_state(state):
|
||||
return state
|
||||
|
||||
|
||||
async def not_legacy_state(state: int):
|
||||
...
|
||||
|
||||
|
||||
async def command(cmd: Tuple[str, ...] = Command()) -> Tuple[str, ...]:
|
||||
return cmd
|
||||
|
||||
|
@ -37,15 +37,32 @@ async def test_depend(app: App, load_plugin):
|
||||
async def test_bot(app: App, load_plugin):
|
||||
from nonebot.params import BotParam
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from plugins.param.param_bot import SubBot, get_bot, sub_bot
|
||||
from plugins.param.param_bot import (
|
||||
FooBot,
|
||||
get_bot,
|
||||
not_bot,
|
||||
sub_bot,
|
||||
union_bot,
|
||||
legacy_bot,
|
||||
not_legacy_bot,
|
||||
)
|
||||
|
||||
async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx:
|
||||
bot = ctx.create_bot()
|
||||
ctx.pass_params(bot=bot)
|
||||
ctx.should_return(bot)
|
||||
|
||||
async with app.test_dependent(legacy_bot, allow_types=[BotParam]) as ctx:
|
||||
bot = ctx.create_bot()
|
||||
ctx.pass_params(bot=bot)
|
||||
ctx.should_return(bot)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with app.test_dependent(not_legacy_bot, allow_types=[BotParam]) as ctx:
|
||||
...
|
||||
|
||||
async with app.test_dependent(sub_bot, allow_types=[BotParam]) as ctx:
|
||||
bot = ctx.create_bot(base=SubBot)
|
||||
bot = ctx.create_bot(base=FooBot)
|
||||
ctx.pass_params(bot=bot)
|
||||
ctx.should_return(bot)
|
||||
|
||||
@ -54,37 +71,68 @@ async def test_bot(app: App, load_plugin):
|
||||
bot = ctx.create_bot()
|
||||
ctx.pass_params(bot=bot)
|
||||
|
||||
async with app.test_dependent(union_bot, allow_types=[BotParam]) as ctx:
|
||||
bot = ctx.create_bot(base=FooBot)
|
||||
ctx.pass_params(bot=bot)
|
||||
ctx.should_return(bot)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with app.test_dependent(not_bot, allow_types=[BotParam]) as ctx:
|
||||
...
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event(app: App, load_plugin):
|
||||
from nonebot.exception import TypeMisMatch
|
||||
from nonebot.params import EventParam, DependParam
|
||||
from plugins.param.param_event import (
|
||||
SubEvent,
|
||||
FooEvent,
|
||||
event,
|
||||
not_event,
|
||||
sub_event,
|
||||
event_type,
|
||||
event_to_me,
|
||||
union_event,
|
||||
legacy_event,
|
||||
event_message,
|
||||
event_plain_text,
|
||||
not_legacy_event,
|
||||
)
|
||||
|
||||
fake_message = make_fake_message()("text")
|
||||
fake_event = make_fake_event(_message=fake_message)()
|
||||
fake_subevent = make_fake_event(_base=SubEvent)()
|
||||
fake_fooevent = make_fake_event(_base=FooEvent)()
|
||||
|
||||
async with app.test_dependent(event, allow_types=[EventParam]) as ctx:
|
||||
ctx.pass_params(event=fake_event)
|
||||
ctx.should_return(fake_event)
|
||||
|
||||
async with app.test_dependent(legacy_event, allow_types=[EventParam]) as ctx:
|
||||
ctx.pass_params(event=fake_event)
|
||||
ctx.should_return(fake_event)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with app.test_dependent(
|
||||
not_legacy_event, allow_types=[EventParam]
|
||||
) as ctx:
|
||||
...
|
||||
|
||||
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
|
||||
ctx.pass_params(event=fake_subevent)
|
||||
ctx.should_return(fake_subevent)
|
||||
ctx.pass_params(event=fake_fooevent)
|
||||
ctx.should_return(fake_fooevent)
|
||||
|
||||
with pytest.raises(TypeMisMatch):
|
||||
async with app.test_dependent(sub_event, allow_types=[EventParam]) as ctx:
|
||||
ctx.pass_params(event=fake_event)
|
||||
|
||||
async with app.test_dependent(union_event, allow_types=[EventParam]) as ctx:
|
||||
ctx.pass_params(event=fake_fooevent)
|
||||
ctx.should_return(fake_event)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with app.test_dependent(not_event, allow_types=[EventParam]) as ctx:
|
||||
...
|
||||
|
||||
async with app.test_dependent(
|
||||
event_type, allow_types=[EventParam, DependParam]
|
||||
) as ctx:
|
||||
@ -132,8 +180,10 @@ async def test_state(app: App, load_plugin):
|
||||
command_arg,
|
||||
raw_command,
|
||||
regex_group,
|
||||
legacy_state,
|
||||
command_start,
|
||||
regex_matched,
|
||||
not_legacy_state,
|
||||
shell_command_args,
|
||||
shell_command_argv,
|
||||
)
|
||||
@ -157,6 +207,16 @@ async def test_state(app: App, load_plugin):
|
||||
ctx.pass_params(state=fake_state)
|
||||
ctx.should_return(fake_state)
|
||||
|
||||
async with app.test_dependent(legacy_state, allow_types=[StateParam]) as ctx:
|
||||
ctx.pass_params(state=fake_state)
|
||||
ctx.should_return(fake_state)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with app.test_dependent(
|
||||
not_legacy_state, allow_types=[StateParam]
|
||||
) as ctx:
|
||||
...
|
||||
|
||||
async with app.test_dependent(
|
||||
command, allow_types=[StateParam, DependParam]
|
||||
) as ctx:
|
||||
|
Loading…
Reference in New Issue
Block a user