mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-12 22:45:47 +08:00
230 lines
7.3 KiB
Python
230 lines
7.3 KiB
Python
"""本模块模块实现了依赖注入的定义与处理。
|
|
|
|
FrontMatter:
|
|
mdx:
|
|
format: md
|
|
sidebar_position: 0
|
|
description: nonebot.dependencies 模块
|
|
"""
|
|
|
|
import abc
|
|
from collections.abc import Awaitable, Iterable
|
|
from dataclasses import dataclass, field
|
|
from functools import partial
|
|
import inspect
|
|
from typing import Any, Callable, Generic, Optional, TypeVar, cast
|
|
|
|
import anyio
|
|
from exceptiongroup import BaseExceptionGroup, catch
|
|
|
|
from nonebot.compat import FieldInfo, ModelField, PydanticUndefined
|
|
from nonebot.exception import SkippedException
|
|
from nonebot.log import logger
|
|
from nonebot.typing import _DependentCallable
|
|
from nonebot.utils import (
|
|
flatten_exception_group,
|
|
is_coroutine_callable,
|
|
run_coro_with_shield,
|
|
run_sync,
|
|
)
|
|
|
|
from .utils import check_field_type, get_typed_signature
|
|
|
|
R = TypeVar("R")
|
|
T = TypeVar("T", bound="Dependent")
|
|
|
|
|
|
class Param(abc.ABC, FieldInfo):
|
|
"""依赖注入的基本单元 —— 参数。
|
|
|
|
继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。
|
|
"""
|
|
|
|
def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
|
|
super().__init__(*args, **kwargs)
|
|
self.validate = validate
|
|
|
|
@classmethod
|
|
def _check_param(
|
|
cls, param: inspect.Parameter, allow_types: tuple[type["Param"], ...]
|
|
) -> Optional["Param"]:
|
|
return
|
|
|
|
@classmethod
|
|
def _check_parameterless(
|
|
cls, value: Any, allow_types: tuple[type["Param"], ...]
|
|
) -> Optional["Param"]:
|
|
return
|
|
|
|
@abc.abstractmethod
|
|
async def _solve(self, **kwargs: Any) -> Any:
|
|
raise NotImplementedError
|
|
|
|
async def _check(self, **kwargs: Any) -> None:
|
|
return
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class Dependent(Generic[R]):
|
|
"""依赖注入容器
|
|
|
|
参数:
|
|
call: 依赖注入的可调用对象,可以是任何 Callable 对象
|
|
pre_checkers: 依赖注入解析前的参数检查
|
|
params: 具名参数列表
|
|
parameterless: 匿名参数列表
|
|
allow_types: 允许的参数类型
|
|
"""
|
|
|
|
call: _DependentCallable[R]
|
|
params: tuple[ModelField, ...] = field(default_factory=tuple)
|
|
parameterless: tuple[Param, ...] = field(default_factory=tuple)
|
|
|
|
def __repr__(self) -> str:
|
|
if inspect.isfunction(self.call) or inspect.isclass(self.call):
|
|
call_str = self.call.__name__
|
|
else:
|
|
call_str = repr(self.call)
|
|
return (
|
|
f"Dependent(call={call_str}"
|
|
+ (f", parameterless={self.parameterless}" if self.parameterless else "")
|
|
+ ")"
|
|
)
|
|
|
|
async def __call__(self, **kwargs: Any) -> R:
|
|
exception: Optional[BaseExceptionGroup[SkippedException]] = None
|
|
|
|
def _handle_skipped(exc_group: BaseExceptionGroup[SkippedException]):
|
|
nonlocal exception
|
|
exception = exc_group
|
|
# raise one of the exceptions instead
|
|
excs = list(flatten_exception_group(exc_group))
|
|
logger.trace(f"{self} skipped due to {excs}")
|
|
|
|
with catch({SkippedException: _handle_skipped}):
|
|
# do pre-check
|
|
await self.check(**kwargs)
|
|
|
|
# solve param values
|
|
values = await self.solve(**kwargs)
|
|
|
|
# call function
|
|
if is_coroutine_callable(self.call):
|
|
return await cast(Callable[..., Awaitable[R]], self.call)(**values)
|
|
else:
|
|
return await run_sync(cast(Callable[..., R], self.call))(**values)
|
|
|
|
raise exception
|
|
|
|
@staticmethod
|
|
def parse_params(
|
|
call: _DependentCallable[R], allow_types: tuple[type[Param], ...]
|
|
) -> tuple[ModelField, ...]:
|
|
fields: list[ModelField] = []
|
|
params = get_typed_signature(call).parameters.values()
|
|
|
|
for param in params:
|
|
if isinstance(param.default, Param):
|
|
field_info = param.default
|
|
else:
|
|
for allow_type in allow_types:
|
|
if field_info := allow_type._check_param(param, allow_types):
|
|
break
|
|
else:
|
|
raise ValueError(
|
|
f"Unknown parameter {param.name} "
|
|
f"for function {call} with type {param.annotation}"
|
|
)
|
|
|
|
annotation: Any = Any
|
|
if param.annotation is not param.empty:
|
|
annotation = param.annotation
|
|
|
|
fields.append(
|
|
ModelField.construct(
|
|
name=param.name, annotation=annotation, field_info=field_info
|
|
)
|
|
)
|
|
|
|
return tuple(fields)
|
|
|
|
@staticmethod
|
|
def parse_parameterless(
|
|
parameterless: tuple[Any, ...], allow_types: tuple[type[Param], ...]
|
|
) -> tuple[Param, ...]:
|
|
parameterless_params: list[Param] = []
|
|
for value in parameterless:
|
|
for allow_type in allow_types:
|
|
if param := allow_type._check_parameterless(value, allow_types):
|
|
break
|
|
else:
|
|
raise ValueError(f"Unknown parameterless {value}")
|
|
parameterless_params.append(param)
|
|
return tuple(parameterless_params)
|
|
|
|
@classmethod
|
|
def parse(
|
|
cls,
|
|
*,
|
|
call: _DependentCallable[R],
|
|
parameterless: Optional[Iterable[Any]] = None,
|
|
allow_types: Iterable[type[Param]],
|
|
) -> "Dependent[R]":
|
|
allow_types = tuple(allow_types)
|
|
|
|
params = cls.parse_params(call, allow_types)
|
|
parameterless_params = (
|
|
()
|
|
if parameterless is None
|
|
else cls.parse_parameterless(tuple(parameterless), allow_types)
|
|
)
|
|
|
|
return cls(call, params, parameterless_params)
|
|
|
|
async def check(self, **params: Any) -> None:
|
|
if self.parameterless:
|
|
async with anyio.create_task_group() as tg:
|
|
for param in self.parameterless:
|
|
tg.start_soon(partial(param._check, **params))
|
|
|
|
if self.params:
|
|
async with anyio.create_task_group() as tg:
|
|
for param in self.params:
|
|
tg.start_soon(
|
|
partial(cast(Param, param.field_info)._check, **params)
|
|
)
|
|
|
|
async def _solve_field(self, field: ModelField, params: dict[str, Any]) -> Any:
|
|
param = cast(Param, field.field_info)
|
|
value = await param._solve(**params)
|
|
if value is PydanticUndefined:
|
|
value = field.get_default()
|
|
v = check_field_type(field, value)
|
|
return v if param.validate else value
|
|
|
|
async def solve(self, **params: Any) -> dict[str, Any]:
|
|
# solve parameterless
|
|
for param in self.parameterless:
|
|
await param._solve(**params)
|
|
|
|
# solve param values
|
|
result: dict[str, Any] = {}
|
|
if not self.params:
|
|
return result
|
|
|
|
async def _solve_field(field: ModelField, params: dict[str, Any]) -> None:
|
|
value = await self._solve_field(field, params)
|
|
result[field.name] = value
|
|
|
|
async with anyio.create_task_group() as tg:
|
|
for field in self.params:
|
|
# shield the task to prevent cancellation
|
|
# when one of the tasks raises an exception
|
|
# this will improve the dependency cache reusability
|
|
tg.start_soon(run_coro_with_shield, _solve_field(field, params))
|
|
|
|
return result
|
|
|
|
|
|
__autodoc__ = {"CustomConfig": False}
|