💡 🚸 complete comments and optimize usage in mirai adapter

This commit is contained in:
Mix 2021-02-01 13:19:37 +08:00
parent 923cbd3b8c
commit 8fe562e864
3 changed files with 189 additions and 88 deletions

View File

@ -4,8 +4,18 @@ Mirai-API-HTTP 协议适配
协议详情请看: `mirai-api-http 文档`_
.. mirai-api-http 文档:
\:\:\: tip
该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试
如果你在使用过程中遇到了任何问题, 请前往 `Issue页面`_ 为我们提供反馈
\:\:\:
.. _mirai-api-http 文档:
https://github.com/project-mirai/mirai-api-http/tree/master/docs
.. _Issue页面
https://github.com/nonebot/nonebot2/issues
"""
from .bot import MiraiBot

View File

@ -1,34 +1,23 @@
from datetime import datetime, timedelta
from functools import wraps
from io import BytesIO
from ipaddress import IPv4Address
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
import httpx
from nonebot.adapters import Bot as BaseBot
from nonebot.config import Config
from nonebot.drivers import Driver, WebSocket
from nonebot.exception import ActionFailed as BaseActionFailed
from nonebot.exception import RequestDenied
from nonebot.exception import ApiNotAvailable, RequestDenied
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from .config import Config as MiraiConfig
from .event import Event, FriendMessage, GroupMessage, TempMessage
from .message import MessageChain, MessageSegment
class ActionFailed(BaseActionFailed):
def __init__(self, **kwargs):
super().__init__('mirai')
self.data = kwargs.copy()
def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join(
map(lambda m: '%s=%r' % m, self.data.items()))
from .utils import catch_network_error, argument_validation
class SessionManager:
@ -39,19 +28,11 @@ class SessionManager:
def __init__(self, session_key: str, client: httpx.AsyncClient):
self.session_key, self.client = session_key, client
@staticmethod
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
logger.opt(colors=True).debug(
f'Mirai API returned data: <y>{escape_tag(str(data))}</y>')
if isinstance(data, dict) and ('code' in data):
if data['code'] != 0:
raise ActionFailed(**data)
return data
@catch_network_error
async def post(self,
path: str,
*,
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@ -75,13 +56,13 @@ class SessionManager:
timeout=3,
)
response.raise_for_status()
return self._raise_code(response.json())
return response.json()
@catch_network_error
async def request(self,
path: str,
*,
params: Optional[Dict[str,
Any]] = None) -> Dict[str, Any]:
params: Optional[Dict[str, Any]] = None) -> Any:
"""
:说明:
@ -91,10 +72,6 @@ class SessionManager:
* ``path: str``: 对应API路径
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
:返回:
- ``Dict[str, Any]``: API 返回值
"""
response = await self.client.get(
path,
@ -105,25 +82,34 @@ class SessionManager:
timeout=3,
)
response.raise_for_status()
return self._raise_code(response.json())
return response.json()
async def upload(self, path: str, *, type: str,
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
@catch_network_error
async def upload(self, path: str, *, params: Dict[str, Any]) -> Any:
"""
:说明:
file_type, file_io = file
response = await self.client.post(path,
data={
'sessionKey': self.session_key,
'type': type
},
files={file_type: file_io},
timeout=6)
以表单(``multipart/form-data``)形式主动提交API请求
:参数:
* ``path: str``: 对应API路径
* ``params: Dict[str, Any]``: 请求参数 (无需sessionKey)
"""
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
form = {k: v for k, v in params.items() if k not in files}
response = await self.client.post(
path,
data=form,
files=files,
timeout=6,
)
response.raise_for_status()
return self._raise_code(response.json())
return response.json()
@classmethod
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
auth_key: str):
auth_key: str) -> "SessionManager":
session = cls.get(self_id)
if session is not None:
return session
@ -145,7 +131,9 @@ class SessionManager:
return cls(session_key, client)
@classmethod
def get(cls, self_id: int, check_expire: bool = True):
def get(cls,
self_id: int,
check_expire: bool = True) -> Optional["SessionManager"]:
if self_id not in cls.sessions:
return None
key, time, client = cls.sessions[self_id]
@ -157,6 +145,13 @@ class SessionManager:
class MiraiBot(BaseBot):
"""
mirai-api-http 协议 Bot 适配
\:\:\: warning
API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
部分字段可能与文档在符号上不一致
\:\:\:
"""
@overrides(BaseBot)
@ -207,9 +202,9 @@ class MiraiBot(BaseBot):
@overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config"):
cls.mirai_config = MiraiConfig(**config.dict())
assert cls.mirai_config.auth_key is not None
assert cls.mirai_config.host is not None
assert cls.mirai_config.port is not None
if (cls.mirai_config.auth_key and cls.mirai_config.host and
cls.mirai_config.port) is None:
raise ApiNotAvailable('mirai')
super().register(driver, config)
@overrides(BaseBot)
@ -222,7 +217,12 @@ class MiraiBot(BaseBot):
@overrides(BaseBot)
async def call_api(self, api: str, **data) -> NoReturn:
"""由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
"""
由于Mirai的HTTP API特殊性, 该API暂时无法实现
\:\:\: tip
你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替
\:\:\:
"""
raise NotImplementedError
@overrides(BaseBot)
@ -231,6 +231,7 @@ class MiraiBot(BaseBot):
raise NotImplementedError
@overrides(BaseBot)
@argument_validation
async def send(self,
event: Event,
message: Union[MessageChain, MessageSegment, str],
@ -245,10 +246,6 @@ class MiraiBot(BaseBot):
* ``event: Event``: Event对象
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
* ``at_sender: bool``: 是否 @ 事件主题
:返回:
- ``Any``: API 调用返回数据
"""
if isinstance(message, MessageSegment):
message = MessageChain(message)
@ -269,6 +266,7 @@ class MiraiBot(BaseBot):
else:
raise ValueError(f'Unsupported event type {event!r}.')
@argument_validation
async def send_friend_message(self, target: int,
message_chain: MessageChain):
"""
@ -280,10 +278,6 @@ class MiraiBot(BaseBot):
* ``target: int``: 发送消息目标好友的 QQ
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
:返回:
- ``Any``: API 调用返回数据
"""
return await self.api.post('sendFriendMessage',
params={
@ -291,6 +285,7 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export()
})
@argument_validation
async def send_temp_message(self, qq: int, group: int,
message_chain: MessageChain):
"""
@ -303,10 +298,6 @@ class MiraiBot(BaseBot):
* ``qq: int``: 临时会话对象 QQ
* ``group: int``: 临时会话群号
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
:返回:
- ``Any``: API 调用返回数据
"""
return await self.api.post('sendTempMessage',
params={
@ -315,6 +306,7 @@ class MiraiBot(BaseBot):
'messageChain': message_chain.export()
})
@argument_validation
async def send_group_message(self,
group: int,
message_chain: MessageChain,
@ -329,10 +321,6 @@ class MiraiBot(BaseBot):
* ``group: int``: 发送消息目标群的群号
* ``message_chain: MessageChain``: 消息链是一个消息对象构成的数组
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
:返回:
- ``Any``: API 调用返回数据
"""
return await self.api.post('sendGroupMessage',
params={
@ -341,6 +329,7 @@ class MiraiBot(BaseBot):
'quote': quote
})
@argument_validation
async def recall(self, target: int):
"""
:说明:
@ -350,13 +339,10 @@ class MiraiBot(BaseBot):
:参数:
* ``target: int``: 需要撤回的消息的message_id
:返回:
- ``Any``: API 调用返回数据
"""
return await self.api.post('recall', params={'target': target})
@argument_validation
async def send_image_message(self, target: int, qq: int, group: int,
urls: List[str]) -> List[str]:
"""
@ -384,8 +370,9 @@ class MiraiBot(BaseBot):
'qq': qq,
'group': group,
'urls': urls
}) # type: ignore
})
@argument_validation
async def upload_image(self, type: str, img: BytesIO):
"""
:说明:
@ -396,15 +383,14 @@ class MiraiBot(BaseBot):
* ``type: str``: "friend" "group" "temp"
* ``img: BytesIO``: 图片的BytesIO对象
:返回:
- ``Any``: API 调用返回数据
"""
return await self.api.upload('uploadImage',
type=type,
file=('img', img))
params={
'type': type,
'img': img
})
@argument_validation
async def upload_voice(self, type: str, voice: BytesIO):
"""
:说明:
@ -415,15 +401,14 @@ class MiraiBot(BaseBot):
* ``type: str``: 当前仅支持 "group"
* ``voice: BytesIO``: 语音的BytesIO对象
:返回:
- ``Any``: API 调用返回数据
"""
return await self.api.upload('uploadVoice',
type=type,
file=('voice', voice))
params={
'type': type,
'voice': voice
})
@argument_validation
async def fetch_message(self, count: int = 10):
"""
:说明:
@ -437,6 +422,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('fetchMessage', params={'count': count})
@argument_validation
async def fetch_latest_message(self, count: int = 10):
"""
:说明:
@ -451,6 +437,7 @@ class MiraiBot(BaseBot):
return await self.api.request('fetchLatestMessage',
params={'count': count})
@argument_validation
async def peek_message(self, count: int = 10):
"""
:说明:
@ -464,6 +451,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('peekMessage', params={'count': count})
@argument_validation
async def peek_latest_message(self, count: int = 10):
"""
:说明:
@ -478,6 +466,7 @@ class MiraiBot(BaseBot):
return await self.api.request('peekLatestMessage',
params={'count': count})
@argument_validation
async def messsage_from_id(self, id: int):
"""
:说明:
@ -491,6 +480,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('messageFromId', params={'id': id})
@argument_validation
async def count_message(self):
"""
:说明:
@ -499,6 +489,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('countMessage')
@argument_validation
async def friend_list(self) -> List[Dict[str, Any]]:
"""
:说明:
@ -509,8 +500,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的好友列表数据
"""
return await self.api.request('friendList') # type: ignore
return await self.api.request('friendList')
@argument_validation
async def group_list(self) -> List[Dict[str, Any]]:
"""
:说明:
@ -521,8 +513,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群列表数据
"""
return await self.api.request('groupList') # type: ignore
return await self.api.request('groupList')
@argument_validation
async def member_list(self, target: int) -> List[Dict[str, Any]]:
"""
:说明:
@ -537,9 +530,9 @@ class MiraiBot(BaseBot):
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
"""
return await self.api.request('memberList',
params={'target': target}) # type: ignore
return await self.api.request('memberList', params={'target': target})
@argument_validation
async def mute(self, target: int, member_id: int, time: int):
"""
:说明:
@ -559,6 +552,7 @@ class MiraiBot(BaseBot):
'time': time
})
@argument_validation
async def unmute(self, target: int, member_id: int):
"""
:说明:
@ -576,6 +570,7 @@ class MiraiBot(BaseBot):
'memberId': member_id
})
@argument_validation
async def kick(self, target: int, member_id: int, msg: str):
"""
:说明:
@ -595,6 +590,7 @@ class MiraiBot(BaseBot):
'msg': msg
})
@argument_validation
async def quit(self, target: int):
"""
:说明:
@ -607,6 +603,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.post('quit', params={'target': target})
@argument_validation
async def mute_all(self, target: int):
"""
:说明:
@ -619,6 +616,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.post('muteAll', params={'target': target})
@argument_validation
async def unmute_all(self, target: int):
"""
:说明:
@ -631,6 +629,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.post('unmuteAll', params={'target': target})
@argument_validation
async def group_config(self, target: int):
"""
:说明:
@ -656,6 +655,7 @@ class MiraiBot(BaseBot):
"""
return await self.api.request('groupConfig', params={'target': target})
@argument_validation
async def modify_group_config(self, target: int, config: Dict[str, Any]):
"""
:说明:
@ -673,6 +673,7 @@ class MiraiBot(BaseBot):
'config': config
})
@argument_validation
async def member_info(self, target: int, member_id: int):
"""
:说明:
@ -699,6 +700,7 @@ class MiraiBot(BaseBot):
'memberId': member_id
})
@argument_validation
async def modify_member_info(self, target: int, member_id: int,
info: Dict[str, Any]):
"""

