From 2ccdc218e0568d6609e63e597fb40e2061851bbd Mon Sep 17 00:00:00 2001 From: yanyongyu Date: Mon, 10 Jan 2022 11:20:06 +0800 Subject: [PATCH] :wheelchair: improve state detect #677 --- nonebot/params.py | 31 +++++++++++++----------------- nonebot/rule.py | 5 ++--- tests/plugins/param/param_state.py | 3 +-- 3 files changed, 16 insertions(+), 23 deletions(-) diff --git a/nonebot/params.py b/nonebot/params.py index be6baaee..2563d9a9 100644 --- a/nonebot/params.py +++ b/nonebot/params.py @@ -278,27 +278,22 @@ def EventToMe() -> bool: return Depends(_event_to_me) -class StateInner: - ... - - -def State() -> T_State: - return StateInner() # type: ignore - - class StateParam(Param): @classmethod def _check_param( cls, dependent: Dependent, name: str, param: inspect.Parameter ) -> Optional["StateParam"]: - if isinstance(param.default, StateInner): - return cls(Required) + if param.default == param.empty: + if param.annotation is T_State: + return cls(Required) + elif param.annotation == param.empty and name == "state": + return cls(Required) async def _solve(self, state: T_State, **kwargs: Any) -> Any: return state -def _command(state=State()) -> Message: +def _command(state: T_State) -> Message: return state[PREFIX_KEY][CMD_KEY] @@ -306,7 +301,7 @@ def Command() -> Tuple[str, ...]: return Depends(_command, use_cache=False) -def _raw_command(state=State()) -> Message: +def _raw_command(state: T_State) -> Message: return state[PREFIX_KEY][RAW_CMD_KEY] @@ -314,7 +309,7 @@ def RawCommand() -> str: return Depends(_raw_command, use_cache=False) -def _command_arg(state=State()) -> Message: +def _command_arg(state: T_State) -> Message: return state[PREFIX_KEY][CMD_ARG_KEY] @@ -322,7 +317,7 @@ def CommandArg() -> Any: return Depends(_command_arg, use_cache=False) -def _shell_command_args(state=State()) -> Any: +def _shell_command_args(state: T_State) -> Any: return state[SHELL_ARGS] @@ -330,7 +325,7 @@ def ShellCommandArgs(): return Depends(_shell_command_args, use_cache=False) -def _shell_command_argv(state=State()) -> List[str]: +def _shell_command_argv(state: T_State) -> List[str]: return state[SHELL_ARGV] @@ -338,7 +333,7 @@ def ShellCommandArgv() -> Any: return Depends(_shell_command_argv, use_cache=False) -def _regex_matched(state=State()) -> str: +def _regex_matched(state: T_State) -> str: return state[REGEX_MATCHED] @@ -346,7 +341,7 @@ def RegexMatched() -> str: return Depends(_regex_matched, use_cache=False) -def _regex_group(state=State()): +def _regex_group(state: T_State): return state[REGEX_GROUP] @@ -354,7 +349,7 @@ def RegexGroup() -> Tuple[Any, ...]: return Depends(_regex_group, use_cache=False) -def _regex_dict(state=State()): +def _regex_dict(state: T_State): return state[REGEX_DICT] diff --git a/nonebot/rule.py b/nonebot/rule.py index b4cf827c..368dc490 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -39,7 +39,6 @@ from nonebot.consts import ( REGEX_MATCHED, ) from nonebot.params import ( - State, Command, BotParam, EventToMe, @@ -390,9 +389,9 @@ class ShellCommandRule: async def __call__( self, + state: T_State, cmd: Optional[Tuple[str, ...]] = Command(), msg: Message = EventMessage(), - state: T_State = State(), ) -> bool: if cmd in self.cmds: message = str(msg) @@ -475,9 +474,9 @@ class RegexRule: async def __call__( self, + state: T_State, type: str = EventType(), msg: Message = EventMessage(), - state: T_State = State(), ) -> bool: if type != "message": return False diff --git a/tests/plugins/param/param_state.py b/tests/plugins/param/param_state.py index beec94b8..636015fb 100644 --- a/tests/plugins/param/param_state.py +++ b/tests/plugins/param/param_state.py @@ -3,7 +3,6 @@ from typing import List, Tuple from nonebot.typing import T_State from nonebot.adapters import Message from nonebot.params import ( - State, Command, RegexDict, CommandArg, @@ -15,7 +14,7 @@ from nonebot.params import ( ) -async def state(x: T_State = State()) -> T_State: +async def state(x: T_State) -> T_State: return x