mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ cqhttp support forward websocket
This commit is contained in:
parent
32787fdc1e
commit
04b3fda40c
@ -71,11 +71,11 @@ class Bot(abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(cls, driver: Driver, config: Config):
|
def register(cls, driver: Driver, config: Config, **kwargs):
|
||||||
"""
|
"""
|
||||||
:说明:
|
:说明:
|
||||||
|
|
||||||
`register` 方法会在 `driver.register_adapter` 时被调用,用于初始化相关配置
|
``register`` 方法会在 ``driver.register_adapter`` 时被调用,用于初始化相关配置
|
||||||
"""
|
"""
|
||||||
cls.driver = driver
|
cls.driver = driver
|
||||||
cls.config = config
|
cls.config = config
|
||||||
|
@ -84,6 +84,7 @@ class Driver(abc.ABC):
|
|||||||
|
|
||||||
* ``name: str``: 适配器名称,用于在连接时进行识别
|
* ``name: str``: 适配器名称,用于在连接时进行识别
|
||||||
* ``adapter: Type[Bot]``: 适配器 Class
|
* ``adapter: Type[Bot]``: 适配器 Class
|
||||||
|
* ``**kwargs``: 其他传递给适配器的参数
|
||||||
"""
|
"""
|
||||||
if name in self._adapters:
|
if name in self._adapters:
|
||||||
logger.opt(
|
logger.opt(
|
||||||
@ -195,7 +196,8 @@ class Driver(abc.ABC):
|
|||||||
class ForwardDriver(Driver):
|
class ForwardDriver(Driver):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def setup(self, adapter: str, request: "HTTPConnection") -> None:
|
def setup(self, adapter: str, self_id: str,
|
||||||
|
request: "HTTPConnection") -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,6 +24,7 @@ AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket]
|
|||||||
@dataclass
|
@dataclass
|
||||||
class RequestSetup:
|
class RequestSetup:
|
||||||
adapter: str
|
adapter: str
|
||||||
|
self_id: str
|
||||||
request: AVAILABLE_REQUEST
|
request: AVAILABLE_REQUEST
|
||||||
poll_interval: float
|
poll_interval: float
|
||||||
reconnect_interval: float
|
reconnect_interval: float
|
||||||
@ -61,13 +62,15 @@ class Driver(ForwardDriver):
|
|||||||
@overrides(ForwardDriver)
|
@overrides(ForwardDriver)
|
||||||
def setup(self,
|
def setup(self,
|
||||||
adapter: str,
|
adapter: str,
|
||||||
|
self_id: str,
|
||||||
request: HTTPConnection,
|
request: HTTPConnection,
|
||||||
poll_interval: float = 3.,
|
poll_interval: float = 3.,
|
||||||
reconnect_interval: float = 3.) -> None:
|
reconnect_interval: float = 3.) -> None:
|
||||||
if not isinstance(request, (HTTPRequest, WebSocket)):
|
if not isinstance(request, (HTTPRequest, WebSocket)):
|
||||||
raise TypeError(f"Request Type {type(request)!r} is not supported!")
|
raise TypeError(f"Request Type {type(request)!r} is not supported!")
|
||||||
self.requests.append(
|
self.requests.append(
|
||||||
RequestSetup(adapter, request, poll_interval, reconnect_interval))
|
RequestSetup(adapter, self_id, request, poll_interval,
|
||||||
|
reconnect_interval))
|
||||||
|
|
||||||
@overrides(ForwardDriver)
|
@overrides(ForwardDriver)
|
||||||
def run(self, *args, **kwargs):
|
def run(self, *args, **kwargs):
|
||||||
@ -90,11 +93,11 @@ class Driver(ForwardDriver):
|
|||||||
for setup in self.requests:
|
for setup in self.requests:
|
||||||
if isinstance(setup.request, HTTPRequest):
|
if isinstance(setup.request, HTTPRequest):
|
||||||
setups.append(
|
setups.append(
|
||||||
self._http_setup(setup.adapter, setup.request,
|
self._http_setup(setup.adapter, setup.self_id,
|
||||||
setup.poll_interval))
|
setup.request, setup.poll_interval))
|
||||||
else:
|
else:
|
||||||
setups.append(
|
setups.append(
|
||||||
self._ws_setup(setup.adapter, setup.request,
|
self._ws_setup(setup.adapter, setup.self_id, setup.request,
|
||||||
setup.reconnect_interval))
|
setup.reconnect_interval))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -142,26 +145,17 @@ class Driver(ForwardDriver):
|
|||||||
|
|
||||||
loop.stop()
|
loop.stop()
|
||||||
|
|
||||||
async def _http_setup(self, adapter: str, request: HTTPRequest,
|
async def _http_setup(self, adapter: str, self_id: str,
|
||||||
poll_interval: float):
|
request: HTTPRequest, poll_interval: float):
|
||||||
BotClass = self._adapters[adapter]
|
BotClass = self._adapters[adapter]
|
||||||
self_id, _ = await BotClass.check_permission(self, request)
|
|
||||||
|
|
||||||
if not self_id:
|
|
||||||
raise SetupFailed("Bot self_id get failed")
|
|
||||||
|
|
||||||
bot = BotClass(self_id, request)
|
bot = BotClass(self_id, request)
|
||||||
self._bot_connect(bot)
|
self._bot_connect(bot)
|
||||||
asyncio.create_task(self._http_loop(bot, request, poll_interval))
|
asyncio.create_task(self._http_loop(bot, request, poll_interval))
|
||||||
|
|
||||||
async def _ws_setup(self, adapter: str, request: WebSocket,
|
async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket,
|
||||||
reconnect_interval: float):
|
reconnect_interval: float):
|
||||||
BotClass = self._adapters[adapter]
|
BotClass = self._adapters[adapter]
|
||||||
self_id, _ = await BotClass.check_permission(self, request)
|
|
||||||
|
|
||||||
if not self_id:
|
|
||||||
raise SetupFailed("Bot self_id get failed")
|
|
||||||
|
|
||||||
bot = BotClass(self_id, request)
|
bot = BotClass(self_id, request)
|
||||||
self._bot_connect(bot)
|
self._bot_connect(bot)
|
||||||
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))
|
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))
|
||||||
|
@ -3,6 +3,7 @@ import sys
|
|||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from urllib.parse import urlsplit
|
||||||
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
|
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@ -11,7 +12,8 @@ from nonebot.typing import overrides
|
|||||||
from nonebot.message import handle_event
|
from nonebot.message import handle_event
|
||||||
from nonebot.adapters import Bot as BaseBot
|
from nonebot.adapters import Bot as BaseBot
|
||||||
from nonebot.utils import escape_tag, DataclassEncoder
|
from nonebot.utils import escape_tag, DataclassEncoder
|
||||||
from nonebot.drivers import Driver, HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
|
from nonebot.drivers import Driver, ForwardDriver, ReverseDriver
|
||||||
|
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
|
||||||
|
|
||||||
from .utils import log, escape
|
from .utils import log, escape
|
||||||
from .config import Config as CQHTTPConfig
|
from .config import Config as CQHTTPConfig
|
||||||
@ -237,6 +239,29 @@ class Bot(BaseBot):
|
|||||||
def register(cls, driver: Driver, config: "Config"):
|
def register(cls, driver: Driver, config: "Config"):
|
||||||
super().register(driver, config)
|
super().register(driver, config)
|
||||||
cls.cqhttp_config = CQHTTPConfig(**config.dict())
|
cls.cqhttp_config = CQHTTPConfig(**config.dict())
|
||||||
|
if not isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
|
||||||
|
logger.warning(
|
||||||
|
f"Current driver {cls.config.driver} don't support forward connections"
|
||||||
|
)
|
||||||
|
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
|
||||||
|
for self_id, url in cls.cqhttp_config.ws_urls.items():
|
||||||
|
try:
|
||||||
|
url_info = urlsplit(url)
|
||||||
|
headers = {
|
||||||
|
"authorization":
|
||||||
|
f"Bearer {cls.cqhttp_config.access_token}",
|
||||||
|
"host":
|
||||||
|
url_info.netloc if not url_info.port else
|
||||||
|
f"{url_info.netloc}:{url_info.port}",
|
||||||
|
}
|
||||||
|
driver.setup(
|
||||||
|
"cqhttp", self_id,
|
||||||
|
WebSocket("1.1", url_info.scheme, url_info.path,
|
||||||
|
url_info.query.encode("latin-1"), headers))
|
||||||
|
except Exception as e:
|
||||||
|
logger.opt(colors=True, exception=e).error(
|
||||||
|
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "
|
||||||
|
"in cqhttp forward websocket</bg></r>")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@overrides(BaseBot)
|
@overrides(BaseBot)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import Field, BaseModel, AnyUrl
|
||||||
|
|
||||||
|
|
||||||
# priority: alias > origin
|
# priority: alias > origin
|
||||||
@ -12,10 +12,13 @@ class Config(BaseModel):
|
|||||||
|
|
||||||
- ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌
|
- ``access_token`` / ``cqhttp_access_token``: CQHTTP 协议授权令牌
|
||||||
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
|
- ``secret`` / ``cqhttp_secret``: CQHTTP HTTP 上报数据签名口令
|
||||||
|
- ``ws_urls`` / ``cqhttp_ws_urls``: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典
|
||||||
"""
|
"""
|
||||||
access_token: Optional[str] = Field(default=None,
|
access_token: Optional[str] = Field(default=None,
|
||||||
alias="cqhttp_access_token")
|
alias="cqhttp_access_token")
|
||||||
secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
|
secret: Optional[str] = Field(default=None, alias="cqhttp_secret")
|
||||||
|
ws_urls: Dict[str, AnyUrl] = Field(default_factory=set,
|
||||||
|
alias="cqhttp_ws_urls")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "ignore"
|
extra = "ignore"
|
||||||
|
Loading…
Reference in New Issue
Block a user