Feat: 支持 WebSocket 连接同时获取 str 或 bytes (#962)

This commit is contained in:
Ju4tCode 2022-05-14 21:06:57 +08:00 committed by GitHub
parent 91c5056c97
commit 56f99b7f0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 98 additions and 25 deletions

View File

@ -132,20 +132,33 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
async def receive(self) -> str:
msg = await self._receive()
if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
raise TypeError(
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
)
return msg.data
@overrides(BaseWebSocket)
async def receive_text(self) -> str:
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
raise TypeError(
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
)
return msg.data
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
msg = await self._receive()
if msg.type != aiohttp.WSMsgType.TEXT:
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
if msg.type != aiohttp.WSMsgType.BINARY:
raise TypeError(
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
)
return msg.data
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
await self.websocket.send_str(data)
@overrides(BaseWebSocket)

View File

@ -11,7 +11,7 @@ FrontMatter:
import logging
from functools import wraps
from typing import Any, List, Tuple, Callable, Optional
from typing import Any, List, Tuple, Union, Callable, Optional
import uvicorn
from pydantic import BaseSettings
@ -36,6 +36,8 @@ def catch_closed(func):
return await func(*args, **kwargs)
except WebSocketDisconnect as e:
raise WebSocketClosed(e.code)
except KeyError:
raise TypeError("WebSocket received unexpected frame type")
return decorator
@ -261,9 +263,17 @@ class FastAPIWebSocket(BaseWebSocket):
) -> None:
await self.websocket.close(code)
@overrides(BaseWebSocket)
async def receive(self) -> Union[str, bytes]:
# assert self.websocket.application_state == WebSocketState.CONNECTED
msg = await self.websocket.receive()
if msg["type"] == "websocket.disconnect":
raise WebSocketClosed(msg["code"])
return msg["text"] if "text" in msg else msg["bytes"]
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
async def receive_text(self) -> str:
return await self.websocket.receive_text()
@overrides(BaseWebSocket)
@ -272,7 +282,7 @@ class FastAPIWebSocket(BaseWebSocket):
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
await self.websocket.send({"type": "websocket.send", "text": data})
@overrides(BaseWebSocket)

View File

@ -49,22 +49,22 @@ class Mixin(ForwardMixin):
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2,
proxies=setup.proxy,
proxies=setup.proxy, # type: ignore
follow_redirects=True,
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
content=setup.content, # type: ignore
data=setup.data, # type: ignore
json=setup.json,
files=setup.files,
files=setup.files, # type: ignore
headers=tuple(setup.headers.items()),
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)

View File

@ -17,7 +17,7 @@ FrontMatter:
import asyncio
from functools import wraps
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
from typing import List, Tuple, Union, TypeVar, Callable, Optional, Coroutine
import uvicorn
from pydantic import BaseSettings
@ -199,7 +199,7 @@ class Driver(ReverseDriver):
http_request = BaseRequest(
request.method,
request.url,
headers=request.headers.items(),
headers=list(request.headers.items()),
cookies=list(request.cookies.items()),
content=await request.get_data(
cache=False, as_text=False, parse_form_data=False
@ -224,7 +224,7 @@ class Driver(ReverseDriver):
http_request = BaseRequest(
websocket.method,
websocket.url,
headers=websocket.headers.items(),
headers=list(websocket.headers.items()),
cookies=list(websocket.cookies.items()),
version=websocket.http_version,
)
@ -257,7 +257,12 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
async def receive(self) -> Union[str, bytes]:
return await self.websocket.receive()
@overrides(BaseWebSocket)
@catch_closed
async def receive_text(self) -> str:
msg = await self.websocket.receive()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
@ -272,7 +277,7 @@ class WebSocket(BaseWebSocket):
return msg
@overrides(BaseWebSocket)
async def send(self, data: str):
async def send_text(self, data: str):
await self.websocket.send(data)
@overrides(BaseWebSocket)

View File

@ -16,8 +16,8 @@ FrontMatter:
"""
import logging
from functools import wraps
from typing import Type, AsyncGenerator
from contextlib import asynccontextmanager
from typing import Type, Union, AsyncGenerator
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
@ -46,9 +46,9 @@ def catch_closed(func):
return await func(*args, **kwargs)
except ConnectionClosed as e:
if e.rcvd_then_sent:
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
else:
raise WebSocketClosed(e.sent.code, e.sent.reason)
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore
return decorator
@ -100,7 +100,13 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> str:
async def receive(self) -> Union[str, bytes]:
msg = await self.websocket.recv()
return msg
@overrides(BaseWebSocket)
@catch_closed
async def receive_text(self) -> str:
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
@ -115,7 +121,7 @@ class WebSocket(BaseWebSocket):
return msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
await self.websocket.send(data)
@overrides(BaseWebSocket)

View File

@ -186,7 +186,12 @@ class WebSocket(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> str:
async def receive(self) -> Union[str, bytes]:
"""接收一条 WebSocket text/bytes 信息"""
raise NotImplementedError
@abc.abstractmethod
async def receive_text(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@ -195,8 +200,17 @@ class WebSocket(abc.ABC):
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
async def send(self, data: Union[str, bytes]) -> None:
"""发送一条 WebSocket text/bytes 信息"""
if isinstance(data, str):
await self.send_text(data)
elif isinstance(data, bytes):
await self.send_bytes(data)
else:
raise TypeError("WebSocker send method expects str or bytes!")
@abc.abstractmethod
async def send(self, data: str) -> None:
async def send_text(self, data: str) -> None:
"""发送一条 WebSocket text 信息"""
raise NotImplementedError

View File

@ -15,6 +15,7 @@ from nonebug import App
)
async def test_reverse_driver(app: App):
import nonebot
from nonebot.exception import WebSocketClosed
from nonebot.drivers import (
URL,
Request,
@ -36,7 +37,21 @@ async def test_reverse_driver(app: App):
data = await ws.receive()
assert data == "ping"
await ws.send("pong")
await ws.close()
data = await ws.receive()
assert data == b"ping"
await ws.send(b"pong")
data = await ws.receive_text()
assert data == "ping"
await ws.send("pong")
data = await ws.receive_bytes()
assert data == b"ping"
await ws.send(b"pong")
with pytest.raises(WebSocketClosed):
await ws.receive()
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
driver.setup_http_server(http_setup)
@ -53,3 +68,13 @@ async def test_reverse_driver(app: App):
async with client.websocket_connect("/ws_test") as ws:
await ws.send_text("ping")
assert await ws.receive_text() == "pong"
await ws.send_bytes(b"ping")
assert await ws.receive_bytes() == b"pong"
await ws.send_text("ping")
assert await ws.receive_text() == "pong"
await ws.send_bytes(b"ping")
assert await ws.receive_bytes() == b"pong"
await ws.close()