nonebot2/nonebot/dependencies/__init__.py

230 lines
7.3 KiB
Python
Raw Normal View History

2022-01-22 15:23:07 +08:00
"""本模块模块实现了依赖注入的定义与处理。
FrontMatter:
mdx:
format: md
2022-01-22 15:23:07 +08:00
sidebar_position: 0
description: nonebot.dependencies 模块
"""
import abc
from collections.abc import Awaitable, Iterable
from dataclasses import dataclass, field
from functools import partial
import inspect
from typing import Any, Callable, Generic, Optional, TypeVar, cast
2021-11-12 20:55:59 +08:00
import anyio
from exceptiongroup import BaseExceptionGroup, catch
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.exception import SkippedException
2021-11-14 18:51:23 +08:00
from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.utils import (
flatten_exception_group,
is_coroutine_callable,
run_coro_with_shield,
run_sync,
)
2021-11-21 15:46:48 +08:00
from .utils import check_field_type, get_typed_signature
2022-01-15 21:27:43 +08:00
R = TypeVar("R")
T = TypeVar("T", bound="Dependent")
2021-11-12 20:55:59 +08:00
class Param(abc.ABC, FieldInfo):
2022-01-21 21:04:17 +08:00
"""依赖注入的基本单元 —— 参数。
继承自 `pydantic.fields.FieldInfo`用于描述参数信息不包括参数名
"""
def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.validate = validate
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: tuple[type["Param"], ...]
) -> Optional["Param"]:
return
@classmethod
def _check_parameterless(
cls, value: Any, allow_types: tuple[type["Param"], ...]
) -> Optional["Param"]:
return
@abc.abstractmethod
async def _solve(self, **kwargs: Any) -> Any:
raise NotImplementedError
2021-11-12 20:55:59 +08:00
async def _check(self, **kwargs: Any) -> None:
return
@dataclass(frozen=True)
class Dependent(Generic[R]):
2022-01-21 21:04:17 +08:00
"""依赖注入容器
参数:
call: 依赖注入的可调用对象可以是任何 Callable 对象
pre_checkers: 依赖注入解析前的参数检查
params: 具名参数列表
parameterless: 匿名参数列表
allow_types: 允许的参数类型
"""
call: _DependentCallable[R]
params: tuple[ModelField, ...] = field(default_factory=tuple)
parameterless: tuple[Param, ...] = field(default_factory=tuple)
def __repr__(self) -> str:
if inspect.isfunction(self.call) or inspect.isclass(self.call):
call_str = self.call.__name__
else:
call_str = repr(self.call)
return (
f"Dependent(call={call_str}"
+ (f", parameterless={self.parameterless}" if self.parameterless else "")
+ ")"
)
async def __call__(self, **kwargs: Any) -> R:
exception: Optional[BaseExceptionGroup[SkippedException]] = None
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
nonlocal exception
exception = exc_group
# raise one of the exceptions instead
excs = list(flatten_exception_group(exc_group))
logger.trace(f"{self} skipped due to {excs}")
with catch({SkippedException: _handle_skipped}):
# do pre-check
await self.check(**kwargs)
# solve param values
values = await self.solve(**kwargs)
# call function
if is_coroutine_callable(self.call):
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
else:
return await run_sync(cast(Callable[..., R], self.call))(**values)
raise exception
2021-11-13 19:38:01 +08:00
@staticmethod
def parse_params(
call: _DependentCallable[R], allow_types: tuple[type[Param], ...]
) -> tuple[ModelField, ...]:
fields: list[ModelField] = []
params = get_typed_signature(call).parameters.values()
2021-11-13 19:38:01 +08:00
for param in params:
if isinstance(param.default, Param):
field_info = param.default
else:
for allow_type in allow_types:
if field_info := allow_type._check_param(param, allow_types):
break
else:
raise ValueError(
f"Unknown parameter {param.name} "
f"for function {call} with type {param.annotation}"
)
annotation: Any = Any
if param.annotation is not param.empty:
annotation = param.annotation
fields.append(
ModelField.construct(
name=param.name, annotation=annotation, field_info=field_info
)
)
2021-11-21 17:09:31 +08:00
return tuple(fields)
@staticmethod
def parse_parameterless(
parameterless: tuple[Any, ...], allow_types: tuple[type[Param], ...]
) -> tuple[Param, ...]:
parameterless_params: list[Param] = []
for value in parameterless:
for allow_type in allow_types:
if param := allow_type._check_parameterless(value, allow_types):
break
else:
raise ValueError(f"Unknown parameterless {value}")
parameterless_params.append(param)
return tuple(parameterless_params)
@classmethod
def parse(
cls,
*,
call: _DependentCallable[R],
parameterless: Optional[Iterable[Any]] = None,
allow_types: Iterable[type[Param]],
) -> "Dependent[R]":
allow_types = tuple(allow_types)
params = cls.parse_params(call, allow_types)
parameterless_params = (
()
if parameterless is None
else cls.parse_parameterless(tuple(parameterless), allow_types)
)
return cls(call, params, parameterless_params)
async def check(self, **params: Any) -> None:
if self.parameterless:
async with anyio.create_task_group() as tg:
for param in self.parameterless:
tg.start_soon(partial(param._check, **params))
if self.params:
async with anyio.create_task_group() as tg:
for param in self.params:
tg.start_soon(
partial(cast(Param, param.field_info)._check, **params)
)
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
param = cast(Param, field.field_info)
value = await param._solve(**params)
if value is PydanticUndefined:
value = field.get_default()
v = check_field_type(field, value)
return v if param.validate else value
async def solve(self, **params: Any) -> dict[str, Any]:
# solve parameterless
2021-12-20 00:28:02 +08:00
for param in self.parameterless:
await param._solve(**params)
# solve param values
result: dict[str, Any] = {}
if not self.params:
return result
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
value = await self._solve_field(field, params)
result[field.name] = value
async with anyio.create_task_group() as tg:
for field in self.params:
# shield the task to prevent cancellation
# when one of the tasks raises an exception
# this will improve the dependency cache reusability
tg.start_soon(run_coro_with_shield, _solve_field(field, params))
return result
2022-01-21 21:04:17 +08:00
__autodoc__ = {"CustomConfig": False}