reuse type check code for dependent

This commit is contained in:
yanyongyu 2022-01-28 14:49:04 +08:00
parent 1271a757c9
commit ad712c59b3
No known key found for this signature in database
GPG Key ID: 796D8A7FB73396EB
3 changed files with 29 additions and 14 deletions

View File

@ -17,10 +17,10 @@ from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.utils import run_sync, is_coroutine_callable
from .utils import get_typed_signature
from .utils import check_field_type, get_typed_signature
T = TypeVar("T", bound="Dependent")
R = TypeVar("R")
T = TypeVar("T", bound="Dependent")
class Param(abc.ABC, FieldInfo):
@ -196,16 +196,16 @@ class Dependent(Generic[R]):
value = await field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
if errs_:
try:
values[field.name] = check_field_type(field, value)
except TypeMisMatch:
logger.debug(
f"{field_info} "
f"type {type(value)} not match depends {self.call} "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, value)
else:
values[field.name] = value
raise
return values

View File

@ -4,11 +4,16 @@ FrontMatter:
description: nonebot.dependencies.utils 模块
"""
import inspect
from typing import Any, Dict, Callable
from typing import Any, Dict, TypeVar, Callable
from loguru import logger
from pydantic.fields import ModelField
from pydantic.typing import ForwardRef, evaluate_forwardref
from nonebot.exception import TypeMisMatch
V = TypeVar("V")
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""获取可调用对象签名"""
@ -40,3 +45,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
)
return inspect.Parameter.empty
return annotation
def check_field_type(field: ModelField, value: V) -> V:
_, errs_ = field.validate(value, {}, loc=())
if errs_:
raise TypeMisMatch(field, value)
return value

View File

@ -17,6 +17,7 @@ from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from nonebot.exception import TypeMisMatch
from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import (
@ -174,13 +175,14 @@ class DependParam(Param):
class _BotChecker(Param):
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(bot, {}, loc=("bot",))
if errs_:
try:
return check_field_type(field, bot)
except TypeMisMatch:
logger.debug(
f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, bot)
raise
class BotParam(Param):
@ -217,13 +219,14 @@ class BotParam(Param):
class _EventChecker(Param):
async def _solve(self, event: Event, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"]
_, errs_ = field.validate(event, {}, loc=("event",))
if errs_:
try:
return check_field_type(field, event)
except TypeMisMatch:
logger.debug(
f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored"
)
raise TypeMisMatch(field, event)
raise
class EventParam(Param):