diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index ec579f51..3b355e3f 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -13,7 +13,7 @@ FastAPI 驱动适配
import asyncio
import logging
from dataclasses import dataclass
-from typing import List, Union, Callable, Optional, Awaitable, cast
+from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
import httpx
import uvicorn
@@ -40,6 +40,7 @@ from nonebot.drivers import (
HTTPPollingSetup,
)
+S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
@@ -408,90 +409,98 @@ class FullDriver(ForwardDriver, Driver):
if not task.done():
task.cancel()
- async def _http_loop(self, setup: HTTPPOLLING_SETUP):
- async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
- url = httpx.URL(setup.url)
- if not url.netloc:
- logger.opt(colors=True).error(
- f"Error parsing url {escape_tag(str(url))}"
- )
- return
- return HTTPRequest(
- setup.http_version,
- url.scheme,
- url.path,
- url.query,
- {**setup.headers, "host": url.netloc.decode("ascii")},
- setup.method,
- setup.body,
+ async def _prepare_setup(
+ self, setup: Union[S, Callable[[], Awaitable[S]]]
+ ) -> Optional[S]:
+ try:
+ if callable(setup):
+ return await setup()
+ else:
+ return setup
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error while parsing setup "
+ f"{escape_tag(repr(setup))}."
)
+ return
- bot: Optional[Bot] = None
- request: Optional[HTTPRequest] = None
- setup_: Optional[HTTPPollingSetup] = None
-
- logger.opt(colors=True).info(
- f"Start http polling for {escape_tag(setup.adapter.upper())} "
- f"Bot {escape_tag(setup.self_id)}"
+ def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
+ url = httpx.URL(setup.url)
+ if not url.netloc:
+ logger.opt(colors=True).error(
+ f"Error parsing url {escape_tag(str(url))}"
+ )
+ return
+ return HTTPRequest(
+ setup.http_version,
+ url.scheme,
+ url.path,
+ url.query,
+ setup.headers,
+ setup.method,
+ setup.body,
)
+ async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
+
+ http2: bool = False
+ bot: Optional[Bot] = None
+ request: Optional[HTTPRequest] = None
+ client: Optional[httpx.AsyncClient] = None
+
+ # FIXME: seperate const values from setup (self_id, adapter)
+ # logger.opt(colors=True).info(
+ # f"Start http polling for {escape_tag(_setup.adapter.upper())} "
+ # f"Bot {escape_tag(_setup.self_id)}"
+ # )
+
try:
- async with httpx.AsyncClient(http2=True, follow_redirects=True) as session:
- while not self.shutdown.is_set():
+ while not self.shutdown.is_set():
- try:
- if callable(setup):
- setup_ = await setup()
- else:
- setup_ = setup
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error while parsing setup "
- f"{escape_tag(repr(setup))}."
- )
- await asyncio.sleep(3)
- continue
+ setup = await self._prepare_setup(_setup)
+ if not setup:
+ await asyncio.sleep(3)
+ continue
+ request = self._build_http_request(setup)
+ if not request:
+ await asyncio.sleep(setup.poll_interval)
+ continue
- setup_ = cast(HTTPPollingSetup, setup_)
+ if not client:
+ client = httpx.AsyncClient(http2=http2, follow_redirects=True)
+ elif http2 != (setup.http_version == "2"):
+ await client.aclose()
+ client = httpx.AsyncClient(http2=http2, follow_redirects=True)
+ http2 = setup.http_version == "2"
- if not bot:
- request = await _build_request(setup_)
- if not request:
- return
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
- self._bot_connect(bot)
- elif callable(setup):
- request = await _build_request(setup_)
- if not request:
- await asyncio.sleep(setup_.poll_interval)
- continue
- bot.request = request
+ if not bot:
+ BotClass = self._adapters[setup.adapter]
+ bot = BotClass(setup.self_id, request)
+ self._bot_connect(bot)
+ else:
+ bot.request = request
- request = cast(HTTPRequest, request)
- headers = request.headers
-
- logger.debug(
- f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
+ logger.debug(
+ f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
+ )
+ try:
+ response = await client.request(
+ request.method,
+ setup.url,
+ content=request.body,
+ headers=request.headers,
+ timeout=30.0,
+ )
+ response.raise_for_status()
+ data = response.read()
+ asyncio.create_task(bot.handle_message(data))
+ except httpx.HTTPError as e:
+ logger.opt(colors=True, exception=e).error(
+ f"Error occurred while requesting {escape_tag(setup.url)}. "
+ "Try to reconnect..."
)
- try:
- response = await session.request(
- request.method,
- setup_.url,
- content=request.body,
- headers=headers,
- timeout=30.0,
- )
- response.raise_for_status()
- data = response.read()
- asyncio.create_task(bot.handle_message(data))
- except httpx.HTTPError as e:
- logger.opt(colors=True, exception=e).error(
- f"Error occurred while requesting {escape_tag(setup_.url)}. "
- "Try to reconnect..."
- )
- await asyncio.sleep(setup_.poll_interval)
+ await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
@@ -503,50 +512,43 @@ class FullDriver(ForwardDriver, Driver):
finally:
if bot:
self._bot_disconnect(bot)
+ if client:
+ await client.aclose()
- async def _ws_loop(self, setup: WEBSOCKET_SETUP):
+ async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
while True:
- try:
- if callable(setup):
- setup_ = await setup()
- else:
- setup_ = setup
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error while parsing setup "
- f"{escape_tag(repr(setup))}."
- )
+ setup = await self._prepare_setup(_setup)
+ if not setup:
await asyncio.sleep(3)
continue
- url = httpx.URL(setup_.url)
+ url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"Error parsing url {escape_tag(str(url))}"
)
return
- headers = setup_.headers.copy()
logger.debug(
- f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
+ f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
- connection = Connect(setup_.url, extra_headers=headers)
+ connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws:
logger.opt(colors=True).info(
- f"WebSocket Connection to {escape_tag(setup_.adapter.upper())} "
- f"Bot {escape_tag(setup_.self_id)} succeeded!"
+ f"WebSocket Connection to {escape_tag(setup.adapter.upper())} "
+ f"Bot {escape_tag(setup.self_id)} succeeded!"
)
request = WebSocket(
- "1.1", url.scheme, url.path, url.query, headers, ws
+ "1.1", url.scheme, url.path, url.query, setup.headers, ws
)
- BotClass = self._adapters[setup_.adapter]
- bot = BotClass(setup_.self_id, request)
+ BotClass = self._adapters[setup.adapter]
+ bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
@@ -569,12 +571,10 @@ class FullDriver(ForwardDriver, Driver):
self._bot_disconnect(bot)
bot = None
- if not setup_.reconnect:
- logger.info(
- f"WebSocket reconnect disabled for bot {setup_.self_id}"
- )
+ if not setup.reconnect:
+ logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
break
- await asyncio.sleep(setup_.reconnect_interval)
+ await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass