mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
🐛 Fix: websockets 驱动器连接关闭 code 获取错误 (#2537)
This commit is contained in:
parent
c2d2169a9f
commit
2c6affecea
@ -50,10 +50,7 @@ def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
|||||||
try:
|
try:
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
except ConnectionClosed as e:
|
except ConnectionClosed as e:
|
||||||
if e.rcvd_then_sent:
|
raise WebSocketClosed(e.code, e.reason)
|
||||||
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
|
|
||||||
else:
|
|
||||||
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
6
poetry.lock
generated
6
poetry.lock
generated
@ -889,7 +889,7 @@ files = [
|
|||||||
name = "h11"
|
name = "h11"
|
||||||
version = "0.14.0"
|
version = "0.14.0"
|
||||||
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
|
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
|
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
|
||||||
@ -2267,7 +2267,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
|
|||||||
name = "wsproto"
|
name = "wsproto"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
description = "WebSockets state-machine based protocol implementation"
|
description = "WebSockets state-machine based protocol implementation"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = ">=3.7.0"
|
python-versions = ">=3.7.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"},
|
{file = "wsproto-1.2.0-py3-none-any.whl", hash = "sha256:b9acddd652b585d75b20477888c56642fdade28bdfd3579aa24a4d2c037dd736"},
|
||||||
@ -2406,4 +2406,4 @@ websockets = ["websockets"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.8"
|
python-versions = "^3.8"
|
||||||
content-hash = "e7bd1c1b070f1a46d94022047f2b76dbf90751f49086a099139f2ade4ad07a65"
|
content-hash = "ec064b0d1c22da40c55132f706fbf3802b8a5f8dcf647c2302ee0a2d248e3340"
|
||||||
|
@ -52,6 +52,7 @@ ruff = ">=0.0.272,<1.0.0"
|
|||||||
|
|
||||||
[tool.poetry.group.test.dependencies]
|
[tool.poetry.group.test.dependencies]
|
||||||
nonebug = "^0.3.0"
|
nonebug = "^0.3.0"
|
||||||
|
wsproto = "^1.2.0"
|
||||||
pytest-cov = "^4.0.0"
|
pytest-cov = "^4.0.0"
|
||||||
pytest-xdist = "^3.0.2"
|
pytest-xdist = "^3.0.2"
|
||||||
pytest-asyncio = "^0.23.2"
|
pytest-asyncio = "^0.23.2"
|
||||||
|
@ -1,9 +1,15 @@
|
|||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
|
import socket
|
||||||
from typing import Dict, List, Union, TypeVar
|
from typing import Dict, List, Union, TypeVar
|
||||||
|
|
||||||
|
from wsproto.events import Ping
|
||||||
from werkzeug import Request, Response
|
from werkzeug import Request, Response
|
||||||
from werkzeug.datastructures import MultiDict
|
from werkzeug.datastructures import MultiDict
|
||||||
|
from wsproto.frame_protocol import CloseReason
|
||||||
|
from wsproto.events import Request as WSRequest
|
||||||
|
from wsproto import WSConnection, ConnectionType
|
||||||
|
from wsproto.events import TextMessage, BytesMessage, CloseConnection, AcceptConnection
|
||||||
|
|
||||||
K = TypeVar("K")
|
K = TypeVar("K")
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
@ -29,8 +35,7 @@ def flattern(d: "MultiDict[K, V]") -> Dict[K, Union[V, List[V]]]:
|
|||||||
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}
|
return {k: v[0] if len(v) == 1 else v for k, v in d.to_dict(flat=False).items()}
|
||||||
|
|
||||||
|
|
||||||
@Request.application
|
def http_echo(request: Request) -> Response:
|
||||||
def request_handler(request: Request) -> Response:
|
|
||||||
try:
|
try:
|
||||||
_json = json.loads(request.data.decode("utf-8"))
|
_json = json.loads(request.data.decode("utf-8"))
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
@ -67,3 +72,65 @@ def request_handler(request: Request) -> Response:
|
|||||||
status=200,
|
status=200,
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def websocket_echo(request: Request) -> Response:
|
||||||
|
stream = request.environ["werkzeug.socket"]
|
||||||
|
|
||||||
|
ws = WSConnection(ConnectionType.SERVER)
|
||||||
|
|
||||||
|
in_data = b"GET %s HTTP/1.1\r\n" % request.path.encode("utf-8")
|
||||||
|
for header, value in request.headers.items():
|
||||||
|
in_data += f"{header}: {value}\r\n".encode()
|
||||||
|
in_data += b"\r\n"
|
||||||
|
|
||||||
|
ws.receive_data(in_data)
|
||||||
|
|
||||||
|
running: bool = True
|
||||||
|
while True:
|
||||||
|
out_data = b""
|
||||||
|
|
||||||
|
for event in ws.events():
|
||||||
|
if isinstance(event, WSRequest):
|
||||||
|
out_data += ws.send(AcceptConnection())
|
||||||
|
elif isinstance(event, CloseConnection):
|
||||||
|
out_data += ws.send(event.response())
|
||||||
|
running = False
|
||||||
|
elif isinstance(event, Ping):
|
||||||
|
out_data += ws.send(event.response())
|
||||||
|
elif isinstance(event, TextMessage):
|
||||||
|
if event.data == "quit":
|
||||||
|
out_data += ws.send(
|
||||||
|
CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
|
||||||
|
)
|
||||||
|
running = False
|
||||||
|
else:
|
||||||
|
out_data += ws.send(TextMessage(data=event.data))
|
||||||
|
elif isinstance(event, BytesMessage):
|
||||||
|
if event.data == b"quit":
|
||||||
|
out_data += ws.send(
|
||||||
|
CloseConnection(CloseReason.NORMAL_CLOSURE, "bye")
|
||||||
|
)
|
||||||
|
running = False
|
||||||
|
else:
|
||||||
|
out_data += ws.send(BytesMessage(data=event.data))
|
||||||
|
|
||||||
|
if out_data:
|
||||||
|
stream.send(out_data)
|
||||||
|
|
||||||
|
if not running:
|
||||||
|
break
|
||||||
|
|
||||||
|
in_data = stream.recv(4096)
|
||||||
|
ws.receive_data(in_data)
|
||||||
|
|
||||||
|
stream.shutdown(socket.SHUT_RDWR)
|
||||||
|
return Response("", status=204)
|
||||||
|
|
||||||
|
|
||||||
|
@Request.application
|
||||||
|
def request_handler(request: Request) -> Response:
|
||||||
|
if request.headers.get("Connection") == "Upgrade":
|
||||||
|
return websocket_echo(request)
|
||||||
|
else:
|
||||||
|
return http_echo(request)
|
||||||
|
@ -131,7 +131,7 @@ async def test_websocket_server(app: App, driver: Driver):
|
|||||||
assert data == b"ping"
|
assert data == b"ping"
|
||||||
await ws.send(b"pong")
|
await ws.send(b"pong")
|
||||||
|
|
||||||
with pytest.raises(WebSocketClosed):
|
with pytest.raises(WebSocketClosed, match=r"code=1000"):
|
||||||
await ws.receive()
|
await ws.receive()
|
||||||
|
|
||||||
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
|
ws_setup = WebSocketServerSetup(URL("/ws_test"), "ws_test", _handle_ws)
|
||||||
@ -152,7 +152,7 @@ async def test_websocket_server(app: App, driver: Driver):
|
|||||||
await ws.send_bytes(b"ping")
|
await ws.send_bytes(b"ping")
|
||||||
assert await ws.receive_bytes() == b"pong"
|
assert await ws.receive_bytes() == b"pong"
|
||||||
|
|
||||||
await ws.close()
|
await ws.close(code=1000)
|
||||||
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
@ -315,9 +315,29 @@ async def test_http_client(driver: Driver, server_url: URL):
|
|||||||
],
|
],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
async def test_websocket_client(driver: Driver):
|
async def test_websocket_client(driver: Driver, server_url: URL):
|
||||||
assert isinstance(driver, WebSocketClientMixin)
|
assert isinstance(driver, WebSocketClientMixin)
|
||||||
|
|
||||||
|
request = Request("GET", server_url.with_scheme("ws"))
|
||||||
|
async with driver.websocket(request) as ws:
|
||||||
|
await ws.send("test")
|
||||||
|
assert await ws.receive() == "test"
|
||||||
|
|
||||||
|
await ws.send(b"test")
|
||||||
|
assert await ws.receive() == b"test"
|
||||||
|
|
||||||
|
await ws.send_text("test")
|
||||||
|
assert await ws.receive_text() == "test"
|
||||||
|
|
||||||
|
await ws.send_bytes(b"test")
|
||||||
|
assert await ws.receive_bytes() == b"test"
|
||||||
|
|
||||||
|
await ws.send("quit")
|
||||||
|
with pytest.raises(WebSocketClosed, match=r"code=1000"):
|
||||||
|
await ws.receive()
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
Reference in New Issue
Block a user