diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 3ff6e9b9..fc8d8c3f 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -20,6 +20,8 @@ from typing import ( overload, ) +from pydantic import parse_obj_as + from ._template import MessageTemplate T = TypeVar("T") @@ -33,7 +35,7 @@ class MessageSegment(abc.ABC, Generic[TM]): type: str """消息段类型""" - data: Dict[str, Any] = field(default_factory=lambda: {}) + data: Dict[str, Any] = field(default_factory=dict) """消息段数据""" @classmethod @@ -59,6 +61,18 @@ class MessageSegment(abc.ABC, Generic[TM]): def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: return self.get_message_class()(other) + self + @classmethod + def __get_validators__(cls): + yield cls._validate + + @classmethod + def _validate(cls, value): + if isinstance(value, cls): + return value + if not isinstance(value, dict): + raise ValueError(f"Expected dict for MessageSegment, got {type(value)}") + return cls(**value) + def get(self, key: str, default: Any = None): return asdict(self).get(key, default) @@ -89,7 +103,7 @@ class Message(List[TMS], abc.ABC): def __init__( self, - message: Union[str, None, Iterable[TMS], TMS, Any] = None, + message: Union[str, None, Iterable[TMS], TMS] = None, ): super().__init__() if message is None: @@ -150,11 +164,21 @@ class Message(List[TMS], abc.ABC): @classmethod def _validate(cls, value): + if isinstance(value, str): + pass + elif isinstance(value, dict): + value = parse_obj_as(cls.get_segment_class(), value) + elif isinstance(value, Iterable): + value = [parse_obj_as(cls.get_segment_class(), v) for v in value] + else: + raise ValueError( + f"Expected str, dict or iterable for Message, got {type(value)}" + ) return cls(value) @staticmethod @abc.abstractmethod - def _construct(msg: Union[str, Any]) -> Iterable[TMS]: + def _construct(msg: str) -> Iterable[TMS]: """构造消息数组""" raise NotImplementedError