Feature: 支持子依赖定义 Pydantic 类型校验 (#2310)

This commit is contained in:
Ju4tCode 2023-08-29 18:45:12 +08:00 committed by GitHub
parent 79f833b946
commit f59271bd47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 244 additions and 25 deletions

View File

@ -45,6 +45,10 @@ class Param(abc.ABC, FieldInfo):
继承自 `pydantic.fields.FieldInfo`用于描述参数信息不包括参数名 继承自 `pydantic.fields.FieldInfo`用于描述参数信息不包括参数名
""" """
def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.validate = validate
@classmethod @classmethod
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
@ -206,10 +210,12 @@ class Dependent(Generic[R]):
raise raise
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any: async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
value = await cast(Param, field.field_info)._solve(**params) param = cast(Param, field.field_info)
value = await param._solve(**params)
if value is Undefined: if value is Undefined:
value = field.get_default() value = field.get_default()
return check_field_type(field, value) v = check_field_type(field, value)
return v if param.validate else value
async def solve(self, **params: Any) -> Dict[str, Any]: async def solve(self, **params: Any) -> Dict[str, Any]:
# solve parameterless # solve parameterless

View File

@ -5,7 +5,7 @@ FrontMatter:
""" """
import inspect import inspect
from typing import Any, Dict, TypeVar, Callable, ForwardRef from typing import Any, Dict, Callable, ForwardRef
from loguru import logger from loguru import logger
from pydantic.fields import ModelField from pydantic.fields import ModelField
@ -13,8 +13,6 @@ from pydantic.typing import evaluate_forwardref
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
V = TypeVar("V")
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""获取可调用对象签名""" """获取可调用对象签名"""
@ -49,10 +47,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
return annotation return annotation
def check_field_type(field: ModelField, value: V) -> V: def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配""" """检查字段类型是否匹配"""
_, errs_ = field.validate(value, {}, loc=()) v, errs_ = field.validate(value, {}, loc=())
if errs_: if errs_:
raise TypeMisMatch(field, value) raise TypeMisMatch(field, value)
return value return v

View File

@ -1,11 +1,21 @@
import asyncio import asyncio
import inspect import inspect
from typing_extensions import Annotated from typing_extensions import Self, Annotated, override
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast from typing import (
TYPE_CHECKING,
Any,
Type,
Tuple,
Union,
Literal,
Callable,
Optional,
cast,
)
from pydantic.typing import get_args, get_origin from pydantic.typing import get_args, get_origin
from pydantic.fields import Required, Undefined, ModelField from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from nonebot.dependencies.utils import check_field_type from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.dependencies import Param, Dependent, CustomConfig
@ -24,6 +34,23 @@ if TYPE_CHECKING:
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
EXTRA_FIELD_INFO = (
"gt",
"lt",
"ge",
"le",
"multiple_of",
"allow_inf_nan",
"max_digits",
"decimal_places",
"min_items",
"max_items",
"unique_items",
"min_length",
"max_length",
"regex",
)
class DependsInner: class DependsInner:
def __init__( def __init__(
@ -31,26 +58,31 @@ class DependsInner:
dependency: Optional[T_Handler] = None, dependency: Optional[T_Handler] = None,
*, *,
use_cache: bool = True, use_cache: bool = True,
validate: Union[bool, FieldInfo] = False,
) -> None: ) -> None:
self.dependency = dependency self.dependency = dependency
self.use_cache = use_cache self.use_cache = use_cache
self.validate = validate
def __repr__(self) -> str: def __repr__(self) -> str:
dep = get_name(self.dependency) dep = get_name(self.dependency)
cache = "" if self.use_cache else ", use_cache=False" cache = "" if self.use_cache else ", use_cache=False"
return f"DependsInner({dep}{cache})" validate = f", validate={self.validate}" if self.validate else ""
return f"DependsInner({dep}{cache}{validate})"
def Depends( def Depends(
dependency: Optional[T_Handler] = None, dependency: Optional[T_Handler] = None,
*, *,
use_cache: bool = True, use_cache: bool = True,
validate: Union[bool, FieldInfo] = False,
) -> Any: ) -> Any:
"""子依赖装饰器 """子依赖装饰器
参数: 参数:
dependency: 依赖函数默认为参数的类型注释 dependency: 依赖函数默认为参数的类型注释
use_cache: 是否使用缓存默认为 `True` use_cache: 是否使用缓存默认为 `True`
validate: 是否使用 Pydantic 类型校验默认为 `False`
用法: 用法:
```python ```python
@ -70,7 +102,7 @@ def Depends(
... ...
``` ```
""" """
return DependsInner(dependency, use_cache=use_cache) return DependsInner(dependency, use_cache=use_cache, validate=validate)
class DependParam(Param): class DependParam(Param):
@ -85,23 +117,44 @@ class DependParam(Param):
return f"Depends({self.extra['dependent']})" return f"Depends({self.extra['dependent']})"
@classmethod @classmethod
def _from_field(
cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
) -> Self:
kwargs = {}
if isinstance(validate, FieldInfo):
kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)
return cls(
Required,
validate=bool(validate),
**kwargs,
dependent=sub_dependent,
use_cache=use_cache,
)
@classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DependParam"]: ) -> Optional[Self]:
type_annotation, depends_inner = param.annotation, None type_annotation, depends_inner = param.annotation, None
# extract type annotation and dependency from Annotated
if get_origin(param.annotation) is Annotated: if get_origin(param.annotation) is Annotated:
type_annotation, *extra_args = get_args(param.annotation) type_annotation, *extra_args = get_args(param.annotation)
depends_inner = next( depends_inner = next(
(x for x in extra_args if isinstance(x, DependsInner)), None (x for x in extra_args if isinstance(x, DependsInner)), None
) )
# param default value takes higher priority
depends_inner = ( depends_inner = (
param.default if isinstance(param.default, DependsInner) else depends_inner param.default if isinstance(param.default, DependsInner) else depends_inner
) )
# not a dependent
if depends_inner is None: if depends_inner is None:
return return
dependency: T_Handler dependency: T_Handler
# sub dependency is not specified, use type annotation
if depends_inner.dependency is None: if depends_inner.dependency is None:
assert ( assert (
type_annotation is not inspect.Signature.empty type_annotation is not inspect.Signature.empty
@ -109,13 +162,18 @@ class DependParam(Param):
dependency = type_annotation dependency = type_annotation
else: else:
dependency = depends_inner.dependency dependency = depends_inner.dependency
# parse sub dependency
sub_dependent = Dependent[Any].parse( sub_dependent = Dependent[Any].parse(
call=dependency, call=dependency,
allow_types=allow_types, allow_types=allow_types,
) )
return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent)
return cls._from_field(
sub_dependent, depends_inner.use_cache, depends_inner.validate
)
@classmethod @classmethod
@override
def _check_parameterless( def _check_parameterless(
cls, value: Any, allow_types: Tuple[Type[Param], ...] cls, value: Any, allow_types: Tuple[Type[Param], ...]
) -> Optional["Param"]: ) -> Optional["Param"]:
@ -124,8 +182,9 @@ class DependParam(Param):
dependent = Dependent[Any].parse( dependent = Dependent[Any].parse(
call=value.dependency, allow_types=allow_types call=value.dependency, allow_types=allow_types
) )
return cls(Required, use_cache=value.use_cache, dependent=dependent) return cls._from_field(dependent, value.use_cache, value.validate)
@override
async def _solve( async def _solve(
self, self,
stack: Optional[AsyncExitStack] = None, stack: Optional[AsyncExitStack] = None,
@ -169,6 +228,7 @@ class DependParam(Param):
dependency_cache[call] = task dependency_cache[call] = task
return await task return await task
@override
async def _check(self, **kwargs: Any) -> None: async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers # run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"] sub_dependent: Dependent = self.extra["dependent"]
@ -195,9 +255,10 @@ class BotParam(Param):
) )
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["BotParam"]: ) -> Optional[Self]:
from nonebot.adapters import Bot from nonebot.adapters import Bot
# param type is Bot(s) or subclass(es) of Bot or None # param type is Bot(s) or subclass(es) of Bot or None
@ -217,9 +278,11 @@ class BotParam(Param):
elif param.annotation == param.empty and param.name == "bot": elif param.annotation == param.empty and param.name == "bot":
return cls(Required) return cls(Required)
@override
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot return bot
@override
async def _check(self, bot: "Bot", **kwargs: Any) -> None: async def _check(self, bot: "Bot", **kwargs: Any) -> None:
if checker := self.extra.get("checker"): if checker := self.extra.get("checker"):
check_field_type(checker, bot) check_field_type(checker, bot)
@ -245,9 +308,10 @@ class EventParam(Param):
) )
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["EventParam"]: ) -> Optional[Self]:
from nonebot.adapters import Event from nonebot.adapters import Event
# param type is Event(s) or subclass(es) of Event or None # param type is Event(s) or subclass(es) of Event or None
@ -267,9 +331,11 @@ class EventParam(Param):
elif param.annotation == param.empty and param.name == "event": elif param.annotation == param.empty and param.name == "event":
return cls(Required) return cls(Required)
@override
async def _solve(self, event: "Event", **kwargs: Any) -> Any: async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event return event
@override
async def _check(self, event: "Event", **kwargs: Any) -> Any: async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None): if checker := self.extra.get("checker", None):
check_field_type(checker, event) check_field_type(checker, event)
@ -287,9 +353,10 @@ class StateParam(Param):
return "StateParam()" return "StateParam()"
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["StateParam"]: ) -> Optional[Self]:
# param type is T_State # param type is T_State
if param.annotation is T_State: if param.annotation is T_State:
return cls(Required) return cls(Required)
@ -297,6 +364,7 @@ class StateParam(Param):
elif param.annotation == param.empty and param.name == "state": elif param.annotation == param.empty and param.name == "state":
return cls(Required) return cls(Required)
@override
async def _solve(self, state: T_State, **kwargs: Any) -> Any: async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state return state
@ -313,9 +381,10 @@ class MatcherParam(Param):
return "MatcherParam()" return "MatcherParam()"
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["MatcherParam"]: ) -> Optional[Self]:
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
# param type is Matcher(s) or subclass(es) of Matcher or None # param type is Matcher(s) or subclass(es) of Matcher or None
@ -335,9 +404,11 @@ class MatcherParam(Param):
elif param.annotation == param.empty and param.name == "matcher": elif param.annotation == param.empty and param.name == "matcher":
return cls(Required) return cls(Required)
@override
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
return matcher return matcher
@override
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any: async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None): if checker := self.extra.get("checker", None):
check_field_type(checker, matcher) check_field_type(checker, matcher)
@ -382,9 +453,10 @@ class ArgParam(Param):
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ArgParam"]: ) -> Optional[Self]:
if isinstance(param.default, ArgInner): if isinstance(param.default, ArgInner):
return cls( return cls(
Required, key=param.default.key or param.name, type=param.default.type Required, key=param.default.key or param.name, type=param.default.type
@ -419,9 +491,10 @@ class ExceptionParam(Param):
return "ExceptionParam()" return "ExceptionParam()"
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["ExceptionParam"]: ) -> Optional[Self]:
# param type is Exception(s) or subclass(es) of Exception or None # param type is Exception(s) or subclass(es) of Exception or None
if generic_check_issubclass(param.annotation, Exception): if generic_check_issubclass(param.annotation, Exception):
return cls(Required) return cls(Required)
@ -429,6 +502,7 @@ class ExceptionParam(Param):
elif param.annotation == param.empty and param.name == "exception": elif param.annotation == param.empty and param.name == "exception":
return cls(Required) return cls(Required)
@override
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
return exception return exception
@ -445,12 +519,14 @@ class DefaultParam(Param):
return f"DefaultParam(default={self.default!r})" return f"DefaultParam(default={self.default!r})"
@classmethod @classmethod
@override
def _check_param( def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
) -> Optional["DefaultParam"]: ) -> Optional[Self]:
if param.default != param.empty: if param.default != param.empty:
return cls(param.default) return cls(param.default)
@override
async def _solve(self, **kwargs: Any) -> Any: async def _solve(self, **kwargs: Any) -> Any:
return Undefined return Undefined

