cqhttp support forward websocket

This commit is contained in:
yanyongyu 2021-07-19 23:46:29 +08:00
parent 32787fdc1e
commit 04b3fda40c
5 changed files with 46 additions and 22 deletions

View File

@ -71,11 +71,11 @@ class Bot(abc.ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def register(cls, driver: Driver, config: Config): def register(cls, driver: Driver, config: Config, **kwargs):
""" """
:说明: :说明:
`register` 方法会在 `driver.register_adapter` 时被调用用于初始化相关配置 ``register`` 方法会在 ``driver.register_adapter`` 时被调用用于初始化相关配置
""" """
cls.driver = driver cls.driver = driver
cls.config = config cls.config = config

View File

@ -84,6 +84,7 @@ class Driver(abc.ABC):
* ``name: str``: 适配器名称用于在连接时进行识别 * ``name: str``: 适配器名称用于在连接时进行识别
* ``adapter: Type[Bot]``: 适配器 Class * ``adapter: Type[Bot]``: 适配器 Class
* ``**kwargs``: 其他传递给适配器的参数
""" """
if name in self._adapters: if name in self._adapters:
logger.opt( logger.opt(
@ -195,7 +196,8 @@ class Driver(abc.ABC):
class ForwardDriver(Driver): class ForwardDriver(Driver):
@abc.abstractmethod @abc.abstractmethod
def setup(self, adapter: str, request: "HTTPConnection") -> None: def setup(self, adapter: str, self_id: str,
request: "HTTPConnection") -> None:
raise NotImplementedError raise NotImplementedError

View File

@ -24,6 +24,7 @@ AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket]
@dataclass @dataclass
class RequestSetup: class RequestSetup:
adapter: str adapter: str
self_id: str
request: AVAILABLE_REQUEST request: AVAILABLE_REQUEST
poll_interval: float poll_interval: float
reconnect_interval: float reconnect_interval: float
@ -61,13 +62,15 @@ class Driver(ForwardDriver):
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup(self, def setup(self,
adapter: str, adapter: str,
self_id: str,
request: HTTPConnection, request: HTTPConnection,
poll_interval: float = 3., poll_interval: float = 3.,
reconnect_interval: float = 3.) -> None: reconnect_interval: float = 3.) -> None:
if not isinstance(request, (HTTPRequest, WebSocket)): if not isinstance(request, (HTTPRequest, WebSocket)):
raise TypeError(f"Request Type {type(request)!r} is not supported!") raise TypeError(f"Request Type {type(request)!r} is not supported!")
self.requests.append( self.requests.append(
RequestSetup(adapter, request, poll_interval, reconnect_interval)) RequestSetup(adapter, self_id, request, poll_interval,
reconnect_interval))
@overrides(ForwardDriver) @overrides(ForwardDriver)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
@ -90,11 +93,11 @@ class Driver(ForwardDriver):
for setup in self.requests: for setup in self.requests:
if isinstance(setup.request, HTTPRequest): if isinstance(setup.request, HTTPRequest):
setups.append( setups.append(
self._http_setup(setup.adapter, setup.request, self._http_setup(setup.adapter, setup.self_id,
setup.poll_interval)) setup.request, setup.poll_interval))
else: else:
setups.append( setups.append(
self._ws_setup(setup.adapter, setup.request, self._ws_setup(setup.adapter, setup.self_id, setup.request,
setup.reconnect_interval)) setup.reconnect_interval))
try: try:
@ -142,26 +145,17 @@ class Driver(ForwardDriver):
loop.stop() loop.stop()
async def _http_setup(self, adapter: str, request: HTTPRequest, async def _http_setup(self, adapter: str, self_id: str,
poll_interval: float): request: HTTPRequest, poll_interval: float):
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
if not self_id:
raise SetupFailed("Bot self_id get failed")
bot = BotClass(self_id, request) bot = BotClass(self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
asyncio.create_task(self._http_loop(bot, request, poll_interval)) asyncio.create_task(self._http_loop(bot, request, poll_interval))
async def _ws_setup(self, adapter: str, request: WebSocket, async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket,
reconnect_interval: float): reconnect_interval: float):
BotClass = self._adapters[adapter] BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
if not self_id:
raise SetupFailed("Bot self_id get failed")
bot = BotClass(self_id, request) bot = BotClass(self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval)) asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))

View File

@ -3,6 +3,7 @@ import sys
import hmac import hmac
import json import json
import asyncio import asyncio
from urllib.parse import urlsplit
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx import httpx
@ -11,7 +12,8 @@ from nonebot.typing import overrides
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.utils import escape_tag, DataclassEncoder from nonebot.utils import escape_tag, DataclassEncoder
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket from nonebot.drivers import Driver, ForwardDriver, ReverseDriver
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape from .utils import log, escape
from .config import Config as CQHTTPConfig from .config import Config as CQHTTPConfig
@ -237,6 +239,29 @@ class Bot(BaseBot):
def register(cls, driver: Driver, config: "Config"): def register(cls, driver: Driver, config: "Config"):
super().register(driver, config) super().register(driver, config)
cls.cqhttp_config = CQHTTPConfig(**config.dict()) cls.cqhttp_config = CQHTTPConfig(**config.dict())
if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
logger.warning(
f"Current driver {cls.config.driver} don't support forward connections"
)
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items():
try:
url_info = urlsplit(url)
headers = {
"authorization":
f"Bearer {cls.cqhttp_config.access_token}",
"host":
url_info.netloc if not url_info.port else
f"{url_info.netloc}:{url_info.port}",
}
driver.setup(
"cqhttp", self_id,
WebSocket("1.1", url_info.scheme, url_info.path,
url_info.query.encode("latin-1"), headers))
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "
"in cqhttp forward websocket</bg></r>")
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)

View File

@ -1,6 +1,6 @@
from typing import Optional from typing import Dict, Optional
from pydantic import Field, BaseModel from pydantic import Field, BaseModel, AnyUrl
# priority: alias > origin # priority: alias > origin
@ -12,10 +12,13 @@ class Config(BaseModel):
- ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌 - ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令 - ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID目标 URL 字典
""" """
access_token: Optional[str] = Field(default=None, access_token: Optional[str] = Field(default=None,
alias="cqhttp_access_token") alias="cqhttp_access_token")
secret: Optional[str] = Field(default=None, alias="cqhttp_secret") secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set,
alias="cqhttp_ws_urls")
class Config: class Config:
extra = "ignore" extra = "ignore"