🐛 Fix: 内置规则和权限没有捕获错误 (#1291)

This commit is contained in:
Ju4tCode 2022-09-29 16:56:06 +08:00 committed by GitHub
parent ab85b8651e
commit 71aad502d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 56 deletions

View File

@ -132,9 +132,12 @@ class User:
) )
async def __call__(self, bot: Bot, event: Event) -> bool: async def __call__(self, bot: Bot, event: Event) -> bool:
try:
session = event.get_session_id()
except Exception:
return False
return bool( return bool(
event.get_session_id() in self.users session in self.users and (self.perm is None or await self.perm(bot, event))
and (self.perm is None or await self.perm(bot, event))
) )

View File

@ -94,10 +94,14 @@ class SuperUser:
return "Superuser()" return "Superuser()"
async def __call__(self, bot: Bot, event: Event) -> bool: async def __call__(self, bot: Bot, event: Event) -> bool:
try:
user_id = event.get_user_id()
except Exception:
return False
return ( return (
f"{bot.adapter.get_name().split(maxsplit=1)[0].lower()}:{event.get_user_id()}" f"{bot.adapter.get_name().split(maxsplit=1)[0].lower()}:{user_id}"
in bot.config.superusers in bot.config.superusers
or event.get_user_id() in bot.config.superusers # 兼容旧配置 or user_id in bot.config.superusers # 兼容旧配置
) )

View File

