From 71aad502d151ab22f94ff99a29e7b7fedac5fbc7 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Thu, 29 Sep 2022 16:56:06 +0800 Subject: [PATCH] =?UTF-8?q?:bug:=20Fix:=20=E5=86=85=E7=BD=AE=E8=A7=84?= =?UTF-8?q?=E5=88=99=E5=92=8C=E6=9D=83=E9=99=90=E6=B2=A1=E6=9C=89=E6=8D=95?= =?UTF-8?q?=E8=8E=B7=E9=94=99=E8=AF=AF=20(#1291)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/internal/permission.py | 7 +++- nonebot/permission.py | 8 +++- nonebot/rule.py | 70 ++++++++++++++++----------------- tests/test_permission.py | 6 ++- tests/test_rule.py | 72 ++++++++++++++++++++++++++++------ tests/utils.py | 6 ++- 6 files changed, 113 insertions(+), 56 deletions(-) diff --git a/nonebot/internal/permission.py b/nonebot/internal/permission.py index d844ffdd..f1be4b37 100644 --- a/nonebot/internal/permission.py +++ b/nonebot/internal/permission.py @@ -132,9 +132,12 @@ class User: ) async def __call__(self, bot: Bot, event: Event) -> bool: + try: + session = event.get_session_id() + except Exception: + return False return bool( - event.get_session_id() in self.users - and (self.perm is None or await self.perm(bot, event)) + session in self.users and (self.perm is None or await self.perm(bot, event)) ) diff --git a/nonebot/permission.py b/nonebot/permission.py index 49226f6f..9c4da085 100644 --- a/nonebot/permission.py +++ b/nonebot/permission.py @@ -94,10 +94,14 @@ class SuperUser: return "Superuser()" async def __call__(self, bot: Bot, event: Event) -> bool: + try: + user_id = event.get_user_id() + except Exception: + return False return ( - f"{bot.adapter.get_name().split(maxsplit=1)[0].lower()}:{event.get_user_id()}" + f"{bot.adapter.get_name().split(maxsplit=1)[0].lower()}:{user_id}" in bot.config.superusers - or event.get_user_id() in bot.config.superusers # 兼容旧配置 + or user_id in bot.config.superusers # 兼容旧配置 ) diff --git a/nonebot/rule.py b/nonebot/rule.py index febb455a..5d1d8689 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -39,15 +39,8 @@ from nonebot.log import logger from nonebot.typing import T_State from nonebot.exception import ParserExit from nonebot.internal.rule import Rule as Rule +from nonebot.params import Command, EventToMe, CommandArg from nonebot.adapters import Bot, Event, Message, MessageSegment -from nonebot.params import ( - Command, - EventToMe, - EventType, - CommandArg, - EventMessage, - EventPlainText, -) from nonebot.consts import ( CMD_KEY, PREFIX_KEY, @@ -143,10 +136,12 @@ class StartswithRule: def __hash__(self) -> int: return hash((frozenset(self.msg), self.ignorecase)) - async def __call__( - self, type: str = EventType(), text: str = EventPlainText() - ) -> Any: - if type != "message": + async def __call__(self, event: Event) -> bool: + if event.get_type() != "message": + return False + try: + text = event.get_plaintext() + except Exception: return False return bool( re.match( @@ -197,10 +192,12 @@ class EndswithRule: def __hash__(self) -> int: return hash((frozenset(self.msg), self.ignorecase)) - async def __call__( - self, type: str = EventType(), text: str = EventPlainText() - ) -> Any: - if type != "message": + async def __call__(self, event: Event) -> bool: + if event.get_type() != "message": + return False + try: + text = event.get_plaintext() + except Exception: return False return bool( re.search( @@ -251,13 +248,14 @@ class FullmatchRule: def __hash__(self) -> int: return hash((frozenset(self.msg), self.ignorecase)) - async def __call__( - self, type_: str = EventType(), text: str = EventPlainText() - ) -> bool: - return ( - type_ == "message" - and (text.casefold() if self.ignorecase else text) in self.msg - ) + async def __call__(self, event: Event) -> bool: + if event.get_type() != "message": + return False + try: + text = event.get_plaintext() + except Exception: + return False + return (text.casefold() if self.ignorecase else text) in self.msg def fullmatch(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: @@ -296,10 +294,12 @@ class KeywordsRule: def __hash__(self) -> int: return hash(frozenset(self.keywords)) - async def __call__( - self, type: str = EventType(), text: str = EventPlainText() - ) -> bool: - if type != "message": + async def __call__(self, event: Event) -> bool: + if event.get_type() != "message": + return False + try: + text = event.get_plaintext() + except Exception: return False return bool(text and any(keyword in text for keyword in self.keywords)) @@ -583,16 +583,14 @@ class RegexRule: def __hash__(self) -> int: return hash((self.regex, self.flags)) - async def __call__( - self, - state: T_State, - type: str = EventType(), - msg: Message = EventMessage(), - ) -> bool: - if type != "message": + async def __call__(self, event: Event, state: T_State) -> bool: + if event.get_type() != "message": return False - matched = re.search(self.regex, str(msg), self.flags) - if matched: + try: + msg = event.get_message() + except Exception: + return False + if matched := re.search(self.regex, str(msg), self.flags): state[REGEX_MATCHED] = matched.group() state[REGEX_GROUP] = matched.groups() state[REGEX_DICT] = matched.groupdict() diff --git a/tests/test_permission.py b/tests/test_permission.py index bea16e33..5d988977 100644 --- a/tests/test_permission.py +++ b/tests/test_permission.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional import pytest from nonebug import App @@ -144,6 +144,7 @@ async def test_metaevent( ("message", "test", True), ("message", "foo", False), ("message", "faketest", True), + ("message", None, False), ("notice", "test", True), ], ) @@ -173,10 +174,11 @@ async def test_superuser( [ (("user", "foo"), "user", True), (("user", "foo"), "bar", False), + (("user", "foo"), None, False), ], ) async def test_user( - app: App, session_ids: Tuple[str, ...], session_id: str, expected: bool + app: App, session_ids: Tuple[str, ...], session_id: Optional[str], expected: bool ): from nonebot.permission import USER, User diff --git a/tests/test_rule.py b/tests/test_rule.py index 6d685592..eeefc9cb 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -1,5 +1,5 @@ import sys -from typing import Tuple, Union +from typing import Dict, Tuple, Union, Optional import pytest from nonebug import App @@ -52,6 +52,7 @@ async def test_rule(app: App): ("prefix", True, "message", "Prefix_", True), ("prefix", False, "message", "prefoo", False), ("prefix", False, "message", "fooprefix", False), + ("prefix", False, "message", None, False), (("prefix", "foo"), False, "message", "fooprefix", True), ("prefix", False, "notice", "foo", False), ], @@ -61,7 +62,7 @@ async def test_startswith( msg: Union[str, Tuple[str, ...]], ignorecase: bool, type: str, - text: str, + text: Optional[str], expected: bool, ): from nonebot.rule import StartswithRule, startswith @@ -74,7 +75,7 @@ async def test_startswith( assert checker.msg == (msg,) if isinstance(msg, str) else msg assert checker.ignorecase == ignorecase - message = make_fake_message()(text) + message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() assert await dependent(event=event) == expected @@ -89,6 +90,7 @@ async def test_startswith( ("suffix", True, "message", "_Suffix", True), ("suffix", False, "message", "suffoo", False), ("suffix", False, "message", "suffixfoo", False), + ("suffix", False, "message", None, False), (("suffix", "foo"), False, "message", "suffixfoo", True), ("suffix", False, "notice", "foo", False), ], @@ -98,7 +100,7 @@ async def test_endswith( msg: Union[str, Tuple[str, ...]], ignorecase: bool, type: str, - text: str, + text: Optional[str], expected: bool, ): from nonebot.rule import EndswithRule, endswith @@ -111,7 +113,7 @@ async def test_endswith( assert checker.msg == (msg,) if isinstance(msg, str) else msg assert checker.ignorecase == ignorecase - message = make_fake_message()(text) + message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() assert await dependent(event=event) == expected @@ -126,6 +128,7 @@ async def test_endswith( ("fullmatch", True, "message", "Fullmatch", True), ("fullmatch", False, "message", "fullfoo", False), ("fullmatch", False, "message", "_fullmatch_", False), + ("fullmatch", False, "message", None, False), (("fullmatch", "foo"), False, "message", "fullmatchfoo", False), ("fullmatch", False, "notice", "foo", False), ], @@ -135,7 +138,7 @@ async def test_fullmatch( msg: Union[str, Tuple[str, ...]], ignorecase: bool, type: str, - text: str, + text: Optional[str], expected: bool, ): from nonebot.rule import FullmatchRule, fullmatch @@ -148,7 +151,7 @@ async def test_fullmatch( assert checker.msg == ((msg,) if isinstance(msg, str) else msg) assert checker.ignorecase == ignorecase - message = make_fake_message()(text) + message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() assert await dependent(event=event) == expected @@ -159,6 +162,7 @@ async def test_fullmatch( [ (("key",), "message", "_key_", True), (("key", "foo"), "message", "_foo_", True), + (("key",), "message", None, False), (("key",), "notice", "foo", False), ], ) @@ -166,7 +170,7 @@ async def test_keyword( app: App, kws: Tuple[str, ...], type: str, - text: str, + text: Optional[str], expected: bool, ): from nonebot.rule import KeywordsRule, keyword @@ -178,7 +182,7 @@ async def test_keyword( assert isinstance(checker, KeywordsRule) assert checker.keywords == kws - message = make_fake_message()(text) + message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() assert await dependent(event=event) == expected @@ -302,7 +306,51 @@ async def test_shell_command(app: App): assert state[SHELL_ARGS].status != 0 -# TODO: regex +@pytest.mark.asyncio +@pytest.mark.parametrize( + "pattern,type,text,expected,matched,group,dict", + [ + ( + r"(?Pkey\d)", + "message", + "_key1_", + True, + "key1", + ("key1",), + {"key": "key1"}, + ), + (r"foo", "message", None, False, None, None, None), + (r"foo", "notice", "foo", False, None, None, None), + ], +) +async def test_regex( + app: App, + pattern: str, + type: str, + text: Optional[str], + expected: bool, + matched: Optional[str], + group: Optional[Tuple[str, ...]], + dict: Optional[Dict[str, str]], +): + from nonebot.typing import T_State + from nonebot.rule import RegexRule, regex + from nonebot.consts import REGEX_DICT, REGEX_GROUP, REGEX_MATCHED + + test_regex = regex(pattern) + dependent = list(test_regex.checkers)[0] + checker = dependent.call + + assert isinstance(checker, RegexRule) + assert checker.regex == pattern + + message = text if text is None else make_fake_message()(text) + event = make_fake_event(_type=type, _message=message)() + state = {} + assert await dependent(event=event, state=state) == expected + assert state.get(REGEX_MATCHED) == matched + assert state.get(REGEX_GROUP) == group + assert state.get(REGEX_DICT) == dict @pytest.mark.asyncio @@ -310,8 +358,8 @@ async def test_shell_command(app: App): async def test_to_me(app: App, expected: bool): from nonebot.rule import ToMeRule, to_me - test_keyword = to_me() - dependent = list(test_keyword.checkers)[0] + test_to_me = to_me() + dependent = list(test_to_me.checkers)[0] checker = dependent.call assert isinstance(checker, ToMeRule) diff --git a/tests/utils.py b/tests/utils.py index d0715e0e..46f0ade0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -65,7 +65,7 @@ def make_fake_event( _type: str = "message", _name: str = "test", _description: str = "test", - _user_id: str = "test", + _user_id: Optional[str] = "test", _session_id: Optional[str] = "test", _message: Optional["Message"] = None, _to_me: bool = True, @@ -86,7 +86,9 @@ def make_fake_event( return _description def get_user_id(self) -> str: - return _user_id + if _user_id is not None: + return _user_id + raise NotImplementedError def get_session_id(self) -> str: if _session_id is not None: