From 3fda97806425bfbea200b3b1c71ab7eca9fd50f3 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Tue, 30 Aug 2022 09:54:09 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E7=B1=BB=E5=9E=8B=E8=BF=87=E6=BB=A4=20rule?= =?UTF-8?q?=20(#1183)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/__init__.py | 2 + nonebot/plugin/__init__.py | 2 + nonebot/plugin/on.py | 51 +++++++ nonebot/plugin/on.pyi | 30 +++- nonebot/rule.py | 137 ++++++++++++++++- tests/plugins/plugin/__init__.py | 1 + tests/plugins/plugin/matchers.py | 243 +++++++++++++++++++++++++++++++ tests/test_plugin/test_on.py | 116 +++++++++++++++ tests/test_rule.py | 25 +++- 9 files changed, 597 insertions(+), 10 deletions(-) create mode 100644 tests/plugins/plugin/__init__.py create mode 100644 tests/plugins/plugin/matchers.py create mode 100644 tests/test_plugin/test_on.py diff --git a/nonebot/__init__.py b/nonebot/__init__.py index 50592b26..85d74aa7 100644 --- a/nonebot/__init__.py +++ b/nonebot/__init__.py @@ -16,6 +16,7 @@ - `on_command` => {ref}``on_command` ` - `on_shell_command` => {ref}``on_shell_command` ` - `on_regex` => {ref}``on_regex` ` +- `on_type` => {ref}``on_type` ` - `CommandGroup` => {ref}``CommandGroup` ` - `Matchergroup` => {ref}``MatcherGroup` ` - `load_plugin` => {ref}``load_plugin` ` @@ -260,6 +261,7 @@ def run(*args: Any, **kwargs: Any) -> None: from nonebot.plugin import on as on +from nonebot.plugin import on_type as on_type from nonebot.plugin import require as require from nonebot.plugin import on_regex as on_regex from nonebot.plugin import on_notice as on_notice diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index bc44ffb7..48d18bc8 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -16,6 +16,7 @@ - `on_command` => {ref}``on_command` ` - `on_shell_command` => {ref}``on_shell_command` ` - `on_regex` => {ref}``on_regex` ` +- `on_type` => {ref}``on_type` ` - `CommandGroup` => {ref}``CommandGroup` ` - `Matchergroup` => {ref}``MatcherGroup` ` - `load_plugin` => {ref}``load_plugin` ` @@ -105,6 +106,7 @@ def get_available_plugin_names() -> Set[str]: from .on import on as on from .manager import PluginManager +from .on import on_type as on_type from .load import require as require from .on import on_regex as on_regex from .plugin import Plugin as Plugin diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 390d7a27..62b24c9e 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -10,6 +10,7 @@ from types import ModuleType from datetime import datetime, timedelta from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional +from nonebot.adapters import Event from nonebot.matcher import Matcher from nonebot.permission import Permission from nonebot.dependencies import Dependent @@ -19,6 +20,7 @@ from nonebot.rule import ( ArgumentParser, regex, command, + is_type, keyword, endswith, fullmatch, @@ -437,6 +439,30 @@ def on_regex( return on_message(regex(pattern, flags) & rule, **kwargs, _depth=_depth + 1) +def on_type( + types: Union[Type[Event], Tuple[Type[Event]]], + rule: Optional[Union[Rule, T_RuleChecker]] = None, + *, + _depth: int = 0, + **kwargs, +) -> Type[Matcher]: + """注册一个事件响应器,并且当事件为指定类型时响应。 + + 参数: + types: 事件类型 + rule: 事件响应规则 + permission: 事件响应权限 + handlers: 事件处理函数列表 + temp: 是否为临时事件响应器(仅执行一次) + expire_time: 事件响应器最终有效时间点,过时即被删除 + priority: 事件响应器优先级 + block: 是否阻止事件向更低优先级传递 + state: 默认 state + """ + event_types = types if isinstance(types, tuple) else (types,) + return on(rule=is_type(*event_types) & rule, **kwargs, _depth=_depth + 1) + + class CommandGroup: """命令组,用于声明一组有相同名称前缀的命令。 @@ -593,6 +619,7 @@ class MatcherGroup: final_kwargs = self.base_kwargs.copy() final_kwargs.update(kwargs) final_kwargs.pop("type", None) + final_kwargs.pop("permission", None) matcher = on_notice(**final_kwargs, _depth=1) self.matchers.append(matcher) return matcher @@ -612,6 +639,7 @@ class MatcherGroup: final_kwargs = self.base_kwargs.copy() final_kwargs.update(kwargs) final_kwargs.pop("type", None) + final_kwargs.pop("permission", None) matcher = on_request(**final_kwargs, _depth=1) self.matchers.append(matcher) return matcher @@ -794,3 +822,26 @@ class MatcherGroup: matcher = on_regex(pattern, flags=flags, **final_kwargs, _depth=1) self.matchers.append(matcher) return matcher + + def on_type( + self, types: Union[Type[Event], Tuple[Type[Event]]], **kwargs + ) -> Type[Matcher]: + """注册一个事件响应器,并且当事件为指定类型时响应。 + + 参数: + types: 事件类型 + rule: 事件响应规则 + permission: 事件响应权限 + handlers: 事件处理函数列表 + temp: 是否为临时事件响应器(仅执行一次) + expire_time: 事件响应器最终有效时间点,过时即被删除 + priority: 事件响应器优先级 + block: 是否阻止事件向更低优先级传递 + state: 默认 state + """ + final_kwargs = self.base_kwargs.copy() + final_kwargs.update(kwargs) + final_kwargs.pop("type", None) + matcher = on_type(types, **final_kwargs, _depth=1) + self.matchers.append(matcher) + return matcher diff --git a/nonebot/plugin/on.pyi b/nonebot/plugin/on.pyi index 03f5c73d..ec45e05e 100644 --- a/nonebot/plugin/on.pyi +++ b/nonebot/plugin/on.pyi @@ -2,6 +2,7 @@ import re from datetime import datetime, timedelta from typing import Set, List, Type, Tuple, Union, Optional +from nonebot.adapters import Event from nonebot.matcher import Matcher from nonebot.permission import Permission from nonebot.dependencies import Dependent @@ -152,6 +153,18 @@ def on_regex( block: bool = ..., state: Optional[T_State] = ..., ) -> Type[Matcher]: ... +def on_type( + types: Union[Type[Event], Tuple[Type[Event]]], + rule: Optional[Union[Rule, T_RuleChecker]] = ..., + *, + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., + temp: bool = ..., + expire_time: Optional[Union[datetime, timedelta]] = ..., + priority: int = ..., + block: bool = ..., + state: Optional[T_State] = ..., +) -> Type[Matcher]: ... class CommandGroup: def __init__( @@ -171,8 +184,8 @@ class CommandGroup: self, cmd: Union[str, Tuple[str, ...]], *, - aliases: Optional[Set[Union[str, Tuple[str, ...]]]], rule: Optional[Union[Rule, T_RuleChecker]] = ..., + aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ..., handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., temp: bool = ..., @@ -186,7 +199,7 @@ class CommandGroup: cmd: Union[str, Tuple[str, ...]], *, rule: Optional[Union[Rule, T_RuleChecker]] = ..., - aliases: Optional[Set[Union[str, Tuple[str, ...]]]], + aliases: Optional[Set[Union[str, Tuple[str, ...]]]] = ..., parser: Optional[ArgumentParser] = ..., permission: Optional[Union[Permission, T_PermissionChecker]] = ..., handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., @@ -367,3 +380,16 @@ class MatcherGroup: block: bool = ..., state: Optional[T_State] = ..., ) -> Type[Matcher]: ... + def on_type( + self, + types: Union[Type[Event], Tuple[Type[Event]]], + *, + rule: Optional[Union[Rule, T_RuleChecker]] = ..., + permission: Optional[Union[Permission, T_PermissionChecker]] = ..., + handlers: Optional[List[Union[T_Handler, Dependent]]] = ..., + temp: bool = ..., + expire_time: Optional[Union[datetime, timedelta]] = ..., + priority: int = ..., + block: bool = ..., + state: Optional[T_State] = ..., + ) -> Type[Matcher]: ... diff --git a/nonebot/rule.py b/nonebot/rule.py index 03b2df56..f3e83a03 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -20,6 +20,7 @@ from typing import ( TYPE_CHECKING, Any, List, + Type, Tuple, Union, TypeVar, @@ -129,6 +130,19 @@ class StartswithRule: self.msg = msg self.ignorecase = ignorecase + def __repr__(self) -> str: + return f"StartswithRule(msg={self.msg}, ignorecase={self.ignorecase})" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, StartswithRule) + and frozenset(self.msg) == frozenset(other.msg) + and self.ignorecase == other.ignorecase + ) + + def __hash__(self) -> int: + return hash((frozenset(self.msg), self.ignorecase)) + async def __call__( self, type: str = EventType(), text: str = EventPlainText() ) -> Any: @@ -170,6 +184,19 @@ class EndswithRule: self.msg = msg self.ignorecase = ignorecase + def __repr__(self) -> str: + return f"EndswithRule(msg={self.msg}, ignorecase={self.ignorecase})" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, EndswithRule) + and frozenset(self.msg) == frozenset(other.msg) + and self.ignorecase == other.ignorecase + ) + + def __hash__(self) -> int: + return hash((frozenset(self.msg), self.ignorecase)) + async def __call__( self, type: str = EventType(), text: str = EventPlainText() ) -> Any: @@ -208,9 +235,22 @@ class FullmatchRule: __slots__ = ("msg", "ignorecase") def __init__(self, msg: Tuple[str, ...], ignorecase: bool = False): - self.msg = frozenset(map(str.casefold, msg) if ignorecase else msg) + self.msg = tuple(map(str.casefold, msg) if ignorecase else msg) self.ignorecase = ignorecase + def __repr__(self) -> str: + return f"FullmatchRule(msg={self.msg}, ignorecase={self.ignorecase})" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, FullmatchRule) + and frozenset(self.msg) == frozenset(other.msg) + and self.ignorecase == other.ignorecase + ) + + def __hash__(self) -> int: + return hash((frozenset(self.msg), self.ignorecase)) + async def __call__( self, type_: str = EventType(), text: str = EventPlainText() ) -> bool: @@ -245,6 +285,17 @@ class KeywordsRule: def __init__(self, *keywords: str): self.keywords = keywords + def __repr__(self) -> str: + return f"KeywordsRule(keywords={self.keywords})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, KeywordsRule) and frozenset( + self.keywords + ) == frozenset(other.keywords) + + def __hash__(self) -> int: + return hash(frozenset(self.keywords)) + async def __call__( self, type: str = EventType(), text: str = EventPlainText() ) -> bool: @@ -273,14 +324,22 @@ class CommandRule: __slots__ = ("cmds",) def __init__(self, cmds: List[Tuple[str, ...]]): - self.cmds = cmds + self.cmds = tuple(cmds) + + def __repr__(self) -> str: + return f"CommandRule(cmds={self.cmds})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, CommandRule) and frozenset(self.cmds) == frozenset( + other.cmds + ) + + def __hash__(self) -> int: + return hash((frozenset(self.cmds),)) async def __call__(self, cmd: Optional[Tuple[str, ...]] = Command()) -> bool: return cmd in self.cmds - def __repr__(self): - return f"" - def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: """匹配消息命令。 @@ -391,9 +450,22 @@ class ShellCommandRule: __slots__ = ("cmds", "parser") def __init__(self, cmds: List[Tuple[str, ...]], parser: Optional[ArgumentParser]): - self.cmds = cmds + self.cmds = tuple(cmds) self.parser = parser + def __repr__(self) -> str: + return f"ShellCommandRule(cmds={self.cmds}, parser={self.parser})" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, ShellCommandRule) + and frozenset(self.cmds) == frozenset(other.cmds) + and self.parser is other.parser + ) + + def __hash__(self) -> int: + return hash((frozenset(self.cmds), self.parser)) + async def __call__( self, state: T_State, @@ -498,6 +570,19 @@ class RegexRule: self.regex = regex self.flags = flags + def __repr__(self) -> str: + return f"RegexRule(regex={self.regex!r}, flags={self.flags})" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, RegexRule) + and self.regex == other.regex + and self.flags == other.flags + ) + + def __hash__(self) -> int: + return hash((self.regex, self.flags)) + async def __call__( self, state: T_State, @@ -544,6 +629,15 @@ class ToMeRule: __slots__ = () + def __repr__(self) -> str: + return "ToMeRule()" + + def __eq__(self, other: object) -> bool: + return isinstance(other, ToMeRule) + + def __hash__(self) -> int: + return hash((self.__class__,)) + async def __call__(self, to_me: bool = EventToMe()) -> bool: return to_me @@ -554,6 +648,37 @@ def to_me() -> Rule: return Rule(ToMeRule()) +class IsTypeRule: + """检查事件类型是否为指定类型。""" + + __slots__ = ("types",) + + def __init__(self, *types: Type[Event]): + self.types = types + + def __repr__(self) -> str: + return f"IsTypeRule(types={tuple(type.__name__ for type in self.types)})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, IsTypeRule) and self.types == other.types + + def __hash__(self) -> int: + return hash((self.types,)) + + async def __call__(self, event: Event) -> bool: + return isinstance(event, self.types) + + +def is_type(*types: Type[Event]) -> Rule: + """匹配事件类型。 + + 参数: + types: 事件类型 + """ + + return Rule(IsTypeRule(*types)) + + __autodoc__ = { "Rule": True, "Rule.__call__": True, diff --git a/tests/plugins/plugin/__init__.py b/tests/plugins/plugin/__init__.py new file mode 100644 index 00000000..e8604447 --- /dev/null +++ b/tests/plugins/plugin/__init__.py @@ -0,0 +1 @@ +from . import matchers diff --git a/tests/plugins/plugin/matchers.py b/tests/plugins/plugin/matchers.py new file mode 100644 index 00000000..ba721d0e --- /dev/null +++ b/tests/plugins/plugin/matchers.py @@ -0,0 +1,243 @@ +from datetime import datetime, timezone + +from nonebot.adapters import Event +from nonebot import ( + CommandGroup, + MatcherGroup, + on, + on_type, + on_regex, + on_notice, + on_command, + on_keyword, + on_message, + on_request, + on_endswith, + on_fullmatch, + on_metaevent, + on_startswith, + on_shell_command, +) + + +async def rule() -> bool: + return True + + +async def permission() -> bool: + return True + + +async def handler(): + return + + +expire_time = datetime.now(timezone.utc) +priority = 100 +state = {"test": "test"} + + +matcher_on = on( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_metaevent = on_metaevent( + rule=rule, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_message = on_message( + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_notice = on_notice( + rule=rule, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_request = on_request( + rule=rule, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_startswith = on_startswith( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_endswith = on_endswith( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_fullmatch = on_fullmatch( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_keyword = on_keyword( + {"test"}, + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_command = on_command( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_shell_command = on_shell_command( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +matcher_on_regex = on_regex( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +class TestEvent(Event): + ... + + +matcher_on_type = on_type( + TestEvent, + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) + + +cmd_group = CommandGroup( + "test", + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) +matcher_sub_cmd = cmd_group.command("sub") +matcher_sub_shell_cmd = cmd_group.shell_command("sub") + + +matcher_group = MatcherGroup( + rule=rule, + permission=permission, + handlers=[handler], + temp=True, + expire_time=expire_time, + priority=priority, + block=True, + state=state, +) +matcher_group_on = matcher_group.on(type="test") +matcher_group_on_metaevent = matcher_group.on_metaevent() +matcher_group_on_message = matcher_group.on_message() +matcher_group_on_notice = matcher_group.on_notice() +matcher_group_on_request = matcher_group.on_request() +matcher_group_on_startswith = matcher_group.on_startswith("test") +matcher_group_on_endswith = matcher_group.on_endswith("test") +matcher_group_on_fullmatch = matcher_group.on_fullmatch("test") +matcher_group_on_keyword = matcher_group.on_keyword({"test"}) +matcher_group_on_command = matcher_group.on_command("test") +matcher_group_on_shell_command = matcher_group.on_shell_command("test") +matcher_group_on_regex = matcher_group.on_regex("test") +matcher_group_on_type = matcher_group.on_type(TestEvent) diff --git a/tests/test_plugin/test_on.py b/tests/test_plugin/test_on.py new file mode 100644 index 00000000..48efe4c0 --- /dev/null +++ b/tests/test_plugin/test_on.py @@ -0,0 +1,116 @@ +from typing import Type, Optional + +import pytest +from nonebug import App + + +@pytest.mark.asyncio +async def test_on(app: App, load_plugin): + import nonebot + import plugins.plugin.matchers as module + from nonebot.typing import T_RuleChecker + from nonebot.matcher import Matcher, matchers + from nonebot.rule import ( + RegexRule, + IsTypeRule, + CommandRule, + EndswithRule, + KeywordsRule, + FullmatchRule, + StartswithRule, + ShellCommandRule, + ) + from plugins.plugin.matchers import ( + TestEvent, + rule, + state, + handler, + priority, + matcher_on, + permission, + expire_time, + matcher_on_type, + matcher_sub_cmd, + matcher_group_on, + matcher_on_regex, + matcher_on_notice, + matcher_on_command, + matcher_on_keyword, + matcher_on_message, + matcher_on_request, + matcher_on_endswith, + matcher_on_fullmatch, + matcher_on_metaevent, + matcher_group_on_type, + matcher_on_startswith, + matcher_sub_shell_cmd, + matcher_group_on_regex, + matcher_group_on_notice, + matcher_group_on_command, + matcher_group_on_keyword, + matcher_group_on_message, + matcher_group_on_request, + matcher_on_shell_command, + matcher_group_on_endswith, + matcher_group_on_fullmatch, + matcher_group_on_metaevent, + matcher_group_on_startswith, + matcher_group_on_shell_command, + ) + + plugin = nonebot.get_plugin("plugin") + + def _check( + matcher: Type[Matcher], + pre_rule: Optional[T_RuleChecker], + has_permission: bool, + ): + assert {dependent.call for dependent in matcher.rule.checkers} == ( + {pre_rule, rule} if pre_rule else {rule} + ) + if has_permission: + assert {dependent.call for dependent in matcher.permission.checkers} == { + permission + } + else: + assert not matcher.permission.checkers + assert [dependent.call for dependent in matcher.handlers] == [handler] + assert matcher.temp is True + assert matcher.expire_time == expire_time + assert matcher in matchers[priority] + assert matcher.block is True + assert matcher._default_state == state + + assert matcher.plugin is plugin + assert matcher.module is module + assert matcher.plugin_name == "plugin" + assert matcher.module_name == "plugins.plugin.matchers" + + _check(matcher_on, None, True) + _check(matcher_on_metaevent, None, False) + _check(matcher_on_message, None, True) + _check(matcher_on_notice, None, False) + _check(matcher_on_request, None, False) + _check(matcher_on_startswith, StartswithRule(("test",)), True) + _check(matcher_on_endswith, EndswithRule(("test",)), True) + _check(matcher_on_fullmatch, FullmatchRule(("test",)), True) + _check(matcher_on_keyword, KeywordsRule("test"), True) + _check(matcher_on_command, CommandRule([("test",)]), True) + _check(matcher_on_shell_command, ShellCommandRule([("test",)], None), True) + _check(matcher_on_regex, RegexRule("test"), True) + _check(matcher_on_type, IsTypeRule(TestEvent), True) + _check(matcher_sub_cmd, CommandRule([("test", "sub")]), True) + _check(matcher_sub_shell_cmd, ShellCommandRule([("test", "sub")], None), True) + _check(matcher_group_on, None, True) + _check(matcher_group_on_metaevent, None, False) + _check(matcher_group_on_message, None, True) + _check(matcher_group_on_notice, None, False) + _check(matcher_group_on_request, None, False) + _check(matcher_group_on_startswith, StartswithRule(("test",)), True) + _check(matcher_group_on_endswith, EndswithRule(("test",)), True) + _check(matcher_group_on_fullmatch, FullmatchRule(("test",)), True) + _check(matcher_group_on_keyword, KeywordsRule("test"), True) + _check(matcher_group_on_command, CommandRule([("test",)]), True) + _check(matcher_group_on_shell_command, ShellCommandRule([("test",)], None), True) + _check(matcher_group_on_regex, RegexRule("test"), True) + _check(matcher_group_on_type, IsTypeRule(TestEvent), True) diff --git a/tests/test_rule.py b/tests/test_rule.py index 4aac218f..6d685592 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -145,7 +145,7 @@ async def test_fullmatch( checker = dependent.call assert isinstance(checker, FullmatchRule) - assert checker.msg == {msg} if isinstance(msg, str) else {*msg} + assert checker.msg == ((msg,) if isinstance(msg, str) else msg) assert checker.ignorecase == ignorecase message = make_fake_message()(text) @@ -196,7 +196,7 @@ async def test_command(app: App, cmds: Tuple[Tuple[str, ...]]): checker = dependent.call assert isinstance(checker, CommandRule) - assert checker.cmds == list(cmds) + assert checker.cmds == cmds for cmd in cmds: state = {PREFIX_KEY: {CMD_KEY: cmd}} @@ -318,3 +318,24 @@ async def test_to_me(app: App, expected: bool): event = make_fake_event(_to_me=expected)() assert await dependent(event=event) == expected + + +@pytest.mark.asyncio +async def test_is_type(app: App): + from nonebot.rule import IsTypeRule, is_type + + Event1 = make_fake_event() + Event2 = make_fake_event() + Event3 = make_fake_event() + + test_type = is_type(Event1, Event2) + dependent = list(test_type.checkers)[0] + checker = dependent.call + + assert isinstance(checker, IsTypeRule) + + event = Event1() + assert await dependent(event=event) + + event = Event3() + assert not await dependent(event=event)