From f6b0809e5fd9fd4089501d63f7a8db0141ecf785 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sun, 11 Jun 2023 15:33:33 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E6=B3=A8=E5=85=A5=E6=94=AF=E6=8C=81=20Generic=20TypeVar=20?= =?UTF-8?q?=E5=92=8C=20Matcher=20=E9=87=8D=E8=BD=BD=20(#2089)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/internal/params.py | 16 +++++- nonebot/utils.py | 18 ++++++- tests/plugins/param/param_bot.py | 16 +++++- tests/plugins/param/param_event.py | 16 +++++- tests/plugins/param/param_matcher.py | 46 ++++++++++++++++++ tests/test_param.py | 73 +++++++++++++++++++++++++++- website/docs/advanced/dependency.mdx | 17 +++++-- website/docs/advanced/matcher.md | 4 ++ website/docs/appendices/overload.md | 2 +- 9 files changed, 198 insertions(+), 10 deletions(-) diff --git a/nonebot/internal/params.py b/nonebot/internal/params.py index ad27ce3c..6012eca5 100644 --- a/nonebot/internal/params.py +++ b/nonebot/internal/params.py @@ -317,7 +317,17 @@ class MatcherParam(Param): # param type is Matcher(s) or subclass(es) of Matcher or None if generic_check_issubclass(param.annotation, Matcher): - return cls(Required) + checker: Optional[ModelField] = None + if param.annotation is not Matcher: + checker = ModelField( + name=param.name, + type_=param.annotation, + class_validators=None, + model_config=CustomConfig, + default=None, + required=True, + ) + return cls(Required, checker=checker) # legacy: param is named "matcher" and has no type annotation elif param.annotation == param.empty and param.name == "matcher": return cls(Required) @@ -325,6 +335,10 @@ class MatcherParam(Param): async def _solve(self, matcher: "Matcher", **kwargs: Any) -> Any: return matcher + async def _check(self, matcher: "Matcher", **kwargs: Any) -> Any: + if checker := self.extra.get("checker", None): + check_field_type(checker, matcher) + class ArgInner: def __init__( diff --git a/nonebot/utils.py b/nonebot/utils.py index 5c770af4..8a13a8ab 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -58,8 +58,12 @@ def generic_check_issubclass( ) -> bool: """检查 cls 是否是 class_or_tuple 中的一个类型子类。 - 特别的,如果 cls 是 `typing.Union` 或 `types.UnionType` 类型, - 则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。 + 特别的: + + - 如果 cls 是 `typing.Union` 或 `types.UnionType` 类型, + 则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。 + - 如果 cls 是 `typing.TypeVar` 类型, + 则会检查其 `__bound__` 或 `__constraints__` 是否是 class_or_tuple 中一个类型的子类或 None。 """ try: return issubclass(cls, class_or_tuple) @@ -70,8 +74,18 @@ def generic_check_issubclass( is_none_type(type_) or generic_check_issubclass(type_, class_or_tuple) for type_ in get_args(cls) ) + # ensure generic List, Dict can be checked elif origin: return issubclass(origin, class_or_tuple) + elif isinstance(cls, TypeVar): + if cls.__constraints__: + return all( + is_none_type(type_) + or generic_check_issubclass(type_, class_or_tuple) + for type_ in cls.__constraints__ + ) + elif cls.__bound__: + return generic_check_issubclass(cls.__bound__, class_or_tuple) return False diff --git a/tests/plugins/param/param_bot.py b/tests/plugins/param/param_bot.py index 40c343f5..17943d26 100644 --- a/tests/plugins/param/param_bot.py +++ b/tests/plugins/param/param_bot.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, TypeVar from nonebot.adapters import Bot @@ -31,5 +31,19 @@ async def union_bot(b: Union[FooBot, BarBot]) -> Union[FooBot, BarBot]: return b +B = TypeVar("B", bound=Bot) + + +async def generic_bot(b: B) -> B: + return b + + +CB = TypeVar("CB", Bot, None) + + +async def generic_bot_none(b: CB) -> CB: + return b + + async def not_bot(b: Union[int, Bot]): ... diff --git a/tests/plugins/param/param_event.py b/tests/plugins/param/param_event.py index 3f4005fe..4def5e86 100644 --- a/tests/plugins/param/param_event.py +++ b/tests/plugins/param/param_event.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, TypeVar from nonebot.adapters import Event, Message from nonebot.params import EventToMe, EventType, EventMessage, EventPlainText @@ -32,6 +32,20 @@ async def union_event(e: Union[FooEvent, BarEvent]) -> Union[FooEvent, BarEvent] return e +E = TypeVar("E", bound=Event) + + +async def generic_event(e: E) -> E: + return e + + +CE = TypeVar("CE", Event, None) + + +async def generic_event_none(e: CE) -> CE: + return e + + async def not_event(e: Union[int, Event]): ... diff --git a/tests/plugins/param/param_matcher.py b/tests/plugins/param/param_matcher.py index ad8d5bd8..d1f9da6e 100644 --- a/tests/plugins/param/param_matcher.py +++ b/tests/plugins/param/param_matcher.py @@ -1,3 +1,5 @@ +from typing import Union, TypeVar + from nonebot.adapters import Event from nonebot.matcher import Matcher from nonebot.params import Received, LastReceived @@ -7,6 +9,50 @@ async def matcher(m: Matcher) -> Matcher: return m +async def legacy_matcher(matcher): + return matcher + + +async def not_legacy_matcher(matcher: int): + ... + + +class FooMatcher(Matcher): + ... + + +async def sub_matcher(m: FooMatcher) -> FooMatcher: + return m + + +class BarMatcher(Matcher): + ... + + +async def union_matcher( + m: Union[FooMatcher, BarMatcher] +) -> Union[FooMatcher, BarMatcher]: + return m + + +M = TypeVar("M", bound=Matcher) + + +async def generic_matcher(m: M) -> M: + return m + + +CM = TypeVar("CM", Matcher, None) + + +async def generic_matcher_none(m: CM) -> CM: + return m + + +async def not_matcher(m: Union[int, Matcher]): + ... + + async def receive(e: Event = Received("test")) -> Event: return e diff --git a/tests/test_param.py b/tests/test_param.py index 42dbd200..6fd930dc 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -90,7 +90,9 @@ async def test_bot(app: App): sub_bot, union_bot, legacy_bot, + generic_bot, not_legacy_bot, + generic_bot_none, ) async with app.test_dependent(get_bot, allow_types=[BotParam]) as ctx: @@ -122,6 +124,16 @@ async def test_bot(app: App): ctx.pass_params(bot=bot) ctx.should_return(bot) + async with app.test_dependent(generic_bot, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + ctx.should_return(bot) + + async with app.test_dependent(generic_bot_none, allow_types=[BotParam]) as ctx: + bot = ctx.create_bot() + ctx.pass_params(bot=bot) + ctx.should_return(bot) + with pytest.raises(ValueError): async with app.test_dependent(not_bot, allow_types=[BotParam]) as ctx: ... @@ -139,8 +151,10 @@ async def test_event(app: App): union_event, legacy_event, event_message, + generic_event, event_plain_text, not_legacy_event, + generic_event_none, ) fake_message = make_fake_message()("text") @@ -173,6 +187,14 @@ async def test_event(app: App): ctx.pass_params(event=fake_fooevent) ctx.should_return(fake_event) + async with app.test_dependent(generic_event, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event) + + async with app.test_dependent(generic_event_none, allow_types=[EventParam]) as ctx: + ctx.pass_params(event=fake_event) + ctx.should_return(fake_event) + with pytest.raises(ValueError): async with app.test_dependent(not_event, allow_types=[EventParam]) as ctx: ... @@ -351,14 +373,63 @@ async def test_state(app: App): @pytest.mark.asyncio async def test_matcher(app: App): - from plugins.param.param_matcher import matcher, receive, last_receive + from plugins.param.param_matcher import ( + FooMatcher, + matcher, + receive, + not_matcher, + sub_matcher, + last_receive, + union_matcher, + legacy_matcher, + generic_matcher, + not_legacy_matcher, + generic_matcher_none, + ) fake_matcher = Matcher() + foo_matcher = FooMatcher() async with app.test_dependent(matcher, allow_types=[MatcherParam]) as ctx: ctx.pass_params(matcher=fake_matcher) ctx.should_return(fake_matcher) + async with app.test_dependent(legacy_matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(fake_matcher) + + with pytest.raises(ValueError): + async with app.test_dependent( + not_legacy_matcher, allow_types=[MatcherParam] + ) as ctx: + ... + + async with app.test_dependent(sub_matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=foo_matcher) + ctx.should_return(foo_matcher) + + with pytest.raises(TypeMisMatch): + async with app.test_dependent(sub_matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=fake_matcher) + + async with app.test_dependent(union_matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=foo_matcher) + ctx.should_return(foo_matcher) + + async with app.test_dependent(generic_matcher, allow_types=[MatcherParam]) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(fake_matcher) + + async with app.test_dependent( + generic_matcher_none, allow_types=[MatcherParam] + ) as ctx: + ctx.pass_params(matcher=fake_matcher) + ctx.should_return(fake_matcher) + + with pytest.raises(ValueError): + async with app.test_dependent(not_matcher, allow_types=[MatcherParam]) as ctx: + ... + event = make_fake_event()() fake_matcher.set_receive("test", event) event_next = make_fake_event()() diff --git a/website/docs/advanced/dependency.mdx b/website/docs/advanced/dependency.mdx index 9e5ef31b..7d33381b 100644 --- a/website/docs/advanced/dependency.mdx +++ b/website/docs/advanced/dependency.mdx @@ -71,7 +71,9 @@ async def _(foo: str = "bar"): ... 获取当前事件的 Bot 对象。 -通过标注参数为 `Bot` 类型,或者一系列 `Bot` 类型,即可获取到当前事件的 Bot 对象。为兼容性考虑,如果参数名为 `bot` 且无类型注解,也会视为 `Bot` 依赖注入。 +通过标注参数为 `Bot` 类型,或者一系列 `Bot` 类型,即可获取到当前事件的 Bot 对象。为兼容性考虑,如果参数名为 `bot` 且无类型注解,也会视为 Bot 依赖注入。 + +Bot 依赖注入支持重载(即:可以标注参数为子类型)且具有[重载优先检查权](../appendices/overload.md#重载)。 @@ -108,7 +110,9 @@ async def _(bot): ... # 兼容性处理 获取当前事件。 -通过标注参数为 `Event` 类型,或者一系列 `Event` 类型,即可获取到当前事件。为兼容性考虑,如果参数名为 `event` 且无类型注解,也会视为 `Event` 依赖注入。 +通过标注参数为 `Event` 类型,或者一系列 `Event` 类型,即可获取到当前事件。为兼容性考虑,如果参数名为 `event` 且无类型注解,也会视为 Event 依赖注入。 + +Event 依赖注入支持重载(即:可以标注参数为子类型)且具有[重载优先检查权](../appendices/overload.md#重载)。 @@ -143,6 +147,8 @@ async def _(event): ... # 兼容性处理 获取当前[会话状态](../appendices/session-state.md)。 +通过标注参数为 `T_State` 类型,即可获取到当前会话状态。为兼容性考虑,如果参数名为 `state` 且无类型注解,也会视为 State 依赖注入。 + ```python from nonebot.typing import T_State @@ -153,10 +159,15 @@ async def _(foo: T_State): ... 获取当前事件响应器实例。常用于使用[事件响应器操作](../appendices/session-control.mdx)。 +通过标注参数为 `Matcher` 类型,或者一系列 `Matcher` 类型,即可获取到当前事件。为兼容性考虑,如果参数名为 `matcher` 且无类型注解,也会视为 Matcher 依赖注入。 + +Matcher 依赖注入支持重载(即:可以标注参数为子类型)且具有[重载优先检查权](../appendices/overload.md#重载)。 + ```python from nonebot.matcher import Matcher -async def _(matcher: Matcher): ... +async def _(foo: Matcher): ... +async def _(matcher): ... # 兼容性处理 ``` ### Exception diff --git a/website/docs/advanced/matcher.md b/website/docs/advanced/matcher.md index f831dffe..e1f0b443 100644 --- a/website/docs/advanced/matcher.md +++ b/website/docs/advanced/matcher.md @@ -12,6 +12,10 @@ options: 在[指南](../tutorial/matcher.md)与[深入](../appendices/rule.md)中,我们已经介绍了事件响应器的基本用法以及响应规则、权限控制等功能。在这一节中,我们将介绍事件响应器的组成,内置的响应规则,与第三方响应规则拓展。 +:::tip 提示 +事件响应器允许继承,你可以通过直接继承 `Matcher` 类来创建一个新的事件响应器。 +::: + ## 事件响应器组成 ### 事件响应器类型 diff --git a/website/docs/appendices/overload.md b/website/docs/appendices/overload.md index 480de334..a466a4d5 100644 --- a/website/docs/appendices/overload.md +++ b/website/docs/appendices/overload.md @@ -66,7 +66,7 @@ async def handle_onebot(bot: OneBot): :::warning 注意 重载机制对所有的参数类型注解都有效,因此,依赖注入也可以使用这个特性来对不同的返回值进行处理。 -但 Bot 和 Event 二者的参数类型注解具有最高检查优先级,如果二者类型注解不匹配,那么其他依赖注入将不会执行(如:`Depends`)。 +但 Bot、Event 和 Matcher 三者的参数类型注解具有最高检查优先级,如果三者任一类型注解不匹配,那么其他依赖注入将不会执行(如:`Depends`)。 ::: :::tip 提示