mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
🐛 update mirai adapter
This commit is contained in:
parent
a4c6d834ff
commit
cd12718dcb
@ -234,9 +234,8 @@ class Bot(abc.ABC):
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
TMS = TypeVar("TMS")
|
||||
TMS = TypeVar("TMS", covariant=True)
|
||||
TM = TypeVar("TM", bound="Message")
|
||||
# TM = TypeVar("TM_co", bound="Message")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -16,6 +16,7 @@ class MessageSegment(BaseMessageSegment["Message"]):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseMessageSegment)
|
||||
def get_message_class(cls) -> Type["Message"]:
|
||||
return Message
|
||||
|
||||
@ -235,6 +236,7 @@ class Message(BaseMessage[MessageSegment]):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseMessage)
|
||||
def get_segment_class(cls) -> Type[MessageSegment]:
|
||||
return MessageSegment
|
||||
|
||||
|
@ -64,7 +64,7 @@ GROUP = Permission(_group)
|
||||
- **说明**: 匹配任意群聊消息类型事件
|
||||
"""
|
||||
GROUP_MEMBER = Permission(_group_member)
|
||||
"""
|
||||
r"""
|
||||
- **说明**: 匹配任意群员群聊消息类型事件
|
||||
|
||||
\:\:\:warning 警告
|
||||
|
@ -1,35 +1,28 @@
|
||||
from copy import copy
|
||||
from typing import Any, Dict, Union, Mapping, Iterable
|
||||
from typing import Any, Dict, Type, Union, Mapping, Iterable
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.adapters import Message as BaseMessage, MessageSegment as BaseMessageSegment
|
||||
|
||||
|
||||
class MessageSegment(BaseMessageSegment):
|
||||
class MessageSegment(BaseMessageSegment["Message"]):
|
||||
"""
|
||||
钉钉 协议 MessageSegment 适配。具体方法参考协议消息段类型或源码。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseMessageSegment)
|
||||
def __init__(self, type_: str, data: Dict[str, Any]) -> None:
|
||||
super().__init__(type=type_, data=data)
|
||||
def get_message_class(cls) -> Type["Message"]:
|
||||
return Message
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.type == "text":
|
||||
return str(self.data["content"])
|
||||
elif self.type == "markdown":
|
||||
return str(self.data["text"])
|
||||
return ""
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __add__(self, other) -> "Message":
|
||||
return Message(self) + other
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __radd__(self, other) -> "Message":
|
||||
return Message(other) + self
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def is_text(self) -> bool:
|
||||
return self.type == "text"
|
||||
@ -143,7 +136,7 @@ class MessageSegment(BaseMessageSegment):
|
||||
def raw(data) -> "MessageSegment":
|
||||
return MessageSegment('raw', data)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# 让用户可以直接发送原始的消息格式
|
||||
if self.type == "raw":
|
||||
return copy(self.data)
|
||||
@ -160,6 +153,11 @@ class Message(BaseMessage[MessageSegment]):
|
||||
钉钉 协议 Message 适配。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseMessage)
|
||||
def get_segment_class(cls) -> Type[MessageSegment]:
|
||||
return MessageSegment
|
||||
|
||||
@staticmethod
|
||||
@overrides(BaseMessage)
|
||||
def _construct(
|
||||
|
@ -30,12 +30,27 @@ class WebSocket(BaseWebSocket):
|
||||
params={'sessionKey': session_key})
|
||||
websocket = await websockets.connect(uri=str(listen_address))
|
||||
await (await websocket.ping())
|
||||
return cls(websocket)
|
||||
return cls("1.1",
|
||||
listen_address.scheme,
|
||||
listen_address.path,
|
||||
listen_address.query,
|
||||
websocket=websocket)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
def __init__(self, websocket: websockets.WebSocketClientProtocol):
|
||||
def __init__(self,
|
||||
http_version: str,
|
||||
scheme: str,
|
||||
path: str,
|
||||
query_string: bytes = b"",
|
||||
headers: Dict[str, str] = None,
|
||||
websocket: websockets.WebSocketClientProtocol = None):
|
||||
self.event_handlers: Set[WebsocketHandlerFunction] = set()
|
||||
super().__init__(websocket)
|
||||
self.websocket: websockets.WebSocketClientProtocol = websocket # type: ignore
|
||||
super(WebSocket, self).__init__(http_version=http_version,
|
||||
scheme=scheme,
|
||||
path=path,
|
||||
query_string=query_string,
|
||||
headers=headers or {})
|
||||
|
||||
@property
|
||||
@overrides(BaseWebSocket)
|
||||
@ -146,9 +161,7 @@ class WebsocketBot(Bot):
|
||||
host=cls.mirai_config.host, # type: ignore
|
||||
port=cls.mirai_config.port, # type: ignore
|
||||
session_key=session.session_key)
|
||||
bot = cls(connection_type='forward_ws',
|
||||
self_id=str(qq),
|
||||
websocket=websocket)
|
||||
bot = cls(self_id=str(qq), request=websocket)
|
||||
websocket.handle(bot.handle_message)
|
||||
await websocket.accept()
|
||||
return bot
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""
|
||||
\:\:\: warning
|
||||
r"""
|
||||
\:\:\: warning
|
||||
事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
|
||||
|
||||
部分字段可能与文档在符号上不一致
|
||||
|
@ -14,12 +14,12 @@ from nonebot.typing import overrides
|
||||
class UserPermission(str, Enum):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
用户权限枚举类
|
||||
|
||||
* ``OWNER``: 群主
|
||||
* ``ADMINISTRATOR``: 群管理
|
||||
* ``MEMBER``: 普通群成员
|
||||
用户权限枚举类
|
||||
|
||||
* ``OWNER``: 群主
|
||||
* ``ADMINISTRATOR``: 群管理
|
||||
* ``MEMBER``: 普通群成员
|
||||
"""
|
||||
OWNER = 'OWNER'
|
||||
ADMINISTRATOR = 'ADMINISTRATOR'
|
||||
|
@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, List, Dict, Type, Iterable, Optional, Union
|
||||
|
||||
from pydantic import validate_arguments
|
||||
|
||||
@ -25,7 +25,7 @@ class MessageType(str, Enum):
|
||||
POKE = 'Poke'
|
||||
|
||||
|
||||
class MessageSegment(BaseMessageSegment):
|
||||
class MessageSegment(BaseMessageSegment["MessageChain"]):
|
||||
"""
|
||||
Mirai-API-HTTP 协议 MessageSegment 适配。具体方法参考 `mirai-api-http 消息类型`_
|
||||
|
||||
@ -36,9 +36,13 @@ class MessageSegment(BaseMessageSegment):
|
||||
type: MessageType
|
||||
data: Dict[str, Any]
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
@classmethod
|
||||
def get_message_class(cls) -> Type["MessageChain"]:
|
||||
return MessageChain
|
||||
|
||||
@validate_arguments
|
||||
def __init__(self, type: MessageType, **data):
|
||||
@overrides(BaseMessageSegment)
|
||||
def __init__(self, type: MessageType, **data: Any):
|
||||
super().__init__(type=type,
|
||||
data={k: v for k, v in data.items() if v is not None})
|
||||
|
||||
@ -55,14 +59,6 @@ class MessageSegment(BaseMessageSegment):
|
||||
),
|
||||
])
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __add__(self, other) -> "MessageChain":
|
||||
return MessageChain(self) + other
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def __radd__(self, other) -> "MessageChain":
|
||||
return MessageChain(other) + self
|
||||
|
||||
@overrides(BaseMessageSegment)
|
||||
def is_text(self) -> bool:
|
||||
return self.type == MessageType.PLAIN
|
||||
@ -273,6 +269,11 @@ class MessageChain(BaseMessage[MessageSegment]):
|
||||
由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@overrides(BaseMessage)
|
||||
def get_segment_class(cls) -> Type[MessageSegment]:
|
||||
return MessageSegment
|
||||
|
||||
@overrides(BaseMessage)
|
||||
def __init__(self, message: Union[List[Dict[str,
|
||||
Any]], Iterable[MessageSegment],
|
||||
|
@ -73,7 +73,7 @@ class InvalidArgument(exception.AdapterException):
|
||||
|
||||
|
||||
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
|
||||
"""
|
||||
r"""
|
||||
:说明:
|
||||
|
||||
捕捉函数抛出的httpx网络异常并释放 ``NetworkError`` 异常
|
||||
@ -170,7 +170,6 @@ def process_reply(bot: "Bot", event: GroupMessage) -> GroupMessage:
|
||||
|
||||
async def process_event(bot: "Bot", event: Event) -> None:
|
||||
if isinstance(event, MessageEvent):
|
||||
event.message_chain.reduce()
|
||||
Log.debug(event.message_chain)
|
||||
event = process_source(bot, event)
|
||||
if isinstance(event, GroupMessage):
|
||||
|
Loading…
Reference in New Issue
Block a user