mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 09:05:04 +08:00
♿ improve pydantic validate for message
This commit is contained in:
parent
b43dfb983d
commit
5fa7806a2f
@ -20,6 +20,8 @@ from typing import (
|
|||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from pydantic import parse_obj_as
|
||||||
|
|
||||||
from ._template import MessageTemplate
|
from ._template import MessageTemplate
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -33,7 +35,7 @@ class MessageSegment(abc.ABC, Generic[TM]):
|
|||||||
|
|
||||||
type: str
|
type: str
|
||||||
"""消息段类型"""
|
"""消息段类型"""
|
||||||
data: Dict[str, Any] = field(default_factory=lambda: {})
|
data: Dict[str, Any] = field(default_factory=dict)
|
||||||
"""消息段数据"""
|
"""消息段数据"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -59,6 +61,18 @@ class MessageSegment(abc.ABC, Generic[TM]):
|
|||||||
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||||
return self.get_message_class()(other) + self
|
return self.get_message_class()(other) + self
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls._validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate(cls, value):
|
||||||
|
if isinstance(value, cls):
|
||||||
|
return value
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
|
||||||
|
return cls(**value)
|
||||||
|
|
||||||
def get(self, key: str, default: Any = None):
|
def get(self, key: str, default: Any = None):
|
||||||
return asdict(self).get(key, default)
|
return asdict(self).get(key, default)
|
||||||
|
|
||||||
@ -89,7 +103,7 @@ class Message(List[TMS], abc.ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message: Union[str, None, Iterable[TMS], TMS, Any] = None,
|
message: Union[str, None, Iterable[TMS], TMS] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if message is None:
|
if message is None:
|
||||||
@ -150,11 +164,21 @@ class Message(List[TMS], abc.ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate(cls, value):
|
def _validate(cls, value):
|
||||||
|
if isinstance(value, str):
|
||||||
|
pass
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
value = parse_obj_as(cls.get_segment_class(), value)
|
||||||
|
elif isinstance(value, Iterable):
|
||||||
|
value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected str, dict or iterable for Message, got {type(value)}"
|
||||||
|
)
|
||||||
return cls(value)
|
return cls(value)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _construct(msg: Union[str, Any]) -> Iterable[TMS]:
|
def _construct(msg: str) -> Iterable[TMS]:
|
||||||
"""构造消息数组"""
|
"""构造消息数组"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user