mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-25 17:45:05 +08:00
251 lines
8.0 KiB
Python
251 lines
8.0 KiB
Python
import itertools
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
|
|
Union)
|
|
|
|
from nonebot.adapters import Message as BaseMessage
|
|
from nonebot.adapters import MessageSegment as BaseMessageSegment
|
|
from nonebot.typing import overrides
|
|
|
|
|
|
class MessageSegment(BaseMessageSegment["Message"]):
|
|
"""
|
|
飞书 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
|
"""
|
|
|
|
@classmethod
|
|
@overrides(BaseMessageSegment)
|
|
def get_message_class(cls) -> Type["Message"]:
|
|
return Message
|
|
|
|
@property
|
|
def segment_text(self) -> dict:
|
|
return {
|
|
"image": "[图片]",
|
|
"file": "[文件]",
|
|
"audio": "[音频]",
|
|
"media": "[视频]",
|
|
"sticker": "[表情包]",
|
|
"interactive": "[消息卡片]",
|
|
"hongbao": "[红包]",
|
|
"share_calendar_event": "[日程卡片]",
|
|
"share_chat": "[群名片]",
|
|
"share_user": "[个人名片]",
|
|
"system": "[系统消息]",
|
|
"location": "[位置]",
|
|
"video_chat": "[视频通话]"
|
|
}
|
|
|
|
def __str__(self) -> str:
|
|
if self.type in ["text", "hongbao", "a"]:
|
|
return str(self.data["text"])
|
|
elif self.type == "at":
|
|
return str(f"@{self.data['user_name']}")
|
|
else:
|
|
return self.segment_text.get(self.type, "")
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def __add__(self, other) -> "Message":
|
|
return Message(self) + (MessageSegment.text(other) if isinstance(
|
|
other, str) else other)
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def __radd__(self, other) -> "Message":
|
|
return (MessageSegment.text(other)
|
|
if isinstance(other, str) else Message(other)) + self
|
|
|
|
@overrides(BaseMessageSegment)
|
|
def is_text(self) -> bool:
|
|
return self.type == "text"
|
|
|
|
#接收消息
|
|
@staticmethod
|
|
def at(user_id: str) -> "MessageSegment":
|
|
return MessageSegment("at", {"user_id": user_id})
|
|
|
|
#发送消息
|
|
@staticmethod
|
|
def text(text: str) -> "MessageSegment":
|
|
return MessageSegment("text", {"text": text})
|
|
|
|
@staticmethod
|
|
def post(title: str, content: list) -> "MessageSegment":
|
|
return MessageSegment("post", {"title": title, "content": content})
|
|
|
|
@staticmethod
|
|
def image(image_key: str) -> "MessageSegment":
|
|
return MessageSegment("image", {"image_key": image_key})
|
|
|
|
@staticmethod
|
|
def interactive(title: str, elements: list) -> "MessageSegment":
|
|
return MessageSegment("interactive", {
|
|
"title": title,
|
|
"elements": elements
|
|
})
|
|
|
|
@staticmethod
|
|
def share_chat(chat_id: str) -> "MessageSegment":
|
|
return MessageSegment("share_chat", {"chat_id": chat_id})
|
|
|
|
@staticmethod
|
|
def share_user(user_id: str) -> "MessageSegment":
|
|
return MessageSegment("share_user", {"user_id": user_id})
|
|
|
|
@staticmethod
|
|
def audio(file_key: str, duration: int) -> "MessageSegment":
|
|
return MessageSegment("audio", {
|
|
"file_key": file_key,
|
|
"duration": duration
|
|
})
|
|
|
|
@staticmethod
|
|
def media(file_key: str, image_key: str, file_name: str,
|
|
duration: int) -> "MessageSegment":
|
|
return MessageSegment(
|
|
"media", {
|
|
"file_key": file_key,
|
|
"image_key": image_key,
|
|
"file_name": file_name,
|
|
"duration": duration
|
|
})
|
|
|
|
@staticmethod
|
|
def file(file_key: str, file_name: str) -> "MessageSegment":
|
|
return MessageSegment("file", {
|
|
"file_key": file_key,
|
|
"file_name": file_name
|
|
})
|
|
|
|
@staticmethod
|
|
def sticker(file_key) -> "MessageSegment":
|
|
return MessageSegment("sticker", {"file_key": file_key})
|
|
|
|
|
|
class Message(BaseMessage[MessageSegment]):
|
|
"""
|
|
飞书 协议 Message 适配。
|
|
"""
|
|
|
|
@classmethod
|
|
@overrides(BaseMessage)
|
|
def get_segment_class(cls) -> Type[MessageSegment]:
|
|
return MessageSegment
|
|
|
|
@overrides(BaseMessage)
|
|
def __add__(self, other: Union[str, Mapping,
|
|
Iterable[Mapping]]) -> "Message":
|
|
return super(Message, self).__add__(
|
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
|
|
|
@overrides(BaseMessage)
|
|
def __radd__(self, other: Union[str, Mapping,
|
|
Iterable[Mapping]]) -> "Message":
|
|
return super(Message, self).__radd__(
|
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
|
|
|
@staticmethod
|
|
@overrides(BaseMessage)
|
|
def _construct(
|
|
msg: Union[str, Mapping,
|
|
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
|
if isinstance(msg, Mapping):
|
|
yield MessageSegment(msg["type"], msg.get("data") or {})
|
|
return
|
|
elif isinstance(msg, str):
|
|
yield MessageSegment.text(msg)
|
|
elif isinstance(msg, Iterable):
|
|
for seg in msg:
|
|
if isinstance(seg, MessageSegment):
|
|
yield seg
|
|
else:
|
|
yield MessageSegment(seg["type"], seg.get("data") or {})
|
|
|
|
def _merge(self) -> "Message":
|
|
i: int
|
|
seg: MessageSegment
|
|
msg: List[MessageSegment] = []
|
|
for i, seg in enumerate(self):
|
|
if seg.type == "text" and i != 0 and msg[-1].type == "text":
|
|
msg[-1] = MessageSegment(
|
|
"text", {"text": msg[-1].data["text"] + seg.data["text"]})
|
|
else:
|
|
msg.append(seg)
|
|
return Message(msg)
|
|
|
|
@overrides(BaseMessage)
|
|
def extract_plain_text(self) -> str:
|
|
return "".join(seg.data["text"] for seg in self if seg.is_text())
|
|
|
|
|
|
@dataclass
|
|
class MessageSerializer:
|
|
"""
|
|
飞书 协议 Message 序列化器。
|
|
"""
|
|
message: Message
|
|
|
|
def serialize(self) -> Tuple[str, str]:
|
|
segments = list(self.message)
|
|
last_segment_type: str = ""
|
|
if len(segments) > 1:
|
|
msg = {"title": "", "content": [[]]}
|
|
for segment in segments:
|
|
if segment == "image":
|
|
if last_segment_type != "image":
|
|
msg["content"].append([])
|
|
else:
|
|
if last_segment_type == "image":
|
|
msg["content"].append([])
|
|
msg["content"][-1].append({
|
|
"tag": segment.type if segment.type != "image" else "img",
|
|
**segment.data
|
|
})
|
|
last_segment_type = segment.type
|
|
return "post", json.dumps({"zh_cn": {**msg}})
|
|
|
|
else:
|
|
return self.message[0].type, json.dumps(self.message[0].data)
|
|
|
|
|
|
@dataclass
|
|
class MessageDeserializer:
|
|
"""
|
|
飞书 协议 Message 反序列化器。
|
|
"""
|
|
type: str
|
|
data: Dict[str, Any]
|
|
mentions: Optional[List[dict]]
|
|
|
|
def deserialize(self) -> Message:
|
|
dict_mention = {}
|
|
if self.mentions:
|
|
for mention in self.mentions:
|
|
dict_mention[mention["key"]] = mention
|
|
|
|
if self.type == "post":
|
|
msg = Message()
|
|
if self.data["title"] != "":
|
|
msg += MessageSegment("text", {'text': self.data["title"]})
|
|
|
|
for seg in itertools.chain(*self.data["content"]):
|
|
tag = seg.pop("tag")
|
|
if tag == "at":
|
|
seg["user_name"] = dict_mention[seg["user_id"]]["name"]
|
|
seg["user_id"] = dict_mention[
|
|
seg["user_id"]]["id"]["open_id"]
|
|
|
|
msg += MessageSegment(tag if tag != "img" else "image", seg)
|
|
|
|
return msg._merge()
|
|
elif self.type == "text":
|
|
for key, mention in dict_mention.items():
|
|
self.data["text"] = self.data["text"].replace(
|
|
key, f"@{mention['name']}")
|
|
self.data["mentions"] = dict_mention
|
|
|
|
return Message(MessageSegment(self.type, self.data))
|
|
|
|
else:
|
|
return Message(MessageSegment(self.type, self.data))
|