mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
🐛 Fix: 修复结构化并发子依赖取消缓存问题 (#3084)
This commit is contained in:
parent
be732cf9d8
commit
e3cb4c7907
@ -21,7 +21,12 @@ from nonebot.log import logger
|
|||||||
from nonebot.typing import _DependentCallable
|
from nonebot.typing import _DependentCallable
|
||||||
from nonebot.exception import SkippedException
|
from nonebot.exception import SkippedException
|
||||||
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
|
from nonebot.utils import (
|
||||||
|
run_sync,
|
||||||
|
run_coro_with_shield,
|
||||||
|
is_coroutine_callable,
|
||||||
|
flatten_exception_group,
|
||||||
|
)
|
||||||
|
|
||||||
from .utils import check_field_type, get_typed_signature
|
from .utils import check_field_type, get_typed_signature
|
||||||
|
|
||||||
@ -207,7 +212,10 @@ class Dependent(Generic[R]):
|
|||||||
|
|
||||||
async with anyio.create_task_group() as tg:
|
async with anyio.create_task_group() as tg:
|
||||||
for field in self.params:
|
for field in self.params:
|
||||||
tg.start_soon(_solve_field, field, params)
|
# shield the task to prevent cancellation
|
||||||
|
# when one of the tasks raises an exception
|
||||||
|
# this will improve the dependency cache reusability
|
||||||
|
tg.start_soon(run_coro_with_shield, _solve_field(field, params))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -115,6 +115,9 @@ class DependencyCache:
|
|||||||
self._exception: Optional[BaseException] = None
|
self._exception: Optional[BaseException] = None
|
||||||
self._waiter = anyio.Event()
|
self._waiter = anyio.Event()
|
||||||
|
|
||||||
|
def done(self) -> bool:
|
||||||
|
return self._state == CacheState.FINISHED
|
||||||
|
|
||||||
def result(self) -> Any:
|
def result(self) -> Any:
|
||||||
"""获取子依赖结果"""
|
"""获取子依赖结果"""
|
||||||
|
|
||||||
@ -304,11 +307,18 @@ class DependParam(Param):
|
|||||||
dependency_cache[call] = cache = DependencyCache()
|
dependency_cache[call] = cache = DependencyCache()
|
||||||
try:
|
try:
|
||||||
result = await target
|
result = await target
|
||||||
cache.set_result(result)
|
except Exception as e:
|
||||||
return result
|
|
||||||
except BaseException as e:
|
|
||||||
cache.set_exception(e)
|
cache.set_exception(e)
|
||||||
raise
|
raise
|
||||||
|
except BaseException as e:
|
||||||
|
cache.set_exception(e)
|
||||||
|
# remove cache when base exception occurs
|
||||||
|
# e.g. CancelledError
|
||||||
|
dependency_cache.pop(call, None)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
cache.set_result(result)
|
||||||
|
return result
|
||||||
|
|
||||||
@override
|
@override
|
||||||
async def _check(self, **kwargs: Any) -> None:
|
async def _check(self, **kwargs: Any) -> None:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import anyio
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nonebot import on_message
|
from nonebot import on_message
|
||||||
@ -105,3 +106,26 @@ async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))):
|
|||||||
|
|
||||||
async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
|
async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
async def _dep():
|
||||||
|
await anyio.sleep(1)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def _dep_mismatch():
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
async def cache_exception_func1(
|
||||||
|
dep: int = Depends(_dep),
|
||||||
|
mismatch: dict = Depends(_dep_mismatch),
|
||||||
|
):
|
||||||
|
raise RuntimeError("Never reach here")
|
||||||
|
|
||||||
|
|
||||||
|
async def cache_exception_func2(
|
||||||
|
dep: int = Depends(_dep),
|
||||||
|
match: int = Depends(_dep_mismatch),
|
||||||
|
):
|
||||||
|
return dep
|
||||||
|
@ -51,6 +51,8 @@ async def test_depend(app: App):
|
|||||||
annotated_depend,
|
annotated_depend,
|
||||||
sub_type_mismatch,
|
sub_type_mismatch,
|
||||||
validate_field_fail,
|
validate_field_fail,
|
||||||
|
cache_exception_func1,
|
||||||
|
cache_exception_func2,
|
||||||
annotated_class_depend,
|
annotated_class_depend,
|
||||||
annotated_multi_depend,
|
annotated_multi_depend,
|
||||||
annotated_prior_depend,
|
annotated_prior_depend,
|
||||||
@ -130,6 +132,26 @@ async def test_depend(app: App):
|
|||||||
if isinstance(exc_info.value, BaseExceptionGroup):
|
if isinstance(exc_info.value, BaseExceptionGroup):
|
||||||
assert exc_info.group_contains(TypeMisMatch)
|
assert exc_info.group_contains(TypeMisMatch)
|
||||||
|
|
||||||
|
# test cache reuse when exception raised
|
||||||
|
dependency_cache = {}
|
||||||
|
with pytest.raises((TypeMisMatch, BaseExceptionGroup)) as exc_info:
|
||||||
|
async with app.test_dependent(
|
||||||
|
cache_exception_func1, allow_types=[DependParam]
|
||||||
|
) as ctx:
|
||||||
|
ctx.pass_params(dependency_cache=dependency_cache)
|
||||||
|
|
||||||
|
if isinstance(exc_info.value, BaseExceptionGroup):
|
||||||
|
assert exc_info.group_contains(TypeMisMatch)
|
||||||
|
|
||||||
|
# dependency solve tasks should be shielded even if one of them raises an exception
|
||||||
|
assert len(dependency_cache) == 2
|
||||||
|
|
||||||
|
async with app.test_dependent(
|
||||||
|
cache_exception_func2, allow_types=[DependParam]
|
||||||
|
) as ctx:
|
||||||
|
ctx.pass_params(dependency_cache=dependency_cache)
|
||||||
|
ctx.should_return(1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_bot(app: App):
|
async def test_bot(app: App):
|
||||||
|
Loading…
Reference in New Issue
Block a user