nonebot2/nonebot/drivers/fastapi.py

615 lines
20 KiB
Python
Raw Normal View History

2020-10-16 01:10:46 +08:00
"""
FastAPI 驱动适配
================
2021-07-31 12:24:11 +08:00
本驱动同时支持服务端以及客户端连接
2020-10-16 01:10:46 +08:00
后端使用方法请参考: `FastAPI 文档`_
.. _FastAPI 文档:
https://fastapi.tiangolo.com/
"""
2020-07-04 22:51:10 +08:00
import asyncio
2020-07-04 22:51:10 +08:00
import logging
2021-06-10 21:52:20 +08:00
from dataclasses import dataclass
from typing import List, Union, Callable, Optional, Awaitable, cast
2020-07-04 22:51:10 +08:00
import httpx
2020-07-04 22:51:10 +08:00
import uvicorn
2021-01-12 18:02:05 +08:00
from pydantic import BaseSettings
2020-08-25 18:02:18 +08:00
from fastapi.responses import Response
from websockets.exceptions import ConnectionClosed
from fastapi import FastAPI, Request, HTTPException, status
from starlette.websockets import WebSocket as FastAPIWebSocket
from starlette.websockets import WebSocketState, WebSocketDisconnect
from websockets.legacy.client import Connect, WebSocketClientProtocol
2020-07-04 22:51:10 +08:00
from nonebot.config import Env
2020-07-05 20:39:34 +08:00
from nonebot.log import logger
from nonebot.adapters import Bot
2020-12-06 02:30:19 +08:00
from nonebot.typing import overrides
from nonebot.utils import escape_tag
from nonebot.config import Config as NoneBotConfig
from nonebot.drivers import WebSocket as BaseWebSocket
from nonebot.drivers import (
HTTPRequest,
ForwardDriver,
ReverseDriver,
WebSocketSetup,
HTTPPollingSetup,
)
# An HTTP polling setup, given either as a ready-made ``HTTPPollingSetup`` or as
# an async zero-argument callable that produces one.
HTTPPOLLING_SETUP = Union[
    HTTPPollingSetup,
    Callable[[], Awaitable[HTTPPollingSetup]],
]
# A WebSocket setup, given either as a ready-made ``WebSocketSetup`` or as an
# async zero-argument callable that produces one.
WEBSOCKET_SETUP = Union[
    WebSocketSetup,
    Callable[[], Awaitable[WebSocketSetup]],
]
2020-08-20 16:34:07 +08:00
2021-01-12 18:02:05 +08:00
class Config(BaseSettings):
    """
    FastAPI driver settings. See the FastAPI documentation for the meaning of
    the underlying FastAPI / uvicorn options.
    """

    fastapi_openapi_url: Union[str, None] = None
    """
    :Type:

      ``Optional[str]``

    :Description:

      URL of ``openapi.json``. Defaults to ``None``, which disables it.
    """

    fastapi_docs_url: Union[str, None] = None
    """
    :Type:

      ``Optional[str]``

    :Description:

      URL of the ``swagger`` UI. Defaults to ``None``, which disables it.
    """

    fastapi_redoc_url: Union[str, None] = None
    """
    :Type:

      ``Optional[str]``

    :Description:

      URL of the ``redoc`` UI. Defaults to ``None``, which disables it.
    """

    fastapi_reload: Union[bool, None] = None
    """
    :Type:

      ``Optional[bool]``

    :Description:

      Enable/disable auto reload. When left as ``None``, reload is enabled only
      if ``Driver.run`` receives an ``app`` import string and debug mode is on.
    """

    fastapi_reload_dirs: Union[List[str], None] = None
    """
    :Type:

      ``Optional[List[str]]``

    :Description:

      Directories watched for reload. ``None`` leaves uvicorn's default in effect.
    """

    fastapi_reload_delay: Union[float, None] = None
    """
    :Type:

      ``Optional[float]``

    :Description:

      Reload delay. ``None`` leaves uvicorn's default in effect.
    """

    fastapi_reload_includes: Union[List[str], None] = None
    """
    :Type:

      ``Optional[List[str]]``

    :Description:

      Files to watch (glob patterns supported). ``None`` leaves uvicorn's
      default in effect.
    """

    fastapi_reload_excludes: Union[List[str], None] = None
    """
    :Type:

      ``Optional[List[str]]``

    :Description:

      Files NOT to watch (glob patterns supported). ``None`` leaves uvicorn's
      default in effect.
    """

    class Config:
        # The driver builds this model from the whole NoneBot config dict, so
        # unrelated keys must be dropped rather than rejected.
        extra = "ignore"
