diff --git a/nonebot/adapters/mirai/bot.py b/nonebot/adapters/mirai/bot.py index 6190bedb..70166eff 100644 --- a/nonebot/adapters/mirai/bot.py +++ b/nonebot/adapters/mirai/bot.py @@ -9,6 +9,7 @@ import websockets from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Event as BaseEvent +from nonebot.config import Config from nonebot.drivers import Driver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.exception import RequestDenied @@ -16,7 +17,7 @@ from nonebot.log import logger from nonebot.message import handle_event from nonebot.typing import overrides -from .config import Config +from .config import Config as MiraiConfig from .event import Event WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] @@ -24,6 +25,32 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T', bound=WebsocketHandlerFunction) +async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str, + qq: int) -> str: + + async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]: + response = await client.request(method, path, **kwargs) + response.raise_for_status() + return response.json() + + about = await request('GET', path='/about') + logger.opt(colors=True).debug('Mirai API HTTP backend version: ' + f'{about["data"]["version"]}') + + status = await request('POST', path='/auth', json={'authKey': auth_key}) + assert status['code'] == 0 + session_key = status['session'] + + verify = await request('POST', + path='/verify', + json={ + 'sessionKey': session_key, + 'qq': qq + }) + assert verify['code'] == 0, verify['msg'] + return session_key + + class WebSocket(BaseWebSocket): @classmethod @@ -44,6 +71,11 @@ class WebSocket(BaseWebSocket): def websocket(self) -> websockets.WebSocketClientProtocol: return self._websocket + @property + @overrides(BaseWebSocket) + def closed(self) -> bool: + return self.websocket.closed + @overrides(BaseWebSocket) async def send(self, data: Dict[str, Any]): return await self.websocket.send(json.dumps(data)) @@ -54,23 +86,26 @@ class WebSocket(BaseWebSocket): return json.loads(received) async def _dispatcher(self): - while not self.websocket.closed: + while not self.closed: try: data = await self.receive() except websockets.ConnectionClosedOK: logger.debug(f'Websocket connection {self.websocket} closed') break - except Exception as e: + except websockets.ConnectionClosedError: + logger.exception(f'Websocket connection {self.websocket} ' + 'connection closed abnormally:') + break + except json.JSONDecodeError as e: logger.exception(f'Websocket client listened {self.websocket} ' - f'failed to receive data: {e}') + f'failed to decode data: {e}') continue - asyncio.ensure_future( - asyncio.gather(*map(lambda f: f(data), self.event_handlers), - return_exceptions=True)) + asyncio.gather(*map(lambda f: f(data), self.event_handlers), + return_exceptions=True) @overrides(BaseWebSocket) async def accept(self): - asyncio.ensure_future(self._dispatcher()) + asyncio.create_task(self._dispatcher()) @overrides(BaseWebSocket) async def close(self): @@ -92,6 +127,10 @@ class MiraiBot(BaseBot): def type(self) -> str: return "mirai" + @property + def alive(self) -> bool: + return not self.websocket.closed + @classmethod @overrides(BaseBot) async def check_permission(cls, driver: "Driver", connection_type: str, @@ -103,33 +142,26 @@ class MiraiBot(BaseBot): @classmethod @overrides(BaseBot) def register(cls, driver: "Driver", config: "Config", qq: int): - config = Config.parse_obj(config.dict()) - assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid' + cls.mirai_config = MiraiConfig(**config.dict()) + cls.active = True + assert cls.mirai_config.auth_key is not None + assert cls.mirai_config.host is not None + assert cls.mirai_config.port is not None + super().register(driver, config) - super().register(driver, config) # type: ignore - - @driver.on_startup - async def _startup(): + async def _bot_connection(): async with httpx.AsyncClient( - base_url=f'http://{config.host}:{config.port}') as client: - response = await client.get('/about') - info = response.json() - logger.debug(f'Mirai API returned info: {info}') - response = await client.post('/auth', - json={'authKey': config.auth_key}) - status = response.json() - assert status['code'] == 0 - session_key = status['session'] - response = await client.post('/verify', - json={ - 'sessionKey': session_key, - 'qq': qq - }) - assert response.json()['code'] == 0 + base_url= + f'http://{cls.mirai_config.host}:{cls.mirai_config.port}' + ) as client: + session_key = await _ws_authorization( + client, + auth_key=cls.mirai_config.auth_key, # type: ignore + qq=qq) # type: ignore websocket = await WebSocket.new( - host=config.host, # type: ignore - port=config.port, # type: ignore + host=cls.mirai_config.host, # type: ignore + port=cls.mirai_config.port, # type: ignore session_key=session_key) bot = cls(connection_type='forward_ws', self_id=str(qq), @@ -138,8 +170,32 @@ class MiraiBot(BaseBot): driver._clients[str(qq)] = bot await websocket.accept() + async def _connection_ensure(): + if str(qq) not in driver._clients: + await _bot_connection() + elif not driver._clients[str(qq)].alive: + driver._clients.pop(str(qq), None) + await _bot_connection() + + @driver.on_startup + async def _startup(): + + async def _checker(): + while cls.active: + try: + await _connection_ensure() + except Exception as e: + logger.opt(colors=True).warning( + 'Failed to create mirai connection to ' + f'{qq}, reason: {e}. ' + 'Will retry after 3 seconds') + await asyncio.sleep(3) + + asyncio.create_task(_checker()) + @driver.on_shutdown async def _shutdown(): + cls.active = False bot = driver._clients.pop(str(qq), None) if bot is None: return