diff --git a/nonebot/internal/adapter/template.py b/nonebot/internal/adapter/template.py index e04830a8..39855903 100644 --- a/nonebot/internal/adapter/template.py +++ b/nonebot/internal/adapter/template.py @@ -49,7 +49,9 @@ class MessageTemplate(Formatter, Generic[TF]): ) -> None: ... - def __init__(self, template, factory=str) -> None: + def __init__( # type:ignore + self, template, factory=str + ) -> None: # TODO: fix type hint here self.template: TF = template self.factory: Type[TF] = factory self.format_specs: Dict[str, FormatSpecFunc] = {} @@ -72,25 +74,37 @@ class MessageTemplate(Formatter, Generic[TF]): return self._format([], mapping) def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF: - msg = self.factory() + full_message = self.factory() + used_args, arg_index = set(), 0 + if isinstance(self.template, str): - msg += self.vformat(self.template, args, kwargs) + msg, arg_index = self._vformat( + self.template, args, kwargs, used_args, arg_index + ) + full_message += msg elif isinstance(self.template, self.factory): template = cast("Message[MessageSegment]", self.template) for seg in template: - msg += self.vformat(str(seg), args, kwargs) if seg.is_text() else seg + if not seg.is_text(): + full_message += seg + else: + msg, arg_index = self._vformat( + str(seg), args, kwargs, used_args, arg_index + ) + full_message += msg else: raise TypeError("template must be a string or instance of Message!") - return msg # type:ignore + self.check_unused_args(list(used_args), args, kwargs) + return cast(TF, full_message) def vformat( - self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any] + self, + format_string: str, + args: Sequence[Any], + kwargs: Mapping[str, Any], ) -> TF: - used_args = set() - result, _ = self._vformat(format_string, args, kwargs, used_args, 2) - self.check_unused_args(list(used_args), args, kwargs) - return result + raise NotImplementedError("`vformat` has merged into `_format`") def _vformat( self, @@ -98,12 +112,8 @@ class MessageTemplate(Formatter, Generic[TF]): args: Sequence[Any], kwargs: Mapping[str, Any], used_args: Set[Union[int, str]], - recursion_depth: int, auto_arg_index: int = 0, ) -> Tuple[TF, int]: - if recursion_depth < 0: - raise ValueError("Max string recursion exceeded") - results: List[Any] = [self.factory()] for (literal_text, field_name, format_spec, conversion) in self.parse( @@ -143,23 +153,13 @@ class MessageTemplate(Formatter, Generic[TF]): obj, arg_used = self.get_field(field_name, args, kwargs) used_args.add(arg_used) - assert format_spec is not None - # do any conversion on the resulting object obj = self.convert_field(obj, conversion) if conversion else obj - # expand the format spec, if needed - format_control, auto_arg_index = self._vformat( - format_spec, - args, - kwargs, - used_args, - recursion_depth - 1, - auto_arg_index, - ) - # format the object and append to the result - formatted_text = self.format_field(obj, str(format_control)) + formatted_text = ( + self.format_field(obj, format_spec) if format_spec else obj + ) results.append(formatted_text) return functools.reduce(self._add, results), auto_arg_index diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 84856625..50a87493 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -32,6 +32,26 @@ def test_template_message(): assert str(formatted) == "custom-custom!text[fake:image]" +def test_rich_template_message(): + Message = make_fake_message() + MS = Message.get_segment_class() + + pic1, pic2, pic3 = ( + MS.image("file:///pic1.jpg"), + MS.image("file:///pic2.jpg"), + MS.image("file:///pic3.jpg"), + ) + + template = Message.template("{}{}" + pic2 + "{}") + + result = template.format(pic1, "[fake:image]", pic3) + + assert result["image"] == Message([pic1, pic2, pic3]) + assert str(result) == ( + "[fake:image]" + escape_text("[fake:image]") + "[fake:image]" + "[fake:image]" + ) + + def test_message_injection(): Message = make_fake_message()