2021-11-16 18:30:16 +08:00
|
|
|
"""
依赖注入处理模块
================

该模块实现了依赖注入的定义与处理。
"""
|
|
|
|
|
2021-12-12 18:19:08 +08:00
|
|
|
import abc
|
2021-11-12 20:55:59 +08:00
|
|
|
import inspect
|
2021-12-12 18:19:08 +08:00
|
|
|
from typing import Any, Dict, List, Type, Generic, TypeVar, Callable, Optional
|
2021-11-12 20:55:59 +08:00
|
|
|
|
2021-11-15 21:44:24 +08:00
|
|
|
from pydantic import BaseConfig
|
|
|
|
from pydantic.schema import get_annotation_from_field_info
|
2021-12-12 18:19:08 +08:00
|
|
|
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
2021-11-15 21:44:24 +08:00
|
|
|
|
2021-11-14 18:51:23 +08:00
|
|
|
from nonebot.log import logger
|
2021-11-15 21:44:24 +08:00
|
|
|
from .utils import get_typed_signature
|
2021-11-19 18:18:53 +08:00
|
|
|
from nonebot.exception import SkippedException
|
2021-12-12 18:19:08 +08:00
|
|
|
from nonebot.utils import run_sync, is_coroutine_callable
|
2021-11-21 15:46:48 +08:00
|
|
|
|
2021-12-12 18:19:08 +08:00
|
|
|
# Return type of the wrapped callable, so that ``Dependent[R]()`` yields ``R``.
R = TypeVar("R")
# Self-type for ``Dependent.parse`` (``cls: Type[T]`` -> ``T``), so calling it
# on a subclass is typed as returning that subclass.
T = TypeVar("T", bound="Dependent")
|
2021-11-12 20:55:59 +08:00
|
|
|
|
|
|
|
|
2021-12-12 18:19:08 +08:00
|
|
|
class Param(abc.ABC, FieldInfo):
    """Abstract base for an injectable parameter.

    A ``Param`` is a pydantic ``FieldInfo`` that can also produce its own
    value through :meth:`_solve`. ``Dependent`` asks every allowed ``Param``
    subclass, in order, whether it wants a given signature parameter
    (:meth:`_check_param`) or parameterless value
    (:meth:`_check_parameterless`). The base hooks decline everything by
    returning ``None``.
    """

    @classmethod
    def _check_param(
        cls,
        dependent: "Dependent",
        name: str,
        param: inspect.Parameter,
    ) -> Optional["Param"]:
        """Offer to handle one parameter of ``dependent``'s callable.

        :param dependent: the ``Dependent`` being parsed.
        :param name: the parameter's name in the signature.
        :param param: the ``inspect.Parameter`` being examined.
        :return: a ``Param`` to use for this parameter, or ``None`` to let
            the next allowed type try. The base implementation always
            declines.
        """
        return None

    @classmethod
    def _check_parameterless(
        cls,
        dependent: "Dependent",
        value: Any,
    ) -> Optional["Param"]:
        """Offer to handle one parameterless value of ``dependent``.

        :param dependent: the ``Dependent`` being parsed.
        :param value: the raw parameterless value supplied by the caller.
        :return: a ``Param`` wrapping ``value``, or ``None`` to let the next
            allowed type try. The base implementation always declines.
        """
        return None

    @abc.abstractmethod
    async def _solve(self, **kwargs: Any) -> Any:
        """Resolve this parameter's value.

        :param kwargs: the keyword context passed to ``Dependent.solve``.
        :return: the resolved value. For a signature parameter, returning
            pydantic's ``Undefined`` makes ``Dependent.solve`` fall back to
            the field default. For a parameterless entry, the result is
            discarded.
        """
        raise NotImplementedError
|
2021-11-12 20:55:59 +08:00
|
|
|
|
2021-11-22 11:38:42 +08:00
|
|
|
|
2021-12-12 18:19:08 +08:00
|
|
|
class CustomConfig(BaseConfig):
    """pydantic config for the ``ModelField`` objects built by ``Dependent.parse``.

    Enabling pydantic's ``arbitrary_types_allowed`` option lets dependency
    parameters be annotated with arbitrary user classes.
    """

    arbitrary_types_allowed = True
