🐛 fix http2 for fastapi

This commit is contained in:
yanyongyu 2021-12-02 20:52:39 +08:00
parent 534119eaf0
commit 226fc0feb3

View File

@ -13,7 +13,7 @@ FastAPI 驱动适配
import asyncio
import logging
from dataclasses import dataclass
from typing import List, Union, Callable, Optional, Awaitable, cast
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
import httpx
import uvicorn
@ -40,6 +40,7 @@ from nonebot.drivers import (
HTTPPollingSetup,
)
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
@ -408,8 +409,22 @@ class FullDriver(ForwardDriver, Driver):
if not task.done():
task.cancel()
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
async def _prepare_setup(
self, setup: Union[S, Callable[[], Awaitable[S]]]
) -> Optional[S]:
try:
if callable(setup):
return await setup()
else:
return setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
return
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
@ -421,65 +436,59 @@ class FullDriver(ForwardDriver, Driver):
url.scheme,
url.path,
url.query,
{**setup.headers, "host": url.netloc.decode("ascii")},
setup.headers,
setup.method,
setup.body,
)
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
http2: bool = False
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
setup_: Optional[HTTPPollingSetup] = None
client: Optional[httpx.AsyncClient] = None
logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>"
)
# FIXME: seperate const values from setup (self_id, adapter)
# logger.opt(colors=True).info(
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
# f"Bot {escape_tag(_setup.self_id)}</y>"
# )
try:
async with httpx.AsyncClient(http2=True, follow_redirects=True) as session:
while not self.shutdown.is_set():
try:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
request = self._build_http_request(setup)
if not request:
await asyncio.sleep(setup.poll_interval)
continue
setup_ = cast(HTTPPollingSetup, setup_)
if not client:
client = httpx.AsyncClient(http2=http2, follow_redirects=True)
elif http2 != (setup.http_version == "2"):
await client.aclose()
client = httpx.AsyncClient(http2=http2, follow_redirects=True)
http2 = setup.http_version == "2"
if not bot:
request = await _build_request(setup_)
if not request:
return
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
elif callable(setup):
request = await _build_request(setup_)
if not request:
await asyncio.sleep(setup_.poll_interval)
continue
else:
bot.request = request
request = cast(HTTPRequest, request)
headers = request.headers
logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
)
try:
response = await session.request(
response = await client.request(
request.method,
setup_.url,
setup.url,
content=request.body,
headers=headers,
headers=request.headers,
timeout=30.0,
)
response.raise_for_status()
@ -487,11 +496,11 @@ class FullDriver(ForwardDriver, Driver):
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval)
await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
@ -503,50 +512,43 @@ class FullDriver(ForwardDriver, Driver):
finally:
if bot:
self._bot_disconnect(bot)
if client:
await client.aclose()
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
while True:
try:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
url = httpx.URL(setup_.url)
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
headers = setup_.headers.copy()
logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
connection = Connect(setup_.url, extra_headers=headers)
connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path, url.query, headers, ws
"1.1", url.scheme, url.path, url.query, setup.headers, ws
)
BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
@ -569,12 +571,10 @@ class FullDriver(ForwardDriver, Driver):
self._bot_disconnect(bot)
bot = None
if not setup_.reconnect:
logger.info(
f"WebSocket reconnect disabled for bot {setup_.self_id}"
)
if not setup.reconnect:
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
break
await asyncio.sleep(setup_.reconnect_interval)
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass