Feature: 为消息类添加 has join include exclude 方法 (#1895)

This commit is contained in:
Ju4tCode 2023-04-04 21:42:01 +08:00 committed by GitHub
parent 20820e72ad
commit 1817102a7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 356 additions and 54 deletions

View File

@ -19,6 +19,11 @@ __autodoc__ = {
"Event": True, "Event": True,
"Adapter": True, "Adapter": True,
"Message": True, "Message": True,
"Message.__getitem__": True,
"Message.__contains__": True,
"Message._construct": True,
"MessageSegment": True, "MessageSegment": True,
"MessageSegment.__str__": True,
"MessageSegment.__add__": True,
"MessageTemplate": True, "MessageTemplate": True,
} }

View File

@ -1,5 +1,6 @@
import abc import abc
from copy import deepcopy from copy import deepcopy
from typing_extensions import Self
from dataclasses import field, asdict, dataclass from dataclasses import field, asdict, dataclass
from typing import ( from typing import (
Any, Any,
@ -12,6 +13,7 @@ from typing import (
TypeVar, TypeVar,
Iterable, Iterable,
Optional, Optional,
SupportsIndex,
overload, overload,
) )
@ -19,7 +21,6 @@ from pydantic import parse_obj_as
from .template import MessageTemplate from .template import MessageTemplate
T = TypeVar("T")
TMS = TypeVar("TMS", bound="MessageSegment") TMS = TypeVar("TMS", bound="MessageSegment")
TM = TypeVar("TM", bound="Message") TM = TypeVar("TM", bound="Message")
@ -47,7 +48,7 @@ class MessageSegment(abc.ABC, Generic[TM]):
def __len__(self) -> int: def __len__(self) -> int:
return len(str(self)) return len(str(self))
def __ne__(self: T, other: T) -> bool: def __ne__(self, other: Self) -> bool:
return not self == other return not self == other
def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM: def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
@ -61,7 +62,7 @@ class MessageSegment(abc.ABC, Generic[TM]):
yield cls._validate yield cls._validate
@classmethod @classmethod
def _validate(cls, value): def _validate(cls, value) -> Self:
if isinstance(value, cls): if isinstance(value, cls):
return value return value
if not isinstance(value, dict): if not isinstance(value, dict):
@ -84,7 +85,10 @@ class MessageSegment(abc.ABC, Generic[TM]):
def items(self): def items(self):
return asdict(self).items() return asdict(self).items()
def copy(self: T) -> T: def join(self: TMS, iterable: Iterable[Union[TMS, TM]]) -> TM:
return self.get_message_class()(self).join(iterable)
def copy(self) -> Self:
return deepcopy(self) return deepcopy(self)
@abc.abstractmethod @abc.abstractmethod
@ -117,7 +121,7 @@ class Message(List[TMS], abc.ABC):
self.extend(self._construct(message)) # pragma: no cover self.extend(self._construct(message)) # pragma: no cover
@classmethod @classmethod
def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]: def template(cls, format_string: Union[str, TM]) -> MessageTemplate[Self]:
"""创建消息模板。 """创建消息模板。
用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板 用法和 `str.format` 大致相同, 但是可以输出消息对象, 并且支持以 `Message` 对象作为消息模板
@ -146,7 +150,7 @@ class Message(List[TMS], abc.ABC):
yield cls._validate yield cls._validate
@classmethod @classmethod
def _validate(cls, value): def _validate(cls, value) -> Self:
if isinstance(value, cls): if isinstance(value, cls):
return value return value
elif isinstance(value, Message): elif isinstance(value, Message):
@ -169,16 +173,16 @@ class Message(List[TMS], abc.ABC):
"""构造消息数组""" """构造消息数组"""
raise NotImplementedError raise NotImplementedError
def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: def __add__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
result = self.copy() result = self.copy()
result += other result += other
return result return result
def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: def __radd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
result = self.__class__(other) result = self.__class__(other)
return result + self return result + self
def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM: def __iadd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
if isinstance(other, str): if isinstance(other, str):
self.extend(self._construct(other)) self.extend(self._construct(other))
elif isinstance(other, MessageSegment): elif isinstance(other, MessageSegment):
@ -190,57 +194,62 @@ class Message(List[TMS], abc.ABC):
return self return self
@overload @overload
def __getitem__(self: TM, __args: str) -> TM: def __getitem__(self, args: str) -> Self:
""" """获取仅包含指定消息段类型的消息
参数: 参数:
__args: 消息段类型 args: 消息段类型
返回: 返回:
所有类型为 `__args` 的消息段 所有类型为 `args` 的消息段
""" """
@overload @overload
def __getitem__(self, __args: Tuple[str, int]) -> TMS: def __getitem__(self, args: Tuple[str, int]) -> TMS:
""" """索引指定类型的消息段
参数: 参数:
__args: 消息段类型和索引 args: 消息段类型和索引
返回: 返回:
类型为 `__args[0]` 的消息段第 `__args[1]` 类型为 `args[0]` 的消息段第 `args[1]`
""" """
@overload @overload
def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM: def __getitem__(self, args: Tuple[str, slice]) -> Self:
""" """切片指定类型的消息段
参数: 参数:
__args: 消息段类型和切片 args: 消息段类型和切片
返回: 返回:
类型为 `__args[0]` 的消息段切片 `__args[1]` 类型为 `args[0]` 的消息段切片 `args[1]`
""" """
@overload @overload
def __getitem__(self, __args: int) -> TMS: def __getitem__(self, args: int) -> TMS:
""" """索引消息段
参数: 参数:
__args: 索引 args: 索引
返回: 返回:
`__args` 个消息段 `args` 个消息段
""" """
@overload @overload
def __getitem__(self: TM, __args: slice) -> TM: def __getitem__(self, args: slice) -> Self:
""" """切片消息段
参数: 参数:
__args: 切片 args: 切片
返回: 返回:
消息切片 `__args` 消息切片 `args`
""" """
def __getitem__( def __getitem__(
self: TM, self,
args: Union[ args: Union[
str, str,
Tuple[str, int], Tuple[str, int],
@ -248,7 +257,7 @@ class Message(List[TMS], abc.ABC):
int, int,
slice, slice,
], ],
) -> Union[TMS, TM]: ) -> Union[TMS, Self]:
arg1, arg2 = args if isinstance(args, tuple) else (args, None) arg1, arg2 = args if isinstance(args, tuple) else (args, None)
if isinstance(arg1, int) and arg2 is None: if isinstance(arg1, int) and arg2 is None:
return super().__getitem__(arg1) return super().__getitem__(arg1)
@ -263,15 +272,52 @@ class Message(List[TMS], abc.ABC):
else: else:
raise ValueError("Incorrect arguments to slice") # pragma: no cover raise ValueError("Incorrect arguments to slice") # pragma: no cover
def index(self, value: Union[TMS, str], *args) -> int: def __contains__(self, value: Union[TMS, str]) -> bool:
"""检查消息段是否存在
参数:
value: 消息段或消息段类型
返回:
消息内是否存在给定消息段或给定类型的消息段
"""
if isinstance(value, str):
return bool(next((seg for seg in self if seg.type == value), None))
return super().__contains__(value)
def has(self, value: Union[TMS, str]) -> bool:
"""{ref}``__contains__` <nonebot.adapters.Message.__contains__>` 相同"""
return value in self
def index(self, value: Union[TMS, str], *args: SupportsIndex) -> int:
"""索引消息段
参数:
value: 消息段或者消息段类型
arg: start end
返回:
索引 index
异常:
ValueError: 消息段不存在
"""
if isinstance(value, str): if isinstance(value, str):
first_segment = next((seg for seg in self if seg.type == value), None) first_segment = next((seg for seg in self if seg.type == value), None)
if first_segment is None: if first_segment is None:
raise ValueError(f"Segment with type {value} is not in message") raise ValueError(f"Segment with type {value!r} is not in message")
return super().index(first_segment, *args) return super().index(first_segment, *args)
return super().index(value, *args) return super().index(value, *args)
def get(self: TM, type_: str, count: Optional[int] = None) -> TM: def get(self, type_: str, count: Optional[int] = None) -> Self:
"""获取指定类型的消息段
参数:
type_: 消息段类型
count: 获取个数
返回:
构建的新消息
"""
if count is None: if count is None:
return self[type_] return self[type_]
@ -286,9 +332,30 @@ class Message(List[TMS], abc.ABC):
return filtered return filtered
def count(self, value: Union[TMS, str]) -> int: def count(self, value: Union[TMS, str]) -> int:
"""计算指定消息段的个数
参数:
value: 消息段或消息段类型
返回:
个数
"""
return len(self[value]) if isinstance(value, str) else super().count(value) return len(self[value]) if isinstance(value, str) else super().count(value)
def append(self: TM, obj: Union[str, TMS]) -> TM: def only(self, value: Union[TMS, str]) -> bool:
"""检查消息中是否仅包含指定消息段
参数:
value: 指定消息段或消息段类型
返回:
是否仅包含指定消息段
"""
if isinstance(value, str):
return all(seg.type == value for seg in self)
return all(seg == value for seg in self)
def append(self, obj: Union[str, TMS]) -> Self:
"""添加一个消息段到消息数组末尾。 """添加一个消息段到消息数组末尾。
参数: 参数:
@ -302,7 +369,7 @@ class Message(List[TMS], abc.ABC):
raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover raise ValueError(f"Unexpected type: {type(obj)} {obj}") # pragma: no cover
return self return self
def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM: def extend(self, obj: Union[Self, Iterable[TMS]]) -> Self:
"""拼接一个消息数组或多个消息段到消息数组末尾。 """拼接一个消息数组或多个消息段到消息数组末尾。
参数: 参数:
@ -312,18 +379,52 @@ class Message(List[TMS], abc.ABC):
self.append(segment) self.append(segment)
return self return self
def copy(self: TM) -> TM: def join(self, iterable: Iterable[Union[TMS, Self]]) -> Self:
"""将多个消息连接并将自身作为分割
参数:
iterable: 要连接的消息
返回:
连接后的消息
"""
ret = self.__class__()
for index, msg in enumerate(iterable):
if index != 0:
ret.extend(self)
if isinstance(msg, MessageSegment):
ret.append(msg.copy())
else:
ret.extend(msg.copy())
return ret
def copy(self) -> Self:
"""深拷贝消息"""
return deepcopy(self) return deepcopy(self)
def include(self, *types: str) -> Self:
"""过滤消息
参数:
types: 包含的消息段类型
返回:
新构造的消息
"""
return self.__class__(seg for seg in self if seg.type in types)
def exclude(self, *types: str) -> Self:
"""过滤消息
参数:
types: 不包含的消息段类型
返回:
新构造的消息
"""
return self.__class__(seg for seg in self if seg.type not in types)
def extract_plain_text(self) -> str: def extract_plain_text(self) -> str:
"""提取消息内纯文本消息""" """提取消息内纯文本消息"""
return "".join(str(seg) for seg in self if seg.is_text()) return "".join(str(seg) for seg in self if seg.is_text())
__autodoc__ = {
"MessageSegment.__str__": True,
"MessageSegment.__add__": True,
"Message.__getitem__": True,
"Message._construct": True,
}

