add cqhttp forward support

This commit is contained in:
yanyongyu 2021-07-20 15:35:56 +08:00
parent 04b3fda40c
commit ecc613f6c5
10 changed files with 278 additions and 159 deletions

View File

@ -85,12 +85,12 @@ Config 配置对象
Adapter 类型
### _classmethod_ `register(driver, config)`
### _classmethod_ `register(driver, config, **kwargs)`
* **说明**
register 方法会在 driver.register_adapter 时被调用,用于初始化相关配置
`register` 方法会在 `driver.register_adapter` 时被调用,用于初始化相关配置

View File

@ -26,6 +26,9 @@ CQHTTP 配置类
* `secret` / `cqhttp_secret`: CQHTTP HTTP 上报数据签名口令
* `ws_urls` / `cqhttp_ws_urls`: CQHTTP 正向 Websocket 连接 Bot ID、目标 URL 字典
# NoneBot.adapters.cqhttp.utils 模块

View File

@ -153,6 +153,9 @@ Driver 基类。
* `adapter: Type[Bot]`: 适配器 Class
* `**kwargs`: 其他传递给适配器的参数
### _abstract property_ `type`

View File

@ -61,28 +61,6 @@ sidebarDepth: 0
## _exception_ `DriverException`
基类:`nonebot.exception.NoneBotException`
* **说明**
代表 `Driver` 抛出的异常
## _exception_ `SetupFailed`
基类:`nonebot.exception.DriverException`
* **说明**
`ForwardDriver` 建立连接失败
## _exception_ `PausedException`
基类:`nonebot.exception.NoneBotException`

View File

@ -196,8 +196,25 @@ class Driver(abc.ABC):
class ForwardDriver(Driver):
@abc.abstractmethod
def setup(self, adapter: str, self_id: str,
request: "HTTPConnection") -> None:
def setup_http_polling(self,
adapter: str,
self_id: str,
url: str,
polling_interval: float = 3.,
method: str = "GET",
body: bytes = b"",
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
raise NotImplementedError
@abc.abstractmethod
def setup_websocket(self,
adapter: str,
self_id: str,
url: str,
reconnect_interval: float = 3.,
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
raise NotImplementedError

View File

@ -3,8 +3,9 @@
import signal
import asyncio
import threading
from dataclasses import dataclass
from typing import Set, List, Union, Callable, Awaitable
from typing import Set, List, Dict, Optional, Callable, Awaitable
import aiohttp
from yarl import URL
@ -13,20 +14,35 @@ from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.config import Env, Config
from nonebot.exception import SetupFailed
from nonebot.drivers import ForwardDriver, HTTPConnection, HTTPRequest, WebSocket
from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
AVAILABLE_REQUEST = Union[HTTPRequest, WebSocket]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)
@dataclass
class RequestSetup:
class HTTPPollingSetup:
adapter: str
self_id: str
request: AVAILABLE_REQUEST
url: str
method: str
body: bytes
headers: Dict[str, str]
http_version: str
poll_interval: float
@dataclass
class WebSocketSetup:
adapter: str
self_id: str
url: str
headers: Dict[str, str]
http_version: str
reconnect_interval: float
@ -36,7 +52,11 @@ class Driver(ForwardDriver):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.requests: List[RequestSetup] = []
self.http_pollings: List[HTTPPollingSetup] = []
self.websockets: List[WebSocketSetup] = []
self.connections: List[asyncio.Task] = []
self.should_exit: bool = False
self.force_exit: bool = False
@property
@overrides(ForwardDriver)
@ -60,54 +80,52 @@ class Driver(ForwardDriver):
return func
@overrides(ForwardDriver)
def setup(self,
def setup_http_polling(self,
adapter: str,
self_id: str,
request: HTTPConnection,
poll_interval: float = 3.,
reconnect_interval: float = 3.) -> None:
if not isinstance(request, (HTTPRequest, WebSocket)):
raise TypeError(f"Request Type {type(request)!r} is not supported!")
self.requests.append(
RequestSetup(adapter, self_id, request, poll_interval,
url: str,
polling_interval: float = 3.,
method: str = "GET",
body: bytes = b"",
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
self.http_pollings.append(
HTTPPollingSetup(adapter, self_id, url, method, body, headers,
http_version, polling_interval))
@overrides(ForwardDriver)
def setup_websocket(self,
adapter: str,
self_id: str,
url: str,
reconnect_interval: float = 3.,
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None:
self.websockets.append(
WebSocketSetup(adapter, self_id, url, headers, http_version,
reconnect_interval))
@overrides(ForwardDriver)
def run(self, *args, **kwargs):
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
for s in signals:
loop.add_signal_handler(
s,
lambda s=s: asyncio.create_task(self.shutdown(loop, signal=s)))
loop.run_until_complete(self.serve())
try:
asyncio.create_task(self.startup())
loop.run_forever()
finally:
loop.close()
async def serve(self):
self.install_signal_handlers()
await self.startup()
if self.should_exit:
return
await self.main_loop()
await self.shutdown()
async def startup(self):
setups = []
loop = asyncio.get_event_loop()
for setup in self.requests:
if isinstance(setup.request, HTTPRequest):
setups.append(
self._http_setup(setup.adapter, setup.self_id,
setup.request, setup.poll_interval))
else:
setups.append(
self._ws_setup(setup.adapter, setup.self_id, setup.request,
setup.reconnect_interval))
for setup in self.http_pollings:
self.connections.append(asyncio.create_task(self._http_loop(setup)))
for setup in self.websockets:
self.connections.append(asyncio.create_task(self._ws_loop(setup)))
try:
await asyncio.gather(*setups)
except Exception as e:
logger.opt(
colors=True,
exception=e).error("Application startup failed. Exiting.")
asyncio.create_task(self.shutdown(loop))
return
logger.info("Application startup completed.")
# run startup
cors = [startup() for startup in self.startup_funcs]
@ -119,11 +137,11 @@ class Driver(ForwardDriver):
"<r><bg #f8bbd0>Error when running startup function. "
"Ignored!</bg #f8bbd0></r>")
async def shutdown(self,
loop: asyncio.AbstractEventLoop,
signal: signal.Signals = None):
# TODO: shutdown
async def main_loop(self):
while not self.should_exit:
await asyncio.sleep(0.1)
async def shutdown(self):
# run shutdown
cors = [shutdown() for shutdown in self.shutdown_funcs]
if cors:
@ -134,44 +152,89 @@ class Driver(ForwardDriver):
"<r><bg #f8bbd0>Error when running shutdown function. "
"Ignored!</bg #f8bbd0></r>")
for task in self.connections:
if not task.done():
task.cancel()
await asyncio.sleep(0.1)
tasks = [
t for t in asyncio.all_tasks() if t is not asyncio.current_task()
]
if tasks and not self.force_exit:
logger.info("Waiting for tasks to finish. (CTRL+C to force quit)")
while tasks and not self.force_exit:
await asyncio.sleep(0.1)
tasks = [
t for t in asyncio.all_tasks()
if t is not asyncio.current_task()
]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
loop = asyncio.get_event_loop()
loop.stop()
async def _http_setup(self, adapter: str, self_id: str,
request: HTTPRequest, poll_interval: float):
BotClass = self._adapters[adapter]
def install_signal_handlers(self) -> None:
if threading.current_thread() is not threading.main_thread():
# Signals can only be listened to from the main thread.
return
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._http_loop(bot, request, poll_interval))
loop = asyncio.get_event_loop()
async def _ws_setup(self, adapter: str, self_id: str, request: WebSocket,
reconnect_interval: float):
BotClass = self._adapters[adapter]
bot = BotClass(self_id, request)
self._bot_connect(bot)
asyncio.create_task(self._ws_loop(bot, request, reconnect_interval))
async def _http_loop(self, bot: Bot, request: HTTPRequest,
poll_interval: float):
try:
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)
def handle_exit(self, sig, frame):
if self.should_exit:
self.force_exit = True
else:
self.should_exit = True
async def _http_loop(self, setup: HTTPPollingSetup):
url = URL(setup.url)
if not url.is_absolute() or not url.host:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return
host = f"{url.host}:{url.port}" if url.port else url.host
request = HTTPRequest(setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), {
**setup.headers, "host": host
}, setup.method, setup.body)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
headers = request.headers
url = URL.build(scheme=request.scheme,
host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion
if request.http_version == "1.0":
version = aiohttp.HttpVersion10
elif request.http_version == "1.1":
version = aiohttp.HttpVersion11
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Unsupported HTTP Version "
f"{request.http_version}</bg #f8bbd0></r>")
return
try:
async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session:
while True:
timeout=timeout,
version=version) as session:
while not self.should_exit:
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
)
try:
async with session.request(
request.method, url,
@ -181,9 +244,10 @@ class Driver(ForwardDriver):
asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error(
f"Error occurred while requesting {url}")
f"<r><bg #f8bbd0>Error occurred while requesting {url}. "
"Try to reconnect...</bg #f8bbd0></r>")
await asyncio.sleep(poll_interval)
await asyncio.sleep(setup.poll_interval)
except asyncio.CancelledError:
pass
@ -193,37 +257,111 @@ class Driver(ForwardDriver):
finally:
self._bot_disconnect(bot)
async def _ws_loop(self, bot: Bot, request: WebSocket,
reconnect_interval: float):
try:
headers = request.headers
url = URL.build(scheme=request.scheme,
host=request.headers["host"],
path=request.path,
query_string=request.query_string.decode("latin-1"))
async def _ws_loop(self, setup: WebSocketSetup):
url = URL(setup.url)
if not url.is_absolute() or not url.host:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return
host = f"{url.host}:{url.port}" if url.port else url.host
headers = {**setup.headers, "host": host}
timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion
if setup.http_version == "1.0":
version = aiohttp.HttpVersion10
elif setup.http_version == "1.1":
version = aiohttp.HttpVersion11
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Unsupported HTTP Version "
f"{setup.http_version}</bg #f8bbd0></r>")
return
bot: Optional[Bot] = None
try:
async with aiohttp.ClientSession(headers=headers,
timeout=timeout) as session:
timeout=timeout,
version=version) as session:
while True:
logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
)
try:
async with session.ws_connect(url) as ws:
async for msg in ws:
request = WebSocket(
setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), {
**setup.headers, "host": host
}, ws)
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
while not self.should_exit:
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.text:
asyncio.create_task(
bot.handle_message(msg.data.encode()))
elif msg.type == aiohttp.WSMsgType.binary:
asyncio.create_task(bot.handle_message(
msg.data))
asyncio.create_task(
bot.handle_message(msg.data))
elif msg.type == aiohttp.WSMsgType.error:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Error while handling websocket frame. "
"Try to reconnect...</bg></r>")
"Try to reconnect...</bg #f8bbd0></r>")
break
asyncio.sleep(reconnect_interval)
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>WebSocket connection closed by peer. "
"Try to reconnect...</bg #f8bbd0></r>")
break
except aiohttp.WSServerHandshakeError as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error while connecting to {url}"
"Try to reconnect...</bg #f8bbd0></r>")
finally:
if bot:
self._bot_disconnect(bot)
bot = None
await asyncio.sleep(setup.reconnect_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.opt(colors=True, exception=e).error(
"Unexpected exception occurred while websocket loop")
finally:
self._bot_disconnect(bot)
@dataclass
class WebSocket(BaseWebSocket):
websocket: aiohttp.ClientWebSocketResponse = None # type: ignore
@property
@overrides(BaseWebSocket)
def closed(self):
return self.websocket.closed
@overrides(BaseWebSocket)
async def accept(self):
raise NotImplementedError
@overrides(BaseWebSocket)
async def close(self, code: int = 1000):
await self.websocket.close(code=code)
@overrides(BaseWebSocket)
async def receive(self) -> str:
return await self.websocket.receive_str()
@overrides(BaseWebSocket)
async def receive_bytes(self) -> bytes:
return await self.websocket.receive_bytes()
@overrides(BaseWebSocket)
async def send(self, data: str) -> None:
await self.websocket.send_str(data)
@overrides(BaseWebSocket)
async def send_bytes(self, data: bytes) -> None:
await self.websocket.send_bytes(data)

View File

@ -60,24 +60,6 @@ class ParserExit(NoneBotException):
return self.__repr__()
class DriverException(NoneBotException):
"""
:说明:
代表 ``Driver`` 抛出的异常
"""
pass
class SetupFailed(DriverException):
"""
:说明:
``ForwardDriver`` 建立连接失败
"""
pass
class PausedException(NoneBotException):
"""
:说明:

View File

@ -3,7 +3,6 @@ import sys
import hmac
import json
import asyncio
from urllib.parse import urlsplit
from typing import Any, Dict, Tuple, Union, Optional, TYPE_CHECKING
import httpx
@ -246,22 +245,18 @@ class Bot(BaseBot):
elif isinstance(driver, ForwardDriver) and cls.cqhttp_config.ws_urls:
for self_id, url in cls.cqhttp_config.ws_urls.items():
try:
url_info = urlsplit(url)
headers = {
"authorization":
f"Bearer {cls.cqhttp_config.access_token}",
"host":
url_info.netloc if not url_info.port else
f"{url_info.netloc}:{url_info.port}",
f"Bearer {cls.cqhttp_config.access_token}"
}
driver.setup(
"cqhttp", self_id,
WebSocket("1.1", url_info.scheme, url_info.path,
url_info.query.encode("latin-1"), headers))
driver.setup_websocket("cqhttp",
self_id,
url,
headers=headers)
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "
"in cqhttp forward websocket</bg></r>")
"in cqhttp forward websocket</bg #f8bbd0></r>")
@classmethod
@overrides(BaseBot)

View File

@ -1,4 +1,4 @@
DRIVER=nonebot.drivers.fastapi
DRIVER=nonebot.drivers.aiohttp:Driver
HOST=0.0.0.0
PORT=2333
DEBUG=true
@ -13,6 +13,8 @@ COMMAND_SEP=["/", "."]
CUSTOM_CONFIG1=config in env
CUSTOM_CONFIG3=
CQHTTP_WS_URLS={"123123123": "ws://127.0.0.1:6700/"}
MIRAI_AUTH_KEY=12345678
MIRAI_HOST=127.0.0.1
MIRAI_PORT=8080

View File

@ -18,7 +18,7 @@ logger.add("error.log",
format=default_format)
nonebot.init(custom_config2="config on init")
app = nonebot.get_asgi()
# app = nonebot.get_asgi()
driver = nonebot.get_driver()
driver.register_adapter("cqhttp", Bot)
driver.register_adapter("ding", DingBot)
@ -37,4 +37,5 @@ config.custom_config3 = config.custom_config1
config.custom_config4 = "New custom config"
if __name__ == "__main__":
nonebot.run(app="__mp_main__:app")
# nonebot.run(app="__mp_main__:app")
nonebot.run()