diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index a9e93aca..59f82563 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -11,23 +11,36 @@ FastAPI 驱动适配 """ import logging +from functools import wraps from typing import Any, List, Tuple, Callable, Optional import uvicorn from pydantic import BaseSettings from fastapi.responses import Response from fastapi import FastAPI, Request, UploadFile, status -from starlette.websockets import WebSocket, WebSocketState +from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect from ._model import FileTypes from nonebot.config import Env from nonebot.typing import overrides +from nonebot.exception import WebSocketClosed from nonebot.config import Config as NoneBotConfig from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup +def catch_closed(func): + @wraps(func) + async def decorator(*args, **kwargs): + try: + return await func(*args, **kwargs) + except WebSocketDisconnect as e: + raise WebSocketClosed(e.code) + + return decorator + + class Config(BaseSettings): """ FastAPI 驱动框架设置,详情参考 FastAPI 文档 @@ -311,10 +324,12 @@ class FastAPIWebSocket(BaseWebSocket): await self.websocket.close(code) @overrides(BaseWebSocket) + @catch_closed async def receive(self) -> str: return await self.websocket.receive_text() @overrides(BaseWebSocket) + @catch_closed async def receive_bytes(self) -> bytes: return await self.websocket.receive_bytes() diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 1f5df4c6..c9fccf26 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -8,6 +8,8 @@ Quart 驱动适配 https://pgjones.gitlab.io/quart/index.html """ +import asyncio +from functools import wraps from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine import uvicorn @@ -16,6 +18,7 @@ from pydantic import BaseSettings from ._model import FileTypes from nonebot.config import Env from nonebot.typing import overrides +from nonebot.exception import WebSocketClosed from nonebot.config import Config as NoneBotConfig from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket @@ -35,6 +38,17 @@ except ImportError: _AsyncCallable = TypeVar("_AsyncCallable", bound=Callable[..., Coroutine]) +def catch_closed(func): + @wraps(func) + async def decorator(*args, **kwargs): + try: + return await func(*args, **kwargs) + except asyncio.CancelledError as e: + raise WebSocketClosed(1000) + + return decorator + + class Config(BaseSettings): """ Quart 驱动框架设置 @@ -281,6 +295,7 @@ class WebSocket(BaseWebSocket): await self.websocket.close(code, reason) @overrides(BaseWebSocket) + @catch_closed async def receive(self) -> str: msg = await self.websocket.receive() if isinstance(msg, bytes): @@ -288,6 +303,7 @@ class WebSocket(BaseWebSocket): return msg @overrides(BaseWebSocket) + @catch_closed async def receive_bytes(self) -> bytes: msg = await self.websocket.receive() if isinstance(msg, str):