"""[AIOHTTP](https://aiohttp.readthedocs.io/en/stable/) 驱动适配器。

```bash
nb driver install aiohttp
# 或者
pip install nonebot2[aiohttp]
```

:::tip 提示
本驱动仅支持客户端连接
:::

FrontMatter:
    sidebar_position: 2
    description: nonebot.drivers.aiohttp 模块
"""

from typing import Type, AsyncGenerator
from contextlib import asynccontextmanager

from nonebot.typing import overrides
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import HTTPVersion, ForwardMixin, ForwardDriver, combine_driver

try:
    import aiohttp
except ImportError:
    raise ImportError(
        "Please install aiohttp first to use this driver. `pip install nonebot2[aiohttp]`"
    ) from None


class Mixin(ForwardMixin):
    """AIOHTTP Mixin"""

    @property
    @overrides(ForwardMixin)
    def type(self) -> str:
        return "aiohttp"

    @overrides(ForwardMixin)
    async def request(self, setup: Request) -> Response:
        if setup.version == HTTPVersion.H10:
            version = aiohttp.HttpVersion10
        elif setup.version == HTTPVersion.H11:
            version = aiohttp.HttpVersion11
        else:
            raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

        timeout = aiohttp.ClientTimeout(setup.timeout)
        files = None
        if setup.files:
            files = aiohttp.FormData()
            for name, file in setup.files:
                files.add_field(name, file[1], content_type=file[2], filename=file[0])
        async with aiohttp.ClientSession(version=version, trust_env=True) as session:
            async with session.request(
                setup.method,
                setup.url,
                data=setup.content or setup.data or files,
                json=setup.json,
                headers=setup.headers,
                timeout=timeout,
                proxy=setup.proxy,
            ) as response:
                res = Response(
                    response.status,
                    headers=response.headers.copy(),
                    content=await response.read(),
                    request=setup,
                )
                return res

    @overrides(ForwardMixin)
    @asynccontextmanager
    async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
        if setup.version == HTTPVersion.H10:
            version = aiohttp.HttpVersion10
        elif setup.version == HTTPVersion.H11:
            version = aiohttp.HttpVersion11
        else:
            raise RuntimeError(f"Unsupported HTTP version: {setup.version}")

        async with aiohttp.ClientSession(version=version, trust_env=True) as session:
            async with session.ws_connect(
                setup.url,
                method=setup.method,
                timeout=setup.timeout or 10,
                headers=setup.headers,
                proxy=setup.proxy,
            ) as ws:
                websocket = WebSocket(request=setup, session=session, websocket=ws)
                yield websocket


class WebSocket(BaseWebSocket):
    """AIOHTTP Websocket Wrapper"""

    def __init__(
        self,
        *,
        request: Request,
        session: aiohttp.ClientSession,
        websocket: aiohttp.ClientWebSocketResponse,
    ):
        super().__init__(request=request)
        self.session = session
        self.websocket = websocket

    @property
    @overrides(BaseWebSocket)
    def closed(self):
        return self.websocket.closed

    @overrides(BaseWebSocket)
    async def accept(self):
        raise NotImplementedError

    @overrides(BaseWebSocket)
    async def close(self, code: int = 1000):
        await self.websocket.close(code=code)
        await self.session.close()

    async def _receive(self) -> aiohttp.WSMessage:
        msg = await self.websocket.receive()
        if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
            raise WebSocketClosed(self.websocket.close_code or 1006)
        return msg

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        msg = await self._receive()
        if msg.type != aiohttp.WSMsgType.TEXT:
            raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
        return msg.data

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        msg = await self._receive()
        if msg.type != aiohttp.WSMsgType.TEXT:
            raise TypeError(f"WebSocket received unexpected frame type: {msg.type}")
        return msg.data

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        await self.websocket.send_str(data)

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        await self.websocket.send_bytes(data)


Driver: Type[ForwardDriver] = combine_driver(BlockDriver, Mixin)  # type: ignore
"""AIOHTTP Driver"""