mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
♿ improve pydantic validate for message
This commit is contained in:
parent
b43dfb983d
commit
5fa7806a2f
@ -20,6 +20,8 @@ from typing import (
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from ._template import MessageTemplate
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -33,7 +35,7 @@ class MessageSegment(abc.ABC, Generic[TM]):
|
||||
|
||||
type: str
|
||||
"""消息段类型"""
|
||||
data: Dict[str, Any] = field(default_factory=lambda: {})
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
"""消息段数据"""
|
||||
|
||||
@classmethod
|
||||
@ -59,6 +61,18 @@ class MessageSegment(abc.ABC, Generic[TM]):
|
||||
def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
|
||||
return self.get_message_class()(other) + self
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls._validate
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, value):
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")
|
||||
return cls(**value)
|
||||
|
||||
def get(self, key: str, default: Any = None):
|
||||
return asdict(self).get(key, default)
|
||||
|
||||
@ -89,7 +103,7 @@ class Message(List[TMS], abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: Union[str, None, Iterable[TMS], TMS, Any] = None,
|
||||
message: Union[str, None, Iterable[TMS], TMS] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if message is None:
|
||||
@ -150,11 +164,21 @@ class Message(List[TMS], abc.ABC):
|
||||
|
||||
@classmethod
|
||||
def _validate(cls, value):
|
||||
if isinstance(value, str):
|
||||
pass
|
||||
elif isinstance(value, dict):
|
||||
value = parse_obj_as(cls.get_segment_class(), value)
|
||||
elif isinstance(value, Iterable):
|
||||
value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected str, dict or iterable for Message, got {type(value)}"
|
||||
)
|
||||
return cls(value)
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def _construct(msg: Union[str, Any]) -> Iterable[TMS]:
|
||||
def _construct(msg: str) -> Iterable[TMS]:
|
||||
"""构造消息数组"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user