mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-20 02:08:20 +08:00
✨ implement message constructor
This commit is contained in:
parent
7a8f881b04
commit
f9f1e33262
@ -1,10 +1,13 @@
|
|||||||
from typing import Type, Union, Mapping, Iterable
|
import itertools
|
||||||
|
import json
|
||||||
|
|
||||||
|
from typing import Any, Tuple, Type, Union, Mapping, Iterable
|
||||||
|
|
||||||
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
|
|
||||||
class MessageSegment(BaseMessageSegment):
|
class MessageSegment(BaseMessageSegment["Message"]):
|
||||||
"""
|
"""
|
||||||
飞书 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
飞书 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
||||||
"""
|
"""
|
||||||
@ -19,12 +22,17 @@ class MessageSegment(BaseMessageSegment):
|
|||||||
return str(self.data["text"])
|
return str(self.data["text"])
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
@overrides(BaseMessageSegment)
|
||||||
def __add__(self, other) -> "Message":
|
def __add__(self, other) -> "Message":
|
||||||
return Message(self) + other
|
return Message(self) + (MessageSegment.text(other) if isinstance(
|
||||||
|
other, str) else other)
|
||||||
|
|
||||||
|
@overrides(BaseMessageSegment)
|
||||||
def __radd__(self, other) -> "Message":
|
def __radd__(self, other) -> "Message":
|
||||||
return Message(other) + self
|
return (MessageSegment.text(other)
|
||||||
|
if isinstance(other, str) else Message(other)) + self
|
||||||
|
|
||||||
|
@overrides(BaseMessageSegment)
|
||||||
def is_text(self) -> bool:
|
def is_text(self) -> bool:
|
||||||
return self.type == "text"
|
return self.type == "text"
|
||||||
|
|
||||||
@ -123,7 +131,7 @@ class MessageSegment(BaseMessageSegment):
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseMessage):
|
class Message(BaseMessage[MessageSegment]):
|
||||||
"""
|
"""
|
||||||
飞书 协议 Message 适配。
|
飞书 协议 Message 适配。
|
||||||
"""
|
"""
|
||||||
@ -133,12 +141,38 @@ class Message(BaseMessage):
|
|||||||
def get_segment_class(cls) -> Type[MessageSegment]:
|
def get_segment_class(cls) -> Type[MessageSegment]:
|
||||||
return MessageSegment
|
return MessageSegment
|
||||||
|
|
||||||
|
@overrides(BaseMessage)
|
||||||
|
def __add__(self, other: Union[str, Mapping,
|
||||||
|
Iterable[Mapping]]) -> "Message":
|
||||||
|
return super(Message, self).__add__(
|
||||||
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||||
|
|
||||||
|
@overrides(BaseMessage)
|
||||||
|
def __radd__(self, other: Union[str, Mapping,
|
||||||
|
Iterable[Mapping]]) -> "Message":
|
||||||
|
return super(Message, self).__radd__(
|
||||||
|
MessageSegment.text(other) if isinstance(other, str) else other)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@overrides(BaseMessage)
|
||||||
def _construct(
|
def _construct(
|
||||||
msg: Union[str, Mapping,
|
msg: Union[str, Mapping,
|
||||||
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
||||||
if isinstance(msg, Mapping):
|
if isinstance(msg, Mapping):
|
||||||
yield MessageSegment(msg["type"], msg.get("data") or {})
|
|
||||||
|
def _iter_message(msg: Mapping) -> Iterable[Tuple[str, dict]]:
|
||||||
|
pure_text: str = msg.get("text", "")
|
||||||
|
content: dict = msg.get("content", {})
|
||||||
|
if pure_text and not content:
|
||||||
|
yield "text", {"text": pure_text}
|
||||||
|
elif content and not pure_text:
|
||||||
|
for element in list(itertools.chain(*content)):
|
||||||
|
tag = element.pop("tag")
|
||||||
|
yield tag, element
|
||||||
|
|
||||||
|
for type_, data in _iter_message(msg):
|
||||||
|
yield MessageSegment(type_, data)
|
||||||
|
|
||||||
elif isinstance(msg, str):
|
elif isinstance(msg, str):
|
||||||
yield MessageSegment.text(msg)
|
yield MessageSegment.text(msg)
|
||||||
elif isinstance(msg, Iterable):
|
elif isinstance(msg, Iterable):
|
||||||
@ -147,3 +181,7 @@ class Message(BaseMessage):
|
|||||||
|
|
||||||
def _produce(self) -> dict:
|
def _produce(self) -> dict:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@overrides(BaseMessage)
|
||||||
|
def extract_plain_text(self) -> str:
|
||||||
|
return "".join(seg.data["text"] for seg in self if seg.is_text())
|
||||||
|
Loading…
Reference in New Issue
Block a user