🔀 Merge pull request #123

New: WebSocket Connection Hook
This commit is contained in:
Ju4tCode 2020-12-28 14:12:04 +08:00 committed by GitHub
commit e14d3d8d73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 219 additions and 5 deletions

View File

@ -32,6 +32,36 @@ Driver 基类。将后端框架封装,以满足适配器使用。
### `_ws_connection_hook`
* **类型**
`Set[T_WebSocketConnectionHook]`
* **说明**
WebSocket 连接建立时执行的函数
### `_ws_disconnection_hook`
* **类型**
`Set[T_WebSocketDisconnectionHook]`
* **说明**
WebSocket 连接断开时执行的函数
### _abstract_ `__init__(env, config)` ### _abstract_ `__init__(env, config)`
@ -154,6 +184,48 @@ Driver 基类。将后端框架封装,以满足适配器使用。
注册一个在驱动停止时运行的函数 注册一个在驱动停止时运行的函数
### `on_bot_connect(func)`
* **说明**
装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行。
* **函数参数**
* `bot: Bot`: 当前连接上的 Bot 对象
### `on_bot_disconnect(func)`
* **说明**
装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行。
* **函数参数**
* `bot: Bot`: 当前连接上的 Bot 对象
### `_bot_connect(bot)`
在 WebSocket 连接成功后,调用该函数来注册 bot 对象
### `_bot_disconnect(bot)`
在 WebSocket 连接断开后,调用该函数来注销 bot 对象
### _abstract_ `run(host=None, port=None, *args, **kwargs)` ### _abstract_ `run(host=None, port=None, *args, **kwargs)`

View File

@ -46,6 +46,38 @@ sidebarDepth: 0
## `T_WebSocketConnectionHook`
* **类型**
`Callable[[Bot], Awaitable[None]]`
* **说明**
WebSocket 连接建立时执行的函数
## `T_WebSocketDisconnectionHook`
* **类型**
`Callable[[Bot], Awaitable[None]]`
* **说明**
WebSocket 连接断开时执行的函数
## `T_EventPreProcessor` ## `T_EventPreProcessor`

View File

