From 8377680fd7a1a98d2db465977ec2d91af810dac8 Mon Sep 17 00:00:00 2001 From: Akirami <66513481+A-kirami@users.noreply.github.com> Date: Wed, 12 Oct 2022 13:41:28 +0800 Subject: [PATCH] =?UTF-8?q?:sparkles:=20Feature:=20=E6=B7=BB=E5=8A=A0=20St?= =?UTF-8?q?ate=20=E5=93=8D=E5=BA=94=E5=99=A8=E8=A7=A6=E5=8F=91=E6=B6=88?= =?UTF-8?q?=E6=81=AF=E6=B3=A8=E5=85=A5=20(#1315)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/consts.py | 8 +++ nonebot/params.py | 40 +++++++++++++ nonebot/rule.py | 58 ++++++++++++------- tests/plugins/param/param_state.py | 20 +++++++ tests/test_param.py | 36 ++++++++++++ tests/test_rule.py | 33 ++++++++--- .../docs/tutorial/plugin/create-handler.md | 56 ++++++++++++++++++ 7 files changed, 223 insertions(+), 28 deletions(-) diff --git a/nonebot/consts.py b/nonebot/consts.py index 31b781d6..17637149 100644 --- a/nonebot/consts.py +++ b/nonebot/consts.py @@ -42,3 +42,11 @@ REGEX_GROUP: Literal["_matched_groups"] = "_matched_groups" """正则匹配 group 元组存储 key""" REGEX_DICT: Literal["_matched_dict"] = "_matched_dict" """正则匹配 group 字典存储 key""" +STARTSWITH_KEY: Literal["_startswith"] = "_startswith" +"""响应触发前缀 key""" +ENDSWITH_KEY: Literal["_endswith"] = "_endswith" +"""响应触发后缀 key""" +FULLMATCH_KEY: Literal["_fullmatch"] = "_fullmatch" +"""响应触发完整消息 key""" +KEYWORD_KEY: Literal["_keyword"] = "_keyword" +"""响应触发关键字 key""" diff --git a/nonebot/params.py b/nonebot/params.py index c5b8c751..4628aae9 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -29,10 +29,14 @@ from nonebot.consts import ( SHELL_ARGS, SHELL_ARGV, CMD_ARG_KEY, + KEYWORD_KEY, RAW_CMD_KEY, REGEX_GROUP, + ENDSWITH_KEY, CMD_START_KEY, + FULLMATCH_KEY, REGEX_MATCHED, + STARTSWITH_KEY, ) @@ -153,6 +157,42 @@ def RegexDict() -> Dict[str, Any]: return Depends(_regex_dict, use_cache=False) +def _startswith(state: T_State) -> str: + return state[STARTSWITH_KEY] + + +def Startswith() -> str: + """响应触发前缀""" + return Depends(_startswith, use_cache=False) + + +def _endswith(state: T_State) -> str: + return state[ENDSWITH_KEY] + + +def Endswith() -> str: + """响应触发后缀""" + return Depends(_endswith, use_cache=False) + + +def _fullmatch(state: T_State) -> str: + return state[FULLMATCH_KEY] + + +def Fullmatch() -> str: + """响应触发完整消息""" + return Depends(_fullmatch, use_cache=False) + + +def _keyword(state: T_State) -> str: + return state[KEYWORD_KEY] + + +def Keyword() -> str: + """响应触发关键字""" + return Depends(_keyword, use_cache=False) + + def Received(id: Optional[str] = None, default: Any = None) -> Any: """`receive` 事件参数""" diff --git a/nonebot/rule.py b/nonebot/rule.py index 5d1d8689..a9b1cca3 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -18,7 +18,6 @@ from argparse import ArgumentParser as ArgParser from typing import ( IO, TYPE_CHECKING, - Any, List, Type, Tuple, @@ -48,10 +47,14 @@ from nonebot.consts import ( SHELL_ARGS, SHELL_ARGV, CMD_ARG_KEY, + KEYWORD_KEY, RAW_CMD_KEY, REGEX_GROUP, + ENDSWITH_KEY, CMD_START_KEY, + FULLMATCH_KEY, REGEX_MATCHED, + STARTSWITH_KEY, ) T = TypeVar("T") @@ -136,20 +139,21 @@ class StartswithRule: def __hash__(self) -> int: return hash((frozenset(self.msg), self.ignorecase)) - async def __call__(self, event: Event) -> bool: + async def __call__(self, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False try: text = event.get_plaintext() except Exception: return False - return bool( - re.match( - f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})", - text, - re.IGNORECASE if self.ignorecase else 0, - ) - ) + if match := re.match( + f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})", + text, + re.IGNORECASE if self.ignorecase else 0, + ): + state[STARTSWITH_KEY] = match.group() + return True + return False def startswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: @@ -192,20 +196,21 @@ class EndswithRule: def __hash__(self) -> int: return hash((frozenset(self.msg), self.ignorecase)) - async def __call__(self, event: Event) -> bool: + async def __call__(self, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False try: text = event.get_plaintext() except Exception: return False - return bool( - re.search( - f"(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})$", - text, - re.IGNORECASE if self.ignorecase else 0, - ) - ) + if match := re.search( + f"(?:{'|'.join(re.escape(suffix) for suffix in self.msg)})$", + text, + re.IGNORECASE if self.ignorecase else 0, + ): + state[ENDSWITH_KEY] = match.group() + return True + return False def endswith(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: @@ -248,14 +253,20 @@ class FullmatchRule: def __hash__(self) -> int: return hash((frozenset(self.msg), self.ignorecase)) - async def __call__(self, event: Event) -> bool: + async def __call__(self, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False try: text = event.get_plaintext() except Exception: return False - return (text.casefold() if self.ignorecase else text) in self.msg + if not text: + return False + text = text.casefold() if self.ignorecase else text + if text in self.msg: + state[FULLMATCH_KEY] = text + return True + return False def fullmatch(msg: Union[str, Tuple[str, ...]], ignorecase: bool = False) -> Rule: @@ -294,14 +305,19 @@ class KeywordsRule: def __hash__(self) -> int: return hash(frozenset(self.keywords)) - async def __call__(self, event: Event) -> bool: + async def __call__(self, event: Event, state: T_State) -> bool: if event.get_type() != "message": return False try: text = event.get_plaintext() except Exception: return False - return bool(text and any(keyword in text for keyword in self.keywords)) + if not text: + return False + if key := next((k for k in self.keywords if k in text), None): + state[KEYWORD_KEY] = key + return True + return False def keyword(*keywords: str) -> Rule: diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index d9a2e21d..97f84bd4 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -4,10 +4,14 @@ from nonebot.typing import T_State from nonebot.adapters import Message from nonebot.params import ( Command, + Keyword, + Endswith, + Fullmatch, RegexDict, CommandArg, RawCommand, RegexGroup, + Startswith, CommandStart, RegexMatched, ShellCommandArgs, @@ -65,3 +69,19 @@ async def regex_group(regex_group: Tuple = RegexGroup()) -> Tuple: async def regex_matched(regex_matched: str = RegexMatched()) -> str: return regex_matched + + +async def startswith(startswith: str = Startswith()) -> str: + return startswith + + +async def endswith(endswith: str = Endswith()) -> str: + return endswith + + +async def fullmatch(fullmatch: str = Fullmatch()) -> str: + return fullmatch + + +async def keyword(keyword: str = Keyword()) -> str: + return keyword diff --git a/tests/test_param.py b/tests/test_param.py index e904ba14..05122371 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -168,15 +168,23 @@ async def test_state(app: App, load_plugin): SHELL_ARGS, SHELL_ARGV, CMD_ARG_KEY, + KEYWORD_KEY, RAW_CMD_KEY, REGEX_GROUP, + ENDSWITH_KEY, CMD_START_KEY, + FULLMATCH_KEY, REGEX_MATCHED, + STARTSWITH_KEY, ) from plugins.param.param_state import ( state, command, + keyword, + endswith, + fullmatch, regex_dict, + startswith, command_arg, raw_command, regex_group, @@ -201,6 +209,10 @@ async def test_state(app: App, load_plugin): REGEX_MATCHED: "[cq:test,arg=value]", REGEX_GROUP: ("test", "arg=value"), REGEX_DICT: {"type": "test", "arg": "value"}, + STARTSWITH_KEY: "startswith", + ENDSWITH_KEY: "endswith", + FULLMATCH_KEY: "fullmatch", + KEYWORD_KEY: "keyword", } async with app.test_dependent(state, allow_types=[StateParam]) as ctx: @@ -271,6 +283,30 @@ async def test_state(app: App, load_plugin): ctx.pass_params(state=fake_state) ctx.should_return(fake_state[REGEX_DICT]) + async with app.test_dependent( + startswith, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[STARTSWITH_KEY]) + + async with app.test_dependent( + endswith, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[ENDSWITH_KEY]) + + async with app.test_dependent( + fullmatch, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[FULLMATCH_KEY]) + + async with app.test_dependent( + keyword, allow_types=[StateParam, DependParam] + ) as ctx: + ctx.pass_params(state=fake_state) + ctx.should_return(fake_state[KEYWORD_KEY]) + @pytest.mark.asyncio async def test_matcher(app: App, load_plugin): diff --git a/tests/test_rule.py b/tests/test_rule.py index eeefc9cb..d815b91c 100644 --- a/tests/test_rule.py +++ b/tests/test_rule.py @@ -65,19 +65,24 @@ async def test_startswith( text: Optional[str], expected: bool, ): + from nonebot.consts import STARTSWITH_KEY from nonebot.rule import StartswithRule, startswith test_startswith = startswith(msg, ignorecase) dependent = list(test_startswith.checkers)[0] checker = dependent.call + msg = (msg,) if isinstance(msg, str) else msg + assert isinstance(checker, StartswithRule) - assert checker.msg == (msg,) if isinstance(msg, str) else msg + assert checker.msg == msg assert checker.ignorecase == ignorecase message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() - assert await dependent(event=event) == expected + for prefix in msg: + state = {STARTSWITH_KEY: prefix} + assert await dependent(event=event, state=state) == expected @pytest.mark.asyncio @@ -103,19 +108,24 @@ async def test_endswith( text: Optional[str], expected: bool, ): + from nonebot.consts import ENDSWITH_KEY from nonebot.rule import EndswithRule, endswith test_endswith = endswith(msg, ignorecase) dependent = list(test_endswith.checkers)[0] checker = dependent.call + msg = (msg,) if isinstance(msg, str) else msg + assert isinstance(checker, EndswithRule) - assert checker.msg == (msg,) if isinstance(msg, str) else msg + assert checker.msg == msg assert checker.ignorecase == ignorecase message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() - assert await dependent(event=event) == expected + for suffix in msg: + state = {ENDSWITH_KEY: suffix} + assert await dependent(event=event, state=state) == expected @pytest.mark.asyncio @@ -141,19 +151,24 @@ async def test_fullmatch( text: Optional[str], expected: bool, ): + from nonebot.consts import FULLMATCH_KEY from nonebot.rule import FullmatchRule, fullmatch test_fullmatch = fullmatch(msg, ignorecase) dependent = list(test_fullmatch.checkers)[0] checker = dependent.call + msg = (msg,) if isinstance(msg, str) else msg + assert isinstance(checker, FullmatchRule) - assert checker.msg == ((msg,) if isinstance(msg, str) else msg) + assert checker.msg == msg assert checker.ignorecase == ignorecase message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() - assert await dependent(event=event) == expected + for full in msg: + state = {FULLMATCH_KEY: full} + assert await dependent(event=event, state=state) == expected @pytest.mark.asyncio @@ -164,6 +179,7 @@ async def test_fullmatch( (("key", "foo"), "message", "_foo_", True), (("key",), "message", None, False), (("key",), "notice", "foo", False), + (("key",), "message", "foo", False), ], ) async def test_keyword( @@ -173,6 +189,7 @@ async def test_keyword( text: Optional[str], expected: bool, ): + from nonebot.consts import KEYWORD_KEY from nonebot.rule import KeywordsRule, keyword test_keyword = keyword(*kws) @@ -184,7 +201,9 @@ async def test_keyword( message = text if text is None else make_fake_message()(text) event = make_fake_event(_type=type, _message=message)() - assert await dependent(event=event) == expected + for kw in kws: + state = {KEYWORD_KEY: kw} + assert await dependent(event=event, state=state) == expected @pytest.mark.asyncio diff --git a/website/docs/tutorial/plugin/create-handler.md b/website/docs/tutorial/plugin/create-handler.md index c6aef475..382bfc0c 100644 --- a/website/docs/tutorial/plugin/create-handler.md +++ b/website/docs/tutorial/plugin/create-handler.md @@ -363,6 +363,62 @@ matcher = on_regex("regex") async def _(foo: Dict[str, Any] = RegexDict()): ... ``` +### Startswith + +获取触发响应器的消息前缀字符串。 + +```python {7} +from nonebot import on_startswith +from nonebot.params import Startswith + +matcher = on_startswith("prefix") + +@matcher.handle() +async def _(foo: str = Startswith()): ... +``` + +### Endswith + +获取触发响应器的消息后缀字符串。 + +```python {7} +from nonebot import on_endswith +from nonebot.params import Endswith + +matcher = on_endswith("suffix") + +@matcher.handle() +async def _(foo: str = Endswith()): ... +``` + +### Fullmatch + +获取触发响应器的消息字符串。 + +```python {7} +from nonebot import on_fullmatch +from nonebot.params import Fullmatch + +matcher = on_fullmatch("fullmatch") + +@matcher.handle() +async def _(foo: str = Fullmatch()): ... +``` + +### Keyword + +获取触发响应器的关键字字符串。 + +```python {7} +from nonebot import on_keyword +from nonebot.params import Keyword + +matcher = on_keyword({"keyword"}) + +@matcher.handle() +async def _(foo: str = Keyword()): ... +``` + ### Matcher 获取当前事件响应器实例。