diff --git a/nonebot/internal/adapter/template.py b/nonebot/internal/adapter/template.py index f0edc32e..b720c257 100644 --- a/nonebot/internal/adapter/template.py +++ b/nonebot/internal/adapter/template.py @@ -20,9 +20,17 @@ from typing import ( overload, ) +from _string import formatter_field_name_split # type: ignore + if TYPE_CHECKING: from .message import Message, MessageSegment + def formatter_field_name_split( # noqa: F811 + field_name: str, + ) -> Tuple[str, List[Tuple[bool, str]]]: + ... + + TM = TypeVar("TM", bound="Message") TF = TypeVar("TF", str, "Message") @@ -36,26 +44,37 @@ class MessageTemplate(Formatter, Generic[TF]): 参数: template: 模板 factory: 消息类型工厂,默认为 `str` + private_getattr: 是否允许在模板中访问私有属性,默认为 `False` """ @overload def __init__( - self: "MessageTemplate[str]", template: str, factory: Type[str] = str + self: "MessageTemplate[str]", + template: str, + factory: Type[str] = str, + private_getattr: bool = False, ) -> None: ... @overload def __init__( - self: "MessageTemplate[TM]", template: Union[str, TM], factory: Type[TM] + self: "MessageTemplate[TM]", + template: Union[str, TM], + factory: Type[TM], + private_getattr: bool = False, ) -> None: ... def __init__( - self, template: Union[str, TM], factory: Union[Type[str], Type[TM]] = str + self, + template: Union[str, TM], + factory: Union[Type[str], Type[TM]] = str, + private_getattr: bool = False, ) -> None: self.template: TF = template # type: ignore self.factory: Type[TF] = factory # type: ignore self.format_specs: Dict[str, FormatSpecFunc] = {} + self.private_getattr = private_getattr def __repr__(self) -> str: return f"MessageTemplate({self.template!r}, factory={self.factory!r})" @@ -167,6 +186,19 @@ class MessageTemplate(Formatter, Generic[TF]): return functools.reduce(self._add, results), auto_arg_index + def get_field( + self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any] + ) -> Tuple[Any, Union[int, str]]: + first, rest = formatter_field_name_split(field_name) + obj = self.get_value(first, args, kwargs) + + for is_attr, value in rest: + if not self.private_getattr and value.startswith("_"): + raise ValueError("Cannot access private attribute") + obj = getattr(obj, value) if is_attr else obj[value] + + return obj, first + def format_field(self, value: Any, format_spec: str) -> Any: formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec) if formatter is None and not issubclass(self.factory, str): diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 710c556a..3ca840ea 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,3 +1,5 @@ +import pytest + from nonebot.adapters import MessageTemplate from utils import FakeMessage, FakeMessageSegment, escape_text @@ -15,12 +17,8 @@ def test_template_message(): def custom(input: str) -> str: return f"{input}-custom!" - try: + with pytest.raises(ValueError, match="already exists"): template.add_format_spec(custom) - except ValueError: - pass - else: - raise AssertionError("Should raise ValueError") format_args = { "a": "custom", @@ -57,3 +55,22 @@ def test_message_injection(): message = template.format(name="[fake:image]") assert message.extract_plain_text() == escape_text("[fake:image]Is Bad") + + +def test_malformed_template(): + positive_template = FakeMessage.template("{a}{b}") + message = positive_template.format(a="a", b="b") + assert message.extract_plain_text() == "ab" + + malformed_template = FakeMessage.template("{a.__init__}") + with pytest.raises(ValueError, match="private attribute"): + message = malformed_template.format(a="a") + + malformed_template = FakeMessage.template("{a[__builtins__]}") + with pytest.raises(ValueError, match="private attribute"): + message = malformed_template.format(a=globals()) + + malformed_template = MessageTemplate( + "{a[__builtins__][__import__]}{b.__init__}", private_getattr=True + ) + message = malformed_template.format(a=globals(), b="b")