mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
♻️ rewrite driver request and response class
This commit is contained in:
parent
c0f321116a
commit
ec9e159ef6
@ -7,22 +7,13 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import field, dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Set,
|
|
||||||
Dict,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
Callable,
|
|
||||||
Optional,
|
|
||||||
Awaitable,
|
|
||||||
)
|
|
||||||
|
|
||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.utils import escape_tag
|
from nonebot.utils import escape_tag
|
||||||
from nonebot.config import Env, Config
|
from nonebot.config import Env, Config
|
||||||
|
from ._model import URL, Request, Response, WebSocket
|
||||||
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -213,11 +204,11 @@ class ForwardDriver(Driver):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def request(self, setup: "HTTPRequest") -> Any:
|
async def request(self, setup: "Request") -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def websocket(self, setup: "HTTPConnection") -> Any:
|
async def websocket(self, setup: "Request") -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -247,153 +238,16 @@ class ReverseDriver(Driver):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
# TODO: repack dataclass
|
|
||||||
@dataclass
|
|
||||||
class HTTPConnection(abc.ABC):
|
|
||||||
http_version: str
|
|
||||||
"""One of ``"1.0"``, ``"1.1"`` or ``"2"``."""
|
|
||||||
scheme: str
|
|
||||||
"""URL scheme portion (likely ``"http"`` or ``"https"``)."""
|
|
||||||
path: str
|
|
||||||
"""
|
|
||||||
HTTP request target excluding any query string,
|
|
||||||
with percent-encoded sequences and UTF-8 byte sequences
|
|
||||||
decoded into characters.
|
|
||||||
"""
|
|
||||||
query_string: bytes = b""
|
|
||||||
""" URL portion after the ``?``, percent-encoded."""
|
|
||||||
headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
"""A dict of name-value pairs,
|
|
||||||
where name is the header name, and value is the header value.
|
|
||||||
|
|
||||||
Order of header values must be preserved from the original HTTP request;
|
|
||||||
order of header names is not important.
|
|
||||||
|
|
||||||
Header names must be lowercased.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Connection type."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HTTPRequest(HTTPConnection):
|
|
||||||
"""HTTP 请求封装。参考 `asgi http scope`_。
|
|
||||||
|
|
||||||
.. _asgi http scope:
|
|
||||||
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
|
|
||||||
"""
|
|
||||||
|
|
||||||
method: str = "GET"
|
|
||||||
"""The HTTP method name, uppercased."""
|
|
||||||
body: bytes = b""
|
|
||||||
"""Body of the request.
|
|
||||||
|
|
||||||
Optional; if missing defaults to ``b""``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Always ``http``"""
|
|
||||||
return "http"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HTTPResponse:
|
|
||||||
"""HTTP 响应封装。参考 `asgi http scope`_。
|
|
||||||
|
|
||||||
.. _asgi http scope:
|
|
||||||
https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
|
|
||||||
"""
|
|
||||||
|
|
||||||
status: int
|
|
||||||
"""HTTP status code."""
|
|
||||||
body: Optional[bytes] = None
|
|
||||||
"""HTTP body content.
|
|
||||||
|
|
||||||
Optional; if missing defaults to ``None``.
|
|
||||||
"""
|
|
||||||
headers: Dict[str, str] = field(default_factory=dict)
|
|
||||||
"""A dict of name-value pairs,
|
|
||||||
where name is the header name, and value is the header value.
|
|
||||||
|
|
||||||
Order must be preserved in the HTTP response.
|
|
||||||
|
|
||||||
Header names must be lowercased.
|
|
||||||
|
|
||||||
Optional; if missing defaults to an empty dict.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Always ``http``"""
|
|
||||||
return "http"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class WebSocket(HTTPConnection, abc.ABC):
|
|
||||||
"""WebSocket 连接封装。参考 `asgi websocket scope`_。
|
|
||||||
|
|
||||||
.. _asgi websocket scope:
|
|
||||||
https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Always ``websocket``"""
|
|
||||||
return "websocket"
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def closed(self) -> bool:
|
|
||||||
"""
|
|
||||||
:类型: ``bool``
|
|
||||||
:说明: 连接是否已经关闭
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def accept(self):
|
|
||||||
"""接受 WebSocket 连接请求"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def close(self, code: int):
|
|
||||||
"""关闭 WebSocket 连接请求"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def receive(self) -> str:
|
|
||||||
"""接收一条 WebSocket text 信息"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def receive_bytes(self) -> bytes:
|
|
||||||
"""接收一条 WebSocket binary 信息"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def send(self, data: str):
|
|
||||||
"""发送一条 WebSocket text 信息"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def send_bytes(self, data: bytes):
|
|
||||||
"""发送一条 WebSocket binary 信息"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HTTPServerSetup:
|
class HTTPServerSetup:
|
||||||
path: str
|
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
||||||
method: str
|
method: str
|
||||||
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
|
name: str
|
||||||
|
handle_func: Callable[[Request], Awaitable[Response]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WebSocketServerSetup:
|
class WebSocketServerSetup:
|
||||||
path: str
|
path: URL # path should not be absolute, check it by URL.is_absolute() == False
|
||||||
|
name: str
|
||||||
handle_func: Callable[[WebSocket], Awaitable[Any]]
|
handle_func: Callable[[WebSocket], Awaitable[Any]]
|
||||||
|
282
nonebot/drivers/_model.py
Normal file
282
nonebot/drivers/_model.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
import abc
|
||||||
|
from enum import Enum
|
||||||
|
from http.cookiejar import Cookie, CookieJar
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
Mapping,
|
||||||
|
Iterator,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
MutableMapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
from yarl import URL as URL
|
||||||
|
from multidict import CIMultiDict
|
||||||
|
|
||||||
|
RawURL = Tuple[bytes, bytes, Optional[int], bytes]
|
||||||
|
|
||||||
|
SimpleQuery = Union[str, int, float]
|
||||||
|
QueryVariable = Union[SimpleQuery, Sequence[SimpleQuery]]
|
||||||
|
QueryTypes = Union[
|
||||||
|
None, str, Mapping[str, QueryVariable], Sequence[Tuple[str, QueryVariable]]
|
||||||
|
]
|
||||||
|
|
||||||
|
HeaderTypes = Union[
|
||||||
|
None,
|
||||||
|
CIMultiDict[str],
|
||||||
|
Dict[str, str],
|
||||||
|
Sequence[Tuple[str, str]],
|
||||||
|
]
|
||||||
|
|
||||||
|
ContentTypes = Union[str, bytes]
|
||||||
|
CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPVersion(Enum):
|
||||||
|
H10 = "1.0"
|
||||||
|
H11 = "1.1"
|
||||||
|
H2 = "2"
|
||||||
|
|
||||||
|
|
||||||
|
class Request:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
method: Union[str, bytes],
|
||||||
|
url: Union["URL", str, RawURL],
|
||||||
|
*,
|
||||||
|
params: QueryTypes = None,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
cookies: CookieTypes = None,
|
||||||
|
content: ContentTypes = None,
|
||||||
|
version: Union[str, HTTPVersion] = HTTPVersion.H11,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
):
|
||||||
|
# method
|
||||||
|
self.method = (
|
||||||
|
method.decode("ascii").upper()
|
||||||
|
if isinstance(method, bytes)
|
||||||
|
else method.upper()
|
||||||
|
)
|
||||||
|
# http version
|
||||||
|
self.version = HTTPVersion(version)
|
||||||
|
# timeout
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
# url
|
||||||
|
if isinstance(url, tuple):
|
||||||
|
scheme, host, port, path = url
|
||||||
|
url = URL.build(
|
||||||
|
scheme=scheme.decode("ascii"),
|
||||||
|
host=host.decode("ascii"),
|
||||||
|
port=port,
|
||||||
|
path=path.decode("ascii"),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
url = URL(url)
|
||||||
|
|
||||||
|
if params is not None:
|
||||||
|
url = url.update_query(params)
|
||||||
|
self.url = url
|
||||||
|
|
||||||
|
# headers
|
||||||
|
if headers is not None:
|
||||||
|
self.headers = CIMultiDict(headers)
|
||||||
|
else:
|
||||||
|
self.headers = CIMultiDict()
|
||||||
|
|
||||||
|
# cookies
|
||||||
|
self.cookies = Cookies(cookies)
|
||||||
|
|
||||||
|
# body
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
class_name = self.__class__.__name__
|
||||||
|
url = str(self.url)
|
||||||
|
return f"<{class_name}({self.method!r}, {url!r})>"
|
||||||
|
|
||||||
|
|
||||||
|
class Response:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
status_code: int,
|
||||||
|
*,
|
||||||
|
headers: HeaderTypes = None,
|
||||||
|
content: ContentTypes = None,
|
||||||
|
request: Optional[Request] = None,
|
||||||
|
):
|
||||||
|
# status code
|
||||||
|
self.status_code = status_code
|
||||||
|
|
||||||
|
# headers
|
||||||
|
if headers is not None:
|
||||||
|
self.headers = CIMultiDict(headers)
|
||||||
|
else:
|
||||||
|
self.headers = CIMultiDict()
|
||||||
|
|
||||||
|
# body
|
||||||
|
self.content = content
|
||||||
|
|
||||||
|
# request
|
||||||
|
self.request = request
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocket(abc.ABC):
|
||||||
|
def __init__(self, *, request: Request):
|
||||||
|
# request
|
||||||
|
self.request = request
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abc.abstractmethod
|
||||||
|
def closed(self) -> bool:
|
||||||
|
"""
|
||||||
|
:类型: ``bool``
|
||||||
|
:说明: 连接是否已经关闭
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def accept(self):
|
||||||
|
"""接受 WebSocket 连接请求"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def close(self, code: int):
|
||||||
|
"""关闭 WebSocket 连接请求"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def receive(self) -> str:
|
||||||
|
"""接收一条 WebSocket text 信息"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def receive_bytes(self) -> bytes:
|
||||||
|
"""接收一条 WebSocket binary 信息"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def send(self, data: str):
|
||||||
|
"""发送一条 WebSocket text 信息"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def send_bytes(self, data: bytes):
|
||||||
|
"""发送一条 WebSocket binary 信息"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class Cookies(MutableMapping):
|
||||||
|
def __init__(self, cookies: CookieTypes = None) -> None:
|
||||||
|
self.jar = cookies if isinstance(cookies, CookieJar) else CookieJar()
|
||||||
|
if cookies is not None and not isinstance(cookies, CookieJar):
|
||||||
|
if isinstance(cookies, dict):
|
||||||
|
for key, value in cookies.items():
|
||||||
|
self.set(key, value)
|
||||||
|
elif isinstance(cookies, list):
|
||||||
|
for key, value in cookies:
|
||||||
|
self.set(key, value)
|
||||||
|
elif isinstance(cookies, Cookies):
|
||||||
|
for cookie in cookies.jar:
|
||||||
|
self.jar.set_cookie(cookie)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")
|
||||||
|
|
||||||
|
def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
|
||||||
|
cookie = Cookie(
|
||||||
|
version=0,
|
||||||
|
name=name,
|
||||||
|
value=value,
|
||||||
|
port=None,
|
||||||
|
port_specified=False,
|
||||||
|
domain=domain,
|
||||||
|
domain_specified=bool(domain),
|
||||||
|
domain_initial_dot=domain.startswith("."),
|
||||||
|
path=path,
|
||||||
|
path_specified=bool(path),
|
||||||
|
secure=False,
|
||||||
|
expires=None,
|
||||||
|
discard=True,
|
||||||
|
comment=None,
|
||||||
|
comment_url=None,
|
||||||
|
rest={},
|
||||||
|
rfc2109=False,
|
||||||
|
)
|
||||||
|
self.jar.set_cookie(cookie)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
default: Optional[str] = None,
|
||||||
|
domain: str = None,
|
||||||
|
path: str = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
value: Optional[str] = None
|
||||||
|
for cookie in self.jar:
|
||||||
|
if (
|
||||||
|
cookie.name == name
|
||||||
|
and (domain is None or cookie.domain == domain)
|
||||||
|
and (path is None or cookie.path == path)
|
||||||
|
):
|
||||||
|
if value is not None:
|
||||||
|
message = f"Multiple cookies exist with name={name}"
|
||||||
|
raise ValueError(message)
|
||||||
|
value = cookie.value
|
||||||
|
|
||||||
|
return default if value is None else value
|
||||||
|
|
||||||
|
def delete(
|
||||||
|
self, name: str, domain: Optional[str] = None, path: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
|
if domain is not None and path is not None:
|
||||||
|
return self.jar.clear(domain, path, name)
|
||||||
|
|
||||||
|
remove = [
|
||||||
|
cookie
|
||||||
|
for cookie in self.jar
|
||||||
|
if cookie.name == name
|
||||||
|
and (domain is None or cookie.domain == domain)
|
||||||
|
and (path is None or cookie.path == path)
|
||||||
|
]
|
||||||
|
|
||||||
|
for cookie in remove:
|
||||||
|
self.jar.clear(cookie.domain, cookie.path, cookie.name)
|
||||||
|
|
||||||
|
def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
|
||||||
|
self.jar.clear(domain, path)
|
||||||
|
|
||||||
|
def update(self, cookies: CookieTypes = None) -> None:
|
||||||
|
cookies = Cookies(cookies)
|
||||||
|
for cookie in cookies.jar:
|
||||||
|
self.jar.set_cookie(cookie)
|
||||||
|
|
||||||
|
def __setitem__(self, name: str, value: str) -> None:
|
||||||
|
return self.set(name, value)
|
||||||
|
|
||||||
|
def __getitem__(self, name: str) -> str:
|
||||||
|
value = self.get(name)
|
||||||
|
if value is None:
|
||||||
|
raise KeyError(name)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def __delitem__(self, name: str) -> None:
|
||||||
|
return self.delete(name)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.jar)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Cookie]:
|
||||||
|
return (cookie for cookie in self.jar)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
cookies_repr = ", ".join(
|
||||||
|
[
|
||||||
|
f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
|
||||||
|
for cookie in self.jar
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"<Cookies [{cookies_repr}]>"
|
6
poetry.lock
generated
6
poetry.lock
generated
@ -513,7 +513,7 @@ name = "multidict"
|
|||||||
version = "5.2.0"
|
version = "5.2.0"
|
||||||
description = "multidict implementation"
|
description = "multidict implementation"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -1152,7 +1152,7 @@ name = "yarl"
|
|||||||
version = "1.7.2"
|
version = "1.7.2"
|
||||||
description = "Yet another URL library"
|
description = "Yet another URL library"
|
||||||
category = "main"
|
category = "main"
|
||||||
optional = true
|
optional = false
|
||||||
python-versions = ">=3.6"
|
python-versions = ">=3.6"
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@ -1180,7 +1180,7 @@ quart = ["Quart"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.7.3"
|
python-versions = "^3.7.3"
|
||||||
content-hash = "32ef230c4c02d1eb68754ad0bb938d1d960d7c6362ea2250cc23bbed910fbe34"
|
content-hash = "3cef667f39a760749d4c1d35883be4cec9c224e8079ee28752f6aa7f58688442"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
aiodns = [
|
aiodns = [
|
||||||
|
@ -34,6 +34,7 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
|
|||||||
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
|
||||||
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
uvicorn = { version = "^0.15.0", extras = ["standard"] }
|
||||||
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
|
||||||
|
yarl = "^1.7.2"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
sphinx = "^4.1.1"
|
sphinx = "^4.1.1"
|
||||||
|
Loading…
Reference in New Issue
Block a user