improve command rule types

This commit is contained in:
yanyongyu 2021-11-17 00:27:58 +08:00
parent 4cbdd726e5
commit dc31afbd18
2 changed files with 52 additions and 65 deletions

View File

@ -4,14 +4,14 @@ import inspect
from types import ModuleType
from typing import Any, Set, Dict, List, Type, Tuple, Union, Optional
from nonebot.adapters import Event
from nonebot.handler import Handler
from nonebot.matcher import Matcher
from .manager import _current_plugin
from nonebot.adapters import Bot, Event
from nonebot.permission import Permission
from nonebot.typing import T_State, T_Handler, T_RuleChecker, T_StateFactory
from nonebot.rule import (Rule, ArgumentParser, regex, command, keyword,
endswith, startswith, shell_command)
from nonebot.rule import (PREFIX_KEY, RAW_CMD_KEY, Rule, ArgumentParser, regex,
command, keyword, endswith, startswith, shell_command)
def _store_matcher(matcher: Type[Matcher]) -> None:
@ -373,16 +373,16 @@ def on_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]``
"""
async def _strip_cmd(bot: Bot, event: Event, state: T_State):
async def _strip_cmd(event: Event, state: T_State):
message = event.get_message()
if len(message) < 1:
return
segment = message.pop(0)
segment_text = str(segment).lstrip()
if not segment_text.startswith(state["_prefix"]["raw_command"]):
if not segment_text.startswith(state[PREFIX_KEY][RAW_CMD_KEY]):
return
new_message = message.__class__(
segment_text[len(state["_prefix"]["raw_command"]):].lstrip())
segment_text[len(state[PREFIX_KEY][RAW_CMD_KEY]):].lstrip())
for new_segment in reversed(new_message):
message.insert(0, new_segment)
@ -430,7 +430,7 @@ def on_shell_command(cmd: Union[str, Tuple[str, ...]],
- ``Type[Matcher]``
"""
async def _strip_cmd(bot: Bot, event: Event, state: T_State):
async def _strip_cmd(event: Event, state: T_State):
message = event.get_message()
segment = message.pop(0)
new_message = message.__class__(

View File

@ -14,18 +14,35 @@ import shlex
import asyncio
from itertools import product
from argparse import Namespace
from typing_extensions import TypedDict
from argparse import ArgumentParser as ArgParser
from typing import (Any, Dict, Tuple, Union, Callable, NoReturn, Optional,
Sequence, Awaitable)
from typing import (Any, Tuple, Union, Callable, NoReturn, Optional, Sequence,
Awaitable)
from pygtrie import CharTrie
from nonebot import get_driver
from nonebot.log import logger
from nonebot.utils import run_sync
from nonebot.adapters import Bot, Event
from nonebot.exception import ParserExit
from nonebot.typing import T_State, T_RuleChecker
from nonebot.adapters import Bot, Event, MessageSegment
PREFIX_KEY = "_prefix"
SUFFIX_KEY = "_suffix"
CMD_KEY = "command"
RAW_CMD_KEY = "raw_command"
CMD_RESULT = TypedDict("CMD_RESULT", {
"command": Optional[Tuple[str, ...]],
"raw_command": Optional[str]
})
SHELL_ARGS = "_args"
SHELL_ARGV = "_argv"
REGEX_MATCHED = "_matched"
REGEX_GROUP = "_matched_groups"
REGEX_DICT = "_matched_dict"
class Rule:
@ -121,57 +138,27 @@ class TrieRule:
@classmethod
def get_value(cls, bot: Bot, event: Event,
state: T_State) -> Tuple[Dict[str, Any], Dict[str, Any]]:
state: T_State) -> Tuple[CMD_RESULT, CMD_RESULT]:
prefix = CMD_RESULT(command=None, raw_command=None)
suffix = CMD_RESULT(command=None, raw_command=None)
state[PREFIX_KEY] = prefix
state[SUFFIX_KEY] = suffix
if event.get_type() != "message":
state["_prefix"] = {"raw_command": None, "command": None}
state["_suffix"] = {"raw_command": None, "command": None}
return {
"raw_command": None,
"command": None
}, {
"raw_command": None,
"command": None
}
return prefix, suffix
prefix = None
suffix = None
message = event.get_message()
message_seg = message[0]
message_seg: MessageSegment = message[0]
if message_seg.is_text():
prefix = cls.prefix.longest_prefix(str(message_seg).lstrip())
message_seg_r = message[-1]
pf = cls.prefix.longest_prefix(str(message_seg).lstrip())
prefix[RAW_CMD_KEY] = pf.key
prefix[CMD_KEY] = pf.value
message_seg_r: MessageSegment = message[-1]
if message_seg_r.is_text():
suffix = cls.suffix.longest_prefix(
str(message_seg_r).rstrip()[::-1])
sf = cls.suffix.longest_prefix(str(message_seg_r).rstrip()[::-1])
suffix[RAW_CMD_KEY] = sf.key
suffix[CMD_KEY] = sf.value
state["_prefix"] = {
"raw_command": prefix.key,
"command": prefix.value
} if prefix else {
"raw_command": None,
"command": None
}
state["_suffix"] = {
"raw_command": suffix.key,
"command": suffix.value
} if suffix else {
"raw_command": None,
"command": None
}
return ({
"raw_command": prefix.key,
"command": prefix.value
} if prefix else {
"raw_command": None,
"command": None
}, {
"raw_command": suffix.key,
"command": suffix.value
} if suffix else {
"raw_command": None,
"command": None
})
return prefix, suffix
def startswith(msg: Union[str, Tuple[str, ...]],
@ -288,7 +275,7 @@ def command(*cmds: Union[str, Tuple[str, ...]]) -> Rule:
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _command(bot: Bot, event: Event, state: T_State) -> bool:
return state["_prefix"]["command"] in commands
return state[PREFIX_KEY][CMD_KEY] in commands
return Rule(_command)
@ -374,17 +361,17 @@ def shell_command(*cmds: Union[str, Tuple[str, ...]],
TrieRule.add_prefix(f"{start}{sep.join(command)}", command)
async def _shell_command(bot: Bot, event: Event, state: T_State) -> bool:
if state["_prefix"]["command"] in commands:
if state[PREFIX_KEY][CMD_KEY] in commands:
message = str(event.get_message())
strip_message = message[len(state["_prefix"]["raw_command"]
strip_message = message[len(state[PREFIX_KEY][RAW_CMD_KEY]
):].lstrip()
state["argv"] = shlex.split(strip_message)
state[SHELL_ARGV] = shlex.split(strip_message)
if parser:
try:
args = parser.parse_args(state["argv"])
state["args"] = args
args = parser.parse_args(state[SHELL_ARGV])
state[SHELL_ARGS] = args
except ParserExit as e:
state["args"] = e
state[SHELL_ARGS] = e
return True
else:
return False
@ -418,9 +405,9 @@ def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
return False
matched = pattern.search(str(event.get_message()))
if matched:
state["_matched"] = matched.group()
state["_matched_groups"] = matched.groups()
state["_matched_dict"] = matched.groupdict()
state[REGEX_MATCHED] = matched.group()
state[REGEX_GROUP] = matched.groups()
state[REGEX_DICT] = matched.groupdict()
return True
else:
return False