diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index 044837aa..c2fbf00f 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -45,6 +45,10 @@ class Param(abc.ABC, FieldInfo): 继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。 """ + def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.validate = validate + @classmethod def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...] @@ -206,10 +210,12 @@ class Dependent(Generic[R]): raise async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any: - value = await cast(Param, field.field_info)._solve(**params) + param = cast(Param, field.field_info) + value = await param._solve(**params) if value is Undefined: value = field.get_default() - return check_field_type(field, value) + v = check_field_type(field, value) + return v if param.validate else value async def solve(self, **params: Any) -> Dict[str, Any]: # solve parameterless diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py index 55029c34..dcc7385f 100644 --- a/nonebot/dependencies/utils.py +++ b/nonebot/dependencies/utils.py @@ -5,7 +5,7 @@ FrontMatter: """ import inspect -from typing import Any, Dict, TypeVar, Callable, ForwardRef +from typing import Any, Dict, Callable, ForwardRef from loguru import logger from pydantic.fields import ModelField @@ -13,8 +13,6 @@ from pydantic.typing import evaluate_forwardref from nonebot.exception import TypeMisMatch -V = TypeVar("V") - def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: """获取可调用对象签名""" @@ -49,10 +47,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> return annotation -def check_field_type(field: ModelField, value: V) -> V: +def check_field_type(field: ModelField, value: Any) -> Any: """检查字段类型是否匹配""" - _, errs_ = field.validate(value, {}, loc=()) + v, errs_ = field.validate(value, {}, loc=()) if errs_: raise TypeMisMatch(field, value) - return value + return v diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 9320742c..ca2dc8ad 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -1,11 +1,21 @@ import asyncio import inspect -from typing_extensions import Annotated +from typing_extensions import Self, Annotated, override from contextlib import AsyncExitStack, contextmanager, asynccontextmanager -from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast +from typing import ( + TYPE_CHECKING, + Any, + Type, + Tuple, + Union, + Literal, + Callable, + Optional, + cast, +) from pydantic.typing import get_args, get_origin -from pydantic.fields import Required, Undefined, ModelField +from pydantic.fields import Required, FieldInfo, Undefined, ModelField from nonebot.dependencies.utils import check_field_type from nonebot.dependencies import Param, Dependent, CustomConfig @@ -24,6 +34,23 @@ if TYPE_CHECKING: from nonebot.matcher import Matcher from nonebot.adapters import Bot, Event +EXTRA_FIELD_INFO = ( + "gt", + "lt", + "ge", + "le", + "multiple_of", + "allow_inf_nan", + "max_digits", + "decimal_places", + "min_items", + "max_items", + "unique_items", + "min_length", + "max_length", + "regex", +) + class DependsInner: def __init__( @@ -31,26 +58,31 @@ class DependsInner: dependency: Optional[T_Handler] = None, *, use_cache: bool = True, + validate: Union[bool, FieldInfo] = False, ) -> None: self.dependency = dependency self.use_cache = use_cache + self.validate = validate def __repr__(self) -> str: dep = get_name(self.dependency) cache = "" if self.use_cache else ", use_cache=False" - return f"DependsInner({dep}{cache})" + validate = f", validate={self.validate}" if self.validate else "" + return f"DependsInner({dep}{cache}{validate})" def Depends( dependency: Optional[T_Handler] = None, *, use_cache: bool = True, + validate: Union[bool, FieldInfo] = False, ) -> Any: """子依赖装饰器 参数: dependency: 依赖函数。默认为参数的类型注释。 use_cache: 是否使用缓存。默认为 `True`。 + validate: 是否使用 Pydantic 类型校验。默认为 `False`。 用法: ```python @@ -70,7 +102,7 @@ def Depends( ... ``` """ - return DependsInner(dependency, use_cache=use_cache) + return DependsInner(dependency, use_cache=use_cache, validate=validate) class DependParam(Param): @@ -85,23 +117,44 @@ class DependParam(Param): return f"Depends({self.extra['dependent']})" @classmethod + def _from_field( + cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo] + ) -> Self: + kwargs = {} + if isinstance(validate, FieldInfo): + kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO) + + return cls( + Required, + validate=bool(validate), + **kwargs, + dependent=sub_dependent, + use_cache=use_cache, + ) + + @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["DependParam"]: + ) -> Optional[Self]: type_annotation, depends_inner = param.annotation, None + # extract type annotation and dependency from Annotated if get_origin(param.annotation) is Annotated: type_annotation, *extra_args = get_args(param.annotation) depends_inner = next( (x for x in extra_args if isinstance(x, DependsInner)), None ) + # param default value takes higher priority depends_inner = ( param.default if isinstance(param.default, DependsInner) else depends_inner ) + # not a dependent if depends_inner is None: return dependency: T_Handler + # sub dependency is not specified, use type annotation if depends_inner.dependency is None: assert ( type_annotation is not inspect.Signature.empty @@ -109,13 +162,18 @@ class DependParam(Param): dependency = type_annotation else: dependency = depends_inner.dependency + # parse sub dependency sub_dependent = Dependent[Any].parse( call=dependency, allow_types=allow_types, ) - return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent) + + return cls._from_field( + sub_dependent, depends_inner.use_cache, depends_inner.validate + ) @classmethod + @override def _check_parameterless( cls, value: Any, allow_types: Tuple[Type[Param], ...] ) -> Optional["Param"]: @@ -124,8 +182,9 @@ class DependParam(Param): dependent = Dependent[Any].parse( call=value.dependency, allow_types=allow_types ) - return cls(Required, use_cache=value.use_cache, dependent=dependent) + return cls._from_field(dependent, value.use_cache, value.validate) + @override async def _solve( self, stack: Optional[AsyncExitStack] = None, @@ -169,6 +228,7 @@ class DependParam(Param): dependency_cache[call] = task return await task + @override async def _check(self, **kwargs: Any) -> None: # run sub dependent pre-checkers sub_dependent: Dependent = self.extra["dependent"] @@ -195,9 +255,10 @@ class BotParam(Param): ) @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["BotParam"]: + ) -> Optional[Self]: from nonebot.adapters import Bot # param type is Bot(s) or subclass(es) of Bot or None @@ -217,9 +278,11 @@ class BotParam(Param): elif param.annotation == param.empty and param.name == "bot": return cls(Required) + @override async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: return bot + @override async def _check(self, bot: "Bot", **kwargs: Any) -> None: if checker := self.extra.get("checker"): check_field_type(checker, bot) @@ -245,9 +308,10 @@ class EventParam(Param): ) @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["EventParam"]: + ) -> Optional[Self]: from nonebot.adapters import Event # param type is Event(s) or subclass(es) of Event or None @@ -267,9 +331,11 @@ class EventParam(Param): elif param.annotation == param.empty and param.name == "event": return cls(Required) + @override async def _solve(self, event: "Event", **kwargs: Any) -> Any: return event + @override async def _check(self, event: "Event", **kwargs: Any) -> Any: if checker := self.extra.get("checker", None): check_field_type(checker, event) @@ -287,9 +353,10 @@ class StateParam(Param): return "StateParam()" @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["StateParam"]: + ) -> Optional[Self]: # param type is T_State if param.annotation is T_State: return cls(Required) @@ -297,6 +364,7 @@ class StateParam(Param): elif param.annotation == param.empty and param.name == "state": return cls(Required) + @override async def _solve(self, state: T_State, **kwargs: Any) -> Any: return state @@ -313,9 +381,10 @@ class MatcherParam(Param): return "MatcherParam()" @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["MatcherParam"]: + ) -> Optional[Self]: from nonebot.matcher import Matcher # param type is Matcher(s) or subclass(es) of Matcher or None @@ -335,9 +404,11 @@ class MatcherParam(Param): elif param.annotation == param.empty and param.name == "matcher": return cls(Required) + @override async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: return matcher + @override async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any: if checker := self.extra.get("checker", None): check_field_type(checker, matcher) @@ -382,9 +453,10 @@ class ArgParam(Param): return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["ArgParam"]: + ) -> Optional[Self]: if isinstance(param.default, ArgInner): return cls( Required, key=param.default.key or param.name, type=param.default.type @@ -419,9 +491,10 @@ class ExceptionParam(Param): return "ExceptionParam()" @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["ExceptionParam"]: + ) -> Optional[Self]: # param type is Exception(s) or subclass(es) of Exception or None if generic_check_issubclass(param.annotation, Exception): return cls(Required) @@ -429,6 +502,7 @@ class ExceptionParam(Param): elif param.annotation == param.empty and param.name == "exception": return cls(Required) + @override async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: return exception @@ -445,12 +519,14 @@ class DefaultParam(Param): return f"DefaultParam(default={self.default!r})" @classmethod + @override def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] - ) -> Optional["DefaultParam"]: + ) -> Optional[Self]: if param.default != param.empty: return cls(param.default) + @override async def _solve(self, **kwargs: Any) -> Any: return Undefined diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py index dd87199d..9a7f3fc2 100644 --- a/tests/plugins/param/param_depend.py +++ b/tests/plugins/param/param_depend.py @@ -1,7 +1,10 @@ from dataclasses import dataclass from typing_extensions import Annotated +from pydantic import Field + from nonebot import on_message +from nonebot.adapters import Bot from nonebot.params import Depends test_depends = on_message() @@ -33,6 +36,14 @@ class ClassDependency: y: int = Depends(gen_async) +class FooBot(Bot): + ... + + +async def sub_bot(b: FooBot) -> FooBot: + return b + + # test parameterless @test_depends.handle(parameterless=[Depends(parameterless)]) async def depends(x: int = Depends(dependency)): @@ -46,19 +57,46 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)): return y +# test class dependency async def class_depend(c: ClassDependency = Depends()): return c +# test annotated dependency async def annotated_depend(x: Annotated[int, Depends(dependency)]): return x +# test annotated class dependency async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]): return c +# test dependency priority async def annotated_prior_depend( x: Annotated[int, Depends(lambda: 2)] = Depends(dependency) ): return x + + +# test sub dependency type mismatch +async def sub_type_mismatch(b: FooBot = Depends(sub_bot)): + return b + + +# test type validate +async def validate(x: int = Depends(lambda: "1", validate=True)): + return x + + +async def validate_fail(x: int = Depends(lambda: "not_number", validate=True)): + return x + + +# test FieldInfo validate +async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))): + return x + + +async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))): + return x diff --git a/tests/test_param.py b/tests/test_param.py index 9ea7f4c3..4795e2a6 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -42,9 +42,14 @@ async def test_depend(app: App): ClassDependency, runned, depends, + validate, class_depend, test_depends, + validate_fail, + validate_field, annotated_depend, + sub_type_mismatch, + validate_field_fail, annotated_class_depend, annotated_prior_depend, ) @@ -62,8 +67,7 @@ async def test_depend(app: App): event_next = make_fake_event()() ctx.receive_event(bot, event_next) - assert len(runned) == 2 - assert runned[0] == runned[1] == 1 + assert runned == [1, 1] runned.clear() @@ -84,6 +88,29 @@ async def test_depend(app: App): ) as ctx: ctx.should_return(ClassDependency(x=1, y=2)) + with pytest.raises(TypeMisMatch): # noqa: PT012 + async with app.test_dependent( + sub_type_mismatch, allow_types=[DependParam, BotParam] + ) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + + async with app.test_dependent(validate, allow_types=[DependParam]) as ctx: + ctx.should_return(1) + + with pytest.raises(TypeMisMatch): + async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx: + ... + + async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx: + ctx.should_return(1) + + with pytest.raises(TypeMisMatch): + async with app.test_dependent( + validate_field_fail, allow_types=[DependParam] + ) as ctx: + ... + @pytest.mark.asyncio async def test_bot(app: App): diff --git a/website/docs/advanced/dependency.mdx b/website/docs/advanced/dependency.mdx index d511403f..55b28f43 100644 --- a/website/docs/advanced/dependency.mdx +++ b/website/docs/advanced/dependency.mdx @@ -353,6 +353,80 @@ async def _(x: int = Depends(random_result, use_cache=False)): 缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。 ::: +### 类型转换与校验 + +在依赖注入系统中,我们可以对子依赖的返回值进行自动类型转换与校验。这个功能由 Pydantic 支持,因此我们通过参数类型注解自动使用 Pydantic 支持的类型转换。例如: + + + + +```python {6,9} +from typing import Annotated + +from nonebot.params import Depends +from nonebot.adapters import Event + +def get_user_id(event: Event) -> str: + return event.get_user_id() + +async def _(user_id: Annotated[int, Depends(get_user_id, validate=True)]): + print(user_id) +``` + + + + +```python {4,7} +from nonebot.params import Depends +from nonebot.adapters import Event + +def get_user_id(event: Event) -> str: + return event.get_user_id() + +async def _(user_id: int = Depends(get_user_id, validate=True)): + print(user_id) +``` + + + + +在进行类型自动转换的同时,Pydantic 还支持对数据进行更多的限制,如:大于、小于、长度等。使用方法如下: + + + + +```python {7,10} +from typing import Annotated + +from pydantic import Field +from nonebot.params import Depends +from nonebot.adapters import Event + +def get_user_id(event: Event) -> str: + return event.get_user_id() + +async def _(user_id: Annotated[int, Depends(get_user_id, validate=Field(gt=100))]): + print(user_id) +``` + + + + +```python {5,8} +from pydantic import Field +from nonebot.params import Depends +from nonebot.adapters import Event + +def get_user_id(event: Event) -> str: + return event.get_user_id() + +async def _(user_id: int = Depends(get_user_id, validate=Field(gt=100))): + print(user_id) +``` + + + + ### 类作为依赖 在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如: