mirror of
https://github.com/nonebot/nonebot2.git
synced 2025-02-07 19:36:56 +08:00
✨ Feat: 支持 WebSocket 连接同时获取 str 或 bytes (#962)
This commit is contained in:
parent
91c5056c97
commit
56f99b7f0b
@ -132,20 +132,33 @@ class WebSocket(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> str:
|
||||||
|
msg = await self._receive()
|
||||||
|
if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
|
||||||
|
raise TypeError(
|
||||||
|
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
|
||||||
|
)
|
||||||
|
return msg.data
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def receive_text(self) -> str:
|
||||||
msg = await self._receive()
|
msg = await self._receive()
|
||||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
if msg.type != aiohttp.WSMsgType.TEXT:
|
||||||
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
|
raise TypeError(
|
||||||
|
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
|
||||||
|
)
|
||||||
return msg.data
|
return msg.data
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def receive_bytes(self) -> bytes:
|
async def receive_bytes(self) -> bytes:
|
||||||
msg = await self._receive()
|
msg = await self._receive()
|
||||||
if msg.type != aiohttp.WSMsgType.TEXT:
|
if msg.type != aiohttp.WSMsgType.BINARY:
|
||||||
raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
|
raise TypeError(
|
||||||
|
f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
|
||||||
|
)
|
||||||
return msg.data
|
return msg.data
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: str) -> None:
|
async def send_text(self, data: str) -> None:
|
||||||
await self.websocket.send_str(data)
|
await self.websocket.send_str(data)
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
|
@ -11,7 +11,7 @@ FrontMatter:
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, List, Tuple, Callable, Optional
|
from typing import Any, List, Tuple, Union, Callable, Optional
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
@ -36,6 +36,8 @@ def catch_closed(func):
|
|||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
except WebSocketDisconnect as e:
|
except WebSocketDisconnect as e:
|
||||||
raise WebSocketClosed(e.code)
|
raise WebSocketClosed(e.code)
|
||||||
|
except KeyError:
|
||||||
|
raise TypeError("WebSocket received unexpected frame type")
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@ -261,9 +263,17 @@ class FastAPIWebSocket(BaseWebSocket):
|
|||||||
) -> None:
|
) -> None:
|
||||||
await self.websocket.close(code)
|
await self.websocket.close(code)
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
async def receive(self) -> Union[str, bytes]:
|
||||||
|
# assert self.websocket.application_state == WebSocketState.CONNECTED
|
||||||
|
msg = await self.websocket.receive()
|
||||||
|
if msg["type"] == "websocket.disconnect":
|
||||||
|
raise WebSocketClosed(msg["code"])
|
||||||
|
return msg["text"] if "text" in msg else msg["bytes"]
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@catch_closed
|
@catch_closed
|
||||||
async def receive(self) -> str:
|
async def receive_text(self) -> str:
|
||||||
return await self.websocket.receive_text()
|
return await self.websocket.receive_text()
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@ -272,7 +282,7 @@ class FastAPIWebSocket(BaseWebSocket):
|
|||||||
return await self.websocket.receive_bytes()
|
return await self.websocket.receive_bytes()
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: str) -> None:
|
async def send_text(self, data: str) -> None:
|
||||||
await self.websocket.send({"type": "websocket.send", "text": data})
|
await self.websocket.send({"type": "websocket.send", "text": data})
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
|
@ -49,22 +49,22 @@ class Mixin(ForwardMixin):
|
|||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
http2=setup.version == HTTPVersion.H2,
|
http2=setup.version == HTTPVersion.H2,
|
||||||
proxies=setup.proxy,
|
proxies=setup.proxy, # type: ignore
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
) as client:
|
) as client:
|
||||||
response = await client.request(
|
response = await client.request(
|
||||||
setup.method,
|
setup.method,
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
content=setup.content,
|
content=setup.content, # type: ignore
|
||||||
data=setup.data,
|
data=setup.data, # type: ignore
|
||||||
json=setup.json,
|
json=setup.json,
|
||||||
files=setup.files,
|
files=setup.files, # type: ignore
|
||||||
headers=tuple(setup.headers.items()),
|
headers=tuple(setup.headers.items()),
|
||||||
timeout=setup.timeout,
|
timeout=setup.timeout,
|
||||||
)
|
)
|
||||||
return Response(
|
return Response(
|
||||||
response.status_code,
|
response.status_code,
|
||||||
headers=response.headers,
|
headers=response.headers.multi_items(),
|
||||||
content=response.content,
|
content=response.content,
|
||||||
request=setup,
|
request=setup,
|
||||||
)
|
)
|
||||||
|
@ -17,7 +17,7 @@ FrontMatter:
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import List, Tuple, TypeVar, Callable, Optional, Coroutine
|
from typing import List, Tuple, Union, TypeVar, Callable, Optional, Coroutine
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
@ -199,7 +199,7 @@ class Driver(ReverseDriver):
|
|||||||
http_request = BaseRequest(
|
http_request = BaseRequest(
|
||||||
request.method,
|
request.method,
|
||||||
request.url,
|
request.url,
|
||||||
headers=request.headers.items(),
|
headers=list(request.headers.items()),
|
||||||
cookies=list(request.cookies.items()),
|
cookies=list(request.cookies.items()),
|
||||||
content=await request.get_data(
|
content=await request.get_data(
|
||||||
cache=False, as_text=False, parse_form_data=False
|
cache=False, as_text=False, parse_form_data=False
|
||||||
@ -224,7 +224,7 @@ class Driver(ReverseDriver):
|
|||||||
http_request = BaseRequest(
|
http_request = BaseRequest(
|
||||||
websocket.method,
|
websocket.method,
|
||||||
websocket.url,
|
websocket.url,
|
||||||
headers=websocket.headers.items(),
|
headers=list(websocket.headers.items()),
|
||||||
cookies=list(websocket.cookies.items()),
|
cookies=list(websocket.cookies.items()),
|
||||||
version=websocket.http_version,
|
version=websocket.http_version,
|
||||||
)
|
)
|
||||||
@ -257,7 +257,12 @@ class WebSocket(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@catch_closed
|
@catch_closed
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> Union[str, bytes]:
|
||||||
|
return await self.websocket.receive()
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
@catch_closed
|
||||||
|
async def receive_text(self) -> str:
|
||||||
msg = await self.websocket.receive()
|
msg = await self.websocket.receive()
|
||||||
if isinstance(msg, bytes):
|
if isinstance(msg, bytes):
|
||||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||||
@ -272,7 +277,7 @@ class WebSocket(BaseWebSocket):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: str):
|
async def send_text(self, data: str):
|
||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
|
@ -16,8 +16,8 @@ FrontMatter:
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Type, AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Type, Union, AsyncGenerator
|
||||||
|
|
||||||
from nonebot.typing import overrides
|
from nonebot.typing import overrides
|
||||||
from nonebot.log import LoguruHandler
|
from nonebot.log import LoguruHandler
|
||||||
@ -46,9 +46,9 @@ def catch_closed(func):
|
|||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
except ConnectionClosed as e:
|
except ConnectionClosed as e:
|
||||||
if e.rcvd_then_sent:
|
if e.rcvd_then_sent:
|
||||||
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason)
|
raise WebSocketClosed(e.rcvd.code, e.rcvd.reason) # type: ignore
|
||||||
else:
|
else:
|
||||||
raise WebSocketClosed(e.sent.code, e.sent.reason)
|
raise WebSocketClosed(e.sent.code, e.sent.reason) # type: ignore
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@ -100,7 +100,13 @@ class WebSocket(BaseWebSocket):
|
|||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@catch_closed
|
@catch_closed
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> Union[str, bytes]:
|
||||||
|
msg = await self.websocket.recv()
|
||||||
|
return msg
|
||||||
|
|
||||||
|
@overrides(BaseWebSocket)
|
||||||
|
@catch_closed
|
||||||
|
async def receive_text(self) -> str:
|
||||||
msg = await self.websocket.recv()
|
msg = await self.websocket.recv()
|
||||||
if isinstance(msg, bytes):
|
if isinstance(msg, bytes):
|
||||||
raise TypeError("WebSocket received unexpected frame type: bytes")
|
raise TypeError("WebSocket received unexpected frame type: bytes")
|
||||||
@ -115,7 +121,7 @@ class WebSocket(BaseWebSocket):
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
async def send(self, data: str) -> None:
|
async def send_text(self, data: str) -> None:
|
||||||
await self.websocket.send(data)
|
await self.websocket.send(data)
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
|
@ -186,7 +186,12 @@ class WebSocket(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def receive(self) -> str:
|
async def receive(self) -> Union[str, bytes]:
|
||||||
|
"""接收一条 WebSocket text/bytes 信息"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def receive_text(self) -> str:
|
||||||
"""接收一条 WebSocket text 信息"""
|
"""接收一条 WebSocket text 信息"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -195,8 +200,17 @@ class WebSocket(abc.ABC):
|
|||||||
"""接收一条 WebSocket binary 信息"""
|
"""接收一条 WebSocket binary 信息"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def send(self, data: Union[str, bytes]) -> None:
|
||||||
|
"""发送一条 WebSocket text/bytes 信息"""
|
||||||
|
if isinstance(data, str):
|
||||||
|
await self.send_text(data)
|
||||||
|
elif isinstance(data, bytes):
|
||||||
|
await self.send_bytes(data)
|
||||||
|
else:
|
||||||
|
raise TypeError("WebSocker send method expects str or bytes!")
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def send(self, data: str) -> None:
|
async def send_text(self, data: str) -> None:
|
||||||
"""发送一条 WebSocket text 信息"""
|
"""发送一条 WebSocket text 信息"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ from nonebug import App
|
|||||||
)
|
)
|
||||||
async def test_reverse_driver(app: App):
|
async def test_reverse_driver(app: App):
|
||||||
import nonebot
|
import nonebot
|
||||||
|
from nonebot.exception import WebSocketClosed
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
URL,
|
URL,
|
||||||
Request,
|
Request,
|
||||||
@ -36,7 +37,21 @@ async def test_reverse_driver(app: App):
|
|||||||
data = await ws.receive()
|
data = await ws.receive()
|
||||||
assert data == "ping"
|
assert data == "ping"
|
||||||
await ws.send("pong")
|
await ws.send("pong")
|
||||||
await ws.close()
|
|
||||||
|
data = await ws.receive()
|
||||||
|
assert data == b"ping"
|
||||||
|
await ws.send(b"pong")
|
||||||
|
|
||||||
|
data = await ws.receive_text()
|
||||||
|
assert data == "ping"
|
||||||
|
await ws.send("pong")
|
||||||
|
|
||||||
|
data = await ws.receive_bytes()
|
||||||
|
assert data == b"ping"
|
||||||
|
await ws.send(b"pong")
|
||||||
|
|
||||||
|
with pytest.raises(WebSocketClosed):
|
||||||
|
await ws.receive()
|
||||||
|
|
||||||
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
|
http_setup = HTTPServerSetup(URL("/http_test"), "POST", "http_test", _handle_http)
|
||||||
driver.setup_http_server(http_setup)
|
driver.setup_http_server(http_setup)
|
||||||
@ -53,3 +68,13 @@ async def test_reverse_driver(app: App):
|
|||||||
async with client.websocket_connect("/ws_test") as ws:
|
async with client.websocket_connect("/ws_test") as ws:
|
||||||
await ws.send_text("ping")
|
await ws.send_text("ping")
|
||||||
assert await ws.receive_text() == "pong"
|
assert await ws.receive_text() == "pong"
|
||||||
|
await ws.send_bytes(b"ping")
|
||||||
|
assert await ws.receive_bytes() == b"pong"
|
||||||
|
|
||||||
|
await ws.send_text("ping")
|
||||||
|
assert await ws.receive_text() == "pong"
|
||||||
|
|
||||||
|
await ws.send_bytes(b"ping")
|
||||||
|
assert await ws.receive_bytes() == b"pong"
|
||||||
|
|
||||||
|
await ws.close()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user