🚧 💡 add comments for message etc. in mirai adapter

This commit is contained in:
Mix 2021-02-01 01:04:30 +08:00
parent 7fdfd89525
commit 56592fc413
8 changed files with 250 additions and 28 deletions

View File

@ -1,3 +1,13 @@
"""
Mirai-API-HTTP 协议适配
============================
协议详情请看: `mirai-api-http 文档`_
.. mirai-api-http 文档:
https://github.com/project-mirai/mirai-api-http/tree/master/docs
"""
from .bot import MiraiBot from .bot import MiraiBot
from .bot_ws import MiraiWebsocketBot from .bot_ws import MiraiWebsocketBot
from .event import * from .event import *

View File

@ -32,6 +32,7 @@ class ActionFailed(BaseActionFailed):
class SessionManager: class SessionManager:
"""Bot会话管理器, 提供API主动调用接口"""
sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {} sessions: Dict[int, Tuple[str, datetime, httpx.AsyncClient]] = {}
session_expiry: timedelta = timedelta(minutes=15) session_expiry: timedelta = timedelta(minutes=15)
@ -40,10 +41,10 @@ class SessionManager:
@staticmethod @staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]: def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
code = data.get('code', 0) logger.opt(colors=True).debug(
logger.opt(colors=True).debug('Mirai API returned data: ' f'Mirai API returned data: <y>{escape_tag(str(data))}</y>')
f'<y>{escape_tag(str(data))}</y>') if isinstance(data, dict) and ('code' in data):
if code != 0: if data['code'] != 0:
raise ActionFailed(**data) raise ActionFailed(**data)
return data return data
@ -51,8 +52,28 @@ class SessionManager:
path: str, path: str,
*, *,
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
params = {**(params or {}), 'sessionKey': self.session_key} """
response = await self.client.post(path, json=params, timeout=3) :说明:
以POST方式主动提交API请求
:参数:
* ``path: str``: 对应API路径
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
:返回:
- ``Dict[str, Any]``: API 返回值
"""
response = await self.client.post(
path,
json={
**(params or {}),
'sessionKey': self.session_key,
},
timeout=3,
)
response.raise_for_status() response.raise_for_status()
return self._raise_code(response.json()) return self._raise_code(response.json())
@ -61,12 +82,28 @@ class SessionManager:
*, *,
params: Optional[Dict[str, params: Optional[Dict[str,
Any]] = None) -> Dict[str, Any]: Any]] = None) -> Dict[str, Any]:
response = await self.client.get(path, """
:说明:
以GET方式主动提交API请求
:参数:
* ``path: str``: 对应API路径
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
:返回:
- ``Dict[str, Any]``: API 返回值
"""
response = await self.client.get(
path,
params={ params={
**(params or {}), 'sessionKey': **(params or {}),
self.session_key 'sessionKey': self.session_key,
}, },
timeout=3) timeout=3,
)
response.raise_for_status() response.raise_for_status()
return self._raise_code(response.json()) return self._raise_code(response.json())
@ -108,11 +145,11 @@ class SessionManager:
return cls(session_key, client) return cls(session_key, client)
@classmethod @classmethod
def get(cls, self_id: int): def get(cls, self_id: int, check_expire: bool = True):
if self_id not in cls.sessions: if self_id not in cls.sessions:
return None return None
key, time, client = cls.sessions[self_id] key, time, client = cls.sessions[self_id]
if datetime.now() - time > cls.session_expiry: if check_expire and (datetime.now() - time > cls.session_expiry):
return None return None
return cls(key, client) return cls(key, client)
@ -129,7 +166,6 @@ class MiraiBot(BaseBot):
*, *,
websocket: Optional[WebSocket] = None): websocket: Optional[WebSocket] = None):
super().__init__(connection_type, self_id, websocket=websocket) super().__init__(connection_type, self_id, websocket=websocket)
self.api = SessionManager.get(int(self_id))
@property @property
@overrides(BaseBot) @overrides(BaseBot)
@ -140,6 +176,13 @@ class MiraiBot(BaseBot):
def alive(self) -> bool: def alive(self) -> bool:
return not self.websocket.closed return not self.websocket.closed
@property
def api(self) -> SessionManager:
"""返回该Bot对象的会话管理实例以提供API主动调用"""
api = SessionManager.get(self_id=int(self.self_id))
assert api is not None, 'SessionManager has not been initialized'
return api
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(cls, driver: "Driver", connection_type: str,
@ -179,10 +222,12 @@ class MiraiBot(BaseBot):
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn: async def call_api(self, api: str, **data) -> NoReturn:
"""由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
raise NotImplementedError raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)
def __getattr__(self, key: str) -> NoReturn: def __getattr__(self, key: str) -> NoReturn:
"""由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
raise NotImplementedError raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)

