♻️ rewrite driver request and response class

This commit is contained in:
yanyongyu 2021-12-17 23:20:19 +08:00
parent c0f321116a
commit ec9e159ef6
4 changed files with 296 additions and 159 deletions

View File

@ -7,22 +7,13 @@
import abc import abc
import asyncio import asyncio
from dataclasses import field, dataclass from dataclasses import dataclass
from typing import ( from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Union,
Callable,
Optional,
Awaitable,
)
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Env, Config from nonebot.config import Env, Config
from ._model import URL, Request, Response, WebSocket
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING: if TYPE_CHECKING:
@ -213,11 +204,11 @@ class ForwardDriver(Driver):
""" """
@abc.abstractmethod @abc.abstractmethod
async def request(self, setup: "HTTPRequest") -> Any: async def request(self, setup: "Request") -> Any:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def websocket(self, setup: "HTTPConnection") -> Any: async def websocket(self, setup: "Request") -> Any:
raise NotImplementedError raise NotImplementedError
@ -247,153 +238,16 @@ class ReverseDriver(Driver):
raise NotImplementedError raise NotImplementedError
# TODO: repack dataclass
@dataclass
class HTTPConnection(abc.ABC):
http_version: str
"""One of ``"1.0"``, ``"1.1"`` or ``"2"``."""
scheme: str
"""URL scheme portion (likely ``"http"`` or ``"https"``)."""
path: str
"""
HTTP request target excluding any query string,
with percent-encoded sequences and UTF-8 byte sequences
decoded into characters.
"""
query_string: bytes = b""
""" URL portion after the ``?``, percent-encoded."""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order of header values must be preserved from the original HTTP request;
order of header names is not important.
Header names must be lowercased.
"""
@property
@abc.abstractmethod
def type(self) -> str:
"""Connection type."""
raise NotImplementedError
@dataclass
class HTTPRequest(HTTPConnection):
"""HTTP 请求封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
method: str = "GET"
"""The HTTP method name, uppercased."""
body: bytes = b""
"""Body of the request.
Optional; if missing defaults to ``b""``.
"""
@property
def type(self) -> str:
"""Always ``http``"""
return "http"
@dataclass
class HTTPResponse:
"""HTTP 响应封装。参考 `asgi http scope`_。
.. _asgi http scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
"""
status: int
"""HTTP status code."""
body: Optional[bytes] = None
"""HTTP body content.
Optional; if missing defaults to ``None``.
"""
headers: Dict[str, str] = field(default_factory=dict)
"""A dict of name-value pairs,
where name is the header name, and value is the header value.
Order must be preserved in the HTTP response.
Header names must be lowercased.
Optional; if missing defaults to an empty dict.
"""
@property
def type(self) -> str:
"""Always ``http``"""
return "http"
@dataclass
class WebSocket(HTTPConnection, abc.ABC):
"""WebSocket 连接封装。参考 `asgi websocket scope`_。
.. _asgi websocket scope:
https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
"""
@property
def type(self) -> str:
"""Always ``websocket``"""
return "websocket"
@property
@abc.abstractmethod
def closed(self) -> bool:
"""
:类型: ``bool``
:说明: 连接是否已经关闭
"""
raise NotImplementedError
@abc.abstractmethod
async def accept(self):
"""接受 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def close(self, code: int):
"""关闭 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket binary 信息"""
raise NotImplementedError
@dataclass @dataclass
class HTTPServerSetup: class HTTPServerSetup:
path: str path: URL # path should not be absolute, check it by URL.is_absolute() == False
method: str method: str
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]] name: str
handle_func: Callable[[Request], Awaitable[Response]]
@dataclass @dataclass
class WebSocketServerSetup: class WebSocketServerSetup:
path: str path: URL # path should not be absolute, check it by URL.is_absolute() == False
name: str
handle_func: Callable[[WebSocket], Awaitable[Any]] handle_func: Callable[[WebSocket], Awaitable[Any]]

282
nonebot/drivers/_model.py Normal file
View File

@ -0,0 +1,282 @@
import abc
from enum import Enum
from http.cookiejar import Cookie, CookieJar
from typing import (
Dict,
List,
Tuple,
Union,
Mapping,
Iterator,
Optional,
Sequence,
MutableMapping,
)
from yarl import URL as URL
from multidict import CIMultiDict
RawURL = Tuple[bytes, bytes, Optional[int], bytes]
SimpleQuery = Union[str, int, float]
QueryVariable = Union[SimpleQuery, Sequence[SimpleQuery]]
QueryTypes = Union[
None, str, Mapping[str, QueryVariable], Sequence[Tuple[str, QueryVariable]]
]
HeaderTypes = Union[
None,
CIMultiDict[str],
Dict[str, str],
Sequence[Tuple[str, str]],
]
ContentTypes = Union[str, bytes]
CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
class HTTPVersion(Enum):
H10 = "1.0"
H11 = "1.1"
H2 = "2"
class Request:
def __init__(
self,
method: Union[str, bytes],
url: Union["URL", str, RawURL],
*,
params: QueryTypes = None,
headers: HeaderTypes = None,
cookies: CookieTypes = None,
content: ContentTypes = None,
version: Union[str, HTTPVersion] = HTTPVersion.H11,
timeout: Optional[float] = None,
):
# method
self.method = (
method.decode("ascii").upper()
if isinstance(method, bytes)
else method.upper()
)
# http version
self.version = HTTPVersion(version)
# timeout
self.timeout = timeout
# url
if isinstance(url, tuple):
scheme, host, port, path = url
url = URL.build(
scheme=scheme.decode("ascii"),
host=host.decode("ascii"),
port=port,
path=path.decode("ascii"),
)
else:
url = URL(url)
if params is not None:
url = url.update_query(params)
self.url = url
# headers
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
# cookies
self.cookies = Cookies(cookies)
# body
self.content = content
def __repr__(self) -> str:
class_name = self.__class__.__name__
url = str(self.url)
return f"<{class_name}({self.method!r}, {url!r})>"
class Response:
def __init__(
self,
status_code: int,
*,
headers: HeaderTypes = None,
content: ContentTypes = None,
request: Optional[Request] = None,
):
# status code
self.status_code = status_code
# headers
if headers is not None:
self.headers = CIMultiDict(headers)
else:
self.headers = CIMultiDict()
# body
self.content = content
# request
self.request = request
class WebSocket(abc.ABC):
def __init__(self, *, request: Request):
# request
self.request = request
@property
@abc.abstractmethod
def closed(self) -> bool:
"""
:类型: ``bool``
:说明: 连接是否已经关闭
"""
raise NotImplementedError
@abc.abstractmethod
async def accept(self):
"""接受 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def close(self, code: int):
"""关闭 WebSocket 连接请求"""
raise NotImplementedError
@abc.abstractmethod
async def receive(self) -> str:
"""接收一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def receive_bytes(self) -> bytes:
"""接收一条 WebSocket binary 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send(self, data: str):
"""发送一条 WebSocket text 信息"""
raise NotImplementedError
@abc.abstractmethod
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket binary 信息"""
raise NotImplementedError
class Cookies(MutableMapping):
def __init__(self, cookies: CookieTypes = None) -> None:
self.jar = cookies if isinstance(cookies, CookieJar) else CookieJar()
if cookies is not None and not isinstance(cookies, CookieJar):
if isinstance(cookies, dict):
for key, value in cookies.items():
self.set(key, value)
elif isinstance(cookies, list):
for key, value in cookies:
self.set(key, value)
elif isinstance(cookies, Cookies):
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
else:
raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")
def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
cookie = Cookie(
version=0,
name=name,
value=value,
port=None,
port_specified=False,
domain=domain,
domain_specified=bool(domain),
domain_initial_dot=domain.startswith("."),
path=path,
path_specified=bool(path),
secure=False,
expires=None,
discard=True,
comment=None,
comment_url=None,
rest={},
rfc2109=False,
)
self.jar.set_cookie(cookie)
def get(
self,
name: str,
default: Optional[str] = None,
domain: str = None,
path: str = None,
) -> Optional[str]:
value: Optional[str] = None
for cookie in self.jar:
if (
cookie.name == name
and (domain is None or cookie.domain == domain)
and (path is None or cookie.path == path)
):
if value is not None:
message = f"Multiple cookies exist with name={name}"
raise ValueError(message)
value = cookie.value
return default if value is None else value
def delete(
self, name: str, domain: Optional[str] = None, path: Optional[str] = None
) -> None:
if domain is not None and path is not None:
return self.jar.clear(domain, path, name)
remove = [
cookie
for cookie in self.jar
if cookie.name == name
and (domain is None or cookie.domain == domain)
and (path is None or cookie.path == path)
]
for cookie in remove:
self.jar.clear(cookie.domain, cookie.path, cookie.name)
def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
self.jar.clear(domain, path)
def update(self, cookies: CookieTypes = None) -> None:
cookies = Cookies(cookies)
for cookie in cookies.jar:
self.jar.set_cookie(cookie)
def __setitem__(self, name: str, value: str) -> None:
return self.set(name, value)
def __getitem__(self, name: str) -> str:
value = self.get(name)
if value is None:
raise KeyError(name)
return value
def __delitem__(self, name: str) -> None:
return self.delete(name)
def __len__(self) -> int:
return len(self.jar)
def __iter__(self) -> Iterator[Cookie]:
return (cookie for cookie in self.jar)
def __repr__(self) -> str:
cookies_repr = ", ".join(
[
f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
for cookie in self.jar
]
)
return f"<Cookies [{cookies_repr}]>"

6
poetry.lock generated
View File

@ -513,7 +513,7 @@ name = "multidict"
version = "5.2.0" version = "5.2.0"
description = "multidict implementation" description = "multidict implementation"
category = "main" category = "main"
optional = true optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[[package]] [[package]]
@ -1152,7 +1152,7 @@ name = "yarl"
version = "1.7.2" version = "1.7.2"
description = "Yet another URL library" description = "Yet another URL library"
category = "main" category = "main"
optional = true optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
[package.dependencies] [package.dependencies]
@ -1180,7 +1180,7 @@ quart = ["Quart"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.7.3" python-versions = "^3.7.3"
content-hash = "32ef230c4c02d1eb68754ad0bb938d1d960d7c6362ea2250cc23bbed910fbe34" content-hash = "3cef667f39a760749d4c1d35883be4cec9c224e8079ee28752f6aa7f58688442"
[metadata.files] [metadata.files]
aiodns = [ aiodns = [

View File

@ -34,6 +34,7 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
pydantic = { version = "~1.8.0", extras = ["dotenv"] } pydantic = { version = "~1.8.0", extras = ["dotenv"] }
uvicorn = { version = "^0.15.0", extras = ["standard"] } uvicorn = { version = "^0.15.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
yarl = "^1.7.2"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
sphinx = "^4.1.1" sphinx = "^4.1.1"