🐛 fix quart context error (#2192)

This commit is contained in:
Ju4tCode 2023-07-17 15:01:21 +08:00 committed by GitHub
parent 29364679c4
commit e167865686
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 9 deletions

View File

@ -17,7 +17,18 @@ FrontMatter:
import asyncio import asyncio
from functools import wraps from functools import wraps
from typing import Any, Dict, List, Tuple, Union, TypeVar, Callable, Optional, Coroutine from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from pydantic import BaseSettings from pydantic import BaseSettings
@ -33,7 +44,8 @@ from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
try: try:
import uvicorn import uvicorn
from quart import request as _request from quart import request as _request
from quart import websocket as _websocket from quart.ctx import WebsocketContext
from quart.globals import websocket_ctx
from quart import Quart, Request, Response from quart import Quart, Request, Response
from quart.datastructures import FileStorage from quart.datastructures import FileStorage
from quart import Websocket as QuartWebSocket from quart import Websocket as QuartWebSocket
@ -222,7 +234,8 @@ class Driver(ReverseDriver):
) )
async def _handle_ws(self, setup: WebSocketServerSetup) -> None: async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
websocket: QuartWebSocket = _websocket ctx = cast(WebsocketContext, websocket_ctx.copy())
websocket = websocket_ctx.websocket
http_request = BaseRequest( http_request = BaseRequest(
websocket.method, websocket.method,
@ -232,7 +245,7 @@ class Driver(ReverseDriver):
version=websocket.http_version, version=websocket.http_version,
) )
ws = WebSocket(request=http_request, websocket=websocket) ws = WebSocket(request=http_request, websocket_ctx=ctx)
await setup.handle_func(ws) await setup.handle_func(ws)
@ -240,9 +253,13 @@ class Driver(ReverseDriver):
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
"""Quart WebSocket Wrapper""" """Quart WebSocket Wrapper"""
def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket): def __init__(self, *, request: BaseRequest, websocket_ctx: WebsocketContext):
super().__init__(request=request) super().__init__(request=request)
self.websocket = websocket self.websocket_ctx = websocket_ctx
@property
def websocket(self) -> QuartWebSocket:
return self.websocket_ctx.websocket
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)

View File

@ -234,7 +234,7 @@ async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
bot = ctx.create_bot() bot = ctx.create_bot()
event = make_fake_event()() event = make_fake_event()()
ctx.receive_event(bot, event) ctx.receive_event(bot, event)
ctx.should_call_send(event, "test", True, bot) ctx.should_call_send(event, "test", True, bot=bot)
assert runned, "run_preprocessor should runned" assert runned, "run_preprocessor should runned"
@ -346,7 +346,7 @@ async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
bot = ctx.create_bot() bot = ctx.create_bot()
event = make_fake_event()() event = make_fake_event()()
ctx.receive_event(bot, event) ctx.receive_event(bot, event)
ctx.should_call_send(event, "test", True, bot) ctx.should_call_send(event, "test", True, bot=bot)
assert runned, "run_postprocessor should runned" assert runned, "run_postprocessor should runned"

View File

@ -1,6 +1,6 @@
import json import json
import asyncio import asyncio
from typing import Any, Set, cast from typing import Any, Set, Optional, cast
import pytest import pytest
from nonebug import App from nonebug import App
@ -154,6 +154,63 @@ async def test_websocket_server(app: App, driver: Driver):
await asyncio.sleep(1) await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
pytest.param("nonebot.drivers.quart:Driver", id="quart"),
],
indirect=True,
)
async def test_cross_context(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
ws: Optional[WebSocket] = None
ws_ready = asyncio.Event()
ws_should_close = asyncio.Event()
async def background_task():
try:
await ws_ready.wait()
assert ws is not None
await ws.send("ping")
data = await ws.receive()
assert data == "pong"
finally:
ws_should_close.set()
task = asyncio.create_task(background_task())
async def _handle_ws(websocket: WebSocket) -> None:
nonlocal ws
await websocket.accept()
ws = websocket
ws_ready.set()
await ws_should_close.wait()
await websocket.close()
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
driver.setup_websocket_server(ws_setup)
async with app.test_server(driver.asgi) as ctx:
client = ctx.get_client()
async with client.websocket_connect("/ws_test") as websocket:
try:
data = await websocket.receive_text()
assert data == "ping"
await websocket.send_text("pong")
except Exception as e:
if not e.args or "websocket.close" not in str(e.args[0]):
raise
await task
await asyncio.sleep(1)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",