implement message serializer & deserializer

This commit is contained in:
StarHeartHunt 2021-07-06 18:01:41 +08:00
parent 5635c83bfb
commit 603a63a629
3 changed files with 80 additions and 24 deletions

View File

@ -12,7 +12,7 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
from .config import Config as FeishuConfig from .config import Config as FeishuConfig
from .event import Event, GroupMessageEvent, PrivateMessageEvent, get_event_model from .event import Event, GroupMessageEvent, PrivateMessageEvent, get_event_model
from .exception import ActionFailed, ApiNotAvailable, NetworkError from .exception import ActionFailed, ApiNotAvailable, NetworkError
from .message import Message, MessageSegment from .message import Message, MessageSegment, MessageSerializer
from .utils import log, AESCipher from .utils import log, AESCipher
if TYPE_CHECKING: if TYPE_CHECKING:
@ -103,7 +103,6 @@ class Bot(BaseBot):
super().register(driver, config) super().register(driver, config)
cls.feishu_config = FeishuConfig(**config.dict()) cls.feishu_config = FeishuConfig(**config.dict())
#TODO:校验schema 要求为2.0
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission( async def check_permission(
@ -133,7 +132,10 @@ class Bot(BaseBot):
schema = data.get("schema") schema = data.get("schema")
if not schema: if not schema:
return None, HTTPResponse(400, b"Missing `schema` in POST body, only accept event of version 2.0") return None, HTTPResponse(
400,
b"Missing `schema` in POST body, only accept event of version 2.0"
)
headers = data.get("header") headers = data.get("header")
if headers: if headers:
@ -162,6 +164,7 @@ class Bot(BaseBot):
处理事件并转换为 `Event <#class-event>`_ 处理事件并转换为 `Event <#class-event>`_
""" """
data = json.loads(message) data = json.loads(message)
print("handle_event start")
print(data) print(data)
if data.get("type") == "url_verification": if data.get("type") == "url_verification":
return return
@ -230,6 +233,7 @@ class Bot(BaseBot):
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token "Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
try: try:
print("call_api request start")
print(data) print(data)
async with httpx.AsyncClient(headers=headers) as client: async with httpx.AsyncClient(headers=headers) as client:
response = await client.post( response = await client.post(
@ -237,7 +241,7 @@ class Bot(BaseBot):
json=data["body"], json=data["body"],
params=data["query"], params=data["query"],
timeout=self.config.api_timeout) timeout=self.config.api_timeout)
print("remote server returned.")
print(response.json()) print(response.json())
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
result = response.json() result = response.json()
@ -285,14 +289,22 @@ class Bot(BaseBot):
else: else:
raise ValueError( raise ValueError(
"Cannot guess `receive_id` and `receive_id_type` to reply!") "Cannot guess `receive_id` and `receive_id_type` to reply!")
if isinstance(message, MessageSegment):
msg_type = message.type
elif isinstance(message, Message):
msg_type = message[0].type
else:
msg_type = "text"
params = { params = {
"query": { "query": {
"receive_id_type": receive_id_type "receive_id_type": receive_id_type
}, },
"body": { "body": {
"receive_id": receive_id, "receive_id": receive_id,
"content": str(message), "content": MessageSerializer(Message(message)).serialize(),
"msg_type": "text" if len(message) == 1 else "content" "msg_type": msg_type
} }
} }

View File

@ -8,7 +8,7 @@ from pydantic import BaseModel, root_validator, Field
from nonebot.adapters import Event as BaseEvent from nonebot.adapters import Event as BaseEvent
from nonebot.typing import overrides from nonebot.typing import overrides
from .message import Message from .message import Message, MessageDeserializer
class EventHeader(BaseModel): class EventHeader(BaseModel):
@ -97,7 +97,9 @@ class EventMessage(BaseModel):
@root_validator(pre=True) @root_validator(pre=True)
def parse_message(cls, values: dict): def parse_message(cls, values: dict):
values["content"] = json.loads(values["content"]) values["content"] = MessageDeserializer(
data=json.loads(values["content"]),
type=values["message_type"]).deserialize()
return values return values

View File

@ -1,6 +1,8 @@
import itertools import itertools
from typing import Tuple, Type, Union, Mapping, Iterable from dataclasses import dataclass
import json
from typing import Any, Dict, List, Tuple, Type, Union, Mapping, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from nonebot.typing import overrides from nonebot.typing import overrides
@ -158,24 +160,15 @@ class Message(BaseMessage[MessageSegment]):
msg: Union[str, Mapping, msg: Union[str, Mapping,
Iterable[Mapping]]) -> Iterable[MessageSegment]: Iterable[Mapping]]) -> Iterable[MessageSegment]:
if isinstance(msg, Mapping): if isinstance(msg, Mapping):
yield MessageSegment(msg["type"], msg.get("data") or {})
def _iter_message(msg: Mapping) -> Iterable[Tuple[str, dict]]: return
pure_text: str = msg.get("text", "")
content: dict = msg.get("content", {})
if pure_text and not content:
yield "text", {"text": pure_text}
elif content and not pure_text:
for element in list(itertools.chain(*content)):
tag = element.pop("tag")
yield tag, element
for type_, data in _iter_message(msg):
yield MessageSegment(type_, data)
elif isinstance(msg, str): elif isinstance(msg, str):
yield MessageSegment.text(msg) yield MessageSegment.text(msg)
elif isinstance(msg, Iterable): elif isinstance(msg, Iterable):
for seg in msg: for seg in msg:
if isinstance(seg, MessageSegment):
yield seg
else:
yield MessageSegment(seg["type"], seg.get("data") or {}) yield MessageSegment(seg["type"], seg.get("data") or {})
def _produce(self) -> dict: def _produce(self) -> dict:
@ -184,3 +177,52 @@ class Message(BaseMessage[MessageSegment]):
@overrides(BaseMessage) @overrides(BaseMessage)
def extract_plain_text(self) -> str: def extract_plain_text(self) -> str:
return "".join(seg.data["text"] for seg in self if seg.is_text()) return "".join(seg.data["text"] for seg in self if seg.is_text())
@dataclass
class MessageSerializer:
"""
飞书 协议 Message 序列化器
"""
message: Message
def serialize(self):
for segment in self.message:
if segment.type == "post":
raise NotImplementedError
else:
return json.dumps(segment.data)
@dataclass
class MessageDeserializer:
"""
飞书 协议 Message 反序列化器
"""
data: Dict[str, Any]
type: str
def deserialize(self):
print(self.type, self.data)
if self.type == "post":
return self._parse_rich_text(self.data)
else:
return Message(MessageSegment(self.type, self.data))
def _parse_rich_text(self, message_data: Dict[str,
Any]) -> List[MessageSegment]:
def _iter_message(
message_data: Dict[str,
Any]) -> Iterable[Tuple[str, Dict[str, Any]]]:
content: dict = message_data.get("content", {})
if content:
for element in list(itertools.chain(*content)):
tag = element.get("tag")
yield tag, element
temp = Message()
for type_, data in _iter_message(message_data):
temp += MessageSegment(type_, data)
return temp