nonebot2/nonebot/internal/driver/model.py
2024-04-16 00:33:48 +08:00

360 lines
11 KiB
Python

import abc
import urllib.request
from enum import Enum
from dataclasses import dataclass
from typing_extensions import TypeAlias
from http.cookiejar import Cookie, CookieJar
from typing import IO, Any, Union, Callable, Optional
from collections.abc import Mapping, Iterator, Awaitable, MutableMapping
from yarl import URL as URL
from multidict import CIMultiDict
# Raw URL parts in the order (scheme, host, port, path); the text parts are bytes.
RawURL: TypeAlias = tuple[bytes, bytes, Optional[int], bytes]

# A single scalar query-string value.
SimpleQuery: TypeAlias = Union[str, int, float]
# One query value, or a list of values for the same key.
QueryVariable: TypeAlias = Union[SimpleQuery, list[SimpleQuery]]
# Everything accepted as `params` by Request.
QueryTypes: TypeAlias = Union[
    None, str, Mapping[str, QueryVariable], list[tuple[str, SimpleQuery]]
]

# Everything accepted as `headers` by Request and Response.
HeaderTypes: TypeAlias = Union[
    None,
    CIMultiDict[str],
    dict[str, str],
    list[tuple[str, str]],
]

# Everything accepted by the Cookies constructor ("Cookies" is a forward reference).
CookieTypes: TypeAlias = Union[
    None, "Cookies", CookieJar, dict[str, str], list[tuple[str, str]]
]

# Body value accepted as `content` by Request and Response.
ContentTypes: TypeAlias = Union[str, bytes, None]
# Value accepted as `data` by Request.
DataTypes: TypeAlias = Union[dict, None]

# Payload of an uploaded file: an open binary stream or raw bytes.
FileContent: TypeAlias = Union[IO[bytes], bytes]
# Fully specified file entry: (filename, file (or bytes), content_type).
FileType: TypeAlias = tuple[Optional[str], FileContent, Optional[str]]
# Shorthand forms a caller may pass for one file.
FileTypes: TypeAlias = Union[
    # file (or bytes)
    FileContent,
    # (filename, file (or bytes))
    tuple[Optional[str], FileContent],
    # (filename, file (or bytes), content_type)
    FileType,
]
# Everything accepted as `files` by Request: field name -> file, as dict or pair list.
FilesTypes: TypeAlias = Union[dict[str, FileTypes], list[tuple[str, FileTypes]], None]
class HTTPVersion(Enum):
    """HTTP protocol versions; each member's value is the version string."""

    H10 = "1.0"
    H11 = "1.1"
    H2 = "2"
