From 3120abacb37d4b59d69c605f8c11872a11dda823 Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 22 Nov 2021 11:38:42 +0800 Subject: [PATCH] :wheelchair: allow extra param with default value --- nonebot/dependencies/__init__.py | 31 +++++++++++++++++++++---------- nonebot/dependencies/models.py | 6 ------ nonebot/matcher.py | 2 +- nonebot/message.py | 9 ++++++--- nonebot/params.py | 12 ++++++++++++ nonebot/permission.py | 5 ++--- nonebot/rule.py | 6 +++--- 7 files changed, 45 insertions(+), 26 deletions(-) diff --git a/nonebot/dependencies/__init__.py b/nonebot/dependencies/__init__.py index eb502e95..44f568f7 100644 --- a/nonebot/dependencies/__init__.py +++ b/nonebot/dependencies/__init__.py @@ -11,8 +11,8 @@ from typing import Any, Dict, List, Type, Tuple, Callable, Optional, cast from contextlib import AsyncExitStack, contextmanager, asynccontextmanager from pydantic import BaseConfig -from pydantic.fields import Required, ModelField from pydantic.schema import get_annotation_from_field_info +from pydantic.fields import Required, Undefined, ModelField from nonebot.log import logger from .models import Param as Param @@ -90,16 +90,25 @@ def get_dependent(*, dependent.dependencies.append(sub_dependent) continue - for allow_type in dependent.allow_types: - if allow_type._check(param_name, param): - field_info = allow_type(param.default) - break + default_value = Required + if param.default != param.empty: + default_value = param.default + + if isinstance(default_value, Param): + field_info = default_value + default_value = field_info.default else: - raise ValueError( - f"Unknown parameter {param_name} for function {func} with type {param.annotation}" - ) + for allow_type in dependent.allow_types: + if allow_type._check(param_name, param): + field_info = allow_type(default_value) + break + else: + raise ValueError( + f"Unknown parameter {param_name} for function {func} with type {param.annotation}" + ) annotation: Any = Any + required = default_value == Required if param.annotation != param.empty: annotation = param.annotation annotation = get_annotation_from_field_info(annotation, field_info, @@ -109,8 +118,8 @@ def get_dependent(*, type_=annotation, class_validators=None, model_config=CustomConfig, - default=Required, - required=True, + default=None if required else default_value, + required=required, field_info=field_info)) return dependent @@ -176,6 +185,8 @@ async def solve_dependencies( assert isinstance(field_info, Param), "Params must be subclasses of Param" value = field_info._solve(**params) + if value == Undefined: + value = field.get_default() _, errs_ = field.validate(value, values, loc=(str(field_info), field.alias)) diff --git a/nonebot/dependencies/models.py b/nonebot/dependencies/models.py index 36ee5af0..9acba4db 100644 --- a/nonebot/dependencies/models.py +++ b/nonebot/dependencies/models.py @@ -10,12 +10,6 @@ from nonebot.typing import T_Handler class Param(abc.ABC, FieldInfo): - def __repr__(self) -> str: - return f"{self.__class__.__name__}" - - def __str__(self) -> str: - return repr(self) - @classmethod @abc.abstractmethod def _check(cls, name: str, param: inspect.Parameter) -> bool: diff --git a/nonebot/matcher.py b/nonebot/matcher.py index 7eeda5ab..cbaab224 100644 --- a/nonebot/matcher.py +++ b/nonebot/matcher.py @@ -158,7 +158,7 @@ class Matcher(metaclass=MatcherMeta): HANDLER_PARAM_TYPES = [ params.BotParam, params.EventParam, params.StateParam, - params.MatcherParam + params.MatcherParam, params.DefaultParam ] def __init__(self): diff --git a/nonebot/message.py b/nonebot/message.py index bedf778c..87184362 100644 --- a/nonebot/message.py +++ b/nonebot/message.py @@ -29,13 +29,16 @@ _event_postprocessors: Set[Handler] = set() _run_preprocessors: Set[Handler] = set() _run_postprocessors: Set[Handler] = set() -EVENT_PCS_PARAMS = [params.BotParam, params.EventParam, params.StateParam] +EVENT_PCS_PARAMS = [ + params.BotParam, params.EventParam, params.StateParam, params.DefaultParam +] RUN_PREPCS_PARAMS = [ - params.MatcherParam, params.BotParam, params.EventParam, params.StateParam + params.MatcherParam, params.BotParam, params.EventParam, params.StateParam, + params.DefaultParam ] RUN_POSTPCS_PARAMS = [ params.MatcherParam, params.ExceptionParam, params.BotParam, - params.EventParam, params.StateParam + params.EventParam, params.StateParam, params.DefaultParam ] diff --git a/nonebot/params.py b/nonebot/params.py index 8b644f91..06fb6987 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -1,6 +1,8 @@ import inspect from typing import Any, Dict, Optional +from pydantic.fields import Undefined + from nonebot.typing import T_State from nonebot.dependencies import Param from nonebot.adapters import Bot, Event @@ -69,4 +71,14 @@ class ExceptionParam(Param): return exception +class DefaultParam(Param): + + @classmethod + def _check(cls, name: str, param: inspect.Parameter) -> bool: + return param.default != param.empty + + def _solve(self, **kwargs: Any) -> Any: + return Undefined + + from nonebot.matcher import Matcher diff --git a/nonebot/permission.py b/nonebot/permission.py index e52058cb..4e7e89da 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -15,7 +15,6 @@ from typing import Any, Dict, List, Type, Union, Callable, NoReturn, Optional from nonebot import params from nonebot.handler import Handler -from nonebot.dependencies import Param from nonebot.adapters import Bot, Event from nonebot.typing import T_PermissionChecker @@ -37,8 +36,8 @@ class Permission: """ __slots__ = ("checkers",) - HANDLER_PARAM_TYPES: List[Type[Param]] = [ - params.BotParam, params.EventParam + HANDLER_PARAM_TYPES = [ + params.BotParam, params.EventParam, params.DefaultParam ] def __init__(self, *checkers: Union[T_PermissionChecker, Handler]) -> None: diff --git a/nonebot/rule.py b/nonebot/rule.py index 1621aef6..2059ceac 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -25,7 +25,6 @@ from pygtrie import CharTrie from nonebot.log import logger from nonebot.handler import Handler from nonebot import params, get_driver -from nonebot.dependencies import Param from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker from nonebot.adapters import Bot, Event, MessageSegment @@ -64,8 +63,9 @@ class Rule: """ __slots__ = ("checkers",) - HANDLER_PARAM_TYPES: List[Type[Param]] = [ - params.BotParam, params.EventParam, params.StateParam + HANDLER_PARAM_TYPES = [ + params.BotParam, params.EventParam, params.StateParam, + params.DefaultParam ] def __init__(self, *checkers: Union[T_RuleChecker, Handler]) -> None: