⚗️ experimenting aiohttp driver

This commit is contained in:
yanyongyu 2021-07-19 14:51:28 +08:00
parent 637c48aea7
commit 32787fdc1e

View File

@ -3,9 +3,11 @@
import signal import signal
import asyncio import asyncio
from typing import Set, Union, Callable, Awaitable, DefaultDict from dataclasses import dataclass
from typing import Set, List, Union, Callable, Awaitable
import aiohttp import aiohttp
from yarl import URL
from nonebot.log import logger from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -19,14 +21,21 @@ SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket] AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket]
@dataclass
class RequestSetup:
adapter: str
request: AVAILABLE_REQUEST
poll_interval: float
reconnect_interval: float
class Driver(ForwardDriver): class Driver(ForwardDriver):
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(env, config) super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set() self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.requests: DefaultDict[str, self.requests: List[RequestSetup] = []
Set[AVAILABLE_REQUEST]] = DefaultDict(set)
@property @property
@overrides(ForwardDriver) @overrides(ForwardDriver)
@ -50,10 +59,15 @@ class Driver(ForwardDriver):
return func return func
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup(self, adapter: str, request: HTTPConnection) -> None: def setup(self,
adapter: str,
request: HTTPConnection,
poll_interval: float = 3.,
reconnect_interval: float = 3.) -> None:
if not isinstance(request, (HTTPRequest, WebSocket)): if not isinstance(request, (HTTPRequest, WebSocket)):
raise TypeError(f"Request Type {type(request)!r} is not supported!") raise TypeError(f"Request Type {type(request)!r} is not supported!")
self.requests[adapter].add(request) self.requests.append(
RequestSetup(adapter, request, poll_interval, reconnect_interval))
@overrides(ForwardDriver) @overrides(ForwardDriver)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
@ -73,12 +87,15 @@ class Driver(ForwardDriver):
async def startup(self): async def startup(self):
setups = [] setups = []
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for adapter, requests in self.requests.items(): for setup in self.requests:
for request in requests: if isinstance(setup.request, HTTPRequest):
if isinstance(request, HTTPRequest): setups.append(
setups.append(self._http_setup(adapter, request)) self._http_setup(setup.adapter, setup.request,
else: setup.poll_interval))
setups.append(self._ws_setup(adapter, request)) else:
setups.append(
self._ws_setup(setup.adapter, setup.request,
setup.reconnect_interval))
try: try:
await asyncio.gather(*setups) await asyncio.gather(*setups)
@ -125,7 +142,8 @@ class Driver(ForwardDriver):
loop.stop() loop.stop()
async def _http_setup(self, adapter: str, request: HTTPRequest): async def _http_setup(self, adapter: str, request: HTTPRequest,
poll_interval: float):
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request) self_id, _ = await BotClass.check_permission(self, request)
@ -134,16 +152,84 @@ class Driver(ForwardDriver):
bot = BotClass(self_id, request) bot = BotClass(self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
asyncio.create_task(self._http_loop(bot, request)) asyncio.create_task(self._http_loop(bot, request, poll_interval))
async def _ws_setup(self, adapter: str, request: WebSocket): async def _ws_setup(self, adapter: str, request: WebSocket,
... reconnect_interval: float):
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
async def _http_loop(self, bot: Bot, request: HTTPRequest): if not self_id:
# TODO: main loop for HTTP long polling raise SetupFailed("Bot self_id get failed")
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))
async def _http_loop(self, bot: Bot, request: HTTPRequest,
poll_interval: float):
try: try:
while True: headers = request.headers
... url = URL.build(scheme=request.scheme,
# include asyncio.CancelledError host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
timeout = aiohttp.ClientTimeout(30)
async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session:
while True:
try:
async with session.request(
request.method, url,
data=request.body) as response:
response.raise_for_status()
data = await response.read()
asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error(
f"Error occurred while requesting {url}")
await asyncio.sleep(poll_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"Unexpected exception occurred while http polling")
finally:
self._bot_disconnect(bot)
async def _ws_loop(self, bot: Bot, request: WebSocket,
reconnect_interval: float):
try:
headers = request.headers
url = URL.build(scheme=request.scheme,
host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
timeout = aiohttp.ClientTimeout(30)
async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session:
while True:
async with session.ws_connect(url) as ws:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task(
bot.handle_message(msg.data.encode()))
elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task(bot.handle_message(
msg.data))
elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg></r>")
break
asyncio.sleep(reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"Unexpected exception occurred while websocket loop")
finally: finally:
self._bot_disconnect(bot) self._bot_disconnect(bot)