class Driver(ReverseDriver, ForwardDriver):
    """
    FastAPI driver.

    Serves reverse (server-side) connections through a FastAPI app. It also runs
    forward (client-side) HTTP polling and WebSocket connections registered via
    ``setup_http_polling`` / ``setup_websocket``.

    :Report endpoints:

      * ``/{adapter name}/``: HTTP POST report
      * ``/{adapter name}/http``: HTTP POST report
      * ``/{adapter name}/ws``: WebSocket report
      * ``/{adapter name}/ws/``: WebSocket report
    """

    def __init__(self, env: Env, config: NoneBotConfig):
        super().__init__(env, config)

        # Driver settings are parsed out of the global NoneBot config values.
        self.fastapi_config: Config = Config(**config.dict())

        # Forward connection setups registered before startup.
        self.http_pollings: List[HTTPPOLLING_SETUP] = []
        self.websockets: List[WEBSOCKET_SETUP] = []
        # Set on shutdown so the forward connection loops stop iterating.
        self.shutdown: asyncio.Event = asyncio.Event()
        # Background tasks running the forward connection loops.
        self.connections: List[asyncio.Task] = []

        self._server_app = FastAPI(
            debug=config.debug,
            openapi_url=self.fastapi_config.fastapi_openapi_url,
            docs_url=self.fastapi_config.fastapi_docs_url,
            redoc_url=self.fastapi_config.fastapi_redoc_url,
        )

        self._server_app.post("/{adapter}/")(self._handle_http)
        self._server_app.post("/{adapter}/http")(self._handle_http)
        self._server_app.websocket("/{adapter}/ws")(self._handle_ws_reverse)
        self._server_app.websocket("/{adapter}/ws/")(self._handle_ws_reverse)

        self.on_startup(self._run_forward)
        self.on_shutdown(self._shutdown_forward)

    @property
    @overrides(ReverseDriver)
    def type(self) -> str:
        """Driver name: ``fastapi``"""
        return "fastapi"

    @property
    @overrides(ReverseDriver)
    def server_app(self) -> FastAPI:
        """The ``FastAPI APP`` object"""
        return self._server_app

    @property
    @overrides(ReverseDriver)
    def asgi(self) -> FastAPI:
        """The ``FastAPI APP`` object"""
        return self._server_app

    @property
    @overrides(ReverseDriver)
    def logger(self) -> logging.Logger:
        """The logger used by fastapi"""
        return logging.getLogger("fastapi")

    @overrides(ReverseDriver)
    def on_startup(self, func: Callable) -> Callable:
        """Reference: `Events <https://fastapi.tiangolo.com/advanced/events/#startup-event>`_"""
        return self.server_app.on_event("startup")(func)

    @overrides(ReverseDriver)
    def on_shutdown(self, func: Callable) -> Callable:
        """Reference: `Events <https://fastapi.tiangolo.com/advanced/events/#shutdown-event>`_"""
        return self.server_app.on_event("shutdown")(func)

    @overrides(ForwardDriver)
    def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
        """
        :Description:

          Register an HTTP polling connection. If a function is passed, it is
          awaited to obtain a fresh setup before every poll.

        :Parameters:

          * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
        """
        self.http_pollings.append(setup)

    @overrides(ForwardDriver)
    def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
        """
        :Description:

          Register a WebSocket connection. If a function is passed, it is
          awaited to obtain a fresh setup before every (re)connect.

        :Parameters:

          * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
        """
        self.websockets.append(setup)

    @overrides(ReverseDriver)
    def run(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        *,
        app: Optional[str] = None,
        **kwargs,
    ):
        """Start FastAPI with ``uvicorn``."""
        super().run(host, port, app, **kwargs)
        # Route uvicorn's loggers through NoneBot's loguru handler.
        LOGGING_CONFIG = {
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "default": {
                    "class": "nonebot.log.LoguruHandler",
                },
            },
            "loggers": {
                "uvicorn.error": {"handlers": ["default"], "level": "INFO"},
                "uvicorn.access": {
                    "handlers": ["default"],
                    "level": "INFO",
                },
            },
        }
        uvicorn.run(
            app or self.server_app,  # type: ignore
            host=host or str(self.config.host),
            port=port or self.config.port,
            # Unless configured explicitly, reload only when an import string
            # ``app`` was given and debug mode is on.
            reload=self.fastapi_config.fastapi_reload
            if self.fastapi_config.fastapi_reload is not None
            else (bool(app) and self.config.debug),
            reload_dirs=self.fastapi_config.fastapi_reload_dirs,
            reload_delay=self.fastapi_config.fastapi_reload_delay,
            reload_includes=self.fastapi_config.fastapi_reload_includes,
            reload_excludes=self.fastapi_config.fastapi_reload_excludes,
            debug=self.config.debug,
            log_config=LOGGING_CONFIG,
            **kwargs,
        )

    def _run_forward(self):
        """Startup hook: spawn one background task per registered forward setup."""
        for setup in self.http_pollings:
            self.connections.append(asyncio.create_task(self._http_loop(setup)))
        for setup in self.websockets:
            self.connections.append(asyncio.create_task(self._ws_loop(setup)))

    def _shutdown_forward(self):
        """Shutdown hook: signal the loops to stop and cancel unfinished tasks."""
        self.shutdown.set()
        for task in self.connections:
            if not task.done():
                task.cancel()

    async def _handle_http(self, adapter: str, request: Request):
        """
        Handle an HTTP POST report for ``adapter``.

        Responds 404 for an unknown adapter. Rejects the request when the
        adapter's ``check_permission`` yields no self id. Otherwise it schedules
        the body for handling and replies with the adapter-provided response.
        """
        data = await request.body()

        if adapter not in self._adapters:
            logger.warning(
                f"Unknown adapter {adapter}. Please register the adapter before use."
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="adapter not found"
            )

        # Create the Bot object
        BotClass = self._adapters[adapter]
        http_request = HTTPRequest(
            request.scope["http_version"],
            request.url.scheme,
            request.url.path,
            request.scope["query_string"],
            dict(request.headers),
            request.method,
            data,
        )
        x_self_id, response = await BotClass.check_permission(self, http_request)

        if not x_self_id:
            raise HTTPException(
                response and response.status or 401,
                response and response.body and response.body.decode("utf-8"),
            )

        if x_self_id in self._clients:
            logger.warning(
                "There's already a reverse websocket connection, "
                "so the event may be handled twice."
            )

        bot = BotClass(x_self_id, http_request)

        # Handle the event in the background so the HTTP reply is not delayed.
        asyncio.create_task(bot.handle_message(data))
        return Response(response and response.body, response and response.status or 200)

    async def _handle_ws_reverse(self, adapter: str, websocket: FastAPIWebSocket):
        """
        Handle a reverse WebSocket connection for ``adapter``.

        Closes with a policy-violation code for an unknown adapter, a failed
        permission check, or a duplicate self id. Otherwise it accepts the
        connection, registers the bot and forwards every received text message
        to it until the socket closes.
        """
        ws = WebSocket(
            websocket.scope.get("http_version", "1.1"),
            websocket.url.scheme,
            websocket.url.path,
            websocket.scope["query_string"],
            dict(websocket.headers),
            websocket,
        )

        if adapter not in self._adapters:
            logger.warning(
                f"Unknown adapter {adapter}. Please register the adapter before use."
            )
            await ws.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        # Create the Bot object
        BotClass = self._adapters[adapter]
        self_id, _ = await BotClass.check_permission(self, ws)

        if not self_id:
            await ws.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        if self_id in self._clients:
            logger.opt(colors=True).warning(
                "There's already a websocket connection, "
                f"<y>{escape_tag(adapter.upper())} Bot {escape_tag(self_id)}</y> ignored."
            )
            await ws.close(code=status.WS_1008_POLICY_VIOLATION)
            return

        bot = BotClass(self_id, ws)

        await ws.accept()
        logger.opt(colors=True).info(
            f"WebSocket Connection from <y>{escape_tag(adapter.upper())} "
            f"Bot {escape_tag(self_id)}</y> Accepted!"
        )

        self._bot_connect(bot)

        try:
            while not ws.closed:
                try:
                    data = await ws.receive()
                except WebSocketDisconnect:
                    logger.error("WebSocket disconnected by peer.")
                    break
                except Exception as e:
                    logger.opt(exception=e).error(
                        "Error when receiving data from websocket."
                    )
                    break

                asyncio.create_task(bot.handle_message(data.encode()))
        finally:
            self._bot_disconnect(bot)

    async def _http_loop(self, setup: HTTPPOLLING_SETUP):
        """
        Poll an HTTP endpoint until shutdown, feeding each response body to a bot.

        ``setup`` may be a static setup or an async factory. A factory is awaited
        on every iteration, and the bot's request is rebuilt from its result.
        """

        async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
            # Returns None (after logging) when the setup URL has no host part.
            url = httpx.URL(setup.url)
            if not url.netloc:
                logger.opt(colors=True).error(
                    f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
                )
                return None
            return HTTPRequest(
                setup.http_version,
                url.scheme,
                url.path,
                url.query,
                {**setup.headers, "host": url.netloc.decode("ascii")},
                setup.method,
                setup.body,
            )

        bot: Optional[Bot] = None
        request: Optional[HTTPRequest] = None
        setup_: Optional[HTTPPollingSetup] = None

        try:
            async with httpx.AsyncClient(http2=True, follow_redirects=True) as session:
                while not self.shutdown.is_set():

                    try:
                        if callable(setup):
                            setup_ = await setup()
                        else:
                            setup_ = setup
                    except Exception as e:
                        logger.opt(colors=True, exception=e).error(
                            "<r><bg #f8bbd0>Error while parsing setup "
                            f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
                        )
                        await asyncio.sleep(3)
                        continue

                    setup_ = cast(HTTPPollingSetup, setup_)

                    if not bot:
                        request = await _build_request(setup_)
                        if not request:
                            return
                        # Always read adapter/self_id from the resolved setup:
                        # ``setup`` itself may be a factory function that has
                        # no such attributes.
                        logger.opt(colors=True).info(
                            f"Start http polling for <y>{escape_tag(setup_.adapter.upper())} "
                            f"Bot {escape_tag(setup_.self_id)}</y>"
                        )
                        BotClass = self._adapters[setup_.adapter]
                        bot = BotClass(setup_.self_id, request)
                        self._bot_connect(bot)
                    elif callable(setup):
                        # A factory may yield a different request each poll.
                        request = await _build_request(setup_)
                        if not request:
                            await asyncio.sleep(setup_.poll_interval)
                            continue
                        bot.request = request

                    request = cast(HTTPRequest, request)
                    headers = request.headers

                    logger.debug(
                        f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
                    )
                    try:
                        response = await session.request(
                            request.method,
                            setup_.url,
                            content=request.body,
                            headers=headers,
                            timeout=30.0,
                        )
                        response.raise_for_status()
                        data = response.read()
                        asyncio.create_task(bot.handle_message(data))
                    except httpx.HTTPError as e:
                        logger.opt(colors=True, exception=e).error(
                            f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
                            "Try to reconnect...</bg #f8bbd0></r>"
                        )

                    await asyncio.sleep(setup_.poll_interval)

        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.opt(colors=True, exception=e).error(
                "<r><bg #f8bbd0>Unexpected exception occurred "
                "while http polling</bg #f8bbd0></r>"
            )
        finally:
            if bot:
                self._bot_disconnect(bot)

    async def _ws_loop(self, setup: WEBSOCKET_SETUP):
        """
        Keep a forward WebSocket connection alive until shutdown.

        ``setup`` may be a static setup or an async factory that is awaited
        before every (re)connect. Each connection gets a fresh bot, which is
        disconnected when the connection ends. The loop then waits
        ``reconnect_interval`` before reconnecting.
        """
        bot: Optional[Bot] = None

        try:
            # Honor the shutdown flag (like ``_http_loop``) so no new connection
            # is attempted once shutdown has started.
            while not self.shutdown.is_set():

                try:
                    if callable(setup):
                        setup_ = await setup()
                    else:
                        setup_ = setup
                except Exception as e:
                    logger.opt(colors=True, exception=e).error(
                        "<r><bg #f8bbd0>Error while parsing setup "
                        f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
                    )
                    await asyncio.sleep(3)
                    continue

                url = httpx.URL(setup_.url)
                if not url.netloc:
                    logger.opt(colors=True).error(
                        f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
                    )
                    return

                headers = setup_.headers.copy()
                logger.debug(
                    f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
                )
                try:
                    connection = Connect(setup_.url, extra_headers=headers)
                    async with connection as ws:
                        logger.opt(colors=True).info(
                            f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
                            f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
                        )
                        request = WebSocket(
                            "1.1", url.scheme, url.path, url.query, headers, ws
                        )

                        BotClass = self._adapters[setup_.adapter]
                        bot = BotClass(setup_.self_id, request)
                        self._bot_connect(bot)
                        while not self.shutdown.is_set():
                            # use try except instead of "request.closed" because of queued message
                            try:
                                msg = await request.receive_bytes()
                                asyncio.create_task(bot.handle_message(msg))
                            except ConnectionClosed:
                                logger.opt(colors=True).error(
                                    "<r><bg #f8bbd0>WebSocket connection closed by peer. "
                                    "Try to reconnect...</bg #f8bbd0></r>"
                                )
                                break
                except Exception as e:
                    # The URL is escaped because this message is rendered with
                    # loguru color markup.
                    logger.opt(colors=True, exception=e).error(
                        f"<r><bg #f8bbd0>Error while connecting to {escape_tag(str(url))}. "
                        "Try to reconnect...</bg #f8bbd0></r>"
                    )
                finally:
                    if bot:
                        self._bot_disconnect(bot)
                    bot = None

                await asyncio.sleep(setup_.reconnect_interval)

        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.opt(colors=True, exception=e).error(
                "<r><bg #f8bbd0>Unexpected exception occurred "
                "while websocket loop</bg #f8bbd0></r>"
            )
2021-06-10 21:52:20 +08:00
@dataclass
class WebSocket(BaseWebSocket):
    """
    WebSocket wrapper used for both connection directions.

    ``websocket`` holds either a Starlette/FastAPI server-side socket (reverse
    connections) or a ``websockets`` client protocol (forward connections).
    Most methods branch on which of the two it holds.
    """

    websocket: Union[FastAPIWebSocket, WebSocketClientProtocol] = None  # type: ignore

    @property
    @overrides(BaseWebSocket)
    def closed(self) -> bool:
        conn = self.websocket
        if not isinstance(conn, FastAPIWebSocket):
            return conn.closed
        # A server-side socket counts as closed once either side disconnected.
        return WebSocketState.DISCONNECTED in (
            conn.client_state,
            conn.application_state,
        )

    @overrides(BaseWebSocket)
    async def accept(self):
        # Only the server-side socket supports an explicit accept step.
        if not isinstance(self.websocket, FastAPIWebSocket):
            raise NotImplementedError
        await self.websocket.accept()

    @overrides(BaseWebSocket)
    async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE):
        conn = self.websocket
        await conn.close(code)

    @overrides(BaseWebSocket)
    async def receive(self) -> str:
        conn = self.websocket
        if isinstance(conn, FastAPIWebSocket):
            return await conn.receive_text()
        raw = await conn.recv()
        # Normalize binary frames from the client protocol to text.
        if isinstance(raw, bytes):
            return raw.decode("utf-8")
        return raw

    @overrides(BaseWebSocket)
    async def receive_bytes(self) -> bytes:
        conn = self.websocket
        if isinstance(conn, FastAPIWebSocket):
            return await conn.receive_bytes()
        raw = await conn.recv()
        # Normalize text frames from the client protocol to bytes.
        if isinstance(raw, str):
            return raw.encode("utf-8")
        return raw

    @overrides(BaseWebSocket)
    async def send(self, data: str) -> None:
        conn = self.websocket
        if not isinstance(conn, FastAPIWebSocket):
            await conn.send(data)
            return
        await conn.send({"type": "websocket.send", "text": data})

    @overrides(BaseWebSocket)
    async def send_bytes(self, data: bytes) -> None:
        conn = self.websocket
        if not isinstance(conn, FastAPIWebSocket):
            await conn.send(data)
            return
        await conn.send({"type": "websocket.send", "bytes": data})