mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
💡 🚸 complete comments and optimize usage in mirai adapter
This commit is contained in:
parent
923cbd3b8c
commit
8fe562e864
@ -4,8 +4,18 @@ Mirai-API-HTTP 协议适配
|
||||
|
||||
协议详情请看: `mirai-api-http 文档`_
|
||||
|
||||
.. mirai-api-http 文档:
|
||||
\:\:\: tip
|
||||
该Adapter目前仍然处在早期实验性阶段, 并未经过充分测试
|
||||
|
||||
如果你在使用过程中遇到了任何问题, 请前往 `Issue页面`_ 为我们提供反馈
|
||||
\:\:\:
|
||||
|
||||
.. _mirai-api-http 文档:
|
||||
https://github.com/project-mirai/mirai-api-http/tree/master/docs
|
||||
|
||||
.. _Issue页面
|
||||
https://github.com/nonebot/nonebot2/issues
|
||||
|
||||
"""
|
||||
|
||||
from .bot import MiraiBot
|
||||
|
@ -1,34 +1,23 @@
|
||||
from datetime import datetime, timedelta
|
||||
from functools import wraps
|
||||
from io import BytesIO
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union
|
||||
from typing import (Any, Dict, List, NoReturn, Optional, Tuple, Union)
|
||||
|
||||
import httpx
|
||||
|
||||
from nonebot.adapters import Bot as BaseBot
|
||||
from nonebot.config import Config
|
||||
from nonebot.drivers import Driver, WebSocket
|
||||
from nonebot.exception import ActionFailed as BaseActionFailed
|
||||
from nonebot.exception import RequestDenied
|
||||
from nonebot.exception import ApiNotAvailable, RequestDenied
|
||||
from nonebot.log import logger
|
||||
from nonebot.message import handle_event
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.utils import escape_tag
|
||||
|
||||
from .config import Config as MiraiConfig
|
||||
from .event import Event, FriendMessage, GroupMessage, TempMessage
|
||||
from .message import MessageChain, MessageSegment
|
||||
|
||||
|
||||
class ActionFailed(BaseActionFailed):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__('mirai')
|
||||
self.data = kwargs.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(%s)' % ', '.join(
|
||||
map(lambda m: '%s=%r' % m, self.data.items()))
|
||||
from .utils import catch_network_error, argument_validation
|
||||
|
||||
|
||||
class SessionManager:
|
||||
@ -39,19 +28,11 @@ class SessionManager:
|
||||
def __init__(self, session_key: str, client: httpx.AsyncClient):
|
||||
self.session_key, self.client = session_key, client
|
||||
|
||||
@staticmethod
|
||||
def _raise_code(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
logger.opt(colors=True).debug(
|
||||
f'Mirai API returned data: <y>{escape_tag(str(data))}</y>')
|
||||
if isinstance(data, dict) and ('code' in data):
|
||||
if data['code'] != 0:
|
||||
raise ActionFailed(**data)
|
||||
return data
|
||||
|
||||
@catch_network_error
|
||||
async def post(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -75,13 +56,13 @@ class SessionManager:
|
||||
timeout=3,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._raise_code(response.json())
|
||||
return response.json()
|
||||
|
||||
@catch_network_error
|
||||
async def request(self,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str,
|
||||
Any]] = None) -> Dict[str, Any]:
|
||||
params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
@ -91,10 +72,6 @@ class SessionManager:
|
||||
|
||||
* ``path: str``: 对应API路径
|
||||
* ``params: Optional[Dict[str, Any]]``: 请求参数 (无需sessionKey)
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Dict[str, Any]``: API 返回值
|
||||
"""
|
||||
response = await self.client.get(
|
||||
path,
|
||||
@ -105,25 +82,34 @@ class SessionManager:
|
||||
timeout=3,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._raise_code(response.json())
|
||||
return response.json()
|
||||
|
||||
async def upload(self, path: str, *, type: str,
|
||||
file: Tuple[str, BytesIO]) -> Dict[str, Any]:
|
||||
@catch_network_error
|
||||
async def upload(self, path: str, *, params: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
file_type, file_io = file
|
||||
response = await self.client.post(path,
|
||||
data={
|
||||
'sessionKey': self.session_key,
|
||||
'type': type
|
||||
},
|
||||
files={file_type: file_io},
|
||||
timeout=6)
|
||||
以表单(``multipart/form-data``)形式主动提交API请求
|
||||
|
||||
:参数:
|
||||
|
||||
* ``path: str``: 对应API路径
|
||||
* ``params: Dict[str, Any]``: 请求参数 (无需sessionKey)
|
||||
"""
|
||||
files = {k: v for k, v in params.items() if isinstance(v, BytesIO)}
|
||||
form = {k: v for k, v in params.items() if k not in files}
|
||||
response = await self.client.post(
|
||||
path,
|
||||
data=form,
|
||||
files=files,
|
||||
timeout=6,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return self._raise_code(response.json())
|
||||
return response.json()
|
||||
|
||||
@classmethod
|
||||
async def new(cls, self_id: int, *, host: IPv4Address, port: int,
|
||||
auth_key: str):
|
||||
auth_key: str) -> "SessionManager":
|
||||
session = cls.get(self_id)
|
||||
if session is not None:
|
||||
return session
|
||||
@ -145,7 +131,9 @@ class SessionManager:
|
||||
return cls(session_key, client)
|
||||
|
||||
@classmethod
|
||||
def get(cls, self_id: int, check_expire: bool = True):
|
||||
def get(cls,
|
||||
self_id: int,
|
||||
check_expire: bool = True) -> Optional["SessionManager"]:
|
||||
if self_id not in cls.sessions:
|
||||
return None
|
||||
key, time, client = cls.sessions[self_id]
|
||||
@ -157,6 +145,13 @@ class SessionManager:
|
||||
class MiraiBot(BaseBot):
|
||||
"""
|
||||
mirai-api-http 协议 Bot 适配。
|
||||
|
||||
\:\:\: warning
|
||||
API中为了使代码更加整洁, 我们采用了与PEP8相符的命名规则取代Mirai原有的驼峰命名
|
||||
|
||||
部分字段可能与文档在符号上不一致
|
||||
\:\:\:
|
||||
|
||||
"""
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -207,9 +202,9 @@ class MiraiBot(BaseBot):
|
||||
@overrides(BaseBot)
|
||||
def register(cls, driver: "Driver", config: "Config"):
|
||||
cls.mirai_config = MiraiConfig(**config.dict())
|
||||
assert cls.mirai_config.auth_key is not None
|
||||
assert cls.mirai_config.host is not None
|
||||
assert cls.mirai_config.port is not None
|
||||
if (cls.mirai_config.auth_key and cls.mirai_config.host and
|
||||
cls.mirai_config.port) is None:
|
||||
raise ApiNotAvailable('mirai')
|
||||
super().register(driver, config)
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -222,7 +217,12 @@ class MiraiBot(BaseBot):
|
||||
|
||||
@overrides(BaseBot)
|
||||
async def call_api(self, api: str, **data) -> NoReturn:
|
||||
"""由于Mirai的HTTP API特殊性, 该API暂时无法实现"""
|
||||
"""
|
||||
由于Mirai的HTTP API特殊性, 该API暂时无法实现
|
||||
\:\:\: tip
|
||||
你可以使用 ``MiraiBot.api`` 中提供的调用方法来代替
|
||||
\:\:\:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
@ -231,6 +231,7 @@ class MiraiBot(BaseBot):
|
||||
raise NotImplementedError
|
||||
|
||||
@overrides(BaseBot)
|
||||
@argument_validation
|
||||
async def send(self,
|
||||
event: Event,
|
||||
message: Union[MessageChain, MessageSegment, str],
|
||||
@ -245,10 +246,6 @@ class MiraiBot(BaseBot):
|
||||
* ``event: Event``: Event对象
|
||||
* ``message: Union[MessageChain, MessageSegment, str]``: 要发送的消息
|
||||
* ``at_sender: bool``: 是否 @ 事件主题
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
if isinstance(message, MessageSegment):
|
||||
message = MessageChain(message)
|
||||
@ -269,6 +266,7 @@ class MiraiBot(BaseBot):
|
||||
else:
|
||||
raise ValueError(f'Unsupported event type {event!r}.')
|
||||
|
||||
@argument_validation
|
||||
async def send_friend_message(self, target: int,
|
||||
message_chain: MessageChain):
|
||||
"""
|
||||
@ -280,10 +278,6 @@ class MiraiBot(BaseBot):
|
||||
|
||||
* ``target: int``: 发送消息目标好友的 QQ 号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('sendFriendMessage',
|
||||
params={
|
||||
@ -291,6 +285,7 @@ class MiraiBot(BaseBot):
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def send_temp_message(self, qq: int, group: int,
|
||||
message_chain: MessageChain):
|
||||
"""
|
||||
@ -303,10 +298,6 @@ class MiraiBot(BaseBot):
|
||||
* ``qq: int``: 临时会话对象 QQ 号
|
||||
* ``group: int``: 临时会话群号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('sendTempMessage',
|
||||
params={
|
||||
@ -315,6 +306,7 @@ class MiraiBot(BaseBot):
|
||||
'messageChain': message_chain.export()
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def send_group_message(self,
|
||||
group: int,
|
||||
message_chain: MessageChain,
|
||||
@ -329,10 +321,6 @@ class MiraiBot(BaseBot):
|
||||
* ``group: int``: 发送消息目标群的群号
|
||||
* ``message_chain: MessageChain``: 消息链,是一个消息对象构成的数组
|
||||
* ``quote: Optional[int]``: 引用一条消息的 message_id 进行回复
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('sendGroupMessage',
|
||||
params={
|
||||
@ -341,6 +329,7 @@ class MiraiBot(BaseBot):
|
||||
'quote': quote
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def recall(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -350,13 +339,10 @@ class MiraiBot(BaseBot):
|
||||
:参数:
|
||||
|
||||
* ``target: int``: 需要撤回的消息的message_id
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.post('recall', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def send_image_message(self, target: int, qq: int, group: int,
|
||||
urls: List[str]) -> List[str]:
|
||||
"""
|
||||
@ -384,8 +370,9 @@ class MiraiBot(BaseBot):
|
||||
'qq': qq,
|
||||
'group': group,
|
||||
'urls': urls
|
||||
}) # type: ignore
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def upload_image(self, type: str, img: BytesIO):
|
||||
"""
|
||||
:说明:
|
||||
@ -396,15 +383,14 @@ class MiraiBot(BaseBot):
|
||||
|
||||
* ``type: str``: "friend" 或 "group" 或 "temp"
|
||||
* ``img: BytesIO``: 图片的BytesIO对象
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.upload('uploadImage',
|
||||
type=type,
|
||||
file=('img', img))
|
||||
params={
|
||||
'type': type,
|
||||
'img': img
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def upload_voice(self, type: str, voice: BytesIO):
|
||||
"""
|
||||
:说明:
|
||||
@ -415,15 +401,14 @@ class MiraiBot(BaseBot):
|
||||
|
||||
* ``type: str``: 当前仅支持 "group"
|
||||
* ``voice: BytesIO``: 语音的BytesIO对象
|
||||
|
||||
:返回:
|
||||
|
||||
- ``Any``: API 调用返回数据
|
||||
"""
|
||||
return await self.api.upload('uploadVoice',
|
||||
type=type,
|
||||
file=('voice', voice))
|
||||
params={
|
||||
'type': type,
|
||||
'voice': voice
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def fetch_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -437,6 +422,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('fetchMessage', params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def fetch_latest_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -451,6 +437,7 @@ class MiraiBot(BaseBot):
|
||||
return await self.api.request('fetchLatestMessage',
|
||||
params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def peek_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -464,6 +451,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('peekMessage', params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def peek_latest_message(self, count: int = 10):
|
||||
"""
|
||||
:说明:
|
||||
@ -478,6 +466,7 @@ class MiraiBot(BaseBot):
|
||||
return await self.api.request('peekLatestMessage',
|
||||
params={'count': count})
|
||||
|
||||
@argument_validation
|
||||
async def messsage_from_id(self, id: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -491,6 +480,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('messageFromId', params={'id': id})
|
||||
|
||||
@argument_validation
|
||||
async def count_message(self):
|
||||
"""
|
||||
:说明:
|
||||
@ -499,6 +489,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('countMessage')
|
||||
|
||||
@argument_validation
|
||||
async def friend_list(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -509,8 +500,9 @@ class MiraiBot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的好友列表数据
|
||||
"""
|
||||
return await self.api.request('friendList') # type: ignore
|
||||
return await self.api.request('friendList')
|
||||
|
||||
@argument_validation
|
||||
async def group_list(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -521,8 +513,9 @@ class MiraiBot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的群列表数据
|
||||
"""
|
||||
return await self.api.request('groupList') # type: ignore
|
||||
return await self.api.request('groupList')
|
||||
|
||||
@argument_validation
|
||||
async def member_list(self, target: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
:说明:
|
||||
@ -537,9 +530,9 @@ class MiraiBot(BaseBot):
|
||||
|
||||
- ``List[Dict[str, Any]]``: 返回的群成员列表数据
|
||||
"""
|
||||
return await self.api.request('memberList',
|
||||
params={'target': target}) # type: ignore
|
||||
return await self.api.request('memberList', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def mute(self, target: int, member_id: int, time: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -559,6 +552,7 @@ class MiraiBot(BaseBot):
|
||||
'time': time
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def unmute(self, target: int, member_id: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -576,6 +570,7 @@ class MiraiBot(BaseBot):
|
||||
'memberId': member_id
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def kick(self, target: int, member_id: int, msg: str):
|
||||
"""
|
||||
:说明:
|
||||
@ -595,6 +590,7 @@ class MiraiBot(BaseBot):
|
||||
'msg': msg
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def quit(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -607,6 +603,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.post('quit', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def mute_all(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -619,6 +616,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.post('muteAll', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def unmute_all(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -631,6 +629,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.post('unmuteAll', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def group_config(self, target: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -656,6 +655,7 @@ class MiraiBot(BaseBot):
|
||||
"""
|
||||
return await self.api.request('groupConfig', params={'target': target})
|
||||
|
||||
@argument_validation
|
||||
async def modify_group_config(self, target: int, config: Dict[str, Any]):
|
||||
"""
|
||||
:说明:
|
||||
@ -673,6 +673,7 @@ class MiraiBot(BaseBot):
|
||||
'config': config
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def member_info(self, target: int, member_id: int):
|
||||
"""
|
||||
:说明:
|
||||
@ -699,6 +700,7 @@ class MiraiBot(BaseBot):
|
||||
'memberId': member_id
|
||||
})
|
||||
|
||||
@argument_validation
|
||||
async def modify_member_info(self, target: int, member_id: int,
|
||||
info: Dict[str, Any]):
|
||||
"""
|
||||
|
89
nonebot/adapters/mirai/utils.py
Normal file
89
nonebot/adapters/mirai/utils.py
Normal file
@ -0,0 +1,89 @@
|
||||
from functools import wraps
|
||||
from typing import Callable, Coroutine, TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import ValidationError, validate_arguments, Extra
|
||||
|
||||
import nonebot.exception as exception
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import escape_tag
|
||||
|
||||
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
|
||||
_AnyCallable = TypeVar("_AnyCallable", bound=Callable)
|
||||
|
||||
|
||||
class ActionFailed(exception.ActionFailed):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
API 请求成功返回数据,但 API 操作失败。
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__('mirai')
|
||||
self.data = kwargs.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(%s)' % ', '.join(
|
||||
map(lambda m: '%s=%r' % m, self.data.items()))
|
||||
|
||||
|
||||
class InvalidArgument(exception.AdapterException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
调用API的参数出错
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__('mirai')
|
||||
|
||||
|
||||
def catch_network_error(function: _AsyncCallable) -> _AsyncCallable:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
捕捉函数抛出的httpx网络异常并释放``NetworkError``异常
|
||||
处理返回数据, 在code不为0时释放``ActionFailed``异常
|
||||
|
||||
\:\:\: warning
|
||||
此装饰器只支持使用了httpx的异步函数
|
||||
\:\:\:
|
||||
"""
|
||||
|
||||
@wraps(function)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
data = await function(*args, **kwargs)
|
||||
except httpx.HTTPError:
|
||||
raise exception.NetworkError('mirai')
|
||||
logger.opt(colors=True).debug('<b>Mirai API returned data:</b> '
|
||||
f'<y>{escape_tag(str(data))}</y>')
|
||||
if isinstance(data, dict):
|
||||
if data.get('code', 0) != 0:
|
||||
raise ActionFailed(**data)
|
||||
return data
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
def argument_validation(function: _AnyCallable) -> _AnyCallable:
|
||||
"""
|
||||
:说明:
|
||||
|
||||
通过函数签名中的类型注解来对传入参数进行运行时校验
|
||||
会在参数出错时释放``InvalidArgument``异常
|
||||
"""
|
||||
function = validate_arguments(config={
|
||||
'arbitrary_types_allowed': True,
|
||||
'extra': Extra.forbid
|
||||
})(function)
|
||||
|
||||
@wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return function(*args, **kwargs)
|
||||
except ValidationError:
|
||||
raise InvalidArgument
|
||||
|
||||
return wrapper # type: ignore
|
Loading…
Reference in New Issue
Block a user