nonebot2/nonebot/drivers/aiohttp.py

187 lines
5.4 KiB
Python
Raw Normal View History

2022-01-22 15:23:07 +08:00
"""[AIOHTTP](https://aiohttp.readthedocs.io/en/stable/) 驱动适配器。

```bash
nb driver install aiohttp
# 或者
pip install nonebot2[aiohttp]
```

:::tip 提示
本驱动仅支持客户端连接
:::

FrontMatter:
    sidebar_position: 2
    description: nonebot.drivers.aiohttp 模块
"""
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Dict, Union, AsyncGenerator

from typing_extensions import override
2021-12-22 16:53:55 +08:00
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (
HTTPVersion,
HTTPClientMixin,
WebSocketClientMixin,
combine_driver,
)
2021-06-21 01:22:33 +08:00
try:
import aiohttp
except ModuleNotFoundError as e: # pragma: no cover
raise ImportError(
"Please install aiohttp first to use this driver. "
"Install with pip: `pip install nonebot2[aiohttp]`"
) from e
2021-06-21 01:22:33 +08:00
class Mixin(HTTPClientMixin, WebSocketClientMixin):
    """AIOHTTP Mixin

    Implements the HTTP client and WebSocket client capabilities of a driver on
    top of `aiohttp.ClientSession`. A fresh session is created per request /
    per WebSocket connection.
    """

    @property
    @override
    def type(self) -> str:
        return "aiohttp"

    @staticmethod
    def _aiohttp_version(setup: Request) -> "aiohttp.HttpVersion":
        """Translate the request's `HTTPVersion` into aiohttp's version constant.

        Raises:
            RuntimeError: if the request asks for anything other than HTTP/1.0
                or HTTP/1.1 (the only versions this driver maps).
        """
        if setup.version == HTTPVersion.H10:
            return aiohttp.HttpVersion10
        if setup.version == HTTPVersion.H11:
            return aiohttp.HttpVersion11
        raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

    @staticmethod
    def _aiohttp_cookies(setup: Request) -> Dict[str, str]:
        """Flatten the request's cookies into a name -> value mapping for aiohttp.

        Only cookies whose value is `None` are skipped (presumably cookie-jar
        entries without a value — confirm against `Cookies`); empty-string
        values are legitimate cookie values and are kept.
        """
        return {
            cookie.name: cookie.value
            for cookie in setup.cookies
            if cookie.value is not None
        }

    @override
    async def request(self, setup: Request) -> Response:
        """Send an HTTP request and return the fully-read response.

        Raises:
            RuntimeError: for unsupported HTTP versions.
        """
        version = self._aiohttp_version(setup)
        # `ClientTimeout(None)` disables the total timeout.
        timeout = aiohttp.ClientTimeout(setup.timeout)

        data = setup.data
        if setup.files:
            # Multipart upload: start from the plain form fields (if any) and
            # append each file as (filename, content, content_type).
            data = aiohttp.FormData(data or {})
            for name, file in setup.files:
                data.add_field(name, file[1], content_type=file[2], filename=file[0])

        async with aiohttp.ClientSession(
            cookies=self._aiohttp_cookies(setup), version=version, trust_env=True
        ) as session:
            async with session.request(
                setup.method,
                setup.url,
                # Raw content takes precedence over form/multipart data.
                data=setup.content or data,
                json=setup.json,
                headers=setup.headers,
                timeout=timeout,
                proxy=setup.proxy,
            ) as response:
                return Response(
                    response.status,
                    headers=response.headers.copy(),
                    content=await response.read(),
                    request=setup,
                )

    @override
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket connection for the lifetime of the context.

        The request's cookies are sent with the handshake, consistent with
        `request`.

        Raises:
            RuntimeError: for unsupported HTTP versions.
        """
        version = self._aiohttp_version(setup)
        async with aiohttp.ClientSession(
            cookies=self._aiohttp_cookies(setup), version=version, trust_env=True
        ) as session:
            async with session.ws_connect(
                setup.url,
                method=setup.method,
                # Fall back to 10 seconds when the request has no timeout set.
                timeout=setup.timeout or 10,
                headers=setup.headers,
                proxy=setup.proxy,
            ) as ws:
                yield WebSocket(request=setup, session=session, websocket=ws)
2021-07-19 01:20:17 +08:00
2021-07-31 12:24:11 +08:00
2021-07-20 15:35:56 +08:00
class WebSocket(BaseWebSocket):
    """AIOHTTP Websocket Wrapper

    Wraps an `aiohttp.ClientWebSocketResponse` together with the
    `ClientSession` that owns it, so closing the WebSocket also closes the
    session.
    """

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ) -> None:
        super().__init__(request=request)
        self.session = session
        self.websocket = websocket

    @property
    @override
    def closed(self) -> bool:
        return self.websocket.closed

    @override
    async def accept(self):
        # Client-side connections are already established; nothing to accept.
        raise NotImplementedError

    @override
    async def close(self, code: int = 1000):
        """Close the WebSocket with `code`, then always release the session."""
        try:
            await self.websocket.close(code=code)
        finally:
            # Ensure the session is closed even if closing the socket fails.
            await self.session.close()

    async def _receive(self) -> aiohttp.WSMessage:
        """Receive one message, converting connection-termination messages.

        Raises:
            WebSocketClosed: when aiohttp reports a close frame, a closing or
                already-closed connection, or an error that aborted the
                connection. The close code falls back to 1006 (abnormal
                closure) when aiohttp has none.
        """
        msg = await self.websocket.receive()
        if msg.type == aiohttp.WSMsgType.ERROR:
            # For ERROR messages aiohttp puts the underlying exception in
            # `data`; chain it so the root cause is not lost.
            cause = msg.data if isinstance(msg.data, BaseException) else None
            raise WebSocketClosed(self.websocket.close_code or 1006) from cause
        if msg.type in (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        ):
            raise WebSocketClosed(self.websocket.close_code or 1006)
        return msg

    @staticmethod
    def _expect_type(msg: aiohttp.WSMessage, *expected: aiohttp.WSMsgType) -> None:
        """Raise `TypeError` unless `msg` is one of the `expected` frame types."""
        if msg.type not in expected:
            raise TypeError(
                f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
            )

    @override
    async def receive(self) -> Union[str, bytes]:
        """Receive a text (`str`) or binary (`bytes`) frame."""
        msg = await self._receive()
        self._expect_type(msg, aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY)
        return msg.data

    @override
    async def receive_text(self) -> str:
        """Receive a text frame; raise `TypeError` for any other frame type."""
        msg = await self._receive()
        self._expect_type(msg, aiohttp.WSMsgType.TEXT)
        return msg.data

    @override
    async def receive_bytes(self) -> bytes:
        """Receive a binary frame; raise `TypeError` for any other frame type."""
        msg = await self._receive()
        self._expect_type(msg, aiohttp.WSMsgType.BINARY)
        return msg.data

    @override
    async def send_text(self, data: str) -> None:
        await self.websocket.send_str(data)

    @override
    async def send_bytes(self, data: bytes) -> None:
        await self.websocket.send_bytes(data)
2021-12-22 16:53:55 +08:00
if not TYPE_CHECKING:
    # Runtime: assemble the concrete driver class from the base NoneDriver
    # plus the aiohttp client mixin.
    Driver = combine_driver(NoneDriver, Mixin)
else:
    # Static analysis only: an equivalent explicit class so type checkers can
    # see the combined interface of the mixin and NoneDriver.
    class Driver(Mixin, NoneDriver): ...
"""AIOHTTP Driver"""