🐛 update mirai adapter

This commit is contained in:
yanyongyu 2021-06-18 01:23:13 +08:00
parent a4c6d834ff
commit cd12718dcb
9 changed files with 56 additions and 44 deletions

View File

@ -234,9 +234,8 @@ class Bot(abc.ABC):
T = TypeVar("T") T = TypeVar("T")
TMS = TypeVar("TMS") TMS = TypeVar("TMS", covariant=True)
TM = TypeVar("TM", bound="Message") TM = TypeVar("TM", bound="Message")
# TM = TypeVar("TM_co", bound="Message")
@dataclass @dataclass

View File

@ -16,6 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
""" """
@classmethod @classmethod
@overrides(BaseMessageSegment)
def get_message_class(cls) -> Type["Message"]: def get_message_class(cls) -> Type["Message"]:
return Message return Message
@ -235,6 +236,7 @@ class Message(BaseMessage[MessageSegment]):
""" """
@classmethod @classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]: def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment return MessageSegment

View File

@ -64,7 +64,7 @@ GROUP = Permission(_group)
- **说明**: 匹配任意群聊消息类型事件 - **说明**: 匹配任意群聊消息类型事件
""" """
GROUP_MEMBER = Permission(_group_member) GROUP_MEMBER = Permission(_group_member)
""" r"""
- **说明**: 匹配任意群员群聊消息类型事件 - **说明**: 匹配任意群员群聊消息类型事件
\:\:\:warning 警告 \:\:\:warning 警告

View File

@ -1,35 +1,28 @@
from copy import copy from copy import copy
from typing import Any, Dict, Union, Mapping, Iterable from typing import Any, Dict, Type, Union, Mapping, Iterable
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment["Message"]):
""" """
钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码 钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
""" """
@classmethod
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __init__(self, type_: str, data: Dict[str, Any]) -> None: def get_message_class(cls) -> Type["Message"]:
super().__init__(type=type_, data=data) return Message
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def __str__(self): def __str__(self) -> str:
if self.type == "text": if self.type == "text":
return str(self.data["content"]) return str(self.data["content"])
elif self.type == "markdown": elif self.type == "markdown":
return str(self.data["text"]) return str(self.data["text"])
return "" return ""
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return Message(other) + self
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"
@ -143,7 +136,7 @@ class MessageSegment(BaseMessageSegment):
def raw(data) -> "MessageSegment": def raw(data) -> "MessageSegment":
return MessageSegment('raw', data) return MessageSegment('raw', data)
def to_dict(self) -> dict: def to_dict(self) -> Dict[str, Any]:
# 让用户可以直接发送原始的消息格式 # 让用户可以直接发送原始的消息格式
if self.type == "raw": if self.type == "raw":
return copy(self.data) return copy(self.data)
@ -160,6 +153,11 @@ class Message(BaseMessage[MessageSegment]):
钉钉 协议 Message 适配 钉钉 协议 Message 适配
""" """
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@staticmethod @staticmethod
@overrides(BaseMessage) @overrides(BaseMessage)
def _construct( def _construct(

View File

@ -30,12 +30,27 @@ class WebSocket(BaseWebSocket):
params={'sessionKey': session_key}) params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address)) websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping()) await (await websocket.ping())
return cls(websocket) return cls("1.1",
listen_address.scheme,
listen_address.path,
listen_address.query,
websocket=websocket)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def __init__(self, websocket: websockets.WebSocketClientProtocol): def __init__(self,
http_version: str,
scheme: str,
path: str,
query_string: bytes = b"",
headers: Dict[str, str] = None,
websocket: websockets.WebSocketClientProtocol = None):
self.event_handlers: Set[WebsocketHandlerFunction] = set() self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket) self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
super(WebSocket, self).__init__(http_version=http_version,
scheme=scheme,
path=path,
query_string=query_string,
headers=headers or {})
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@ -146,9 +161,7 @@ class WebsocketBot(Bot):
host=cls.mirai_config.host, # type: ignore host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore port=cls.mirai_config.port, # type: ignore
session_key=session.session_key) session_key=session.session_key)
bot = cls(connection_type='forward_ws', bot = cls(self_id=str(qq), request=websocket)
self_id=str(qq),
websocket=websocket)
websocket.handle(bot.handle_message) websocket.handle(bot.handle_message)
await websocket.accept() await websocket.accept()
return bot return bot

View File

@ -1,4 +1,4 @@
""" r"""
\:\:\: warning \:\:\: warning
事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名 事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union from typing import Any, List, Dict, Type, Iterable, Optional, Union
from pydantic import validate_arguments from pydantic import validate_arguments
@ -25,7 +25,7 @@ class MessageType(str, Enum):
POKE = 'Poke' POKE = 'Poke'
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment["MessageChain"]):
""" """
Mirai-API-HTTP 协议 MessageSegment 适配具体方法参考 `mirai-api-http 消息类型`_ Mirai-API-HTTP 协议 MessageSegment 适配具体方法参考 `mirai-api-http 消息类型`_
@ -36,9 +36,13 @@ class MessageSegment(BaseMessageSegment):
type: MessageType type: MessageType
data: Dict[str, Any] data: Dict[str, Any]
@overrides(BaseMessageSegment) @classmethod
def get_message_class(cls) -> Type["MessageChain"]:
return MessageChain
@validate_arguments @validate_arguments
def __init__(self, type: MessageType, **data): @overrides(BaseMessageSegment)
def __init__(self, type: MessageType, **data: Any):
super().__init__(type=type, super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None}) data={k: v for k, v in data.items() if v is not None})
@ -55,14 +59,6 @@ class MessageSegment(BaseMessageSegment):
), ),
]) ])
@overrides(BaseMessageSegment)
def __add__(self, other) -> "MessageChain":
return MessageChain(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "MessageChain":
return MessageChain(other) + self
@overrides(BaseMessageSegment) @overrides(BaseMessageSegment)
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == MessageType.PLAIN return self.type == MessageType.PLAIN
@ -273,6 +269,11 @@ class MessageChain(BaseMessage[MessageSegment]):
由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名 由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名
""" """
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@overrides(BaseMessage) @overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str, def __init__(self, message: Union[List[Dict[str,
Any]], Iterable[MessageSegment], Any]], Iterable[MessageSegment],

View File

@ -73,7 +73,7 @@ class InvalidArgument(exception.AdapterException):
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable: def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
""" r"""
:说明: :说明:
捕捉函数抛出的httpx网络异常并释放 ``NetworkError`` 异常 捕捉函数抛出的httpx网络异常并释放 ``NetworkError`` 异常
@ -170,7 +170,6 @@ def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
async def process_event(bot: "Bot", event: Event) -> None: async def process_event(bot: "Bot", event: Event) -> None:
if isinstance(event, MessageEvent): if isinstance(event, MessageEvent):
event.message_chain.reduce()
Log.debug(event.message_chain) Log.debug(event.message_chain)
event = process_source(bot, event) event = process_source(bot, event)
if isinstance(event, GroupMessage): if isinstance(event, GroupMessage):