mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-01-19 01:18:19 +08:00
✨ Feat: 支持 WebSocket 连接同时获取 str 或 bytes (#962)
This commit is contained in:
parent
91c5056c97
commit
56f99b7f0b
@ -132,20 +132,33 @@ class WebSocket(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> str:
|
||||
msg = await self._receive()
|
||||
if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||
raise TypeError(
|
||||
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
|
||||
)
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive_text(self) -> str:
|
||||
msg = await self._receive()
|
||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
||||
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
|
||||
raise TypeError(
|
||||
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
|
||||
)
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive_bytes(self) -> bytes:
|
||||
msg = await self._receive()
|
||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
||||
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
|
||||
if msg.type != aiohttp.WSMsgType.BINARY:
|
||||
raise TypeError(
|
||||
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
|
||||
)
|
||||
return msg.data
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.websocket.send_str(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
|
@ -11,7 +11,7 @@ FrontMatter:
|
||||
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, List, Tuple, Callable, Optional
|
||||
from typing import Any, List, Tuple, Union, Callable, Optional
|
||||
|
||||
import uvicorn
|
||||
from pydantic import BaseSettings
|
||||
@ -36,6 +36,8 @@ def catch_closed(func):
|
||||
return await func(*args, **kwargs)
|
||||
except WebSocketDisconnect as e:
|
||||
raise WebSocketClosed(e.code)
|
||||
except KeyError:
|
||||
raise TypeError("WebSocket received unexpected frame type")
|
||||
|
||||
return decorator
|
||||
|
||||
@ -261,9 +263,17 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
) -> None:
|
||||
await self.websocket.close(code)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
# assert self.websocket.application_state == WebSocketState.CONNECTED
|
||||
msg = await self.websocket.receive()
|
||||
if msg["type"] == "websocket.disconnect":
|
||||
raise WebSocketClosed(msg["code"])
|
||||
return msg["text"] if "text" in msg else msg["bytes"]
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> str:
|
||||
async def receive_text(self) -> str:
|
||||
return await self.websocket.receive_text()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@ -272,7 +282,7 @@ class FastAPIWebSocket(BaseWebSocket):
|
||||
return await self.websocket.receive_bytes()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.websocket.send({"type": "websocket.send", "text": data})
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
|
@ -49,22 +49,22 @@ class Mixin(ForwardMixin):
|
||||
async def request(self, setup: Request) -> Response:
|
||||
async with httpx.AsyncClient(
|
||||
http2=setup.version == HTTPVersion.H2,
|
||||
proxies=setup.proxy,
|
||||
proxies=setup.proxy, # type: ignore
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
response = await client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
content=setup.content,
|
||||
data=setup.data,
|
||||
content=setup.content, # type: ignore
|
||||
data=setup.data, # type: ignore
|
||||
json=setup.json,
|
||||
files=setup.files,
|
||||
files=setup.files, # type: ignore
|
||||
headers=tuple(setup.headers.items()),
|
||||
timeout=setup.timeout,
|
||||
)
|
||||
return Response(
|
||||
response.status_code,
|
||||
headers=response.headers,
|
||||
headers=response.headers.multi_items(),
|
||||
content=response.content,
|
||||
request=setup,
|
||||
)
|
||||
|
@ -17,7 +17,7 @@ FrontMatter:
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
|
||||
from typing import List, Tuple, Union, TypeVar, Callable, Optional, Coroutine
|
||||
|
||||
import uvicorn
|
||||
from pydantic import BaseSettings
|
||||
@ -199,7 +199,7 @@ class Driver(ReverseDriver):
|
||||
http_request = BaseRequest(
|
||||
request.method,
|
||||
request.url,
|
||||
headers=request.headers.items(),
|
||||
headers=list(request.headers.items()),
|
||||
cookies=list(request.cookies.items()),
|
||||
content=await request.get_data(
|
||||
cache=False, as_text=False, parse_form_data=False
|
||||
@ -224,7 +224,7 @@ class Driver(ReverseDriver):
|
||||
http_request = BaseRequest(
|
||||
websocket.method,
|
||||
websocket.url,
|
||||
headers=websocket.headers.items(),
|
||||
headers=list(websocket.headers.items()),
|
||||
cookies=list(websocket.cookies.items()),
|
||||
version=websocket.http_version,
|
||||
)
|
||||
@ -257,7 +257,12 @@ class WebSocket(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> str:
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
return await self.websocket.receive()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive_text(self) -> str:
|
||||
msg = await self.websocket.receive()
|
||||
if isinstance(msg, bytes):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
@ -272,7 +277,7 @@ class WebSocket(BaseWebSocket):
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str):
|
||||
async def send_text(self, data: str):
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
|
@ -16,8 +16,8 @@ FrontMatter:
|
||||
"""
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Type, AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Type, Union, AsyncGenerator
|
||||
|
||||
from nonebot.typing import overrides
|
||||
from nonebot.log import LoguruHandler
|
||||
@ -46,9 +46,9 @@ def catch_closed(func):
|
||||
return await func(*args, **kwargs)
|
||||
except ConnectionClosed as e:
|
||||
if e.rcvd_then_sent:
|
||||
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
|
||||
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
|
||||
else:
|
||||
raise WebSocketClosed(e.sent.code, e.sent.reason)
|
||||
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore
|
||||
|
||||
return decorator
|
||||
|
||||
@ -100,7 +100,13 @@ class WebSocket(BaseWebSocket):
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> str:
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
msg = await self.websocket.recv()
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive_text(self) -> str:
|
||||
msg = await self.websocket.recv()
|
||||
if isinstance(msg, bytes):
|
||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||
@ -115,7 +121,7 @@ class WebSocket(BaseWebSocket):
|
||||
return msg
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
async def send(self, data: str) -> None:
|
||||
async def send_text(self, data: str) -> None:
|
||||
await self.websocket.send(data)
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
|
@ -186,7 +186,12 @@ class WebSocket(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def receive(self) -> str:
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
"""接收一条 WebSocket text/bytes 信息"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def receive_text(self) -> str:
|
||||
"""接收一条 WebSocket text 信息"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -195,8 +200,17 @@ class WebSocket(abc.ABC):
|
||||
"""接收一条 WebSocket binary 信息"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def send(self, data: Union[str, bytes]) -> None:
|
||||
"""发送一条 WebSocket text/bytes 信息"""
|
||||
if isinstance(data, str):
|
||||
await self.send_text(data)
|
||||
elif isinstance(data, bytes):
|
||||
await self.send_bytes(data)
|
||||
else:
|
||||
raise TypeError("WebSocker send method expects str or bytes!")
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send(self, data: str) -> None:
|
||||
async def send_text(self, data: str) -> None:
|
||||
"""发送一条 WebSocket text 信息"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -15,6 +15,7 @@ from nonebug import App
|
||||
)
|
||||
async def test_reverse_driver(app: App):
|
||||
import nonebot
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers import (
|
||||
URL,
|
||||
Request,
|
||||
@ -36,7 +37,21 @@ async def test_reverse_driver(app: App):
|
||||
data = await ws.receive()
|
||||
assert data == "ping"
|
||||
await ws.send("pong")
|
||||
await ws.close()
|
||||
|
||||
data = await ws.receive()
|
||||
assert data == b"ping"
|
||||
await ws.send(b"pong")
|
||||
|
||||
data = await ws.receive_text()
|
||||
assert data == "ping"
|
||||
await ws.send("pong")
|
||||
|
||||
data = await ws.receive_bytes()
|
||||
assert data == b"ping"
|
||||
await ws.send(b"pong")
|
||||
|
||||
with pytest.raises(WebSocketClosed):
|
||||
await ws.receive()
|
||||
|
||||
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
|
||||
driver.setup_http_server(http_setup)
|
||||
@ -53,3 +68,13 @@ async def test_reverse_driver(app: App):
|
||||
async with client.websocket_connect("/ws_test") as ws:
|
||||
await ws.send_text("ping")
|
||||
assert await ws.receive_text() == "pong"
|
||||
await ws.send_bytes(b"ping")
|
||||
assert await ws.receive_bytes() == b"pong"
|
||||
|
||||
await ws.send_text("ping")
|
||||
assert await ws.receive_text() == "pong"
|
||||
|
||||
await ws.send_bytes(b"ping")
|
||||
assert await ws.receive_bytes() == b"pong"
|
||||
|
||||
await ws.close()
|
||||
|
Loading…
Reference in New Issue
Block a user