diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py
index 1785e1b6..75398eb9 100644
--- a/nonebot/drivers/__init__.py
+++ b/nonebot/drivers/__init__.py
@@ -7,22 +7,13 @@
import abc
import asyncio
-from dataclasses import field, dataclass
-from typing import (
- TYPE_CHECKING,
- Any,
- Set,
- Dict,
- Type,
- Union,
- Callable,
- Optional,
- Awaitable,
-)
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
+from ._model import URL, Request, Response, WebSocket
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:
@@ -213,11 +204,11 @@ class ForwardDriver(Driver):
"""
@abc.abstractmethod
- async def request(self, setup: "HTTPRequest") -> Any:
+ async def request(self, setup: "Request") -> Any:
raise NotImplementedError
@abc.abstractmethod
- async def websocket(self, setup: "HTTPConnection") -> Any:
+ async def websocket(self, setup: "Request") -> Any:
raise NotImplementedError
@@ -247,153 +238,16 @@ class ReverseDriver(Driver):
raise NotImplementedError
-# TODO: repack dataclass
-@dataclass
-class HTTPConnection(abc.ABC):
- http_version: str
- """One of ``"1.0"``, ``"1.1"`` or ``"2"``."""
- scheme: str
- """URL scheme portion (likely ``"http"`` or ``"https"``)."""
- path: str
- """
- HTTP request target excluding any query string,
- with percent-encoded sequences and UTF-8 byte sequences
- decoded into characters.
- """
- query_string: bytes = b""
- """ URL portion after the ``?``, percent-encoded."""
- headers: Dict[str, str] = field(default_factory=dict)
- """A dict of name-value pairs,
- where name is the header name, and value is the header value.
-
- Order of header values must be preserved from the original HTTP request;
- order of header names is not important.
-
- Header names must be lowercased.
- """
-
- @property
- @abc.abstractmethod
- def type(self) -> str:
- """Connection type."""
- raise NotImplementedError
-
-
-@dataclass
-class HTTPRequest(HTTPConnection):
- """HTTP 请求封装。参考 `asgi http scope`_。
-
- .. _asgi http scope:
- https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
- """
-
- method: str = "GET"
- """The HTTP method name, uppercased."""
- body: bytes = b""
- """Body of the request.
-
- Optional; if missing defaults to ``b""``.
- """
-
- @property
- def type(self) -> str:
- """Always ``http``"""
- return "http"
-
-
-@dataclass
-class HTTPResponse:
- """HTTP 响应封装。参考 `asgi http scope`_。
-
- .. _asgi http scope:
- https://asgi.readthedocs.io/en/latest/specs/www.html#http-connection-scope
- """
-
- status: int
- """HTTP status code."""
- body: Optional[bytes] = None
- """HTTP body content.
-
- Optional; if missing defaults to ``None``.
- """
- headers: Dict[str, str] = field(default_factory=dict)
- """A dict of name-value pairs,
- where name is the header name, and value is the header value.
-
- Order must be preserved in the HTTP response.
-
- Header names must be lowercased.
-
- Optional; if missing defaults to an empty dict.
- """
-
- @property
- def type(self) -> str:
- """Always ``http``"""
- return "http"
-
-
-@dataclass
-class WebSocket(HTTPConnection, abc.ABC):
- """WebSocket 连接封装。参考 `asgi websocket scope`_。
-
- .. _asgi websocket scope:
- https://asgi.readthedocs.io/en/latest/specs/www.html#websocket-connection-scope
- """
-
- @property
- def type(self) -> str:
- """Always ``websocket``"""
- return "websocket"
-
- @property
- @abc.abstractmethod
- def closed(self) -> bool:
- """
- :类型: ``bool``
- :说明: 连接是否已经关闭
- """
- raise NotImplementedError
-
- @abc.abstractmethod
- async def accept(self):
- """接受 WebSocket 连接请求"""
- raise NotImplementedError
-
- @abc.abstractmethod
- async def close(self, code: int):
- """关闭 WebSocket 连接请求"""
- raise NotImplementedError
-
- @abc.abstractmethod
- async def receive(self) -> str:
- """接收一条 WebSocket text 信息"""
- raise NotImplementedError
-
- @abc.abstractmethod
- async def receive_bytes(self) -> bytes:
- """接收一条 WebSocket binary 信息"""
- raise NotImplementedError
-
- @abc.abstractmethod
- async def send(self, data: str):
- """发送一条 WebSocket text 信息"""
- raise NotImplementedError
-
- @abc.abstractmethod
- async def send_bytes(self, data: bytes):
- """发送一条 WebSocket binary 信息"""
- raise NotImplementedError
-
-
@dataclass
class HTTPServerSetup:
- path: str
+ path: URL # path should not be absolute, check it by URL.is_absolute() == False
method: str
- handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]]
+ name: str
+ handle_func: Callable[[Request], Awaitable[Response]]
@dataclass
class WebSocketServerSetup:
- path: str
+ path: URL # path should not be absolute, check it by URL.is_absolute() == False
+ name: str
handle_func: Callable[[WebSocket], Awaitable[Any]]
diff --git a/nonebot/drivers/_model.py b/nonebot/drivers/_model.py
new file mode 100644
index 00000000..ede3276a
--- /dev/null
+++ b/nonebot/drivers/_model.py
@@ -0,0 +1,282 @@
+import abc
+from enum import Enum
+from http.cookiejar import Cookie, CookieJar
+from typing import (
+ Dict,
+ List,
+ Tuple,
+ Union,
+ Mapping,
+ Iterator,
+ Optional,
+ Sequence,
+ MutableMapping,
+)
+
+from yarl import URL as URL
+from multidict import CIMultiDict
+
+RawURL = Tuple[bytes, bytes, Optional[int], bytes]
+
+SimpleQuery = Union[str, int, float]
+QueryVariable = Union[SimpleQuery, Sequence[SimpleQuery]]
+QueryTypes = Union[
+ None, str, Mapping[str, QueryVariable], Sequence[Tuple[str, QueryVariable]]
+]
+
+HeaderTypes = Union[
+ None,
+ CIMultiDict[str],
+ Dict[str, str],
+ Sequence[Tuple[str, str]],
+]
+
+ContentTypes = Union[str, bytes]
+CookieTypes = Union[None, "Cookies", CookieJar, Dict[str, str], List[Tuple[str, str]]]
+
+
+class HTTPVersion(Enum):
+ H10 = "1.0"
+ H11 = "1.1"
+ H2 = "2"
+
+
+class Request:
+ def __init__(
+ self,
+ method: Union[str, bytes],
+ url: Union["URL", str, RawURL],
+ *,
+ params: QueryTypes = None,
+ headers: HeaderTypes = None,
+ cookies: CookieTypes = None,
+ content: ContentTypes = None,
+ version: Union[str, HTTPVersion] = HTTPVersion.H11,
+ timeout: Optional[float] = None,
+ ):
+ # method
+ self.method = (
+ method.decode("ascii").upper()
+ if isinstance(method, bytes)
+ else method.upper()
+ )
+ # http version
+ self.version = HTTPVersion(version)
+ # timeout
+ self.timeout = timeout
+
+ # url
+ if isinstance(url, tuple):
+ scheme, host, port, path = url
+ url = URL.build(
+ scheme=scheme.decode("ascii"),
+ host=host.decode("ascii"),
+ port=port,
+ path=path.decode("ascii"),
+ )
+ else:
+ url = URL(url)
+
+ if params is not None:
+ url = url.update_query(params)
+ self.url = url
+
+ # headers
+ if headers is not None:
+ self.headers = CIMultiDict(headers)
+ else:
+ self.headers = CIMultiDict()
+
+ # cookies
+ self.cookies = Cookies(cookies)
+
+ # body
+ self.content = content
+
+ def __repr__(self) -> str:
+ class_name = self.__class__.__name__
+ url = str(self.url)
+ return f"<{class_name}({self.method!r}, {url!r})>"
+
+
+class Response:
+ def __init__(
+ self,
+ status_code: int,
+ *,
+ headers: HeaderTypes = None,
+ content: ContentTypes = None,
+ request: Optional[Request] = None,
+ ):
+ # status code
+ self.status_code = status_code
+
+ # headers
+ if headers is not None:
+ self.headers = CIMultiDict(headers)
+ else:
+ self.headers = CIMultiDict()
+
+ # body
+ self.content = content
+
+ # request
+ self.request = request
+
+
+class WebSocket(abc.ABC):
+ def __init__(self, *, request: Request):
+ # request
+ self.request = request
+
+ @property
+ @abc.abstractmethod
+ def closed(self) -> bool:
+ """
+ :类型: ``bool``
+ :说明: 连接是否已经关闭
+ """
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ async def accept(self):
+ """接受 WebSocket 连接请求"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ async def close(self, code: int):
+ """关闭 WebSocket 连接请求"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ async def receive(self) -> str:
+ """接收一条 WebSocket text 信息"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ async def receive_bytes(self) -> bytes:
+ """接收一条 WebSocket binary 信息"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ async def send(self, data: str):
+ """发送一条 WebSocket text 信息"""
+ raise NotImplementedError
+
+ @abc.abstractmethod
+ async def send_bytes(self, data: bytes):
+ """发送一条 WebSocket binary 信息"""
+ raise NotImplementedError
+
+
+class Cookies(MutableMapping):
+ def __init__(self, cookies: CookieTypes = None) -> None:
+ self.jar = cookies if isinstance(cookies, CookieJar) else CookieJar()
+ if cookies is not None and not isinstance(cookies, CookieJar):
+ if isinstance(cookies, dict):
+ for key, value in cookies.items():
+ self.set(key, value)
+ elif isinstance(cookies, list):
+ for key, value in cookies:
+ self.set(key, value)
+ elif isinstance(cookies, Cookies):
+ for cookie in cookies.jar:
+ self.jar.set_cookie(cookie)
+ else:
+ raise TypeError(f"Cookies must be dict or list, not {type(cookies)}")
+
+ def set(self, name: str, value: str, domain: str = "", path: str = "/") -> None:
+ cookie = Cookie(
+ version=0,
+ name=name,
+ value=value,
+ port=None,
+ port_specified=False,
+ domain=domain,
+ domain_specified=bool(domain),
+ domain_initial_dot=domain.startswith("."),
+ path=path,
+ path_specified=bool(path),
+ secure=False,
+ expires=None,
+ discard=True,
+ comment=None,
+ comment_url=None,
+ rest={},
+ rfc2109=False,
+ )
+ self.jar.set_cookie(cookie)
+
+ def get(
+ self,
+ name: str,
+ default: Optional[str] = None,
+ domain: str = None,
+ path: str = None,
+ ) -> Optional[str]:
+ value: Optional[str] = None
+ for cookie in self.jar:
+ if (
+ cookie.name == name
+ and (domain is None or cookie.domain == domain)
+ and (path is None or cookie.path == path)
+ ):
+ if value is not None:
+ message = f"Multiple cookies exist with name={name}"
+ raise ValueError(message)
+ value = cookie.value
+
+ return default if value is None else value
+
+ def delete(
+ self, name: str, domain: Optional[str] = None, path: Optional[str] = None
+ ) -> None:
+ if domain is not None and path is not None:
+ return self.jar.clear(domain, path, name)
+
+ remove = [
+ cookie
+ for cookie in self.jar
+ if cookie.name == name
+ and (domain is None or cookie.domain == domain)
+ and (path is None or cookie.path == path)
+ ]
+
+ for cookie in remove:
+ self.jar.clear(cookie.domain, cookie.path, cookie.name)
+
+ def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
+ self.jar.clear(domain, path)
+
+ def update(self, cookies: CookieTypes = None) -> None:
+ cookies = Cookies(cookies)
+ for cookie in cookies.jar:
+ self.jar.set_cookie(cookie)
+
+ def __setitem__(self, name: str, value: str) -> None:
+ return self.set(name, value)
+
+ def __getitem__(self, name: str) -> str:
+ value = self.get(name)
+ if value is None:
+ raise KeyError(name)
+ return value
+
+ def __delitem__(self, name: str) -> None:
+ return self.delete(name)
+
+ def __len__(self) -> int:
+ return len(self.jar)
+
+ def __iter__(self) -> Iterator[Cookie]:
+ return (cookie for cookie in self.jar)
+
+ def __repr__(self) -> str:
+ cookies_repr = ", ".join(
+ [
+ f""
+ for cookie in self.jar
+ ]
+ )
+
+ return f""
diff --git a/poetry.lock b/poetry.lock
index 12ca0bc1..fedf1714 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -513,7 +513,7 @@ name = "multidict"
version = "5.2.0"
description = "multidict implementation"
category = "main"
-optional = true
+optional = false
python-versions = ">=3.6"
[[package]]
@@ -1152,7 +1152,7 @@ name = "yarl"
version = "1.7.2"
description = "Yet another URL library"
category = "main"
-optional = true
+optional = false
python-versions = ">=3.6"
[package.dependencies]
@@ -1180,7 +1180,7 @@ quart = ["Quart"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7.3"
-content-hash = "32ef230c4c02d1eb68754ad0bb938d1d960d7c6362ea2250cc23bbed910fbe34"
+content-hash = "3cef667f39a760749d4c1d35883be4cec9c224e8079ee28752f6aa7f58688442"
[metadata.files]
aiodns = [
diff --git a/pyproject.toml b/pyproject.toml
index 4312f2a8..5dfd8cbe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] }
pydantic = { version = "~1.8.0", extras = ["dotenv"] }
uvicorn = { version = "^0.15.0", extras = ["standard"] }
aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true }
+yarl = "^1.7.2"
[tool.poetry.dev-dependencies]
sphinx = "^4.1.1"