🐛 Fix: websockets 驱动器连接关闭 code 获取错误 (#2537)

This commit is contained in:
Ju4tCode 2024-01-17 16:39:35 +08:00 committed by GitHub
parent c2d2169a9f
commit 2c6affecea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 97 additions and 12 deletions

View File

@ -50,10 +50,7 @@ def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
try: try:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except ConnectionClosed as e: except ConnectionClosed as e:
if e.rcvd_then_sent: raise WebSocketClosed(e.code, e.reason)
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
else:
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore
return decorator return decorator

6
poetry.lock generated
View File

@ -889,7 +889,7 @@ files = [
name = "h11" name = "h11"
version = "0.14.0" version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = true optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
@ -2267,7 +2267,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
name = "wsproto" name = "wsproto"
version = "1.2.0" version = "1.2.0"
description = "WebSockets state-machine based protocol implementation" description = "WebSockets state-machine based protocol implementation"
optional = true optional = false
python-versions = ">=3.7.0" python-versions = ">=3.7.0"
files = [ files = [
{file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"}, {file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"},
@ -2406,4 +2406,4 @@ websockets = ["websockets"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8" python-versions = "^3.8"
content-hash = "e7bd1c1b070f1a46d94022047f2b76dbf90751f49086a099139f2ade4ad07a65" content-hash = "ec064b0d1c22da40c55132f706fbf3802b8a5f8dcf647c2302ee0a2d248e3340"

View File

@ -52,6 +52,7 @@ ruff = ">=0.0.272,<1.0.0"
[tool.poetry.group.test.dependencies] [tool.poetry.group.test.dependencies]
nonebug = "^0.3.0" nonebug = "^0.3.0"
wsproto = "^1.2.0"
pytest-cov = "^4.0.0" pytest-cov = "^4.0.0"
pytest-xdist = "^3.0.2" pytest-xdist = "^3.0.2"
pytest-asyncio = "^0.23.2" pytest-asyncio = "^0.23.2"

View File

@ -1,9 +1,15 @@
import json import json
import base64 import base64
import socket
from typing import Dict, List, Union, TypeVar from typing import Dict, List, Union, TypeVar
from wsproto.events import Ping
from werkzeug import Request, Response from werkzeug import Request, Response
from werkzeug.datastructures import MultiDict from werkzeug.datastructures import MultiDict
from wsproto.frame_protocol import CloseReason
from wsproto.events import Request as WSRequest
from wsproto import WSConnection, ConnectionType
from wsproto.events import TextMessage, BytesMessage, CloseConnection, AcceptConnection
K = TypeVar("K") K = TypeVar("K")
V = TypeVar("V") V = TypeVar("V")
@ -29,8 +35,7 @@ def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]:
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()} return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}
@Request.application def http_echo(request: Request) -> Response:
def request_handler(request: Request) -> Response:
try: try:
_json = json.loads(request.data.decode("utf-8")) _json = json.loads(request.data.decode("utf-8"))
except (ValueError, TypeError): except (ValueError, TypeError):
@ -67,3 +72,65 @@ def request_handler(request: Request) -> Response:
status=200, status=200,
content_type="application/json", content_type="application/json",
) )
def websocket_echo(request: Request) -> Response:
stream = request.environ["werkzeug.socket"]
ws = WSConnection(ConnectionType.SERVER)
in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf-8")
for header, value in request.headers.items():
in_data += f"{header}: {value}\r\n".encode()
in_data += b"\r\n"
ws.receive_data(in_data)
running: bool = True
while True:
out_data = b""
for event in ws.events():
if isinstance(event, WSRequest):
out_data += ws.send(AcceptConnection())
elif isinstance(event, CloseConnection):
out_data += ws.send(event.response())
running = False
elif isinstance(event, Ping):
out_data += ws.send(event.response())
elif isinstance(event, TextMessage):
if event.data == "quit":
out_data += ws.send(
CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
)
running = False
else:
out_data += ws.send(TextMessage(data=event.data))
elif isinstance(event, BytesMessage):
if event.data == b"quit":
out_data += ws.send(
CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
)
running = False
else:
out_data += ws.send(BytesMessage(data=event.data))
if out_data:
stream.send(out_data)
if not running:
break
in_data = stream.recv(4096)
ws.receive_data(in_data)
stream.shutdown(socket.SHUT_RDWR)
return Response("", status=204)
@Request.application
def request_handler(request: Request) -> Response:
if request.headers.get("Connection") == "Upgrade":
return websocket_echo(request)
else:
return http_echo(request)

View File

@ -131,7 +131,7 @@ async def test_websocket_server(app: App, driver: Driver):
assert data == b"ping" assert data == b"ping"
await ws.send(b"pong") await ws.send(b"pong")
with pytest.raises(WebSocketClosed): with pytest.raises(WebSocketClosed, match=r"code=1000"):
await ws.receive() await ws.receive()
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws) ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
@ -152,7 +152,7 @@ async def test_websocket_server(app: App, driver: Driver):
await ws.send_bytes(b"ping") await ws.send_bytes(b"ping")
assert await ws.receive_bytes() == b"pong" assert await ws.receive_bytes() == b"pong"
await ws.close() await ws.close(code=1000)
await asyncio.sleep(1) await asyncio.sleep(1)
@ -315,9 +315,29 @@ async def test_http_client(driver: Driver, server_url: URL):
], ],
indirect=True, indirect=True,
) )
async def test_websocket_client(driver: Driver): async def test_websocket_client(driver: Driver, server_url: URL):
assert isinstance(driver, WebSocketClientMixin) assert isinstance(driver, WebSocketClientMixin)
request = Request("GET", server_url.with_scheme("ws"))
async with driver.websocket(request) as ws:
await ws.send("test")
assert await ws.receive() == "test"
await ws.send(b"test")
assert await ws.receive() == b"test"
await ws.send_text("test")
assert await ws.receive_text() == "test"
await ws.send_bytes(b"test")
assert await ws.receive_bytes() == b"test"
await ws.send("quit")
with pytest.raises(WebSocketClosed, match=r"code=1000"):
await ws.receive()
await asyncio.sleep(1)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(