nonebot2/nonebot/internal/adapter/message.py

443 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import abc
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import asdict, dataclass, field
from typing import ( # noqa: UP035
Any,
Generic,
Optional,
SupportsIndex,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Self
from nonebot.compat import custom_validation, type_validate_python
from .template import MessageTemplate
# Type variable bound to `MessageSegment`; used for annotations that refer to
# an adapter's concrete segment class.
TMS = TypeVar("TMS", bound="MessageSegment")
# Type variable bound to `Message`; used for annotations that refer to an
# adapter's concrete message class.
TM = TypeVar("TM", bound="Message")
@custom_validation
@dataclass
class MessageSegment(abc.ABC, Generic[TM]):
    """Base class for a single message segment.

    A segment is described by a ``type`` string plus a free-form ``data``
    mapping. Adapters subclass it and link it to their own ``Message`` class
    through :meth:`get_message_class`.
    """

    type: str
    """Segment type."""
    data: dict[str, Any] = field(default_factory=dict)
    """Segment payload."""

    @classmethod
    @abc.abstractmethod
    def get_message_class(cls) -> Type[TM]:  # noqa: UP006
        """Return the message (sequence) class that holds this segment class."""
        raise NotImplementedError

    @abc.abstractmethod
    def __str__(self) -> str:
        """Return the text this segment stands for; used in command matching."""
        raise NotImplementedError

    def __len__(self) -> int:
        # Length is measured on the textual form of the segment.
        text = str(self)
        return len(text)

    def __ne__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, other: Self
    ) -> bool:
        is_equal = self == other
        return not is_equal

    def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        # Promote this segment to a one-element message, then concatenate.
        message_cls = self.get_message_class()
        return message_cls(self) + other

    def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        # Promote the left operand to a message, then append this segment.
        message_cls = self.get_message_class()
        return message_cls(other) + self

    @classmethod
    def __get_validators__(cls):
        yield cls._validate

    @classmethod
    def _validate(cls, value) -> Self:
        """Coerce ``value`` into an instance of ``cls``.

        Instances of ``cls`` pass through unchanged. A dict with a ``type`` key
        (and an optional ``data`` key) is turned into a new segment. Anything
        else, including segments of a different segment class, is rejected
        with ``ValueError``.
        """
        if isinstance(value, cls):
            return value

        if isinstance(value, MessageSegment):
            problem = f"Type {type(value)} can not be converted to {cls}"
        elif not isinstance(value, dict):
            problem = f"Expected dict for MessageSegment, got {type(value)}"
        elif "type" not in value:
            problem = f"Expected dict with 'type' for MessageSegment, got {value}"
        else:
            seg_type = value["type"]
            seg_data = value.get("data", {})
            return cls(type=seg_type, data=seg_data)
        raise ValueError(problem)

    def get(self, key: str, default: Any = None):
        """Dict-style access to the dataclass fields (via ``asdict``)."""
        fields_dict = asdict(self)
        return fields_dict.get(key, default)

    def keys(self):
        """Field names, as produced by ``asdict``."""
        fields_dict = asdict(self)
        return fields_dict.keys()

    def values(self):
        """Field values, as produced by ``asdict``."""
        fields_dict = asdict(self)
        return fields_dict.values()

    def items(self):
        """Field ``(name, value)`` pairs, as produced by ``asdict``."""
        fields_dict = asdict(self)
        return fields_dict.items()

    def join(self: TMS, iterable: Iterable[Union[TMS, TM]]) -> TM:
        """Join ``iterable`` using this segment (as a message) as separator."""
        separator = self.get_message_class()(self)
        return separator.join(iterable)

    def copy(self) -> Self:
        """Return a deep copy of this segment."""
        duplicate = deepcopy(self)
        return duplicate

    @abc.abstractmethod
    def is_text(self) -> bool:
        """Whether this segment is plain text."""
        raise NotImplementedError
