mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-27 18:45:05 +08:00
✨ Feature: 支持子依赖定义 Pydantic 类型校验 (#2310)
This commit is contained in:
parent
79f833b946
commit
f59271bd47
@ -45,6 +45,10 @@ class Param(abc.ABC, FieldInfo):
|
|||||||
继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。
|
继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.validate = validate
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
|
||||||
@ -206,10 +210,12 @@ class Dependent(Generic[R]):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
|
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
|
||||||
value = await cast(Param, field.field_info)._solve(**params)
|
param = cast(Param, field.field_info)
|
||||||
|
value = await param._solve(**params)
|
||||||
if value is Undefined:
|
if value is Undefined:
|
||||||
value = field.get_default()
|
value = field.get_default()
|
||||||
return check_field_type(field, value)
|
v = check_field_type(field, value)
|
||||||
|
return v if param.validate else value
|
||||||
|
|
||||||
async def solve(self, **params: Any) -> Dict[str, Any]:
|
async def solve(self, **params: Any) -> Dict[str, Any]:
|
||||||
# solve parameterless
|
# solve parameterless
|
||||||
|
@ -5,7 +5,7 @@ FrontMatter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, TypeVar, Callable, ForwardRef
|
from typing import Any, Dict, Callable, ForwardRef
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic.fields import ModelField
|
from pydantic.fields import ModelField
|
||||||
@ -13,8 +13,6 @@ from pydantic.typing import evaluate_forwardref
|
|||||||
|
|
||||||
from nonebot.exception import TypeMisMatch
|
from nonebot.exception import TypeMisMatch
|
||||||
|
|
||||||
V = TypeVar("V")
|
|
||||||
|
|
||||||
|
|
||||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
"""获取可调用对象签名"""
|
"""获取可调用对象签名"""
|
||||||
@ -49,10 +47,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
|
|||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
def check_field_type(field: ModelField, value: V) -> V:
|
def check_field_type(field: ModelField, value: Any) -> Any:
|
||||||
"""检查字段类型是否匹配"""
|
"""检查字段类型是否匹配"""
|
||||||
|
|
||||||
_, errs_ = field.validate(value, {}, loc=())
|
v, errs_ = field.validate(value, {}, loc=())
|
||||||
if errs_:
|
if errs_:
|
||||||
raise TypeMisMatch(field, value)
|
raise TypeMisMatch(field, value)
|
||||||
return value
|
return v
|
||||||
|
@ -1,11 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Self, Annotated, override
|
||||||
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
|
||||||
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Type,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
Literal,
|
||||||
|
Callable,
|
||||||
|
Optional,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from pydantic.typing import get_args, get_origin
|
from pydantic.typing import get_args, get_origin
|
||||||
from pydantic.fields import Required, Undefined, ModelField
|
from pydantic.fields import Required, FieldInfo, Undefined, ModelField
|
||||||
|
|
||||||
from nonebot.dependencies.utils import check_field_type
|
from nonebot.dependencies.utils import check_field_type
|
||||||
from nonebot.dependencies import Param, Dependent, CustomConfig
|
from nonebot.dependencies import Param, Dependent, CustomConfig
|
||||||
@ -24,6 +34,23 @@ if TYPE_CHECKING:
|
|||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
from nonebot.adapters import Bot, Event
|
from nonebot.adapters import Bot, Event
|
||||||
|
|
||||||
|
EXTRA_FIELD_INFO = (
|
||||||
|
"gt",
|
||||||
|
"lt",
|
||||||
|
"ge",
|
||||||
|
"le",
|
||||||
|
"multiple_of",
|
||||||
|
"allow_inf_nan",
|
||||||
|
"max_digits",
|
||||||
|
"decimal_places",
|
||||||
|
"min_items",
|
||||||
|
"max_items",
|
||||||
|
"unique_items",
|
||||||
|
"min_length",
|
||||||
|
"max_length",
|
||||||
|
"regex",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DependsInner:
|
class DependsInner:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -31,26 +58,31 @@ class DependsInner:
|
|||||||
dependency: Optional[T_Handler] = None,
|
dependency: Optional[T_Handler] = None,
|
||||||
*,
|
*,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
|
validate: Union[bool, FieldInfo] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.dependency = dependency
|
self.dependency = dependency
|
||||||
self.use_cache = use_cache
|
self.use_cache = use_cache
|
||||||
|
self.validate = validate
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
dep = get_name(self.dependency)
|
dep = get_name(self.dependency)
|
||||||
cache = "" if self.use_cache else ", use_cache=False"
|
cache = "" if self.use_cache else ", use_cache=False"
|
||||||
return f"DependsInner({dep}{cache})"
|
validate = f", validate={self.validate}" if self.validate else ""
|
||||||
|
return f"DependsInner({dep}{cache}{validate})"
|
||||||
|
|
||||||
|
|
||||||
def Depends(
|
def Depends(
|
||||||
dependency: Optional[T_Handler] = None,
|
dependency: Optional[T_Handler] = None,
|
||||||
*,
|
*,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
|
validate: Union[bool, FieldInfo] = False,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""子依赖装饰器
|
"""子依赖装饰器
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
dependency: 依赖函数。默认为参数的类型注释。
|
dependency: 依赖函数。默认为参数的类型注释。
|
||||||
use_cache: 是否使用缓存。默认为 `True`。
|
use_cache: 是否使用缓存。默认为 `True`。
|
||||||
|
validate: 是否使用 Pydantic 类型校验。默认为 `False`。
|
||||||
|
|
||||||
用法:
|
用法:
|
||||||
```python
|
```python
|
||||||
@ -70,7 +102,7 @@ def Depends(
|
|||||||
...
|
...
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
return DependsInner(dependency, use_cache=use_cache)
|
return DependsInner(dependency, use_cache=use_cache, validate=validate)
|
||||||
|
|
||||||
|
|
||||||
class DependParam(Param):
|
class DependParam(Param):
|
||||||
@ -85,23 +117,44 @@ class DependParam(Param):
|
|||||||
return f"Depends({self.extra['dependent']})"
|
return f"Depends({self.extra['dependent']})"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
def _from_field(
|
||||||
|
cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
|
||||||
|
) -> Self:
|
||||||
|
kwargs = {}
|
||||||
|
if isinstance(validate, FieldInfo):
|
||||||
|
kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
Required,
|
||||||
|
validate=bool(validate),
|
||||||
|
**kwargs,
|
||||||
|
dependent=sub_dependent,
|
||||||
|
use_cache=use_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["DependParam"]:
|
) -> Optional[Self]:
|
||||||
type_annotation, depends_inner = param.annotation, None
|
type_annotation, depends_inner = param.annotation, None
|
||||||
|
# extract type annotation and dependency from Annotated
|
||||||
if get_origin(param.annotation) is Annotated:
|
if get_origin(param.annotation) is Annotated:
|
||||||
type_annotation, *extra_args = get_args(param.annotation)
|
type_annotation, *extra_args = get_args(param.annotation)
|
||||||
depends_inner = next(
|
depends_inner = next(
|
||||||
(x for x in extra_args if isinstance(x, DependsInner)), None
|
(x for x in extra_args if isinstance(x, DependsInner)), None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# param default value takes higher priority
|
||||||
depends_inner = (
|
depends_inner = (
|
||||||
param.default if isinstance(param.default, DependsInner) else depends_inner
|
param.default if isinstance(param.default, DependsInner) else depends_inner
|
||||||
)
|
)
|
||||||
|
# not a dependent
|
||||||
if depends_inner is None:
|
if depends_inner is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
dependency: T_Handler
|
dependency: T_Handler
|
||||||
|
# sub dependency is not specified, use type annotation
|
||||||
if depends_inner.dependency is None:
|
if depends_inner.dependency is None:
|
||||||
assert (
|
assert (
|
||||||
type_annotation is not inspect.Signature.empty
|
type_annotation is not inspect.Signature.empty
|
||||||
@ -109,13 +162,18 @@ class DependParam(Param):
|
|||||||
dependency = type_annotation
|
dependency = type_annotation
|
||||||
else:
|
else:
|
||||||
dependency = depends_inner.dependency
|
dependency = depends_inner.dependency
|
||||||
|
# parse sub dependency
|
||||||
sub_dependent = Dependent[Any].parse(
|
sub_dependent = Dependent[Any].parse(
|
||||||
call=dependency,
|
call=dependency,
|
||||||
allow_types=allow_types,
|
allow_types=allow_types,
|
||||||
)
|
)
|
||||||
return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent)
|
|
||||||
|
return cls._from_field(
|
||||||
|
sub_dependent, depends_inner.use_cache, depends_inner.validate
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_parameterless(
|
def _check_parameterless(
|
||||||
cls, value: Any, allow_types: Tuple[Type[Param], ...]
|
cls, value: Any, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["Param"]:
|
) -> Optional["Param"]:
|
||||||
@ -124,8 +182,9 @@ class DependParam(Param):
|
|||||||
dependent = Dependent[Any].parse(
|
dependent = Dependent[Any].parse(
|
||||||
call=value.dependency, allow_types=allow_types
|
call=value.dependency, allow_types=allow_types
|
||||||
)
|
)
|
||||||
return cls(Required, use_cache=value.use_cache, dependent=dependent)
|
return cls._from_field(dependent, value.use_cache, value.validate)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(
|
async def _solve(
|
||||||
self,
|
self,
|
||||||
stack: Optional[AsyncExitStack] = None,
|
stack: Optional[AsyncExitStack] = None,
|
||||||
@ -169,6 +228,7 @@ class DependParam(Param):
|
|||||||
dependency_cache[call] = task
|
dependency_cache[call] = task
|
||||||
return await task
|
return await task
|
||||||
|
|
||||||
|
@override
|
||||||
async def _check(self, **kwargs: Any) -> None:
|
async def _check(self, **kwargs: Any) -> None:
|
||||||
# run sub dependent pre-checkers
|
# run sub dependent pre-checkers
|
||||||
sub_dependent: Dependent = self.extra["dependent"]
|
sub_dependent: Dependent = self.extra["dependent"]
|
||||||
@ -195,9 +255,10 @@ class BotParam(Param):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["BotParam"]:
|
) -> Optional[Self]:
|
||||||
from nonebot.adapters import Bot
|
from nonebot.adapters import Bot
|
||||||
|
|
||||||
# param type is Bot(s) or subclass(es) of Bot or None
|
# param type is Bot(s) or subclass(es) of Bot or None
|
||||||
@ -217,9 +278,11 @@ class BotParam(Param):
|
|||||||
elif param.annotation == param.empty and param.name == "bot":
|
elif param.annotation == param.empty and param.name == "bot":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
|
||||||
return bot
|
return bot
|
||||||
|
|
||||||
|
@override
|
||||||
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
|
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
|
||||||
if checker := self.extra.get("checker"):
|
if checker := self.extra.get("checker"):
|
||||||
check_field_type(checker, bot)
|
check_field_type(checker, bot)
|
||||||
@ -245,9 +308,10 @@ class EventParam(Param):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["EventParam"]:
|
) -> Optional[Self]:
|
||||||
from nonebot.adapters import Event
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
# param type is Event(s) or subclass(es) of Event or None
|
# param type is Event(s) or subclass(es) of Event or None
|
||||||
@ -267,9 +331,11 @@ class EventParam(Param):
|
|||||||
elif param.annotation == param.empty and param.name == "event":
|
elif param.annotation == param.empty and param.name == "event":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@override
|
||||||
async def _check(self, event: "Event", **kwargs: Any) -> Any:
|
async def _check(self, event: "Event", **kwargs: Any) -> Any:
|
||||||
if checker := self.extra.get("checker", None):
|
if checker := self.extra.get("checker", None):
|
||||||
check_field_type(checker, event)
|
check_field_type(checker, event)
|
||||||
@ -287,9 +353,10 @@ class StateParam(Param):
|
|||||||
return "StateParam()"
|
return "StateParam()"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["StateParam"]:
|
) -> Optional[Self]:
|
||||||
# param type is T_State
|
# param type is T_State
|
||||||
if param.annotation is T_State:
|
if param.annotation is T_State:
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
@ -297,6 +364,7 @@ class StateParam(Param):
|
|||||||
elif param.annotation == param.empty and param.name == "state":
|
elif param.annotation == param.empty and param.name == "state":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@ -313,9 +381,10 @@ class MatcherParam(Param):
|
|||||||
return "MatcherParam()"
|
return "MatcherParam()"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["MatcherParam"]:
|
) -> Optional[Self]:
|
||||||
from nonebot.matcher import Matcher
|
from nonebot.matcher import Matcher
|
||||||
|
|
||||||
# param type is Matcher(s) or subclass(es) of Matcher or None
|
# param type is Matcher(s) or subclass(es) of Matcher or None
|
||||||
@ -335,9 +404,11 @@ class MatcherParam(Param):
|
|||||||
elif param.annotation == param.empty and param.name == "matcher":
|
elif param.annotation == param.empty and param.name == "matcher":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
@override
|
||||||
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
|
||||||
if checker := self.extra.get("checker", None):
|
if checker := self.extra.get("checker", None):
|
||||||
check_field_type(checker, matcher)
|
check_field_type(checker, matcher)
|
||||||
@ -382,9 +453,10 @@ class ArgParam(Param):
|
|||||||
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["ArgParam"]:
|
) -> Optional[Self]:
|
||||||
if isinstance(param.default, ArgInner):
|
if isinstance(param.default, ArgInner):
|
||||||
return cls(
|
return cls(
|
||||||
Required, key=param.default.key or param.name, type=param.default.type
|
Required, key=param.default.key or param.name, type=param.default.type
|
||||||
@ -419,9 +491,10 @@ class ExceptionParam(Param):
|
|||||||
return "ExceptionParam()"
|
return "ExceptionParam()"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["ExceptionParam"]:
|
) -> Optional[Self]:
|
||||||
# param type is Exception(s) or subclass(es) of Exception or None
|
# param type is Exception(s) or subclass(es) of Exception or None
|
||||||
if generic_check_issubclass(param.annotation, Exception):
|
if generic_check_issubclass(param.annotation, Exception):
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
@ -429,6 +502,7 @@ class ExceptionParam(Param):
|
|||||||
elif param.annotation == param.empty and param.name == "exception":
|
elif param.annotation == param.empty and param.name == "exception":
|
||||||
return cls(Required)
|
return cls(Required)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
|
||||||
return exception
|
return exception
|
||||||
|
|
||||||
@ -445,12 +519,14 @@ class DefaultParam(Param):
|
|||||||
return f"DefaultParam(default={self.default!r})"
|
return f"DefaultParam(default={self.default!r})"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _check_param(
|
def _check_param(
|
||||||
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
|
||||||
) -> Optional["DefaultParam"]:
|
) -> Optional[Self]:
|
||||||
if param.default != param.empty:
|
if param.default != param.empty:
|
||||||
return cls(param.default)
|
return cls(param.default)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _solve(self, **kwargs: Any) -> Any:
|
async def _solve(self, **kwargs: Any) -> Any:
|
||||||
return Undefined
|
return Undefined
|
||||||
|
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from nonebot import on_message
|
from nonebot import on_message
|
||||||
|
from nonebot.adapters import Bot
|
||||||
from nonebot.params import Depends
|
from nonebot.params import Depends
|
||||||
|
|
||||||
test_depends = on_message()
|
test_depends = on_message()
|
||||||
@ -33,6 +36,14 @@ class ClassDependency:
|
|||||||
y: int = Depends(gen_async)
|
y: int = Depends(gen_async)
|
||||||
|
|
||||||
|
|
||||||
|
class FooBot(Bot):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
async def sub_bot(b: FooBot) -> FooBot:
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
# test parameterless
|
# test parameterless
|
||||||
@test_depends.handle(parameterless=[Depends(parameterless)])
|
@test_depends.handle(parameterless=[Depends(parameterless)])
|
||||||
async def depends(x: int = Depends(dependency)):
|
async def depends(x: int = Depends(dependency)):
|
||||||
@ -46,19 +57,46 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
# test class dependency
|
||||||
async def class_depend(c: ClassDependency = Depends()):
|
async def class_depend(c: ClassDependency = Depends()):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
# test annotated dependency
|
||||||
async def annotated_depend(x: Annotated[int, Depends(dependency)]):
|
async def annotated_depend(x: Annotated[int, Depends(dependency)]):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# test annotated class dependency
|
||||||
async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]):
|
async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]):
|
||||||
return c
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
# test dependency priority
|
||||||
async def annotated_prior_depend(
|
async def annotated_prior_depend(
|
||||||
x: Annotated[int, Depends(lambda: 2)] = Depends(dependency)
|
x: Annotated[int, Depends(lambda: 2)] = Depends(dependency)
|
||||||
):
|
):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# test sub dependency type mismatch
|
||||||
|
async def sub_type_mismatch(b: FooBot = Depends(sub_bot)):
|
||||||
|
return b
|
||||||
|
|
||||||
|
|
||||||
|
# test type validate
|
||||||
|
async def validate(x: int = Depends(lambda: "1", validate=True)):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_fail(x: int = Depends(lambda: "not_number", validate=True)):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# test FieldInfo validate
|
||||||
|
async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
|
||||||
|
return x
|
||||||
|
@ -42,9 +42,14 @@ async def test_depend(app: App):
|
|||||||
ClassDependency,
|
ClassDependency,
|
||||||
runned,
|
runned,
|
||||||
depends,
|
depends,
|
||||||
|
validate,
|
||||||
class_depend,
|
class_depend,
|
||||||
test_depends,
|
test_depends,
|
||||||
|
validate_fail,
|
||||||
|
validate_field,
|
||||||
annotated_depend,
|
annotated_depend,
|
||||||
|
sub_type_mismatch,
|
||||||
|
validate_field_fail,
|
||||||
annotated_class_depend,
|
annotated_class_depend,
|
||||||
annotated_prior_depend,
|
annotated_prior_depend,
|
||||||
)
|
)
|
||||||
@ -62,8 +67,7 @@ async def test_depend(app: App):
|
|||||||
event_next = make_fake_event()()
|
event_next = make_fake_event()()
|
||||||
ctx.receive_event(bot, event_next)
|
ctx.receive_event(bot, event_next)
|
||||||
|
|
||||||
assert len(runned) == 2
|
assert runned == [1, 1]
|
||||||
assert runned[0] == runned[1] == 1
|
|
||||||
|
|
||||||
runned.clear()
|
runned.clear()
|
||||||
|
|
||||||
@ -84,6 +88,29 @@ async def test_depend(app: App):
|
|||||||
) as ctx:
|
) as ctx:
|
||||||
ctx.should_return(ClassDependency(x=1, y=2))
|
ctx.should_return(ClassDependency(x=1, y=2))
|
||||||
|
|
||||||
|
with pytest.raises(TypeMisMatch): # noqa: PT012
|
||||||
|
async with app.test_dependent(
|
||||||
|
sub_type_mismatch, allow_types=[DependParam, BotParam]
|
||||||
|
) as ctx:
|
||||||
|
bot = ctx.create_bot()
|
||||||
|
ctx.pass_params(bot=bot)
|
||||||
|
|
||||||
|
async with app.test_dependent(validate, allow_types=[DependParam]) as ctx:
|
||||||
|
ctx.should_return(1)
|
||||||
|
|
||||||
|
with pytest.raises(TypeMisMatch):
|
||||||
|
async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx:
|
||||||
|
...
|
||||||
|
|
||||||
|
async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx:
|
||||||
|
ctx.should_return(1)
|
||||||
|
|
||||||
|
with pytest.raises(TypeMisMatch):
|
||||||
|
async with app.test_dependent(
|
||||||
|
validate_field_fail, allow_types=[DependParam]
|
||||||
|
) as ctx:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bot(app: App):
|
async def test_bot(app: App):
|
||||||
|
@ -353,6 +353,80 @@ async def _(x: int = Depends(random_result, use_cache=False)):
|
|||||||
缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。
|
缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
### 类型转换与校验
|
||||||
|
|
||||||
|
在依赖注入系统中,我们可以对子依赖的返回值进行自动类型转换与校验。这个功能由 Pydantic 支持,因此我们通过参数类型注解自动使用 Pydantic 支持的类型转换。例如:
|
||||||
|
|
||||||
|
<Tabs groupId="python">
|
||||||
|
<TabItem value="3.9" label="Python 3.9+" default>
|
||||||
|
|
||||||
|
```python {6,9}
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
|
def get_user_id(event: Event) -> str:
|
||||||
|
return event.get_user_id()
|
||||||
|
|
||||||
|
async def _(user_id: Annotated[int, Depends(get_user_id, validate=True)]):
|
||||||
|
print(user_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="3.8" label="Python 3.8+">
|
||||||
|
|
||||||
|
```python {4,7}
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
|
def get_user_id(event: Event) -> str:
|
||||||
|
return event.get_user_id()
|
||||||
|
|
||||||
|
async def _(user_id: int = Depends(get_user_id, validate=True)):
|
||||||
|
print(user_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
|
在进行类型自动转换的同时,Pydantic 还支持对数据进行更多的限制,如:大于、小于、长度等。使用方法如下:
|
||||||
|
|
||||||
|
<Tabs groupId="python">
|
||||||
|
<TabItem value="3.9" label="Python 3.9+" default>
|
||||||
|
|
||||||
|
```python {7,10}
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
|
def get_user_id(event: Event) -> str:
|
||||||
|
return event.get_user_id()
|
||||||
|
|
||||||
|
async def _(user_id: Annotated[int, Depends(get_user_id, validate=Field(gt=100))]):
|
||||||
|
print(user_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
<TabItem value="3.8" label="Python 3.8+">
|
||||||
|
|
||||||
|
```python {5,8}
|
||||||
|
from pydantic import Field
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
|
||||||
|
def get_user_id(event: Event) -> str:
|
||||||
|
return event.get_user_id()
|
||||||
|
|
||||||
|
async def _(user_id: int = Depends(get_user_id, validate=Field(gt=100))):
|
||||||
|
print(user_id)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### 类作为依赖
|
### 类作为依赖
|
||||||
|
|
||||||
在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如:
|
在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如:
|
||||||
|
Loading…
Reference in New Issue
Block a user