From e887c3999866cf84694653b8369c9383c1e884e2 Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 29 Jan 2022 13:56:54 +0800 Subject: [PATCH 1/7] :label: update message typing --- nonebot/adapters/_message.py | 123 ++++++++++++++++++++++------------- 1 file changed, 76 insertions(+), 47 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 57eaebaf..8232fac4 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -24,12 +24,12 @@ from typing import ( from ._template import MessageTemplate T = TypeVar("T") -TMS = TypeVar("TMS", covariant=True) +TMS = TypeVar("TMS", bound="MessageSegment") TM = TypeVar("TM", bound="Message") @dataclass -class MessageSegment(Mapping, abc.ABC, Generic[TM]): +class MessageSegment(abc.ABC, Generic[TM]): """消息段基类""" type: str @@ -54,26 +54,14 @@ class MessageSegment(Mapping, abc.ABC, Generic[TM]): def __ne__(self: T, other: T) -> bool: return not self == other - def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - return self.get_message_class()(self) + other # type: ignore + def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: + return self.get_message_class()(self) + other - def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - return self.get_message_class()(other) + self # type: ignore - - def __getitem__(self, key: str): - return getattr(self, key) - - def __setitem__(self, key: str, value: Any): - return setattr(self, key, value) - - def __iter__(self): - yield from asdict(self).keys() - - def __contains__(self, key: Any) -> bool: - return key in asdict(self).keys() + def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: + return self.get_message_class()(other) + self def get(self, key: str, default: Any = None): - return getattr(self, key, default) + return asdict(self).get(key, default) def keys(self): return asdict(self).keys() @@ -101,18 +89,18 @@ class Message(List[TMS], abc.ABC): """ def __init__( - self: TM, - message: Union[str, None, Mapping, Iterable[Mapping], TMS, TM, Any] = None, - *args, - **kwargs, + self, + message: Union[str, None, Iterable[TMS], TMS, Any] = None, ): - super().__init__(*args, **kwargs) + super().__init__() if message is None: return - elif isinstance(message, Message): - self.extend(message) + elif isinstance(message, str): + self.extend(self._construct(message)) elif isinstance(message, MessageSegment): self.append(message) + elif isinstance(message, Iterable): + self.extend(message) else: self.extend(self._construct(message)) @@ -154,7 +142,7 @@ class Message(List[TMS], abc.ABC): """获取消息段类型""" raise NotImplementedError - def __str__(self): + def __str__(self) -> str: return "".join(str(seg) for seg in self) @classmethod @@ -167,47 +155,79 @@ class Message(List[TMS], abc.ABC): @staticmethod @abc.abstractmethod - def _construct(msg: Union[str, Mapping, Iterable[Mapping], Any]) -> Iterable[TMS]: + def _construct(msg: Union[str, Any]) -> Iterable[TMS]: """构造消息数组""" raise NotImplementedError - def __add__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: + def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: result = self.copy() result += other return result - def __radd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - result = self.__class__(other) # type: ignore + def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + result = self.__class__(other) return result + self - def __iadd__(self: TM, other: Union[str, Mapping, Iterable[Mapping]]) -> TM: - if isinstance(other, MessageSegment): + def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + if isinstance(other, str): + self.extend(self._construct(other)) + elif isinstance(other, MessageSegment): self.append(other) - elif isinstance(other, Message): + elif isinstance(other, Iterable): self.extend(other) else: - self.extend(self._construct(other)) + raise ValueError(f"Unsupported type: {type(other)}") return self @overload def __getitem__(self: TM, __args: str) -> TM: - ... + """ + 参数: + __args: 消息段类型 + + 返回: + 所有类型为 `__args` 的消息段 + """ @overload def __getitem__(self, __args: Tuple[str, int]) -> TMS: - ... + """ + 参数: + __args: 消息段类型和索引 + + 返回: + 类型为 `__args[0]` 的消息段第 `__args[1]` 个 + """ @overload def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM: - ... + """ + 参数: + __args: 消息段类型和切片 + + 返回: + 类型为 `__args[0]` 的消息段切片 `__args[1]` + """ @overload def __getitem__(self, __args: int) -> TMS: - ... + """ + 参数: + __args: 索引 + + 返回: + 第 `__args` 个消息段 + """ @overload def __getitem__(self: TM, __args: slice) -> TM: - ... + """ + 参数: + __args: 切片 + + 返回: + 消息切片 `__args` + """ def __getitem__( self: TM, @@ -235,21 +255,25 @@ class Message(List[TMS], abc.ABC): def index(self, value: Union[TMS, str], *args) -> int: if isinstance(value, str): - first_segment = next((seg for seg in self if seg.type == value), None) # type: ignore - return super().index(first_segment, *args) # type: ignore + first_segment = next((seg for seg in self if seg.type == value), None) + if first_segment is None: + raise ValueError(f"Segment with type {value} is not in message") + return super().index(first_segment, *args) return super().index(value, *args) def get(self: TM, type_: str, count: Optional[int] = None) -> TM: if count is None: return self[type_] - iterator, filtered = (seg for seg in self if seg.type == type_), [] + iterator, filtered = ( + seg for seg in self if seg.type == type_ + ), self.__class__() for _ in range(count): seg = next(iterator, None) if seg is None: break filtered.append(seg) - return self.__class__(filtered) + return filtered def count(self, value: Union[TMS, str]) -> int: return len(self[value]) if isinstance(value, str) else super().count(value) @@ -261,7 +285,7 @@ class Message(List[TMS], abc.ABC): obj: 要添加的消息段 """ if isinstance(obj, MessageSegment): - super(Message, self).append(obj) + super().append(obj) elif isinstance(obj, str): self.extend(self._construct(obj)) else: @@ -281,10 +305,15 @@ class Message(List[TMS], abc.ABC): def copy(self: TM) -> TM: return deepcopy(self) - def extract_plain_text(self: "Message[MessageSegment]") -> str: + def extract_plain_text(self) -> str: """提取消息内纯文本消息""" return "".join(str(seg) for seg in self if seg.is_text()) -__autodoc__ = {"MessageSegment.__str__": True} +__autodoc__ = { + "MessageSegment.__str__": True, + "MessageSegment.__add__": True, + "Message.__getitem__": True, + "Message._construct": True, +} From b43dfb983d8aa1e90ea1c28255c71b3731f076d8 Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 29 Jan 2022 15:36:25 +0800 Subject: [PATCH 2/7] :coffin: remove unused import --- nonebot/adapters/_message.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 8232fac4..3ff6e9b9 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -14,7 +14,6 @@ from typing import ( Tuple, Union, Generic, - Mapping, TypeVar, Iterable, Optional, From 5fa7806a2f349248bb1d7d6e76a1de792f264680 Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 29 Jan 2022 18:20:30 +0800 Subject: [PATCH 3/7] :wheelchair: improve pydantic validate for message --- nonebot/adapters/_message.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 3ff6e9b9..fc8d8c3f 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -20,6 +20,8 @@ from typing import ( overload, ) +from pydantic import parse_obj_as + from ._template import MessageTemplate T = TypeVar("T") @@ -33,7 +35,7 @@ class MessageSegment(abc.ABC, Generic[TM]): type: str """消息段类型""" - data: Dict[str, Any] = field(default_factory=lambda: {}) + data: Dict[str, Any] = field(default_factory=dict) """消息段数据""" @classmethod @@ -59,6 +61,18 @@ class MessageSegment(abc.ABC, Generic[TM]): def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: return self.get_message_class()(other) + self + @classmethod + def __get_validators__(cls): + yield cls._validate + + @classmethod + def _validate(cls, value): + if isinstance(value, cls): + return value + if not isinstance(value, dict): + raise ValueError(f"Expected dict for MessageSegment, got {type(value)}") + return cls(**value) + def get(self, key: str, default: Any = None): return asdict(self).get(key, default) @@ -89,7 +103,7 @@ class Message(List[TMS], abc.ABC): def __init__( self, - message: Union[str, None, Iterable[TMS], TMS, Any] = None, + message: Union[str, None, Iterable[TMS], TMS] = None, ): super().__init__() if message is None: @@ -150,11 +164,21 @@ class Message(List[TMS], abc.ABC): @classmethod def _validate(cls, value): + if isinstance(value, str): + pass + elif isinstance(value, dict): + value = parse_obj_as(cls.get_segment_class(), value) + elif isinstance(value, Iterable): + value = [parse_obj_as(cls.get_segment_class(), v) for v in value] + else: + raise ValueError( + f"Expected str, dict or iterable for Message, got {type(value)}" + ) return cls(value) @staticmethod @abc.abstractmethod - def _construct(msg: Union[str, Any]) -> Iterable[TMS]: + def _construct(msg: str) -> Iterable[TMS]: """构造消息数组""" raise NotImplementedError From 5abf55d0956f26c650ee9d7c8447154820138133 Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 29 Jan 2022 23:39:13 +0800 Subject: [PATCH 4/7] :white_check_mark: add message tests --- tests/test_adapters/test_message.py | 48 +++++++++++++++++++++------- tests/test_adapters/test_template.py | 17 ++++++++++ 2 files changed, 54 insertions(+), 11 deletions(-) create mode 100644 tests/test_adapters/test_template.py diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index 88d9e563..56fbab7c 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -1,23 +1,49 @@ from utils import make_fake_message -def test_message_template(): - from nonebot.adapters import MessageTemplate - +def test_segment_add(): Message = make_fake_message() + MessageSegment = Message.get_segment_class() - template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + assert MessageSegment.text("text") + MessageSegment.text("text") == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) - @template.add_format_spec - def custom(input: str) -> str: - return input + "-custom!" + assert MessageSegment.text("text") + "text" == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) - formatted = template.format(a="test", b="test", c="https://example.com/test") - assert formatted.extract_plain_text() == "test-custom!test" - assert str(formatted) == "test-custom!test[fake:image]" + assert MessageSegment.text("text") + Message( + [MessageSegment.text("text")] + ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + assert "text" + MessageSegment.text("text") == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) -def test_message_slice(): +def test_message_add(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert Message([MessageSegment.text("text")]) + MessageSegment.text( + "text" + ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + assert Message([MessageSegment.text("text")]) + "text" == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) + + assert Message([MessageSegment.text("text")]) + Message( + [MessageSegment.text("text")] + ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + msg = Message([MessageSegment.text("text")]) + msg += MessageSegment.text("text") + assert msg == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + + +def test_message_getitem(): Message = make_fake_message() MessageSegment = Message.get_segment_class() diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py new file mode 100644 index 00000000..3dbef541 --- /dev/null +++ b/tests/test_adapters/test_template.py @@ -0,0 +1,17 @@ +from utils import make_fake_message + + +def test_message_template(): + from nonebot.adapters import MessageTemplate + + Message = make_fake_message() + + template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + + @template.add_format_spec + def custom(input: str) -> str: + return input + "-custom!" + + formatted = template.format(a="test", b="test", c="https://example.com/test") + assert formatted.extract_plain_text() == "test-custom!test" + assert str(formatted) == "test-custom!test[fake:image]" From 2ec5917709947f2420a9028cb38566bd22c71177 Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 29 Jan 2022 23:55:14 +0800 Subject: [PATCH 5/7] :bug: fix missing self instance validate --- nonebot/adapters/_message.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index fc8d8c3f..21c4a46b 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -164,7 +164,9 @@ class Message(List[TMS], abc.ABC): @classmethod def _validate(cls, value): - if isinstance(value, str): + if isinstance(value, cls): + return value + elif isinstance(value, str): pass elif isinstance(value, dict): value = parse_obj_as(cls.get_segment_class(), value) From f3cc93c699f62eeee5a79cf489525b109998043c Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 30 Jan 2022 00:05:01 +0800 Subject: [PATCH 6/7] :white_check_mark: add more tests --- tests/test_adapters/test_message.py | 38 +++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index 56fbab7c..a5d88df0 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -1,3 +1,5 @@ +from pydantic import ValidationError, parse_obj_as + from utils import make_fake_message @@ -22,6 +24,21 @@ def test_segment_add(): ) +def test_segment_validate(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert parse_obj_as( + MessageSegment, {"type": "text", "data": {"text": "text"}} + ) == MessageSegment.text("text") + + try: + parse_obj_as(MessageSegment, "some str") + assert False + except ValidationError: + assert True + + def test_message_add(): Message = make_fake_message() MessageSegment = Message.get_segment_class() @@ -78,3 +95,24 @@ def test_message_getitem(): assert message.get("image", 1) == Message([message["image", 0]]) assert message.count("image") == 2 + + +def test_message_validate(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")]) + + assert parse_obj_as(Message, {"type": "text", "data": {"text": "text"}}) == Message( + [MessageSegment.text("text")] + ) + + assert parse_obj_as( + Message, [{"type": "text", "data": {"text": "text"}}] + ) == Message([MessageSegment.text("text")]) + + try: + parse_obj_as(Message, object()) + assert False + except ValidationError: + assert True From 2cd6867bd1a6f392f485e0a8f0741c38904f8109 Mon Sep 17 00:00:00 2001 From: yanyongyu <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 30 Jan 2022 11:04:02 +0800 Subject: [PATCH 7/7] :white_check_mark: add more tests --- nonebot/adapters/_message.py | 10 ++++--- tests/test_adapters/test_message.py | 44 +++++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/nonebot/adapters/_message.py b/nonebot/adapters/_message.py index 21c4a46b..ffeb2bf5 100644 --- a/nonebot/adapters/_message.py +++ b/nonebot/adapters/_message.py @@ -115,7 +115,7 @@ class Message(List[TMS], abc.ABC): elif isinstance(message, Iterable): self.extend(message) else: - self.extend(self._construct(message)) + self.extend(self._construct(message)) # pragma: no cover @classmethod def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]: @@ -166,6 +166,8 @@ class Message(List[TMS], abc.ABC): def _validate(cls, value): if isinstance(value, cls): return value + elif isinstance(value, Message): + raise ValueError(f"Type {type(value)} can not be converted to {cls}") elif isinstance(value, str): pass elif isinstance(value, dict): @@ -201,7 +203,7 @@ class Message(List[TMS], abc.ABC): elif isinstance(other, Iterable): self.extend(other) else: - raise ValueError(f"Unsupported type: {type(other)}") + raise ValueError(f"Unsupported type: {type(other)}") # pragma: no cover return self @overload @@ -276,7 +278,7 @@ class Message(List[TMS], abc.ABC): elif isinstance(arg1, str) and isinstance(arg2, slice): return self.__class__([seg for seg in self if seg.type == arg1][arg2]) else: - raise ValueError("Incorrect arguments to slice") + raise ValueError("Incorrect arguments to slice") # pragma: no cover def index(self, value: Union[TMS, str], *args) -> int: if isinstance(value, str): @@ -314,7 +316,7 @@ class Message(List[TMS], abc.ABC): elif isinstance(obj, str): self.extend(self._construct(obj)) else: - raise ValueError(f"Unexpected type: {type(obj)} {obj}") + raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover return self def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM: diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index a5d88df0..fd6bd5a4 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -39,6 +39,26 @@ def test_segment_validate(): assert True +def test_segment(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + assert len(MessageSegment.text("text")) == 4 + assert MessageSegment.text("text") != MessageSegment.text("other") + assert MessageSegment.text("text").get("data") == {"text": "text"} + assert list(MessageSegment.text("text").keys()) == ["type", "data"] + assert list(MessageSegment.text("text").values()) == ["text", {"text": "text"}] + assert list(MessageSegment.text("text").items()) == [ + ("type", "text"), + ("data", {"text": "text"}), + ] + + origin = MessageSegment.text("text") + copy = origin.copy() + assert origin is not copy + assert origin == copy + + def test_message_add(): Message = make_fake_message() MessageSegment = Message.get_segment_class() @@ -55,6 +75,10 @@ def test_message_add(): [MessageSegment.text("text")] ) == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + assert "text" + Message([MessageSegment.text("text")]) == Message( + [MessageSegment.text("text"), MessageSegment.text("text")] + ) + msg = Message([MessageSegment.text("text")]) msg += MessageSegment.text("text") assert msg == Message([MessageSegment.text("text"), MessageSegment.text("text")]) @@ -101,15 +125,29 @@ def test_message_validate(): Message = make_fake_message() MessageSegment = Message.get_segment_class() + Message_ = make_fake_message() + + assert parse_obj_as(Message, Message([])) == Message([]) + + try: + parse_obj_as(Message, Message_([])) + assert False + except ValidationError: + assert True + assert parse_obj_as(Message, "text") == Message([MessageSegment.text("text")]) assert parse_obj_as(Message, {"type": "text", "data": {"text": "text"}}) == Message( [MessageSegment.text("text")] ) - assert parse_obj_as( - Message, [{"type": "text", "data": {"text": "text"}}] - ) == Message([MessageSegment.text("text")]) + assert ( + parse_obj_as( + Message, + [MessageSegment.text("text"), {"type": "text", "data": {"text": "text"}}], + ) + == Message([MessageSegment.text("text"), MessageSegment.text("text")]) + ) try: parse_obj_as(Message, object())