From ecc613f6c500e20f9e932741eb1f71317f90f802 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Tue, 20 Jul 2021 15:35:56 +0800 Subject: [PATCH] :sparkles: add cqhttp forward support --- docs/api/adapters/README.md | 4 +- docs/api/adapters/cqhttp.md | 3 + docs/api/drivers/README.md | 3 + docs/api/exception.md | 22 -- nonebot/drivers/__init__.py | 21 +- nonebot/drivers/aiohttp.py | 340 ++++++++++++------ nonebot/exception.py | 18 - .../nonebot/adapters/cqhttp/bot.py | 17 +- tests/.env.dev | 4 +- tests/bot.py | 5 +- 10 files changed, 278 insertions(+), 159 deletions(-) diff --git a/docs/api/adapters/README.md b/docs/api/adapters/README.md index e1682ce4..d1431b1f 100644 --- a/docs/api/adapters/README.md +++ b/docs/api/adapters/README.md @@ -85,12 +85,12 @@ Config 配置对象 Adapter 类型 -### _classmethod_ `register(driver, config)` +### _classmethod_ `register(driver, config, **kwargs)` * **说明** - register 方法会在 driver.register_adapter 时被调用,用于初始化相关配置 + `register` 方法会在 `driver.register_adapter` 时被调用,用于初始化相关配置 diff --git a/docs/api/adapters/cqhttp.md b/docs/api/adapters/cqhttp.md index 613cfe6c..1147a2cd 100644 --- a/docs/api/adapters/cqhttp.md +++ b/docs/api/adapters/cqhttp.md @@ -26,6 +26,9 @@ CQHTTP 配置类 * `secret` / `cqhttp_secret`: CQHTTP HTTP 上报数据签名口令 + * `ws_urls` / `cqhttp_ws_urls`: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典 + + # NoneBot.adapters.cqhttp.utils 模块 diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md index 057a1903..1bfba3fc 100644 --- a/docs/api/drivers/README.md +++ b/docs/api/drivers/README.md @@ -153,6 +153,9 @@ Driver 基类。 * `adapter: Type[Bot]`: 适配器 Class + * `**kwargs`: 其他传递给适配器的参数 + + ### _abstract property_ `type` diff --git a/docs/api/exception.md b/docs/api/exception.md index 8f71b87a..f48a493b 100644 --- a/docs/api/exception.md +++ b/docs/api/exception.md @@ -61,28 +61,6 @@ sidebarDepth: 0 -## _exception_ `DriverException` - -基类:`nonebot.exception.NoneBotException` - - -* **说明** - - 代表 `Driver` 抛出的异常 - - - -## _exception_ `SetupFailed` - -基类:`nonebot.exception.DriverException` - - -* **说明** - - `ForwardDriver` 建立连接失败 - - - ## _exception_ `PausedException` 基类:`nonebot.exception.NoneBotException` diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 54b6b3e4..abb289fd 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -196,8 +196,25 @@ class Driver(abc.ABC): class ForwardDriver(Driver): @abc.abstractmethod - def setup(self, adapter: str, self_id: str, - request: "HTTPConnection") -> None: + def setup_http_polling(self, + adapter: str, + self_id: str, + url: str, + polling_interval: float = 3., + method: str = "GET", + body: bytes = b"", + headers: Dict[str, str] = {}, + http_version: str = "1.1") -> None: + raise NotImplementedError + + @abc.abstractmethod + def setup_websocket(self, + adapter: str, + self_id: str, + url: str, + reconnect_interval: float = 3., + headers: Dict[str, str] = {}, + http_version: str = "1.1") -> None: raise NotImplementedError diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index ba9d9285..7bff8cc1 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -3,8 +3,9 @@ import signal import asyncio +import threading from dataclasses import dataclass -from typing import Set, List, Union, Callable, Awaitable +from typing import Set, List, Dict, Optional, Callable, Awaitable import aiohttp from yarl import URL @@ -13,20 +14,35 @@ from nonebot.log import logger from nonebot.adapters import Bot from nonebot.typing import overrides from nonebot.config import Env, Config -from nonebot.exception import SetupFailed -from nonebot.drivers import ForwardDriver, HTTPConnection, HTTPRequest, WebSocket +from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket STARTUP_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]] -AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket] +HANDLED_SIGNALS = ( + signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. + signal.SIGTERM, # Unix signal 15. Sent by `kill `. +) @dataclass -class RequestSetup: +class HTTPPollingSetup: adapter: str self_id: str - request: AVAILABLE_REQUEST + url: str + method: str + body: bytes + headers: Dict[str, str] + http_version: str poll_interval: float + + +@dataclass +class WebSocketSetup: + adapter: str + self_id: str + url: str + headers: Dict[str, str] + http_version: str reconnect_interval: float @@ -36,7 +52,11 @@ class Driver(ForwardDriver): super().__init__(env, config) self.startup_funcs: Set[STARTUP_FUNC] = set() self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() - self.requests: List[RequestSetup] = [] + self.http_pollings: List[HTTPPollingSetup] = [] + self.websockets: List[WebSocketSetup] = [] + self.connections: List[asyncio.Task] = [] + self.should_exit: bool = False + self.force_exit: bool = False @property @overrides(ForwardDriver) @@ -60,54 +80,52 @@ class Driver(ForwardDriver): return func @overrides(ForwardDriver) - def setup(self, - adapter: str, - self_id: str, - request: HTTPConnection, - poll_interval: float = 3., - reconnect_interval: float = 3.) -> None: - if not isinstance(request, (HTTPRequest, WebSocket)): - raise TypeError(f"Request Type {type(request)!r} is not supported!") - self.requests.append( - RequestSetup(adapter, self_id, request, poll_interval, - reconnect_interval)) + def setup_http_polling(self, + adapter: str, + self_id: str, + url: str, + polling_interval: float = 3., + method: str = "GET", + body: bytes = b"", + headers: Dict[str, str] = {}, + http_version: str = "1.1") -> None: + self.http_pollings.append( + HTTPPollingSetup(adapter, self_id, url, method, body, headers, + http_version, polling_interval)) + + @overrides(ForwardDriver) + def setup_websocket(self, + adapter: str, + self_id: str, + url: str, + reconnect_interval: float = 3., + headers: Dict[str, str] = {}, + http_version: str = "1.1") -> None: + self.websockets.append( + WebSocketSetup(adapter, self_id, url, headers, http_version, + reconnect_interval)) @overrides(ForwardDriver) def run(self, *args, **kwargs): + super().run(*args, **kwargs) loop = asyncio.get_event_loop() - signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) - for s in signals: - loop.add_signal_handler( - s, - lambda s=s: asyncio.create_task(self.shutdown(loop, signal=s))) + loop.run_until_complete(self.serve()) - try: - asyncio.create_task(self.startup()) - loop.run_forever() - finally: - loop.close() + async def serve(self): + self.install_signal_handlers() + await self.startup() + if self.should_exit: + return + await self.main_loop() + await self.shutdown() async def startup(self): - setups = [] - loop = asyncio.get_event_loop() - for setup in self.requests: - if isinstance(setup.request, HTTPRequest): - setups.append( - self._http_setup(setup.adapter, setup.self_id, - setup.request, setup.poll_interval)) - else: - setups.append( - self._ws_setup(setup.adapter, setup.self_id, setup.request, - setup.reconnect_interval)) + for setup in self.http_pollings: + self.connections.append(asyncio.create_task(self._http_loop(setup))) + for setup in self.websockets: + self.connections.append(asyncio.create_task(self._ws_loop(setup))) - try: - await asyncio.gather(*setups) - except Exception as e: - logger.opt( - colors=True, - exception=e).error("Application startup failed. Exiting.") - asyncio.create_task(self.shutdown(loop)) - return + logger.info("Application startup completed.") # run startup cors = [startup() for startup in self.startup_funcs] @@ -119,11 +137,11 @@ class Driver(ForwardDriver): "Error when running startup function. " "Ignored!") - async def shutdown(self, - loop: asyncio.AbstractEventLoop, - signal: signal.Signals = None): - # TODO: shutdown + async def main_loop(self): + while not self.should_exit: + await asyncio.sleep(0.1) + async def shutdown(self): # run shutdown cors = [shutdown() for shutdown in self.shutdown_funcs] if cors: @@ -134,44 +152,89 @@ class Driver(ForwardDriver): "Error when running shutdown function. " "Ignored!") + for task in self.connections: + if not task.done(): + task.cancel() + await asyncio.sleep(0.1) + tasks = [ t for t in asyncio.all_tasks() if t is not asyncio.current_task() ] + if tasks and not self.force_exit: + logger.info("Waiting for tasks to finish. (CTRL+C to force quit)") + while tasks and not self.force_exit: + await asyncio.sleep(0.1) + tasks = [ + t for t in asyncio.all_tasks() + if t is not asyncio.current_task() + ] for task in tasks: task.cancel() await asyncio.gather(*tasks, return_exceptions=True) + loop = asyncio.get_event_loop() loop.stop() - async def _http_setup(self, adapter: str, self_id: str, - request: HTTPRequest, poll_interval: float): - BotClass = self._adapters[adapter] + def install_signal_handlers(self) -> None: + if threading.current_thread() is not threading.main_thread(): + # Signals can only be listened to from the main thread. + return - bot = BotClass(self_id, request) - self._bot_connect(bot) - asyncio.create_task(self._http_loop(bot, request, poll_interval)) + loop = asyncio.get_event_loop() + + try: + for sig in HANDLED_SIGNALS: + loop.add_signal_handler(sig, self.handle_exit, sig, None) + except NotImplementedError: + # Windows + for sig in HANDLED_SIGNALS: + signal.signal(sig, self.handle_exit) + + def handle_exit(self, sig, frame): + if self.should_exit: + self.force_exit = True + else: + self.should_exit = True + + async def _http_loop(self, setup: HTTPPollingSetup): + url = URL(setup.url) + if not url.is_absolute() or not url.host: + logger.opt(colors=True).error( + f"Error parsing url {url}") + return + host = f"{url.host}:{url.port}" if url.port else url.host + request = HTTPRequest(setup.http_version, url.scheme, url.path, + url.raw_query_string.encode("latin-1"), { + **setup.headers, "host": host + }, setup.method, setup.body) + + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + + headers = request.headers + timeout = aiohttp.ClientTimeout(30) + version: aiohttp.HttpVersion + if request.http_version == "1.0": + version = aiohttp.HttpVersion10 + elif request.http_version == "1.1": + version = aiohttp.HttpVersion11 + else: + logger.opt(colors=True).error( + "Unsupported HTTP Version " + f"{request.http_version}") + return - async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket, - reconnect_interval: float): - BotClass = self._adapters[adapter] - bot = BotClass(self_id, request) - self._bot_connect(bot) - asyncio.create_task(self._ws_loop(bot, request, reconnect_interval)) - - async def _http_loop(self, bot: Bot, request: HTTPRequest, - poll_interval: float): try: - headers = request.headers - url = URL.build(scheme=request.scheme, - host=request.headers["host"], - path=request.path, - query_string=request.query_string.decode("latin-1")) - timeout = aiohttp.ClientTimeout(30) async with aiohttp.ClientSession(headers=headers, - timeout=timeout) as session: - while True: + timeout=timeout, + version=version) as session: + while not self.should_exit: + logger.debug( + f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" + ) try: async with session.request( request.method, url, @@ -181,9 +244,10 @@ class Driver(ForwardDriver): asyncio.create_task(bot.handle_message(data)) except aiohttp.ClientResponseError as e: logger.opt(colors=True, exception=e).error( - f"Error occurred while requesting {url}") + f"Error occurred while requesting {url}. " + "Try to reconnect...") - await asyncio.sleep(poll_interval) + await asyncio.sleep(setup.poll_interval) except asyncio.CancelledError: pass @@ -193,37 +257,111 @@ class Driver(ForwardDriver): finally: self._bot_disconnect(bot) - async def _ws_loop(self, bot: Bot, request: WebSocket, - reconnect_interval: float): + async def _ws_loop(self, setup: WebSocketSetup): + url = URL(setup.url) + if not url.is_absolute() or not url.host: + logger.opt(colors=True).error( + f"Error parsing url {url}") + return + host = f"{url.host}:{url.port}" if url.port else url.host + + headers = {**setup.headers, "host": host} + timeout = aiohttp.ClientTimeout(30) + version: aiohttp.HttpVersion + if setup.http_version == "1.0": + version = aiohttp.HttpVersion10 + elif setup.http_version == "1.1": + version = aiohttp.HttpVersion11 + else: + logger.opt(colors=True).error( + "Unsupported HTTP Version " + f"{setup.http_version}") + return + + bot: Optional[Bot] = None try: - headers = request.headers - url = URL.build(scheme=request.scheme, - host=request.headers["host"], - path=request.path, - query_string=request.query_string.decode("latin-1")) - timeout = aiohttp.ClientTimeout(30) async with aiohttp.ClientSession(headers=headers, - timeout=timeout) as session: + timeout=timeout, + version=version) as session: while True: - async with session.ws_connect(url) as ws: - async for msg in ws: - if msg.type == aiohttp.WSMsgType.text: - asyncio.create_task( - bot.handle_message(msg.data.encode())) - elif msg.type == aiohttp.WSMsgType.binary: - asyncio.create_task(bot.handle_message( - msg.data)) - elif msg.type == aiohttp.WSMsgType.error: - logger.opt(colors=True).error( - "Error while handling websocket frame. " - "Try to reconnect...") - break - asyncio.sleep(reconnect_interval) + logger.debug( + f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" + ) + try: + async with session.ws_connect(url) as ws: + request = WebSocket( + setup.http_version, url.scheme, url.path, + url.raw_query_string.encode("latin-1"), { + **setup.headers, "host": host + }, ws) + + BotClass = self._adapters[setup.adapter] + bot = BotClass(setup.self_id, request) + self._bot_connect(bot) + while not self.should_exit: + msg = await ws.receive() + if msg.type == aiohttp.WSMsgType.text: + asyncio.create_task( + bot.handle_message(msg.data.encode())) + elif msg.type == aiohttp.WSMsgType.binary: + asyncio.create_task( + bot.handle_message(msg.data)) + elif msg.type == aiohttp.WSMsgType.error: + logger.opt(colors=True).error( + "Error while handling websocket frame. " + "Try to reconnect...") + break + else: + logger.opt(colors=True).error( + "WebSocket connection closed by peer. " + "Try to reconnect...") + break + except aiohttp.WSServerHandshakeError as e: + logger.opt(colors=True, exception=e).error( + f"Error while connecting to {url}" + "Try to reconnect...") + finally: + if bot: + self._bot_disconnect(bot) + bot = None + await asyncio.sleep(setup.reconnect_interval) except asyncio.CancelledError: pass except Exception as e: logger.opt(colors=True, exception=e).error( "Unexpected exception occurred while websocket loop") - finally: - self._bot_disconnect(bot) + + +@dataclass +class WebSocket(BaseWebSocket): + websocket: aiohttp.ClientWebSocketResponse = None # type: ignore + + @property + @overrides(BaseWebSocket) + def closed(self): + return self.websocket.closed + + @overrides(BaseWebSocket) + async def accept(self): + raise NotImplementedError + + @overrides(BaseWebSocket) + async def close(self, code: int = 1000): + await self.websocket.close(code=code) + + @overrides(BaseWebSocket) + async def receive(self) -> str: + return await self.websocket.receive_str() + + @overrides(BaseWebSocket) + async def receive_bytes(self) -> bytes: + return await self.websocket.receive_bytes() + + @overrides(BaseWebSocket) + async def send(self, data: str) -> None: + await self.websocket.send_str(data) + + @overrides(BaseWebSocket) + async def send_bytes(self, data: bytes) -> None: + await self.websocket.send_bytes(data) diff --git a/nonebot/exception.py b/nonebot/exception.py index c218f1a4..3cad317a 100644 --- a/nonebot/exception.py +++ b/nonebot/exception.py @@ -60,24 +60,6 @@ class ParserExit(NoneBotException): return self.__repr__() -class DriverException(NoneBotException): - """ - :说明: - - 代表 ``Driver`` 抛出的异常 - """ - pass - - -class SetupFailed(DriverException): - """ - :说明: - - ``ForwardDriver`` 建立连接失败 - """ - pass - - class PausedException(NoneBotException): """ :说明: diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py index 842ef0ea..00947e31 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py @@ -3,7 +3,6 @@ import sys import hmac import json import asyncio -from urllib.parse import urlsplit from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING import httpx @@ -246,22 +245,18 @@ class Bot(BaseBot): elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls: for self_id, url in cls.cqhttp_config.ws_urls.items(): try: - url_info = urlsplit(url) headers = { "authorization": - f"Bearer {cls.cqhttp_config.access_token}", - "host": - url_info.netloc if not url_info.port else - f"{url_info.netloc}:{url_info.port}", + f"Bearer {cls.cqhttp_config.access_token}" } - driver.setup( - "cqhttp", self_id, - WebSocket("1.1", url_info.scheme, url_info.path, - url_info.query.encode("latin-1"), headers)) + driver.setup_websocket("cqhttp", + self_id, + url, + headers=headers) except Exception as e: logger.opt(colors=True, exception=e).error( f"Bad url {url} for bot {self_id} " - "in cqhttp forward websocket") + "in cqhttp forward websocket") @classmethod @overrides(BaseBot) diff --git a/tests/.env.dev b/tests/.env.dev index 8e167d1d..7056d21a 100644 --- a/tests/.env.dev +++ b/tests/.env.dev @@ -1,4 +1,4 @@ -DRIVER=nonebot.drivers.fastapi +DRIVER=nonebot.drivers.aiohttp:Driver HOST=0.0.0.0 PORT=2333 DEBUG=true @@ -13,6 +13,8 @@ COMMAND_SEP=["/", "."] CUSTOM_CONFIG1=config in env CUSTOM_CONFIG3= +CQHTTP_WS_URLS={"123123123": "ws://127.0.0.1:6700/"} + MIRAI_AUTH_KEY=12345678 MIRAI_HOST=127.0.0.1 MIRAI_PORT=8080 diff --git a/tests/bot.py b/tests/bot.py index ccccd009..9542ebe2 100644 --- a/tests/bot.py +++ b/tests/bot.py @@ -18,7 +18,7 @@ logger.add("error.log", format=default_format) nonebot.init(custom_config2="config on init") -app = nonebot.get_asgi() +# app = nonebot.get_asgi() driver = nonebot.get_driver() driver.register_adapter("cqhttp", Bot) driver.register_adapter("ding", DingBot) @@ -37,4 +37,5 @@ config.custom_config3 = config.custom_config1 config.custom_config4 = "New custom config" if __name__ == "__main__": - nonebot.run(app="__mp_main__:app") + # nonebot.run(app="__mp_main__:app") + nonebot.run()