mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
4dae23d3bb
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ju4tCode <42488585+yanyongyu@users.noreply.github.com>
352 lines
12 KiB
Python
352 lines
12 KiB
Python
"""本模块为 Pydantic 版本兼容层模块
|
|
|
|
为兼容 Pydantic V1 与 V2 版本,定义了一系列兼容函数与类供使用。
|
|
|
|
FrontMatter:
|
|
sidebar_position: 16
|
|
description: nonebot.compat 模块
|
|
"""
|
|
|
|
from dataclasses import dataclass, is_dataclass
|
|
from typing_extensions import Self, Annotated, get_args, get_origin, is_typeddict
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Set,
|
|
Dict,
|
|
List,
|
|
Type,
|
|
TypeVar,
|
|
Callable,
|
|
Optional,
|
|
Protocol,
|
|
Generator,
|
|
)
|
|
|
|
from pydantic import VERSION, BaseModel
|
|
|
|
from nonebot.typing import origin_is_annotated
|
|
|
|
# Generic type variable; `type_validate_python` uses it to tie the requested
# type to the type of the validated result.
T = TypeVar("T")

# Only the leading (major) component of the installed pydantic version string
# is inspected: True selects the pydantic v2 implementations below, anything
# else falls through to the pydantic v1 implementations.
PYDANTIC_V2 = int(VERSION.partition(".")[0]) == 2
|
|
|
|
if TYPE_CHECKING:
    # These names exist for static type checkers only; at runtime they are
    # referenced exclusively through string annotations.

    class _CustomValidationClass(Protocol):
        """Structural type of a class exposing a `__get_validators__` classmethod."""

        @classmethod
        def __get_validators__(
            cls,
        ) -> Generator[Callable[..., Any], None, None]:
            ...

    # Type variable bounded by the protocol so `custom_validation` can return
    # the very same class type it was given.
    CVC = TypeVar("CVC", bound=_CustomValidationClass)
|
|
|
|
|
|
# Public names of this module. Every entry is provided by both the
# pydantic v2 branch and the pydantic v1 branch below.
__all__ = (
    # sentinels and configuration
    "Required",
    "PydanticUndefined",
    "PydanticUndefinedType",
    "ConfigDict",
    "DEFAULT_CONFIG",
    # field wrappers
    "FieldInfo",
    "ModelField",
    # helper functions
    "extract_field_info",
    "model_field_validate",
    "model_fields",
    "model_config",
    "model_dump",
    "type_validate_python",
    "custom_validation",
)

# Short descriptions for the re-exported pydantic objects.
# NOTE(review): presumably consumed by the documentation generator - confirm.
__autodoc__ = {
    "PydanticUndefined": "Pydantic Undefined object",
    "PydanticUndefinedType": "Pydantic Undefined type",
}
|
|
|
|
|
|
if PYDANTIC_V2: # pragma: pydantic-v2
|
|
from pydantic_core import CoreSchema, core_schema
|
|
from pydantic._internal._repr import display_as_type
|
|
from pydantic import TypeAdapter, GetCoreSchemaHandler
|
|
from pydantic.fields import FieldInfo as BaseFieldInfo
|
|
|
|
Required = ...
"""Alias of Ellipsis for compatibility with pydantic v1"""
|
|
|
|
# Export undefined type
|
|
from pydantic_core import PydanticUndefined as PydanticUndefined
|
|
from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
|
|
|
|
# isort: split
|
|
|
|
# Export model config dict
|
|
from pydantic import ConfigDict as ConfigDict
|
|
|
|
DEFAULT_CONFIG = ConfigDict(
    extra="allow",
    arbitrary_types_allowed=True,
)
"""Default config for validations"""
|
|
|
|
class FieldInfo(BaseFieldInfo):
    """pydantic v2 FieldInfo extended with a pydantic v1 style `extra` property."""

    def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
        """Create the field info.

        `default` may be given positionally; it is forwarded to the base
        class as the `default` keyword argument together with `kwargs`.
        """
        super().__init__(default=default, **kwargs)

    @property
    def extra(self) -> Dict[str, Any]:
        """Extra data that is not part of the standard pydantic fields.

        Returns every explicitly set attribute whose name is not one of the
        base class slots. Provided for compatibility with pydantic v1.
        """
        known_slots = super().__slots__
        extras: Dict[str, Any] = {}
        for key, value in self._attributes_set.items():
            if key not in known_slots:
                extras[key] = value
        return extras
|
|
|
|
@dataclass
class ModelField:
    """ModelField class for compatibility with pydantic v1"""

    name: str
    """The name of the field."""
    annotation: Any
    """The annotation of the field."""
    field_info: FieldInfo
    """The FieldInfo of the field."""

    @classmethod
    def _construct(cls, name: str, annotation: Any, field_info: FieldInfo) -> Self:
        """Build the field directly from already prepared values."""
        return cls(name=name, annotation=annotation, field_info=field_info)

    @classmethod
    def construct(
        cls, name: str, annotation: Any, field_info: Optional[FieldInfo] = None
    ) -> Self:
        """Construct a ModelField from given infos.

        A fresh `FieldInfo()` is used when `field_info` is not supplied.
        """
        info = field_info or FieldInfo()
        return cls._construct(name, annotation, info)

    def _annotation_has_config(self) -> bool:
        """Check if the annotation has config.

        TypeAdapter raise error when annotation has config
        and given config is not None.
        """
        inner_type = self.annotation
        # For `Annotated[X, ...]` only the wrapped type `X` is inspected.
        if origin_is_annotated(get_origin(inner_type)):
            inner_type = get_args(inner_type)[0]

        try:
            if issubclass(inner_type, BaseModel):
                return True
            if is_dataclass(inner_type):
                return True
            return is_typeddict(inner_type)
        except TypeError:
            # `issubclass` rejects objects that are not classes.
            return False

    def get_default(self) -> Any:
        """Get the default value of the field, calling the default factory."""
        info = self.field_info
        return info.get_default(call_default_factory=True)

    def _type_display(self):
        """Get the display of the type of the field."""
        annotation = self.annotation
        return display_as_type(annotation)

    def __hash__(self) -> int:
        # Hash by identity: each ModelField is unique for our purposes,
        # which allows storing them in a set.
        return id(self)
|
|
|
|
def extract_field_info(field_info: BaseFieldInfo) -> Dict[str, Any]:
    """Get FieldInfo init kwargs from a FieldInfo instance.

    The result is a copy of the explicitly set attributes, with the
    `annotation` entry replaced by the rebuilt annotation.
    """
    return {
        **field_info._attributes_set,
        "annotation": field_info.rebuild_annotation(),
    }
|
|
|
|
def model_field_validate(
    model_field: ModelField, value: Any, config: Optional[ConfigDict] = None
) -> Any:
    """Validate the value pass to the field.

    The field annotation is combined with its FieldInfo through `Annotated`
    and validated with a `TypeAdapter`. `config` is dropped when the
    annotation already carries its own config.
    """
    annotated_type: Any = Annotated[model_field.annotation, model_field.field_info]

    if model_field._annotation_has_config():
        adapter_config = None
    else:
        adapter_config = config

    adapter = TypeAdapter(annotated_type, config=adapter_config)
    return adapter.validate_python(value)
|
|
|
|
def model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Get field list of a model.

    Each pydantic field is wrapped into a compat `ModelField` with a
    compat `FieldInfo` rebuilt from the original field's init kwargs.
    """
    fields: List[ModelField] = []
    for name, field_info in model.model_fields.items():
        annotation = field_info.rebuild_annotation()
        compat_info = FieldInfo(**extract_field_info(field_info))
        fields.append(
            ModelField._construct(
                name=name,
                annotation=annotation,
                field_info=compat_info,
            )
        )
    return fields
|
|
|
|
def model_config(model: Type[BaseModel]) -> Any:
    """Get config of a model (its `model_config` attribute)."""
    config = model.model_config
    return config
|
|
|
|
def model_dump(
    model: BaseModel,
    include: Optional[Set[str]] = None,
    exclude: Optional[Set[str]] = None,
) -> Dict[str, Any]:
    """Dump a model instance to a dict.

    `include` and `exclude` are forwarded unchanged to `model.model_dump`.
    """
    dumped = model.model_dump(include=include, exclude=exclude)
    return dumped
|
|
|
|
def type_validate_python(type_: Type[T], data: Any) -> T:
    """Validate data with given type using a `TypeAdapter`."""
    adapter = TypeAdapter(type_)
    return adapter.validate_python(data)
|
|
|
|
def __get_pydantic_core_schema__(
    cls: Type["_CustomValidationClass"],
    source_type: Any,
    handler: GetCoreSchemaHandler,
) -> CoreSchema:
    """Build a core schema from the validators yielded by `__get_validators__`.

    Each validator becomes a plain validator schema. A single validator is
    returned on its own; otherwise the schemas are combined with
    `chain_schema` in the order they were yielded. `source_type` and
    `handler` are accepted but not used.
    """
    # Exhaust the generator first, then wrap every validator function.
    validators = list(cls.__get_validators__())
    make_plain = core_schema.no_info_plain_validator_function
    steps = [make_plain(validator) for validator in validators]

    if len(steps) == 1:
        return steps[0]
    return core_schema.chain_schema(steps)
|
|
|
|
def custom_validation(class_: Type["CVC"]) -> Type["CVC"]:
    """Use pydantic v1 like validator generator in pydantic v2

    Attaches the module-level `__get_pydantic_core_schema__` function to
    `class_` as a classmethod and returns the same class.
    """
    schema_hook = classmethod(__get_pydantic_core_schema__)
    setattr(class_, "__get_pydantic_core_schema__", schema_hook)
    return class_
|
|
|
|
else: # pragma: pydantic-v1
|
|
from pydantic import Extra
|
|
from pydantic import parse_obj_as
|
|
from pydantic import BaseConfig as PydanticConfig
|
|
from pydantic.fields import FieldInfo as BaseFieldInfo
|
|
from pydantic.fields import ModelField as BaseModelField
|
|
from pydantic.schema import get_annotation_from_field_info
|
|
|
|
# isort: split
|
|
|
|
from pydantic.fields import Required as Required
|
|
|
|
# isort: split
|
|
|
|
from pydantic.fields import Undefined as PydanticUndefined
|
|
from pydantic.fields import UndefinedType as PydanticUndefinedType
|
|
|
|
class ConfigDict(PydanticConfig):
    """Config class that allow get value with default value."""

    @classmethod
    def get(cls, field: str, default: Any = None) -> Any:
        """Get a config value.

        Returns the class attribute named `field`, or `default` when the
        attribute does not exist.
        """
        value = getattr(cls, field, default)
        return value
|
|
|
|
class DEFAULT_CONFIG(ConfigDict):
    # Default config for validations: extra fields and arbitrary types
    # are both allowed.
    arbitrary_types_allowed = True
    extra = Extra.allow
|
|
|
|
class FieldInfo(BaseFieldInfo):
    def __init__(self, default: Any = PydanticUndefined, **kwargs: Any):
        """Create the field info.

        For compatibility with pydantic v2, a `default` of `Required` is
        replaced by `PydanticUndefined` before reaching the base class.
        """
        normalized = PydanticUndefined if default is Required else default
        super().__init__(normalized, **kwargs)
|
|
|
|
class ModelField(BaseModelField):
    @classmethod
    def _construct(cls, name: str, annotation: Any, field_info: FieldInfo) -> Self:
        """Build the field directly, without preprocessing the annotation."""
        # The field is required only when neither a default value nor a
        # default factory is present on the field info.
        is_required = (
            field_info.default is PydanticUndefined
            and field_info.default_factory is None
        )
        return cls(
            name=name,
            type_=annotation,
            class_validators=None,
            model_config=DEFAULT_CONFIG,
            default=field_info.default,
            default_factory=field_info.default_factory,
            required=is_required,
            field_info=field_info,
        )

    @classmethod
    def construct(
        cls, name: str, annotation: Any, field_info: Optional[FieldInfo] = None
    ) -> Self:
        """Construct a ModelField from given infos.

        Field annotation is preprocessed with field_info when one is given;
        otherwise a fresh `FieldInfo()` is used.
        """
        processed = (
            annotation
            if field_info is None
            else get_annotation_from_field_info(annotation, field_info, name)
        )
        return cls._construct(name, processed, field_info or FieldInfo())
|
|
|
|
def extract_field_info(field_info: BaseFieldInfo) -> Dict[str, Any]:
    """Get FieldInfo init kwargs from a FieldInfo instance.

    Collects every slot value except the `extra` slot itself, then merges
    the contents of `field_info.extra` on top.
    """
    kwargs: Dict[str, Any] = {}
    for slot in field_info.__slots__:
        if slot != "extra":
            kwargs[slot] = getattr(field_info, slot)
    kwargs.update(field_info.extra)
    return kwargs
|
|
|
|
def model_field_validate(
    model_field: ModelField, value: Any, config: Optional[Type[ConfigDict]] = None
) -> Any:
    """Validate the value pass to the field.

    Set config before validate to ensure validate correctly: when the
    field's current config is not `config`, it is switched to `config`
    (or `ConfigDict` if `config` is None).

    Raises:
        ValueError: the field reported validation errors; the arguments
            are the value and the model field.
    """
    if model_field.model_config is not config:
        model_field.set_config(config or ConfigDict)

    validated, errors = model_field.validate(value, {}, loc=())
    if not errors:
        return validated
    raise ValueError(value, model_field)
|
|
|
|
def model_fields(model: Type[BaseModel]) -> List[ModelField]:
    """Get field list of a model.

    Fields are rebuilt through `ModelField._construct`, i.e. without
    annotation preprocessing, to avoid errors.
    """
    fields: List[ModelField] = []
    for original in model.__fields__.values():
        compat_info = FieldInfo(**extract_field_info(original.field_info))
        fields.append(
            ModelField._construct(
                name=original.name,
                annotation=original.annotation,
                field_info=compat_info,
            )
        )
    return fields
|
|
|
|
def model_config(model: Type[BaseModel]) -> Any:
    """Get config of a model (its `__config__` attribute)."""
    config = model.__config__
    return config
|
|
|
|
def model_dump(
    model: BaseModel,
    include: Optional[Set[str]] = None,
    exclude: Optional[Set[str]] = None,
) -> Dict[str, Any]:
    """Dump a model instance to a dict.

    `include` and `exclude` are forwarded unchanged to `model.dict`.
    """
    dumped = model.dict(include=include, exclude=exclude)
    return dumped
|
|
|
|
def type_validate_python(type_: Type[T], data: Any) -> T:
    """Validate data with given type using `parse_obj_as`."""
    validated = parse_obj_as(type_, data)
    return validated
|
|
|
|
def custom_validation(class_: Type["CVC"]) -> Type["CVC"]:
    """Do nothing in pydantic v1

    The class is returned unchanged.
    """
    unchanged = class_
    return unchanged
|