mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
✨ Feature: 支持 HTTP 客户端会话 (#2627)
This commit is contained in:
parent
53e2a86dd9
commit
485aa62755
@ -23,6 +23,7 @@ from nonebot.internal.driver import ReverseDriver as ReverseDriver
|
|||||||
from nonebot.internal.driver import combine_driver as combine_driver
|
from nonebot.internal.driver import combine_driver as combine_driver
|
||||||
from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin
|
from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin
|
||||||
from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup
|
from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup
|
||||||
|
from nonebot.internal.driver import HTTPClientSession as HTTPClientSession
|
||||||
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
|
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
|
||||||
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
|
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
|
||||||
|
|
||||||
|
@ -17,15 +17,19 @@ FrontMatter:
|
|||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, AsyncGenerator
|
from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator
|
||||||
|
|
||||||
|
from multidict import CIMultiDict
|
||||||
|
|
||||||
from nonebot.drivers import Request, Response
|
|
||||||
from nonebot.exception import WebSocketClosed
|
from nonebot.exception import WebSocketClosed
|
||||||
|
from nonebot.drivers import URL, Request, Response
|
||||||
from nonebot.drivers.none import Driver as NoneDriver
|
from nonebot.drivers.none import Driver as NoneDriver
|
||||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||||
|
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
HTTPVersion,
|
HTTPVersion,
|
||||||
HTTPClientMixin,
|
HTTPClientMixin,
|
||||||
|
HTTPClientSession,
|
||||||
WebSocketClientMixin,
|
WebSocketClientMixin,
|
||||||
combine_driver,
|
combine_driver,
|
||||||
)
|
)
|
||||||
@ -39,6 +43,105 @@ except ModuleNotFoundError as e: # pragma: no cover
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
class Session(HTTPClientSession):
|
||||||
|
@override
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: QueryTypes = None,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
cookies: CookieTypes = None,
|
||||||
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
proxy: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._client: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
self._params = URL.build(query=params).query if params is not None else None
|
||||||
|
|
||||||
|
self._headers = CIMultiDict(headers) if headers is not None else None
|
||||||
|
self._cookies = tuple(
|
||||||
|
(cookie.name, cookie.value)
|
||||||
|
for cookie in Cookies(cookies)
|
||||||
|
if cookie.value is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
version = HTTPVersion(version)
|
||||||
|
if version == HTTPVersion.H10:
|
||||||
|
self._version = aiohttp.HttpVersion10
|
||||||
|
elif version == HTTPVersion.H11:
|
||||||
|
self._version = aiohttp.HttpVersion11
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported HTTP version: {version}")
|
||||||
|
|
||||||
|
self._timeout = timeout
|
||||||
|
self._proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> aiohttp.ClientSession:
|
||||||
|
if self._client is None:
|
||||||
|
raise RuntimeError("Session is not initialized")
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def request(self, setup: Request) -> Response:
|
||||||
|
if self._params:
|
||||||
|
params = self._params.copy()
|
||||||
|
params.update(setup.url.query)
|
||||||
|
url = setup.url.with_query(params)
|
||||||
|
else:
|
||||||
|
url = setup.url
|
||||||
|
|
||||||
|
data = setup.data
|
||||||
|
if setup.files:
|
||||||
|
data = aiohttp.FormData(data or {}, quote_fields=False)
|
||||||
|
for name, file in setup.files:
|
||||||
|
data.add_field(name, file[1], content_type=file[2], filename=file[0])
|
||||||
|
|
||||||
|
cookies = (
|
||||||
|
(cookie.name, cookie.value)
|
||||||
|
for cookie in setup.cookies
|
||||||
|
if cookie.value is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||||
|
|
||||||
|
async with await self.client.request(
|
||||||
|
setup.method,
|
||||||
|
url,
|
||||||
|
data=setup.content or data,
|
||||||
|
json=setup.json,
|
||||||
|
cookies=cookies,
|
||||||
|
headers=setup.headers,
|
||||||
|
proxy=setup.proxy or self._proxy,
|
||||||
|
timeout=timeout,
|
||||||
|
) as response:
|
||||||
|
return Response(
|
||||||
|
response.status,
|
||||||
|
headers=response.headers.copy(),
|
||||||
|
content=await response.read(),
|
||||||
|
request=setup,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def setup(self) -> None:
|
||||||
|
self._client = aiohttp.ClientSession(
|
||||||
|
cookies=self._cookies,
|
||||||
|
headers=self._headers,
|
||||||
|
version=self._version,
|
||||||
|
timeout=self._timeout,
|
||||||
|
trust_env=True,
|
||||||
|
)
|
||||||
|
await self._client.__aenter__()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def close(self) -> None:
|
||||||
|
try:
|
||||||
|
if self._client is not None:
|
||||||
|
await self._client.close()
|
||||||
|
finally:
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
|
||||||
class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
||||||
"""AIOHTTP Mixin"""
|
"""AIOHTTP Mixin"""
|
||||||
|
|
||||||
@ -49,42 +152,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
if setup.version == HTTPVersion.H10:
|
async with self.get_session() as session:
|
||||||
version = aiohttp.HttpVersion10
|
return await session.request(setup)
|
||||||
elif setup.version == HTTPVersion.H11:
|
|
||||||
version = aiohttp.HttpVersion11
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
|
||||||
|
|
||||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
|
||||||
|
|
||||||
data = setup.data
|
|
||||||
if setup.files:
|
|
||||||
data = aiohttp.FormData(data or {}, quote_fields=False)
|
|
||||||
for name, file in setup.files:
|
|
||||||
data.add_field(name, file[1], content_type=file[2], filename=file[0])
|
|
||||||
|
|
||||||
cookies = {
|
|
||||||
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
|
|
||||||
}
|
|
||||||
async with aiohttp.ClientSession(
|
|
||||||
cookies=cookies, version=version, trust_env=True
|
|
||||||
) as session:
|
|
||||||
async with session.request(
|
|
||||||
setup.method,
|
|
||||||
setup.url,
|
|
||||||
data=setup.content or data,
|
|
||||||
json=setup.json,
|
|
||||||
headers=setup.headers,
|
|
||||||
timeout=timeout,
|
|
||||||
proxy=setup.proxy,
|
|
||||||
) as response:
|
|
||||||
return Response(
|
|
||||||
response.status,
|
|
||||||
headers=response.headers.copy(),
|
|
||||||
content=await response.read(),
|
|
||||||
request=setup,
|
|
||||||
)
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -106,6 +175,25 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
|||||||
) as ws:
|
) as ws:
|
||||||
yield WebSocket(request=setup, session=session, websocket=ws)
|
yield WebSocket(request=setup, session=session, websocket=ws)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
params: QueryTypes = None,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
cookies: CookieTypes = None,
|
||||||
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
proxy: Optional[str] = None,
|
||||||
|
) -> Session:
|
||||||
|
return Session(
|
||||||
|
params=params,
|
||||||
|
headers=headers,
|
||||||
|
cookies=cookies,
|
||||||
|
version=version,
|
||||||
|
timeout=timeout,
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
"""AIOHTTP Websocket Wrapper"""
|
"""AIOHTTP Websocket Wrapper"""
|
||||||
|
@ -15,15 +15,20 @@ FrontMatter:
|
|||||||
description: nonebot.drivers.httpx 模块
|
description: nonebot.drivers.httpx 模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
from typing import TYPE_CHECKING, Union, Optional
|
||||||
|
|
||||||
|
from multidict import CIMultiDict
|
||||||
|
|
||||||
from nonebot.drivers.none import Driver as NoneDriver
|
from nonebot.drivers.none import Driver as NoneDriver
|
||||||
|
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
|
||||||
from nonebot.drivers import (
|
from nonebot.drivers import (
|
||||||
|
URL,
|
||||||
Request,
|
Request,
|
||||||
Response,
|
Response,
|
||||||
HTTPVersion,
|
HTTPVersion,
|
||||||
HTTPClientMixin,
|
HTTPClientMixin,
|
||||||
|
HTTPClientSession,
|
||||||
combine_driver,
|
combine_driver,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,6 +41,77 @@ except ModuleNotFoundError as e: # pragma: no cover
|
|||||||
) from e
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
class Session(HTTPClientSession):
|
||||||
|
@override
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: QueryTypes = None,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
cookies: CookieTypes = None,
|
||||||
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
proxy: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
self._params = (
|
||||||
|
tuple(URL.build(query=params).query.items()) if params is not None else None
|
||||||
|
)
|
||||||
|
self._headers = (
|
||||||
|
tuple(CIMultiDict(headers).items()) if headers is not None else None
|
||||||
|
)
|
||||||
|
self._cookies = Cookies(cookies)
|
||||||
|
self._version = HTTPVersion(version)
|
||||||
|
self._timeout = timeout
|
||||||
|
self._proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> httpx.AsyncClient:
|
||||||
|
if self._client is None:
|
||||||
|
raise RuntimeError("Session is not initialized")
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def request(self, setup: Request) -> Response:
|
||||||
|
response = await self.client.request(
|
||||||
|
setup.method,
|
||||||
|
str(setup.url),
|
||||||
|
content=setup.content,
|
||||||
|
data=setup.data,
|
||||||
|
files=setup.files,
|
||||||
|
json=setup.json,
|
||||||
|
headers=tuple(setup.headers.items()),
|
||||||
|
cookies=setup.cookies.jar,
|
||||||
|
timeout=setup.timeout,
|
||||||
|
)
|
||||||
|
return Response(
|
||||||
|
response.status_code,
|
||||||
|
headers=response.headers.multi_items(),
|
||||||
|
content=response.content,
|
||||||
|
request=setup,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def setup(self) -> None:
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
params=self._params,
|
||||||
|
headers=self._headers,
|
||||||
|
cookies=self._cookies.jar,
|
||||||
|
http2=self._version == HTTPVersion.H2,
|
||||||
|
proxies=self._proxy,
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
await self._client.__aenter__()
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def close(self) -> None:
|
||||||
|
try:
|
||||||
|
if self._client is not None:
|
||||||
|
await self._client.aclose()
|
||||||
|
finally:
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
|
||||||
class Mixin(HTTPClientMixin):
|
class Mixin(HTTPClientMixin):
|
||||||
"""HTTPX Mixin"""
|
"""HTTPX Mixin"""
|
||||||
|
|
||||||
@ -46,28 +122,29 @@ class Mixin(HTTPClientMixin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
async with httpx.AsyncClient(
|
async with self.get_session(
|
||||||
cookies=setup.cookies.jar,
|
version=setup.version, proxy=setup.proxy
|
||||||
http2=setup.version == HTTPVersion.H2,
|
) as session:
|
||||||
proxies=setup.proxy,
|
return await session.request(setup)
|
||||||
follow_redirects=True,
|
|
||||||
) as client:
|
@override
|
||||||
response = await client.request(
|
def get_session(
|
||||||
setup.method,
|
self,
|
||||||
str(setup.url),
|
params: QueryTypes = None,
|
||||||
content=setup.content,
|
headers: HeaderTypes = None,
|
||||||
data=setup.data,
|
cookies: CookieTypes = None,
|
||||||
json=setup.json,
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
files=setup.files,
|
timeout: Optional[float] = None,
|
||||||
headers=tuple(setup.headers.items()),
|
proxy: Optional[str] = None,
|
||||||
timeout=setup.timeout,
|
) -> Session:
|
||||||
)
|
return Session(
|
||||||
return Response(
|
params=params,
|
||||||
response.status_code,
|
headers=headers,
|
||||||
headers=response.headers.multi_items(),
|
cookies=cookies,
|
||||||
content=response.content,
|
version=version,
|
||||||
request=setup,
|
timeout=timeout,
|
||||||
)
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -26,5 +26,6 @@ from .abstract import ReverseDriver as ReverseDriver
|
|||||||
from .combine import combine_driver as combine_driver
|
from .combine import combine_driver as combine_driver
|
||||||
from .model import HTTPServerSetup as HTTPServerSetup
|
from .model import HTTPServerSetup as HTTPServerSetup
|
||||||
from .abstract import HTTPClientMixin as HTTPClientMixin
|
from .abstract import HTTPClientMixin as HTTPClientMixin
|
||||||
|
from .abstract import HTTPClientSession as HTTPClientSession
|
||||||
from .model import WebSocketServerSetup as WebSocketServerSetup
|
from .model import WebSocketServerSetup as WebSocketServerSetup
|
||||||
from .abstract import WebSocketClientMixin as WebSocketClientMixin
|
from .abstract import WebSocketClientMixin as WebSocketClientMixin
|
||||||
|
@ -1,8 +1,19 @@
|
|||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing_extensions import TypeAlias
|
from types import TracebackType
|
||||||
|
from typing_extensions import Self, TypeAlias
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, ClassVar, AsyncGenerator
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Set,
|
||||||
|
Dict,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
ClassVar,
|
||||||
|
Optional,
|
||||||
|
AsyncGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
@ -17,7 +28,17 @@ from nonebot.typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||||
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
from .model import (
|
||||||
|
Request,
|
||||||
|
Response,
|
||||||
|
WebSocket,
|
||||||
|
QueryTypes,
|
||||||
|
CookieTypes,
|
||||||
|
HeaderTypes,
|
||||||
|
HTTPVersion,
|
||||||
|
HTTPServerSetup,
|
||||||
|
WebSocketServerSetup,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nonebot.internal.adapter import Bot, Adapter
|
from nonebot.internal.adapter import Bot, Adapter
|
||||||
@ -222,6 +243,49 @@ class ReverseMixin(Mixin):
|
|||||||
"""服务端混入基类。"""
|
"""服务端混入基类。"""
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPClientSession(abc.ABC):
|
||||||
|
"""HTTP 客户端会话基类。"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: QueryTypes = None,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
cookies: CookieTypes = None,
|
||||||
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
proxy: Optional[str] = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def request(self, setup: Request) -> Response:
|
||||||
|
"""发送一个 HTTP 请求"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def setup(self) -> None:
|
||||||
|
"""初始化会话"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭会话"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Self:
|
||||||
|
await self.setup()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc: Optional[BaseException],
|
||||||
|
tb: Optional[TracebackType],
|
||||||
|
) -> None:
|
||||||
|
await self.close()
|
||||||
|
|
||||||
|
|
||||||
class HTTPClientMixin(ForwardMixin):
|
class HTTPClientMixin(ForwardMixin):
|
||||||
"""HTTP 客户端混入基类。"""
|
"""HTTP 客户端混入基类。"""
|
||||||
|
|
||||||
@ -230,6 +294,19 @@ class HTTPClientMixin(ForwardMixin):
|
|||||||
"""发送一个 HTTP 请求"""
|
"""发送一个 HTTP 请求"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
params: QueryTypes = None,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
cookies: CookieTypes = None,
|
||||||
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
proxy: Optional[str] = None,
|
||||||
|
) -> HTTPClientSession:
|
||||||
|
"""获取一个 HTTP 会话"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class WebSocketClientMixin(ForwardMixin):
|
class WebSocketClientMixin(ForwardMixin):
|
||||||
"""WebSocket 客户端混入基类。"""
|
"""WebSocket 客户端混入基类。"""
|
||||||
|
@ -27,7 +27,7 @@ RawURL: TypeAlias = Tuple[bytes, bytes, Optional[int], bytes]
|
|||||||
SimpleQuery: TypeAlias = Union[str, int, float]
|
SimpleQuery: TypeAlias = Union[str, int, float]
|
||||||
QueryVariable: TypeAlias = Union[SimpleQuery, List[SimpleQuery]]
|
QueryVariable: TypeAlias = Union[SimpleQuery, List[SimpleQuery]]
|
||||||
QueryTypes: TypeAlias = Union[
|
QueryTypes: TypeAlias = Union[
|
||||||
None, str, Mapping[str, QueryVariable], List[Tuple[str, QueryVariable]]
|
None, str, Mapping[str, QueryVariable], List[Tuple[str, SimpleQuery]]
|
||||||
]
|
]
|
||||||
|
|
||||||
HeaderTypes: TypeAlias = Union[
|
HeaderTypes: TypeAlias = Union[
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from http.cookies import SimpleCookie
|
||||||
from typing import Any, Set, Optional
|
from typing import Any, Set, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -306,6 +307,116 @@ async def test_http_client(driver: Driver, server_url: URL):
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"driver",
|
||||||
|
[
|
||||||
|
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
|
||||||
|
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||||
|
],
|
||||||
|
indirect=True,
|
||||||
|
)
|
||||||
|
async def test_http_client_session(driver: Driver, server_url: URL):
|
||||||
|
assert isinstance(driver, HTTPClientMixin)
|
||||||
|
|
||||||
|
session = driver.get_session(
|
||||||
|
params={"session": "test"},
|
||||||
|
headers={"X-Session": "test"},
|
||||||
|
cookies={"session": "test"},
|
||||||
|
)
|
||||||
|
request = Request("GET", server_url)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await session.request(request)
|
||||||
|
|
||||||
|
async with session as session:
|
||||||
|
# simple post with query, headers, cookies and content
|
||||||
|
request = Request(
|
||||||
|
"POST",
|
||||||
|
server_url,
|
||||||
|
params={"param": "test"},
|
||||||
|
headers={"X-Test": "test"},
|
||||||
|
cookies={"cookie": "test"},
|
||||||
|
content="test",
|
||||||
|
)
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content
|
||||||
|
data = json.loads(response.content)
|
||||||
|
assert data["method"] == "POST"
|
||||||
|
assert data["args"] == {"session": "test", "param": "test"}
|
||||||
|
assert data["headers"].get("X-Session") == "test"
|
||||||
|
assert data["headers"].get("X-Test") == "test"
|
||||||
|
assert {
|
||||||
|
key: cookie.value
|
||||||
|
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||||
|
} == {
|
||||||
|
"session": "test",
|
||||||
|
"cookie": "test",
|
||||||
|
}
|
||||||
|
assert data["data"] == "test"
|
||||||
|
|
||||||
|
# post with data body
|
||||||
|
request = Request("POST", server_url, data={"form": "test"})
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content
|
||||||
|
data = json.loads(response.content)
|
||||||
|
assert data["method"] == "POST"
|
||||||
|
assert data["args"] == {"session": "test"}
|
||||||
|
assert data["headers"].get("X-Session") == "test"
|
||||||
|
assert {
|
||||||
|
key: cookie.value
|
||||||
|
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||||
|
} == {"session": "test"}
|
||||||
|
assert data["form"] == {"form": "test"}
|
||||||
|
|
||||||
|
# post with json body
|
||||||
|
request = Request("POST", server_url, json={"json": "test"})
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content
|
||||||
|
data = json.loads(response.content)
|
||||||
|
assert data["method"] == "POST"
|
||||||
|
assert data["args"] == {"session": "test"}
|
||||||
|
assert data["headers"].get("X-Session") == "test"
|
||||||
|
assert {
|
||||||
|
key: cookie.value
|
||||||
|
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||||
|
} == {"session": "test"}
|
||||||
|
assert data["json"] == {"json": "test"}
|
||||||
|
|
||||||
|
# post with files and form data
|
||||||
|
request = Request(
|
||||||
|
"POST",
|
||||||
|
server_url,
|
||||||
|
data={"form": "test"},
|
||||||
|
files=[
|
||||||
|
("test1", b"test"),
|
||||||
|
("test2", ("test.txt", b"test")),
|
||||||
|
("test3", ("test.txt", b"test", "text/plain")),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
response = await session.request(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.content
|
||||||
|
data = json.loads(response.content)
|
||||||
|
assert data["method"] == "POST"
|
||||||
|
assert data["args"] == {"session": "test"}
|
||||||
|
assert data["headers"].get("X-Session") == "test"
|
||||||
|
assert {
|
||||||
|
key: cookie.value
|
||||||
|
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||||
|
} == {"session": "test"}
|
||||||
|
assert data["form"] == {"form": "test"}
|
||||||
|
assert data["files"] == {
|
||||||
|
"test1": "test",
|
||||||
|
"test2": "test",
|
||||||
|
"test3": "test",
|
||||||
|
}, "file parsing error"
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"driver",
|
"driver",
|
||||||
|
Loading…
Reference in New Issue
Block a user