diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index ec579f51..3b355e3f 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -13,7 +13,7 @@ FastAPI 驱动适配 import asyncio import logging from dataclasses import dataclass -from typing import List, Union, Callable, Optional, Awaitable, cast +from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast import httpx import uvicorn @@ -40,6 +40,7 @@ from nonebot.drivers import ( HTTPPollingSetup, ) +S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup]) HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]] WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] @@ -408,90 +409,98 @@ class FullDriver(ForwardDriver, Driver): if not task.done(): task.cancel() - async def _http_loop(self, setup: HTTPPOLLING_SETUP): - async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]: - url = httpx.URL(setup.url) - if not url.netloc: - logger.opt(colors=True).error( - f"Error parsing url {escape_tag(str(url))}" - ) - return - return HTTPRequest( - setup.http_version, - url.scheme, - url.path, - url.query, - {**setup.headers, "host": url.netloc.decode("ascii")}, - setup.method, - setup.body, + async def _prepare_setup( + self, setup: Union[S, Callable[[], Awaitable[S]]] + ) -> Optional[S]: + try: + if callable(setup): + return await setup() + else: + return setup + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error while parsing setup " + f"{escape_tag(repr(setup))}." ) + return - bot: Optional[Bot] = None - request: Optional[HTTPRequest] = None - setup_: Optional[HTTPPollingSetup] = None - - logger.opt(colors=True).info( - f"Start http polling for {escape_tag(setup.adapter.upper())} " - f"Bot {escape_tag(setup.self_id)}" + def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]: + url = httpx.URL(setup.url) + if not url.netloc: + logger.opt(colors=True).error( + f"Error parsing url {escape_tag(str(url))}" + ) + return + return HTTPRequest( + setup.http_version, + url.scheme, + url.path, + url.query, + setup.headers, + setup.method, + setup.body, ) + async def _http_loop(self, _setup: HTTPPOLLING_SETUP): + + http2: bool = False + bot: Optional[Bot] = None + request: Optional[HTTPRequest] = None + client: Optional[httpx.AsyncClient] = None + + # FIXME: seperate const values from setup (self_id, adapter) + # logger.opt(colors=True).info( + # f"Start http polling for {escape_tag(_setup.adapter.upper())} " + # f"Bot {escape_tag(_setup.self_id)}" + # ) + try: - async with httpx.AsyncClient(http2=True, follow_redirects=True) as session: - while not self.shutdown.is_set(): + while not self.shutdown.is_set(): - try: - if callable(setup): - setup_ = await setup() - else: - setup_ = setup - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error while parsing setup " - f"{escape_tag(repr(setup))}." - ) - await asyncio.sleep(3) - continue + setup = await self._prepare_setup(_setup) + if not setup: + await asyncio.sleep(3) + continue + request = self._build_http_request(setup) + if not request: + await asyncio.sleep(setup.poll_interval) + continue - setup_ = cast(HTTPPollingSetup, setup_) + if not client: + client = httpx.AsyncClient(http2=http2, follow_redirects=True) + elif http2 != (setup.http_version == "2"): + await client.aclose() + client = httpx.AsyncClient(http2=http2, follow_redirects=True) + http2 = setup.http_version == "2" - if not bot: - request = await _build_request(setup_) - if not request: - return - BotClass = self._adapters[setup.adapter] - bot = BotClass(setup.self_id, request) - self._bot_connect(bot) - elif callable(setup): - request = await _build_request(setup_) - if not request: - await asyncio.sleep(setup_.poll_interval) - continue - bot.request = request + if not bot: + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + else: + bot.request = request - request = cast(HTTPRequest, request) - headers = request.headers - - logger.debug( - f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}" + logger.debug( + f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}" + ) + try: + response = await client.request( + request.method, + setup.url, + content=request.body, + headers=request.headers, + timeout=30.0, + ) + response.raise_for_status() + data = response.read() + asyncio.create_task(bot.handle_message(data)) + except httpx.HTTPError as e: + logger.opt(colors=True, exception=e).error( + f"Error occurred while requesting {escape_tag(setup.url)}. " + "Try to reconnect..." ) - try: - response = await session.request( - request.method, - setup_.url, - content=request.body, - headers=headers, - timeout=30.0, - ) - response.raise_for_status() - data = response.read() - asyncio.create_task(bot.handle_message(data)) - except httpx.HTTPError as e: - logger.opt(colors=True, exception=e).error( - f"Error occurred while requesting {escape_tag(setup_.url)}. " - "Try to reconnect..." - ) - await asyncio.sleep(setup_.poll_interval) + await asyncio.sleep(setup.poll_interval) except asyncio.CancelledError: pass @@ -503,50 +512,43 @@ class FullDriver(ForwardDriver, Driver): finally: if bot: self._bot_disconnect(bot) + if client: + await client.aclose() - async def _ws_loop(self, setup: WEBSOCKET_SETUP): + async def _ws_loop(self, _setup: WEBSOCKET_SETUP): bot: Optional[Bot] = None try: while True: - try: - if callable(setup): - setup_ = await setup() - else: - setup_ = setup - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error while parsing setup " - f"{escape_tag(repr(setup))}." - ) + setup = await self._prepare_setup(_setup) + if not setup: await asyncio.sleep(3) continue - url = httpx.URL(setup_.url) + url = httpx.URL(setup.url) if not url.netloc: logger.opt(colors=True).error( f"Error parsing url {escape_tag(str(url))}" ) return - headers = setup_.headers.copy() logger.debug( - f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}" + f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" ) try: - connection = Connect(setup_.url, extra_headers=headers) + connection = Connect(setup.url, extra_headers=setup.headers) async with connection as ws: logger.opt(colors=True).info( - f"WebSocket Connection to {escape_tag(setup_.adapter.upper())} " - f"Bot {escape_tag(setup_.self_id)} succeeded!" + f"WebSocket Connection to {escape_tag(setup.adapter.upper())} " + f"Bot {escape_tag(setup.self_id)} succeeded!" ) request = WebSocket( - "1.1", url.scheme, url.path, url.query, headers, ws + "1.1", url.scheme, url.path, url.query, setup.headers, ws ) - BotClass = self._adapters[setup_.adapter] - bot = BotClass(setup_.self_id, request) + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) self._bot_connect(bot) while not self.shutdown.is_set(): # use try except instead of "request.closed" because of queued message @@ -569,12 +571,10 @@ class FullDriver(ForwardDriver, Driver): self._bot_disconnect(bot) bot = None - if not setup_.reconnect: - logger.info( - f"WebSocket reconnect disabled for bot {setup_.self_id}" - ) + if not setup.reconnect: + logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}") break - await asyncio.sleep(setup_.reconnect_interval) + await asyncio.sleep(setup.reconnect_interval) except asyncio.CancelledError: pass