add retry for mirai adapter when websocket connection down

This commit is contained in:
Mix 2021-01-30 13:36:31 +08:00
parent e2f837055e
commit 8b3eb4e076

View File

@ -9,6 +9,7 @@ import websockets
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent from nonebot.adapters import Event as BaseEvent
from nonebot.config import Config
from nonebot.drivers import Driver from nonebot.drivers import Driver
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.exception import RequestDenied from nonebot.exception import RequestDenied
@ -16,7 +17,7 @@ from nonebot.log import logger
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.typing import overrides from nonebot.typing import overrides
from .config import Config from .config import Config as MiraiConfig
from .event import Event from .event import Event
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]] WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
@ -24,6 +25,32 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T',
bound=WebsocketHandlerFunction) bound=WebsocketHandlerFunction)
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
qq: int) -> str:
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
response = await client.request(method, path, **kwargs)
response.raise_for_status()
return response.json()
about = await request('GET', path='/about')
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
f'<g><b>{about["data"]["version"]}</b></g>')
status = await request('POST', path='/auth', json={'authKey': auth_key})
assert status['code'] == 0
session_key = status['session']
verify = await request('POST',
path='/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert verify['code'] == 0, verify['msg']
return session_key
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
@classmethod @classmethod
@ -44,6 +71,11 @@ class WebSocket(BaseWebSocket):
def websocket(self) -> websockets.WebSocketClientProtocol: def websocket(self) -> websockets.WebSocketClientProtocol:
return self._websocket return self._websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return self.websocket.closed
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: Dict[str, Any]): async def send(self, data: Dict[str, Any]):
return await self.websocket.send(json.dumps(data)) return await self.websocket.send(json.dumps(data))
@ -54,23 +86,26 @@ class WebSocket(BaseWebSocket):
return json.loads(received) return json.loads(received)
async def _dispatcher(self): async def _dispatcher(self):
while not self.websocket.closed: while not self.closed:
try: try:
data = await self.receive() data = await self.receive()
except websockets.ConnectionClosedOK: except websockets.ConnectionClosedOK:
logger.debug(f'Websocket connection {self.websocket} closed') logger.debug(f'Websocket connection {self.websocket} closed')
break break
except Exception as e: except websockets.ConnectionClosedError:
logger.exception(f'Websocket connection {self.websocket} '
'connection closed abnormally:')
break
except json.JSONDecodeError as e:
logger.exception(f'Websocket client listened {self.websocket} ' logger.exception(f'Websocket client listened {self.websocket} '
f'failed to receive data: {e}') f'failed to decode data: {e}')
continue continue
asyncio.ensure_future( asyncio.gather(*map(lambda f: f(data), self.event_handlers),
asyncio.gather(*map(lambda f: f(data), self.event_handlers), return_exceptions=True)
return_exceptions=True))
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
asyncio.ensure_future(self._dispatcher()) asyncio.create_task(self._dispatcher())
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self): async def close(self):
@ -92,6 +127,10 @@ class MiraiBot(BaseBot):
def type(self) -> str: def type(self) -> str:
return "mirai" return "mirai"
@property
def alive(self) -> bool:
return not self.websocket.closed
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
async def check_permission(cls, driver: "Driver", connection_type: str, async def check_permission(cls, driver: "Driver", connection_type: str,
@ -103,33 +142,26 @@ class MiraiBot(BaseBot):
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)
def register(cls, driver: "Driver", config: "Config", qq: int): def register(cls, driver: "Driver", config: "Config", qq: int):
config = Config.parse_obj(config.dict()) cls.mirai_config = MiraiConfig(**config.dict())
assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid' cls.active = True
assert cls.mirai_config.auth_key is not None
assert cls.mirai_config.host is not None
assert cls.mirai_config.port is not None
super().register(driver, config)
super().register(driver, config) # type: ignore async def _bot_connection():
@driver.on_startup
async def _startup():
async with httpx.AsyncClient( async with httpx.AsyncClient(
base_url=f'http://{config.host}:{config.port}') as client: base_url=
response = await client.get('/about') f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
info = response.json() ) as client:
logger.debug(f'Mirai API returned info: {info}') session_key = await _ws_authorization(
response = await client.post('/auth', client,
json={'authKey': config.auth_key}) auth_key=cls.mirai_config.auth_key, # type: ignore
status = response.json() qq=qq) # type: ignore
assert status['code'] == 0
session_key = status['session']
response = await client.post('/verify',
json={
'sessionKey': session_key,
'qq': qq
})
assert response.json()['code'] == 0
websocket = await WebSocket.new( websocket = await WebSocket.new(
host=config.host, # type: ignore host=cls.mirai_config.host, # type: ignore
port=config.port, # type: ignore port=cls.mirai_config.port, # type: ignore
session_key=session_key) session_key=session_key)
bot = cls(connection_type='forward_ws', bot = cls(connection_type='forward_ws',
self_id=str(qq), self_id=str(qq),
@ -138,8 +170,32 @@ class MiraiBot(BaseBot):
driver._clients[str(qq)] = bot driver._clients[str(qq)] = bot
await websocket.accept() await websocket.accept()
async def _connection_ensure():
if str(qq) not in driver._clients:
await _bot_connection()
elif not driver._clients[str(qq)].alive:
driver._clients.pop(str(qq), None)
await _bot_connection()
@driver.on_startup
async def _startup():
async def _checker():
while cls.active:
try:
await _connection_ensure()
except Exception as e:
logger.opt(colors=True).warning(
'Failed to create mirai connection to '
f'<y>{qq}</y>, reason: <r>{e}</r>. '
'Will retry after 3 seconds')
await asyncio.sleep(3)
asyncio.create_task(_checker())
@driver.on_shutdown @driver.on_shutdown
async def _shutdown(): async def _shutdown():
cls.active = False
bot = driver._clients.pop(str(qq), None) bot = driver._clients.pop(str(qq), None)
if bot is None: if bot is None:
return return