View File

@ -67,7 +67,7 @@ CMD_RESULT = TypedDict(
{ {
"command": Optional[Tuple[str, ...]], "command": Optional[Tuple[str, ...]],
"raw_command": Optional[str], "raw_command": Optional[str],
"command_arg": Optional[Message[MessageSegment]], "command_arg": Optional[Message],
"command_start": Optional[str], "command_start": Optional[str],
"command_whitespace": Optional[str], "command_whitespace": Optional[str],
}, },

View File

@ -41,6 +41,26 @@ def test_segment_validate():
parse_obj_as(MessageSegment, {"data": {}}) parse_obj_as(MessageSegment, {"data": {}})
def test_segment_join():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
seg = MessageSegment.text("test")
iterable = [
MessageSegment.text("first"),
Message([MessageSegment.text("second"), MessageSegment.text("third")]),
]
assert seg.join(iterable) == Message(
[
MessageSegment.text("first"),
MessageSegment.text("test"),
MessageSegment.text("second"),
MessageSegment.text("third"),
]
)
def test_segment(): def test_segment():
Message = make_fake_message() Message = make_fake_message()
MessageSegment = Message.get_segment_class() MessageSegment = Message.get_segment_class()
@ -146,3 +166,124 @@ def test_message_validate():
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
parse_obj_as(Message, object()) parse_obj_as(Message, object())
def test_message_contains():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
]
)
assert message.has(MessageSegment.text("test")) is True
assert MessageSegment.text("test") in message
assert message.has("image") is True
assert "image" in message
assert message.has(MessageSegment.text("foo")) is False
assert MessageSegment.text("foo") not in message
assert message.has("foo") is False
assert "foo" not in message
def test_message_only():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
[
MessageSegment.text("test"),
MessageSegment.text("test2"),
]
)
assert message.only("text") is True
assert message.only(MessageSegment.text("test")) is False
message = Message(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
]
)
assert message.only("text") is False
message = Message(
[
MessageSegment.text("test"),
MessageSegment.text("test"),
]
)
assert message.only(MessageSegment.text("test")) is True
def test_message_join():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
msg = Message([MessageSegment.text("test")])
iterable = [
MessageSegment.text("first"),
Message([MessageSegment.text("second"), MessageSegment.text("third")]),
]
assert msg.join(iterable) == Message(
[
MessageSegment.text("first"),
MessageSegment.text("test"),
MessageSegment.text("second"),
MessageSegment.text("third"),
]
)
def test_message_include():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
]
)
assert message.include("text") == Message(
[
MessageSegment.text("test"),
MessageSegment.text("test4"),
]
)
def test_message_exclude():
Message = make_fake_message()
MessageSegment = Message.get_segment_class()
message = Message(
[
MessageSegment.text("test"),
MessageSegment.image("test2"),
MessageSegment.image("test3"),
MessageSegment.text("test4"),
]
)
assert message.exclude("image") == Message(
[
MessageSegment.text("test"),
MessageSegment.text("test4"),
]
)

