nonebot2/nonebot/adapters/_message.py
2022-01-29 18:20:30 +08:00

343 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FrontMatter:
sidebar_position: 4
description: nonebot.adapters._message 模块
"""
import abc
from copy import deepcopy
from dataclasses import field, asdict, dataclass
from typing import (
Any,
Dict,
List,
Type,
Tuple,
Union,
Generic,
TypeVar,
Iterable,
Optional,
overload,
)
from pydantic import parse_obj_as
from ._template import MessageTemplate
# Unconstrained type variable; used below to annotate methods whose result has
# the same type as `self` (e.g. `MessageSegment.copy`).
T = TypeVar("T")
# Type variable bounded by `MessageSegment` (given as a forward reference
# because the class is declared further down in this module).
TMS = TypeVar("TMS", bound="MessageSegment")
# Type variable bounded by `Message` (forward reference for the same reason).
TM = TypeVar("TM", bound="Message")
@dataclass
class MessageSegment(abc.ABC, Generic[TM]):
    """Abstract base class for a single message segment.

    A segment is a `type` tag plus a `data` dictionary. Concrete adapters
    subclass it and supply `get_message_class`, `__str__` and `is_text`.
    """

    type: str
    """Segment type."""
    data: Dict[str, Any] = field(default_factory=dict)
    """Segment data."""

    @classmethod
    @abc.abstractmethod
    def get_message_class(cls) -> Type[TM]:
        """Return the message (segment list) class that pairs with this segment class."""
        raise NotImplementedError

    @abc.abstractmethod
    def __str__(self) -> str:
        """Return the string this segment stands for; used during command matching."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the length of the segment's string form."""
        rendered = str(self)
        return len(rendered)

    def __ne__(self: T, other: T) -> bool:
        """Return the negation of `self == other`."""
        return not (self == other)

    def __add__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        """Wrap this segment in a message and append `other` to it."""
        message = self.get_message_class()(self)
        return message + other

    def __radd__(self: TMS, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        """Build a message from `other` and append this segment to it."""
        message = self.get_message_class()(other)
        return message + self

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables for this class (pydantic custom-type hook)."""
        yield cls._validate

    @classmethod
    def _validate(cls, value):
        """Coerce `value` into an instance of `cls`.

        Instances of `cls` are returned unchanged; a dict is expanded into
        keyword arguments for the constructor.

        Raises:
            ValueError: `value` is neither an instance of `cls` nor a dict.
        """
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            return cls(**value)
        raise ValueError(f"Expected dict for MessageSegment, got {type(value)}")

    def get(self, key: str, default: Any = None):
        """Look up `key` in the `asdict` view of this segment, falling back to `default`."""
        as_mapping = asdict(self)
        return as_mapping.get(key, default)

    def keys(self):
        """Return the keys of the `asdict` view of this segment."""
        as_mapping = asdict(self)
        return as_mapping.keys()

    def values(self):
        """Return the values of the `asdict` view of this segment."""
        as_mapping = asdict(self)
        return as_mapping.values()

    def items(self):
        """Return the (key, value) pairs of the `asdict` view of this segment."""
        as_mapping = asdict(self)
        return as_mapping.items()

    def copy(self: T) -> T:
        """Return a deep copy of this segment."""
        duplicate = deepcopy(self)
        return duplicate

    @abc.abstractmethod
    def is_text(self) -> bool:
        """Return whether this segment is plain text."""
        raise NotImplementedError
class Message(List[TMS], abc.ABC):
    """Message: a list of message segments.

    Parameters:
        message: initial content. `None` gives an empty message, a `str` is
            parsed with `_construct`, a single segment is appended, and any
            other iterable is treated as a sequence of segments.
    """

    def __init__(
        self,
        message: Union[str, None, Iterable[TMS], TMS] = None,
    ):
        super().__init__()
        if message is None:
            return
        elif isinstance(message, str):
            self.extend(self._construct(message))
        elif isinstance(message, MessageSegment):
            self.append(message)
        elif isinstance(message, Iterable):
            self.extend(message)
        else:
            # Any other object is handed to the subclass parser unchanged.
            self.extend(self._construct(message))

    @classmethod
    def template(cls: Type[TM], format_string: Union[str, TM]) -> MessageTemplate[TM]:
        """Create a message template.

        Usage is broadly the same as `str.format`, but the result is a message
        object and a `Message` object may itself serve as the template.
        Extended format specifiers are also available: they build segments
        through the factory methods of the matching `MessageSegment` class.

        Usage:
            ```python
            >>> Message.template("{} {}").format("hello", "world") # basic demo
            Message(MessageSegment(type='text', data={'text': 'hello world'}))
            >>> Message.template("{} {}").format(MessageSegment.image("file///..."), "world") # segments and similar objects are supported
            Message(MessageSegment(type='image', data={'file': 'file///...'}), MessageSegment(type='text', data={'text': 'world'}))
            >>> Message.template( # a Message object can be used as the template
            ...         MessageSegment.text('test {event.user_id}') + MessageSegment.face(233) +
            ...         MessageSegment.text('test {event.message}')).format(event={'user_id':123456, 'message':'hello world'})
            Message(MessageSegment(type='text', data={'text': 'test 123456'}),
                    MessageSegment(type='face', data={'face': 233}),
                    MessageSegment(type='text', data={'text': 'test hello world'}))
            >>> Message.template("{link:image}").format(link='https://...') # extended format specifier
            Message(MessageSegment(type='image', data={'file': 'https://...'}))
            ```

        Parameters:
            format_string: the format string (or template message)

        Returns:
            the message formatter
        """
        return MessageTemplate(format_string, cls)

    @classmethod
    @abc.abstractmethod
    def get_segment_class(cls) -> Type[TMS]:
        """Return the segment class that pairs with this message class."""
        raise NotImplementedError

    def __str__(self) -> str:
        """Concatenate the string form of every segment."""
        return "".join(str(seg) for seg in self)

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables for this class (pydantic custom-type hook)."""
        yield cls._validate

    @classmethod
    def _validate(cls, value):
        """Coerce `value` into an instance of `cls`.

        A `str` is passed straight to the constructor; a dict is parsed into
        one segment; any other iterable is parsed element by element.

        Raises:
            ValueError: `value` is not a str, dict or iterable.
        """
        if isinstance(value, str):
            pass
        elif isinstance(value, dict):
            value = parse_obj_as(cls.get_segment_class(), value)
        elif isinstance(value, Iterable):
            value = [parse_obj_as(cls.get_segment_class(), v) for v in value]
        else:
            raise ValueError(
                f"Expected str, dict or iterable for Message, got {type(value)}"
            )
        return cls(value)

    @staticmethod
    @abc.abstractmethod
    def _construct(msg: str) -> Iterable[TMS]:
        """Build the segments that represent `msg`."""
        raise NotImplementedError

    def __add__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        """Return a new message made of a copy of this one followed by `other`."""
        result = self.copy()
        result += other
        return result

    def __radd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        """Return a new message made of `other` followed by this message."""
        result = self.__class__(other)
        return result + self

    def __iadd__(self: TM, other: Union[str, TMS, Iterable[TMS]]) -> TM:
        """Append `other` in place.

        Raises:
            ValueError: `other` is not a str, a segment or an iterable.
        """
        if isinstance(other, str):
            self.extend(self._construct(other))
        elif isinstance(other, MessageSegment):
            self.append(other)
        elif isinstance(other, Iterable):
            self.extend(other)
        else:
            raise ValueError(f"Unsupported type: {type(other)}")
        return self

    @overload
    def __getitem__(self: TM, __args: str) -> TM:
        """
        Parameters:
            __args: segment type

        Returns:
            every segment whose type is `__args`
        """

    @overload
    def __getitem__(self, __args: Tuple[str, int]) -> TMS:
        """
        Parameters:
            __args: segment type and index

        Returns:
            the `__args[1]`-th segment among those of type `__args[0]`
        """

    @overload
    def __getitem__(self: TM, __args: Tuple[str, slice]) -> TM:
        """
        Parameters:
            __args: segment type and slice

        Returns:
            slice `__args[1]` of the segments of type `__args[0]`
        """

    @overload
    def __getitem__(self, __args: int) -> TMS:
        """
        Parameters:
            __args: index

        Returns:
            the `__args`-th segment
        """

    @overload
    def __getitem__(self: TM, __args: slice) -> TM:
        """
        Parameters:
            __args: slice

        Returns:
            the message slice `__args`
        """

    def __getitem__(
        self: TM,
        args: Union[
            str,
            Tuple[str, int],
            Tuple[str, slice],
            int,
            slice,
        ],
    ) -> Union[TMS, TM]:
        # Normalise to a pair so that every call shape is handled by one chain.
        arg1, arg2 = args if isinstance(args, tuple) else (args, None)
        if isinstance(arg1, int) and arg2 is None:
            return super().__getitem__(arg1)
        elif isinstance(arg1, slice) and arg2 is None:
            return self.__class__(super().__getitem__(arg1))
        elif isinstance(arg1, str) and arg2 is None:
            return self.__class__(seg for seg in self if seg.type == arg1)
        elif isinstance(arg1, str) and isinstance(arg2, int):
            return [seg for seg in self if seg.type == arg1][arg2]
        elif isinstance(arg1, str) and isinstance(arg2, slice):
            return self.__class__([seg for seg in self if seg.type == arg1][arg2])
        else:
            raise ValueError("Incorrect arguments to slice")

    def index(self, value: Union[TMS, str], *args) -> int:
        """Return the position of the first match for `value`.

        Parameters:
            value: a segment (compared by equality, as in `list.index`) or a
                segment type name (compared against each segment's `type`)
            *args: optional `start` and `stop` bounds, as in `list.index`

        Returns:
            the index of the first matching segment inside the bounds

        Raises:
            ValueError: no matching segment lies inside the bounds.
            TypeError: more than two bounds were supplied.
        """
        if not isinstance(value, str):
            return super().index(value, *args)
        if len(args) > 2:
            raise TypeError(
                f"index expected at most 3 arguments, got {len(args) + 1}"
            )
        lower = args[0] if len(args) > 0 else None
        upper = args[1] if len(args) > 1 else None
        # slice.indices resolves negative and out-of-range bounds the same way
        # list.index does, so the type search honours the requested window
        # instead of always matching against the first segment of that type in
        # the whole message.
        lower, upper, _ = slice(lower, upper).indices(len(self))
        for position in range(lower, upper):
            if super().__getitem__(position).type == value:
                return position
        raise ValueError(f"Segment with type {value} is not in message")

    def get(self: TM, type_: str, count: Optional[int] = None) -> TM:
        """Return a message holding segments of type `type_`.

        Parameters:
            type_: segment type to select
            count: maximum number of segments to take; `None` takes all

        Returns:
            a new message with the selected segments, in their original order
        """
        if count is None:
            return self[type_]

        iterator, filtered = (
            seg for seg in self if seg.type == type_
        ), self.__class__()
        for _ in range(count):
            seg = next(iterator, None)
            if seg is None:
                # Fewer than `count` segments of this type exist.
                break
            filtered.append(seg)
        return filtered

    def count(self, value: Union[TMS, str]) -> int:
        """Count segments of type `value` (str) or segments equal to `value`."""
        return len(self[value]) if isinstance(value, str) else super().count(value)

    def append(self: TM, obj: Union[str, TMS]) -> TM:
        """Append one segment to the end of the message.

        Parameters:
            obj: the segment to add; a `str` is parsed with `_construct` first

        Returns:
            this message, to allow chaining

        Raises:
            ValueError: `obj` is neither a segment nor a str.
        """
        if isinstance(obj, MessageSegment):
            super().append(obj)
        elif isinstance(obj, str):
            self.extend(self._construct(obj))
        else:
            raise ValueError(f"Unexpected type: {type(obj)} {obj}")
        return self

    def extend(self: TM, obj: Union[TM, Iterable[TMS]]) -> TM:
        """Append a message, or several segments, to the end of the message.

        Parameters:
            obj: the message (or iterable of segments) to add

        Returns:
            this message, to allow chaining
        """
        for segment in obj:
            self.append(segment)
        return self

    def copy(self: TM) -> TM:
        """Return a deep copy of this message."""
        return deepcopy(self)

    def extract_plain_text(self) -> str:
        """Return the concatenated string form of the plain-text segments."""
        return "".join(str(seg) for seg in self if seg.is_text())
# Maps dotted member names to True for two dunder methods and one private
# method of the classes above.
# NOTE(review): presumably read by the project's documentation generator to
# include members it would otherwise skip; the consumer is not visible in this
# file, so confirm against that tool.
__autodoc__ = {
"MessageSegment.__str__": True,
"MessageSegment.__add__": True,
"Message.__getitem__": True,
"Message._construct": True,
}