View File

@ -107,6 +107,12 @@ class MiraiWebsocketBot(MiraiBot):
def alive(self) -> bool: def alive(self) -> bool:
return not self.websocket.closed return not self.websocket.closed
@property
def api(self) -> SessionManager:
api = SessionManager.get(self_id=int(self.self_id), check_expire=False)
assert api is not None, 'SessionManager has not been initialized'
return api
@classmethod @classmethod
@overrides(MiraiBot) @overrides(MiraiBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(cls, driver: "Driver", connection_type: str,
@ -118,12 +124,23 @@ class MiraiWebsocketBot(MiraiBot):
@classmethod @classmethod
@overrides(MiraiBot) @overrides(MiraiBot)
def register(cls, driver: "Driver", config: "Config", qq: int): def register(cls, driver: "Driver", config: "Config", qq: int):
cls.mirai_config = MiraiConfig(**config.dict()) """
cls.active = True :说明:
assert cls.mirai_config.auth_key is not None
assert cls.mirai_config.host is not None 注册该Adapter
assert cls.mirai_config.port is not None
:参数:
* ``driver: Driver``: 程序所使用的``Driver``
* ``config: Config``: 程序配置对象
* ``qq: int``: 要使用的Bot的QQ号 **注意: 在使用正向Websocket时必须指定该值!**
:返回:
- ``[type]``: [description]
"""
super().register(driver, config) super().register(driver, config)
cls.active = True
async def _bot_connection(): async def _bot_connection():
session: SessionManager = await SessionManager.new( session: SessionManager = await SessionManager.new(

View File

@ -5,6 +5,15 @@ from pydantic import BaseModel, Extra, Field
class Config(BaseModel): class Config(BaseModel):
"""
Mirai 配置类
:必填:
- ``mirai_auth_key``: mirai-api-http的auth_key
- ``mirai_host``: mirai-api-http的地址
- ``mirai_port``: mirai-api-http的端口
"""
auth_key: Optional[str] = Field(None, alias='mirai_auth_key') auth_key: Optional[str] = Field(None, alias='mirai_auth_key')
host: Optional[IPv4Address] = Field(None, alias='mirai_host') host: Optional[IPv4Address] = Field(None, alias='mirai_host')
port: Optional[int] = Field(None, alias='mirai_port') port: Optional[int] = Field(None, alias='mirai_port')

View File

@ -1,3 +1,10 @@
"""
\:\:\:warning 警告
事件中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
部分字段可能与文档在符号上不一致
\:\:\:
"""
from .base import Event, GroupChatInfo, GroupInfo, UserPermission, PrivateChatInfo from .base import Event, GroupChatInfo, GroupInfo, UserPermission, PrivateChatInfo
from .message import * from .message import *
from .notice import * from .notice import *

View File

@ -45,9 +45,9 @@ class PrivateChatInfo(BaseModel):
class Event(BaseEvent): class Event(BaseEvent):
""" """
mirai-api-http 协议事件字段与 mirai-api-http 一致各事件字段参考 `mirai-api-http 文档`_ mirai-api-http 协议事件字段与 mirai-api-http 一致各事件字段参考 `mirai-api-http 事件类型`_
.. _mirai-api-http 文档: .. _mirai-api-http 事件类型:
https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md https://github.com/project-mirai/mirai-api-http/blob/master/docs/EventType.md
""" """
self_id: int self_id: int

View File

@ -55,10 +55,6 @@ class NewFriendRequestEvent(RequestEvent):
- ``1``: 拒绝添加好友 - ``1``: 拒绝添加好友
- ``2``: 拒绝添加好友并添加黑名单不再接收该用户的好友申请 - ``2``: 拒绝添加好友并添加黑名单不再接收该用户的好友申请
* ``message: str``: 回复的信息 * ``message: str``: 回复的信息
:返回:
- ``[type]``: [description]
""" """
assert operate > 0 assert operate > 0
return await bot.api.post('/resp/newFriendRequestEvent', return await bot.api.post('/resp/newFriendRequestEvent',

View File

@ -9,6 +9,7 @@ from nonebot.typing import overrides
class MessageType(str, Enum): class MessageType(str, Enum):
"""消息类型枚举类"""
SOURCE = 'Source' SOURCE = 'Source'
QUOTE = 'Quote' QUOTE = 'Quote'
AT = 'At' AT = 'At'
@ -25,6 +26,13 @@ class MessageType(str, Enum):
class MessageSegment(BaseMessageSegment): class MessageSegment(BaseMessageSegment):
"""
CQHTTP 协议 MessageSegment 适配具体方法参考 `mirai-api-http 消息类型`_
.. _mirai-api-http 消息类型:
https://github.com/project-mirai/mirai-api-http/blob/master/docs/MessageType.md
"""
type: MessageType type: MessageType
data: Dict[str, Any] data: Dict[str, Any]
@ -59,6 +67,7 @@ class MessageSegment(BaseMessageSegment):
return self.type == MessageType.PLAIN return self.type == MessageType.PLAIN
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
"""导出可以被正常json序列化的结构体"""
return {'type': self.type.value, **self.data} return {'type': self.type.value, **self.data}
@classmethod @classmethod
@ -68,6 +77,19 @@ class MessageSegment(BaseMessageSegment):
@classmethod @classmethod
def quote(cls, id: int, group_id: int, sender_id: int, target_id: int, def quote(cls, id: int, group_id: int, sender_id: int, target_id: int,
origin: "MessageChain"): origin: "MessageChain"):
"""
:说明:
生成回复引用消息段
:参数:
* ``id: int``: 被引用回复的原消息的message_id
* ``group_id: int``: 被引用回复的原消息所接收的群号当为好友消息时为0
* ``sender_id: int``: 被引用回复的原消息的发送者的QQ号
* ``target_id: int``: 被引用回复的原消息的接收者者的QQ号或群号
* ``origin: MessageChain``: 被引用回复的原消息的消息链对象
"""
return cls(type=MessageType.QUOTE, return cls(type=MessageType.QUOTE,
id=id, id=id,
groupId=group_id, groupId=group_id,
@ -77,18 +99,51 @@ class MessageSegment(BaseMessageSegment):
@classmethod @classmethod
def at(cls, target: int): def at(cls, target: int):
"""
:说明:
@某个人
:参数:
* ``target: int``: 群员QQ号
"""
return cls(type=MessageType.AT, target=target) return cls(type=MessageType.AT, target=target)
@classmethod @classmethod
def at_all(cls): def at_all(cls):
"""
:说明:
@全体成员
"""
return cls(type=MessageType.AT_ALL) return cls(type=MessageType.AT_ALL)
@classmethod @classmethod
def face(cls, face_id: Optional[int] = None, name: Optional[str] = None): def face(cls, face_id: Optional[int] = None, name: Optional[str] = None):
"""
:说明:
发送QQ表情
:参数:
* ``face_id: Optional[int]``: QQ表情编号可选优先高于name
* ``name: Optional[str]``: QQ表情拼音可选
"""
return cls(type=MessageType.FACE, faceId=face_id, name=name) return cls(type=MessageType.FACE, faceId=face_id, name=name)
@classmethod @classmethod
def plain(cls, text: str): def plain(cls, text: str):
"""
:说明:
纯文本消息
:参数:
* ``text: str``: 文字消息
"""
return cls(type=MessageType.PLAIN, text=text) return cls(type=MessageType.PLAIN, text=text)
@classmethod @classmethod
@ -96,6 +151,21 @@ class MessageSegment(BaseMessageSegment):
image_id: Optional[str] = None, image_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
path: Optional[str] = None): path: Optional[str] = None):
"""
:说明:
图片消息
:参数:
* ``image_id: Optional[str]``: 图片的image_id群图片与好友图片格式不同不为空时将忽略url属性
* ``url: Optional[str]``: 图片的URL发送时可作网络图片的链接
* ``path: Optional[str]``: 图片的路径发送本地图片
:返回:
- ``[type]``: [description]
"""
return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path) return cls(type=MessageType.IMAGE, imageId=image_id, url=url, path=path)
@classmethod @classmethod
@ -103,6 +173,15 @@ class MessageSegment(BaseMessageSegment):
image_id: Optional[str] = None, image_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
path: Optional[str] = None): path: Optional[str] = None):
"""
:说明:
闪照消息
:参数:
``image``
"""
return cls(type=MessageType.FLASH_IMAGE, return cls(type=MessageType.FLASH_IMAGE,
imageId=image_id, imageId=image_id,
url=url, url=url,
@ -113,6 +192,17 @@ class MessageSegment(BaseMessageSegment):
voice_id: Optional[str] = None, voice_id: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
path: Optional[str] = None): path: Optional[str] = None):
"""
:说明:
语音消息
:参数:
* ``voice_id: Optional[str]``: 语音的voice_id不为空时将忽略url属性
* ``url: Optional[str]``: 语音的URL发送时可作网络语音的链接
* ``path: Optional[str]``: 语音的路径发送本地语音
"""
return cls(type=MessageType.FLASH_IMAGE, return cls(type=MessageType.FLASH_IMAGE,
imageId=voice_id, imageId=voice_id,
url=url, url=url,
@ -120,22 +210,69 @@ class MessageSegment(BaseMessageSegment):
@classmethod @classmethod
def xml(cls, xml: str): def xml(cls, xml: str):
"""
:说明:
XML消息
:参数:
* ``xml: str``: XML文本
"""
return cls(type=MessageType.XML, xml=xml) return cls(type=MessageType.XML, xml=xml)
@classmethod @classmethod
def json(cls, json: str): def json(cls, json: str):
"""
:说明:
Json消息
:参数:
* ``json: str``: Json文本
"""
return cls(type=MessageType.JSON, json=json) return cls(type=MessageType.JSON, json=json)
@classmethod @classmethod
def app(cls, content: str): def app(cls, content: str):
"""
:说明:
应用程序消息
:参数:
* ``content: str``: 内容
"""
return cls(type=MessageType.APP, content=content) return cls(type=MessageType.APP, content=content)
@classmethod @classmethod
def poke(cls, name: str): def poke(cls, name: str):
"""
:说明:
戳一戳消息
:参数:
* ``name: str``: 戳一戳的类型
- "Poke": 戳一戳
- "ShowLove": 比心
- "Like": 点赞
- "Heartbroken": 心碎
- "SixSixSix": 666
- "FangDaZhao": 放大招
"""
return cls(type=MessageType.POKE, name=name) return cls(type=MessageType.POKE, name=name)
class MessageChain(BaseMessage): #type:List[MessageSegment] class MessageChain(BaseMessage):
"""
Mirai 协议 Messaqge 适配
由于Mirai协议的Message实现较为特殊, 故使用MessageChain命名
"""
@overrides(BaseMessage) @overrides(BaseMessage)
def __init__(self, message: Union[List[Dict[str, Any]], def __init__(self, message: Union[List[Dict[str, Any]],
@ -166,6 +303,7 @@ class MessageChain(BaseMessage): #type:List[MessageSegment]
] ]
def export(self) -> List[Dict[str, Any]]: def export(self) -> List[Dict[str, Any]]:
"""导出为可以被正常json序列化的数组"""
return [ return [
*map(lambda segment: segment.as_dict(), self.copy()) # type: ignore *map(lambda segment: segment.as_dict(), self.copy()) # type: ignore
] ]