From cda1ad093f40832c9e835ea18a63b226163b9fd7 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Sat, 31 Jul 2021 12:24:11 +0800 Subject: [PATCH] :boom: change forward setup api --- docs/.vuepress/config.js | 4 + docs/api/README.md | 3 + docs/api/drivers/README.md | 114 ++++++++ docs/api/drivers/aiohttp.md | 101 +++++++ docs/api/drivers/fastapi.md | 46 ++- docs_build/README.rst | 1 + docs_build/drivers/aiohttp.rst | 12 + nonebot/drivers/__init__.py | 81 ++++-- nonebot/drivers/aiohttp.py | 274 ++++++++++-------- nonebot/drivers/fastapi.py | 200 +++++++------ .../nonebot/adapters/cqhttp/bot.py | 8 +- 11 files changed, 599 insertions(+), 245 deletions(-) create mode 100644 docs/api/drivers/aiohttp.md create mode 100644 docs_build/drivers/aiohttp.rst diff --git a/docs/.vuepress/config.js b/docs/.vuepress/config.js index 69efc7ba..4aef42b7 100644 --- a/docs/.vuepress/config.js +++ b/docs/.vuepress/config.js @@ -219,6 +219,10 @@ module.exports = (context) => ({ title: "nonebot.drivers.quart 模块", path: "drivers/quart", }, + { + title: "nonebot.drivers.aiohttp 模块", + path: "drivers/aiohttp", + }, { title: "nonebot.adapters 模块", path: "adapters/", diff --git a/docs/api/README.md b/docs/api/README.md index 26d4a0f5..38fac915 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -49,6 +49,9 @@ * [nonebot.drivers.quart](drivers/quart.html) + * [nonebot.drivers.aiohttp](drivers/aiohttp.html) + + * [nonebot.adapters](adapters/) diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md index 57c4d135..9f8ee3ee 100644 --- a/docs/api/drivers/README.md +++ b/docs/api/drivers/README.md @@ -238,6 +238,45 @@ Driver 基类。 在 WebSocket 连接断开后,调用该函数来注销 bot 对象 +## _class_ `ForwardDriver` + +基类:`nonebot.drivers.Driver` + +Forward Driver 基类。将客户端框架封装,以满足适配器使用。 + + +### _abstract_ `setup_http_polling(setup)` + + +* **说明** + + 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 + + + +* **参数** + + + * `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]` + + + +### _abstract_ `setup_websocket(setup)` + + +* **说明** + + 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 + + + +* **参数** + + + * `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]` + + + ## _class_ `ReverseDriver` 基类:`nonebot.drivers.Driver` @@ -413,3 +452,78 @@ Always `websocket` ### _abstract async_ `send_bytes(data)` 发送一条 WebSocket binary 信息 + + +## _class_ `HTTPPollingSetup` + +基类:`object` + + +### `adapter` + +协议适配器名称 + + +### `self_id` + +机器人 ID + + +### `url` + +URL + + +### `method` + +HTTP method + + +### `body` + +HTTP body + + +### `headers` + +HTTP headers + + +### `http_version` + +HTTP version + + +### `poll_interval` + +HTTP 轮询间隔 + + +## _class_ `WebSocketSetup` + +基类:`object` + + +### `adapter` + +协议适配器名称 + + +### `self_id` + +机器人 ID + + +### `url` + +URL + + +### `headers` + +HTTP headers + + +### `reconnect_interval` + +WebSocket 重连间隔 diff --git a/docs/api/drivers/aiohttp.md b/docs/api/drivers/aiohttp.md new file mode 100644 index 00000000..4159d44e --- /dev/null +++ b/docs/api/drivers/aiohttp.md @@ -0,0 +1,101 @@ +--- +contentSidebar: true +sidebarDepth: 0 +--- + +# NoneBot.drivers.aiohttp 模块 + +## AIOHTTP 驱动适配 + +本驱动仅支持客户端连接 + + +## _class_ `Driver` + +基类:[`nonebot.drivers.ForwardDriver`](README.md#nonebot.drivers.ForwardDriver) + +AIOHTTP 驱动框架 + + +### _property_ `type` + +驱动名称: `aiohttp` + + +### _property_ `logger` + +aiohttp driver 使用的 logger + + +### `on_startup(func)` + + +* **说明** + + 注册一个启动时执行的函数 + + + +* **参数** + + + * `func: Callable[[], Awaitable[None]]` + + + +### `on_shutdown(func)` + + +* **说明** + + 注册一个停止时执行的函数 + + + +* **参数** + + + * `func: Callable[[], Awaitable[None]]` + + + +### `setup_http_polling(setup)` + + +* **说明** + + 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 + + + +* **参数** + + + * `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]` + + + +### `setup_websocket(setup)` + + +* **说明** + + 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 + + + +* **参数** + + + * `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]` + + + +### `run(*args, **kwargs)` + +启动 aiohttp driver + + +## _class_ `WebSocket` + +基类:[`nonebot.drivers.WebSocket`](README.md#nonebot.drivers.WebSocket) diff --git a/docs/api/drivers/fastapi.md b/docs/api/drivers/fastapi.md index 3a79f42c..2c02c8d6 100644 --- a/docs/api/drivers/fastapi.md +++ b/docs/api/drivers/fastapi.md @@ -7,19 +7,11 @@ sidebarDepth: 0 ## FastAPI 驱动适配 +本驱动同时支持服务端以及客户端连接 + 后端使用方法请参考: [FastAPI 文档](https://fastapi.tiangolo.com/) -## _class_ `HTTPPollingSetup` - -基类:`object` - - -## _class_ `WebSocketSetup` - -基类:`object` - - ## _class_ `Config` 基类:`pydantic.env_settings.BaseSettings` @@ -89,7 +81,7 @@ FastAPI 驱动框架设置,详情参考 FastAPI 文档 ## _class_ `Driver` -基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver), `nonebot.drivers.ForwardDriver` +基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver), [`nonebot.drivers.ForwardDriver`](README.md#nonebot.drivers.ForwardDriver) FastAPI 驱动框架 @@ -140,6 +132,38 @@ fastapi 使用的 logger 参考文档: [Events](https://fastapi.tiangolo.com/advanced/events/#startup-event) +### `setup_http_polling(setup)` + + +* **说明** + + 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 + + + +* **参数** + + + * `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]` + + + +### `setup_websocket(setup)` + + +* **说明** + + 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 + + + +* **参数** + + + * `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]` + + + ### `run(host=None, port=None, *, app=None, **kwargs)` 使用 `uvicorn` 启动 FastAPI diff --git a/docs_build/README.rst b/docs_build/README.rst index 35d9368c..5d5fc86d 100644 --- a/docs_build/README.rst +++ b/docs_build/README.rst @@ -17,6 +17,7 @@ NoneBot Api Reference - `nonebot.drivers `_ - `nonebot.drivers.fastapi `_ - `nonebot.drivers.quart `_ + - `nonebot.drivers.aiohttp `_ - `nonebot.adapters `_ - `nonebot.adapters.cqhttp `_ - `nonebot.adapters.ding `_ diff --git a/docs_build/drivers/aiohttp.rst b/docs_build/drivers/aiohttp.rst new file mode 100644 index 00000000..077da6c5 --- /dev/null +++ b/docs_build/drivers/aiohttp.rst @@ -0,0 +1,12 @@ +\-\-\- +contentSidebar: true +sidebarDepth: 0 +\-\-\- + +NoneBot.drivers.aiohttp 模块 +============================= + +.. automodule:: nonebot.drivers.aiohttp + :members: + :private-members: + :show-inheritance: diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index d7a7a8df..7607e4b3 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -8,7 +8,7 @@ import abc import asyncio from dataclasses import dataclass, field -from typing import Any, Set, Dict, Type, Optional, Callable, TYPE_CHECKING +from typing import Any, Set, Dict, Type, Union, Optional, Callable, Awaitable, TYPE_CHECKING from nonebot.log import logger from nonebot.config import Env, Config @@ -193,27 +193,40 @@ class Driver(abc.ABC): class ForwardDriver(Driver): + """ + Forward Driver 基类。将客户端框架封装,以满足适配器使用。 + """ @abc.abstractmethod - def setup_http_polling(self, - adapter: str, - self_id: str, - url: str, - polling_interval: float = 3., - method: str = "GET", - body: bytes = b"", - headers: Dict[str, str] = {}, - http_version: str = "1.1") -> None: + def setup_http_polling( + self, setup: Union["HTTPPollingSetup", + Callable[[], Awaitable["HTTPPollingSetup"]]] + ) -> None: + """ + :说明: + + 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 + + :参数: + + * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`` + """ raise NotImplementedError @abc.abstractmethod - def setup_websocket(self, - adapter: str, - self_id: str, - url: str, - reconnect_interval: float = 3., - headers: Dict[str, str] = {}, - http_version: str = "1.1") -> None: + def setup_websocket( + self, setup: Union["WebSocketSetup", + Callable[[], Awaitable["WebSocketSetup"]]] + ) -> None: + """ + :说明: + + 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 + + :参数: + + * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`` + """ raise NotImplementedError @@ -369,3 +382,37 @@ class WebSocket(HTTPConnection, abc.ABC): async def send_bytes(self, data: bytes): """发送一条 WebSocket binary 信息""" raise NotImplementedError + + +@dataclass +class HTTPPollingSetup: + adapter: str + """协议适配器名称""" + self_id: str + """机器人 ID""" + url: str + """URL""" + method: str + """HTTP method""" + body: bytes + """HTTP body""" + headers: Dict[str, str] + """HTTP headers""" + http_version: str + """HTTP version""" + poll_interval: float + """HTTP 轮询间隔""" + + +@dataclass +class WebSocketSetup: + adapter: str + """协议适配器名称""" + self_id: str + """机器人 ID""" + url: str + """URL""" + headers: Dict[str, str] = field(default_factory=dict) + """HTTP headers""" + reconnect_interval: float = 3. + """WebSocket 重连间隔""" diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 7977acfb..10d464c0 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -1,11 +1,15 @@ """ +AIOHTTP 驱动适配 +================ + +本驱动仅支持客户端连接 """ import signal import asyncio import threading from dataclasses import dataclass -from typing import Set, List, Dict, Optional, Callable, Awaitable +from typing import Set, List, cast, Union, Optional, Callable, Awaitable import aiohttp from yarl import URL @@ -14,46 +18,31 @@ from nonebot.log import logger from nonebot.adapters import Bot from nonebot.typing import overrides from nonebot.config import Env, Config -from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket +from nonebot.drivers import (ForwardDriver, HTTPPollingSetup, WebSocketSetup, + HTTPRequest, WebSocket as BaseWebSocket) STARTUP_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]] +HTTPPOLLING_SETUP = Union[HTTPPollingSetup, + Callable[[], Awaitable[HTTPPollingSetup]]] +WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGTERM, # Unix signal 15. Sent by `kill `. ) -@dataclass -class HTTPPollingSetup: - adapter: str - self_id: str - url: str - method: str - body: bytes - headers: Dict[str, str] - http_version: str - poll_interval: float - - -@dataclass -class WebSocketSetup: - adapter: str - self_id: str - url: str - headers: Dict[str, str] - http_version: str - reconnect_interval: float - - class Driver(ForwardDriver): + """ + AIOHTTP 驱动框架 + """ def __init__(self, env: Env, config: Config): super().__init__(env, config) self.startup_funcs: Set[STARTUP_FUNC] = set() self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() - self.http_pollings: List[HTTPPollingSetup] = [] - self.websockets: List[WebSocketSetup] = [] + self.http_pollings: List[HTTPPOLLING_SETUP] = [] + self.websockets: List[WEBSOCKET_SETUP] = [] self.connections: List[asyncio.Task] = [] self.should_exit: asyncio.Event = asyncio.Event() self.force_exit: bool = False @@ -67,46 +56,66 @@ class Driver(ForwardDriver): @property @overrides(ForwardDriver) def logger(self): + """aiohttp driver 使用的 logger""" return logger @overrides(ForwardDriver) - def on_startup(self, func: Callable) -> Callable: + def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC: + """ + :说明: + + 注册一个启动时执行的函数 + + :参数: + + * ``func: Callable[[], Awaitable[None]]`` + """ self.startup_funcs.add(func) return func @overrides(ForwardDriver) - def on_shutdown(self, func: Callable) -> Callable: + def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC: + """ + :说明: + + 注册一个停止时执行的函数 + + :参数: + + * ``func: Callable[[], Awaitable[None]]`` + """ self.shutdown_funcs.add(func) return func @overrides(ForwardDriver) - def setup_http_polling(self, - adapter: str, - self_id: str, - url: str, - polling_interval: float = 3., - method: str = "GET", - body: bytes = b"", - headers: Dict[str, str] = {}, - http_version: str = "1.1") -> None: - self.http_pollings.append( - HTTPPollingSetup(adapter, self_id, url, method, body, headers, - http_version, polling_interval)) + def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None: + """ + :说明: + + 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 + + :参数: + + * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`` + """ + self.http_pollings.append(setup) @overrides(ForwardDriver) - def setup_websocket(self, - adapter: str, - self_id: str, - url: str, - reconnect_interval: float = 3., - headers: Dict[str, str] = {}, - http_version: str = "1.1") -> None: - self.websockets.append( - WebSocketSetup(adapter, self_id, url, headers, http_version, - reconnect_interval)) + def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None: + """ + :说明: + + 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 + + :参数: + + * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`` + """ + self.websockets.append(setup) @overrides(ForwardDriver) def run(self, *args, **kwargs): + """启动 aiohttp driver""" super().run(*args, **kwargs) loop = asyncio.get_event_loop() loop.run_until_complete(self.serve()) @@ -197,59 +206,88 @@ class Driver(ForwardDriver): else: self.should_exit.set() - async def _http_loop(self, setup: HTTPPollingSetup): - url = URL(setup.url) - if not url.is_absolute() or not url.host: - logger.opt(colors=True).error( - f"Error parsing url {url}") - return - host = f"{url.host}:{url.port}" if url.port else url.host - request = HTTPRequest(setup.http_version, url.scheme, url.path, - url.raw_query_string.encode("latin-1"), { - **setup.headers, "host": host - }, setup.method, setup.body) + async def _http_loop(self, setup: HTTPPOLLING_SETUP): + + async def _build_request( + setup: HTTPPollingSetup) -> Optional[HTTPRequest]: + url = URL(setup.url) + if not url.is_absolute() or not url.host: + logger.opt(colors=True).error( + f"Error parsing url {url}") + return + host = f"{url.host}:{url.port}" if url.port else url.host + return HTTPRequest(setup.http_version, url.scheme, url.path, + url.raw_query_string.encode("latin-1"), { + **setup.headers, "host": host + }, setup.method, setup.body) + + bot: Optional[Bot] = None + request: Optional[HTTPRequest] = None + setup_: Optional[HTTPPollingSetup] = None - BotClass = self._adapters[setup.adapter] - bot = BotClass(setup.self_id, request) - self._bot_connect(bot) logger.opt(colors=True).info( f"Start http polling for {setup.adapter.upper()} " f"Bot {setup.self_id}") - headers = request.headers - timeout = aiohttp.ClientTimeout(30) - version: aiohttp.HttpVersion - if request.http_version == "1.0": - version = aiohttp.HttpVersion10 - elif request.http_version == "1.1": - version = aiohttp.HttpVersion11 - else: - logger.opt(colors=True).error( - "Unsupported HTTP Version " - f"{request.http_version}") - return - try: - async with aiohttp.ClientSession(headers=headers, - timeout=timeout, - version=version) as session: + async with aiohttp.ClientSession() as session: while not self.should_exit.is_set(): + if not bot: + if callable(setup): + setup_ = await setup() + else: + setup_ = setup + request = await _build_request(setup_) + if not request: + return + + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + elif callable(setup): + setup_ = await setup() + request = await _build_request(setup_) + if not request: + await asyncio.sleep(setup_.poll_interval) + continue + bot.request = request + + request = cast(HTTPRequest, request) + setup_ = cast(HTTPPollingSetup, setup_) + + headers = request.headers + timeout = aiohttp.ClientTimeout(30) + version: aiohttp.HttpVersion + if request.http_version == "1.0": + version = aiohttp.HttpVersion10 + elif request.http_version == "1.1": + version = aiohttp.HttpVersion11 + else: + logger.opt(colors=True).error( + "Unsupported HTTP Version " + f"{request.http_version}") + return + logger.debug( - f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" + f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}" ) + try: - async with session.request( - request.method, url, - data=request.body) as response: + async with session.request(request.method, + setup_.url, + data=request.body, + headers=headers, + timeout=timeout, + version=version) as response: response.raise_for_status() data = await response.read() asyncio.create_task(bot.handle_message(data)) except aiohttp.ClientResponseError as e: logger.opt(colors=True, exception=e).error( - f"Error occurred while requesting {url}. " + f"Error occurred while requesting {setup_.url}. " "Try to reconnect...") - await asyncio.sleep(setup.poll_interval) + await asyncio.sleep(setup_.poll_interval) except asyncio.CancelledError: pass @@ -258,50 +296,48 @@ class Driver(ForwardDriver): "Unexpected exception occurred " "while http polling") finally: - self._bot_disconnect(bot) - - async def _ws_loop(self, setup: WebSocketSetup): - url = URL(setup.url) - if not url.is_absolute() or not url.host: - logger.opt(colors=True).error( - f"Error parsing url {url}") - return - host = f"{url.host}:{url.port}" if url.port else url.host - - headers = {**setup.headers, "host": host} - timeout = aiohttp.ClientTimeout(30) - version: aiohttp.HttpVersion - if setup.http_version == "1.0": - version = aiohttp.HttpVersion10 - elif setup.http_version == "1.1": - version = aiohttp.HttpVersion11 - else: - logger.opt(colors=True).error( - "Unsupported HTTP Version " - f"{setup.http_version}") - return + if bot: + self._bot_disconnect(bot) + async def _ws_loop(self, setup: WEBSOCKET_SETUP): bot: Optional[Bot] = None + try: - async with aiohttp.ClientSession(headers=headers, - timeout=timeout, - version=version) as session: + async with aiohttp.ClientSession() as session: while True: + if callable(setup): + setup_ = await setup() + else: + setup_ = setup + + url = URL(setup_.url) + if not url.is_absolute() or not url.host: + logger.opt(colors=True).error( + f"Error parsing url {url}" + ) + await asyncio.sleep(setup_.reconnect_interval) + continue + + host = f"{url.host}:{url.port}" if url.port else url.host + headers = {**setup_.headers, "host": host} + logger.debug( - f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" + f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}" ) try: - async with session.ws_connect(url) as ws: + async with session.ws_connect(url, + headers=headers, + timeout=30.) as ws: logger.opt(colors=True).info( - f"WebSocket Connection to {setup.adapter.upper()} " - f"Bot {setup.self_id} succeeded!") + f"WebSocket Connection to {setup_.adapter.upper()} " + f"Bot {setup_.self_id} succeeded!") request = WebSocket( - setup.http_version, url.scheme, url.path, + "1.1", url.scheme, url.path, url.raw_query_string.encode("latin-1"), headers, ws) - BotClass = self._adapters[setup.adapter] - bot = BotClass(setup.self_id, request) + BotClass = self._adapters[setup_.adapter] + bot = BotClass(setup_.self_id, request) self._bot_connect(bot) while not self.should_exit.is_set(): msg = await ws.receive() @@ -330,7 +366,7 @@ class Driver(ForwardDriver): if bot: self._bot_disconnect(bot) bot = None - await asyncio.sleep(setup.reconnect_interval) + await asyncio.sleep(setup_.reconnect_interval) except asyncio.CancelledError: pass diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index caf9e2e9..b263ba7a 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -2,6 +2,8 @@ FastAPI 驱动适配 ================ +本驱动同时支持服务端以及客户端连接 + 后端使用方法请参考: `FastAPI 文档`_ .. _FastAPI 文档: @@ -11,7 +13,7 @@ FastAPI 驱动适配 import asyncio import logging from dataclasses import dataclass -from typing import List, Dict, Union, Optional, Callable +from typing import List, cast, Union, Optional, Callable, Awaitable import httpx import uvicorn @@ -27,30 +29,13 @@ from nonebot.log import logger from nonebot.adapters import Bot from nonebot.typing import overrides from nonebot.config import Env, Config as NoneBotConfig -from nonebot.drivers import ReverseDriver, ForwardDriver -from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket +from nonebot.drivers import (ReverseDriver, ForwardDriver, HTTPPollingSetup, + WebSocketSetup, HTTPRequest, WebSocket as + BaseWebSocket) - -@dataclass -class HTTPPollingSetup: - adapter: str - self_id: str - url: str - method: str - body: bytes - headers: Dict[str, str] - http_version: str - poll_interval: float - - -@dataclass -class WebSocketSetup: - adapter: str - self_id: str - url: str - headers: Dict[str, str] - http_version: str - reconnect_interval: float +HTTPPOLLING_SETUP = Union[HTTPPollingSetup, + Callable[[], Awaitable[HTTPPollingSetup]]] +WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]] class Config(BaseSettings): @@ -118,8 +103,8 @@ class Driver(ReverseDriver, ForwardDriver): super().__init__(env, config) self.fastapi_config: Config = Config(**config.dict()) - self.http_pollings: List[HTTPPollingSetup] = [] - self.websockets: List[WebSocketSetup] = [] + self.http_pollings: List[HTTPPOLLING_SETUP] = [] + self.websockets: List[WEBSOCKET_SETUP] = [] self.shutdown: asyncio.Event = asyncio.Event() self.connections: List[asyncio.Task] = [] @@ -173,30 +158,30 @@ class Driver(ReverseDriver, ForwardDriver): return self.server_app.on_event("shutdown")(func) @overrides(ForwardDriver) - def setup_http_polling(self, - adapter: str, - self_id: str, - url: str, - polling_interval: float = 3., - method: str = "GET", - body: bytes = b"", - headers: Dict[str, str] = {}, - http_version: str = "1.1") -> None: - self.http_pollings.append( - HTTPPollingSetup(adapter, self_id, url, method, body, headers, - http_version, polling_interval)) + def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None: + """ + :说明: + + 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用 + + :参数: + + * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`` + """ + self.http_pollings.append(setup) @overrides(ForwardDriver) - def setup_websocket(self, - adapter: str, - self_id: str, - url: str, - reconnect_interval: float = 3., - headers: Dict[str, str] = {}, - http_version: str = "1.1") -> None: - self.websockets.append( - WebSocketSetup(adapter, self_id, url, headers, http_version, - reconnect_interval)) + def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None: + """ + :说明: + + 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用 + + :参数: + + * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`` + """ + self.websockets.append(setup) @overrides(ReverseDriver) def run(self, @@ -336,50 +321,72 @@ class Driver(ReverseDriver, ForwardDriver): finally: self._bot_disconnect(bot) - async def _http_loop(self, setup: HTTPPollingSetup): - url = httpx.URL(setup.url) - if not url.netloc: - logger.opt(colors=True).error( - f"Error parsing url {url}") - return - request = HTTPRequest( - setup.http_version, url.scheme, url.path, url.query, { - **setup.headers, "host": url.netloc.decode("ascii") - }, setup.method, setup.body) + async def _http_loop(self, setup: HTTPPOLLING_SETUP): + + async def _build_request( + setup: HTTPPollingSetup) -> Optional[HTTPRequest]: + url = httpx.URL(setup.url) + if not url.netloc: + logger.opt(colors=True).error( + f"Error parsing url {url}") + return + return HTTPRequest( + setup.http_version, url.scheme, url.path, url.query, { + **setup.headers, "host": url.netloc.decode("ascii") + }, setup.method, setup.body) + + bot: Optional[Bot] = None + request: Optional[HTTPRequest] = None + setup_: Optional[HTTPPollingSetup] = None - BotClass = self._adapters[setup.adapter] - bot = BotClass(setup.self_id, request) - self._bot_connect(bot) logger.opt(colors=True).info( f"Start http polling for {setup.adapter.upper()} " f"Bot {setup.self_id}") - headers = request.headers - http2: bool = False - if request.http_version == "2": - http2 = True - try: - async with httpx.AsyncClient(headers=headers, - timeout=30., - http2=http2) as session: + async with httpx.AsyncClient(http2=True) as session: while not self.shutdown.is_set(): + if not bot: + if callable(setup): + setup_ = await setup() + else: + setup_ = setup + request = await _build_request(setup_) + if not request: + return + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + elif callable(setup): + setup_ = await setup() + request = await _build_request(setup_) + if not request: + await asyncio.sleep(setup_.poll_interval) + continue + bot.request = request + + setup_ = cast(HTTPPollingSetup, setup_) + request = cast(HTTPRequest, request) + headers = request.headers + logger.debug( - f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" + f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}" ) try: response = await session.request(request.method, - url, - content=request.body) + setup_.url, + content=request.body, + headers=headers, + timeout=30.) response.raise_for_status() data = response.read() asyncio.create_task(bot.handle_message(data)) except httpx.HTTPError as e: logger.opt(colors=True, exception=e).error( - f"Error occurred while requesting {url}. " + f"Error occurred while requesting {setup_.url}. " "Try to reconnect...") - await asyncio.sleep(setup.poll_interval) + await asyncio.sleep(setup_.poll_interval) except asyncio.CancelledError: pass @@ -388,34 +395,41 @@ class Driver(ReverseDriver, ForwardDriver): "Unexpected exception occurred " "while http polling") finally: - self._bot_disconnect(bot) - - async def _ws_loop(self, setup: WebSocketSetup): - url = httpx.URL(setup.url) - if not url.netloc: - logger.opt(colors=True).error( - f"Error parsing url {url}") - return - - headers = {**setup.headers, "host": url.netloc.decode("ascii")} + if bot: + self._bot_disconnect(bot) + async def _ws_loop(self, setup: WEBSOCKET_SETUP): bot: Optional[Bot] = None + try: while True: + if callable(setup): + setup_ = await setup() + else: + setup_ = setup + + url = httpx.URL(setup_.url) + if not url.netloc: + logger.opt(colors=True).error( + f"Error parsing url {url}" + ) + return + + headers = {**setup_.headers, "host": url.netloc.decode("ascii")} logger.debug( - f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" + f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}" ) try: - connection = Connect(setup.url) + connection = Connect(setup_.url) async with connection as ws: logger.opt(colors=True).info( - f"WebSocket Connection to {setup.adapter.upper()} " - f"Bot {setup.self_id} succeeded!") - request = WebSocket(setup.http_version, url.scheme, - url.path, url.query, headers, ws) + f"WebSocket Connection to {setup_.adapter.upper()} " + f"Bot {setup_.self_id} succeeded!") + request = WebSocket("1.1", url.scheme, url.path, + url.query, headers, ws) - BotClass = self._adapters[setup.adapter] - bot = BotClass(setup.self_id, request) + BotClass = self._adapters[setup_.adapter] + bot = BotClass(setup_.self_id, request) self._bot_connect(bot) while not self.shutdown.is_set(): # use try except instead of "request.closed" because of queued message @@ -434,7 +448,7 @@ class Driver(ReverseDriver, ForwardDriver): if bot: self._bot_disconnect(bot) bot = None - await asyncio.sleep(setup.reconnect_interval) + await asyncio.sleep(setup_.reconnect_interval) except asyncio.CancelledError: pass diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py index 6f6233bc..a6a922df 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py @@ -11,7 +11,7 @@ from nonebot.typing import overrides from nonebot.message import handle_event from nonebot.adapters import Bot as BaseBot from nonebot.utils import escape_tag, DataclassEncoder -from nonebot.drivers import Driver, ForwardDriver, ReverseDriver +from nonebot.drivers import Driver, ForwardDriver, WebSocketSetup from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket from .utils import log, escape @@ -249,10 +249,8 @@ class Bot(BaseBot): "authorization": f"Bearer {cls.cqhttp_config.access_token}" } if cls.cqhttp_config.access_token else {} - driver.setup_websocket("cqhttp", - self_id, - url, - headers=headers) + driver.setup_websocket( + WebSocketSetup("cqhttp", self_id, url, headers=headers)) except Exception as e: logger.opt(colors=True, exception=e).error( f"Bad url {url} for bot {self_id} "