🍻 tweak

This commit is contained in:
StarHeartHunt 2021-07-08 10:07:55 +08:00
parent fe43c8d69a
commit 79b6601d12
2 changed files with 35 additions and 71 deletions

View File

@ -98,8 +98,8 @@ class EventMessage(BaseModel):
@root_validator(pre=True) @root_validator(pre=True)
def parse_message(cls, values: dict): def parse_message(cls, values: dict):
values["content"] = MessageDeserializer( values["content"] = MessageDeserializer(
data=json.loads(values["content"]), values["message_type"],
type=values["message_type"]).deserialize() json.loads(values["content"])).deserialize()
return values return values
@ -141,7 +141,7 @@ class MessageEvent(Event):
return ( return (
f"{self.event.message.message_id} from {self.get_user_id()}" f"{self.event.message.message_id} from {self.get_user_id()}"
f"@[{self.event.message.chat_type}:{self.event.message.chat_id}]" f"@[{self.event.message.chat_type}:{self.event.message.chat_id}]"
f" {MessageSerializer(self.get_message()).serialize()[1]}") f" {self.get_message()}")
@overrides(Event) @overrides(Event)
def get_message(self) -> Message: def get_message(self) -> Message:

View File

@ -19,15 +19,10 @@ class MessageSegment(BaseMessageSegment["Message"]):
return Message return Message
def __str__(self) -> str: def __str__(self) -> str:
if self.type == "post": if self.type == "text" or self.type == "hongbao":
return "".join(
str(MessageSegment(seg["tag"], seg))
for seg in itertools.chain(*self.data["content"]))
elif self.type == "text" or self.type == "hongbao":
return str(self.data["text"]) return str(self.data["text"])
elif self.type == "img" or self.type == "image": elif self.type == "image":
return "[图片]" return "[图片]"
return "" return ""
@ -59,12 +54,20 @@ class MessageSegment(BaseMessageSegment["Message"]):
return MessageSegment("image", {"image_key": image_key}) return MessageSegment("image", {"image_key": image_key})
@staticmethod @staticmethod
def file(file_key: str, file_name: str) -> "MessageSegment": def interactive(title: str, elements: list) -> "MessageSegment":
return MessageSegment("file", { return MessageSegment("interactive", {
"file_key": file_key, "title": title,
"file_name": file_name "elements": elements
}) })
@staticmethod
def share_chat(chat_id: str) -> "MessageSegment":
return MessageSegment("share_chat", {"chat_id": chat_id})
@staticmethod
def share_user(user_id: str) -> "MessageSegment":
return MessageSegment("share_user", {"user_id": user_id})
@staticmethod @staticmethod
def audio(file_key: str, duration: int) -> "MessageSegment": def audio(file_key: str, duration: int) -> "MessageSegment":
return MessageSegment("audio", { return MessageSegment("audio", {
@ -83,63 +86,17 @@ class MessageSegment(BaseMessageSegment["Message"]):
"duration": duration "duration": duration
}) })
@staticmethod
def file(file_key: str, file_name: str) -> "MessageSegment":
return MessageSegment("file", {
"file_key": file_key,
"file_name": file_name
})
@staticmethod @staticmethod
def sticker(file_key) -> "MessageSegment": def sticker(file_key) -> "MessageSegment":
return MessageSegment("sticker", {"file_key": file_key}) return MessageSegment("sticker", {"file_key": file_key})
@staticmethod
def interactive(title: str, elements: list) -> "MessageSegment":
return MessageSegment("interactive", {
"title": title,
"elements": elements
})
@staticmethod
def hongbao(text: str) -> "MessageSegment":
return MessageSegment("hongbao", {"text": text})
@staticmethod
def share_calendar_event(summary: str, start_time: str,
end_time: str) -> "MessageSegment":
return MessageSegment("share_calendar_event", {
"summary": summary,
"start_time": start_time,
"end_time": end_time
})
@staticmethod
def share_chat(chat_id: str) -> "MessageSegment":
return MessageSegment("share_chat", {"chat_id": chat_id})
@staticmethod
def share_user(user_id: str) -> "MessageSegment":
return MessageSegment("share_user", {"user_id": user_id})
@staticmethod
def system(template: str, from_user: list,
to_chatters: list) -> "MessageSegment":
return MessageSegment(
"system", {
"template": template,
"from_user": from_user,
"to_chatters": to_chatters
})
@staticmethod
def location(name: str, longitude: str, latitude: str) -> "MessageSegment":
return MessageSegment("location", {
"name": name,
"longitude": longitude,
"latitude": latitude
})
@staticmethod
def video_chat(topic: str, start_time: str) -> "MessageSegment":
return MessageSegment("video_chat", {
"topic": topic,
"start_time": start_time,
})
class Message(BaseMessage[MessageSegment]): class Message(BaseMessage[MessageSegment]):
""" """
@ -180,9 +137,6 @@ class Message(BaseMessage[MessageSegment]):
else: else:
yield MessageSegment(seg["type"], seg.get("data") or {}) yield MessageSegment(seg["type"], seg.get("data") or {})
def _produce(self) -> dict:
raise NotImplementedError
@overrides(BaseMessage) @overrides(BaseMessage)
def extract_plain_text(self) -> str: def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text()) return "".join(seg.data["text"] for seg in self if seg.is_text())
@ -208,4 +162,14 @@ class MessageDeserializer:
data: Dict[str, Any] data: Dict[str, Any]
def deserialize(self) -> Message: def deserialize(self) -> Message:
return Message(MessageSegment(self.type, self.data)) if self.type == "post":
msg = Message()
if self.data["title"] != "":
msg += MessageSegment("text", {'text': self.data["title"]})
for seg in itertools.chain(*self.data["content"]):
tag = seg.pop("tag")
msg += MessageSegment(tag if tag != "img" else "image", seg)
return msg
else:
return Message(MessageSegment(self.type, self.data))