nonebot2/nonebot/drivers/aiohttp.py

277 lines
7.9 KiB
Python
Raw Normal View History

2022-01-22 15:23:07 +08:00
"""[AIOHTTP](https://aiohttp.readthedocs.io/en/stable/) 驱动适配器。
```bash
nb driver install aiohttp
# 或者
pip install nonebot2[aiohttp]
```
2021-07-31 12:24:11 +08:00
2022-01-22 15:23:07 +08:00
:::tip 提示
2021-07-31 12:24:11 +08:00
本驱动仅支持客户端连接
2022-01-22 15:23:07 +08:00
:::
FrontMatter:
sidebar_position: 2
description: nonebot.drivers.aiohttp 模块
2021-06-21 01:22:33 +08:00
"""
from typing_extensions import override
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator
from multidict import CIMultiDict
from nonebot.exception import WebSocketClosed
from nonebot.drivers import URL, Request, Response
from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import (
HTTPVersion,
HTTPClientMixin,
HTTPClientSession,
WebSocketClientMixin,
combine_driver,
)
# aiohttp is an optional extra (``nonebot2[aiohttp]``); turn a missing
# package into an ImportError that tells the user how to install it.
try:
    import aiohttp
except ModuleNotFoundError as missing:  # pragma: no cover
    _install_hint = (
        "Please install aiohttp first to use this driver. "
        "Install with pip: `pip install nonebot2[aiohttp]`"
    )
    raise ImportError(_install_hint) from missing
class Session(HTTPClientSession):
    """HTTP client session backed by ``aiohttp.ClientSession``.

    Session-wide query params, headers, cookies, HTTP version, timeout and
    proxy are captured at construction time. The underlying aiohttp session
    is only created by :meth:`setup` and released by :meth:`close`.
    """

    @override
    def __init__(
        self,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ):
        self._client: Optional[aiohttp.ClientSession] = None

        # Normalise query params through URL.build so they can later be
        # merged with each request's own query string.
        self._params = URL.build(query=params).query if params is not None else None
        self._headers = CIMultiDict(headers) if headers is not None else None
        # Only cookies that carry a value are forwarded, as (name, value) pairs.
        self._cookies = tuple(
            (cookie.name, cookie.value)
            for cookie in Cookies(cookies)
            if cookie.value is not None
        )

        version = HTTPVersion(version)
        if version == HTTPVersion.H10:
            self._version = aiohttp.HttpVersion10
        elif version == HTTPVersion.H11:
            self._version = aiohttp.HttpVersion11
        else:
            # Only HTTP/1.0 and HTTP/1.1 are mapped for aiohttp.
            raise RuntimeError(f"Unsupported HTTP version: {version}")

        # Session-wide default timeout in seconds (None = no total timeout).
        self._timeout = timeout
        self._proxy = proxy

    @property
    def client(self) -> aiohttp.ClientSession:
        """Return the live aiohttp session.

        Raises:
            RuntimeError: if :meth:`setup` has not been called, or the
                session has since been closed.
        """
        if self._client is None:
            raise RuntimeError("Session is not initialized")
        return self._client

    @override
    async def request(self, setup: Request) -> Response:
        """Send ``setup`` through the aiohttp session and buffer the response.

        Session-level query params are merged with the request's own query
        (request values win). Files switch the body to multipart form data.
        A per-request timeout takes precedence over the session timeout.

        Raises:
            RuntimeError: if the session has not been set up.
        """
        if self._params:
            params = self._params.copy()
            # Request-level query values override session-level ones.
            params.update(setup.url.query)
            url = setup.url.with_query(params)
        else:
            url = setup.url

        data = setup.data
        if setup.files:
            # Multipart upload: keep any plain form fields alongside the files.
            data = aiohttp.FormData(data or {}, quote_fields=False)
            for name, file in setup.files:
                data.add_field(name, file[1], content_type=file[2], filename=file[0])

        cookies = (
            (cookie.name, cookie.value)
            for cookie in setup.cookies
            if cookie.value is not None
        )

        # Fall back to the session-wide timeout so the ``timeout`` given to
        # the session is honoured when the request does not set its own.
        timeout = aiohttp.ClientTimeout(
            setup.timeout if setup.timeout is not None else self._timeout
        )

        async with await self.client.request(
            setup.method,
            url,
            data=setup.content or data,
            json=setup.json,
            cookies=cookies,
            headers=setup.headers,
            proxy=setup.proxy or self._proxy,
            timeout=timeout,
        ) as response:
            return Response(
                response.status,
                headers=response.headers.copy(),
                content=await response.read(),
                request=setup,
            )

    @override
    async def setup(self) -> None:
        """Create the underlying ``aiohttp.ClientSession``.

        Raises:
            RuntimeError: if the session was already set up.
        """
        if self._client is not None:
            raise RuntimeError("Session has already been initialized")
        self._client = aiohttp.ClientSession(
            cookies=self._cookies,
            headers=self._headers,
            version=self._version,
            # ClientSession expects a ClientTimeout, not a bare number;
            # recent aiohttp releases raise ValueError for a float here.
            timeout=aiohttp.ClientTimeout(self._timeout),
            trust_env=True,
        )
        await self._client.__aenter__()

    @override
    async def close(self) -> None:
        """Close the aiohttp session (if any) and mark this session unset."""
        try:
            if self._client is not None:
                await self._client.close()
        finally:
            self._client = None
