🐛 fix ws close exception not catch in server

This commit is contained in:
yanyongyu 2021-12-30 12:11:31 +08:00
parent 80c0ac5456
commit 23d0b2509e
2 changed files with 32 additions and 1 deletions

View File

@ -11,23 +11,36 @@ FastAPI 驱动适配
""" """
import logging import logging
from functools import wraps
from typing import Any, List, Tuple, Callable, Optional from typing import Any, List, Tuple, Callable, Optional
import uvicorn import uvicorn
from pydantic import BaseSettings from pydantic import BaseSettings
from fastapi.responses import Response from fastapi.responses import Response
from fastapi import FastAPI, Request, UploadFile, status from fastapi import FastAPI, Request, UploadFile, status
from starlette.websockets import WebSocket, WebSocketState from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
from ._model import FileTypes from ._model import FileTypes
from nonebot.config import Env from nonebot.config import Env
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.exception import WebSocketClosed
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
def catch_closed(func):
@wraps(func)
async def decorator(*args, **kwargs):
try:
return await func(*args, **kwargs)
except WebSocketDisconnect as e:
raise WebSocketClosed(e.code)
return decorator
class Config(BaseSettings): class Config(BaseSettings):
""" """
FastAPI 驱动框架设置详情参考 FastAPI 文档 FastAPI 驱动框架设置详情参考 FastAPI 文档
@ -311,10 +324,12 @@ class FastAPIWebSocket(BaseWebSocket):
await self.websocket.close(code) await self.websocket.close(code)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str: async def receive(self) -> str:
return await self.websocket.receive_text() return await self.websocket.receive_text()
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@catch_closed
async def receive_bytes(self) -> bytes: async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes() return await self.websocket.receive_bytes()

View File

@ -8,6 +8,8 @@ Quart 驱动适配
https://pgjones.gitlab.io/quart/index.html https://pgjones.gitlab.io/quart/index.html
""" """
import asyncio
from functools import wraps
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
import uvicorn import uvicorn
@ -16,6 +18,7 @@ from pydantic import BaseSettings
from ._model import FileTypes from ._model import FileTypes
from nonebot.config import Env from nonebot.config import Env
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.exception import WebSocketClosed
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
@ -35,6 +38,17 @@ except ImportError:
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) _AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
def catch_closed(func):
@wraps(func)
async def decorator(*args, **kwargs):
try:
return await func(*args, **kwargs)
except asyncio.CancelledError as e:
raise WebSocketClosed(1000)
return decorator
class Config(BaseSettings): class Config(BaseSettings):
""" """
Quart 驱动框架设置 Quart 驱动框架设置
@ -281,6 +295,7 @@ class WebSocket(BaseWebSocket):
await self.websocket.close(code, reason) await self.websocket.close(code, reason)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str: async def receive(self) -> str:
msg = await self.websocket.receive() msg = await self.websocket.receive()
if isinstance(msg, bytes): if isinstance(msg, bytes):
@ -288,6 +303,7 @@ class WebSocket(BaseWebSocket):
return msg return msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@catch_closed
async def receive_bytes(self) -> bytes: async def receive_bytes(self) -> bytes:
msg = await self.websocket.receive() msg = await self.websocket.receive()
if isinstance(msg, str): if isinstance(msg, str):