nonebot2/nonebot/dependencies/__init__.py

219 lines
8.1 KiB
Python
Raw Normal View History

"""
依赖注入处理模块
===============
该模块实现了依赖注入的定义与处理
"""
2021-11-12 20:55:59 +08:00
import inspect
2021-11-13 19:38:01 +08:00
from itertools import chain
2021-11-15 21:44:24 +08:00
from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
2021-11-12 20:55:59 +08:00
2021-11-15 21:44:24 +08:00
from pydantic import BaseConfig
from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info
2021-11-14 18:51:23 +08:00
from nonebot.log import logger
from .models import Param as Param
2021-11-15 21:44:24 +08:00
from .utils import get_typed_signature
from .models import Dependent as Dependent
2021-11-19 18:18:53 +08:00
from nonebot.exception import SkippedException
from .models import DependsWrapper as DependsWrapper
from nonebot.typing import T_Handler, T_DependencyCache
from nonebot.utils import (run_sync, is_gen_callable, run_sync_ctx_manager,
is_async_gen_callable, is_coroutine_callable)
2021-11-12 20:55:59 +08:00
class CustomConfig(BaseConfig):
    """pydantic model config used for the ``ModelField`` objects that
    ``get_dependent`` builds for plain (non-``Depends``) parameters."""
    # Per pydantic's Config docs, this lets fields be annotated with arbitrary
    # classes that pydantic has no built-in validator for (checked via
    # isinstance), which handler parameter annotations commonly are.
    arbitrary_types_allowed = True
def get_param_sub_dependent(
        *,
        param: inspect.Parameter,
        allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
    """Build the sub-dependent for a parameter whose default is ``Depends(...)``.

    :param param: the function parameter; its default must be a
        ``DependsWrapper``.
    :param allow_types: ``Param`` subclasses allowed for the sub-dependent's
        own plain parameters.
    :returns: the ``Dependent`` that fills ``param`` by name.
    """
    wrapper: DependsWrapper = param.default
    # ``Depends()`` with no explicit (truthy) dependency falls back to the
    # parameter's annotation as the dependency callable.
    target = wrapper.dependency or param.annotation
    return get_sub_dependant(
        depends=wrapper,
        dependency=target,
        name=param.name,
        allow_types=allow_types,
    )
2021-11-12 20:55:59 +08:00
2021-11-15 21:44:24 +08:00
def get_parameterless_sub_dependant(
        *,
        depends: DependsWrapper,
        allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
    """Build an unnamed sub-dependent from a standalone ``Depends(...)``.

    Because there is no parameter to take an annotation from, the wrapper
    must carry an explicit callable dependency.

    :param depends: the ``DependsWrapper`` to resolve.
    :param allow_types: ``Param`` subclasses allowed for the dependency's
        own plain parameters.
    :returns: a ``Dependent`` with no name (its result is not bound to an
        argument).
    """
    target = depends.dependency
    assert callable(
        target), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends,
                             dependency=target,
                             allow_types=allow_types)
2021-11-12 20:55:59 +08:00
def get_sub_dependant(
        *,
        depends: DependsWrapper,
        dependency: T_Handler,
        name: Optional[str] = None,
        allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
    """Create the ``Dependent`` for ``dependency``, honouring the wrapper's
    ``use_cache`` setting.

    :param depends: wrapper supplying the ``use_cache`` flag.
    :param dependency: the callable to analyse.
    :param name: argument name the result is bound to, or ``None``.
    :param allow_types: ``Param`` subclasses allowed for plain parameters.
    :returns: the analysed ``Dependent``.
    """
    return get_dependent(
        func=dependency,
        name=name,
        use_cache=depends.use_cache,
        allow_types=allow_types,
    )
def get_dependent(*,
                  func: T_Handler,
                  name: Optional[str] = None,
                  use_cache: bool = True,
                  allow_types: Optional[List[Type[Param]]] = None) -> Dependent:
    """Analyse ``func``'s signature and build its :class:`Dependent`.

    Parameters whose default is a ``DependsWrapper`` (``Depends(...)``) become
    sub-dependents. Every other parameter must be claimed by one of the
    dependent's ``allow_types`` and becomes a required pydantic
    ``ModelField`` whose ``field_info`` is that ``Param`` instance.

    :param func: the callable to inspect.
    :param name: argument name this dependent's result fills in its parent,
        if any.
    :param use_cache: whether the solved value may be reused from the
        dependency cache.
    :param allow_types: ``Param`` subclasses allowed for plain parameters.
    :returns: the populated ``Dependent``.
    :raises ValueError: if a plain parameter is not accepted by any allowed
        ``Param`` type.
    """
    signature = get_typed_signature(func)
    dependent = Dependent(func=func,
                          name=name,
                          allow_types=allow_types,
                          use_cache=use_cache)

    for param_name, param in signature.parameters.items():
        if isinstance(param.default, DependsWrapper):
            dependent.dependencies.append(
                get_param_sub_dependent(param=param, allow_types=allow_types))
            continue

        # The first allowed Param type whose _check accepts the parameter wins.
        for allow_type in dependent.allow_types:
            if allow_type._check(param_name, param):
                field_info = allow_type(param.default)
                break
        else:
            raise ValueError(
                f"Unknown parameter {param_name} for function {func} "
                f"with type {param.annotation}")

        # ``inspect.Parameter.empty`` is a sentinel class: compare by identity
        # so annotation objects with a custom ``__eq__``/``__ne__`` cannot
        # confuse the check.
        annotation: Any = Any
        if param.annotation is not param.empty:
            annotation = param.annotation
        annotation = get_annotation_from_field_info(annotation, field_info,
                                                    param_name)
        dependent.params.append(
            ModelField(name=param_name,
                       type_=annotation,
                       class_validators=None,
                       model_config=CustomConfig,
                       default=Required,
                       required=True,
                       field_info=field_info))

    return dependent
2021-11-13 19:38:01 +08:00
async def solve_dependencies(
        *,
        _dependent: Dependent,
        _stack: Optional[AsyncExitStack] = None,
        _sub_dependents: Optional[List[Dependent]] = None,
        _dependency_overrides_provider: Optional[Any] = None,
        _dependency_cache: Optional[T_DependencyCache] = None,
        **params: Any) -> Tuple[Dict[str, Any], T_DependencyCache]:
    """Resolve all sub-dependencies and parameters of ``_dependent``.

    Sub-dependents (``_sub_dependents`` first, then
    ``_dependent.dependencies``) are solved recursively and then called. A
    result is bound to the sub-dependent's ``name`` when it has one and is
    recorded in the dependency cache. Plain parameters are then obtained from
    their ``Param`` field info via ``**params`` and validated against the
    parameter annotation.

    :param _dependent: the dependent whose inputs should be solved.
    :param _stack: exit stack that generator (sync or async) dependencies are
        entered into. It is required whenever such a dependency is called,
        including nested ones.
    :param _sub_dependents: extra sub-dependents solved before
        ``_dependent.dependencies``. They are not propagated to nested calls.
    :param _dependency_overrides_provider: optional object whose
        ``dependency_overrides`` mapping substitutes dependency callables.
    :param _dependency_cache: cache shared across dependency calls; a new
        dict is created only when ``None``.
    :param params: keyword values forwarded to each ``Param._solve``.
    :returns: ``(values, dependency_cache)``, where ``values`` maps argument
        names to solved values.
    :raises SkippedException: if a parameter value fails validation against
        its annotation.
    """
    values: Dict[str, Any] = {}
    # ``is None`` (not ``or``) so that an explicitly passed, still-empty cache
    # dict is shared and filled in place rather than silently replaced.
    dependency_cache = {} if _dependency_cache is None else _dependency_cache

    # solve sub dependencies
    sub_dependent: Dependent
    for sub_dependent in chain(_sub_dependents or tuple(),
                               _dependent.dependencies):
        sub_dependent.func = cast(Callable[..., Any], sub_dependent.func)
        sub_dependent.cache_key = cast(Callable[..., Any],
                                       sub_dependent.cache_key)
        func = sub_dependent.func

        # dependency overrides
        use_sub_dependant = sub_dependent
        if (_dependency_overrides_provider and hasattr(
                _dependency_overrides_provider, "dependency_overrides")):
            original_call = sub_dependent.func
            func = getattr(_dependency_overrides_provider,
                           "dependency_overrides",
                           {}).get(original_call, original_call)
            use_sub_dependant = get_dependent(
                func=func,
                name=sub_dependent.name,
                allow_types=sub_dependent.allow_types,
            )

        # Solve the sub dependency's own inputs with the current cache. The
        # exit stack must be forwarded so nested generator dependencies can
        # be entered, and the cache must use its real keyword
        # (``_dependency_cache``); otherwise it would end up in ``**params``.
        sub_values, sub_dependency_cache = await solve_dependencies(
            _dependent=use_sub_dependant,
            _stack=_stack,
            _dependency_overrides_provider=_dependency_overrides_provider,
            _dependency_cache=dependency_cache,
            **params)

        # Normally the same dict object; merging is harmless either way.
        dependency_cache.update(sub_dependency_cache)

        # run dependency function (or reuse a cached result)
        if sub_dependent.use_cache and sub_dependent.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependent.cache_key]
        elif is_gen_callable(func) or is_async_gen_callable(func):
            assert isinstance(
                _stack, AsyncExitStack
            ), "Generator dependency should be called in context"
            if is_gen_callable(func):
                cm = run_sync_ctx_manager(contextmanager(func)(**sub_values))
            else:
                cm = asynccontextmanager(func)(**sub_values)
            solved = await _stack.enter_async_context(cm)
        elif is_coroutine_callable(func):
            solved = await func(**sub_values)
        else:
            solved = await run_sync(func)(**sub_values)

        # parameter dependency
        if sub_dependent.name is not None:
            values[sub_dependent.name] = solved
        # save current dependency to cache
        if sub_dependent.cache_key not in dependency_cache:
            dependency_cache[sub_dependent.cache_key] = solved

    # usual dependency
    for field in _dependent.params:
        field_info = field.field_info
        assert isinstance(field_info,
                          Param), "Params must be subclasses of Param"
        value = field_info._solve(**params)
        # Only the validation errors are used; the original (unconverted)
        # value is what gets passed to the handler.
        _, errs_ = field.validate(value,
                                  values,
                                  loc=(str(field_info), field.alias))
        if errs_:
            logger.debug(
                f"{field_info} "
                f"type {type(value)} not match depends {_dependent.func} "
                f"annotation {field._type_display()}, ignored")
            raise SkippedException
        else:
            values[field.name] = value

    return values, dependency_cache
2021-11-13 19:38:01 +08:00
def Depends(dependency: Optional[T_Handler] = None,
            *,
            use_cache: bool = True) -> Any:
    """
    :Description:

      Dependency-injection marker for handler parameters. Use it as a
      parameter's default value so that the parameter is filled with the
      result of a dependency.

    :Parameters:

      * ``dependency: Optional[Callable[..., Any]] = None``: the dependency callable; when omitted, the parameter's type annotation is used
      * ``use_cache: bool = True``: whether a cached result may be reused; defaults to ``True``
    """
    marker = DependsWrapper(dependency=dependency, use_cache=use_cache)
    return marker