improve ding adapter

This commit is contained in:
yanyongyu 2020-12-30 00:36:29 +08:00
parent 0221d02ca7
commit c8cd6de2f2
8 changed files with 194 additions and 193 deletions

View File

@ -6,6 +6,7 @@
"""
import abc
from copy import copy
from typing_extensions import Literal
from functools import reduce, partial
from dataclasses import dataclass, field
@ -292,7 +293,7 @@ class MessageSegment(abc.ABC):
@abc.abstractmethod
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
T_Message]) -> "T_Message":
T_Message]) -> T_Message:
"""你需要在这里实现不同消息段的合并:
比如
if isinstance(other, str):
@ -326,6 +327,9 @@ class MessageSegment(abc.ABC):
def get(self, key, default=None):
return getattr(self, key, default)
def copy(self: T_MessageSegment) -> T_MessageSegment:
return copy(self)
@abc.abstractmethod
def is_text(self) -> bool:
raise NotImplementedError
@ -335,7 +339,8 @@ class Message(list, abc.ABC):
"""消息数组"""
def __init__(self,
message: Union[str, list, dict, T_MessageSegment, T_Message, Any] = None,
message: Union[str, list, dict, T_MessageSegment, T_Message,
Any] = None,
*args,
**kwargs):
"""
@ -364,7 +369,8 @@ class Message(list, abc.ABC):
@staticmethod
@abc.abstractmethod
def _construct(msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]:
def _construct(
msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]:
raise NotImplementedError
def __add__(self: T_Message, other: Union[str, T_MessageSegment,

View File

@ -6,7 +6,6 @@ import asyncio
from typing import Any, Dict, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides

View File

@ -11,7 +11,7 @@
from .utils import log
from .bot import Bot
from .event import Event
from .message import Message, MessageSegment
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent
from .exception import (DingAdapterException, ApiNotAvailable, NetworkError,
ActionFailed, SessionExpired)

View File

@ -6,18 +6,18 @@ from typing import Any, Union, Optional, TYPE_CHECKING
import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied
from .utils import log
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent
from .model import ConversationType
from .message import Message, MessageSegment
from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType
if TYPE_CHECKING:
from nonebot.drivers import BaseDriver as Driver
from nonebot.drivers import Driver
class Bot(BaseBot):
@ -38,6 +38,7 @@ class Bot(BaseBot):
return "ding"
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> str:
"""
@ -73,18 +74,22 @@ class Bot(BaseBot):
log("WARNING", "Ding signature check ignored!")
return body["chatbotUserId"]
async def handle_message(self, body: dict):
if not body:
@overrides(BaseBot)
async def handle_message(self, message: dict):
if not message:
return
# 判断消息类型,生成不同的 Event
conversation_type = body["conversationType"]
if conversation_type == ConversationType.private:
event = PrivateMessageEvent.parse_obj(body)
else:
event = GroupMessageEvent.parse_obj(body)
if not event:
try:
conversation_type = message["conversationType"]
if conversation_type == ConversationType.private:
event = PrivateMessageEvent.parse_obj(message)
elif conversation_type == ConversationType.group:
event = GroupMessageEvent.parse_obj(message)
else:
raise ValueError("Unsupported conversation type")
except Exception as e:
log("Error", "Event Parser Error", e)
return
try:
@ -95,6 +100,7 @@ class Bot(BaseBot):
)
return
@overrides(BaseBot)
async def call_api(self,
api: str,
event: Optional[MessageEvent] = None,
@ -138,19 +144,18 @@ class Bot(BaseBot):
target = event.sessionWebhook
else:
target = None
if not target:
raise ApiNotAvailable
headers = {}
segment: MessageSegment = data["message"][0]
message: Message = data.get("message", None)
if not message:
raise ValueError("Message not found")
try:
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(
target,
params={"access_token": self.config.access_token},
json=segment.data,
json=message._produce(),
timeout=self.config.api_timeout)
if 200 <= response.status_code < 300:
@ -167,8 +172,9 @@ class Bot(BaseBot):
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self,
event: Event,
event: MessageEvent,
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
**kwargs) -> Any:
@ -196,13 +202,15 @@ class Bot(BaseBot):
"""
msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and bool(event.user_id)
at_sender = at_sender and bool(event.senderId)
params = {}
params["event"] = event
params.update(kwargs)
if at_sender and event.detail_type != "private":
params["message"] = f"@{event.user_id} " + msg
if at_sender and event.conversationType != ConversationType.private:
params[
"message"] = f"@{event.senderId} " + msg + MessageSegment.atMobiles(
event.senderId)
else:
params["message"] = msg

View File

@ -1,84 +1,124 @@
from typing import Union, Optional
from enum import Enum
from typing import List, Optional
from typing_extensions import Literal
from pydantic import BaseModel, validator, parse_obj_as
from pydantic.fields import ModelField
from pydantic import BaseModel
from nonebot.adapters import Event as BaseEvent
from nonebot.utils import escape_tag
from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message
from .model import MessageModel, PrivateMessageModel, GroupMessageModel, ConversationType, TextMessage
class Event(BaseEvent):
"""
钉钉 协议 Event 适配继承属性参考 `BaseEvent <./#class-baseevent>`_ 。
钉钉 协议 Event 适配各事件字段参考 `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p
"""
message: Message = None
def __init__(self, **data):
super().__init__(**data)
# 其实目前钉钉机器人只能接收到 text 类型的消息
message: Union[TextMessage] = getattr(self, self.msgtype, None)
self.message = parse_obj_as(Message, message)
chatbotUserId: str
def get_type(self) -> Literal["message"]:
"""
- 类型: ``str``
- 说明: 事件类型
"""
return "message"
@overrides(BaseEvent)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_event_name(self) -> str:
detail_type = self.conversationType.name
return self.get_type() + "." + detail_type
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_event_description(self) -> str:
return (f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "' +
"".join(
map(
lambda x: escape_tag(str(x))
if x.is_text() else f"<le>{escape_tag(str(x))}</le>",
self.message,
)) + '"')
def get_user_id(self) -> str:
return self.senderId
def get_session_id(self) -> str:
"""
- 类型: ``str``
- 说明: 消息 ID
"""
return self.msgId
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_message(self) -> "Message":
"""
- 类型: ``Message``
- 说明: 消息内容
"""
return self.message
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_plaintext(self) -> str:
"""
- 类型: ``str``
- 说明: 纯文本消息内容
"""
return self.message.extract_plain_text().strip() if self.message else ""
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def get_user_id(self) -> str:
raise ValueError("Event has no type!")
class MessageEvent(MessageModel, Event):
pass
class PrivateMessageEvent(PrivateMessageModel, Event):
@overrides(BaseEvent)
def get_session_id(self) -> str:
raise ValueError("Event has no type!")
@overrides(BaseEvent)
def is_tome(self) -> bool:
return True
class GroupMessageEvent(GroupMessageModel, Event):
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = "1"
group = "2"
class MessageEvent(Event):
msgtype: str
text: TextMessage
msgId: str
createAt: int # ms
conversationType: ConversationType
conversationId: str
senderId: str
senderNick: str
senderCorpId: str
sessionWebhook: str
sessionWebhookExpiredTime: int
isAdmin: bool
@overrides(Event)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
return "message"
@overrides(BaseEvent)
def get_event_name(self) -> str:
return f"{self.get_type()}.{self.conversationType.name}"
@overrides(BaseEvent)
def get_event_description(self) -> str:
return f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "{self.text.content}"'
@overrides(BaseEvent)
def get_plaintext(self) -> str:
return self.text.content
@overrides(BaseEvent)
def get_user_id(self) -> str:
return self.senderId
@overrides(BaseEvent)
def get_session_id(self) -> str:
return self.senderId
class PrivateMessageEvent(MessageEvent):
chatbotCorpId: str
senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private
class GroupMessageEvent(MessageEvent):
atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group
conversationTitle: str
isInAtList: bool
@overrides(MessageEvent)
def is_tome(self) -> bool:
return self.isInAtList

View File

@ -39,6 +39,9 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg={self.errmsg}>"
def __str__(self):
return self.__repr__()
class ApiNotAvailable(BaseApiNotAvailable, DingAdapterException):
pass
@ -66,7 +69,7 @@ class NetworkError(BaseNetworkError, DingAdapterException):
return self.__repr__()
class SessionExpired(BaseApiNotAvailable, DingAdapterException):
class SessionExpired(ApiNotAvailable, DingAdapterException):
"""
:说明:
@ -75,3 +78,6 @@ class SessionExpired(BaseApiNotAvailable, DingAdapterException):
def __repr__(self) -> str:
return f"<Session Webhook is Expired>"
def __str__(self):
return self.__repr__()

View File

@ -2,39 +2,23 @@ from typing import Any, Dict, Union, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from .utils import log
from .model import TextMessage
class MessageSegment(BaseMessageSegment):
"""
钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
"""
def __init__(self, type_: str, msg: Dict[str, Any]) -> None:
data = {
"msgtype": type_,
}
if msg:
data.update(msg)
log("DEBUG", f"data {data}")
def __init__(self, type_: str, data: Dict[str, Any]) -> None:
super().__init__(type=type_, data=data)
@classmethod
def from_segment(cls, segment: "MessageSegment"):
return MessageSegment(segment.type, segment.data)
def __str__(self):
log("DEBUG", f"__str__: self.type {self.type} data {self.data}")
if self.type == "text":
return str(self.data["text"]["content"].strip())
return str(self.data["content"])
elif self.type == "markdown":
return str(self.data["text"])
return ""
def __add__(self, other) -> "Message":
if isinstance(other, str):
if self.type == 'text':
self.data['text']['content'] += other
return MessageSegment.from_segment(self)
return Message(self) + other
def __radd__(self, other) -> "Message":
@ -43,43 +27,41 @@ class MessageSegment(BaseMessageSegment):
def is_text(self) -> bool:
return self.type == "text"
def atMobile(self, mobileNumber):
self.data.setdefault("at", {})
self.data["at"].setdefault("atMobiles", [])
self.data["at"]["atMobiles"].append(mobileNumber)
def atAll(self, value):
self.data.setdefault("at", {})
self.data["at"]["isAtAll"] = value
@staticmethod
def atAll() -> "MessageSegment":
return MessageSegment("at", {"isAtAll": True})
@staticmethod
def text(text_: str) -> "MessageSegment":
return MessageSegment("text", {"text": {"content": text_.strip()}})
def atMobiles(*mobileNumber: str) -> "MessageSegment":
return MessageSegment("at", {"atMobiles": list(mobileNumber)})
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"content": text})
@staticmethod
def markdown(title: str, text: str) -> "MessageSegment":
return MessageSegment("markdown", {
"markdown": {
return MessageSegment(
"markdown",
{
"title": title,
"text": text,
},
})
)
@staticmethod
def actionCardSingleBtn(title: str, text: str, btnTitle: str,
btnUrl) -> "MessageSegment":
return MessageSegment(
"actionCard", {
"actionCard": {
"title": title,
"text": text,
"singleTitle": btnTitle,
"singleURL": btnUrl
}
"title": title,
"text": text,
"singleTitle": btnTitle,
"singleURL": btnUrl
})
@staticmethod
def actionCardSingleMultiBtns(
def actionCardMultiBtns(
title: str,
text: str,
btns: list = [],
@ -95,28 +77,26 @@ class MessageSegment(BaseMessageSegment):
"""
return MessageSegment(
"actionCard", {
"actionCard": {
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
}
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
})
@staticmethod
def feedCard(links: list = [],) -> "MessageSegment":
def feedCard(links: list = []) -> "MessageSegment":
"""
:参数:
* ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...]
"""
return MessageSegment("feedCard", {"feedCard": {"links": links}})
return MessageSegment("feedCard", {"links": links})
@staticmethod
def empty() -> "MessageSegment":
"""不想回复消息到群里"""
return MessageSegment("empty")
return MessageSegment("empty", {})
class Message(BaseMessage):
@ -129,17 +109,35 @@ class Message(BaseMessage):
return cls(value)
@staticmethod
def _construct(
msg: Union[str, dict, list,
TextMessage]) -> Iterable[MessageSegment]:
def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
if isinstance(msg, dict):
yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, list):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
return
elif isinstance(msg, TextMessage):
yield MessageSegment("text", {"text": msg.dict()})
elif isinstance(msg, str):
yield MessageSegment.text(msg)
def _produce(self) -> dict:
data = {}
for segment in self:
if segment.type == "text":
data["msgtype"] = "text"
data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault(
"content", "") + segment.data["content"]
elif segment.type == "markdown":
data["msgtype"] = "markdown"
data.setdefault("markdown", {})
data["markdown"]["text"] = data["markdown"].setdefault(
"content", "") + segment.data["content"]
elif segment.type == "empty":
data["msgtype"] = "empty"
elif segment.type == "at" and "atMobiles" in segment.data:
data.setdefault("at", {})
data["at"]["atMobiles"] = data["at"].setdefault(
"atMobiles", []) + segment.data["atMobiles"]
elif segment.data:
data.setdefault(segment.type, {})
data[segment.type].update(segment.data)
return data

View File

@ -1,56 +0,0 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class Headers(BaseModel):
sign: str
token: str
# ms
timestamp: int
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = '1'
group = '2'
class MessageModel(BaseModel):
chatbotUserId: str = None
conversationId: str = None
conversationType: ConversationType = None
# ms
createAt: int = None
isAdmin: bool = None
msgId: str = None
msgtype: str = None
senderCorpId: str = None
senderId: str = None
senderNick: str = None
sessionWebhook: str = None
# ms
sessionWebhookExpiredTime: int = None
text: Optional[TextMessage] = None
class PrivateMessageModel(MessageModel):
chatbotCorpId: str = None
conversationType: ConversationType = ConversationType.private
senderStaffId: str = None
class GroupMessageModel(MessageModel):
atUsers: List[AtUsersItem] = None
conversationType: ConversationType = ConversationType.group
conversationTitle: str = None
isInAtList: bool = None