implement message serializer & deserializer

This commit is contained in:
StarHeartHunt 2021-07-06 18:01:41 +08:00
parent 5635c83bfb
commit 603a63a629
3 changed files with 80 additions and 24 deletions

View File

@ -12,7 +12,7 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
from .config import Config as FeishuConfig
from .event import Event, GroupMessageEvent, PrivateMessageEvent, get_event_model
from .exception import ActionFailed, ApiNotAvailable, NetworkError
from .message import Message, MessageSegment
from .message import Message, MessageSegment, MessageSerializer
from .utils import log, AESCipher
if TYPE_CHECKING:
@ -103,7 +103,6 @@ class Bot(BaseBot):
super().register(driver, config)
cls.feishu_config = FeishuConfig(**config.dict())
#TODO:校验schema 要求为2.0
@classmethod
@overrides(BaseBot)
async def check_permission(
@ -133,7 +132,10 @@ class Bot(BaseBot):
schema = data.get("schema")
if not schema:
return None, HTTPResponse(400, b"Missing `schema` in POST body, only accept event of version 2.0")
return None, HTTPResponse(
400,
b"Missing `schema` in POST body, only accept event of version 2.0"
)
headers = data.get("header")
if headers:
@ -162,6 +164,7 @@ class Bot(BaseBot):
处理事件并转换为 `Event <#class-event>`_
"""
data = json.loads(message)
print("handle_event start")
print(data)
if data.get("type") == "url_verification":
return
@ -230,6 +233,7 @@ class Bot(BaseBot):
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
try:
print("call_api request start")
print(data)
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(
@ -237,7 +241,7 @@ class Bot(BaseBot):
json=data["body"],
params=data["query"],
timeout=self.config.api_timeout)
print("remote server returned.")
print(response.json())
if 200 <= response.status_code < 300:
result = response.json()
@ -285,14 +289,22 @@ class Bot(BaseBot):
else:
raise ValueError(
"Cannot guess `receive_id` and `receive_id_type` to reply!")
if isinstance(message, MessageSegment):
msg_type = message.type
elif isinstance(message, Message):
msg_type = message[0].type
else:
msg_type = "text"
params = {
"query": {
"receive_id_type": receive_id_type
},
"body": {
"receive_id": receive_id,
"content": str(message),
"msg_type": "text" if len(message) == 1 else "content"
"content": MessageSerializer(Message(message)).serialize(),
"msg_type": msg_type
}
}

View File

@ -8,7 +8,7 @@ from pydantic import BaseModel, root_validator, Field
from nonebot.adapters import Event as BaseEvent
from nonebot.typing import overrides
from .message import Message
from .message import Message, MessageDeserializer
class EventHeader(BaseModel):
@ -97,7 +97,9 @@ class EventMessage(BaseModel):
@root_validator(pre=True)
def parse_message(cls, values: dict):
values["content"] = json.loads(values["content"])
values["content"] = MessageDeserializer(
data=json.loads(values["content"]),
type=values["message_type"]).deserialize()
return values

View File

@ -1,6 +1,8 @@
import itertools
from typing import Tuple, Type, Union, Mapping, Iterable
from dataclasses import dataclass
import json
from typing import Any, Dict, List, Tuple, Type, Union, Mapping, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from nonebot.typing import overrides
@ -158,24 +160,15 @@ class Message(BaseMessage[MessageSegment]):
msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping):
def _iter_message(msg: Mapping) -> Iterable[Tuple[str, dict]]:
pure_text: str = msg.get("text", "")
content: dict = msg.get("content", {})
if pure_text and not content:
yield "text", {"text": pure_text}
elif content and not pure_text:
for element in list(itertools.chain(*content)):
tag = element.pop("tag")
yield tag, element
for type_, data in _iter_message(msg):
yield MessageSegment(type_, data)
yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, str):
yield MessageSegment.text(msg)
elif isinstance(msg, Iterable):
for seg in msg:
if isinstance(seg, MessageSegment):
yield seg
else:
yield MessageSegment(seg["type"], seg.get("data") or {})
def _produce(self) -> dict:
@ -184,3 +177,52 @@ class Message(BaseMessage[MessageSegment]):
@overrides(BaseMessage)
def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text())
@dataclass
class MessageSerializer:
"""
飞书 协议 Message 序列化器
"""
message: Message
def serialize(self):
for segment in self.message:
if segment.type == "post":
raise NotImplementedError
else:
return json.dumps(segment.data)
@dataclass
class MessageDeserializer:
"""
飞书 协议 Message 反序列化器
"""
data: Dict[str, Any]
type: str
def deserialize(self):
print(self.type, self.data)
if self.type == "post":
return self._parse_rich_text(self.data)
else:
return Message(MessageSegment(self.type, self.data))
def _parse_rich_text(self, message_data: Dict[str,
Any]) -> List[MessageSegment]:
def _iter_message(
message_data: Dict[str,
Any]) -> Iterable[Tuple[str, Dict[str, Any]]]:
content: dict = message_data.get("content", {})
if content:
for element in list(itertools.chain(*content)):
tag = element.get("tag")
yield tag, element
temp = Message()
for type_, data in _iter_message(message_data):
temp += MessageSegment(type_, data)
return temp