nonebot2/nonebot/dependencies/__init__.py

216 lines
6.9 KiB
Python
Raw Normal View History

2022-01-22 15:23:07 +08:00
"""本模块实现了依赖注入的定义与处理。

FrontMatter:
    mdx:
        format: md
    sidebar_position: 0
    description: nonebot.dependencies 模块
"""
import abc
2021-11-12 20:55:59 +08:00
import inspect
from functools import partial
from dataclasses import field, dataclass
from collections.abc import Iterable, Awaitable
from typing import Any, Generic, TypeVar, Callable, Optional, cast
2021-11-12 20:55:59 +08:00
import anyio
from exceptiongroup import BaseExceptionGroup, catch
2021-11-14 18:51:23 +08:00
from nonebot.log import logger
from nonebot.typing import _DependentCallable
from nonebot.exception import SkippedException
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
from nonebot.utils import run_sync, is_coroutine_callable, flatten_exception_group
2021-11-21 15:46:48 +08:00
from .utils import check_field_type, get_typed_signature
2022-01-15 21:27:43 +08:00
# Return type of a dependency callable; parameterizes ``Dependent[R]``.
R = TypeVar("R")
# TypeVar bound to ``Dependent`` (and subclasses). NOTE(review): not used in
# the visible code of this module — confirm whether it is still needed.
T = TypeVar("T", bound="Dependent")
2021-11-12 20:55:59 +08:00
class Param(abc.ABC, FieldInfo):
    """Basic unit of dependency injection: a single parameter.

    Inherits from ``FieldInfo`` (``pydantic.fields.FieldInfo`` via
    ``nonebot.compat``) and describes a parameter's information, excluding
    the parameter's name.

    Args:
        *args: forwarded unchanged to ``FieldInfo.__init__``.
        validate: stored as ``self.validate``; when true, the type-checked
            value is used instead of the raw solved value.
        **kwargs: forwarded unchanged to ``FieldInfo.__init__``.
    """

    def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
        # Everything except ``validate`` belongs to FieldInfo.
        super().__init__(*args, **kwargs)
        self.validate = validate

    @classmethod
    def _check_param(
        cls, param: inspect.Parameter, allow_types: tuple[type["Param"], ...]
    ) -> Optional["Param"]:
        """Try to produce a ``Param`` for a named function parameter.

        The base implementation never claims a parameter. A subclass may
        return an instance to claim ``param``; ``None`` means "not mine".
        """
        return None

    @classmethod
    def _check_parameterless(
        cls, value: Any, allow_types: tuple[type["Param"], ...]
    ) -> Optional["Param"]:
        """Try to produce a ``Param`` for an anonymous (parameterless) value.

        The base implementation never claims a value and yields ``None``.
        """
        return None

    @abc.abstractmethod
    async def _solve(self, **kwargs: Any) -> Any:
        """Compute this parameter's value from the injection context."""
        raise NotImplementedError

    async def _check(self, **kwargs: Any) -> None:
        """Pre-check hook run before solving; a no-op by default."""
        return None
