2024-12-01 12:31:11 +08:00
|
|
|
|
from _string import formatter_field_name_split # type: ignore
|
|
|
|
|
from collections.abc import Mapping, Sequence
|
2021-09-27 00:19:30 +08:00
|
|
|
|
import functools
|
2021-08-27 02:52:24 +08:00
|
|
|
|
from string import Formatter
|
2021-11-22 23:21:26 +08:00
|
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Any,
|
2022-01-02 18:18:37 +08:00
|
|
|
|
Callable,
|
2024-12-01 12:31:11 +08:00
|
|
|
|
Generic,
|
2022-01-02 18:18:37 +08:00
|
|
|
|
Optional,
|
2024-12-01 12:31:11 +08:00
|
|
|
|
TypeVar,
|
|
|
|
|
Union,
|
2021-11-22 23:21:26 +08:00
|
|
|
|
cast,
|
|
|
|
|
overload,
|
|
|
|
|
)
|
2024-12-01 12:31:11 +08:00
|
|
|
|
from typing_extensions import TypeAlias
|
2024-01-04 11:11:37 +08:00
|
|
|
|
|
2021-08-27 02:52:24 +08:00
|
|
|
|
if TYPE_CHECKING:
    # Imported only for type annotations (e.g. `TypeVar(..., bound="Message")`);
    # presumably guarded to avoid a runtime import cycle — verify before moving.
    from .message import Message, MessageSegment

    # Type-checker-only signature for the `_string.formatter_field_name_split`
    # helper that is imported at runtime at the top of this file. Because it is
    # guarded by `TYPE_CHECKING`, this definition never executes and does not
    # shadow the real helper.
    # The helper splits a replacement-field name such as "user.name[0]" into its
    # first key and an iterable of remaining `(is_attr, key)` lookups, which
    # `MessageTemplate.get_field` applies with `getattr` / `obj[key]`.
    # NOTE(review): CPython yields `int` (not `str`) for purely numeric keys,
    # e.g. both "0" and "1" in "0[1]", and the remaining lookups come back as an
    # iterator rather than a list; the annotation is widened accordingly —
    # confirm against the `_string` implementation before narrowing it again.
    def formatter_field_name_split(
        field_name: str,
    ) -> tuple[Union[int, str], Any]: ...
|
2024-01-04 11:11:37 +08:00
|
|
|
|
|
|
|
|
|
|
2021-08-27 14:46:15 +08:00
|
|
|
|
# A `Message` subclass produced by a template whose factory is that subclass.
TM = TypeVar("TM", bound="Message")
# The output type of a `MessageTemplate`: either plain `str` or a `Message`.
TF = TypeVar("TF", str, "Message")

