From 983e5aefdb8a2b75ab469adf645d1f54ffd19e13 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sat, 1 Jan 2022 21:52:54 +0800 Subject: [PATCH 1/3] :sparkles: support user-defined format spec for message template --- nonebot/adapters/_template.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/nonebot/adapters/_template.py b/nonebot/adapters/_template.py index fad17d84..a4a762f9 100644 --- a/nonebot/adapters/_template.py +++ b/nonebot/adapters/_template.py @@ -4,6 +4,9 @@ from string import Formatter from typing import ( TYPE_CHECKING, Any, + Callable, + Dict, + Optional, Set, List, Type, @@ -23,6 +26,9 @@ if TYPE_CHECKING: TM = TypeVar("TM", bound="Message") TF = TypeVar("TF", str, "Message") +FormatSpecFunc = Callable[[Any], str] +FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc) + class MessageTemplate(Formatter, Generic[TF]): """消息模板格式化实现类""" @@ -50,8 +56,18 @@ class MessageTemplate(Formatter, Generic[TF]): * ``template: Union[str, Message]``: 模板 * ``factory: Union[str, Message]``: 消息构造类型,默认为 `str` """ - self.template: Union[str, TF] = template + self.template: TF = template self.factory: Type[TF] = factory + self.format_specs: Dict[str, FormatSpecFunc] = {} + + def add_format_spec( + self, spec: FormatSpecFunc_T, name: Optional[str] = None + ) -> FormatSpecFunc_T: + name = name or spec.__name__ + if name in self.format_specs: + raise ValueError(f"Format spec {name} already exists!") + self.format_specs[name] = spec + return spec def format(self, *args: Any, **kwargs: Any) -> TF: """ @@ -69,7 +85,7 @@ class MessageTemplate(Formatter, Generic[TF]): else: raise TypeError("template must be a string or instance of Message!") - return msg + return msg # type:ignore def vformat( self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any] @@ -165,12 +181,14 @@ class MessageTemplate(Formatter, Generic[TF]): ( super().format_field(value, format_spec) if ( - (method is None) - or ( - not isinstance(method_type, (classmethod, staticmethod)) - ) # Only Call staticmethod or classmethod + method is None + or not isinstance(method_type, (classmethod, staticmethod)) + ) + else ( + self.format_specs[format_spec](value) + if format_spec in self.format_specs + else method(value) ) - else method(value) ) if format_spec else value From 43938a004e10731eff4a20ce1c80afb15d248b89 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jan 2022 13:13:43 +0800 Subject: [PATCH 2/3] :recycle: refactor template `format_field` to improve readability --- nonebot/adapters/_template.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/nonebot/adapters/_template.py b/nonebot/adapters/_template.py index a4a762f9..82903dc5 100644 --- a/nonebot/adapters/_template.py +++ b/nonebot/adapters/_template.py @@ -171,27 +171,16 @@ class MessageTemplate(Formatter, Generic[TF]): ) def format_field(self, value: Any, format_spec: str) -> Any: - if issubclass(self.factory, str): - return super().format_field(value, format_spec) - - segment_class: Type[MessageSegment] = self.factory.get_segment_class() - method = getattr(segment_class, format_spec, None) - method_type = inspect.getattr_static(segment_class, format_spec, None) + formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec) + if (formatter is None) and (not issubclass(self.factory, str)): + segment_class: Type["MessageSegment"] = self.factory.get_segment_class() + method = getattr(segment_class, format_spec, None) + if inspect.ismethod(method): + formatter = getattr(segment_class, format_spec) return ( - ( - super().format_field(value, format_spec) - if ( - method is None - or not isinstance(method_type, (classmethod, staticmethod)) - ) - else ( - self.format_specs[format_spec](value) - if format_spec in self.format_specs - else method(value) - ) - ) - if format_spec - else value + super().format_field(value, format_spec) + if formatter is None + else formatter(value) ) def _add(self, a: Any, b: Any) -> Any: From be1915381eff6d4855287b6b88fdc9910266662d Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Sun, 2 Jan 2022 14:29:04 +0800 Subject: [PATCH 3/3] :white_check_mark: add tests for message template --- tests/adapters/test_template.py | 17 +++++++++++++++++ tests/utils.py | 8 ++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) create mode 100644 tests/adapters/test_template.py diff --git a/tests/adapters/test_template.py b/tests/adapters/test_template.py new file mode 100644 index 00000000..3dbef541 --- /dev/null +++ b/tests/adapters/test_template.py @@ -0,0 +1,17 @@ +from utils import make_fake_message + + +def test_message_template(): + from nonebot.adapters import MessageTemplate + + Message = make_fake_message() + + template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + + @template.add_format_spec + def custom(input: str) -> str: + return input + "-custom!" + + formatted = template.format(a="test", b="test", c="https://example.com/test") + assert formatted.extract_plain_text() == "test-custom!test" + assert str(formatted) == "test-custom!test[fake:image]" diff --git a/tests/utils.py b/tests/utils.py index d516665f..a8d0c0ed 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -18,14 +18,18 @@ def make_fake_message() -> Type["Message"]: return FakeMessage def __str__(self) -> str: - return self.data["text"] + return self.data["text"] if self.type == "text" else f"[fake:{self.type}]" @classmethod def text(cls, text: str): return cls("text", {"text": text}) + @classmethod + def image(cls, url: str): + return cls("image", {"url": url}) + def is_text(self) -> bool: - return True + return self.type == "text" class FakeMessage(Message): @classmethod