diff --git a/nonebot/internal/driver/driver.py b/nonebot/internal/driver/driver.py
index 0040dd10..94f9798d 100644
--- a/nonebot/internal/driver/driver.py
+++ b/nonebot/internal/driver/driver.py
@@ -1,6 +1,6 @@
import abc
import asyncio
-from contextlib import asynccontextmanager
+from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator
from nonebot.log import logger
@@ -8,8 +8,12 @@ from nonebot.config import Env, Config
from nonebot.dependencies import Dependent
from nonebot.exception import SkippedException
from nonebot.utils import escape_tag, run_coro_with_catch
-from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook
from nonebot.internal.params import BotParam, DependParam, DefaultParam
+from nonebot.typing import (
+ T_DependencyCache,
+ T_BotConnectionHook,
+ T_BotDisconnectionHook,
+)
from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup
@@ -135,20 +139,22 @@ class Driver(abc.ABC):
self._bots[bot.self_id] = bot
async def _run_hook(bot: "Bot") -> None:
- coros = list(
- map(
- lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
- self._bot_connection_hook,
- )
- )
- if coros:
- try:
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running WebSocketConnection hook. "
- "Running cancelled!"
+ dependency_cache: T_DependencyCache = {}
+ async with AsyncExitStack() as stack:
+ if coros := [
+ run_coro_with_catch(
+ hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
+ (SkippedException,),
)
+ for hook in self._bot_connection_hook
+ ]:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketConnection hook. "
+ "Running cancelled!"
+ )
asyncio.create_task(_run_hook(bot))
@@ -158,20 +164,22 @@ class Driver(abc.ABC):
del self._bots[bot.self_id]
async def _run_hook(bot: "Bot") -> None:
- coros = list(
- map(
- lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)),
- self._bot_disconnection_hook,
- )
- )
- if coros:
- try:
- await asyncio.gather(*coros)
- except Exception as e:
- logger.opt(colors=True, exception=e).error(
- "Error when running WebSocketDisConnection hook. "
- "Running cancelled!"
+ dependency_cache: T_DependencyCache = {}
+ async with AsyncExitStack() as stack:
+ if coros := [
+ run_coro_with_catch(
+ hook(bot=bot, stack=stack, dependency_cache=dependency_cache),
+ (SkippedException,),
)
+ for hook in self._bot_disconnection_hook
+ ]:
+ try:
+ await asyncio.gather(*coros)
+ except Exception as e:
+ logger.opt(colors=True, exception=e).error(
+ "Error when running WebSocketDisConnection hook. "
+ "Running cancelled!"
+ )
asyncio.create_task(_run_hook(bot))
diff --git a/tests/test_driver.py b/tests/test_driver.py
index 128d7f05..d0b00c88 100644
--- a/tests/test_driver.py
+++ b/tests/test_driver.py
@@ -1,13 +1,16 @@
import json
import asyncio
-from typing import cast
+from typing import Any, Set, cast
import pytest
from nonebug import App
import nonebot
from nonebot.config import Env
+from nonebot.adapters import Bot
+from nonebot.params import Depends
from nonebot import _resolve_combine_expr
+from nonebot.dependencies import Dependent
from nonebot.exception import WebSocketClosed
from nonebot.drivers import (
URL,
@@ -169,3 +172,69 @@ async def test_http_driver(driver: Driver):
)
async def test_combine_driver(driver: Driver, driver_type: str):
assert driver.type == driver_type
+
+
+@pytest.mark.asyncio
+async def test_bot_connect_hook(app: App, driver: Driver):
+ with pytest.MonkeyPatch.context() as m:
+ conn_hooks: Set[Dependent[Any]] = set()
+ disconn_hooks: Set[Dependent[Any]] = set()
+ m.setattr(Driver, "_bot_connection_hook", conn_hooks)
+ m.setattr(Driver, "_bot_disconnection_hook", disconn_hooks)
+
+ conn_should_be_called = False
+ disconn_should_be_called = False
+ dependency_should_be_run = False
+ dependency_should_be_cleaned = False
+
+ async def dependency():
+ nonlocal dependency_should_be_run, dependency_should_be_cleaned
+ dependency_should_be_run = True
+ try:
+ yield 1
+ finally:
+ dependency_should_be_cleaned = True
+
+ @driver.on_bot_connect
+ async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1):
+ nonlocal conn_should_be_called
+ conn_should_be_called = True
+
+ if foo is not bot:
+ pytest.fail("on_bot_connect hook called with wrong bot")
+ if dep != 1:
+ pytest.fail("on_bot_connect hook called with wrong dependency")
+ if default != 1:
+ pytest.fail("on_bot_connect hook called with wrong default value")
+
+ @driver.on_bot_disconnect
+ async def disconn_hook(
+ foo: Bot, dep: int = Depends(dependency), default: int = 1
+ ):
+ nonlocal disconn_should_be_called
+ disconn_should_be_called = True
+
+ if foo is not bot:
+ pytest.fail("on_bot_disconnect hook called with wrong bot")
+ if dep != 1:
+ pytest.fail("on_bot_connect hook called with wrong dependency")
+ if default != 1:
+ pytest.fail("on_bot_connect hook called with wrong default value")
+
+ if conn_hook not in {hook.call for hook in conn_hooks}:
+ pytest.fail("on_bot_connect hook not registered")
+ if disconn_hook not in {hook.call for hook in disconn_hooks}:
+ pytest.fail("on_bot_disconnect hook not registered")
+
+ async with app.test_api() as ctx:
+ bot = ctx.create_bot()
+
+ await asyncio.sleep(1)
+ if not conn_should_be_called:
+ pytest.fail("on_bot_connect hook not called")
+ if not disconn_should_be_called:
+ pytest.fail("on_bot_disconnect hook not called")
+ if not dependency_should_be_run:
+ pytest.fail("dependency not run")
+ if not dependency_should_be_cleaned:
+ pytest.fail("dependency not cleaned")