mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-28 04:26:24 +08:00
360 lines
11 KiB
Python
360 lines
11 KiB
Python
import abc
|
|
import urllib.request
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from typing_extensions import TypeAlias
|
|
from http.cookiejar import Cookie, CookieJar
|
|
from typing import IO, Any, Union, Callable, Optional
|
|
from collections.abc import Mapping, Iterator, Awaitable, MutableMapping
|
|
|
|
from yarl import URL as URL
|
|
from multidict import CIMultiDict
|
|
|
|
# A URL in raw form: (scheme, host, port, path) with the text parts as bytes;
# the port may be absent (None).
RawURL: TypeAlias = tuple[bytes, bytes, Optional[int], bytes]

# --- query parameters ------------------------------------------------------
# A single query value.
SimpleQuery: TypeAlias = Union[str, int, float]
# One value, or a list of values, for a single query key.
QueryVariable: TypeAlias = Union[SimpleQuery, list[SimpleQuery]]
# Accepted query parameter forms: a raw string, a mapping of keys to values,
# or a list of (key, value) pairs.
QueryTypes: TypeAlias = Union[
    None,
    str,
    Mapping[str, QueryVariable],
    list[tuple[str, SimpleQuery]],
]

# --- headers ---------------------------------------------------------------
HeaderTypes: TypeAlias = Union[
    None,
    CIMultiDict[str],
    dict[str, str],
    list[tuple[str, str]],
]

# --- cookies ---------------------------------------------------------------
# "Cookies" is a forward reference to the Cookies class in this module.
CookieTypes: TypeAlias = Union[
    None,
    "Cookies",
    CookieJar,
    dict[str, str],
    list[tuple[str, str]],
]

# --- bodies ----------------------------------------------------------------
ContentTypes: TypeAlias = Optional[Union[str, bytes]]
DataTypes: TypeAlias = Optional[dict]
# Raw file content: a binary file object or the bytes themselves.
FileContent: TypeAlias = Union[IO[bytes], bytes]
# A fully specified upload: (filename, content, content_type).
FileType: TypeAlias = tuple[Optional[str], FileContent, Optional[str]]
# Any accepted shape for a single upload.
FileTypes: TypeAlias = Union[
    FileContent,  # bare file object or bytes
    tuple[Optional[str], FileContent],  # (filename, content)
    FileType,  # (filename, content, content_type)
]
# All uploads, as a mapping or a list of (name, upload) pairs.
FilesTypes: TypeAlias = Optional[
    Union[dict[str, FileTypes], list[tuple[str, FileTypes]]]
]
|
|
|
|
|
|
class HTTPVersion(Enum):
    """Supported HTTP protocol versions; each value is the version string."""

    # HTTP/1.0
    H10 = "1.0"
    # HTTP/1.1
    H11 = "1.1"
    # HTTP/2
    H2 = "2"
