♻️ rewrite quart driver

This commit is contained in:
yanyongyu 2021-12-20 15:46:23 +08:00
parent c49059f9d3
commit ea8f7717b9
3 changed files with 64 additions and 129 deletions

View File

@ -27,9 +27,9 @@ from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as GenericRequest
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import Response as BaseResponse
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import Response as GenericResponse
from nonebot.drivers import (
HTTPVersion,
ForwardDriver,
@ -247,7 +247,7 @@ class Driver(ReverseDriver):
request: Request,
setup: HTTPServerSetup,
):
http_request = GenericRequest(
http_request = BaseRequest(
request.method,
str(request.url),
headers=request.headers.items(),
@ -260,7 +260,7 @@ class Driver(ReverseDriver):
return Response(response.content, response.status_code, dict(response.headers))
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
request = GenericRequest(
request = BaseRequest(
"GET",
str(websocket.url),
headers=websocket.headers.items(),
@ -293,7 +293,7 @@ class FullDriver(ForwardDriver, Driver):
return "fastapi_full"
@overrides(ForwardDriver)
async def request(self, setup: "GenericRequest") -> Any:
async def request(self, setup: "BaseRequest") -> Any:
async with httpx.AsyncClient(
http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client:
@ -304,7 +304,7 @@ class FullDriver(ForwardDriver, Driver):
headers=tuple(setup.headers.items()),
timeout=30.0,
)
return GenericResponse(
return BaseResponse(
response.status_code,
headers=response.headers,
content=response.content,
@ -312,31 +312,24 @@ class FullDriver(ForwardDriver, Driver):
)
@overrides(ForwardDriver)
async def websocket(self, setup: "GenericRequest") -> Any:
async def websocket(self, setup: "BaseRequest") -> Any:
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
return WebSocketsWS(request=setup, websocket=ws)
class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol):
def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
# if isinstance(self.websocket, WebSocket):
# return (
# self.websocket.client_state == WebSocketState.DISCONNECTED
# or self.websocket.application_state == WebSocketState.DISCONNECTED
# )
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
# if isinstance(self.websocket, WebSocket):
# await self.websocket.accept()
raise NotImplementedError
@overrides(BaseWebSocket)
@ -345,8 +338,6 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
async def receive(self) -> str:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_text()
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
@ -354,8 +345,6 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_bytes()
msg = await self.websocket.recv()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
@ -363,20 +352,16 @@ class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "text": data})
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "bytes": data})
await self.websocket.send(data)
class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocket):
def __init__(self, *, request: BaseRequest, websocket: WebSocket):
super().__init__(request=request)
self.websocket = websocket

View File

@ -8,8 +8,7 @@ Quart 驱动适配
https://pgjones.gitlab.io/quart/index.html
"""
import asyncio
from dataclasses import dataclass
from functools import partial
from typing import List, TypeVar, Callable, Optional, Coroutine
import uvicorn
@ -20,8 +19,9 @@ from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import HTTPRequest, ReverseDriver
from nonebot.drivers import Request as BaseRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup
try:
from quart import request as _request
@ -98,11 +98,6 @@ class Config(BaseSettings):
class Driver(ReverseDriver):
"""
Quart 驱动框架
:上报地址:
* ``/{adapter name}/http``: HTTP POST 上报
* ``/{adapter name}/ws``: WebSocket 上报
"""
def __init__(self, env: Env, config: NoneBotConfig):
@ -111,12 +106,6 @@ class Driver(ReverseDriver):
self.quart_config = Config(**config.dict())
self._server_app = Quart(self.__class__.__qualname__)
self._server_app.add_url_rule(
"/<adapter>/http", methods=["POST"], view_func=self._handle_http
)
self._server_app.add_websocket(
"/<adapter>/ws", view_func=self._handle_ws_reverse
)
@property
@overrides(ReverseDriver)
@ -142,6 +131,21 @@ class Driver(ReverseDriver):
"""Quart 使用的 logger"""
return self._server_app.logger
@overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup):
self._server_app.add_url_rule(
setup.path.path,
methods=[setup.method],
view_func=partial(self._handle_http, setup=setup),
)
@overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
self._server_app.add_websocket(
setup.path.path,
view_func=partial(self._handle_ws, setup=setup),
)
@overrides(ReverseDriver)
def on_startup(self, func: _AsyncCallable) -> _AsyncCallable:
"""参考文档: `Startup and Shutdown`_
@ -199,128 +203,75 @@ class Driver(ReverseDriver):
**kwargs,
)
async def _handle_http(self, adapter: str):
async def _handle_http(self, setup: HTTPServerSetup) -> Response:
request: Request = _request
data: bytes = await request.get_data() # type: ignore
if adapter not in self._adapters:
logger.warning(
f"Unknown adapter {adapter}. " "Please register the adapter before use."
)
raise exceptions.NotFound()
BotClass = self._adapters[adapter]
http_request = HTTPRequest(
request.http_version,
request.scheme,
request.path,
request.query_string,
dict(request.headers),
http_request = BaseRequest(
request.method,
data,
request.url,
headers=request.headers.items(),
cookies=list(request.cookies.items()),
content=await request.get_data(
cache=False, as_text=False, parse_form_data=False
),
version=request.http_version,
)
self_id, response = await BotClass.check_permission(self, http_request)
response = await setup.handle_func(http_request)
if not self_id:
raise exceptions.Unauthorized(
description=(response and response.body or b"").decode()
)
if self_id in self._clients:
logger.warning(
"There's already a reverse websocket connection,"
"so the event may be handled twice."
)
bot = BotClass(self_id, http_request)
asyncio.create_task(bot.handle_message(data))
return Response(
response and response.body or "", response and response.status or 200
response.content or "",
response.status_code or 200,
headers=dict(response.headers),
)
async def _handle_ws_reverse(self, adapter: str):
async def _handle_ws(self, setup: WebSocketServerSetup) -> None:
websocket: QuartWebSocket = _websocket
ws = WebSocket(
websocket.http_version,
websocket.scheme,
websocket.path,
websocket.query_string,
dict(websocket.headers),
websocket,
http_request = BaseRequest(
websocket.method,
websocket.url,
headers=websocket.headers.items(),
cookies=list(websocket.cookies.items()),
version=websocket.http_version,
)
if adapter not in self._adapters:
logger.warning(
f"Unknown adapter {adapter}. Please register the adapter before use."
)
raise exceptions.NotFound()
ws = WebSocket(request=http_request, websocket=websocket)
BotClass = self._adapters[adapter]
self_id, response = await BotClass.check_permission(self, ws)
if not self_id:
raise exceptions.Unauthorized(
description=(response and response.body or b"").decode()
)
if self_id in self._clients:
logger.opt(colors=True).warning(
"There's already a websocket connection, "
f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
)
raise exceptions.Forbidden(description="Client already exists.")
bot = BotClass(self_id, ws)
await ws.accept()
logger.opt(colors=True).info(
f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
f"Bot {escape_tag(self_id)}</y> Accepted!"
)
self._bot_connect(bot)
try:
while not ws.closed:
try:
data = await ws.receive()
except asyncio.CancelledError:
logger.warning("WebSocket disconnected by peer.")
break
except Exception as e:
logger.opt(exception=e).error(
"Error when receiving data from websocket."
)
break
asyncio.create_task(bot.handle_message(data.encode()))
finally:
self._bot_disconnect(bot)
await setup.handle_func(ws)
@dataclass
class WebSocket(BaseWebSocket):
websocket: QuartWebSocket = None # type: ignore
def __init__(self, *, request: BaseRequest, websocket: QuartWebSocket):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self):
# FIXME
return False
raise NotImplementedError
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
@overrides(BaseWebSocket)
async def close(self):
# FIXME
pass
async def close(self, code: int = 1000):
await self.websocket.close(code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.receive() # type: ignore
msg = await self.websocket.receive()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive() # type: ignore
msg = await self.websocket.receive()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
async def send(self, data: str):

View File

@ -61,7 +61,6 @@ def Depends(
* ``dependency: Optional[Callable[..., Any]] = None``: 依赖函数默认为参数的类型注释
* ``use_cache: bool = True``: 是否使用缓存默认为 ``True``
* ``allow_types: Optional[List[Type[Param]]] = None``: 允许的参数类型默认为 ``None``
.. code-block:: python