change websocket client to context manager

This commit is contained in:
yanyongyu 2021-12-26 13:42:13 +08:00
parent 00c2ee8490
commit 7b204d72e6
4 changed files with 37 additions and 12 deletions

View File

@ -8,7 +8,17 @@
import abc
import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, Awaitable
from contextlib import asynccontextmanager
from typing import (
TYPE_CHECKING,
Any,
Set,
Dict,
Type,
Callable,
Awaitable,
AsyncGenerator,
)
from ._model import URL as URL
from nonebot.log import logger
@ -215,8 +225,10 @@ class ForwardMixin(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
async def websocket(self, setup: Request) -> WebSocket:
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
raise NotImplementedError
yield # used for static type checking's generator detection
class ForwardDriver(Driver, ForwardMixin):

View File

@ -5,6 +5,9 @@ AIOHTTP 驱动适配
本驱动仅支持客户端连接
"""
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.drivers import Request, Response
from nonebot.drivers._block_driver import BlockDriver
@ -59,7 +62,8 @@ class Mixin(ForwardMixin):
return res
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket":
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
if setup.version == HTTPVersion.H10:
version = aiohttp.HttpVersion10
elif setup.version == HTTPVersion.H11:
@ -68,15 +72,15 @@ class Mixin(ForwardMixin):
raise RuntimeError(f"Unsupported HTTP version: {setup.version}")
session = aiohttp.ClientSession(version=version, trust_env=True)
ws = await session.ws_connect(
async with session.ws_connect(
setup.url,
method=setup.method,
timeout=setup.timeout or 10,
headers=setup.headers,
proxy=setup.proxy,
)
websocket = WebSocket(request=setup, session=session, websocket=ws)
return websocket
) as ws:
websocket = WebSocket(request=setup, session=session, websocket=ws)
yield websocket
class WebSocket(BaseWebSocket):

View File

@ -1,3 +1,6 @@
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.drivers._block_driver import BlockDriver
from nonebot.drivers import (
@ -48,8 +51,10 @@ class Mixin(ForwardMixin):
)
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> WebSocket:
return await super(Mixin, self).websocket(setup)
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator[WebSocket, None]:
async with super(Mixin, self).websocket(setup) as ws:
yield ws
Driver = combine_driver(BlockDriver, Mixin)

View File

@ -1,4 +1,6 @@
import logging
from typing import AsyncGenerator
from contextlib import asynccontextmanager
from nonebot.typing import overrides
from nonebot.log import LoguruHandler
@ -29,13 +31,15 @@ class Mixin(ForwardMixin):
return await super(Mixin, self).request(setup)
@overrides(ForwardMixin)
async def websocket(self, setup: Request) -> "WebSocket":
ws = await Connect(
@asynccontextmanager
async def websocket(self, setup: Request) -> AsyncGenerator["WebSocket", None]:
connection = Connect(
str(setup.url),
extra_headers=setup.headers.items(),
open_timeout=setup.timeout,
)
return WebSocket(request=setup, websocket=ws)
async with connection as ws:
yield WebSocket(request=setup, websocket=ws)
class WebSocket(BaseWebSocket):