View File

@ -0,0 +1,89 @@
from functools import wraps
from typing import Callable, Coroutine, TypeVar
import httpx
from pydantic import ValidationError, validate_arguments, Extra
import nonebot.exception as exception
from nonebot.log import logger
from nonebot.utils import escape_tag
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
_AnyCallable = TypeVar("_AnyCallable", bound=Callable)
class ActionFailed(exception.ActionFailed):
"""
:说明:
API 请求成功返回数据 API 操作失败
"""
def __init__(self, **kwargs):
super().__init__('mirai')
self.data = kwargs.copy()
def __repr__(self):
return self.__class__.__name__ + '(%s)' % ', '.join(
map(lambda m: '%s=%r' % m, self.data.items()))
class InvalidArgument(exception.AdapterException):
"""
:说明:
调用API的参数出错
"""
def __init__(self, **kwargs):
super().__init__('mirai')
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
"""
:说明:
捕捉函数抛出的httpx网络异常并释放``NetworkError``异常
处理返回数据, 在code不为0时释放``ActionFailed``异常
\:\:\: warning
此装饰器只支持使用了httpx的异步函数
\:\:\:
"""
@wraps(function)
async def wrapper(*args, **kwargs):
try:
data = await function(*args, **kwargs)
except httpx.HTTPError:
raise exception.NetworkError('mirai')
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
f'<y>{escape_tag(str(data))}</y>')
if isinstance(data, dict):
if data.get('code', 0) != 0:
raise ActionFailed(**data)
return data
return wrapper # type: ignore
def argument_validation(function: _AnyCallable) -> _AnyCallable:
"""
:说明:
通过函数签名中的类型注解来对传入参数进行运行时校验
会在参数出错时释放``InvalidArgument``异常
"""
function = validate_arguments(config={
'arbitrary_types_allowed': True,
'extra': Extra.forbid
})(function)
@wraps(function)
def wrapper(*args, **kwargs):
try:
return function(*args, **kwargs)
except ValidationError:
raise InvalidArgument
return wrapper # type: ignore