251 lines
8.0 KiB
Python
Raw Normal View History

2021-07-07 22:36:08 +08:00
import itertools
2021-07-23 14:46:55 +08:00
import json
from dataclasses import dataclass
2021-07-23 14:46:55 +08:00
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
Union)
2021-07-03 13:56:47 +08:00
2021-07-23 14:46:55 +08:00
from nonebot.adapters import Message as BaseMessage
from nonebot.adapters import MessageSegment as BaseMessageSegment
2021-07-03 13:56:47 +08:00
from nonebot.typing import overrides
2021-07-01 07:59:50 +08:00
2021-07-04 14:19:36 +08:00
class MessageSegment(BaseMessageSegment["Message"]):
    """
    Feishu protocol MessageSegment adapter.

    For the concrete segment kinds, refer to the Feishu message segment types
    or the factory methods defined on this class.
    """

    @classmethod
    @overrides(BaseMessageSegment)
    def get_message_class(cls) -> Type["Message"]:
        return Message

    @property
    def segment_text(self) -> dict:
        """Placeholder text used by ``__str__`` for non-text segment types.

        A fresh dict is built on every access.
        """
        return {
            "image": "[图片]",
            "file": "[文件]",
            "audio": "[音频]",
            "media": "[视频]",
            "sticker": "[表情包]",
            "interactive": "[消息卡片]",
            "hongbao": "[红包]",
            "share_calendar_event": "[日程卡片]",
            "share_chat": "[群名片]",
            "share_user": "[个人名片]",
            "system": "[系统消息]",
            "location": "[位置]",
            "video_chat": "[视频通话]"
        }

    def __str__(self) -> str:
        """Render the segment as human-readable text.

        * "text", "hongbao", "a": the ``text`` field. If it is absent, the
          placeholder from ``segment_text`` is used instead (or "").
        * "at": ``@`` followed by ``user_name``. If no name is present,
          ``user_id`` is used.
        * anything else: the placeholder from ``segment_text``, or "" for
          unknown types.
        """
        if self.type in ("text", "hongbao", "a"):
            text = self.data.get("text")
            if text is not None:
                return str(text)
        elif self.type == "at":
            # Segments built by ``MessageSegment.at`` only carry ``user_id``,
            # so fall back to it instead of raising KeyError.
            name = self.data.get("user_name") or self.data.get("user_id", "")
            return f"@{name}"
        return self.segment_text.get(self.type, "")

    @overrides(BaseMessageSegment)
    def __add__(self, other) -> "Message":
        """Concatenate into a new Message; a plain ``str`` becomes a text segment."""
        return Message(self) + (MessageSegment.text(other)
                                if isinstance(other, str) else other)

    @overrides(BaseMessageSegment)
    def __radd__(self, other) -> "Message":
        """Right-hand concatenation; a plain ``str`` becomes a text segment."""
        return (MessageSegment.text(other)
                if isinstance(other, str) else Message(other)) + self

    @overrides(BaseMessageSegment)
    def is_text(self) -> bool:
        return self.type == "text"

    # Mention segment. Received "post" messages are also deserialized into
    # "at" segments (carrying ``user_id`` and ``user_name``).
    @staticmethod
    def at(user_id: str) -> "MessageSegment":
        return MessageSegment("at", {"user_id": user_id})

    # Factories for segments used when sending messages.
    @staticmethod
    def text(text: str) -> "MessageSegment":
        return MessageSegment("text", {"text": text})

    @staticmethod
    def post(title: str, content: list) -> "MessageSegment":
        return MessageSegment("post", {"title": title, "content": content})

    @staticmethod
    def image(image_key: str) -> "MessageSegment":
        return MessageSegment("image", {"image_key": image_key})

    @staticmethod
    def interactive(title: str, elements: list) -> "MessageSegment":
        return MessageSegment("interactive", {
            "title": title,
            "elements": elements
        })

    @staticmethod
    def share_chat(chat_id: str) -> "MessageSegment":
        return MessageSegment("share_chat", {"chat_id": chat_id})

    @staticmethod
    def share_user(user_id: str) -> "MessageSegment":
        return MessageSegment("share_user", {"user_id": user_id})

    @staticmethod
    def audio(file_key: str, duration: int) -> "MessageSegment":
        return MessageSegment("audio", {
            "file_key": file_key,
            "duration": duration
        })

    @staticmethod
    def media(file_key: str, image_key: str, file_name: str,
              duration: int) -> "MessageSegment":
        return MessageSegment(
            "media", {
                "file_key": file_key,
                "image_key": image_key,
                "file_name": file_name,
                "duration": duration
            })

    @staticmethod
    def file(file_key: str, file_name: str) -> "MessageSegment":
        return MessageSegment("file", {
            "file_key": file_key,
            "file_name": file_name
        })

    @staticmethod
    def sticker(file_key: str) -> "MessageSegment":
        return MessageSegment("sticker", {"file_key": file_key})
