Feature: 移除内置响应规则事件类型限制 (#1824)

This commit is contained in:
Ju4tCode 2023-03-19 15:45:32 +08:00 committed by GitHub
parent f65127e655
commit 36e99bc3ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 19 deletions

View File

@ -21,7 +21,7 @@ from typing import (
from nonebot.log import logger from nonebot.log import logger
from nonebot.internal.rule import Rule from nonebot.internal.rule import Rule
from nonebot.dependencies import Dependent from nonebot.dependencies import Dependent
from nonebot.internal.permission import USER, User, Permission from nonebot.internal.permission import User, Permission
from nonebot.internal.adapter import ( from nonebot.internal.adapter import (
Bot, Bot,
Event, Event,
@ -682,12 +682,7 @@ class Matcher(metaclass=MatcherMeta):
stack=stack, stack=stack,
dependency_cache=dependency_cache, dependency_cache=dependency_cache,
) )
permission = self.permission return Permission(User.from_event(event, perm=self.permission))
if len(permission.checkers) == 1 and isinstance(
user_perm := tuple(permission.checkers)[0].call, User
):
permission = user_perm.perm
return USER(event.get_session_id(), perm=permission)
async def resolve_reject(self): async def resolve_reject(self):
handler = current_handler.get() handler = current_handler.get()

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
from typing_extensions import Self
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import Set, Tuple, Union, NoReturn, Optional from typing import Set, Tuple, Union, NoReturn, Optional
@ -140,6 +141,22 @@ class User:
session in self.users and (self.perm is None or await self.perm(bot, event)) session in self.users and (self.perm is None or await self.perm(bot, event))
) )
@classmethod
def from_event(cls, event: Event, perm: Optional[Permission] = None) -> Self:
"""从事件中获取会话 ID
参数:
event: Event 对象
perm: 需同时满足的权限
"""
if (
perm
and len(perm.checkers) == 1
and isinstance(user_perm := tuple(perm.checkers)[0].call, cls)
):
perm = user_perm.perm
return cls((event.get_session_id(),), perm)
def USER(*users: str, perm: Optional[Permission] = None): def USER(*users: str, perm: Optional[Permission] = None):
"""匹配当前事件属于指定会话 """匹配当前事件属于指定会话

View File

@ -164,8 +164,6 @@ class StartswithRule:
return hash((frozenset(self.msg), self.ignorecase)) return hash((frozenset(self.msg), self.ignorecase))
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try: try:
text = event.get_plaintext() text = event.get_plaintext()
except Exception: except Exception:
@ -221,8 +219,6 @@ class EndswithRule:
return hash((frozenset(self.msg), self.ignorecase)) return hash((frozenset(self.msg), self.ignorecase))
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try: try:
text = event.get_plaintext() text = event.get_plaintext()
except Exception: except Exception:
@ -278,8 +274,6 @@ class FullmatchRule:
return hash((frozenset(self.msg), self.ignorecase)) return hash((frozenset(self.msg), self.ignorecase))
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try: try:
text = event.get_plaintext() text = event.get_plaintext()
except Exception: except Exception:
@ -330,8 +324,6 @@ class KeywordsRule:
return hash(frozenset(self.keywords)) return hash(frozenset(self.keywords))
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try: try:
text = event.get_plaintext() text = event.get_plaintext()
except Exception: except Exception:
@ -649,8 +641,6 @@ class RegexRule:
return hash((self.regex, self.flags)) return hash((self.regex, self.flags))
async def __call__(self, event: Event, state: T_State) -> bool: async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try: try:
msg = event.get_message() msg = event.get_message()
except Exception: except Exception:

View File

@ -133,6 +133,7 @@ async def test_trie(app: App):
("prefix", False, "message", "fooprefix", False), ("prefix", False, "message", "fooprefix", False),
("prefix", False, "message", None, False), ("prefix", False, "message", None, False),
(("prefix", "foo"), False, "message", "fooprefix", True), (("prefix", "foo"), False, "message", "fooprefix", True),
("prefix", False, "notice", "prefix", True),
("prefix", False, "notice", "foo", False), ("prefix", False, "notice", "foo", False),
], ],
) )
@ -172,6 +173,7 @@ async def test_startswith(
("suffix", False, "message", "suffixfoo", False), ("suffix", False, "message", "suffixfoo", False),
("suffix", False, "message", None, False), ("suffix", False, "message", None, False),
(("suffix", "foo"), False, "message", "suffixfoo", True), (("suffix", "foo"), False, "message", "suffixfoo", True),
("suffix", False, "notice", "suffix", True),
("suffix", False, "notice", "foo", False), ("suffix", False, "notice", "foo", False),
], ],
) )
@ -211,6 +213,7 @@ async def test_endswith(
("fullmatch", False, "message", "_fullmatch_", False), ("fullmatch", False, "message", "_fullmatch_", False),
("fullmatch", False, "message", None, False), ("fullmatch", False, "message", None, False),
(("fullmatch", "foo"), False, "message", "fullmatchfoo", False), (("fullmatch", "foo"), False, "message", "fullmatchfoo", False),
("fullmatch", False, "notice", "fullmatch", True),
("fullmatch", False, "notice", "foo", False), ("fullmatch", False, "notice", "foo", False),
], ],
) )
@ -245,8 +248,9 @@ async def test_fullmatch(
(("key",), "message", "_key_", True), (("key",), "message", "_key_", True),
(("key", "foo"), "message", "_foo_", True), (("key", "foo"), "message", "_foo_", True),
(("key",), "message", None, False), (("key",), "message", None, False),
(("key",), "notice", "foo", False),
(("key",), "message", "foo", False), (("key",), "message", "foo", False),
(("key",), "notice", "_key_", True),
(("key",), "notice", "foo", False),
], ],
) )
async def test_keyword( async def test_keyword(
@ -410,7 +414,8 @@ async def test_shell_command():
{"key": "key1"}, {"key": "key1"},
), ),
(r"foo", "message", None, False, None, None, None, None), (r"foo", "message", None, False, None, None, None, None),
(r"foo", "notice", "foo", False, None, None, None, None), (r"foo", "notice", "foo", True, "foo", "foo", tuple(), {}),
(r"foo", "notice", "bar", False, None, None, None, None),
], ],
) )
async def test_regex( async def test_regex(