diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index fa109d7a..f61cbfc3 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -19,6 +19,11 @@ __autodoc__ = { "Event": True, "Adapter": True, "Message": True, + "Message.__getitem__": True, + "Message.__contains__": True, + "Message._construct": True, "MessageSegment": True, + "MessageSegment.__str__": True, + "MessageSegment.__add__": True, "MessageTemplate": True, } diff --git a/nonebot/internal/adapter/message.py b/nonebot/internal/adapter/message.py index 6abb1d48..82e7fdde 100644 --- a/nonebot/internal/adapter/message.py +++ b/nonebot/internal/adapter/message.py @@ -1,5 +1,6 @@ import abc from copy import deepcopy +from typing_extensions import Self from dataclasses import field, asdict, dataclass from typing import ( Any, @@ -12,6 +13,7 @@ from typing import ( TypeVar, Iterable, Optional, + SupportsIndex, overload, ) @@ -19,7 +21,6 @@ from pydantic import parse_obj_as from .template import MessageTemplate -T = TypeVar("T") TMS = TypeVar("TMS", bound="MessageSegment") TM = TypeVar("TM", bound="Message") @@ -47,7 +48,7 @@ class MessageSegment(abc.ABC, Generic[TM]): def __len__(self) -> int: return len(str(self)) - def __ne__(self: T, other: T) -> bool: + def __ne__(self, other: Self) -> bool: return not self == other def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: @@ -61,7 +62,7 @@ class MessageSegment(abc.ABC, Generic[TM]): yield cls._validate @classmethod - def _validate(cls, value): + def _validate(cls, value) -> Self: if isinstance(value, cls): return value if not isinstance(value, dict): @@ -84,7 +85,10 @@ class MessageSegment(abc.ABC, Generic[TM]): def items(self): return asdict(self).items() - def copy(self: T) -> T: + def join(self: TMS, iterable: Iterable[Union[TMS, TM]]) -> TM: + return self.get_message_class()(self).join(iterable) + + def copy(self) -> Self: return deepcopy(self) @abc.abstractmethod @@ -117,7 +121,7 @@ class Message(List[TMS], abc.ABC): self.extend(self._construct(message)) # pragma: no cover @classmethod - def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]: + def template(cls, format_string: Union[str, TM]) -> MessageTemplate[Self]: """创建消息模板。 用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板 @@ -146,7 +150,7 @@ class Message(List[TMS], abc.ABC): yield cls._validate @classmethod - def _validate(cls, value): + def _validate(cls, value) -> Self: if isinstance(value, cls): return value elif isinstance(value, Message): @@ -169,16 +173,16 @@ class Message(List[TMS], abc.ABC): """构造消息数组""" raise NotImplementedError - def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + def __add__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self: result = self.copy() result += other return result - def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + def __radd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self: result = self.__class__(other) return result + self - def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: + def __iadd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self: if isinstance(other, str): self.extend(self._construct(other)) elif isinstance(other, MessageSegment): @@ -190,57 +194,62 @@ class Message(List[TMS], abc.ABC): return self @overload - def __getitem__(self: TM, __args: str) -> TM: - """ + def __getitem__(self, args: str) -> Self: + """获取仅包含指定消息段类型的消息 + 参数: - __args: 消息段类型 + args: 消息段类型 返回: - 所有类型为 `__args` 的消息段 + 所有类型为 `args` 的消息段 """ @overload - def __getitem__(self, __args: Tuple[str, int]) -> TMS: - """ + def __getitem__(self, args: Tuple[str, int]) -> TMS: + """索引指定类型的消息段 + 参数: - __args: 消息段类型和索引 + args: 消息段类型和索引 返回: - 类型为 `__args[0]` 的消息段第 `__args[1]` 个 + 类型为 `args[0]` 的消息段第 `args[1]` 个 """ @overload - def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM: - """ + def __getitem__(self, args: Tuple[str, slice]) -> Self: + """切片指定类型的消息段 + 参数: - __args: 消息段类型和切片 + args: 消息段类型和切片 返回: - 类型为 `__args[0]` 的消息段切片 `__args[1]` + 类型为 `args[0]` 的消息段切片 `args[1]` """ @overload - def __getitem__(self, __args: int) -> TMS: - """ + def __getitem__(self, args: int) -> TMS: + """索引消息段 + 参数: - __args: 索引 + args: 索引 返回: - 第 `__args` 个消息段 + 第 `args` 个消息段 """ @overload - def __getitem__(self: TM, __args: slice) -> TM: - """ + def __getitem__(self, args: slice) -> Self: + """切片消息段 + 参数: - __args: 切片 + args: 切片 返回: - 消息切片 `__args` + 消息切片 `args` """ def __getitem__( - self: TM, + self, args: Union[ str, Tuple[str, int], @@ -248,7 +257,7 @@ class Message(List[TMS], abc.ABC): int, slice, ], - ) -> Union[TMS, TM]: + ) -> Union[TMS, Self]: arg1, arg2 = args if isinstance(args, tuple) else (args, None) if isinstance(arg1, int) and arg2 is None: return super().__getitem__(arg1) @@ -263,15 +272,52 @@ class Message(List[TMS], abc.ABC): else: raise ValueError("Incorrect arguments to slice") # pragma: no cover - def index(self, value: Union[TMS, str], *args) -> int: + def __contains__(self, value: Union[TMS, str]) -> bool: + """检查消息段是否存在 + + 参数: + value: 消息段或消息段类型 + 返回: + 消息内是否存在给定消息段或给定类型的消息段 + """ + if isinstance(value, str): + return bool(next((seg for seg in self if seg.type == value), None)) + return super().__contains__(value) + + def has(self, value: Union[TMS, str]) -> bool: + """与 {ref}``__contains__` ` 相同""" + return value in self + + def index(self, value: Union[TMS, str], *args: SupportsIndex) -> int: + """索引消息段 + + 参数: + value: 消息段或者消息段类型 + arg: start 与 end + + 返回: + 索引 index + + 异常: + ValueError: 消息段不存在 + """ if isinstance(value, str): first_segment = next((seg for seg in self if seg.type == value), None) if first_segment is None: - raise ValueError(f"Segment with type {value} is not in message") + raise ValueError(f"Segment with type {value!r} is not in message") return super().index(first_segment, *args) return super().index(value, *args) - def get(self: TM, type_: str, count: Optional[int] = None) -> TM: + def get(self, type_: str, count: Optional[int] = None) -> Self: + """获取指定类型的消息段 + + 参数: + type_: 消息段类型 + count: 获取个数 + + 返回: + 构建的新消息 + """ if count is None: return self[type_] @@ -286,9 +332,30 @@ class Message(List[TMS], abc.ABC): return filtered def count(self, value: Union[TMS, str]) -> int: + """计算指定消息段的个数 + + 参数: + value: 消息段或消息段类型 + + 返回: + 个数 + """ return len(self[value]) if isinstance(value, str) else super().count(value) - def append(self: TM, obj: Union[str, TMS]) -> TM: + def only(self, value: Union[TMS, str]) -> bool: + """检查消息中是否仅包含指定消息段 + + 参数: + value: 指定消息段或消息段类型 + + 返回: + 是否仅包含指定消息段 + """ + if isinstance(value, str): + return all(seg.type == value for seg in self) + return all(seg == value for seg in self) + + def append(self, obj: Union[str, TMS]) -> Self: """添加一个消息段到消息数组末尾。 参数: @@ -302,7 +369,7 @@ class Message(List[TMS], abc.ABC): raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover return self - def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM: + def extend(self, obj: Union[Self, Iterable[TMS]]) -> Self: """拼接一个消息数组或多个消息段到消息数组末尾。 参数: @@ -312,18 +379,52 @@ class Message(List[TMS], abc.ABC): self.append(segment) return self - def copy(self: TM) -> TM: + def join(self, iterable: Iterable[Union[TMS, Self]]) -> Self: + """将多个消息连接并将自身作为分割 + + 参数: + iterable: 要连接的消息 + + 返回: + 连接后的消息 + """ + ret = self.__class__() + for index, msg in enumerate(iterable): + if index != 0: + ret.extend(self) + if isinstance(msg, MessageSegment): + ret.append(msg.copy()) + else: + ret.extend(msg.copy()) + return ret + + def copy(self) -> Self: + """深拷贝消息""" return deepcopy(self) + def include(self, *types: str) -> Self: + """过滤消息 + + 参数: + types: 包含的消息段类型 + + 返回: + 新构造的消息 + """ + return self.__class__(seg for seg in self if seg.type in types) + + def exclude(self, *types: str) -> Self: + """过滤消息 + + 参数: + types: 不包含的消息段类型 + + 返回: + 新构造的消息 + """ + return self.__class__(seg for seg in self if seg.type not in types) + def extract_plain_text(self) -> str: """提取消息内纯文本消息""" return "".join(str(seg) for seg in self if seg.is_text()) - - -__autodoc__ = { - "MessageSegment.__str__": True, - "MessageSegment.__add__": True, - "Message.__getitem__": True, - "Message._construct": True, -} diff --git a/nonebot/rule.py b/nonebot/rule.py index d241f605..7c9cebde 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -67,7 +67,7 @@ CMD_RESULT = TypedDict( { "command": Optional[Tuple[str, ...]], "raw_command": Optional[str], - "command_arg": Optional[Message[MessageSegment]], + "command_arg": Optional[Message], "command_start": Optional[str], "command_whitespace": Optional[str], }, diff --git a/tests/test_adapters/test_message.py b/tests/test_adapters/test_message.py index 1827d465..6e6bc377 100644 --- a/tests/test_adapters/test_message.py +++ b/tests/test_adapters/test_message.py @@ -41,6 +41,26 @@ def test_segment_validate(): parse_obj_as(MessageSegment, {"data": {}}) +def test_segment_join(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + seg = MessageSegment.text("test") + iterable = [ + MessageSegment.text("first"), + Message([MessageSegment.text("second"), MessageSegment.text("third")]), + ] + + assert seg.join(iterable) == Message( + [ + MessageSegment.text("first"), + MessageSegment.text("test"), + MessageSegment.text("second"), + MessageSegment.text("third"), + ] + ) + + def test_segment(): Message = make_fake_message() MessageSegment = Message.get_segment_class() @@ -146,3 +166,124 @@ def test_message_validate(): with pytest.raises(ValidationError): parse_obj_as(Message, object()) + + +def test_message_contains(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.image("test2"), + MessageSegment.image("test3"), + MessageSegment.text("test4"), + ] + ) + + assert message.has(MessageSegment.text("test")) is True + assert MessageSegment.text("test") in message + assert message.has("image") is True + assert "image" in message + + assert message.has(MessageSegment.text("foo")) is False + assert MessageSegment.text("foo") not in message + assert message.has("foo") is False + assert "foo" not in message + + +def test_message_only(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.text("test2"), + ] + ) + + assert message.only("text") is True + assert message.only(MessageSegment.text("test")) is False + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.image("test2"), + MessageSegment.image("test3"), + MessageSegment.text("test4"), + ] + ) + + assert message.only("text") is False + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.text("test"), + ] + ) + + assert message.only(MessageSegment.text("test")) is True + + +def test_message_join(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + msg = Message([MessageSegment.text("test")]) + iterable = [ + MessageSegment.text("first"), + Message([MessageSegment.text("second"), MessageSegment.text("third")]), + ] + + assert msg.join(iterable) == Message( + [ + MessageSegment.text("first"), + MessageSegment.text("test"), + MessageSegment.text("second"), + MessageSegment.text("third"), + ] + ) + + +def test_message_include(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.image("test2"), + MessageSegment.image("test3"), + MessageSegment.text("test4"), + ] + ) + + assert message.include("text") == Message( + [ + MessageSegment.text("test"), + MessageSegment.text("test4"), + ] + ) + + +def test_message_exclude(): + Message = make_fake_message() + MessageSegment = Message.get_segment_class() + + message = Message( + [ + MessageSegment.text("test"), + MessageSegment.image("test2"), + MessageSegment.image("test3"), + MessageSegment.text("test4"), + ] + ) + + assert message.exclude("image") == Message( + [ + MessageSegment.text("test"), + MessageSegment.text("test4"), + ] + ) diff --git a/tests/utils.py b/tests/utils.py index 9b15557f..5367014a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,7 +13,7 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str: def make_fake_message(): - class FakeMessageSegment(MessageSegment): + class FakeMessageSegment(MessageSegment["FakeMessage"]): @classmethod def get_message_class(cls): return FakeMessage @@ -36,7 +36,7 @@ def make_fake_message(): def is_text(self) -> bool: return self.type == "text" - class FakeMessage(Message): + class FakeMessage(Message[FakeMessageSegment]): @classmethod def get_segment_class(cls): return FakeMessageSegment @@ -50,7 +50,9 @@ def make_fake_message(): yield FakeMessageSegment(**seg) return - def __add__(self, other): + def __add__( + self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]] + ): other = escape_text(other) if isinstance(other, str) else other return super().__add__(other) diff --git a/website/docs/tutorial/message.md b/website/docs/tutorial/message.md index 935c1a4f..33af7a80 100644 --- a/website/docs/tutorial/message.md +++ b/website/docs/tutorial/message.md @@ -120,16 +120,37 @@ Message( ### 遍历 -`Message` 继承自 `List[MessageSegment]` ,因此可以使用 `for` 循环遍历消息段。 +消息序列继承自 `List[MessageSegment]` ,因此可以使用 `for` 循环遍历消息段。 ```python for segment in message: ... ``` -### 索引与切片 +### 比较 -`Message` 对列表的索引与切片进行了增强,在原有列表 int 索引与切片的基础上,支持 `type` 过滤索引与切片。 +消息和消息段都可以使用 `==` 或 `!=` 运算符比较是否相同。 + +```python +MessageSegment.text("text") != MessageSegment.text("foo") + +some_message == Message([MessageSegment.text("text")]) +``` + +### 检查消息段 + +我们可以通过 `in` 运算符或消息序列的 `has` 方法来: + +```python +# 是否存在消息段 +MessageSegment.text("text") in message +# 是否存在指定类型的消息段 +"text" in message +``` + +### 过滤、索引与切片 + +消息序列对列表的索引与切片进行了增强,在原有列表 `int` 索引与 `slice` 切片的基础上,支持 `type` 过滤索引与切片。 ```python from nonebot.adapters.console import Message, MessageSegment @@ -160,7 +181,14 @@ message["markdown", 0:2] == Message( ) ``` -同样的,`Message` 对列表的 `index`、`count` 方法也进行了增强,可以用于索引指定类型的消息段。 +我们也可以通过消息序列的 `include`、`exclude` 方法进行类型过滤。 + +```python +message.include("text", "markdown") +message.exclude("text") +``` + +同样的,消息序列对列表的 `index`、`count` 方法也进行了增强,可以用于索引指定类型的消息段。 ```python # 指定类型首个消息段索引 @@ -169,7 +197,7 @@ message.index("markdown") == 1 message.count("markdown") == 2 ``` -此外,`Message` 添加了一个 `get` 方法,可以用于获取指定类型指定个数的消息段。 +此外,消息序列添加了一个 `get` 方法,可以用于获取指定类型指定个数的消息段。 ```python # 获取指定类型指定个数的消息段 @@ -214,6 +242,31 @@ msg.append(MessageSegment.text("text")) msg.extend([MessageSegment.text("text")]) ``` +我们也可以通过消息段或消息序列的 `join` 方法来拼接一串消息: + +```python +seg = MessageSegment.text("text") +msg = seg.join( + [ + MessageSegment.text("first"), + Message( + [ + MessageSegment.text("second"), + MessageSegment.text("third"), + ] + ) + ] +) +msg == Message( + [ + MessageSegment.text("first"), + MessageSegment.text("text"), + MessageSegment.text("second"), + MessageSegment.text("third"), + ] +) +``` + ### 使用消息模板 为了提供安全可靠的跨平台模板字符, 我们提供了一个消息模板功能来构建消息序列