🚧 rewrite fastapi driver implementation

This commit is contained in:
yanyongyu 2021-12-18 23:19:37 +08:00
parent ec9e159ef6
commit ca045b2f73
3 changed files with 123 additions and 87 deletions

View File

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from nonebot.log import logger
from nonebot.utils import escape_tag
from nonebot.config import Env, Config
from ._model import URL, Request, Response, WebSocket
from ._model import URL, Request, Response, WebSocket, HTTPVersion
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING:

View File

@ -82,6 +82,7 @@ class Request:
self.url = url
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
@ -112,6 +113,7 @@ class Response:
self.status_code = status_code
# headers
self.headers: CIMultiDict[str]
if headers is not None:
self.headers = CIMultiDict(headers)
else:
@ -144,7 +146,7 @@ class WebSocket(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def close(self, code: int):
async def close(self, code: int = 1000):
"""关闭 WebSocket 连接请求"""
raise NotImplementedError

View File

@ -10,7 +10,6 @@ FastAPI 驱动适配
https://fastapi.tiangolo.com/
"""
import asyncio
import logging
from functools import partial
from dataclasses import dataclass
@ -20,23 +19,21 @@ import httpx
import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from starlette.websockets import WebSocketState
from fastapi import Depends, FastAPI, Request, status
from starlette.websockets import WebSocket as FastAPIWebSocket
from fastapi import FastAPI, Request, status
from starlette.websockets import WebSocket, WebSocketState
from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.drivers import WebSocket
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as GenericRequest
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import Response as GenericResponse
from nonebot.drivers import (
HTTPRequest,
HTTPResponse,
HTTPVersion,
ForwardDriver,
ReverseDriver,
HTTPConnection,
HTTPServerSetup,
WebSocketServerSetup,
)
@ -173,26 +170,23 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup):
def _get_handle_func():
return setup.handle_func
async def _handle(request: Request) -> Response:
return await self._handle_http(request, setup)
self._server_app.add_api_route(
setup.path,
partial(self._handle_http, handle_func=Depends(_get_handle_func)),
setup.path.path,
_handle,
methods=[setup.method],
)
@overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
def _get_handle_func():
return setup.handle_func
async def _handle(websocket: WebSocket):
await self._handle_ws(websocket, setup)
self._server_app.add_api_websocket_route(
setup.path,
partial(
self._handle_ws,
handle_func=Depends(_get_handle_func),
),
setup.path.path,
_handle,
)
@overrides(ReverseDriver)
@ -251,36 +245,34 @@ class Driver(ReverseDriver):
async def _handle_http(
self,
request: Request,
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]],
setup: HTTPServerSetup,
):
http_request = HTTPRequest(
request.scope["http_version"],
request.url.scheme,
request.url.path,
request.scope["query_string"],
dict(request.headers),
http_request = GenericRequest(
request.method,
await request.body(),
str(request.url),
headers=request.headers.items(),
cookies=request.cookies,
content=await request.body(),
version=request.scope["http_version"],
)
response = await handle_func(http_request)
return Response(response.body, response.status, response.headers)
response = await setup.handle_func(http_request)
return Response(response.content, response.status_code, dict(response.headers))
async def _handle_ws(
self,
websocket: FastAPIWebSocket,
handle_func: Callable[[WebSocket], Awaitable[Any]],
):
ws = WebSocket(
websocket.scope.get("http_version", "1.1"),
websocket.url.scheme,
websocket.url.path,
websocket.scope["query_string"],
dict(websocket.headers),
websocket,
async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
request = GenericRequest(
"GET",
str(websocket.url),
headers=websocket.headers.items(),
cookies=websocket.cookies,
version=websocket.scope["http_version"],
)
ws = FastAPIWebSocket(
request=request,
websocket=websocket,
)
await handle_func(ws)
await setup.handle_func(ws)
class FullDriver(ForwardDriver, Driver):
@ -295,85 +287,127 @@ class FullDriver(ForwardDriver, Driver):
"""
@property
@overrides(ForwardDriver)
@overrides(Driver)
def type(self) -> str:
"""驱动名称: ``fastapi_full``"""
return "fastapi_full"
@overrides(ForwardDriver)
async def request(self, setup: "HTTPRequest") -> Any:
async def request(self, setup: "GenericRequest") -> Any:
async with httpx.AsyncClient(
http2=setup.http_version == "2", follow_redirects=True
http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client:
response = await client.request(
setup.method,
setup.url,
content=setup.body,
headers=setup.headers,
str(setup.url),
content=setup.content,
headers=tuple(setup.headers.items()),
timeout=30.0,
)
return HTTPResponse(
response.status_code, response.content, response.headers
return GenericResponse(
response.status_code,
headers=response.headers,
content=response.content,
request=setup,
)
@overrides(ForwardDriver)
async def websocket(self, setup: "HTTPConnection") -> Any:
ws = await Connect(setup.url, extra_headers=setup.headers)
return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws)
async def websocket(self, setup: "GenericRequest") -> Any:
ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
return WebSocketsWS(request=setup, websocket=ws)
@dataclass
class WebSocket(BaseWebSocket):
websocket: Union[FastAPIWebSocket, WebSocketClientProtocol] = None # type: ignore
class WebSocketsWS(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
if isinstance(self.websocket, FastAPIWebSocket):
return (
self.websocket.client_state == WebSocketState.DISCONNECTED
or self.websocket.application_state == WebSocketState.DISCONNECTED
)
else:
# if isinstance(self.websocket, WebSocket):
# return (
# self.websocket.client_state == WebSocketState.DISCONNECTED
# or self.websocket.application_state == WebSocketState.DISCONNECTED
# )
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
if isinstance(self.websocket, FastAPIWebSocket):
await self.websocket.accept()
else:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.accept()
raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = 1000):
await self.websocket.close(code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_text()
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_bytes()
msg = await self.websocket.recv()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "text": data})
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "bytes": data})
await self.websocket.send(data)
class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocket):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return (
self.websocket.client_state == WebSocketState.DISCONNECTED
or self.websocket.application_state == WebSocketState.DISCONNECTED
)
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
@overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
await self.websocket.close(code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
if isinstance(self.websocket, FastAPIWebSocket):
return await self.websocket.receive_text()
else:
msg = await self.websocket.recv()
return msg.decode("utf-8") if isinstance(msg, bytes) else msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
if isinstance(self.websocket, FastAPIWebSocket):
return await self.websocket.receive_bytes()
else:
msg = await self.websocket.recv()
return msg.encode("utf-8") if isinstance(msg, str) else msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
if isinstance(self.websocket, FastAPIWebSocket):
await self.websocket.send({"type": "websocket.send", "text": data})
else:
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
if isinstance(self.websocket, FastAPIWebSocket):
await self.websocket.send({"type": "websocket.send", "bytes": data})
else:
await self.websocket.send(data)