From 455c599b06c5f59305e828a08c82cb33d585ece5 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Thu, 10 Feb 2022 13:15:59 +0800 Subject: [PATCH 1/3] :test_tube: Add a fail test to reproduce #781 --- tests/test_adapters/test_template.py | 16 +++++++++++----- tests/utils.py | 13 ++++++++++++- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 1814b04d..32974b63 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,4 +1,4 @@ -from utils import make_fake_message +from utils import make_fake_message, escape_text def test_template_basis(): @@ -10,11 +10,8 @@ def test_template_basis(): def test_template_message(): - from nonebot.adapters import MessageTemplate - Message = make_fake_message() - - template = MessageTemplate("{a:custom}{b:text}{c:image}", Message) + template = Message.template("{a:custom}{b:text}{c:image}") @template.add_format_spec def custom(input: str) -> str: @@ -33,3 +30,12 @@ def test_template_message(): assert template.format_map(format_args) == formatted assert formatted.extract_plain_text() == "custom-custom!text" assert str(formatted) == "custom-custom!text[fake:image]" + + +def test_message_injection(): + Message = make_fake_message() + + template = Message.template("{name}Is Bad") + message = template.format(name="[fake:image]") + + assert message.extract_plain_text() == escape_text("[fake:image]Is Bad") diff --git a/tests/utils.py b/tests/utils.py index ef54b69b..0cd94be4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,14 @@ if TYPE_CHECKING: from nonebot.adapters import Event, Message -def make_fake_message() -> Type["Message"]: +def escape_text(s: str, *, escape_comma: bool = True) -> str: + s = s.replace("&", "&").replace("[", "[").replace("]", "]") + if escape_comma: + s = s.replace(",", ",") + return s + + +def make_fake_message(): from nonebot.adapters import Message, MessageSegment class FakeMessageSegment(MessageSegment): @@ -42,6 +49,10 @@ def make_fake_message() -> Type["Message"]: yield FakeMessageSegment(**seg) return + def __add__(self, other): + other = escape_text(other) if isinstance(other, str) else other + return super().__add__(other) + return FakeMessage From b7762b91765419a98850819327738d949f8811d5 Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Thu, 10 Feb 2022 13:17:11 +0800 Subject: [PATCH 2/3] :lock: :bug: Add initial value to vformat results list, fix #781 --- nonebot/internal/adapter/template.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/nonebot/internal/adapter/template.py b/nonebot/internal/adapter/template.py index 22d2aebb..e04830a8 100644 --- a/nonebot/internal/adapter/template.py +++ b/nonebot/internal/adapter/template.py @@ -104,7 +104,7 @@ class MessageTemplate(Formatter, Generic[TF]): if recursion_depth < 0: raise ValueError("Max string recursion exceeded") - results: List[Any] = [] + results: List[Any] = [self.factory()] for (literal_text, field_name, format_spec, conversion) in self.parse( format_string @@ -162,10 +162,7 @@ class MessageTemplate(Formatter, Generic[TF]): formatted_text = self.format_field(obj, str(format_control)) results.append(formatted_text) - return ( - self.factory(functools.reduce(self._add, results or [""])), - auto_arg_index, - ) + return functools.reduce(self._add, results), auto_arg_index def format_field(self, value: Any, format_spec: str) -> Any: formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec) From dc982fe5eb06a8a8b79de9959e9048e5b23daa7a Mon Sep 17 00:00:00 2001 From: Mix <32300164+mnixry@users.noreply.github.com> Date: Thu, 10 Feb 2022 15:12:27 +0800 Subject: [PATCH 3/3] :white_check_mark: Fix update cause failed test --- tests/test_adapters/test_template.py | 2 +- tests/test_examples/test_weather.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/test_adapters/test_template.py b/tests/test_adapters/test_template.py index 32974b63..84856625 100644 --- a/tests/test_adapters/test_template.py +++ b/tests/test_adapters/test_template.py @@ -1,4 +1,4 @@ -from utils import make_fake_message, escape_text +from utils import escape_text, make_fake_message def test_template_basis(): diff --git a/tests/test_examples/test_weather.py b/tests/test_examples/test_weather.py index 08bab064..086e2c1c 100644 --- a/tests/test_examples/test_weather.py +++ b/tests/test_examples/test_weather.py @@ -29,7 +29,11 @@ async def test_weather(app: App): event = make_fake_event(_message=msg, _to_me=True)() ctx.receive_event(bot, event) - ctx.should_call_send(event, Message("你想查询的城市 南京 暂不支持,请重新输入!"), True) + ctx.should_call_send( + event, + Message.template("你想查询的城市 {} 暂不支持,请重新输入!").format("南京"), + True, + ) ctx.should_rejected() msg = Message("北京") @@ -53,7 +57,11 @@ async def test_weather(app: App): event = make_fake_event(_message=msg)() ctx.receive_event(bot, event) - ctx.should_call_send(event, Message("你想查询的城市 杭州 暂不支持,请重新输入!"), True) + ctx.should_call_send( + event, + Message.template("你想查询的城市 {} 暂不支持,请重新输入!").format("杭州"), + True, + ) ctx.should_rejected() msg = Message("北京")