"""
FastAPI driver adapter
======================

For backend usage, please refer to the `FastAPI documentation`_.

.. _FastAPI documentation:
    https://fastapi.tiangolo.com/
"""

import asyncio
import logging
from dataclasses import dataclass
from typing import List, Optional, Callable

import uvicorn
from pydantic import BaseSettings
from fastapi.responses import Response
from fastapi import status, Request, FastAPI, HTTPException
from starlette.websockets import (WebSocketState, WebSocketDisconnect,
                                  WebSocket as FastAPIWebSocket)

from nonebot.log import logger
from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, HTTPRequest, WebSocket as BaseWebSocket


class Config(BaseSettings):
    """
    Settings for the FastAPI driver; see the FastAPI documentation for details.
    """
    fastapi_openapi_url: Optional[str] = None
    """
    :Type:

      ``Optional[str]``

    :Description:

      URL of ``openapi.json``; defaults to ``None``, which disables it
    """
    fastapi_docs_url: Optional[str] = None
    """
    :Type:

      ``Optional[str]``

    :Description:

      URL of the ``swagger`` page; defaults to ``None``, which disables it
    """
    fastapi_redoc_url: Optional[str] = None
    """
    :Type:

      ``Optional[str]``

    :Description:

      URL of the ``redoc`` page; defaults to ``None``, which disables it
    """
    fastapi_reload_dirs: List[str] = []
    """
    :Type:

      ``List[str]``

    :Description:

      Directories watched for reloading in ``debug`` mode; when empty,
      ``None`` is passed to uvicorn so that its own default applies
    """

    class Config:
        # Unrelated keys of the NoneBot config are passed in as well; drop them.
        extra = "ignore"


