diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py
index b1942f85..b8a9de20 100644
--- a/nonebot/drivers/__init__.py
+++ b/nonebot/drivers/__init__.py
@@ -140,7 +140,7 @@ class Driver(abc.ABC):
def on_bot_disconnect(
self,
- func: T_WebSocketDisConnectionHook) -> T_WebSocketDisConnectionHook:
+ func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook:
"""
:说明:
@@ -153,30 +153,38 @@ class Driver(abc.ABC):
self._ws_disconnection_hook.add(func)
return func
- async def bot_connect(self, bot: "Bot") -> None:
+ def bot_connect(self, bot: "Bot") -> None:
"""在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
self._clients[bot.self_id] = bot
- coros = list(map(lambda x: x(bot), self._ws_connection_hook))
- if coros:
- try:
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running WebSocketConnection hook. "
- "Running cancelled!")
- async def bot_disconnect(self, bot: "Bot") -> None:
+ async def _run_hook(bot: "Bot") -> None:
+ coros = list(map(lambda x: x(bot), self._ws_connection_hook))
+ if coros:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketConnection hook. "
+ "Running cancelled!")
+
+ asyncio.create_task(_run_hook(bot))
+
+ def bot_disconnect(self, bot: "Bot") -> None:
"""在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._clients:
del self._clients[bot.self_id]
- coros = list(map(lambda x: x(bot), self._ws_disconnection_hook))
- if coros:
- try:
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running WebSocketDisConnection hook. "
- "Running cancelled!")
+
+ async def _run_hook(bot: "Bot") -> None:
+ coros = list(map(lambda x: x(bot), self._ws_disconnection_hook))
+ if coros:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketDisConnection hook. "
+ "Running cancelled!")
+
+ asyncio.create_task(_run_hook(bot))
@abc.abstractmethod
def run(self,
diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py
index 53acccf0..0e4457d9 100644
--- a/nonebot/drivers/fastapi.py
+++ b/nonebot/drivers/fastapi.py
@@ -188,11 +188,12 @@ class Driver(BaseDriver):
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
await ws.accept()
- await self.bot_connect(bot)
logger.opt(colors=True).info(
f"WebSocket Connection from {adapter.upper()} "
f"Bot {x_self_id} Accepted!")
+ self.bot_connect(bot)
+
try:
while not ws.closed:
data = await ws.receive()
@@ -202,7 +203,7 @@ class Driver(BaseDriver):
asyncio.create_task(bot.handle_message(data))
finally:
- await self.bot_disconnect(bot)
+ self.bot_disconnect(bot)
class WebSocket(BaseWebSocket):
diff --git a/tests/test_plugins/test_ws_hook.py b/tests/test_plugins/test_ws_hook.py
new file mode 100644
index 00000000..7b5f9b2b
--- /dev/null
+++ b/tests/test_plugins/test_ws_hook.py
@@ -0,0 +1,14 @@
+import nonebot
+from nonebot.adapters import Bot
+
+driver = nonebot.get_driver()
+
+
+@driver.on_bot_connect
+async def connect(bot: Bot) -> None:
+ print("Connect", bot)
+
+
+@driver.on_bot_disconnect
+async def disconnect(bot: Bot) -> None:
+ print("Disconnect", bot)