diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index d7df01c0..4e5aac0d 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -17,7 +17,18 @@ FrontMatter: import asyncio from functools import wraps -from typing import Any, Dict, List, Tuple, Union, TypeVar, Callable, Optional, Coroutine +from typing import ( + Any, + Dict, + List, + Tuple, + Union, + TypeVar, + Callable, + Optional, + Coroutine, + cast, +) from pydantic import BaseSettings @@ -33,7 +44,8 @@ from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup try: import uvicorn from quart import request as _request - from quart import websocket as _websocket + from quart.ctx import WebsocketContext + from quart.globals import websocket_ctx from quart import Quart, Request, Response from quart.datastructures import FileStorage from quart import Websocket as QuartWebSocket @@ -222,7 +234,8 @@ class Driver(ReverseDriver): ) async def _handle_ws(self, setup: WebSocketServerSetup) -> None: - websocket: QuartWebSocket = _websocket + ctx = cast(WebsocketContext, websocket_ctx.copy()) + websocket = websocket_ctx.websocket http_request = BaseRequest( websocket.method, @@ -232,7 +245,7 @@ class Driver(ReverseDriver): version=websocket.http_version, ) - ws = WebSocket(request=http_request, websocket=websocket) + ws = WebSocket(request=http_request, websocket_ctx=ctx) await setup.handle_func(ws) @@ -240,9 +253,13 @@ class Driver(ReverseDriver): class WebSocket(BaseWebSocket): """Quart WebSocket Wrapper""" - def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket): + def __init__(self, *, request: BaseRequest, websocket_ctx: WebsocketContext): super().__init__(request=request) - self.websocket = websocket + self.websocket_ctx = websocket_ctx + + @property + def websocket(self) -> QuartWebSocket: + return self.websocket_ctx.websocket @property @overrides(BaseWebSocket) diff --git a/tests/test_broadcast.py b/tests/test_broadcast.py index 4814761e..28aacad5 100644 --- a/tests/test_broadcast.py +++ b/tests/test_broadcast.py @@ -234,7 +234,7 @@ async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch): bot = ctx.create_bot() event = make_fake_event()() ctx.receive_event(bot, event) - ctx.should_call_send(event, "test", True, bot) + ctx.should_call_send(event, "test", True, bot=bot) assert runned, "run_preprocessor should runned" @@ -346,7 +346,7 @@ async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch): bot = ctx.create_bot() event = make_fake_event()() ctx.receive_event(bot, event) - ctx.should_call_send(event, "test", True, bot) + ctx.should_call_send(event, "test", True, bot=bot) assert runned, "run_postprocessor should runned" diff --git a/tests/test_driver.py b/tests/test_driver.py index 66394710..bd56deed 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,6 +1,6 @@ import json import asyncio -from typing import Any, Set, cast +from typing import Any, Set, Optional, cast import pytest from nonebug import App @@ -154,6 +154,63 @@ async def test_websocket_server(app: App, driver: Driver): await asyncio.sleep(1) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"), + pytest.param("nonebot.drivers.quart:Driver", id="quart"), + ], + indirect=True, +) +async def test_cross_context(app: App, driver: Driver): + driver = cast(ReverseDriver, driver) + + ws: Optional[WebSocket] = None + ws_ready = asyncio.Event() + ws_should_close = asyncio.Event() + + async def background_task(): + try: + await ws_ready.wait() + assert ws is not None + + await ws.send("ping") + data = await ws.receive() + assert data == "pong" + finally: + ws_should_close.set() + + task = asyncio.create_task(background_task()) + + async def _handle_ws(websocket: WebSocket) -> None: + nonlocal ws + await websocket.accept() + ws = websocket + ws_ready.set() + + await ws_should_close.wait() + await websocket.close() + + ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) + driver.setup_websocket_server(ws_setup) + + async with app.test_server(driver.asgi) as ctx: + client = ctx.get_client() + + async with client.websocket_connect("/ws_test") as websocket: + try: + data = await websocket.receive_text() + assert data == "ping" + await websocket.send_text("pong") + except Exception as e: + if not e.args or "websocket.close" not in str(e.args[0]): + raise + + await task + await asyncio.sleep(1) + + @pytest.mark.asyncio @pytest.mark.parametrize( "driver",