mirror of
https://github.com/nonebot/nonebot2.git
synced 2024-12-01 01:25:07 +08:00
♿ reuse type check code for dependent
This commit is contained in:
parent
1271a757c9
commit
ad712c59b3
@ -17,10 +17,10 @@ from nonebot.log import logger
|
|||||||
from nonebot.exception import TypeMisMatch
|
from nonebot.exception import TypeMisMatch
|
||||||
from nonebot.utils import run_sync, is_coroutine_callable
|
from nonebot.utils import run_sync, is_coroutine_callable
|
||||||
|
|
||||||
from .utils import get_typed_signature
|
from .utils import check_field_type, get_typed_signature
|
||||||
|
|
||||||
T = TypeVar("T", bound="Dependent")
|
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
T = TypeVar("T", bound="Dependent")
|
||||||
|
|
||||||
|
|
||||||
class Param(abc.ABC, FieldInfo):
|
class Param(abc.ABC, FieldInfo):
|
||||||
@ -196,16 +196,16 @@ class Dependent(Generic[R]):
|
|||||||
value = await field_info._solve(**params)
|
value = await field_info._solve(**params)
|
||||||
if value == Undefined:
|
if value == Undefined:
|
||||||
value = field.get_default()
|
value = field.get_default()
|
||||||
_, errs_ = field.validate(value, values, loc=(str(field_info), field.alias))
|
|
||||||
if errs_:
|
try:
|
||||||
|
values[field.name] = check_field_type(field, value)
|
||||||
|
except TypeMisMatch:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{field_info} "
|
f"{field_info} "
|
||||||
f"type {type(value)} not match depends {self.call} "
|
f"type {type(value)} not match depends {self.call} "
|
||||||
f"annotation {field._type_display()}, ignored"
|
f"annotation {field._type_display()}, ignored"
|
||||||
)
|
)
|
||||||
raise TypeMisMatch(field, value)
|
raise
|
||||||
else:
|
|
||||||
values[field.name] = value
|
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@ -4,11 +4,16 @@ FrontMatter:
|
|||||||
description: nonebot.dependencies.utils 模块
|
description: nonebot.dependencies.utils 模块
|
||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Dict, Callable
|
from typing import Any, Dict, TypeVar, Callable
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic.fields import ModelField
|
||||||
from pydantic.typing import ForwardRef, evaluate_forwardref
|
from pydantic.typing import ForwardRef, evaluate_forwardref
|
||||||
|
|
||||||
|
from nonebot.exception import TypeMisMatch
|
||||||
|
|
||||||
|
V = TypeVar("V")
|
||||||
|
|
||||||
|
|
||||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||||
"""获取可调用对象签名"""
|
"""获取可调用对象签名"""
|
||||||
@ -40,3 +45,10 @@ def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) ->
|
|||||||
)
|
)
|
||||||
return inspect.Parameter.empty
|
return inspect.Parameter.empty
|
||||||
return annotation
|
return annotation
|
||||||
|
|
||||||
|
|
||||||
|
def check_field_type(field: ModelField, value: V) -> V:
|
||||||
|
_, errs_ = field.validate(value, {}, loc=())
|
||||||
|
if errs_:
|
||||||
|
raise TypeMisMatch(field, value)
|
||||||
|
return value
|
||||||
|
@ -17,6 +17,7 @@ from pydantic.fields import Required, Undefined, ModelField
|
|||||||
from nonebot.log import logger
|
from nonebot.log import logger
|
||||||
from nonebot.exception import TypeMisMatch
|
from nonebot.exception import TypeMisMatch
|
||||||
from nonebot.adapters import Bot, Event, Message
|
from nonebot.adapters import Bot, Event, Message
|
||||||
|
from nonebot.dependencies.utils import check_field_type
|
||||||
from nonebot.dependencies import Param, Dependent, CustomConfig
|
from nonebot.dependencies import Param, Dependent, CustomConfig
|
||||||
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
from nonebot.typing import T_State, T_Handler, T_DependencyCache
|
||||||
from nonebot.consts import (
|
from nonebot.consts import (
|
||||||
@ -174,13 +175,14 @@ class DependParam(Param):
|
|||||||
class _BotChecker(Param):
|
class _BotChecker(Param):
|
||||||
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
|
async def _solve(self, bot: Bot, **kwargs: Any) -> Any:
|
||||||
field: ModelField = self.extra["field"]
|
field: ModelField = self.extra["field"]
|
||||||
_, errs_ = field.validate(bot, {}, loc=("bot",))
|
try:
|
||||||
if errs_:
|
return check_field_type(field, bot)
|
||||||
|
except TypeMisMatch:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Bot type {type(bot)} not match "
|
f"Bot type {type(bot)} not match "
|
||||||
f"annotation {field._type_display()}, ignored"
|
f"annotation {field._type_display()}, ignored"
|
||||||
)
|
)
|
||||||
raise TypeMisMatch(field, bot)
|
raise
|
||||||
|
|
||||||
|
|
||||||
class BotParam(Param):
|
class BotParam(Param):
|
||||||
@ -217,13 +219,14 @@ class BotParam(Param):
|
|||||||
class _EventChecker(Param):
|
class _EventChecker(Param):
|
||||||
async def _solve(self, event: Event, **kwargs: Any) -> Any:
|
async def _solve(self, event: Event, **kwargs: Any) -> Any:
|
||||||
field: ModelField = self.extra["field"]
|
field: ModelField = self.extra["field"]
|
||||||
_, errs_ = field.validate(event, {}, loc=("event",))
|
try:
|
||||||
if errs_:
|
return check_field_type(field, event)
|
||||||
|
except TypeMisMatch:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Event type {type(event)} not match "
|
f"Event type {type(event)} not match "
|
||||||
f"annotation {field._type_display()}, ignored"
|
f"annotation {field._type_display()}, ignored"
|
||||||
)
|
)
|
||||||
raise TypeMisMatch(field, event)
|
raise
|
||||||
|
|
||||||
|
|
||||||
class EventParam(Param):
|
class EventParam(Param):
|
||||||
|
Loading…
Reference in New Issue
Block a user