🚧 rewrite fastapi driver implementation

This commit is contained in:
yanyongyu 2021-12-18 23:19:37 +08:00
parent ec9e159ef6
commit ca045b2f73
3 changed files with 123 additions and 87 deletions

View File

@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.config import Env, Config from nonebot.config import Env, Config
from ._model import URL, Request, Response, WebSocket from ._model import URL, Request, Response, WebSocket, HTTPVersion
from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -82,6 +82,7 @@ class Request:
self.url = url self.url = url
# headers # headers
self.headers: CIMultiDict[str]
if headers is not None: if headers is not None:
self.headers = CIMultiDict(headers) self.headers = CIMultiDict(headers)
else: else:
@ -112,6 +113,7 @@ class Response:
self.status_code = status_code self.status_code = status_code
# headers # headers
self.headers: CIMultiDict[str]
if headers is not None: if headers is not None:
self.headers = CIMultiDict(headers) self.headers = CIMultiDict(headers)
else: else:
@ -144,7 +146,7 @@ class WebSocket(abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def close(self, code: int): async def close(self, code: int = 1000):
"""关闭 WebSocket 连接请求""" """关闭 WebSocket 连接请求"""
raise NotImplementedError raise NotImplementedError

View File

@ -10,7 +10,6 @@ FastAPI 驱动适配
https://fastapi.tiangolo.com/ https://fastapi.tiangolo.com/
""" """
import asyncio
import logging import logging
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
@ -20,23 +19,21 @@ import httpx
import uvicorn import uvicorn
from pydantic import BaseSettings from pydantic import BaseSettings
from fastapi.responses import Response from fastapi.responses import Response
from starlette.websockets import WebSocketState from fastapi import FastAPI, Request, status
from fastapi import Depends, FastAPI, Request, status from starlette.websockets import WebSocket, WebSocketState
from starlette.websockets import WebSocket as FastAPIWebSocket
from websockets.legacy.client import Connect, WebSocketClientProtocol from websockets.legacy.client import Connect, WebSocketClientProtocol
from nonebot.config import Env from nonebot.config import Env
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.utils import escape_tag from nonebot.utils import escape_tag
from nonebot.drivers import WebSocket
from nonebot.config import Config as NoneBotConfig from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import Request as GenericRequest
from nonebot.drivers import WebSocket as BaseWebSocket from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import Response as GenericResponse
from nonebot.drivers import ( from nonebot.drivers import (
HTTPRequest, HTTPVersion,
HTTPResponse,
ForwardDriver, ForwardDriver,
ReverseDriver, ReverseDriver,
HTTPConnection,
HTTPServerSetup, HTTPServerSetup,
WebSocketServerSetup, WebSocketServerSetup,
) )
@ -173,26 +170,23 @@ class Driver(ReverseDriver):
@overrides(ReverseDriver) @overrides(ReverseDriver)
def setup_http_server(self, setup: HTTPServerSetup): def setup_http_server(self, setup: HTTPServerSetup):
def _get_handle_func(): async def _handle(request: Request) -> Response:
return setup.handle_func return await self._handle_http(request, setup)
self._server_app.add_api_route( self._server_app.add_api_route(
setup.path, setup.path.path,
partial(self._handle_http, handle_func=Depends(_get_handle_func)), _handle,
methods=[setup.method], methods=[setup.method],
) )
@overrides(ReverseDriver) @overrides(ReverseDriver)
def setup_websocket_server(self, setup: WebSocketServerSetup) -> None: def setup_websocket_server(self, setup: WebSocketServerSetup) -> None:
def _get_handle_func(): async def _handle(websocket: WebSocket):
return setup.handle_func await self._handle_ws(websocket, setup)
self._server_app.add_api_websocket_route( self._server_app.add_api_websocket_route(
setup.path, setup.path.path,
partial( _handle,
self._handle_ws,
handle_func=Depends(_get_handle_func),
),
) )
@overrides(ReverseDriver) @overrides(ReverseDriver)
@ -251,36 +245,34 @@ class Driver(ReverseDriver):
async def _handle_http( async def _handle_http(
self, self,
request: Request, request: Request,
handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]], setup: HTTPServerSetup,
): ):
http_request = HTTPRequest( http_request = GenericRequest(
request.scope["http_version"],
request.url.scheme,
request.url.path,
request.scope["query_string"],
dict(request.headers),
request.method, request.method,
await request.body(), str(request.url),
headers=request.headers.items(),
cookies=request.cookies,
content=await request.body(),
version=request.scope["http_version"],
) )
response = await handle_func(http_request) response = await setup.handle_func(http_request)
return Response(response.body, response.status, response.headers) return Response(response.content, response.status_code, dict(response.headers))
async def _handle_ws( async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup):
self, request = GenericRequest(
websocket: FastAPIWebSocket, "GET",
handle_func: Callable[[WebSocket], Awaitable[Any]], str(websocket.url),
): headers=websocket.headers.items(),
ws = WebSocket( cookies=websocket.cookies,
websocket.scope.get("http_version", "1.1"), version=websocket.scope["http_version"],
websocket.url.scheme, )
websocket.url.path, ws = FastAPIWebSocket(
websocket.scope["query_string"], request=request,
dict(websocket.headers), websocket=websocket,
websocket,
) )
await handle_func(ws) await setup.handle_func(ws)
class FullDriver(ForwardDriver, Driver): class FullDriver(ForwardDriver, Driver):
@ -295,54 +287,110 @@ class FullDriver(ForwardDriver, Driver):
""" """
@property @property
@overrides(ForwardDriver) @overrides(Driver)
def type(self) -> str: def type(self) -> str:
"""驱动名称: ``fastapi_full``""" """驱动名称: ``fastapi_full``"""
return "fastapi_full" return "fastapi_full"
@overrides(ForwardDriver) @overrides(ForwardDriver)
async def request(self, setup: "HTTPRequest") -> Any: async def request(self, setup: "GenericRequest") -> Any:
async with httpx.AsyncClient( async with httpx.AsyncClient(
http2=setup.http_version == "2", follow_redirects=True http2=setup.version == HTTPVersion.H2, follow_redirects=True
) as client: ) as client:
response = await client.request( response = await client.request(
setup.method, setup.method,
setup.url, str(setup.url),
content=setup.body, content=setup.content,
headers=setup.headers, headers=tuple(setup.headers.items()),
timeout=30.0, timeout=30.0,
) )
return HTTPResponse( return GenericResponse(
response.status_code, response.content, response.headers response.status_code,
headers=response.headers,
content=response.content,
request=setup,
) )
@overrides(ForwardDriver) @overrides(ForwardDriver)
async def websocket(self, setup: "HTTPConnection") -> Any: async def websocket(self, setup: "GenericRequest") -> Any:
ws = await Connect(setup.url, extra_headers=setup.headers) ws = await Connect(str(setup.url), extra_headers=setup.headers.items())
return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws) return WebSocketsWS(request=setup, websocket=ws)
@dataclass class WebSocketsWS(BaseWebSocket):
class WebSocket(BaseWebSocket): @overrides(BaseWebSocket)
websocket: Union[FastAPIWebSocket, WebSocketClientProtocol] = None # type: ignore def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol):
super().__init__(request=request)
self.websocket = websocket
@property @property
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
def closed(self) -> bool: def closed(self) -> bool:
if isinstance(self.websocket, FastAPIWebSocket): # if isinstance(self.websocket, WebSocket):
return ( # return (
self.websocket.client_state == WebSocketState.DISCONNECTED # self.websocket.client_state == WebSocketState.DISCONNECTED
or self.websocket.application_state == WebSocketState.DISCONNECTED # or self.websocket.application_state == WebSocketState.DISCONNECTED
) # )
else: return self.websocket.closed
return self.websocket.closed
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def accept(self): async def accept(self):
if isinstance(self.websocket, FastAPIWebSocket): # if isinstance(self.websocket, WebSocket):
await self.websocket.accept() # await self.websocket.accept()
else: raise NotImplementedError
raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = 1000):
await self.websocket.close(code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_text()
msg = await self.websocket.recv()
if isinstance(msg, bytes):
raise TypeError("WebSocket received unexpected frame type: bytes")
return msg
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
# if isinstance(self.websocket, WebSocket):
# return await self.websocket.receive_bytes()
msg = await self.websocket.recv()
if isinstance(msg, str):
raise TypeError("WebSocket received unexpected frame type: str")
return msg
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "text": data})
await self.websocket.send(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
# if isinstance(self.websocket, WebSocket):
# await self.websocket.send({"type": "websocket.send", "bytes": data})
await self.websocket.send(data)
class FastAPIWebSocket(BaseWebSocket):
@overrides(BaseWebSocket)
def __init__(self, *, request: GenericRequest, websocket: WebSocket):
super().__init__(request=request)
self.websocket = websocket
@property
@overrides(BaseWebSocket)
def closed(self) -> bool:
return (
self.websocket.client_state == WebSocketState.DISCONNECTED
or self.websocket.application_state == WebSocketState.DISCONNECTED
)
@overrides(BaseWebSocket)
async def accept(self):
await self.websocket.accept()
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
@ -350,30 +398,16 @@ class WebSocket(BaseWebSocket):
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive(self) -> str: async def receive(self) -> str:
if isinstance(self.websocket, FastAPIWebSocket): return await self.websocket.receive_text()
return await self.websocket.receive_text()
else:
msg = await self.websocket.recv()
return msg.decode("utf-8") if isinstance(msg, bytes) else msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes: async def receive_bytes(self) -> bytes:
if isinstance(self.websocket, FastAPIWebSocket): return await self.websocket.receive_bytes()
return await self.websocket.receive_bytes()
else:
msg = await self.websocket.recv()
return msg.encode("utf-8") if isinstance(msg, str) else msg
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send(self, data: str) -> None: async def send(self, data: str) -> None:
if isinstance(self.websocket, FastAPIWebSocket): await self.websocket.send({"type": "websocket.send", "text": data})
await self.websocket.send({"type": "websocket.send", "text": data})
else:
await self.websocket.send(data)
@overrides(BaseWebSocket) @overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None: async def send_bytes(self, data: bytes) -> None:
if isinstance(self.websocket, FastAPIWebSocket): await self.websocket.send({"type": "websocket.send", "bytes": data})
await self.websocket.send({"type": "websocket.send", "bytes": data})
else:
await self.websocket.send(data)