mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ add cookies support for forward driver
This commit is contained in:
parent
827d8fbc0e
commit
2d08465426
@ -56,7 +56,13 @@ class Mixin(ForwardMixin):
|
||||
files = aiohttp.FormData()
|
||||
for name, file in setup.files:
|
||||
files.add_field(name, file[1], content_type=file[2], filename=file[0])
|
||||
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
|
||||
|
||||
cookies = {
|
||||
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
|
||||
}
|
||||
async with aiohttp.ClientSession(
|
||||
cookies=cookies, version=version, trust_env=True
|
||||
) as session:
|
||||
async with session.request(
|
||||
setup.method,
|
||||
setup.url,
|
||||
@ -66,13 +72,12 @@ class Mixin(ForwardMixin):
|
||||
timeout=timeout,
|
||||
proxy=setup.proxy,
|
||||
) as response:
|
||||
res = Response(
|
||||
return Response(
|
||||
response.status,
|
||||
headers=response.headers.copy(),
|
||||
content=await response.read(),
|
||||
request=setup,
|
||||
)
|
||||
return res
|
||||
|
||||
@overrides(ForwardMixin)
|
||||
@asynccontextmanager
|
||||
@ -92,8 +97,7 @@ class Mixin(ForwardMixin):
|
||||
headers=setup.headers,
|
||||
proxy=setup.proxy,
|
||||
) as ws:
|
||||
websocket = WebSocket(request=setup, session=session, websocket=ws)
|
||||
yield websocket
|
||||
yield WebSocket(request=setup, session=session, websocket=ws)
|
||||
|
||||
|
||||
class WebSocket(BaseWebSocket):
|
||||
|
@ -48,17 +48,18 @@ class Mixin(ForwardMixin):
|
||||
@overrides(ForwardMixin)
|
||||
async def request(self, setup: Request) -> Response:
|
||||
async with httpx.AsyncClient(
|
||||
cookies=setup.cookies.jar,
|
||||
http2=setup.version == HTTPVersion.H2,
|
||||
proxies=setup.proxy, # type: ignore
|
||||
proxies=setup.proxy,
|
||||
follow_redirects=True,
|
||||
) as client:
|
||||
response = await client.request(
|
||||
setup.method,
|
||||
str(setup.url),
|
||||
content=setup.content, # type: ignore
|
||||
data=setup.data, # type: ignore
|
||||
content=setup.content,
|
||||
data=setup.data,
|
||||
json=setup.json,
|
||||
files=setup.files, # type: ignore
|
||||
files=setup.files,
|
||||
headers=tuple(setup.headers.items()),
|
||||
timeout=setup.timeout,
|
||||
)
|
||||
|
@ -70,7 +70,7 @@ class Mixin(ForwardMixin):
|
||||
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||
connection = Connect(
|
||||
str(setup.url),
|
||||
extra_headers=setup.headers.items(),
|
||||
extra_headers={**setup.headers, **setup.cookies.as_header(setup)},
|
||||
open_timeout=setup.timeout,
|
||||
)
|
||||
async with connection as ws:
|
||||
@ -101,8 +101,7 @@ class WebSocket(BaseWebSocket):
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
async def receive(self) -> Union[str, bytes]:
|
||||
msg = await self.websocket.recv()
|
||||
return msg
|
||||
return await self.websocket.recv()
|
||||
|
||||
@overrides(BaseWebSocket)
|
||||
@catch_closed
|
||||
|
@ -49,11 +49,11 @@ class MessageTemplate(Formatter, Generic[TF]):
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def __init__( # type:ignore
|
||||
self, template, factory=str
|
||||
) -> None: # TODO: fix type hint here
|
||||
self.template: TF = template
|
||||
self.factory: Type[TF] = factory
|
||||
def __init__(
|
||||
self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str
|
||||
) -> None:
|
||||
self.template: TF = template # type: ignore
|
||||
self.factory: Type[TF] = factory # type: ignore
|
||||
self.format_specs: Dict[str, FormatSpecFunc] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@ -98,7 +98,7 @@ class MessageTemplate(Formatter, Generic[TF]):
|
||||
else:
|
||||
raise TypeError("template must be a string or instance of Message!")
|
||||
|
||||
self.check_unused_args(list(used_args), args, kwargs)
|
||||
self.check_unused_args(used_args, args, kwargs)
|
||||
return cast(TF, full_message)
|
||||
|
||||
def vformat(
|
||||
|
@ -1,4 +1,5 @@
|
||||
import abc
|
||||
import urllib.request
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from http.cookiejar import Cookie, CookieJar
|
||||
@ -105,12 +106,9 @@ class Request:
|
||||
self.url: URL = url
|
||||
|
||||
# headers
|
||||
self.headers: CIMultiDict[str]
|
||||
if headers is not None:
|
||||
self.headers = CIMultiDict(headers)
|
||||
else:
|
||||
self.headers = CIMultiDict()
|
||||
|
||||
self.headers: CIMultiDict[str] = (
|
||||
CIMultiDict(headers) if headers is not None else CIMultiDict()
|
||||
)
|
||||
# cookies
|
||||
self.cookies = Cookies(cookies)
|
||||
|
||||
@ -147,12 +145,9 @@ class Response:
|
||||
self.status_code: int = status_code
|
||||
|
||||
# headers
|
||||
self.headers: CIMultiDict[str]
|
||||
if headers is not None:
|
||||
self.headers = CIMultiDict(headers)
|
||||
else:
|
||||
self.headers = CIMultiDict()
|
||||
|
||||
self.headers: CIMultiDict[str] = (
|
||||
CIMultiDict(headers) if headers is not None else CIMultiDict()
|
||||
)
|
||||
# body
|
||||
self.content: ContentTypes = content
|
||||
|
||||
@ -308,6 +303,11 @@ class Cookies(MutableMapping):
|
||||
for cookie in cookies.jar:
|
||||
self.jar.set_cookie(cookie)
|
||||
|
||||
def as_header(self, request: Request) -> Dict[str, str]:
|
||||
urllib_request = self._CookieCompatRequest(request)
|
||||
self.jar.add_cookie_header(urllib_request)
|
||||
return urllib_request.added_headers
|
||||
|
||||
def __setitem__(self, name: str, value: str) -> None:
|
||||
return self.set(name, value)
|
||||
|
||||
@ -333,6 +333,20 @@ class Cookies(MutableMapping):
|
||||
)
|
||||
return f"{self.__class__.__name__}({cookies_repr})"
|
||||
|
||||
class _CookieCompatRequest(urllib.request.Request):
|
||||
def __init__(self, request: Request) -> None:
|
||||
super().__init__(
|
||||
url=str(request.url),
|
||||
headers=dict(request.headers),
|
||||
method=request.method,
|
||||
)
|
||||
self.request = request
|
||||
self.added_headers: Dict[str, str] = {}
|
||||
|
||||
def add_unredirected_header(self, key: str, value: str) -> None:
|
||||
super().add_unredirected_header(key, value)
|
||||
self.added_headers[key] = value
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPServerSetup:
|
||||
|
@ -14,7 +14,7 @@ from itertools import chain
|
||||
from types import ModuleType
|
||||
from importlib.abc import MetaPathFinder
|
||||
from importlib.machinery import PathFinder, SourceFileLoader
|
||||
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
|
||||
from typing import Set, Dict, List, Iterable, Optional, Sequence
|
||||
|
||||
from nonebot.log import logger
|
||||
from nonebot.utils import escape_tag, path_to_module_name
|
||||
@ -174,7 +174,7 @@ class PluginFinder(MetaPathFinder):
|
||||
def find_spec(
|
||||
self,
|
||||
fullname: str,
|
||||
path: Optional[Sequence[Union[bytes, str]]],
|
||||
path: Optional[Sequence[str]],
|
||||
target: Optional[ModuleType] = None,
|
||||
):
|
||||
if _managers:
|
||||
|
@ -88,6 +88,9 @@ extra_standard_library = ["typing_extensions"]
|
||||
path = "."
|
||||
all = false
|
||||
|
||||
[tool.pyright]
|
||||
reportShadowedImports = false
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry_core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
Loading…
Reference in New Issue
Block a user