💡 🚸 complete comments and optimize usage in mirai adapter

This commit is contained in:
Mix 2021-02-01 13:19:37 +08:00
parent 923cbd3b8c
commit 8fe562e864
3 changed files with 189 additions and 88 deletions

View File

@ -4,8 +4,18 @@ Mirai-API-HTTP 协议适配
协议详情请看: `mirai-api-http 文档`_ 协议详情请看: `mirai-api-http 文档`_
.. mirai-api-http 文档: \:\:\: tip
该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试
如果你在使用过程中遇到了任何问题, 请前往 `Issue页面`_ 为我们提供反馈
\:\:\:
.. _mirai-api-http 文档:
https://github.com/project-mirai/mirai-api-http/tree/master/docs https://github.com/project-mirai/mirai-api-http/tree/master/docs
.. _Issue页面
https://github.com/nonebot/nonebot2/issues
""" """
from .bot import MiraiBot from .bot import MiraiBot

View File

@ -1,34 +1,23 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import wraps
from io import BytesIO from io import BytesIO
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
import httpx import httpx
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ActionFailed as BaseActionFailed from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.exception import RequestDenied
from nonebot.log import logger from nonebot.log import logger
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment from .message import MessageChain, MessageSegment
from .utils import catch_network_error, argument_validation
class ActionFailed(BaseActionFailed):
def __init__(self, **kwargs):
super().__init__('mirai')
self.data = kwargs.copy()
def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join(
map(lambda m: '%s=%r' % m, self.data.items()))
class SessionManager: class SessionManager:
@ -39,19 +28,11 @@ class SessionManager:
def __init__(self, session_key: str, client: httpx.AsyncClient): def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client self.session_key, self.client = session_key, client
@staticmethod @catch_network_error
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
logger.opt(colors=True).debug(
f'Mirai API returned data: <y>{escape_tag(str(data))}</y>')
if isinstance(data, dict) and ('code' in data):
if data['code'] != 0:
raise ActionFailed(**data)
return data
async def post(self, async def post(self,
path: str, path: str,
*, *,
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: params: Optional[Dict[str, Any]] = None) -> Any:
""" """
:说明: :说明:
@ -75,13 +56,13 @@ class SessionManager:
timeout=3, timeout=3,
) )
response.raise_for_status() response.raise_for_status()
return self._raise_code(response.json()) return response.json()
@catch_network_error
async def request(self, async def request(self,
path: str, path: str,
*, *,
params: Optional[Dict[str, params: Optional[Dict[str, Any]] = None) -> Any:
Any]] = None) -> Dict[str, Any]:
""" """
:说明: :说明:
@ -91,10 +72,6 @@ class SessionManager:
* ``path: str``: 对应API路径 * ``path: str``: 对应API路径
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey) * ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
:返回:
- ``Dict[str, Any]``: API 返回值
""" """
response = await self.client.get( response = await self.client.get(
path, path,
@ -105,25 +82,34 @@ class SessionManager:
timeout=3, timeout=3,
) )
response.raise_for_status() response.raise_for_status()
return self._raise_code(response.json()) return response.json()
async def upload(self, path: str, *, type: str, @catch_network_error
file: Tuple[str, BytesIO]) -> Dict[str, Any]: async def upload(self, path: str, *, params: Dict[str, Any]) -> Any:
"""
:说明:
file_type, file_io = file 以表单(``multipart/form-data``)形式主动提交API请求
response = await self.client.post(path,
data={ :参数:
'sessionKey': self.session_key,
'type': type * ``path: str``: 对应API路径
}, * ``params: Dict[str, Any]``: 请求参数 (无需sessionKey)
files={file_type: file_io}, """
timeout=6) files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
form = {k: v for k, v in params.items() if k not in files}
response = await self.client.post(
path,
data=form,
files=files,
timeout=6,
)
response.raise_for_status() response.raise_for_status()
return self._raise_code(response.json()) return response.json()
@classmethod @classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int, async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str): auth_key: str) -> "SessionManager":
session = cls.get(self_id) session = cls.get(self_id)
if session is not None: if session is not None:
return session return session
@ -145,7 +131,9 @@ class SessionManager:
return cls(session_key, client) return cls(session_key, client)
@classmethod @classmethod
def get(cls, self_id: int, check_expire: bool = True): def get(cls,
self_id: int,
check_expire: bool = True) -> Optional["SessionManager"]:
if self_id not in cls.sessions: if self_id not in cls.sessions:
return None return None
key, time, client = cls.sessions[self_id] key, time, client = cls.sessions[self_id]
@ -157,6 +145,13 @@ class SessionManager:
class MiraiBot(BaseBot): class MiraiBot(BaseBot):
""" """
mirai-api-http 协议 Bot 适配 mirai-api-http 协议 Bot 适配
\:\:\: warning
API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
部分字段可能与文档在符号上不一致
\:\:\:
""" """
@overrides(BaseBot) @overrides(BaseBot)
@ -207,9 +202,9 @@ class MiraiBot(BaseBot):
@overrides(BaseBot) @overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config"): def register(cls, driver: "Driver", config: "Config"):
cls.mirai_config = MiraiConfig(**config.dict()) cls.mirai_config = MiraiConfig(**config.dict())
assert cls.mirai_config.auth_key is not None if (cls.mirai_config.auth_key and cls.mirai_config.host and
assert cls.mirai_config.host is not None cls.mirai_config.port) is None:
assert cls.mirai_config.port is not None raise ApiNotAvailable('mirai')
super().register(driver, config) super().register(driver, config)
@overrides(BaseBot) @overrides(BaseBot)
@ -222,7 +217,12 @@ class MiraiBot(BaseBot):
@overrides(BaseBot) @overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn: async def call_api(self, api: str, **data) -> NoReturn:
"""由于Mirai的HTTP API特殊性, 该API暂时无法实现""" """
由于Mirai的HTTP API特殊性, 该API暂时无法实现
\:\:\: tip
你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替
\:\:\:
"""
raise NotImplementedError raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)
@ -231,6 +231,7 @@ class MiraiBot(BaseBot):
raise NotImplementedError raise NotImplementedError
@overrides(BaseBot) @overrides(BaseBot)
@argument_validation
async def send(self, async def send(self,
event: Event, event: Event,
message: Union[MessageChain, MessageSegment, str], message: Union[MessageChain, MessageSegment, str],
@ -245,10 +246,6 @@ class MiraiBot(BaseBot):
* ``event: Event``: Event对象 * ``event: Event``: Event对象
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息 * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主题 * ``at_sender: bool``: 是否 @ 事件主题
:返回:
- ``Any``: API 调用返回数据
""" """
if isinstance(message, MessageSegment): if isinstance(message, MessageSegment):
message = MessageChain(message) message = MessageChain(message)
@ -269,6 +266,7 @@ class MiraiBot(BaseBot):
else: else:
raise ValueError(f'Unsupported event type {event!r}.') raise ValueError(f'Unsupported event type {event!r}.')
@argument_validation
async def send_friend_message(self, target: int, async def send_friend_message(self, target: int,
message_chain: MessageChain): message_chain: MessageChain):
""" """
@ -280,10 +278,6 @@ class MiraiBot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ * ``target: int``: 发送消息目标好友的 QQ
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组 * ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
:返回:
- ``Any``: API 调用返回数据
""" """
return await self.api.post('sendFriendMessage', return await self.api.post('sendFriendMessage',
params={ params={
@ -291,6 +285,7 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export() 'messageChain': message_chain.export()
}) })
@argument_validation
async def send_temp_message(self, qq: int, group: int, async def send_temp_message(self, qq: int, group: int,
message_chain: MessageChain): message_chain: MessageChain):
""" """
@ -303,10 +298,6 @@ class MiraiBot(BaseBot):
* ``qq: int``: 临时会话对象 QQ * ``qq: int``: 临时会话对象 QQ
* ``group: int``: 临时会话群号 * ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组 * ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
:返回:
- ``Any``: API 调用返回数据
""" """
return await self.api.post('sendTempMessage', return await self.api.post('sendTempMessage',
params={ params={
@ -315,6 +306,7 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export() 'messageChain': message_chain.export()
}) })
@argument_validation
async def send_group_message(self, async def send_group_message(self,
group: int, group: int,
message_chain: MessageChain, message_chain: MessageChain,
@ -329,10 +321,6 @@ class MiraiBot(BaseBot):
* ``group: int``: 发送消息目标群的群号 * ``group: int``: 发送消息目标群的群号
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组 * ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复 * ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
:返回:
- ``Any``: API 调用返回数据
""" """
return await self.api.post('sendGroupMessage', return await self.api.post('sendGroupMessage',
params={ params={
@ -341,6 +329,7 @@ class MiraiBot(BaseBot):
'quote': quote 'quote': quote
}) })
@argument_validation
async def recall(self, target: int): async def recall(self, target: int):
""" """
:说明: :说明:
@ -350,13 +339,10 @@ class MiraiBot(BaseBot):
:参数: :参数:
* ``target: int``: 需要撤回的消息的message_id * ``target: int``: 需要撤回的消息的message_id
:返回:
- ``Any``: API 调用返回数据
""" """
return await self.api.post('recall', params={'target': target}) return await self.api.post('recall', params={'target': target})
@argument_validation
async def send_image_message(self, target: int, qq: int, group: int, async def send_image_message(self, target: int, qq: int, group: int,
urls: List[str]) -> List[str]: urls: List[str]) -> List[str]:
""" """
@ -384,8 +370,9 @@ class MiraiBot(BaseBot):
'qq': qq, 'qq': qq,
'group': group, 'group': group,
'urls': urls 'urls': urls
}) # type: ignore })
@argument_validation
async def upload_image(self, type: str, img: BytesIO): async def upload_image(self, type: str, img: BytesIO):
""" """
:说明: :说明:
@ -396,15 +383,14 @@ class MiraiBot(BaseBot):
* ``type: str``: "friend" "group" "temp" * ``type: str``: "friend" "group" "temp"
* ``img: BytesIO``: 图片的BytesIO对象 * ``img: BytesIO``: 图片的BytesIO对象
:返回:
- ``Any``: API 调用返回数据
""" """
return await self.api.upload('uploadImage', return await self.api.upload('uploadImage',
type=type, params={
file=('img', img)) 'type': type,
'img': img
})
@argument_validation
async def upload_voice(self, type: str, voice: BytesIO): async def upload_voice(self, type: str, voice: BytesIO):
""" """
:说明: :说明:
@ -415,15 +401,14 @@ class MiraiBot(BaseBot):
* ``type: str``: 当前仅支持 "group" * ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象 * ``voice: BytesIO``: 语音的BytesIO对象
:返回:
- ``Any``: API 调用返回数据
""" """
return await self.api.upload('uploadVoice', return await self.api.upload('uploadVoice',
type=type, params={
file=('voice', voice)) 'type': type,
'voice': voice
})
@argument_validation
async def fetch_message(self, count: int = 10): async def fetch_message(self, count: int = 10):
""" """
:说明: :说明:
@ -437,6 +422,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.request('fetchMessage', params={'count': count}) return await self.api.request('fetchMessage', params={'count': count})
@argument_validation
async def fetch_latest_message(self, count: int = 10): async def fetch_latest_message(self, count: int = 10):
""" """
:说明: :说明:
@ -451,6 +437,7 @@ class MiraiBot(BaseBot):
return await self.api.request('fetchLatestMessage', return await self.api.request('fetchLatestMessage',
params={'count': count}) params={'count': count})
@argument_validation
async def peek_message(self, count: int = 10): async def peek_message(self, count: int = 10):
""" """
:说明: :说明:
@ -464,6 +451,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.request('peekMessage', params={'count': count}) return await self.api.request('peekMessage', params={'count': count})
@argument_validation
async def peek_latest_message(self, count: int = 10): async def peek_latest_message(self, count: int = 10):
""" """
:说明: :说明:
@ -478,6 +466,7 @@ class MiraiBot(BaseBot):
return await self.api.request('peekLatestMessage', return await self.api.request('peekLatestMessage',
params={'count': count}) params={'count': count})
@argument_validation
async def messsage_from_id(self, id: int): async def messsage_from_id(self, id: int):
""" """
:说明: :说明:
@ -491,6 +480,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.request('messageFromId', params={'id': id}) return await self.api.request('messageFromId', params={'id': id})
@argument_validation
async def count_message(self): async def count_message(self):
""" """
:说明: :说明:
@ -499,6 +489,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.request('countMessage') return await self.api.request('countMessage')
@argument_validation
async def friend_list(self) -> List[Dict[str, Any]]: async def friend_list(self) -> List[Dict[str, Any]]:
""" """
:说明: :说明:
@ -509,8 +500,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据 - ``List[Dict[str, Any]]``: 返回的好友列表数据
""" """
return await self.api.request('friendList') # type: ignore return await self.api.request('friendList')
@argument_validation
async def group_list(self) -> List[Dict[str, Any]]: async def group_list(self) -> List[Dict[str, Any]]:
""" """
:说明: :说明:
@ -521,8 +513,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据 - ``List[Dict[str, Any]]``: 返回的群列表数据
""" """
return await self.api.request('groupList') # type: ignore return await self.api.request('groupList')
@argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]: async def member_list(self, target: int) -> List[Dict[str, Any]]:
""" """
:说明: :说明:
@ -537,9 +530,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据 - ``List[Dict[str, Any]]``: 返回的群成员列表数据
""" """
return await self.api.request('memberList', return await self.api.request('memberList', params={'target': target})
params={'target': target}) # type: ignore
@argument_validation
async def mute(self, target: int, member_id: int, time: int): async def mute(self, target: int, member_id: int, time: int):
""" """
:说明: :说明:
@ -559,6 +552,7 @@ class MiraiBot(BaseBot):
'time': time 'time': time
}) })
@argument_validation
async def unmute(self, target: int, member_id: int): async def unmute(self, target: int, member_id: int):
""" """
:说明: :说明:
@ -576,6 +570,7 @@ class MiraiBot(BaseBot):
'memberId': member_id 'memberId': member_id
}) })
@argument_validation
async def kick(self, target: int, member_id: int, msg: str): async def kick(self, target: int, member_id: int, msg: str):
""" """
:说明: :说明:
@ -595,6 +590,7 @@ class MiraiBot(BaseBot):
'msg': msg 'msg': msg
}) })
@argument_validation
async def quit(self, target: int): async def quit(self, target: int):
""" """
:说明: :说明:
@ -607,6 +603,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.post('quit', params={'target': target}) return await self.api.post('quit', params={'target': target})
@argument_validation
async def mute_all(self, target: int): async def mute_all(self, target: int):
""" """
:说明: :说明:
@ -619,6 +616,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.post('muteAll', params={'target': target}) return await self.api.post('muteAll', params={'target': target})
@argument_validation
async def unmute_all(self, target: int): async def unmute_all(self, target: int):
""" """
:说明: :说明:
@ -631,6 +629,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.post('unmuteAll', params={'target': target}) return await self.api.post('unmuteAll', params={'target': target})
@argument_validation
async def group_config(self, target: int): async def group_config(self, target: int):
""" """
:说明: :说明:
@ -656,6 +655,7 @@ class MiraiBot(BaseBot):
""" """
return await self.api.request('groupConfig', params={'target': target}) return await self.api.request('groupConfig', params={'target': target})
@argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]): async def modify_group_config(self, target: int, config: Dict[str, Any]):
""" """
:说明: :说明:
@ -673,6 +673,7 @@ class MiraiBot(BaseBot):
'config': config 'config': config
}) })
@argument_validation
async def member_info(self, target: int, member_id: int): async def member_info(self, target: int, member_id: int):
""" """
:说明: :说明:
@ -699,6 +700,7 @@ class MiraiBot(BaseBot):
'memberId': member_id 'memberId': member_id
}) })
@argument_validation
async def modify_member_info(self, target: int, member_id: int, async def modify_member_info(self, target: int, member_id: int,
info: Dict[str, Any]): info: Dict[str, Any]):
""" """

