diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index b17901fa..c362174d 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -56,7 +56,13 @@ class Mixin(ForwardMixin): files = aiohttp.FormData() for name, file in setup.files: files.add_field(name, file[1], content_type=file[2], filename=file[0]) - async with aiohttp.ClientSession(version=version, trust_env=True) as session: + + cookies = { + cookie.name: cookie.value for cookie in setup.cookies if cookie.value + } + async with aiohttp.ClientSession( + cookies=cookies, version=version, trust_env=True + ) as session: async with session.request( setup.method, setup.url, @@ -66,13 +72,12 @@ class Mixin(ForwardMixin): timeout=timeout, proxy=setup.proxy, ) as response: - res = Response( + return Response( response.status, headers=response.headers.copy(), content=await response.read(), request=setup, ) - return res @overrides(ForwardMixin) @asynccontextmanager @@ -92,8 +97,7 @@ class Mixin(ForwardMixin): headers=setup.headers, proxy=setup.proxy, ) as ws: - websocket = WebSocket(request=setup, session=session, websocket=ws) - yield websocket + yield WebSocket(request=setup, session=session, websocket=ws) class WebSocket(BaseWebSocket): diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 755ea0ec..de5617d2 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -48,17 +48,18 @@ class Mixin(ForwardMixin): @overrides(ForwardMixin) async def request(self, setup: Request) -> Response: async with httpx.AsyncClient( + cookies=setup.cookies.jar, http2=setup.version == HTTPVersion.H2, - proxies=setup.proxy, # type: ignore + proxies=setup.proxy, follow_redirects=True, ) as client: response = await client.request( setup.method, str(setup.url), - content=setup.content, # type: ignore - data=setup.data, # type: ignore + content=setup.content, + data=setup.data, json=setup.json, - files=setup.files, # type: ignore + files=setup.files, headers=tuple(setup.headers.items()), timeout=setup.timeout, ) diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 1b4302f0..0701dfec 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -70,7 +70,7 @@ class Mixin(ForwardMixin): async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: connection = Connect( str(setup.url), - extra_headers=setup.headers.items(), + extra_headers={**setup.headers, **setup.cookies.as_header(setup)}, open_timeout=setup.timeout, ) async with connection as ws: @@ -101,8 +101,7 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) @catch_closed async def receive(self) -> Union[str, bytes]: - msg = await self.websocket.recv() - return msg + return await self.websocket.recv() @overrides(BaseWebSocket) @catch_closed diff --git a/nonebot/internal/adapter/template.py b/nonebot/internal/adapter/template.py index b67c71b8..45c58f20 100644 --- a/nonebot/internal/adapter/template.py +++ b/nonebot/internal/adapter/template.py @@ -49,11 +49,11 @@ class MessageTemplate(Formatter, Generic[TF]): ) -> None: ... - def __init__( # type:ignore - self, template, factory=str - ) -> None: # TODO: fix type hint here - self.template: TF = template - self.factory: Type[TF] = factory + def __init__( + self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str + ) -> None: + self.template: TF = template # type: ignore + self.factory: Type[TF] = factory # type: ignore self.format_specs: Dict[str, FormatSpecFunc] = {} def __repr__(self) -> str: @@ -98,7 +98,7 @@ class MessageTemplate(Formatter, Generic[TF]): else: raise TypeError("template must be a string or instance of Message!") - self.check_unused_args(list(used_args), args, kwargs) + self.check_unused_args(used_args, args, kwargs) return cast(TF, full_message) def vformat( diff --git a/nonebot/internal/driver/model.py b/nonebot/internal/driver/model.py index f4bb9ae1..104b5c6d 100644 --- a/nonebot/internal/driver/model.py +++ b/nonebot/internal/driver/model.py @@ -1,4 +1,5 @@ import abc +import urllib.request from enum import Enum from dataclasses import dataclass from http.cookiejar import Cookie, CookieJar @@ -105,12 +106,9 @@ class Request: self.url: URL = url # headers - self.headers: CIMultiDict[str] - if headers is not None: - self.headers = CIMultiDict(headers) - else: - self.headers = CIMultiDict() - + self.headers: CIMultiDict[str] = ( + CIMultiDict(headers) if headers is not None else CIMultiDict() + ) # cookies self.cookies = Cookies(cookies) @@ -147,12 +145,9 @@ class Response: self.status_code: int = status_code # headers - self.headers: CIMultiDict[str] - if headers is not None: - self.headers = CIMultiDict(headers) - else: - self.headers = CIMultiDict() - + self.headers: CIMultiDict[str] = ( + CIMultiDict(headers) if headers is not None else CIMultiDict() + ) # body self.content: ContentTypes = content @@ -308,6 +303,11 @@ class Cookies(MutableMapping): for cookie in cookies.jar: self.jar.set_cookie(cookie) + def as_header(self, request: Request) -> Dict[str, str]: + urllib_request = self._CookieCompatRequest(request) + self.jar.add_cookie_header(urllib_request) + return urllib_request.added_headers + def __setitem__(self, name: str, value: str) -> None: return self.set(name, value) @@ -333,6 +333,20 @@ class Cookies(MutableMapping): ) return f"{self.__class__.__name__}({cookies_repr})" + class _CookieCompatRequest(urllib.request.Request): + def __init__(self, request: Request) -> None: + super().__init__( + url=str(request.url), + headers=dict(request.headers), + method=request.method, + ) + self.request = request + self.added_headers: Dict[str, str] = {} + + def add_unredirected_header(self, key: str, value: str) -> None: + super().add_unredirected_header(key, value) + self.added_headers[key] = value + @dataclass class HTTPServerSetup: diff --git a/nonebot/plugin/manager.py b/nonebot/plugin/manager.py index 5611b954..572879cf 100644 --- a/nonebot/plugin/manager.py +++ b/nonebot/plugin/manager.py @@ -14,7 +14,7 @@ from itertools import chain from types import ModuleType from importlib.abc import MetaPathFinder from importlib.machinery import PathFinder, SourceFileLoader -from typing import Set, Dict, List, Union, Iterable, Optional, Sequence +from typing import Set, Dict, List, Iterable, Optional, Sequence from nonebot.log import logger from nonebot.utils import escape_tag, path_to_module_name @@ -174,7 +174,7 @@ class PluginFinder(MetaPathFinder): def find_spec( self, fullname: str, - path: Optional[Sequence[Union[bytes, str]]], + path: Optional[Sequence[str]], target: Optional[ModuleType] = None, ): if _managers: diff --git a/pyproject.toml b/pyproject.toml index bbf7a605..1683fddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,9 @@ extra_standard_library = ["typing_extensions"] path = "." all = false +[tool.pyright] +reportShadowedImports = false + [build-system] requires = ["poetry_core>=1.0.0"] build-backend = "poetry.core.masonry.api"