diff --git a/nonebot/params.py b/nonebot/params.py index fc5990b9..9f3c60cd 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -5,7 +5,18 @@ FrontMatter: description: nonebot.params 模块 """ -from typing import Any, Dict, List, Match, Tuple, Union, Optional +from typing import ( + Any, + Dict, + List, + Match, + Tuple, + Union, + Literal, + Callable, + Optional, + overload, +) from nonebot.typing import T_State from nonebot.matcher import Matcher @@ -147,13 +158,37 @@ def RegexMatched() -> Match[str]: return Depends(_regex_matched, use_cache=False) -def _regex_str(state: T_State) -> str: - return _regex_matched(state).group() +def _regex_str( + groups: Tuple[Union[str, int], ...] +) -> Callable[[T_State], Union[str, Tuple[Union[str, Any], ...], Any]]: + def _regex_str_dependency( + state: T_State, + ) -> Union[str, Tuple[Union[str, Any], ...], Any]: + return _regex_matched(state).group(*groups) + + return _regex_str_dependency -def RegexStr() -> str: +@overload +def RegexStr(__group: Literal[0] = 0) -> str: + ... + + +@overload +def RegexStr(__group: Union[str, int]) -> Union[str, Any]: + ... + + +@overload +def RegexStr( + __group1: Union[str, int], __group2: Union[str, int], *groups: Union[str, int] +) -> Tuple[Union[str, Any], ...]: + ... + + +def RegexStr(*groups: Union[str, int]) -> Union[str, Tuple[Union[str, Any], ...], Any]: """正则匹配结果文本""" - return Depends(_regex_str, use_cache=False) + return Depends(_regex_str(groups), use_cache=False) def _regex_group(state: T_State) -> Tuple[Any, ...]: diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index 06731ada..f513ecd6 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -77,8 +77,13 @@ async def regex_matched(regex_matched: Match[str] = RegexMatched()) -> Match[str return regex_matched -async def regex_str(regex_str: str = RegexStr()) -> str: - return regex_str +async def regex_str( + entire: str = RegexStr(), + type_: str = RegexStr("type"), + second: str = RegexStr(2), + groups: Tuple[str, ...] = RegexStr(1, "arg"), +) -> Tuple[str, str, str, Tuple[str, ...]]: + return entire, type_, second, groups async def startswith(startswith: str = Startswith()) -> str: diff --git a/tests/test_param.py b/tests/test_param.py index 3bbf70ae..8a5323d2 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -361,7 +361,9 @@ async def test_state(app: App): regex_str, allow_types=[StateParam, DependParam] ) as ctx: ctx.pass_params(state=fake_state) - ctx.should_return("[cq:test,arg=value]") + ctx.should_return( + ("[cq:test,arg=value]", "test", "arg=value", ("test", "arg=value")) + ) async with app.test_dependent( regex_group, allow_types=[StateParam, DependParam]