|
2021-11-12 20:55:59 +08:00
|
|
|
|
2021-11-15 21:44:24 +08:00
|
|
|
|
2021-12-12 18:19:08 +08:00
|
|
|
class Dependent(Generic[R]):
    """A callable together with the dependencies injected into it.

    Each parameter of ``call`` is described by a pydantic ``ModelField``
    whose ``field_info`` is a :class:`Param`. The ``Param`` produces the
    value, and the field checks it against the parameter's annotation.
    ``parameterless`` params are solved before the call only for their side
    effects; their results are not passed to ``call``.

    :param call: function or coroutine function to inject into.
    :param params: pre-built fields for ``call``'s parameters.
    :param parameterless: params solved before ``call`` whose results are
        discarded.
    :param allow_types: ``Param`` subclasses consulted, in order, when
        parsing parameters and parameterless values.
    """

    def __init__(
        self,
        *,
        call: Callable[..., Any],
        params: Optional[List[ModelField]] = None,
        parameterless: Optional[List[Param]] = None,
        allow_types: Optional[List[Type[Param]]] = None,
    ) -> None:
        self.call = call
        self.params = params or []
        self.parameterless = parameterless or []
        self.allow_types = allow_types or []

    async def __call__(self, **kwargs: Any) -> R:
        """Solve all dependencies from ``kwargs`` and invoke ``call``.

        A non-coroutine ``call`` is wrapped with ``run_sync`` and awaited.

        :raises SkippedException: if a resolved value fails validation
            (raised by :meth:`solve`).
        """
        values = await self.solve(**kwargs)

        if is_coroutine_callable(self.call):
            return await self.call(**values)
        else:
            return await run_sync(self.call)(**values)

    def parse_param(self, name: str, param: inspect.Parameter) -> Param:
        """Return the ``Param`` for signature parameter ``name``.

        The first type in ``allow_types`` whose ``_check_param`` returns a
        ``Param`` wins.

        :raises ValueError: if no allowed type claims the parameter.
        """
        for allow_type in self.allow_types:
            field_info = allow_type._check_param(self, name, param)
            # The hook signals "not mine" with ``None``. Compare by identity
            # so the result's truthiness never matters.
            if field_info is not None:
                return field_info
        raise ValueError(
            f"Unknown parameter {name} for function {self.call} with type {param.annotation}"
        )

    def parse_parameterless(self, value: Any) -> Param:
        """Return the ``Param`` wrapping a parameterless ``value``.

        The first type in ``allow_types`` whose ``_check_parameterless``
        returns a ``Param`` wins.

        :raises ValueError: if no allowed type claims the value.
        """
        for allow_type in self.allow_types:
            field_info = allow_type._check_parameterless(self, value)
            if field_info is not None:
                return field_info
        raise ValueError(
            f"Unknown parameterless {value} for function {self.call} with type {type(value)}"
        )

    def prepend_parameterless(self, value: Any) -> None:
        """Parse ``value`` and run it before existing parameterless params."""
        self.parameterless.insert(0, self.parse_parameterless(value))

    def append_parameterless(self, value: Any) -> None:
        """Parse ``value`` and run it after existing parameterless params."""
        self.parameterless.append(self.parse_parameterless(value))

    @classmethod
    def parse(
        cls: Type[T],
        *,
        call: Callable[..., Any],
        parameterless: Optional[List[Any]] = None,
        allow_types: Optional[List[Type[Param]]] = None,
    ) -> T:
        """Build a dependent by inspecting ``call``'s signature.

        If a parameter's default is already a ``Param``, that ``Param`` is
        used as is. Any other parameter goes through :meth:`parse_param`.
        Either way, the ``Param``'s own ``default`` becomes the field
        default, and the field is required when that default is pydantic's
        ``Required``.

        :param call: the callable to wrap.
        :param parameterless: raw values, each converted with
            :meth:`parse_parameterless`.
        :param allow_types: ``Param`` subclasses allowed for this dependent.
        :raises ValueError: if a parameter or parameterless value is not
            claimed by any allowed type.
        """
        signature = get_typed_signature(call)
        params = signature.parameters
        dependent = cls(
            call=call,
            allow_types=allow_types,
        )

        parameterless_params = [
            dependent.parse_parameterless(param) for param in (parameterless or [])
        ]
        dependent.parameterless.extend(parameterless_params)

        for param_name, param in params.items():
            default_value = Required
            # ``empty`` and ``Required`` are sentinels. Compare them by
            # identity, because ``==``/``!=`` would call the user-supplied
            # default's or annotation's ``__eq__``. That method may be
            # element-wise (making ``if`` raise) or may raise by itself.
            if param.default is not param.empty:
                default_value = param.default

            if isinstance(default_value, Param):
                field_info = default_value
            else:
                field_info = dependent.parse_param(param_name, param)
            # In both branches, the Param's own default is authoritative.
            default_value = field_info.default

            annotation: Any = Any
            required = default_value is Required
            if param.annotation is not param.empty:
                annotation = param.annotation
            annotation = get_annotation_from_field_info(
                annotation, field_info, param_name
            )
            dependent.params.append(
                ModelField(
                    name=param_name,
                    type_=annotation,
                    class_validators=None,
                    model_config=CustomConfig,
                    default=None if required else default_value,
                    required=required,
                    field_info=field_info,
                )
            )

        return dependent

    async def solve(
        self,
        **params: Any,
    ) -> Dict[str, Any]:
        """Resolve the keyword arguments for ``call``.

        Parameterless params are solved first, and their results are
        discarded. Next, each field's ``Param`` is solved. A result of
        ``Undefined`` falls back to the field default, and the value is then
        validated against the field.

        :param params: context forwarded to every ``Param._solve``.
        :return: mapping of parameter name to resolved value.
        :raises SkippedException: if a value fails field validation.
        """
        values: Dict[str, Any] = {}

        for param in self.parameterless:
            await param._solve(**params)

        for field in self.params:
            field_info = field.field_info
            assert isinstance(field_info, Param), "Params must be subclasses of Param"
            value = await field_info._solve(**params)
            # ``Undefined`` is a singleton sentinel. An identity check avoids
            # calling the resolved value's (possibly overloaded) ``__eq__``.
            if value is Undefined:
                value = field.get_default()
            # Validation only gates on errors. The first element of the
            # ``validate`` result is ignored, and the original ``value`` is
            # what gets passed to ``call``.
            _, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
            if errs_:
                logger.debug(
                    f"{field_info} "
                    f"type {type(value)} not match depends {self.call} "
                    f"annotation {field._type_display()}, ignored"
                )
                raise SkippedException(field, value)
            values[field.name] = value

        return values
|