From 56f99b7f0be8cbb5e6b3d6a8ed385b0e7112253f Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 14 May 2022 21:06:57 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feat:=20=E6=94=AF=E6=8C=81=20WebSo?= =?UTF-8?q?cket=20=E8=BF=9E=E6=8E=A5=E5=90=8C=E6=97=B6=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=20str=20=E6=88=96=20bytes=20(#962)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/drivers/aiohttp.py | 21 +++++++++++++++++---- nonebot/drivers/fastapi.py | 16 +++++++++++++--- nonebot/drivers/httpx.py | 10 +++++----- nonebot/drivers/quart.py | 15 ++++++++++----- nonebot/drivers/websockets.py | 16 +++++++++++----- nonebot/internal/driver/model.py | 18 ++++++++++++++++-- tests/test_driver.py | 27 ++++++++++++++++++++++++++- 7 files changed, 98 insertions(+), 25 deletions(-) diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 971ce4ac..6d686e05 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -132,20 +132,33 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def receive(self) -> str: + msg = await self._receive() + if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): + raise TypeError( + f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}" + ) + return msg.data + + @overrides(BaseWebSocket) + async def receive_text(self) -> str: msg = await self._receive() if msg.type != aiohttp.WSMsgType.TEXT: - raise TypeError(f"WebSocket received unexpected frame type: {msg.type}") + raise TypeError( + f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}" + ) return msg.data @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: msg = await self._receive() - if msg.type != aiohttp.WSMsgType.TEXT: - raise TypeError(f"WebSocket received unexpected frame type: {msg.type}") + if msg.type != aiohttp.WSMsgType.BINARY: + raise TypeError( + f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}" + ) return msg.data @overrides(BaseWebSocket) - async def send(self, data: str) -> None: + async def send_text(self, data: str) -> None: await self.websocket.send_str(data) @overrides(BaseWebSocket) diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index e714300a..e0494ff6 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -11,7 +11,7 @@ FrontMatter: import logging from functools import wraps -from typing import Any, List, Tuple, Callable, Optional +from typing import Any, List, Tuple, Union, Callable, Optional import uvicorn from pydantic import BaseSettings @@ -36,6 +36,8 @@ def catch_closed(func): return await func(*args, **kwargs) except WebSocketDisconnect as e: raise WebSocketClosed(e.code) + except KeyError: + raise TypeError("WebSocket received unexpected frame type") return decorator @@ -261,9 +263,17 @@ class FastAPIWebSocket(BaseWebSocket): ) -> None: await self.websocket.close(code) + @overrides(BaseWebSocket) + async def receive(self) -> Union[str, bytes]: + # assert self.websocket.application_state == WebSocketState.CONNECTED + msg = await self.websocket.receive() + if msg["type"] == "websocket.disconnect": + raise WebSocketClosed(msg["code"]) + return msg["text"] if "text" in msg else msg["bytes"] + @overrides(BaseWebSocket) @catch_closed - async def receive(self) -> str: + async def receive_text(self) -> str: return await self.websocket.receive_text() @overrides(BaseWebSocket) @@ -272,7 +282,7 @@ class FastAPIWebSocket(BaseWebSocket): return await self.websocket.receive_bytes() @overrides(BaseWebSocket) - async def send(self, data: str) -> None: + async def send_text(self, data: str) -> None: await self.websocket.send({"type": "websocket.send", "text": data}) @overrides(BaseWebSocket) diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 79100dcf..fa167b6b 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -49,22 +49,22 @@ class Mixin(ForwardMixin): async def request(self, setup: Request) -> Response: async with httpx.AsyncClient( http2=setup.version == HTTPVersion.H2, - proxies=setup.proxy, + proxies=setup.proxy, # type: ignore follow_redirects=True, ) as client: response = await client.request( setup.method, str(setup.url), - content=setup.content, - data=setup.data, + content=setup.content, # type: ignore + data=setup.data, # type: ignore json=setup.json, - files=setup.files, + files=setup.files, # type: ignore headers=tuple(setup.headers.items()), timeout=setup.timeout, ) return Response( response.status_code, - headers=response.headers, + headers=response.headers.multi_items(), content=response.content, request=setup, ) diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 882072ae..61df7c16 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -17,7 +17,7 @@ FrontMatter: import asyncio from functools import wraps -from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine +from typing import List, Tuple, Union, TypeVar, Callable, Optional, Coroutine import uvicorn from pydantic import BaseSettings @@ -199,7 +199,7 @@ class Driver(ReverseDriver): http_request = BaseRequest( request.method, request.url, - headers=request.headers.items(), + headers=list(request.headers.items()), cookies=list(request.cookies.items()), content=await request.get_data( cache=False, as_text=False, parse_form_data=False @@ -224,7 +224,7 @@ class Driver(ReverseDriver): http_request = BaseRequest( websocket.method, websocket.url, - headers=websocket.headers.items(), + headers=list(websocket.headers.items()), cookies=list(websocket.cookies.items()), version=websocket.http_version, ) @@ -257,7 +257,12 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) @catch_closed - async def receive(self) -> str: + async def receive(self) -> Union[str, bytes]: + return await self.websocket.receive() + + @overrides(BaseWebSocket) + @catch_closed + async def receive_text(self) -> str: msg = await self.websocket.receive() if isinstance(msg, bytes): raise TypeError("WebSocket received unexpected frame type: bytes") @@ -272,7 +277,7 @@ class WebSocket(BaseWebSocket): return msg @overrides(BaseWebSocket) - async def send(self, data: str): + async def send_text(self, data: str): await self.websocket.send(data) @overrides(BaseWebSocket) diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 57b87169..9b78c829 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -16,8 +16,8 @@ FrontMatter: """ import logging from functools import wraps -from typing import Type, AsyncGenerator from contextlib import asynccontextmanager +from typing import Type, Union, AsyncGenerator from nonebot.typing import overrides from nonebot.log import LoguruHandler @@ -46,9 +46,9 @@ def catch_closed(func): return await func(*args, **kwargs) except ConnectionClosed as e: if e.rcvd_then_sent: - raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) + raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore else: - raise WebSocketClosed(e.sent.code, e.sent.reason) + raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore return decorator @@ -100,7 +100,13 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) @catch_closed - async def receive(self) -> str: + async def receive(self) -> Union[str, bytes]: + msg = await self.websocket.recv() + return msg + + @overrides(BaseWebSocket) + @catch_closed + async def receive_text(self) -> str: msg = await self.websocket.recv() if isinstance(msg, bytes): raise TypeError("WebSocket received unexpected frame type: bytes") @@ -115,7 +121,7 @@ class WebSocket(BaseWebSocket): return msg @overrides(BaseWebSocket) - async def send(self, data: str) -> None: + async def send_text(self, data: str) -> None: await self.websocket.send(data) @overrides(BaseWebSocket) diff --git a/nonebot/internal/driver/model.py b/nonebot/internal/driver/model.py index bd601083..de13b81a 100644 --- a/nonebot/internal/driver/model.py +++ b/nonebot/internal/driver/model.py @@ -186,7 +186,12 @@ class WebSocket(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def receive(self) -> str: + async def receive(self) -> Union[str, bytes]: + """接收一条 WebSocket text/bytes 信息""" + raise NotImplementedError + + @abc.abstractmethod + async def receive_text(self) -> str: """接收一条 WebSocket text 信息""" raise NotImplementedError @@ -195,8 +200,17 @@ class WebSocket(abc.ABC): """接收一条 WebSocket binary 信息""" raise NotImplementedError + async def send(self, data: Union[str, bytes]) -> None: + """发送一条 WebSocket text/bytes 信息""" + if isinstance(data, str): + await self.send_text(data) + elif isinstance(data, bytes): + await self.send_bytes(data) + else: + raise TypeError("WebSocker send method expects str or bytes!") + @abc.abstractmethod - async def send(self, data: str) -> None: + async def send_text(self, data: str) -> None: """发送一条 WebSocket text 信息""" raise NotImplementedError diff --git a/tests/test_driver.py b/tests/test_driver.py index 21949b2b..c74a448a 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -15,6 +15,7 @@ from nonebug import App ) async def test_reverse_driver(app: App): import nonebot + from nonebot.exception import WebSocketClosed from nonebot.drivers import ( URL, Request, @@ -36,7 +37,21 @@ async def test_reverse_driver(app: App): data = await ws.receive() assert data == "ping" await ws.send("pong") - await ws.close() + + data = await ws.receive() + assert data == b"ping" + await ws.send(b"pong") + + data = await ws.receive_text() + assert data == "ping" + await ws.send("pong") + + data = await ws.receive_bytes() + assert data == b"ping" + await ws.send(b"pong") + + with pytest.raises(WebSocketClosed): + await ws.receive() http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http) driver.setup_http_server(http_setup) @@ -53,3 +68,13 @@ async def test_reverse_driver(app: App): async with client.websocket_connect("/ws_test") as ws: await ws.send_text("ping") assert await ws.receive_text() == "pong" + await ws.send_bytes(b"ping") + assert await ws.receive_bytes() == b"pong" + + await ws.send_text("ping") + assert await ws.receive_text() == "pong" + + await ws.send_bytes(b"ping") + assert await ws.receive_bytes() == b"pong" + + await ws.close()