mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-12 22:45:47 +08:00
143 lines
4.6 KiB
Python
143 lines
4.6 KiB
Python
"""
|
|
AIOHTTP driver adapter
======================

This driver only supports client connections.
|
|
"""
|
|
|
|
from typing import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
|
|
from nonebot.typing import overrides
|
|
from nonebot.drivers import Request, Response
|
|
from nonebot.exception import WebSocketClosed
|
|
from nonebot.drivers._block_driver import BlockDriver
|
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
|
from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver
|
|
|
|
try:
|
|
import aiohttp
|
|
except ImportError:
|
|
raise ImportError(
|
|
"Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
|
|
) from None
|
|
|
|
|
|
class Mixin(ForwardMixin):
    """Forward-driver mixin that performs HTTP requests and opens WebSocket
    client connections through aiohttp."""

    @property
    @overrides(ForwardMixin)
    def type(self) -> str:
        """Identifier of this driver implementation."""
        return "aiohttp"

    @staticmethod
    def _aiohttp_version(version):
        """Translate a nonebot ``HTTPVersion`` into the aiohttp equivalent.

        Raises:
            RuntimeError: if ``version`` is neither HTTP/1.0 nor HTTP/1.1.
        """
        if version == HTTPVersion.H10:
            return aiohttp.HttpVersion10
        if version == HTTPVersion.H11:
            return aiohttp.HttpVersion11
        raise RuntimeError(f"Unsupported HTTP version: {version}")

    @overrides(ForwardMixin)
    async def request(self, setup: Request) -> Response:
        """Send the HTTP request described by ``setup`` and return its response.

        The request body is chosen in this order: raw ``setup.content``, then a
        form built from ``setup.data`` plus ``setup.files`` (when files are
        present), then ``setup.data`` on its own.

        Args:
            setup: request description (method, url, headers, body, timeout,
                proxy, HTTP version).

        Returns:
            A ``Response`` holding the status code, a copy of the response
            headers, the fully read body and the originating request.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup.version)
        timeout = aiohttp.ClientTimeout(setup.timeout)

        form = None
        if setup.files:
            # Seed the form with the plain form fields so that a request that
            # carries both ``data`` and ``files`` sends both; previously the
            # files were silently dropped whenever ``data`` was set.
            form = aiohttp.FormData(setup.data or ())
            for name, file in setup.files:
                # Each entry is indexed as (filename, content, content_type).
                form.add_field(
                    name, file[1], content_type=file[2], filename=file[0]
                )

        # Falls back to None (no body) when nothing truthy was supplied.
        body = setup.content or form or setup.data or None

        async with aiohttp.ClientSession(version=version, trust_env=True) as session:
            async with session.request(
                setup.method,
                setup.url,
                data=body,
                json=setup.json,
                headers=setup.headers,
                timeout=timeout,
                proxy=setup.proxy,
            ) as response:
                # Read the body while the connection is still open.
                return Response(
                    response.status,
                    headers=response.headers.copy(),
                    content=await response.read(),
                    request=setup,
                )

    @overrides(ForwardMixin)
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        """Open a WebSocket client connection described by ``setup``.

        Used as an async context manager; the aiohttp session and the
        connection are released when the context exits.

        Args:
            setup: connection description (url, method, headers, timeout,
                proxy, HTTP version). A falsy timeout falls back to 10.

        Yields:
            A ``WebSocket`` wrapping the established connection.

        Raises:
            RuntimeError: if ``setup.version`` is not HTTP/1.0 or HTTP/1.1.
        """
        version = self._aiohttp_version(setup.version)

        async with aiohttp.ClientSession(version=version, trust_env=True) as session:
            async with session.ws_connect(
                setup.url,
                method=setup.method,
                timeout=setup.timeout or 10,
                headers=setup.headers,
                proxy=setup.proxy,
            ) as ws:
                yield WebSocket(request=setup, session=session, websocket=ws)
|
|
|
|
|
|
class WebSocket(BaseWebSocket):
    """Client-side WebSocket connection backed by an aiohttp session."""

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ):
        """Wrap an established aiohttp WebSocket.

        Args:
            request: the request that was used to open the connection.
            session: the aiohttp session owning the connection; it is closed
                together with the socket in :meth:`close`.
            websocket: the underlying aiohttp client WebSocket.
        """
        super().__init__(request=request)
        self.session = session
        self.websocket = websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self):
        """Whether the underlying aiohttp WebSocket reports itself closed."""
        return self.websocket.closed

    @overrides(BaseWebSocket)
    async def accept(self):
        """Not implemented: this class only wraps client connections."""
        raise NotImplementedError

    @overrides(BaseWebSocket)
    async def close(self, code: int = 1000):
        """Close the WebSocket with ``code``, then close the owning session."""
        await self.websocket.close(code=code)
        await self.session.close()

    async def _receive(self) -> aiohttp.WSMessage:
        """Receive the next frame, turning any close notification into an error.

        Raises:
            WebSocketClosed: if the frame signals that the connection is
                closing or already closed. Carries the socket's close code,
                or 1006 when none is available.
        """
        msg = await self.websocket.receive()
        # CLOSED is what aiohttp hands back when reading from a socket that is
        # already closed; treat it like CLOSE/CLOSING instead of letting it
        # surface as an "unexpected frame type" TypeError in the callers.
        if msg.type in (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        ):
            raise WebSocketClosed(self.websocket.close_code or 1006)
        return msg

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        """Receive one text frame and return its payload.

        Raises:
            WebSocketClosed: if the connection was closed.
            TypeError: if the received frame is not a text frame.
        """
        msg = await self._receive()
        if msg.type != aiohttp.WSMsgType.TEXT:
            raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
        return msg.data

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        """Receive one binary frame and return its payload.

        Raises:
            WebSocketClosed: if the connection was closed.
            TypeError: if the received frame is not a binary frame.
        """
        msg = await self._receive()
        # Must be BINARY: checking for TEXT here rejected every binary frame
        # and returned ``str`` payloads from a method annotated as ``bytes``.
        if msg.type != aiohttp.WSMsgType.BINARY:
            raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
        return msg.data

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        """Send ``data`` as a text frame."""
        await self.websocket.send_str(data)

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary frame."""
        await self.websocket.send_bytes(data)
|
|
|
|
|
|
# Driver class exported by this module: BlockDriver combined with the aiohttp Mixin.
Driver = combine_driver(BlockDriver, Mixin)
|