
119 lines
4.3 KiB
Raw Normal View History

import functools
import operator
from string import Formatter
from typing import (Any, Generic, List, Mapping, Protocol, Sequence, Set, Tuple,
Type, TypeVar, Union, TYPE_CHECKING)
from nonebot.adapters import Message
class AddAble(Protocol):
    """Structural type for values that can be concatenated with ``+``.

    Anything that implements ``__add__`` and ``__str__`` satisfies this
    protocol; the methods below are signature-only stubs.
    """

    def __add__(self, __s: Any) -> "AddAble":
        """Return the concatenation of this value with ``__s``."""
        ...

    def __str__(self) -> str:
        """Return the textual form of this value."""
        ...
# Type variable bounded by the AddAble protocol (values supporting ``+`` and
# ``str()``).
AddAble_T = TypeVar("AddAble_T", bound=AddAble)
# Covariant type variable bounded by the forward reference "Message"; it is the
# type parameter of MessageFormatter below.
MessageResult_T = TypeVar("MessageResult_T", bound="Message", covariant=True)
class MessageFormatter(Formatter, Generic[MessageResult_T]):
    """``string.Formatter`` variant whose result is built with ``+``.

    Instead of joining every formatted piece into a ``str``, the pieces are
    combined with ``operator.add`` so that non-string values passed as
    arguments can survive formatting. The combined value is finally passed
    through ``factory`` unless it already is an instance of it.
    """

    def __init__(self, factory: Type[MessageResult_T], template: str) -> None:
        """Store the format ``template`` and the result ``factory``.

        Args:
            factory: Type used to wrap the formatted result.
            template: Format string using ``str.format`` replacement fields.
        """
        self.template = template
        self.factory = factory

    def format(self, *args: AddAble, **kwargs: AddAble) -> MessageResult_T:
        """Format the stored template with the given arguments.

        Returns:
            The formatted value, wrapped in ``self.factory`` when it is not
            already an instance of it.
        """
        # Formatter.format delegates to self.vformat, which is overridden below.
        msg: AddAble = super().format(self.template, *args, **kwargs)
        return msg if isinstance(msg, self.factory) else self.factory(
            msg)  # type: ignore

    def vformat(self, format_string: str, args: Sequence[AddAble],
                kwargs: Mapping[str, AddAble]):
        """Format ``format_string`` and run the unused-argument check.

        Args:
            format_string: The template to expand.
            args: Positional arguments referenced by the template.
            kwargs: Keyword arguments referenced by the template.

        Returns:
            The ``+``-combined result produced by ``_vformat``.
        """
        # 2 is the maximum nesting depth allowed for format specs.
        result, _, used_args = self._vformat(format_string, args, kwargs,
                                             set(), 2)
        self.check_unused_args(list(used_args), args, kwargs)
        return result

    def _vformat(
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        used_args: Set[Union[int, str]],
        recursion_depth: int,
        auto_arg_index: int = 0,
    ) -> Tuple[AddAble, int, Set[Union[int, str]]]:
        """Expand ``format_string`` and combine the pieces with ``+``.

        Args:
            format_string: The template (or nested format spec) to expand.
            args: Positional arguments referenced by the template.
            kwargs: Keyword arguments referenced by the template.
            used_args: Set of argument keys consumed so far; updated in place.
            recursion_depth: Remaining nesting depth for format specs.
            auto_arg_index: Next index for automatic ``{}`` numbering, or
                ``False`` once manual numbering has been used.

        Returns:
            A tuple ``(result, auto_arg_index, used_args)``.

        Raises:
            ValueError: If the nesting depth is exceeded, or if automatic and
                manual field numbering are mixed.
        """
        if recursion_depth < 0:
            raise ValueError("Max string recursion exceeded")

        results: List[AddAble] = []

        for (literal_text, field_name, format_spec,
             conversion) in self.parse(format_string):

            # output the literal text
            if literal_text:
                results.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # this is some markup, find the object and do
                # the formatting

                # handle arg indexing when empty field_names are given.
                if field_name == "":
                    if auto_arg_index is False:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering")
                    field_name = str(auto_arg_index)
                    auto_arg_index += 1
                elif field_name.isdigit():
                    if auto_arg_index:
                        raise ValueError(
                            "cannot switch from automatic field numbering to "
                            "manual field specification")
                    # disable auto arg incrementing, if it gets
                    # used later on, then an exception will be raised
                    auto_arg_index = False

                # given the field_name, find the object it references
                # and the argument it came from
                obj, arg_used = self.get_field(field_name, args, kwargs)
                used_args.add(arg_used)

                assert format_spec is not None

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion) if conversion else obj

                # expand the format spec, if needed; the nested call works on
                # a copy of the used-argument set which is merged back below
                format_control, auto_arg_index, formatted_args = self._vformat(
                    format_spec,
                    args,
                    kwargs,
                    used_args.copy(),
                    recursion_depth - 1,
                    auto_arg_index,
                )
                used_args |= formatted_args

                # format the object and append to the result
                formatted_text = self.format_field(obj, str(format_control))
                results.append(formatted_text)

        # Combine with ``+`` instead of ``"".join`` so non-str pieces are kept;
        # the ``[""]`` fallback covers an empty template.
        return functools.reduce(operator.add, results or
                                [""]), auto_arg_index, used_args

    def format_field(self, value: AddAble_T,
                     format_spec: str) -> Union[AddAble_T, str]:
        """Format ``value`` only when a non-empty ``format_spec`` is given.

        With an empty spec the value is returned unchanged rather than being
        converted to ``str``.
        """
        return super().format_field(value,
                                    format_spec) if format_spec else value