"""[AIOHTTP](https://aiohttp.readthedocs.io/en/stable/) 驱动适配器。 ```bash nb driver install aiohttp # 或者 pip install nonebot2[aiohttp] ``` :::tip 提示 本驱动仅支持客户端连接 ::: FrontMatter: sidebar_position: 2 description: nonebot.drivers.aiohttp 模块 """ from typing import Type, AsyncGenerator from contextlib import asynccontextmanager from nonebot.typing import overrides from nonebot.drivers import Request, Response from nonebot.exception import WebSocketClosed from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_driver try: import aiohttp except ImportError as e: # pragma: no cover raise ImportError( "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`" ) from e class Mixin(ForwardMixin): """AIOHTTP Mixin""" @property @overrides(ForwardMixin) def type(self) -> str: return "aiohttp" @overrides(ForwardMixin) async def request(self, setup: Request) -> Response: if setup.version == HTTPVersion.H10: version = aiohttp.HttpVersion10 elif setup.version == HTTPVersion.H11: version = aiohttp.HttpVersion11 else: raise RuntimeError(f"Unsupported HTTP version: {setup.version}") timeout = aiohttp.ClientTimeout(setup.timeout) files = None if setup.files: files = aiohttp.FormData() for name, file in setup.files: files.add_field(name, file[1], content_type=file[2], filename=file[0]) cookies = { cookie.name: cookie.value for cookie in setup.cookies if cookie.value } async with aiohttp.ClientSession( cookies=cookies, version=version, trust_env=True ) as session: async with session.request( setup.method, setup.url, data=setup.content or setup.data or files, json=setup.json, headers=setup.headers, timeout=timeout, proxy=setup.proxy, ) as response: return Response( response.status, headers=response.headers.copy(), content=await response.read(), request=setup, ) @overrides(ForwardMixin) @asynccontextmanager async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: if setup.version == HTTPVersion.H10: version = aiohttp.HttpVersion10 elif setup.version == HTTPVersion.H11: version = aiohttp.HttpVersion11 else: raise RuntimeError(f"Unsupported HTTP version: {setup.version}") async with aiohttp.ClientSession(version=version, trust_env=True) as session: async with session.ws_connect( setup.url, method=setup.method, timeout=setup.timeout or 10, headers=setup.headers, proxy=setup.proxy, ) as ws: yield WebSocket(request=setup, session=session, websocket=ws) class WebSocket(BaseWebSocket): """AIOHTTP Websocket Wrapper""" def __init__( self, *, request: Request, session: aiohttp.ClientSession, websocket: aiohttp.ClientWebSocketResponse, ): super().__init__(request=request) self.session = session self.websocket = websocket @property @overrides(BaseWebSocket) def closed(self): return self.websocket.closed @overrides(BaseWebSocket) async def accept(self): raise NotImplementedError @overrides(BaseWebSocket) async def close(self, code: int = 1000): await self.websocket.close(code=code) await self.session.close() async def _receive(self) -> aiohttp.WSMessage: msg = await self.websocket.receive() if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): raise WebSocketClosed(self.websocket.close_code or 1006) return msg @overrides(BaseWebSocket) async def receive(self) -> str: msg = await self._receive() if msg.type not in (aiohttp.WSMsgType.TEXT, aiohttp.WSMsgType.BINARY): raise TypeError( f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}" ) return msg.data @overrides(BaseWebSocket) async def receive_text(self) -> str: msg = await self._receive() if msg.type != aiohttp.WSMsgType.TEXT: raise TypeError( f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}" ) return msg.data @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: msg = await self._receive() if msg.type != aiohttp.WSMsgType.BINARY: raise TypeError( f"WebSocket received unexpected frame type: {msg.type}, {msg.data!r}" ) return msg.data @overrides(BaseWebSocket) async def send_text(self, data: str) -> None: await self.websocket.send_str(data) @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: await self.websocket.send_bytes(data) Driver: Type[ForwardDriver] = combine_driver(NoneDriver, Mixin) # type: ignore """AIOHTTP Driver"""