diff --git a/nonebot/internal/driver/driver.py b/nonebot/internal/driver/driver.py index 0040dd10..94f9798d 100644 --- a/nonebot/internal/driver/driver.py +++ b/nonebot/internal/driver/driver.py @@ -1,6 +1,6 @@ import abc import asyncio -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from typing import TYPE_CHECKING, Any, Set, Dict, Type, Callable, AsyncGenerator from nonebot.log import logger @@ -8,8 +8,12 @@ from nonebot.config import Env, Config from nonebot.dependencies import Dependent from nonebot.exception import SkippedException from nonebot.utils import escape_tag, run_coro_with_catch -from nonebot.typing import T_BotConnectionHook, T_BotDisconnectionHook from nonebot.internal.params import BotParam, DependParam, DefaultParam +from nonebot.typing import ( + T_DependencyCache, + T_BotConnectionHook, + T_BotDisconnectionHook, +) from .model import Request, Response, WebSocket, HTTPServerSetup, WebSocketServerSetup @@ -135,20 +139,22 @@ class Driver(abc.ABC): self._bots[bot.self_id] = bot async def _run_hook(bot: "Bot") -> None: - coros = list( - map( - lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)), - self._bot_connection_hook, - ) - ) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketConnection hook. " - "Running cancelled!" + dependency_cache: T_DependencyCache = {} + async with AsyncExitStack() as stack: + if coros := [ + run_coro_with_catch( + hook(bot=bot, stack=stack, dependency_cache=dependency_cache), + (SkippedException,), ) + for hook in self._bot_connection_hook + ]: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketConnection hook. " + "Running cancelled!" + ) asyncio.create_task(_run_hook(bot)) @@ -158,20 +164,22 @@ class Driver(abc.ABC): del self._bots[bot.self_id] async def _run_hook(bot: "Bot") -> None: - coros = list( - map( - lambda x: run_coro_with_catch(x(bot=bot), (SkippedException,)), - self._bot_disconnection_hook, - ) - ) - if coros: - try: - await asyncio.gather(*coros) - except Exception as e: - logger.opt(colors=True, exception=e).error( - "Error when running WebSocketDisConnection hook. " - "Running cancelled!" + dependency_cache: T_DependencyCache = {} + async with AsyncExitStack() as stack: + if coros := [ + run_coro_with_catch( + hook(bot=bot, stack=stack, dependency_cache=dependency_cache), + (SkippedException,), ) + for hook in self._bot_disconnection_hook + ]: + try: + await asyncio.gather(*coros) + except Exception as e: + logger.opt(colors=True, exception=e).error( + "Error when running WebSocketDisConnection hook. " + "Running cancelled!" + ) asyncio.create_task(_run_hook(bot)) diff --git a/tests/test_driver.py b/tests/test_driver.py index 128d7f05..d0b00c88 100644 --- a/tests/test_driver.py +++ b/tests/test_driver.py @@ -1,13 +1,16 @@ import json import asyncio -from typing import cast +from typing import Any, Set, cast import pytest from nonebug import App import nonebot from nonebot.config import Env +from nonebot.adapters import Bot +from nonebot.params import Depends from nonebot import _resolve_combine_expr +from nonebot.dependencies import Dependent from nonebot.exception import WebSocketClosed from nonebot.drivers import ( URL, @@ -169,3 +172,69 @@ async def test_http_driver(driver: Driver): ) async def test_combine_driver(driver: Driver, driver_type: str): assert driver.type == driver_type + + +@pytest.mark.asyncio +async def test_bot_connect_hook(app: App, driver: Driver): + with pytest.MonkeyPatch.context() as m: + conn_hooks: Set[Dependent[Any]] = set() + disconn_hooks: Set[Dependent[Any]] = set() + m.setattr(Driver, "_bot_connection_hook", conn_hooks) + m.setattr(Driver, "_bot_disconnection_hook", disconn_hooks) + + conn_should_be_called = False + disconn_should_be_called = False + dependency_should_be_run = False + dependency_should_be_cleaned = False + + async def dependency(): + nonlocal dependency_should_be_run, dependency_should_be_cleaned + dependency_should_be_run = True + try: + yield 1 + finally: + dependency_should_be_cleaned = True + + @driver.on_bot_connect + async def conn_hook(foo: Bot, dep: int = Depends(dependency), default: int = 1): + nonlocal conn_should_be_called + conn_should_be_called = True + + if foo is not bot: + pytest.fail("on_bot_connect hook called with wrong bot") + if dep != 1: + pytest.fail("on_bot_connect hook called with wrong dependency") + if default != 1: + pytest.fail("on_bot_connect hook called with wrong default value") + + @driver.on_bot_disconnect + async def disconn_hook( + foo: Bot, dep: int = Depends(dependency), default: int = 1 + ): + nonlocal disconn_should_be_called + disconn_should_be_called = True + + if foo is not bot: + pytest.fail("on_bot_disconnect hook called with wrong bot") + if dep != 1: + pytest.fail("on_bot_connect hook called with wrong dependency") + if default != 1: + pytest.fail("on_bot_connect hook called with wrong default value") + + if conn_hook not in {hook.call for hook in conn_hooks}: + pytest.fail("on_bot_connect hook not registered") + if disconn_hook not in {hook.call for hook in disconn_hooks}: + pytest.fail("on_bot_disconnect hook not registered") + + async with app.test_api() as ctx: + bot = ctx.create_bot() + + await asyncio.sleep(1) + if not conn_should_be_called: + pytest.fail("on_bot_connect hook not called") + if not disconn_should_be_called: + pytest.fail("on_bot_disconnect hook not called") + if not dependency_should_be_run: + pytest.fail("dependency not run") + if not dependency_should_be_cleaned: + pytest.fail("dependency not cleaned")