View File

@ -0,0 +1,89 @@
from functools import wraps
from typing import Callable, Coroutine, TypeVar
import httpx
from pydantic import ValidationError, validate_arguments, Extra
import nonebot.exception as exception
from nonebot.log import logger
from nonebot.utils import escape_tag
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
_AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class ActionFailed(exception.ActionFailed):
"""
:说明:
API 请求成功返回数据 API 操作失败
"""
def __init__(self, **kwargs):
super().__init__('mirai')
self.data = kwargs.copy()
def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join(
map(lambda m: '%s=%r' % m, self.data.items()))
class InvalidArgument(exception.AdapterException):
"""
:说明:
调用API的参数出错
"""
def __init__(self, **kwargs):
super().__init__('mirai')
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
"""
:说明:
捕捉函数抛出的httpx网络异常并释放``NetworkError``异常
处理返回数据, 在code不为0时释放``ActionFailed``异常
\:\:\: warning
此装饰器只支持使用了httpx的异步函数
\:\:\:
"""
@wraps(function)
async def wrapper(*args, **kwargs):
try:
data = await function(*args, **kwargs)
except httpx.HTTPError:
raise exception.NetworkError('mirai')
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
f'<y>{escape_tag(str(data))}</y>')
if isinstance(data, dict):
if data.get('code', 0) != 0:
raise ActionFailed(**data)
return data
return wrapper # type: ignore
def argument_validation(function: _AnyCallable) -> _AnyCallable:
"""
:说明:
通过函数签名中的类型注解来对传入参数进行运行时校验
会在参数出错时释放``InvalidArgument``异常
"""
function = validate_arguments(config={
'arbitrary_types_allowed': True,
'extra': Extra.forbid
})(function)
@wraps(function)
def wrapper(*args, **kwargs):
try:
return function(*args, **kwargs)
except ValidationError:
raise InvalidArgument
return wrapper # type: ignore