Feature: 支持 HTTP 客户端会话 (#2627)

This commit is contained in:
Ju4tCode 2024-04-05 21:11:05 +08:00 committed by GitHub
parent 53e2a86dd9
commit 485aa62755
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 420 additions and 65 deletions

View File

@ -23,6 +23,7 @@ from nonebot.internal.driver import ReverseDriver as ReverseDriver
from nonebot.internal.driver import combine_driver as combine_driver from nonebot.internal.driver import combine_driver as combine_driver
from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin
from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup
from nonebot.internal.driver import HTTPClientSession as HTTPClientSession
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup

View File

@ -17,15 +17,19 @@ FrontMatter:
from typing_extensions import override from typing_extensions import override
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator
from multidict import CIMultiDict
from nonebot.drivers import Request, Response
from nonebot.exception import WebSocketClosed from nonebot.exception import WebSocketClosed
from nonebot.drivers import URL, Request, Response
from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers.none import Driver as NoneDriver
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import ( from nonebot.drivers import (
HTTPVersion, HTTPVersion,
HTTPClientMixin, HTTPClientMixin,
HTTPClientSession,
WebSocketClientMixin, WebSocketClientMixin,
combine_driver, combine_driver,
) )
@ -39,6 +43,105 @@ except ModuleNotFoundError as e: # pragma: no cover
) from e ) from e
class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[aiohttp.ClientSession] = None
self._params = URL.build(query=params).query if params is not None else None
self._headers = CIMultiDict(headers) if headers is not None else None
self._cookies = tuple(
(cookie.name, cookie.value)
for cookie in Cookies(cookies)
if cookie.value is not None
)
version = HTTPVersion(version)
if version == HTTPVersion.H10:
self._version = aiohttp.HttpVersion10
elif version == HTTPVersion.H11:
self._version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {version}")
self._timeout = timeout
self._proxy = proxy
@property
def client(self) -> aiohttp.ClientSession:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client
@override
async def request(self, setup: Request) -> Response:
if self._params:
params = self._params.copy()
params.update(setup.url.query)
url = setup.url.with_query(params)
else:
url = setup.url
data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])
cookies = (
(cookie.name, cookie.value)
for cookie in setup.cookies
if cookie.value is not None
)
timeout = aiohttp.ClientTimeout(setup.timeout)
async with await self.client.request(
setup.method,
url,
data=setup.content or data,
json=setup.json,
cookies=cookies,
headers=setup.headers,
proxy=setup.proxy or self._proxy,
timeout=timeout,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
@override
async def setup(self) -> None:
self._client = aiohttp.ClientSession(
cookies=self._cookies,
headers=self._headers,
version=self._version,
timeout=self._timeout,
trust_env=True,
)
await self._client.__aenter__()
@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.close()
finally:
self._client = None
class Mixin(HTTPClientMixin, WebSocketClientMixin): class Mixin(HTTPClientMixin, WebSocketClientMixin):
"""AIOHTTP Mixin""" """AIOHTTP Mixin"""
@ -49,42 +152,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
@override @override
async def request(self, setup: Request) -> Response: async def request(self, setup: Request) -> Response:
if setup.version == HTTPVersion.H10: async with self.get_session() as session:
version = aiohttp.HttpVersion10 return await session.request(setup)
elif setup.version == HTTPVersion.H11:
version = aiohttp.HttpVersion11
else:
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
timeout = aiohttp.ClientTimeout(setup.timeout)
data = setup.data
if setup.files:
data = aiohttp.FormData(data or {}, quote_fields=False)
for name, file in setup.files:
data.add_field(name, file[1], content_type=file[2], filename=file[0])
cookies = {
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
}
async with aiohttp.ClientSession(
cookies=cookies, version=version, trust_env=True
) as session:
async with session.request(
setup.method,
setup.url,
data=setup.content or data,
json=setup.json,
headers=setup.headers,
timeout=timeout,
proxy=setup.proxy,
) as response:
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
@override @override
@asynccontextmanager @asynccontextmanager
@ -106,6 +175,25 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
) as ws: ) as ws:
yield WebSocket(request=setup, session=session, websocket=ws) yield WebSocket(request=setup, session=session, websocket=ws)
@override
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> Session:
return Session(
params=params,
headers=headers,
cookies=cookies,
version=version,
timeout=timeout,
proxy=proxy,
)
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):
"""AIOHTTP Websocket Wrapper""" """AIOHTTP Websocket Wrapper"""

View File

@ -15,15 +15,20 @@ FrontMatter:
description: nonebot.drivers.httpx 模块 description: nonebot.drivers.httpx 模块
""" """
from typing import TYPE_CHECKING
from typing_extensions import override from typing_extensions import override
from typing import TYPE_CHECKING, Union, Optional
from multidict import CIMultiDict
from nonebot.drivers.none import Driver as NoneDriver from nonebot.drivers.none import Driver as NoneDriver
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
from nonebot.drivers import ( from nonebot.drivers import (
URL,
Request, Request,
Response, Response,
HTTPVersion, HTTPVersion,
HTTPClientMixin, HTTPClientMixin,
HTTPClientSession,
combine_driver, combine_driver,
) )
@ -36,6 +41,77 @@ except ModuleNotFoundError as e: # pragma: no cover
) from e ) from e
class Session(HTTPClientSession):
@override
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
self._client: Optional[httpx.AsyncClient] = None
self._params = (
tuple(URL.build(query=params).query.items()) if params is not None else None
)
self._headers = (
tuple(CIMultiDict(headers).items()) if headers is not None else None
)
self._cookies = Cookies(cookies)
self._version = HTTPVersion(version)
self._timeout = timeout
self._proxy = proxy
@property
def client(self) -> httpx.AsyncClient:
if self._client is None:
raise RuntimeError("Session is not initialized")
return self._client
@override
async def request(self, setup: Request) -> Response:
response = await self.client.request(
setup.method,
str(setup.url),
content=setup.content,
data=setup.data,
files=setup.files,
json=setup.json,
headers=tuple(setup.headers.items()),
cookies=setup.cookies.jar,
timeout=setup.timeout,
)
return Response(
response.status_code,
headers=response.headers.multi_items(),
content=response.content,
request=setup,
)
@override
async def setup(self) -> None:
self._client = httpx.AsyncClient(
params=self._params,
headers=self._headers,
cookies=self._cookies.jar,
http2=self._version == HTTPVersion.H2,
proxies=self._proxy,
follow_redirects=True,
)
await self._client.__aenter__()
@override
async def close(self) -> None:
try:
if self._client is not None:
await self._client.aclose()
finally:
self._client = None
class Mixin(HTTPClientMixin): class Mixin(HTTPClientMixin):
"""HTTPX Mixin""" """HTTPX Mixin"""
@ -46,28 +122,29 @@ class Mixin(HTTPClientMixin):
@override @override
async def request(self, setup: Request) -> Response: async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient( async with self.get_session(
cookies=setup.cookies.jar, version=setup.version, proxy=setup.proxy
http2=setup.version == HTTPVersion.H2, ) as session:
proxies=setup.proxy, return await session.request(setup)
follow_redirects=True,
) as client: @override
response = await client.request( def get_session(
setup.method, self,
str(setup.url), params: QueryTypes = None,
content=setup.content, headers: HeaderTypes = None,
data=setup.data, cookies: CookieTypes = None,
json=setup.json, version: Union[str, HTTPVersion] = HTTPVersion.H11,
files=setup.files, timeout: Optional[float] = None,
headers=tuple(setup.headers.items()), proxy: Optional[str] = None,
timeout=setup.timeout, ) -> Session:
) return Session(
return Response( params=params,
response.status_code, headers=headers,
headers=response.headers.multi_items(), cookies=cookies,
content=response.content, version=version,
request=setup, timeout=timeout,
) proxy=proxy,
)
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -26,5 +26,6 @@ from .abstract import ReverseDriver as ReverseDriver
from .combine import combine_driver as combine_driver from .combine import combine_driver as combine_driver
from .model import HTTPServerSetup as HTTPServerSetup from .model import HTTPServerSetup as HTTPServerSetup
from .abstract import HTTPClientMixin as HTTPClientMixin from .abstract import HTTPClientMixin as HTTPClientMixin
from .abstract import HTTPClientSession as HTTPClientSession
from .model import WebSocketServerSetup as WebSocketServerSetup from .model import WebSocketServerSetup as WebSocketServerSetup
from .abstract import WebSocketClientMixin as WebSocketClientMixin from .abstract import WebSocketClientMixin as WebSocketClientMixin

View File

