diff --git a/nonebot/adapters/_base.py b/nonebot/adapters/_base.py index 603f795a..8223aa0e 100644 --- a/nonebot/adapters/_base.py +++ b/nonebot/adapters/_base.py @@ -71,11 +71,11 @@ class Bot(abc.ABC): raise NotImplementedError @classmethod - def register(cls, driver: Driver, config: Config): + def register(cls, driver: Driver, config: Config, **kwargs): """ :说明: - `register` 方法会在 `driver.register_adapter` 时被调用,用于初始化相关配置 + ``register`` 方法会在 ``driver.register_adapter`` 时被调用,用于初始化相关配置 """ cls.driver = driver cls.config = config diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index 04372028..54b6b3e4 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -84,6 +84,7 @@ class Driver(abc.ABC): * ``name: str``: 适配器名称,用于在连接时进行识别 * ``adapter: Type[Bot]``: 适配器 Class + * ``**kwargs``: 其他传递给适配器的参数 """ if name in self._adapters: logger.opt( @@ -195,7 +196,8 @@ class Driver(abc.ABC): class ForwardDriver(Driver): @abc.abstractmethod - def setup(self, adapter: str, request: "HTTPConnection") -> None: + def setup(self, adapter: str, self_id: str, + request: "HTTPConnection") -> None: raise NotImplementedError diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py index 035f4c84..ba9d9285 100644 --- a/nonebot/drivers/aiohttp.py +++ b/nonebot/drivers/aiohttp.py @@ -24,6 +24,7 @@ AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket] @dataclass class RequestSetup: adapter: str + self_id: str request: AVAILABLE_REQUEST poll_interval: float reconnect_interval: float @@ -61,13 +62,15 @@ class Driver(ForwardDriver): @overrides(ForwardDriver) def setup(self, adapter: str, + self_id: str, request: HTTPConnection, poll_interval: float = 3., reconnect_interval: float = 3.) -> None: if not isinstance(request, (HTTPRequest, WebSocket)): raise TypeError(f"Request Type {type(request)!r} is not supported!") self.requests.append( - RequestSetup(adapter, request, poll_interval, reconnect_interval)) + RequestSetup(adapter, self_id, request, poll_interval, + reconnect_interval)) @overrides(ForwardDriver) def run(self, *args, **kwargs): @@ -90,11 +93,11 @@ class Driver(ForwardDriver): for setup in self.requests: if isinstance(setup.request, HTTPRequest): setups.append( - self._http_setup(setup.adapter, setup.request, - setup.poll_interval)) + self._http_setup(setup.adapter, setup.self_id, + setup.request, setup.poll_interval)) else: setups.append( - self._ws_setup(setup.adapter, setup.request, + self._ws_setup(setup.adapter, setup.self_id, setup.request, setup.reconnect_interval)) try: @@ -142,26 +145,17 @@ class Driver(ForwardDriver): loop.stop() - async def _http_setup(self, adapter: str, request: HTTPRequest, - poll_interval: float): + async def _http_setup(self, adapter: str, self_id: str, + request: HTTPRequest, poll_interval: float): BotClass = self._adapters[adapter] - self_id, _ = await BotClass.check_permission(self, request) - - if not self_id: - raise SetupFailed("Bot self_id get failed") bot = BotClass(self_id, request) self._bot_connect(bot) asyncio.create_task(self._http_loop(bot, request, poll_interval)) - async def _ws_setup(self, adapter: str, request: WebSocket, + async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket, reconnect_interval: float): BotClass = self._adapters[adapter] - self_id, _ = await BotClass.check_permission(self, request) - - if not self_id: - raise SetupFailed("Bot self_id get failed") - bot = BotClass(self_id, request) self._bot_connect(bot) asyncio.create_task(self._ws_loop(bot, request, reconnect_interval)) diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py index e20950b0..842ef0ea 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py @@ -3,6 +3,7 @@ import sys import hmac import json import asyncio +from urllib.parse import urlsplit from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING import httpx @@ -11,7 +12,8 @@ from nonebot.typing import overrides from nonebot.message import handle_event from nonebot.adapters import Bot as BaseBot from nonebot.utils import escape_tag, DataclassEncoder -from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket +from nonebot.drivers import Driver, ForwardDriver, ReverseDriver +from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket from .utils import log, escape from .config import Config as CQHTTPConfig @@ -237,6 +239,29 @@ class Bot(BaseBot): def register(cls, driver: Driver, config: "Config"): super().register(driver, config) cls.cqhttp_config = CQHTTPConfig(**config.dict()) + if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls: + logger.warning( + f"Current driver {cls.config.driver} don't support forward connections" + ) + elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls: + for self_id, url in cls.cqhttp_config.ws_urls.items(): + try: + url_info = urlsplit(url) + headers = { + "authorization": + f"Bearer {cls.cqhttp_config.access_token}", + "host": + url_info.netloc if not url_info.port else + f"{url_info.netloc}:{url_info.port}", + } + driver.setup( + "cqhttp", self_id, + WebSocket("1.1", url_info.scheme, url_info.path, + url_info.query.encode("latin-1"), headers)) + except Exception as e: + logger.opt(colors=True, exception=e).error( + f"Bad url {url} for bot {self_id} " + "in cqhttp forward websocket") @classmethod @overrides(BaseBot) diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/config.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/config.py index 1a17f853..ee894893 100644 --- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/config.py +++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/config.py @@ -1,6 +1,6 @@ -from typing import Optional +from typing import Dict, Optional -from pydantic import Field, BaseModel +from pydantic import Field, BaseModel, AnyUrl # priority: alias > origin @@ -12,10 +12,13 @@ class Config(BaseModel): - ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌 - ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令 + - ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典 """ access_token: Optional[str] = Field(default=None, alias="cqhttp_access_token") secret: Optional[str] = Field(default=None, alias="cqhttp_secret") + ws_urls: Dict[str, AnyUrl] = Field(default_factory=set, + alias="cqhttp_ws_urls") class Config: extra = "ignore"