From cd12718dcb1b2053fc5275643f94b98c04870740 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Fri, 18 Jun 2021 01:23:13 +0800 Subject: [PATCH] :bug: update mirai adapter --- nonebot/adapters/_base.py | 3 +-- .../nonebot/adapters/cqhttp/message.py | 2 ++ .../nonebot/adapters/cqhttp/permission.py | 2 +- .../nonebot/adapters/ding/message.py | 26 +++++++++---------- .../nonebot/adapters/mirai/bot_ws.py | 25 +++++++++++++----- .../nonebot/adapters/mirai/event/__init__.py | 4 +-- .../nonebot/adapters/mirai/event/base.py | 10 +++---- .../nonebot/adapters/mirai/message.py | 25 +++++++++--------- .../nonebot/adapters/mirai/utils.py | 3 +-- 9 files changed, 56 insertions(+), 44 deletions(-) diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index e1986bbf..24916b90 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -234,9 +234,8 @@ class Bot(abc.ABC): T = TypeVar("T") -TMS = TypeVar("TMS") +TMS = TypeVar("TMS", covariant=True) TM = TypeVar("TM", bound="Message") -# TM = TypeVar("TM_co", bound="Message") @dataclass diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py index 1b0184b6..79b016bd 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/message.py @@ -16,6 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]): """ @classmethod + @overrides(BaseMessageSegment) def get_message_class(cls) -> Type["Message"]: return Message @@ -235,6 +236,7 @@ class Message(BaseMessage[MessageSegment]): """ @classmethod + @overrides(BaseMessage) def get_segment_class(cls) -> Type[MessageSegment]: return MessageSegment diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py index 5a9cfea8..1d3b3f36 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/permission.py @@ -64,7 +64,7 @@ GROUP = Permission(_group) - **说明**: 匹配任意群聊消息类型事件 """ GROUP_MEMBER = Permission(_group_member) -""" +r""" - **说明**: 匹配任意群员群聊消息类型事件 \:\:\:warning 警告 diff --git a/packages/nonebot-adapter-ding/nonebot/adapters/ding/message.py b/packages/nonebot-adapter-ding/nonebot/adapters/ding/message.py index 5da41b65..a9559435 100644 --- a/packages/nonebot-adapter-ding/nonebot/adapters/ding/message.py +++ b/packages/nonebot-adapter-ding/nonebot/adapters/ding/message.py @@ -1,35 +1,28 @@ from copy import copy -from typing import Any, Dict, Union, Mapping, Iterable +from typing import Any, Dict, Type, Union, Mapping, Iterable from nonebot.typing import overrides from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment -class MessageSegment(BaseMessageSegment): +class MessageSegment(BaseMessageSegment["Message"]): """ 钉钉 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。 """ + @classmethod @overrides(BaseMessageSegment) - def __init__(self, type_: str, data: Dict[str, Any]) -> None: - super().__init__(type=type_, data=data) + def get_message_class(cls) -> Type["Message"]: + return Message @overrides(BaseMessageSegment) - def __str__(self): + def __str__(self) -> str: if self.type == "text": return str(self.data["content"]) elif self.type == "markdown": return str(self.data["text"]) return "" - @overrides(BaseMessageSegment) - def __add__(self, other) -> "Message": - return Message(self) + other - - @overrides(BaseMessageSegment) - def __radd__(self, other) -> "Message": - return Message(other) + self - @overrides(BaseMessageSegment) def is_text(self) -> bool: return self.type == "text" @@ -143,7 +136,7 @@ class MessageSegment(BaseMessageSegment): def raw(data) -> "MessageSegment": return MessageSegment('raw', data) - def to_dict(self) -> dict: + def to_dict(self) -> Dict[str, Any]: # 让用户可以直接发送原始的消息格式 if self.type == "raw": return copy(self.data) @@ -160,6 +153,11 @@ class Message(BaseMessage[MessageSegment]): 钉钉 协议 Message 适配。 """ + @classmethod + @overrides(BaseMessage) + def get_segment_class(cls) -> Type[MessageSegment]: + return MessageSegment + @staticmethod @overrides(BaseMessage) def _construct( diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py index c7139772..29fc12bf 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/bot_ws.py @@ -30,12 +30,27 @@ class WebSocket(BaseWebSocket): params={'sessionKey': session_key}) websocket = await websockets.connect(uri=str(listen_address)) await (await websocket.ping()) - return cls(websocket) + return cls("1.1", + listen_address.scheme, + listen_address.path, + listen_address.query, + websocket=websocket) @overrides(BaseWebSocket) - def __init__(self, websocket: websockets.WebSocketClientProtocol): + def __init__(self, + http_version: str, + scheme: str, + path: str, + query_string: bytes = b"", + headers: Dict[str, str] = None, + websocket: websockets.WebSocketClientProtocol = None): self.event_handlers: Set[WebsocketHandlerFunction] = set() - super().__init__(websocket) + self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore + super(WebSocket, self).__init__(http_version=http_version, + scheme=scheme, + path=path, + query_string=query_string, + headers=headers or {}) @property @overrides(BaseWebSocket) @@ -146,9 +161,7 @@ class WebsocketBot(Bot): host=cls.mirai_config.host, # type: ignore port=cls.mirai_config.port, # type: ignore session_key=session.session_key) - bot = cls(connection_type='forward_ws', - self_id=str(qq), - websocket=websocket) + bot = cls(self_id=str(qq), request=websocket) websocket.handle(bot.handle_message) await websocket.accept() return bot diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/__init__.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/__init__.py index 91f4b127..78e5cba4 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/__init__.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/__init__.py @@ -1,5 +1,5 @@ -""" -\:\:\: warning +r""" +\:\:\: warning 事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名 部分字段可能与文档在符号上不一致 diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/base.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/base.py index 4a7b3809..e0b976bc 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/base.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/event/base.py @@ -14,12 +14,12 @@ from nonebot.typing import overrides class UserPermission(str, Enum): """ :说明: - - 用户权限枚举类 - * ``OWNER``: 群主 - * ``ADMINISTRATOR``: 群管理 - * ``MEMBER``: 普通群成员 + 用户权限枚举类 + + * ``OWNER``: 群主 + * ``ADMINISTRATOR``: 群管理 + * ``MEMBER``: 普通群成员 """ OWNER = 'OWNER' ADMINISTRATOR = 'ADMINISTRATOR' diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py index 6d061ccb..c7ad6841 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/message.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, List, Dict, Type, Iterable, Optional, Union from pydantic import validate_arguments @@ -25,7 +25,7 @@ class MessageType(str, Enum): POKE = 'Poke' -class MessageSegment(BaseMessageSegment): +class MessageSegment(BaseMessageSegment["MessageChain"]): """ Mirai-API-HTTP 协议 MessageSegment 适配。具体方法参考 `mirai-api-http 消息类型`_ @@ -36,9 +36,13 @@ class MessageSegment(BaseMessageSegment): type: MessageType data: Dict[str, Any] - @overrides(BaseMessageSegment) + @classmethod + def get_message_class(cls) -> Type["MessageChain"]: + return MessageChain + @validate_arguments - def __init__(self, type: MessageType, **data): + @overrides(BaseMessageSegment) + def __init__(self, type: MessageType, **data: Any): super().__init__(type=type, data={k: v for k, v in data.items() if v is not None}) @@ -55,14 +59,6 @@ class MessageSegment(BaseMessageSegment): ), ]) - @overrides(BaseMessageSegment) - def __add__(self, other) -> "MessageChain": - return MessageChain(self) + other - - @overrides(BaseMessageSegment) - def __radd__(self, other) -> "MessageChain": - return MessageChain(other) + self - @overrides(BaseMessageSegment) def is_text(self) -> bool: return self.type == MessageType.PLAIN @@ -273,6 +269,11 @@ class MessageChain(BaseMessage[MessageSegment]): 由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名 """ + @classmethod + @overrides(BaseMessage) + def get_segment_class(cls) -> Type[MessageSegment]: + return MessageSegment + @overrides(BaseMessage) def __init__(self, message: Union[List[Dict[str, Any]], Iterable[MessageSegment], diff --git a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/utils.py b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/utils.py index 14879170..200f0197 100644 --- a/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/utils.py +++ b/packages/nonebot-adapter-mirai/nonebot/adapters/mirai/utils.py @@ -73,7 +73,7 @@ class InvalidArgument(exception.AdapterException): def catch_network_error(function: _AsyncCallable) -> _AsyncCallable: - """ + r""" :说明: 捕捉函数抛出的httpx网络异常并释放 ``NetworkError`` 异常 @@ -170,7 +170,6 @@ def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage: async def process_event(bot: "Bot", event: Event) -> None: if isinstance(event, MessageEvent): - event.message_chain.reduce() Log.debug(event.message_chain) event = process_source(bot, event) if isinstance(event, GroupMessage):