mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
♿ add websocket close exception
This commit is contained in:
parent
e64f399370
commit
8093c5d154
@ -10,6 +10,7 @@ from contextlib import asynccontextmanager
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
|
||||
@ -109,13 +110,25 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.close(code=code)
|
||||
await self.session.close()
|
||||
|
||||
async def _receive(self) -> aiohttp.WSMessage:
|
||||
msg = await self.websocket.receive()
|
||||
if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
|
||||
raise WebSocketClosed(self.websocket.close_code or 1006)
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> str:
|
||||
return await self.websocket.receive_str()
|
||||
msg = await self._receive()
|
||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
||||
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive_bytes(self) -> bytes:
|
||||
return await self.websocket.receive_bytes()
|
||||
msg = await self._receive()
|
||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
||||
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
|
@ -1,15 +1,18 @@
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.log import LoguruHandler
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers._block_driver import BlockDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ForwardMixin, combine_driver
|
||||
|
||||
try:
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@ -20,6 +23,20 @@ logger = logging.Logger("websockets.client", "INFO")
|
||||
logger.addHandler(LoguruHandler())
|
||||
|
||||
|
||||
def catch_closed(func):
|
||||
@wraps(func)
|
||||
async def decorator(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except ConnectionClosed as e:
|
||||
if e.rcvd_then_sent:
|
||||
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
|
||||
else:
|
||||
raise WebSocketClosed(e.sent.code, e.sent.reason)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Mixin(ForwardMixin):
|
||||
@property
|
||||
@overrides(ForwardMixin)
|
||||
@ -62,6 +79,7 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> str:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, bytes):
|
||||
@ -69,6 +87,7 @@ class WebSocket(BaseWebSocket):
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, str):
|
||||
|
@ -230,3 +230,27 @@ class ActionFailed(AdapterException):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# Driver Exceptions
|
||||
class DriverException(NoneBotException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
``Driver`` 抛出的异常基类
|
||||
"""
|
||||
|
||||
|
||||
class WebSocketClosed(DriverException):
|
||||
"""
|
||||
:说明:
|
||||
|
||||
WebSocket 连接已关闭
|
||||
"""
|
||||
|
||||
def __init__(self, code: int, reason: Optional[str] = None):
|
||||
self.code = code
|
||||
self.reason = reason
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<WebSocketClosed code={self.code} reason={self.reason}>"
|
||||
|
Loading…
Reference in New Issue
Block a user