improve state detect #677

This commit is contained in:
yanyongyu 2022-01-10 11:20:06 +08:00
parent a14cfc8d77
commit 2ccdc218e0
3 changed files with 16 additions and 23 deletions

View File

@ -278,27 +278,22 @@ def EventToMe() -> bool:
return Depends(_event_to_me) return Depends(_event_to_me)
class StateInner:
...
def State() -> T_State:
return StateInner() # type: ignore
class StateParam(Param): class StateParam(Param):
@classmethod @classmethod
def _check_param( def _check_param(
cls, dependent: Dependent, name: str, param: inspect.Parameter cls, dependent: Dependent, name: str, param: inspect.Parameter
) -> Optional["StateParam"]: ) -> Optional["StateParam"]:
if isinstance(param.default, StateInner): if param.default == param.empty:
if param.annotation is T_State:
return cls(Required)
elif param.annotation == param.empty and name == "state":
return cls(Required) return cls(Required)
async def _solve(self, state: T_State, **kwargs: Any) -> Any: async def _solve(self, state: T_State, **kwargs: Any) -> Any:
return state return state
def _command(state=State()) -> Message: def _command(state: T_State) -> Message:
return state[PREFIX_KEY][CMD_KEY] return state[PREFIX_KEY][CMD_KEY]
@ -306,7 +301,7 @@ def Command() -> Tuple[str, ...]:
return Depends(_command, use_cache=False) return Depends(_command, use_cache=False)
def _raw_command(state=State()) -> Message: def _raw_command(state: T_State) -> Message:
return state[PREFIX_KEY][RAW_CMD_KEY] return state[PREFIX_KEY][RAW_CMD_KEY]
@ -314,7 +309,7 @@ def RawCommand() -> str:
return Depends(_raw_command, use_cache=False) return Depends(_raw_command, use_cache=False)
def _command_arg(state=State()) -> Message: def _command_arg(state: T_State) -> Message:
return state[PREFIX_KEY][CMD_ARG_KEY] return state[PREFIX_KEY][CMD_ARG_KEY]
@ -322,7 +317,7 @@ def CommandArg() -> Any:
return Depends(_command_arg, use_cache=False) return Depends(_command_arg, use_cache=False)
def _shell_command_args(state=State()) -> Any: def _shell_command_args(state: T_State) -> Any:
return state[SHELL_ARGS] return state[SHELL_ARGS]
@ -330,7 +325,7 @@ def ShellCommandArgs():
return Depends(_shell_command_args, use_cache=False) return Depends(_shell_command_args, use_cache=False)
def _shell_command_argv(state=State()) -> List[str]: def _shell_command_argv(state: T_State) -> List[str]:
return state[SHELL_ARGV] return state[SHELL_ARGV]
@ -338,7 +333,7 @@ def ShellCommandArgv() -> Any:
return Depends(_shell_command_argv, use_cache=False) return Depends(_shell_command_argv, use_cache=False)
def _regex_matched(state=State()) -> str: def _regex_matched(state: T_State) -> str:
return state[REGEX_MATCHED] return state[REGEX_MATCHED]
@ -346,7 +341,7 @@ def RegexMatched() -> str:
return Depends(_regex_matched, use_cache=False) return Depends(_regex_matched, use_cache=False)
def _regex_group(state=State()): def _regex_group(state: T_State):
return state[REGEX_GROUP] return state[REGEX_GROUP]
@ -354,7 +349,7 @@ def RegexGroup() -> Tuple[Any, ...]:
return Depends(_regex_group, use_cache=False) return Depends(_regex_group, use_cache=False)
def _regex_dict(state=State()): def _regex_dict(state: T_State):
return state[REGEX_DICT] return state[REGEX_DICT]

View File

@ -39,7 +39,6 @@ from nonebot.consts import (
REGEX_MATCHED, REGEX_MATCHED,
) )
from nonebot.params import ( from nonebot.params import (
State,
Command, Command,
BotParam, BotParam,
EventToMe, EventToMe,
@ -390,9 +389,9 @@ class ShellCommandRule:
async def __call__( async def __call__(
self, self,
state: T_State,
cmd: Optional[Tuple[str, ...]] = Command(), cmd: Optional[Tuple[str, ...]] = Command(),
msg: Message = EventMessage(), msg: Message = EventMessage(),
state: T_State = State(),
) -> bool: ) -> bool:
if cmd in self.cmds: if cmd in self.cmds:
message = str(msg) message = str(msg)
@ -475,9 +474,9 @@ class RegexRule:
async def __call__( async def __call__(
self, self,
state: T_State,
type: str = EventType(), type: str = EventType(),
msg: Message = EventMessage(), msg: Message = EventMessage(),
state: T_State = State(),
) -> bool: ) -> bool:
if type != "message": if type != "message":
return False return False

View File

@ -3,7 +3,6 @@ from typing import List, Tuple
from nonebot.typing import T_State from nonebot.typing import T_State
from nonebot.adapters import Message from nonebot.adapters import Message
from nonebot.params import ( from nonebot.params import (
State,
Command, Command,
RegexDict, RegexDict,
CommandArg, CommandArg,
@ -15,7 +14,7 @@ from nonebot.params import (
) )
async def state(x: T_State = State()) -> T_State: async def state(x: T_State) -> T_State:
return x return x