💥 change forward setup api

This commit is contained in:
yanyongyu 2021-07-31 12:24:11 +08:00
parent f48c61c2e0
commit cda1ad093f
11 changed files with 599 additions and 245 deletions

View File

@ -219,6 +219,10 @@ module.exports = (context) => ({
title: "nonebot.drivers.quart 模块", title: "nonebot.drivers.quart 模块",
path: "drivers/quart", path: "drivers/quart",
}, },
{
title: "nonebot.drivers.aiohttp 模块",
path: "drivers/aiohttp",
},
{ {
title: "nonebot.adapters 模块", title: "nonebot.adapters 模块",
path: "adapters/", path: "adapters/",

View File

@ -49,6 +49,9 @@
* [nonebot.drivers.quart](drivers/quart.html) * [nonebot.drivers.quart](drivers/quart.html)
* [nonebot.drivers.aiohttp](drivers/aiohttp.html)
* [nonebot.adapters](adapters/) * [nonebot.adapters](adapters/)

View File

@ -238,6 +238,45 @@ Driver 基类。
在 WebSocket 连接断开后,调用该函数来注销 bot 对象 在 WebSocket 连接断开后,调用该函数来注销 bot 对象
## _class_ `ForwardDriver`
基类:`nonebot.drivers.Driver`
Forward Driver 基类。将客户端框架封装,以满足适配器使用。
### _abstract_ `setup_http_polling(setup)`
* **说明**
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
* **参数**
* `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`
### _abstract_ `setup_websocket(setup)`
* **说明**
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
* **参数**
* `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`
## _class_ `ReverseDriver` ## _class_ `ReverseDriver`
基类:`nonebot.drivers.Driver` 基类:`nonebot.drivers.Driver`
@ -413,3 +452,78 @@ Always `websocket`
### _abstract async_ `send_bytes(data)` ### _abstract async_ `send_bytes(data)`
发送一条 WebSocket binary 信息 发送一条 WebSocket binary 信息
## _class_ `HTTPPollingSetup`
基类:`object`
### `adapter`
协议适配器名称
### `self_id`
机器人 ID
### `url`
URL
### `method`
HTTP method
### `body`
HTTP body
### `headers`
HTTP headers
### `http_version`
HTTP version
### `poll_interval`
HTTP 轮询间隔
## _class_ `WebSocketSetup`
基类:`object`
### `adapter`
协议适配器名称
### `self_id`
机器人 ID
### `url`
URL
### `headers`
HTTP headers
### `reconnect_interval`
WebSocket 重连间隔

101
docs/api/drivers/aiohttp.md Normal file
View File

@ -0,0 +1,101 @@
---
contentSidebar: true
sidebarDepth: 0
---
# NoneBot.drivers.aiohttp 模块
## AIOHTTP 驱动适配
本驱动仅支持客户端连接
## _class_ `Driver`
基类:[`nonebot.drivers.ForwardDriver`](README.md#nonebot.drivers.ForwardDriver)
AIOHTTP 驱动框架
### _property_ `type`
驱动名称: `aiohttp`
### _property_ `logger`
aiohttp driver 使用的 logger
### `on_startup(func)`
* **说明**
注册一个启动时执行的函数
* **参数**
* `func: Callable[[], Awaitable[None]]`
### `on_shutdown(func)`
* **说明**
注册一个停止时执行的函数
* **参数**
* `func: Callable[[], Awaitable[None]]`
### `setup_http_polling(setup)`
* **说明**
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
* **参数**
* `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`
### `setup_websocket(setup)`
* **说明**
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
* **参数**
* `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`
### `run(*args, **kwargs)`
启动 aiohttp driver
## _class_ `WebSocket`
基类:[`nonebot.drivers.WebSocket`](README.md#nonebot.drivers.WebSocket)

View File

@ -7,19 +7,11 @@ sidebarDepth: 0
## FastAPI 驱动适配 ## FastAPI 驱动适配
本驱动同时支持服务端以及客户端连接
后端使用方法请参考: [FastAPI 文档](https://fastapi.tiangolo.com/) 后端使用方法请参考: [FastAPI 文档](https://fastapi.tiangolo.com/)
## _class_ `HTTPPollingSetup`
基类:`object`
## _class_ `WebSocketSetup`
基类:`object`
## _class_ `Config` ## _class_ `Config`
基类:`pydantic.env_settings.BaseSettings` 基类:`pydantic.env_settings.BaseSettings`
@ -89,7 +81,7 @@ FastAPI 驱动框架设置,详情参考 FastAPI 文档
## _class_ `Driver` ## _class_ `Driver`
基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver), `nonebot.drivers.ForwardDriver` 基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver), [`nonebot.drivers.ForwardDriver`](README.md#nonebot.drivers.ForwardDriver)
FastAPI 驱动框架 FastAPI 驱动框架
@ -140,6 +132,38 @@ fastapi 使用的 logger
参考文档: [Events](https://fastapi.tiangolo.com/advanced/events/#startup-event) 参考文档: [Events](https://fastapi.tiangolo.com/advanced/events/#startup-event)
### `setup_http_polling(setup)`
* **说明**
注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
* **参数**
* `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`
### `setup_websocket(setup)`
* **说明**
注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
* **参数**
* `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`
### `run(host=None, port=None, *, app=None, **kwargs)` ### `run(host=None, port=None, *, app=None, **kwargs)`
使用 `uvicorn` 启动 FastAPI 使用 `uvicorn` 启动 FastAPI

View File

@ -17,6 +17,7 @@ NoneBot Api Reference
- `nonebot.drivers <drivers/>`_ - `nonebot.drivers <drivers/>`_
- `nonebot.drivers.fastapi <drivers/fastapi.html>`_ - `nonebot.drivers.fastapi <drivers/fastapi.html>`_
- `nonebot.drivers.quart <drivers/quart.html>`_ - `nonebot.drivers.quart <drivers/quart.html>`_
- `nonebot.drivers.aiohttp <drivers/aiohttp.html>`_
- `nonebot.adapters <adapters/>`_ - `nonebot.adapters <adapters/>`_
- `nonebot.adapters.cqhttp <adapters/cqhttp.html>`_ - `nonebot.adapters.cqhttp <adapters/cqhttp.html>`_
- `nonebot.adapters.ding <adapters/ding.html>`_ - `nonebot.adapters.ding <adapters/ding.html>`_

View File

@ -0,0 +1,12 @@
\-\-\-
contentSidebar: true
sidebarDepth: 0
\-\-\-
NoneBot.drivers.aiohttp 模块
=============================
.. automodule:: nonebot.drivers.aiohttp
:members:
:private-members:
:show-inheritance:

View File

@ -8,7 +8,7 @@
import abc import abc
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Set, Dict, Type, Optional, Callable, TYPE_CHECKING from typing import Any, Set, Dict, Type, Union, Optional, Callable, Awaitable, TYPE_CHECKING
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
@ -193,27 +193,40 @@ class Driver(abc.ABC):
class ForwardDriver(Driver): class ForwardDriver(Driver):
"""
Forward Driver 基类将客户端框架封装以满足适配器使用
"""
@abc.abstractmethod @abc.abstractmethod
def setup_http_polling(self, def setup_http_polling(
adapter: str, self, setup: Union["HTTPPollingSetup",
self_id: str, Callable[[], Awaitable["HTTPPollingSetup"]]]
url: str, ) -> None:
polling_interval: float = 3., """
method: str = "GET", :说明:
body: bytes = b"",
headers: Dict[str, str] = {}, 注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
http_version: str = "1.1") -> None:
:参数:
* ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
"""
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def setup_websocket(self, def setup_websocket(
adapter: str, self, setup: Union["WebSocketSetup",
self_id: str, Callable[[], Awaitable["WebSocketSetup"]]]
url: str, ) -> None:
reconnect_interval: float = 3., """
headers: Dict[str, str] = {}, :说明:
http_version: str = "1.1") -> None:
注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
:参数:
* ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
"""
raise NotImplementedError raise NotImplementedError
@ -369,3 +382,37 @@ class WebSocket(HTTPConnection, abc.ABC):
async def send_bytes(self, data: bytes): async def send_bytes(self, data: bytes):
"""发送一条 WebSocket binary 信息""" """发送一条 WebSocket binary 信息"""
raise NotImplementedError raise NotImplementedError
@dataclass
class HTTPPollingSetup:
adapter: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
method: str
"""HTTP method"""
body: bytes
"""HTTP body"""
headers: Dict[str, str]
"""HTTP headers"""
http_version: str
"""HTTP version"""
poll_interval: float
"""HTTP 轮询间隔"""
@dataclass
class WebSocketSetup:
adapter: str
"""协议适配器名称"""
self_id: str
"""机器人 ID"""
url: str
"""URL"""
headers: Dict[str, str] = field(default_factory=dict)
"""HTTP headers"""
reconnect_interval: float = 3.
"""WebSocket 重连间隔"""

View File

@ -1,11 +1,15 @@
""" """
AIOHTTP 驱动适配
================
本驱动仅支持客户端连接
""" """
import signal import signal
import asyncio import asyncio
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Set, List, Dict, Optional, Callable, Awaitable from typing import Set, List, cast, Union, Optional, Callable, Awaitable
import aiohttp import aiohttp
from yarl import URL from yarl import URL
@ -14,46 +18,31 @@ from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket from nonebot.drivers import (ForwardDriver, HTTPPollingSetup, WebSocketSetup,
HTTPRequest, WebSocket as BaseWebSocket)
STARTUP_FUNC = Callable[[], Awaitable[None]] STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]] SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
Callable[[], Awaitable[HTTPPollingSetup]]]
WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
HANDLED_SIGNALS = ( HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`. signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
) )
@dataclass
class HTTPPollingSetup:
adapter: str
self_id: str
url: str
method: str
body: bytes
headers: Dict[str, str]
http_version: str
poll_interval: float
@dataclass
class WebSocketSetup:
adapter: str
self_id: str
url: str
headers: Dict[str, str]
http_version: str
reconnect_interval: float
class Driver(ForwardDriver): class Driver(ForwardDriver):
"""
AIOHTTP 驱动框架
"""
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
super().__init__(env, config) super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set() self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set() self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
self.http_pollings: List[HTTPPollingSetup] = [] self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WebSocketSetup] = [] self.websockets: List[WEBSOCKET_SETUP] = []
self.connections: List[asyncio.Task] = [] self.connections: List[asyncio.Task] = []
self.should_exit: asyncio.Event = asyncio.Event() self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False self.force_exit: bool = False
@ -67,46 +56,66 @@ class Driver(ForwardDriver):
@property @property
@overrides(ForwardDriver) @overrides(ForwardDriver)
def logger(self): def logger(self):
"""aiohttp driver 使用的 logger"""
return logger return logger
@overrides(ForwardDriver) @overrides(ForwardDriver)
def on_startup(self, func: Callable) -> Callable: def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
"""
:说明:
注册一个启动时执行的函数
:参数:
* ``func: Callable[[], Awaitable[None]]``
"""
self.startup_funcs.add(func) self.startup_funcs.add(func)
return func return func
@overrides(ForwardDriver) @overrides(ForwardDriver)
def on_shutdown(self, func: Callable) -> Callable: def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
"""
:说明:
注册一个停止时执行的函数
:参数:
* ``func: Callable[[], Awaitable[None]]``
"""
self.shutdown_funcs.add(func) self.shutdown_funcs.add(func)
return func return func
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup_http_polling(self, def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
adapter: str, """
self_id: str, :说明:
url: str,
polling_interval: float = 3., 注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
method: str = "GET",
body: bytes = b"", :参数:
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None: * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
self.http_pollings.append( """
HTTPPollingSetup(adapter, self_id, url, method, body, headers, self.http_pollings.append(setup)
http_version, polling_interval))
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup_websocket(self, def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
adapter: str, """
self_id: str, :说明:
url: str,
reconnect_interval: float = 3., 注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None: :参数:
self.websockets.append(
WebSocketSetup(adapter, self_id, url, headers, http_version, * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
reconnect_interval)) """
self.websockets.append(setup)
@overrides(ForwardDriver) @overrides(ForwardDriver)
def run(self, *args, **kwargs): def run(self, *args, **kwargs):
"""启动 aiohttp driver"""
super().run(*args, **kwargs) super().run(*args, **kwargs)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve()) loop.run_until_complete(self.serve())
@ -197,25 +206,55 @@ class Driver(ForwardDriver):
else: else:
self.should_exit.set() self.should_exit.set()
async def _http_loop(self, setup: HTTPPollingSetup): async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = URL(setup.url) url = URL(setup.url)
if not url.is_absolute() or not url.host: if not url.is_absolute() or not url.host:
logger.opt(colors=True).error( logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>") f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return return
host = f"{url.host}:{url.port}" if url.port else url.host host = f"{url.host}:{url.port}" if url.port else url.host
request = HTTPRequest(setup.http_version, url.scheme, url.path, return HTTPRequest(setup.http_version, url.scheme, url.path,
url.raw_query_string.encode("latin-1"), { url.raw_query_string.encode("latin-1"), {
**setup.headers, "host": host **setup.headers, "host": host
}, setup.method, setup.body) }, setup.method, setup.body)
BotClass = self._adapters[setup.adapter] bot: Optional[Bot] = None
bot = BotClass(setup.self_id, request) request: Optional[HTTPRequest] = None
self._bot_connect(bot) setup_: Optional[HTTPPollingSetup] = None
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Start http polling for <y>{setup.adapter.upper()} " f"Start http polling for <y>{setup.adapter.upper()} "
f"Bot {setup.self_id}</y>") f"Bot {setup.self_id}</y>")
try:
async with aiohttp.ClientSession() as session:
while not self.should_exit.is_set():
if not bot:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
request = await _build_request(setup_)
if not request:
return
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
elif callable(setup):
setup_ = await setup()
request = await _build_request(setup_)
if not request:
await asyncio.sleep(setup_.poll_interval)
continue
bot.request = request
request = cast(HTTPRequest, request)
setup_ = cast(HTTPPollingSetup, setup_)
headers = request.headers headers = request.headers
timeout = aiohttp.ClientTimeout(30) timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion version: aiohttp.HttpVersion
@ -229,27 +268,26 @@ class Driver(ForwardDriver):
f"{request.http_version}</bg #f8bbd0></r>") f"{request.http_version}</bg #f8bbd0></r>")
return return
try:
async with aiohttp.ClientSession(headers=headers,
timeout=timeout,
version=version) as session:
while not self.should_exit.is_set():
logger.debug( logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
) )
try: try:
async with session.request( async with session.request(request.method,
request.method, url, setup_.url,
data=request.body) as response: data=request.body,
headers=headers,
timeout=timeout,
version=version) as response:
response.raise_for_status() response.raise_for_status()
data = await response.read() data = await response.read()
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e: except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {url}. " f"<r><bg #f8bbd0>Error occurred while requesting {setup_.url}. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>")
await asyncio.sleep(setup.poll_interval) await asyncio.sleep(setup_.poll_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@ -258,50 +296,48 @@ class Driver(ForwardDriver):
"<r><bg #f8bbd0>Unexpected exception occurred " "<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>") "while http polling</bg #f8bbd0></r>")
finally: finally:
if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
async def _ws_loop(self, setup: WebSocketSetup): async def _ws_loop(self, setup: WEBSOCKET_SETUP):
url = URL(setup.url) bot: Optional[Bot] = None
try:
async with aiohttp.ClientSession() as session:
while True:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
url = URL(setup_.url)
if not url.is_absolute() or not url.host: if not url.is_absolute() or not url.host:
logger.opt(colors=True).error( logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>") f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>"
return )
await asyncio.sleep(setup_.reconnect_interval)
continue
host = f"{url.host}:{url.port}" if url.port else url.host host = f"{url.host}:{url.port}" if url.port else url.host
headers = {**setup_.headers, "host": host}
headers = {**setup.headers, "host": host}
timeout = aiohttp.ClientTimeout(30)
version: aiohttp.HttpVersion
if setup.http_version == "1.0":
version = aiohttp.HttpVersion10
elif setup.http_version == "1.1":
version = aiohttp.HttpVersion11
else:
logger.opt(colors=True).error(
"<r><bg #f8bbd0>Unsupported HTTP Version "
f"{setup.http_version}</bg #f8bbd0></r>")
return
bot: Optional[Bot] = None
try:
async with aiohttp.ClientSession(headers=headers,
timeout=timeout,
version=version) as session:
while True:
logger.debug( logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
) )
try: try:
async with session.ws_connect(url) as ws: async with session.ws_connect(url,
headers=headers,
timeout=30.) as ws:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection to <y>{setup.adapter.upper()} " f"WebSocket Connection to <y>{setup_.adapter.upper()} "
f"Bot {setup.self_id}</y> succeeded!") f"Bot {setup_.self_id}</y> succeeded!")
request = WebSocket( request = WebSocket(
setup.http_version, url.scheme, url.path, "1.1", url.scheme, url.path,
url.raw_query_string.encode("latin-1"), headers, url.raw_query_string.encode("latin-1"), headers,
ws) ws)
BotClass = self._adapters[setup.adapter] BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup.self_id, request) bot = BotClass(setup_.self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
while not self.should_exit.is_set(): while not self.should_exit.is_set():
msg = await ws.receive() msg = await ws.receive()
@ -330,7 +366,7 @@ class Driver(ForwardDriver):
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
bot = None bot = None
await asyncio.sleep(setup.reconnect_interval) await asyncio.sleep(setup_.reconnect_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass

View File

@ -2,6 +2,8 @@
FastAPI 驱动适配 FastAPI 驱动适配
================ ================
本驱动同时支持服务端以及客户端连接
后端使用方法请参考: `FastAPI 文档`_ 后端使用方法请参考: `FastAPI 文档`_
.. _FastAPI 文档: .. _FastAPI 文档:
@ -11,7 +13,7 @@ FastAPI 驱动适配
import asyncio import asyncio
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Dict, Union, Optional, Callable from typing import List, cast, Union, Optional, Callable, Awaitable
import httpx import httpx
import uvicorn import uvicorn
@ -27,30 +29,13 @@ from nonebot.log import logger
from nonebot.adapters import Bot from nonebot.adapters import Bot
from nonebot.typing import overrides from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig from nonebot.config import Env, Config as NoneBotConfig
from nonebot.drivers import ReverseDriver, ForwardDriver from nonebot.drivers import (ReverseDriver, ForwardDriver, HTTPPollingSetup,
from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket WebSocketSetup, HTTPRequest, WebSocket as
BaseWebSocket)
HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
@dataclass Callable[[], Awaitable[HTTPPollingSetup]]]
class HTTPPollingSetup: WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
adapter: str
self_id: str
url: str
method: str
body: bytes
headers: Dict[str, str]
http_version: str
poll_interval: float
@dataclass
class WebSocketSetup:
adapter: str
self_id: str
url: str
headers: Dict[str, str]
http_version: str
reconnect_interval: float
class Config(BaseSettings): class Config(BaseSettings):
@ -118,8 +103,8 @@ class Driver(ReverseDriver, ForwardDriver):
super().__init__(env, config) super().__init__(env, config)
self.fastapi_config: Config = Config(**config.dict()) self.fastapi_config: Config = Config(**config.dict())
self.http_pollings: List[HTTPPollingSetup] = [] self.http_pollings: List[HTTPPOLLING_SETUP] = []
self.websockets: List[WebSocketSetup] = [] self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event() self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = [] self.connections: List[asyncio.Task] = []
@ -173,30 +158,30 @@ class Driver(ReverseDriver, ForwardDriver):
return self.server_app.on_event("shutdown")(func) return self.server_app.on_event("shutdown")(func)
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup_http_polling(self, def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
adapter: str, """
self_id: str, :说明:
url: str,
polling_interval: float = 3., 注册一个 HTTP 轮询连接如果传入一个函数则该函数会在每次连接时被调用
method: str = "GET",
body: bytes = b"", :参数:
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None: * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
self.http_pollings.append( """
HTTPPollingSetup(adapter, self_id, url, method, body, headers, self.http_pollings.append(setup)
http_version, polling_interval))
@overrides(ForwardDriver) @overrides(ForwardDriver)
def setup_websocket(self, def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
adapter: str, """
self_id: str, :说明:
url: str,
reconnect_interval: float = 3., 注册一个 WebSocket 连接如果传入一个函数则该函数会在每次重连时被调用
headers: Dict[str, str] = {},
http_version: str = "1.1") -> None: :参数:
self.websockets.append(
WebSocketSetup(adapter, self_id, url, headers, http_version, * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
reconnect_interval)) """
self.websockets.append(setup)
@overrides(ReverseDriver) @overrides(ReverseDriver)
def run(self, def run(self,
@ -336,50 +321,72 @@ class Driver(ReverseDriver, ForwardDriver):
finally: finally:
self._bot_disconnect(bot) self._bot_disconnect(bot)
async def _http_loop(self, setup: HTTPPollingSetup): async def _http_loop(self, setup: HTTPPOLLING_SETUP):
async def _build_request(
setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
url = httpx.URL(setup.url) url = httpx.URL(setup.url)
if not url.netloc: if not url.netloc:
logger.opt(colors=True).error( logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>") f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return return
request = HTTPRequest( return HTTPRequest(
setup.http_version, url.scheme, url.path, url.query, { setup.http_version, url.scheme, url.path, url.query, {
**setup.headers, "host": url.netloc.decode("ascii") **setup.headers, "host": url.netloc.decode("ascii")
}, setup.method, setup.body) }, setup.method, setup.body)
BotClass = self._adapters[setup.adapter] bot: Optional[Bot] = None
bot = BotClass(setup.self_id, request) request: Optional[HTTPRequest] = None
self._bot_connect(bot) setup_: Optional[HTTPPollingSetup] = None
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"Start http polling for <y>{setup.adapter.upper()} " f"Start http polling for <y>{setup.adapter.upper()} "
f"Bot {setup.self_id}</y>") f"Bot {setup.self_id}</y>")
headers = request.headers
http2: bool = False
if request.http_version == "2":
http2 = True
try: try:
async with httpx.AsyncClient(headers=headers, async with httpx.AsyncClient(http2=True) as session:
timeout=30.,
http2=http2) as session:
while not self.shutdown.is_set(): while not self.shutdown.is_set():
if not bot:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
request = await _build_request(setup_)
if not request:
return
BotClass = self._adapters[setup.adapter]
bot = BotClass(setup.self_id, request)
self._bot_connect(bot)
elif callable(setup):
setup_ = await setup()
request = await _build_request(setup_)
if not request:
await asyncio.sleep(setup_.poll_interval)
continue
bot.request = request
setup_ = cast(HTTPPollingSetup, setup_)
request = cast(HTTPRequest, request)
headers = request.headers
logger.debug( logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} request {url}" f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
) )
try: try:
response = await session.request(request.method, response = await session.request(request.method,
url, setup_.url,
content=request.body) content=request.body,
headers=headers,
timeout=30.)
response.raise_for_status() response.raise_for_status()
data = response.read() data = response.read()
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e: except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Error occurred while requesting {url}. " f"<r><bg #f8bbd0>Error occurred while requesting {setup_.url}. "
"Try to reconnect...</bg #f8bbd0></r>") "Try to reconnect...</bg #f8bbd0></r>")
await asyncio.sleep(setup.poll_interval) await asyncio.sleep(setup_.poll_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
@ -388,34 +395,41 @@ class Driver(ReverseDriver, ForwardDriver):
"<r><bg #f8bbd0>Unexpected exception occurred " "<r><bg #f8bbd0>Unexpected exception occurred "
"while http polling</bg #f8bbd0></r>") "while http polling</bg #f8bbd0></r>")
finally: finally:
if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
async def _ws_loop(self, setup: WebSocketSetup): async def _ws_loop(self, setup: WEBSOCKET_SETUP):
url = httpx.URL(setup.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>")
return
headers = {**setup.headers, "host": url.netloc.decode("ascii")}
bot: Optional[Bot] = None bot: Optional[Bot] = None
try: try:
while True: while True:
if callable(setup):
setup_ = await setup()
else:
setup_ = setup
url = httpx.URL(setup_.url)
if not url.netloc:
logger.opt(colors=True).error(
f"<r><bg #f8bbd0>Error parsing url {url}</bg #f8bbd0></r>"
)
return
headers = {**setup_.headers, "host": url.netloc.decode("ascii")}
logger.debug( logger.debug(
f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}" f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
) )
try: try:
connection = Connect(setup.url) connection = Connect(setup_.url)
async with connection as ws: async with connection as ws:
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection to <y>{setup.adapter.upper()} " f"WebSocket Connection to <y>{setup_.adapter.upper()} "
f"Bot {setup.self_id}</y> succeeded!") f"Bot {setup_.self_id}</y> succeeded!")
request = WebSocket(setup.http_version, url.scheme, request = WebSocket("1.1", url.scheme, url.path,
url.path, url.query, headers, ws) url.query, headers, ws)
BotClass = self._adapters[setup.adapter] BotClass = self._adapters[setup_.adapter]
bot = BotClass(setup.self_id, request) bot = BotClass(setup_.self_id, request)
self._bot_connect(bot) self._bot_connect(bot)
while not self.shutdown.is_set(): while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message # use try except instead of "request.closed" because of queued message
@ -434,7 +448,7 @@ class Driver(ReverseDriver, ForwardDriver):
if bot: if bot:
self._bot_disconnect(bot) self._bot_disconnect(bot)
bot = None bot = None
await asyncio.sleep(setup.reconnect_interval) await asyncio.sleep(setup_.reconnect_interval)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass

View File

@ -11,7 +11,7 @@ from nonebot.typing import overrides
from nonebot.message import handle_event from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot from nonebot.adapters import Bot as BaseBot
from nonebot.utils import escape_tag, DataclassEncoder from nonebot.utils import escape_tag, DataclassEncoder
from nonebot.drivers import Driver, ForwardDriver, ReverseDriver from nonebot.drivers import Driver, ForwardDriver, WebSocketSetup
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape from .utils import log, escape
@ -249,10 +249,8 @@ class Bot(BaseBot):
"authorization": "authorization":
f"Bearer {cls.cqhttp_config.access_token}" f"Bearer {cls.cqhttp_config.access_token}"
} if cls.cqhttp_config.access_token else {} } if cls.cqhttp_config.access_token else {}
driver.setup_websocket("cqhttp", driver.setup_websocket(
self_id, WebSocketSetup("cqhttp", self_id, url, headers=headers))
url,
headers=headers)
except Exception as e: except Exception as e:
logger.opt(colors=True, exception=e).error( logger.opt(colors=True, exception=e).error(
f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} " f"<r><bg #f8bbd0>Bad url {url} for bot {self_id} "