diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 1785e1b6..75398eb9 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -7,22 +7,13 @@ import abc import asyncio -from dataclasses import field, dataclass -from typing import ( - TYPE_CHECKING, - Any, - Set, - Dict, - Type, - Union, - Callable, - Optional, - Awaitable, -) +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable from nonebot.log import logger from nonebot.utils import escape_tag from nonebot.config import Env, Config +from ._model import URL, Request, Response, WebSocket from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook if TYPE_CHECKING: @@ -213,11 +204,11 @@ class ForwardDriver(Driver): """ @abc.abstractmethod - async def request(self, setup: "HTTPRequest") -> Any: + async def request(self, setup: "Request") -> Any: raise NotImplementedError @abc.abstractmethod - async def websocket(self, setup: "HTTPConnection") -> Any: + async def websocket(self, setup: "Request") -> Any: raise NotImplementedError @@ -247,153 +238,16 @@ class ReverseDriver(Driver): raise NotImplementedError -# TODO: repack dataclass -@dataclass -class HTTPConnection(abc.ABC): - http_version: str - """One of ``"1.0"``, ``"1.1"`` or ``"2"``.""" - scheme: str - """URL scheme portion (likely ``"http"`` or ``"https"``).""" - path: str - """ - HTTP request target excluding any query string, - with percent-encoded sequences and UTF-8 byte sequences - decoded into characters. - """ - query_string: bytes = b"" - """ URL portion after the ``?``, percent-encoded.""" - headers: Dict[str, str] = field(default_factory=dict) - """A dict of name-value pairs, - where name is the header name, and value is the header value. - - Order of header values must be preserved from the original HTTP request; - order of header names is not important. - - Header names must be lowercased. - """ - - @property - @abc.abstractmethod - def type(self) -> str: - """Connection type.""" - raise NotImplementedError - - -@dataclass -class HTTPRequest(HTTPConnection): - """HTTP 请求封装。参考 `asgi http scope`_。 - - .. _asgi http scope: - https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope - """ - - method: str = "GET" - """The HTTP method name, uppercased.""" - body: bytes = b"" - """Body of the request. - - Optional; if missing defaults to ``b""``. - """ - - @property - def type(self) -> str: - """Always ``http``""" - return "http" - - -@dataclass -class HTTPResponse: - """HTTP 响应封装。参考 `asgi http scope`_。 - - .. _asgi http scope: - https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope - """ - - status: int - """HTTP status code.""" - body: Optional[bytes] = None - """HTTP body content. - - Optional; if missing defaults to ``None``. - """ - headers: Dict[str, str] = field(default_factory=dict) - """A dict of name-value pairs, - where name is the header name, and value is the header value. - - Order must be preserved in the HTTP response. - - Header names must be lowercased. - - Optional; if missing defaults to an empty dict. - """ - - @property - def type(self) -> str: - """Always ``http``""" - return "http" - - -@dataclass -class WebSocket(HTTPConnection, abc.ABC): - """WebSocket 连接封装。参考 `asgi websocket scope`_。 - - .. _asgi websocket scope: - https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope - """ - - @property - def type(self) -> str: - """Always ``websocket``""" - return "websocket" - - @property - @abc.abstractmethod - def closed(self) -> bool: - """ - :类型: ``bool`` - :说明: 连接是否已经关闭 - """ - raise NotImplementedError - - @abc.abstractmethod - async def accept(self): - """接受 WebSocket 连接请求""" - raise NotImplementedError - - @abc.abstractmethod - async def close(self, code: int): - """关闭 WebSocket 连接请求""" - raise NotImplementedError - - @abc.abstractmethod - async def receive(self) -> str: - """接收一条 WebSocket text 信息""" - raise NotImplementedError - - @abc.abstractmethod - async def receive_bytes(self) -> bytes: - """接收一条 WebSocket binary 信息""" - raise NotImplementedError - - @abc.abstractmethod - async def send(self, data: str): - """发送一条 WebSocket text 信息""" - raise NotImplementedError - - @abc.abstractmethod - async def send_bytes(self, data: bytes): - """发送一条 WebSocket binary 信息""" - raise NotImplementedError - - @dataclass class HTTPServerSetup: - path: str + path: URL # path should not be absolute, check it by URL.is_absolute() == False method: str - handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]] + name: str + handle_func: Callable[[Request], Awaitable[Response]] @dataclass class WebSocketServerSetup: - path: str + path: URL # path should not be absolute, check it by URL.is_absolute() == False + name: str handle_func: Callable[[WebSocket], Awaitable[Any]] diff --git a/nonebot/drivers/_model.py b/nonebot/drivers/_model.py new file mode 100644 index 00000000..ede3276a --- /dev/null +++ b/nonebot/drivers/_model.py @@ -0,0 +1,282 @@ +import abc +from enum import Enum +from http.cookiejar import Cookie, CookieJar +from typing import ( + Dict, + List, + Tuple, + Union, + Mapping, + Iterator, + Optional, + Sequence, + MutableMapping, +) + +from yarl import URL as URL +from multidict import CIMultiDict + +RawURL = Tuple[bytes, bytes, Optional[int], bytes] + +SimpleQuery = Union[str, int, float] +QueryVariable = Union[SimpleQuery, Sequence[SimpleQuery]] +QueryTypes = Union[ + None, str, Mapping[str, QueryVariable], Sequence[Tuple[str, QueryVariable]] +] + +HeaderTypes = Union[ + None, + CIMultiDict[str], + Dict[str, str], + Sequence[Tuple[str, str]], +] + +ContentTypes = Union[str, bytes] +CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]] + + +class HTTPVersion(Enum): + H10 = "1.0" + H11 = "1.1" + H2 = "2" + + +class Request: + def __init__( + self, + method: Union[str, bytes], + url: Union["URL", str, RawURL], + *, + params: QueryTypes = None, + headers: HeaderTypes = None, + cookies: CookieTypes = None, + content: ContentTypes = None, + version: Union[str, HTTPVersion] = HTTPVersion.H11, + timeout: Optional[float] = None, + ): + # method + self.method = ( + method.decode("ascii").upper() + if isinstance(method, bytes) + else method.upper() + ) + # http version + self.version = HTTPVersion(version) + # timeout + self.timeout = timeout + + # url + if isinstance(url, tuple): + scheme, host, port, path = url + url = URL.build( + scheme=scheme.decode("ascii"), + host=host.decode("ascii"), + port=port, + path=path.decode("ascii"), + ) + else: + url = URL(url) + + if params is not None: + url = url.update_query(params) + self.url = url + + # headers + if headers is not None: + self.headers = CIMultiDict(headers) + else: + self.headers = CIMultiDict() + + # cookies + self.cookies = Cookies(cookies) + + # body + self.content = content + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + url = str(self.url) + return f"<{class_name}({self.method!r}, {url!r})>" + + +class Response: + def __init__( + self, + status_code: int, + *, + headers: HeaderTypes = None, + content: ContentTypes = None, + request: Optional[Request] = None, + ): + # status code + self.status_code = status_code + + # headers + if headers is not None: + self.headers = CIMultiDict(headers) + else: + self.headers = CIMultiDict() + + # body + self.content = content + + # request + self.request = request + + +class WebSocket(abc.ABC): + def __init__(self, *, request: Request): + # request + self.request = request + + @property + @abc.abstractmethod + def closed(self) -> bool: + """ + :类型: ``bool`` + :说明: 连接是否已经关闭 + """ + raise NotImplementedError + + @abc.abstractmethod + async def accept(self): + """接受 WebSocket 连接请求""" + raise NotImplementedError + + @abc.abstractmethod + async def close(self, code: int): + """关闭 WebSocket 连接请求""" + raise NotImplementedError + + @abc.abstractmethod + async def receive(self) -> str: + """接收一条 WebSocket text 信息""" + raise NotImplementedError + + @abc.abstractmethod + async def receive_bytes(self) -> bytes: + """接收一条 WebSocket binary 信息""" + raise NotImplementedError + + @abc.abstractmethod + async def send(self, data: str): + """发送一条 WebSocket text 信息""" + raise NotImplementedError + + @abc.abstractmethod + async def send_bytes(self, data: bytes): + """发送一条 WebSocket binary 信息""" + raise NotImplementedError + + +class Cookies(MutableMapping): + def __init__(self, cookies: CookieTypes = None) -> None: + self.jar = cookies if isinstance(cookies, CookieJar) else CookieJar() + if cookies is not None and not isinstance(cookies, CookieJar): + if isinstance(cookies, dict): + for key, value in cookies.items(): + self.set(key, value) + elif isinstance(cookies, list): + for key, value in cookies: + self.set(key, value) + elif isinstance(cookies, Cookies): + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + else: + raise TypeError(f"Cookies must be dict or list, not {type(cookies)}") + + def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None: + cookie = Cookie( + version=0, + name=name, + value=value, + port=None, + port_specified=False, + domain=domain, + domain_specified=bool(domain), + domain_initial_dot=domain.startswith("."), + path=path, + path_specified=bool(path), + secure=False, + expires=None, + discard=True, + comment=None, + comment_url=None, + rest={}, + rfc2109=False, + ) + self.jar.set_cookie(cookie) + + def get( + self, + name: str, + default: Optional[str] = None, + domain: str = None, + path: str = None, + ) -> Optional[str]: + value: Optional[str] = None + for cookie in self.jar: + if ( + cookie.name == name + and (domain is None or cookie.domain == domain) + and (path is None or cookie.path == path) + ): + if value is not None: + message = f"Multiple cookies exist with name={name}" + raise ValueError(message) + value = cookie.value + + return default if value is None else value + + def delete( + self, name: str, domain: Optional[str] = None, path: Optional[str] = None + ) -> None: + if domain is not None and path is not None: + return self.jar.clear(domain, path, name) + + remove = [ + cookie + for cookie in self.jar + if cookie.name == name + and (domain is None or cookie.domain == domain) + and (path is None or cookie.path == path) + ] + + for cookie in remove: + self.jar.clear(cookie.domain, cookie.path, cookie.name) + + def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None: + self.jar.clear(domain, path) + + def update(self, cookies: CookieTypes = None) -> None: + cookies = Cookies(cookies) + for cookie in cookies.jar: + self.jar.set_cookie(cookie) + + def __setitem__(self, name: str, value: str) -> None: + return self.set(name, value) + + def __getitem__(self, name: str) -> str: + value = self.get(name) + if value is None: + raise KeyError(name) + return value + + def __delitem__(self, name: str) -> None: + return self.delete(name) + + def __len__(self) -> int: + return len(self.jar) + + def __iter__(self) -> Iterator[Cookie]: + return (cookie for cookie in self.jar) + + def __repr__(self) -> str: + cookies_repr = ", ".join( + [ + f"" + for cookie in self.jar + ] + ) + + return f"" diff --git a/poetry.lock b/poetry.lock index 12ca0bc1..fedf1714 100644 --- a/poetry.lock +++ b/poetry.lock @@ -513,7 +513,7 @@ name = "multidict" version = "5.2.0" description = "multidict implementation" category = "main" -optional = true +optional = false python-versions = ">=3.6" [[package]] @@ -1152,7 +1152,7 @@ name = "yarl" version = "1.7.2" description = "Yet another URL library" category = "main" -optional = true +optional = false python-versions = ">=3.6" [package.dependencies] @@ -1180,7 +1180,7 @@ quart = ["Quart"] [metadata] lock-version = "1.1" python-versions = "^3.7.3" -content-hash = "32ef230c4c02d1eb68754ad0bb938d1d960d7c6362ea2250cc23bbed910fbe34" +content-hash = "3cef667f39a760749d4c1d35883be4cec9c224e8079ee28752f6aa7f58688442" [metadata.files] aiodns = [ diff --git a/pyproject.toml b/pyproject.toml index 4312f2a8..5dfd8cbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] } pydantic = { version = "~1.8.0", extras = ["dotenv"] } uvicorn = { version = "^0.15.0", extras = ["standard"] } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } +yarl = "^1.7.2" [tool.poetry.dev-dependencies] sphinx = "^4.1.1"