🐛 fix ws close exception not catch in server

This commit is contained in:
yanyongyu 2021-12-30 12:11:31 +08:00
parent 80c0ac5456
commit 23d0b2509e
2 changed files with 32 additions and 1 deletions

View File

@ -11,23 +11,36 @@ FastAPI 驱动适配
"""
import logging
from functools import wraps
from typing import Any, List, Tuple, Callable, Optional
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from fastapi import FastAPI, Request, UploadFile, status
from starlette.websockets import WebSocket, WebSocketState
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
from ._model import FileTypes
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.exception import WebSocketClosed
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
def catch_closed(func):
@wraps(func)
async def decorator(*args, **kwargs):
try:
return await func(*args, **kwargs)
except WebSocketDisconnect as e:
raise WebSocketClosed(e.code)
return decorator
class Config(BaseSettings):
"""
FastAPI 驱动框架设置详情参考 FastAPI 文档
@ -311,10 +324,12 @@ class FastAPIWebSocket(BaseWebSocket):
await self.websocket.close(code)
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
return await self.websocket.receive_text()
@overrides(BaseWebSocket)
@catch_closed
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()

View File

@ -8,6 +8,8 @@ Quart 驱动适配
https://pgjones.gitlab.io/quart/index.html
"""
import asyncio
from functools import wraps
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
import uvicorn
@ -16,6 +18,7 @@ from pydantic import BaseSettings
from ._model import FileTypes
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.exception import WebSocketClosed
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
@ -35,6 +38,17 @@ except ImportError:
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
def catch_closed(func):
@wraps(func)
async def decorator(*args, **kwargs):
try:
return await func(*args, **kwargs)
except asyncio.CancelledError as e:
raise WebSocketClosed(1000)
return decorator
class Config(BaseSettings):
"""
Quart 驱动框架设置
@ -281,6 +295,7 @@ class WebSocket(BaseWebSocket):
await self.websocket.close(code, reason)
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
msg = await self.websocket.receive()
if isinstance(msg, bytes):
@ -288,6 +303,7 @@ class WebSocket(BaseWebSocket):
return msg
@overrides(BaseWebSocket)
@catch_closed
async def receive_bytes(self) -> bytes:
msg = await self.websocket.receive()
if isinstance(msg, str):