add cqhttp forward support

This commit is contained in:
yanyongyu 2021-07-20 15:35:56 +08:00
parent 04b3fda40c
commit ecc613f6c5
10 changed files with 278 additions and 159 deletions

View File

@ -85,12 +85,12 @@ Config 配置对象
Adapter 类型 Adapter 类型
### _classmethod_ `register(driver, config)` ### _classmethod_ `register(driver, config, **kwargs)`
* **说明** * **说明**
register 方法会在 driver.register_adapter 时被调用,用于初始化相关配置 `register` 方法会在 `driver.register_adapter` 时被调用,用于初始化相关配置

View File

@ -26,6 +26,9 @@ CQHTTP 配置类
* `secret` / `cqhttp_secret`: CQHTTP HTTP 上报数据签名口令 * `secret` / `cqhttp_secret`: CQHTTP HTTP 上报数据签名口令
* `ws_urls` / `cqhttp_ws_urls`: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典
# NoneBot.adapters.cqhttp.utils 模块 # NoneBot.adapters.cqhttp.utils 模块

View File

@ -153,6 +153,9 @@ Driver 基类。
* `adapter: Type[Bot]`: 适配器 Class * `adapter: Type[Bot]`: 适配器 Class
* `**kwargs`: 其他传递给适配器的参数
### _abstract property_ `type` ### _abstract property_ `type`

View File

@ -61,28 +61,6 @@ sidebarDepth: 0
## _exception_ `DriverException`
基类:`nonebot.exception.NoneBotException`
* **说明**
代表 `Driver` 抛出的异常
## _exception_ `SetupFailed`
基类:`nonebot.exception.DriverException`
* **说明**
`ForwardDriver` 建立连接失败
## _exception_ `PausedException` ## _exception_ `PausedException`
基类:`nonebot.exception.NoneBotException` 基类:`nonebot.exception.NoneBotException`

View File

@ -196,8 +196,25 @@ class Driver(abc.ABC):
class ForwardDriver(Driver): class ForwardDriver(Driver):
@abc.abstractmethod @abc.abstractmethod
def setup(self, adapter: str, self_id: str, def setup_http_polling(self,
request: "HTTPConnection") -> None: adapter: str,
self_id: str,
url: str,
polling_interval: float = 3.,
method: str = "GET",
body: bytes = b"",
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket(self,
adapter: str,
self_id: str,
url: str,
reconnect_interval: float = 3.,
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
raise NotImplementedError raise NotImplementedError

View File

@ -3,8 +3,9 @@
import signal import signal
import asyncio import asyncio
import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Set, List, Union, Callable, Awaitable from typing import Set, List, Dict, Optional, Callable, Awaitable
import aiohttp import aiohttp
from yarl import URL from yarl import URL
@ -13,20 +14,35 @@ from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.exception import SetupFailed from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket
from nonebot.drivers import ForwardDriver, HTTPConnection, HTTPRequest, WebSocket
STARTUP_FUNC = Callable[[], Awaitable[None]] STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket] HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
@dataclass @dataclass
class RequestSetup: class HTTPPollingSetup:
adapter: str adapter: str
self_id: str self_id: str
request: AVAILABLE_REQUEST url: str
method: str
body: bytes
headers: Dict[str, str]
http_version: str
poll_interval: float poll_interval: float
@dataclass
class WebSocketSetup:
adapter: str
self_id: str
url: str
headers: Dict[str, str]
http_version: str
reconnect_interval: float reconnect_interval: float
@ -36,7 +52,11 @@ class Driver(ForwardDriver):
super().__init__(env, config) super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set() self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.requests: List[RequestSetup] = [] self.http_pollings: List[HTTPPollingSetup] = []
self.websockets: List[WebSocketSetup] = []
self.connections: List[asyncio.Task] = []
self.should_exit: bool = False
self.force_exit: bool = False
@property @property
@overrides(ForwardDriver) @overrides(ForwardDriver)
@ -60,54 +80,52 @@ class Driver(ForwardDriver):
return func return func
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup(self, def setup_http_polling(self,
adapter: str, adapter: str,
self_id: str, self_id: str,
request: HTTPConnection, url: str,
poll_interval: float = 3., polling_interval: float = 3.,
reconnect_interval: float = 3.) -> None: method: str = "GET",
if not isinstance(request, (HTTPRequest, WebSocket)): body: bytes = b"",
raise TypeError(f"Request Type {type(request)!r} is not supported!") headers: Dict[str, str] = {},
self.requests.append( http_version: str = "1.1") -> None:
RequestSetup(adapter, self_id, request, poll_interval, self.http_pollings.append(
HTTPPollingSetup(adapter, self_id, url, method, body, headers,
http_version, polling_interval))
@overrides(ForwardDriver)
def setup_websocket(self,
adapter: str,
self_id: str,
url: str,
reconnect_interval: float = 3.,
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
self.websockets.append(
WebSocketSetup(adapter, self_id, url, headers, http_version,
reconnect_interval)) reconnect_interval))
@overrides(ForwardDriver) @overrides(ForwardDriver)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
super().run(*args, **kwargs)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) loop.run_until_complete(self.serve())
for s in signals:
loop.add_signal_handler(
s,
lambda s=s: asyncio.create_task(self.shutdown(loop, signal=s)))
try: async def serve(self):
asyncio.create_task(self.startup()) self.install_signal_handlers()
loop.run_forever() await self.startup()
finally: if self.should_exit:
loop.close() return
await self.main_loop()
await self.shutdown()
async def startup(self): async def startup(self):
setups = [] for setup in self.http_pollings:
loop = asyncio.get_event_loop() self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.requests: for setup in self.websockets:
if isinstance(setup.request, HTTPRequest): self.connections.append(asyncio.create_task(self._ws_loop(setup)))
setups.append(
self._http_setup(setup.adapter, setup.self_id,
setup.request, setup.poll_interval))
else:
setups.append(
self._ws_setup(setup.adapter, setup.self_id, setup.request,
setup.reconnect_interval))
try: logger.info("Application startup completed.")
await asyncio.gather(*setups)
except Exception as e:
logger.opt(
colors=True,
exception=e).error("Application startup failed. Exiting.")
asyncio.create_task(self.shutdown(loop))
return
# run startup # run startup
cors = [startup() for startup in self.startup_funcs] cors = [startup() for startup in self.startup_funcs]
@ -119,11 +137,11 @@ class Driver(ForwardDriver):
"<r><bg #f8bbd0>Error when running startup function. " "<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>") "Ignored!</bg #f8bbd0></r>")
async def shutdown(self, async def main_loop(self):
loop: asyncio.AbstractEventLoop, while not self.should_exit:
signal: signal.Signals = None): await asyncio.sleep(0.1)
# TODO: shutdown
async def shutdown(self):
# run shutdown # run shutdown
cors = [shutdown() for shutdown in self.shutdown_funcs] cors = [shutdown() for shutdown in self.shutdown_funcs]
if cors: if cors:
@ -134,44 +152,89 @@ class Driver(ForwardDriver):
"<r><bg #f8bbd0>Error when running shutdown function. " "<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>") "Ignored!</bg #f8bbd0></r>")
for task in self.connections:
if not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [ tasks = [
t for t in asyncio.all_tasks() if t is not asyncio.current_task() t for t in asyncio.all_tasks() if t is not asyncio.current_task()
] ]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [
t for t in asyncio.all_tasks()
if t is not asyncio.current_task()
]
for task in tasks: for task in tasks:
task.cancel() task.cancel()
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
loop = asyncio.get_event_loop()
loop.stop() loop.stop()
async def _http_setup(self, adapter: str, self_id: str, def install_signal_handlers(self) -> None:
request: HTTPRequest, poll_interval: float): if threading.current_thread() is not threading.main_thread():
BotClass = self._adapters[adapter] # Signals can only be listened to from the main thread.
return
bot = BotClass(self_id, request) loop = asyncio.get_event_loop()
self._bot_connect(bot)
asyncio.create_task(self._http_loop(bot, request, poll_interval))
async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket,
reconnect_interval: float):
BotClass = self._adapters[adapter]
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))
async def _http_loop(self, bot: Bot, request: HTTPRequest,
poll_interval: float):
try: try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
def handle_exit(self, sig, frame):
if self.should_exit:
self.force_exit = True
else:
self.should_exit = True
async def _http_loop(self, setup: HTTPPollingSetup):
url = URL(setup.url)
if not url.is_absolute() or not url.host:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return
host = f"{url.host}:{url.port}" if url.port else url.host
request = HTTPRequest(setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), {
**setup.headers, "host": host
}, setup.method, setup.body)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
headers = request.headers headers = request.headers
url = URL.build(scheme=request.scheme,
host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
timeout = aiohttp.ClientTimeout(30) timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion
if request.http_version == "1.0":
version = aiohttp.HttpVersion10
elif request.http_version == "1.1":
version = aiohttp.HttpVersion11
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Unsupported HTTP Version "
f"{request.http_version}</bg #f8bbd0></r>")
return
try:
async with aiohttp.ClientSession(headers=headers, async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session: timeout=timeout,
while True: version=version) as session:
while not self.should_exit:
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
)
try: try:
async with session.request( async with session.request(
request.method, url, request.method, url,
@ -181,9 +244,10 @@ class Driver(ForwardDriver):
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e: except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"Error occurred while requesting {url}") f"<r><bg #f8bbd0>Error occurred while requesting {url}. "
"Try to reconnect...</bg #f8bbd0></r>")
await asyncio.sleep(poll_interval) await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@ -193,37 +257,111 @@ class Driver(ForwardDriver):
finally: finally:
self._bot_disconnect(bot) self._bot_disconnect(bot)
async def _ws_loop(self, bot: Bot, request: WebSocket, async def _ws_loop(self, setup: WebSocketSetup):
reconnect_interval: float): url = URL(setup.url)
try: if not url.is_absolute() or not url.host:
headers = request.headers logger.opt(colors=True).error(
url = URL.build(scheme=request.scheme, f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
host=request.headers["host"], return
path=request.path, host = f"{url.host}:{url.port}" if url.port else url.host
query_string=request.query_string.decode("latin-1"))
headers = {**setup.headers, "host": host}
timeout = aiohttp.ClientTimeout(30) timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion
if setup.http_version == "1.0":
version = aiohttp.HttpVersion10
elif setup.http_version == "1.1":
version = aiohttp.HttpVersion11
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Unsupported HTTP Version "
f"{setup.http_version}</bg #f8bbd0></r>")
return
bot: Optional[Bot] = None
try:
async with aiohttp.ClientSession(headers=headers, async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session: timeout=timeout,
version=version) as session:
while True: while True:
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
async with session.ws_connect(url) as ws: async with session.ws_connect(url) as ws:
async for msg in ws: request = WebSocket(
setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), {
**setup.headers, "host": host
}, ws)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.should_exit:
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.text: if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task( asyncio.create_task(
bot.handle_message(msg.data.encode())) bot.handle_message(msg.data.encode()))
elif msg.type == aiohttp.WSMsgType.binary: elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task(bot.handle_message( asyncio.create_task(
msg.data)) bot.handle_message(msg.data))
elif msg.type == aiohttp.WSMsgType.error: elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error( logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. " "<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg></r>") "Try to reconnect...</bg #f8bbd0></r>")
break break
asyncio.sleep(reconnect_interval) else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>")
break
except aiohttp.WSServerHandshakeError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}"
"Try to reconnect...</bg #f8bbd0></r>")
finally:
if bot:
self._bot_disconnect(bot)
bot = None
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
"Unexpected exception occurred while websocket loop") "Unexpected exception occurred while websocket loop")
finally:
self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket):
websocket: aiohttp.ClientWebSocketResponse = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = 1000):
await self.websocket.close(code=code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.receive_str()
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send_str(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send_bytes(data)

View File

@ -60,24 +60,6 @@ class ParserExit(NoneBotException):
return self.__repr__() return self.__repr__()
class DriverException(NoneBotException):
"""
:说明:
代表 ``Driver`` 抛出的异常
"""
pass
class SetupFailed(DriverException):
"""
:说明:
``ForwardDriver`` 建立连接失败
"""
pass
class PausedException(NoneBotException): class PausedException(NoneBotException):
""" """
:说明: :说明:

View File

@ -3,7 +3,6 @@ import sys
import hmac import hmac
import json import json
import asyncio import asyncio
from urllib.parse import urlsplit
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx import httpx
@ -246,22 +245,18 @@ class Bot(BaseBot):
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls: elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items(): for self_id, url in cls.cqhttp_config.ws_urls.items():
try: try:
url_info = urlsplit(url)
headers = { headers = {
"authorization": "authorization":
f"Bearer {cls.cqhttp_config.access_token}", f"Bearer {cls.cqhttp_config.access_token}"
"host":
url_info.netloc if not url_info.port else
f"{url_info.netloc}:{url_info.port}",
} }
driver.setup( driver.setup_websocket("cqhttp",
"cqhttp", self_id, self_id,
WebSocket("1.1", url_info.scheme, url_info.path, url,
url_info.query.encode("latin-1"), headers)) headers=headers)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} " f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "
"in cqhttp forward websocket</bg></r>") "in cqhttp forward websocket</bg #f8bbd0></r>")
@classmethod @classmethod
@overrides(BaseBot) @overrides(BaseBot)

View File

@ -1,4 +1,4 @@
DRIVER=nonebot.drivers.fastapi DRIVER=nonebot.drivers.aiohttp:Driver
HOST=0.0.0.0 HOST=0.0.0.0
PORT=2333 PORT=2333
DEBUG=true DEBUG=true
@ -13,6 +13,8 @@ COMMAND_SEP=["/", "."]
CUSTOM_CONFIG1=config in env CUSTOM_CONFIG1=config in env
CUSTOM_CONFIG3= CUSTOM_CONFIG3=
CQHTTP_WS_URLS={"123123123": "ws://127.0.0.1:6700/"}
MIRAI_AUTH_KEY=12345678 MIRAI_AUTH_KEY=12345678
MIRAI_HOST=127.0.0.1 MIRAI_HOST=127.0.0.1
MIRAI_PORT=8080 MIRAI_PORT=8080

View File

@ -18,7 +18,7 @@ logger.add("error.log",
format=default_format) format=default_format)
nonebot.init(custom_config2="config on init") nonebot.init(custom_config2="config on init")
app = nonebot.get_asgi() # app = nonebot.get_asgi()
driver = nonebot.get_driver() driver = nonebot.get_driver()
driver.register_adapter("cqhttp", Bot) driver.register_adapter("cqhttp", Bot)
driver.register_adapter("ding", DingBot) driver.register_adapter("ding", DingBot)
@ -37,4 +37,5 @@ config.custom_config3 = config.custom_config1
config.custom_config4 = "New custom config" config.custom_config4 = "New custom config"
if __name__ == "__main__": if __name__ == "__main__":
nonebot.run(app="__mp_main__:app") # nonebot.run(app="__mp_main__:app")
nonebot.run()