nonebot2/nonebot/drivers/aiohttp.py

124 lines
3.7 KiB
Python
Raw Normal View History

2021-06-21 01:22:33 +08:00
"""
2021-07-31 12:24:11 +08:00
AIOHTTP 驱动适配
================
本驱动仅支持客户端连接
2021-06-21 01:22:33 +08:00
"""
from nonebot.typing import overrides
2021-12-22 16:53:55 +08:00
from nonebot.drivers import Request, Response
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
2021-12-22 16:53:55 +08:00
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
2021-06-21 01:22:33 +08:00
try:
import aiohttp
except ImportError:
raise ImportError(
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
) from None
2021-06-21 01:22:33 +08:00
class Mixin(ForwardMixin):
    """Forward (client-side) driver mixin backed by aiohttp.

    Provides outgoing HTTP requests and outgoing WebSocket connections.
    """

    @property
    @overrides(ForwardMixin)
    def type(self) -> str:
        """Name identifying this driver implementation."""
        return "aiohttp"

    @staticmethod
    def _aiohttp_version(setup: Request):
        """Translate ``setup.version`` into the matching aiohttp version constant.

        Raises:
            RuntimeError: if the requested version is neither HTTP/1.0 nor
                HTTP/1.1.
        """
        if setup.version == HTTPVersion.H10:
            return aiohttp.HttpVersion10
        if setup.version == HTTPVersion.H11:
            return aiohttp.HttpVersion11
        raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

    @overrides(ForwardMixin)
    async def request(self, setup: Request) -> Response:
        """Send the HTTP request described by ``setup`` and return the response.

        A fresh ``aiohttp.ClientSession`` is opened for the call and closed
        again before returning; the response body is read completely.

        Args:
            setup: request description (method, url, headers, body, timeout,
                HTTP version, optional files).

        Returns:
            A ``Response`` carrying the status code, a copy of the response
            headers, the full body bytes and the originating request.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup)
        timeout = aiohttp.ClientTimeout(setup.timeout)

        # Build a multipart form only when files were supplied. Each entry is
        # indexed as file[0] -> filename, file[1] -> payload, file[2] -> content type.
        files = None
        if setup.files:
            files = aiohttp.FormData()
            for name, file in setup.files:
                files.add_field(
                    name, file[1], content_type=file[2], filename=file[0]
                )

        async with aiohttp.ClientSession(version=version) as session:
            async with session.request(
                setup.method,
                setup.url,
                # The first truthy body source wins: raw content, then form
                # data, then the multipart files built above.
                data=setup.content or setup.data or files,
                json=setup.json,
                headers=setup.headers,
                timeout=timeout,
            ) as response:
                return Response(
                    response.status,
                    headers=response.headers.copy(),
                    content=await response.read(),
                    request=setup,
                )

    @overrides(ForwardMixin)
    async def websocket(self, setup: Request) -> "WebSocket":
        """Open a client WebSocket connection described by ``setup``.

        The returned ``WebSocket`` owns the underlying session, which is
        released when the wrapper's ``close()`` is called.

        Args:
            setup: connection description (url, method, headers, timeout,
                HTTP version). A falsy timeout falls back to 10.

        Returns:
            A ``WebSocket`` wrapping the established connection.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup)
        session = aiohttp.ClientSession(version=version)
        try:
            ws = await session.ws_connect(
                setup.url,
                method=setup.method,
                timeout=setup.timeout or 10,
                headers=setup.headers,
            )
        except BaseException:
            # The session is not managed by a context manager here, so close
            # it explicitly when the handshake fails or is cancelled;
            # otherwise it would leak. The original error is re-raised.
            await session.close()
            raise
        return WebSocket(request=setup, session=session, websocket=ws)
2021-07-19 01:20:17 +08:00
2021-07-31 12:24:11 +08:00
2021-07-20 15:35:56 +08:00
class WebSocket(BaseWebSocket):
    """Client WebSocket wrapper around an aiohttp connection.

    Holds both the aiohttp websocket and the ``ClientSession`` it was opened
    on, so that closing the wrapper releases both.
    """

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ):
        """Store the originating request, the session and the connection.

        Args:
            request: the request used to establish the connection.
            session: the aiohttp session the connection was opened on; it is
                closed by ``close()``.
            websocket: the established aiohttp client websocket.
        """
        super().__init__(request=request)
        self.session = session
        self.websocket = websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self) -> bool:
        """Whether the underlying aiohttp websocket reports itself closed."""
        return self.websocket.closed

    @overrides(BaseWebSocket)
    async def accept(self) -> None:
        """Not available for this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @overrides(BaseWebSocket)
    async def close(self, code: int = 1000) -> None:
        """Close the websocket with ``code`` and then close the session.

        The session is closed even if closing the websocket raises, so the
        session is not leaked on a failed close handshake.
        """
        try:
            await self.websocket.close(code=code)
        finally:
            await self.session.close()

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        """Wait for the next message and return it as text."""
        return await self.websocket.receive_str()

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        """Wait for the next message and return it as bytes."""
        return await self.websocket.receive_bytes()

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        """Send ``data`` as a text message."""
        await self.websocket.send_str(data)

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary message."""
        await self.websocket.send_bytes(data)
2021-12-22 16:53:55 +08:00
# Driver class exposed by this module: the result of combining BlockDriver
# with the aiohttp-based Mixin through combine_driver.
Driver = combine_driver(BlockDriver, Mixin)