class Driver(ReverseDriver):
    """
    FastAPI driver.

    :Registered routes:

      * ``/{adapter}/``: HTTP POST
      * ``/{adapter}/http``: HTTP POST
      * ``/{adapter}/ws``: WebSocket
      * ``/{adapter}/ws/``: WebSocket
    """

    def __init__(self, env: Env, config: NoneBotConfig):
        """Build the FastAPI application and register the adapter routes."""
        super().__init__(env, config)

        # Extract the ``fastapi_*`` settings out of the global NoneBot config.
        self.fastapi_config = Config(**config.dict())

        self._server_app = FastAPI(
            debug=config.debug,
            openapi_url=self.fastapi_config.fastapi_openapi_url,
            docs_url=self.fastapi_config.fastapi_docs_url,
            redoc_url=self.fastapi_config.fastapi_redoc_url,
        )

        self._server_app.post("/{adapter}/")(self._handle_http)
        self._server_app.post("/{adapter}/http")(self._handle_http)
        self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
        self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)

    @property
    @overrides(ReverseDriver)
    def type(self) -> str:
        """Driver name: ``fastapi``"""
        return "fastapi"

    @property
    @overrides(ReverseDriver)
    def server_app(self) -> FastAPI:
        """The ``FastAPI`` application object"""
        return self._server_app

    @property
    @overrides(ReverseDriver)
    def asgi(self) -> FastAPI:
        """The ``FastAPI`` application object (same object as ``server_app``)"""
        return self._server_app

    @property
    @overrides(ReverseDriver)
    def logger(self) -> logging.Logger:
        """The standard-library logger named ``fastapi``"""
        return logging.getLogger("fastapi")

    @overrides(ReverseDriver)
    def on_startup(self, func: Callable) -> Callable:
        """Reference: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
        return self.server_app.on_event("startup")(func)

    @overrides(ReverseDriver)
    def on_shutdown(self, func: Callable) -> Callable:
        """Reference: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_"""
        return self.server_app.on_event("shutdown")(func)

    @overrides(ReverseDriver)
    def run(self,
            host: Optional[str] = None,
            port: Optional[int] = None,
            *,
            app: Optional[str] = None,
            **kwargs):
        """
        Start FastAPI with ``uvicorn``.

        :param host: bind address; falls back to ``self.config.host``
        :param port: bind port; falls back to ``self.config.port``
        :param app: application passed to uvicorn instead of
            ``self.server_app``; reloading is only enabled when this is
            given and ``self.config.debug`` is set
        :param kwargs: forwarded to the parent ``run`` and to ``uvicorn.run``
        """
        super().run(host, port, app, **kwargs)

        # Route uvicorn's own loggers through the handler at
        # ``nonebot.log.LoguruHandler``.
        LOGGING_CONFIG = {
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "default": {
                    "class": "nonebot.log.LoguruHandler",
                },
            },
            "loggers": {
                "uvicorn.error": {
                    "handlers": ["default"],
                    "level": "INFO"
                },
                "uvicorn.access": {
                    "handlers": ["default"],
                    "level": "INFO",
                },
            },
        }
        uvicorn.run(app or self.server_app,
                    host=host or str(self.config.host),
                    port=port or self.config.port,
                    reload=bool(app) and self.config.debug,
                    reload_dirs=self.fastapi_config.fastapi_reload_dirs or None,
                    debug=self.config.debug,
                    log_config=LOGGING_CONFIG,
                    **kwargs)

    async def _handle_http(self, adapter: str, request: Request):
        """
        Handle an HTTP POST for the adapter named in the path.

        :param adapter: adapter name taken from the URL path
        :param request: incoming FastAPI request
        :raises HTTPException: 404 when the adapter is not registered; the
            status from the permission check (401 when none is given) when
            the check yields no bot id
        :return: a ``Response`` built from the permission-check response,
            with status 200 when none is given
        """
        data = await request.body()

        if adapter not in self._adapters:
            logger.warning(
                f"Unknown adapter {adapter}. Please register the adapter before use."
            )
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                                detail="adapter not found")

        # Create the Bot object
        BotClass = self._adapters[adapter]
        http_request = HTTPRequest(request.scope["http_version"],
                                   request.url.scheme, request.url.path,
                                   request.scope["query_string"],
                                   dict(request.headers), request.method, data)
        x_self_id, response = await BotClass.check_permission(
            self, http_request)

        if not x_self_id:
            # ``response`` is treated as optional everywhere else in this
            # handler, so guard the detail as well: an unguarded
            # ``response.body`` raises AttributeError (surfacing as a 500)
            # instead of the intended 401 when no response is returned.
            raise HTTPException(response and response.status or 401,
                                response and response.body)

        if x_self_id in self._clients:
            logger.warning("There's already a reverse websocket connection,"
                           "so the event may be handled twice.")

        bot = BotClass(x_self_id, http_request)

        # Handle the event in the background so the HTTP reply is not delayed.
        asyncio.create_task(bot.handle_message(data))
        return Response(response and response.body,
                        response and response.status or 200)

    async def _handle_ws_reverse(self, adapter: str,
                                 websocket: FastAPIWebSocket):
        """
        Handle a reverse WebSocket connection for the adapter in the path.

        The connection is closed with ``WS_1008_POLICY_VIOLATION`` when the
        adapter is unknown, when the permission check yields no bot id, or
        when a client with the same id is already in ``self._clients``.
        Otherwise it is accepted and each received text message is handed to
        ``bot.handle_message`` until the socket closes or receiving fails.

        :param adapter: adapter name taken from the URL path
        :param websocket: incoming Starlette/FastAPI websocket
        """
        ws = WebSocket(websocket.scope.get("http_version",
                                           "1.1"), websocket.url.scheme,
                       websocket.url.path, websocket.scope["query_string"],
                       dict(websocket.headers), websocket)

        if adapter not in self._adapters:
            logger.warning(
                f"Unknown adapter {adapter}. Please register the adapter before use."
            )
            await ws.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        # Create Bot Object
        BotClass = self._adapters[adapter]
        x_self_id, _ = await BotClass.check_permission(self, ws)

        if not x_self_id:
            await ws.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        if x_self_id in self._clients:
            logger.opt(colors=True).warning(
                "There's already a reverse websocket connection, "
                f"{adapter.upper()} Bot {x_self_id} ignored.")
            await ws.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        bot = BotClass(x_self_id, ws)

        await ws.accept()
        logger.opt(colors=True).info(
            f"WebSocket Connection from {adapter.upper()} "
            f"Bot {x_self_id} Accepted!")

        self._bot_connect(bot)

        try:
            while not ws.closed:
                try:
                    data = await ws.receive()
                except WebSocketDisconnect:
                    logger.error("WebSocket disconnected by peer.")
                    break
                except Exception as e:
                    logger.opt(exception=e).error(
                        "Error when receiving data from websocket.")
                    break

                # Process each message in the background so the receive loop
                # keeps reading.
                asyncio.create_task(bot.handle_message(data.encode()))
        finally:
            # Always unregister the bot, however the loop ended.
            self._bot_disconnect(bot)


@dataclass
class WebSocket(BaseWebSocket):
    """Adapter exposing a Starlette websocket through ``BaseWebSocket``."""

    # Declared with a ``None`` default so it can follow the parent
    # dataclass's fields; the driver always passes a real websocket.
    websocket: FastAPIWebSocket = None  # type: ignore

    @property
    @overrides(BaseWebSocket)
    def closed(self):
        """``True`` once either side of the connection is disconnected."""
        return (
            self.websocket.client_state == WebSocketState.DISCONNECTED or
            self.websocket.application_state == WebSocketState.DISCONNECTED)

    @overrides(BaseWebSocket)
    async def accept(self):
        """Accept the websocket connection."""
        await self.websocket.accept()

    @overrides(BaseWebSocket)
    async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
        """Close the connection with the given close code."""
        await self.websocket.close(code=code)

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        """Receive one text message."""
        return await self.websocket.receive_text()

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        """Receive one binary message."""
        return await self.websocket.receive_bytes()

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        """Send ``data`` as a text message."""
        await self.websocket.send({"type": "websocket.send", "text": data})

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary message."""
        await self.websocket.send({"type": "websocket.send", "bytes": data})