🐛 fix http2 for fastapi

This commit is contained in:
yanyongyu 2021-12-02 20:52:39 +08:00
parent 534119eaf0
commit 226fc0feb3

View File

@ -13,7 +13,7 @@ FastAPI 驱动适配
import asyncio import asyncio
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union, Callable, Optional, Awaitable, cast from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
import httpx import httpx
import uvicorn import uvicorn
@ -40,6 +40,7 @@ from nonebot.drivers import (
HTTPPollingSetup, HTTPPollingSetup,
) )
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]] HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
@ -408,90 +409,98 @@ class FullDriver(ForwardDriver, Driver):
if not task.done(): if not task.done():
task.cancel() task.cancel()
async def _http_loop(self, setup: HTTPPOLLING_SETUP): async def _prepare_setup(
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]: self, setup: Union[S, Callable[[], Awaitable[S]]]
url = httpx.URL(setup.url) ) -> Optional[S]:
if not url.netloc: try:
logger.opt(colors=True).error( if callable(setup):
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>" return await setup()
) else:
return return setup
return HTTPRequest( except Exception as e:
setup.http_version, logger.opt(colors=True, exception=e).error(
url.scheme, "<r><bg #f8bbd0>Error while parsing setup "
url.path, f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
url.query,
{**setup.headers, "host": url.netloc.decode("ascii")},
setup.method,
setup.body,
) )
return
bot: Optional[Bot] = None def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
request: Optional[HTTPRequest] = None url = httpx.URL(setup.url)
setup_: Optional[HTTPPollingSetup] = None if not url.netloc:
logger.opt(colors=True).error(
logger.opt(colors=True).info( f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} " )
f"Bot {escape_tag(setup.self_id)}</y>" return
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.query,
setup.headers,
setup.method,
setup.body,
) )
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
http2: bool = False
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
client: Optional[httpx.AsyncClient] = None
# FIXME: seperate const values from setup (self_id, adapter)
# logger.opt(colors=True).info(
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
# f"Bot {escape_tag(_setup.self_id)}</y>"
# )
try: try:
async with httpx.AsyncClient(http2=True, follow_redirects=True) as session: while not self.shutdown.is_set():
while not self.shutdown.is_set():
try: setup = await self._prepare_setup(_setup)
if callable(setup): if not setup:
setup_ = await setup() await asyncio.sleep(3)
else: continue
setup_ = setup request = self._build_http_request(setup)
except Exception as e: if not request:
logger.opt(colors=True, exception=e).error( await asyncio.sleep(setup.poll_interval)
"<r><bg #f8bbd0>Error while parsing setup " continue
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
setup_ = cast(HTTPPollingSetup, setup_) if not client:
client = httpx.AsyncClient(http2=http2, follow_redirects=True)
elif http2 != (setup.http_version == "2"):
await client.aclose()
client = httpx.AsyncClient(http2=http2, follow_redirects=True)
http2 = setup.http_version == "2"
if not bot: if not bot:
request = await _build_request(setup_) BotClass = self._adapters[setup.adapter]
if not request: bot = BotClass(setup.self_id, request)
return self._bot_connect(bot)
BotClass = self._adapters[setup.adapter] else:
bot = BotClass(setup.self_id, request) bot.request = request
self._bot_connect(bot)
elif callable(setup):
request = await _build_request(setup_)
if not request:
await asyncio.sleep(setup_.poll_interval)
continue
bot.request = request
request = cast(HTTPRequest, request) logger.debug(
headers = request.headers f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
)
logger.debug( try:
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}" response = await client.request(
request.method,
setup.url,
content=request.body,
headers=request.headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
) )
try:
response = await session.request(
request.method,
setup_.url,
content=request.body,
headers=headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval) await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@ -503,50 +512,43 @@ class FullDriver(ForwardDriver, Driver):
finally: finally:
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
if client:
await client.aclose()
async def _ws_loop(self, setup: WEBSOCKET_SETUP): async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None bot: Optional[Bot] = None
try: try:
while True: while True:
try: setup = await self._prepare_setup(_setup)
if callable(setup): if not setup:
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3) await asyncio.sleep(3)
continue continue
url = httpx.URL(setup_.url) url = httpx.URL(setup.url)
if not url.netloc: if not url.netloc:
logger.opt(colors=True).error( logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>" f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
) )
return return
headers = setup_.headers.copy()
logger.debug( logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}" f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
) )
try: try:
connection = Connect(setup_.url, extra_headers=headers) connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws: async with connection as ws:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} " f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!" f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
) )
request = WebSocket( request = WebSocket(
"1.1", url.scheme, url.path, url.query, headers, ws "1.1", url.scheme, url.path, url.query, setup.headers, ws
) )
BotClass = self._adapters[setup_.adapter] BotClass = self._adapters[setup.adapter]
bot = BotClass(setup_.self_id, request) bot = BotClass(setup.self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
while not self.shutdown.is_set(): while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message # use try except instead of "request.closed" because of queued message
@ -569,12 +571,10 @@ class FullDriver(ForwardDriver, Driver):
self._bot_disconnect(bot) self._bot_disconnect(bot)
bot = None bot = None
if not setup_.reconnect: if not setup.reconnect:
logger.info( logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
f"WebSocket reconnect disabled for bot {setup_.self_id}"
)
break break
await asyncio.sleep(setup_.reconnect_interval) await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass