diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 8bb3951a..eeb44a09 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -80,7 +80,7 @@ class Driver(abc.ABC): """ return self._clients - def register_adapter(self, adapter: Type["Adapter"], **kwargs): + def register_adapter(self, adapter: Type["Adapter"], **kwargs) -> None: """ :说明: @@ -105,7 +105,7 @@ class Driver(abc.ABC): @property @abc.abstractmethod - def type(self): + def type(self) -> str: """驱动类型名称""" raise NotImplementedError @@ -204,10 +204,11 @@ class Driver(abc.ABC): asyncio.create_task(_run_hook(bot)) -class ForwardDriver(Driver): - """ - Forward Driver 基类。将客户端框架封装,以满足适配器使用。 - """ +class ForwardMixin(abc.ABC): + @property + @abc.abstractmethod + def type(self) -> str: + raise NotImplementedError @abc.abstractmethod async def request(self, setup: Request) -> Response: @@ -218,6 +219,12 @@ class ForwardDriver(Driver): raise NotImplementedError +class ForwardDriver(Driver, ForwardMixin): + """ + Forward Driver 基类。将客户端框架封装,以满足适配器使用。 + """ + + class ReverseDriver(Driver): """ Reverse Driver 基类。将后端框架封装,以满足适配器使用。 @@ -244,6 +251,19 @@ class ReverseDriver(Driver): raise NotImplementedError +def combine_driver(driver: Type[Driver], *mixins: Type[ForwardMixin]) -> Type[Driver]: + class CombinedDriver(driver, *mixins): # type: ignore + @property + def type(self) -> str: + return ( + driver.type.__get__(self) + + "+" + + "+".join(map(lambda x: x.type.__get__(self), mixins)) + ) + + return CombinedDriver + + @dataclass class HTTPServerSetup: path: URL # path should not be absolute, check it by URL.is_absolute() == False diff --git a/nonebot/drivers/_block_driver.py b/nonebot/drivers/_block_driver.py new file mode 100644 index 00000000..afd7f8b3 --- /dev/null +++ b/nonebot/drivers/_block_driver.py @@ -0,0 +1,158 @@ +import signal +import asyncio +import threading +from typing import Set, Callable, Awaitable + +from nonebot.log import logger +from nonebot.typing import overrides +from nonebot.config import Env, Config +from nonebot.drivers import ForwardDriver + +STARTUP_FUNC = Callable[[], Awaitable[None]] +SHUTDOWN_FUNC = Callable[[], Awaitable[None]] +HANDLED_SIGNALS = ( + signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. + signal.SIGTERM, # Unix signal 15. Sent by `kill `. +) + + +class BlockDriver(ForwardDriver): + """ + AIOHTTP 驱动框架 + """ + + def __init__(self, env: Env, config: Config): + super().__init__(env, config) + self.startup_funcs: Set[STARTUP_FUNC] = set() + self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() + self.should_exit: asyncio.Event = asyncio.Event() + self.force_exit: bool = False + + @property + @overrides(ForwardDriver) + def type(self) -> str: + """驱动名称: ``block_driver``""" + return "block_driver" + + @property + @overrides(ForwardDriver) + def logger(self): + """block driver 使用的 logger""" + return logger + + @overrides(ForwardDriver) + def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC: + """ + :说明: + + 注册一个启动时执行的函数 + + :参数: + + * ``func: Callable[[], Awaitable[None]]`` + """ + self.startup_funcs.add(func) + return func + + @overrides(ForwardDriver) + def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC: + """ + :说明: + + 注册一个停止时执行的函数 + + :参数: + + * ``func: Callable[[], Awaitable[None]]`` + """ + self.shutdown_funcs.add(func) + return func + + @overrides(ForwardDriver) + def run(self, *args, **kwargs): + """启动 block driver""" + super().run(*args, **kwargs) + loop = asyncio.get_event_loop() + loop.run_until_complete(self.serve()) + + async def serve(self): + self.install_signal_handlers() + await self.startup() + if self.should_exit.is_set(): + return + await self.main_loop() + await self.shutdown() + + async def startup(self): + # run startup + cors = [startup() for startup in self.startup_funcs] + if cors: + try: + await asyncio.gather(*cors) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running startup function. " + "Ignored!" + ) + + logger.info("Application startup completed.") + + async def main_loop(self): + await self.should_exit.wait() + + async def shutdown(self): + logger.info("Shutting down") + + logger.info("Waiting for application shutdown.") + # run shutdown + cors = [shutdown() for shutdown in self.shutdown_funcs] + if cors: + try: + await asyncio.gather(*cors) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running shutdown function. " + "Ignored!" + ) + + for task in asyncio.all_tasks(): + if task is not asyncio.current_task() and not task.done(): + task.cancel() + await asyncio.sleep(0.1) + + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + if tasks and not self.force_exit: + logger.info("Waiting for tasks to finish. (CTRL+C to force quit)") + while tasks and not self.force_exit: + await asyncio.sleep(0.1) + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + + for task in tasks: + task.cancel() + + await asyncio.gather(*tasks, return_exceptions=True) + + logger.info("Application shutdown complete.") + loop = asyncio.get_event_loop() + loop.stop() + + def install_signal_handlers(self) -> None: + if threading.current_thread() is not threading.main_thread(): + # Signals can only be listened to from the main thread. + return + + loop = asyncio.get_event_loop() + + try: + for sig in HANDLED_SIGNALS: + loop.add_signal_handler(sig, self.handle_exit, sig, None) + except NotImplementedError: + # Windows + for sig in HANDLED_SIGNALS: + signal.signal(sig, self.handle_exit) + + def handle_exit(self, sig, frame): + if self.should_exit.is_set(): + self.force_exit = True + else: + self.should_exit.set() diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index fbbf673d..cb7f7bbe 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -5,428 +5,79 @@ AIOHTTP 驱动适配 本驱动仅支持客户端连接 """ -import signal -import asyncio -import threading -from dataclasses import dataclass -from typing import Set, List, Union, Callable, Optional, Awaitable, cast import aiohttp -from yarl import URL -from nonebot.log import logger -from nonebot.adapters import Bot from nonebot.typing import overrides -from nonebot.utils import escape_tag -from nonebot.config import Env, Config +from nonebot.drivers import Request, Response +from nonebot.drivers._block_driver import BlockDriver from nonebot.drivers import WebSocket as BaseWebSocket -from nonebot.drivers import ( - HTTPRequest, - ForwardDriver, - WebSocketSetup, - HTTPPollingSetup, -) - -STARTUP_FUNC = Callable[[], Awaitable[None]] -SHUTDOWN_FUNC = Callable[[], Awaitable[None]] -HTTPPOLLING_SETUP = Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]] -WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] -HANDLED_SIGNALS = ( - signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. - signal.SIGTERM, # Unix signal 15. Sent by `kill `. -) +from nonebot.drivers import HTTPVersion, ForwardMixin, combine_driver -class Driver(ForwardDriver): - """ - AIOHTTP 驱动框架 - """ - - def __init__(self, env: Env, config: Config): - super().__init__(env, config) - self.startup_funcs: Set[STARTUP_FUNC] = set() - self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() - self.http_pollings: List[HTTPPOLLING_SETUP] = [] - self.websockets: List[WEBSOCKET_SETUP] = [] - self.connections: List[asyncio.Task] = [] - self.should_exit: asyncio.Event = asyncio.Event() - self.force_exit: bool = False - +class AiohttpMixin(ForwardMixin): @property - @overrides(ForwardDriver) + @overrides(ForwardMixin) def type(self) -> str: - """驱动名称: ``aiohttp``""" return "aiohttp" - @property - @overrides(ForwardDriver) - def logger(self): - """aiohttp driver 使用的 logger""" - return logger - - @overrides(ForwardDriver) - def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC: - """ - :说明: - - 注册一个启动时执行的函数 - - :参数: - - * ``func: Callable[[], Awaitable[None]]`` - """ - self.startup_funcs.add(func) - return func - - @overrides(ForwardDriver) - def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC: - """ - :说明: - - 注册一个停止时执行的函数 - - :参数: - - * ``func: Callable[[], Awaitable[None]]`` - """ - self.shutdown_funcs.add(func) - return func - - @overrides(ForwardDriver) - def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None: - """ - :说明: - - 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 - - :参数: - - * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`` - """ - self.http_pollings.append(setup) - - @overrides(ForwardDriver) - def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None: - """ - :说明: - - 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 - - :参数: - - * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`` - """ - self.websockets.append(setup) - - @overrides(ForwardDriver) - def run(self, *args, **kwargs): - """启动 aiohttp driver""" - super().run(*args, **kwargs) - loop = asyncio.get_event_loop() - loop.run_until_complete(self.serve()) - - async def serve(self): - self.install_signal_handlers() - await self.startup() - if self.should_exit.is_set(): - return - await self.main_loop() - await self.shutdown() - - async def startup(self): - for setup in self.http_pollings: - self.connections.append(asyncio.create_task(self._http_loop(setup))) - for setup in self.websockets: - self.connections.append(asyncio.create_task(self._ws_loop(setup))) - - logger.info("Application startup completed.") - - # run startup - cors = [startup() for startup in self.startup_funcs] - if cors: - try: - await asyncio.gather(*cors) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running startup function. " - "Ignored!" - ) - - async def main_loop(self): - await self.should_exit.wait() - - async def shutdown(self): - # run shutdown - cors = [shutdown() for shutdown in self.shutdown_funcs] - if cors: - try: - await asyncio.gather(*cors) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running shutdown function. " - "Ignored!" - ) - - for task in self.connections: - if not task.done(): - task.cancel() - await asyncio.sleep(0.1) - - tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] - if tasks and not self.force_exit: - logger.info("Waiting for tasks to finish. (CTRL+C to force quit)") - while tasks and not self.force_exit: - await asyncio.sleep(0.1) - tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] - - for task in tasks: - task.cancel() - - await asyncio.gather(*tasks, return_exceptions=True) - - loop = asyncio.get_event_loop() - loop.stop() - - def install_signal_handlers(self) -> None: - if threading.current_thread() is not threading.main_thread(): - # Signals can only be listened to from the main thread. - return - - loop = asyncio.get_event_loop() - - try: - for sig in HANDLED_SIGNALS: - loop.add_signal_handler(sig, self.handle_exit, sig, None) - except NotImplementedError: - # Windows - for sig in HANDLED_SIGNALS: - signal.signal(sig, self.handle_exit) - - def handle_exit(self, sig, frame): - if self.should_exit.is_set(): - self.force_exit = True + @overrides(ForwardMixin) + async def request(self, setup: Request) -> Response: + if setup.version == HTTPVersion.H10: + version = aiohttp.HttpVersion10 + elif setup.version == HTTPVersion.H11: + version = aiohttp.HttpVersion11 else: - self.should_exit.set() + raise RuntimeError(f"Unsupported HTTP version: {setup.version}") - async def _http_loop(self, setup: HTTPPOLLING_SETUP): - async def _build_request(setup: HTTPPollingSetup) -> Optional[HTTPRequest]: - url = URL(setup.url) - if not url.is_absolute() or not url.host: - logger.opt(colors=True).error( - f"Error parsing url {escape_tag(str(url))}" - ) - return - host = f"{url.host}:{url.port}" if url.port else url.host - return HTTPRequest( - setup.http_version, - url.scheme, - url.path, - url.raw_query_string.encode("latin-1"), - {**setup.headers, "host": host}, + timeout = aiohttp.ClientTimeout(setup.timeout) + async with aiohttp.ClientSession(version=version) as session: + async with session.request( setup.method, - setup.body, - ) + setup.url, + data=setup.content, + headers=setup.headers, + timeout=timeout, + ) as response: + res = Response( + response.status, + headers=response.headers.copy(), + content=await response.read(), + request=setup, + ) + return res - bot: Optional[Bot] = None - request: Optional[HTTPRequest] = None - setup_: Optional[HTTPPollingSetup] = None + @overrides(ForwardMixin) + async def websocket(self, setup: Request) -> "WebSocket": + if setup.version == HTTPVersion.H10: + version = aiohttp.HttpVersion10 + elif setup.version == HTTPVersion.H11: + version = aiohttp.HttpVersion11 + else: + raise RuntimeError(f"Unsupported HTTP version: {setup.version}") - logger.opt(colors=True).info( - f"Start http polling for {escape_tag(setup.adapter.upper())} " - f"Bot {escape_tag(setup.self_id)}" + session = aiohttp.ClientSession(version=version) + ws = await session.ws_connect( + setup.url, + method=setup.method, + timeout=setup.timeout or 10, + headers=setup.headers, ) - - try: - async with aiohttp.ClientSession() as session: - while not self.should_exit.is_set(): - - try: - if callable(setup): - setup_ = await setup() - else: - setup_ = setup - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error while parsing setup " - f"{escape_tag(repr(setup))}." - ) - await asyncio.sleep(3) - continue - - setup_ = cast(HTTPPollingSetup, setup_) - - if not bot: - request = await _build_request(setup_) - if not request: - return - - BotClass = self._adapters[setup.adapter] - bot = BotClass(setup.self_id, request) - self._bot_connect(bot) - elif callable(setup): - request = await _build_request(setup_) - if not request: - await asyncio.sleep(setup_.poll_interval) - continue - bot.request = request - - request = cast(HTTPRequest, request) - - headers = request.headers - timeout = aiohttp.ClientTimeout(30) - version: aiohttp.HttpVersion - if request.http_version == "1.0": - version = aiohttp.HttpVersion10 - elif request.http_version == "1.1": - version = aiohttp.HttpVersion11 - else: - logger.opt(colors=True).error( - "Unsupported HTTP Version " - f"{escape_tag(request.http_version)}" - ) - return - - logger.debug( - f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}" - ) - - try: - async with session.request( - request.method, - setup_.url, - data=request.body, - headers=headers, - timeout=timeout, - version=version, - ) as response: - response.raise_for_status() - data = await response.read() - asyncio.create_task(bot.handle_message(data)) - except aiohttp.ClientResponseError as e: - logger.opt(colors=True, exception=e).error( - f"Error occurred while requesting {escape_tag(setup_.url)}. " - "Try to reconnect..." - ) - - await asyncio.sleep(setup_.poll_interval) - - except asyncio.CancelledError: - pass - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Unexpected exception occurred " - "while http polling" - ) - finally: - if bot: - self._bot_disconnect(bot) - - async def _ws_loop(self, setup: WEBSOCKET_SETUP): - bot: Optional[Bot] = None - - try: - async with aiohttp.ClientSession() as session: - while True: - - try: - if callable(setup): - setup_ = await setup() - else: - setup_ = setup - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error while parsing setup " - f"{escape_tag(repr(setup))}." - ) - await asyncio.sleep(3) - continue - - url = URL(setup_.url) - if not url.is_absolute() or not url.host: - logger.opt(colors=True).error( - f"Error parsing url {escape_tag(str(url))}" - ) - await asyncio.sleep(setup_.reconnect_interval) - continue - - host = f"{url.host}:{url.port}" if url.port else url.host - headers = {**setup_.headers, "host": host} - - logger.debug( - f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}" - ) - try: - async with session.ws_connect( - url, headers=headers, timeout=30.0 - ) as ws: - logger.opt(colors=True).info( - f"WebSocket Connection to {escape_tag(setup_.adapter.upper())} " - f"Bot {escape_tag(setup_.self_id)} succeeded!" - ) - request = WebSocket( - "1.1", - url.scheme, - url.path, - url.raw_query_string.encode("latin-1"), - headers, - ws, - ) - - BotClass = self._adapters[setup_.adapter] - bot = BotClass(setup_.self_id, request) - self._bot_connect(bot) - while not self.should_exit.is_set(): - msg = await ws.receive() - if msg.type == aiohttp.WSMsgType.text: - asyncio.create_task( - bot.handle_message(msg.data.encode()) - ) - elif msg.type == aiohttp.WSMsgType.binary: - asyncio.create_task(bot.handle_message(msg.data)) - elif msg.type == aiohttp.WSMsgType.error: - logger.opt(colors=True).error( - "Error while handling websocket frame. " - "Try to reconnect..." - ) - break - else: - logger.opt(colors=True).error( - "WebSocket connection closed by peer. " - "Try to reconnect..." - ) - break - except ( - aiohttp.ClientResponseError, - aiohttp.ClientConnectionError, - ) as e: - logger.opt(colors=True, exception=e).error( - f"Error while connecting to {escape_tag(str(url))}. " - "Try to reconnect..." - ) - finally: - if bot: - self._bot_disconnect(bot) - bot = None - - if not setup_.reconnect: - logger.info( - f"WebSocket reconnect disabled for bot {setup_.self_id}" - ) - break - await asyncio.sleep(setup_.reconnect_interval) - - except asyncio.CancelledError: - pass - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Unexpected exception occurred " - "while websocket loop" - ) + websocket = WebSocket(request=setup, session=session, websocket=ws) + return websocket -@dataclass class WebSocket(BaseWebSocket): - websocket: aiohttp.ClientWebSocketResponse = None # type: ignore + def __init__( + self, + *, + request: Request, + session: aiohttp.ClientSession, + websocket: aiohttp.ClientWebSocketResponse, + ): + super().__init__(request=request) + self.session = session + self.websocket = websocket @property @overrides(BaseWebSocket) @@ -440,6 +91,7 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def close(self, code: int = 1000): await self.websocket.close(code=code) + await self.session.close() @overrides(BaseWebSocket) async def receive(self) -> str: @@ -456,3 +108,6 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: await self.websocket.send_bytes(data) + + +Driver = combine_driver(BlockDriver, AiohttpMixin) diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 9752eb26..66dd0d0a 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -11,31 +11,28 @@ FastAPI 驱动适配 """ import logging -from functools import partial -from dataclasses import dataclass -from typing import Any, List, Union, Callable, Optional, Awaitable +from typing import List, Callable, Optional -import httpx import uvicorn from pydantic import BaseSettings from fastapi.responses import Response from fastapi import FastAPI, Request, status from starlette.websockets import WebSocket, WebSocketState -from websockets.legacy.client import Connect, WebSocketClientProtocol from nonebot.config import Env from nonebot.typing import overrides from nonebot.utils import escape_tag +from nonebot.drivers.httpx import HttpxMixin +from nonebot.drivers.aiohttp import AiohttpMixin from nonebot.config import Config as NoneBotConfig from nonebot.drivers import Request as BaseRequest -from nonebot.drivers import Response as BaseResponse from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.drivers.websockets import WebSocketsMixin from nonebot.drivers import ( - HTTPVersion, - ForwardDriver, ReverseDriver, HTTPServerSetup, WebSocketServerSetup, + combine_driver, ) @@ -246,7 +243,7 @@ class Driver(ReverseDriver): self, request: Request, setup: HTTPServerSetup, - ): + ) -> Response: http_request = BaseRequest( request.method, str(request.url), @@ -265,7 +262,7 @@ class Driver(ReverseDriver): str(websocket.url), headers=websocket.headers.items(), cookies=websocket.cookies, - version=websocket.scope["http_version"], + version=websocket.scope.get("http_version", "1.1"), ) ws = FastAPIWebSocket( request=request, @@ -275,90 +272,6 @@ class Driver(ReverseDriver): await setup.handle_func(ws) -class FullDriver(ForwardDriver, Driver): - """ - 完整的 FastAPI 驱动框架,包含正向 Client 支持和反向 Server 支持。 - - :使用方法: - - .. code-block:: dotenv - - DRIVER=nonebot.drivers.fastapi:FullDriver - """ - - @property - @overrides(Driver) - def type(self) -> str: - """驱动名称: ``fastapi_full``""" - return "fastapi_full" - - @overrides(ForwardDriver) - async def request(self, setup: "BaseRequest") -> Any: - async with httpx.AsyncClient( - http2=setup.version == HTTPVersion.H2, follow_redirects=True - ) as client: - response = await client.request( - setup.method, - str(setup.url), - content=setup.content, - headers=tuple(setup.headers.items()), - timeout=30.0, - ) - return BaseResponse( - response.status_code, - headers=response.headers, - content=response.content, - request=setup, - ) - - @overrides(ForwardDriver) - async def websocket(self, setup: "BaseRequest") -> Any: - ws = await Connect(str(setup.url), extra_headers=setup.headers.items()) - return WebSocketsWS(request=setup, websocket=ws) - - -class WebSocketsWS(BaseWebSocket): - @overrides(BaseWebSocket) - def __init__(self, *, request: BaseRequest, websocket: WebSocketClientProtocol): - super().__init__(request=request) - self.websocket = websocket - - @property - @overrides(BaseWebSocket) - def closed(self) -> bool: - return self.websocket.closed - - @overrides(BaseWebSocket) - async def accept(self): - raise NotImplementedError - - @overrides(BaseWebSocket) - async def close(self, code: int = 1000, reason: str = ""): - await self.websocket.close(code, reason) - - @overrides(BaseWebSocket) - async def receive(self) -> str: - msg = await self.websocket.recv() - if isinstance(msg, bytes): - raise TypeError("WebSocket received unexpected frame type: bytes") - return msg - - @overrides(BaseWebSocket) - async def receive_bytes(self) -> bytes: - msg = await self.websocket.recv() - if isinstance(msg, str): - raise TypeError("WebSocket received unexpected frame type: str") - return msg - - @overrides(BaseWebSocket) - async def send(self, data: str) -> None: - await self.websocket.send(data) - - @overrides(BaseWebSocket) - async def send_bytes(self, data: bytes) -> None: - await self.websocket.send(data) - - class FastAPIWebSocket(BaseWebSocket): @overrides(BaseWebSocket) def __init__(self, *, request: BaseRequest, websocket: WebSocket): @@ -398,3 +311,7 @@ class FastAPIWebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def send_bytes(self, data: bytes) -> None: await self.websocket.send({"type": "websocket.send", "bytes": data}) + + +FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin) +AiohttpDriver = combine_driver(Driver, AiohttpMixin) diff --git a/nonebot/drivers/httpx.py b/nonebot/drivers/httpx.py new file mode 100644 index 00000000..1965c439 --- /dev/null +++ b/nonebot/drivers/httpx.py @@ -0,0 +1,45 @@ +import httpx + +from nonebot.typing import overrides +from nonebot.drivers._block_driver import BlockDriver +from nonebot.drivers import ( + Request, + Response, + WebSocket, + HTTPVersion, + ForwardMixin, + combine_driver, +) + + +class HttpxMixin(ForwardMixin): + @property + @overrides(ForwardMixin) + def type(self) -> str: + return "httpx" + + @overrides(ForwardMixin) + async def request(self, setup: Request) -> Response: + async with httpx.AsyncClient( + http2=setup.version == HTTPVersion.H2, follow_redirects=True + ) as client: + response = await client.request( + setup.method, + str(setup.url), + content=setup.content, + headers=tuple(setup.headers.items()), + timeout=setup.timeout, + ) + return Response( + response.status_code, + headers=response.headers, + content=response.content, + request=setup, + ) + + @overrides(ForwardMixin) + async def websocket(self, setup: Request) -> WebSocket: + return await super(HttpxMixin, self).websocket(setup) + + +Driver = combine_driver(BlockDriver, HttpxMixin) diff --git a/nonebot/drivers/quart.py b/nonebot/drivers/quart.py index 9d32463f..bec26c9f 100644 --- a/nonebot/drivers/quart.py +++ b/nonebot/drivers/quart.py @@ -18,10 +18,17 @@ from nonebot.config import Env from nonebot.log import logger from nonebot.typing import overrides from nonebot.utils import escape_tag +from nonebot.drivers.httpx import HttpxMixin from nonebot.config import Config as NoneBotConfig from nonebot.drivers import Request as BaseRequest from nonebot.drivers import WebSocket as BaseWebSocket -from nonebot.drivers import ReverseDriver, HTTPServerSetup, WebSocketServerSetup +from nonebot.drivers.websockets import WebSocketsMixin +from nonebot.drivers import ( + ReverseDriver, + HTTPServerSetup, + WebSocketServerSetup, + combine_driver, +) try: from quart import request as _request @@ -281,3 +288,6 @@ class WebSocket(BaseWebSocket): @overrides(BaseWebSocket) async def send_bytes(self, data: bytes): await self.websocket.send(data) + + +FullDriver = combine_driver(Driver, HttpxMixin, WebSocketsMixin) diff --git a/nonebot/drivers/websockets.py b/nonebot/drivers/websockets.py new file mode 100644 index 00000000..7dbd6399 --- /dev/null +++ b/nonebot/drivers/websockets.py @@ -0,0 +1,78 @@ +import logging + +from websockets.legacy.client import Connect, WebSocketClientProtocol + +from nonebot.typing import overrides +from nonebot.log import LoguruHandler +from nonebot.drivers import Request, Response +from nonebot.drivers._block_driver import BlockDriver +from nonebot.drivers import WebSocket as BaseWebSocket +from nonebot.drivers import ForwardMixin, combine_driver + +logger = logging.Logger("websockets.client", "INFO") +logger.addHandler(LoguruHandler()) + + +class WebSocketsMixin(ForwardMixin): + @property + @overrides(ForwardMixin) + def type(self) -> str: + return "websockets" + + @overrides(ForwardMixin) + async def request(self, setup: Request) -> Response: + return await super(WebSocketsMixin, self).request(setup) + + @overrides(ForwardMixin) + async def websocket(self, setup: Request) -> "WebSocket": + ws = await Connect( + str(setup.url), + extra_headers=setup.headers.items(), + open_timeout=setup.timeout, + ) + return WebSocket(request=setup, websocket=ws) + + +class WebSocket(BaseWebSocket): + @overrides(BaseWebSocket) + def __init__(self, *, request: Request, websocket: WebSocketClientProtocol): + super().__init__(request=request) + self.websocket = websocket + + @property + @overrides(BaseWebSocket) + def closed(self) -> bool: + return self.websocket.closed + + @overrides(BaseWebSocket) + async def accept(self): + raise NotImplementedError + + @overrides(BaseWebSocket) + async def close(self, code: int = 1000, reason: str = ""): + await self.websocket.close(code, reason) + + @overrides(BaseWebSocket) + async def receive(self) -> str: + msg = await self.websocket.recv() + if isinstance(msg, bytes): + raise TypeError("WebSocket received unexpected frame type: bytes") + return msg + + @overrides(BaseWebSocket) + async def receive_bytes(self) -> bytes: + msg = await self.websocket.recv() + if isinstance(msg, str): + raise TypeError("WebSocket received unexpected frame type: str") + return msg + + @overrides(BaseWebSocket) + async def send(self, data: str) -> None: + await self.websocket.send(data) + + @overrides(BaseWebSocket) + async def send_bytes(self, data: bytes) -> None: + await self.websocket.send(data) + + +Driver = combine_driver(BlockDriver, WebSocketsMixin) diff --git a/packages/nonebot-plugin-docs/pyproject.toml b/packages/nonebot-plugin-docs/pyproject.toml index 8ee81dc6..48349e9e 100644 --- a/packages/nonebot-plugin-docs/pyproject.toml +++ b/packages/nonebot-plugin-docs/pyproject.toml @@ -12,9 +12,8 @@ include = ["nonebot_plugin_docs/dist/**/*"] [tool.poetry.dependencies] -python = "^3.7" -aiofiles = "^0.7.0" -nonebot2 = "^2.0.0-alpha.1" +python = "^3.7.3" +nonebot2 = "^2.0.0-beta.1" [tool.poetry.dev-dependencies] diff --git a/poetry.lock b/poetry.lock index 32507fb9..588b3567 100644 --- a/poetry.lock +++ b/poetry.lock @@ -93,6 +93,18 @@ typing-extensions = {version = "*", markers = "python_version < \"3.8\""} [package.extras] tests = ["pytest", "pytest-asyncio", "mypy (>=0.800)"] +[[package]] +name = "async-asgi-testclient" +version = "1.4.8" +description = "Async client for testing ASGI web applications" +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +multidict = ">=4.0,<6.0" +requests = ">=2.21,<3.0" + [[package]] name = "async-timeout" version = "4.0.2" @@ -534,6 +546,7 @@ python-versions = "^3.7.3" develop = false [package.dependencies] +async-asgi-testclient = "^1.4.8" nonebot2 = "^2.0.0-beta.1" pytest = "^6.2.5" pytest-asyncio = "^0.16.0" @@ -543,7 +556,7 @@ pytest-order = "^1.0.0" type = "git" url = "https://github.com/nonebot/nonebug.git" reference = "master" -resolved_reference = "0a1132e9dc1803517ded0d485bfbe8c47a1d8585" +resolved_reference = "4af5bd99c3eb0f63f4619620461b16de6c96b227" [[package]] name = "packaging" @@ -1281,6 +1294,9 @@ asgiref = [ {file = "asgiref-3.4.1-py3-none-any.whl", hash = "sha256:ffc141aa908e6f175673e7b1b3b7af4fdb0ecb738fc5c8b88f69f055c2415214"}, {file = "asgiref-3.4.1.tar.gz", hash = "sha256:4ef1ab46b484e3c706329cedeff284a5d40824200638503f5768edb6de7d58e9"}, ] +async-asgi-testclient = [ + {file = "async-asgi-testclient-1.4.8.tar.gz", hash = "sha256:52d666ea75971c8a825befd34a5684414578f3c5bfa5a90e7eb7de924f447aae"}, +] async-timeout = [ {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, diff --git a/pyproject.toml b/pyproject.toml index 5dfd8cbe..db906c0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ include = ["nonebot/py.typed"] [tool.poetry.dependencies] python = "^3.7.3" +yarl = "^1.7.2" loguru = "^0.5.1" pygtrie = "^2.4.1" tomlkit = "^0.7.0" @@ -34,7 +35,6 @@ httpx = { version = ">=0.20.0, <1.0.0", extras = ["http2"] } pydantic = { version = "~1.8.0", extras = ["dotenv"] } uvicorn = { version = "^0.15.0", extras = ["standard"] } aiohttp = { version = "^3.7.4", extras = ["speedups"], optional = true } -yarl = "^1.7.2" [tool.poetry.dev-dependencies] sphinx = "^4.1.1" diff --git a/tests/test_driver.py b/tests/test_driver.py new file mode 100644 index 00000000..5e5d8a28 --- /dev/null +++ b/tests/test_driver.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "nonebug_init", + [{"driver": "nonebot.drivers.fastapi"}], + indirect=True, +) +async def test_driver(nonebug_init): + ...