mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 08:12:14 +08:00
✨ implement message serializer & deserializer
This commit is contained in:
parent
5635c83bfb
commit
603a63a629
@ -12,7 +12,7 @@ from nonebot.drivers import Driver, HTTPRequest, HTTPResponse
|
|||||||
from .config import Config as FeishuConfig
|
from .config import Config as FeishuConfig
|
||||||
from .event import Event, GroupMessageEvent, PrivateMessageEvent, get_event_model
|
from .event import Event, GroupMessageEvent, PrivateMessageEvent, get_event_model
|
||||||
from .exception import ActionFailed, ApiNotAvailable, NetworkError
|
from .exception import ActionFailed, ApiNotAvailable, NetworkError
|
||||||
from .message import Message, MessageSegment
|
from .message import Message, MessageSegment, MessageSerializer
|
||||||
from .utils import log, AESCipher
|
from .utils import log, AESCipher
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -103,7 +103,6 @@ class Bot(BaseBot):
|
|||||||
super().register(driver, config)
|
super().register(driver, config)
|
||||||
cls.feishu_config = FeishuConfig(**config.dict())
|
cls.feishu_config = FeishuConfig(**config.dict())
|
||||||
|
|
||||||
#TODO:校验schema 要求为2.0
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def check_permission(
|
async def check_permission(
|
||||||
@ -133,7 +132,10 @@ class Bot(BaseBot):
|
|||||||
|
|
||||||
schema = data.get("schema")
|
schema = data.get("schema")
|
||||||
if not schema:
|
if not schema:
|
||||||
return None, HTTPResponse(400, b"Missing `schema` in POST body, only accept event of version 2.0")
|
return None, HTTPResponse(
|
||||||
|
400,
|
||||||
|
b"Missing `schema` in POST body, only accept event of version 2.0"
|
||||||
|
)
|
||||||
|
|
||||||
headers = data.get("header")
|
headers = data.get("header")
|
||||||
if headers:
|
if headers:
|
||||||
@ -162,6 +164,7 @@ class Bot(BaseBot):
|
|||||||
处理事件并转换为 `Event <#class-event>`_
|
处理事件并转换为 `Event <#class-event>`_
|
||||||
"""
|
"""
|
||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
|
print("handle_event start")
|
||||||
print(data)
|
print(data)
|
||||||
if data.get("type") == "url_verification":
|
if data.get("type") == "url_verification":
|
||||||
return
|
return
|
||||||
@ -230,6 +233,7 @@ class Bot(BaseBot):
|
|||||||
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
|
"Authorization"] = "Bearer " + self.feishu_config.tenant_access_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
print("call_api request start")
|
||||||
print(data)
|
print(data)
|
||||||
async with httpx.AsyncClient(headers=headers) as client:
|
async with httpx.AsyncClient(headers=headers) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
@ -237,7 +241,7 @@ class Bot(BaseBot):
|
|||||||
json=data["body"],
|
json=data["body"],
|
||||||
params=data["query"],
|
params=data["query"],
|
||||||
timeout=self.config.api_timeout)
|
timeout=self.config.api_timeout)
|
||||||
|
print("remote server returned.")
|
||||||
print(response.json())
|
print(response.json())
|
||||||
if 200 <= response.status_code < 300:
|
if 200 <= response.status_code < 300:
|
||||||
result = response.json()
|
result = response.json()
|
||||||
@ -285,14 +289,22 @@ class Bot(BaseBot):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot guess `receive_id` and `receive_id_type` to reply!")
|
"Cannot guess `receive_id` and `receive_id_type` to reply!")
|
||||||
|
|
||||||
|
if isinstance(message, MessageSegment):
|
||||||
|
msg_type = message.type
|
||||||
|
elif isinstance(message, Message):
|
||||||
|
msg_type = message[0].type
|
||||||
|
else:
|
||||||
|
msg_type = "text"
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"query": {
|
"query": {
|
||||||
"receive_id_type": receive_id_type
|
"receive_id_type": receive_id_type
|
||||||
},
|
},
|
||||||
"body": {
|
"body": {
|
||||||
"receive_id": receive_id,
|
"receive_id": receive_id,
|
||||||
"content": str(message),
|
"content": MessageSerializer(Message(message)).serialize(),
|
||||||
"msg_type": "text" if len(message) == 1 else "content"
|
"msg_type": msg_type
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, root_validator, Field
|
|||||||
from nonebot.adapters import Event as BaseEvent
|
from nonebot.adapters import Event as BaseEvent
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
from .message import Message
|
from .message import Message, MessageDeserializer
|
||||||
|
|
||||||
|
|
||||||
class EventHeader(BaseModel):
|
class EventHeader(BaseModel):
|
||||||
@ -97,7 +97,9 @@ class EventMessage(BaseModel):
|
|||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def parse_message(cls, values: dict):
|
def parse_message(cls, values: dict):
|
||||||
values["content"] = json.loads(values["content"])
|
values["content"] = MessageDeserializer(
|
||||||
|
data=json.loads(values["content"]),
|
||||||
|
type=values["message_type"]).deserialize()
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
from typing import Tuple, Type, Union, Mapping, Iterable
|
from dataclasses import dataclass
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Tuple, Type, Union, Mapping, Iterable
|
||||||
|
|
||||||
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
@ -158,25 +160,16 @@ class Message(BaseMessage[MessageSegment]):
|
|||||||
msg: Union[str, Mapping,
|
msg: Union[str, Mapping,
|
||||||
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
Iterable[Mapping]]) -> Iterable[MessageSegment]:
|
||||||
if isinstance(msg, Mapping):
|
if isinstance(msg, Mapping):
|
||||||
|
yield MessageSegment(msg["type"], msg.get("data") or {})
|
||||||
def _iter_message(msg: Mapping) -> Iterable[Tuple[str, dict]]:
|
return
|
||||||
pure_text: str = msg.get("text", "")
|
|
||||||
content: dict = msg.get("content", {})
|
|
||||||
if pure_text and not content:
|
|
||||||
yield "text", {"text": pure_text}
|
|
||||||
elif content and not pure_text:
|
|
||||||
for element in list(itertools.chain(*content)):
|
|
||||||
tag = element.pop("tag")
|
|
||||||
yield tag, element
|
|
||||||
|
|
||||||
for type_, data in _iter_message(msg):
|
|
||||||
yield MessageSegment(type_, data)
|
|
||||||
|
|
||||||
elif isinstance(msg, str):
|
elif isinstance(msg, str):
|
||||||
yield MessageSegment.text(msg)
|
yield MessageSegment.text(msg)
|
||||||
elif isinstance(msg, Iterable):
|
elif isinstance(msg, Iterable):
|
||||||
for seg in msg:
|
for seg in msg:
|
||||||
yield MessageSegment(seg["type"], seg.get("data") or {})
|
if isinstance(seg, MessageSegment):
|
||||||
|
yield seg
|
||||||
|
else:
|
||||||
|
yield MessageSegment(seg["type"], seg.get("data") or {})
|
||||||
|
|
||||||
def _produce(self) -> dict:
|
def _produce(self) -> dict:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -184,3 +177,52 @@ class Message(BaseMessage[MessageSegment]):
|
|||||||
@overrides(BaseMessage)
|
@overrides(BaseMessage)
|
||||||
def extract_plain_text(self) -> str:
|
def extract_plain_text(self) -> str:
|
||||||
return "".join(seg.data["text"] for seg in self if seg.is_text())
|
return "".join(seg.data["text"] for seg in self if seg.is_text())
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageSerializer:
|
||||||
|
"""
|
||||||
|
飞书 协议 Message 序列化器。
|
||||||
|
"""
|
||||||
|
message: Message
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
for segment in self.message:
|
||||||
|
if segment.type == "post":
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
return json.dumps(segment.data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageDeserializer:
|
||||||
|
"""
|
||||||
|
飞书 协议 Message 反序列化器。
|
||||||
|
"""
|
||||||
|
data: Dict[str, Any]
|
||||||
|
type: str
|
||||||
|
|
||||||
|
def deserialize(self):
|
||||||
|
print(self.type, self.data)
|
||||||
|
if self.type == "post":
|
||||||
|
return self._parse_rich_text(self.data)
|
||||||
|
else:
|
||||||
|
return Message(MessageSegment(self.type, self.data))
|
||||||
|
|
||||||
|
def _parse_rich_text(self, message_data: Dict[str,
|
||||||
|
Any]) -> List[MessageSegment]:
|
||||||
|
|
||||||
|
def _iter_message(
|
||||||
|
message_data: Dict[str,
|
||||||
|
Any]) -> Iterable[Tuple[str, Dict[str, Any]]]:
|
||||||
|
content: dict = message_data.get("content", {})
|
||||||
|
if content:
|
||||||
|
for element in list(itertools.chain(*content)):
|
||||||
|
tag = element.get("tag")
|
||||||
|
yield tag, element
|
||||||
|
|
||||||
|
temp = Message()
|
||||||
|
for type_, data in _iter_message(message_data):
|
||||||
|
temp += MessageSegment(type_, data)
|
||||||
|
|
||||||
|
return temp
|
||||||
|
Loading…
Reference in New Issue
Block a user