diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 5c4e167e..e618af1b 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -50,10 +50,7 @@ def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: try: return await func(*args, **kwargs) except ConnectionClosed as e: - if e.rcvd_then_sent: - raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore - else: - raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore + raise WebSocketClosed(e.code, e.reason) return decorator diff --git a/poetry.lock b/poetry.lock index 1ca194d9..ddeacaaf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -889,7 +889,7 @@ files = [ name = "h11" version = "0.14.0" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -optional = true +optional = false python-versions = ">=3.7" files = [ {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, @@ -2267,7 +2267,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] name = "wsproto" version = "1.2.0" description = "WebSockets state-machine based protocol implementation" -optional = true +optional = false python-versions = ">=3.7.0" files = [ {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"}, @@ -2406,4 +2406,4 @@ websockets = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "e7bd1c1b070f1a46d94022047f2b76dbf90751f49086a099139f2ade4ad07a65" +content-hash = "ec064b0d1c22da40c55132f706fbf3802b8a5f8dcf647c2302ee0a2d248e3340" diff --git a/pyproject.toml b/pyproject.toml index 3cbde147..68ab0cba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ ruff = ">=0.0.272,<1.0.0" [tool.poetry.group.test.dependencies] nonebug = "^0.3.0" +wsproto = "^1.2.0" pytest-cov = "^4.0.0" pytest-xdist = "^3.0.2" pytest-asyncio = "^0.23.2" diff --git a/tests/fake_server.py b/tests/fake_server.py index 3f19cf2c..353b7512 100644 --- a/tests/fake_server.py +++ b/tests/fake_server.py @@ -1,9 +1,15 @@ import json import base64 +import socket from typing import Dict, List, Union, TypeVar +from wsproto.events import Ping from werkzeug import Request, Response from werkzeug.datastructures import MultiDict +from wsproto.frame_protocol import CloseReason +from wsproto.events import Request as WSRequest +from wsproto import WSConnection, ConnectionType +from wsproto.events import TextMessage, BytesMessage, CloseConnection, AcceptConnection K = TypeVar("K") V = TypeVar("V") @@ -29,8 +35,7 @@ def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]: return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()} -@Request.application -def request_handler(request: Request) -> Response: +def http_echo(request: Request) -> Response: try: _json = json.loads(request.data.decode("utf-8")) except (ValueError, TypeError): @@ -67,3 +72,65 @@ def request_handler(request: Request) -> Response: status=200, content_type="application/json", ) + + +def websocket_echo(request: Request) -> Response: + stream = request.environ["werkzeug.socket"] + + ws = WSConnection(ConnectionType.SERVER) + + in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf-8") + for header, value in request.headers.items(): + in_data += f"{header}: {value}\r\n".encode() + in_data += b"\r\n" + + ws.receive_data(in_data) + + running: bool = True + while True: + out_data = b"" + + for event in ws.events(): + if isinstance(event, WSRequest): + out_data += ws.send(AcceptConnection()) + elif isinstance(event, CloseConnection): + out_data += ws.send(event.response()) + running = False + elif isinstance(event, Ping): + out_data += ws.send(event.response()) + elif isinstance(event, TextMessage): + if event.data == "quit": + out_data += ws.send( + CloseConnection(CloseReason.NORMAL_CLOSURE, "bye") + ) + running = False + else: + out_data += ws.send(TextMessage(data=event.data)) + elif isinstance(event, BytesMessage): + if event.data == b"quit": + out_data += ws.send( + CloseConnection(CloseReason.NORMAL_CLOSURE, "bye") + ) + running = False + else: + out_data += ws.send(BytesMessage(data=event.data)) + + if out_data: + stream.send(out_data) + + if not running: + break + + in_data = stream.recv(4096) + ws.receive_data(in_data) + + stream.shutdown(socket.SHUT_RDWR) + return Response("", status=204) + + +@Request.application +def request_handler(request: Request) -> Response: + if request.headers.get("Connection") == "Upgrade": + return websocket_echo(request) + else: + return http_echo(request) diff --git a/tests/test_driver.py b/tests/test_driver.py index cd9bc3a8..958536b7 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -131,7 +131,7 @@ async def test_websocket_server(app: App, driver: Driver): assert data == b"ping" await ws.send(b"pong") - with pytest.raises(WebSocketClosed): + with pytest.raises(WebSocketClosed, match=r"code=1000"): await ws.receive() ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) @@ -152,7 +152,7 @@ async def test_websocket_server(app: App, driver: Driver): await ws.send_bytes(b"ping") assert await ws.receive_bytes() == b"pong" - await ws.close() + await ws.close(code=1000) await asyncio.sleep(1) @@ -315,9 +315,29 @@ async def test_http_client(driver: Driver, server_url: URL): ], indirect=True, ) -async def test_websocket_client(driver: Driver): +async def test_websocket_client(driver: Driver, server_url: URL): assert isinstance(driver, WebSocketClientMixin) + request = Request("GET", server_url.with_scheme("ws")) + async with driver.websocket(request) as ws: + await ws.send("test") + assert await ws.receive() == "test" + + await ws.send(b"test") + assert await ws.receive() == b"test" + + await ws.send_text("test") + assert await ws.receive_text() == "test" + + await ws.send_bytes(b"test") + assert await ws.receive_bytes() == b"test" + + await ws.send("quit") + with pytest.raises(WebSocketClosed, match=r"code=1000"): + await ws.receive() + + await asyncio.sleep(1) + @pytest.mark.asyncio @pytest.mark.parametrize(