🐛 fix quart context error (#2192)

This commit is contained in:
Ju4tCode 2023-07-17 15:01:21 +08:00 committed by GitHub
parent 29364679c4
commit e167865686
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 9 deletions

View File

@ -17,7 +17,18 @@ FrontMatter:
import asyncio
from functools import wraps
from typing import Any, Dict, List, Tuple, Union, TypeVar, Callable, Optional, Coroutine
from typing import (
Any,
Dict,
List,
Tuple,
Union,
TypeVar,
Callable,
Optional,
Coroutine,
cast,
)
from pydantic import BaseSettings
@ -33,7 +44,8 @@ from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
try:
import uvicorn
from quart import request as _request
from quart import websocket as _websocket
from quart.ctx import WebsocketContext
from quart.globals import websocket_ctx
from quart import Quart, Request, Response
from quart.datastructures import FileStorage
from quart import Websocket as QuartWebSocket
@ -222,7 +234,8 @@ class Driver(ReverseDriver):
)
async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
websocket: QuartWebSocket = _websocket
ctx = cast(WebsocketContext, websocket_ctx.copy())
websocket = websocket_ctx.websocket
http_request = BaseRequest(
websocket.method,
@ -232,7 +245,7 @@ class Driver(ReverseDriver):
version=websocket.http_version,
)
ws = WebSocket(request=http_request, websocket=websocket)
ws = WebSocket(request=http_request, websocket_ctx=ctx)
await setup.handle_func(ws)
@ -240,9 +253,13 @@ class Driver(ReverseDriver):
class WebSocket(BaseWebSocket):
"""Quart WebSocket Wrapper"""
def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket):
def __init__(self, *, request: BaseRequest, websocket_ctx: WebsocketContext):
super().__init__(request=request)
self.websocket = websocket
self.websocket_ctx = websocket_ctx
@property
def websocket(self) -> QuartWebSocket:
return self.websocket_ctx.websocket
@property
@overrides(BaseWebSocket)

View File

@ -234,7 +234,7 @@ async def test_run_preprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
bot = ctx.create_bot()
event = make_fake_event()()
ctx.receive_event(bot, event)
ctx.should_call_send(event, "test", True, bot)
ctx.should_call_send(event, "test", True, bot=bot)
assert runned, "run_preprocessor should runned"
@ -346,7 +346,7 @@ async def test_run_postprocessor(app: App, monkeypatch: pytest.MonkeyPatch):
bot = ctx.create_bot()
event = make_fake_event()()
ctx.receive_event(bot, event)
ctx.should_call_send(event, "test", True, bot)
ctx.should_call_send(event, "test", True, bot=bot)
assert runned, "run_postprocessor should runned"

View File

@ -1,6 +1,6 @@
import json
import asyncio
from typing import Any, Set, cast
from typing import Any, Set, Optional, cast
import pytest
from nonebug import App
@ -154,6 +154,63 @@ async def test_websocket_server(app: App, driver: Driver):
await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.fastapi:Driver", id="fastapi"),
pytest.param("nonebot.drivers.quart:Driver", id="quart"),
],
indirect=True,
)
async def test_cross_context(app: App, driver: Driver):
driver = cast(ReverseDriver, driver)
ws: Optional[WebSocket] = None
ws_ready = asyncio.Event()
ws_should_close = asyncio.Event()
async def background_task():
try:
await ws_ready.wait()
assert ws is not None
await ws.send("ping")
data = await ws.receive()
assert data == "pong"
finally:
ws_should_close.set()
task = asyncio.create_task(background_task())
async def _handle_ws(websocket: WebSocket) -> None:
nonlocal ws
await websocket.accept()
ws = websocket
ws_ready.set()
await ws_should_close.wait()
await websocket.close()
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
driver.setup_websocket_server(ws_setup)
async with app.test_server(driver.asgi) as ctx:
client = ctx.get_client()
async with client.websocket_connect("/ws_test") as websocket:
try:
data = await websocket.receive_text()
assert data == "ping"
await websocket.send_text("pong")
except Exception as e:
if not e.args or "websocket.close" not in str(e.args[0]):
raise
await task
await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",