Add ding adapter

This commit is contained in:
Artin 2020-12-03 00:59:32 +08:00
parent f332199baa
commit edb4458031
11 changed files with 695 additions and 11 deletions

View File

@ -9,9 +9,11 @@ import abc
from functools import reduce, partial
from dataclasses import dataclass, field
from pydantic import BaseModel
from nonebot.config import Config
from nonebot.typing import Driver, Message, WebSocket
from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable
from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable, TypeVar, Generic
class BaseBot(abc.ABC):
@ -135,24 +137,27 @@ class BaseBot(abc.ABC):
raise NotImplementedError
class BaseEvent(abc.ABC):
T = TypeVar("T", dict, BaseModel)
class BaseEvent(abc.ABC, Generic[T]):
"""
Event 基类提供上报信息的关键信息其余信息可从原始上报消息获取
"""
def __init__(self, raw_event: dict):
def __init__(self, raw_event: T):
"""
:参数:
* ``raw_event: dict``: 原始上报消息
* ``raw_event: T``: 原始上报消息
"""
self._raw_event = raw_event
self._raw_event: T = raw_event
def __repr__(self) -> str:
return f"<Event {self.self_id}: {self.name} {self.time}>"
@property
def raw_event(self) -> dict:
def raw_event(self) -> T:
"""原始上报消息"""
return self._raw_event
@ -347,17 +352,17 @@ class BaseMessage(list, abc.ABC):
"""消息数组"""
def __init__(self,
message: Union[str, dict, list, BaseMessageSegment,
message: Union[str, dict, list, BaseModel, BaseMessageSegment,
"BaseMessage"] = None,
*args,
**kwargs):
"""
:参数:
* ``message: Union[str, dict, list, MessageSegment, Message]``: 消息内容
* ``message: Union[str, dict, list, BaseModel, MessageSegment, Message]``: 消息内容
"""
super().__init__(*args, **kwargs)
if isinstance(message, (str, dict, list)):
if isinstance(message, (str, dict, list, BaseModel)):
self.extend(self._construct(message))
elif isinstance(message, BaseMessage):
self.extend(message)
@ -448,4 +453,4 @@ class BaseMessage(list, abc.ABC):
return f"{x} {y}" if y.type == "text" else x
plain_text = reduce(_concat, self, "")
return plain_text[1:] if plain_text else plain_text
return plain_text.strip()

View File

@ -0,0 +1,15 @@
"""
钉钉群机器人 协议适配
============================
协议详情请看: `钉钉文档`_
.. _钉钉文档:
https://ding-doc.dingtalk.com/doc#/serverapi2/krgddi
"""
from .bot import Bot
from .event import Event
from .message import Message, MessageSegment
from .exception import ApiError, SessionExpired, AdapterException

View File

@ -0,0 +1,205 @@
from datetime import datetime
import httpx
from nonebot.log import logger
from nonebot.config import Config
from nonebot.message import handle_event
from nonebot.typing import Driver, WebSocket, NoReturn
from nonebot.typing import Any, Union, Optional
from nonebot.adapters import BaseBot
from nonebot.exception import NetworkError, RequestDenied, ApiNotAvailable
from .exception import ApiError, SessionExpired
from .utils import check_legal, log
from .event import Event
from .message import Message, MessageSegment
from .model import MessageModel
class Bot(BaseBot):
"""
钉钉 协议 Bot 适配继承属性参考 `BaseBot <./#class-basebot>`_ 。
"""
def __init__(self,
driver: Driver,
connection_type: str,
config: Config,
self_id: str,
*,
websocket: Optional[WebSocket] = None):
super().__init__(driver,
connection_type,
config,
self_id,
websocket=websocket)
@property
def type(self) -> str:
"""
- 返回: ``"ding"``
"""
return "ding"
@classmethod
async def check_permission(cls, driver: Driver, connection_type: str,
headers: dict,
body: Optional[dict]) -> Union[str, NoReturn]:
"""
:说明:
钉钉协议鉴权参考 `鉴权 <https://ding-doc.dingtalk.com/doc#/serverapi2/elzz1p>`_
"""
timestamp = headers.get("timestamp")
sign = headers.get("sign")
log("DEBUG", "headers: {}".format(headers))
log("DEBUG", "body: {}".format(body))
# 检查 timestamp
if not timestamp:
log("WARNING", "Missing `timestamp` Header")
raise RequestDenied(400, "Missing `timestamp` Header")
# 检查 sign
if not sign:
log("WARNING", "Missing `sign` Header")
raise RequestDenied(400, "Missing `sign` Header")
# 校验 sign 和 timestamp判断是否是来自钉钉的合法请求
if not check_legal(timestamp, sign, driver):
log("WARNING", "Signature Header is invalid")
raise RequestDenied(403, "Signature is invalid")
# 检查连接方式
if connection_type not in ["http"]:
log("WARNING", "Unsupported connection type")
raise RequestDenied(405, "Unsupported connection type")
access_token = driver.config.access_token
if access_token and access_token != access_token:
log(
"WARNING", "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
raise RequestDenied(
403, "Authorization Header is invalid"
if access_token else "Missing Authorization Header")
return body.get("chatbotUserId")
async def handle_message(self, body: dict):
message = MessageModel.parse_obj(body)
if not message:
return
log("DEBUG", "message: {}".format(message))
try:
event = Event(message)
await handle_event(self, event)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Failed to handle event. Raw: {message}</bg #f8bbd0></r>"
)
return
async def call_api(self, api: str, **data) -> Union[Any, NoReturn]:
"""
:说明:
调用 钉钉 协议 API
:参数:
* ``api: str``: API 名称
* ``**data: Any``: API 参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
if "self_id" in data:
self_id = data.pop("self_id")
if self_id:
bot = self.driver.bots[str(self_id)]
return await bot.call_api(api, **data)
log("DEBUG", f"Calling API <y>{api}</y>")
log("DEBUG", f"Calling data <y>{data}</y>")
if self.connection_type == "http" and api == "post_webhook":
raw_event: MessageModel = data["raw_event"]
if int(datetime.now().timestamp()) > int(
raw_event.sessionWebhookExpiredTime / 1000):
raise SessionExpired
target = raw_event.sessionWebhook
if not target:
raise ApiNotAvailable
headers = {}
segment: MessageSegment = data["message"][0]
try:
async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(
target,
params={"access_token": self.config.access_token},
json=segment.data,
timeout=self.config.api_timeout)
if 200 <= response.status_code < 300:
result = response.json()
if isinstance(result, dict):
if result.get("errcode") != 0:
raise ApiError(errcode=result.get("errcode"),
errmsg=result.get("errmsg"))
return result
raise NetworkError(f"HTTP request received unexpected "
f"status code: {response.status_code}")
except httpx.InvalidURL:
raise NetworkError("API root url invalid")
except httpx.HTTPError:
raise NetworkError("HTTP request failed")
async def send(self,
event: "Event",
message: Union[str, "Message", "MessageSegment"],
at_sender: bool = False,
**kwargs) -> Union[Any, NoReturn]:
"""
:说明:
根据 ``event`` 向触发事件的主体发送消息
:参数:
* ``event: Event``: Event 对象
* ``message: Union[str, Message, MessageSegment]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主体
* ``**kwargs``: 覆盖默认参数
:返回:
- ``Any``: API 调用返回数据
:异常:
- ``ValueError``: 缺少 ``user_id``, ``group_id``
- ``NetworkError``: 网络错误
- ``ActionFailed``: API 调用失败
"""
msg = message if isinstance(message, Message) else Message(message)
log("DEBUG", f"send -> msg: {msg}")
at_sender = at_sender and bool(event.user_id)
log("DEBUG", f"send -> at_sender: {at_sender}")
params = {"raw_event": event.raw_event}
params.update(kwargs)
if at_sender and event.detail_type != "private":
params["message"] = f"@{event.user_id} " + msg
else:
params["message"] = msg
log("DEBUG", f"send -> params: {params}")
return await self.call_api("post_webhook", **params)

View File

@ -0,0 +1,207 @@
from typing import Literal, Union
from nonebot.adapters import BaseEvent
from nonebot.typing import Optional
from .utils import log
from .message import Message
from .model import MessageModel, ConversationType, TextMessage
class Event(BaseEvent):
"""
钉钉 协议 Event 适配继承属性参考 `BaseEvent <./#class-baseevent>`_ 。
"""
def __init__(self, message: MessageModel):
super().__init__(message)
if not message.msgtype:
log("ERROR", "message has no msgtype")
# 目前钉钉机器人只能接收到 text 类型的消息
self._message = Message(getattr(message, message.msgtype or "text"))
@property
def raw_event(self) -> MessageModel:
"""原始上报消息"""
return self._raw_event
@property
def id(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 消息 ID
"""
return self.raw_event.msgId
@property
def name(self) -> str:
"""
- 类型: ``str``
- 说明: 事件名称由类型与 ``.`` 组合而成
"""
n = self.type + "." + self.detail_type
if self.sub_type:
n += "." + self.sub_type
return n
@property
def self_id(self) -> str:
"""
- 类型: ``str``
- 说明: 机器人自身 ID
"""
return str(self.raw_event.chatbotUserId)
@property
def time(self) -> int:
"""
- 类型: ``int``
- 说明: 消息的时间戳单位 s
"""
# 单位 ms -> s
return int(self.raw_event.createAt / 1000)
@property
def type(self) -> str:
"""
- 类型: ``str``
- 说明: 事件类型
"""
return "message"
@type.setter
def type(self, value) -> None:
pass
@property
def detail_type(self) -> Literal["private", "group"]:
"""
- 类型: ``str``
- 说明: 事件详细类型
"""
return self.raw_event.conversationType.name
@detail_type.setter
def detail_type(self, value) -> None:
if value == "private":
self.raw_event.conversationType = ConversationType.private
if value == "group":
self.raw_event.conversationType = ConversationType.group
@property
def sub_type(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 事件子类型
"""
return ""
@sub_type.setter
def sub_type(self, value) -> None:
pass
@property
def user_id(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 发送者 ID
"""
return self.raw_event.senderId
@user_id.setter
def user_id(self, value) -> None:
self.raw_event.senderId = value
@property
def group_id(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 事件主体群 ID
"""
return self.raw_event.conversationId
@group_id.setter
def group_id(self, value) -> None:
self.raw_event.conversationId = value
@property
def to_me(self) -> Optional[bool]:
"""
- 类型: ``Optional[bool]``
- 说明: 消息是否与机器人相关
"""
return self.detail_type == "private" or self.raw_event.isInAtList
@to_me.setter
def to_me(self, value) -> None:
self.raw_event.isInAtList = value
@property
def message(self) -> Optional["Message"]:
"""
- 类型: ``Optional[Message]``
- 说明: 消息内容
"""
return self._message
@message.setter
def message(self, value) -> None:
self._message = value
@property
def reply(self) -> None:
"""
- 类型: ``None``
- 说明: 回复消息详情
"""
raise ValueError("暂不支持 reply")
@property
def raw_message(self) -> Optional[TextMessage]:
"""
- 类型: ``Optional[str]``
- 说明: 原始消息
"""
return getattr(self.raw_event, self.raw_event.msgtype)
@raw_message.setter
def raw_message(self, value) -> None:
setattr(self.raw_event, self.raw_event.msgtype, value)
@property
def plain_text(self) -> Optional[str]:
"""
- 类型: ``Optional[str]``
- 说明: 纯文本消息内容
"""
return self.message and self.message.extract_plain_text().strip()
@property
def sender(self) -> Optional[dict]:
"""
- 类型: ``Optional[dict]``
- 说明: 消息发送者信息
"""
result = {
# 加密的发送者ID。
"senderId": self.raw_event.senderId,
# 发送者昵称。
"senderNick": self.raw_event.senderNick,
# 企业内部群有的发送者当前群的企业 corpId。
"senderCorpId": self.raw_event.senderCorpId,
# 企业内部群有的发送者在企业内的 userId。
"senderStaffId": self.raw_event.senderStaffId,
"role": "admin" if self.raw_event.isAdmin else "member"
}
return result
@sender.setter
def sender(self, value) -> None:
def set_wrapper(name):
if value.get(name):
setattr(self.raw_event, name, value.get(name))
set_wrapper("senderId")
set_wrapper("senderNick")
set_wrapper("senderCorpId")
set_wrapper("senderStaffId")

View File

@ -0,0 +1,29 @@
from nonebot.exception import AdapterException
class DingAdapterException(AdapterException):
def __init__(self) -> None:
super.__init__("DING")
class ApiError(DingAdapterException):
"""
:说明:
API 请求成功返回数据 API 操作失败
"""
def __init__(self, errcode: int, errmsg: str):
self.errcode = errcode
self.errmsg = errmsg
def __repr__(self):
return f"<ApiError errcode={self.errcode} errmsg={self.errmsg}>"
class SessionExpired(DingAdapterException):
def __repr__(self) -> str:
return f"<sessionWebhook is Expired>"

View File

@ -0,0 +1,133 @@
from nonebot.typing import Any, Dict, Union, Iterable
from nonebot.adapters import BaseMessage, BaseMessageSegment
from .utils import log
from .model import TextMessage
class MessageSegment(BaseMessageSegment):
"""
钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
"""
def __init__(self, type_: str, msg: Dict[str, Any]) -> None:
data = {
"msgtype": type_,
}
if msg:
data.update(msg)
log("DEBUG", f"data {data}")
super().__init__(type=type_, data=data)
@classmethod
def from_segment(cls, segment: "MessageSegment"):
return MessageSegment(segment.type, segment.data)
def __str__(self):
log("DEBUG", f"__str__: self.type {self.type} data {self.data}")
if self.type == "text":
return str(self.data["text"]["content"].strip())
return ""
def __add__(self, other) -> "Message":
if isinstance(other, str):
if self.type == 'text':
self.data['text']['content'] += other
return MessageSegment.from_segment(self)
return Message(self) + other
def atMobile(self, mobileNumber):
self.data.setdefault("at", {})
self.data["at"].setdefault("atMobiles", [])
self.data["at"]["atMobiles"].append(mobileNumber)
def atAll(self, value):
self.data.setdefault("at", {})
self.data["at"]["isAtAll"] = value
@staticmethod
def text(text: str) -> "MessageSegment":
return MessageSegment("text", {"text": {"content": text.strip()}})
@staticmethod
def markdown(title: str, text: str) -> "MessageSegment":
return MessageSegment("markdown", {
"markdown": {
"title": title,
"text": text,
},
})
@staticmethod
def actionCardSingleBtn(title: str, text: str, btnTitle: str,
btnUrl) -> "MessageSegment":
return MessageSegment(
"actionCard", {
"actionCard": {
"title": title,
"text": text,
"singleTitle": btnTitle,
"singleURL": btnUrl
}
})
@staticmethod
def actionCardSingleMultiBtns(
title: str,
text: str,
btns: list = [],
hideAvatar: bool = False,
btnOrientation: str = '1',
) -> "MessageSegment":
"""
:参数:
* ``btnOrientation``: 0按钮竖直排列 1按钮横向排列
* ``btns``: [{ "title": title, "actionURL": actionURL }, ...]
"""
return MessageSegment(
"actionCard", {
"actionCard": {
"title": title,
"text": text,
"hideAvatar": "1" if hideAvatar else "0",
"btnOrientation": btnOrientation,
"btns": btns
}
})
@staticmethod
def feedCard(links: list = [],) -> "MessageSegment":
"""
:参数:
* ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...]
"""
return MessageSegment("feedCard", {"feedCard": {"links": links}})
@staticmethod
def empty() -> "MessageSegment":
"""不想回复消息到群里"""
return MessageSegment("empty")
class Message(BaseMessage):
"""
钉钉 协议 Message 适配
"""
@staticmethod
def _construct(
msg: Union[str, dict, list,
TextMessage]) -> Iterable[MessageSegment]:
if isinstance(msg, dict):
yield MessageSegment(msg["type"], msg.get("data") or {})
return
elif isinstance(msg, list):
for seg in msg:
yield MessageSegment(seg["type"], seg.get("data") or {})
return
elif isinstance(msg, TextMessage):
yield MessageSegment("text", {"text": msg.dict()})
elif isinstance(msg, str):
yield MessageSegment.text(str)

View File

@ -0,0 +1,47 @@
from typing import List, Optional
from enum import Enum
from pydantic import BaseModel
class Headers(BaseModel):
sign: str
token: str
# ms
timestamp: int
class TextMessage(BaseModel):
content: str
class AtUsersItem(BaseModel):
dingtalkId: str
staffId: Optional[str]
class ConversationType(str, Enum):
private = '1'
group = '2'
class MessageModel(BaseModel):
msgtype: str = None
text: Optional[TextMessage] = None
msgId: str
# ms
createAt: int = None
conversationType: ConversationType = None
conversationId: str = None
conversationTitle: str = None
senderId: str = None
senderNick: str = None
senderCorpId: str = None
senderStaffId: str = None
chatbotUserId: str = None
chatbotCorpId: str = None
atUsers: List[AtUsersItem] = None
sessionWebhook: str = None
# ms
sessionWebhookExpiredTime: int = None
isAdmin: bool = None
isInAtList: bool = None

View File

@ -0,0 +1,35 @@
import base64
import hashlib
import hmac
from typing import TYPE_CHECKING
from nonebot.utils import logger_wrapper
if TYPE_CHECKING:
from nonebot.drivers import BaseDriver
log = logger_wrapper("DING")
def check_legal(timestamp, remote_sign, driver: "BaseDriver"):
"""
1. timestamp 与系统当前时间戳如果相差1小时以上则认为是非法的请求
2. sign 与开发者自己计算的结果不一致则认为是非法的请求
必须当timestamp和sign同时验证通过才能认为是来自钉钉的合法请求
"""
# 目前先设置成 secret
# TODO 后面可能可以从 secret[adapter_name] 获取
app_secret = driver.config.secret # 机器人的 appSecret
if not app_secret:
# TODO warning
log("WARNING", "No ding secrets set, won't check sign")
return True
app_secret_enc = app_secret.encode('utf-8')
string_to_sign = '{}\n{}'.format(timestamp, app_secret)
string_to_sign_enc = string_to_sign.encode('utf-8')
hmac_code = hmac.new(app_secret_enc,
string_to_sign_enc,
digestmod=hashlib.sha256).digest()
sign = base64.b64encode(hmac_code).decode('utf-8')
return remote_sign == sign

View File

@ -145,3 +145,9 @@ class ActionFailed(Exception):
def __str__(self):
return self.__repr__()
class AdapterException(Exception):
def __init__(self, adapter_name) -> None:
self.adapter_name = adapter_name

View File

@ -21,7 +21,7 @@
from types import ModuleType
from typing import NoReturn, TYPE_CHECKING
from typing import Any, Set, List, Dict, Type, Tuple, Mapping
from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable
from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable, Generic
# import some modules needed when checking types
if TYPE_CHECKING:

View File

@ -5,6 +5,7 @@ sys.path.insert(0, os.path.abspath(".."))
import nonebot
from nonebot.adapters.cqhttp import Bot
from nonebot.adapters.ding import Bot as DingBot
from nonebot.log import logger, default_format
# test custom log
@ -18,6 +19,7 @@ nonebot.init(custom_config2="config on init")
app = nonebot.get_asgi()
driver = nonebot.get_driver()
driver.register_adapter("cqhttp", Bot)
driver.register_adapter("ding", DingBot)
# load builtin plugin
nonebot.load_builtin_plugins()