mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
✨ Feature: 支持 HTTP 客户端会话 (#2627)
This commit is contained in:
parent
53e2a86dd9
commit
485aa62755
@ -23,6 +23,7 @@ from nonebot.internal.driver import ReverseDriver as ReverseDriver
|
||||
from nonebot.internal.driver import combine_driver as combine_driver
|
||||
from nonebot.internal.driver import HTTPClientMixin as HTTPClientMixin
|
||||
from nonebot.internal.driver import HTTPServerSetup as HTTPServerSetup
|
||||
from nonebot.internal.driver import HTTPClientSession as HTTPClientSession
|
||||
from nonebot.internal.driver import WebSocketClientMixin as WebSocketClientMixin
|
||||
from nonebot.internal.driver import WebSocketServerSetup as WebSocketServerSetup
|
||||
|
||||
|
@ -17,15 +17,19 @@ FrontMatter:
|
||||
|
||||
from typing_extensions import override
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Union, Optional, AsyncGenerator
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from nonebot.drivers import Request, Response
|
||||
from nonebot.exception import WebSocketClosed
|
||||
from nonebot.drivers import URL, Request, Response
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
from nonebot.drivers import WebSocket as BaseWebSocket
|
||||
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
|
||||
from nonebot.drivers import (
|
||||
HTTPVersion,
|
||||
HTTPClientMixin,
|
||||
HTTPClientSession,
|
||||
WebSocketClientMixin,
|
||||
combine_driver,
|
||||
)
|
||||
@ -39,6 +43,105 @@ except ModuleNotFoundError as e: # pragma: no cover
|
||||
) from e
|
||||
|
||||
|
||||
class Session(HTTPClientSession):
|
||||
@override
|
||||
def __init__(
|
||||
self,
|
||||
params: QueryTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||
timeout: Optional[float] = None,
|
||||
proxy: Optional[str] = None,
|
||||
):
|
||||
self._client: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
self._params = URL.build(query=params).query if params is not None else None
|
||||
|
||||
self._headers = CIMultiDict(headers) if headers is not None else None
|
||||
self._cookies = tuple(
|
||||
(cookie.name, cookie.value)
|
||||
for cookie in Cookies(cookies)
|
||||
if cookie.value is not None
|
||||
)
|
||||
|
||||
version = HTTPVersion(version)
|
||||
if version == HTTPVersion.H10:
|
||||
self._version = aiohttp.HttpVersion10
|
||||
elif version == HTTPVersion.H11:
|
||||
self._version = aiohttp.HttpVersion11
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported HTTP version: {version}")
|
||||
|
||||
self._timeout = timeout
|
||||
self._proxy = proxy
|
||||
|
||||
@property
|
||||
def client(self) -> aiohttp.ClientSession:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Session is not initialized")
|
||||
return self._client
|
||||
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if self._params:
|
||||
params = self._params.copy()
|
||||
params.update(setup.url.query)
|
||||
url = setup.url.with_query(params)
|
||||
else:
|
||||
url = setup.url
|
||||
|
||||
data = setup.data
|
||||
if setup.files:
|
||||
data = aiohttp.FormData(data or {}, quote_fields=False)
|
||||
for name, file in setup.files:
|
||||
data.add_field(name, file[1], content_type=file[2], filename=file[0])
|
||||
|
||||
cookies = (
|
||||
(cookie.name, cookie.value)
|
||||
for cookie in setup.cookies
|
||||
if cookie.value is not None
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||
|
||||
async with await self.client.request(
|
||||
setup.method,
|
||||
url,
|
||||
data=setup.content or data,
|
||||
json=setup.json,
|
||||
cookies=cookies,
|
||||
headers=setup.headers,
|
||||
proxy=setup.proxy or self._proxy,
|
||||
timeout=timeout,
|
||||
) as response:
|
||||
return Response(
|
||||
response.status,
|
||||
headers=response.headers.copy(),
|
||||
content=await response.read(),
|
||||
request=setup,
|
||||
)
|
||||
|
||||
@override
|
||||
async def setup(self) -> None:
|
||||
self._client = aiohttp.ClientSession(
|
||||
cookies=self._cookies,
|
||||
headers=self._headers,
|
||||
version=self._version,
|
||||
timeout=self._timeout,
|
||||
trust_env=True,
|
||||
)
|
||||
await self._client.__aenter__()
|
||||
|
||||
@override
|
||||
async def close(self) -> None:
|
||||
try:
|
||||
if self._client is not None:
|
||||
await self._client.close()
|
||||
finally:
|
||||
self._client = None
|
||||
|
||||
|
||||
class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
||||
"""AIOHTTP Mixin"""
|
||||
|
||||
@ -49,42 +152,8 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
||||
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
if setup.version == HTTPVersion.H10:
|
||||
version = aiohttp.HttpVersion10
|
||||
elif setup.version == HTTPVersion.H11:
|
||||
version = aiohttp.HttpVersion11
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
|
||||
|
||||
timeout = aiohttp.ClientTimeout(setup.timeout)
|
||||
|
||||
data = setup.data
|
||||
if setup.files:
|
||||
data = aiohttp.FormData(data or {}, quote_fields=False)
|
||||
for name, file in setup.files:
|
||||
data.add_field(name, file[1], content_type=file[2], filename=file[0])
|
||||
|
||||
cookies = {
|
||||
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
|
||||
}
|
||||
async with aiohttp.ClientSession(
|
||||
cookies=cookies, version=version, trust_env=True
|
||||
) as session:
|
||||
async with session.request(
|
||||
setup.method,
|
||||
setup.url,
|
||||
data=setup.content or data,
|
||||
json=setup.json,
|
||||
headers=setup.headers,
|
||||
timeout=timeout,
|
||||
proxy=setup.proxy,
|
||||
) as response:
|
||||
return Response(
|
||||
response.status,
|
||||
headers=response.headers.copy(),
|
||||
content=await response.read(),
|
||||
request=setup,
|
||||
)
|
||||
async with self.get_session() as session:
|
||||
return await session.request(setup)
|
||||
|
||||
@override
|
||||
@asynccontextmanager
|
||||
@ -106,6 +175,25 @@ class Mixin(HTTPClientMixin, WebSocketClientMixin):
|
||||
) as ws:
|
||||
yield WebSocket(request=setup, session=session, websocket=ws)
|
||||
|
||||
@override
|
||||
def get_session(
|
||||
self,
|
||||
params: QueryTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||
timeout: Optional[float] = None,
|
||||
proxy: Optional[str] = None,
|
||||
) -> Session:
|
||||
return Session(
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
version=version,
|
||||
timeout=timeout,
|
||||
proxy=proxy,
|
||||
)
|
||||
|
||||
|
||||
class WebSocket(BaseWebSocket):
|
||||
"""AIOHTTP Websocket Wrapper"""
|
||||
|
@ -15,15 +15,20 @@ FrontMatter:
|
||||
description: nonebot.drivers.httpx 模块
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
from typing import TYPE_CHECKING, Union, Optional
|
||||
|
||||
from multidict import CIMultiDict
|
||||
|
||||
from nonebot.drivers.none import Driver as NoneDriver
|
||||
from nonebot.internal.driver import Cookies, QueryTypes, CookieTypes, HeaderTypes
|
||||
from nonebot.drivers import (
|
||||
URL,
|
||||
Request,
|
||||
Response,
|
||||
HTTPVersion,
|
||||
HTTPClientMixin,
|
||||
HTTPClientSession,
|
||||
combine_driver,
|
||||
)
|
||||
|
||||
@ -36,6 +41,77 @@ except ModuleNotFoundError as e: # pragma: no cover
|
||||
) from e
|
||||
|
||||
|
||||
class Session(HTTPClientSession):
|
||||
@override
|
||||
def __init__(
|
||||
self,
|
||||
params: QueryTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||
timeout: Optional[float] = None,
|
||||
proxy: Optional[str] = None,
|
||||
):
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
self._params = (
|
||||
tuple(URL.build(query=params).query.items()) if params is not None else None
|
||||
)
|
||||
self._headers = (
|
||||
tuple(CIMultiDict(headers).items()) if headers is not None else None
|
||||
)
|
||||
self._cookies = Cookies(cookies)
|
||||
self._version = HTTPVersion(version)
|
||||
self._timeout = timeout
|
||||
self._proxy = proxy
|
||||
|
||||
@property
|
||||
def client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
raise RuntimeError("Session is not initialized")
|
||||
return self._client
|
||||
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
response = await self.client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
content=setup.content,
|
||||
data=setup.data,
|
||||
files=setup.files,
|
||||
json=setup.json,
|
||||
headers=tuple(setup.headers.items()),
|
||||
cookies=setup.cookies.jar,
|
||||
timeout=setup.timeout,
|
||||
)
|
||||
return Response(
|
||||
response.status_code,
|
||||
headers=response.headers.multi_items(),
|
||||
content=response.content,
|
||||
request=setup,
|
||||
)
|
||||
|
||||
@override
|
||||
async def setup(self) -> None:
|
||||
self._client = httpx.AsyncClient(
|
||||
params=self._params,
|
||||
headers=self._headers,
|
||||
cookies=self._cookies.jar,
|
||||
http2=self._version == HTTPVersion.H2,
|
||||
proxies=self._proxy,
|
||||
follow_redirects=True,
|
||||
)
|
||||
await self._client.__aenter__()
|
||||
|
||||
@override
|
||||
async def close(self) -> None:
|
||||
try:
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
finally:
|
||||
self._client = None
|
||||
|
||||
|
||||
class Mixin(HTTPClientMixin):
|
||||
"""HTTPX Mixin"""
|
||||
|
||||
@ -46,28 +122,29 @@ class Mixin(HTTPClientMixin):
|
||||
|
||||
@override
|
||||
async def request(self, setup: Request) -> Response:
|
||||
async with httpx.AsyncClient(
|
||||
cookies=setup.cookies.jar,
|
||||
http2=setup.version == HTTPVersion.H2,
|
||||
proxies=setup.proxy,
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
response = await client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
content=setup.content,
|
||||
data=setup.data,
|
||||
json=setup.json,
|
||||
files=setup.files,
|
||||
headers=tuple(setup.headers.items()),
|
||||
timeout=setup.timeout,
|
||||
)
|
||||
return Response(
|
||||
response.status_code,
|
||||
headers=response.headers.multi_items(),
|
||||
content=response.content,
|
||||
request=setup,
|
||||
)
|
||||
async with self.get_session(
|
||||
version=setup.version, proxy=setup.proxy
|
||||
) as session:
|
||||
return await session.request(setup)
|
||||
|
||||
@override
|
||||
def get_session(
|
||||
self,
|
||||
params: QueryTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||
timeout: Optional[float] = None,
|
||||
proxy: Optional[str] = None,
|
||||
) -> Session:
|
||||
return Session(
|
||||
params=params,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
version=version,
|
||||
timeout=timeout,
|
||||
proxy=proxy,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -26,5 +26,6 @@ from .abstract import ReverseDriver as ReverseDriver
|
||||
from .combine import combine_driver as combine_driver
|
||||
from .model import HTTPServerSetup as HTTPServerSetup
|
||||
from .abstract import HTTPClientMixin as HTTPClientMixin
|
||||
from .abstract import HTTPClientSession as HTTPClientSession
|
||||
from .model import WebSocketServerSetup as WebSocketServerSetup
|
||||
from .abstract import WebSocketClientMixin as WebSocketClientMixin
|
||||
|
@ -1,8 +1,19 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from typing_extensions import TypeAlias
|
||||
from types import TracebackType
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from contextlib import AsyncExitStack, asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Set, Dict, Type, ClassVar, AsyncGenerator
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Set,
|
||||
Dict,
|
||||
Type,
|
||||
Union,
|
||||
ClassVar,
|
||||
Optional,
|
||||
AsyncGenerator,
|
||||
)
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.config import Env, Config
|
||||
@ -17,7 +28,17 @@ from nonebot.typing import (
|
||||
)
|
||||
|
||||
from ._lifespan import LIFESPAN_FUNC, Lifespan
|
||||
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
|
||||
from .model import (
|
||||
Request,
|
||||
Response,
|
||||
WebSocket,
|
||||
QueryTypes,
|
||||
CookieTypes,
|
||||
HeaderTypes,
|
||||
HTTPVersion,
|
||||
HTTPServerSetup,
|
||||
WebSocketServerSetup,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nonebot.internal.adapter import Bot, Adapter
|
||||
@ -222,6 +243,49 @@ class ReverseMixin(Mixin):
|
||||
"""服务端混入基类。"""
|
||||
|
||||
|
||||
class HTTPClientSession(abc.ABC):
|
||||
"""HTTP 客户端会话基类。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
params: QueryTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||
timeout: Optional[float] = None,
|
||||
proxy: Optional[str] = None,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def request(self, setup: Request) -> Response:
|
||||
"""发送一个 HTTP 请求"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def setup(self) -> None:
|
||||
"""初始化会话"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""关闭会话"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
await self.setup()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc: Optional[BaseException],
|
||||
tb: Optional[TracebackType],
|
||||
) -> None:
|
||||
await self.close()
|
||||
|
||||
|
||||
class HTTPClientMixin(ForwardMixin):
|
||||
"""HTTP 客户端混入基类。"""
|
||||
|
||||
@ -230,6 +294,19 @@ class HTTPClientMixin(ForwardMixin):
|
||||
"""发送一个 HTTP 请求"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_session(
|
||||
self,
|
||||
params: QueryTypes = None,
|
||||
headers: HeaderTypes = None,
|
||||
cookies: CookieTypes = None,
|
||||
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||
timeout: Optional[float] = None,
|
||||
proxy: Optional[str] = None,
|
||||
) -> HTTPClientSession:
|
||||
"""获取一个 HTTP 会话"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WebSocketClientMixin(ForwardMixin):
|
||||
"""WebSocket 客户端混入基类。"""
|
||||
|
@ -27,7 +27,7 @@ RawURL: TypeAlias = Tuple[bytes, bytes, Optional[int], bytes]
|
||||
SimpleQuery: TypeAlias = Union[str, int, float]
|
||||
QueryVariable: TypeAlias = Union[SimpleQuery, List[SimpleQuery]]
|
||||
QueryTypes: TypeAlias = Union[
|
||||
None, str, Mapping[str, QueryVariable], List[Tuple[str, QueryVariable]]
|
||||
None, str, Mapping[str, QueryVariable], List[Tuple[str, SimpleQuery]]
|
||||
]
|
||||
|
||||
HeaderTypes: TypeAlias = Union[
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
import asyncio
|
||||
from http.cookies import SimpleCookie
|
||||
from typing import Any, Set, Optional
|
||||
|
||||
import pytest
|
||||
@ -306,6 +307,116 @@ async def test_http_client(driver: Driver, server_url: URL):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
[
|
||||
pytest.param("nonebot.drivers.httpx:Driver", id="httpx"),
|
||||
pytest.param("nonebot.drivers.aiohttp:Driver", id="aiohttp"),
|
||||
],
|
||||
indirect=True,
|
||||
)
|
||||
async def test_http_client_session(driver: Driver, server_url: URL):
|
||||
assert isinstance(driver, HTTPClientMixin)
|
||||
|
||||
session = driver.get_session(
|
||||
params={"session": "test"},
|
||||
headers={"X-Session": "test"},
|
||||
cookies={"session": "test"},
|
||||
)
|
||||
request = Request("GET", server_url)
|
||||
with pytest.raises(RuntimeError):
|
||||
await session.request(request)
|
||||
|
||||
async with session as session:
|
||||
# simple post with query, headers, cookies and content
|
||||
request = Request(
|
||||
"POST",
|
||||
server_url,
|
||||
params={"param": "test"},
|
||||
headers={"X-Test": "test"},
|
||||
cookies={"cookie": "test"},
|
||||
content="test",
|
||||
)
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
assert response.content
|
||||
data = json.loads(response.content)
|
||||
assert data["method"] == "POST"
|
||||
assert data["args"] == {"session": "test", "param": "test"}
|
||||
assert data["headers"].get("X-Session") == "test"
|
||||
assert data["headers"].get("X-Test") == "test"
|
||||
assert {
|
||||
key: cookie.value
|
||||
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||
} == {
|
||||
"session": "test",
|
||||
"cookie": "test",
|
||||
}
|
||||
assert data["data"] == "test"
|
||||
|
||||
# post with data body
|
||||
request = Request("POST", server_url, data={"form": "test"})
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
assert response.content
|
||||
data = json.loads(response.content)
|
||||
assert data["method"] == "POST"
|
||||
assert data["args"] == {"session": "test"}
|
||||
assert data["headers"].get("X-Session") == "test"
|
||||
assert {
|
||||
key: cookie.value
|
||||
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||
} == {"session": "test"}
|
||||
assert data["form"] == {"form": "test"}
|
||||
|
||||
# post with json body
|
||||
request = Request("POST", server_url, json={"json": "test"})
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
assert response.content
|
||||
data = json.loads(response.content)
|
||||
assert data["method"] == "POST"
|
||||
assert data["args"] == {"session": "test"}
|
||||
assert data["headers"].get("X-Session") == "test"
|
||||
assert {
|
||||
key: cookie.value
|
||||
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||
} == {"session": "test"}
|
||||
assert data["json"] == {"json": "test"}
|
||||
|
||||
# post with files and form data
|
||||
request = Request(
|
||||
"POST",
|
||||
server_url,
|
||||
data={"form": "test"},
|
||||
files=[
|
||||
("test1", b"test"),
|
||||
("test2", ("test.txt", b"test")),
|
||||
("test3", ("test.txt", b"test", "text/plain")),
|
||||
],
|
||||
)
|
||||
response = await session.request(request)
|
||||
assert response.status_code == 200
|
||||
assert response.content
|
||||
data = json.loads(response.content)
|
||||
assert data["method"] == "POST"
|
||||
assert data["args"] == {"session": "test"}
|
||||
assert data["headers"].get("X-Session") == "test"
|
||||
assert {
|
||||
key: cookie.value
|
||||
for key, cookie in SimpleCookie(data["headers"].get("Cookie")).items()
|
||||
} == {"session": "test"}
|
||||
assert data["form"] == {"form": "test"}
|
||||
assert data["files"] == {
|
||||
"test1": "test",
|
||||
"test2": "test",
|
||||
"test3": "test",
|
||||
}, "file parsing error"
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"driver",
|
||||
|
Loading…
Reference in New Issue
Block a user