🐛 Fix: 内置规则和权限没有捕获错误 (#1291)

This commit is contained in:
Ju4tCode 2022-09-29 16:56:06 +08:00 committed by GitHub
parent ab85b8651e
commit 71aad502d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 56 deletions

View File

@ -132,9 +132,12 @@ class User:
)
async def __call__(self, bot: Bot, event: Event) -> bool:
try:
session = event.get_session_id()
except Exception:
return False
return bool(
event.get_session_id() in self.users
and (self.perm is None or await self.perm(bot, event))
session in self.users and (self.perm is None or await self.perm(bot, event))
)

View File

@ -94,10 +94,14 @@ class SuperUser:
return "Superuser()"
async def __call__(self, bot: Bot, event: Event) -> bool:
try:
user_id = event.get_user_id()
except Exception:
return False
return (
f"{bot.adapter.get_name().split(maxsplit=1)[0].lower()}:{event.get_user_id()}"
f"{bot.adapter.get_name().split(maxsplit=1)[0].lower()}:{user_id}"
in bot.config.superusers
or event.get_user_id() in bot.config.superusers # 兼容旧配置
or user_id in bot.config.superusers # 兼容旧配置
)

View File

@ -39,15 +39,8 @@ from nonebot.log import logger
from nonebot.typing import T_State
from nonebot.exception import ParserExit
from nonebot.internal.rule import Rule as Rule
from nonebot.params import Command, EventToMe, CommandArg
from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.params import (
Command,
EventToMe,
EventType,
CommandArg,
EventMessage,
EventPlainText,
)
from nonebot.consts import (
CMD_KEY,
PREFIX_KEY,
@ -143,10 +136,12 @@ class StartswithRule:
def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase))
async def __call__(
self, type: str = EventType(), text: str = EventPlainText()
) -> Any:
if type != "message":
async def __call__(self, event: Event) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return bool(
re.match(
@ -197,10 +192,12 @@ class EndswithRule:
def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase))
async def __call__(
self, type: str = EventType(), text: str = EventPlainText()
) -> Any:
if type != "message":
async def __call__(self, event: Event) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return bool(
re.search(
@ -251,13 +248,14 @@ class FullmatchRule:
def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase))
async def __call__(
self, type_: str = EventType(), text: str = EventPlainText()
) -> bool:
return (
type_ == "message"
and (text.casefold() if self.ignorecase else text) in self.msg
)
async def __call__(self, event: Event) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return (text.casefold() if self.ignorecase else text) in self.msg
def fullmatch(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
@ -296,10 +294,12 @@ class KeywordsRule:
def __hash__(self) -> int:
return hash(frozenset(self.keywords))
async def __call__(
self, type: str = EventType(), text: str = EventPlainText()
) -> bool:
if type != "message":
async def __call__(self, event: Event) -> bool:
if event.get_type() != "message":
return False
try:
text = event.get_plaintext()
except Exception:
return False
return bool(text and any(keyword in text for keyword in self.keywords))
@ -583,16 +583,14 @@ class RegexRule:
def __hash__(self) -> int:
return hash((self.regex, self.flags))
async def __call__(
self,
state: T_State,
type: str = EventType(),
msg: Message = EventMessage(),
) -> bool:
if type != "message":
async def __call__(self, event: Event, state: T_State) -> bool:
if event.get_type() != "message":
return False
matched = re.search(self.regex, str(msg), self.flags)
if matched:
try:
msg = event.get_message()
except Exception:
return False
if matched := re.search(self.regex, str(msg), self.flags):
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict()

View File

@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, Optional
import pytest
from nonebug import App
@ -144,6 +144,7 @@ async def test_metaevent(
("message", "test", True),
("message", "foo", False),
("message", "faketest", True),
("message", None, False),
("notice", "test", True),
],
)
@ -173,10 +174,11 @@ async def test_superuser(
[
(("user", "foo"), "user", True),
(("user", "foo"), "bar", False),
(("user", "foo"), None, False),
],
)
async def test_user(
app: App, session_ids: Tuple[str, ...], session_id: str, expected: bool
app: App, session_ids: Tuple[str, ...], session_id: Optional[str], expected: bool
):
from nonebot.permission import USER, User

View File

