nonebot2/nonebot/drivers/websockets.py
2023-07-17 15:56:27 +08:00

138 lines
3.9 KiB
Python

"""[websockets](https://websockets.readthedocs.io/) 驱动适配
```bash
nb driver install websockets
# 或者
pip install nonebot2[websockets]
```
:::tip 提示
本驱动仅支持客户端 WebSocket 连接
:::
FrontMatter:
sidebar_position: 4
description: nonebot.drivers.websockets 模块
"""
import logging
from functools import wraps
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, override
from typing import Type, Union, TypeVar, Callable, Awaitable, AsyncGenerator
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
try:
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install websockets first to use this driver. "
"Install with pip: `pip install nonebot2[websockets]`"
) from e
# Type variables that let ``catch_closed`` preserve the wrapped coroutine's
# parameter list (``P``) and awaited result type (``T``).
T = TypeVar("T")
P = ParamSpec("P")

# Logger carrying the "websockets.client" name at INFO level, with nonebot's
# ``LoguruHandler`` attached.
# NOTE(review): this instantiates ``logging.Logger`` directly instead of going
# through ``logging.getLogger`` — confirm a stand-alone instance is intended.
logger = logging.Logger(name="websockets.client", level="INFO")
logger.addHandler(LoguruHandler())
def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    """Wrap a coroutine so websockets' ``ConnectionClosed`` becomes ``WebSocketClosed``.

    Args:
        func: Coroutine function that may raise ``ConnectionClosed``.

    Returns:
        A coroutine function with the same signature as ``func``.

    Raises:
        WebSocketClosed: Raised in place of ``ConnectionClosed``, carrying the
            close code and reason of the close frame that came first. When no
            close frame is available at all, code ``1006`` (abnormal closure)
            and an empty reason are used.
    """

    @wraps(func)
    async def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except ConnectionClosed as e:
            # ``rcvd_then_sent`` is only a real boolean when both close frames
            # exist. Either frame (or both) can be missing, so never
            # dereference one without checking it first.
            if e.rcvd_then_sent or e.sent is None:
                frame = e.rcvd
            else:
                frame = e.sent
            if frame is None:
                # Connection dropped without any close frame being exchanged.
                raise WebSocketClosed(1006, "") from e
            raise WebSocketClosed(frame.code, frame.reason) from e

    return decorator
class Mixin(ForwardMixin):
    """Forward mixin that opens client WebSocket connections via websockets."""

    @property
    @override
    def type(self) -> str:
        """Identifier string of this driver mixin."""
        return "websockets"

    @override
    async def request(self, setup: Request) -> Response:
        """Hand the HTTP request over to the parent implementation."""
        response = await super().request(setup)
        return response

    @override
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket connection described by ``setup``.

        The connection stays open for the duration of the ``async with`` body
        and is wrapped in this module's ``WebSocket`` class before being
        yielded.
        """
        target = str(setup.url)
        # Cookie-derived headers are unpacked last, so they win on a key clash.
        handshake_headers = {**setup.headers, **setup.cookies.as_header(setup)}
        async with Connect(
            target,
            extra_headers=handshake_headers,
            open_timeout=setup.timeout,
        ) as ws:
            yield WebSocket(request=setup, websocket=ws)
class WebSocket(BaseWebSocket):
    """Wrapper exposing a websockets client connection as a nonebot WebSocket."""

    @override
    def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
        super().__init__(request=request)
        # Underlying websockets protocol object that every call is forwarded to.
        self.websocket = websocket

    @property
    @override
    def closed(self) -> bool:
        """Whether the underlying connection reports itself as closed."""
        return self.websocket.closed

    @override
    async def accept(self):
        """Not implemented by this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @override
    async def close(self, code: int = 1000, reason: str = ""):
        """Close the underlying connection with the given code and reason."""
        await self.websocket.close(code, reason)

    @override
    @catch_closed
    async def receive(self) -> Union[str, bytes]:
        """Return the next message, whether it is text or binary."""
        payload = await self.websocket.recv()
        return payload

    @override
    @catch_closed
    async def receive_text(self) -> str:
        """Return the next message, raising ``TypeError`` if it is ``bytes``."""
        payload = await self.websocket.recv()
        if not isinstance(payload, bytes):
            return payload
        raise TypeError("WebSocket received unexpected frame type: bytes")

    @override
    @catch_closed
    async def receive_bytes(self) -> bytes:
        """Return the next message, raising ``TypeError`` if it is ``str``."""
        payload = await self.websocket.recv()
        if not isinstance(payload, str):
            return payload
        raise TypeError("WebSocket received unexpected frame type: str")

    @override
    async def send_text(self, data: str) -> None:
        """Send ``data`` over the underlying connection."""
        await self.websocket.send(data)

    @override
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` over the underlying connection."""
        await self.websocket.send(data)
# Public driver class exported by this module, produced by passing
# ``NoneDriver`` and the websockets ``Mixin`` to ``combine_driver``.
Driver: Type[ForwardDriver] = combine_driver(NoneDriver, Mixin) # type: ignore
"""Websockets Driver"""