|
|
|
|
|
|
class Request:
    """An HTTP request: method, URL, headers, cookies, body and options."""

    def __init__(
        self,
        method: Union[str, bytes],
        url: Union["URL", str, RawURL],
        *,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        content: ContentTypes = None,
        data: DataTypes = None,
        json: Any = None,
        files: FilesTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ):
        """Build a request.

        Args:
            method: HTTP method as ``str`` or ASCII ``bytes``; stored upper-cased.
            url: a ``URL``, a string, or a ``RawURL`` tuple
                ``(scheme, host, port, path)`` of bytes. A raw path may carry
                a query string after ``?``; it becomes the URL's query.
            params: extra query parameters merged into the URL with
                ``URL.update_query``.
            headers: initial headers, copied into a new ``CIMultiDict``.
            cookies: initial cookies, wrapped in a ``Cookies`` container.
            content: raw body, stored unchanged.
            data: dict body, stored unchanged.
            json: JSON-able body, stored unchanged.
            files: uploads as a mapping or list of ``(name, file)`` pairs.
                Each is normalised to ``(name, (filename, content, content_type))``;
                bare content uses the field name as filename. An empty value
                leaves ``self.files`` as ``None``.
            version: an ``HTTPVersion`` or its value string.
            timeout: stored unchanged (``None`` allowed).
            proxy: stored unchanged (``None`` allowed).

        Raises:
            ValueError: if ``version`` is not a valid ``HTTPVersion`` value.
            UnicodeDecodeError: if a bytes ``method`` or raw URL part is not ASCII.
        """
        # method
        self.method: str = (
            method.decode("ascii").upper()
            if isinstance(method, bytes)
            else method.upper()
        )
        # http version
        self.version: HTTPVersion = HTTPVersion(version)
        # timeout
        self.timeout: Optional[float] = timeout
        # proxy
        self.proxy: Optional[str] = proxy

        # url
        if isinstance(url, tuple):
            scheme, host, port, raw_path = url
            # A raw path may include the query string (b"/p?a=1"). Passing it
            # whole as ``path`` would percent-encode the "?" into the path,
            # so split the query off and hand it to URL.build separately.
            path, _, query = raw_path.partition(b"?")
            url = URL.build(
                scheme=scheme.decode("ascii"),
                host=host.decode("ascii"),
                port=port,
                path=path.decode("ascii"),
                query_string=query.decode("ascii"),
            )
        else:
            url = URL(url)

        if params is not None:
            url = url.update_query(params)
        self.url: URL = url

        # headers
        self.headers: CIMultiDict[str] = (
            CIMultiDict(headers) if headers is not None else CIMultiDict()
        )
        # cookies
        self.cookies = Cookies(cookies)

        # body
        self.content: ContentTypes = content
        self.data: DataTypes = data
        self.json: Any = json
        self.files: Optional[list[tuple[str, FileType]]] = None
        if files:
            self.files = []
            files_ = files.items() if isinstance(files, dict) else files
            for name, file_info in files_:
                if not isinstance(file_info, tuple):
                    # bare content: reuse the field name as the filename
                    self.files.append((name, (name, file_info, None)))
                elif len(file_info) == 2:
                    # (filename, content): no content type given
                    self.files.append((name, (file_info[0], file_info[1], None)))
                else:
                    self.files.append((name, file_info))  # type: ignore

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(method={self.method!r}, url='{self.url!s}')"
|
|
|
|
|
|
class Response:
    """An HTTP response: status code, headers, body and originating request."""

    def __init__(
        self,
        status_code: int,
        *,
        headers: HeaderTypes = None,
        content: ContentTypes = None,
        request: Optional[Request] = None,
    ):
        self.status_code: int = status_code

        # Normalise headers into a new case-insensitive multidict.
        if headers is None:
            self.headers: CIMultiDict[str] = CIMultiDict()
        else:
            self.headers = CIMultiDict(headers)

        # Response body, stored unchanged.
        self.content: ContentTypes = content

        # The request this response belongs to, if known.
        self.request: Optional[Request] = request

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(status_code={self.status_code!r})"
|
|
|
|
|
|
class WebSocket(abc.ABC):
    """Abstract WebSocket connection.

    Subclasses implement the transport primitives. :meth:`send` dispatches to
    :meth:`send_text` or :meth:`send_bytes` based on the payload type.
    """

    def __init__(self, *, request: Request):
        self.request: Request = request

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.request.url!s}')"

    @property
    @abc.abstractmethod
    def closed(self) -> bool:
        """Whether the connection has already been closed."""
        raise NotImplementedError

    @abc.abstractmethod
    async def accept(self) -> None:
        """Accept the WebSocket connection request."""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self, code: int = 1000, reason: str = "") -> None:
        """Close the WebSocket connection with ``code`` and ``reason``."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive(self) -> Union[str, bytes]:
        """Receive one WebSocket message, either text or binary."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_text(self) -> str:
        """Receive one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_bytes(self) -> bytes:
        """Receive one WebSocket binary message."""
        raise NotImplementedError

    async def send(self, data: Union[str, bytes]) -> None:
        """Send one WebSocket message, text or binary.

        Args:
            data: a ``str`` is sent with :meth:`send_text`, ``bytes`` with
                :meth:`send_bytes`.

        Raises:
            TypeError: if ``data`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            await self.send_text(data)
        elif isinstance(data, bytes):
            await self.send_bytes(data)
        else:
            raise TypeError("WebSocket send method expects str or bytes!")

    @abc.abstractmethod
    async def send_text(self, data: str) -> None:
        """Send one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def send_bytes(self, data: bytes) -> None:
        """Send one WebSocket binary message."""
        raise NotImplementedError
|
|
|
|
|
|
class Cookies(MutableMapping):
    """Cookie container backed by an :class:`http.cookiejar.CookieJar`.

    Item access (``cookies[name]``) looks a cookie up by name across every
    domain and path. Iteration yields :class:`http.cookiejar.Cookie` objects,
    not names, and ``len()`` is the number of cookies in the jar.
    """

    def __init__(self, cookies: CookieTypes = None) -> None:
        """Create a cookie container.

        Args:
            cookies: initial cookies. A ``CookieJar`` is used directly, not
                copied. A ``Cookies`` instance has its cookies added to a new
                jar. A dict or list of ``(name, value)`` pairs is added via
                :meth:`set` with the default domain and path.

        Raises:
            TypeError: if ``cookies`` is of an unsupported type.
        """
        self.jar: CookieJar = cookies if isinstance(cookies, CookieJar) else CookieJar()

        if cookies is not None and not isinstance(cookies, CookieJar):
            if isinstance(cookies, dict):
                for key, value in cookies.items():
                    self.set(key, value)
            elif isinstance(cookies, list):
                for key, value in cookies:
                    self.set(key, value)
            elif isinstance(cookies, Cookies):
                # New jar, but the Cookie objects themselves are shared.
                for cookie in cookies.jar:
                    self.jar.set_cookie(cookie)
            else:
                raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")

    def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
        """Store a cookie in the jar.

        The cookie has no expiry (``expires=None``, ``discard=True``) and is
        not marked secure. An empty ``domain`` is recorded as not specified.
        """
        cookie = Cookie(
            version=0,
            name=name,
            value=value,
            port=None,
            port_specified=False,
            domain=domain,
            domain_specified=bool(domain),
            domain_initial_dot=domain.startswith("."),
            path=path,
            path_specified=bool(path),
            secure=False,
            expires=None,
            discard=True,
            comment=None,
            comment_url=None,
            rest={},
            rfc2109=False,
        )
        self.jar.set_cookie(cookie)

    def get(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        name: str,
        default: Optional[str] = None,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> Optional[str]:
        """Return the value of the cookie called ``name``.

        ``domain`` and ``path`` narrow the search when given.

        Returns:
            The cookie's value, or ``default`` when no cookie matches or the
            matching cookie's value is ``None``.

        Raises:
            ValueError: if more than one cookie matches.
        """
        value: Optional[str] = None
        found = False
        for cookie in self.jar:
            if (
                cookie.name == name
                and (domain is None or cookie.domain == domain)
                and (path is None or cookie.path == path)
            ):
                # Track matches with a flag rather than via ``value``: a
                # cookie's value may be None, which must not hide a duplicate.
                if found:
                    message = f"Multiple cookies exist with name={name}"
                    raise ValueError(message)
                found = True
                value = cookie.value

        return default if value is None else value

    def delete(
        self, name: str, domain: Optional[str] = None, path: Optional[str] = None
    ) -> None:
        """Remove cookies called ``name``.

        With both ``domain`` and ``path`` given, this delegates directly to
        ``CookieJar.clear``, which raises ``KeyError`` if the cookie is absent.
        Otherwise every matching cookie is removed, and nothing happens when
        none match.
        """
        if domain is not None and path is not None:
            return self.jar.clear(domain, path, name)

        # Collect first: the jar must not be mutated while iterating it.
        remove = [
            cookie
            for cookie in self.jar
            if cookie.name == name
            and (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        ]

        for cookie in remove:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
        """Remove cookies via ``CookieJar.clear``; no arguments clears all."""
        self.jar.clear(domain, path)

    def update(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, cookies: CookieTypes = None
    ) -> None:
        """Add every cookie from ``cookies``. Any type accepted by the constructor is allowed."""
        cookies = Cookies(cookies)
        for cookie in cookies.jar:
            self.jar.set_cookie(cookie)

    def as_header(self, request: Request) -> dict[str, str]:
        """Return the headers the jar would add to ``request``.

        The result is typically empty or a single ``Cookie`` header.
        """
        urllib_request = self._CookieCompatRequest(request)
        self.jar.add_cookie_header(urllib_request)
        return urllib_request.added_headers

    def __setitem__(self, name: str, value: str) -> None:
        return self.set(name, value)

    def __getitem__(self, name: str) -> str:
        value = self.get(name)
        if value is None:
            raise KeyError(name)
        return value

    def __delitem__(self, name: str) -> None:
        return self.delete(name)

    def __len__(self) -> int:
        return len(self.jar)

    def __iter__(self) -> Iterator[Cookie]:
        # Yields Cookie objects (not names); callers rely on cookie attributes.
        return iter(self.jar)

    def __repr__(self) -> str:
        cookies_repr = ", ".join(
            f"Cookie({cookie.name}={cookie.value} for {cookie.domain})"
            for cookie in self.jar
        )
        return f"{self.__class__.__name__}({cookies_repr})"

    class _CookieCompatRequest(urllib.request.Request):
        """Adapter exposing a Request to ``CookieJar.add_cookie_header``.

        Headers added by the jar are recorded in ``added_headers``.
        """

        def __init__(self, request: Request) -> None:
            super().__init__(
                url=str(request.url),
                headers=dict(request.headers),
                method=request.method,
            )
            self.request = request
            self.added_headers: dict[str, str] = {}

        def add_unredirected_header(  # pyright: ignore[reportIncompatibleMethodOverride]
            self, key: str, value: str
        ) -> None:
            super().add_unredirected_header(key, value)
            self.added_headers[key] = value
|
|
|
|
|
|
@dataclass()
class HTTPServerSetup:
    """HTTP server route configuration."""

    # Route path. It should be relative, i.e. ``URL.is_absolute()`` is False.
    path: URL
    # HTTP method of the route.
    method: str
    # Route name.
    name: str
    # Handler taking the incoming Request and awaiting to a Response.
    handle_func: Callable[[Request], Awaitable[Response]]
|
|
|
|
|
|
@dataclass()
class WebSocketServerSetup:
    """WebSocket server route configuration."""

    # Route path. It should be relative, i.e. ``URL.is_absolute()`` is False.
    path: URL
    # Route name.
    name: str
    # Handler taking the WebSocket connection; its awaited result is untyped.
    handle_func: Callable[[WebSocket], Awaitable[Any]]
|