2021-07-01 07:59:50 +08:00
2021-07-04 14:19:36 +08:00
class Message(BaseMessage[MessageSegment]):
    """
    Feishu protocol Message adapter.

    A sequence of :class:`MessageSegment`. Plain strings are wrapped as text
    segments when added or constructed.
    """

    @classmethod
    @overrides(BaseMessage)
    def get_segment_class(cls) -> Type[MessageSegment]:
        return MessageSegment

    @overrides(BaseMessage)
    def __add__(self, other: Union[str, Mapping,
                                   Iterable[Mapping]]) -> "Message":
        return super(Message, self).__add__(
            MessageSegment.text(other) if isinstance(other, str) else other)

    @overrides(BaseMessage)
    def __radd__(self, other: Union[str, Mapping,
                                    Iterable[Mapping]]) -> "Message":
        return super(Message, self).__radd__(
            MessageSegment.text(other) if isinstance(other, str) else other)

    @staticmethod
    @overrides(BaseMessage)
    def _construct(
            msg: Union[str, Mapping,
                       Iterable[Mapping]]) -> Iterable[MessageSegment]:
        """
        Yield segments from a raw message value.

        Accepts a segment, a ``str`` (text segment), a mapping with ``type``
        and optional ``data`` keys, or an iterable mixing any of those.
        """
        if isinstance(msg, MessageSegment):
            yield msg
        elif isinstance(msg, str):
            yield MessageSegment.text(msg)
        elif isinstance(msg, Mapping):
            yield MessageSegment(msg["type"], msg.get("data") or {})
        elif isinstance(msg, Iterable):
            for seg in msg:
                if isinstance(seg, MessageSegment):
                    yield seg
                elif isinstance(seg, str):
                    # Previously a TypeError (str indexed by "type").
                    yield MessageSegment.text(seg)
                else:
                    yield MessageSegment(seg["type"], seg.get("data") or {})

    def _merge(self) -> "Message":
        """Return a new Message in which runs of adjacent text segments are joined."""
        merged: List[MessageSegment] = []
        for seg in self:
            if seg.type == "text" and merged and merged[-1].type == "text":
                merged[-1] = MessageSegment(
                    "text", {"text": merged[-1].data["text"] + seg.data["text"]})
            else:
                merged.append(seg)
        return Message(merged)

    @overrides(BaseMessage)
    def extract_plain_text(self) -> str:
        """Concatenate the ``text`` of every text segment."""
        return "".join(seg.data["text"] for seg in self if seg.is_text())