@ -1,8 +1,19 @@
import abc import abc
import asyncio import asyncio
from typing_extensions import TypeAlias from types import TracebackType
from typing_extensions import Self, TypeAlias
from contextlib import AsyncExitStack, asynccontextmanager from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, ClassVar, AsyncGenerator from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
ClassVar,
Optional,
AsyncGenerator,
)
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
@ -17,7 +28,17 @@ from nonebot.typing import (
) )
from ._lifespan import LIFESPAN_FUNC, Lifespan from ._lifespan import LIFESPAN_FUNC, Lifespan
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup from .model import (
Request,
Response,
WebSocket,
QueryTypes,
CookieTypes,
HeaderTypes,
HTTPVersion,
HTTPServerSetup,
WebSocketServerSetup,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.internal.adapter import Bot, Adapter from nonebot.internal.adapter import Bot, Adapter
@ -222,6 +243,49 @@ class ReverseMixin(Mixin):
"""服务端混入基类。""" """服务端混入基类。"""
class HTTPClientSession(abc.ABC):
"""HTTP 客户端会话基类。"""
@abc.abstractmethod
def __init__(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
):
raise NotImplementedError
@abc.abstractmethod
async def request(self, setup: Request) -> Response:
"""发送一个 HTTP 请求"""
raise NotImplementedError
@abc.abstractmethod
async def setup(self) -> None:
"""初始化会话"""
raise NotImplementedError
@abc.abstractmethod
async def close(self) -> None:
"""关闭会话"""
raise NotImplementedError
async def __aenter__(self) -> Self:
await self.setup()
return self
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self.close()
class HTTPClientMixin(ForwardMixin): class HTTPClientMixin(ForwardMixin):
"""HTTP 客户端混入基类。""" """HTTP 客户端混入基类。"""
@ -230,6 +294,19 @@ class HTTPClientMixin(ForwardMixin):
"""发送一个 HTTP 请求""" """发送一个 HTTP 请求"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod
def get_session(
self,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
proxy: Optional[str] = None,
) -> HTTPClientSession:
"""获取一个 HTTP 会话"""
raise NotImplementedError
class WebSocketClientMixin(ForwardMixin): class WebSocketClientMixin(ForwardMixin):
"""WebSocket 客户端混入基类。""" """WebSocket 客户端混入基类。"""

View File

@ -27,7 +27,7 @@ RawURL: TypeAlias = Tuple[bytes, bytes, Optional[int], bytes]
SimpleQuery: TypeAlias = Union[str, int, float] SimpleQuery: TypeAlias = Union[str, int, float]
QueryVariable: TypeAlias = Union[SimpleQuery, List[SimpleQuery]] QueryVariable: TypeAlias = Union[SimpleQuery, List[SimpleQuery]]
QueryTypes: TypeAlias = Union[ QueryTypes: TypeAlias = Union[
None, str, Mapping[str, QueryVariable], List[Tuple[str, QueryVariable]] None, str, Mapping[str, QueryVariable], List[Tuple[str, SimpleQuery]]
] ]
HeaderTypes: TypeAlias = Union[ HeaderTypes: TypeAlias = Union[

View File

@ -1,5 +1,6 @@
import json import json
import asyncio import asyncio
from http.cookies import SimpleCookie
from typing import Any, Set, Optional from typing import Any, Set, Optional
import pytest import pytest
@ -306,6 +307,116 @@ async def test_http_client(driver: Driver, server_url: URL):
await asyncio.sleep(1) await asyncio.sleep(1)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"driver",
[
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
],
indirect=True,
)
async def test_http_client_session(driver: Driver, server_url: URL):
assert isinstance(driver, HTTPClientMixin)
session = driver.get_session(
params={"session": "test"},
headers={"X-Session": "test"},
cookies={"session": "test"},
)
request = Request("GET", server_url)
with pytest.raises(RuntimeError):
await session.request(request)
async with session as session:
# simple post with query, headers, cookies and content
request = Request(
"POST",
server_url,
params={"param": "test"},
headers={"X-Test": "test"},
cookies={"cookie": "test"},
content="test",
)
response = await session.request(request)
assert response.status_code == 200
assert response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["args"] == {"session": "test", "param": "test"}
assert data["headers"].get("X-Session") == "test"
assert data["headers"].get("X-Test") == "test"
assert {
key: cookie.value
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
} == {
"session": "test",
"cookie": "test",
}
assert data["data"] == "test"
# post with data body
request = Request("POST", server_url, data={"form": "test"})
response = await session.request(request)
assert response.status_code == 200
assert response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["args"] == {"session": "test"}
assert data["headers"].get("X-Session") == "test"
assert {
key: cookie.value
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
} == {"session": "test"}
assert data["form"] == {"form": "test"}
# post with json body
request = Request("POST", server_url, json={"json": "test"})
response = await session.request(request)
assert response.status_code == 200
assert response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["args"] == {"session": "test"}
assert data["headers"].get("X-Session") == "test"
assert {
key: cookie.value
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
} == {"session": "test"}
assert data["json"] == {"json": "test"}
# post with files and form data
request = Request(
"POST",
server_url,
data={"form": "test"},
files=[
("test1", b"test"),
("test2", ("test.txt", b"test")),
("test3", ("test.txt", b"test", "text/plain")),
],
)
response = await session.request(request)
assert response.status_code == 200
assert response.content
data = json.loads(response.content)
assert data["method"] == "POST"
assert data["args"] == {"session": "test"}
assert data["headers"].get("X-Session") == "test"
assert {
key: cookie.value
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
} == {"session": "test"}
assert data["form"] == {"form": "test"}
assert data["files"] == {
"test1": "test",
"test2": "test",
"test3": "test",
}, "file parsing error"
await asyncio.sleep(1)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"driver", "driver",