From aa48299d5d66ca8c4687337b4f8a68b063affb48 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 21 May 2023 16:01:55 +0800 Subject: [PATCH] :sparkles: improve dependency injection params (#2034) --- nonebot/internal/params.py | 155 +++++++++++++++++++++----------- nonebot/utils.py | 2 +- tests/plugins/param/priority.py | 23 +++++ tests/test_param.py | 39 ++++++++ 4 files changed, 165 insertions(+), 54 deletions(-) create mode 100644 tests/plugins/param/priority.py diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index 0fd5bbc3..ad27ce3c 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -71,7 +71,12 @@ def Depends( class DependParam(Param): - """子依赖参数""" + """子依赖注入参数。 + + 本注入解析所有子依赖注入,然后将它们的返回值作为参数值传递给父依赖。 + + 本注入应该具有最高优先级,因此应该在其他参数之前检查。 + """ def __repr__(self) -> str: return f"Depends({self.extra['dependent']})" @@ -168,7 +173,12 @@ class DependParam(Param): class BotParam(Param): - """{ref}`nonebot.adapters.Bot` 参数""" + """{ref}`nonebot.adapters.Bot` 注入参数。 + + 本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Bot` 及其子类或 `None` 的参数。 + + 为保证兼容性,本注入还会解析名为 `bot` 且没有类型注解的参数。 + """ def __repr__(self) -> str: return ( @@ -187,21 +197,22 @@ class BotParam(Param): ) -> Optional["BotParam"]: from nonebot.adapters import Bot - if param.default == param.empty: - if generic_check_issubclass(param.annotation, Bot): - checker: Optional[ModelField] = None - if param.annotation is not Bot: - checker = ModelField( - name=param.name, - type_=param.annotation, - class_validators=None, - model_config=CustomConfig, - default=None, - required=True, - ) - return cls(Required, checker=checker) - elif param.annotation == param.empty and param.name == "bot": - return cls(Required) + # param type is Bot(s) or subclass(es) of Bot or None + if generic_check_issubclass(param.annotation, Bot): + checker: Optional[ModelField] = None + if param.annotation is not Bot: + checker = ModelField( + name=param.name, + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, + ) + return cls(Required, checker=checker) + # legacy: param is named "bot" and has no type annotation + elif param.annotation == param.empty and param.name == "bot": + return cls(Required) async def _solve(self, bot: "Bot", **kwargs: Any) -> Any: return bot @@ -212,7 +223,12 @@ class BotParam(Param): class EventParam(Param): - """{ref}`nonebot.adapters.Event` 参数""" + """{ref}`nonebot.adapters.Event` 注入参数 + + 本注入解析所有类型为且仅为 {ref}`nonebot.adapters.Event` 及其子类或 `None` 的参数。 + + 为保证兼容性,本注入还会解析名为 `event` 且没有类型注解的参数。 + """ def __repr__(self) -> str: return ( @@ -231,21 +247,22 @@ class EventParam(Param): ) -> Optional["EventParam"]: from nonebot.adapters import Event - if param.default == param.empty: - if generic_check_issubclass(param.annotation, Event): - checker: Optional[ModelField] = None - if param.annotation is not Event: - checker = ModelField( - name=param.name, - type_=param.annotation, - class_validators=None, - model_config=CustomConfig, - default=None, - required=True, - ) - return cls(Required, checker=checker) - elif param.annotation == param.empty and param.name == "event": - return cls(Required) + # param type is Event(s) or subclass(es) of Event or None + if generic_check_issubclass(param.annotation, Event): + checker: Optional[ModelField] = None + if param.annotation is not Event: + checker = ModelField( + name=param.name, + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, + ) + return cls(Required, checker=checker) + # legacy: param is named "event" and has no type annotation + elif param.annotation == param.empty and param.name == "event": + return cls(Required) async def _solve(self, event: "Event", **kwargs: Any) -> Any: return event @@ -256,7 +273,12 @@ class EventParam(Param): class StateParam(Param): - """事件处理状态参数""" + """事件处理状态注入参数 + + 本注入解析所有类型为 `T_State` 的参数。 + + 为保证兼容性,本注入还会解析名为 `state` 且没有类型注解的参数。 + """ def __repr__(self) -> str: return "StateParam()" @@ -265,18 +287,24 @@ class StateParam(Param): def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["StateParam"]: - if param.default == param.empty: - if param.annotation is T_State: - return cls(Required) - elif param.annotation == param.empty and param.name == "state": - return cls(Required) + # param type is T_State + if param.annotation is T_State: + return cls(Required) + # legacy: param is named "state" and has no type annotation + elif param.annotation == param.empty and param.name == "state": + return cls(Required) async def _solve(self, state: T_State, **kwargs: Any) -> Any: return state class MatcherParam(Param): - """事件响应器实例参数""" + """事件响应器实例注入参数 + + 本注入解析所有类型为且仅为 {ref}`nonebot.matcher.Matcher` 及其子类或 `None` 的参数。 + + 为保证兼容性,本注入还会解析名为 `matcher` 且没有类型注解的参数。 + """ def __repr__(self) -> str: return "MatcherParam()" @@ -287,9 +315,11 @@ class MatcherParam(Param): ) -> Optional["MatcherParam"]: from nonebot.matcher import Matcher - if generic_check_issubclass(param.annotation, Matcher) or ( - param.annotation == param.empty and param.name == "matcher" - ): + # param type is Matcher(s) or subclass(es) of Matcher or None + if generic_check_issubclass(param.annotation, Matcher): + return cls(Required) + # legacy: param is named "matcher" and has no type annotation + elif param.annotation == param.empty and param.name == "matcher": return cls(Required) async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: @@ -308,22 +338,28 @@ class ArgInner: def Arg(key: Optional[str] = None) -> Any: - """`got` 的 Arg 参数消息""" + """Arg 参数消息""" return ArgInner(key, "message") def ArgStr(key: Optional[str] = None) -> str: - """`got` 的 Arg 参数消息文本""" + """Arg 参数消息文本""" return ArgInner(key, "str") # type: ignore def ArgPlainText(key: Optional[str] = None) -> str: - """`got` 的 Arg 参数消息纯文本""" + """Arg 参数消息纯文本""" return ArgInner(key, "plaintext") # type: ignore class ArgParam(Param): - """`got` 的 Arg 参数""" + """Arg 注入参数 + + 本注入解析事件响应器操作 `got` 所获取的参数。 + + 可以通过 `Arg`、`ArgStr`、`ArgPlainText` 等函数参数 `key` 指定获取的参数, + 留空则会根据参数名称获取。 + """ def __repr__(self) -> str: return f"ArgParam(key={self.extra['key']!r}, type={self.extra['type']!r})" @@ -338,7 +374,8 @@ class ArgParam(Param): ) async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: - message = matcher.get_arg(self.extra["key"]) + key: str = self.extra["key"] + message = matcher.get_arg(key) if message is None: return message if self.extra["type"] == "message": @@ -350,7 +387,12 @@ class ArgParam(Param): class ExceptionParam(Param): - """`run_postprocessor` 的异常参数""" + """{ref}`nonebot.message.run_postprocessor` 的异常注入参数 + + 本注入解析所有类型为 `Exception` 或 `None` 的参数。 + + 为保证兼容性,本注入还会解析名为 `exception` 且没有类型注解的参数。 + """ def __repr__(self) -> str: return "ExceptionParam()" @@ -359,9 +401,11 @@ class ExceptionParam(Param): def _check_param( cls, param: inspect.Parameter, allow_types: Tuple[Type[Param], ...] ) -> Optional["ExceptionParam"]: - if generic_check_issubclass(param.annotation, Exception) or ( - param.annotation == param.empty and param.name == "exception" - ): + # param type is Exception(s) or subclass(es) of Exception or None + if generic_check_issubclass(param.annotation, Exception): + return cls(Required) + # legacy: param is named "exception" and has no type annotation + elif param.annotation == param.empty and param.name == "exception": return cls(Required) async def _solve(self, exception: Optional[Exception] = None, **kwargs: Any) -> Any: @@ -369,7 +413,12 @@ class ExceptionParam(Param): class DefaultParam(Param): - """默认值参数""" + """默认值注入参数 + + 本注入解析所有剩余未能解析且具有默认值的参数。 + + 本注入参数应该具有最低优先级,因此应该在所有其他注入参数之后使用。 + """ def __repr__(self) -> str: return f"DefaultParam(default={self.default!r})" diff --git a/nonebot/utils.py b/nonebot/utils.py index 93cffb70..9391a507 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -59,7 +59,7 @@ def generic_check_issubclass( """检查 cls 是否是 class_or_tuple 中的一个类型子类。 特别的,如果 cls 是 `typing.Union` 或 `types.UnionType` 类型, - 则会检查其中的类型是否是 class_or_tuple 中的一个类型子类。(None 会被忽略) + 则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。 """ try: return issubclass(cls, class_or_tuple) diff --git a/tests/plugins/param/priority.py b/tests/plugins/param/priority.py new file mode 100644 index 00000000..311855eb --- /dev/null +++ b/tests/plugins/param/priority.py @@ -0,0 +1,23 @@ +from typing import Optional + +from nonebot.typing import T_State +from nonebot.matcher import Matcher +from nonebot.params import Arg, Depends +from nonebot.adapters import Bot, Event, Message + + +def dependency(): + return 1 + + +async def complex_priority( + sub: int = Depends(dependency), + bot: Optional[Bot] = None, + event: Optional[Event] = None, + state: T_State = {}, + matcher: Optional[Matcher] = None, + arg: Message = Arg(), + exception: Optional[Exception] = None, + default: int = 1, +): + ... diff --git a/tests/test_param.py b/tests/test_param.py index 1878d320..42dbd200 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -4,6 +4,7 @@ import pytest from nonebug import App from nonebot.matcher import Matcher +from nonebot.dependencies import Dependent from nonebot.exception import TypeMisMatch from utils import make_fake_event, make_fake_message from nonebot.params import ( @@ -413,3 +414,41 @@ async def test_default(app: App): async with app.test_dependent(default, allow_types=[DefaultParam]) as ctx: ctx.should_return(1) + + +@pytest.mark.asyncio +async def test_priority(): + from plugins.param.priority import complex_priority + + dependent = Dependent.parse( + call=complex_priority, + allow_types=[ + DependParam, + BotParam, + EventParam, + StateParam, + MatcherParam, + ArgParam, + ExceptionParam, + DefaultParam, + ], + ) + for param in dependent.params: + if param.name == "sub": + assert isinstance(param.field_info, DependParam) + elif param.name == "bot": + assert isinstance(param.field_info, BotParam) + elif param.name == "event": + assert isinstance(param.field_info, EventParam) + elif param.name == "state": + assert isinstance(param.field_info, StateParam) + elif param.name == "matcher": + assert isinstance(param.field_info, MatcherParam) + elif param.name == "arg": + assert isinstance(param.field_info, ArgParam) + elif param.name == "exception": + assert isinstance(param.field_info, ExceptionParam) + elif param.name == "default": + assert isinstance(param.field_info, DefaultParam) + else: + raise ValueError(f"unknown param {param.name}")