# A custom format-spec handler: called with the field value, typed as
# returning `str`.
FormatSpecFunc: TypeAlias = Callable[[Any], str]
# Preserves the concrete callable type through `MessageTemplate.add_format_spec`,
# which returns the callable it was given.
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)
|
|
|
|
|
|
2021-08-27 02:52:24 +08:00
|
|
|
|
|
2021-10-04 22:00:32 +08:00
|
|
|
|
class MessageTemplate(Formatter, Generic[TF]):
    """Message template formatter.

    Renders a template into a new value created by `factory`. The template is
    either a `str`, or a `Message`. For a `Message`, text segments are
    formatted and all other segments are copied through unchanged.

    Args:
        template: The template, either a `str` or an instance of `factory`.
        factory: Message type factory; defaults to `str`.
        private_getattr: Whether the template may look up attributes or keys
            whose names start with `_`; defaults to `False`.
    """

    @overload
    def __init__(
        self: "MessageTemplate[str]",
        template: str,
        factory: type[str] = str,
        private_getattr: bool = False,
    ) -> None: ...

    @overload
    def __init__(
        self: "MessageTemplate[TM]",
        template: Union[str, TM],
        factory: type[TM],
        private_getattr: bool = False,
    ) -> None: ...

    def __init__(
        self,
        template: Union[str, TM],
        factory: Union[type[str], type[TM]] = str,
        private_getattr: bool = False,
    ) -> None:
        self.template: TF = template  # type: ignore
        self.factory: type[TF] = factory  # type: ignore
        # Custom handlers registered with `add_format_spec`, keyed by spec name.
        self.format_specs: dict[str, FormatSpecFunc] = {}
        self.private_getattr = private_getattr

    def __repr__(self) -> str:
        return f"MessageTemplate({self.template!r}, factory={self.factory!r})"

    def add_format_spec(
        self, spec: FormatSpecFunc_T, name: Optional[str] = None
    ) -> FormatSpecFunc_T:
        """Register `spec` as a custom format spec, used as `{field:name}`.

        Args:
            spec: Callable applied to the field value.
            name: Spec name; defaults to `spec.__name__`.

        Returns:
            `spec` itself, so this method can also be used as a decorator.

        Raises:
            ValueError: A format spec with the same name is already registered.
        """
        name = name or spec.__name__
        if name in self.format_specs:
            raise ValueError(f"Format spec {name} already exists!")
        self.format_specs[name] = spec
        return spec

    def format(  # pyright: ignore[reportIncompatibleMethodOverride]
        self, *args, **kwargs
    ) -> TF:
        """Build a message from the template and the given arguments."""
        return self._format(args, kwargs)

    def format_map(self, mapping: Mapping[str, Any]) -> TF:
        """Build a message from the template using `mapping` for named fields.

        Useful when field names are not valid Python identifiers.
        """
        return self._format([], mapping)

    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
        """Render the whole template.

        For a `Message` template, only text segments are formatted; other
        segments are appended as-is. One automatic-numbering counter is shared
        across all text segments, so `{}` fields keep counting from one text
        segment to the next.

        Raises:
            TypeError: The template is neither a `str` nor a `factory` instance.
        """
        full_message = self.factory()
        used_args, arg_index = set(), 0

        if isinstance(self.template, str):
            msg, arg_index = self._vformat(
                self.template, args, kwargs, used_args, arg_index
            )
            full_message += msg
        elif isinstance(self.template, self.factory):
            template = cast("Message[MessageSegment]", self.template)
            for seg in template:
                if not seg.is_text():
                    full_message += seg
                else:
                    msg, arg_index = self._vformat(
                        str(seg), args, kwargs, used_args, arg_index
                    )
                    full_message += msg
        else:
            raise TypeError("template must be a string or instance of Message!")

        self.check_unused_args(used_args, args, kwargs)
        return cast(TF, full_message)

    def vformat(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
    ) -> TF:
        """Disabled; formatting is implemented by `_format` / `_vformat`."""
        raise NotImplementedError("`vformat` has merged into `_format`")

    def _vformat(  # pyright: ignore[reportIncompatibleMethodOverride]
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        used_args: set[Union[int, str]],
        auto_arg_index: int = 0,
    ) -> tuple[TF, int]:
        """Format one plain-text template string.

        This mirrors `string.Formatter._vformat`, with two differences.
        Formatted objects are kept as-is and combined with `_add`, so they are
        not forced to `str`. The format spec is used verbatim, so nested
        replacement fields inside it are not expanded.

        `auto_arg_index` doubles as a flag: it becomes `False` once a manually
        numbered field (e.g. `{0}`) is seen.

        Returns:
            The rendered value and the updated `auto_arg_index`.

        Raises:
            ValueError: Automatic (`{}`) and manual (`{0}`) numbering are mixed.
        """
        results: list[Any] = [self.factory()]

        for literal_text, field_name, format_spec, conversion in self.parse(
            format_string
        ):
            # output the literal text
            if literal_text:
                results.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # this is some markup, find the object and do
                # the formatting

                # handle arg indexing when empty field_names are given.
                if field_name == "":
                    if auto_arg_index is False:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    field_name = str(auto_arg_index)
                    auto_arg_index += 1
                elif field_name.isdigit():
                    if auto_arg_index:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    # disable auto arg incrementing, if it gets
                    # used later on, then an exception will be raised
                    auto_arg_index = False

                # given the field_name, find the object it references
                # and the argument it came from
                obj, arg_used = self.get_field(field_name, args, kwargs)
                used_args.add(arg_used)

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion) if conversion else obj

                # format the object and append to the result
                formatted_text = (
                    self.format_field(obj, format_spec) if format_spec else obj
                )
                results.append(formatted_text)

        return functools.reduce(self._add, results), auto_arg_index

    def get_field(
        self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]
    ) -> tuple[Any, Union[int, str]]:
        """Resolve a field name such as `user.name` or `0[1]` to an object.

        Returns:
            The resolved object and the first key (the argument it came from).

        Raises:
            ValueError: A `str` attribute or item key starts with `_` while
                `private_getattr` is disabled.
        """
        first, rest = formatter_field_name_split(field_name)
        obj = self.get_value(first, args, kwargs)

        for is_attr, key in rest:
            # Numeric item keys (the `1` in `{0[1]}`) arrive as `int`, which has
            # no `startswith` and cannot name a private member, so only `str`
            # keys are subject to the privacy check.
            if (
                not self.private_getattr
                and isinstance(key, str)
                and key.startswith("_")
            ):
                raise ValueError("Cannot access private attribute")
            obj = getattr(obj, key) if is_attr else obj[key]

        return obj, first

    def format_field(self, value: Any, format_spec: str) -> Any:
        """Apply `format_spec` to `value`.

        The handler is chosen in order:
        1. A handler registered with `add_format_spec`.
        2. For non-`str` factories, a public callable of the same name on the
           factory's segment class.
        3. The standard `format()` mini-language.
        """
        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
        # Skip `_`-prefixed spec names so that e.g. `{x:__class__}` cannot reach
        # dunder attributes whose resolved object has a public `__name__`.
        if (
            formatter is None
            and not issubclass(self.factory, str)
            and not format_spec.startswith("_")
        ):
            segment_class: type["MessageSegment"] = self.factory.get_segment_class()
            method = getattr(segment_class, format_spec, None)
            # Also reject public aliases of private callables; fall back to the
            # (already checked) spec name for callables lacking `__name__`.
            method_name = cast(str, getattr(method, "__name__", format_spec))
            if callable(method) and not method_name.startswith("_"):
                formatter = method
        return (
            super().format_field(value, format_spec)
            if formatter is None
            else formatter(value)
        )

    def _add(self, a: Any, b: Any) -> Any:
        """Return `a + b`, retrying as `a + str(b)` if that raises `TypeError`."""
        try:
            return a + b
        except TypeError:
            return a + str(b)
|