2021-06-21 01:22:33 +08:00
|
|
|
"""
AIOHTTP 驱动适配
================

本驱动仅支持客户端连接

AIOHTTP driver adapter. This driver only supports client-side connections.
"""
|
|
|
|
|
|
|
|
from nonebot.typing import overrides
|
2021-12-22 16:53:55 +08:00
|
|
|
from nonebot.drivers import Request, Response
|
|
|
|
from nonebot.drivers._block_driver import BlockDriver
|
2021-09-25 19:14:16 +08:00
|
|
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
2021-12-22 16:53:55 +08:00
|
|
|
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
|
2021-06-21 01:22:33 +08:00
|
|
|
|
2021-12-23 14:29:21 +08:00
|
|
|
try:
|
|
|
|
import aiohttp
|
|
|
|
except ImportError:
|
|
|
|
raise ImportError(
|
|
|
|
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
|
|
|
|
) from None
|
|
|
|
|
2021-06-21 01:22:33 +08:00
|
|
|
|
2021-12-23 17:20:26 +08:00
|
|
|
class Mixin(ForwardMixin):
    """Forward (client-side) HTTP and WebSocket support backed by aiohttp.

    Each HTTP request is sent through its own short-lived
    ``aiohttp.ClientSession``. Each WebSocket connection owns a dedicated
    session that is handed to the returned ``WebSocket`` wrapper.
    """

    @property
    @overrides(ForwardMixin)
    def type(self) -> str:
        """Identifier of this driver mixin (always ``"aiohttp"``)."""
        return "aiohttp"

    @staticmethod
    def _aiohttp_version(setup: Request) -> "aiohttp.HttpVersion":
        """Map ``setup.version`` onto the matching aiohttp version constant.

        Raises:
            RuntimeError: if the requested version is neither HTTP/1.0 nor
                HTTP/1.1, the only versions this driver handles.
        """
        if setup.version == HTTPVersion.H10:
            return aiohttp.HttpVersion10
        if setup.version == HTTPVersion.H11:
            return aiohttp.HttpVersion11
        raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

    @overrides(ForwardMixin)
    async def request(self, setup: Request) -> Response:
        """Send one HTTP request and return the fully-read response.

        Args:
            setup: request description (method, URL, headers, body content,
                HTTP version and timeout).

        Returns:
            A ``Response`` with the status code, a copy of the response
            headers, the complete body bytes, and the originating request.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup)
        # The first positional argument of ClientTimeout is ``total``.
        timeout = aiohttp.ClientTimeout(setup.timeout)

        async with aiohttp.ClientSession(version=version) as session:
            async with session.request(
                setup.method,
                setup.url,
                data=setup.content,
                headers=setup.headers,
                timeout=timeout,
            ) as response:
                # The body must be read while the response/session are still
                # open; both are closed as soon as these blocks exit.
                return Response(
                    response.status,
                    headers=response.headers.copy(),
                    content=await response.read(),
                    request=setup,
                )

    @overrides(ForwardMixin)
    async def websocket(self, setup: Request) -> "WebSocket":
        """Open a client WebSocket connection described by ``setup``.

        A dedicated ``aiohttp.ClientSession`` is created for the connection
        and passed to the returned ``WebSocket``, whose ``close()`` closes it.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
            Any exception from ``ClientSession.ws_connect`` (e.g. a failed
            handshake) is re-raised after the session has been closed.
        """
        version = self._aiohttp_version(setup)

        session = aiohttp.ClientSession(version=version)
        try:
            ws = await session.ws_connect(
                setup.url,
                method=setup.method,
                # Fall back to 10 seconds when the request sets no timeout.
                timeout=setup.timeout or 10,
                headers=setup.headers,
            )
        except BaseException:
            # Without this, the session (and its connector) would leak when
            # the connection cannot be established or the call is cancelled.
            await session.close()
            raise
        return WebSocket(request=setup, session=session, websocket=ws)
|
2021-07-19 01:20:17 +08:00
|
|
|
|
2021-07-31 12:24:11 +08:00
|
|
|
|
2021-07-20 15:35:56 +08:00
|
|
|
class WebSocket(BaseWebSocket):
    """``BaseWebSocket`` implementation wrapping an aiohttp client connection.

    The instance holds both the ``ClientWebSocketResponse`` and the
    ``ClientSession`` that opened it. ``close()`` releases both.
    """

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ):
        super().__init__(request=request)
        # Session that opened ``websocket``; closed together with it in close().
        self.session = session
        # Underlying aiohttp client WebSocket connection.
        self.websocket = websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self) -> bool:
        """Whether the underlying aiohttp WebSocket reports itself closed."""
        return self.websocket.closed

    @overrides(BaseWebSocket)
    async def accept(self):
        """Not supported: this driver only creates client-side connections."""
        raise NotImplementedError

    @overrides(BaseWebSocket)
    async def close(self, code: int = 1000):
        """Close the WebSocket with ``code`` and release its client session.

        The session is closed even if closing the WebSocket itself raises.
        """
        try:
            await self.websocket.close(code=code)
        finally:
            await self.session.close()

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        """Receive the next message as text.

        Delegates to aiohttp's ``receive_str``, which raises ``TypeError``
        when the next message is not a text frame.
        """
        return await self.websocket.receive_str()

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        """Receive the next message as bytes via aiohttp's ``receive_bytes``."""
        return await self.websocket.receive_bytes()

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        """Send ``data`` as a text frame."""
        await self.websocket.send_str(data)

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary frame."""
        await self.websocket.send_bytes(data)
|
2021-12-22 16:53:55 +08:00
|
|
|
|
|
|
|
|
2021-12-23 17:20:26 +08:00
|
|
|
# Public driver class: the blocking base driver combined with the aiohttp
# forward (client) mixin, so it can issue HTTP requests and open WebSockets.
Driver = combine_driver(
    BlockDriver,
    Mixin,
)
|