diff --git a/nonebot/internal/adapter/template.py b/nonebot/internal/adapter/template.py index d6c934ee..22d2aebb 100644 --- a/nonebot/internal/adapter/template.py +++ b/nonebot/internal/adapter/template.py @@ -1,4 +1,3 @@ -import inspect import functools from string import Formatter from typing import ( @@ -35,7 +34,7 @@ class MessageTemplate(Formatter, Generic[TF]): 参数: template: 模板 - factory: 消息构造类型,默认为 `str` + factory: 消息类型工厂,默认为 `str` """ @overload @@ -64,8 +63,15 @@ class MessageTemplate(Formatter, Generic[TF]): self.format_specs[name] = spec return spec - def format(self, *args: Any, **kwargs: Any) -> TF: - """根据模板和参数生成消息对象""" + def format(self, *args, **kwargs): + """根据传入参数和模板生成消息对象""" + return self._format(args, kwargs) + + def format_map(self, mapping: Mapping[str, Any]) -> TF: + """根据传入字典和模板生成消息对象, 在传入字段名不是有效标识符时有用""" + return self._format([], mapping) + + def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF: msg = self.factory() if isinstance(self.template, str): msg += self.vformat(self.template, args, kwargs) @@ -166,7 +172,7 @@ class MessageTemplate(Formatter, Generic[TF]): if formatter is None and not issubclass(self.factory, str): segment_class: Type["MessageSegment"] = self.factory.get_segment_class() method = getattr(segment_class, format_spec, None) - if inspect.ismethod(method): + if callable(method) and not cast(str, method.__name__).startswith("_"): formatter = getattr(segment_class, format_spec) return ( super().format_field(value, format_spec) diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 3dbef541..1814b04d 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,7 +1,15 @@ from utils import make_fake_message -def test_message_template(): +def test_template_basis(): + from nonebot.adapters import MessageTemplate + + template = MessageTemplate("{key:.3%}") + formatted = template.format(key=0.123456789) + assert formatted == "12.346%" + + +def test_template_message(): from nonebot.adapters import MessageTemplate Message = make_fake_message() @@ -12,6 +20,16 @@ def test_message_template(): def custom(input: str) -> str: return input + "-custom!" - formatted = template.format(a="test", b="test", c="https://example.com/test") - assert formatted.extract_plain_text() == "test-custom!test" - assert str(formatted) == "test-custom!test[fake:image]" + try: + template.add_format_spec(custom) + except ValueError: + pass + else: + raise AssertionError("Should raise ValueError") + + format_args = {"a": "custom", "b": "text", "c": "https://example.com/test"} + formatted = template.format(**format_args) + + assert template.format_map(format_args) == formatted + assert formatted.extract_plain_text() == "custom-custom!text" + assert str(formatted) == "custom-custom!text[fake:image]" diff --git a/tests/utils.py b/tests/utils.py index e82ad7d9..ef54b69b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,9 +21,9 @@ def make_fake_message() -> Type["Message"]: def text(cls, text: str): return cls("text", {"text": text}) - @classmethod - def image(cls, url: str): - return cls("image", {"url": url}) + @staticmethod + def image(url: str): + return FakeMessageSegment("image", {"url": url}) def is_text(self) -> bool: return self.type == "text"