allow extra param with default value

This commit is contained in:
yanyongyu 2021-11-22 11:38:42 +08:00
parent 23c237cb2a
commit 3120abacb3
7 changed files with 45 additions and 26 deletions

View File

@ -11,8 +11,8 @@ from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast
from contextlib import AsyncExitStack, contextmanager, asynccontextmanager
from pydantic import BaseConfig
from pydantic.fields import Required, ModelField
from pydantic.schema import get_annotation_from_field_info
from pydantic.fields import Required, Undefined, ModelField
from nonebot.log import logger
from .models import Param as Param
@ -90,16 +90,25 @@ def get_dependent(*,
dependent.dependencies.append(sub_dependent)
continue
for allow_type in dependent.allow_types:
if allow_type._check(param_name, param):
field_info = allow_type(param.default)
break
default_value = Required
if param.default != param.empty:
default_value = param.default
if isinstance(default_value, Param):
field_info = default_value
default_value = field_info.default
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
)
for allow_type in dependent.allow_types:
if allow_type._check(param_name, param):
field_info = allow_type(default_value)
break
else:
raise ValueError(
f"Unknown parameter {param_name} for function {func} with type {param.annotation}"
)
annotation: Any = Any
required = default_value == Required
if param.annotation != param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info,
@ -109,8 +118,8 @@ def get_dependent(*,
type_=annotation,
class_validators=None,
model_config=CustomConfig,
default=Required,
required=True,
default=None if required else default_value,
required=required,
field_info=field_info))
return dependent
@ -176,6 +185,8 @@ async def solve_dependencies(
assert isinstance(field_info,
Param), "Params must be subclasses of Param"
value = field_info._solve(**params)
if value == Undefined:
value = field.get_default()
_, errs_ = field.validate(value,
values,
loc=(str(field_info), field.alias))

View File

@ -10,12 +10,6 @@ from nonebot.typing import T_Handler
class Param(abc.ABC, FieldInfo):
def __repr__(self) -> str:
return f"{self.__class__.__name__}"
def __str__(self) -> str:
return repr(self)
@classmethod
@abc.abstractmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:

View File

@ -158,7 +158,7 @@ class Matcher(metaclass=MatcherMeta):
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam,
params.MatcherParam
params.MatcherParam, params.DefaultParam
]
def __init__(self):

View File

@ -29,13 +29,16 @@ _event_postprocessors: Set[Handler] = set()
_run_preprocessors: Set[Handler] = set()
_run_postprocessors: Set[Handler] = set()
EVENT_PCS_PARAMS = [params.BotParam, params.EventParam, params.StateParam]
EVENT_PCS_PARAMS = [
params.BotParam, params.EventParam, params.StateParam, params.DefaultParam
]
RUN_PREPCS_PARAMS = [
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam
params.MatcherParam, params.BotParam, params.EventParam, params.StateParam,
params.DefaultParam
]
RUN_POSTPCS_PARAMS = [
params.MatcherParam, params.ExceptionParam, params.BotParam,
params.EventParam, params.StateParam
params.EventParam, params.StateParam, params.DefaultParam
]

View File

@ -1,6 +1,8 @@
import inspect
from typing import Any, Dict, Optional
from pydantic.fields import Undefined
from nonebot.typing import T_State
from nonebot.dependencies import Param
from nonebot.adapters import Bot, Event
@ -69,4 +71,14 @@ class ExceptionParam(Param):
return exception
class DefaultParam(Param):
@classmethod
def _check(cls, name: str, param: inspect.Parameter) -> bool:
return param.default != param.empty
def _solve(self, **kwargs: Any) -> Any:
return Undefined
from nonebot.matcher import Matcher

View File

@ -15,7 +15,6 @@ from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional
from nonebot import params
from nonebot.handler import Handler
from nonebot.dependencies import Param
from nonebot.adapters import Bot, Event
from nonebot.typing import T_PermissionChecker
@ -37,8 +36,8 @@ class Permission:
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES: List[Type[Param]] = [
params.BotParam, params.EventParam
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.DefaultParam
]
def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None:

View File

@ -25,7 +25,6 @@ from pygtrie import CharTrie
from nonebot.log import logger
from nonebot.handler import Handler
from nonebot import params, get_driver
from nonebot.dependencies import Param
from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment
@ -64,8 +63,9 @@ class Rule:
"""
__slots__ = ("checkers",)
HANDLER_PARAM_TYPES: List[Type[Param]] = [
params.BotParam, params.EventParam, params.StateParam
HANDLER_PARAM_TYPES = [
params.BotParam, params.EventParam, params.StateParam,
params.DefaultParam
]
def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None: