From 8093c5d15489ea2d74681c4a75eabe48ef3dd3b7 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Sun, 26 Dec 2021 14:20:09 +0800 Subject: [PATCH] :wheelchair: add websocket close exception --- nonebot/drivers/aiohttp.py | 17 +++++++++++++++-- nonebot/drivers/websockets.py | 19 +++++++++++++++++++ nonebot/exception.py | 24 ++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 080744cf..c2e50481 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -10,6 +10,7 @@ from contextlib import asynccontextmanager from nonebot.typing import overrides from nonebot.drivers import Request, Response +from nonebot.exception import WebSocketClosed from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver @@ -109,13 +110,25 @@ class WebSocket(BaseWebSocket): await self.websocket.close(code=code) await self.session.close() + async def _receive(self) -> aiohttp.WSMessage: + msg = await self.websocket.receive() + if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): + raise WebSocketClosed(self.websocket.close_code or 1006) + return msg + @overrides(BaseWebSocket) async def receive(self) -> str: - return await self.websocket.receive_str() + msg = await self._receive() + if msg.type != aiohttp.WSMsgType.TEXT: + raise TypeError(f"WebSocket received unexpected frame type: {msg.type}") + return msg.data @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: - return await self.websocket.receive_bytes() + msg = await self._receive() + if msg.type != aiohttp.WSMsgType.TEXT: + raise TypeError(f"WebSocket received unexpected frame type: {msg.type}") + return msg.data @overrides(BaseWebSocket) async def send(self, data: str) -> None: diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 0cc4827b..b3f2d99c 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -1,15 +1,18 @@ import logging +from functools import wraps from typing import AsyncGenerator from contextlib import asynccontextmanager from nonebot.typing import overrides from nonebot.log import LoguruHandler from nonebot.drivers import Request, Response +from nonebot.exception import WebSocketClosed from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import ForwardMixin, combine_driver try: + from websockets.exceptions import ConnectionClosed from websockets.legacy.client import Connect, WebSocketClientProtocol except ImportError: raise ImportError( @@ -20,6 +23,20 @@ logger = logging.Logger("websockets.client", "INFO") logger.addHandler(LoguruHandler()) +def catch_closed(func): + @wraps(func) + async def decorator(*args, **kwargs): + try: + return await func(*args, **kwargs) + except ConnectionClosed as e: + if e.rcvd_then_sent: + raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) + else: + raise WebSocketClosed(e.sent.code, e.sent.reason) + + return decorator + + class Mixin(ForwardMixin): @property @overrides(ForwardMixin) @@ -62,6 +79,7 @@ class WebSocket(BaseWebSocket): await self.websocket.close(code, reason) @overrides(BaseWebSocket) + @catch_closed async def receive(self) -> str: msg = await self.websocket.recv() if isinstance(msg, bytes): @@ -69,6 +87,7 @@ class WebSocket(BaseWebSocket): return msg @overrides(BaseWebSocket) + @catch_closed async def receive_bytes(self) -> bytes: msg = await self.websocket.recv() if isinstance(msg, str): diff --git a/nonebot/exception.py b/nonebot/exception.py index 44ac6a7e..ce2fb9e6 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -230,3 +230,27 @@ class ActionFailed(AdapterException): """ pass + + +# Driver Exceptions +class DriverException(NoneBotException): + """ + :说明: + + ``Driver`` 抛出的异常基类 + """ + + +class WebSocketClosed(DriverException): + """ + :说明: + + WebSocket 连接已关闭 + """ + + def __init__(self, code: int, reason: Optional[str] = None): + self.code = code + self.reason = reason + + def __repr__(self) -> str: + return f""