"""Pydantic version compatibility layer.

To stay compatible with both Pydantic V1 and V2, this module defines a set of
compatibility functions and classes for the rest of the package to use.

FrontMatter:
    mdx:
        format: md
    sidebar_position: 16
    description: nonebot.compat 模块
"""
2024-04-16 00:33:48 +08:00
from collections . abc import Generator
2024-01-26 11:12:57 +08:00
from dataclasses import dataclass , is_dataclass
2024-12-01 12:31:11 +08:00
from functools import cached_property
2024-01-26 11:12:57 +08:00
from typing import (
TYPE_CHECKING ,
2024-12-01 12:31:11 +08:00
Annotated ,
2024-01-26 11:12:57 +08:00
Any ,
Callable ,
2024-12-01 12:31:11 +08:00
Generic ,
2025-01-31 23:48:22 +08:00
Literal ,
2024-01-26 11:12:57 +08:00
Optional ,
Protocol ,
2024-12-01 12:31:11 +08:00
TypeVar ,
Union ,
2024-08-11 15:15:59 +08:00
overload ,
2024-01-26 11:12:57 +08:00
)
2024-12-01 12:31:11 +08:00
from typing_extensions import Self , get_args , get_origin , is_typeddict
2024-01-26 11:12:57 +08:00
from pydantic import VERSION , BaseModel
2024-02-06 12:48:23 +08:00
from nonebot . typing import origin_is_annotated
2024-01-26 11:12:57 +08:00
T = TypeVar("T")

# Only the major component of the installed pydantic version matters here,
# e.g. "2.5.3" -> "2".
PYDANTIC_V2 = int(VERSION.partition(".")[0]) == 2
if TYPE_CHECKING:

    class _CustomValidationClass(Protocol):
        """Structural type: a class providing pydantic v1 style validators."""

        @classmethod
        def __get_validators__(cls) -> Generator[Callable[..., Any], None, None]:
            """Yield the validator callables to apply, in order."""

    # Type variable for classes accepted by ``custom_validation``.
    CVC = TypeVar("CVC", bound=_CustomValidationClass)
# Public API, kept in isort-style order: constants, classes, then functions.
__all__ = (
    # constants
    "DEFAULT_CONFIG",
    "PYDANTIC_V2",
    # classes / types
    "ConfigDict",
    "FieldInfo",
    "ModelField",
    "PydanticUndefined",
    "PydanticUndefinedType",
    "Required",
    "TypeAdapter",
    # functions / decorators
    "custom_validation",
    "extract_field_info",
    "field_validator",
    "model_config",
    "model_dump",
    "model_fields",
    "model_validator",
    "type_validate_json",
    "type_validate_python",
)

# Descriptions for re-exported names that carry no docstring of their own
# (presumably read by the project's API doc generator).
__autodoc__ = {
    "PydanticUndefined": "Pydantic Undefined object",
    "PydanticUndefinedType": "Pydantic Undefined type",
}
if PYDANTIC_V2 : # pragma: pydantic-v2
2024-08-11 15:15:59 +08:00
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
2025-01-31 23:48:22 +08:00
from pydantic import field_validator as field_validator
from pydantic import model_validator as model_validator
2024-01-26 11:12:57 +08:00
from pydantic . _internal . _repr import display_as_type
from pydantic . fields import FieldInfo as BaseFieldInfo
2024-12-01 12:31:11 +08:00
from pydantic_core import CoreSchema , core_schema
2024-01-26 11:12:57 +08:00
# The very same object as ``Ellipsis``; code written for pydantic v1 can keep
# spelling a required default as ``Required``.
Required = ...
"""Alias of Ellipsis, kept for compatibility with pydantic v1."""
# Export undefined type
from pydantic_core import PydanticUndefined as PydanticUndefined
from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
# isort: split
# Export model config dict
from pydantic import ConfigDict as ConfigDict
DEFAULT_CONFIG = ConfigDict(
    extra="allow",
    arbitrary_types_allowed=True,
)
"""Default config for validations: extra fields and arbitrary types allowed."""
class FieldInfo(BaseFieldInfo):
    """pydantic v2 FieldInfo that also exposes the v1-style ``extra`` property."""

    # Unlike the base class, this accepts ``default`` as a positional argument.
    def __init__(self, default: Any = PydanticUndefined, **kwargs: Any) -> None:
        super().__init__(default=default, **kwargs)

    @property
    def extra(self) -> dict[str, Any]:
        """Attributes set on this FieldInfo that are not standard pydantic fields.

        Provided for compatibility with pydantic v1's ``FieldInfo.extra``.
        """
        # Resolve the base-class slots up front: zero-argument ``super()``
        # cannot be used inside a comprehension on CPython before 3.12, where
        # comprehensions are not inlined (PEP 709).
        known_slots = super().__slots__
        extras: dict[str, Any] = {}
        for attr_name, attr_value in self._attributes_set.items():
            if attr_name not in known_slots:
                extras[attr_name] = attr_value
        return extras
@dataclass
class ModelField:
    """Stand-in for pydantic v1's ModelField when running under pydantic v2."""

    name: str
    """The name of the field."""
    annotation: Any
    """The annotation of the field."""
    field_info: FieldInfo
    """The FieldInfo of the field."""

    @classmethod
    def _construct(cls, name: str, annotation: Any, field_info: FieldInfo) -> Self:
        return cls(name, annotation, field_info)

    @classmethod
    def construct(
        cls, name: str, annotation: Any, field_info: Optional[FieldInfo] = None
    ) -> Self:
        """Build a ModelField, using an empty FieldInfo if none is supplied."""
        info = field_info if field_info else FieldInfo()
        return cls._construct(name, annotation, info)

    def __hash__(self) -> int:
        # Identity-based hashing: every instance counts as distinct so fields
        # can be kept in sets, even though dataclass ``__eq__`` compares values.
        return id(self)

    @cached_property
    def type_adapter(self) -> TypeAdapter:
        """TypeAdapter for this field, created on first access and then cached.

        Building a TypeAdapter costs a lot of CPU time in pydantic v2, see:
        https://github.com/pydantic/pydantic/issues/9834
        """
        # TypeAdapter raises when the annotated type already carries its own
        # config and another one is passed, so only pass ours otherwise.
        adapter_config = None if self._annotation_has_config() else DEFAULT_CONFIG
        return TypeAdapter(
            Annotated[self.annotation, self.field_info],
            config=adapter_config,
        )

    def _annotation_has_config(self) -> bool:
        """Tell whether the (possibly ``Annotated``) type defines its own config.

        BaseModel subclasses, dataclasses and TypedDicts are treated as such.
        """
        target = self.annotation
        if origin_is_annotated(get_origin(target)):
            target = get_args(target)[0]
        try:
            if issubclass(target, BaseModel):
                return True
            return is_dataclass(target) or is_typeddict(target)
        except TypeError:
            # ``issubclass`` raises for annotations that are not classes.
            return False

    def get_default(self) -> Any:
        """Return the field default, invoking ``default_factory`` if there is one."""
        return self.field_info.get_default(call_default_factory=True)

    def _type_display(self):
        """Return a human-readable rendering of the field's annotation."""
        return display_as_type(self.annotation)

    def validate_value(self, value: Any) -> Any:
        """Validate ``value`` against the field's type and return the result."""
        adapter = self.type_adapter
        return adapter.validate_python(value)
2024-01-26 11:12:57 +08:00
2024-04-16 00:33:48 +08:00
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
    """Return keyword arguments that rebuild ``field_info`` via ``FieldInfo(**kw)``.

    Contains every explicitly set attribute plus the rebuilt ``annotation``.
    """
    return {
        **field_info._attributes_set,
        "annotation": field_info.rebuild_annotation(),
    }
2024-04-16 00:33:48 +08:00
def model_fields(model: type[BaseModel]) -> list[ModelField]:
    """Return the fields of ``model`` as compat ``ModelField`` objects."""
    collected: list[ModelField] = []
    for field_name, info in model.model_fields.items():
        collected.append(
            ModelField._construct(
                name=field_name,
                annotation=info.rebuild_annotation(),
                field_info=FieldInfo(**extract_field_info(info)),
            )
        )
    return collected
2024-04-16 00:33:48 +08:00
def model_config(model: type[BaseModel]) -> Any:
    """Return the configuration of ``model`` (its ``model_config`` in v2)."""
    config = model.model_config
    return config
2024-02-05 14:00:49 +08:00
def model_dump(
    model: BaseModel,
    include: Optional[set[str]] = None,
    exclude: Optional[set[str]] = None,
    by_alias: bool = False,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
) -> dict[str, Any]:
    """Serialize ``model`` to a dict using pydantic v2's ``model_dump``.

    All keyword options are forwarded unchanged.
    """
    options = {
        "include": include,
        "exclude": exclude,
        "by_alias": by_alias,
        "exclude_unset": exclude_unset,
        "exclude_defaults": exclude_defaults,
        "exclude_none": exclude_none,
    }
    return model.model_dump(**options)
2024-02-05 14:00:49 +08:00
2024-04-16 00:33:48 +08:00
def type_validate_python(type_: type[T], data: Any) -> T:
    """Validate a Python object against ``type_`` and return the validated value."""
    adapter = TypeAdapter(type_)
    return adapter.validate_python(data)
2024-04-16 00:33:48 +08:00
def type_validate_json(type_: type[T], data: Union[str, bytes]) -> T:
    """Parse JSON ``data`` and validate it against ``type_``."""
    adapter = TypeAdapter(type_)
    return adapter.validate_json(data)
2024-01-26 11:12:57 +08:00
def __get_pydantic_core_schema__(
    cls: type["_CustomValidationClass"],
    source_type: Any,
    handler: GetCoreSchemaHandler,
) -> CoreSchema:
    """Build a pydantic v2 core schema from a v1-style ``__get_validators__``.

    Each yielded validator becomes a plain (no-info) validator function. With
    several validators they are chained, each receiving the previous output.
    ``source_type`` and ``handler`` belong to the hook signature but are unused.
    """
    validators = list(cls.__get_validators__())
    if not validators:
        # No validators means "accept as-is" (as in pydantic v1); an empty
        # ``chain_schema`` would instead be rejected by pydantic-core, which
        # requires at least one step in a chain.
        return core_schema.any_schema()
    if len(validators) == 1:
        return core_schema.no_info_plain_validator_function(validators[0])
    return core_schema.chain_schema(
        [core_schema.no_info_plain_validator_function(func) for func in validators]
    )
2024-04-16 00:33:48 +08:00
def custom_validation(class_: type["CVC"]) -> type["CVC"]:
    """Make a class with v1-style ``__get_validators__`` usable under pydantic v2.

    Attaches a ``__get_pydantic_core_schema__`` classmethod built from those
    validators and returns the same class, so it works as a class decorator.
    """
    schema_hook = classmethod(__get_pydantic_core_schema__)
    setattr(class_, "__get_pydantic_core_schema__", schema_hook)
    return class_
else : # pragma: pydantic-v1
from pydantic import BaseConfig as PydanticConfig
2025-01-31 23:48:22 +08:00
from pydantic import Extra , parse_obj_as , parse_raw_as , root_validator , validator
2024-01-26 11:12:57 +08:00
from pydantic . fields import FieldInfo as BaseFieldInfo
from pydantic . fields import ModelField as BaseModelField
from pydantic . schema import get_annotation_from_field_info
# isort: split
from pydantic . fields import Required as Required
# isort: split
from pydantic . fields import Undefined as PydanticUndefined
from pydantic . fields import UndefinedType as PydanticUndefinedType
class ConfigDict(PydanticConfig):
    """pydantic v1 config class with a dict-like ``get`` accessor."""

    @classmethod
    def get(cls, field: str, default: Any = None) -> Any:
        """Return config option ``field``, or ``default`` when it is not defined."""
        return getattr(cls, field, default)
