From 8fe562e864df0ff3012785c99dbb69eb53bdc550 Mon Sep 17 00:00:00 2001 From: Mix Date: Mon, 1 Feb 2021 13:19:37 +0800 Subject: [PATCH] :bulb: :children_crossing: complete comments and optimize usage in mirai adapter --- nonebot/adapters/mirai/__init__.py | 12 +- nonebot/adapters/mirai/bot.py | 176 +++++++++++++++-------------- nonebot/adapters/mirai/utils.py | 89 +++++++++++++++ 3 files changed, 189 insertions(+), 88 deletions(-) create mode 100644 nonebot/adapters/mirai/utils.py diff --git a/nonebot/adapters/mirai/__init__.py b/nonebot/adapters/mirai/__init__.py index b209657e..75afaff8 100644 --- a/nonebot/adapters/mirai/__init__.py +++ b/nonebot/adapters/mirai/__init__.py @@ -4,8 +4,18 @@ Mirai-API-HTTP 协议适配 协议详情请看: `mirai-api-http 文档`_ -.. mirai-api-http 文档: +\:\:\: tip +该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试 + +如果你在使用过程中遇到了任何问题, 请前往 `Issue页面`_ 为我们提供反馈 +\:\:\: + +.. _mirai-api-http 文档: https://github.com/project-mirai/mirai-api-http/tree/master/docs + +.. _Issue页面 + https://github.com/nonebot/nonebot2/issues + """ from .bot import MiraiBot diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 74c4f602..ed0b9ae1 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -1,34 +1,23 @@ from datetime import datetime, timedelta +from functools import wraps from io import BytesIO from ipaddress import IPv4Address -from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union +from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union) import httpx from nonebot.adapters import Bot as BaseBot from nonebot.config import Config from nonebot.drivers import Driver, WebSocket -from nonebot.exception import ActionFailed as BaseActionFailed -from nonebot.exception import RequestDenied +from nonebot.exception import ApiNotAvailable, RequestDenied from nonebot.log import logger from nonebot.message import handle_event from nonebot.typing import overrides -from nonebot.utils import escape_tag from .config import Config as MiraiConfig from .event import Event, FriendMessage, GroupMessage, TempMessage from .message import MessageChain, MessageSegment - - -class ActionFailed(BaseActionFailed): - - def __init__(self, **kwargs): - super().__init__('mirai') - self.data = kwargs.copy() - - def __repr__(self): - return self.__class__.__name__ + '(%s)' % ', '.join( - map(lambda m: '%s=%r' % m, self.data.items())) +from .utils import catch_network_error, argument_validation class SessionManager: @@ -39,19 +28,11 @@ class SessionManager: def __init__(self, session_key: str, client: httpx.AsyncClient): self.session_key, self.client = session_key, client - @staticmethod - def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]: - logger.opt(colors=True).debug( - f'Mirai API returned data: {escape_tag(str(data))}') - if isinstance(data, dict) and ('code' in data): - if data['code'] != 0: - raise ActionFailed(**data) - return data - + @catch_network_error async def post(self, path: str, *, - params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + params: Optional[Dict[str, Any]] = None) -> Any: """ :说明: @@ -75,13 +56,13 @@ class SessionManager: timeout=3, ) response.raise_for_status() - return self._raise_code(response.json()) + return response.json() + @catch_network_error async def request(self, path: str, *, - params: Optional[Dict[str, - Any]] = None) -> Dict[str, Any]: + params: Optional[Dict[str, Any]] = None) -> Any: """ :说明: @@ -91,10 +72,6 @@ class SessionManager: * ``path: str``: 对应API路径 * ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey) - - :返回: - - - ``Dict[str, Any]``: API 返回值 """ response = await self.client.get( path, @@ -105,25 +82,34 @@ class SessionManager: timeout=3, ) response.raise_for_status() - return self._raise_code(response.json()) + return response.json() - async def upload(self, path: str, *, type: str, - file: Tuple[str, BytesIO]) -> Dict[str, Any]: + @catch_network_error + async def upload(self, path: str, *, params: Dict[str, Any]) -> Any: + """ + :说明: - file_type, file_io = file - response = await self.client.post(path, - data={ - 'sessionKey': self.session_key, - 'type': type - }, - files={file_type: file_io}, - timeout=6) + 以表单(``multipart/form-data``)形式主动提交API请求 + + :参数: + + * ``path: str``: 对应API路径 + * ``params: Dict[str, Any]``: 请求参数 (无需sessionKey) + """ + files = {k: v for k, v in params.items() if isinstance(v, BytesIO)} + form = {k: v for k, v in params.items() if k not in files} + response = await self.client.post( + path, + data=form, + files=files, + timeout=6, + ) response.raise_for_status() - return self._raise_code(response.json()) + return response.json() @classmethod async def new(cls, self_id: int, *, host: IPv4Address, port: int, - auth_key: str): + auth_key: str) -> "SessionManager": session = cls.get(self_id) if session is not None: return session @@ -145,7 +131,9 @@ class SessionManager: return cls(session_key, client) @classmethod - def get(cls, self_id: int, check_expire: bool = True): + def get(cls, + self_id: int, + check_expire: bool = True) -> Optional["SessionManager"]: if self_id not in cls.sessions: return None key, time, client = cls.sessions[self_id] @@ -157,6 +145,13 @@ class SessionManager: class MiraiBot(BaseBot): """ mirai-api-http 协议 Bot 适配。 + + \:\:\: warning + API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名 + + 部分字段可能与文档在符号上不一致 + \:\:\: + """ @overrides(BaseBot) @@ -207,9 +202,9 @@ class MiraiBot(BaseBot): @overrides(BaseBot) def register(cls, driver: "Driver", config: "Config"): cls.mirai_config = MiraiConfig(**config.dict()) - assert cls.mirai_config.auth_key is not None - assert cls.mirai_config.host is not None - assert cls.mirai_config.port is not None + if (cls.mirai_config.auth_key and cls.mirai_config.host and + cls.mirai_config.port) is None: + raise ApiNotAvailable('mirai') super().register(driver, config) @overrides(BaseBot) @@ -222,7 +217,12 @@ class MiraiBot(BaseBot): @overrides(BaseBot) async def call_api(self, api: str, **data) -> NoReturn: - """由于Mirai的HTTP API特殊性, 该API暂时无法实现""" + """ + 由于Mirai的HTTP API特殊性, 该API暂时无法实现 + \:\:\: tip + 你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替 + \:\:\: + """ raise NotImplementedError @overrides(BaseBot) @@ -231,6 +231,7 @@ class MiraiBot(BaseBot): raise NotImplementedError @overrides(BaseBot) + @argument_validation async def send(self, event: Event, message: Union[MessageChain, MessageSegment, str], @@ -245,10 +246,6 @@ class MiraiBot(BaseBot): * ``event: Event``: Event对象 * ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息 * ``at_sender: bool``: 是否 @ 事件主题 - - :返回: - - - ``Any``: API 调用返回数据 """ if isinstance(message, MessageSegment): message = MessageChain(message) @@ -269,6 +266,7 @@ class MiraiBot(BaseBot): else: raise ValueError(f'Unsupported event type {event!r}.') + @argument_validation async def send_friend_message(self, target: int, message_chain: MessageChain): """ @@ -280,10 +278,6 @@ class MiraiBot(BaseBot): * ``target: int``: 发送消息目标好友的 QQ 号 * ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组 - - :返回: - - - ``Any``: API 调用返回数据 """ return await self.api.post('sendFriendMessage', params={ @@ -291,6 +285,7 @@ class MiraiBot(BaseBot): 'messageChain': message_chain.export() }) + @argument_validation async def send_temp_message(self, qq: int, group: int, message_chain: MessageChain): """ @@ -303,10 +298,6 @@ class MiraiBot(BaseBot): * ``qq: int``: 临时会话对象 QQ 号 * ``group: int``: 临时会话群号 * ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组 - - :返回: - - - ``Any``: API 调用返回数据 """ return await self.api.post('sendTempMessage', params={ @@ -315,6 +306,7 @@ class MiraiBot(BaseBot): 'messageChain': message_chain.export() }) + @argument_validation async def send_group_message(self, group: int, message_chain: MessageChain, @@ -329,10 +321,6 @@ class MiraiBot(BaseBot): * ``group: int``: 发送消息目标群的群号 * ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组 * ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复 - - :返回: - - - ``Any``: API 调用返回数据 """ return await self.api.post('sendGroupMessage', params={ @@ -341,6 +329,7 @@ class MiraiBot(BaseBot): 'quote': quote }) + @argument_validation async def recall(self, target: int): """ :说明: @@ -350,13 +339,10 @@ class MiraiBot(BaseBot): :参数: * ``target: int``: 需要撤回的消息的message_id - - :返回: - - - ``Any``: API 调用返回数据 """ return await self.api.post('recall', params={'target': target}) + @argument_validation async def send_image_message(self, target: int, qq: int, group: int, urls: List[str]) -> List[str]: """ @@ -384,8 +370,9 @@ class MiraiBot(BaseBot): 'qq': qq, 'group': group, 'urls': urls - }) # type: ignore + }) + @argument_validation async def upload_image(self, type: str, img: BytesIO): """ :说明: @@ -396,15 +383,14 @@ class MiraiBot(BaseBot): * ``type: str``: "friend" 或 "group" 或 "temp" * ``img: BytesIO``: 图片的BytesIO对象 - - :返回: - - - ``Any``: API 调用返回数据 """ return await self.api.upload('uploadImage', - type=type, - file=('img', img)) + params={ + 'type': type, + 'img': img + }) + @argument_validation async def upload_voice(self, type: str, voice: BytesIO): """ :说明: @@ -415,15 +401,14 @@ class MiraiBot(BaseBot): * ``type: str``: 当前仅支持 "group" * ``voice: BytesIO``: 语音的BytesIO对象 - - :返回: - - - ``Any``: API 调用返回数据 """ return await self.api.upload('uploadVoice', - type=type, - file=('voice', voice)) + params={ + 'type': type, + 'voice': voice + }) + @argument_validation async def fetch_message(self, count: int = 10): """ :说明: @@ -437,6 +422,7 @@ class MiraiBot(BaseBot): """ return await self.api.request('fetchMessage', params={'count': count}) + @argument_validation async def fetch_latest_message(self, count: int = 10): """ :说明: @@ -451,6 +437,7 @@ class MiraiBot(BaseBot): return await self.api.request('fetchLatestMessage', params={'count': count}) + @argument_validation async def peek_message(self, count: int = 10): """ :说明: @@ -464,6 +451,7 @@ class MiraiBot(BaseBot): """ return await self.api.request('peekMessage', params={'count': count}) + @argument_validation async def peek_latest_message(self, count: int = 10): """ :说明: @@ -478,6 +466,7 @@ class MiraiBot(BaseBot): return await self.api.request('peekLatestMessage', params={'count': count}) + @argument_validation async def messsage_from_id(self, id: int): """ :说明: @@ -491,6 +480,7 @@ class MiraiBot(BaseBot): """ return await self.api.request('messageFromId', params={'id': id}) + @argument_validation async def count_message(self): """ :说明: @@ -499,6 +489,7 @@ class MiraiBot(BaseBot): """ return await self.api.request('countMessage') + @argument_validation async def friend_list(self) -> List[Dict[str, Any]]: """ :说明: @@ -509,8 +500,9 @@ class MiraiBot(BaseBot): - ``List[Dict[str, Any]]``: 返回的好友列表数据 """ - return await self.api.request('friendList') # type: ignore + return await self.api.request('friendList') + @argument_validation async def group_list(self) -> List[Dict[str, Any]]: """ :说明: @@ -521,8 +513,9 @@ class MiraiBot(BaseBot): - ``List[Dict[str, Any]]``: 返回的群列表数据 """ - return await self.api.request('groupList') # type: ignore + return await self.api.request('groupList') + @argument_validation async def member_list(self, target: int) -> List[Dict[str, Any]]: """ :说明: @@ -537,9 +530,9 @@ class MiraiBot(BaseBot): - ``List[Dict[str, Any]]``: 返回的群成员列表数据 """ - return await self.api.request('memberList', - params={'target': target}) # type: ignore + return await self.api.request('memberList', params={'target': target}) + @argument_validation async def mute(self, target: int, member_id: int, time: int): """ :说明: @@ -559,6 +552,7 @@ class MiraiBot(BaseBot): 'time': time }) + @argument_validation async def unmute(self, target: int, member_id: int): """ :说明: @@ -576,6 +570,7 @@ class MiraiBot(BaseBot): 'memberId': member_id }) + @argument_validation async def kick(self, target: int, member_id: int, msg: str): """ :说明: @@ -595,6 +590,7 @@ class MiraiBot(BaseBot): 'msg': msg }) + @argument_validation async def quit(self, target: int): """ :说明: @@ -607,6 +603,7 @@ class MiraiBot(BaseBot): """ return await self.api.post('quit', params={'target': target}) + @argument_validation async def mute_all(self, target: int): """ :说明: @@ -619,6 +616,7 @@ class MiraiBot(BaseBot): """ return await self.api.post('muteAll', params={'target': target}) + @argument_validation async def unmute_all(self, target: int): """ :说明: @@ -631,6 +629,7 @@ class MiraiBot(BaseBot): """ return await self.api.post('unmuteAll', params={'target': target}) + @argument_validation async def group_config(self, target: int): """ :说明: @@ -656,6 +655,7 @@ class MiraiBot(BaseBot): """ return await self.api.request('groupConfig', params={'target': target}) + @argument_validation async def modify_group_config(self, target: int, config: Dict[str, Any]): """ :说明: @@ -673,6 +673,7 @@ class MiraiBot(BaseBot): 'config': config }) + @argument_validation async def member_info(self, target: int, member_id: int): """ :说明: @@ -699,6 +700,7 @@ class MiraiBot(BaseBot): 'memberId': member_id }) + @argument_validation async def modify_member_info(self, target: int, member_id: int, info: Dict[str, Any]): """ diff --git a/nonebot/adapters/mirai/utils.py b/nonebot/adapters/mirai/utils.py new file mode 100644 index 00000000..0a4b4a1b --- /dev/null +++ b/nonebot/adapters/mirai/utils.py @@ -0,0 +1,89 @@ +from functools import wraps +from typing import Callable, Coroutine, TypeVar + +import httpx +from pydantic import ValidationError, validate_arguments, Extra + +import nonebot.exception as exception +from nonebot.log import logger +from nonebot.utils import escape_tag + +_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) +_AnyCallable = TypeVar("_AnyCallable", bound=Callable) + + +class ActionFailed(exception.ActionFailed): + """ + :说明: + + API 请求成功返回数据,但 API 操作失败。 + """ + + def __init__(self, **kwargs): + super().__init__('mirai') + self.data = kwargs.copy() + + def __repr__(self): + return self.__class__.__name__ + '(%s)' % ', '.join( + map(lambda m: '%s=%r' % m, self.data.items())) + + +class InvalidArgument(exception.AdapterException): + """ + :说明: + + 调用API的参数出错 + """ + + def __init__(self, **kwargs): + super().__init__('mirai') + + +def catch_network_error(function: _AsyncCallable) -> _AsyncCallable: + """ + :说明: + + 捕捉函数抛出的httpx网络异常并释放``NetworkError``异常 + 处理返回数据, 在code不为0时释放``ActionFailed``异常 + + \:\:\: warning + 此装饰器只支持使用了httpx的异步函数 + \:\:\: + """ + + @wraps(function) + async def wrapper(*args, **kwargs): + try: + data = await function(*args, **kwargs) + except httpx.HTTPError: + raise exception.NetworkError('mirai') + logger.opt(colors=True).debug('Mirai API returned data: ' + f'{escape_tag(str(data))}') + if isinstance(data, dict): + if data.get('code', 0) != 0: + raise ActionFailed(**data) + return data + + return wrapper # type: ignore + + +def argument_validation(function: _AnyCallable) -> _AnyCallable: + """ + :说明: + + 通过函数签名中的类型注解来对传入参数进行运行时校验 + 会在参数出错时释放``InvalidArgument``异常 + """ + function = validate_arguments(config={ + 'arbitrary_types_allowed': True, + 'extra': Extra.forbid + })(function) + + @wraps(function) + def wrapper(*args, **kwargs): + try: + return function(*args, **kwargs) + except ValidationError: + raise InvalidArgument + + return wrapper # type: ignore