@dataclass(frozen=True)
class Dependent(Generic[R]):
    """Dependency injection container.

    Bundles a callable with the parameters that must be resolved before it is
    invoked. Build instances with ``Dependent.parse``.

    Attributes:
        call: the object to inject into; any callable. Coroutine callables
            are awaited directly, others are wrapped with ``run_sync``.
        params: named parameters, one ``ModelField`` per parameter of
            ``call``; their solved values are passed to ``call`` as keyword
            arguments.
        parameterless: anonymous parameters; they are checked and solved,
            but their results are not passed to ``call``.
    """

    call: _DependentCallable[R]
    params: tuple[ModelField, ...] = field(default_factory=tuple)
    parameterless: tuple[Param, ...] = field(default_factory=tuple)

    def __repr__(self) -> str:
        # Plain functions and classes are shown by name; anything else
        # (bound methods, callable instances, partials) by its repr.
        if inspect.isfunction(self.call) or inspect.isclass(self.call):
            call_str = self.call.__name__
        else:
            call_str = repr(self.call)
        return (
            f"Dependent(call={call_str}"
            + (f", parameterless={self.parameterless}" if self.parameterless else "")
            + ")"
        )

    async def __call__(self, **kwargs: Any) -> R:
        """Run pre-checks, solve parameters, then invoke ``call``.

        Args:
            **kwargs: injection context forwarded to every parameter's
                ``_check`` and ``_solve``.

        Returns:
            Whatever ``call`` returns.

        Raises:
            BaseExceptionGroup: the group of ``SkippedException`` raised while
                checking, solving or calling; it is logged at trace level
                first. Other exceptions are not handled here.
        """
        exception: Optional[BaseExceptionGroup[SkippedException]] = None

        def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
            nonlocal exception
            # Keep the group so it can be re-raised once ``catch`` has exited.
            exception = exc_group
            excs = list(flatten_exception_group(exc_group))
            logger.trace(f"{self} skipped due to {excs}")

        with catch({SkippedException: _handle_skipped}):
            # do pre-check
            await self.check(**kwargs)

            # solve param values
            values = await self.solve(**kwargs)

            # call function
            if is_coroutine_callable(self.call):
                return await cast(Callable[..., Awaitable[R]], self.call)(**values)
            else:
                return await run_sync(cast(Callable[..., R], self.call))(**values)

        # Only reachable when ``_handle_skipped`` swallowed a skip: re-raise the
        # whole captured group (not a single member) to the caller.
        raise exception

    @staticmethod
    def parse_params(
        call: _DependentCallable[R], allow_types: tuple[type[Param], ...]
    ) -> tuple[ModelField, ...]:
        """Build one ``ModelField`` per parameter of ``call``'s signature.

        A parameter whose default is already a ``Param`` uses it directly.
        Otherwise each type in ``allow_types`` is asked, in order, via
        ``_check_param``, and the first non-``None`` result wins.

        Raises:
            ValueError: if no allowed type accepts a parameter.
        """
        fields: list[ModelField] = []
        params = get_typed_signature(call).parameters.values()

        for param in params:
            if isinstance(param.default, Param):
                field_info = param.default
            else:
                for allow_type in allow_types:
                    if field_info := allow_type._check_param(param, allow_types):
                        break
                else:
                    raise ValueError(
                        f"Unknown parameter {param.name} "
                        f"for function {call} with type {param.annotation}"
                    )

            # Unannotated parameters are typed as ``Any``.
            annotation: Any = Any
            if param.annotation is not param.empty:
                annotation = param.annotation

            fields.append(
                ModelField.construct(
                    name=param.name, annotation=annotation, field_info=field_info
                )
            )

        return tuple(fields)

    @staticmethod
    def parse_parameterless(
        parameterless: tuple[Any, ...], allow_types: tuple[type[Param], ...]
    ) -> tuple[Param, ...]:
        """Convert anonymous values into ``Param`` objects.

        Each type in ``allow_types`` is asked, in order, via
        ``_check_parameterless``; the first non-``None`` result wins.

        Raises:
            ValueError: if no allowed type accepts a value.
        """
        parameterless_params: list[Param] = []
        for value in parameterless:
            for allow_type in allow_types:
                if param := allow_type._check_parameterless(value, allow_types):
                    break
            else:
                raise ValueError(f"Unknown parameterless {value}")
            parameterless_params.append(param)
        return tuple(parameterless_params)

    @classmethod
    def parse(
        cls,
        *,
        call: _DependentCallable[R],
        parameterless: Optional[Iterable[Any]] = None,
        allow_types: Iterable[type[Param]],
    ) -> "Dependent[R]":
        """Create a ``Dependent`` for ``call``.

        Args:
            call: the callable to wrap.
            parameterless: optional anonymous values to resolve as well.
            allow_types: ``Param`` subclasses tried, in order, for each
                parameter and parameterless value.

        Raises:
            ValueError: if a parameter or parameterless value is not
                accepted by any of ``allow_types``.
        """
        # Materialize once: the iterable is consumed by both parse steps.
        allowed = tuple(allow_types)
        params = cls.parse_params(call, allowed)
        parameterless_params = (
            ()
            if parameterless is None
            else cls.parse_parameterless(tuple(parameterless), allowed)
        )

        return cls(call, params, parameterless_params)

    async def check(self, **params: Any) -> None:
        """Run every ``_check`` hook.

        All parameterless checks run concurrently in one task group. Once
        that group finishes, all named-parameter checks run concurrently in
        a second group.
        """
        async with anyio.create_task_group() as tg:
            for param in self.parameterless:
                tg.start_soon(partial(param._check, **params))

        async with anyio.create_task_group() as tg:
            for param in self.params:
                tg.start_soon(partial(cast(Param, param.field_info)._check, **params))

    async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
        """Solve a single named parameter.

        The value from ``_solve`` falls back to the field default when it is
        ``PydanticUndefined``. It is always passed through
        ``check_field_type``, which presumably raises on a type mismatch
        (confirm in ``.utils``). The checked value is returned only when the
        ``Param`` has ``validate`` set; otherwise the raw value is returned.
        """
        param = cast(Param, field.field_info)
        value = await param._solve(**params)
        if value is PydanticUndefined:
            value = field.get_default()
        v = check_field_type(field, value)
        return v if param.validate else value

    async def solve(self, **params: Any) -> dict[str, Any]:
        """Resolve all parameters and return the kwargs for ``call``.

        Parameterless values are solved sequentially, in order, and their
        results are discarded. Named parameters are then solved concurrently.

        Returns:
            A mapping from parameter name to solved value.
        """
        # solve parameterless
        for param in self.parameterless:
            await param._solve(**params)

        # solve param values
        result: dict[str, Any] = {}

        # Named ``_solve_one`` so it does not shadow the ``_solve_field`` method.
        async def _solve_one(field: ModelField, params: dict[str, Any]) -> None:
            value = await self._solve_field(field, params)
            result[field.name] = value

        async with anyio.create_task_group() as tg:
            for field in self.params:
                tg.start_soon(_solve_one, field, params)

        return result
2022-01-21 21:04:17 +08:00
# Per-name autodoc overrides, presumably read by nonebot's documentation
# tooling. NOTE(review): ``CustomConfig`` is not defined in the visible code of
# this module — possibly a stale entry; confirm before removing.
__autodoc__ = dict(CustomConfig=False)