View File

@ -1,7 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing_extensions import Annotated from typing_extensions import Annotated
from pydantic import Field
from nonebot import on_message from nonebot import on_message
from nonebot.adapters import Bot
from nonebot.params import Depends from nonebot.params import Depends
test_depends = on_message() test_depends = on_message()
@ -33,6 +36,14 @@ class ClassDependency:
y: int = Depends(gen_async) y: int = Depends(gen_async)
class FooBot(Bot):
...
async def sub_bot(b: FooBot) -> FooBot:
return b
# test parameterless # test parameterless
@test_depends.handle(parameterless=[Depends(parameterless)]) @test_depends.handle(parameterless=[Depends(parameterless)])
async def depends(x: int = Depends(dependency)): async def depends(x: int = Depends(dependency)):
@ -46,19 +57,46 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)):
return y return y
# test class dependency
async def class_depend(c: ClassDependency = Depends()): async def class_depend(c: ClassDependency = Depends()):
return c return c
# test annotated dependency
async def annotated_depend(x: Annotated[int, Depends(dependency)]): async def annotated_depend(x: Annotated[int, Depends(dependency)]):
return x return x
# test annotated class dependency
async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]): async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]):
return c return c
# test dependency priority
async def annotated_prior_depend( async def annotated_prior_depend(
x: Annotated[int, Depends(lambda: 2)] = Depends(dependency) x: Annotated[int, Depends(lambda: 2)] = Depends(dependency)
): ):
return x return x
# test sub dependency type mismatch
async def sub_type_mismatch(b: FooBot = Depends(sub_bot)):
return b
# test type validate
async def validate(x: int = Depends(lambda: "1", validate=True)):
return x
async def validate_fail(x: int = Depends(lambda: "not_number", validate=True)):
return x
# test FieldInfo validate
async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))):
return x
async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
return x

View File

@ -42,9 +42,14 @@ async def test_depend(app: App):
ClassDependency, ClassDependency,
runned, runned,
depends, depends,
validate,
class_depend, class_depend,
test_depends, test_depends,
validate_fail,
validate_field,
annotated_depend, annotated_depend,
sub_type_mismatch,
validate_field_fail,
annotated_class_depend, annotated_class_depend,
annotated_prior_depend, annotated_prior_depend,
) )
@ -62,8 +67,7 @@ async def test_depend(app: App):
event_next = make_fake_event()() event_next = make_fake_event()()
ctx.receive_event(bot, event_next) ctx.receive_event(bot, event_next)
assert len(runned) == 2 assert runned == [1, 1]
assert runned[0] == runned[1] == 1
runned.clear() runned.clear()
@ -84,6 +88,29 @@ async def test_depend(app: App):
) as ctx: ) as ctx:
ctx.should_return(ClassDependency(x=1, y=2)) ctx.should_return(ClassDependency(x=1, y=2))
with pytest.raises(TypeMisMatch): # noqa: PT012
async with app.test_dependent(
sub_type_mismatch, allow_types=[DependParam, BotParam]
) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
async with app.test_dependent(validate, allow_types=[DependParam]) as ctx:
ctx.should_return(1)
with pytest.raises(TypeMisMatch):
async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx:
...
async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx:
ctx.should_return(1)
with pytest.raises(TypeMisMatch):
async with app.test_dependent(
validate_field_fail, allow_types=[DependParam]
) as ctx:
...
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bot(app: App): async def test_bot(app: App):

View File

@ -353,6 +353,80 @@ async def _(x: int = Depends(random_result, use_cache=False)):
缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。 缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。
::: :::
### 类型转换与校验
在依赖注入系统中,我们可以对子依赖的返回值进行自动类型转换与校验。这个功能由 Pydantic 支持,因此我们通过参数类型注解自动使用 Pydantic 支持的类型转换。例如:
<Tabs groupId="python">
<TabItem value="3.9" label="Python 3.9+" default>
```python {6,9}
from typing import Annotated
from nonebot.params import Depends
from nonebot.adapters import Event
def get_user_id(event: Event) -> str:
return event.get_user_id()
async def _(user_id: Annotated[int, Depends(get_user_id, validate=True)]):
print(user_id)
```
</TabItem>
<TabItem value="3.8" label="Python 3.8+">
```python {4,7}
from nonebot.params import Depends
from nonebot.adapters import Event
def get_user_id(event: Event) -> str:
return event.get_user_id()
async def _(user_id: int = Depends(get_user_id, validate=True)):
print(user_id)
```
</TabItem>
</Tabs>
在进行类型自动转换的同时Pydantic 还支持对数据进行更多的限制,如:大于、小于、长度等。使用方法如下:
<Tabs groupId="python">
<TabItem value="3.9" label="Python 3.9+" default>
```python {7,10}
from typing import Annotated
from pydantic import Field
from nonebot.params import Depends
from nonebot.adapters import Event
def get_user_id(event: Event) -> str:
return event.get_user_id()
async def _(user_id: Annotated[int, Depends(get_user_id, validate=Field(gt=100))]):
print(user_id)
```
</TabItem>
<TabItem value="3.8" label="Python 3.8+">
```python {5,8}
from pydantic import Field
from nonebot.params import Depends
from nonebot.adapters import Event
def get_user_id(event: Event) -> str:
return event.get_user_id()
async def _(user_id: int = Depends(get_user_id, validate=Field(gt=100))):
print(user_id)
```
</TabItem>
</Tabs>
### 类作为依赖 ### 类作为依赖
在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如: 在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如: