🔀 Merge pull request #645

Support user-defined format spec for message template
This commit is contained in:
Ju4tCode 2022-01-02 16:25:43 +08:00 committed by GitHub
commit 0541008fb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 22 deletions

View File

@ -4,6 +4,9 @@ from string import Formatter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable,
Dict,
Optional,
Set, Set,
List, List,
Type, Type,
@ -23,6 +26,9 @@ if TYPE_CHECKING:
TM = TypeVar("TM", bound="Message") TM = TypeVar("TM", bound="Message")
TF = TypeVar("TF", str, "Message") TF = TypeVar("TF", str, "Message")
FormatSpecFunc = Callable[[Any], str]
FormatSpecFunc_T = TypeVar("FormatSpecFunc_T", bound=FormatSpecFunc)
class MessageTemplate(Formatter, Generic[TF]): class MessageTemplate(Formatter, Generic[TF]):
"""消息模板格式化实现类""" """消息模板格式化实现类"""
@ -50,8 +56,18 @@ class MessageTemplate(Formatter, Generic[TF]):
* ``template: Union[str, Message]``: 模板 * ``template: Union[str, Message]``: 模板
* ``factory: Union[str, Message]``: 消息构造类型默认为 `str` * ``factory: Union[str, Message]``: 消息构造类型默认为 `str`
""" """
self.template: Union[str, TF] = template self.template: TF = template
self.factory: Type[TF] = factory self.factory: Type[TF] = factory
self.format_specs: Dict[str, FormatSpecFunc] = {}
def add_format_spec(
self, spec: FormatSpecFunc_T, name: Optional[str] = None
) -> FormatSpecFunc_T:
name = name or spec.__name__
if name in self.format_specs:
raise ValueError(f"Format spec {name} already exists!")
self.format_specs[name] = spec
return spec
def format(self, *args: Any, **kwargs: Any) -> TF: def format(self, *args: Any, **kwargs: Any) -> TF:
""" """
@ -69,7 +85,7 @@ class MessageTemplate(Formatter, Generic[TF]):
else: else:
raise TypeError("template must be a string or instance of Message!") raise TypeError("template must be a string or instance of Message!")
return msg return msg # type:ignore
def vformat( def vformat(
self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any] self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]
@ -155,25 +171,16 @@ class MessageTemplate(Formatter, Generic[TF]):
) )
def format_field(self, value: Any, format_spec: str) -> Any: def format_field(self, value: Any, format_spec: str) -> Any:
if issubclass(self.factory, str): formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)
return super().format_field(value, format_spec) if (formatter is None) and (not issubclass(self.factory, str)):
segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
segment_class: Type[MessageSegment] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None) method = getattr(segment_class, format_spec, None)
method_type = inspect.getattr_static(segment_class, format_spec, None) if inspect.ismethod(method):
formatter = getattr(segment_class, format_spec)
return ( return (
(
super().format_field(value, format_spec) super().format_field(value, format_spec)
if ( if formatter is None
(method is None) else formatter(value)
or (
not isinstance(method_type, (classmethod, staticmethod))
) # Only Call staticmethod or classmethod
)
else method(value)
)
if format_spec
else value
) )
def _add(self, a: Any, b: Any) -> Any: def _add(self, a: Any, b: Any) -> Any:

View File

@ -0,0 +1,17 @@
from utils import make_fake_message
def test_message_template():
from nonebot.adapters import MessageTemplate
Message = make_fake_message()
template = MessageTemplate("{a:custom}{b:text}{c:image}", Message)
@template.add_format_spec
def custom(input: str) -> str:
return input + "-custom!"
formatted = template.format(a="test", b="test", c="https://example.com/test")
assert formatted.extract_plain_text() == "test-custom!test"
assert str(formatted) == "test-custom!test[fake:image]"

View File

@ -18,14 +18,18 @@ def make_fake_message() -> Type["Message"]:
return FakeMessage return FakeMessage
def __str__(self) -> str: def __str__(self) -> str:
return self.data["text"] return self.data["text"] if self.type == "text" else f"[fake:{self.type}]"
@classmethod @classmethod
def text(cls, text: str): def text(cls, text: str):
return cls("text", {"text": text}) return cls("text", {"text": text})
@classmethod
def image(cls, url: str):
return cls("image", {"url": url})
def is_text(self) -> bool: def is_text(self) -> bool:
return True return self.type == "text"
class FakeMessage(Message): class FakeMessage(Message):
@classmethod @classmethod