class Mixin(HTTPClientMixin, WebSocketClientMixin):
    """AIOHTTP Mixin

    Provides HTTP requests and WebSocket client connections backed by aiohttp.
    """

    @property
    @override
    def type(self) -> str:
        """Driver type name."""
        return "aiohttp"

    @override
    async def request(self, setup: Request) -> Response:
        """Send a single request through a fresh, short-lived :class:`Session`."""
        async with self.get_session() as session:
            return await session.request(setup)

    @override
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket described by ``setup``.

        A dedicated ``aiohttp.ClientSession`` is created for the connection
        and closed when the context exits.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        if setup.version == HTTPVersion.H10:
            version = aiohttp.HttpVersion10
        elif setup.version == HTTPVersion.H11:
            version = aiohttp.HttpVersion11
        else:
            raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

        # Forward the request's cookies to the handshake, filtered the same
        # way Session.request filters them for plain HTTP requests.
        cookies = tuple(
            (cookie.name, cookie.value)
            for cookie in setup.cookies
            if cookie.value is not None
        )

        async with aiohttp.ClientSession(
            version=version, cookies=cookies, trust_env=True
        ) as session:
            async with session.ws_connect(
                setup.url,
                method=setup.method,
                timeout=setup.timeout or 10,
                headers=setup.headers,
                proxy=setup.proxy,
            ) as ws:
                yield WebSocket(request=setup, session=session, websocket=ws)

    @override
    def get_session(
        self,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ) -> Session:
        """Build a new (not yet set up) :class:`Session` with these defaults."""
        return Session(
            params=params,
            headers=headers,
            cookies=cookies,
            version=version,
            timeout=timeout,
            proxy=proxy,
        )
class WebSocket(BaseWebSocket):
    """AIOHTTP Websocket Wrapper

    Wraps an ``aiohttp.ClientWebSocketResponse`` together with the
    ``aiohttp.ClientSession`` that owns it.
    """

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ):
        super().__init__(request=request)
        self.session = session
        self.websocket = websocket

    @property
    @override
    def closed(self) -> bool:
        """Whether the underlying aiohttp WebSocket reports itself closed."""
        return self.websocket.closed

    @override
    async def accept(self):
        # Client-side connection: there is nothing to accept.
        raise NotImplementedError

    @override
    async def close(self, code: int = 1000):
        """Close the WebSocket with ``code`` and release the owning session."""
        try:
            await self.websocket.close(code=code)
        finally:
            # Release the ClientSession even if the close handshake fails.
            await self.session.close()

    async def _receive(self) -> aiohttp.WSMessage:
        """Receive one message, translating closure into WebSocketClosed.

        aiohttp reports a closing peer as CLOSE/CLOSING and a connection that
        is already gone (EOF, client error) as CLOSED. It returns ERROR only
        after closing the socket itself. All of these mean the connection is
        finished.

        Raises:
            WebSocketClosed: with the close code, or 1006 if none was recorded.
        """
        msg = await self.websocket.receive()
        if msg.type in (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        ):
            raise WebSocketClosed(self.websocket.close_code or 1006)
        if msg.type == aiohttp.WSMsgType.ERROR:
            # For ERROR frames aiohttp places the exception in ``data``;
            # chain it so the root cause is not lost.
            cause = msg.data if isinstance(msg.data, BaseException) else None
            raise WebSocketClosed(self.websocket.close_code or 1006) from cause
        return msg

    @override
    async def receive(self) -> Union[str, bytes]:
        """Receive a text (str) or binary (bytes) message.

        Raises:
            WebSocketClosed: if the connection is closed.
            TypeError: for any other frame type.
        """
        msg = await self._receive()
        if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY):
            raise TypeError(
                f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
            )
        return msg.data

    @override
    async def receive_text(self) -> str:
        """Receive a text message; non-text frames raise TypeError."""
        msg = await self._receive()
        if msg.type != aiohttp.WSMsgType.TEXT:
            raise TypeError(
                f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
            )
        return msg.data

    @override
    async def receive_bytes(self) -> bytes:
        """Receive a binary message; non-binary frames raise TypeError."""
        msg = await self._receive()
        if msg.type != aiohttp.WSMsgType.BINARY:
            raise TypeError(
                f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}"
            )
        return msg.data

    @override
    async def send_text(self, data: str) -> None:
        """Send ``data`` as a text frame."""
        await self.websocket.send_str(data)

    @override
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary frame."""
        await self.websocket.send_bytes(data)
# Type checkers see an explicit subclass of the aiohttp Mixin and NoneDriver.
# At runtime the class is assembled by ``combine_driver`` from the same two
# pieces. The result is a client-only driver, per the module docstring.
if TYPE_CHECKING:

    class Driver(Mixin, NoneDriver): ...

else:
    Driver = combine_driver(NoneDriver, Mixin)
    """AIOHTTP Driver"""