diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index b1942f85..b8a9de20 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -140,7 +140,7 @@ class Driver(abc.ABC): def on_bot_disconnect( self, - func: T_WebSocketDisConnectionHook) -> T_WebSocketDisConnectionHook: + func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook: """ :说明: @@ -153,30 +153,38 @@ class Driver(abc.ABC): self._ws_disconnection_hook.add(func) return func - async def bot_connect(self, bot: "Bot") -> None: + def bot_connect(self, bot: "Bot") -> None: """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" self._clients[bot.self_id] = bot - coros = list(map(lambda x: x(bot), self._ws_connection_hook)) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketConnection hook. " - "Running cancelled!") - async def bot_disconnect(self, bot: "Bot") -> None: + async def _run_hook(bot: "Bot") -> None: + coros = list(map(lambda x: x(bot), self._ws_connection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketConnection hook. " + "Running cancelled!") + + asyncio.create_task(_run_hook(bot)) + + def bot_disconnect(self, bot: "Bot") -> None: """在 WebSocket 连接断开后,调用该函数来注销 bot 对象""" if bot.self_id in self._clients: del self._clients[bot.self_id] - coros = list(map(lambda x: x(bot), self._ws_disconnection_hook)) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketDisConnection hook. " - "Running cancelled!") + + async def _run_hook(bot: "Bot") -> None: + coros = list(map(lambda x: x(bot), self._ws_disconnection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketDisConnection hook. " + "Running cancelled!") + + asyncio.create_task(_run_hook(bot)) @abc.abstractmethod def run(self, diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 53acccf0..0e4457d9 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -188,11 +188,12 @@ class Driver(BaseDriver): bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) await ws.accept() - await self.bot_connect(bot) logger.opt(colors=True).info( f"WebSocket Connection from {adapter.upper()} " f"Bot {x_self_id} Accepted!") + self.bot_connect(bot) + try: while not ws.closed: data = await ws.receive() @@ -202,7 +203,7 @@ class Driver(BaseDriver): asyncio.create_task(bot.handle_message(data)) finally: - await self.bot_disconnect(bot) + self.bot_disconnect(bot) class WebSocket(BaseWebSocket): diff --git a/tests/test_plugins/test_ws_hook.py b/tests/test_plugins/test_ws_hook.py new file mode 100644 index 00000000..7b5f9b2b --- /dev/null +++ b/tests/test_plugins/test_ws_hook.py @@ -0,0 +1,14 @@ +import nonebot +from nonebot.adapters import Bot + +driver = nonebot.get_driver() + + +@driver.on_bot_connect +async def connect(bot: Bot) -> None: + print("Connect", bot) + + +@driver.on_bot_disconnect +async def disconnect(bot: Bot) -> None: + print("Disconnect", bot)