diff --git a/nonebot/adapters/__init__.py b/nonebot/adapters/__init__.py index 81d2a1fb..5f630bdf 100644 --- a/nonebot/adapters/__init__.py +++ b/nonebot/adapters/__init__.py @@ -6,6 +6,7 @@ """ import abc +from copy import copy from typing_extensions import Literal from functools import reduce, partial from dataclasses import dataclass, field @@ -292,7 +293,7 @@ class MessageSegment(abc.ABC): @abc.abstractmethod def __add__(self: T_MessageSegment, other: Union[str, T_MessageSegment, - T_Message]) -> "T_Message": + T_Message]) -> T_Message: """你需要在这里实现不同消息段的合并: 比如: if isinstance(other, str): @@ -326,6 +327,9 @@ class MessageSegment(abc.ABC): def get(self, key, default=None): return getattr(self, key, default) + def copy(self: T_MessageSegment) -> T_MessageSegment: + return copy(self) + @abc.abstractmethod def is_text(self) -> bool: raise NotImplementedError @@ -335,7 +339,8 @@ class Message(list, abc.ABC): """消息数组""" def __init__(self, - message: Union[str, list, dict, T_MessageSegment, T_Message, Any] = None, + message: Union[str, list, dict, T_MessageSegment, T_Message, + Any] = None, *args, **kwargs): """ @@ -364,7 +369,8 @@ class Message(list, abc.ABC): @staticmethod @abc.abstractmethod - def _construct(msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]: + def _construct( + msg: Union[str, list, dict, Any]) -> Iterable[T_MessageSegment]: raise NotImplementedError def __add__(self: T_Message, other: Union[str, T_MessageSegment, diff --git a/nonebot/adapters/cqhttp/bot.py b/nonebot/adapters/cqhttp/bot.py index ee29e903..aa2783f3 100644 --- a/nonebot/adapters/cqhttp/bot.py +++ b/nonebot/adapters/cqhttp/bot.py @@ -6,7 +6,6 @@ import asyncio from typing import Any, Dict, Union, Optional, TYPE_CHECKING import httpx - from nonebot.log import logger from nonebot.config import Config from nonebot.typing import overrides diff --git a/nonebot/adapters/ding/__init__.py b/nonebot/adapters/ding/__init__.py index ea076c6f..b27d9c3e 100644 --- a/nonebot/adapters/ding/__init__.py +++ b/nonebot/adapters/ding/__init__.py @@ -11,7 +11,7 @@ from .utils import log from .bot import Bot -from .event import Event from .message import Message, MessageSegment +from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent from .exception import (DingAdapterException, ApiNotAvailable, NetworkError, ActionFailed, SessionExpired) diff --git a/nonebot/adapters/ding/bot.py b/nonebot/adapters/ding/bot.py index 18dc5b69..eddccd33 100644 --- a/nonebot/adapters/ding/bot.py +++ b/nonebot/adapters/ding/bot.py @@ -6,18 +6,18 @@ from typing import Any, Union, Optional, TYPE_CHECKING import httpx from nonebot.log import logger from nonebot.config import Config +from nonebot.typing import overrides from nonebot.message import handle_event from nonebot.adapters import Bot as BaseBot from nonebot.exception import RequestDenied from .utils import log -from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent -from .model import ConversationType from .message import Message, MessageSegment from .exception import NetworkError, ApiNotAvailable, ActionFailed, SessionExpired +from .event import Event, MessageEvent, PrivateMessageEvent, GroupMessageEvent, ConversationType if TYPE_CHECKING: - from nonebot.drivers import BaseDriver as Driver + from nonebot.drivers import Driver class Bot(BaseBot): @@ -38,6 +38,7 @@ class Bot(BaseBot): return "ding" @classmethod + @overrides(BaseBot) async def check_permission(cls, driver: "Driver", connection_type: str, headers: dict, body: Optional[dict]) -> str: """ @@ -73,18 +74,22 @@ class Bot(BaseBot): log("WARNING", "Ding signature check ignored!") return body["chatbotUserId"] - async def handle_message(self, body: dict): - if not body: + @overrides(BaseBot) + async def handle_message(self, message: dict): + if not message: return # 判断消息类型,生成不同的 Event - conversation_type = body["conversationType"] - if conversation_type == ConversationType.private: - event = PrivateMessageEvent.parse_obj(body) - else: - event = GroupMessageEvent.parse_obj(body) - - if not event: + try: + conversation_type = message["conversationType"] + if conversation_type == ConversationType.private: + event = PrivateMessageEvent.parse_obj(message) + elif conversation_type == ConversationType.group: + event = GroupMessageEvent.parse_obj(message) + else: + raise ValueError("Unsupported conversation type") + except Exception as e: + log("Error", "Event Parser Error", e) return try: @@ -95,6 +100,7 @@ class Bot(BaseBot): ) return + @overrides(BaseBot) async def call_api(self, api: str, event: Optional[MessageEvent] = None, @@ -138,19 +144,18 @@ class Bot(BaseBot): target = event.sessionWebhook else: - target = None - - if not target: raise ApiNotAvailable headers = {} - segment: MessageSegment = data["message"][0] + message: Message = data.get("message", None) + if not message: + raise ValueError("Message not found") try: async with httpx.AsyncClient(headers=headers) as client: response = await client.post( target, params={"access_token": self.config.access_token}, - json=segment.data, + json=message._produce(), timeout=self.config.api_timeout) if 200 <= response.status_code < 300: @@ -167,8 +172,9 @@ class Bot(BaseBot): except httpx.HTTPError: raise NetworkError("HTTP request failed") + @overrides(BaseBot) async def send(self, - event: Event, + event: MessageEvent, message: Union[str, "Message", "MessageSegment"], at_sender: bool = False, **kwargs) -> Any: @@ -196,13 +202,15 @@ class Bot(BaseBot): """ msg = message if isinstance(message, Message) else Message(message) - at_sender = at_sender and bool(event.user_id) + at_sender = at_sender and bool(event.senderId) params = {} params["event"] = event params.update(kwargs) - if at_sender and event.detail_type != "private": - params["message"] = f"@{event.user_id} " + msg + if at_sender and event.conversationType != ConversationType.private: + params[ + "message"] = f"@{event.senderId} " + msg + MessageSegment.atMobiles( + event.senderId) else: params["message"] = msg diff --git a/nonebot/adapters/ding/event.py b/nonebot/adapters/ding/event.py index 507e9ccc..d5c670e5 100644 --- a/nonebot/adapters/ding/event.py +++ b/nonebot/adapters/ding/event.py @@ -1,84 +1,124 @@ -from typing import Union, Optional +from enum import Enum +from typing import List, Optional from typing_extensions import Literal -from pydantic import BaseModel, validator, parse_obj_as -from pydantic.fields import ModelField +from pydantic import BaseModel -from nonebot.adapters import Event as BaseEvent from nonebot.utils import escape_tag +from nonebot.typing import overrides +from nonebot.adapters import Event as BaseEvent from .message import Message -from .model import MessageModel, PrivateMessageModel, GroupMessageModel, ConversationType, TextMessage class Event(BaseEvent): """ - 钉钉 协议 Event 适配。继承属性参考 `BaseEvent <./#class-baseevent>`_ 。 + 钉钉 协议 Event 适配。各事件字段参考 `钉钉文档`_ + + .. _钉钉文档: + https://ding-doc.dingtalk.com/document#/org-dev-guide/elzz1p """ - message: Message = None - def __init__(self, **data): - super().__init__(**data) - # 其实目前钉钉机器人只能接收到 text 类型的消息 - message: Union[TextMessage] = getattr(self, self.msgtype, None) - self.message = parse_obj_as(Message, message) + chatbotUserId: str - def get_type(self) -> Literal["message"]: - """ - - 类型: ``str`` - - 说明: 事件类型 - """ - return "message" + @overrides(BaseEvent) + def get_type(self) -> Literal["message", "notice", "request", "meta_event"]: + raise ValueError("Event has no type!") + @overrides(BaseEvent) def get_event_name(self) -> str: - detail_type = self.conversationType.name - return self.get_type() + "." + detail_type + raise ValueError("Event has no type!") + @overrides(BaseEvent) def get_event_description(self) -> str: - return (f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "' + - "".join( - map( - lambda x: escape_tag(str(x)) - if x.is_text() else f"{escape_tag(str(x))}", - self.message, - )) + '"') - - def get_user_id(self) -> str: - return self.senderId - - def get_session_id(self) -> str: - """ - - 类型: ``str`` - - 说明: 消息 ID - """ - return self.msgId + raise ValueError("Event has no type!") + @overrides(BaseEvent) def get_message(self) -> "Message": - """ - - 类型: ``Message`` - - 说明: 消息内容 - """ - return self.message + raise ValueError("Event has no type!") + @overrides(BaseEvent) def get_plaintext(self) -> str: - """ - - 类型: ``str`` - - 说明: 纯文本消息内容 - """ - return self.message.extract_plain_text().strip() if self.message else "" + raise ValueError("Event has no type!") + @overrides(BaseEvent) + def get_user_id(self) -> str: + raise ValueError("Event has no type!") -class MessageEvent(MessageModel, Event): - pass - - -class PrivateMessageEvent(PrivateMessageModel, Event): + @overrides(BaseEvent) + def get_session_id(self) -> str: + raise ValueError("Event has no type!") + @overrides(BaseEvent) def is_tome(self) -> bool: return True -class GroupMessageEvent(GroupMessageModel, Event): +class TextMessage(BaseModel): + content: str + +class AtUsersItem(BaseModel): + dingtalkId: str + staffId: Optional[str] + + +class ConversationType(str, Enum): + private = "1" + group = "2" + + +class MessageEvent(Event): + msgtype: str + text: TextMessage + msgId: str + createAt: int # ms + conversationType: ConversationType + conversationId: str + senderId: str + senderNick: str + senderCorpId: str + sessionWebhook: str + sessionWebhookExpiredTime: int + isAdmin: bool + + @overrides(Event) + def get_type(self) -> Literal["message", "notice", "request", "meta_event"]: + return "message" + + @overrides(BaseEvent) + def get_event_name(self) -> str: + return f"{self.get_type()}.{self.conversationType.name}" + + @overrides(BaseEvent) + def get_event_description(self) -> str: + return f'Message[{self.msgtype}] {self.msgId} from {self.senderId} "{self.text.content}"' + + @overrides(BaseEvent) + def get_plaintext(self) -> str: + return self.text.content + + @overrides(BaseEvent) + def get_user_id(self) -> str: + return self.senderId + + @overrides(BaseEvent) + def get_session_id(self) -> str: + return self.senderId + + +class PrivateMessageEvent(MessageEvent): + chatbotCorpId: str + senderStaffId: Optional[str] + conversationType: ConversationType = ConversationType.private + + +class GroupMessageEvent(MessageEvent): + atUsers: List[AtUsersItem] + conversationType: ConversationType = ConversationType.group + conversationTitle: str + isInAtList: bool + + @overrides(MessageEvent) def is_tome(self) -> bool: return self.isInAtList diff --git a/nonebot/adapters/ding/exception.py b/nonebot/adapters/ding/exception.py index 37276eaa..63721efc 100644 --- a/nonebot/adapters/ding/exception.py +++ b/nonebot/adapters/ding/exception.py @@ -39,6 +39,9 @@ class ActionFailed(BaseActionFailed, DingAdapterException): def __repr__(self): return f"" + def __str__(self): + return self.__repr__() + class ApiNotAvailable(BaseApiNotAvailable, DingAdapterException): pass @@ -66,7 +69,7 @@ class NetworkError(BaseNetworkError, DingAdapterException): return self.__repr__() -class SessionExpired(BaseApiNotAvailable, DingAdapterException): +class SessionExpired(ApiNotAvailable, DingAdapterException): """ :说明: @@ -75,3 +78,6 @@ class SessionExpired(BaseApiNotAvailable, DingAdapterException): def __repr__(self) -> str: return f"" + + def __str__(self): + return self.__repr__() diff --git a/nonebot/adapters/ding/message.py b/nonebot/adapters/ding/message.py index cf6f56c0..db3a0083 100644 --- a/nonebot/adapters/ding/message.py +++ b/nonebot/adapters/ding/message.py @@ -2,39 +2,23 @@ from typing import Any, Dict, Union, Iterable from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment -from .utils import log -from .model import TextMessage - class MessageSegment(BaseMessageSegment): """ 钉钉 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 """ - def __init__(self, type_: str, msg: Dict[str, Any]) -> None: - data = { - "msgtype": type_, - } - if msg: - data.update(msg) - log("DEBUG", f"data {data}") + def __init__(self, type_: str, data: Dict[str, Any]) -> None: super().__init__(type=type_, data=data) - @classmethod - def from_segment(cls, segment: "MessageSegment"): - return MessageSegment(segment.type, segment.data) - def __str__(self): - log("DEBUG", f"__str__: self.type {self.type} data {self.data}") if self.type == "text": - return str(self.data["text"]["content"].strip()) + return str(self.data["content"]) + elif self.type == "markdown": + return str(self.data["text"]) return "" def __add__(self, other) -> "Message": - if isinstance(other, str): - if self.type == 'text': - self.data['text']['content'] += other - return MessageSegment.from_segment(self) return Message(self) + other def __radd__(self, other) -> "Message": @@ -43,43 +27,41 @@ class MessageSegment(BaseMessageSegment): def is_text(self) -> bool: return self.type == "text" - def atMobile(self, mobileNumber): - self.data.setdefault("at", {}) - self.data["at"].setdefault("atMobiles", []) - self.data["at"]["atMobiles"].append(mobileNumber) - - def atAll(self, value): - self.data.setdefault("at", {}) - self.data["at"]["isAtAll"] = value + @staticmethod + def atAll() -> "MessageSegment": + return MessageSegment("at", {"isAtAll": True}) @staticmethod - def text(text_: str) -> "MessageSegment": - return MessageSegment("text", {"text": {"content": text_.strip()}}) + def atMobiles(*mobileNumber: str) -> "MessageSegment": + return MessageSegment("at", {"atMobiles": list(mobileNumber)}) + + @staticmethod + def text(text: str) -> "MessageSegment": + return MessageSegment("text", {"content": text}) @staticmethod def markdown(title: str, text: str) -> "MessageSegment": - return MessageSegment("markdown", { - "markdown": { + return MessageSegment( + "markdown", + { "title": title, "text": text, }, - }) + ) @staticmethod def actionCardSingleBtn(title: str, text: str, btnTitle: str, btnUrl) -> "MessageSegment": return MessageSegment( "actionCard", { - "actionCard": { - "title": title, - "text": text, - "singleTitle": btnTitle, - "singleURL": btnUrl - } + "title": title, + "text": text, + "singleTitle": btnTitle, + "singleURL": btnUrl }) @staticmethod - def actionCardSingleMultiBtns( + def actionCardMultiBtns( title: str, text: str, btns: list = [], @@ -95,28 +77,26 @@ class MessageSegment(BaseMessageSegment): """ return MessageSegment( "actionCard", { - "actionCard": { - "title": title, - "text": text, - "hideAvatar": "1" if hideAvatar else "0", - "btnOrientation": btnOrientation, - "btns": btns - } + "title": title, + "text": text, + "hideAvatar": "1" if hideAvatar else "0", + "btnOrientation": btnOrientation, + "btns": btns }) @staticmethod - def feedCard(links: list = [],) -> "MessageSegment": + def feedCard(links: list = []) -> "MessageSegment": """ :参数: * ``links``: [{ "title": xxx, "messageURL": xxx, "picURL": xxx }, ...] """ - return MessageSegment("feedCard", {"feedCard": {"links": links}}) + return MessageSegment("feedCard", {"links": links}) @staticmethod def empty() -> "MessageSegment": """不想回复消息到群里""" - return MessageSegment("empty") + return MessageSegment("empty", {}) class Message(BaseMessage): @@ -129,17 +109,35 @@ class Message(BaseMessage): return cls(value) @staticmethod - def _construct( - msg: Union[str, dict, list, - TextMessage]) -> Iterable[MessageSegment]: + def _construct(msg: Union[str, dict, list]) -> Iterable[MessageSegment]: if isinstance(msg, dict): yield MessageSegment(msg["type"], msg.get("data") or {}) - return elif isinstance(msg, list): for seg in msg: yield MessageSegment(seg["type"], seg.get("data") or {}) - return - elif isinstance(msg, TextMessage): - yield MessageSegment("text", {"text": msg.dict()}) elif isinstance(msg, str): yield MessageSegment.text(msg) + + def _produce(self) -> dict: + data = {} + for segment in self: + if segment.type == "text": + data["msgtype"] = "text" + data.setdefault("text", {}) + data["text"]["content"] = data["text"].setdefault( + "content", "") + segment.data["content"] + elif segment.type == "markdown": + data["msgtype"] = "markdown" + data.setdefault("markdown", {}) + data["markdown"]["text"] = data["markdown"].setdefault( + "content", "") + segment.data["content"] + elif segment.type == "empty": + data["msgtype"] = "empty" + elif segment.type == "at" and "atMobiles" in segment.data: + data.setdefault("at", {}) + data["at"]["atMobiles"] = data["at"].setdefault( + "atMobiles", []) + segment.data["atMobiles"] + elif segment.data: + data.setdefault(segment.type, {}) + data[segment.type].update(segment.data) + return data diff --git a/nonebot/adapters/ding/model.py b/nonebot/adapters/ding/model.py deleted file mode 100644 index 49e4b0f5..00000000 --- a/nonebot/adapters/ding/model.py +++ /dev/null @@ -1,56 +0,0 @@ -from enum import Enum -from typing import List, Optional - -from pydantic import BaseModel - - -class Headers(BaseModel): - sign: str - token: str - # ms - timestamp: int - - -class TextMessage(BaseModel): - content: str - - -class AtUsersItem(BaseModel): - dingtalkId: str - staffId: Optional[str] - - -class ConversationType(str, Enum): - private = '1' - group = '2' - - -class MessageModel(BaseModel): - chatbotUserId: str = None - conversationId: str = None - conversationType: ConversationType = None - # ms - createAt: int = None - isAdmin: bool = None - msgId: str = None - msgtype: str = None - senderCorpId: str = None - senderId: str = None - senderNick: str = None - sessionWebhook: str = None - # ms - sessionWebhookExpiredTime: int = None - text: Optional[TextMessage] = None - - -class PrivateMessageModel(MessageModel): - chatbotCorpId: str = None - conversationType: ConversationType = ConversationType.private - senderStaffId: str = None - - -class GroupMessageModel(MessageModel): - atUsers: List[AtUsersItem] = None - conversationType: ConversationType = ConversationType.group - conversationTitle: str = None - isInAtList: bool = None