From 7b204d72e61cd2bfa75b195df3e1837391e4bfa5 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Sun, 26 Dec 2021 13:42:13 +0800 Subject: [PATCH] :wheelchair: change websocket client to context manager --- nonebot/drivers/__init__.py | 16 ++++++++++++++-- nonebot/drivers/aiohttp.py | 14 +++++++++----- nonebot/drivers/httpx.py | 9 +++++++-- nonebot/drivers/websockets.py | 10 +++++++--- 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 079534ce..7c7bf3a9 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -8,7 +8,17 @@ import abc import asyncio from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable +from contextlib import asynccontextmanager +from typing import ( + TYPE_CHECKING, + Any, + Set, + Dict, + Type, + Callable, + Awaitable, + AsyncGenerator, +) from ._model import URL as URL from nonebot.log import logger @@ -215,8 +225,10 @@ class ForwardMixin(abc.ABC): raise NotImplementedError @abc.abstractmethod - async def websocket(self, setup: Request) -> WebSocket: + @asynccontextmanager + async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]: raise NotImplementedError + yield # used for static type checking's generator detection class ForwardDriver(Driver, ForwardMixin): diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 264e78e8..080744cf 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -5,6 +5,9 @@ AIOHTTP 驱动适配 本驱动仅支持客户端连接 """ +from typing import AsyncGenerator +from contextlib import asynccontextmanager + from nonebot.typing import overrides from nonebot.drivers import Request, Response from nonebot.drivers._block_driver import BlockDriver @@ -59,7 +62,8 @@ class Mixin(ForwardMixin): return res @overrides(ForwardMixin) - async def websocket(self, setup: Request) -> "WebSocket": + @asynccontextmanager + async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: if setup.version == HTTPVersion.H10: version = aiohttp.HttpVersion10 elif setup.version == HTTPVersion.H11: @@ -68,15 +72,15 @@ class Mixin(ForwardMixin): raise RuntimeError(f"Unsupported HTTP version: {setup.version}") session = aiohttp.ClientSession(version=version, trust_env=True) - ws = await session.ws_connect( + async with session.ws_connect( setup.url, method=setup.method, timeout=setup.timeout or 10, headers=setup.headers, proxy=setup.proxy, - ) - websocket = WebSocket(request=setup, session=session, websocket=ws) - return websocket + ) as ws: + websocket = WebSocket(request=setup, session=session, websocket=ws) + yield websocket class WebSocket(BaseWebSocket): diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py index 742b6b6c..fe82e50d 100644 --- a/nonebot/drivers/httpx.py +++ b/nonebot/drivers/httpx.py @@ -1,3 +1,6 @@ +from typing import AsyncGenerator +from contextlib import asynccontextmanager + from nonebot.typing import overrides from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers import ( @@ -48,8 +51,10 @@ class Mixin(ForwardMixin): ) @overrides(ForwardMixin) - async def websocket(self, setup: Request) -> WebSocket: - return await super(Mixin, self).websocket(setup) + @asynccontextmanager + async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]: + async with super(Mixin, self).websocket(setup) as ws: + yield ws Driver = combine_driver(BlockDriver, Mixin) diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py index 81903698..0cc4827b 100644 --- a/nonebot/drivers/websockets.py +++ b/nonebot/drivers/websockets.py @@ -1,4 +1,6 @@ import logging +from typing import AsyncGenerator +from contextlib import asynccontextmanager from nonebot.typing import overrides from nonebot.log import LoguruHandler @@ -29,13 +31,15 @@ class Mixin(ForwardMixin): return await super(Mixin, self).request(setup) @overrides(ForwardMixin) - async def websocket(self, setup: Request) -> "WebSocket": - ws = await Connect( + @asynccontextmanager + async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]: + connection = Connect( str(setup.url), extra_headers=setup.headers.items(), open_timeout=setup.timeout, ) - return WebSocket(request=setup, websocket=ws) + async with connection as ws: + yield WebSocket(request=setup, websocket=ws) class WebSocket(BaseWebSocket):