From f9f1e332624c6b60c67b3e04592820b393a21974 Mon Sep 17 00:00:00 2001 From: StarHeartHunt Date: Sun, 4 Jul 2021 14:19:36 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20implement=20message=20constructor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../nonebot/adapters/feishu/message.py | 50 ++++++++++++++++--- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/packages/nonebot-adapter-feishu/nonebot/adapters/feishu/message.py b/packages/nonebot-adapter-feishu/nonebot/adapters/feishu/message.py index a0722f30..8a40ad1e 100644 --- a/packages/nonebot-adapter-feishu/nonebot/adapters/feishu/message.py +++ b/packages/nonebot-adapter-feishu/nonebot/adapters/feishu/message.py @@ -1,10 +1,13 @@ -from typing import Type, Union, Mapping, Iterable +import itertools +import json + +from typing import Any, Tuple, Type, Union, Mapping, Iterable from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.typing import overrides -class MessageSegment(BaseMessageSegment): +class MessageSegment(BaseMessageSegment["Message"]): """ 飞书 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 """ @@ -19,12 +22,17 @@ class MessageSegment(BaseMessageSegment): return str(self.data["text"]) return "" + @overrides(BaseMessageSegment) def __add__(self, other) -> "Message": - return Message(self) + other + return Message(self) + (MessageSegment.text(other) if isinstance( + other, str) else other) + @overrides(BaseMessageSegment) def __radd__(self, other) -> "Message": - return Message(other) + self + return (MessageSegment.text(other) + if isinstance(other, str) else Message(other)) + self + @overrides(BaseMessageSegment) def is_text(self) -> bool: return self.type == "text" @@ -123,7 +131,7 @@ class MessageSegment(BaseMessageSegment): }) -class Message(BaseMessage): +class Message(BaseMessage[MessageSegment]): """ 飞书 协议 Message 适配。 """ @@ -133,12 +141,38 @@ class Message(BaseMessage): def get_segment_class(cls) -> Type[MessageSegment]: return MessageSegment + @overrides(BaseMessage) + def __add__(self, other: Union[str, Mapping, + Iterable[Mapping]]) -> "Message": + return super(Message, self).__add__( + MessageSegment.text(other) if isinstance(other, str) else other) + + @overrides(BaseMessage) + def __radd__(self, other: Union[str, Mapping, + Iterable[Mapping]]) -> "Message": + return super(Message, self).__radd__( + MessageSegment.text(other) if isinstance(other, str) else other) + @staticmethod + @overrides(BaseMessage) def _construct( msg: Union[str, Mapping, Iterable[Mapping]]) -> Iterable[MessageSegment]: if isinstance(msg, Mapping): - yield MessageSegment(msg["type"], msg.get("data") or {}) + + def _iter_message(msg: Mapping) -> Iterable[Tuple[str, dict]]: + pure_text: str = msg.get("text", "") + content: dict = msg.get("content", {}) + if pure_text and not content: + yield "text", {"text": pure_text} + elif content and not pure_text: + for element in list(itertools.chain(*content)): + tag = element.pop("tag") + yield tag, element + + for type_, data in _iter_message(msg): + yield MessageSegment(type_, data) + elif isinstance(msg, str): yield MessageSegment.text(msg) elif isinstance(msg, Iterable): @@ -147,3 +181,7 @@ class Message(BaseMessage): def _produce(self) -> dict: raise NotImplementedError + + @overrides(BaseMessage) + def extract_plain_text(self) -> str: + return "".join(seg.data["text"] for seg in self if seg.is_text())