@ -39,15 +39,8 @@ from nonebot.log import logger
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.exception import ParserExit from nonebot.exception import ParserExit
from nonebot.internal.rule import Rule as Rule from nonebot.internal.rule import Rule as Rule
from nonebot.params import Command, EventToMe, CommandArg
from nonebot.adapters import Bot, Event, Message, MessageSegment from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.params import (
Command,
EventToMe,
EventType,
CommandArg,
EventMessage,
EventPlainText,
)
from nonebot.consts import ( from nonebot.consts import (
CMD_KEY, CMD_KEY,
PREFIX_KEY, PREFIX_KEY,
@ -143,10 +136,12 @@ class StartswithRule:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase)) return hash((frozenset(self.msg), self.ignorecase))
async def __call__( async def __call__(self, event: Event) -> bool:
self, type: str = EventType(), text: str = EventPlainText() if event.get_type() != "message":
) -> Any: return False
if type != "message": try:
text = event.get_plaintext()
except Exception:
return False return False
return bool( return bool(
re.match( re.match(
@ -197,10 +192,12 @@ class EndswithRule:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase)) return hash((frozenset(self.msg), self.ignorecase))
async def __call__( async def __call__(self, event: Event) -> bool:
self, type: str = EventType(), text: str = EventPlainText() if event.get_type() != "message":
) -> Any: return False
if type != "message": try:
text = event.get_plaintext()
except Exception:
return False return False
return bool( return bool(
re.search( re.search(
@ -251,13 +248,14 @@ class FullmatchRule:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((frozenset(self.msg), self.ignorecase)) return hash((frozenset(self.msg), self.ignorecase))
async def __call__( async def __call__(self, event: Event) -> bool:
self, type_: str = EventType(), text: str = EventPlainText() if event.get_type() != "message":
) -> bool: return False
return ( try:
type_ == "message" text = event.get_plaintext()
and (text.casefold() if self.ignorecase else text) in self.msg except Exception:
) return False
return (text.casefold() if self.ignorecase else text) in self.msg
def fullmatch(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: def fullmatch(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule:
@ -296,10 +294,12 @@ class KeywordsRule:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(frozenset(self.keywords)) return hash(frozenset(self.keywords))
async def __call__( async def __call__(self, event: Event) -> bool:
self, type: str = EventType(), text: str = EventPlainText() if event.get_type() != "message":
) -> bool: return False
if type != "message": try:
text = event.get_plaintext()
except Exception:
return False return False
return bool(text and any(keyword in text for keyword in self.keywords)) return bool(text and any(keyword in text for keyword in self.keywords))
@ -583,16 +583,14 @@ class RegexRule:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((self.regex, self.flags)) return hash((self.regex, self.flags))
async def __call__( async def __call__(self, event: Event, state: T_State) -> bool:
self, if event.get_type() != "message":
state: T_State,
type: str = EventType(),
msg: Message = EventMessage(),
) -> bool:
if type != "message":
return False return False
matched = re.search(self.regex, str(msg), self.flags) try:
if matched: msg = event.get_message()
except Exception:
return False
if matched := re.search(self.regex, str(msg), self.flags):
state[REGEX_MATCHED] = matched.group() state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups() state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict() state[REGEX_DICT] = matched.groupdict()

View File

@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple, Optional
import pytest import pytest
from nonebug import App from nonebug import App
@ -144,6 +144,7 @@ async def test_metaevent(
("message", "test", True), ("message", "test", True),
("message", "foo", False), ("message", "foo", False),
("message", "faketest", True), ("message", "faketest", True),
("message", None, False),
("notice", "test", True), ("notice", "test", True),
], ],
) )
@ -173,10 +174,11 @@ async def test_superuser(
[ [
(("user", "foo"), "user", True), (("user", "foo"), "user", True),
(("user", "foo"), "bar", False), (("user", "foo"), "bar", False),
(("user", "foo"), None, False),
], ],
) )
async def test_user( async def test_user(
app: App, session_ids: Tuple[str, ...], session_id: str, expected: bool app: App, session_ids: Tuple[str, ...], session_id: Optional[str], expected: bool
): ):
from nonebot.permission import USER, User from nonebot.permission import USER, User

View File

@ -1,5 +1,5 @@
import sys import sys
from typing import Tuple, Union from typing import Dict, Tuple, Union, Optional
import pytest import pytest
from nonebug import App from nonebug import App
@ -52,6 +52,7 @@ async def test_rule(app: App):
("prefix", True, "message", "Prefix_", True), ("prefix", True, "message", "Prefix_", True),
("prefix", False, "message", "prefoo", False), ("prefix", False, "message", "prefoo", False),
("prefix", False, "message", "fooprefix", False), ("prefix", False, "message", "fooprefix", False),
("prefix", False, "message", None, False),
(("prefix", "foo"), False, "message", "fooprefix", True), (("prefix", "foo"), False, "message", "fooprefix", True),
("prefix", False, "notice", "foo", False), ("prefix", False, "notice", "foo", False),
], ],
@ -61,7 +62,7 @@ async def test_startswith(
msg: Union[str, Tuple[str, ...]], msg: Union[str, Tuple[str, ...]],
ignorecase: bool, ignorecase: bool,
type: str, type: str,
text: str, text: Optional[str],
expected: bool, expected: bool,
): ):
from nonebot.rule import StartswithRule, startswith from nonebot.rule import StartswithRule, startswith
@ -74,7 +75,7 @@ async def test_startswith(
assert checker.msg == (msg,) if isinstance(msg, str) else msg assert checker.msg == (msg,) if isinstance(msg, str) else msg
assert checker.ignorecase == ignorecase assert checker.ignorecase == ignorecase
message = make_fake_message()(text) message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@ -89,6 +90,7 @@ async def test_startswith(
("suffix", True, "message", "_Suffix", True), ("suffix", True, "message", "_Suffix", True),
("suffix", False, "message", "suffoo", False), ("suffix", False, "message", "suffoo", False),
("suffix", False, "message", "suffixfoo", False), ("suffix", False, "message", "suffixfoo", False),
("suffix", False, "message", None, False),
(("suffix", "foo"), False, "message", "suffixfoo", True), (("suffix", "foo"), False, "message", "suffixfoo", True),
("suffix", False, "notice", "foo", False), ("suffix", False, "notice", "foo", False),
], ],
@ -98,7 +100,7 @@ async def test_endswith(
msg: Union[str, Tuple[str, ...]], msg: Union[str, Tuple[str, ...]],
ignorecase: bool, ignorecase: bool,
type: str, type: str,
text: str, text: Optional[str],
expected: bool, expected: bool,
): ):
from nonebot.rule import EndswithRule, endswith from nonebot.rule import EndswithRule, endswith
@ -111,7 +113,7 @@ async def test_endswith(
assert checker.msg == (msg,) if isinstance(msg, str) else msg assert checker.msg == (msg,) if isinstance(msg, str) else msg
assert checker.ignorecase == ignorecase assert checker.ignorecase == ignorecase
message = make_fake_message()(text) message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@ -126,6 +128,7 @@ async def test_endswith(
("fullmatch", True, "message", "Fullmatch", True), ("fullmatch", True, "message", "Fullmatch", True),
("fullmatch", False, "message", "fullfoo", False), ("fullmatch", False, "message", "fullfoo", False),
("fullmatch", False, "message", "_fullmatch_", False), ("fullmatch", False, "message", "_fullmatch_", False),
("fullmatch", False, "message", None, False),
(("fullmatch", "foo"), False, "message", "fullmatchfoo", False), (("fullmatch", "foo"), False, "message", "fullmatchfoo", False),
("fullmatch", False, "notice", "foo", False), ("fullmatch", False, "notice", "foo", False),
], ],
@ -135,7 +138,7 @@ async def test_fullmatch(
msg: Union[str, Tuple[str, ...]], msg: Union[str, Tuple[str, ...]],
ignorecase: bool, ignorecase: bool,
type: str, type: str,
text: str, text: Optional[str],
expected: bool, expected: bool,
): ):
from nonebot.rule import FullmatchRule, fullmatch from nonebot.rule import FullmatchRule, fullmatch
@ -148,7 +151,7 @@ async def test_fullmatch(
assert checker.msg == ((msg,) if isinstance(msg, str) else msg) assert checker.msg == ((msg,) if isinstance(msg, str) else msg)
assert checker.ignorecase == ignorecase assert checker.ignorecase == ignorecase
message = make_fake_message()(text) message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@ -159,6 +162,7 @@ async def test_fullmatch(
[ [
(("key",), "message", "_key_", True), (("key",), "message", "_key_", True),
(("key", "foo"), "message", "_foo_", True), (("key", "foo"), "message", "_foo_", True),
(("key",), "message", None, False),
(("key",), "notice", "foo", False), (("key",), "notice", "foo", False),
], ],
) )
@ -166,7 +170,7 @@ async def test_keyword(
app: App, app: App,
kws: Tuple[str, ...], kws: Tuple[str, ...],
type: str, type: str,
text: str, text: Optional[str],
expected: bool, expected: bool,
): ):
from nonebot.rule import KeywordsRule, keyword from nonebot.rule import KeywordsRule, keyword
@ -178,7 +182,7 @@ async def test_keyword(
assert isinstance(checker, KeywordsRule) assert isinstance(checker, KeywordsRule)
assert checker.keywords == kws assert checker.keywords == kws
message = make_fake_message()(text) message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)() event = make_fake_event(_type=type, _message=message)()
assert await dependent(event=event) == expected assert await dependent(event=event) == expected
@ -302,7 +306,51 @@ async def test_shell_command(app: App):
assert state[SHELL_ARGS].status != 0 assert state[SHELL_ARGS].status != 0
# TODO: regex @pytest.mark.asyncio
@pytest.mark.parametrize(
"pattern,type,text,expected,matched,group,dict",
[
(
r"(?P<key>key\d)",
"message",
"_key1_",
True,
"key1",
("key1",),
{"key": "key1"},
),
(r"foo", "message", None, False, None, None, None),
(r"foo", "notice", "foo", False, None, None, None),
],
)
async def test_regex(
app: App,
pattern: str,
type: str,
text: Optional[str],
expected: bool,
matched: Optional[str],
group: Optional[Tuple[str, ...]],
dict: Optional[Dict[str, str]],
):
from nonebot.typing import T_State
from nonebot.rule import RegexRule, regex
from nonebot.consts import REGEX_DICT, REGEX_GROUP, REGEX_MATCHED
test_regex = regex(pattern)
dependent = list(test_regex.checkers)[0]
checker = dependent.call
assert isinstance(checker, RegexRule)
assert checker.regex == pattern
message = text if text is None else make_fake_message()(text)
event = make_fake_event(_type=type, _message=message)()
state = {}
assert await dependent(event=event, state=state) == expected
assert state.get(REGEX_MATCHED) == matched
assert state.get(REGEX_GROUP) == group
assert state.get(REGEX_DICT) == dict
@pytest.mark.asyncio @pytest.mark.asyncio
@ -310,8 +358,8 @@ async def test_shell_command(app: App):
async def test_to_me(app: App, expected: bool): async def test_to_me(app: App, expected: bool):
from nonebot.rule import ToMeRule, to_me from nonebot.rule import ToMeRule, to_me
test_keyword = to_me() test_to_me = to_me()
dependent = list(test_keyword.checkers)[0] dependent = list(test_to_me.checkers)[0]
checker = dependent.call checker = dependent.call
assert isinstance(checker, ToMeRule) assert isinstance(checker, ToMeRule)

View File

@ -65,7 +65,7 @@ def make_fake_event(
_type: str = "message", _type: str = "message",
_name: str = "test", _name: str = "test",
_description: str = "test", _description: str = "test",
_user_id: str = "test", _user_id: Optional[str] = "test",
_session_id: Optional[str] = "test", _session_id: Optional[str] = "test",
_message: Optional["Message"] = None, _message: Optional["Message"] = None,
_to_me: bool = True, _to_me: bool = True,
@ -86,7 +86,9 @@ def make_fake_event(
return _description return _description
def get_user_id(self) -> str: def get_user_id(self) -> str:
if _user_id is not None:
return _user_id return _user_id
raise NotImplementedError
def get_session_id(self) -> str: def get_session_id(self) -> str:
if _session_id is not None: if _session_id is not None: