From 455c599b06c5f59305e828a08c82cb33d585ece5 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Thu, 10 Feb 2022 13:15:59 +0800 Subject: [PATCH] :test_tube: Add a fail test to reproduce #781 --- tests/test_adapters/test_template.py | 16 +++++++++++----- tests/utils.py | 13 ++++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 1814b04d..32974b63 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,4 +1,4 @@ -from utils import make_fake_message +from utils import make_fake_message, escape_text def test_template_basis(): @@ -10,11 +10,8 @@ def test_template_basis(): def test_template_message(): - from nonebot.adapters import MessageTemplate - Message = make_fake_message() - - template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + template = Message.template("{a:custom}{b:text}{c:image}") @template.add_format_spec def custom(input: str) -> str: @@ -33,3 +30,12 @@ def test_template_message(): assert template.format_map(format_args) == formatted assert formatted.extract_plain_text() == "custom-custom!text" assert str(formatted) == "custom-custom!text[fake:image]" + + +def test_message_injection(): + Message = make_fake_message() + + template = Message.template("{name}Is Bad") + message = template.format(name="[fake:image]") + + assert message.extract_plain_text() == escape_text("[fake:image]Is Bad") diff --git a/tests/utils.py b/tests/utils.py index ef54b69b..0cd94be4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,14 @@ if TYPE_CHECKING: from nonebot.adapters import Event, Message -def make_fake_message() -> Type["Message"]: +def escape_text(s: str, *, escape_comma: bool = True) -> str: + s = s.replace("&", "&").replace("[", "[").replace("]", "]") + if escape_comma: + s = s.replace(",", ",") + return s + + +def make_fake_message(): from nonebot.adapters import Message, MessageSegment class FakeMessageSegment(MessageSegment): @@ -42,6 +49,10 @@ def make_fake_message() -> Type["Message"]: yield FakeMessageSegment(**seg) return + def __add__(self, other): + other = escape_text(other) if isinstance(other, str) else other + return super().__add__(other) + return FakeMessage