diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 57eaebaf..ffeb2bf5 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -14,27 +14,28 @@ from typing import ( Tuple, Union, Generic, - Mapping, TypeVar, Iterable, Optional, overload, ) +from pydantic import parse_obj_as + from ._template import MessageTemplate T = TypeVar("T") -TMS = TypeVar("TMS", covariant=True) +TMS = TypeVar("TMS", bound="MessageSegment") TM = TypeVar("TM", bound="Message") @dataclass -class MessageSegment(Mapping, abc.ABC, Generic[TM]): +class MessageSegment(abc.ABC, Generic[TM]): """消息段基类""" type: str """消息段类型""" - data: Dict[str, Any] = field(default_factory=lambda: {}) + data: Dict[str, Any] = field(default_factory=dict) """消息段数据""" @classmethod @@ -54,26 +55,26 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]): def __ne__(self: T, other: T) -> bool: return not self == other - def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - return self.get_message_class()(self) + other # type: ignore + def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: + return self.get_message_class()(self) + other - def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - return self.get_message_class()(other) + self # type: ignore + def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: + return self.get_message_class()(other) + self - def __getitem__(self, key: str): - return getattr(self, key) + @classmethod + def __get_validators__(cls): + yield cls._validate - def __setitem__(self, key: str, value: Any): - return setattr(self, key, value) - - def __iter__(self): - yield from asdict(self).keys() - - def __contains__(self, key: Any) -> bool: - return key in asdict(self).keys() + @classmethod + def _validate(cls, value): + if isinstance(value, cls): + return value + if not isinstance(value, dict): + raise ValueError(f"Expected dict for MessageSegment, got {type(value)}") + return cls(**value) def get(self, key: str, default: Any = None): - return getattr(self, key, default) + return asdict(self).get(key, default) def keys(self): return asdict(self).keys() @@ -101,20 +102,20 @@ class Message(List[TMS], abc.ABC): """ def __init__( - self: TM, - message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM, Any] = None, - *args, - **kwargs, + self, + message: Union[str, None, Iterable[TMS], TMS] = None, ): - super().__init__(*args, **kwargs) + super().__init__() if message is None: return - elif isinstance(message, Message): - self.extend(message) + elif isinstance(message, str): + self.extend(self._construct(message)) elif isinstance(message, MessageSegment): self.append(message) + elif isinstance(message, Iterable): + self.extend(message) else: - self.extend(self._construct(message)) + self.extend(self._construct(message)) # pragma: no cover @classmethod def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]: @@ -154,7 +155,7 @@ class Message(List[TMS], abc.ABC): """获取消息段类型""" raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return "".join(str(seg) for seg in self) @classmethod @@ -163,51 +164,97 @@ class Message(List[TMS], abc.ABC): @classmethod def _validate(cls, value): + if isinstance(value, cls): + return value + elif isinstance(value, Message): + raise ValueError(f"Type {type(value)} can not be converted to {cls}") + elif isinstance(value, str): + pass + elif isinstance(value, dict): + value = parse_obj_as(cls.get_segment_class(), value) + elif isinstance(value, Iterable): + value = [parse_obj_as(cls.get_segment_class(), v) for v in value] + else: + raise ValueError( + f"Expected str, dict or iterable for Message, got {type(value)}" + ) return cls(value) @staticmethod @abc.abstractmethod - def _construct(msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]: + def _construct(msg: str) -> Iterable[TMS]: """构造消息数组""" raise NotImplementedError - def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: + def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: result = self.copy() result += other return result - def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - result = self.__class__(other) # type: ignore + def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + result = self.__class__(other) return result + self - def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - if isinstance(other, MessageSegment): + def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + if isinstance(other, str): + self.extend(self._construct(other)) + elif isinstance(other, MessageSegment): self.append(other) - elif isinstance(other, Message): + elif isinstance(other, Iterable): self.extend(other) else: - self.extend(self._construct(other)) + raise ValueError(f"Unsupported type: {type(other)}") # pragma: no cover return self @overload def __getitem__(self: TM, __args: str) -> TM: - ... + """ + 参数: + __args: 消息段类型 + + 返回: + 所有类型为 `__args` 的消息段 + """ @overload def __getitem__(self, __args: Tuple[str, int]) -> TMS: - ... + """ + 参数: + __args: 消息段类型和索引 + + 返回: + 类型为 `__args[0]` 的消息段第 `__args[1]` 个 + """ @overload def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM: - ... + """ + 参数: + __args: 消息段类型和切片 + + 返回: + 类型为 `__args[0]` 的消息段切片 `__args[1]` + """ @overload def __getitem__(self, __args: int) -> TMS: - ... + """ + 参数: + __args: 索引 + + 返回: + 第 `__args` 个消息段 + """ @overload def __getitem__(self: TM, __args: slice) -> TM: - ... + """ + 参数: + __args: 切片 + + 返回: + 消息切片 `__args` + """ def __getitem__( self: TM, @@ -231,25 +278,29 @@ class Message(List[TMS], abc.ABC): elif isinstance(arg1, str) and isinstance(arg2, slice): return self.__class__([seg for seg in self if seg.type == arg1][arg2]) else: - raise ValueError("Incorrect arguments to slice") + raise ValueError("Incorrect arguments to slice") # pragma: no cover def index(self, value: Union[TMS, str], *args) -> int: if isinstance(value, str): - first_segment = next((seg for seg in self if seg.type == value), None) # type: ignore - return super().index(first_segment, *args) # type: ignore + first_segment = next((seg for seg in self if seg.type == value), None) + if first_segment is None: + raise ValueError(f"Segment with type {value} is not in message") + return super().index(first_segment, *args) return super().index(value, *args) def get(self: TM, type_: str, count: Optional[int] = None) -> TM: if count is None: return self[type_] - iterator, filtered = (seg for seg in self if seg.type == type_), [] + iterator, filtered = ( + seg for seg in self if seg.type == type_ + ), self.__class__() for _ in range(count): seg = next(iterator, None) if seg is None: break filtered.append(seg) - return self.__class__(filtered) + return filtered def count(self, value: Union[TMS, str]) -> int: return len(self[value]) if isinstance(value, str) else super().count(value) @@ -261,11 +312,11 @@ class Message(List[TMS], abc.ABC): obj: 要添加的消息段 """ if isinstance(obj, MessageSegment): - super(Message, self).append(obj) + super().append(obj) elif isinstance(obj, str): self.extend(self._construct(obj)) else: - raise ValueError(f"Unexpected type: {type(obj)} {obj}") + raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover return self def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM: @@ -281,10 +332,15 @@ class Message(List[TMS], abc.ABC): def copy(self: TM) -> TM: return deepcopy(self) - def extract_plain_text(self: "Message[MessageSegment]") -> str: + def extract_plain_text(self) -> str: """提取消息内纯文本消息""" return "".join(str(seg) for seg in self if seg.is_text()) -__autodoc__ = {"MessageSegment.__str__": True} +__autodoc__ = { + "MessageSegment.__str__": True, + "MessageSegment.__add__": True, + "Message.__getitem__": True, + "Message._construct": True, +} diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index 88d9e563..fd6bd5a4 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -1,23 +1,90 @@ +from pydantic import ValidationError, parse_obj_as + from utils import make_fake_message -def test_message_template(): - from nonebot.adapters import MessageTemplate - +def test_segment_add(): Message = make_fake_message() + MessageSegment = Message.get_segment_class() - template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + assert MessageSegment.text("text") + MessageSegment.text("text") == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) - @template.add_format_spec - def custom(input: str) -> str: - return input + "-custom!" + assert MessageSegment.text("text") + "text" == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) - formatted = template.format(a="test", b="test", c="https://example.com/test") - assert formatted.extract_plain_text() == "test-custom!test" - assert str(formatted) == "test-custom!test[fake:image]" + assert MessageSegment.text("text") + Message( + [MessageSegment.text("text")] + ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + assert "text" + MessageSegment.text("text") == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) -def test_message_slice(): +def test_segment_validate(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert parse_obj_as( + MessageSegment, {"type": "text", "data": {"text": "text"}} + ) == MessageSegment.text("text") + + try: + parse_obj_as(MessageSegment, "some str") + assert False + except ValidationError: + assert True + + +def test_segment(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert len(MessageSegment.text("text")) == 4 + assert MessageSegment.text("text") != MessageSegment.text("other") + assert MessageSegment.text("text").get("data") == {"text": "text"} + assert list(MessageSegment.text("text").keys()) == ["type", "data"] + assert list(MessageSegment.text("text").values()) == ["text", {"text": "text"}] + assert list(MessageSegment.text("text").items()) == [ + ("type", "text"), + ("data", {"text": "text"}), + ] + + origin = MessageSegment.text("text") + copy = origin.copy() + assert origin is not copy + assert origin == copy + + +def test_message_add(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert Message([MessageSegment.text("text")]) + MessageSegment.text( + "text" + ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + assert Message([MessageSegment.text("text")]) + "text" == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) + + assert Message([MessageSegment.text("text")]) + Message( + [MessageSegment.text("text")] + ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + assert "text" + Message([MessageSegment.text("text")]) == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) + + msg = Message([MessageSegment.text("text")]) + msg += MessageSegment.text("text") + assert msg == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + +def test_message_getitem(): Message = make_fake_message() MessageSegment = Message.get_segment_class() @@ -52,3 +119,38 @@ def test_message_slice(): assert message.get("image", 1) == Message([message["image", 0]]) assert message.count("image") == 2 + + +def test_message_validate(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + Message_ = make_fake_message() + + assert parse_obj_as(Message, Message([])) == Message([]) + + try: + parse_obj_as(Message, Message_([])) + assert False + except ValidationError: + assert True + + assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")]) + + assert parse_obj_as(Message, {"type": "text", "data": {"text": "text"}}) == Message( + [MessageSegment.text("text")] + ) + + assert ( + parse_obj_as( + Message, + [MessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}], + ) + == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + ) + + try: + parse_obj_as(Message, object()) + assert False + except ValidationError: + assert True diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py new file mode 100644 index 00000000..3dbef541 --- /dev/null +++ b/tests/test_adapters/test_template.py @@ -0,0 +1,17 @@ +from utils import make_fake_message + + +def test_message_template(): + from nonebot.adapters import MessageTemplate + + Message = make_fake_message() + + template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + + @template.add_format_spec + def custom(input: str) -> str: + return input + "-custom!" + + formatted = template.format(a="test", b="test", c="https://example.com/test") + assert formatted.extract_plain_text() == "test-custom!test" + assert str(formatted) == "test-custom!test[fake:image]"