cqhttp support forward websocket

This commit is contained in:
yanyongyu 2021-07-19 23:46:29 +08:00
parent 32787fdc1e
commit 04b3fda40c
5 changed files with 46 additions and 22 deletions

View File

@ -71,11 +71,11 @@ class Bot(abc.ABC):
raise NotImplementedError
@classmethod
def register(cls, driver: Driver, config: Config):
def register(cls, driver: Driver, config: Config, **kwargs):
"""
:说明:
`register` 方法会在 `driver.register_adapter` 时被调用用于初始化相关配置
``register`` 方法会在 ``driver.register_adapter`` 时被调用用于初始化相关配置
"""
cls.driver = driver
cls.config = config

View File

@ -84,6 +84,7 @@ class Driver(abc.ABC):
* ``name: str``: 适配器名称用于在连接时进行识别
* ``adapter: Type[Bot]``: 适配器 Class
* ``**kwargs``: 其他传递给适配器的参数
"""
if name in self._adapters:
logger.opt(
@ -195,7 +196,8 @@ class Driver(abc.ABC):
class ForwardDriver(Driver):
@abc.abstractmethod
def setup(self, adapter: str, request: "HTTPConnection") -> None:
def setup(self, adapter: str, self_id: str,
request: "HTTPConnection") -> None:
raise NotImplementedError

View File

@ -24,6 +24,7 @@ AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket]
@dataclass
class RequestSetup:
adapter: str
self_id: str
request: AVAILABLE_REQUEST
poll_interval: float
reconnect_interval: float
@ -61,13 +62,15 @@ class Driver(ForwardDriver):
@overrides(ForwardDriver)
def setup(self,
adapter: str,
self_id: str,
request: HTTPConnection,
poll_interval: float = 3.,
reconnect_interval: float = 3.) -> None:
if not isinstance(request, (HTTPRequest, WebSocket)):
raise TypeError(f"Request Type {type(request)!r} is not supported!")
self.requests.append(
RequestSetup(adapter, request, poll_interval, reconnect_interval))
RequestSetup(adapter, self_id, request, poll_interval,
reconnect_interval))
@overrides(ForwardDriver)
def run(self, *args, **kwargs):
@ -90,11 +93,11 @@ class Driver(ForwardDriver):
for setup in self.requests:
if isinstance(setup.request, HTTPRequest):
setups.append(
self._http_setup(setup.adapter, setup.request,
setup.poll_interval))
self._http_setup(setup.adapter, setup.self_id,
setup.request, setup.poll_interval))
else:
setups.append(
self._ws_setup(setup.adapter, setup.request,
self._ws_setup(setup.adapter, setup.self_id, setup.request,
setup.reconnect_interval))
try:
@ -142,26 +145,17 @@ class Driver(ForwardDriver):
loop.stop()
async def _http_setup(self, adapter: str, request: HTTPRequest,
poll_interval: float):
async def _http_setup(self, adapter: str, self_id: str,
request: HTTPRequest, poll_interval: float):
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
if not self_id:
raise SetupFailed("Bot self_id get failed")
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._http_loop(bot, request, poll_interval))
async def _ws_setup(self, adapter: str, request: WebSocket,
async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket,
reconnect_interval: float):
BotClass = self._adapters[adapter]
self_id, _ = await BotClass.check_permission(self, request)
if not self_id:
raise SetupFailed("Bot self_id get failed")
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))

View File

@ -3,6 +3,7 @@ import sys
import hmac
import json
import asyncio
from urllib.parse import urlsplit
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx
@ -11,7 +12,8 @@ from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.utils import escape_tag, DataclassEncoder
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from nonebot.drivers import Driver, ForwardDriver, ReverseDriver
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape
from .config import Config as CQHTTPConfig
@ -237,6 +239,29 @@ class Bot(BaseBot):
def register(cls, driver: Driver, config: "Config"):
super().register(driver, config)
cls.cqhttp_config = CQHTTPConfig(**config.dict())
if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
logger.warning(
f"Current driver {cls.config.driver} don't support forward connections"
)
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items():
try:
url_info = urlsplit(url)
headers = {
"authorization":
f"Bearer {cls.cqhttp_config.access_token}",
"host":
url_info.netloc if not url_info.port else
f"{url_info.netloc}:{url_info.port}",
}
driver.setup(
"cqhttp", self_id,
WebSocket("1.1", url_info.scheme, url_info.path,
url_info.query.encode("latin-1"), headers))
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "
"in cqhttp forward websocket</bg></r>")
@classmethod
@overrides(BaseBot)

View File

@ -1,6 +1,6 @@
from typing import Optional
from typing import Dict, Optional
from pydantic import Field, BaseModel
from pydantic import Field, BaseModel, AnyUrl
# priority: alias > origin
@ -12,10 +12,13 @@ class Config(BaseModel):
- ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID目标 URL 字典
"""
access_token: Optional[str] = Field(default=None,
alias="cqhttp_access_token")
secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set,
alias="cqhttp_ws_urls")
class Config:
extra = "ignore"