From 81cb3565033f81303efd8367d78735b75eb9956d Mon Sep 17 00:00:00 2001 From: Akirami <66513481+A-kirami@users.noreply.github.com> Date: Tue, 5 Sep 2023 00:17:55 +0800 Subject: [PATCH] =?UTF-8?q?:memo:=20Feature:=20=E8=A1=A5=E5=85=85=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E6=B3=A8=E5=85=A5=E9=83=A8=E5=88=86=E6=83=85=E5=86=B5?= =?UTF-8?q?=E4=B8=8B=E7=B1=BB=E5=9E=8B=E9=94=99=E8=AF=AF=E6=97=B6=E7=9A=84?= =?UTF-8?q?=E6=97=A5=E5=BF=97=E6=8F=90=E7=A4=BA=20(#2343)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/dependencies/__init__.py | 39 ++++++++++++++------------------ 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index dc787f77..4b73e0b2 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -101,17 +101,21 @@ class Dependent(Generic[R]): ) async def __call__(self, **kwargs: Any) -> R: - # do pre-check - await self.check(**kwargs) + try: + # do pre-check + await self.check(**kwargs) - # solve param values - values = await self.solve(**kwargs) + # solve param values + values = await self.solve(**kwargs) - # call function - if is_coroutine_callable(self.call): - return await cast(Callable[..., Awaitable[R]], self.call)(**values) - else: - return await run_sync(cast(Callable[..., R], self.call))(**values) + # call function + if is_coroutine_callable(self.call): + return await cast(Callable[..., Awaitable[R]], self.call)(**values) + else: + return await run_sync(cast(Callable[..., R], self.call))(**values) + except SkippedException as e: + logger.trace(f"{self} skipped due to {e}") + raise @staticmethod def parse_params( @@ -195,19 +199,10 @@ class Dependent(Generic[R]): return cls(call, params, parameterless_params) async def check(self, **params: Any) -> None: - try: - await asyncio.gather( - *(param._check(**params) for param in self.parameterless) - ) - await asyncio.gather( - *( - cast(Param, param.field_info)._check(**params) - for param in self.params - ) - ) - except SkippedException as e: - logger.trace(f"{self} skipped due to {e}") - raise + await asyncio.gather(*(param._check(**params) for param in self.parameterless)) + await asyncio.gather( + *(cast(Param, param.field_info)._check(**params) for param in self.params) + ) async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any: param = cast(Param, field.field_info)