2022-01-22 15:23:07 +08:00
|
|
|
"""[websockets](https://websockets.readthedocs.io/) 驱动适配
|
|
|
|
|
|
|
|
```bash
nb driver install websockets
# 或者
pip install nonebot2[websockets]
```

:::tip 提示
本驱动仅支持客户端 WebSocket 连接
:::

FrontMatter:
    mdx:
        format: md
    sidebar_position: 4
    description: nonebot.drivers.websockets 模块
"""
|
2023-06-24 14:47:35 +08:00
|
|
|
|
2021-12-22 16:53:55 +08:00
|
|
|
import logging
|
2021-12-26 14:20:09 +08:00
|
|
|
from functools import wraps
|
2021-12-26 13:42:13 +08:00
|
|
|
from contextlib import asynccontextmanager
|
2023-07-17 15:56:27 +08:00
|
|
|
from typing_extensions import ParamSpec, override
|
2024-04-16 00:33:48 +08:00
|
|
|
from collections.abc import Coroutine, AsyncGenerator
|
|
|
|
from typing import TYPE_CHECKING, Any, Union, TypeVar, Callable
|
2021-12-22 16:53:55 +08:00
|
|
|
|
2023-08-26 11:03:24 +08:00
|
|
|
from nonebot.drivers import Request
|
2021-12-22 16:53:55 +08:00
|
|
|
from nonebot.log import LoguruHandler
|
2021-12-26 14:20:09 +08:00
|
|
|
from nonebot.exception import WebSocketClosed
|
2023-01-01 15:08:00 +08:00
|
|
|
from nonebot.drivers.none import Driver as NoneDriver
|
2021-12-22 16:53:55 +08:00
|
|
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
2023-08-26 11:03:24 +08:00
|
|
|
from nonebot.drivers import WebSocketClientMixin, combine_driver
|
2021-12-22 16:53:55 +08:00
|
|
|
|
2021-12-23 17:20:26 +08:00
|
|
|
try:
|
2021-12-26 14:20:09 +08:00
|
|
|
from websockets.exceptions import ConnectionClosed
|
2021-12-23 17:20:26 +08:00
|
|
|
from websockets.legacy.client import Connect, WebSocketClientProtocol
|
2023-03-29 15:59:54 +08:00
|
|
|
except ModuleNotFoundError as e: # pragma: no cover
|
2021-12-23 17:20:26 +08:00
|
|
|
raise ImportError(
|
2023-06-24 14:47:35 +08:00
|
|
|
"Please install websockets first to use this driver. "
|
|
|
|
"Install with pip: `pip install nonebot2[websockets]`"
|
2023-02-09 10:24:27 +08:00
|
|
|
) from e
|
2021-12-23 17:20:26 +08:00
|
|
|
|
2023-06-24 14:47:35 +08:00
|
|
|
# Type variables used to describe the signature preserved by ``catch_closed``.
T = TypeVar("T")
P = ParamSpec("P")

