mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
🐛 fix ws close exception not catch in server
This commit is contained in:
parent
80c0ac5456
commit
23d0b2509e
@ -11,23 +11,36 @@ FastAPI 驱动适配
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, List, Tuple, Callable, Optional
|
||||
|
||||
import uvicorn
|
||||
from pydantic import BaseSettings
|
||||
from fastapi.responses import Response
|
||||
from fastapi import FastAPI, Request, UploadFile, status
|
||||
from starlette.websockets import WebSocket, WebSocketState
|
||||
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
|
||||
|
||||
from ._model import FileTypes
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
|
||||
|
||||
|
||||
def catch_closed(func):
|
||||
@wraps(func)
|
||||
async def decorator(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except WebSocketDisconnect as e:
|
||||
raise WebSocketClosed(e.code)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""
|
||||
FastAPI 驱动框架设置,详情参考 FastAPI 文档
|
||||
@ -311,10 +324,12 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
await self.websocket.close(code)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> str:
|
||||
return await self.websocket.receive_text()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive_bytes(self) -> bytes:
|
||||
return await self.websocket.receive_bytes()
|
||||
|
||||
|
@ -8,6 +8,8 @@ Quart 驱动适配
|
||||
https://pgjones.gitlab.io/quart/index.html
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
|
||||
|
||||
import uvicorn
|
||||
@ -16,6 +18,7 @@ from pydantic import BaseSettings
|
||||
from ._model import FileTypes
|
||||
from nonebot.config import Env
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.config import Config as NoneBotConfig
|
||||
from nonebot.drivers import Request as BaseRequest
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
@ -35,6 +38,17 @@ except ImportError:
|
||||
_AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine])
|
||||
|
||||
|
||||
def catch_closed(func):
|
||||
@wraps(func)
|
||||
async def decorator(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except asyncio.CancelledError as e:
|
||||
raise WebSocketClosed(1000)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""
|
||||
Quart 驱动框架设置
|
||||
@ -281,6 +295,7 @@ class WebSocket(BaseWebSocket):
|
||||
await self.websocket.close(code, reason)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> str:
|
||||
msg = await self.websocket.receive()
|
||||
if isinstance(msg, bytes):
|
||||
@ -288,6 +303,7 @@ class WebSocket(BaseWebSocket):
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self.websocket.receive()
|
||||
if isinstance(msg, str):
|
||||
|
Loading…
Reference in New Issue
Block a user