import functools
from string import Formatter
from typing import (
    TYPE_CHECKING,
    Any,
    Set,
    Dict,
    List,
    Type,
    Tuple,
    Union,
    Generic,
    Mapping,
    TypeVar,
    Callable,
    Optional,
    Sequence,
    cast,
    overload,
)

if TYPE_CHECKING:
    from .message import Message, MessageSegment

TM = TypeVar("TM", bound="Message")
TF = TypeVar("TF", str, "Message")

FormatSpecFunc = Callable[[Any], str]
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)


class MessageTemplate(Formatter, Generic[TF]):
    """Message template formatter.

    Args:
        template: The template, either a format string or a message object.
        factory: The message type factory used to build results; defaults
            to `str`.
    """

    @overload
    def __init__(
        self: "MessageTemplate[str]", template: str, factory: Type[str] = str
    ) -> None:
        ...

    @overload
    def __init__(
        self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM]
    ) -> None:
        ...

    def __init__(self, template, factory=str) -> None:
        self.template: TF = template
        self.factory: Type[TF] = factory
        # Custom format specs registered through `add_format_spec`, by name.
        self.format_specs: Dict[str, FormatSpecFunc] = {}

    def add_format_spec(
        self, spec: FormatSpecFunc_T, name: Optional[str] = None
    ) -> FormatSpecFunc_T:
        """Register a custom format spec function.

        Args:
            spec: Callable invoked with the field value when the spec is used.
            name: Name of the spec inside templates; defaults to
                `spec.__name__`.

        Returns:
            `spec` unchanged, so the method can be used as a decorator.

        Raises:
            ValueError: A spec with the same name is already registered.
        """
        name = name or spec.__name__
        if name in self.format_specs:
            raise ValueError(f"Format spec {name} already exists!")
        self.format_specs[name] = spec
        return spec

    def format(self, *args: Any, **kwargs: Any) -> TF:
        """Build a message object from the template and the given arguments."""
        return self._format(args, kwargs)

    def format_map(self, mapping: Mapping[str, Any]) -> TF:
        """Build a message object from the template and a mapping.

        Useful when field names are not valid Python identifiers.
        """
        return self._format([], mapping)

    def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
        """Format the stored template and accumulate the result in a new message.

        Raises:
            TypeError: The template is neither a `str` nor an instance of
                `self.factory`.
        """
        msg = self.factory()
        if isinstance(self.template, str):
            msg += self.vformat(self.template, args, kwargs)
        elif isinstance(self.template, self.factory):
            template = cast("Message[MessageSegment]", self.template)
            for seg in template:
                # Only text segments are treated as format strings; every
                # other segment is appended untouched.
                msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg
        else:
            raise TypeError("template must be a string or instance of Message!")

        return msg  # type:ignore

    def vformat(
        self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
    ) -> TF:
        """Format `format_string` and return the result built by the factory."""
        used_args = set()
        # 2 is the maximum nesting depth allowed for replacement fields.
        result, _ = self._vformat(format_string, args, kwargs, used_args, 2)
        self.check_unused_args(list(used_args), args, kwargs)
        return result

    def _vformat(
        self,
        format_string: str,
        args: Sequence[Any],
        kwargs: Mapping[str, Any],
        used_args: Set[Union[int, str]],
        recursion_depth: int,
        auto_arg_index: int = 0,
    ) -> Tuple[TF, int]:
        """Format `format_string`, expanding nested format specs recursively.

        Args:
            format_string: The string to parse and format.
            args: Positional arguments for field lookup.
            kwargs: Keyword arguments for field lookup.
            used_args: Set updated in place with every argument key used.
            recursion_depth: Remaining nesting depth for format specs.
            auto_arg_index: Next automatic field index, or `False` once
                manual numbering has been used.

        Returns:
            The formatted result built by `self.factory`, and the updated
            `auto_arg_index`.

        Raises:
            ValueError: Recursion is too deep, or the template mixes
                automatic and manual field numbering.
        """
        if recursion_depth < 0:
            raise ValueError("Max string recursion exceeded")

        results: List[Any] = []

        for literal_text, field_name, format_spec, conversion in self.parse(
            format_string
        ):
            # output the literal text
            if literal_text:
                results.append(literal_text)

            # if there's a field, output it
            if field_name is not None:
                # this is some markup, find the object and do
                # the formatting

                # handle arg indexing when empty field_names are given.
                if field_name == "":
                    if auto_arg_index is False:
                        raise ValueError(
                            "cannot switch from manual field specification to "
                            "automatic field numbering"
                        )
                    field_name = str(auto_arg_index)
                    auto_arg_index += 1
                elif field_name.isdigit():
                    # A truthy index means automatic numbering has already
                    # consumed at least one field.
                    if auto_arg_index:
                        raise ValueError(
                            "cannot switch from automatic field numbering to "
                            "manual field specification"
                        )
                    # disable auto arg incrementing, if it gets
                    # used later on, then an exception will be raised
                    auto_arg_index = False

                # given the field_name, find the object it references
                # and the argument it came from
                obj, arg_used = self.get_field(field_name, args, kwargs)
                used_args.add(arg_used)

                assert format_spec is not None

                # do any conversion on the resulting object
                obj = self.convert_field(obj, conversion) if conversion else obj

                # expand the format spec, if needed
                format_control, auto_arg_index = self._vformat(
                    format_spec,
                    args,
                    kwargs,
                    used_args,
                    recursion_depth - 1,
                    auto_arg_index,
                )

                # format the object and append to the result
                formatted_text = self.format_field(obj, str(format_control))
                results.append(formatted_text)

        # Fold all pieces together with `_add`; an empty result list is
        # replaced by a single empty string so `reduce` has a value.
        return (
            self.factory(functools.reduce(self._add, results or [""])),
            auto_arg_index,
        )

    def format_field(self, value: Any, format_spec: str) -> Any:
        """Format a single field value according to `format_spec`.

        Lookup order: specs registered with `add_format_spec`, then (for
        non-`str` factories) a callable attribute of the factory's segment
        class whose name does not start with an underscore, and finally the
        standard `Formatter.format_field` behaviour.
        """
        formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
        if formatter is None and not issubclass(self.factory, str):
            segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
            method = getattr(segment_class, format_spec, None)
            if callable(method):
                # Callables such as callable instances have no `__name__`;
                # fall back to the attribute name instead of raising
                # AttributeError.
                method_name = cast(str, getattr(method, "__name__", format_spec))
                if not method_name.startswith("_"):
                    formatter = method
        return (
            super().format_field(value, format_spec)
            if formatter is None
            else formatter(value)
        )

    def _add(self, a: Any, b: Any) -> Any:
        """Return `a + b`, retrying with `str(b)` if the addition raises TypeError."""
        try:
            return a + b
        except TypeError:
            return a + str(b)