Feature: 移除内置响应规则事件类型限制 (#1824)

This commit is contained in:
Ju4tCode 2023-03-19 15:45:32 +08:00 committed by GitHub
parent f65127e655
commit 36e99bc3ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 19 deletions

View File

@ -21,7 +21,7 @@ from typing import (
from nonebot.log import logger
from nonebot.internal.rule import Rule
from nonebot.dependencies import Dependent
from nonebot.internal.permission import USER, User, Permission
from nonebot.internal.permission import User, Permission
from nonebot.internal.adapter import (
Bot,
Event,
@ -682,12 +682,7 @@ class Matcher(metaclass=MatcherMeta):
stack=stack,
dependency_cache=dependency_cache,
)
permission = self.permission
if len(permission.checkers) == 1 and isinstance(
user_perm := tuple(permission.checkers)[0].call, User
):
permission = user_perm.perm
return USER(event.get_session_id(), perm=permission)
return Permission(User.from_event(event, perm=self.permission))
async def resolve_reject(self):
handler = current_handler.get()

View File

@ -1,4 +1,5 @@
import asyncio
from typing_extensions import Self
from contextlib import AsyncExitStack
from typing import Set, Tuple, Union, NoReturn, Optional
@ -140,6 +141,22 @@ class User:
session in self.users and (self.perm is None or await self.perm(bot, event))
)
@classmethod
def from_event(cls, event: Event, perm: Optional[Permission] = None) -> Self:
"""从事件中获取会话 ID
参数:
event: Event 对象
perm: 需同时满足的权限
"""
if (
perm
and len(perm.checkers) == 1
and isinstance(user_perm := tuple(perm.checkers)[0].call, cls)
):
perm = user_perm.perm
return cls((event.get_session_id(),), perm)
def USER(*users: str, perm: Optional[Permission] = None):
"""匹配当前事件属于指定会话

View File

@ -164,8 +164,6 @@ class StartswithRule:
return hash((frozenset(self.msg), self.ignorecase))
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
@ -221,8 +219,6 @@ class EndswithRule:
return hash((frozenset(self.msg), self.ignorecase))
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
@ -278,8 +274,6 @@ class FullmatchRule:
return hash((frozenset(self.msg), self.ignorecase))
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
@ -330,8 +324,6 @@ class KeywordsRule:
return hash(frozenset(self.keywords))
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
@ -649,8 +641,6 @@ class RegexRule:
return hash((self.regex, self.flags))
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
try:
msg = event.get_message()
except Exception:

View File

@ -133,6 +133,7 @@ async def test_trie(app: App):
("prefix", False, "message", "fooprefix", False),
("prefix", False, "message", None, False),
(("prefix", "foo"), False, "message", "fooprefix", True),
("prefix", False, "notice", "prefix", True),
("prefix", False, "notice", "foo", False),
],
)
@ -172,6 +173,7 @@ async def test_startswith(
("suffix", False, "message", "suffixfoo", False),
("suffix", False, "message", None, False),
(("suffix", "foo"), False, "message", "suffixfoo", True),
("suffix", False, "notice", "suffix", True),
("suffix", False, "notice", "foo", False),
],
)
@ -211,6 +213,7 @@ async def test_endswith(
("fullmatch", False, "message", "_fullmatch_", False),
("fullmatch", False, "message", None, False),
(("fullmatch", "foo"), False, "message", "fullmatchfoo", False),
("fullmatch", False, "notice", "fullmatch", True),
("fullmatch", False, "notice", "foo", False),
],
)
@ -245,8 +248,9 @@ async def test_fullmatch(
(("key",), "message", "_key_", True),
(("key", "foo"), "message", "_foo_", True),
(("key",), "message", None, False),
(("key",), "notice", "foo", False),
(("key",), "message", "foo", False),
(("key",), "notice", "_key_", True),
(("key",), "notice", "foo", False),
],
)
async def test_keyword(
@ -410,7 +414,8 @@ async def test_shell_command():
{"key": "key1"},
),
(r"foo", "message", None, False, None, None, None, None),
(r"foo", "notice", "foo", False, None, None, None, None),
(r"foo", "notice", "foo", True, "foo", "foo", tuple(), {}),
(r"foo", "notice", "bar", False, None, None, None, None),
],
)
async def test_regex(