Merge pull request #615 from nonebot/fix/httpx-http2

Fix: http2 for fastapi driver
This commit is contained in:
Ju4tCode 2021-12-04 15:03:17 +08:00 committed by GitHub
commit b52954a240
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,7 +13,7 @@ FastAPI 驱动适配
import asyncio
import logging
from dataclasses import dataclass
from typing import List, Union, Callable, Optional, Awaitable, cast
from typing import List, Union, TypeVar, Callable, Optional, Awaitable, cast
import httpx
import uvicorn
@ -40,6 +40,7 @@ from nonebot.drivers import (
HTTPPollingSetup,
)
S = TypeVar("S", bound=Union[HTTPPollingSetup, WebSocketSetup])
HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
@ -408,90 +409,98 @@ class FullDriver(ForwardDriver, Driver):
if not task.done():
task.cancel()
async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.query,
{**setup.headers, "host": url.netloc.decode("ascii")},
setup.method,
setup.body,
async def _prepare_setup(
self, setup: Union[S, Callable[[], Awaitable[S]]]
) -> Optional[S]:
try:
if callable(setup):
return await setup()
else:
return setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
return
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
setup_: Optional[HTTPPollingSetup] = None
logger.opt(colors=True).info(
f"Start http polling for <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y>"
def _build_http_request(self, setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
return HTTPRequest(
setup.http_version,
url.scheme,
url.path,
url.query,
setup.headers,
setup.method,
setup.body,
)
async def _http_loop(self, _setup: HTTPPOLLING_SETUP):
http2: bool = False
bot: Optional[Bot] = None
request: Optional[HTTPRequest] = None
client: Optional[httpx.AsyncClient] = None
# FIXME: seperate const values from setup (self_id, adapter)
# logger.opt(colors=True).info(
# f"Start http polling for <y>{escape_tag(_setup.adapter.upper())} "
# f"Bot {escape_tag(_setup.self_id)}</y>"
# )
try:
async with httpx.AsyncClient(http2=True, follow_redirects=True) as session:
while not self.shutdown.is_set():
while not self.shutdown.is_set():
try:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
await asyncio.sleep(3)
continue
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
request = self._build_http_request(setup)
if not request:
await asyncio.sleep(setup.poll_interval)
continue
setup_ = cast(HTTPPollingSetup, setup_)
if not client:
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
elif http2 != (setup.http_version == "2"):
await client.aclose()
client = httpx.AsyncClient(http2=setup.http_version == "2", follow_redirects=True)
http2 = setup.http_version == "2"
if not bot:
request = await _build_request(setup_)
if not request:
return
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
elif callable(setup):
request = await _build_request(setup_)
if not request:
await asyncio.sleep(setup_.poll_interval)
continue
bot.request = request
if not bot:
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
else:
bot.request = request
request = cast(HTTPRequest, request)
headers = request.headers
logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {setup.url}"
)
try:
response = await client.request(
request.method,
setup.url,
content=request.body,
headers=request.headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
try:
response = await session.request(
request.method,
setup_.url,
content=request.body,
headers=headers,
timeout=30.0,
)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {escape_tag(setup_.url)}. "
"Try to reconnect...</bg #f8bbd0></r>"
)
await asyncio.sleep(setup_.poll_interval)
await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
@ -503,50 +512,43 @@ class FullDriver(ForwardDriver, Driver):
finally:
if bot:
self._bot_disconnect(bot)
if client:
await client.aclose()
async def _ws_loop(self, setup: WEBSOCKET_SETUP):
async def _ws_loop(self, _setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
try:
while True:
try:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error while parsing setup "
f"{escape_tag(repr(setup))}.</bg #f8bbd0></r>"
)
setup = await self._prepare_setup(_setup)
if not setup:
await asyncio.sleep(3)
continue
url = httpx.URL(setup_.url)
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {escape_tag(str(url))}</bg #f8bbd0></r>"
)
return
headers = setup_.headers.copy()
logger.debug(
f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
connection = Connect(setup_.url, extra_headers=headers)
connection = Connect(setup.url, extra_headers=setup.headers)
async with connection as ws:
logger.opt(colors=True).info(
f"WebSocket Connection to <y>{escape_tag(setup_.adapter.upper())} "
f"Bot {escape_tag(setup_.self_id)}</y> succeeded!"
f"WebSocket Connection to <y>{escape_tag(setup.adapter.upper())} "
f"Bot {escape_tag(setup.self_id)}</y> succeeded!"
)
request = WebSocket(
"1.1", url.scheme, url.path, url.query, headers, ws
"1.1", url.scheme, url.path, url.query, setup.headers, ws
)
BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup_.self_id, request)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
@ -569,12 +571,10 @@ class FullDriver(ForwardDriver, Driver):
self._bot_disconnect(bot)
bot = None
if not setup_.reconnect:
logger.info(
f"WebSocket reconnect disabled for bot {setup_.self_id}"
)
if not setup.reconnect:
logger.info(f"WebSocket reconnect disabled for bot {setup.self_id}")
break
await asyncio.sleep(setup_.reconnect_interval)
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass