"""本模块是 {ref}`nonebot.matcher.Matcher.rule` 的类型定义。

每个{ref}`事件响应器 <nonebot.matcher.Matcher>`拥有一个
{ref}`nonebot.rule.Rule`,其中是 `RuleChecker` 的集合。
只有当所有 `RuleChecker` 检查结果为 `True` 时继续运行。

FrontMatter:
    mdx:
        format: md
    sidebar_position: 5
    description: nonebot.rule 模块
"""

import re
import shlex
from argparse import Action
from gettext import gettext
from argparse import ArgumentError
from contextvars import ContextVar
from collections.abc import Sequence
from itertools import chain, product
from argparse import Namespace as Namespace
from argparse import ArgumentParser as ArgParser
from typing import (
    IO,
    TYPE_CHECKING,
    Union,
    TypeVar,
    Optional,
    TypedDict,
    NamedTuple,
    cast,
    overload,
)

from pygtrie import CharTrie

from nonebot import get_driver
from nonebot.log import logger
from nonebot.typing import T_State
from nonebot.exception import ParserExit
from nonebot.internal.rule import Rule as Rule
from nonebot.adapters import Bot, Event, Message, MessageSegment
from nonebot.params import Command, EventToMe, CommandArg, CommandWhitespace
from nonebot.consts import (
    CMD_KEY,
    PREFIX_KEY,
    SHELL_ARGS,
    SHELL_ARGV,
    CMD_ARG_KEY,
    KEYWORD_KEY,
    RAW_CMD_KEY,
    ENDSWITH_KEY,
    CMD_START_KEY,
    FULLMATCH_KEY,
    REGEX_MATCHED,
    STARTSWITH_KEY,
    CMD_WHITESPACE_KEY,
)

T = TypeVar("T")


class CMD_RESULT(TypedDict):
    command: Optional[tuple[str, ...]]
    raw_command: Optional[str]
    command_arg: Optional[Message]
    command_start: Optional[str]
    command_whitespace: Optional[str]


class TRIE_VALUE(NamedTuple):
    command_start: str
    command: tuple[str, ...]


parser_message: ContextVar[str] = ContextVar("parser_message")


class TrieRule:
    prefix: CharTrie = CharTrie()

    @classmethod
    def add_prefix(cls, prefix: str, value: TRIE_VALUE) -> None:
        if prefix in cls.prefix:
            logger.warning(f'Duplicated prefix rule "{prefix}"')
            return
        cls.prefix[prefix] = value

    @classmethod
    def get_value(cls, bot: Bot, event: Event, state: T_State) -> CMD_RESULT:
        prefix = CMD_RESULT(
            command=None,
            raw_command=None,
            command_arg=None,
            command_start=None,
            command_whitespace=None,
        )
        state[PREFIX_KEY] = prefix
        if event.get_type() != "message":
            return prefix

        message = event.get_message()
        message_seg: MessageSegment = message[0]
        if message_seg.is_text():
            segment_text = str(message_seg).lstrip()
            if pf := cls.prefix.longest_prefix(segment_text):
                value: TRIE_VALUE = pf.value
                prefix[RAW_CMD_KEY] = pf.key
                prefix[CMD_START_KEY] = value.command_start
                prefix[CMD_KEY] = value.command

                msg = message.copy()
                msg.pop(0)

                # check whitespace
                arg_str = segment_text[len(pf.key) :]
                arg_str_stripped = arg_str.lstrip()
                # check next segment until arg detected or no text remain
                while not arg_str_stripped and msg and msg[0].is_text():
                    arg_str += str(msg.pop(0))
                    arg_str_stripped = arg_str.lstrip()

                has_arg = arg_str_stripped or msg
                if (
                    has_arg
                    and (stripped_len := len(arg_str) - len(arg_str_stripped)) > 0
                ):
                    prefix[CMD_WHITESPACE_KEY] = arg_str[:stripped_len]

                # construct command arg
                if arg_str_stripped:
                    new_message = msg.__class__(arg_str_stripped)
                    for new_segment in reversed(new_message):
                        msg.insert(0, new_segment)
                prefix[CMD_ARG_KEY] = msg

        return prefix


class StartswithRule:
    """检查消息纯文本是否以指定字符串开头。

    参数:
        msg: 指定消息开头字符串元组
        ignorecase: 是否忽略大小写
    """

    __slots__ = ("msg", "ignorecase")

    def __init__(self, msg: tuple[str, ...], ignorecase: bool = False):
        self.msg = msg
        self.ignorecase = ignorecase

    def __repr__(self) -> str:
        return f"Startswith(msg={self.msg}, ignorecase={self.ignorecase})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, StartswithRule)
            and frozenset(self.msg) == frozenset(other.msg)
            and self.ignorecase == other.ignorecase
        )

    def __hash__(self) -> int:
        return hash((frozenset(self.msg), self.ignorecase))

    async def __call__(self, event: Event, state: T_State) -> bool:
        try:
            text = event.get_plaintext()
        except Exception:
            return False
        if match := re.match(
            f"^(?:{'|'.join(re.escape(prefix) for prefix in self.msg)})",
            text,
            re.IGNORECASE if self.ignorecase else 0,
        ):
            state[STARTSWITH_KEY] = match.group()
            return True
        return False


def startswith(msg: Union[str, tuple[str, ...]], ignorecase: bool = False) -> Rule:
    """匹配消息纯文本开头。

    参数:
        msg: 指定消息开头字符串元组
        ignorecase: 是否忽略大小写
    """
    if isinstance(msg, str):
        msg = (msg,)

    return Rule(StartswithRule(msg, ignorecase))


class EndswithRule:
    """检查消息纯文本是否以指定字符串结尾。

    参数:
        msg: 指定消息结尾字符串元组
        ignorecase: 是否忽略大小写
    """

    __slots__ = ("msg", "ignorecase")

    def __init__(self, msg: tuple[str, ...], ignorecase: bool = False):
        self.msg = msg
        self.ignorecase = ignorecase

    def __repr__(self) -> str:
        return f"Endswith(msg={self.msg}, ignorecase={self.ignorecase})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, EndswithRule)
            and frozenset(self.msg) == frozenset(other.msg)
            and self.ignorecase == other.ignorecase
        )

    def __hash__(self) -> int:
        return hash((frozenset(self.msg), self.ignorecase))

    async def __call__(self, event: Event, state: T_State) -> bool:
        try:
            text = event.get_plaintext()
        except Exception:
            return False
        if match := re.search(
            f"(?:{'|'.join(re.escape(suffix) for suffix in self.msg)})$",
            text,
            re.IGNORECASE if self.ignorecase else 0,
        ):
            state[ENDSWITH_KEY] = match.group()
            return True
        return False


def endswith(msg: Union[str, tuple[str, ...]], ignorecase: bool = False) -> Rule:
    """匹配消息纯文本结尾。

    参数:
        msg: 指定消息开头字符串元组
        ignorecase: 是否忽略大小写
    """
    if isinstance(msg, str):
        msg = (msg,)

    return Rule(EndswithRule(msg, ignorecase))


class FullmatchRule:
    """检查消息纯文本是否与指定字符串全匹配。

    参数:
        msg: 指定消息全匹配字符串元组
        ignorecase: 是否忽略大小写
    """

    __slots__ = ("msg", "ignorecase")

    def __init__(self, msg: tuple[str, ...], ignorecase: bool = False):
        self.msg = tuple(map(str.casefold, msg) if ignorecase else msg)
        self.ignorecase = ignorecase

    def __repr__(self) -> str:
        return f"Fullmatch(msg={self.msg}, ignorecase={self.ignorecase})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, FullmatchRule)
            and frozenset(self.msg) == frozenset(other.msg)
            and self.ignorecase == other.ignorecase
        )

    def __hash__(self) -> int:
        return hash((frozenset(self.msg), self.ignorecase))

    async def __call__(self, event: Event, state: T_State) -> bool:
        try:
            text = event.get_plaintext()
        except Exception:
            return False
        if not text:
            return False
        text = text.casefold() if self.ignorecase else text
        if text in self.msg:
            state[FULLMATCH_KEY] = text
            return True
        return False


def fullmatch(msg: Union[str, tuple[str, ...]], ignorecase: bool = False) -> Rule:
    """完全匹配消息。

    参数:
        msg: 指定消息全匹配字符串元组
        ignorecase: 是否忽略大小写
    """
    if isinstance(msg, str):
        msg = (msg,)

    return Rule(FullmatchRule(msg, ignorecase))


class KeywordsRule:
    """检查消息纯文本是否包含指定关键字。

    参数:
        keywords: 指定关键字元组
    """

    __slots__ = ("keywords",)

    def __init__(self, *keywords: str):
        self.keywords = keywords

    def __repr__(self) -> str:
        return f"Keywords(keywords={self.keywords})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, KeywordsRule) and frozenset(
            self.keywords
        ) == frozenset(other.keywords)

    def __hash__(self) -> int:
        return hash(frozenset(self.keywords))

    async def __call__(self, event: Event, state: T_State) -> bool:
        try:
            text = event.get_plaintext()
        except Exception:
            return False
        if not text:
            return False
        if key := next((k for k in self.keywords if k in text), None):
            state[KEYWORD_KEY] = key
            return True
        return False


def keyword(*keywords: str) -> Rule:
    """匹配消息纯文本关键词。

    参数:
        keywords: 指定关键字元组
    """

    return Rule(KeywordsRule(*keywords))


class CommandRule:
    """检查消息是否为指定命令。

    参数:
        cmds: 指定命令元组列表
        force_whitespace: 是否强制命令后必须有指定空白符
    """

    __slots__ = ("cmds", "force_whitespace")

    def __init__(
        self,
        cmds: list[tuple[str, ...]],
        force_whitespace: Optional[Union[str, bool]] = None,
    ):
        self.cmds = tuple(cmds)
        self.force_whitespace = force_whitespace

    def __repr__(self) -> str:
        return f"Command(cmds={self.cmds})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, CommandRule) and frozenset(self.cmds) == frozenset(
            other.cmds
        )

    def __hash__(self) -> int:
        return hash((frozenset(self.cmds),))

    async def __call__(
        self,
        cmd: Optional[tuple[str, ...]] = Command(),
        cmd_arg: Optional[Message] = CommandArg(),
        cmd_whitespace: Optional[str] = CommandWhitespace(),
    ) -> bool:
        if cmd not in self.cmds:
            return False
        if self.force_whitespace is None or not cmd_arg:
            return True
        if isinstance(self.force_whitespace, str):
            return self.force_whitespace == cmd_whitespace
        return self.force_whitespace == (cmd_whitespace is not None)


def command(
    *cmds: Union[str, tuple[str, ...]],
    force_whitespace: Optional[Union[str, bool]] = None,
) -> Rule:
    """匹配消息命令。

    根据配置里提供的 {ref}``command_start` <nonebot.config.Config.command_start>`,
    {ref}``command_sep` <nonebot.config.Config.command_sep>` 判断消息是否为命令。

    可以通过 {ref}`nonebot.params.Command` 获取匹配成功的命令(例: `("test",)`),
    通过 {ref}`nonebot.params.RawCommand` 获取匹配成功的原始命令文本(例: `"/test"`),
    通过 {ref}`nonebot.params.CommandArg` 获取匹配成功的命令参数。

    参数:
        cmds: 命令文本或命令元组
        force_whitespace: 是否强制命令后必须有指定空白符

    用法:
        使用默认 `command_start`, `command_sep` 配置情况下:

        命令 `("test",)` 可以匹配: `/test` 开头的消息
        命令 `("test", "sub")` 可以匹配: `/test.sub` 开头的消息

    :::tip 提示
    命令内容与后续消息间无需空格!
    :::
    """

    config = get_driver().config
    command_start = config.command_start
    command_sep = config.command_sep
    commands: list[tuple[str, ...]] = []
    for command in cmds:
        if isinstance(command, str):
            command = (command,)

        commands.append(command)

        if len(command) == 1:
            for start in command_start:
                TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command))
        else:
            for start, sep in product(command_start, command_sep):
                TrieRule.add_prefix(
                    f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
                )

    return Rule(CommandRule(commands, force_whitespace))


class ArgumentParser(ArgParser):
    """`shell_like` 命令参数解析器,解析出错时不会退出程序。

    支持 {ref}`nonebot.adapters.Message` 富文本解析。

    用法:
        用法与 `argparse.ArgumentParser` 相同,
        参考文档: [argparse](https://docs.python.org/3/library/argparse.html)
    """

    if TYPE_CHECKING:

        @overload
        def parse_known_args(
            self,
            args: Optional[Sequence[Union[str, MessageSegment]]] = None,
            namespace: None = None,
        ) -> tuple[Namespace, list[Union[str, MessageSegment]]]: ...

        @overload
        def parse_known_args(
            self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: T
        ) -> tuple[T, list[Union[str, MessageSegment]]]: ...

        @overload
        def parse_known_args(
            self, *, namespace: T
        ) -> tuple[T, list[Union[str, MessageSegment]]]: ...

        def parse_known_args(  # pyright: ignore[reportIncompatibleMethodOverride]
            self,
            args: Optional[Sequence[Union[str, MessageSegment]]] = None,
            namespace: Optional[T] = None,
        ) -> tuple[Union[Namespace, T], list[Union[str, MessageSegment]]]: ...

    @overload
    def parse_args(
        self,
        args: Optional[Sequence[Union[str, MessageSegment]]] = None,
        namespace: None = None,
    ) -> Namespace: ...

    @overload
    def parse_args(
        self, args: Optional[Sequence[Union[str, MessageSegment]]], namespace: T
    ) -> T: ...

    @overload
    def parse_args(self, *, namespace: T) -> T: ...

    def parse_args(
        self,
        args: Optional[Sequence[Union[str, MessageSegment]]] = None,
        namespace: Optional[T] = None,
    ) -> Union[Namespace, T]:
        result, argv = self.parse_known_args(args, namespace)
        if argv:
            msg = gettext("unrecognized arguments: %s")
            self.error(msg % " ".join(map(str, argv)))
        return cast(Union[Namespace, T], result)

    def _parse_optional(
        self, arg_string: Union[str, MessageSegment]
    ) -> Optional[tuple[Optional[Action], str, Optional[str]]]:
        return (
            super()._parse_optional(arg_string) if isinstance(arg_string, str) else None
        )

    def _print_message(self, message: str, file: Optional[IO[str]] = None):
        if (msg := parser_message.get(None)) is not None:
            parser_message.set(msg + message)
        else:
            super()._print_message(message, file)

    def exit(self, status: int = 0, message: Optional[str] = None):
        if message:
            self._print_message(message)
        raise ParserExit(status=status, message=parser_message.get(None))


class ShellCommandRule:
    """检查消息是否为指定 shell 命令。

    参数:
        cmds: 指定命令元组列表
        parser: 可选参数解析器
    """

    __slots__ = ("cmds", "parser")

    def __init__(self, cmds: list[tuple[str, ...]], parser: Optional[ArgumentParser]):
        self.cmds = tuple(cmds)
        self.parser = parser

    def __repr__(self) -> str:
        return f"ShellCommand(cmds={self.cmds}, parser={self.parser})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, ShellCommandRule)
            and frozenset(self.cmds) == frozenset(other.cmds)
            and self.parser is other.parser
        )

    def __hash__(self) -> int:
        return hash((frozenset(self.cmds), self.parser))

    async def __call__(
        self,
        state: T_State,
        cmd: Optional[tuple[str, ...]] = Command(),
        msg: Optional[Message] = CommandArg(),
    ) -> bool:
        if cmd not in self.cmds or msg is None:
            return False

        state[SHELL_ARGV] = list(
            chain.from_iterable(
                shlex.split(str(seg)) if cast(MessageSegment, seg).is_text() else (seg,)
                for seg in msg
            )
        )

        if self.parser:
            t = parser_message.set("")
            try:
                args = self.parser.parse_args(state[SHELL_ARGV])
                state[SHELL_ARGS] = args
            except ArgumentError as e:
                state[SHELL_ARGS] = ParserExit(status=2, message=str(e))
            except ParserExit as e:
                state[SHELL_ARGS] = e
            finally:
                parser_message.reset(t)
        return True


def shell_command(
    *cmds: Union[str, tuple[str, ...]], parser: Optional[ArgumentParser] = None
) -> Rule:
    """匹配 `shell_like` 形式的消息命令。

    根据配置里提供的 {ref}``command_start` <nonebot.config.Config.command_start>`,
    {ref}``command_sep` <nonebot.config.Config.command_sep>` 判断消息是否为命令。

    可以通过 {ref}`nonebot.params.Command` 获取匹配成功的命令
    (例: `("test",)`),
    通过 {ref}`nonebot.params.RawCommand` 获取匹配成功的原始命令文本
    (例: `"/test"`),
    通过 {ref}`nonebot.params.ShellCommandArgv` 获取解析前的参数列表
    (例: `["arg", "-h"]`),
    通过 {ref}`nonebot.params.ShellCommandArgs` 获取解析后的参数字典
    (例: `{"arg": "arg", "h": True}`)。

    :::caution 警告
    如果参数解析失败,则通过 {ref}`nonebot.params.ShellCommandArgs`
    获取的将是 {ref}`nonebot.exception.ParserExit` 异常。
    :::

    参数:
        cmds: 命令文本或命令元组
        parser: {ref}`nonebot.rule.ArgumentParser` 对象

    用法:
        使用默认 `command_start`, `command_sep` 配置,更多示例参考
        [argparse](https://docs.python.org/3/library/argparse.html) 标准库文档。

        ```python
        from nonebot.rule import ArgumentParser

        parser = ArgumentParser()
        parser.add_argument("-a", action="store_true")

        rule = shell_command("ls", parser=parser)
        ```

    :::tip 提示
    命令内容与后续消息间无需空格!
    :::
    """
    if parser is not None and not isinstance(parser, ArgumentParser):
        raise TypeError("`parser` must be an instance of nonebot.rule.ArgumentParser")

    config = get_driver().config
    command_start = config.command_start
    command_sep = config.command_sep
    commands: list[tuple[str, ...]] = []
    for command in cmds:
        if isinstance(command, str):
            command = (command,)

        commands.append(command)

        if len(command) == 1:
            for start in command_start:
                TrieRule.add_prefix(f"{start}{command[0]}", TRIE_VALUE(start, command))
        else:
            for start, sep in product(command_start, command_sep):
                TrieRule.add_prefix(
                    f"{start}{sep.join(command)}", TRIE_VALUE(start, command)
                )

    return Rule(ShellCommandRule(commands, parser))


class RegexRule:
    """检查消息字符串是否符合指定正则表达式。

    参数:
        regex: 正则表达式
        flags: 正则表达式标记
    """

    __slots__ = ("regex", "flags")

    def __init__(self, regex: str, flags: int = 0):
        self.regex = regex
        self.flags = flags

    def __repr__(self) -> str:
        return f"Regex(regex={self.regex!r}, flags={self.flags})"

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, RegexRule)
            and self.regex == other.regex
            and self.flags == other.flags
        )

    def __hash__(self) -> int:
        return hash((self.regex, self.flags))

    async def __call__(self, event: Event, state: T_State) -> bool:
        try:
            msg = event.get_message()
        except Exception:
            return False
        if matched := re.search(self.regex, str(msg), self.flags):
            state[REGEX_MATCHED] = matched
            return True
        else:
            return False


def regex(regex: str, flags: Union[int, re.RegexFlag] = 0) -> Rule:
    """匹配符合正则表达式的消息字符串。

    可以通过 {ref}`nonebot.params.RegexStr` 获取匹配成功的字符串,
    通过 {ref}`nonebot.params.RegexGroup` 获取匹配成功的 group 元组,
    通过 {ref}`nonebot.params.RegexDict` 获取匹配成功的 group 字典。

    参数:
        regex: 正则表达式
        flags: 正则表达式标记

    :::tip 提示
    正则表达式匹配使用 search 而非 match,如需从头匹配请使用 `r"^xxx"` 来确保匹配开头
    :::

    :::tip 提示
    正则表达式匹配使用 `EventMessage` 的 `str` 字符串,
    而非 `EventMessage` 的 `PlainText` 纯文本字符串
    :::
    """

    return Rule(RegexRule(regex, flags))


class ToMeRule:
    """检查事件是否与机器人有关。"""

    __slots__ = ()

    def __repr__(self) -> str:
        return "ToMe()"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ToMeRule)

    def __hash__(self) -> int:
        return hash((self.__class__,))

    async def __call__(self, to_me: bool = EventToMe()) -> bool:
        return to_me


def to_me() -> Rule:
    """匹配与机器人有关的事件。"""

    return Rule(ToMeRule())


class IsTypeRule:
    """检查事件类型是否为指定类型。"""

    __slots__ = ("types",)

    def __init__(self, *types: type[Event]):
        self.types = types

    def __repr__(self) -> str:
        return f"IsType(types={tuple(type.__name__ for type in self.types)})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, IsTypeRule) and self.types == other.types

    def __hash__(self) -> int:
        return hash((self.types,))

    async def __call__(self, event: Event) -> bool:
        return isinstance(event, self.types)


def is_type(*types: type[Event]) -> Rule:
    """匹配事件类型。

    参数:
        types: 事件类型
    """

    return Rule(IsTypeRule(*types))


__autodoc__ = {
    "Rule": True,
    "Rule.__call__": True,
    "TrieRule": False,
    "ArgumentParser.exit": False,
    "ArgumentParser.parse_args": False,
}