mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 03:36:52 +08:00
✨ add cookies support for forward driver
This commit is contained in:
parent
827d8fbc0e
commit
2d08465426
@ -56,7 +56,13 @@ class Mixin(ForwardMixin):
|
|||||||
files = aiohttp.FormData()
|
files = aiohttp.FormData()
|
||||||
for name, file in setup.files:
|
for name, file in setup.files:
|
||||||
files.add_field(name, file[1], content_type=file[2], filename=file[0])
|
files.add_field(name, file[1], content_type=file[2], filename=file[0])
|
||||||
async with aiohttp.ClientSession(version=version, trust_env=True) as session:
|
|
||||||
|
cookies = {
|
||||||
|
cookie.name: cookie.value for cookie in setup.cookies if cookie.value
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
cookies=cookies, version=version, trust_env=True
|
||||||
|
) as session:
|
||||||
async with session.request(
|
async with session.request(
|
||||||
setup.method,
|
setup.method,
|
||||||
setup.url,
|
setup.url,
|
||||||
@ -66,13 +72,12 @@ class Mixin(ForwardMixin):
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
proxy=setup.proxy,
|
proxy=setup.proxy,
|
||||||
) as response:
|
) as response:
|
||||||
res = Response(
|
return Response(
|
||||||
response.status,
|
response.status,
|
||||||
headers=response.headers.copy(),
|
headers=response.headers.copy(),
|
||||||
content=await response.read(),
|
content=await response.read(),
|
||||||
request=setup,
|
request=setup,
|
||||||
)
|
)
|
||||||
return res
|
|
||||||
|
|
||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@ -92,8 +97,7 @@ class Mixin(ForwardMixin):
|
|||||||
headers=setup.headers,
|
headers=setup.headers,
|
||||||
proxy=setup.proxy,
|
proxy=setup.proxy,
|
||||||
) as ws:
|
) as ws:
|
||||||
websocket = WebSocket(request=setup, session=session, websocket=ws)
|
yield WebSocket(request=setup, session=session, websocket=ws)
|
||||||
yield websocket
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocket(BaseWebSocket):
|
class WebSocket(BaseWebSocket):
|
||||||
|
@ -48,17 +48,18 @@ class Mixin(ForwardMixin):
|
|||||||
@overrides(ForwardMixin)
|
@overrides(ForwardMixin)
|
||||||
async def request(self, setup: Request) -> Response:
|
async def request(self, setup: Request) -> Response:
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
|
cookies=setup.cookies.jar,
|
||||||
http2=setup.version == HTTPVersion.H2,
|
http2=setup.version == HTTPVersion.H2,
|
||||||
proxies=setup.proxy, # type: ignore
|
proxies=setup.proxy,
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
) as client:
|
) as client:
|
||||||
response = await client.request(
|
response = await client.request(
|
||||||
setup.method,
|
setup.method,
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
content=setup.content, # type: ignore
|
content=setup.content,
|
||||||
data=setup.data, # type: ignore
|
data=setup.data,
|
||||||
json=setup.json,
|
json=setup.json,
|
||||||
files=setup.files, # type: ignore
|
files=setup.files,
|
||||||
headers=tuple(setup.headers.items()),
|
headers=tuple(setup.headers.items()),
|
||||||
timeout=setup.timeout,
|
timeout=setup.timeout,
|
||||||
)
|
)
|
||||||
|
@ -70,7 +70,7 @@ class Mixin(ForwardMixin):
|
|||||||
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
|
||||||
connection = Connect(
|
connection = Connect(
|
||||||
str(setup.url),
|
str(setup.url),
|
||||||
extra_headers=setup.headers.items(),
|
extra_headers={**setup.headers, **setup.cookies.as_header(setup)},
|
||||||
open_timeout=setup.timeout,
|
open_timeout=setup.timeout,
|
||||||
)
|
)
|
||||||
async with connection as ws:
|
async with connection as ws:
|
||||||
@ -101,8 +101,7 @@ class WebSocket(BaseWebSocket):
|
|||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@catch_closed
|
@catch_closed
|
||||||
async def receive(self) -> Union[str, bytes]:
|
async def receive(self) -> Union[str, bytes]:
|
||||||
msg = await self.websocket.recv()
|
return await self.websocket.recv()
|
||||||
return msg
|
|
||||||
|
|
||||||
@overrides(BaseWebSocket)
|
@overrides(BaseWebSocket)
|
||||||
@catch_closed
|
@catch_closed
|
||||||
|
@ -49,11 +49,11 @@ class MessageTemplate(Formatter, Generic[TF]):
|
|||||||
) -> None:
|
) -> None:
|
||||||
...
|
...
|
||||||
|
|
||||||
def __init__( # type:ignore
|
def __init__(
|
||||||
self, template, factory=str
|
self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str
|
||||||
) -> None: # TODO: fix type hint here
|
) -> None:
|
||||||
self.template: TF = template
|
self.template: TF = template # type: ignore
|
||||||
self.factory: Type[TF] = factory
|
self.factory: Type[TF] = factory # type: ignore
|
||||||
self.format_specs: Dict[str, FormatSpecFunc] = {}
|
self.format_specs: Dict[str, FormatSpecFunc] = {}
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@ -98,7 +98,7 @@ class MessageTemplate(Formatter, Generic[TF]):
|
|||||||
else:
|
else:
|
||||||
raise TypeError("template must be a string or instance of Message!")
|
raise TypeError("template must be a string or instance of Message!")
|
||||||
|
|
||||||
self.check_unused_args(list(used_args), args, kwargs)
|
self.check_unused_args(used_args, args, kwargs)
|
||||||
return cast(TF, full_message)
|
return cast(TF, full_message)
|
||||||
|
|
||||||
def vformat(
|
def vformat(
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import abc
|
import abc
|
||||||
|
import urllib.request
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from http.cookiejar import Cookie, CookieJar
|
from http.cookiejar import Cookie, CookieJar
|
||||||
@ -105,12 +106,9 @@ class Request:
|
|||||||
self.url: URL = url
|
self.url: URL = url
|
||||||
|
|
||||||
# headers
|
# headers
|
||||||
self.headers: CIMultiDict[str]
|
self.headers: CIMultiDict[str] = (
|
||||||
if headers is not None:
|
CIMultiDict(headers) if headers is not None else CIMultiDict()
|
||||||
self.headers = CIMultiDict(headers)
|
)
|
||||||
else:
|
|
||||||
self.headers = CIMultiDict()
|
|
||||||
|
|
||||||
# cookies
|
# cookies
|
||||||
self.cookies = Cookies(cookies)
|
self.cookies = Cookies(cookies)
|
||||||
|
|
||||||
@ -147,12 +145,9 @@ class Response:
|
|||||||
self.status_code: int = status_code
|
self.status_code: int = status_code
|
||||||
|
|
||||||
# headers
|
# headers
|
||||||
self.headers: CIMultiDict[str]
|
self.headers: CIMultiDict[str] = (
|
||||||
if headers is not None:
|
CIMultiDict(headers) if headers is not None else CIMultiDict()
|
||||||
self.headers = CIMultiDict(headers)
|
)
|
||||||
else:
|
|
||||||
self.headers = CIMultiDict()
|
|
||||||
|
|
||||||
# body
|
# body
|
||||||
self.content: ContentTypes = content
|
self.content: ContentTypes = content
|
||||||
|
|
||||||
@ -308,6 +303,11 @@ class Cookies(MutableMapping):
|
|||||||
for cookie in cookies.jar:
|
for cookie in cookies.jar:
|
||||||
self.jar.set_cookie(cookie)
|
self.jar.set_cookie(cookie)
|
||||||
|
|
||||||
|
def as_header(self, request: Request) -> Dict[str, str]:
|
||||||
|
urllib_request = self._CookieCompatRequest(request)
|
||||||
|
self.jar.add_cookie_header(urllib_request)
|
||||||
|
return urllib_request.added_headers
|
||||||
|
|
||||||
def __setitem__(self, name: str, value: str) -> None:
|
def __setitem__(self, name: str, value: str) -> None:
|
||||||
return self.set(name, value)
|
return self.set(name, value)
|
||||||
|
|
||||||
@ -333,6 +333,20 @@ class Cookies(MutableMapping):
|
|||||||
)
|
)
|
||||||
return f"{self.__class__.__name__}({cookies_repr})"
|
return f"{self.__class__.__name__}({cookies_repr})"
|
||||||
|
|
||||||
|
class _CookieCompatRequest(urllib.request.Request):
|
||||||
|
def __init__(self, request: Request) -> None:
|
||||||
|
super().__init__(
|
||||||
|
url=str(request.url),
|
||||||
|
headers=dict(request.headers),
|
||||||
|
method=request.method,
|
||||||
|
)
|
||||||
|
self.request = request
|
||||||
|
self.added_headers: Dict[str, str] = {}
|
||||||
|
|
||||||
|
def add_unredirected_header(self, key: str, value: str) -> None:
|
||||||
|
super().add_unredirected_header(key, value)
|
||||||
|
self.added_headers[key] = value
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HTTPServerSetup:
|
class HTTPServerSetup:
|
||||||
|
@ -14,7 +14,7 @@ from itertools import chain
|
|||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from importlib.abc import MetaPathFinder
|
from importlib.abc import MetaPathFinder
|
||||||
from importlib.machinery import PathFinder, SourceFileLoader
|
from importlib.machinery import PathFinder, SourceFileLoader
|
||||||
from typing import Set, Dict, List, Union, Iterable, Optional, Sequence
|
from typing import Set, Dict, List, Iterable, Optional, Sequence
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.utils import escape_tag, path_to_module_name
|
from nonebot.utils import escape_tag, path_to_module_name
|
||||||
@ -174,7 +174,7 @@ class PluginFinder(MetaPathFinder):
|
|||||||
def find_spec(
|
def find_spec(
|
||||||
self,
|
self,
|
||||||
fullname: str,
|
fullname: str,
|
||||||
path: Optional[Sequence[Union[bytes, str]]],
|
path: Optional[Sequence[str]],
|
||||||
target: Optional[ModuleType] = None,
|
target: Optional[ModuleType] = None,
|
||||||
):
|
):
|
||||||
if _managers:
|
if _managers:
|
||||||
|
@ -88,6 +88,9 @@ extra_standard_library = ["typing_extensions"]
|
|||||||
path = "."
|
path = "."
|
||||||
all = false
|
all = false
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
reportShadowedImports = false
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry_core>=1.0.0"]
|
requires = ["poetry_core>=1.0.0"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
Loading…
Reference in New Issue
Block a user