⚗️ experimenting aiohttp driver

This commit is contained in:
yanyongyu 2021-07-19 14:51:28 +08:00
parent 637c48aea7
commit 32787fdc1e

View File

@ -3,9 +3,11 @@
import signal
import asyncio
from typing import Set, Union, Callable, Awaitable, DefaultDict
from dataclasses import dataclass
from typing import Set, List, Union, Callable, Awaitable
import aiohttp
from yarl import URL
from nonebot.log import logger
from nonebot.adapters import Bot
@ -19,14 +21,21 @@ SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket]
@dataclass
class RequestSetup:
adapter: str
request: AVAILABLE_REQUEST
poll_interval: float
reconnect_interval: float
class Driver(ForwardDriver):
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.requests: DefaultDict[str,
Set[AVAILABLE_REQUEST]] = DefaultDict(set)
self.requests: List[RequestSetup] = []
@property
@overrides(ForwardDriver)
@ -50,10 +59,15 @@ class Driver(ForwardDriver):
return func
@overrides(ForwardDriver)
def setup(self, adapter: str, request: HTTPConnection) -> None:
def setup(self,
adapter: str,
request: HTTPConnection,
poll_interval: float = 3.,
reconnect_interval: float = 3.) -> None:
if not isinstance(request, (HTTPRequest, WebSocket)):
raise TypeError(f"Request Type {type(request)!r} is not supported!")
self.requests[adapter].add(request)
self.requests.append(
RequestSetup(adapter, request, poll_interval, reconnect_interval))
@overrides(ForwardDriver)
def run(self, *args, **kwargs):
@ -73,12 +87,15 @@ class Driver(ForwardDriver):
async def startup(self):
setups = []
loop = asyncio.get_event_loop()
for adapter, requests in self.requests.items():
for request in requests:
if isinstance(request, HTTPRequest):
setups.append(self._http_setup(adapter, request))
else:
setups.append(self._ws_setup(adapter, request))
for setup in self.requests:
if isinstance(setup.request, HTTPRequest):
setups.append(
self._http_setup(setup.adapter, setup.request,
setup.poll_interval))
else:
setups.append(
self._ws_setup(setup.adapter, setup.request,
setup.reconnect_interval))
try:
await asyncio.gather(*setups)
@ -125,7 +142,8 @@ class Driver(ForwardDriver):
loop.stop()
async def _http_setup(self, adapter: str, request: HTTPRequest):
async def _http_setup(self, adapter: str, request: HTTPRequest,
poll_interval: float):
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
@ -134,16 +152,84 @@ class Driver(ForwardDriver):
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._http_loop(bot, request))
asyncio.create_task(self._http_loop(bot, request, poll_interval))
async def _ws_setup(self, adapter: str, request: WebSocket):
...
async def _ws_setup(self, adapter: str, request: WebSocket,
reconnect_interval: float):
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
async def _http_loop(self, bot: Bot, request: HTTPRequest):
# TODO: main loop for HTTP long polling
if not self_id:
raise SetupFailed("Bot self_id get failed")
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))
async def _http_loop(self, bot: Bot, request: HTTPRequest,
poll_interval: float):
try:
while True:
...
# include asyncio.CancelledError
headers = request.headers
url = URL.build(scheme=request.scheme,
host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
timeout = aiohttp.ClientTimeout(30)
async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session:
while True:
try:
async with session.request(
request.method, url,
data=request.body) as response:
response.raise_for_status()
data = await response.read()
asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error(
f"Error occurred while requesting {url}")
await asyncio.sleep(poll_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"Unexpected exception occurred while http polling")
finally:
self._bot_disconnect(bot)
async def _ws_loop(self, bot: Bot, request: WebSocket,
reconnect_interval: float):
try:
headers = request.headers
url = URL.build(scheme=request.scheme,
host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
timeout = aiohttp.ClientTimeout(30)
async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session:
while True:
async with session.ws_connect(url) as ws:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task(
bot.handle_message(msg.data.encode()))
elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task(bot.handle_message(
msg.data))
elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg></r>")
break
asyncio.sleep(reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"Unexpected exception occurred while websocket loop")
finally:
self._bot_disconnect(bot)