diff --git a/docs/.vuepress/config.js b/docs/.vuepress/config.js
index 69efc7ba..4aef42b7 100644
--- a/docs/.vuepress/config.js
+++ b/docs/.vuepress/config.js
@@ -219,6 +219,10 @@ module.exports = (context) => ({
title: "nonebot.drivers.quart 模块",
path: "drivers/quart",
},
+ {
+ title: "nonebot.drivers.aiohttp 模块",
+ path: "drivers/aiohttp",
+ },
{
title: "nonebot.adapters 模块",
path: "adapters/",
diff --git a/docs/api/README.md b/docs/api/README.md
index 26d4a0f5..38fac915 100644
--- a/docs/api/README.md
+++ b/docs/api/README.md
@@ -49,6 +49,9 @@
* [nonebot.drivers.quart](drivers/quart.html)
+ * [nonebot.drivers.aiohttp](drivers/aiohttp.html)
+
+
* [nonebot.adapters](adapters/)
diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md
index 57c4d135..9f8ee3ee 100644
--- a/docs/api/drivers/README.md
+++ b/docs/api/drivers/README.md
@@ -238,6 +238,45 @@ Driver 基类。
在 WebSocket 连接断开后,调用该函数来注销 bot 对象
+## _class_ `ForwardDriver`
+
+基类:`nonebot.drivers.Driver`
+
+Forward Driver 基类。将客户端框架封装,以满足适配器使用。
+
+
+### _abstract_ `setup_http_polling(setup)`
+
+
+* **说明**
+
+ 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
+
+
+
+* **参数**
+
+
+ * `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`
+
+
+
+### _abstract_ `setup_websocket(setup)`
+
+
+* **说明**
+
+ 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
+
+
+
+* **参数**
+
+
+ * `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`
+
+
+
## _class_ `ReverseDriver`
基类:`nonebot.drivers.Driver`
@@ -413,3 +452,78 @@ Always `websocket`
### _abstract async_ `send_bytes(data)`
发送一条 WebSocket binary 信息
+
+
+## _class_ `HTTPPollingSetup`
+
+基类:`object`
+
+
+### `adapter`
+
+协议适配器名称
+
+
+### `self_id`
+
+机器人 ID
+
+
+### `url`
+
+URL
+
+
+### `method`
+
+HTTP method
+
+
+### `body`
+
+HTTP body
+
+
+### `headers`
+
+HTTP headers
+
+
+### `http_version`
+
+HTTP version
+
+
+### `poll_interval`
+
+HTTP 轮询间隔
+
+
+## _class_ `WebSocketSetup`
+
+基类:`object`
+
+
+### `adapter`
+
+协议适配器名称
+
+
+### `self_id`
+
+机器人 ID
+
+
+### `url`
+
+URL
+
+
+### `headers`
+
+HTTP headers
+
+
+### `reconnect_interval`
+
+WebSocket 重连间隔
diff --git a/docs/api/drivers/aiohttp.md b/docs/api/drivers/aiohttp.md
new file mode 100644
index 00000000..4159d44e
--- /dev/null
+++ b/docs/api/drivers/aiohttp.md
@@ -0,0 +1,101 @@
+---
+contentSidebar: true
+sidebarDepth: 0
+---
+
+# NoneBot.drivers.aiohttp 模块
+
+## AIOHTTP 驱动适配
+
+本驱动仅支持客户端连接
+
+
+## _class_ `Driver`
+
+基类:[`nonebot.drivers.ForwardDriver`](README.md#nonebot.drivers.ForwardDriver)
+
+AIOHTTP 驱动框架
+
+
+### _property_ `type`
+
+驱动名称: `aiohttp`
+
+
+### _property_ `logger`
+
+aiohttp driver 使用的 logger
+
+
+### `on_startup(func)`
+
+
+* **说明**
+
+ 注册一个启动时执行的函数
+
+
+
+* **参数**
+
+
+ * `func: Callable[[], Awaitable[None]]`
+
+
+
+### `on_shutdown(func)`
+
+
+* **说明**
+
+ 注册一个停止时执行的函数
+
+
+
+* **参数**
+
+
+ * `func: Callable[[], Awaitable[None]]`
+
+
+
+### `setup_http_polling(setup)`
+
+
+* **说明**
+
+ 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
+
+
+
+* **参数**
+
+
+ * `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`
+
+
+
+### `setup_websocket(setup)`
+
+
+* **说明**
+
+ 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
+
+
+
+* **参数**
+
+
+ * `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`
+
+
+
+### `run(*args, **kwargs)`
+
+启动 aiohttp driver
+
+
+## _class_ `WebSocket`
+
+基类:[`nonebot.drivers.WebSocket`](README.md#nonebot.drivers.WebSocket)
diff --git a/docs/api/drivers/fastapi.md b/docs/api/drivers/fastapi.md
index 3a79f42c..2c02c8d6 100644
--- a/docs/api/drivers/fastapi.md
+++ b/docs/api/drivers/fastapi.md
@@ -7,19 +7,11 @@ sidebarDepth: 0
## FastAPI 驱动适配
+本驱动同时支持服务端以及客户端连接
+
后端使用方法请参考: [FastAPI 文档](https://fastapi.tiangolo.com/)
-## _class_ `HTTPPollingSetup`
-
-基类:`object`
-
-
-## _class_ `WebSocketSetup`
-
-基类:`object`
-
-
## _class_ `Config`
基类:`pydantic.env_settings.BaseSettings`
@@ -89,7 +81,7 @@ FastAPI 驱动框架设置,详情参考 FastAPI 文档
## _class_ `Driver`
-基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver), `nonebot.drivers.ForwardDriver`
+基类:[`nonebot.drivers.ReverseDriver`](README.md#nonebot.drivers.ReverseDriver), [`nonebot.drivers.ForwardDriver`](README.md#nonebot.drivers.ForwardDriver)
FastAPI 驱动框架
@@ -140,6 +132,38 @@ fastapi 使用的 logger
参考文档: [Events](https://fastapi.tiangolo.com/advanced/events/#startup-event)
+### `setup_http_polling(setup)`
+
+
+* **说明**
+
+ 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
+
+
+
+* **参数**
+
+
+ * `setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]`
+
+
+
+### `setup_websocket(setup)`
+
+
+* **说明**
+
+ 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
+
+
+
+* **参数**
+
+
+ * `setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]`
+
+
+
### `run(host=None, port=None, *, app=None, **kwargs)`
使用 `uvicorn` 启动 FastAPI
diff --git a/docs_build/README.rst b/docs_build/README.rst
index 35d9368c..5d5fc86d 100644
--- a/docs_build/README.rst
+++ b/docs_build/README.rst
@@ -17,6 +17,7 @@ NoneBot Api Reference
- `nonebot.drivers `_
- `nonebot.drivers.fastapi `_
- `nonebot.drivers.quart `_
+ - `nonebot.drivers.aiohttp `_
- `nonebot.adapters `_
- `nonebot.adapters.cqhttp `_
- `nonebot.adapters.ding `_
diff --git a/docs_build/drivers/aiohttp.rst b/docs_build/drivers/aiohttp.rst
new file mode 100644
index 00000000..077da6c5
--- /dev/null
+++ b/docs_build/drivers/aiohttp.rst
@@ -0,0 +1,12 @@
+\-\-\-
+contentSidebar: true
+sidebarDepth: 0
+\-\-\-
+
+NoneBot.drivers.aiohttp 模块
+=============================
+
+.. automodule:: nonebot.drivers.aiohttp
+ :members:
+ :private-members:
+ :show-inheritance:
diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py
index d7a7a8df..7607e4b3 100644
--- a/nonebot/drivers/__init__.py
+++ b/nonebot/drivers/__init__.py
@@ -8,7 +8,7 @@
import abc
import asyncio
from dataclasses import dataclass, field
-from typing import Any, Set, Dict, Type, Optional, Callable, TYPE_CHECKING
+from typing import Any, Set, Dict, Type, Union, Optional, Callable, Awaitable, TYPE_CHECKING
from nonebot.log import logger
from nonebot.config import Env, Config
@@ -193,27 +193,40 @@ class Driver(abc.ABC):
class ForwardDriver(Driver):
+ """
+ Forward Driver 基类。将客户端框架封装,以满足适配器使用。
+ """
@abc.abstractmethod
- def setup_http_polling(self,
- adapter: str,
- self_id: str,
- url: str,
- polling_interval: float = 3.,
- method: str = "GET",
- body: bytes = b"",
- headers: Dict[str, str] = {},
- http_version: str = "1.1") -> None:
+ def setup_http_polling(
+ self, setup: Union["HTTPPollingSetup",
+ Callable[[], Awaitable["HTTPPollingSetup"]]]
+ ) -> None:
+ """
+ :说明:
+
+ 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
+
+ :参数:
+
+ * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
+ """
raise NotImplementedError
@abc.abstractmethod
- def setup_websocket(self,
- adapter: str,
- self_id: str,
- url: str,
- reconnect_interval: float = 3.,
- headers: Dict[str, str] = {},
- http_version: str = "1.1") -> None:
+ def setup_websocket(
+ self, setup: Union["WebSocketSetup",
+ Callable[[], Awaitable["WebSocketSetup"]]]
+ ) -> None:
+ """
+ :说明:
+
+ 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
+
+ :参数:
+
+ * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
+ """
raise NotImplementedError
@@ -369,3 +382,37 @@ class WebSocket(HTTPConnection, abc.ABC):
async def send_bytes(self, data: bytes):
"""发送一条 WebSocket binary 信息"""
raise NotImplementedError
+
+
+@dataclass
+class HTTPPollingSetup:
+ adapter: str
+ """协议适配器名称"""
+ self_id: str
+ """机器人 ID"""
+ url: str
+ """URL"""
+ method: str
+ """HTTP method"""
+ body: bytes
+ """HTTP body"""
+ headers: Dict[str, str]
+ """HTTP headers"""
+ http_version: str
+ """HTTP version"""
+ poll_interval: float
+ """HTTP 轮询间隔"""
+
+
+@dataclass
+class WebSocketSetup:
+ adapter: str
+ """协议适配器名称"""
+ self_id: str
+ """机器人 ID"""
+ url: str
+ """URL"""
+ headers: Dict[str, str] = field(default_factory=dict)
+ """HTTP headers"""
+ reconnect_interval: float = 3.
+ """WebSocket 重连间隔"""
diff --git a/nonebot/drivers/aiohttp.py b/nonebot/drivers/aiohttp.py
index 7977acfb..10d464c0 100644
--- a/nonebot/drivers/aiohttp.py
+++ b/nonebot/drivers/aiohttp.py
@@ -1,11 +1,15 @@
"""
+AIOHTTP 驱动适配
+================
+
+本驱动仅支持客户端连接
"""
import signal
import asyncio
import threading
from dataclasses import dataclass
-from typing import Set, List, Dict, Optional, Callable, Awaitable
+from typing import Set, List, cast, Union, Optional, Callable, Awaitable
import aiohttp
from yarl import URL
@@ -14,46 +18,31 @@ from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.config import Env, Config
-from nonebot.drivers import ForwardDriver, HTTPRequest, WebSocket as BaseWebSocket
+from nonebot.drivers import (ForwardDriver, HTTPPollingSetup, WebSocketSetup,
+ HTTPRequest, WebSocket as BaseWebSocket)
STARTUP_FUNC = Callable[[], Awaitable[None]]
SHUTDOWN_FUNC = Callable[[], Awaitable[None]]
+HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
+ Callable[[], Awaitable[HTTPPollingSetup]]]
+WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill `.
)
-@dataclass
-class HTTPPollingSetup:
- adapter: str
- self_id: str
- url: str
- method: str
- body: bytes
- headers: Dict[str, str]
- http_version: str
- poll_interval: float
-
-
-@dataclass
-class WebSocketSetup:
- adapter: str
- self_id: str
- url: str
- headers: Dict[str, str]
- http_version: str
- reconnect_interval: float
-
-
class Driver(ForwardDriver):
+ """
+ AIOHTTP 驱动框架
+ """
def __init__(self, env: Env, config: Config):
super().__init__(env, config)
self.startup_funcs: Set[STARTUP_FUNC] = set()
self.shutdown_funcs: Set[SHUTDOWN_FUNC] = set()
- self.http_pollings: List[HTTPPollingSetup] = []
- self.websockets: List[WebSocketSetup] = []
+ self.http_pollings: List[HTTPPOLLING_SETUP] = []
+ self.websockets: List[WEBSOCKET_SETUP] = []
self.connections: List[asyncio.Task] = []
self.should_exit: asyncio.Event = asyncio.Event()
self.force_exit: bool = False
@@ -67,46 +56,66 @@ class Driver(ForwardDriver):
@property
@overrides(ForwardDriver)
def logger(self):
+ """aiohttp driver 使用的 logger"""
return logger
@overrides(ForwardDriver)
- def on_startup(self, func: Callable) -> Callable:
+ def on_startup(self, func: STARTUP_FUNC) -> STARTUP_FUNC:
+ """
+ :说明:
+
+ 注册一个启动时执行的函数
+
+ :参数:
+
+ * ``func: Callable[[], Awaitable[None]]``
+ """
self.startup_funcs.add(func)
return func
@overrides(ForwardDriver)
- def on_shutdown(self, func: Callable) -> Callable:
+ def on_shutdown(self, func: SHUTDOWN_FUNC) -> SHUTDOWN_FUNC:
+ """
+ :说明:
+
+ 注册一个停止时执行的函数
+
+ :参数:
+
+ * ``func: Callable[[], Awaitable[None]]``
+ """
self.shutdown_funcs.add(func)
return func
@overrides(ForwardDriver)
- def setup_http_polling(self,
- adapter: str,
- self_id: str,
- url: str,
- polling_interval: float = 3.,
- method: str = "GET",
- body: bytes = b"",
- headers: Dict[str, str] = {},
- http_version: str = "1.1") -> None:
- self.http_pollings.append(
- HTTPPollingSetup(adapter, self_id, url, method, body, headers,
- http_version, polling_interval))
+ def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
+ """
+ :说明:
+
+ 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
+
+ :参数:
+
+ * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
+ """
+ self.http_pollings.append(setup)
@overrides(ForwardDriver)
- def setup_websocket(self,
- adapter: str,
- self_id: str,
- url: str,
- reconnect_interval: float = 3.,
- headers: Dict[str, str] = {},
- http_version: str = "1.1") -> None:
- self.websockets.append(
- WebSocketSetup(adapter, self_id, url, headers, http_version,
- reconnect_interval))
+ def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
+ """
+ :说明:
+
+ 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
+
+ :参数:
+
+ * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
+ """
+ self.websockets.append(setup)
@overrides(ForwardDriver)
def run(self, *args, **kwargs):
+ """启动 aiohttp driver"""
super().run(*args, **kwargs)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve())
@@ -197,59 +206,88 @@ class Driver(ForwardDriver):
else:
self.should_exit.set()
- async def _http_loop(self, setup: HTTPPollingSetup):
- url = URL(setup.url)
- if not url.is_absolute() or not url.host:
- logger.opt(colors=True).error(
- f"Error parsing url {url}")
- return
- host = f"{url.host}:{url.port}" if url.port else url.host
- request = HTTPRequest(setup.http_version, url.scheme, url.path,
- url.raw_query_string.encode("latin-1"), {
- **setup.headers, "host": host
- }, setup.method, setup.body)
+ async def _http_loop(self, setup: HTTPPOLLING_SETUP):
+
+ async def _build_request(
+ setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
+ url = URL(setup.url)
+ if not url.is_absolute() or not url.host:
+ logger.opt(colors=True).error(
+ f"Error parsing url {url}")
+ return
+ host = f"{url.host}:{url.port}" if url.port else url.host
+ return HTTPRequest(setup.http_version, url.scheme, url.path,
+ url.raw_query_string.encode("latin-1"), {
+ **setup.headers, "host": host
+ }, setup.method, setup.body)
+
+ bot: Optional[Bot] = None
+ request: Optional[HTTPRequest] = None
+ setup_: Optional[HTTPPollingSetup] = None
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
- self._bot_connect(bot)
logger.opt(colors=True).info(
f"Start http polling for {setup.adapter.upper()} "
f"Bot {setup.self_id}")
- headers = request.headers
- timeout = aiohttp.ClientTimeout(30)
- version: aiohttp.HttpVersion
- if request.http_version == "1.0":
- version = aiohttp.HttpVersion10
- elif request.http_version == "1.1":
- version = aiohttp.HttpVersion11
- else:
- logger.opt(colors=True).error(
- "Unsupported HTTP Version "
- f"{request.http_version}")
- return
-
try:
- async with aiohttp.ClientSession(headers=headers,
- timeout=timeout,
- version=version) as session:
+ async with aiohttp.ClientSession() as session:
while not self.should_exit.is_set():
+ if not bot:
+ if callable(setup):
+ setup_ = await setup()
+ else:
+ setup_ = setup
+ request = await _build_request(setup_)
+ if not request:
+ return
+
+ BotClass = self._adapters[setup.adapter]
+ bot = BotClass(setup.self_id, request)
+ self._bot_connect(bot)
+ elif callable(setup):
+ setup_ = await setup()
+ request = await _build_request(setup_)
+ if not request:
+ await asyncio.sleep(setup_.poll_interval)
+ continue
+ bot.request = request
+
+ request = cast(HTTPRequest, request)
+ setup_ = cast(HTTPPollingSetup, setup_)
+
+ headers = request.headers
+ timeout = aiohttp.ClientTimeout(30)
+ version: aiohttp.HttpVersion
+ if request.http_version == "1.0":
+ version = aiohttp.HttpVersion10
+ elif request.http_version == "1.1":
+ version = aiohttp.HttpVersion11
+ else:
+ logger.opt(colors=True).error(
+ "Unsupported HTTP Version "
+ f"{request.http_version}")
+ return
+
logger.debug(
- f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
+ f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
)
+
try:
- async with session.request(
- request.method, url,
- data=request.body) as response:
+ async with session.request(request.method,
+ setup_.url,
+ data=request.body,
+ headers=headers,
+ timeout=timeout,
+ version=version) as response:
response.raise_for_status()
data = await response.read()
asyncio.create_task(bot.handle_message(data))
except aiohttp.ClientResponseError as e:
logger.opt(colors=True, exception=e).error(
- f"Error occurred while requesting {url}. "
+ f"Error occurred while requesting {setup_.url}. "
"Try to reconnect...")
- await asyncio.sleep(setup.poll_interval)
+ await asyncio.sleep(setup_.poll_interval)
except asyncio.CancelledError:
pass
@@ -258,50 +296,48 @@ class Driver(ForwardDriver):
"Unexpected exception occurred "
"while http polling")
finally:
- self._bot_disconnect(bot)
-
- async def _ws_loop(self, setup: WebSocketSetup):
- url = URL(setup.url)
- if not url.is_absolute() or not url.host:
- logger.opt(colors=True).error(
- f"Error parsing url {url}")
- return
- host = f"{url.host}:{url.port}" if url.port else url.host
-
- headers = {**setup.headers, "host": host}
- timeout = aiohttp.ClientTimeout(30)
- version: aiohttp.HttpVersion
- if setup.http_version == "1.0":
- version = aiohttp.HttpVersion10
- elif setup.http_version == "1.1":
- version = aiohttp.HttpVersion11
- else:
- logger.opt(colors=True).error(
- "Unsupported HTTP Version "
- f"{setup.http_version}")
- return
+ if bot:
+ self._bot_disconnect(bot)
+ async def _ws_loop(self, setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
+
try:
- async with aiohttp.ClientSession(headers=headers,
- timeout=timeout,
- version=version) as session:
+ async with aiohttp.ClientSession() as session:
while True:
+ if callable(setup):
+ setup_ = await setup()
+ else:
+ setup_ = setup
+
+ url = URL(setup_.url)
+ if not url.is_absolute() or not url.host:
+ logger.opt(colors=True).error(
+ f"Error parsing url {url}"
+ )
+ await asyncio.sleep(setup_.reconnect_interval)
+ continue
+
+ host = f"{url.host}:{url.port}" if url.port else url.host
+ headers = {**setup_.headers, "host": host}
+
logger.debug(
- f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
+ f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
)
try:
- async with session.ws_connect(url) as ws:
+ async with session.ws_connect(url,
+ headers=headers,
+ timeout=30.) as ws:
logger.opt(colors=True).info(
- f"WebSocket Connection to {setup.adapter.upper()} "
- f"Bot {setup.self_id} succeeded!")
+ f"WebSocket Connection to {setup_.adapter.upper()} "
+ f"Bot {setup_.self_id} succeeded!")
request = WebSocket(
- setup.http_version, url.scheme, url.path,
+ "1.1", url.scheme, url.path,
url.raw_query_string.encode("latin-1"), headers,
ws)
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
+ BotClass = self._adapters[setup_.adapter]
+ bot = BotClass(setup_.self_id, request)
self._bot_connect(bot)
while not self.should_exit.is_set():
msg = await ws.receive()
@@ -330,7 +366,7 @@ class Driver(ForwardDriver):
if bot:
self._bot_disconnect(bot)
bot = None
- await asyncio.sleep(setup.reconnect_interval)
+ await asyncio.sleep(setup_.reconnect_interval)
except asyncio.CancelledError:
pass
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index caf9e2e9..b263ba7a 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -2,6 +2,8 @@
FastAPI 驱动适配
================
+本驱动同时支持服务端以及客户端连接
+
后端使用方法请参考: `FastAPI 文档`_
.. _FastAPI 文档:
@@ -11,7 +13,7 @@ FastAPI 驱动适配
import asyncio
import logging
from dataclasses import dataclass
-from typing import List, Dict, Union, Optional, Callable
+from typing import List, cast, Union, Optional, Callable, Awaitable
import httpx
import uvicorn
@@ -27,30 +29,13 @@ from nonebot.log import logger
from nonebot.adapters import Bot
from nonebot.typing import overrides
from nonebot.config import Env, Config as NoneBotConfig
-from nonebot.drivers import ReverseDriver, ForwardDriver
-from nonebot.drivers import HTTPRequest, WebSocket as BaseWebSocket
+from nonebot.drivers import (ReverseDriver, ForwardDriver, HTTPPollingSetup,
+ WebSocketSetup, HTTPRequest, WebSocket as
+ BaseWebSocket)
-
-@dataclass
-class HTTPPollingSetup:
- adapter: str
- self_id: str
- url: str
- method: str
- body: bytes
- headers: Dict[str, str]
- http_version: str
- poll_interval: float
-
-
-@dataclass
-class WebSocketSetup:
- adapter: str
- self_id: str
- url: str
- headers: Dict[str, str]
- http_version: str
- reconnect_interval: float
+HTTPPOLLING_SETUP = Union[HTTPPollingSetup,
+ Callable[[], Awaitable[HTTPPollingSetup]]]
+WEBSOCKET_SETUP = Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]
class Config(BaseSettings):
@@ -118,8 +103,8 @@ class Driver(ReverseDriver, ForwardDriver):
super().__init__(env, config)
self.fastapi_config: Config = Config(**config.dict())
- self.http_pollings: List[HTTPPollingSetup] = []
- self.websockets: List[WebSocketSetup] = []
+ self.http_pollings: List[HTTPPOLLING_SETUP] = []
+ self.websockets: List[WEBSOCKET_SETUP] = []
self.shutdown: asyncio.Event = asyncio.Event()
self.connections: List[asyncio.Task] = []
@@ -173,30 +158,30 @@ class Driver(ReverseDriver, ForwardDriver):
return self.server_app.on_event("shutdown")(func)
@overrides(ForwardDriver)
- def setup_http_polling(self,
- adapter: str,
- self_id: str,
- url: str,
- polling_interval: float = 3.,
- method: str = "GET",
- body: bytes = b"",
- headers: Dict[str, str] = {},
- http_version: str = "1.1") -> None:
- self.http_pollings.append(
- HTTPPollingSetup(adapter, self_id, url, method, body, headers,
- http_version, polling_interval))
+ def setup_http_polling(self, setup: HTTPPOLLING_SETUP) -> None:
+ """
+ :说明:
+
+ 注册一个 HTTP 轮询连接,如果传入一个函数,则该函数会在每次连接时被调用
+
+ :参数:
+
+ * ``setup: Union[HTTPPollingSetup, Callable[[], Awaitable[HTTPPollingSetup]]]``
+ """
+ self.http_pollings.append(setup)
@overrides(ForwardDriver)
- def setup_websocket(self,
- adapter: str,
- self_id: str,
- url: str,
- reconnect_interval: float = 3.,
- headers: Dict[str, str] = {},
- http_version: str = "1.1") -> None:
- self.websockets.append(
- WebSocketSetup(adapter, self_id, url, headers, http_version,
- reconnect_interval))
+ def setup_websocket(self, setup: WEBSOCKET_SETUP) -> None:
+ """
+ :说明:
+
+ 注册一个 WebSocket 连接,如果传入一个函数,则该函数会在每次重连时被调用
+
+ :参数:
+
+ * ``setup: Union[WebSocketSetup, Callable[[], Awaitable[WebSocketSetup]]]``
+ """
+ self.websockets.append(setup)
@overrides(ReverseDriver)
def run(self,
@@ -336,50 +321,72 @@ class Driver(ReverseDriver, ForwardDriver):
finally:
self._bot_disconnect(bot)
- async def _http_loop(self, setup: HTTPPollingSetup):
- url = httpx.URL(setup.url)
- if not url.netloc:
- logger.opt(colors=True).error(
- f"Error parsing url {url}")
- return
- request = HTTPRequest(
- setup.http_version, url.scheme, url.path, url.query, {
- **setup.headers, "host": url.netloc.decode("ascii")
- }, setup.method, setup.body)
+ async def _http_loop(self, setup: HTTPPOLLING_SETUP):
+
+ async def _build_request(
+ setup: HTTPPollingSetup) -> Optional[HTTPRequest]:
+ url = httpx.URL(setup.url)
+ if not url.netloc:
+ logger.opt(colors=True).error(
+ f"Error parsing url {url}")
+ return
+ return HTTPRequest(
+ setup.http_version, url.scheme, url.path, url.query, {
+ **setup.headers, "host": url.netloc.decode("ascii")
+ }, setup.method, setup.body)
+
+ bot: Optional[Bot] = None
+ request: Optional[HTTPRequest] = None
+ setup_: Optional[HTTPPollingSetup] = None
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
- self._bot_connect(bot)
logger.opt(colors=True).info(
f"Start http polling for {setup.adapter.upper()} "
f"Bot {setup.self_id}")
- headers = request.headers
- http2: bool = False
- if request.http_version == "2":
- http2 = True
-
try:
- async with httpx.AsyncClient(headers=headers,
- timeout=30.,
- http2=http2) as session:
+ async with httpx.AsyncClient(http2=True) as session:
while not self.shutdown.is_set():
+ if not bot:
+ if callable(setup):
+ setup_ = await setup()
+ else:
+ setup_ = setup
+ request = await _build_request(setup_)
+ if not request:
+ return
+ BotClass = self._adapters[setup.adapter]
+ bot = BotClass(setup.self_id, request)
+ self._bot_connect(bot)
+ elif callable(setup):
+ setup_ = await setup()
+ request = await _build_request(setup_)
+ if not request:
+ await asyncio.sleep(setup_.poll_interval)
+ continue
+ bot.request = request
+
+ setup_ = cast(HTTPPollingSetup, setup_)
+ request = cast(HTTPRequest, request)
+ headers = request.headers
+
logger.debug(
- f"Bot {setup.self_id} from adapter {setup.adapter} request {url}"
+ f"Bot {setup_.self_id} from adapter {setup_.adapter} request {setup_.url}"
)
try:
response = await session.request(request.method,
- url,
- content=request.body)
+ setup_.url,
+ content=request.body,
+ headers=headers,
+ timeout=30.)
response.raise_for_status()
data = response.read()
asyncio.create_task(bot.handle_message(data))
except httpx.HTTPError as e:
logger.opt(colors=True, exception=e).error(
- f"Error occurred while requesting {url}. "
+ f"Error occurred while requesting {setup_.url}. "
"Try to reconnect...")
- await asyncio.sleep(setup.poll_interval)
+ await asyncio.sleep(setup_.poll_interval)
except asyncio.CancelledError:
pass
@@ -388,34 +395,41 @@ class Driver(ReverseDriver, ForwardDriver):
"Unexpected exception occurred "
"while http polling")
finally:
- self._bot_disconnect(bot)
-
- async def _ws_loop(self, setup: WebSocketSetup):
- url = httpx.URL(setup.url)
- if not url.netloc:
- logger.opt(colors=True).error(
- f"Error parsing url {url}")
- return
-
- headers = {**setup.headers, "host": url.netloc.decode("ascii")}
+ if bot:
+ self._bot_disconnect(bot)
+ async def _ws_loop(self, setup: WEBSOCKET_SETUP):
bot: Optional[Bot] = None
+
try:
while True:
+ if callable(setup):
+ setup_ = await setup()
+ else:
+ setup_ = setup
+
+ url = httpx.URL(setup_.url)
+ if not url.netloc:
+ logger.opt(colors=True).error(
+ f"Error parsing url {url}"
+ )
+ return
+
+ headers = {**setup_.headers, "host": url.netloc.decode("ascii")}
logger.debug(
- f"Bot {setup.self_id} from adapter {setup.adapter} connecting to {url}"
+ f"Bot {setup_.self_id} from adapter {setup_.adapter} connecting to {url}"
)
try:
- connection = Connect(setup.url)
+ connection = Connect(setup_.url)
async with connection as ws:
logger.opt(colors=True).info(
- f"WebSocket Connection to {setup.adapter.upper()} "
- f"Bot {setup.self_id} succeeded!")
- request = WebSocket(setup.http_version, url.scheme,
- url.path, url.query, headers, ws)
+ f"WebSocket Connection to {setup_.adapter.upper()} "
+ f"Bot {setup_.self_id} succeeded!")
+ request = WebSocket("1.1", url.scheme, url.path,
+ url.query, headers, ws)
- BotClass = self._adapters[setup.adapter]
- bot = BotClass(setup.self_id, request)
+ BotClass = self._adapters[setup_.adapter]
+ bot = BotClass(setup_.self_id, request)
self._bot_connect(bot)
while not self.shutdown.is_set():
# use try except instead of "request.closed" because of queued message
@@ -434,7 +448,7 @@ class Driver(ReverseDriver, ForwardDriver):
if bot:
self._bot_disconnect(bot)
bot = None
- await asyncio.sleep(setup.reconnect_interval)
+ await asyncio.sleep(setup_.reconnect_interval)
except asyncio.CancelledError:
pass
diff --git a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py
index 6f6233bc..a6a922df 100644
--- a/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py
+++ b/packages/nonebot-adapter-cqhttp/nonebot/adapters/cqhttp/bot.py
@@ -11,7 +11,7 @@ from nonebot.typing import overrides
from nonebot.message import handle_event
from nonebot.adapters import Bot as BaseBot
from nonebot.utils import escape_tag, DataclassEncoder
-from nonebot.drivers import Driver, ForwardDriver, ReverseDriver
+from nonebot.drivers import Driver, ForwardDriver, WebSocketSetup
from nonebot.drivers import HTTPConnection, HTTPRequest, HTTPResponse, WebSocket
from .utils import log, escape
@@ -249,10 +249,8 @@ class Bot(BaseBot):
"authorization":
f"Bearer {cls.cqhttp_config.access_token}"
} if cls.cqhttp_config.access_token else {}
- driver.setup_websocket("cqhttp",
- self_id,
- url,
- headers=headers)
+ driver.setup_websocket(
+ WebSocketSetup("cqhttp", self_id, url, headers=headers))
except Exception as e:
logger.opt(colors=True, exception=e).error(
f"Bad url {url} for bot {self_id} "