mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-11-24 00:55:07 +08:00
⚡ Feature: 优化依赖注入在 pydantic v2 下的性能 (#2870)
This commit is contained in:
parent
aeb75a6ce3
commit
b59b1be6ff
@ -8,17 +8,20 @@ FrontMatter:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
from functools import cached_property
|
||||||
from dataclasses import dataclass, is_dataclass
|
from dataclasses import dataclass, is_dataclass
|
||||||
from typing_extensions import Self, get_args, get_origin, is_typeddict
|
from typing_extensions import Self, get_args, get_origin, is_typeddict
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
Union,
|
Union,
|
||||||
|
Generic,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Callable,
|
Callable,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
Annotated,
|
Annotated,
|
||||||
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import VERSION, BaseModel
|
from pydantic import VERSION, BaseModel
|
||||||
@ -46,8 +49,8 @@ __all__ = (
|
|||||||
"DEFAULT_CONFIG",
|
"DEFAULT_CONFIG",
|
||||||
"FieldInfo",
|
"FieldInfo",
|
||||||
"ModelField",
|
"ModelField",
|
||||||
|
"TypeAdapter",
|
||||||
"extract_field_info",
|
"extract_field_info",
|
||||||
"model_field_validate",
|
|
||||||
"model_fields",
|
"model_fields",
|
||||||
"model_config",
|
"model_config",
|
||||||
"model_dump",
|
"model_dump",
|
||||||
@ -63,9 +66,10 @@ __autodoc__ = {
|
|||||||
|
|
||||||
|
|
||||||
if PYDANTIC_V2: # pragma: pydantic-v2
|
if PYDANTIC_V2: # pragma: pydantic-v2
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic import TypeAdapter as TypeAdapter
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from pydantic._internal._repr import display_as_type
|
from pydantic._internal._repr import display_as_type
|
||||||
from pydantic import TypeAdapter, GetCoreSchemaHandler
|
|
||||||
from pydantic.fields import FieldInfo as BaseFieldInfo
|
from pydantic.fields import FieldInfo as BaseFieldInfo
|
||||||
|
|
||||||
Required = Ellipsis
|
Required = Ellipsis
|
||||||
@ -125,6 +129,25 @@ if PYDANTIC_V2: # pragma: pydantic-v2
|
|||||||
"""Construct a ModelField from given infos."""
|
"""Construct a ModelField from given infos."""
|
||||||
return cls._construct(name, annotation, field_info or FieldInfo())
|
return cls._construct(name, annotation, field_info or FieldInfo())
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
# Each ModelField is unique for our purposes,
|
||||||
|
# to allow store them in a set.
|
||||||
|
return id(self)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def type_adapter(self) -> TypeAdapter:
|
||||||
|
"""TypeAdapter of the field.
|
||||||
|
|
||||||
|
Cache the TypeAdapter to avoid creating it multiple times.
|
||||||
|
Pydantic v2 uses too much cpu time to create TypeAdapter.
|
||||||
|
|
||||||
|
See: https://github.com/pydantic/pydantic/issues/9834
|
||||||
|
"""
|
||||||
|
return TypeAdapter(
|
||||||
|
Annotated[self.annotation, self.field_info],
|
||||||
|
config=None if self._annotation_has_config() else DEFAULT_CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
def _annotation_has_config(self) -> bool:
|
def _annotation_has_config(self) -> bool:
|
||||||
"""Check if the annotation has config.
|
"""Check if the annotation has config.
|
||||||
|
|
||||||
@ -152,10 +175,9 @@ if PYDANTIC_V2: # pragma: pydantic-v2
|
|||||||
"""Get the display of the type of the field."""
|
"""Get the display of the type of the field."""
|
||||||
return display_as_type(self.annotation)
|
return display_as_type(self.annotation)
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def validate_value(self, value: Any) -> Any:
|
||||||
# Each ModelField is unique for our purposes,
|
"""Validate the value pass to the field."""
|
||||||
# to allow store them in a set.
|
return self.type_adapter.validate_python(value)
|
||||||
return id(self)
|
|
||||||
|
|
||||||
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
|
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
|
||||||
"""Get FieldInfo init kwargs from a FieldInfo instance."""
|
"""Get FieldInfo init kwargs from a FieldInfo instance."""
|
||||||
@ -164,15 +186,6 @@ if PYDANTIC_V2: # pragma: pydantic-v2
|
|||||||
kwargs["annotation"] = field_info.rebuild_annotation()
|
kwargs["annotation"] = field_info.rebuild_annotation()
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def model_field_validate(
|
|
||||||
model_field: ModelField, value: Any, config: Optional[ConfigDict] = None
|
|
||||||
) -> Any:
|
|
||||||
"""Validate the value pass to the field."""
|
|
||||||
type: Any = Annotated[model_field.annotation, model_field.field_info]
|
|
||||||
return TypeAdapter(
|
|
||||||
type, config=None if model_field._annotation_has_config() else config
|
|
||||||
).validate_python(value)
|
|
||||||
|
|
||||||
def model_fields(model: type[BaseModel]) -> list[ModelField]:
|
def model_fields(model: type[BaseModel]) -> list[ModelField]:
|
||||||
"""Get field list of a model."""
|
"""Get field list of a model."""
|
||||||
|
|
||||||
@ -305,6 +318,45 @@ else: # pragma: pydantic-v1
|
|||||||
)
|
)
|
||||||
return cls._construct(name, annotation, field_info or FieldInfo())
|
return cls._construct(name, annotation, field_info or FieldInfo())
|
||||||
|
|
||||||
|
def validate_value(self, value: Any) -> Any:
|
||||||
|
"""Validate the value pass to the field."""
|
||||||
|
v, errs_ = self.validate(value, {}, loc=())
|
||||||
|
if errs_:
|
||||||
|
raise ValueError(value, self)
|
||||||
|
return v
|
||||||
|
|
||||||
|
class TypeAdapter(Generic[T]):
|
||||||
|
@overload
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
type: type[T],
|
||||||
|
*,
|
||||||
|
config: Optional[ConfigDict] = ...,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
type: Any,
|
||||||
|
*,
|
||||||
|
config: Optional[ConfigDict] = ...,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
type: Any,
|
||||||
|
*,
|
||||||
|
config: Optional[ConfigDict] = None,
|
||||||
|
) -> None:
|
||||||
|
self.type = type
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def validate_python(self, value: Any) -> T:
|
||||||
|
return type_validate_python(self.type, value)
|
||||||
|
|
||||||
|
def validate_json(self, value: Union[str, bytes]) -> T:
|
||||||
|
return type_validate_json(self.type, value)
|
||||||
|
|
||||||
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
|
def extract_field_info(field_info: BaseFieldInfo) -> dict[str, Any]:
|
||||||
"""Get FieldInfo init kwargs from a FieldInfo instance."""
|
"""Get FieldInfo init kwargs from a FieldInfo instance."""
|
||||||
|
|
||||||
@ -314,22 +366,6 @@ else: # pragma: pydantic-v1
|
|||||||
kwargs.update(field_info.extra)
|
kwargs.update(field_info.extra)
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def model_field_validate(
|
|
||||||
model_field: ModelField, value: Any, config: Optional[type[ConfigDict]] = None
|
|
||||||
) -> Any:
|
|
||||||
"""Validate the value pass to the field.
|
|
||||||
|
|
||||||
Set config before validate to ensure validate correctly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if model_field.model_config is not config:
|
|
||||||
model_field.set_config(config or ConfigDict)
|
|
||||||
|
|
||||||
v, errs_ = model_field.validate(value, {}, loc=())
|
|
||||||
if errs_:
|
|
||||||
raise ValueError(value, model_field)
|
|
||||||
return v
|
|
||||||
|
|
||||||
def model_fields(model: type[BaseModel]) -> list[ModelField]:
|
def model_fields(model: type[BaseModel]) -> list[ModelField]:
|
||||||
"""Get field list of a model."""
|
"""Get field list of a model."""
|
||||||
|
|
||||||
|
@ -9,9 +9,9 @@ from typing import Any, Callable, ForwardRef
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nonebot.compat import ModelField
|
||||||
from nonebot.exception import TypeMisMatch
|
from nonebot.exception import TypeMisMatch
|
||||||
from nonebot.typing import evaluate_forwardref
|
from nonebot.typing import evaluate_forwardref
|
||||||
from nonebot.compat import DEFAULT_CONFIG, ModelField, model_field_validate
|
|
||||||
|
|
||||||
|
|
||||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
@ -51,6 +51,6 @@ def check_field_type(field: ModelField, value: Any) -> Any:
|
|||||||
"""检查字段类型是否匹配"""
|
"""检查字段类型是否匹配"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return model_field_validate(field, value, DEFAULT_CONFIG)
|
return field.validate_value(value)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise TypeMisMatch(field, value)
|
raise TypeMisMatch(field, value)
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
from typing import Any, Optional
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional, Annotated
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from nonebot.compat import (
|
from nonebot.compat import (
|
||||||
DEFAULT_CONFIG,
|
DEFAULT_CONFIG,
|
||||||
Required,
|
Required,
|
||||||
FieldInfo,
|
FieldInfo,
|
||||||
|
TypeAdapter,
|
||||||
PydanticUndefined,
|
PydanticUndefined,
|
||||||
model_dump,
|
model_dump,
|
||||||
custom_validation,
|
custom_validation,
|
||||||
@ -31,6 +32,21 @@ async def test_field_info():
|
|||||||
assert FieldInfo(test="test").extra["test"] == "test"
|
assert FieldInfo(test="test").extra["test"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_type_adapter():
|
||||||
|
t = TypeAdapter(Annotated[int, FieldInfo(ge=1)])
|
||||||
|
|
||||||
|
assert t.validate_python(2) == 2
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
t.validate_python(0)
|
||||||
|
|
||||||
|
assert t.validate_json("2") == 2
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
t.validate_json("0")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_model_dump():
|
async def test_model_dump():
|
||||||
class TestModel(BaseModel):
|
class TestModel(BaseModel):
|
||||||
|
Loading…
Reference in New Issue
Block a user