diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 75398eb9..7511c745 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable from nonebot.log import logger from nonebot.utils import escape_tag from nonebot.config import Env, Config -from ._model import URL, Request, Response, WebSocket +from ._model import URL, Request, Response, WebSocket, HTTPVersion from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook if TYPE_CHECKING: diff --git a/nonebot/drivers/_model.py b/nonebot/drivers/_model.py index ede3276a..05d59c77 100644 --- a/nonebot/drivers/_model.py +++ b/nonebot/drivers/_model.py @@ -82,6 +82,7 @@ class Request: self.url = url # headers + self.headers: CIMultiDict[str] if headers is not None: self.headers = CIMultiDict(headers) else: @@ -112,6 +113,7 @@ class Response: self.status_code = status_code # headers + self.headers: CIMultiDict[str] if headers is not None: self.headers = CIMultiDict(headers) else: @@ -144,7 +146,7 @@ class WebSocket(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def close(self, code: int): + async def close(self, code: int = 1000): """关闭 WebSocket 连接请求""" raise NotImplementedError diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 9fea8d95..8eea6ea2 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -10,7 +10,6 @@ FastAPI 驱动适配 https://fastapi.tiangolo.com/ """ -import asyncio import logging from functools import partial from dataclasses import dataclass @@ -20,23 +19,21 @@ import httpx import uvicorn from pydantic import BaseSettings from fastapi.responses import Response -from starlette.websockets import WebSocketState -from fastapi import Depends, FastAPI, Request, status -from starlette.websockets import WebSocket as FastAPIWebSocket +from fastapi import FastAPI, Request, status +from starlette.websockets import WebSocket, WebSocketState from websockets.legacy.client import Connect, WebSocketClientProtocol from nonebot.config import Env from nonebot.typing import overrides from nonebot.utils import escape_tag -from nonebot.drivers import WebSocket from nonebot.config import Config as NoneBotConfig +from nonebot.drivers import Request as GenericRequest from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.drivers import Response as GenericResponse from nonebot.drivers import ( - HTTPRequest, - HTTPResponse, + HTTPVersion, ForwardDriver, ReverseDriver, - HTTPConnection, HTTPServerSetup, WebSocketServerSetup, ) @@ -173,26 +170,23 @@ class Driver(ReverseDriver): @overrides(ReverseDriver) def setup_http_server(self, setup: HTTPServerSetup): - def _get_handle_func(): - return setup.handle_func + async def _handle(request: Request) -> Response: + return await self._handle_http(request, setup) self._server_app.add_api_route( - setup.path, - partial(self._handle_http, handle_func=Depends(_get_handle_func)), + setup.path.path, + _handle, methods=[setup.method], ) @overrides(ReverseDriver) def setup_websocket_server(self, setup: WebSocketServerSetup) -> None: - def _get_handle_func(): - return setup.handle_func + async def _handle(websocket: WebSocket): + await self._handle_ws(websocket, setup) self._server_app.add_api_websocket_route( - setup.path, - partial( - self._handle_ws, - handle_func=Depends(_get_handle_func), - ), + setup.path.path, + _handle, ) @overrides(ReverseDriver) @@ -251,36 +245,34 @@ class Driver(ReverseDriver): async def _handle_http( self, request: Request, - handle_func: Callable[[HTTPRequest], Awaitable[HTTPResponse]], + setup: HTTPServerSetup, ): - http_request = HTTPRequest( - request.scope["http_version"], - request.url.scheme, - request.url.path, - request.scope["query_string"], - dict(request.headers), + http_request = GenericRequest( request.method, - await request.body(), + str(request.url), + headers=request.headers.items(), + cookies=request.cookies, + content=await request.body(), + version=request.scope["http_version"], ) - response = await handle_func(http_request) - return Response(response.body, response.status, response.headers) + response = await setup.handle_func(http_request) + return Response(response.content, response.status_code, dict(response.headers)) - async def _handle_ws( - self, - websocket: FastAPIWebSocket, - handle_func: Callable[[WebSocket], Awaitable[Any]], - ): - ws = WebSocket( - websocket.scope.get("http_version", "1.1"), - websocket.url.scheme, - websocket.url.path, - websocket.scope["query_string"], - dict(websocket.headers), - websocket, + async def _handle_ws(self, websocket: WebSocket, setup: WebSocketServerSetup): + request = GenericRequest( + "GET", + str(websocket.url), + headers=websocket.headers.items(), + cookies=websocket.cookies, + version=websocket.scope["http_version"], + ) + ws = FastAPIWebSocket( + request=request, + websocket=websocket, ) - await handle_func(ws) + await setup.handle_func(ws) class FullDriver(ForwardDriver, Driver): @@ -295,54 +287,110 @@ class FullDriver(ForwardDriver, Driver): """ @property - @overrides(ForwardDriver) + @overrides(Driver) def type(self) -> str: """驱动名称: ``fastapi_full``""" return "fastapi_full" @overrides(ForwardDriver) - async def request(self, setup: "HTTPRequest") -> Any: + async def request(self, setup: "GenericRequest") -> Any: async with httpx.AsyncClient( - http2=setup.http_version == "2", follow_redirects=True + http2=setup.version == HTTPVersion.H2, follow_redirects=True ) as client: response = await client.request( setup.method, - setup.url, - content=setup.body, - headers=setup.headers, + str(setup.url), + content=setup.content, + headers=tuple(setup.headers.items()), timeout=30.0, ) - return HTTPResponse( - response.status_code, response.content, response.headers + return GenericResponse( + response.status_code, + headers=response.headers, + content=response.content, + request=setup, ) @overrides(ForwardDriver) - async def websocket(self, setup: "HTTPConnection") -> Any: - ws = await Connect(setup.url, extra_headers=setup.headers) - return WebSocket("1.1", url.scheme, url.path, url.query, setup.headers, ws) + async def websocket(self, setup: "GenericRequest") -> Any: + ws = await Connect(str(setup.url), extra_headers=setup.headers.items()) + return WebSocketsWS(request=setup, websocket=ws) -@dataclass -class WebSocket(BaseWebSocket): - websocket: Union[FastAPIWebSocket, WebSocketClientProtocol] = None # type: ignore +class WebSocketsWS(BaseWebSocket): + @overrides(BaseWebSocket) + def __init__(self, *, request: GenericRequest, websocket: WebSocketClientProtocol): + super().__init__(request=request) + self.websocket = websocket @property @overrides(BaseWebSocket) def closed(self) -> bool: - if isinstance(self.websocket, FastAPIWebSocket): - return ( - self.websocket.client_state == WebSocketState.DISCONNECTED - or self.websocket.application_state == WebSocketState.DISCONNECTED - ) - else: - return self.websocket.closed + # if isinstance(self.websocket, WebSocket): + # return ( + # self.websocket.client_state == WebSocketState.DISCONNECTED + # or self.websocket.application_state == WebSocketState.DISCONNECTED + # ) + return self.websocket.closed @overrides(BaseWebSocket) async def accept(self): - if isinstance(self.websocket, FastAPIWebSocket): - await self.websocket.accept() - else: - raise NotImplementedError + # if isinstance(self.websocket, WebSocket): + # await self.websocket.accept() + raise NotImplementedError + + @overrides(BaseWebSocket) + async def close(self, code: int = 1000): + await self.websocket.close(code) + + @overrides(BaseWebSocket) + async def receive(self) -> str: + # if isinstance(self.websocket, WebSocket): + # return await self.websocket.receive_text() + msg = await self.websocket.recv() + if isinstance(msg, bytes): + raise TypeError("WebSocket received unexpected frame type: bytes") + return msg + + @overrides(BaseWebSocket) + async def receive_bytes(self) -> bytes: + # if isinstance(self.websocket, WebSocket): + # return await self.websocket.receive_bytes() + msg = await self.websocket.recv() + if isinstance(msg, str): + raise TypeError("WebSocket received unexpected frame type: str") + return msg + + @overrides(BaseWebSocket) + async def send(self, data: str) -> None: + # if isinstance(self.websocket, WebSocket): + # await self.websocket.send({"type": "websocket.send", "text": data}) + await self.websocket.send(data) + + @overrides(BaseWebSocket) + async def send_bytes(self, data: bytes) -> None: + # if isinstance(self.websocket, WebSocket): + # await self.websocket.send({"type": "websocket.send", "bytes": data}) + await self.websocket.send(data) + + +class FastAPIWebSocket(BaseWebSocket): + @overrides(BaseWebSocket) + def __init__(self, *, request: GenericRequest, websocket: WebSocket): + super().__init__(request=request) + self.websocket = websocket + + @property + @overrides(BaseWebSocket) + def closed(self) -> bool: + return ( + self.websocket.client_state == WebSocketState.DISCONNECTED + or self.websocket.application_state == WebSocketState.DISCONNECTED + ) + + @overrides(BaseWebSocket) + async def accept(self): + await self.websocket.accept() @overrides(BaseWebSocket) async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE): @@ -350,30 +398,16 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def receive(self) -> str: - if isinstance(self.websocket, FastAPIWebSocket): - return await self.websocket.receive_text() - else: - msg = await self.websocket.recv() - return msg.decode("utf-8") if isinstance(msg, bytes) else msg + return await self.websocket.receive_text() @overrides(BaseWebSocket) async def receive_bytes(self) -> bytes: - if isinstance(self.websocket, FastAPIWebSocket): - return await self.websocket.receive_bytes() - else: - msg = await self.websocket.recv() - return msg.encode("utf-8") if isinstance(msg, str) else msg + return await self.websocket.receive_bytes() @overrides(BaseWebSocket) async def send(self, data: str) -> None: - if isinstance(self.websocket, FastAPIWebSocket): - await self.websocket.send({"type": "websocket.send", "text": data}) - else: - await self.websocket.send(data) + await self.websocket.send({"type": "websocket.send", "text": data}) @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: - if isinstance(self.websocket, FastAPIWebSocket): - await self.websocket.send({"type": "websocket.send", "bytes": data}) - else: - await self.websocket.send(data) + await self.websocket.send({"type": "websocket.send", "bytes": data})