Feature: 优化依赖注入在 pydantic v2 下的性能 (#2870)

This commit is contained in:
Ju4tCode 2024-08-11 15:15:59 +08:00 committed by GitHub
parent aeb75a6ce3
commit b59b1be6ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 87 additions and 35 deletions

View File

@ -8,17 +8,20 @@ FrontMatter:
""" """
from collections.abc import Generator from collections.abc import Generator
from functools import cached_property
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from typing_extensions import Self, get_args, get_origin, is_typeddict from typing_extensions import Self, get_args, get_origin, is_typeddict
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Union, Union,
Generic,
TypeVar, TypeVar,
Callable, Callable,
Optional, Optional,
Protocol, Protocol,
Annotated, Annotated,
overload,
) )
from pydantic import VERSION, BaseModel from pydantic import VERSION, BaseModel
@ -46,8 +49,8 @@ __all__ = (
"DEFAULT_CONFIG", "DEFAULT_CONFIG",
"FieldInfo", "FieldInfo",
"ModelField", "ModelField",
"TypeAdapter",
"extract_field_info", "extract_field_info",
"model_field_validate",
"model_fields", "model_fields",
"model_config", "model_config",
"model_dump", "model_dump",
@ -63,9 +66,10 @@ __autodoc__ = {
if PYDANTIC_V2: # pragma: pydantic-v2 if PYDANTIC_V2: # pragma: pydantic-v2
from pydantic import GetCoreSchemaHandler
from pydantic import TypeAdapter as TypeAdapter
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema
from pydantic._internal._repr import display_as_type from pydantic._internal._repr import display_as_type
from pydantic import TypeAdapter, GetCoreSchemaHandler
from pydantic.fields import FieldInfo as BaseFieldInfo from pydantic.fields import FieldInfo as BaseFieldInfo
Required = Ellipsis Required = Ellipsis
@ -125,6 +129,25 @@ if PYDANTIC_V2: # pragma: pydantic-v2
"""Construct a ModelField from given infos.""" """Construct a ModelField from given infos."""
return cls._construct(name, annotation, field_info or FieldInfo()) return cls._construct(name, annotation, field_info or FieldInfo())
def __hash__(self) -> int:
# Each ModelField is unique for our purposes,
# to allow store them in a set.
return id(self)
@cached_property
def type_adapter(self) -> TypeAdapter:
"""TypeAdapter of the field.
Cache the TypeAdapter to avoid creating it multiple times.
Pydantic v2 uses too much cpu time to create TypeAdapter.
See: https://github.com/pydantic/pydantic/issues/9834
"""
return TypeAdapter(
Annotated[self.annotation, self.field_info],
config=None if self._annotation_has_config() else DEFAULT_CONFIG,
)
def _annotation_has_config(self) -> bool: def _annotation_has_config(self) -> bool:
"""Check if the annotation has config. """Check if the annotation has config.
@ -152,10 +175,9 @@ if PYDANTIC_V2: # pragma: pydantic-v2
"""Get the display of the type of the field.""" """Get the display of the type of the field."""
return display_as_type(self.annotation) return display_as_type(self.annotation)
def __hash__(self) -> int: def validate_value(self, value: Any) -> Any:
# Each ModelField is unique for our purposes, """Validate the value pass to the field."""
# to allow store them in a set. return self.type_adapter.validate_python(value)
return id(self)
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
"""Get FieldInfo init kwargs from a FieldInfo instance.""" """Get FieldInfo init kwargs from a FieldInfo instance."""
@ -164,15 +186,6 @@ if PYDANTIC_V2: # pragma: pydantic-v2
kwargs["annotation"] = field_info.rebuild_annotation() kwargs["annotation"] = field_info.rebuild_annotation()
return kwargs return kwargs
def model_field_validate(
model_field: ModelField, value: Any, config: Optional[ConfigDict] = None
) -> Any:
"""Validate the value pass to the field."""
type: Any = Annotated[model_field.annotation, model_field.field_info]
return TypeAdapter(
type, config=None if model_field._annotation_has_config() else config
).validate_python(value)
def model_fields(model: type[BaseModel]) -> list[ModelField]: def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model.""" """Get field list of a model."""
@ -305,6 +318,45 @@ else: # pragma: pydantic-v1
) )
return cls._construct(name, annotation, field_info or FieldInfo()) return cls._construct(name, annotation, field_info or FieldInfo())
def validate_value(self, value: Any) -> Any:
"""Validate the value pass to the field."""
v, errs_ = self.validate(value, {}, loc=())
if errs_:
raise ValueError(value, self)
return v
class TypeAdapter(Generic[T]):
@overload
def __init__(
self,
type: type[T],
*,
config: Optional[ConfigDict] = ...,
) -> None: ...
@overload
def __init__(
self,
type: Any,
*,
config: Optional[ConfigDict] = ...,
) -> None: ...
def __init__(
self,
type: Any,
*,
config: Optional[ConfigDict] = None,
) -> None:
self.type = type
self.config = config
def validate_python(self, value: Any) -> T:
return type_validate_python(self.type, value)
def validate_json(self, value: Union[str, bytes]) -> T:
return type_validate_json(self.type, value)
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]: def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
"""Get FieldInfo init kwargs from a FieldInfo instance.""" """Get FieldInfo init kwargs from a FieldInfo instance."""
@ -314,22 +366,6 @@ else: # pragma: pydantic-v1
kwargs.update(field_info.extra) kwargs.update(field_info.extra)
return kwargs return kwargs
def model_field_validate(
model_field: ModelField, value: Any, config: Optional[type[ConfigDict]] = None
) -> Any:
"""Validate the value pass to the field.
Set config before validate to ensure validate correctly.
"""
if model_field.model_config is not config:
model_field.set_config(config or ConfigDict)
v, errs_ = model_field.validate(value, {}, loc=())
if errs_:
raise ValueError(value, model_field)
return v
def model_fields(model: type[BaseModel]) -> list[ModelField]: def model_fields(model: type[BaseModel]) -> list[ModelField]:
"""Get field list of a model.""" """Get field list of a model."""

View File

@ -9,9 +9,9 @@ from typing import Any, Callable, ForwardRef
from loguru import logger from loguru import logger
from nonebot.compat import ModelField
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from nonebot.typing import evaluate_forwardref from nonebot.typing import evaluate_forwardref
from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
@ -51,6 +51,6 @@ def check_field_type(field: ModelField, value: Any) -> Any:
"""检查字段类型是否匹配""" """检查字段类型是否匹配"""
try: try:
return model_field_validate(field, value, DEFAULT_CONFIG) return field.validate_value(value)
except ValueError: except ValueError:
raise TypeMisMatch(field, value) raise TypeMisMatch(field, value)

View File

@ -1,13 +1,14 @@
from typing import Any, Optional
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional, Annotated
import pytest import pytest
from pydantic import BaseModel from pydantic import BaseModel, ValidationError
from nonebot.compat import ( from nonebot.compat import (
DEFAULT_CONFIG, DEFAULT_CONFIG,
Required, Required,
FieldInfo, FieldInfo,
TypeAdapter,
PydanticUndefined, PydanticUndefined,
model_dump, model_dump,
custom_validation, custom_validation,
@ -31,6 +32,21 @@ async def test_field_info():
assert FieldInfo(test="test").extra["test"] == "test" assert FieldInfo(test="test").extra["test"] == "test"
@pytest.mark.asyncio
async def test_type_adapter():
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])
assert t.validate_python(2) == 2
with pytest.raises(ValidationError):
t.validate_python(0)
assert t.validate_json("2") == 2
with pytest.raises(ValidationError):
t.validate_json("0")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_model_dump(): async def test_model_dump():
class TestModel(BaseModel): class TestModel(BaseModel):