diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py
index 6190bedb..70166eff 100644
--- a/nonebot/adapters/mirai/bot.py
+++ b/nonebot/adapters/mirai/bot.py
@@ -9,6 +9,7 @@ import websockets
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
+from nonebot.config import Config
from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied
@@ -16,7 +17,7 @@ from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.typing import overrides
-from .config import Config
+from .config import Config as MiraiConfig
from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
@@ -24,6 +25,32 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction)
+async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
+ qq: int) -> str:
+
+ async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
+ response = await client.request(method, path, **kwargs)
+ response.raise_for_status()
+ return response.json()
+
+ about = await request('GET', path='/about')
+ logger.opt(colors=True).debug('Mirai API HTTP backend version: '
+ f'{about["data"]["version"]}')
+
+ status = await request('POST', path='/auth', json={'authKey': auth_key})
+ assert status['code'] == 0
+ session_key = status['session']
+
+ verify = await request('POST',
+ path='/verify',
+ json={
+ 'sessionKey': session_key,
+ 'qq': qq
+ })
+ assert verify['code'] == 0, verify['msg']
+ return session_key
+
+
class WebSocket(BaseWebSocket):
@classmethod
@@ -44,6 +71,11 @@ class WebSocket(BaseWebSocket):
def websocket(self) -> websockets.WebSocketClientProtocol:
return self._websocket
+ @property
+ @overrides(BaseWebSocket)
+ def closed(self) -> bool:
+ return self.websocket.closed
+
@overrides(BaseWebSocket)
async def send(self, data: Dict[str, Any]):
return await self.websocket.send(json.dumps(data))
@@ -54,23 +86,26 @@ class WebSocket(BaseWebSocket):
return json.loads(received)
async def _dispatcher(self):
- while not self.websocket.closed:
+ while not self.closed:
try:
data = await self.receive()
except websockets.ConnectionClosedOK:
logger.debug(f'Websocket connection {self.websocket} closed')
break
- except Exception as e:
+ except websockets.ConnectionClosedError:
+ logger.exception(f'Websocket connection {self.websocket} '
+ 'connection closed abnormally:')
+ break
+ except json.JSONDecodeError as e:
logger.exception(f'Websocket client listened {self.websocket} '
- f'failed to receive data: {e}')
+ f'failed to decode data: {e}')
continue
- asyncio.ensure_future(
- asyncio.gather(*map(lambda f: f(data), self.event_handlers),
- return_exceptions=True))
+ asyncio.gather(*map(lambda f: f(data), self.event_handlers),
+ return_exceptions=True)
@overrides(BaseWebSocket)
async def accept(self):
- asyncio.ensure_future(self._dispatcher())
+ asyncio.create_task(self._dispatcher())
@overrides(BaseWebSocket)
async def close(self):
@@ -92,6 +127,10 @@ class MiraiBot(BaseBot):
def type(self) -> str:
return "mirai"
+ @property
+ def alive(self) -> bool:
+ return not self.websocket.closed
+
@classmethod
@overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str,
@@ -103,33 +142,26 @@ class MiraiBot(BaseBot):
@classmethod
@overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config", qq: int):
- config = Config.parse_obj(config.dict())
- assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid'
+ cls.mirai_config = MiraiConfig(**config.dict())
+ cls.active = True
+ assert cls.mirai_config.auth_key is not None
+ assert cls.mirai_config.host is not None
+ assert cls.mirai_config.port is not None
+ super().register(driver, config)
- super().register(driver, config) # type: ignore
-
- @driver.on_startup
- async def _startup():
+ async def _bot_connection():
async with httpx.AsyncClient(
- base_url=f'http://{config.host}:{config.port}') as client:
- response = await client.get('/about')
- info = response.json()
- logger.debug(f'Mirai API returned info: {info}')
- response = await client.post('/auth',
- json={'authKey': config.auth_key})
- status = response.json()
- assert status['code'] == 0
- session_key = status['session']
- response = await client.post('/verify',
- json={
- 'sessionKey': session_key,
- 'qq': qq
- })
- assert response.json()['code'] == 0
+ base_url=
+ f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
+ ) as client:
+ session_key = await _ws_authorization(
+ client,
+ auth_key=cls.mirai_config.auth_key, # type: ignore
+ qq=qq) # type: ignore
websocket = await WebSocket.new(
- host=config.host, # type: ignore
- port=config.port, # type: ignore
+ host=cls.mirai_config.host, # type: ignore
+ port=cls.mirai_config.port, # type: ignore
session_key=session_key)
bot = cls(connection_type='forward_ws',
self_id=str(qq),
@@ -138,8 +170,32 @@ class MiraiBot(BaseBot):
driver._clients[str(qq)] = bot
await websocket.accept()
+ async def _connection_ensure():
+ if str(qq) not in driver._clients:
+ await _bot_connection()
+ elif not driver._clients[str(qq)].alive:
+ driver._clients.pop(str(qq), None)
+ await _bot_connection()
+
+ @driver.on_startup
+ async def _startup():
+
+ async def _checker():
+ while cls.active:
+ try:
+ await _connection_ensure()
+ except Exception as e:
+ logger.opt(colors=True).warning(
+ 'Failed to create mirai connection to '
+ f'{qq}, reason: {e}. '
+ 'Will retry after 3 seconds')
+ await asyncio.sleep(3)
+
+ asyncio.create_task(_checker())
+
@driver.on_shutdown
async def _shutdown():
+ cls.active = False
bot = driver._clients.pop(str(qq), None)
if bot is None:
return