From 32787fdc1ec2f9e5347544338b53f2164fc58d36 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 19 Jul 2021 14:51:28 +0800 Subject: [PATCH] :alembic: experimenting aiohttp driver --- nonebot/drivers/aiohttp.py | 126 +++++++++++++++++++++++++++++++------ 1 file changed, 106 insertions(+), 20 deletions(-) diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index a5bd71e5..035f4c84 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -3,9 +3,11 @@ import signal import asyncio -from typing import Set, Union, Callable, Awaitable, DefaultDict +from dataclasses import dataclass +from typing import Set, List, Union, Callable, Awaitable import aiohttp +from yarl import URL from nonebot.log import logger from nonebot.adapters import Bot @@ -19,14 +21,21 @@ SHUTDOWN_FUNC = Callable[[], Awaitable[None]] AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket] +@dataclass +class RequestSetup: + adapter: str + request: AVAILABLE_REQUEST + poll_interval: float + reconnect_interval: float + + class Driver(ForwardDriver): def __init__(self, env: Env, config: Config): super().__init__(env, config) self.startup_funcs: Set[STARTUP_FUNC] = set() self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() - self.requests: DefaultDict[str, - Set[AVAILABLE_REQUEST]] = DefaultDict(set) + self.requests: List[RequestSetup] = [] @property @overrides(ForwardDriver) @@ -50,10 +59,15 @@ class Driver(ForwardDriver): return func @overrides(ForwardDriver) - def setup(self, adapter: str, request: HTTPConnection) -> None: + def setup(self, + adapter: str, + request: HTTPConnection, + poll_interval: float = 3., + reconnect_interval: float = 3.) -> None: if not isinstance(request, (HTTPRequest, WebSocket)): raise TypeError(f"Request Type {type(request)!r} is not supported!") - self.requests[adapter].add(request) + self.requests.append( + RequestSetup(adapter, request, poll_interval, reconnect_interval)) @overrides(ForwardDriver) def run(self, *args, **kwargs): @@ -73,12 +87,15 @@ class Driver(ForwardDriver): async def startup(self): setups = [] loop = asyncio.get_event_loop() - for adapter, requests in self.requests.items(): - for request in requests: - if isinstance(request, HTTPRequest): - setups.append(self._http_setup(adapter, request)) - else: - setups.append(self._ws_setup(adapter, request)) + for setup in self.requests: + if isinstance(setup.request, HTTPRequest): + setups.append( + self._http_setup(setup.adapter, setup.request, + setup.poll_interval)) + else: + setups.append( + self._ws_setup(setup.adapter, setup.request, + setup.reconnect_interval)) try: await asyncio.gather(*setups) @@ -125,7 +142,8 @@ class Driver(ForwardDriver): loop.stop() - async def _http_setup(self, adapter: str, request: HTTPRequest): + async def _http_setup(self, adapter: str, request: HTTPRequest, + poll_interval: float): BotClass = self._adapters[adapter] self_id, _ = await BotClass.check_permission(self, request) @@ -134,16 +152,84 @@ class Driver(ForwardDriver): bot = BotClass(self_id, request) self._bot_connect(bot) - asyncio.create_task(self._http_loop(bot, request)) + asyncio.create_task(self._http_loop(bot, request, poll_interval)) - async def _ws_setup(self, adapter: str, request: WebSocket): - ... + async def _ws_setup(self, adapter: str, request: WebSocket, + reconnect_interval: float): + BotClass = self._adapters[adapter] + self_id, _ = await BotClass.check_permission(self, request) - async def _http_loop(self, bot: Bot, request: HTTPRequest): - # TODO: main loop for HTTP long polling + if not self_id: + raise SetupFailed("Bot self_id get failed") + + bot = BotClass(self_id, request) + self._bot_connect(bot) + asyncio.create_task(self._ws_loop(bot, request, reconnect_interval)) + + async def _http_loop(self, bot: Bot, request: HTTPRequest, + poll_interval: float): try: - while True: - ... - # include asyncio.CancelledError + headers = request.headers + url = URL.build(scheme=request.scheme, + host=request.headers["host"], + path=request.path, + query_string=request.query_string.decode("latin-1")) + timeout = aiohttp.ClientTimeout(30) + async with aiohttp.ClientSession(headers=headers, + timeout=timeout) as session: + while True: + try: + async with session.request( + request.method, url, + data=request.body) as response: + response.raise_for_status() + data = await response.read() + asyncio.create_task(bot.handle_message(data)) + except aiohttp.ClientResponseError as e: + logger.opt(colors=True, exception=e).error( + f"Error occurred while requesting {url}") + + await asyncio.sleep(poll_interval) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Unexpected exception occurred while http polling") + finally: + self._bot_disconnect(bot) + + async def _ws_loop(self, bot: Bot, request: WebSocket, + reconnect_interval: float): + try: + headers = request.headers + url = URL.build(scheme=request.scheme, + host=request.headers["host"], + path=request.path, + query_string=request.query_string.decode("latin-1")) + timeout = aiohttp.ClientTimeout(30) + async with aiohttp.ClientSession(headers=headers, + timeout=timeout) as session: + while True: + async with session.ws_connect(url) as ws: + async for msg in ws: + if msg.type == aiohttp.WSMsgType.text: + asyncio.create_task( + bot.handle_message(msg.data.encode())) + elif msg.type == aiohttp.WSMsgType.binary: + asyncio.create_task(bot.handle_message( + msg.data)) + elif msg.type == aiohttp.WSMsgType.error: + logger.opt(colors=True).error( + "Error while handling websocket frame. " + "Try to reconnect...") + break + asyncio.sleep(reconnect_interval) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Unexpected exception occurred while websocket loop") finally: self._bot_disconnect(bot)