@ -6,10 +6,12 @@
""" """
import abc import abc
from typing import Dict, Type, Optional, Callable, TYPE_CHECKING import asyncio
from typing import Set, Dict, Type, Optional, Callable, TYPE_CHECKING
from nonebot.log import logger from nonebot.log import logger
from nonebot.config import Env, Config from nonebot.config import Env, Config
from nonebot.typing import T_WebSocketConnectionHook, T_WebSocketDisconnectionHook
if TYPE_CHECKING: if TYPE_CHECKING:
from nonebot.adapters import Bot from nonebot.adapters import Bot
@ -25,6 +27,16 @@ class Driver(abc.ABC):
:类型: ``Dict[str, Type[Bot]]`` :类型: ``Dict[str, Type[Bot]]``
:说明: 已注册的适配器列表 :说明: 已注册的适配器列表
""" """
_ws_connection_hook: Set[T_WebSocketConnectionHook] = set()
"""
:类型: ``Set[T_WebSocketConnectionHook]``
:说明: WebSocket 连接建立时执行的函数
"""
_ws_disconnection_hook: Set[T_WebSocketDisconnectionHook] = set()
"""
:类型: ``Set[T_WebSocketDisconnectionHook]``
:说明: WebSocket 连接断开时执行的函数
"""
@abc.abstractmethod @abc.abstractmethod
def __init__(self, env: Env, config: Config): def __init__(self, env: Env, config: Config):
@ -93,8 +105,12 @@ class Driver(abc.ABC):
@property @property
def bots(self) -> Dict[str, "Bot"]: def bots(self) -> Dict[str, "Bot"]:
""" """
:类型: ``Dict[str, Bot]`` :类型:
:说明: 获取当前所有已连接的 Bot
``Dict[str, Bot]``
:说明:
获取当前所有已连接的 Bot
""" """
return self._clients return self._clients
@ -108,6 +124,68 @@ class Driver(abc.ABC):
"""注册一个在驱动停止时运行的函数""" """注册一个在驱动停止时运行的函数"""
raise NotImplementedError raise NotImplementedError
def on_bot_connect(
self, func: T_WebSocketConnectionHook) -> T_WebSocketConnectionHook:
"""
:说明:
装饰一个函数使他在 bot 通过 WebSocket 连接成功时执行
:函数参数:
* ``bot: Bot``: 当前连接上的 Bot 对象
"""
self._ws_connection_hook.add(func)
return func
def on_bot_disconnect(
self,
func: T_WebSocketDisconnectionHook) -> T_WebSocketDisconnectionHook:
"""
:说明:
装饰一个函数使他在 bot 通过 WebSocket 连接断开时执行
:函数参数:
* ``bot: Bot``: 当前连接上的 Bot 对象
"""
self._ws_disconnection_hook.add(func)
return func
def _bot_connect(self, bot: "Bot") -> None:
"""在 WebSocket 连接成功后,调用该函数来注册 bot 对象"""
self._clients[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot), self._ws_connection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
asyncio.create_task(_run_hook(bot))
def _bot_disconnect(self, bot: "Bot") -> None:
"""在 WebSocket 连接断开后,调用该函数来注销 bot 对象"""
if bot.self_id in self._clients:
del self._clients[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
coros = list(map(lambda x: x(bot), self._ws_disconnection_hook))
if coros:
try:
await asyncio.gather(*coros)
except Exception as e:
logger.opt(colors=True, exception=e).error(
"<r><bg #f8bbd0>Error when running WebSocketDisConnection hook. "
"Running cancelled!</bg #f8bbd0></r>")
asyncio.create_task(_run_hook(bot))
@abc.abstractmethod @abc.abstractmethod
def run(self, def run(self,
host: Optional[str] = None, host: Optional[str] = None,

View File

@ -188,11 +188,12 @@ class Driver(BaseDriver):
bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws) bot = BotClass(self, "websocket", self.config, x_self_id, websocket=ws)
await ws.accept() await ws.accept()
self._clients[x_self_id] = bot
logger.opt(colors=True).info( logger.opt(colors=True).info(
f"WebSocket Connection from <y>{adapter.upper()} " f"WebSocket Connection from <y>{adapter.upper()} "
f"Bot {x_self_id}</y> Accepted!") f"Bot {x_self_id}</y> Accepted!")
self._bot_connect(bot)
try: try:
while not ws.closed: while not ws.closed:
data = await ws.receive() data = await ws.receive()
@ -202,7 +203,7 @@ class Driver(BaseDriver):
asyncio.create_task(bot.handle_message(data)) asyncio.create_task(bot.handle_message(data))
finally: finally:
del self._clients[x_self_id] self._bot_disconnect(bot)
class WebSocket(BaseWebSocket): class WebSocket(BaseWebSocket):

View File

@ -54,6 +54,23 @@ T_StateFactory = Callable[["Bot", "Event"], Awaitable[T_State]]
事件处理状态 State 类工厂函数 事件处理状态 State 类工厂函数
""" """
T_WebSocketConnectionHook = Callable[["Bot"], Awaitable[None]]
"""
:类型: ``Callable[[Bot], Awaitable[None]]``
:说明:
WebSocket 连接建立时执行的函数
"""
T_WebSocketDisconnectionHook = Callable[["Bot"], Awaitable[None]]
"""
:类型: ``Callable[[Bot], Awaitable[None]]``
:说明:
WebSocket 连接断开时执行的函数
"""
T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]] T_EventPreProcessor = Callable[["Bot", "Event", T_State], Awaitable[None]]
""" """
:类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]`` :类型: ``Callable[[Bot, Event, T_State], Awaitable[None]]``

View File

@ -0,0 +1,14 @@
import nonebot
from nonebot.adapters import Bot
driver = nonebot.get_driver()
@driver.on_bot_connect
async def connect(bot: Bot) -> None:
print("Connect", bot)
@driver.on_bot_disconnect
async def disconnect(bot: Bot) -> None:
print("Disconnect", bot)