improve command rule types

This commit is contained in:
yanyongyu 2021-11-17 00:27:58 +08:00
parent 4cbdd726e5
commit dc31afbd18
2 changed files with 52 additions and 65 deletions

View File

@ -4,14 +4,14 @@ import inspect
from types import ModuleType from types import ModuleType
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
from nonebot.adapters import Event
from nonebot.handler import Handler from nonebot.handler import Handler
from nonebot.matcher import Matcher from nonebot.matcher import Matcher
from .manager import _current_plugin from .manager import _current_plugin
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission from nonebot.permission import Permission
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword, from nonebot.rule import (PREFIX_KEY, RAW_CMD_KEY, Rule, ArgumentParser, regex,
endswith, startswith, shell_command) command, keyword, endswith, startswith, shell_command)
def _store_matcher(matcher: Type[Matcher]) -> None: def _store_matcher(matcher: Type[Matcher]) -> None:
@ -373,16 +373,16 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
async def _strip_cmd(bot: Bot, event: Event, state: T_State): async def _strip_cmd(event: Event, state: T_State):
message = event.get_message() message = event.get_message()
if len(message) < 1: if len(message) < 1:
return return
segment = message.pop(0) segment = message.pop(0)
segment_text = str(segment).lstrip() segment_text = str(segment).lstrip()
if not segment_text.startswith(state["_prefix"]["raw_command"]): if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
return return
new_message = message.__class__( new_message = message.__class__(
segment_text[len(state["_prefix"]["raw_command"]):].lstrip()) segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]):].lstrip())
for new_segment in reversed(new_message): for new_segment in reversed(new_message):
message.insert(0, new_segment) message.insert(0, new_segment)
@ -430,7 +430,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]`` - ``Type[Matcher]``
""" """
async def _strip_cmd(bot: Bot, event: Event, state: T_State): async def _strip_cmd(event: Event, state: T_State):
message = event.get_message() message = event.get_message()
segment = message.pop(0) segment = message.pop(0)
new_message = message.__class__( new_message = message.__class__(

View File

@ -14,18 +14,35 @@ import shlex
import asyncio import asyncio
from itertools import product from itertools import product
from argparse import Namespace from argparse import Namespace
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser from argparse import ArgumentParser as ArgParser
from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional, from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence,
Sequence, Awaitable) Awaitable)
from pygtrie import CharTrie from pygtrie import CharTrie
from nonebot import get_driver from nonebot import get_driver
from nonebot.log import logger from nonebot.log import logger
from nonebot.utils import run_sync from nonebot.utils import run_sync
from nonebot.adapters import Bot, Event
from nonebot.exception import ParserExit from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment
PREFIX_KEY = "_prefix"
SUFFIX_KEY = "_suffix"
CMD_KEY = "command"
RAW_CMD_KEY = "raw_command"
CMD_RESULT = TypedDict("CMD_RESULT", {
"command": Optional[Tuple[str, ...]],
"raw_command": Optional[str]
})
SHELL_ARGS = "_args"
SHELL_ARGV = "_argv"
REGEX_MATCHED = "_matched"
REGEX_GROUP = "_matched_groups"
REGEX_DICT = "_matched_dict"
class Rule: class Rule:
@ -121,57 +138,27 @@ class TrieRule:
@classmethod @classmethod
def get_value(cls, bot: Bot, event: Event, def get_value(cls, bot: Bot, event: Event,
state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]: state: T_State) -> Tuple[CMD_RESULT, CMD_RESULT]:
prefix = CMD_RESULT(command=None, raw_command=None)
suffix = CMD_RESULT(command=None, raw_command=None)
state[PREFIX_KEY] = prefix
state[SUFFIX_KEY] = suffix
if event.get_type() != "message": if event.get_type() != "message":
state["_prefix"] = {"raw_command": None, "command": None} return prefix, suffix
state["_suffix"] = {"raw_command": None, "command": None}
return {
"raw_command": None,
"command": None
}, {
"raw_command": None,
"command": None
}
prefix = None
suffix = None
message = event.get_message() message = event.get_message()
message_seg = message[0] message_seg: MessageSegment = message[0]
if message_seg.is_text(): if message_seg.is_text():
prefix = cls.prefix.longest_prefix(str(message_seg).lstrip()) pf = cls.prefix.longest_prefix(str(message_seg).lstrip())
message_seg_r = message[-1] prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_KEY] = pf.value
message_seg_r: MessageSegment = message[-1]
if message_seg_r.is_text(): if message_seg_r.is_text():
suffix = cls.suffix.longest_prefix( sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1])
str(message_seg_r).rstrip()[::-1]) suffix[RAW_CMD_KEY] = sf.key
suffix[CMD_KEY] = sf.value
state["_prefix"] = { return prefix, suffix
"raw_command": prefix.key,
"command": prefix.value
} if prefix else {
"raw_command": None,
"command": None
}
state["_suffix"] = {
"raw_command": suffix.key,
"command": suffix.value
} if suffix else {
"raw_command": None,
"command": None
}
return ({
"raw_command": prefix.key,
"command": prefix.value
} if prefix else {
"raw_command": None,
"command": None
}, {
"raw_command": suffix.key,
"command": suffix.value
} if suffix else {
"raw_command": None,
"command": None
})
def startswith(msg: Union[str, Tuple[str, ...]], def startswith(msg: Union[str, Tuple[str, ...]],
@ -288,7 +275,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: Bot, event: Event, state: T_State) -> bool: async def _command(bot: Bot, event: Event, state: T_State) -> bool:
return state["_prefix"]["command"] in commands return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command) return Rule(_command)
@ -374,17 +361,17 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
TrieRule.add_prefix(f"{start}{sep.join(command)}", command) TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool: async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool:
if state["_prefix"]["command"] in commands: if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message()) message = str(event.get_message())
strip_message = message[len(state["_prefix"]["raw_command"] strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
):].lstrip() ):].lstrip()
state["argv"] = shlex.split(strip_message) state[SHELL_ARGV] = shlex.split(strip_message)
if parser: if parser:
try: try:
args = parser.parse_args(state["argv"]) args = parser.parse_args(state[SHELL_ARGV])
state["args"] = args state[SHELL_ARGS] = args
except ParserExit as e: except ParserExit as e:
state["args"] = e state[SHELL_ARGS] = e
return True return True
else: else:
return False return False
@ -418,9 +405,9 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
return False return False
matched = pattern.search(str(event.get_message())) matched = pattern.search(str(event.get_message()))
if matched: if matched:
state["_matched"] = matched.group() state[REGEX_MATCHED] = matched.group()
state["_matched_groups"] = matched.groups() state[REGEX_GROUP] = matched.groups()
state["_matched_dict"] = matched.groupdict() state[REGEX_DICT] = matched.groupdict()
return True return True
else: else:
return False return False