🔀 Merge pull request #782

Bugfix: Potential message body injection vulnerability in MessageTemplate
This commit is contained in:
Ju4tCode 2022-02-10 15:44:41 +08:00 committed by GitHub
commit 4f91e63759
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 35 additions and 13 deletions

View File

@ -104,7 +104,7 @@ class MessageTemplate(Formatter, Generic[TF]):
if recursion_depth < 0: if recursion_depth < 0:
raise ValueError("Max string recursion exceeded") raise ValueError("Max string recursion exceeded")
results: List[Any] = [] results: List[Any] = [self.factory()]
for (literal_text, field_name, format_spec, conversion) in self.parse( for (literal_text, field_name, format_spec, conversion) in self.parse(
format_string format_string
@ -162,10 +162,7 @@ class MessageTemplate(Formatter, Generic[TF]):
formatted_text = self.format_field(obj, str(format_control)) formatted_text = self.format_field(obj, str(format_control))
results.append(formatted_text) results.append(formatted_text)
return ( return functools.reduce(self._add, results), auto_arg_index
self.factory(functools.reduce(self._add, results or [""])),
auto_arg_index,
)
def format_field(self, value: Any, format_spec: str) -> Any: def format_field(self, value: Any, format_spec: str) -> Any:
formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec) formatter: Optional[FormatSpecFunc] = self.format_specs.get(format_spec)

View File

@ -1,4 +1,4 @@
from utils import make_fake_message from utils import escape_text, make_fake_message
def test_template_basis(): def test_template_basis():
@ -10,11 +10,8 @@ def test_template_basis():
def test_template_message(): def test_template_message():
from nonebot.adapters import MessageTemplate
Message = make_fake_message() Message = make_fake_message()
template = Message.template("{a:custom}{b:text}{c:image}")
template = MessageTemplate("{a:custom}{b:text}{c:image}", Message)
@template.add_format_spec @template.add_format_spec
def custom(input: str) -> str: def custom(input: str) -> str:
@ -33,3 +30,12 @@ def test_template_message():
assert template.format_map(format_args) == formatted assert template.format_map(format_args) == formatted
assert formatted.extract_plain_text() == "custom-custom!text" assert formatted.extract_plain_text() == "custom-custom!text"
assert str(formatted) == "custom-custom!text[fake:image]" assert str(formatted) == "custom-custom!text[fake:image]"
def test_message_injection():
Message = make_fake_message()
template = Message.template("{name}Is Bad")
message = template.format(name="[fake:image]")
assert message.extract_plain_text() == escape_text("[fake:image]Is Bad")

View File

@ -29,7 +29,11 @@ async def test_weather(app: App):
event = make_fake_event(_message=msg, _to_me=True)() event = make_fake_event(_message=msg, _to_me=True)()
ctx.receive_event(bot, event) ctx.receive_event(bot, event)
ctx.should_call_send(event, Message("你想查询的城市 南京 暂不支持,请重新输入!"), True) ctx.should_call_send(
event,
Message.template("你想查询的城市 {} 暂不支持,请重新输入!").format("南京"),
True,
)
ctx.should_rejected() ctx.should_rejected()
msg = Message("北京") msg = Message("北京")
@ -53,7 +57,11 @@ async def test_weather(app: App):
event = make_fake_event(_message=msg)() event = make_fake_event(_message=msg)()
ctx.receive_event(bot, event) ctx.receive_event(bot, event)
ctx.should_call_send(event, Message("你想查询的城市 杭州 暂不支持,请重新输入!"), True) ctx.should_call_send(
event,
Message.template("你想查询的城市 {} 暂不支持,请重新输入!").format("杭州"),
True,
)
ctx.should_rejected() ctx.should_rejected()
msg = Message("北京") msg = Message("北京")

View File

@ -6,7 +6,14 @@ if TYPE_CHECKING:
from nonebot.adapters import Event, Message from nonebot.adapters import Event, Message
def make_fake_message() -> Type["Message"]: def escape_text(s: str, *, escape_comma: bool = True) -> str:
s = s.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;")
if escape_comma:
s = s.replace(",", "&#44;")
return s
def make_fake_message():
from nonebot.adapters import Message, MessageSegment from nonebot.adapters import Message, MessageSegment
class FakeMessageSegment(MessageSegment): class FakeMessageSegment(MessageSegment):
@ -42,6 +49,10 @@ def make_fake_message() -> Type["Message"]:
yield FakeMessageSegment(**seg) yield FakeMessageSegment(**seg)
return return
def __add__(self, other):
other = escape_text(other) if isinstance(other, str) else other
return super().__add__(other)
return FakeMessage return FakeMessage