mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-31 15:41:34 +08:00
⚡ add retry for mirai adapter when websocket connection down
This commit is contained in:
parent
e2f837055e
commit
8b3eb4e076
@ -9,6 +9,7 @@ import websockets
|
|||||||
|
|
||||||
from nonebot.adapters import Bot as BaseBot
|
from nonebot.adapters import Bot as BaseBot
|
||||||
from nonebot.adapters import Event as BaseEvent
|
from nonebot.adapters import Event as BaseEvent
|
||||||
|
from nonebot.config import Config
|
||||||
from nonebot.drivers import Driver
|
from nonebot.drivers import Driver
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
from nonebot.exception import RequestDenied
|
from nonebot.exception import RequestDenied
|
||||||
@ -16,7 +17,7 @@ from nonebot.log import logger
|
|||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config as MiraiConfig
|
||||||
from .event import Event
|
from .event import Event
|
||||||
|
|
||||||
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
WebsocketHandlerFunction = Callable[[Dict[str, Any]], Coroutine[Any, Any, None]]
|
||||||
@ -24,6 +25,32 @@ WebsocketHandler_T = TypeVar('WebsocketHandler_T',
|
|||||||
bound=WebsocketHandlerFunction)
|
bound=WebsocketHandlerFunction)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ws_authorization(client: httpx.AsyncClient, *, auth_key: str,
|
||||||
|
qq: int) -> str:
|
||||||
|
|
||||||
|
async def request(method: str, *, path: str, **kwargs) -> Dict[str, Any]:
|
||||||
|
response = await client.request(method, path, **kwargs)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
about = await request('GET', path='/about')
|
||||||
|
logger.opt(colors=True).debug('Mirai API HTTP backend version: '
|
||||||
|
f'<g><b>{about["data"]["version"]}</b></g>')
|
||||||
|
|
||||||
|
status = await request('POST', path='/auth', json={'authKey': auth_key})
|
||||||
|
assert status['code'] == 0
|
||||||
|
session_key = status['session']
|
||||||
|
|
||||||
|
verify = await request('POST',
|
||||||
|
path='/verify',
|
||||||
|
json={
|
||||||
|
'sessionKey': session_key,
|
||||||
|
'qq': qq
|
||||||
|
})
|
||||||
|
assert verify['code'] == 0, verify['msg']
|
||||||
|
return session_key
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -44,6 +71,11 @@ class WebSocket(BaseWebSocket):
|
|||||||
def websocket(self) -> websockets.WebSocketClientProtocol:
|
def websocket(self) -> websockets.WebSocketClientProtocol:
|
||||||
return self._websocket
|
return self._websocket
|
||||||
|
|
||||||
|
@property
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
def closed(self) -> bool:
|
||||||
|
return self.websocket.closed
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: Dict[str, Any]):
|
async def send(self, data: Dict[str, Any]):
|
||||||
return await self.websocket.send(json.dumps(data))
|
return await self.websocket.send(json.dumps(data))
|
||||||
@ -54,23 +86,26 @@ class WebSocket(BaseWebSocket):
|
|||||||
return json.loads(received)
|
return json.loads(received)
|
||||||
|
|
||||||
async def _dispatcher(self):
|
async def _dispatcher(self):
|
||||||
while not self.websocket.closed:
|
while not self.closed:
|
||||||
try:
|
try:
|
||||||
data = await self.receive()
|
data = await self.receive()
|
||||||
except websockets.ConnectionClosedOK:
|
except websockets.ConnectionClosedOK:
|
||||||
logger.debug(f'Websocket connection {self.websocket} closed')
|
logger.debug(f'Websocket connection {self.websocket} closed')
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except websockets.ConnectionClosedError:
|
||||||
|
logger.exception(f'Websocket connection {self.websocket} '
|
||||||
|
'connection closed abnormally:')
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
logger.exception(f'Websocket client listened {self.websocket} '
|
logger.exception(f'Websocket client listened {self.websocket} '
|
||||||
f'failed to receive data: {e}')
|
f'failed to decode data: {e}')
|
||||||
continue
|
continue
|
||||||
asyncio.ensure_future(
|
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
|
||||||
asyncio.gather(*map(lambda f: f(data), self.event_handlers),
|
return_exceptions=True)
|
||||||
return_exceptions=True))
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def accept(self):
|
async def accept(self):
|
||||||
asyncio.ensure_future(self._dispatcher())
|
asyncio.create_task(self._dispatcher())
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def close(self):
|
async def close(self):
|
||||||
@ -92,6 +127,10 @@ class MiraiBot(BaseBot):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
return "mirai"
|
return "mirai"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def alive(self) -> bool:
|
||||||
|
return not self.websocket.closed
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
async def check_permission(cls, driver: "Driver", connection_type: str,
|
async def check_permission(cls, driver: "Driver", connection_type: str,
|
||||||
@ -103,33 +142,26 @@ class MiraiBot(BaseBot):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
def register(cls, driver: "Driver", config: "Config", qq: int):
|
def register(cls, driver: "Driver", config: "Config", qq: int):
|
||||||
config = Config.parse_obj(config.dict())
|
cls.mirai_config = MiraiConfig(**config.dict())
|
||||||
assert config.auth_key and config.host and config.port, f'Current config {config!r} is invalid'
|
cls.active = True
|
||||||
|
assert cls.mirai_config.auth_key is not None
|
||||||
|
assert cls.mirai_config.host is not None
|
||||||
|
assert cls.mirai_config.port is not None
|
||||||
|
super().register(driver, config)
|
||||||
|
|
||||||
super().register(driver, config) # type: ignore
|
async def _bot_connection():
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def _startup():
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
base_url=f'http://{config.host}:{config.port}') as client:
|
base_url=
|
||||||
response = await client.get('/about')
|
f'http://{cls.mirai_config.host}:{cls.mirai_config.port}'
|
||||||
info = response.json()
|
) as client:
|
||||||
logger.debug(f'Mirai API returned info: {info}')
|
session_key = await _ws_authorization(
|
||||||
response = await client.post('/auth',
|
client,
|
||||||
json={'authKey': config.auth_key})
|
auth_key=cls.mirai_config.auth_key, # type: ignore
|
||||||
status = response.json()
|
qq=qq) # type: ignore
|
||||||
assert status['code'] == 0
|
|
||||||
session_key = status['session']
|
|
||||||
response = await client.post('/verify',
|
|
||||||
json={
|
|
||||||
'sessionKey': session_key,
|
|
||||||
'qq': qq
|
|
||||||
})
|
|
||||||
assert response.json()['code'] == 0
|
|
||||||
|
|
||||||
websocket = await WebSocket.new(
|
websocket = await WebSocket.new(
|
||||||
host=config.host, # type: ignore
|
host=cls.mirai_config.host, # type: ignore
|
||||||
port=config.port, # type: ignore
|
port=cls.mirai_config.port, # type: ignore
|
||||||
session_key=session_key)
|
session_key=session_key)
|
||||||
bot = cls(connection_type='forward_ws',
|
bot = cls(connection_type='forward_ws',
|
||||||
self_id=str(qq),
|
self_id=str(qq),
|
||||||
@ -138,8 +170,32 @@ class MiraiBot(BaseBot):
|
|||||||
driver._clients[str(qq)] = bot
|
driver._clients[str(qq)] = bot
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
|
async def _connection_ensure():
|
||||||
|
if str(qq) not in driver._clients:
|
||||||
|
await _bot_connection()
|
||||||
|
elif not driver._clients[str(qq)].alive:
|
||||||
|
driver._clients.pop(str(qq), None)
|
||||||
|
await _bot_connection()
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def _startup():
|
||||||
|
|
||||||
|
async def _checker():
|
||||||
|
while cls.active:
|
||||||
|
try:
|
||||||
|
await _connection_ensure()
|
||||||
|
except Exception as e:
|
||||||
|
logger.opt(colors=True).warning(
|
||||||
|
'Failed to create mirai connection to '
|
||||||
|
f'<y>{qq}</y>, reason: <r>{e}</r>. '
|
||||||
|
'Will retry after 3 seconds')
|
||||||
|
await asyncio.sleep(3)
|
||||||
|
|
||||||
|
asyncio.create_task(_checker())
|
||||||
|
|
||||||
@driver.on_shutdown
|
@driver.on_shutdown
|
||||||
async def _shutdown():
|
async def _shutdown():
|
||||||
|
cls.active = False
|
||||||
bot = driver._clients.pop(str(qq), None)
|
bot = driver._clients.pop(str(qq), None)
|
||||||
if bot is None:
|
if bot is None:
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user