🧪 Add a fail test to reproduce #781

This commit is contained in:
Mix 2022-02-10 13:15:59 +08:00
parent e9908bcbc4
commit 455c599b06
2 changed files with 23 additions and 6 deletions

View File

@ -1,4 +1,4 @@
from utils import make_fake_message from utils import make_fake_message, escape_text
def test_template_basis(): def test_template_basis():
@ -10,11 +10,8 @@ def test_template_basis():
def test_template_message(): def test_template_message():
from nonebot.adapters import MessageTemplate
Message = make_fake_message() Message = make_fake_message()
template = Message.template("{a:custom}{b:text}{c:image}")
template = MessageTemplate("{a:custom}{b:text}{c:image}", Message)
@template.add_format_spec @template.add_format_spec
def custom(input: str) -> str: def custom(input: str) -> str:
@ -33,3 +30,12 @@ def test_template_message():
assert template.format_map(format_args) == formatted assert template.format_map(format_args) == formatted
assert formatted.extract_plain_text() == "custom-custom!text" assert formatted.extract_plain_text() == "custom-custom!text"
assert str(formatted) == "custom-custom!text[fake:image]" assert str(formatted) == "custom-custom!text[fake:image]"
def test_message_injection():
Message = make_fake_message()
template = Message.template("{name}Is Bad")
message = template.format(name="[fake:image]")
assert message.extract_plain_text() == escape_text("[fake:image]Is Bad")

View File

@ -6,7 +6,14 @@ if TYPE_CHECKING:
from nonebot.adapters import Event, Message from nonebot.adapters import Event, Message
def make_fake_message() -> Type["Message"]: def escape_text(s: str, *, escape_comma: bool = True) -> str:
s = s.replace("&", "&").replace("[", "[").replace("]", "]")
if escape_comma:
s = s.replace(",", ",")
return s
def make_fake_message():
from nonebot.adapters import Message, MessageSegment from nonebot.adapters import Message, MessageSegment
class FakeMessageSegment(MessageSegment): class FakeMessageSegment(MessageSegment):
@ -42,6 +49,10 @@ def make_fake_message() -> Type["Message"]:
yield FakeMessageSegment(**seg) yield FakeMessageSegment(**seg)
return return
def __add__(self, other):
other = escape_text(other) if isinstance(other, str) else other
return super().__add__(other)
return FakeMessage return FakeMessage