import json import itertools from dataclasses import dataclass from typing import (Any, Dict, List, Type, Tuple, Union, Mapping, Iterable, Optional, cast) from nonebot.typing import overrides from nonebot.adapters import Message as BaseMessage from nonebot.adapters import MessageSegment as BaseMessageSegment class MessageSegment(BaseMessageSegment["Message"]): """ 飞书 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 """ @classmethod @overrides(BaseMessageSegment) def get_message_class(cls) -> Type["Message"]: return Message @property def segment_text(self) -> dict: return { "image": "[图片]", "file": "[文件]", "audio": "[音频]", "media": "[视频]", "sticker": "[表情包]", "interactive": "[消息卡片]", "hongbao": "[红包]", "share_calendar_event": "[日程卡片]", "share_chat": "[群名片]", "share_user": "[个人名片]", "system": "[系统消息]", "location": "[位置]", "video_chat": "[视频通话]" } def __str__(self) -> str: if self.type in ["text", "hongbao", "a"]: return str(self.data["text"]) elif self.type == "at": return str(f"@{self.data['user_name']}") else: return self.segment_text.get(self.type, "") @overrides(BaseMessageSegment) def __add__(self, other) -> "Message": return Message(self) + (MessageSegment.text(other) if isinstance( other, str) else other) @overrides(BaseMessageSegment) def __radd__(self, other) -> "Message": return (MessageSegment.text(other) if isinstance(other, str) else Message(other)) + self @overrides(BaseMessageSegment) def is_text(self) -> bool: return self.type == "text" #接收消息 @staticmethod def at(user_id: str) -> "MessageSegment": return MessageSegment("at", {"user_id": user_id}) #发送消息 @staticmethod def text(text: str) -> "MessageSegment": return MessageSegment("text", {"text": text}) @staticmethod def post(title: str, content: list) -> "MessageSegment": return MessageSegment("post", {"title": title, "content": content}) @staticmethod def image(image_key: str) -> "MessageSegment": return MessageSegment("image", {"image_key": image_key}) @staticmethod def interactive(title: str, elements: list) -> "MessageSegment": return MessageSegment("interactive", { "title": title, "elements": elements }) @staticmethod def share_chat(chat_id: str) -> "MessageSegment": return MessageSegment("share_chat", {"chat_id": chat_id}) @staticmethod def share_user(user_id: str) -> "MessageSegment": return MessageSegment("share_user", {"user_id": user_id}) @staticmethod def audio(file_key: str, duration: int) -> "MessageSegment": return MessageSegment("audio", { "file_key": file_key, "duration": duration }) @staticmethod def media(file_key: str, image_key: str, file_name: str, duration: int) -> "MessageSegment": return MessageSegment( "media", { "file_key": file_key, "image_key": image_key, "file_name": file_name, "duration": duration }) @staticmethod def file(file_key: str, file_name: str) -> "MessageSegment": return MessageSegment("file", { "file_key": file_key, "file_name": file_name }) @staticmethod def sticker(file_key) -> "MessageSegment": return MessageSegment("sticker", {"file_key": file_key}) class Message(BaseMessage[MessageSegment]): """ 飞书 协议 Message 适配。 """ @classmethod @overrides(BaseMessage) def get_segment_class(cls) -> Type[MessageSegment]: return MessageSegment @overrides(BaseMessage) def __add__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message": return super(Message, self).__add__( MessageSegment.text(other) if isinstance(other, str) else other) @overrides(BaseMessage) def __radd__(self, other: Union[str, Mapping, Iterable[Mapping]]) -> "Message": return super(Message, self).__radd__( MessageSegment.text(other) if isinstance(other, str) else other) @staticmethod @overrides(BaseMessage) def _construct( msg: Union[str, Mapping, Iterable[Mapping]]) -> Iterable[MessageSegment]: if isinstance(msg, Mapping): msg = cast(Mapping[str, Any], msg) yield MessageSegment(msg["type"], msg.get("data") or {}) return elif isinstance(msg, str): yield MessageSegment.text(msg) elif isinstance(msg, Iterable): for seg in msg: if isinstance(seg, MessageSegment): yield seg else: yield MessageSegment(seg["type"], seg.get("data") or {}) def _merge(self) -> "Message": i: int seg: MessageSegment msg: List[MessageSegment] = [] for i, seg in enumerate(self): if seg.type == "text" and i != 0 and msg[-1].type == "text": msg[-1] = MessageSegment( "text", {"text": msg[-1].data["text"] + seg.data["text"]}) else: msg.append(seg) return Message(msg) @overrides(BaseMessage) def extract_plain_text(self) -> str: return "".join(seg.data["text"] for seg in self if seg.is_text()) @dataclass class MessageSerializer: """ 飞书 协议 Message 序列化器。 """ message: Message def serialize(self) -> Tuple[str, str]: segments = list(self.message) last_segment_type: str = "" if len(segments) > 1: msg = {"title": "", "content": [[]]} for segment in segments: if segment == "image": if last_segment_type != "image": msg["content"].append([]) else: if last_segment_type == "image": msg["content"].append([]) msg["content"][-1].append({ "tag": segment.type if segment.type != "image" else "img", **segment.data }) last_segment_type = segment.type return "post", json.dumps({"zh_cn": {**msg}}) else: return self.message[0].type, json.dumps(self.message[0].data) @dataclass class MessageDeserializer: """ 飞书 协议 Message 反序列化器。 """ type: str data: Dict[str, Any] mentions: Optional[List[dict]] def deserialize(self) -> Message: dict_mention = {} if self.mentions: for mention in self.mentions: dict_mention[mention["key"]] = mention if self.type == "post": msg = Message() if self.data["title"] != "": msg += MessageSegment("text", {'text': self.data["title"]}) for seg in itertools.chain(*self.data["content"]): tag = seg.pop("tag") if tag == "at": seg["user_name"] = dict_mention[seg["user_id"]]["name"] seg["user_id"] = dict_mention[ seg["user_id"]]["id"]["open_id"] msg += MessageSegment(tag if tag != "img" else "image", seg) return msg._merge() elif self.type == "text": for key, mention in dict_mention.items(): self.data["text"] = self.data["text"].replace( key, f"@{mention['name']}") self.data["mentions"] = dict_mention return Message(MessageSegment(self.type, self.data)) else: return Message(MessageSegment(self.type, self.data))