diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index 1b56089b..da02f531 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -21,7 +21,12 @@ from nonebot.log import logger from nonebot.typing import _DependentCallable from nonebot.exception import SkippedException from nonebot.compat import FieldInfo, ModelField, PydanticUndefined -from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group +from nonebot.utils import ( + run_sync, + run_coro_with_shield, + is_coroutine_callable, + flatten_exception_group, +) from .utils import check_field_type, get_typed_signature @@ -207,7 +212,10 @@ class Dependent(Generic[R]): async with anyio.create_task_group() as tg: for field in self.params: - tg.start_soon(_solve_field, field, params) + # shield the task to prevent cancellation + # when one of the tasks raises an exception + # this will improve the dependency cache reusability + tg.start_soon(run_coro_with_shield, _solve_field(field, params)) return result diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 87f7367b..9dbe0b40 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -115,6 +115,9 @@ class DependencyCache: self._exception: Optional[BaseException] = None self._waiter = anyio.Event() + def done(self) -> bool: + return self._state == CacheState.FINISHED + def result(self) -> Any: """获取子依赖结果""" @@ -304,11 +307,18 @@ class DependParam(Param): dependency_cache[call] = cache = DependencyCache() try: result = await target - cache.set_result(result) - return result - except BaseException as e: + except Exception as e: cache.set_exception(e) raise + except BaseException as e: + cache.set_exception(e) + # remove cache when base exception occurs + # e.g. CancelledError + dependency_cache.pop(call, None) + raise + else: + cache.set_result(result) + return result @override async def _check(self, **kwargs: Any) -> None: diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py index 20c05892..6f28677f 100644 --- a/tests/plugins/param/param_depend.py +++ b/tests/plugins/param/param_depend.py @@ -1,6 +1,7 @@ from typing import Annotated from dataclasses import dataclass +import anyio from pydantic import Field from nonebot import on_message @@ -105,3 +106,26 @@ async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))): async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))): return x + + +async def _dep(): + await anyio.sleep(1) + return 1 + + +def _dep_mismatch(): + return 1 + + +async def cache_exception_func1( + dep: int = Depends(_dep), + mismatch: dict = Depends(_dep_mismatch), +): + raise RuntimeError("Never reach here") + + +async def cache_exception_func2( + dep: int = Depends(_dep), + match: int = Depends(_dep_mismatch), +): + return dep diff --git a/tests/test_param.py b/tests/test_param.py index eb00996a..cdd9420b 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -51,6 +51,8 @@ async def test_depend(app: App): annotated_depend, sub_type_mismatch, validate_field_fail, + cache_exception_func1, + cache_exception_func2, annotated_class_depend, annotated_multi_depend, annotated_prior_depend, @@ -130,6 +132,26 @@ async def test_depend(app: App): if isinstance(exc_info.value, BaseExceptionGroup): assert exc_info.group_contains(TypeMisMatch) + # test cache reuse when exception raised + dependency_cache = {} + with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info: + async with app.test_dependent( + cache_exception_func1, allow_types=[DependParam] + ) as ctx: + ctx.pass_params(dependency_cache=dependency_cache) + + if isinstance(exc_info.value, BaseExceptionGroup): + assert exc_info.group_contains(TypeMisMatch) + + # dependency solve tasks should be shielded even if one of them raises an exception + assert len(dependency_cache) == 2 + + async with app.test_dependent( + cache_exception_func2, allow_types=[DependParam] + ) as ctx: + ctx.pass_params(dependency_cache=dependency_cache) + ctx.should_return(1) + @pytest.mark.anyio async def test_bot(app: App):