add cookies support for forward driver

This commit is contained in:
Ju4tCode 2022-12-20 10:13:45 +00:00 committed by GitHub
parent 827d8fbc0e
commit 2d08465426
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 53 additions and 32 deletions

View File

@ -56,7 +56,13 @@ class Mixin(ForwardMixin):
files = aiohttp.FormData() files = aiohttp.FormData()
for name, file in setup.files: for name, file in setup.files:
files.add_field(name, file[1], content_type=file[2], filename=file[0]) files.add_field(name, file[1], content_type=file[2], filename=file[0])
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
cookies = {
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
}
async with aiohttp.ClientSession(
cookies=cookies, version=version, trust_env=True
) as session:
async with session.request( async with session.request(
setup.method, setup.method,
setup.url, setup.url,
@ -66,13 +72,12 @@ class Mixin(ForwardMixin):
timeout=timeout, timeout=timeout,
proxy=setup.proxy, proxy=setup.proxy,
) as response: ) as response:
res = Response( return Response(
response.status, response.status,
headers=response.headers.copy(), headers=response.headers.copy(),
content=await response.read(), content=await response.read(),
request=setup, request=setup,
) )
return res
@overrides(ForwardMixin) @overrides(ForwardMixin)
@asynccontextmanager @asynccontextmanager
@ -92,8 +97,7 @@ class Mixin(ForwardMixin):
headers=setup.headers, headers=setup.headers,
proxy=setup.proxy, proxy=setup.proxy,
) as ws: ) as ws:
websocket = WebSocket(request=setup, session=session, websocket=ws) yield WebSocket(request=setup, session=session, websocket=ws)
yield websocket
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):

View File

@ -48,17 +48,18 @@ class Mixin(ForwardMixin):
@overrides(ForwardMixin) @overrides(ForwardMixin)
async def request(self, setup: Request) -> Response: async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient( async with httpx.AsyncClient(
cookies=setup.cookies.jar,
http2=setup.version == HTTPVersion.H2, http2=setup.version == HTTPVersion.H2,
proxies=setup.proxy, # type: ignore proxies=setup.proxy,
follow_redirects=True, follow_redirects=True,
) as client: ) as client:
response = await client.request( response = await client.request(
setup.method, setup.method,
str(setup.url), str(setup.url),
content=setup.content, # type: ignore content=setup.content,
data=setup.data, # type: ignore data=setup.data,
json=setup.json, json=setup.json,
files=setup.files, # type: ignore files=setup.files,
headers=tuple(setup.headers.items()), headers=tuple(setup.headers.items()),
timeout=setup.timeout, timeout=setup.timeout,
) )

View File

@ -70,7 +70,7 @@ class Mixin(ForwardMixin):
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
connection = Connect( connection = Connect(
str(setup.url), str(setup.url),
extra_headers=setup.headers.items(), extra_headers={**setup.headers, **setup.cookies.as_header(setup)},
open_timeout=setup.timeout, open_timeout=setup.timeout,
) )
async with connection as ws: async with connection as ws:
@ -101,8 +101,7 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@catch_closed @catch_closed
async def receive(self) -> Union[str, bytes]: async def receive(self) -> Union[str, bytes]:
msg = await self.websocket.recv() return await self.websocket.recv()
return msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
@catch_closed @catch_closed

View File

@ -49,11 +49,11 @@ class MessageTemplate(Formatter, Generic[TF]):
) -> None: ) -> None:
... ...
def __init__( # type:ignore def __init__(
self, template, factory=str self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str
) -> None: # TODO: fix type hint here ) -> None:
self.template: TF = template self.template: TF = template # type: ignore
self.factory: Type[TF] = factory self.factory: Type[TF] = factory # type: ignore
self.format_specs: Dict[str, FormatSpecFunc] = {} self.format_specs: Dict[str, FormatSpecFunc] = {}
def __repr__(self) -> str: def __repr__(self) -> str:
@ -98,7 +98,7 @@ class MessageTemplate(Formatter, Generic[TF]):
else: else:
raise TypeError("template must be a string or instance of Message!") raise TypeError("template must be a string or instance of Message!")
self.check_unused_args(list(used_args), args, kwargs) self.check_unused_args(used_args, args, kwargs)
return cast(TF, full_message) return cast(TF, full_message)
def vformat( def vformat(

View File

@ -1,4 +1,5 @@
import abc import abc
import urllib.request
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass
from http.cookiejar import Cookie, CookieJar from http.cookiejar import Cookie, CookieJar
@ -105,12 +106,9 @@ class Request:
self.url: URL = url self.url: URL = url
# headers # headers
self.headers: CIMultiDict[str] self.headers: CIMultiDict[str] = (
if headers is not None: CIMultiDict(headers) if headers is not None else CIMultiDict()
self.headers = CIMultiDict(headers) )
else:
self.headers = CIMultiDict()
# cookies # cookies
self.cookies = Cookies(cookies) self.cookies = Cookies(cookies)
@ -147,12 +145,9 @@ class Response:
self.status_code: int = status_code self.status_code: int = status_code
# headers # headers
self.headers: CIMultiDict[str] self.headers: CIMultiDict[str] = (
if headers is not None: CIMultiDict(headers) if headers is not None else CIMultiDict()
self.headers = CIMultiDict(headers) )
else:
self.headers = CIMultiDict()
# body # body
self.content: ContentTypes = content self.content: ContentTypes = content
@ -308,6 +303,11 @@ class Cookies(MutableMapping):
for cookie in cookies.jar: for cookie in cookies.jar:
self.jar.set_cookie(cookie) self.jar.set_cookie(cookie)
def as_header(self, request: Request) -> Dict[str, str]:
urllib_request = self._CookieCompatRequest(request)
self.jar.add_cookie_header(urllib_request)
return urllib_request.added_headers
def __setitem__(self, name: str, value: str) -> None: def __setitem__(self, name: str, value: str) -> None:
return self.set(name, value) return self.set(name, value)
@ -333,6 +333,20 @@ class Cookies(MutableMapping):
) )
return f"{self.__class__.__name__}({cookies_repr})" return f"{self.__class__.__name__}({cookies_repr})"
class _CookieCompatRequest(urllib.request.Request):
def __init__(self, request: Request) -> None:
super().__init__(
url=str(request.url),
headers=dict(request.headers),
method=request.method,
)
self.request = request
self.added_headers: Dict[str, str] = {}
def add_unredirected_header(self, key: str, value: str) -> None:
super().add_unredirected_header(key, value)
self.added_headers[key] = value
@dataclass @dataclass
class HTTPServerSetup: class HTTPServerSetup:

View File

@ -14,7 +14,7 @@ from itertools import chain
from types import ModuleType from types import ModuleType
from importlib.abc import MetaPathFinder from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader from importlib.machinery import PathFinder, SourceFileLoader
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence from typing import Set, Dict, List, Iterable, Optional, Sequence
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag, path_to_module_name from nonebot.utils import escape_tag, path_to_module_name
@ -174,7 +174,7 @@ class PluginFinder(MetaPathFinder):
def find_spec( def find_spec(
self, self,
fullname: str, fullname: str,
path: Optional[Sequence[Union[bytes, str]]], path: Optional[Sequence[str]],
target: Optional[ModuleType] = None, target: Optional[ModuleType] = None,
): ):
if _managers: if _managers:

View File

@ -88,6 +88,9 @@ extra_standard_library = ["typing_extensions"]
path = "." path = "."
all = false all = false
[tool.pyright]
reportShadowedImports = false
[build-system] [build-system]
requires = ["poetry_core>=1.0.0"] requires = ["poetry_core>=1.0.0"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"