@custom_validation
class Message(list[TMS], abc.ABC):
    """Message sequence: a list of message segments.

    Parameters:
        message: initial content. A string is parsed with ``_construct``, a
            single segment is appended, an iterable of segments is extended
            in, and ``None`` yields an empty message.
    """

    def __init__(
        self,
        message: Union[str, None, Iterable[TMS], TMS] = None,
    ):
        super().__init__()
        if message is None:
            return
        elif isinstance(message, str):
            self.extend(self._construct(message))
        elif isinstance(message, MessageSegment):
            self.append(message)
        elif isinstance(message, Iterable):
            self.extend(message)
        else:
            self.extend(self._construct(message))  # pragma: no cover

    @classmethod
    def template(cls, format_string: Union[str, TM]) -> MessageTemplate[Self]:
        """Create a message template.

        Usage is roughly the same as `str.format`, but a `Message` object may
        be used as the template and the output is a message object. Extended
        format specifiers are supported, so messages can be built through the
        `MessageSegment` factory methods of this message type.

        Parameters:
            format_string: the format template

        Returns:
            the message formatter
        """
        return MessageTemplate(format_string, cls)

    @classmethod
    @abc.abstractmethod
    def get_segment_class(cls) -> type[TMS]:
        """Return the segment class used by this message class."""
        raise NotImplementedError

    def __str__(self) -> str:
        return "".join(str(seg) for seg in self)

    @classmethod
    def __get_validators__(cls):
        yield cls._validate

    @classmethod
    def _validate(cls, value) -> Self:
        """Coerce ``value`` into an instance of ``cls``.

        Accepts an instance of ``cls`` unchanged, a string, a single segment
        dict, or an iterable of segment-compatible values. A message of a
        different message class is rejected with ``ValueError``.
        """
        if isinstance(value, cls):
            return value
        elif isinstance(value, Message):
            raise ValueError(f"Type {type(value)} can not be converted to {cls}")
        elif isinstance(value, str):
            pass
        elif isinstance(value, dict):
            value = type_validate_python(cls.get_segment_class(), value)
        elif isinstance(value, Iterable):
            value = [type_validate_python(cls.get_segment_class(), v) for v in value]
        else:
            raise ValueError(
                f"Expected str, dict or iterable for Message, got {type(value)}"
            )
        return cls(value)

    @staticmethod
    @abc.abstractmethod
    def _construct(msg: str) -> Iterable[TMS]:
        """Build message segments from a raw string."""
        raise NotImplementedError

    def __add__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, other: Union[str, TMS, Iterable[TMS]]
    ) -> Self:
        result = self.copy()
        result += other
        return result

    def __radd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
        result = self.__class__(other)
        return result + self

    def __iadd__(self, other: Union[str, TMS, Iterable[TMS]]) -> Self:
        if isinstance(other, str):
            self.extend(self._construct(other))
        elif isinstance(other, MessageSegment):
            self.append(other)
        elif isinstance(other, Iterable):
            self.extend(other)
        else:
            raise TypeError(f"Unsupported type {type(other)!r}")
        return self

    @overload
    def __getitem__(self, args: str) -> Self:
        """Get a message containing only segments of the given type.

        Parameters:
            args: segment type

        Returns:
            all segments whose type is `args`
        """

    @overload
    def __getitem__(self, args: tuple[str, int]) -> TMS:
        """Index into the segments of a given type.

        Parameters:
            args: segment type and index

        Returns:
            the `args[1]`-th segment whose type is `args[0]`
        """

    @overload
    def __getitem__(self, args: tuple[str, slice]) -> Self:
        """Slice the segments of a given type.

        Parameters:
            args: segment type and slice

        Returns:
            slice `args[1]` of the segments whose type is `args[0]`
        """

    @overload
    def __getitem__(self, args: int) -> TMS:
        """Index a segment.

        Parameters:
            args: index

        Returns:
            the `args`-th segment
        """

    @overload
    def __getitem__(self, args: slice) -> Self:
        """Slice the message.

        Parameters:
            args: slice

        Returns:
            the message slice `args`
        """

    def __getitem__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        args: Union[
            str,
            tuple[str, int],
            tuple[str, slice],
            int,
            slice,
        ],
    ) -> Union[TMS, Self]:
        arg1, arg2 = args if isinstance(args, tuple) else (args, None)
        if isinstance(arg1, int) and arg2 is None:
            return super().__getitem__(arg1)
        elif isinstance(arg1, slice) and arg2 is None:
            return self.__class__(super().__getitem__(arg1))
        elif isinstance(arg1, str) and arg2 is None:
            return self.__class__(seg for seg in self if seg.type == arg1)
        elif isinstance(arg1, str) and isinstance(arg2, int):
            return [seg for seg in self if seg.type == arg1][arg2]
        elif isinstance(arg1, str) and isinstance(arg2, slice):
            return self.__class__([seg for seg in self if seg.type == arg1][arg2])
        else:
            raise ValueError("Incorrect arguments to slice")  # pragma: no cover

    def __contains__(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, value: Union[TMS, str]
    ) -> bool:
        """Check whether a segment exists.

        Parameters:
            value: a segment, or a segment type

        Returns:
            whether the message contains the given segment, or any segment of
            the given type
        """
        if isinstance(value, str):
            return any(seg.type == value for seg in self)
        return super().__contains__(value)

    def has(self, value: Union[TMS, str]) -> bool:
        """Same as {ref}``__contains__` <nonebot.adapters.Message.__contains__>``."""
        return value in self

    def index(self, value: Union[TMS, str], *args: SupportsIndex) -> int:
        """Return the index of a segment.

        Parameters:
            value: a segment, or a segment type
            args: optional ``start`` and ``end``, with the same meaning as for
                ``list.index`` (only positions in ``[start:end]`` are searched)

        Returns:
            the index

        Raises:
            ValueError: no matching segment in the searched range
        """
        if isinstance(value, str):
            if len(args) > 2:
                raise TypeError(
                    f"index expected at most 3 arguments, got {len(args) + 1}"
                )
            start = args[0] if len(args) > 0 else None
            stop = args[1] if len(args) > 1 else None
            # Scan only the requested window. Locating the first match in the
            # whole message and then calling list.index(match, start, end)
            # would miss a later segment of this type when the first one lies
            # before ``start``.
            for position in range(*slice(start, stop).indices(len(self))):
                if self[position].type == value:
                    return position
            raise ValueError(f"Segment with type {value!r} is not in message")
        return super().index(value, *args)

    def get(self, type_: str, count: Optional[int] = None) -> Self:
        """Get segments of the given type.

        Parameters:
            type_: segment type
            count: maximum number of segments to take (all when ``None``)

        Returns:
            a newly built message
        """
        if count is None:
            return self[type_]

        iterator, filtered = (
            (seg for seg in self if seg.type == type_),
            self.__class__(),
        )
        for _ in range(count):
            seg = next(iterator, None)
            if seg is None:
                break
            filtered.append(seg)
        return filtered

    def count(self, value: Union[TMS, str]) -> int:
        """Count occurrences of a segment.

        Parameters:
            value: a segment, or a segment type

        Returns:
            the count
        """
        return len(self[value]) if isinstance(value, str) else super().count(value)

    def only(self, value: Union[TMS, str]) -> bool:
        """Check whether the message contains only the given segment.

        Parameters:
            value: a segment, or a segment type

        Returns:
            whether every segment matches (``True`` for an empty message)
        """
        if isinstance(value, str):
            return all(seg.type == value for seg in self)
        return all(seg == value for seg in self)

    def append(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, obj: Union[str, TMS]
    ) -> Self:
        """Append one segment to the end of the message.

        Parameters:
            obj: the segment to add; a string is parsed with ``_construct``
        """
        if isinstance(obj, MessageSegment):
            super().append(obj)
        elif isinstance(obj, str):
            self.extend(self._construct(obj))
        else:
            raise ValueError(f"Unexpected type: {type(obj)} {obj}")  # pragma: no cover
        return self

    def extend(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, obj: Union[Self, Iterable[TMS]]
    ) -> Self:
        """Append a message, or several segments, to the end of the message.

        Parameters:
            obj: the message / segments to add
        """
        for segment in obj:
            self.append(segment)
        return self

    def join(self, iterable: Iterable[Union[TMS, Self]]) -> Self:
        """Concatenate several messages, using this message as the separator.

        Parameters:
            iterable: the messages (or segments) to join

        Returns:
            the joined message; all segments in it are fresh copies
        """
        ret = self.__class__()
        for index, msg in enumerate(iterable):
            if index != 0:
                # Copy the separator as well as the items, so the result never
                # shares segment objects with ``self`` (or across positions).
                ret.extend(self.copy())
            if isinstance(msg, MessageSegment):
                ret.append(msg.copy())
            else:
                ret.extend(msg.copy())
        return ret

    def copy(self) -> Self:
        """Deep-copy the message."""
        return deepcopy(self)

    def include(self, *types: str) -> Self:
        """Filter the message, keeping the given segment types.

        Parameters:
            types: segment types to keep

        Returns:
            a newly built message
        """
        return self.__class__(seg for seg in self if seg.type in types)

    def exclude(self, *types: str) -> Self:
        """Filter the message, dropping the given segment types.

        Parameters:
            types: segment types to drop

        Returns:
            a newly built message
        """
        return self.__class__(seg for seg in self if seg.type not in types)

    def extract_plain_text(self) -> str:
        """Extract the plain-text segments of the message as one string."""
        return "".join(str(seg) for seg in self if seg.is_text())