🐛 update mirai adapter

This commit is contained in:
yanyongyu 2021-06-18 01:23:13 +08:00
parent a4c6d834ff
commit cd12718dcb
9 changed files with 56 additions and 44 deletions

View File

@ -234,9 +234,8 @@ class Bot(abc.ABC):
T = TypeVar("T")
TMS = TypeVar("TMS")
TMS = TypeVar("TMS", covariant=True)
TM = TypeVar("TM", bound="Message")
# TM = TypeVar("TM_co", bound="Message")
@dataclass

View File

@ -16,6 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
"""
@classmethod
@overrides(BaseMessageSegment)
def get_message_class(cls) -> Type["Message"]:
return Message
@ -235,6 +236,7 @@ class Message(BaseMessage[MessageSegment]):
"""
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment

View File

@ -64,7 +64,7 @@ GROUP = Permission(_group)
- **说明**: 匹配任意群聊消息类型事件
"""
GROUP_MEMBER = Permission(_group_member)
"""
r"""
- **说明**: 匹配任意群员群聊消息类型事件
\:\:\:warning 警告

View File

@ -1,35 +1,28 @@
from copy import copy
from typing import Any, Dict, Union, Mapping, Iterable
from typing import Any, Dict, Type, Union, Mapping, Iterable
from nonebot.typing import overrides
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
class MessageSegment(BaseMessageSegment):
class MessageSegment(BaseMessageSegment["Message"]):
"""
钉钉 协议 MessageSegment 适配具体方法参考协议消息段类型或源码
"""
@classmethod
@overrides(BaseMessageSegment)
def __init__(self, type_: str, data: Dict[str, Any]) -> None:
super().__init__(type=type_, data=data)
def get_message_class(cls) -> Type["Message"]:
return Message
@overrides(BaseMessageSegment)
def __str__(self):
def __str__(self) -> str:
if self.type == "text":
return str(self.data["content"])
elif self.type == "markdown":
return str(self.data["text"])
return ""
@overrides(BaseMessageSegment)
def __add__(self, other) -> "Message":
return Message(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "Message":
return Message(other) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == "text"
@ -143,7 +136,7 @@ class MessageSegment(BaseMessageSegment):
def raw(data) -> "MessageSegment":
return MessageSegment('raw', data)
def to_dict(self) -> dict:
def to_dict(self) -> Dict[str, Any]:
# 让用户可以直接发送原始的消息格式
if self.type == "raw":
return copy(self.data)
@ -160,6 +153,11 @@ class Message(BaseMessage[MessageSegment]):
钉钉 协议 Message 适配
"""
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@staticmethod
@overrides(BaseMessage)
def _construct(

View File

@ -30,12 +30,27 @@ class WebSocket(BaseWebSocket):
params={'sessionKey': session_key})
websocket = await websockets.connect(uri=str(listen_address))
await (await websocket.ping())
return cls(websocket)
return cls("1.1",
listen_address.scheme,
listen_address.path,
listen_address.query,
websocket=websocket)
@overrides(BaseWebSocket)
def __init__(self, websocket: websockets.WebSocketClientProtocol):
def __init__(self,
http_version: str,
scheme: str,
path: str,
query_string: bytes = b"",
headers: Dict[str, str] = None,
websocket: websockets.WebSocketClientProtocol = None):
self.event_handlers: Set[WebsocketHandlerFunction] = set()
super().__init__(websocket)
self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
super(WebSocket, self).__init__(http_version=http_version,
scheme=scheme,
path=path,
query_string=query_string,
headers=headers or {})
@property
@overrides(BaseWebSocket)
@ -146,9 +161,7 @@ class WebsocketBot(Bot):
host=cls.mirai_config.host, # type: ignore
port=cls.mirai_config.port, # type: ignore
session_key=session.session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
websocket=websocket)
bot = cls(self_id=str(qq), request=websocket)
websocket.handle(bot.handle_message)
await websocket.accept()
return bot

View File

@ -1,5 +1,5 @@
"""
\:\:\: warning
r"""
\:\:\: warning
事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
部分字段可能与文档在符号上不一致

View File

@ -14,12 +14,12 @@ from nonebot.typing import overrides
class UserPermission(str, Enum):
"""
:说明:
用户权限枚举类
* ``OWNER``: 群主
* ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员
用户权限枚举类
* ``OWNER``: 群主
* ``ADMINISTRATOR``: 群管理
* ``MEMBER``: 普通群成员
"""
OWNER = 'OWNER'
ADMINISTRATOR = 'ADMINISTRATOR'

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, List, Dict, Type, Iterable, Optional, Union
from pydantic import validate_arguments
@ -25,7 +25,7 @@ class MessageType(str, Enum):
POKE = 'Poke'
class MessageSegment(BaseMessageSegment):
class MessageSegment(BaseMessageSegment["MessageChain"]):
"""
Mirai-API-HTTP 协议 MessageSegment 适配具体方法参考 `mirai-api-http 消息类型`_
@ -36,9 +36,13 @@ class MessageSegment(BaseMessageSegment):
type: MessageType
data: Dict[str, Any]
@overrides(BaseMessageSegment)
@classmethod
def get_message_class(cls) -> Type["MessageChain"]:
return MessageChain
@validate_arguments
def __init__(self, type: MessageType, **data):
@overrides(BaseMessageSegment)
def __init__(self, type: MessageType, **data: Any):
super().__init__(type=type,
data={k: v for k, v in data.items() if v is not None})
@ -55,14 +59,6 @@ class MessageSegment(BaseMessageSegment):
),
])
@overrides(BaseMessageSegment)
def __add__(self, other) -> "MessageChain":
return MessageChain(self) + other
@overrides(BaseMessageSegment)
def __radd__(self, other) -> "MessageChain":
return MessageChain(other) + self
@overrides(BaseMessageSegment)
def is_text(self) -> bool:
return self.type == MessageType.PLAIN
@ -273,6 +269,11 @@ class MessageChain(BaseMessage[MessageSegment]):
由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名
"""
@classmethod
@overrides(BaseMessage)
def get_segment_class(cls) -> Type[MessageSegment]:
return MessageSegment
@overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str,
Any]], Iterable[MessageSegment],

View File

@ -73,7 +73,7 @@ class InvalidArgument(exception.AdapterException):
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
"""
r"""
:说明:
捕捉函数抛出的httpx网络异常并释放 ``NetworkError`` 异常
@ -170,7 +170,6 @@ def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
async def process_event(bot: "Bot", event: Event) -> None:
if isinstance(event, MessageEvent):
event.message_chain.reduce()
Log.debug(event.message_chain)
event = process_source(bot, event)
if isinstance(event, GroupMessage):