🏷️ add segment schema

This commit is contained in:
StarHeartHunt 2021-07-03 13:56:47 +08:00
parent 48816308f8
commit 4afd681ca9

View File

@ -1,11 +1,23 @@
from typing import Union, Mapping, Iterable from typing import Type, Union, Mapping, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from nonebot.typing import overrides
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):
"""
飞书 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
"""
@classmethod
@overrides(BaseMessageSegment)
def get_message_class(cls) -> Type["Message"]:
return Message
def __str__(self) -> str: def __str__(self) -> str:
raise NotImplementedError if self.type == "text" or self.type == "hongbao":
return str(self.data["text"])
return ""
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
return Message(self) + other return Message(self) + other
@ -14,13 +26,124 @@ class MessageSegment(BaseMessageSegment):
return Message(other) + self return Message(other) + self
def is_text(self) -> bool: def is_text(self) -> bool:
raise NotImplementedError return self.type == "text"
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": text})
@staticmethod
def post(title: str, content: list) -> "MessageSegment":
return MessageSegment("post", {"title": title, "content": content})
@staticmethod
def image(image_key: str) -> "MessageSegment":
return MessageSegment("image", {"image_key": image_key})
@staticmethod
def file(file_key: str, file_name: str) -> "MessageSegment":
return MessageSegment("file", {
"file_key": file_key,
"file_name": file_name
})
@staticmethod
def audio(file_key: str, duration: int) -> "MessageSegment":
return MessageSegment("audio", {
"file_key": file_key,
"duration": duration
})
@staticmethod
def media(file_key: str, image_key: str, file_name: str,
duration: int) -> "MessageSegment":
return MessageSegment(
"media", {
"file_key": file_key,
"image_key": image_key,
"file_name": file_name,
"duration": duration
})
@staticmethod
def sticker(file_key) -> "MessageSegment":
return MessageSegment("sticker", {"file_key": file_key})
@staticmethod
def interactive(title: str, elements: list) -> "MessageSegment":
return MessageSegment("interactive", {
"title": title,
"elements": elements
})
@staticmethod
def hongbao(text: str) -> "MessageSegment":
return MessageSegment("hongbao", {"text": text})
@staticmethod
def share_calendar_event(summary: str, start_time: str,
end_time: str) -> "MessageSegment":
return MessageSegment("share_calendar_event", {
"summary": summary,
"start_time": start_time,
"end_time": end_time
})
@staticmethod
def share_chat(chat_id: str) -> "MessageSegment":
return MessageSegment("share_chat", {"chat_id": chat_id})
@staticmethod
def share_user(user_id: str) -> "MessageSegment":
return MessageSegment("share_user", {"user_id": user_id})
@staticmethod
def system(template: str, from_user: list,
to_chatters: list) -> "MessageSegment":
return MessageSegment(
"system", {
"template": template,
"from_user": from_user,
"to_chatters": to_chatters
})
@staticmethod
def location(name: str, longitude: str, latitude: str) -> "MessageSegment":
return MessageSegment("location", {
"name": name,
"longitude": longitude,
"latitude": latitude
})
@staticmethod
def video_chat(topic: str, start_time: str) -> "MessageSegment":
return MessageSegment("video_chat", {
"topic": topic,
"start_time": start_time,
})
class Message(BaseMessage): class Message(BaseMessage):
"""
飞书 协议 Message 适配
"""
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@staticmethod @staticmethod
def _construct( def _construct(
msg: Union[str, Mapping, msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]: Iterable[Mapping]]) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
yield MessageSegment(msg["type"], msg.get("data") or {})
elif isinstance(msg, str):
yield MessageSegment.text(msg)
elif isinstance(msg, Iterable):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
def _produce(self) -> dict:
raise NotImplementedError raise NotImplementedError