@ -1,5 +1,5 @@
import sys
from typing import Tuple, Union
from typing import Dict, Tuple, Union, Optional
import pytest
from nonebug import App
@ -52,6 +52,7 @@ async def test_rule(app: App):
("prefix", True, "message", "Prefix_", True),
("prefix", False, "message", "prefoo", False),
("prefix", False, "message", "fooprefix", False),
("prefix", False, "message", None, False),
(("prefix", "foo"), False, "message", "fooprefix", True),
("prefix", False, "notice", "foo", False),
],
@ -61,7 +62,7 @@ async def test_startswith(
msg: Union[str, Tuple[str, ...]],
ignorecase: bool,
type: str,
text: str,
text: Optional[str],
expected: bool,
):
from nonebot.rule import StartswithRule, startswith
@ -74,7 +75,7 @@ async def test_startswith(
assert checker.msg == (msg,) if isinstance(msg, str) else msg
assert checker.ignorecase == ignorecase
message = make_fake_message()(text)
message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected
@ -89,6 +90,7 @@ async def test_startswith(
("suffix", True, "message", "_Suffix", True),
("suffix", False, "message", "suffoo", False),
("suffix", False, "message", "suffixfoo", False),
("suffix", False, "message", None, False),
(("suffix", "foo"), False, "message", "suffixfoo", True),
("suffix", False, "notice", "foo", False),
],
@ -98,7 +100,7 @@ async def test_endswith(
msg: Union[str, Tuple[str, ...]],
ignorecase: bool,
type: str,
text: str,
text: Optional[str],
expected: bool,
):
from nonebot.rule import EndswithRule, endswith
@ -111,7 +113,7 @@ async def test_endswith(
assert checker.msg == (msg,) if isinstance(msg, str) else msg
assert checker.ignorecase == ignorecase
message = make_fake_message()(text)
message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected
@ -126,6 +128,7 @@ async def test_endswith(
("fullmatch", True, "message", "Fullmatch", True),
("fullmatch", False, "message", "fullfoo", False),
("fullmatch", False, "message", "_fullmatch_", False),
("fullmatch", False, "message", None, False),
(("fullmatch", "foo"), False, "message", "fullmatchfoo", False),
("fullmatch", False, "notice", "foo", False),
],
@ -135,7 +138,7 @@ async def test_fullmatch(
msg: Union[str, Tuple[str, ...]],
ignorecase: bool,
type: str,
text: str,
text: Optional[str],
expected: bool,
):
from nonebot.rule import FullmatchRule, fullmatch
@ -148,7 +151,7 @@ async def test_fullmatch(
assert checker.msg == ((msg,) if isinstance(msg, str) else msg)
assert checker.ignorecase == ignorecase
message = make_fake_message()(text)
message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected
@ -159,6 +162,7 @@ async def test_fullmatch(
[
(("key",), "message", "_key_", True),
(("key", "foo"), "message", "_foo_", True),
(("key",), "message", None, False),
(("key",), "notice", "foo", False),
],
)
@ -166,7 +170,7 @@ async def test_keyword(
app: App,
kws: Tuple[str, ...],
type: str,
text: str,
text: Optional[str],
expected: bool,
):
from nonebot.rule import KeywordsRule, keyword
@ -178,7 +182,7 @@ async def test_keyword(
assert isinstance(checker, KeywordsRule)
assert checker.keywords == kws
message = make_fake_message()(text)
message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected
@ -302,7 +306,51 @@ async def test_shell_command(app: App):
assert state[SHELL_ARGS].status != 0
# TODO: regex
@pytest.mark.asyncio
@pytest.mark.parametrize(
"pattern,type,text,expected,matched,group,dict",
[
(
r"(?P<key>key\d)",
"message",
"_key1_",
True,
"key1",
("key1",),
{"key": "key1"},
),
(r"foo", "message", None, False, None, None, None),
(r"foo", "notice", "foo", False, None, None, None),
],
)
async def test_regex(
app: App,
pattern: str,
type: str,
text: Optional[str],
expected: bool,
matched: Optional[str],
group: Optional[Tuple[str, ...]],
dict: Optional[Dict[str, str]],
):
from nonebot.typing import T_State
from nonebot.rule import RegexRule, regex
from nonebot.consts import REGEX_DICT, REGEX_GROUP, REGEX_MATCHED
test_regex = regex(pattern)
dependent = list(test_regex.checkers)[0]
checker = dependent.call
assert isinstance(checker, RegexRule)
assert checker.regex == pattern
message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)()
state = {}
assert await dependent(event=event, state=state) == expected
assert state.get(REGEX_MATCHED) == matched
assert state.get(REGEX_GROUP) == group
assert state.get(REGEX_DICT) == dict
@pytest.mark.asyncio
@ -310,8 +358,8 @@ async def test_shell_command(app: App):
async def test_to_me(app: App, expected: bool):
from nonebot.rule import ToMeRule, to_me
test_keyword = to_me()
dependent = list(test_keyword.checkers)[0]
test_to_me = to_me()
dependent = list(test_to_me.checkers)[0]
checker = dependent.call
assert isinstance(checker, ToMeRule)

View File

@ -65,7 +65,7 @@ def make_fake_event(
_type: str = "message",
_name: str = "test",
_description: str = "test",
_user_id: str = "test",
_user_id: Optional[str] = "test",
_session_id: Optional[str] = "test",
_message: Optional["Message"] = None,
_to_me: bool = True,
@ -86,7 +86,9 @@ def make_fake_event(
return _description
def get_user_id(self) -> str:
return _user_id
if _user_id is not None:
return _user_id
raise NotImplementedError
def get_session_id(self) -> str:
if _session_id is not None: