From edb4458031468d7821b83f0a83760d677261892a Mon Sep 17 00:00:00 2001 From: Artin Date: Thu, 3 Dec 2020 00:59:32 +0800 Subject: [PATCH] :sparkles: Add ding adapter --- nonebot/adapters/__init__.py | 25 ++-- nonebot/adapters/ding/__init__.py | 15 +++ nonebot/adapters/ding/bot.py | 205 ++++++++++++++++++++++++++++ nonebot/adapters/ding/event.py | 207 +++++++++++++++++++++++++++++ nonebot/adapters/ding/exception.py | 29 ++++ nonebot/adapters/ding/message.py | 133 ++++++++++++++++++ nonebot/adapters/ding/model.py | 47 +++++++ nonebot/adapters/ding/utils.py | 35 +++++ nonebot/exception.py | 6 + nonebot/typing.py | 2 +- tests/bot.py | 2 + 11 files changed, 695 insertions(+), 11 deletions(-) create mode 100644 nonebot/adapters/ding/__init__.py create mode 100644 nonebot/adapters/ding/bot.py create mode 100644 nonebot/adapters/ding/event.py create mode 100644 nonebot/adapters/ding/exception.py create mode 100644 nonebot/adapters/ding/message.py create mode 100644 nonebot/adapters/ding/model.py create mode 100644 nonebot/adapters/ding/utils.py diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 9895b88c..a7dd7b21 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -9,9 +9,11 @@ import abc from functools import reduce, partial from dataclasses import dataclass, field +from pydantic import BaseModel + from nonebot.config import Config from nonebot.typing import Driver, Message, WebSocket -from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable +from nonebot.typing import Any, Dict, Union, Optional, NoReturn, Callable, Iterable, Awaitable, TypeVar, Generic class BaseBot(abc.ABC): @@ -135,24 +137,27 @@ class BaseBot(abc.ABC): raise NotImplementedError -class BaseEvent(abc.ABC): +T = TypeVar("T", dict, BaseModel) + + +class BaseEvent(abc.ABC, Generic[T]): """ Event 基类。提供上报信息的关键信息,其余信息可从原始上报消息获取。 """ - def __init__(self, raw_event: dict): + def __init__(self, raw_event: T): """ :参数: - * ``raw_event: dict``: 原始上报消息 + * ``raw_event: T``: 原始上报消息 """ - self._raw_event = raw_event + self._raw_event: T = raw_event def __repr__(self) -> str: return f"" @property - def raw_event(self) -> dict: + def raw_event(self) -> T: """原始上报消息""" return self._raw_event @@ -347,17 +352,17 @@ class BaseMessage(list, abc.ABC): """消息数组""" def __init__(self, - message: Union[str, dict, list, BaseMessageSegment, + message: Union[str, dict, list, BaseModel, BaseMessageSegment, "BaseMessage"] = None, *args, **kwargs): """ :参数: - * ``message: Union[str, dict, list, MessageSegment, Message]``: 消息内容 + * ``message: Union[str, dict, list, BaseModel, MessageSegment, Message]``: 消息内容 """ super().__init__(*args, **kwargs) - if isinstance(message, (str, dict, list)): + if isinstance(message, (str, dict, list, BaseModel)): self.extend(self._construct(message)) elif isinstance(message, BaseMessage): self.extend(message) @@ -448,4 +453,4 @@ class BaseMessage(list, abc.ABC): return f"{x} {y}" if y.type == "text" else x plain_text = reduce(_concat, self, "") - return plain_text[1:] if plain_text else plain_text + return plain_text.strip() diff --git a/nonebot/adapters/ding/__init__.py b/nonebot/adapters/ding/__init__.py new file mode 100644 index 00000000..8b5f101d --- /dev/null +++ b/nonebot/adapters/ding/__init__.py @@ -0,0 +1,15 @@ +""" +钉钉群机器人 协议适配 +============================ + +协议详情请看: `钉钉文档`_ + +.. _钉钉文档: + https://ding-doc.dingtalk.com/doc#/serverapi2/krgddi + +""" + +from .bot import Bot +from .event import Event +from .message import Message, MessageSegment +from .exception import ApiError, SessionExpired, AdapterException diff --git a/nonebot/adapters/ding/bot.py b/nonebot/adapters/ding/bot.py new file mode 100644 index 00000000..4acfc2fc --- /dev/null +++ b/nonebot/adapters/ding/bot.py @@ -0,0 +1,205 @@ +from datetime import datetime +import httpx + +from nonebot.log import logger +from nonebot.config import Config +from nonebot.message import handle_event +from nonebot.typing import Driver, WebSocket, NoReturn +from nonebot.typing import Any, Union, Optional +from nonebot.adapters import BaseBot +from nonebot.exception import NetworkError, RequestDenied, ApiNotAvailable +from .exception import ApiError, SessionExpired +from .utils import check_legal, log +from .event import Event +from .message import Message, MessageSegment +from .model import MessageModel + + +class Bot(BaseBot): + """ + 钉钉 协议 Bot 适配。继承属性参考 `BaseBot <./#class-basebot>`_ 。 + """ + + def __init__(self, + driver: Driver, + connection_type: str, + config: Config, + self_id: str, + *, + websocket: Optional[WebSocket] = None): + + super().__init__(driver, + connection_type, + config, + self_id, + websocket=websocket) + + @property + def type(self) -> str: + """ + - 返回: ``"ding"`` + """ + return "ding" + + @classmethod + async def check_permission(cls, driver: Driver, connection_type: str, + headers: dict, + body: Optional[dict]) -> Union[str, NoReturn]: + """ + :说明: + 钉钉协议鉴权。参考 `鉴权 `_ + """ + timestamp = headers.get("timestamp") + sign = headers.get("sign") + log("DEBUG", "headers: {}".format(headers)) + log("DEBUG", "body: {}".format(body)) + + # 检查 timestamp + if not timestamp: + log("WARNING", "Missing `timestamp` Header") + raise RequestDenied(400, "Missing `timestamp` Header") + # 检查 sign + if not sign: + log("WARNING", "Missing `sign` Header") + raise RequestDenied(400, "Missing `sign` Header") + # 校验 sign 和 timestamp,判断是否是来自钉钉的合法请求 + if not check_legal(timestamp, sign, driver): + log("WARNING", "Signature Header is invalid") + raise RequestDenied(403, "Signature is invalid") + # 检查连接方式 + if connection_type not in ["http"]: + log("WARNING", "Unsupported connection type") + raise RequestDenied(405, "Unsupported connection type") + + access_token = driver.config.access_token + if access_token and access_token != access_token: + log( + "WARNING", "Authorization Header is invalid" + if access_token else "Missing Authorization Header") + raise RequestDenied( + 403, "Authorization Header is invalid" + if access_token else "Missing Authorization Header") + return body.get("chatbotUserId") + + async def handle_message(self, body: dict): + message = MessageModel.parse_obj(body) + if not message: + return + log("DEBUG", "message: {}".format(message)) + + try: + event = Event(message) + await handle_event(self, event) + except Exception as e: + logger.opt(colors=True, exception=e).error( + f"Failed to handle event. Raw: {message}" + ) + return + + async def call_api(self, api: str, **data) -> Union[Any, NoReturn]: + """ + :说明: + + 调用 钉钉 协议 API + + :参数: + + * ``api: str``: API 名称 + * ``**data: Any``: API 参数 + + :返回: + + - ``Any``: API 调用返回数据 + + :异常: + + - ``NetworkError``: 网络错误 + - ``ActionFailed``: API 调用失败 + """ + if "self_id" in data: + self_id = data.pop("self_id") + if self_id: + bot = self.driver.bots[str(self_id)] + return await bot.call_api(api, **data) + + log("DEBUG", f"Calling API {api}") + log("DEBUG", f"Calling data {data}") + + if self.connection_type == "http" and api == "post_webhook": + raw_event: MessageModel = data["raw_event"] + + if int(datetime.now().timestamp()) > int( + raw_event.sessionWebhookExpiredTime / 1000): + raise SessionExpired + + target = raw_event.sessionWebhook + + if not target: + raise ApiNotAvailable + + headers = {} + segment: MessageSegment = data["message"][0] + try: + async with httpx.AsyncClient(headers=headers) as client: + response = await client.post( + target, + params={"access_token": self.config.access_token}, + json=segment.data, + timeout=self.config.api_timeout) + + if 200 <= response.status_code < 300: + result = response.json() + if isinstance(result, dict): + if result.get("errcode") != 0: + raise ApiError(errcode=result.get("errcode"), + errmsg=result.get("errmsg")) + return result + raise NetworkError(f"HTTP request received unexpected " + f"status code: {response.status_code}") + except httpx.InvalidURL: + raise NetworkError("API root url invalid") + except httpx.HTTPError: + raise NetworkError("HTTP request failed") + + async def send(self, + event: "Event", + message: Union[str, "Message", "MessageSegment"], + at_sender: bool = False, + **kwargs) -> Union[Any, NoReturn]: + """ + :说明: + + 根据 ``event`` 向触发事件的主体发送消息。 + + :参数: + + * ``event: Event``: Event 对象 + * ``message: Union[str, Message, MessageSegment]``: 要发送的消息 + * ``at_sender: bool``: 是否 @ 事件主体 + * ``**kwargs``: 覆盖默认参数 + + :返回: + + - ``Any``: API 调用返回数据 + + :异常: + + - ``ValueError``: 缺少 ``user_id``, ``group_id`` + - ``NetworkError``: 网络错误 + - ``ActionFailed``: API 调用失败 + """ + msg = message if isinstance(message, Message) else Message(message) + log("DEBUG", f"send -> msg: {msg}") + + at_sender = at_sender and bool(event.user_id) + log("DEBUG", f"send -> at_sender: {at_sender}") + params = {"raw_event": event.raw_event} + params.update(kwargs) + + if at_sender and event.detail_type != "private": + params["message"] = f"@{event.user_id} " + msg + else: + params["message"] = msg + log("DEBUG", f"send -> params: {params}") + + return await self.call_api("post_webhook", **params) diff --git a/nonebot/adapters/ding/event.py b/nonebot/adapters/ding/event.py new file mode 100644 index 00000000..a4c50e9d --- /dev/null +++ b/nonebot/adapters/ding/event.py @@ -0,0 +1,207 @@ +from typing import Literal, Union +from nonebot.adapters import BaseEvent +from nonebot.typing import Optional + +from .utils import log +from .message import Message +from .model import MessageModel, ConversationType, TextMessage + + +class Event(BaseEvent): + """ + 钉钉 协议 Event 适配。继承属性参考 `BaseEvent <./#class-baseevent>`_ 。 + """ + + def __init__(self, message: MessageModel): + super().__init__(message) + if not message.msgtype: + log("ERROR", "message has no msgtype") + # 目前钉钉机器人只能接收到 text 类型的消息 + self._message = Message(getattr(message, message.msgtype or "text")) + + @property + def raw_event(self) -> MessageModel: + """原始上报消息""" + return self._raw_event + + @property + def id(self) -> Optional[str]: + """ + - 类型: ``Optional[str]`` + - 说明: 消息 ID + """ + return self.raw_event.msgId + + @property + def name(self) -> str: + """ + - 类型: ``str`` + - 说明: 事件名称,由类型与 ``.`` 组合而成 + """ + n = self.type + "." + self.detail_type + if self.sub_type: + n += "." + self.sub_type + return n + + @property + def self_id(self) -> str: + """ + - 类型: ``str`` + - 说明: 机器人自身 ID + """ + return str(self.raw_event.chatbotUserId) + + @property + def time(self) -> int: + """ + - 类型: ``int`` + - 说明: 消息的时间戳,单位 s + """ + # 单位 ms -> s + return int(self.raw_event.createAt / 1000) + + @property + def type(self) -> str: + """ + - 类型: ``str`` + - 说明: 事件类型 + """ + return "message" + + @type.setter + def type(self, value) -> None: + pass + + @property + def detail_type(self) -> Literal["private", "group"]: + """ + - 类型: ``str`` + - 说明: 事件详细类型 + """ + return self.raw_event.conversationType.name + + @detail_type.setter + def detail_type(self, value) -> None: + if value == "private": + self.raw_event.conversationType = ConversationType.private + if value == "group": + self.raw_event.conversationType = ConversationType.group + + @property + def sub_type(self) -> Optional[str]: + """ + - 类型: ``Optional[str]`` + - 说明: 事件子类型 + """ + return "" + + @sub_type.setter + def sub_type(self, value) -> None: + pass + + @property + def user_id(self) -> Optional[str]: + """ + - 类型: ``Optional[str]`` + - 说明: 发送者 ID + """ + return self.raw_event.senderId + + @user_id.setter + def user_id(self, value) -> None: + self.raw_event.senderId = value + + @property + def group_id(self) -> Optional[str]: + """ + - 类型: ``Optional[str]`` + - 说明: 事件主体群 ID + """ + return self.raw_event.conversationId + + @group_id.setter + def group_id(self, value) -> None: + self.raw_event.conversationId = value + + @property + def to_me(self) -> Optional[bool]: + """ + - 类型: ``Optional[bool]`` + - 说明: 消息是否与机器人相关 + """ + return self.detail_type == "private" or self.raw_event.isInAtList + + @to_me.setter + def to_me(self, value) -> None: + self.raw_event.isInAtList = value + + @property + def message(self) -> Optional["Message"]: + """ + - 类型: ``Optional[Message]`` + - 说明: 消息内容 + """ + return self._message + + @message.setter + def message(self, value) -> None: + self._message = value + + @property + def reply(self) -> None: + """ + - 类型: ``None`` + - 说明: 回复消息详情 + """ + raise ValueError("暂不支持 reply") + + @property + def raw_message(self) -> Optional[TextMessage]: + """ + - 类型: ``Optional[str]`` + - 说明: 原始消息 + """ + return getattr(self.raw_event, self.raw_event.msgtype) + + @raw_message.setter + def raw_message(self, value) -> None: + setattr(self.raw_event, self.raw_event.msgtype, value) + + @property + def plain_text(self) -> Optional[str]: + """ + - 类型: ``Optional[str]`` + - 说明: 纯文本消息内容 + """ + return self.message and self.message.extract_plain_text().strip() + + @property + def sender(self) -> Optional[dict]: + """ + - 类型: ``Optional[dict]`` + - 说明: 消息发送者信息 + """ + result = { + # 加密的发送者ID。 + "senderId": self.raw_event.senderId, + # 发送者昵称。 + "senderNick": self.raw_event.senderNick, + # 企业内部群有的发送者当前群的企业 corpId。 + "senderCorpId": self.raw_event.senderCorpId, + # 企业内部群有的发送者在企业内的 userId。 + "senderStaffId": self.raw_event.senderStaffId, + "role": "admin" if self.raw_event.isAdmin else "member" + } + return result + + @sender.setter + def sender(self, value) -> None: + + def set_wrapper(name): + if value.get(name): + setattr(self.raw_event, name, value.get(name)) + + set_wrapper("senderId") + set_wrapper("senderNick") + set_wrapper("senderCorpId") + set_wrapper("senderStaffId") diff --git a/nonebot/adapters/ding/exception.py b/nonebot/adapters/ding/exception.py new file mode 100644 index 00000000..bfb318c5 --- /dev/null +++ b/nonebot/adapters/ding/exception.py @@ -0,0 +1,29 @@ +from nonebot.exception import AdapterException + + +class DingAdapterException(AdapterException): + + def __init__(self) -> None: + super.__init__("DING") + + +class ApiError(DingAdapterException): + """ + :说明: + + API 请求成功返回数据,但 API 操作失败。 + + """ + + def __init__(self, errcode: int, errmsg: str): + self.errcode = errcode + self.errmsg = errmsg + + def __repr__(self): + return f"" + + +class SessionExpired(DingAdapterException): + + def __repr__(self) -> str: + return f"" diff --git a/nonebot/adapters/ding/message.py b/nonebot/adapters/ding/message.py new file mode 100644 index 00000000..53b83f6e --- /dev/null +++ b/nonebot/adapters/ding/message.py @@ -0,0 +1,133 @@ +from nonebot.typing import Any, Dict, Union, Iterable +from nonebot.adapters import BaseMessage, BaseMessageSegment +from .utils import log +from .model import TextMessage + + +class MessageSegment(BaseMessageSegment): + """ + 钉钉 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 + """ + + def __init__(self, type_: str, msg: Dict[str, Any]) -> None: + data = { + "msgtype": type_, + } + if msg: + data.update(msg) + log("DEBUG", f"data {data}") + super().__init__(type=type_, data=data) + + @classmethod + def from_segment(cls, segment: "MessageSegment"): + return MessageSegment(segment.type, segment.data) + + def __str__(self): + log("DEBUG", f"__str__: self.type {self.type} data {self.data}") + if self.type == "text": + return str(self.data["text"]["content"].strip()) + return "" + + def __add__(self, other) -> "Message": + if isinstance(other, str): + if self.type == 'text': + self.data['text']['content'] += other + return MessageSegment.from_segment(self) + return Message(self) + other + + def atMobile(self, mobileNumber): + self.data.setdefault("at", {}) + self.data["at"].setdefault("atMobiles", []) + self.data["at"]["atMobiles"].append(mobileNumber) + + def atAll(self, value): + self.data.setdefault("at", {}) + self.data["at"]["isAtAll"] = value + + @staticmethod + def text(text: str) -> "MessageSegment": + return MessageSegment("text", {"text": {"content": text.strip()}}) + + @staticmethod + def markdown(title: str, text: str) -> "MessageSegment": + return MessageSegment("markdown", { + "markdown": { + "title": title, + "text": text, + }, + }) + + @staticmethod + def actionCardSingleBtn(title: str, text: str, btnTitle: str, + btnUrl) -> "MessageSegment": + return MessageSegment( + "actionCard", { + "actionCard": { + "title": title, + "text": text, + "singleTitle": btnTitle, + "singleURL": btnUrl + } + }) + + @staticmethod + def actionCardSingleMultiBtns( + title: str, + text: str, + btns: list = [], + hideAvatar: bool = False, + btnOrientation: str = '1', + ) -> "MessageSegment": + """ + :参数: + + * ``btnOrientation``: 0:按钮竖直排列 1:按钮横向排列 + + * ``btns``: [{ "title": title, "actionURL": actionURL }, ...] + """ + return MessageSegment( + "actionCard", { + "actionCard": { + "title": title, + "text": text, + "hideAvatar": "1" if hideAvatar else "0", + "btnOrientation": btnOrientation, + "btns": btns + } + }) + + @staticmethod + def feedCard(links: list = [],) -> "MessageSegment": + """ + :参数: + + * ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...] + """ + return MessageSegment("feedCard", {"feedCard": {"links": links}}) + + @staticmethod + def empty() -> "MessageSegment": + """不想回复消息到群里""" + return MessageSegment("empty") + + +class Message(BaseMessage): + """ + 钉钉 协议 Message 适配。 + """ + + @staticmethod + def _construct( + msg: Union[str, dict, list, + TextMessage]) -> Iterable[MessageSegment]: + if isinstance(msg, dict): + yield MessageSegment(msg["type"], msg.get("data") or {}) + return + elif isinstance(msg, list): + for seg in msg: + yield MessageSegment(seg["type"], seg.get("data") or {}) + return + elif isinstance(msg, TextMessage): + yield MessageSegment("text", {"text": msg.dict()}) + elif isinstance(msg, str): + yield MessageSegment.text(str) diff --git a/nonebot/adapters/ding/model.py b/nonebot/adapters/ding/model.py new file mode 100644 index 00000000..d317ea5b --- /dev/null +++ b/nonebot/adapters/ding/model.py @@ -0,0 +1,47 @@ +from typing import List, Optional +from enum import Enum +from pydantic import BaseModel + + +class Headers(BaseModel): + sign: str + token: str + # ms + timestamp: int + + +class TextMessage(BaseModel): + content: str + + +class AtUsersItem(BaseModel): + dingtalkId: str + staffId: Optional[str] + + +class ConversationType(str, Enum): + private = '1' + group = '2' + + +class MessageModel(BaseModel): + msgtype: str = None + text: Optional[TextMessage] = None + msgId: str + # ms + createAt: int = None + conversationType: ConversationType = None + conversationId: str = None + conversationTitle: str = None + senderId: str = None + senderNick: str = None + senderCorpId: str = None + senderStaffId: str = None + chatbotUserId: str = None + chatbotCorpId: str = None + atUsers: List[AtUsersItem] = None + sessionWebhook: str = None + # ms + sessionWebhookExpiredTime: int = None + isAdmin: bool = None + isInAtList: bool = None diff --git a/nonebot/adapters/ding/utils.py b/nonebot/adapters/ding/utils.py new file mode 100644 index 00000000..8c644683 --- /dev/null +++ b/nonebot/adapters/ding/utils.py @@ -0,0 +1,35 @@ +import base64 +import hashlib +import hmac +from typing import TYPE_CHECKING + +from nonebot.utils import logger_wrapper + +if TYPE_CHECKING: + from nonebot.drivers import BaseDriver +log = logger_wrapper("DING") + + +def check_legal(timestamp, remote_sign, driver: "BaseDriver"): + """ + 1. timestamp 与系统当前时间戳如果相差1小时以上,则认为是非法的请求。 + + 2. sign 与开发者自己计算的结果不一致,则认为是非法的请求。 + + 必须当timestamp和sign同时验证通过,才能认为是来自钉钉的合法请求。 + """ + # 目前先设置成 secret + # TODO 后面可能可以从 secret[adapter_name] 获取 + app_secret = driver.config.secret # 机器人的 appSecret + if not app_secret: + # TODO warning + log("WARNING", "No ding secrets set, won't check sign") + return True + app_secret_enc = app_secret.encode('utf-8') + string_to_sign = '{}\n{}'.format(timestamp, app_secret) + string_to_sign_enc = string_to_sign.encode('utf-8') + hmac_code = hmac.new(app_secret_enc, + string_to_sign_enc, + digestmod=hashlib.sha256).digest() + sign = base64.b64encode(hmac_code).decode('utf-8') + return remote_sign == sign diff --git a/nonebot/exception.py b/nonebot/exception.py index cc65e6da..1f61f5ed 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -145,3 +145,9 @@ class ActionFailed(Exception): def __str__(self): return self.__repr__() + + +class AdapterException(Exception): + + def __init__(self, adapter_name) -> None: + self.adapter_name = adapter_name diff --git a/nonebot/typing.py b/nonebot/typing.py index 09109b37..21a8b0ee 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -21,7 +21,7 @@ from types import ModuleType from typing import NoReturn, TYPE_CHECKING from typing import Any, Set, List, Dict, Type, Tuple, Mapping -from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable +from typing import Union, TypeVar, Optional, Iterable, Callable, Awaitable, Generic # import some modules needed when checking types if TYPE_CHECKING: diff --git a/tests/bot.py b/tests/bot.py index 45f99b95..16d3c5b0 100644 --- a/tests/bot.py +++ b/tests/bot.py @@ -5,6 +5,7 @@ sys.path.insert(0, os.path.abspath("..")) import nonebot from nonebot.adapters.cqhttp import Bot +from nonebot.adapters.ding import Bot as DingBot from nonebot.log import logger, default_format # test custom log @@ -18,6 +19,7 @@ nonebot.init(custom_config2="config on init") app = nonebot.get_asgi() driver = nonebot.get_driver() driver.register_adapter("cqhttp", Bot) +driver.register_adapter("ding", DingBot) # load builtin plugin nonebot.load_builtin_plugins()