@dataclass
class MessageSerializer:
    """
    Feishu protocol Message serializer.

    Converts a :class:`Message` into a ``(msg_type, content_json)`` pair.
    """
    message: Message

    # Segment types that only exist as tags inside a rich-text "post" and so
    # cannot be sent as a standalone message type.
    _POST_ONLY_TYPES = ("at", "a")

    def serialize(self) -> Tuple[str, str]:
        """
        Serialize the wrapped message.

        A single segment is emitted as its own type, with its ``data`` as the
        JSON content. Multiple segments, or a lone post-only segment such as
        ``at``, are packed into a ``post`` whose paragraphs change whenever the
        stream switches between image and non-image segments.

        Returns:
            ``(msg_type, content_json)``.

        Raises:
            ValueError: if the message contains no segments.
        """
        segments = list(self.message)
        if not segments:
            raise ValueError("cannot serialize an empty message")

        if len(segments) == 1 and segments[0].type not in self._POST_ONLY_TYPES:
            only = segments[0]
            return only.type, json.dumps(only.data)

        content: List[List[Dict[str, Any]]] = []
        last_is_image: Optional[bool] = None
        for segment in segments:
            # Fixed: compare the segment *type*; the old `segment == "image"`
            # compared a MessageSegment to a str and was always False.
            is_image = segment.type == "image"
            # Open a paragraph for the first segment and at every switch
            # between image and non-image segments (consecutive images share
            # a paragraph).
            if is_image != last_is_image:
                content.append([])
            content[-1].append({
                "tag": "img" if is_image else segment.type,
                **segment.data
            })
            last_is_image = is_image

        return "post", json.dumps(
            {"zh_cn": {"title": "", "content": content}})
@dataclass
class MessageDeserializer:
    """
    Feishu protocol Message deserializer.

    Converts a received message (its type, parsed content dict and optional
    ``mentions`` list) into a :class:`Message`.
    """
    type: str
    data: Dict[str, Any]
    mentions: Optional[List[dict]]

    def deserialize(self) -> Message:
        """
        Build a :class:`Message` from the stored payload.

        * ``post``: a non-empty title becomes a leading text segment, and every
          tag in ``content`` becomes a segment (``img`` -> ``image``). ``at``
          tags are resolved through ``mentions``, and adjacent text segments
          are merged.
        * ``text``: mention keys in the text are replaced with ``@name``, and
          the mention table is attached as ``data["mentions"]``.
        * anything else: a single segment of that type with a copy of ``data``.

        ``self.data`` is never modified, so this may be called repeatedly.
        """
        # Index mentions by their placeholder key.
        dict_mention: Dict[str, dict] = {
            mention["key"]: mention for mention in self.mentions or []
        }

        if self.type == "post":
            segments: List[MessageSegment] = []
            title = self.data.get("title")
            if title:
                segments.append(MessageSegment("text", {"text": title}))

            # "content" is a list of paragraphs, each a list of tag dicts.
            for raw in itertools.chain.from_iterable(
                    self.data.get("content") or []):
                seg = dict(raw)  # copy: popping "tag" must not mutate self.data
                tag = seg.pop("tag")
                if tag == "at":
                    mention = dict_mention.get(seg.get("user_id"))
                    if mention is not None:
                        seg["user_name"] = mention["name"]
                        seg["user_id"] = mention["id"]["open_id"]
                    else:
                        # No matching mentions entry (previously a KeyError).
                        # Keep the raw id and use it as the display name.
                        seg.setdefault("user_name", seg.get("user_id", ""))
                segments.append(
                    MessageSegment("image" if tag == "img" else tag, seg))
            # Build once instead of repeated `msg +=` (quadratic copying).
            return Message(segments)._merge()

        data = dict(self.data)
        if self.type == "text":
            # Replace longer keys first so a key that is a prefix of another
            # (e.g. "@_user_1" vs "@_user_10") cannot corrupt it.
            for key, mention in sorted(dict_mention.items(),
                                       key=lambda item: len(item[0]),
                                       reverse=True):
                data["text"] = data["text"].replace(key,
                                                    f"@{mention['name']}")
            data["mentions"] = dict_mention
        return Message(MessageSegment(self.type, data))