diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md
index e8045d62..313717cb 100644
--- a/docs/api/drivers/README.md
+++ b/docs/api/drivers/README.md
@@ -32,6 +32,36 @@ Driver 基类。将后端框架封装,以满足适配器使用。
+### `_ws_connection_hook`
+
+
+* **类型**
+
+ `Set[T_WebSocketConnectionHook]`
+
+
+
+* **说明**
+
+ WebSocket 连接建立时执行的函数
+
+
+
+### `_ws_disconnection_hook`
+
+
+* **类型**
+
+ `Set[T_WebSocketDisconnectionHook]`
+
+
+
+* **说明**
+
+ WebSocket 连接断开时执行的函数
+
+
+
### _abstract_ `__init__(env, config)`
@@ -154,6 +184,48 @@ Driver 基类。将后端框架封装,以满足适配器使用。
注册一个在驱动停止时运行的函数
+### `on_bot_connect(func)`
+
+
+* **说明**
+
+ 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。
+
+
+
+* **函数参数**
+
+
+ * `bot: Bot`: 当前连接上的 Bot 对象
+
+
+
+### `on_bot_disconnect(func)`
+
+
+* **说明**
+
+ 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。
+
+
+
+* **函数参数**
+
+
+ * `bot: Bot`: 当前连接上的 Bot 对象
+
+
+
+### `_bot_connect(bot)`
+
+在 WebSocket 连接成功后,调用该函数来注册 bot 对象
+
+
+### `_bot_disconnect(bot)`
+
+在 WebSocket 连接断开后,调用该函数来注销 bot 对象
+
+
### _abstract_ `run(host=None, port=None, *args, **kwargs)`
diff --git a/docs/api/typing.md b/docs/api/typing.md
index 3bf49897..5d1b3d7b 100644
--- a/docs/api/typing.md
+++ b/docs/api/typing.md
@@ -46,6 +46,38 @@ sidebarDepth: 0
+## `T_WebSocketConnectionHook`
+
+
+* **类型**
+
+ `Callable[[Bot], Awaitable[None]]`
+
+
+
+* **说明**
+
+ WebSocket 连接建立时执行的函数
+
+
+
+
+## `T_WebSocketDisconnectionHook`
+
+
+* **类型**
+
+ `Callable[[Bot], Awaitable[None]]`
+
+
+
+* **说明**
+
+ WebSocket 连接断开时执行的函数
+
+
+
+
## `T_EventPreProcessor`
diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py
index b20f535b..7e95ee91 100644
--- a/nonebot/drivers/__init__.py
+++ b/nonebot/drivers/__init__.py
@@ -6,10 +6,12 @@
"""
import abc
-from typing import Dict, Type, Optional, Callable, TYPE_CHECKING
+import asyncio
+from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger
from nonebot.config import Env, Config
+from nonebot.typing import T_WebSocketConnectionHook, T_WebSocketDisconnectionHook
if TYPE_CHECKING:
from nonebot.adapters import Bot
@@ -25,6 +27,16 @@ class Driver(abc.ABC):
:类型: ``Dict[str, Type[Bot]]``
:说明: 已注册的适配器列表
"""
+ _ws_connection_hook: Set[T_WebSocketConnectionHook] = set()
+ """
+ :类型: ``Set[T_WebSocketConnectionHook]``
+ :说明: WebSocket 连接建立时执行的函数
+ """
+ _ws_disconnection_hook: Set[T_WebSocketDisconnectionHook] = set()
+ """
+ :类型: ``Set[T_WebSocketDisconnectionHook]``
+ :说明: WebSocket 连接断开时执行的函数
+ """
@abc.abstractmethod
def __init__(self, env: Env, config: Config):
@@ -93,8 +105,12 @@ class Driver(abc.ABC):
@property
def bots(self) -> Dict[str, "Bot"]:
"""
- :类型: ``Dict[str, Bot]``
- :说明: 获取当前所有已连接的 Bot
+ :类型:
+
+ ``Dict[str, Bot]``
+ :说明:
+
+ 获取当前所有已连接的 Bot
"""
return self._clients
@@ -108,6 +124,68 @@ class Driver(abc.ABC):
"""注册一个在驱动停止时运行的函数"""
raise NotImplementedError
+ def on_bot_connect(
+ self, func: T_WebSocketConnectionHook) -> T_WebSocketConnectionHook:
+ """
+ :说明:
+
+ 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。
+
+ :函数参数:
+
+ * ``bot: Bot``: 当前连接上的 Bot 对象
+ """
+ self._ws_connection_hook.add(func)
+ return func
+
+ def on_bot_disconnect(
+ self,
+ func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook:
+ """
+ :说明:
+
+ 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。
+
+ :函数参数:
+
+ * ``bot: Bot``: 当前连接上的 Bot 对象
+ """
+ self._ws_disconnection_hook.add(func)
+ return func
+
+ def _bot_connect(self, bot: "Bot") -> None:
+ """在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
+ self._clients[bot.self_id] = bot
+
+ async def _run_hook(bot: "Bot") -> None:
+ coros = list(map(lambda x: x(bot), self._ws_connection_hook))
+ if coros:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketConnection hook. "
+ "Running cancelled!")
+
+ asyncio.create_task(_run_hook(bot))
+
+ def _bot_disconnect(self, bot: "Bot") -> None:
+ """在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
+ if bot.self_id in self._clients:
+ del self._clients[bot.self_id]
+
+ async def _run_hook(bot: "Bot") -> None:
+ coros = list(map(lambda x: x(bot), self._ws_disconnection_hook))
+ if coros:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketDisConnection hook. "
+ "Running cancelled!")
+
+ asyncio.create_task(_run_hook(bot))
+
@abc.abstractmethod
def run(self,
host: Optional[str] = None,
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index fa0a27eb..42e21490 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -188,11 +188,12 @@ class Driver(BaseDriver):
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
await ws.accept()
- self._clients[x_self_id] = bot
logger.opt(colors=True).info(
f"WebSocket Connection from {adapter.upper()} "
f"Bot {x_self_id} Accepted!")
+ self._bot_connect(bot)
+
try:
while not ws.closed:
data = await ws.receive()
@@ -202,7 +203,7 @@ class Driver(BaseDriver):
asyncio.create_task(bot.handle_message(data))
finally:
- del self._clients[x_self_id]
+ self._bot_disconnect(bot)
class WebSocket(BaseWebSocket):
diff --git a/nonebot/typing.py b/nonebot/typing.py
index 5d5799b4..f92d1379 100644
--- a/nonebot/typing.py
+++ b/nonebot/typing.py
@@ -54,6 +54,23 @@ T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]]
事件处理状态 State 类工厂函数
"""
+T_WebSocketConnectionHook = Callable[["Bot"], Awaitable[None]]
+"""
+:类型: ``Callable[[Bot], Awaitable[None]]``
+
+:说明:
+
+ WebSocket 连接建立时执行的函数
+"""
+T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
+"""
+:类型: ``Callable[[Bot], Awaitable[None]]``
+
+:说明:
+
+ WebSocket 连接断开时执行的函数
+"""
+
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
"""
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``
diff --git a/tests/test_plugins/test_ws_hook.py b/tests/test_plugins/test_ws_hook.py
new file mode 100644
index 00000000..7b5f9b2b
--- /dev/null
+++ b/tests/test_plugins/test_ws_hook.py
@@ -0,0 +1,14 @@
+import nonebot
+from nonebot.adapters import Bot
+
+driver = nonebot.get_driver()
+
+
+@driver.on_bot_connect
+async def connect(bot: Bot) -> None:
+ print("Connect", bot)
+
+
+@driver.on_bot_disconnect
+async def disconnect(bot: Bot) -> None:
+ print("Disconnect", bot)