add websocket close exception

This commit is contained in:
yanyongyu 2021-12-26 14:20:09 +08:00
parent e64f399370
commit 8093c5d154
3 changed files with 58 additions and 2 deletions

View File

@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
@ -109,13 +110,25 @@ class WebSocket(BaseWebSocket):
await self.websocket.close(code=code)
await self.session.close()
async def _receive(self) -> aiohttp.WSMessage:
msg = await self.websocket.receive()
if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
raise WebSocketClosed(self.websocket.close_code or 1006)
return msg
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.receive_str()
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
return msg.data
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
return msg.data
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:

View File

@ -1,15 +1,18 @@
import logging
from functools import wraps
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, combine_driver
try:
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ImportError:
raise ImportError(
@ -20,6 +23,20 @@ logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler())
def catch_closed(func):
@wraps(func)
async def decorator(*args, **kwargs):
try:
return await func(*args, **kwargs)
except ConnectionClosed as e:
if e.rcvd_then_sent:
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
else:
raise WebSocketClosed(e.sent.code, e.sent.reason)
return decorator
class Mixin(ForwardMixin):
@property
@overrides(ForwardMixin)
@ -62,6 +79,7 @@ class WebSocket(BaseWebSocket):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
msg = await self.websocket.recv()
if isinstance(msg, bytes):
@ -69,6 +87,7 @@ class WebSocket(BaseWebSocket):
return msg
@overrides(BaseWebSocket)
@catch_closed
async def receive_bytes(self) -> bytes:
msg = await self.websocket.recv()
if isinstance(msg, str):

View File

@ -230,3 +230,27 @@ class ActionFailed(AdapterException):
"""
pass
# Driver Exceptions
class DriverException(NoneBotException):
"""
:说明:
``Driver`` 抛出的异常基类
"""
class WebSocketClosed(DriverException):
"""
:说明:
WebSocket 连接已关闭
"""
def __init__(self, code: int, reason: Optional[str] = None):
self.code = code
self.reason = reason
def __repr__(self) -> str:
return f"<WebSocketClosed code={self.code} reason={self.reason}>"