Feature: 添加 pydantic validator 兼容函数 (#3291)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Tarrailt 2025-01-31 23:48:22 +08:00 committed by GitHub
parent b16ddf380e
commit 3c616e758a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 114 additions and 1 deletions

View File

@ -18,6 +18,7 @@ from typing import (
Any,
Callable,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
@ -45,6 +46,7 @@ if TYPE_CHECKING:
__all__ = (
"DEFAULT_CONFIG",
"PYDANTIC_V2",
"ConfigDict",
"FieldInfo",
"ModelField",
@ -54,9 +56,11 @@ __all__ = (
"TypeAdapter",
"custom_validation",
"extract_field_info",
"field_validator",
"model_config",
"model_dump",
"model_fields",
"model_validator",
"type_validate_json",
"type_validate_python",
)
@ -70,6 +74,8 @@ __autodoc__ = {
if PYDANTIC_V2: # pragma: pydantic-v2
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
from pydantic import field_validator as field_validator
from pydantic import model_validator as model_validator
from pydantic._internal._repr import display_as_type
from pydantic.fields import FieldInfo as BaseFieldInfo
from pydantic_core import CoreSchema, core_schema
@ -254,7 +260,7 @@ if PYDANTIC_V2: # pragma: pydantic-v2
else: # pragma: pydantic-v1
from pydantic import BaseConfig as PydanticConfig
from pydantic import Extra, parse_obj_as, parse_raw_as
from pydantic import Extra, parse_obj_as, parse_raw_as, root_validator, validator
from pydantic.fields import FieldInfo as BaseFieldInfo
from pydantic.fields import ModelField as BaseModelField
from pydantic.schema import get_annotation_from_field_info
@ -367,6 +373,44 @@ else: # pragma: pydantic-v1
kwargs.update(field_info.extra)
return kwargs
@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["before"],
check_fields: Optional[bool] = None,
): ...
@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["after"] = ...,
check_fields: Optional[bool] = None,
): ...
def field_validator(
field: str,
/,
*fields: str,
mode: Literal["before", "after"] = "after",
check_fields: Optional[bool] = None,
):
if mode == "before":
return validator(
field,
*fields,
pre=True,
check_fields=check_fields or True,
allow_reuse=True,
)
else:
return validator(
field, *fields, check_fields=check_fields or True, allow_reuse=True
)
def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model."""
@ -404,6 +448,18 @@ else: # pragma: pydantic-v1
exclude_none=exclude_none,
)
@overload
def model_validator(*, mode: Literal["before"]): ...
@overload
def model_validator(*, mode: Literal["after"]): ...
def model_validator(*, mode: Literal["before", "after"]):
if mode == "before":
return root_validator(pre=True, allow_reuse=True)
else:
return root_validator(skip_on_failure=True, allow_reuse=True)
def type_validate_python(type_: type[T], data: Any) -> T:
"""Validate data with given type."""
return parse_obj_as(type_, data)

View File

@ -11,7 +11,9 @@ from nonebot.compat import (
Required,
TypeAdapter,
custom_validation,
field_validator,
model_dump,
model_validator,
type_validate_json,
type_validate_python,
)
@ -30,6 +32,32 @@ def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test"
def test_field_validator():
class TestModel(BaseModel):
foo: int
bar: str
@field_validator("foo")
@classmethod
def test_validator(cls, v: Any) -> Any:
if v > 0:
return v
raise ValueError("test must be greater than 0")
@field_validator("bar", mode="before")
@classmethod
def test_validator_before(cls, v: Any) -> Any:
if not isinstance(v, str):
v = str(v)
return v
assert type_validate_python(TestModel, {"foo": 1, "bar": "test"}).foo == 1
assert type_validate_python(TestModel, {"foo": 1, "bar": 123}).bar == "123"
with pytest.raises(ValidationError):
TestModel(foo=0, bar="test")
def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])
@ -53,6 +81,35 @@ def test_model_dump():
assert model_dump(TestModel(test1=1, test2=2), exclude={"test1"}) == {"test2": 2}
def test_model_validator():
class TestModel(BaseModel):
foo: int
bar: str
@model_validator(mode="before")
@classmethod
def test_validator_before(cls, data: Any) -> Any:
if isinstance(data, dict):
if "foo" not in data:
data["foo"] = 1
return data
@model_validator(mode="after")
@classmethod
def test_validator_after(cls, data: Any) -> Any:
if isinstance(data, dict):
if data["bar"] == "test":
raise ValueError("bar should not be test")
elif data.bar == "test":
raise ValueError("bar should not be test")
return data
assert type_validate_python(TestModel, {"bar": "aaa"}).foo == 1
with pytest.raises(ValidationError):
type_validate_python(TestModel, {"foo": 1, "bar": "test"})
def test_custom_validation():
called = []