nonebot2/nonebot/drivers/aiohttp.py

143 lines
4.5 KiB
Python
Raw Normal View History

2021-06-21 01:22:33 +08:00
"""
AIOHTTP 驱动适配
================
本驱动仅支持客户端连接
"""
from typing import AsyncGenerator
from contextlib import asynccontextmanager
2021-06-21 01:22:33 +08:00
from nonebot.typing import overrides
2021-12-22 16:53:55 +08:00
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
2021-12-22 16:53:55 +08:00
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
2021-12-22 16:53:55 +08:00
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
2021-06-21 01:22:33 +08:00
try:
import aiohttp
except ImportError:
raise ImportError(
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
) from None
2021-06-21 01:22:33 +08:00
class Mixin(ForwardMixin):
    """Forward-connection mixin that performs HTTP requests and WebSocket
    connections using aiohttp.

    This driver only supports client-side (forward) connections.
    """

    @property
    @overrides(ForwardMixin)
    def type(self) -> str:
        """Name of this driver."""
        return "aiohttp"

    @staticmethod
    def _aiohttp_version(setup: Request) -> "aiohttp.HttpVersion":
        """Translate ``setup.version`` into aiohttp's HTTP version object.

        Raises:
            RuntimeError: if the requested version is neither HTTP/1.0 nor
                HTTP/1.1 (the only versions this driver maps).
        """
        if setup.version == HTTPVersion.H10:
            return aiohttp.HttpVersion10
        if setup.version == HTTPVersion.H11:
            return aiohttp.HttpVersion11
        raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

    @overrides(ForwardMixin)
    async def request(self, setup: Request) -> Response:
        """Send the HTTP request described by ``setup`` and return its response.

        A new ``aiohttp.ClientSession`` (with ``trust_env=True``) is opened for
        every call and closed before returning. The whole response body is
        read into memory.

        Body selection, in order of precedence:
        ``setup.content``; a multipart form built from ``setup.files`` (which
        also carries the fields in ``setup.data``); then ``setup.data``.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup)
        timeout = aiohttp.ClientTimeout(setup.timeout)

        form = None
        if setup.files:
            # Seed the multipart form with the plain form fields. Otherwise
            # sending both ``data`` and ``files`` would silently discard one
            # of them (previously the files were dropped).
            form = aiohttp.FormData(setup.data or {})
            for name, file in setup.files:
                # file is (filename, content, content_type)
                form.add_field(
                    name, file[1], content_type=file[2], filename=file[0]
                )

        async with aiohttp.ClientSession(version=version, trust_env=True) as session:
            async with session.request(
                setup.method,
                setup.url,
                data=setup.content or form or setup.data,
                json=setup.json,
                headers=setup.headers,
                timeout=timeout,
                proxy=setup.proxy,
            ) as response:
                return Response(
                    response.status,
                    headers=response.headers.copy(),
                    content=await response.read(),
                    request=setup,
                )

    @overrides(ForwardMixin)
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a client WebSocket connection described by ``setup``.

        Yields a :class:`WebSocket` wrapping the aiohttp connection. The
        connection timeout falls back to 10 when ``setup.timeout`` is falsy.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup)
        # The session is owned by this context manager, so it is closed even
        # when ws_connect fails or the caller never calls WebSocket.close().
        # A second close() from WebSocket.close() is harmless.
        async with aiohttp.ClientSession(version=version, trust_env=True) as session:
            async with session.ws_connect(
                setup.url,
                method=setup.method,
                timeout=setup.timeout or 10,
                headers=setup.headers,
                proxy=setup.proxy,
            ) as ws:
                yield WebSocket(request=setup, session=session, websocket=ws)
2021-07-19 01:20:17 +08:00
2021-07-31 12:24:11 +08:00
2021-07-20 15:35:56 +08:00
class WebSocket(BaseWebSocket):
    """Client WebSocket backed by an aiohttp ``ClientWebSocketResponse``.

    Also keeps a reference to the ``ClientSession`` the connection was opened
    on, and closes that session in :meth:`close`.
    """

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ):
        super().__init__(request=request)
        self.session = session
        self.websocket = websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self) -> bool:
        """Whether the underlying aiohttp WebSocket reports itself closed."""
        return self.websocket.closed

    @overrides(BaseWebSocket)
    async def accept(self):
        # Accepting a connection is server-side; this driver is client-only.
        raise NotImplementedError

    @overrides(BaseWebSocket)
    async def close(self, code: int = 1000):
        """Close the WebSocket with ``code``, then close the owning session.

        The session is closed in ``finally`` so it is not leaked if closing
        the WebSocket itself raises.
        """
        try:
            await self.websocket.close(code=code)
        finally:
            await self.session.close()

    async def _receive(self) -> aiohttp.WSMessage:
        """Receive the next message, turning close notifications into errors.

        Raises:
            WebSocketClosed: if the message signals that the connection is
                closing or closed. The code is ``close_code``, or 1006 when
                no close code is known.
        """
        msg = await self.websocket.receive()
        # CLOSED is what aiohttp hands back when receiving on an
        # already-closed connection. Without it, callers would see a
        # misleading TypeError instead of WebSocketClosed.
        if msg.type in (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        ):
            raise WebSocketClosed(self.websocket.close_code or 1006)
        return msg

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        """Receive a text frame.

        Raises:
            TypeError: if the next frame is not a TEXT frame.
            WebSocketClosed: if the connection is closing/closed.
        """
        msg = await self._receive()
        if msg.type != aiohttp.WSMsgType.TEXT:
            raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
        return msg.data

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        """Receive a binary frame.

        Raises:
            TypeError: if the next frame is not a BINARY frame.
            WebSocketClosed: if the connection is closing/closed.
        """
        msg = await self._receive()
        # Binary frames carry WSMsgType.BINARY. The previous TEXT check
        # rejected every binary frame and returned str for text frames.
        if msg.type != aiohttp.WSMsgType.BINARY:
            raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
        return msg.data

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        """Send ``data`` as a text frame."""
        await self.websocket.send_str(data)

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary frame."""
        await self.websocket.send_bytes(data)
2021-12-22 16:53:55 +08:00
# Driver class exported by this module: ``combine_driver`` merges
# ``BlockDriver`` with this module's aiohttp client ``Mixin``.
Driver = combine_driver(BlockDriver, Mixin)