🔀 Merge pull request #772

Fix Message.template format spec does not support static method
This commit is contained in:
Ju4tCode 2022-02-07 10:59:43 +08:00 committed by GitHub
commit 5ce72655e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 12 deletions

View File

@ -1,4 +1,3 @@
import inspect
import functools import functools
from string import Formatter from string import Formatter
from typing import ( from typing import (
@ -35,7 +34,7 @@ class MessageTemplate(Formatter, Generic[TF]):
参数: 参数:
template: 模板 template: 模板
factory: 消息构造类型默认为 `str` factory: 消息类型工厂默认为 `str`
""" """
@overload @overload
@ -64,8 +63,15 @@ class MessageTemplate(Formatter, Generic[TF]):
self.format_specs[name] = spec self.format_specs[name] = spec
return spec return spec
def format(self, *args: Any, **kwargs: Any) -> TF: def format(self, *args, **kwargs):
"""根据模板和参数生成消息对象""" """根据传入参数和模板生成消息对象"""
return self._format(args, kwargs)
def format_map(self, mapping: Mapping[str, Any]) -> TF:
"""根据传入字典和模板生成消息对象, 在传入字段名不是有效标识符时有用"""
return self._format([], mapping)
def _format(self, args: Sequence[Any], kwargs: Mapping[str, Any]) -> TF:
msg = self.factory() msg = self.factory()
if isinstance(self.template, str): if isinstance(self.template, str):
msg += self.vformat(self.template, args, kwargs) msg += self.vformat(self.template, args, kwargs)
@ -166,7 +172,7 @@ class MessageTemplate(Formatter, Generic[TF]):
if formatter is None and not issubclass(self.factory, str): if formatter is None and not issubclass(self.factory, str):
segment_class: Type["MessageSegment"] = self.factory.get_segment_class() segment_class: Type["MessageSegment"] = self.factory.get_segment_class()
method = getattr(segment_class, format_spec, None) method = getattr(segment_class, format_spec, None)
if inspect.ismethod(method): if callable(method) and not cast(str, method.__name__).startswith("_"):
formatter = getattr(segment_class, format_spec) formatter = getattr(segment_class, format_spec)
return ( return (
super().format_field(value, format_spec) super().format_field(value, format_spec)

View File

@ -1,7 +1,15 @@
from utils import make_fake_message from utils import make_fake_message
def test_message_template(): def test_template_basis():
from nonebot.adapters import MessageTemplate
template = MessageTemplate("{key:.3%}")
formatted = template.format(key=0.123456789)
assert formatted == "12.346%"
def test_template_message():
from nonebot.adapters import MessageTemplate from nonebot.adapters import MessageTemplate
Message = make_fake_message() Message = make_fake_message()
@ -12,6 +20,16 @@ def test_message_template():
def custom(input: str) -> str: def custom(input: str) -> str:
return input + "-custom!" return input + "-custom!"
formatted = template.format(a="test", b="test", c="https://example.com/test") try:
assert formatted.extract_plain_text() == "test-custom!test" template.add_format_spec(custom)
assert str(formatted) == "test-custom!test[fake:image]" except ValueError:
pass
else:
raise AssertionError("Should raise ValueError")
format_args = {"a": "custom", "b": "text", "c": "https://example.com/test"}
formatted = template.format(**format_args)
assert template.format_map(format_args) == formatted
assert formatted.extract_plain_text() == "custom-custom!text"
assert str(formatted) == "custom-custom!text[fake:image]"

View File

@ -21,9 +21,9 @@ def make_fake_message() -> Type["Message"]:
def text(cls, text: str): def text(cls, text: str):
return cls("text", {"text": text}) return cls("text", {"text": text})
@classmethod @staticmethod
def image(cls, url: str): def image(url: str):
return cls("image", {"url": url}) return FakeMessageSegment("image", {"url": url})
def is_text(self) -> bool: def is_text(self) -> bool:
return self.type == "text" return self.type == "text"