improve ding adapter

This commit is contained in:
yanyongyu 2020-12-30 00:36:29 +08:00
parent 0221d02ca7
commit c8cd6de2f2
8 changed files with 194 additions and 193 deletions

View File

@ -6,6 +6,7 @@
""" """
import abc import abc
from copy import copy
from typing_extensions import Literal from typing_extensions import Literal
from functools import reduce, partial from functools import reduce, partial
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -292,7 +293,7 @@ class MessageSegment(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment,
T_Message]) -> "T_Message": T_Message]) -> T_Message:
"""你需要在这里实现不同消息段的合并: """你需要在这里实现不同消息段的合并:
比如 比如
if isinstance(other, str): if isinstance(other, str):
@ -326,6 +327,9 @@ class MessageSegment(abc.ABC):
def get(self, key, default=None): def get(self, key, default=None):
return getattr(self, key, default) return getattr(self, key, default)
def copy(self: T_MessageSegment) -> T_MessageSegment:
return copy(self)
@abc.abstractmethod @abc.abstractmethod
def is_text(self) -> bool: def is_text(self) -> bool:
raise NotImplementedError raise NotImplementedError
@ -335,7 +339,8 @@ class Message(list, abc.ABC):
"""消息数组""" """消息数组"""
def __init__(self, def __init__(self,
message: Union[str, list, dict, T_MessageSegment, T_Message, Any] = None, message: Union[str, list, dict, T_MessageSegment, T_Message,
Any] = None,
*args, *args,
**kwargs): **kwargs):
""" """
@ -364,7 +369,8 @@ class Message(list, abc.ABC):
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
def _construct(msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]: def _construct(
msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]:
raise NotImplementedError raise NotImplementedError
def __add__(self: T_Message, other: Union[str, T_MessageSegment, def __add__(self: T_Message, other: Union[str, T_MessageSegment,

View File

@ -6,7 +6,6 @@ import asyncio
from typing import Any, Dict, Union, Optional, TYPE_CHECKING from typing import Any, Dict, Union, Optional, TYPE_CHECKING
import httpx import httpx
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.typing import overrides from nonebot.typing import overrides

View File

@ -11,7 +11,7 @@
from .utils import log from .utils import log
from .bot import Bot from .bot import Bot
from .event import Event
from .message import Message, MessageSegment from .message import Message, MessageSegment
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent
from .exception import (DingAdapterException, ApiNotAvailable, NetworkError, from .exception import (DingAdapterException, ApiNotAvailable, NetworkError,
ActionFailed, SessionExpired) ActionFailed, SessionExpired)

View File

@ -6,18 +6,18 @@ from typing import Any, Union, Optional, TYPE_CHECKING
import httpx import httpx
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Config from nonebot.config import Config
from nonebot.typing import overrides
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.exception import RequestDenied from nonebot.exception import RequestDenied
from .utils import log from .utils import log
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent
from .model import ConversationType
from .message import Message, MessageSegment from .message import Message, MessageSegment
from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired
from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.drivers import BaseDriver as Driver from nonebot.drivers import Driver
class Bot(BaseBot): class Bot(BaseBot):
@ -38,6 +38,7 @@ class Bot(BaseBot):
return "ding" return "ding"
@classmethod @classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(cls, driver: "Driver", connection_type: str,
headers: dict, body: Optional[dict]) -> str: headers: dict, body: Optional[dict]) -> str:
""" """
@ -73,18 +74,22 @@ class Bot(BaseBot):
log("WARNING", "Ding signature check ignored!") log("WARNING", "Ding signature check ignored!")
return body["chatbotUserId"] return body["chatbotUserId"]
async def handle_message(self, body: dict): @overrides(BaseBot)
if not body: async def handle_message(self, message: dict):
if not message:
return return
# 判断消息类型,生成不同的 Event # 判断消息类型,生成不同的 Event
conversation_type = body["conversationType"] try:
if conversation_type == ConversationType.private: conversation_type = message["conversationType"]
event = PrivateMessageEvent.parse_obj(body) if conversation_type == ConversationType.private:
else: event = PrivateMessageEvent.parse_obj(message)
event = GroupMessageEvent.parse_obj(body) elif conversation_type == ConversationType.group:
event = GroupMessageEvent.parse_obj(message)
if not event: else:
raise ValueError("Unsupported conversation type")
except Exception as e:
log("Error", "Event Parser Error", e)
return return
try: try:
@ -95,6 +100,7 @@ class Bot(BaseBot):
) )
return return
@overrides(BaseBot)
async def call_api(self, async def call_api(self,
api: str, api: str,
event: Optional[MessageEvent] = None, event: Optional[MessageEvent] = None,
@ -138,19 +144,18 @@ class Bot(BaseBot):
target = event.sessionWebhook target = event.sessionWebhook
else: else:
target = None
if not target:
raise ApiNotAvailable raise ApiNotAvailable
headers = {} headers = {}
segment: MessageSegment = data["message"][0] message: Message = data.get("message", None)
if not message:
raise ValueError("Message not found")
try: try:
async with httpx.AsyncClient(headers=headers) as client: async with httpx.AsyncClient(headers=headers) as client:
response = await client.post( response = await client.post(
target, target,
params={"access_token": self.config.access_token}, params={"access_token": self.config.access_token},
json=segment.data, json=message._produce(),
timeout=self.config.api_timeout) timeout=self.config.api_timeout)
if 200 <= response.status_code < 300: if 200 <= response.status_code < 300:
@ -167,8 +172,9 @@ class Bot(BaseBot):
except httpx.HTTPError: except httpx.HTTPError:
raise NetworkError("HTTP request failed") raise NetworkError("HTTP request failed")
@overrides(BaseBot)
async def send(self, async def send(self,
event: Event, event: MessageEvent,
message: Union[str, "Message", "MessageSegment"], message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False, at_sender: bool = False,
**kwargs) -> Any: **kwargs) -> Any:
@ -196,13 +202,15 @@ class Bot(BaseBot):
""" """
msg = message if isinstance(message, Message) else Message(message) msg = message if isinstance(message, Message) else Message(message)
at_sender = at_sender and bool(event.user_id) at_sender = at_sender and bool(event.senderId)
params = {} params = {}
params["event"] = event params["event"] = event
params.update(kwargs) params.update(kwargs)
if at_sender and event.detail_type != "private": if at_sender and event.conversationType != ConversationType.private:
params["message"] = f"@{event.user_id} " + msg params[
"message"] = f"@{event.senderId} " + msg + MessageSegment.atMobiles(
event.senderId)
else: else:
params["message"] = msg params["message"] = msg

View File

@ -1,84 +1,124 @@
from typing import Union, Optional from enum import Enum
from typing import List, Optional
from typing_extensions import Literal from typing_extensions import Literal
from pydantic import BaseModel, validator, parse_obj_as from pydantic import BaseModel
from pydantic.fields import ModelField
from nonebot.adapters import Event as BaseEvent
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.typing import overrides
from nonebot.adapters import Event as BaseEvent
from .message import Message from .message import Message
from .model import MessageModel, PrivateMessageModel, GroupMessageModel, ConversationType, TextMessage
class Event(BaseEvent): class Event(BaseEvent):
""" """
钉钉 协议 Event 适配继承属性参考 `BaseEvent <./#class-baseevent>`_ 。 钉钉 协议 Event 适配各事件字段参考 `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p
""" """
message: Message = None
def __init__(self, **data): chatbotUserId: str
super().__init__(**data)
# 其实目前钉钉机器人只能接收到 text 类型的消息
message: Union[TextMessage] = getattr(self, self.msgtype, None)
self.message = parse_obj_as(Message, message)
def get_type(self) -> Literal["message"]: @overrides(BaseEvent)
""" def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
- 类型: ``str`` raise ValueError("Event has no type!")
- 说明: 事件类型
"""
return "message"
@overrides(BaseEvent)
def get_event_name(self) -> str: def get_event_name(self) -> str:
detail_type = self.conversationType.name raise ValueError("Event has no type!")
return self.get_type() + "." + detail_type
@overrides(BaseEvent)
def get_event_description(self) -> str: def get_event_description(self) -> str:
return (f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "' + raise ValueError("Event has no type!")
"".join(
map(
lambda x: escape_tag(str(x))
if x.is_text() else f"<le>{escape_tag(str(x))}</le>",
self.message,
)) + '"')
def get_user_id(self) -> str:
return self.senderId
def get_session_id(self) -> str:
"""
- 类型: ``str``
- 说明: 消息 ID
"""
return self.msgId
@overrides(BaseEvent)
def get_message(self) -> "Message": def get_message(self) -> "Message":
""" raise ValueError("Event has no type!")
- 类型: ``Message``
- 说明: 消息内容
"""
return self.message
@overrides(BaseEvent)
def get_plaintext(self) -> str: def get_plaintext(self) -> str:
""" raise ValueError("Event has no type!")
- 类型: ``str``
- 说明: 纯文本消息内容
"""
return self.message.extract_plain_text().strip() if self.message else ""
@overrides(BaseEvent)
def get_user_id(self) -> str:
raise ValueError("Event has no type!")
class MessageEvent(MessageModel, Event): @overrides(BaseEvent)
pass def get_session_id(self) -> str:
raise ValueError("Event has no type!")
class PrivateMessageEvent(PrivateMessageModel, Event):
@overrides(BaseEvent)
def is_tome(self) -> bool: def is_tome(self) -> bool:
return True return True
class GroupMessageEvent(GroupMessageModel, Event): class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = "1"
group = "2"
class MessageEvent(Event):
msgtype: str
text: TextMessage
msgId: str
createAt: int # ms
conversationType: ConversationType
conversationId: str
senderId: str
senderNick: str
senderCorpId: str
sessionWebhook: str
sessionWebhookExpiredTime: int
isAdmin: bool
@overrides(Event)
def get_type(self) -> Literal["message", "notice", "request", "meta_event"]:
return "message"
@overrides(BaseEvent)
def get_event_name(self) -> str:
return f"{self.get_type()}.{self.conversationType.name}"
@overrides(BaseEvent)
def get_event_description(self) -> str:
return f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "{self.text.content}"'
@overrides(BaseEvent)
def get_plaintext(self) -> str:
return self.text.content
@overrides(BaseEvent)
def get_user_id(self) -> str:
return self.senderId
@overrides(BaseEvent)
def get_session_id(self) -> str:
return self.senderId
class PrivateMessageEvent(MessageEvent):
chatbotCorpId: str
senderStaffId: Optional[str]
conversationType: ConversationType = ConversationType.private
class GroupMessageEvent(MessageEvent):
atUsers: List[AtUsersItem]
conversationType: ConversationType = ConversationType.group
conversationTitle: str
isInAtList: bool
@overrides(MessageEvent)
def is_tome(self) -> bool: def is_tome(self) -> bool:
return self.isInAtList return self.isInAtList

View File

@ -39,6 +39,9 @@ class ActionFailed(BaseActionFailed, DingAdapterException):
def __repr__(self): def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg={self.errmsg}>" return f"<ApiError errcode={self.errcode} errmsg={self.errmsg}>"
def __str__(self):
return self.__repr__()
class ApiNotAvailable(BaseApiNotAvailable, DingAdapterException): class ApiNotAvailable(BaseApiNotAvailable, DingAdapterException):
pass pass
@ -66,7 +69,7 @@ class NetworkError(BaseNetworkError, DingAdapterException):
return self.__repr__() return self.__repr__()
class SessionExpired(BaseApiNotAvailable, DingAdapterException): class SessionExpired(ApiNotAvailable, DingAdapterException):
""" """
:说明: :说明:
@ -75,3 +78,6 @@ class SessionExpired(BaseApiNotAvailable, DingAdapterException):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Session Webhook is Expired>" return f"<Session Webhook is Expired>"
def __str__(self):
return self.__repr__()

View File

@ -2,39 +2,23 @@ from typing import Any, Dict, Union, Iterable
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
from .utils import log
from .model import TextMessage
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):
""" """
钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码 钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
""" """
def __init__(self, type_: str, msg: Dict[str, Any]) -> None: def __init__(self, type_: str, data: Dict[str, Any]) -> None:
data = {
"msgtype": type_,
}
if msg:
data.update(msg)
log("DEBUG", f"data {data}")
super().__init__(type=type_, data=data) super().__init__(type=type_, data=data)
@classmethod
def from_segment(cls, segment: "MessageSegment"):
return MessageSegment(segment.type, segment.data)
def __str__(self): def __str__(self):
log("DEBUG", f"__str__: self.type {self.type} data {self.data}")
if self.type == "text": if self.type == "text":
return str(self.data["text"]["content"].strip()) return str(self.data["content"])
elif self.type == "markdown":
return str(self.data["text"])
return "" return ""
def __add__(self, other) -> "Message": def __add__(self, other) -> "Message":
if isinstance(other, str):
if self.type == 'text':
self.data['text']['content'] += other
return MessageSegment.from_segment(self)
return Message(self) + other return Message(self) + other
def __radd__(self, other) -> "Message": def __radd__(self, other) -> "Message":
@ -43,43 +27,41 @@ class MessageSegment(BaseMessageSegment):
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"
def atMobile(self, mobileNumber): @staticmethod
self.data.setdefault("at", {}) def atAll() -> "MessageSegment":
self.data["at"].setdefault("atMobiles", []) return MessageSegment("at", {"isAtAll": True})
self.data["at"]["atMobiles"].append(mobileNumber)
def atAll(self, value):
self.data.setdefault("at", {})
self.data["at"]["isAtAll"] = value
@staticmethod @staticmethod
def text(text_: str) -> "MessageSegment": def atMobiles(*mobileNumber: str) -> "MessageSegment":
return MessageSegment("text", {"text": {"content": text_.strip()}}) return MessageSegment("at", {"atMobiles": list(mobileNumber)})
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"content": text})
@staticmethod @staticmethod
def markdown(title: str, text: str) -> "MessageSegment": def markdown(title: str, text: str) -> "MessageSegment":
return MessageSegment("markdown", { return MessageSegment(
"markdown": { "markdown",
{
"title": title, "title": title,
"text": text, "text": text,
}, },
}) )
@staticmethod @staticmethod
def actionCardSingleBtn(title: str, text: str, btnTitle: str, def actionCardSingleBtn(title: str, text: str, btnTitle: str,
btnUrl) -> "MessageSegment": btnUrl) -> "MessageSegment":
return MessageSegment( return MessageSegment(
"actionCard", { "actionCard", {
"actionCard": { "title": title,
"title": title, "text": text,
"text": text, "singleTitle": btnTitle,
"singleTitle": btnTitle, "singleURL": btnUrl
"singleURL": btnUrl
}
}) })
@staticmethod @staticmethod
def actionCardSingleMultiBtns( def actionCardMultiBtns(
title: str, title: str,
text: str, text: str,
btns: list = [], btns: list = [],
@ -95,28 +77,26 @@ class MessageSegment(BaseMessageSegment):
""" """
return MessageSegment( return MessageSegment(
"actionCard", { "actionCard", {
"actionCard": { "title": title,
"title": title, "text": text,
"text": text, "hideAvatar": "1" if hideAvatar else "0",
"hideAvatar": "1" if hideAvatar else "0", "btnOrientation": btnOrientation,
"btnOrientation": btnOrientation, "btns": btns
"btns": btns
}
}) })
@staticmethod @staticmethod
def feedCard(links: list = [],) -> "MessageSegment": def feedCard(links: list = []) -> "MessageSegment":
""" """
:参数: :参数:
* ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...] * ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...]
""" """
return MessageSegment("feedCard", {"feedCard": {"links": links}}) return MessageSegment("feedCard", {"links": links})
@staticmethod @staticmethod
def empty() -> "MessageSegment": def empty() -> "MessageSegment":
"""不想回复消息到群里""" """不想回复消息到群里"""
return MessageSegment("empty") return MessageSegment("empty", {})
class Message(BaseMessage): class Message(BaseMessage):
@ -129,17 +109,35 @@ class Message(BaseMessage):
return cls(value) return cls(value)
@staticmethod @staticmethod
def _construct( def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]:
msg: Union[str, dict, list,
TextMessage]) -> Iterable[MessageSegment]:
if isinstance(msg, dict): if isinstance(msg, dict):
yield MessageSegment(msg["type"], msg.get("data") or {}) yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, list): elif isinstance(msg, list):
for seg in msg: for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {}) yield MessageSegment(seg["type"], seg.get("data") or {})
return
elif isinstance(msg, TextMessage):
yield MessageSegment("text", {"text": msg.dict()})
elif isinstance(msg, str): elif isinstance(msg, str):
yield MessageSegment.text(msg) yield MessageSegment.text(msg)
def _produce(self) -> dict:
data = {}
for segment in self:
if segment.type == "text":
data["msgtype"] = "text"
data.setdefault("text", {})
data["text"]["content"] = data["text"].setdefault(
"content", "") + segment.data["content"]
elif segment.type == "markdown":
data["msgtype"] = "markdown"
data.setdefault("markdown", {})
data["markdown"]["text"] = data["markdown"].setdefault(
"content", "") + segment.data["content"]
elif segment.type == "empty":
data["msgtype"] = "empty"
elif segment.type == "at" and "atMobiles" in segment.data:
data.setdefault("at", {})
data["at"]["atMobiles"] = data["at"].setdefault(
"atMobiles", []) + segment.data["atMobiles"]
elif segment.data:
data.setdefault(segment.type, {})
data[segment.type].update(segment.data)
return data

View File

@ -1,56 +0,0 @@
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class Headers(BaseModel):
sign: str
token: str
# ms
timestamp: int
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = '1'
group = '2'
class MessageModel(BaseModel):
chatbotUserId: str = None
conversationId: str = None
conversationType: ConversationType = None
# ms
createAt: int = None
isAdmin: bool = None
msgId: str = None
msgtype: str = None
senderCorpId: str = None
senderId: str = None
senderNick: str = None
sessionWebhook: str = None
# ms
sessionWebhookExpiredTime: int = None
text: Optional[TextMessage] = None
class PrivateMessageModel(MessageModel):
chatbotCorpId: str = None
conversationType: ConversationType = ConversationType.private
senderStaffId: str = None
class GroupMessageModel(MessageModel):
atUsers: List[AtUsersItem] = None
conversationType: ConversationType = ConversationType.group
conversationTitle: str = None
isInAtList: bool = None