class Request:
    """Driver-independent description of an HTTP request.

    The constructor normalises its loosely typed arguments into fixed
    attribute types:

    * ``method`` - upper-cased ``str`` (bytes are decoded as ASCII first).
    * ``version`` - coerced to :class:`HTTPVersion`.
    * ``url`` - a ``yarl.URL``; a raw ``(scheme, host, port, path)`` bytes
      tuple is decoded as ASCII and assembled with ``URL.build``. When
      ``params`` is not ``None`` it is applied with ``URL.update_query``.
    * ``headers`` - always a ``CIMultiDict`` (empty when ``None`` is given).
    * ``cookies`` - always a ``Cookies`` instance.
    * ``files`` - ``None`` when nothing truthy is passed, otherwise a list of
      ``(field_name, (filename, content, content_type))`` entries.

    ``content``, ``data``, ``json``, ``timeout`` and ``proxy`` are stored
    exactly as given.
    """

    def __init__(
        self,
        method: Union[str, bytes],
        url: Union["URL", str, RawURL],
        *,
        params: QueryTypes = None,
        headers: HeaderTypes = None,
        cookies: CookieTypes = None,
        content: ContentTypes = None,
        data: DataTypes = None,
        json: Any = None,
        files: FilesTypes = None,
        version: Union[str, HTTPVersion] = HTTPVersion.H11,
        timeout: Optional[float] = None,
        proxy: Optional[str] = None,
    ):
        # Method: bytes are decoded first so the attribute is always str.
        if isinstance(method, bytes):
            method = method.decode("ascii")
        self.method: str = method.upper()

        # Protocol version, timeout and proxy.
        self.version: HTTPVersion = HTTPVersion(version)
        self.timeout: Optional[float] = timeout
        self.proxy: Optional[str] = proxy

        # URL: build from raw parts, or let yarl parse/copy what was given.
        if isinstance(url, tuple):
            raw_scheme, raw_host, port, raw_path = url
            target = URL.build(
                scheme=raw_scheme.decode("ascii"),
                host=raw_host.decode("ascii"),
                port=port,
                path=raw_path.decode("ascii"),
            )
        else:
            target = URL(url)
        if params is not None:
            target = target.update_query(params)
        self.url: URL = target

        # Headers and cookies.
        self.headers: CIMultiDict[str] = (
            CIMultiDict() if headers is None else CIMultiDict(headers)
        )
        self.cookies = Cookies(cookies)

        # Body.
        self.content: ContentTypes = content
        self.data: DataTypes = data
        self.json: Any = json

        # Files: expand every shorthand form to a three-item tuple.
        normalized: Optional[list[tuple[str, FileType]]] = None
        if files:
            normalized = []
            entries = files.items() if isinstance(files, dict) else files
            for field, spec in entries:
                if isinstance(spec, tuple):
                    if len(spec) == 2:
                        filename, payload = spec
                        spec = (filename, payload, None)
                    # Any other tuple length is passed through unchanged.
                else:
                    # Bare content: the field name doubles as the filename.
                    spec = (field, spec, None)
                normalized.append((field, spec))  # type: ignore
        self.files: Optional[list[tuple[str, FileType]]] = normalized

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(method={self.method!r}, url='{self.url!s}')"
class Response:
    """Driver-independent description of an HTTP response.

    Attributes:
        status_code: the status code, stored as given.
        headers: always a ``CIMultiDict``; empty when ``headers`` is ``None``.
        content: the body, stored as given.
        request: the ``Request`` passed in, or ``None``.
    """

    def __init__(
        self,
        status_code: int,
        *,
        headers: HeaderTypes = None,
        content: ContentTypes = None,
        request: Optional[Request] = None,
    ):
        self.status_code: int = status_code

        # Copy whatever header container was supplied into a CIMultiDict.
        if headers is None:
            normalized = CIMultiDict()
        else:
            normalized = CIMultiDict(headers)
        self.headers: CIMultiDict[str] = normalized

        self.content: ContentTypes = content
        self.request: Optional[Request] = request

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(status_code={self.status_code!r})"
class WebSocket(abc.ABC):
    """Abstract base class for a WebSocket connection.

    Subclasses implement the primitives (``closed``, ``accept``, ``close``,
    the ``receive*`` methods, ``send_text`` and ``send_bytes``). ``send`` is
    provided here and dispatches to ``send_text`` / ``send_bytes`` based on
    the payload type.
    """

    def __init__(self, *, request: "Request"):
        """Store the request this connection is bound to."""
        self.request: "Request" = request

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}('{self.request.url!s}')"

    @property
    @abc.abstractmethod
    def closed(self) -> bool:
        """Whether the connection has already been closed."""
        raise NotImplementedError

    @abc.abstractmethod
    async def accept(self) -> None:
        """Accept the WebSocket connection request."""
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self, code: int = 1000, reason: str = "") -> None:
        """Close the WebSocket connection.

        Args:
            code: close code (defaults to 1000).
            reason: close reason (defaults to an empty string).
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def receive(self) -> Union[str, bytes]:
        """Receive one WebSocket text or bytes message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_text(self) -> str:
        """Receive one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def receive_bytes(self) -> bytes:
        """Receive one WebSocket binary message."""
        raise NotImplementedError

    async def send(self, data: Union[str, bytes]) -> None:
        """Send one WebSocket message, choosing the frame type from ``data``.

        ``str`` is forwarded to ``send_text`` and ``bytes`` to ``send_bytes``.

        Raises:
            TypeError: if ``data`` is neither ``str`` nor ``bytes``.
        """
        if isinstance(data, str):
            await self.send_text(data)
        elif isinstance(data, bytes):
            await self.send_bytes(data)
        else:
            raise TypeError("WebSocket send method expects str or bytes!")

    @abc.abstractmethod
    async def send_text(self, data: str) -> None:
        """Send one WebSocket text message."""
        raise NotImplementedError

    @abc.abstractmethod
    async def send_bytes(self, data: bytes) -> None:
        """Send one WebSocket binary message."""
        raise NotImplementedError
class Cookies(MutableMapping):
    """Mapping-style wrapper around an ``http.cookiejar.CookieJar``.

    Item access (``cookies[name]``) reads, writes and deletes cookie values
    by name. Note that iterating a ``Cookies`` object yields the stored
    ``Cookie`` objects themselves, not their names.
    """

    def __init__(self, cookies: CookieTypes = None) -> None:
        """Build the jar from ``cookies``.

        A ``CookieJar`` is adopted as-is (not copied). A dict or a list of
        pairs is inserted through :meth:`set`. Another ``Cookies`` has its
        cookies copied into a fresh jar. ``None`` gives an empty jar.

        Raises:
            TypeError: for any other argument type.
        """
        if isinstance(cookies, CookieJar):
            self.jar: CookieJar = cookies
            return

        self.jar = CookieJar()
        if cookies is None:
            return

        if isinstance(cookies, (dict, list)):
            pairs = cookies.items() if isinstance(cookies, dict) else cookies
            for key, value in pairs:
                self.set(key, value)
        elif isinstance(cookies, Cookies):
            for cookie in cookies.jar:
                self.jar.set_cookie(cookie)
        else:
            raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")

    @staticmethod
    def _matches(
        cookie: Cookie, name: str, domain: Optional[str], path: Optional[str]
    ) -> bool:
        """Tell whether ``cookie`` has ``name`` and passes the optional filters."""
        return (
            cookie.name == name
            and (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        )

    def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
        """Store a version-0, non-secure cookie with no expiry in the jar."""
        self.jar.set_cookie(
            Cookie(
                version=0,
                name=name,
                value=value,
                port=None,
                port_specified=False,
                domain=domain,
                domain_specified=bool(domain),
                domain_initial_dot=domain.startswith("."),
                path=path,
                path_specified=bool(path),
                secure=False,
                expires=None,
                discard=True,
                comment=None,
                comment_url=None,
                rest={},
                rfc2109=False,
            )
        )

    def get(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        name: str,
        default: Optional[str] = None,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> Optional[str]:
        """Return the value of the cookie called ``name``, or ``default``.

        ``domain`` and ``path`` narrow the search when given.

        Raises:
            ValueError: if a further cookie matches after a non-``None``
                value has already been found.
        """
        found: Optional[str] = None
        for cookie in self.jar:
            if not self._matches(cookie, name, domain, path):
                continue
            if found is not None:
                message = f"Multiple cookies exist with name={name}"
                raise ValueError(message)
            found = cookie.value
        return found if found is not None else default

    def delete(
        self, name: str, domain: Optional[str] = None, path: Optional[str] = None
    ) -> None:
        """Remove cookies called ``name``, optionally limited by domain/path.

        With both ``domain`` and ``path`` given, the jar's own ``clear`` is
        called for that exact triple; otherwise every matching cookie is
        removed.
        """
        if domain is not None and path is not None:
            return self.jar.clear(domain, path, name)

        # Collect first: the jar must not be mutated while it is iterated.
        doomed = [
            cookie
            for cookie in self.jar
            if self._matches(cookie, name, domain, path)
        ]
        for cookie in doomed:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
        """Forward to ``CookieJar.clear`` with the given domain and path."""
        self.jar.clear(domain, path)

    def update(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, cookies: CookieTypes = None
    ) -> None:
        """Add every cookie from ``cookies`` (any ``CookieTypes``) to this jar."""
        incoming = Cookies(cookies)
        for cookie in incoming.jar:
            self.jar.set_cookie(cookie)

    def as_header(self, request: Request) -> dict[str, str]:
        """Return the headers the jar adds for ``request``.

        The jar writes into a ``urllib.request.Request`` adapter, which
        records each header added through ``add_unredirected_header``.
        """
        adapter = self._CookieCompatRequest(request)
        self.jar.add_cookie_header(adapter)
        return adapter.added_headers

    def __setitem__(self, name: str, value: str) -> None:
        return self.set(name, value)

    def __getitem__(self, name: str) -> str:
        found = self.get(name)
        if found is not None:
            return found
        raise KeyError(name)

    def __delitem__(self, name: str) -> None:
        return self.delete(name)

    def __len__(self) -> int:
        return len(self.jar)

    def __iter__(self) -> Iterator[Cookie]:
        # Yields Cookie objects (not names), straight from the jar.
        return iter(self.jar)

    def __repr__(self) -> str:
        parts = [
            f"Cookie({cookie.name}={cookie.value} for {cookie.domain})"
            for cookie in self.jar
        ]
        joined = ", ".join(parts)
        return f"{self.__class__.__name__}({joined})"

    class _CookieCompatRequest(urllib.request.Request):
        """``urllib`` request adapter that records headers the jar adds."""

        def __init__(self, request: Request) -> None:
            super().__init__(
                url=str(request.url),
                headers=dict(request.headers),
                method=request.method,
            )
            self.request = request
            self.added_headers: dict[str, str] = {}

        def add_unredirected_header(  # pyright: ignore[reportIncompatibleMethodOverride]
            self, key: str, value: str
        ) -> None:
            # Keep the stdlib behaviour, then remember what was added.
            super().add_unredirected_header(key, value)
            self.added_headers[key] = value
@dataclass
class HTTPServerSetup:
    """HTTP server route configuration."""

    # Route path.
    path: URL  # path should not be absolute, check it by URL.is_absolute() == False
    # HTTP method for the route (presumably the one it answers - confirm with drivers).
    method: str
    # Name of the route.
    name: str
    # Async callable that takes a Request and returns a Response.
    handle_func: Callable[[Request], Awaitable[Response]]
@dataclass
class WebSocketServerSetup:
    """WebSocket server route configuration."""

    # Route path.
    path: URL  # path should not be absolute, check it by URL.is_absolute() == False
    # Name of the route.
    name: str
    # Async callable that takes the WebSocket connection; its result is untyped (Any).
    handle_func: Callable[[WebSocket], Awaitable[Any]]