improve pydantic validate for message

This commit is contained in:
yanyongyu 2022-01-29 18:20:30 +08:00
parent b43dfb983d
commit 5fa7806a2f
No known key found for this signature in database
GPG Key ID: 796D8A7FB73396EB

View File

@ -20,6 +20,8 @@ from typing import (
overload, overload,
) )
from pydantic import parse_obj_as
from ._template import MessageTemplate from ._template import MessageTemplate
T = TypeVar("T") T = TypeVar("T")
@ -33,7 +35,7 @@ class MessageSegment(abc.ABC, Generic[TM]):
type: str type: str
"""消息段类型""" """消息段类型"""
data: Dict[str, Any] = field(default_factory=lambda: {}) data: Dict[str, Any] = field(default_factory=dict)
"""消息段数据""" """消息段数据"""
@classmethod @classmethod
@ -59,6 +61,18 @@ class MessageSegment(abc.ABC, Generic[TM]):
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
return self.get_message_class()(other) + self return self.get_message_class()(other) + self
@classmethod
def __get_validators__(cls):
yield cls._validate
@classmethod
def _validate(cls, value):
if isinstance(value, cls):
return value
if not isinstance(value, dict):
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
return cls(**value)
def get(self, key: str, default: Any = None): def get(self, key: str, default: Any = None):
return asdict(self).get(key, default) return asdict(self).get(key, default)
@ -89,7 +103,7 @@ class Message(List[TMS], abc.ABC):
def __init__( def __init__(
self, self,
message: Union[str, None, Iterable[TMS], TMS, Any] = None, message: Union[str, None, Iterable[TMS], TMS] = None,
): ):
super().__init__() super().__init__()
if message is None: if message is None:
@ -150,11 +164,21 @@ class Message(List[TMS], abc.ABC):
@classmethod @classmethod
def _validate(cls, value): def _validate(cls, value):
if isinstance(value, str):
pass
elif isinstance(value, dict):
value = parse_obj_as(cls.get_segment_class(), value)
elif isinstance(value, Iterable):
value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
else:
raise ValueError(
f"Expected str, dict or iterable for Message, got {type(value)}"
)
return cls(value) return cls(value)
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _construct(msg: Union[str, Any]) -> Iterable[TMS]: def _construct(msg: str) -> Iterable[TMS]:
"""构造消息数组""" """构造消息数组"""
raise NotImplementedError raise NotImplementedError