From f59271bd478a72d5a03cc075e1c65d08586b8d71 Mon Sep 17 00:00:00 2001
From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
Date: Tue, 29 Aug 2023 18:45:12 +0800
Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E6=94=AF=E6=8C=81?=
=?UTF-8?q?=E5=AD=90=E4=BE=9D=E8=B5=96=E5=AE=9A=E4=B9=89=20Pydantic=20?=
=?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=A0=A1=E9=AA=8C=20(#2310)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
nonebot/dependencies/__init__.py | 10 ++-
nonebot/dependencies/utils.py | 10 +--
nonebot/internal/params.py | 106 +++++++++++++++++++++++----
tests/plugins/param/param_depend.py | 38 ++++++++++
tests/test_param.py | 31 +++++++-
website/docs/advanced/dependency.mdx | 74 +++++++++++++++++++
6 files changed, 244 insertions(+), 25 deletions(-)
diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py
index 044837aa..c2fbf00f 100644
--- a/nonebot/dependencies/__init__.py
+++ b/nonebot/dependencies/__init__.py
@@ -45,6 +45,10 @@ class Param(abc.ABC, FieldInfo):
继承自 `pydantic.fields.FieldInfo`,用于描述参数信息(不包括参数名)。
"""
+ def __init__(self, *args, validate: bool = False, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+ self.validate = validate
+
@classmethod
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type["Param"], ...]
@@ -206,10 +210,12 @@ class Dependent(Generic[R]):
raise
async def _solve_field(self, field: ModelField, params: Dict[str, Any]) -> Any:
- value = await cast(Param, field.field_info)._solve(**params)
+ param = cast(Param, field.field_info)
+ value = await param._solve(**params)
if value is Undefined:
value = field.get_default()
- return check_field_type(field, value)
+ v = check_field_type(field, value)
+ return v if param.validate else value
async def solve(self, **params: Any) -> Dict[str, Any]:
# solve parameterless
diff --git a/nonebot/dependencies/utils.py b/nonebot/dependencies/utils.py
index 55029c34..dcc7385f 100644
--- a/nonebot/dependencies/utils.py
+++ b/nonebot/dependencies/utils.py
@@ -5,7 +5,7 @@ FrontMatter:
"""
import inspect
-from typing import Any, Dict, TypeVar, Callable, ForwardRef
+from typing import Any, Dict, Callable, ForwardRef
from loguru import logger
from pydantic.fields import ModelField
@@ -13,8 +13,6 @@ from pydantic.typing import evaluate_forwardref
from nonebot.exception import TypeMisMatch
-V = TypeVar("V")
-
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""获取可调用对象签名"""
@@ -49,10 +47,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
return annotation
-def check_field_type(field: ModelField, value: V) -> V:
+def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配"""
- _, errs_ = field.validate(value, {}, loc=())
+ v, errs_ = field.validate(value, {}, loc=())
if errs_:
raise TypeMisMatch(field, value)
- return value
+ return v
diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py
index 9320742c..ca2dc8ad 100644
--- a/nonebot/internal/params.py
+++ b/nonebot/internal/params.py
@@ -1,11 +1,21 @@
import asyncio
import inspect
-from typing_extensions import Annotated
+from typing_extensions import Self, Annotated, override
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
-from typing import TYPE_CHECKING, Any, Type, Tuple, Literal, Callable, Optional, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Type,
+ Tuple,
+ Union,
+ Literal,
+ Callable,
+ Optional,
+ cast,
+)
from pydantic.typing import get_args, get_origin
-from pydantic.fields import Required, Undefined, ModelField
+from pydantic.fields import Required, FieldInfo, Undefined, ModelField
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig
@@ -24,6 +34,23 @@ if TYPE_CHECKING:
from nonebot.matcher import Matcher
from nonebot.adapters import Bot, Event
+EXTRA_FIELD_INFO = (
+ "gt",
+ "lt",
+ "ge",
+ "le",
+ "multiple_of",
+ "allow_inf_nan",
+ "max_digits",
+ "decimal_places",
+ "min_items",
+ "max_items",
+ "unique_items",
+ "min_length",
+ "max_length",
+ "regex",
+)
+
class DependsInner:
def __init__(
@@ -31,26 +58,31 @@ class DependsInner:
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
+ validate: Union[bool, FieldInfo] = False,
) -> None:
self.dependency = dependency
self.use_cache = use_cache
+ self.validate = validate
def __repr__(self) -> str:
dep = get_name(self.dependency)
cache = "" if self.use_cache else ", use_cache=False"
- return f"DependsInner({dep}{cache})"
+ validate = f", validate={self.validate}" if self.validate else ""
+ return f"DependsInner({dep}{cache}{validate})"
def Depends(
dependency: Optional[T_Handler] = None,
*,
use_cache: bool = True,
+ validate: Union[bool, FieldInfo] = False,
) -> Any:
"""子依赖装饰器
参数:
dependency: 依赖函数。默认为参数的类型注释。
use_cache: 是否使用缓存。默认为 `True`。
+ validate: 是否使用 Pydantic 类型校验。默认为 `False`。
用法:
```python
@@ -70,7 +102,7 @@ def Depends(
...
```
"""
- return DependsInner(dependency, use_cache=use_cache)
+ return DependsInner(dependency, use_cache=use_cache, validate=validate)
class DependParam(Param):
@@ -85,23 +117,44 @@ class DependParam(Param):
return f"Depends({self.extra['dependent']})"
@classmethod
+ def _from_field(
+ cls, sub_dependent: Dependent, use_cache: bool, validate: Union[bool, FieldInfo]
+ ) -> Self:
+ kwargs = {}
+ if isinstance(validate, FieldInfo):
+ kwargs.update((k, getattr(validate, k)) for k in EXTRA_FIELD_INFO)
+
+ return cls(
+ Required,
+ validate=bool(validate),
+ **kwargs,
+ dependent=sub_dependent,
+ use_cache=use_cache,
+ )
+
+ @classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["DependParam"]:
+ ) -> Optional[Self]:
type_annotation, depends_inner = param.annotation, None
+ # extract type annotation and dependency from Annotated
if get_origin(param.annotation) is Annotated:
type_annotation, *extra_args = get_args(param.annotation)
depends_inner = next(
(x for x in extra_args if isinstance(x, DependsInner)), None
)
+ # param default value takes higher priority
depends_inner = (
param.default if isinstance(param.default, DependsInner) else depends_inner
)
+ # not a dependent
if depends_inner is None:
return
dependency: T_Handler
+ # sub dependency is not specified, use type annotation
if depends_inner.dependency is None:
assert (
type_annotation is not inspect.Signature.empty
@@ -109,13 +162,18 @@ class DependParam(Param):
dependency = type_annotation
else:
dependency = depends_inner.dependency
+ # parse sub dependency
sub_dependent = Dependent[Any].parse(
call=dependency,
allow_types=allow_types,
)
- return cls(Required, use_cache=depends_inner.use_cache, dependent=sub_dependent)
+
+ return cls._from_field(
+ sub_dependent, depends_inner.use_cache, depends_inner.validate
+ )
@classmethod
+ @override
def _check_parameterless(
cls, value: Any, allow_types: Tuple[Type[Param], ...]
) -> Optional["Param"]:
@@ -124,8 +182,9 @@ class DependParam(Param):
dependent = Dependent[Any].parse(
call=value.dependency, allow_types=allow_types
)
- return cls(Required, use_cache=value.use_cache, dependent=dependent)
+ return cls._from_field(dependent, value.use_cache, value.validate)
+ @override
async def _solve(
self,
stack: Optional[AsyncExitStack] = None,
@@ -169,6 +228,7 @@ class DependParam(Param):
dependency_cache[call] = task
return await task
+ @override
async def _check(self, **kwargs: Any) -> None:
# run sub dependent pre-checkers
sub_dependent: Dependent = self.extra["dependent"]
@@ -195,9 +255,10 @@ class BotParam(Param):
)
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["BotParam"]:
+ ) -> Optional[Self]:
from nonebot.adapters import Bot
# param type is Bot(s) or subclass(es) of Bot or None
@@ -217,9 +278,11 @@ class BotParam(Param):
elif param.annotation == param.empty and param.name == "bot":
return cls(Required)
+ @override
async def _solve(self, bot: "Bot", **kwargs: Any) -> Any:
return bot
+ @override
async def _check(self, bot: "Bot", **kwargs: Any) -> None:
if checker := self.extra.get("checker"):
check_field_type(checker, bot)
@@ -245,9 +308,10 @@ class EventParam(Param):
)
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["EventParam"]:
+ ) -> Optional[Self]:
from nonebot.adapters import Event
# param type is Event(s) or subclass(es) of Event or None
@@ -267,9 +331,11 @@ class EventParam(Param):
elif param.annotation == param.empty and param.name == "event":
return cls(Required)
+ @override
async def _solve(self, event: "Event", **kwargs: Any) -> Any:
return event
+ @override
async def _check(self, event: "Event", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, event)
@@ -287,9 +353,10 @@ class StateParam(Param):
return "StateParam()"
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["StateParam"]:
+ ) -> Optional[Self]:
# param type is T_State
if param.annotation is T_State:
return cls(Required)
@@ -297,6 +364,7 @@ class StateParam(Param):
elif param.annotation == param.empty and param.name == "state":
return cls(Required)
+ @override
async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state
@@ -313,9 +381,10 @@ class MatcherParam(Param):
return "MatcherParam()"
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["MatcherParam"]:
+ ) -> Optional[Self]:
from nonebot.matcher import Matcher
# param type is Matcher(s) or subclass(es) of Matcher or None
@@ -335,9 +404,11 @@ class MatcherParam(Param):
elif param.annotation == param.empty and param.name == "matcher":
return cls(Required)
+ @override
async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any:
return matcher
+ @override
async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any:
if checker := self.extra.get("checker", None):
check_field_type(checker, matcher)
@@ -382,9 +453,10 @@ class ArgParam(Param):
return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})"
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["ArgParam"]:
+ ) -> Optional[Self]:
if isinstance(param.default, ArgInner):
return cls(
Required, key=param.default.key or param.name, type=param.default.type
@@ -419,9 +491,10 @@ class ExceptionParam(Param):
return "ExceptionParam()"
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["ExceptionParam"]:
+ ) -> Optional[Self]:
# param type is Exception(s) or subclass(es) of Exception or None
if generic_check_issubclass(param.annotation, Exception):
return cls(Required)
@@ -429,6 +502,7 @@ class ExceptionParam(Param):
elif param.annotation == param.empty and param.name == "exception":
return cls(Required)
+ @override
async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any:
return exception
@@ -445,12 +519,14 @@ class DefaultParam(Param):
return f"DefaultParam(default={self.default!r})"
@classmethod
+ @override
def _check_param(
cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...]
- ) -> Optional["DefaultParam"]:
+ ) -> Optional[Self]:
if param.default != param.empty:
return cls(param.default)
+ @override
async def _solve(self, **kwargs: Any) -> Any:
return Undefined
diff --git a/tests/plugins/param/param_depend.py b/tests/plugins/param/param_depend.py
index dd87199d..9a7f3fc2 100644
--- a/tests/plugins/param/param_depend.py
+++ b/tests/plugins/param/param_depend.py
@@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing_extensions import Annotated
+from pydantic import Field
+
from nonebot import on_message
+from nonebot.adapters import Bot
from nonebot.params import Depends
test_depends = on_message()
@@ -33,6 +36,14 @@ class ClassDependency:
y: int = Depends(gen_async)
+class FooBot(Bot):
+ ...
+
+
+async def sub_bot(b: FooBot) -> FooBot:
+ return b
+
+
# test parameterless
@test_depends.handle(parameterless=[Depends(parameterless)])
async def depends(x: int = Depends(dependency)):
@@ -46,19 +57,46 @@ async def depends_cache(y: int = Depends(dependency, use_cache=True)):
return y
+# test class dependency
async def class_depend(c: ClassDependency = Depends()):
return c
+# test annotated dependency
async def annotated_depend(x: Annotated[int, Depends(dependency)]):
return x
+# test annotated class dependency
async def annotated_class_depend(c: Annotated[ClassDependency, Depends()]):
return c
+# test dependency priority
async def annotated_prior_depend(
x: Annotated[int, Depends(lambda: 2)] = Depends(dependency)
):
return x
+
+
+# test sub dependency type mismatch
+async def sub_type_mismatch(b: FooBot = Depends(sub_bot)):
+ return b
+
+
+# test type validate
+async def validate(x: int = Depends(lambda: "1", validate=True)):
+ return x
+
+
+async def validate_fail(x: int = Depends(lambda: "not_number", validate=True)):
+ return x
+
+
+# test FieldInfo validate
+async def validate_field(x: int = Depends(lambda: "1", validate=Field(gt=0))):
+ return x
+
+
+async def validate_field_fail(x: int = Depends(lambda: "0", validate=Field(gt=0))):
+ return x
diff --git a/tests/test_param.py b/tests/test_param.py
index 9ea7f4c3..4795e2a6 100644
--- a/tests/test_param.py
+++ b/tests/test_param.py
@@ -42,9 +42,14 @@ async def test_depend(app: App):
ClassDependency,
runned,
depends,
+ validate,
class_depend,
test_depends,
+ validate_fail,
+ validate_field,
annotated_depend,
+ sub_type_mismatch,
+ validate_field_fail,
annotated_class_depend,
annotated_prior_depend,
)
@@ -62,8 +67,7 @@ async def test_depend(app: App):
event_next = make_fake_event()()
ctx.receive_event(bot, event_next)
- assert len(runned) == 2
- assert runned[0] == runned[1] == 1
+ assert runned == [1, 1]
runned.clear()
@@ -84,6 +88,29 @@ async def test_depend(app: App):
) as ctx:
ctx.should_return(ClassDependency(x=1, y=2))
+ with pytest.raises(TypeMisMatch): # noqa: PT012
+ async with app.test_dependent(
+ sub_type_mismatch, allow_types=[DependParam, BotParam]
+ ) as ctx:
+ bot = ctx.create_bot()
+ ctx.pass_params(bot=bot)
+
+ async with app.test_dependent(validate, allow_types=[DependParam]) as ctx:
+ ctx.should_return(1)
+
+ with pytest.raises(TypeMisMatch):
+ async with app.test_dependent(validate_fail, allow_types=[DependParam]) as ctx:
+ ...
+
+ async with app.test_dependent(validate_field, allow_types=[DependParam]) as ctx:
+ ctx.should_return(1)
+
+ with pytest.raises(TypeMisMatch):
+ async with app.test_dependent(
+ validate_field_fail, allow_types=[DependParam]
+ ) as ctx:
+ ...
+
@pytest.mark.asyncio
async def test_bot(app: App):
diff --git a/website/docs/advanced/dependency.mdx b/website/docs/advanced/dependency.mdx
index d511403f..55b28f43 100644
--- a/website/docs/advanced/dependency.mdx
+++ b/website/docs/advanced/dependency.mdx
@@ -353,6 +353,80 @@ async def _(x: int = Depends(random_result, use_cache=False)):
缓存的生命周期与当前接收到的事件相同。接收到事件后,子依赖在首次执行时缓存,在该事件处理完成后,缓存就会被清除。
:::
+### 类型转换与校验
+
+在依赖注入系统中,我们可以对子依赖的返回值进行自动类型转换与校验。这个功能由 Pydantic 支持,因此我们通过参数类型注解自动使用 Pydantic 支持的类型转换。例如:
+
+
+
+
+```python {6,9}
+from typing import Annotated
+
+from nonebot.params import Depends
+from nonebot.adapters import Event
+
+def get_user_id(event: Event) -> str:
+ return event.get_user_id()
+
+async def _(user_id: Annotated[int, Depends(get_user_id, validate=True)]):
+ print(user_id)
+```
+
+
+
+
+```python {4,7}
+from nonebot.params import Depends
+from nonebot.adapters import Event
+
+def get_user_id(event: Event) -> str:
+ return event.get_user_id()
+
+async def _(user_id: int = Depends(get_user_id, validate=True)):
+ print(user_id)
+```
+
+
+
+
+在进行类型自动转换的同时,Pydantic 还支持对数据进行更多的限制,如:大于、小于、长度等。使用方法如下:
+
+
+
+
+```python {7,10}
+from typing import Annotated
+
+from pydantic import Field
+from nonebot.params import Depends
+from nonebot.adapters import Event
+
+def get_user_id(event: Event) -> str:
+ return event.get_user_id()
+
+async def _(user_id: Annotated[int, Depends(get_user_id, validate=Field(gt=100))]):
+ print(user_id)
+```
+
+
+
+
+```python {5,8}
+from pydantic import Field
+from nonebot.params import Depends
+from nonebot.adapters import Event
+
+def get_user_id(event: Event) -> str:
+ return event.get_user_id()
+
+async def _(user_id: int = Depends(get_user_id, validate=Field(gt=100))):
+ print(user_id)
+```
+
+
+
+
### 类作为依赖
在前面的事例中,我们使用了函数作为子依赖。实际上,我们还可以使用类作为依赖。当我们在实例化一个类的时候,其实我们就在调用它,类本身也是一个可调用对象。例如: