From c993f15bca3d7c97887e174be2affffa3db9e6f9 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 28 Dec 2020 13:36:00 +0800 Subject: [PATCH 1/5] :alembic: add ws connection hook --- nonebot/drivers/__init__.py | 76 +++++++++++++++++++++++++++++++++++-- nonebot/drivers/fastapi.py | 4 +- nonebot/typing.py | 17 +++++++++ 3 files changed, 92 insertions(+), 5 deletions(-) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index b20f535b..b1942f85 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -6,10 +6,12 @@ """ import abc -from typing import Dict, Type, Optional, Callable, TYPE_CHECKING +import asyncio +from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING from nonebot.log import logger from nonebot.config import Env, Config +from nonebot.typing import T_WebSocketConnectionHook, T_WebSocketDisconnectionHook if TYPE_CHECKING: from nonebot.adapters import Bot @@ -25,6 +27,16 @@ class Driver(abc.ABC): :类型: ``Dict[str, Type[Bot]]`` :说明: 已注册的适配器列表 """ + _ws_connection_hook: Set[T_WebSocketConnectionHook] = set() + """ + :类型: ``Set[T_WebSocketConnectionHook]`` + :说明: WebSocket 连接建立时执行的函数 + """ + _ws_disconnection_hook: Set[T_WebSocketDisconnectionHook] = set() + """ + :类型: ``Set[T_WebSocketDisconnectionHook]`` + :说明: WebSocket 连接断开时执行的函数 + """ @abc.abstractmethod def __init__(self, env: Env, config: Config): @@ -93,8 +105,12 @@ class Driver(abc.ABC): @property def bots(self) -> Dict[str, "Bot"]: """ - :类型: ``Dict[str, Bot]`` - :说明: 获取当前所有已连接的 Bot + :类型: + + ``Dict[str, Bot]`` + :说明: + + 获取当前所有已连接的 Bot """ return self._clients @@ -108,6 +124,60 @@ class Driver(abc.ABC): """注册一个在驱动停止时运行的函数""" raise NotImplementedError + def on_bot_connect( + self, func: T_WebSocketConnectionHook) -> T_WebSocketConnectionHook: + """ + :说明: + + 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。 + + :函数参数: + + * ``bot: Bot``: 当前连接上的 Bot 对象 + """ + self._ws_connection_hook.add(func) + return func + + def on_bot_disconnect( + self, + func: T_WebSocketDisConnectionHook) -> T_WebSocketDisConnectionHook: + """ + :说明: + + 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。 + + :函数参数: + + * ``bot: Bot``: 当前连接上的 Bot 对象 + """ + self._ws_disconnection_hook.add(func) + return func + + async def bot_connect(self, bot: "Bot") -> None: + """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" + self._clients[bot.self_id] = bot + coros = list(map(lambda x: x(bot), self._ws_connection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketConnection hook. " + "Running cancelled!") + + async def bot_disconnect(self, bot: "Bot") -> None: + """在 WebSocket 连接断开后,调用该函数来注销 bot 对象""" + if bot.self_id in self._clients: + del self._clients[bot.self_id] + coros = list(map(lambda x: x(bot), self._ws_disconnection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketDisConnection hook. " + "Running cancelled!") + @abc.abstractmethod def run(self, host: Optional[str] = None, diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index fa0a27eb..53acccf0 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -188,7 +188,7 @@ class Driver(BaseDriver): bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) await ws.accept() - self._clients[x_self_id] = bot + await self.bot_connect(bot) logger.opt(colors=True).info( f"WebSocket Connection from {adapter.upper()} " f"Bot {x_self_id} Accepted!") @@ -202,7 +202,7 @@ class Driver(BaseDriver): asyncio.create_task(bot.handle_message(data)) finally: - del self._clients[x_self_id] + await self.bot_disconnect(bot) class WebSocket(BaseWebSocket): diff --git a/nonebot/typing.py b/nonebot/typing.py index 5d5799b4..f92d1379 100644 --- a/nonebot/typing.py +++ b/nonebot/typing.py @@ -54,6 +54,23 @@ T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]] 事件处理状态 State 类工厂函数 """ +T_WebSocketConnectionHook = Callable[["Bot"], Awaitable[None]] +""" +:类型: ``Callable[[Bot], Awaitable[None]]`` + +:说明: + + WebSocket 连接建立时执行的函数 +""" +T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]] +""" +:类型: ``Callable[[Bot], Awaitable[None]]`` + +:说明: + + WebSocket 连接断开时执行的函数 +""" + T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] """ :类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` From 299c259d50adb8f6dda9772d803d927da0ae7ad1 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 28 Dec 2020 13:53:24 +0800 Subject: [PATCH 2/5] :bug: fix hook block event receiving bug --- nonebot/drivers/__init__.py | 46 ++++++++++++++++++------------ nonebot/drivers/fastapi.py | 5 ++-- tests/test_plugins/test_ws_hook.py | 14 +++++++++ 3 files changed, 44 insertions(+), 21 deletions(-) create mode 100644 tests/test_plugins/test_ws_hook.py diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index b1942f85..b8a9de20 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -140,7 +140,7 @@ class Driver(abc.ABC): def on_bot_disconnect( self, - func: T_WebSocketDisConnectionHook) -> T_WebSocketDisConnectionHook: + func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook: """ :说明: @@ -153,30 +153,38 @@ class Driver(abc.ABC): self._ws_disconnection_hook.add(func) return func - async def bot_connect(self, bot: "Bot") -> None: + def bot_connect(self, bot: "Bot") -> None: """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" self._clients[bot.self_id] = bot - coros = list(map(lambda x: x(bot), self._ws_connection_hook)) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketConnection hook. " - "Running cancelled!") - async def bot_disconnect(self, bot: "Bot") -> None: + async def _run_hook(bot: "Bot") -> None: + coros = list(map(lambda x: x(bot), self._ws_connection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketConnection hook. " + "Running cancelled!") + + asyncio.create_task(_run_hook(bot)) + + def bot_disconnect(self, bot: "Bot") -> None: """在 WebSocket 连接断开后,调用该函数来注销 bot 对象""" if bot.self_id in self._clients: del self._clients[bot.self_id] - coros = list(map(lambda x: x(bot), self._ws_disconnection_hook)) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketDisConnection hook. " - "Running cancelled!") + + async def _run_hook(bot: "Bot") -> None: + coros = list(map(lambda x: x(bot), self._ws_disconnection_hook)) + if coros: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketDisConnection hook. " + "Running cancelled!") + + asyncio.create_task(_run_hook(bot)) @abc.abstractmethod def run(self, diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 53acccf0..0e4457d9 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -188,11 +188,12 @@ class Driver(BaseDriver): bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) await ws.accept() - await self.bot_connect(bot) logger.opt(colors=True).info( f"WebSocket Connection from {adapter.upper()} " f"Bot {x_self_id} Accepted!") + self.bot_connect(bot) + try: while not ws.closed: data = await ws.receive() @@ -202,7 +203,7 @@ class Driver(BaseDriver): asyncio.create_task(bot.handle_message(data)) finally: - await self.bot_disconnect(bot) + self.bot_disconnect(bot) class WebSocket(BaseWebSocket): diff --git a/tests/test_plugins/test_ws_hook.py b/tests/test_plugins/test_ws_hook.py new file mode 100644 index 00000000..7b5f9b2b --- /dev/null +++ b/tests/test_plugins/test_ws_hook.py @@ -0,0 +1,14 @@ +import nonebot +from nonebot.adapters import Bot + +driver = nonebot.get_driver() + + +@driver.on_bot_connect +async def connect(bot: Bot) -> None: + print("Connect", bot) + + +@driver.on_bot_disconnect +async def disconnect(bot: Bot) -> None: + print("Disconnect", bot) From 0c43d83494d17661ddaf7162b6f16eb3201c1652 Mon Sep 17 00:00:00 2001 From: nonebot Date: Mon, 28 Dec 2020 05:57:28 +0000 Subject: [PATCH 3/5] :memo: update api docs --- docs/api/drivers/README.md | 72 ++++++++++++++++++++++++++++++++++++++ docs/api/typing.md | 32 +++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md index e8045d62..dc5a6d61 100644 --- a/docs/api/drivers/README.md +++ b/docs/api/drivers/README.md @@ -32,6 +32,36 @@ Driver 基类。将后端框架封装,以满足适配器使用。 +### `_ws_connection_hook` + + +* **类型** + + `Set[T_WebSocketConnectionHook]` + + + +* **说明** + + WebSocket 连接建立时执行的函数 + + + +### `_ws_disconnection_hook` + + +* **类型** + + `Set[T_WebSocketDisconnectionHook]` + + + +* **说明** + + WebSocket 连接断开时执行的函数 + + + ### _abstract_ `__init__(env, config)` @@ -154,6 +184,48 @@ Driver 基类。将后端框架封装,以满足适配器使用。 注册一个在驱动停止时运行的函数 +### `on_bot_connect(func)` + + +* **说明** + + 装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。 + + + +* **函数参数** + + + * `bot: Bot`: 当前连接上的 Bot 对象 + + + +### `on_bot_disconnect(func)` + + +* **说明** + + 装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。 + + + +* **函数参数** + + + * `bot: Bot`: 当前连接上的 Bot 对象 + + + +### `bot_connect(bot)` + +在 WebSocket 连接成功后,调用该函数来注册 bot 对象 + + +### `bot_disconnect(bot)` + +在 WebSocket 连接断开后,调用该函数来注销 bot 对象 + + ### _abstract_ `run(host=None, port=None, *args, **kwargs)` diff --git a/docs/api/typing.md b/docs/api/typing.md index 3bf49897..5d1b3d7b 100644 --- a/docs/api/typing.md +++ b/docs/api/typing.md @@ -46,6 +46,38 @@ sidebarDepth: 0 +## `T_WebSocketConnectionHook` + + +* **类型** + + `Callable[[Bot], Awaitable[None]]` + + + +* **说明** + + WebSocket 连接建立时执行的函数 + + + + +## `T_WebSocketDisconnectionHook` + + +* **类型** + + `Callable[[Bot], Awaitable[None]]` + + + +* **说明** + + WebSocket 连接断开时执行的函数 + + + + ## `T_EventPreProcessor` From 1c9df5ac021c0fdea4d385b30caedd72c6a9ca8d Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 28 Dec 2020 13:59:54 +0800 Subject: [PATCH 4/5] :art: change method into private --- nonebot/drivers/__init__.py | 4 ++-- nonebot/drivers/fastapi.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nonebot/drivers/__init__.py b/nonebot/drivers/__init__.py index b8a9de20..7e95ee91 100644 --- a/nonebot/drivers/__init__.py +++ b/nonebot/drivers/__init__.py @@ -153,7 +153,7 @@ class Driver(abc.ABC): self._ws_disconnection_hook.add(func) return func - def bot_connect(self, bot: "Bot") -> None: + def _bot_connect(self, bot: "Bot") -> None: """在 WebSocket 连接成功后,调用该函数来注册 bot 对象""" self._clients[bot.self_id] = bot @@ -169,7 +169,7 @@ class Driver(abc.ABC): asyncio.create_task(_run_hook(bot)) - def bot_disconnect(self, bot: "Bot") -> None: + def _bot_disconnect(self, bot: "Bot") -> None: """在 WebSocket 连接断开后,调用该函数来注销 bot 对象""" if bot.self_id in self._clients: del self._clients[bot.self_id] diff --git a/nonebot/drivers/fastapi.py b/nonebot/drivers/fastapi.py index 0e4457d9..42e21490 100644 --- a/nonebot/drivers/fastapi.py +++ b/nonebot/drivers/fastapi.py @@ -192,7 +192,7 @@ class Driver(BaseDriver): f"WebSocket Connection from {adapter.upper()} " f"Bot {x_self_id} Accepted!") - self.bot_connect(bot) + self._bot_connect(bot) try: while not ws.closed: @@ -203,7 +203,7 @@ class Driver(BaseDriver): asyncio.create_task(bot.handle_message(data)) finally: - self.bot_disconnect(bot) + self._bot_disconnect(bot) class WebSocket(BaseWebSocket): From c23a777a9f66d55867e6eccac677b28031ee8520 Mon Sep 17 00:00:00 2001 From: nonebot Date: Mon, 28 Dec 2020 06:01:32 +0000 Subject: [PATCH 5/5] :memo: update api docs --- docs/api/drivers/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api/drivers/README.md b/docs/api/drivers/README.md index dc5a6d61..313717cb 100644 --- a/docs/api/drivers/README.md +++ b/docs/api/drivers/README.md @@ -216,12 +216,12 @@ Driver 基类。将后端框架封装,以满足适配器使用。 -### `bot_connect(bot)` +### `_bot_connect(bot)` 在 WebSocket 连接成功后,调用该函数来注册 bot 对象 -### `bot_disconnect(bot)` +### `_bot_disconnect(bot)` 在 WebSocket 连接断开后,调用该函数来注销 bot 对象