implement message constructor

This commit is contained in:
StarHeartHunt 2021-07-04 14:19:36 +08:00
parent 7a8f881b04
commit f9f1e33262

View File

@ -1,10 +1,13 @@
from typing import Type, Union, Mapping, Iterable import itertools
import json
from typing import Any, Tuple, Type, Union, Mapping, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from nonebot.typing import overrides from nonebot.typing import overrides
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment["Message"]):
""" """
飞书 协议 MessageSegment 适配具体方法参考协议消息段类型或源码 飞书 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
""" """
@ -19,12 +22,17 @@ class MessageSegment(BaseMessageSegment):
return str(self.data["text"]) return str(self.data["text"])
return "" return ""
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
return Message(self) + other return Message(self) + (MessageSegment.text(other) if isinstance(
other, str) else other)
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message": def __radd__(self, other) -> "Message":
return Message(other) + self return (MessageSegment.text(other)
if isinstance(other, str) else Message(other)) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"
@ -123,7 +131,7 @@ class MessageSegment(BaseMessageSegment):
}) })
class Message(BaseMessage): class Message(BaseMessage[MessageSegment]):
""" """
飞书 协议 Message 适配 飞书 协议 Message 适配
""" """
@ -133,12 +141,38 @@ class Message(BaseMessage):
def get_segment_class(cls) -> Type[MessageSegment]: def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment return MessageSegment
@overrides(BaseMessage)
def __add__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
return super(Message, self).__add__(
MessageSegment.text(other) if isinstance(other, str) else other)
@overrides(BaseMessage)
def __radd__(self, other: Union[str, Mapping,
Iterable[Mapping]]) -> "Message":
return super(Message, self).__radd__(
MessageSegment.text(other) if isinstance(other, str) else other)
@staticmethod @staticmethod
@overrides(BaseMessage)
def _construct( def _construct(
msg: Union[str, Mapping, msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]: Iterable[Mapping]]) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping): if isinstance(msg, Mapping):
yield MessageSegment(msg["type"], msg.get("data") or {})
def _iter_message(msg: Mapping) -> Iterable[Tuple[str, dict]]:
pure_text: str = msg.get("text", "")
content: dict = msg.get("content", {})
if pure_text and not content:
yield "text", {"text": pure_text}
elif content and not pure_text:
for element in list(itertools.chain(*content)):
tag = element.pop("tag")
yield tag, element
for type_, data in _iter_message(msg):
yield MessageSegment(type_, data)
elif isinstance(msg, str): elif isinstance(msg, str):
yield MessageSegment.text(msg) yield MessageSegment.text(msg)
elif isinstance(msg, Iterable): elif isinstance(msg, Iterable):
@ -147,3 +181,7 @@ class Message(BaseMessage):
def _produce(self) -> dict: def _produce(self) -> dict:
raise NotImplementedError raise NotImplementedError
@overrides(BaseMessage)
def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text())