nonebot2/nonebot/drivers/websockets.py

138 lines
3.9 KiB
Python
Raw Normal View History

2022-01-22 15:23:07 +08:00
"""[websockets](https://websockets.readthedocs.io/) 驱动适配
```bash
nb driver install websockets
# 或者
pip install nonebot2[websockets]
```
:::tip 提示
本驱动仅支持客户端 WebSocket 连接
:::
FrontMatter:
sidebar_position: 4
description: nonebot.drivers.websockets 模块
"""
2021-12-22 16:53:55 +08:00
import logging
from functools import wraps
from contextlib import asynccontextmanager
from typing_extensions import ParamSpec, override
from typing import Type, Union, TypeVar, Callable, Awaitable, AsyncGenerator
2021-12-22 16:53:55 +08:00
from nonebot.log import LoguruHandler
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers.none import Driver as NoneDriver
2021-12-22 16:53:55 +08:00
from nonebot.drivers import WebSocket as BaseWebSocket
2022-01-22 15:23:07 +08:00
from nonebot.drivers import ForwardMixin, ForwardDriver, combine_driver
2021-12-22 16:53:55 +08:00
try:
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import Connect, WebSocketClientProtocol
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install websockets first to use this driver. "
"Install with pip: `pip install nonebot2[websockets]`"
) from e
# Generic placeholders used by `catch_closed` to preserve the wrapped
# coroutine's return type (T) and call signature (P).
T = TypeVar("T")
P = ParamSpec("P")
2021-12-22 16:53:55 +08:00
# Module-level logger named "websockets.client" at INFO level, with nonebot's
# LoguruHandler attached.
# NOTE(review): this instantiates logging.Logger directly instead of calling
# logging.getLogger("websockets.client"), so the object is not registered in
# the logging hierarchy and is not passed to `Connect` anywhere in this file —
# confirm that this is intended.
logger = logging.Logger("websockets.client", "INFO")
logger.addHandler(LoguruHandler())
def catch_closed(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    """Translate ``websockets`` connection closures into ``WebSocketClosed``.

    Args:
        func: Coroutine function that may raise ``ConnectionClosed``.

    Returns:
        A coroutine function with the same signature as ``func`` that raises
        ``WebSocketClosed`` whenever ``func`` raises ``ConnectionClosed``.

    Raises:
        WebSocketClosed: Carries the code and reason of the relevant close
            frame, or ``1006`` with an empty reason when no close frame was
            exchanged at all.
    """

    @wraps(func)
    async def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except ConnectionClosed as e:
            # Keep the original preference: the received frame when the peer
            # initiated the close, otherwise the frame we sent.
            if e.rcvd_then_sent:
                preferred, other = e.rcvd, e.sent
            else:
                preferred, other = e.sent, e.rcvd
            # Either frame may be None (e.g. the TCP connection dropped or
            # only one side sent a close frame); dereferencing it blindly
            # would raise AttributeError instead of WebSocketClosed.
            frame = preferred if preferred is not None else other
            if frame is None:
                # No close frame at all: report an abnormal closure (1006).
                raise WebSocketClosed(1006, "") from e
            raise WebSocketClosed(frame.code, frame.reason) from e

    return decorator
class Mixin(ForwardMixin):
    """Forward-driver mixin that opens client WebSocket connections.

    Connections are created with the ``websockets`` legacy client
    (``Connect``); HTTP requests are delegated to the next class in the MRO.
    """

    @property
    @override
    def type(self) -> str:
        """Identifier of this driver implementation."""
        return "websockets"

    @override
    async def request(self, setup: Request) -> Response:
        """Forward the HTTP request to the parent implementation."""
        return await super().request(setup)

    @override
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket connection described by ``setup``.

        The URL, headers, cookies and timeout are taken from ``setup``; on key
        collisions the entries produced by ``setup.cookies.as_header(setup)``
        take precedence over ``setup.headers``. Yields a ``WebSocket`` wrapper
        around the established connection.
        """
        url = str(setup.url)
        headers = {**setup.headers, **setup.cookies.as_header(setup)}
        client = Connect(url, extra_headers=headers, open_timeout=setup.timeout)
        async with client as protocol:
            yield WebSocket(request=setup, websocket=protocol)
2021-12-22 16:53:55 +08:00
class WebSocket(BaseWebSocket):
    """Adapter exposing a ``websockets`` client protocol as a nonebot WebSocket."""

    @override
    def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
        """Store the underlying client protocol alongside the originating request."""
        super().__init__(request=request)
        self.websocket = websocket

    @property
    @override
    def closed(self) -> bool:
        """Whether the underlying connection reports itself as closed."""
        return self.websocket.closed

    @override
    async def accept(self):
        """Unsupported: this driver only provides client-side connections."""
        raise NotImplementedError

    @override
    async def close(self, code: int = 1000, reason: str = ""):
        """Close the underlying connection with the given code and reason."""
        await self.websocket.close(code, reason)

    @override
    @catch_closed
    async def receive(self) -> Union[str, bytes]:
        """Receive the next message, whatever its frame type."""
        return await self.websocket.recv()

    @override
    @catch_closed
    async def receive_text(self) -> str:
        """Receive the next message, requiring it not to be a bytes frame.

        Raises:
            TypeError: If the received message is ``bytes``.
        """
        frame = await self.websocket.recv()
        if not isinstance(frame, bytes):
            return frame
        raise TypeError("WebSocket received unexpected frame type: bytes")

    @override
    @catch_closed
    async def receive_bytes(self) -> bytes:
        """Receive the next message, requiring it not to be a text frame.

        Raises:
            TypeError: If the received message is ``str``.
        """
        frame = await self.websocket.recv()
        if not isinstance(frame, str):
            return frame
        raise TypeError("WebSocket received unexpected frame type: str")

    @override
    async def send_text(self, data: str) -> None:
        """Send ``data`` over the underlying connection as text."""
        await self.websocket.send(data)

    @override
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` over the underlying connection as bytes."""
        await self.websocket.send(data)
# Public driver class: the result of combining NoneDriver with this module's
# Mixin through combine_driver.
Driver: Type[ForwardDriver] = combine_driver(NoneDriver, Mixin) # type: ignore
2022-01-22 15:23:07 +08:00
"""Websockets Driver"""