diff --git a/nonebot/plugin/on.py b/nonebot/plugin/on.py index 626e4ed1..1b2cf16e 100644 --- a/nonebot/plugin/on.py +++ b/nonebot/plugin/on.py @@ -4,14 +4,14 @@ import inspect from types import ModuleType from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional +from nonebot.adapters import Event from nonebot.handler import Handler from nonebot.matcher import Matcher from .manager import _current_plugin -from nonebot.adapters import Bot, Event from nonebot.permission import Permission from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory -from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, - endswith, startswith, shell_command) +from nonebot.rule import (PREFIX_KEY, RAW_CMD_KEY, Rule, ArgumentParser, regex, + command, keyword, endswith, startswith, shell_command) def _store_matcher(matcher: Type[Matcher]) -> None: @@ -373,16 +373,16 @@ def on_command(cmd: Union[str, Tuple[str, ...]], - ``Type[Matcher]`` """ - async def _strip_cmd(bot: Bot, event: Event, state: T_State): + async def _strip_cmd(event: Event, state: T_State): message = event.get_message() if len(message) < 1: return segment = message.pop(0) segment_text = str(segment).lstrip() - if not segment_text.startswith(state["_prefix"]["raw_command"]): + if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]): return new_message = message.__class__( - segment_text[len(state["_prefix"]["raw_command"]):].lstrip()) + segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]):].lstrip()) for new_segment in reversed(new_message): message.insert(0, new_segment) @@ -430,7 +430,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]], - ``Type[Matcher]`` """ - async def _strip_cmd(bot: Bot, event: Event, state: T_State): + async def _strip_cmd(event: Event, state: T_State): message = event.get_message() segment = message.pop(0) new_message = message.__class__( diff --git a/nonebot/rule.py b/nonebot/rule.py index 40fc1b43..7ac0fe8d 100644 --- a/nonebot/rule.py +++ b/nonebot/rule.py @@ -14,18 +14,35 @@ import shlex import asyncio from itertools import product from argparse import Namespace +from typing_extensions import TypedDict from argparse import ArgumentParser as ArgParser -from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional, - Sequence, Awaitable) +from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence, + Awaitable) from pygtrie import CharTrie from nonebot import get_driver from nonebot.log import logger from nonebot.utils import run_sync -from nonebot.adapters import Bot, Event from nonebot.exception import ParserExit from nonebot.typing import T_State, T_RuleChecker +from nonebot.adapters import Bot, Event, MessageSegment + +PREFIX_KEY = "_prefix" +SUFFIX_KEY = "_suffix" +CMD_KEY = "command" +RAW_CMD_KEY = "raw_command" +CMD_RESULT = TypedDict("CMD_RESULT", { + "command": Optional[Tuple[str, ...]], + "raw_command": Optional[str] +}) + +SHELL_ARGS = "_args" +SHELL_ARGV = "_argv" + +REGEX_MATCHED = "_matched" +REGEX_GROUP = "_matched_groups" +REGEX_DICT = "_matched_dict" class Rule: @@ -121,57 +138,27 @@ class TrieRule: @classmethod def get_value(cls, bot: Bot, event: Event, - state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]: + state: T_State) -> Tuple[CMD_RESULT, CMD_RESULT]: + prefix = CMD_RESULT(command=None, raw_command=None) + suffix = CMD_RESULT(command=None, raw_command=None) + state[PREFIX_KEY] = prefix + state[SUFFIX_KEY] = suffix if event.get_type() != "message": - state["_prefix"] = {"raw_command": None, "command": None} - state["_suffix"] = {"raw_command": None, "command": None} - return { - "raw_command": None, - "command": None - }, { - "raw_command": None, - "command": None - } + return prefix, suffix - prefix = None - suffix = None message = event.get_message() - message_seg = message[0] + message_seg: MessageSegment = message[0] if message_seg.is_text(): - prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) - message_seg_r = message[-1] + pf = cls.prefix.longest_prefix(str(message_seg).lstrip()) + prefix[RAW_CMD_KEY] = pf.key + prefix[CMD_KEY] = pf.value + message_seg_r: MessageSegment = message[-1] if message_seg_r.is_text(): - suffix = cls.suffix.longest_prefix( - str(message_seg_r).rstrip()[::-1]) + sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1]) + suffix[RAW_CMD_KEY] = sf.key + suffix[CMD_KEY] = sf.value - state["_prefix"] = { - "raw_command": prefix.key, - "command": prefix.value - } if prefix else { - "raw_command": None, - "command": None - } - state["_suffix"] = { - "raw_command": suffix.key, - "command": suffix.value - } if suffix else { - "raw_command": None, - "command": None - } - - return ({ - "raw_command": prefix.key, - "command": prefix.value - } if prefix else { - "raw_command": None, - "command": None - }, { - "raw_command": suffix.key, - "command": suffix.value - } if suffix else { - "raw_command": None, - "command": None - }) + return prefix, suffix def startswith(msg: Union[str, Tuple[str, ...]], @@ -288,7 +275,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule: TrieRule.add_prefix(f"{start}{sep.join(command)}", command) async def _command(bot: Bot, event: Event, state: T_State) -> bool: - return state["_prefix"]["command"] in commands + return state[PREFIX_KEY][CMD_KEY] in commands return Rule(_command) @@ -374,17 +361,17 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]], TrieRule.add_prefix(f"{start}{sep.join(command)}", command) async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool: - if state["_prefix"]["command"] in commands: + if state[PREFIX_KEY][CMD_KEY] in commands: message = str(event.get_message()) - strip_message = message[len(state["_prefix"]["raw_command"] + strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY] ):].lstrip() - state["argv"] = shlex.split(strip_message) + state[SHELL_ARGV] = shlex.split(strip_message) if parser: try: - args = parser.parse_args(state["argv"]) - state["args"] = args + args = parser.parse_args(state[SHELL_ARGV]) + state[SHELL_ARGS] = args except ParserExit as e: - state["args"] = e + state[SHELL_ARGS] = e return True else: return False @@ -418,9 +405,9 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule: return False matched = pattern.search(str(event.get_message())) if matched: - state["_matched"] = matched.group() - state["_matched_groups"] = matched.groups() - state["_matched_dict"] = matched.groupdict() + state[REGEX_MATCHED] = matched.group() + state[REGEX_GROUP] = matched.groups() + state[REGEX_DICT] = matched.groupdict() return True else: return False