class DEFAULT_CONFIG(ConfigDict):
    """Default config for validations: extra fields and arbitrary types allowed."""

    extra = Extra.allow
    arbitrary_types_allowed = True
class FieldInfo(BaseFieldInfo):
    """pydantic v1 FieldInfo that normalizes the ``Required`` marker."""

    def __init__(self, default: Any = PydanticUndefined, **kwargs: Any):
        # Map the v1 ``Required`` marker onto ``PydanticUndefined`` so required
        # fields are represented the same way as under pydantic v2.
        resolved_default = PydanticUndefined if default is Required else default
        super().__init__(resolved_default, **kwargs)
class ModelField(BaseModelField):
    """pydantic v1 ModelField extended with the constructors of the compat API."""

    @classmethod
    def _construct(cls, name: str, annotation: Any, field_info: FieldInfo) -> Self:
        # A field is required when it has neither a default nor a factory.
        is_required = (
            field_info.default is PydanticUndefined
            and field_info.default_factory is None
        )
        return cls(
            name=name,
            type_=annotation,
            class_validators=None,
            model_config=DEFAULT_CONFIG,
            default=field_info.default,
            default_factory=field_info.default_factory,
            required=is_required,
            field_info=field_info,
        )

    @classmethod
    def construct(
        cls, name: str, annotation: Any, field_info: Optional[FieldInfo] = None
    ) -> Self:
        """Build a ModelField from the given infos.

        When ``field_info`` is supplied, the annotation is first combined with
        it through pydantic's ``get_annotation_from_field_info``.
        """
        if field_info is not None:
            annotation = get_annotation_from_field_info(annotation, field_info, name)
        info = field_info if field_info else FieldInfo()
        return cls._construct(name, annotation, info)

    def validate_value(self, value: Any) -> Any:
        """Validate ``value`` for this field and return the validated result.

        Raises ``ValueError(value, field)`` when pydantic reports any error.
        """
        validated, errors = self.validate(value, {}, loc=())
        if not errors:
            return validated
        raise ValueError(value, self)
class TypeAdapter(Generic[T]):
    """Minimal pydantic v1 stand-in for pydantic v2's ``TypeAdapter``.

    ``config`` is stored on the instance but not used by the validate methods.
    """

    @overload
    def __init__(
        self, type: type[T], *, config: Optional[ConfigDict] = ...
    ) -> None: ...

    @overload
    def __init__(self, type: Any, *, config: Optional[ConfigDict] = ...) -> None: ...

    def __init__(self, type: Any, *, config: Optional[ConfigDict] = None) -> None:
        self.type = type
        self.config = config

    def validate_python(self, value: Any) -> T:
        """Validate a Python object against the wrapped type."""
        return type_validate_python(self.type, value)

    def validate_json(self, value: Union[str, bytes]) -> T:
        """Parse JSON ``value`` and validate it against the wrapped type."""
        return type_validate_json(self.type, value)
2024-04-16 00:33:48 +08:00
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
    """Return keyword arguments that rebuild ``field_info`` via ``FieldInfo(**kw)``.

    Every slot except ``extra`` is copied; the ``extra`` entries are then
    merged on top, so they win on a name clash.
    """
    slot_values = {
        slot: getattr(field_info, slot)
        for slot in field_info.__slots__
        if slot != "extra"
    }
    return {**slot_values, **field_info.extra}
2025-01-31 23:48:22 +08:00
@overload
def field_validator(
    field: str,
    /,
    *fields: str,
    mode: Literal["before"],
    check_fields: Optional[bool] = None,
): ...


@overload
def field_validator(
    field: str,
    /,
    *fields: str,
    mode: Literal["after"] = ...,
    check_fields: Optional[bool] = None,
): ...


def field_validator(
    field: str,
    /,
    *fields: str,
    mode: Literal["before", "after"] = "after",
    check_fields: Optional[bool] = None,
):
    """pydantic v1 implementation of v2's ``field_validator`` decorator.

    ``mode="before"`` becomes a ``pre=True`` validator; any other mode runs
    after the standard validation. ``check_fields=None`` selects pydantic's
    default (``True``); an explicit ``False`` disables the field-existence
    check. Validators are registered with ``allow_reuse=True``.
    """
    return validator(
        field,
        *fields,
        pre=mode == "before",
        # ``check_fields or True`` would silently turn an explicit False into
        # True, so only substitute the default when no value was given.
        check_fields=True if check_fields is None else check_fields,
        allow_reuse=True,
    )
2024-04-16 00:33:48 +08:00
def model_fields(model: type[BaseModel]) -> list[ModelField]:
    """Return the fields of ``model`` as compat ``ModelField`` objects."""
    collected: list[ModelField] = []
    # Build via ``_construct`` to skip the annotation preprocessing that
    # ``construct`` performs, which would raise errors here.
    for v1_field in model.__fields__.values():
        info = FieldInfo(**extract_field_info(v1_field.field_info))
        collected.append(
            ModelField._construct(
                name=v1_field.name,
                annotation=v1_field.annotation,
                field_info=info,
            )
        )
    return collected
2024-04-16 00:33:48 +08:00
def model_config(model: type[BaseModel]) -> Any:
    """Return the configuration of ``model`` (its ``__config__`` class in v1)."""
    config = model.__config__
    return config
2024-02-05 14:00:49 +08:00
def model_dump(
    model: BaseModel,
    include: Optional[set[str]] = None,
    exclude: Optional[set[str]] = None,
    by_alias: bool = False,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
) -> dict[str, Any]:
    """Serialize ``model`` to a dict using pydantic v1's ``BaseModel.dict``.

    All keyword options are forwarded unchanged.
    """
    options = {
        "include": include,
        "exclude": exclude,
        "by_alias": by_alias,
        "exclude_unset": exclude_unset,
        "exclude_defaults": exclude_defaults,
        "exclude_none": exclude_none,
    }
    return model.dict(**options)
2024-02-05 14:00:49 +08:00
2025-01-31 23:48:22 +08:00
@overload
def model_validator(*, mode: Literal["before"]): ...


@overload
def model_validator(*, mode: Literal["after"]): ...


def model_validator(*, mode: Literal["before", "after"]):
    """pydantic v1 implementation of v2's ``model_validator`` decorator.

    ``"before"`` yields a pre root validator; any other mode yields a post
    root validator created with ``skip_on_failure=True``.
    """
    is_before = mode == "before"
    return (
        root_validator(pre=True, allow_reuse=True)
        if is_before
        else root_validator(skip_on_failure=True, allow_reuse=True)
    )
2024-04-16 00:33:48 +08:00
def type_validate_python(type_: type[T], data: Any) -> T:
    """Validate a Python object against ``type_`` using v1's ``parse_obj_as``."""
    validated = parse_obj_as(type_, data)
    return validated
2024-04-16 00:33:48 +08:00
def type_validate_json(type_: type[T], data: Union[str, bytes]) -> T:
    """Parse JSON ``data`` and validate it against ``type_`` via ``parse_raw_as``."""
    validated = parse_raw_as(type_, data)
    return validated
2024-04-16 00:33:48 +08:00
def custom_validation(class_: type["CVC"]) -> type["CVC"]:
    """Return ``class_`` unchanged.

    Under pydantic v1 ``__get_validators__`` needs no adaptation; this no-op
    keeps the decorator usable regardless of the installed pydantic version.
    """
    return class_