reuse type check code for dependent

This commit is contained in:
yanyongyu 2022-01-28 14:49:04 +08:00
parent 1271a757c9
commit ad712c59b3
No known key found for this signature in database
GPG Key ID: 796D8A7FB73396EB
3 changed files with 29 additions and 14 deletions

View File

@ -17,10 +17,10 @@ from nonebot.log import logger
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from nonebot.utils import run_sync, is_coroutine_callable from nonebot.utils import run_sync, is_coroutine_callable
from .utils import get_typed_signature from .utils import check_field_type, get_typed_signature
T = TypeVar("T", bound="Dependent")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T", bound="Dependent")
class Param(abc.ABC, FieldInfo): class Param(abc.ABC, FieldInfo):
@ -196,16 +196,16 @@ class Dependent(Generic[R]):
value = await field_info._solve(**params) value = await field_info._solve(**params)
if value == Undefined: if value == Undefined:
value = field.get_default() value = field.get_default()
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
if errs_: try:
values[field.name] = check_field_type(field, value)
except TypeMisMatch:
logger.debug( logger.debug(
f"{field_info} " f"{field_info} "
f"type {type(value)} not match depends {self.call} " f"type {type(value)} not match depends {self.call} "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise TypeMisMatch(field, value) raise
else:
values[field.name] = value
return values return values

View File

@ -4,11 +4,16 @@ FrontMatter:
description: nonebot.dependencies.utils 模块 description: nonebot.dependencies.utils 模块
""" """
import inspect import inspect
from typing import Any, Dict, Callable from typing import Any, Dict, TypeVar, Callable
from loguru import logger from loguru import logger
from pydantic.fields import ModelField
from pydantic.typing import ForwardRef, evaluate_forwardref from pydantic.typing import ForwardRef, evaluate_forwardref
from nonebot.exception import TypeMisMatch
V = TypeVar("V")
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
"""获取可调用对象签名""" """获取可调用对象签名"""
@ -40,3 +45,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
) )
return inspect.Parameter.empty return inspect.Parameter.empty
return annotation return annotation
def check_field_type(field: ModelField, value: V) -> V:
_, errs_ = field.validate(value, {}, loc=())
if errs_:
raise TypeMisMatch(field, value)
return value

View File

@ -17,6 +17,7 @@ from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger from nonebot.log import logger
from nonebot.exception import TypeMisMatch from nonebot.exception import TypeMisMatch
from nonebot.adapters import Bot, Event, Message from nonebot.adapters import Bot, Event, Message
from nonebot.dependencies.utils import check_field_type
from nonebot.dependencies import Param, Dependent, CustomConfig from nonebot.dependencies import Param, Dependent, CustomConfig
from nonebot.typing import T_State, T_Handler, T_DependencyCache from nonebot.typing import T_State, T_Handler, T_DependencyCache
from nonebot.consts import ( from nonebot.consts import (
@ -174,13 +175,14 @@ class DependParam(Param):
class _BotChecker(Param): class _BotChecker(Param):
async def _solve(self, bot: Bot, **kwargs: Any) -> Any: async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"] field: ModelField = self.extra["field"]
_, errs_ = field.validate(bot, {}, loc=("bot",)) try:
if errs_: return check_field_type(field, bot)
except TypeMisMatch:
logger.debug( logger.debug(
f"Bot type {type(bot)} not match " f"Bot type {type(bot)} not match "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise TypeMisMatch(field, bot) raise
class BotParam(Param): class BotParam(Param):
@ -217,13 +219,14 @@ class BotParam(Param):
class _EventChecker(Param): class _EventChecker(Param):
async def _solve(self, event: Event, **kwargs: Any) -> Any: async def _solve(self, event: Event, **kwargs: Any) -> Any:
field: ModelField = self.extra["field"] field: ModelField = self.extra["field"]
_, errs_ = field.validate(event, {}, loc=("event",)) try:
if errs_: return check_field_type(field, event)
except TypeMisMatch:
logger.debug( logger.debug(
f"Event type {type(event)} not match " f"Event type {type(event)} not match "
f"annotation {field._type_display()}, ignored" f"annotation {field._type_display()}, ignored"
) )
raise TypeMisMatch(field, event) raise
class EventParam(Param): class EventParam(Param):