allow extra param with default value

This commit is contained in:
yanyongyu 2021-11-22 11:38:42 +08:00
parent 23c237cb2a
commit 3120abacb3
7 changed files with 45 additions and 26 deletions

View File

@ -11,8 +11,8 @@ from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic import BaseConfig from pydantic import BaseConfig
from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger from nonebot.log import logger
from .models import Param as Param from .models import Param as Param
@ -90,16 +90,25 @@ def get_dependent(*,
dependent.dependencies.append(sub_dependent) dependent.dependencies.append(sub_dependent)
continue continue
for allow_type in dependent.allow_types: default_value = Required
if allow_type._check(param_name, param): if param.default != param.empty:
field_info = allow_type(param.default) default_value = param.default
break
if isinstance(default_value, Param):
field_info = default_value
default_value = field_info.default
else: else:
raise ValueError( for allow_type in dependent.allow_types:
f"Unknown parameter {param_name} for function {func} with type {param.annotation}" if allow_type._check(param_name, param):
) field_info = allow_type(default_value)
break
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
)
annotation: Any = Any annotation: Any = Any
required = default_value == Required
if param.annotation != param.empty: if param.annotation != param.empty:
annotation = param.annotation annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info, annotation = get_annotation_from_field_info(annotation, field_info,
@ -109,8 +118,8 @@ def get_dependent(*,
type_=annotation, type_=annotation,
class_validators=None, class_validators=None,
model_config=CustomConfig, model_config=CustomConfig,
default=Required, default=None if required else default_value,
required=True, required=required,
field_info=field_info)) field_info=field_info))
return dependent return dependent
@ -176,6 +185,8 @@ async def solve_dependencies(
assert isinstance(field_info, assert isinstance(field_info,
Param), "Params must be subclasses of Param" Param), "Params must be subclasses of Param"
value = field_info._solve(**params) value = field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value, _, errs_ = field.validate(value,
values, values,
loc=(str(field_info), field.alias)) loc=(str(field_info), field.alias))

View File

@ -10,12 +10,6 @@ from nonebot.typing import T_Handler
class Param(abc.ABC, FieldInfo): class Param(abc.ABC, FieldInfo):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def __str__(self) -> str:
return repr(self)
@classmethod @classmethod
@abc.abstractmethod @abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool: def _check(cls, name: str, param: inspect.Parameter) -> bool:

View File

@ -158,7 +158,7 @@ class Matcher(metaclass=MatcherMeta):
HANDLER_PARAM_TYPES = [ HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam, params.BotParam, params.EventParam, params.StateParam,
params.MatcherParam params.MatcherParam, params.DefaultParam
] ]
def __init__(self): def __init__(self):

View File

@ -29,13 +29,16 @@ _event_postprocessors: Set[Handler] = set()
_run_preprocessors: Set[Handler] = set() _run_preprocessors: Set[Handler] = set()
_run_postprocessors: Set[Handler] = set() _run_postprocessors: Set[Handler] = set()
EVENT_PCS_PARAMS = [params.BotParam, params.EventParam, params.StateParam] EVENT_PCS_PARAMS = [
params.BotParam, params.EventParam, params.StateParam, params.DefaultParam
]
RUN_PREPCS_PARAMS = [ RUN_PREPCS_PARAMS = [
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam params.MatcherParam, params.BotParam, params.EventParam, params.StateParam,
params.DefaultParam
] ]
RUN_POSTPCS_PARAMS = [ RUN_POSTPCS_PARAMS = [
params.MatcherParam, params.ExceptionParam, params.BotParam, params.MatcherParam, params.ExceptionParam, params.BotParam,
params.EventParam, params.StateParam params.EventParam, params.StateParam, params.DefaultParam
] ]

View File

@ -1,6 +1,8 @@
import inspect import inspect
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from pydantic.fields import Undefined
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.dependencies import Param from nonebot.dependencies import Param
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
@ -69,4 +71,14 @@ class ExceptionParam(Param):
return exception return exception
class DefaultParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return param.default != param.empty
def _solve(self, **kwargs: Any) -> Any:
return Undefined
from nonebot.matcher import Matcher from nonebot.matcher import Matcher

View File

@ -15,7 +15,6 @@ from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
from nonebot import params from nonebot import params
from nonebot.handler import Handler from nonebot.handler import Handler
from nonebot.dependencies import Param
from nonebot.adapters import Bot, Event from nonebot.adapters import Bot, Event
from nonebot.typing import T_PermissionChecker from nonebot.typing import T_PermissionChecker
@ -37,8 +36,8 @@ class Permission:
""" """
__slots__ = ("checkers",) __slots__ = ("checkers",)
HANDLER_PARAM_TYPES: List[Type[Param]] = [ HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam params.BotParam, params.EventParam, params.DefaultParam
] ]
def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None: def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None:

View File

@ -25,7 +25,6 @@ from pygtrie import CharTrie
from nonebot.log import logger from nonebot.log import logger
from nonebot.handler import Handler from nonebot.handler import Handler
from nonebot import params, get_driver from nonebot import params, get_driver
from nonebot.dependencies import Param
from nonebot.exception import ParserExit from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment from nonebot.adapters import Bot, Event, MessageSegment
@ -64,8 +63,9 @@ class Rule:
""" """
__slots__ = ("checkers",) __slots__ = ("checkers",)
HANDLER_PARAM_TYPES: List[Type[Param]] = [ HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam params.BotParam, params.EventParam, params.StateParam,
params.DefaultParam
] ]
def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None: def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None: