From 485aa6275506189b1ba5641c39e68d85a61613d2 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Fri, 5 Apr 2024 21:11:05 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E6=94=AF=E6=8C=81=20HT?= =?UTF-8?q?TP=20=E5=AE=A2=E6=88=B7=E7=AB=AF=E4=BC=9A=E8=AF=9D=20(#2627)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/drivers/__init__.py | 1 + nonebot/drivers/aiohttp.py | 164 +++++++++++++++++++++------- nonebot/drivers/httpx.py | 123 +++++++++++++++++---- nonebot/internal/driver/__init__.py | 1 + nonebot/internal/driver/abstract.py | 83 +++++++++++++- nonebot/internal/driver/model.py | 2 +- tests/test_driver.py | 111 +++++++++++++++++++ 7 files changed, 420 insertions(+), 65 deletions(-) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index de0f9897..f318fc84 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -23,6 +23,7 @@ from nonebot.internal.driver import ReverseDriver as ReverseDriver from nonebot.internal.driver import combine_driver as combine_driver from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup +from nonebot.internal.driver import HTTPClientSession as HTTPClientSession from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 860b1ec2..b12d204b 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -17,15 +17,19 @@ FrontMatter: from typing_extensions import override from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, AsyncGenerator +from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator + +from multidict import CIMultiDict -from nonebot.drivers import Request, Response from nonebot.exception import WebSocketClosed +from nonebot.drivers import URL, Request, Response from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes from nonebot.drivers import ( HTTPVersion, HTTPClientMixin, + HTTPClientSession, WebSocketClientMixin, combine_driver, ) @@ -39,6 +43,105 @@ except ModuleNotFoundError as e: # pragma: no cover ) from e +class Session(HTTPClientSession): + @override + def __init__( + self, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + proxy: Optional[str] = None, + ): + self._client: Optional[aiohttp.ClientSession] = None + + self._params = URL.build(query=params).query if params is not None else None + + self._headers = CIMultiDict(headers) if headers is not None else None + self._cookies = tuple( + (cookie.name, cookie.value) + for cookie in Cookies(cookies) + if cookie.value is not None + ) + + version = HTTPVersion(version) + if version == HTTPVersion.H10: + self._version = aiohttp.HttpVersion10 + elif version == HTTPVersion.H11: + self._version = aiohttp.HttpVersion11 + else: + raise RuntimeError(f"Unsupported HTTP version: {version}") + + self._timeout = timeout + self._proxy = proxy + + @property + def client(self) -> aiohttp.ClientSession: + if self._client is None: + raise RuntimeError("Session is not initialized") + return self._client + + @override + async def request(self, setup: Request) -> Response: + if self._params: + params = self._params.copy() + params.update(setup.url.query) + url = setup.url.with_query(params) + else: + url = setup.url + + data = setup.data + if setup.files: + data = aiohttp.FormData(data or {}, quote_fields=False) + for name, file in setup.files: + data.add_field(name, file[1], content_type=file[2], filename=file[0]) + + cookies = ( + (cookie.name, cookie.value) + for cookie in setup.cookies + if cookie.value is not None + ) + + timeout = aiohttp.ClientTimeout(setup.timeout) + + async with await self.client.request( + setup.method, + url, + data=setup.content or data, + json=setup.json, + cookies=cookies, + headers=setup.headers, + proxy=setup.proxy or self._proxy, + timeout=timeout, + ) as response: + return Response( + response.status, + headers=response.headers.copy(), + content=await response.read(), + request=setup, + ) + + @override + async def setup(self) -> None: + self._client = aiohttp.ClientSession( + cookies=self._cookies, + headers=self._headers, + version=self._version, + timeout=self._timeout, + trust_env=True, + ) + await self._client.__aenter__() + + @override + async def close(self) -> None: + try: + if self._client is not None: + await self._client.close() + finally: + self._client = None + + class Mixin(HTTPClientMixin, WebSocketClientMixin): """AIOHTTP Mixin""" @@ -49,42 +152,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin): @override async def request(self, setup: Request) -> Response: - if setup.version == HTTPVersion.H10: - version = aiohttp.HttpVersion10 - elif setup.version == HTTPVersion.H11: - version = aiohttp.HttpVersion11 - else: - raise RuntimeError(f"Unsupported HTTP version: {setup.version}") - - timeout = aiohttp.ClientTimeout(setup.timeout) - - data = setup.data - if setup.files: - data = aiohttp.FormData(data or {}, quote_fields=False) - for name, file in setup.files: - data.add_field(name, file[1], content_type=file[2], filename=file[0]) - - cookies = { - cookie.name: cookie.value for cookie in setup.cookies if cookie.value - } - async with aiohttp.ClientSession( - cookies=cookies, version=version, trust_env=True - ) as session: - async with session.request( - setup.method, - setup.url, - data=setup.content or data, - json=setup.json, - headers=setup.headers, - timeout=timeout, - proxy=setup.proxy, - ) as response: - return Response( - response.status, - headers=response.headers.copy(), - content=await response.read(), - request=setup, - ) + async with self.get_session() as session: + return await session.request(setup) @override @asynccontextmanager @@ -106,6 +175,25 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin): ) as ws: yield WebSocket(request=setup, session=session, websocket=ws) + @override + def get_session( + self, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + proxy: Optional[str] = None, + ) -> Session: + return Session( + params=params, + headers=headers, + cookies=cookies, + version=version, + timeout=timeout, + proxy=proxy, + ) + class WebSocket(BaseWebSocket): """AIOHTTP Websocket Wrapper""" diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 8c70aada..8300323b 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -15,15 +15,20 @@ FrontMatter: description: nonebot.drivers.httpx 模块 """ -from typing import TYPE_CHECKING from typing_extensions import override +from typing import TYPE_CHECKING, Union, Optional + +from multidict import CIMultiDict from nonebot.drivers.none import Driver as NoneDriver +from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes from nonebot.drivers import ( + URL, Request, Response, HTTPVersion, HTTPClientMixin, + HTTPClientSession, combine_driver, ) @@ -36,6 +41,77 @@ except ModuleNotFoundError as e: # pragma: no cover ) from e +class Session(HTTPClientSession): + @override + def __init__( + self, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + proxy: Optional[str] = None, + ): + self._client: Optional[httpx.AsyncClient] = None + + self._params = ( + tuple(URL.build(query=params).query.items()) if params is not None else None + ) + self._headers = ( + tuple(CIMultiDict(headers).items()) if headers is not None else None + ) + self._cookies = Cookies(cookies) + self._version = HTTPVersion(version) + self._timeout = timeout + self._proxy = proxy + + @property + def client(self) -> httpx.AsyncClient: + if self._client is None: + raise RuntimeError("Session is not initialized") + return self._client + + @override + async def request(self, setup: Request) -> Response: + response = await self.client.request( + setup.method, + str(setup.url), + content=setup.content, + data=setup.data, + files=setup.files, + json=setup.json, + headers=tuple(setup.headers.items()), + cookies=setup.cookies.jar, + timeout=setup.timeout, + ) + return Response( + response.status_code, + headers=response.headers.multi_items(), + content=response.content, + request=setup, + ) + + @override + async def setup(self) -> None: + self._client = httpx.AsyncClient( + params=self._params, + headers=self._headers, + cookies=self._cookies.jar, + http2=self._version == HTTPVersion.H2, + proxies=self._proxy, + follow_redirects=True, + ) + await self._client.__aenter__() + + @override + async def close(self) -> None: + try: + if self._client is not None: + await self._client.aclose() + finally: + self._client = None + + class Mixin(HTTPClientMixin): """HTTPX Mixin""" @@ -46,28 +122,29 @@ class Mixin(HTTPClientMixin): @override async def request(self, setup: Request) -> Response: - async with httpx.AsyncClient( - cookies=setup.cookies.jar, - http2=setup.version == HTTPVersion.H2, - proxies=setup.proxy, - follow_redirects=True, - ) as client: - response = await client.request( - setup.method, - str(setup.url), - content=setup.content, - data=setup.data, - json=setup.json, - files=setup.files, - headers=tuple(setup.headers.items()), - timeout=setup.timeout, - ) - return Response( - response.status_code, - headers=response.headers.multi_items(), - content=response.content, - request=setup, - ) + async with self.get_session( + version=setup.version, proxy=setup.proxy + ) as session: + return await session.request(setup) + + @override + def get_session( + self, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + proxy: Optional[str] = None, + ) -> Session: + return Session( + params=params, + headers=headers, + cookies=cookies, + version=version, + timeout=timeout, + proxy=proxy, + ) if TYPE_CHECKING: diff --git a/nonebot/internal/driver/__init__.py b/nonebot/internal/driver/__init__.py index bf6db531..3fc7a514 100644 --- a/nonebot/internal/driver/__init__.py +++ b/nonebot/internal/driver/__init__.py @@ -26,5 +26,6 @@ from .abstract import ReverseDriver as ReverseDriver from .combine import combine_driver as combine_driver from .model import HTTPServerSetup as HTTPServerSetup from .abstract import HTTPClientMixin as HTTPClientMixin +from .abstract import HTTPClientSession as HTTPClientSession from .model import WebSocketServerSetup as WebSocketServerSetup from .abstract import WebSocketClientMixin as WebSocketClientMixin diff --git a/nonebot/internal/driver/abstract.py b/nonebot/internal/driver/abstract.py index 59aac425..e18bc378 100644 --- a/nonebot/internal/driver/abstract.py +++ b/nonebot/internal/driver/abstract.py @@ -1,8 +1,19 @@ import abc import asyncio -from typing_extensions import TypeAlias +from types import TracebackType +from typing_extensions import Self, TypeAlias from contextlib import AsyncExitStack, asynccontextmanager -from typing import TYPE_CHECKING, Any, Set, Dict, Type, ClassVar, AsyncGenerator +from typing import ( + TYPE_CHECKING, + Any, + Set, + Dict, + Type, + Union, + ClassVar, + Optional, + AsyncGenerator, +) from nonebot.log import logger from nonebot.config import Env, Config @@ -17,7 +28,17 @@ from nonebot.typing import ( ) from ._lifespan import LIFESPAN_FUNC, Lifespan -from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup +from .model import ( + Request, + Response, + WebSocket, + QueryTypes, + CookieTypes, + HeaderTypes, + HTTPVersion, + HTTPServerSetup, + WebSocketServerSetup, +) if TYPE_CHECKING: from nonebot.internal.adapter import Bot, Adapter @@ -222,6 +243,49 @@ class ReverseMixin(Mixin): """服务端混入基类。""" +class HTTPClientSession(abc.ABC): + """HTTP 客户端会话基类。""" + + @abc.abstractmethod + def __init__( + self, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + proxy: Optional[str] = None, + ): + raise NotImplementedError + + @abc.abstractmethod + async def request(self, setup: Request) -> Response: + """发送一个 HTTP 请求""" + raise NotImplementedError + + @abc.abstractmethod + async def setup(self) -> None: + """初始化会话""" + raise NotImplementedError + + @abc.abstractmethod + async def close(self) -> None: + """关闭会话""" + raise NotImplementedError + + async def __aenter__(self) -> Self: + await self.setup() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + await self.close() + + class HTTPClientMixin(ForwardMixin): """HTTP 客户端混入基类。""" @@ -230,6 +294,19 @@ class HTTPClientMixin(ForwardMixin): """发送一个 HTTP 请求""" raise NotImplementedError + @abc.abstractmethod + def get_session( + self, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + proxy: Optional[str] = None, + ) -> HTTPClientSession: + """获取一个 HTTP 会话""" + raise NotImplementedError + class WebSocketClientMixin(ForwardMixin): """WebSocket 客户端混入基类。""" diff --git a/nonebot/internal/driver/model.py b/nonebot/internal/driver/model.py index 927c9a89..34776e58 100644 --- a/nonebot/internal/driver/model.py +++ b/nonebot/internal/driver/model.py @@ -27,7 +27,7 @@ RawURL: TypeAlias = Tuple[bytes, bytes, Optional[int], bytes] SimpleQuery: TypeAlias = Union[str, int, float] QueryVariable: TypeAlias = Union[SimpleQuery, List[SimpleQuery]] QueryTypes: TypeAlias = Union[ - None, str, Mapping[str, QueryVariable], List[Tuple[str, QueryVariable]] + None, str, Mapping[str, QueryVariable], List[Tuple[str, SimpleQuery]] ] HeaderTypes: TypeAlias = Union[ diff --git a/tests/test_driver.py b/tests/test_driver.py index 8b30da18..546cfeec 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,5 +1,6 @@ import json import asyncio +from http.cookies import SimpleCookie from typing import Any, Set, Optional import pytest @@ -306,6 +307,116 @@ async def test_http_client(driver: Driver, server_url: URL): await asyncio.sleep(1) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "driver", + [ + pytest.param("nonebot.drivers.httpx:Driver", id="httpx"), + pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"), + ], + indirect=True, +) +async def test_http_client_session(driver: Driver, server_url: URL): + assert isinstance(driver, HTTPClientMixin) + + session = driver.get_session( + params={"session": "test"}, + headers={"X-Session": "test"}, + cookies={"session": "test"}, + ) + request = Request("GET", server_url) + with pytest.raises(RuntimeError): + await session.request(request) + + async with session as session: + # simple post with query, headers, cookies and content + request = Request( + "POST", + server_url, + params={"param": "test"}, + headers={"X-Test": "test"}, + cookies={"cookie": "test"}, + content="test", + ) + response = await session.request(request) + assert response.status_code == 200 + assert response.content + data = json.loads(response.content) + assert data["method"] == "POST" + assert data["args"] == {"session": "test", "param": "test"} + assert data["headers"].get("X-Session") == "test" + assert data["headers"].get("X-Test") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == { + "session": "test", + "cookie": "test", + } + assert data["data"] == "test" + + # post with data body + request = Request("POST", server_url, data={"form": "test"}) + response = await session.request(request) + assert response.status_code == 200 + assert response.content + data = json.loads(response.content) + assert data["method"] == "POST" + assert data["args"] == {"session": "test"} + assert data["headers"].get("X-Session") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test"} + assert data["form"] == {"form": "test"} + + # post with json body + request = Request("POST", server_url, json={"json": "test"}) + response = await session.request(request) + assert response.status_code == 200 + assert response.content + data = json.loads(response.content) + assert data["method"] == "POST" + assert data["args"] == {"session": "test"} + assert data["headers"].get("X-Session") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test"} + assert data["json"] == {"json": "test"} + + # post with files and form data + request = Request( + "POST", + server_url, + data={"form": "test"}, + files=[ + ("test1", b"test"), + ("test2", ("test.txt", b"test")), + ("test3", ("test.txt", b"test", "text/plain")), + ], + ) + response = await session.request(request) + assert response.status_code == 200 + assert response.content + data = json.loads(response.content) + assert data["method"] == "POST" + assert data["args"] == {"session": "test"} + assert data["headers"].get("X-Session") == "test" + assert { + key: cookie.value + for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items() + } == {"session": "test"} + assert data["form"] == {"form": "test"} + assert data["files"] == { + "test1": "test", + "test2": "test", + "test3": "test", + }, "file parsing error" + + await asyncio.sleep(1) + + @pytest.mark.asyncio @pytest.mark.parametrize( "driver",