# Logger for this driver module. It is constructed directly (not through
# ``logging.getLogger``), so it is a standalone object that is not registered
# in the logging manager under "websockets.client". nonebot's LoguruHandler is
# attached (presumably bridging stdlib logging records into loguru).
logger = logging.Logger(name="websockets.client", level="INFO")
logger.addHandler(LoguruHandler())
|
|
|
|
|
|
|
|
|
2024-04-16 00:33:48 +08:00
|
|
|
def catch_closed(
    func: Callable[P, Coroutine[Any, Any, T]]
) -> Callable[P, Coroutine[Any, Any, T]]:
    """Translate websockets' connection-closed errors into nonebot's.

    Wraps a coroutine function so that a
    ``websockets.exceptions.ConnectionClosed`` raised while awaiting it is
    re-raised as :class:`nonebot.exception.WebSocketClosed` with the same close
    code and reason. The return value and any other exception pass through
    unchanged.

    Args:
        func: Coroutine function to wrap.

    Returns:
        A coroutine function with the same signature and metadata as ``func``.
    """

    @wraps(func)
    async def decorator(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            return await func(*args, **kwargs)
        except ConnectionClosed as e:
            # NOTE(review): newer websockets releases deprecate
            # ``ConnectionClosed.code``/``.reason`` in favour of ``rcvd``/``sent``;
            # confirm against the supported websockets version.
            # Chain explicitly so the original library error is kept as __cause__.
            raise WebSocketClosed(e.code, e.reason) from e

    return decorator
|
|
|
|
|
|
|
|
|
2023-08-26 11:03:24 +08:00
|
|
|
class Mixin(WebSocketClientMixin):
    """Websockets Mixin

    Client-side WebSocket support built on the ``websockets`` library's
    legacy ``Connect`` API.
    """

    @property
    @override
    def type(self) -> str:
        """Return the name identifying this driver mixin."""
        return "websockets"

    @override
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket connection described by ``setup``.

        The connection lives for the duration of the ``async with`` block
        around ``Connect``; the wrapper is yielded from inside that block.

        Args:
            setup: Request providing the URL, headers, cookies and timeout.
                A proxy, if set, is ignored.

        Yields:
            A :class:`WebSocket` wrapping the opened client protocol.
        """
        if setup.proxy is not None:
            # The proxy is never forwarded to ``Connect``; only warn about it.
            logger.warning("proxy is not supported by websockets driver")

        target = str(setup.url)
        # Cookie header(s) are unpacked last, so they override any same-named
        # key already present in ``setup.headers``.
        handshake_headers = {**setup.headers, **setup.cookies.as_header(setup)}

        async with Connect(
            target,
            extra_headers=handshake_headers,
            open_timeout=setup.timeout,
        ) as ws:
            yield WebSocket(request=setup, websocket=ws)
|
2021-12-22 16:53:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
class WebSocket(BaseWebSocket):
    """Websockets WebSocket Wrapper

    Adapts a ``websockets`` legacy ``WebSocketClientProtocol`` to nonebot's
    WebSocket interface. Receive and send operations translate the library's
    ``ConnectionClosed`` into nonebot's ``WebSocketClosed`` via
    ``catch_closed``, so callers only need to handle the driver-agnostic
    exception.
    """

    @override
    def __init__(self, *, request: Request, websocket: WebSocketClientProtocol):
        """Store the originating request and the underlying protocol object.

        Args:
            request: The request that opened this connection.
            websocket: The connected websockets client protocol.
        """
        super().__init__(request=request)
        self.websocket = websocket

    @property
    @override
    def closed(self) -> bool:
        """Whether the underlying websockets connection reports itself closed."""
        return self.websocket.closed

    @override
    async def accept(self):
        """Not supported: this wrapper only represents client connections.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @override
    async def close(self, code: int = 1000, reason: str = ""):
        """Close the connection with the given close code and reason."""
        await self.websocket.close(code, reason)

    @override
    @catch_closed
    async def receive(self) -> Union[str, bytes]:
        """Receive the next message, either text (``str``) or binary (``bytes``).

        Raises:
            WebSocketClosed: If the connection is closed.
        """
        return await self.websocket.recv()

    @override
    @catch_closed
    async def receive_text(self) -> str:
        """Receive the next message, which must be a text message.

        Raises:
            TypeError: If a binary message is received instead.
            WebSocketClosed: If the connection is closed.
        """
        msg = await self.websocket.recv()
        if isinstance(msg, bytes):
            raise TypeError("WebSocket received unexpected frame type: bytes")
        return msg

    @override
    @catch_closed
    async def receive_bytes(self) -> bytes:
        """Receive the next message, which must be a binary message.

        Raises:
            TypeError: If a text message is received instead.
            WebSocketClosed: If the connection is closed.
        """
        msg = await self.websocket.recv()
        if isinstance(msg, str):
            raise TypeError("WebSocket received unexpected frame type: str")
        return msg

    # Send methods are wrapped like the receive methods so that a send on a
    # closed connection raises nonebot's WebSocketClosed rather than leaking
    # the websockets-specific ConnectionClosed.
    @override
    @catch_closed
    async def send_text(self, data: str) -> None:
        """Send a text message.

        Raises:
            WebSocketClosed: If the connection is closed.
        """
        await self.websocket.send(data)

    @override
    @catch_closed
    async def send_bytes(self, data: bytes) -> None:
        """Send a binary message.

        Raises:
            WebSocketClosed: If the connection is closed.
        """
        await self.websocket.send(data)
|
|
|
|
|
|
|
|
|
2023-08-26 11:03:24 +08:00
|
|
|
# Public ``Driver`` combining the websockets client ``Mixin`` with ``NoneDriver``.
if TYPE_CHECKING:
    # Type checkers only see this explicit subclass of Mixin and NoneDriver.

    class Driver(Mixin, NoneDriver): ...

else:
    # At runtime the class is produced by ``combine_driver``; presumably it is
    # equivalent to the declaration above — confirm against its definition.
    Driver = combine_driver(NoneDriver, Mixin)
    """Websockets Driver"""
|