add cookies support for forward driver

This commit is contained in:
Ju4tCode 2022-12-20 10:13:45 +00:00 committed by GitHub
parent 827d8fbc0e
commit 2d08465426
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 53 additions and 32 deletions

View File

@ -56,7 +56,13 @@ class Mixin(ForwardMixin):
files = aiohttp.FormData()
for name, file in setup.files:
files.add_field(name, file[1], content_type=file[2], filename=file[0])
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
cookies = {
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
}
async with aiohttp.ClientSession(
cookies=cookies, version=version, trust_env=True
) as session:
async with session.request(
setup.method,
setup.url,
@ -66,13 +72,12 @@ class Mixin(ForwardMixin):
timeout=timeout,
proxy=setup.proxy,
) as response:
res = Response(
return Response(
response.status,
headers=response.headers.copy(),
content=await response.read(),
request=setup,
)
return res
@overrides(ForwardMixin)
@asynccontextmanager
@ -92,8 +97,7 @@ class Mixin(ForwardMixin):
headers=setup.headers,
proxy=setup.proxy,
) as ws:
websocket = WebSocket(request=setup, session=session, websocket=ws)
yield websocket
yield WebSocket(request=setup, session=session, websocket=ws)
class WebSocket(BaseWebSocket):

View File

@ -48,17 +48,18 @@ class Mixin(ForwardMixin):
@overrides(ForwardMixin)
async def request(self, setup: Request) -> Response:
async with httpx.AsyncClient(
cookies=setup.cookies.jar,
http2=setup.version == HTTPVersion.H2,
proxies=setup.proxy, # type: ignore
proxies=setup.proxy,
follow_redirects=True,
) as client:
response = await client.request(
setup.method,
str(setup.url),
content=setup.content, # type: ignore
data=setup.data, # type: ignore
content=setup.content,
data=setup.data,
json=setup.json,
files=setup.files, # type: ignore
files=setup.files,
headers=tuple(setup.headers.items()),
timeout=setup.timeout,
)

View File

@ -70,7 +70,7 @@ class Mixin(ForwardMixin):
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
connection = Connect(
str(setup.url),
extra_headers=setup.headers.items(),
extra_headers={**setup.headers, **setup.cookies.as_header(setup)},
open_timeout=setup.timeout,
)
async with connection as ws:
@ -101,8 +101,7 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
@catch_closed
async def receive(self) -> Union[str, bytes]:
msg = await self.websocket.recv()
return msg
return await self.websocket.recv()
@overrides(BaseWebSocket)
@catch_closed

View File

@ -49,11 +49,11 @@ class MessageTemplate(Formatter, Generic[TF]):
) -> None:
...
def __init__( # type:ignore
self, template, factory=str
) -> None: # TODO: fix type hint here
self.template: TF = template
self.factory: Type[TF] = factory
def __init__(
self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str
) -> None:
self.template: TF = template # type: ignore
self.factory: Type[TF] = factory # type: ignore
self.format_specs: Dict[str, FormatSpecFunc] = {}
def __repr__(self) -> str:
@ -98,7 +98,7 @@ class MessageTemplate(Formatter, Generic[TF]):
else:
raise TypeError("template must be a string or instance of Message!")
self.check_unused_args(list(used_args), args, kwargs)
self.check_unused_args(used_args, args, kwargs)
return cast(TF, full_message)
def vformat(

View File

@ -1,4 +1,5 @@
import abc
import urllib.request
from enum import Enum
from dataclasses import dataclass
from http.cookiejar import Cookie, CookieJar
@ -105,12 +106,9 @@ class Request:
self.url: URL = url
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
self.headers: CIMultiDict[str] = (
CIMultiDict(headers) if headers is not None else CIMultiDict()
)
# cookies
self.cookies = Cookies(cookies)
@ -147,12 +145,9 @@ class Response:
self.status_code: int = status_code
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
self.headers: CIMultiDict[str] = (
CIMultiDict(headers) if headers is not None else CIMultiDict()
)
# body
self.content: ContentTypes = content
@ -308,6 +303,11 @@ class Cookies(MutableMapping):
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
def as_header(self, request: Request) -> Dict[str, str]:
urllib_request = self._CookieCompatRequest(request)
self.jar.add_cookie_header(urllib_request)
return urllib_request.added_headers
def __setitem__(self, name: str, value: str) -> None:
return self.set(name, value)
@ -333,6 +333,20 @@ class Cookies(MutableMapping):
)
return f"{self.__class__.__name__}({cookies_repr})"
class _CookieCompatRequest(urllib.request.Request):
def __init__(self, request: Request) -> None:
super().__init__(
url=str(request.url),
headers=dict(request.headers),
method=request.method,
)
self.request = request
self.added_headers: Dict[str, str] = {}
def add_unredirected_header(self, key: str, value: str) -> None:
super().add_unredirected_header(key, value)
self.added_headers[key] = value
@dataclass
class HTTPServerSetup:

View File

@ -14,7 +14,7 @@ from itertools import chain
from types import ModuleType
from importlib.abc import MetaPathFinder
from importlib.machinery import PathFinder, SourceFileLoader
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
from typing import Set, Dict, List, Iterable, Optional, Sequence
from nonebot.log import logger
from nonebot.utils import escape_tag, path_to_module_name
@ -174,7 +174,7 @@ class PluginFinder(MetaPathFinder):
def find_spec(
self,
fullname: str,
path: Optional[Sequence[Union[bytes, str]]],
path: Optional[Sequence[str]],
target: Optional[ModuleType] = None,
):
if _managers:

View File

@ -88,6 +88,9 @@ extra_standard_library = ["typing_extensions"]
path = "."
all = false
[tool.pyright]
reportShadowedImports = false
[build-system]
requires = ["poetry_core>=1.0.0"]
build-backend = "poetry.core.masonry.api"