View File

@ -13,7 +13,7 @@ def escape_text(s: str, *, escape_comma: bool = True) -> str:
def make_fake_message(): def make_fake_message():
class FakeMessageSegment(MessageSegment): class FakeMessageSegment(MessageSegment["FakeMessage"]):
@classmethod @classmethod
def get_message_class(cls): def get_message_class(cls):
return FakeMessage return FakeMessage
@ -36,7 +36,7 @@ def make_fake_message():
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"
class FakeMessage(Message): class FakeMessage(Message[FakeMessageSegment]):
@classmethod @classmethod
def get_segment_class(cls): def get_segment_class(cls):
return FakeMessageSegment return FakeMessageSegment
@ -50,7 +50,9 @@ def make_fake_message():
yield FakeMessageSegment(**seg) yield FakeMessageSegment(**seg)
return return
def __add__(self, other): def __add__(
self, other: Union[str, FakeMessageSegment, Iterable[FakeMessageSegment]]
):
other = escape_text(other) if isinstance(other, str) else other other = escape_text(other) if isinstance(other, str) else other
return super().__add__(other) return super().__add__(other)

View File

@ -120,16 +120,37 @@ Message(
### 遍历 ### 遍历
`Message` 继承自 `List[MessageSegment]` ,因此可以使用 `for` 循环遍历消息段。 消息序列继承自 `List[MessageSegment]` ,因此可以使用 `for` 循环遍历消息段。
```python ```python
for segment in message: for segment in message:
... ...
``` ```
### 索引与切片 ### 比较
`Message` 对列表的索引与切片进行了增强,在原有列表 int 索引与切片的基础上,支持 `type` 过滤索引与切片。 消息和消息段都可以使用 `==``!=` 运算符比较是否相同。
```python
MessageSegment.text("text") != MessageSegment.text("foo")
some_message == Message([MessageSegment.text("text")])
```
### 检查消息段
我们可以通过 `in` 运算符或消息序列的 `has` 方法来:
```python
# 是否存在消息段
MessageSegment.text("text") in message
# 是否存在指定类型的消息段
"text" in message
```
### 过滤、索引与切片
消息序列对列表的索引与切片进行了增强,在原有列表 `int` 索引与 `slice` 切片的基础上,支持 `type` 过滤索引与切片。
```python ```python
from nonebot.adapters.console import Message, MessageSegment from nonebot.adapters.console import Message, MessageSegment
@ -160,7 +181,14 @@ message["markdown", 0:2] == Message(
) )
``` ```
同样的,`Message` 对列表的 `index`、`count` 方法也进行了增强,可以用于索引指定类型的消息段。 我们也可以通过消息序列的 `include`、`exclude` 方法进行类型过滤。
```python
message.include("text", "markdown")
message.exclude("text")
```
同样的,消息序列对列表的 `index`、`count` 方法也进行了增强,可以用于索引指定类型的消息段。
```python ```python
# 指定类型首个消息段索引 # 指定类型首个消息段索引
@ -169,7 +197,7 @@ message.index("markdown") == 1
message.count("markdown") == 2 message.count("markdown") == 2
``` ```
此外,`Message` 添加了一个 `get` 方法,可以用于获取指定类型指定个数的消息段。 此外,消息序列添加了一个 `get` 方法,可以用于获取指定类型指定个数的消息段。
```python ```python
# 获取指定类型指定个数的消息段 # 获取指定类型指定个数的消息段
@ -214,6 +242,31 @@ msg.append(MessageSegment.text("text"))
msg.extend([MessageSegment.text("text")]) msg.extend([MessageSegment.text("text")])
``` ```
我们也可以通过消息段或消息序列的 `join` 方法来拼接一串消息:
```python
seg = MessageSegment.text("text")
msg = seg.join(
[
MessageSegment.text("first"),
Message(
[
MessageSegment.text("second"),
MessageSegment.text("third"),
]
)
]
)
msg == Message(
[
MessageSegment.text("first"),
MessageSegment.text("text"),
MessageSegment.text("second"),
MessageSegment.text("third"),
]
)
```
### 使用消息模板 ### 使用消息模板
为了提供安全可靠的跨平台模板字符, 我们提供了一个消息模板功能来构建消息序列 为了提供安全可靠的跨平台模板字符, 我们提供了一个消息模板功能来构建消息序列