Commit 91022bb0 authored by Stella Lau's avatar Stella Lau Committed by Facebook Github Bot

Enforce forward progress with StreamCodec

Summary:
- Throw exception if no forward progress was made with `StreamCodec.compress()` and `StreamCodec.uncompress()`
- Prevents infinite looping behavior when no forward progress was made
- Update tests

Reviewed By: terrelln

Differential Revision: D5685690

fbshipit-source-id: 969393896b74f51250f0e0ce3af0cd4fedcab49a
parent 82ee3be4
...@@ -234,6 +234,7 @@ void StreamCodec::assertStateIs(State expected) const { ...@@ -234,6 +234,7 @@ void StreamCodec::assertStateIs(State expected) const {
void StreamCodec::resetStream(Optional<uint64_t> uncompressedLength) { void StreamCodec::resetStream(Optional<uint64_t> uncompressedLength) {
state_ = State::RESET; state_ = State::RESET;
uncompressedLength_ = uncompressedLength; uncompressedLength_ = uncompressedLength;
progressMade_ = true;
doResetStream(); doResetStream();
} }
...@@ -279,7 +280,18 @@ bool StreamCodec::compressStream( ...@@ -279,7 +280,18 @@ bool StreamCodec::compressStream(
assertStateIs(State::COMPRESS_END); assertStateIs(State::COMPRESS_END);
break; break;
} }
size_t const inputSize = input.size();
size_t const outputSize = output.size();
bool const done = doCompressStream(input, output, flushOp); bool const done = doCompressStream(input, output, flushOp);
if (!done && inputSize == input.size() && outputSize == output.size()) {
if (!progressMade_) {
throw std::runtime_error("Codec: No forward progress made");
}
// Throw an exception if there is no progress again next time
progressMade_ = false;
} else {
progressMade_ = true;
}
// Handle output state transitions // Handle output state transitions
if (done) { if (done) {
if (state_ == State::COMPRESS_FLUSH) { if (state_ == State::COMPRESS_FLUSH) {
...@@ -309,7 +321,18 @@ bool StreamCodec::uncompressStream( ...@@ -309,7 +321,18 @@ bool StreamCodec::uncompressStream(
state_ = State::UNCOMPRESS; state_ = State::UNCOMPRESS;
} }
assertStateIs(State::UNCOMPRESS); assertStateIs(State::UNCOMPRESS);
size_t const inputSize = input.size();
size_t const outputSize = output.size();
bool const done = doUncompressStream(input, output, flushOp); bool const done = doUncompressStream(input, output, flushOp);
if (!done && inputSize == input.size() && outputSize == output.size()) {
if (!progressMade_) {
throw std::runtime_error("Codec: no forward progress made");
}
// Throw an exception if there is no progress again next time
progressMade_ = false;
} else {
progressMade_ = true;
}
// Handle output state transitions // Handle output state transitions
if (done) { if (done) {
state_ = State::END; state_ = State::END;
...@@ -345,7 +368,8 @@ std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) { ...@@ -345,7 +368,8 @@ std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) {
IOBuf const* current = data; IOBuf const* current = data;
ByteRange input{current->data(), current->length()}; ByteRange input{current->data(), current->length()};
StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE; StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
for (;;) { bool done = false;
while (!done) {
while (input.empty() && current->next() != data) { while (input.empty() && current->next() != data) {
current = current->next(); current = current->next();
input = {current->data(), current->length()}; input = {current->data(), current->length()};
...@@ -357,17 +381,11 @@ std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) { ...@@ -357,17 +381,11 @@ std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) {
if (output.empty()) { if (output.empty()) {
buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength)); buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength));
} }
size_t const inputSize = input.size(); done = compressStream(input, output, flushOp);
size_t const outputSize = output.size();
bool const done = compressStream(input, output, flushOp);
if (done) { if (done) {
DCHECK(input.empty()); DCHECK(input.empty());
DCHECK(flushOp == StreamCodec::FlushOp::END); DCHECK(flushOp == StreamCodec::FlushOp::END);
DCHECK_EQ(current->next(), data); DCHECK_EQ(current->next(), data);
break;
}
if (inputSize == input.size() && outputSize == output.size()) {
throw std::runtime_error("Codec: No forward progress made");
} }
} }
buffer->prev()->trimEnd(output.size()); buffer->prev()->trimEnd(output.size());
...@@ -404,7 +422,8 @@ std::unique_ptr<IOBuf> StreamCodec::doUncompress( ...@@ -404,7 +422,8 @@ std::unique_ptr<IOBuf> StreamCodec::doUncompress(
IOBuf const* current = data; IOBuf const* current = data;
ByteRange input{current->data(), current->length()}; ByteRange input{current->data(), current->length()};
StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE; StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
for (;;) { bool done = false;
while (!done) {
while (input.empty() && current->next() != data) { while (input.empty() && current->next() != data) {
current = current->next(); current = current->next();
input = {current->data(), current->length()}; input = {current->data(), current->length()};
...@@ -416,15 +435,7 @@ std::unique_ptr<IOBuf> StreamCodec::doUncompress( ...@@ -416,15 +435,7 @@ std::unique_ptr<IOBuf> StreamCodec::doUncompress(
if (output.empty()) { if (output.empty()) {
buffer->prependChain(addOutputBuffer(output, defaultBufferLength)); buffer->prependChain(addOutputBuffer(output, defaultBufferLength));
} }
size_t const inputSize = input.size(); done = uncompressStream(input, output, flushOp);
size_t const outputSize = output.size();
bool const done = uncompressStream(input, output, flushOp);
if (done) {
break;
}
if (inputSize == input.size() && outputSize == output.size()) {
throw std::runtime_error("Codec: Truncated data");
}
} }
if (!input.empty()) { if (!input.empty()) {
throw std::runtime_error("Codec: Junk after end of data"); throw std::runtime_error("Codec: Junk after end of data");
......
...@@ -304,7 +304,8 @@ class StreamCodec : public Codec { ...@@ -304,7 +304,8 @@ class StreamCodec : public Codec {
* flushOp. * flushOp.
* *
* A std::logic_error is thrown on incorrect usage of the API. * A std::logic_error is thrown on incorrect usage of the API.
* A std::runtime_error is thrown upon error conditions. * A std::runtime_error is thrown upon error conditions or if no forward
* progress could be made twice in a row.
*/ */
bool compressStream( bool compressStream(
folly::ByteRange& input, folly::ByteRange& input,
...@@ -340,6 +341,10 @@ class StreamCodec : public Codec { ...@@ -340,6 +341,10 @@ class StreamCodec : public Codec {
* compressStream() with flushOp FLUSH. Most users don't need to use this * compressStream() with flushOp FLUSH. Most users don't need to use this
* flushOp. * flushOp.
* *
* A std::runtime_error is thrown upon error conditions or if no forward
* progress could be made upon two consecutive calls to the function (only the
* second call will throw an exception).
*
* Returns true at the end of a frame. At this point resetStream() must be * Returns true at the end of a frame. At this point resetStream() must be
* called to reuse the codec. * called to reuse the codec.
*/ */
...@@ -390,6 +395,7 @@ class StreamCodec : public Codec { ...@@ -390,6 +395,7 @@ class StreamCodec : public Codec {
State state_{State::RESET}; State state_{State::RESET};
ByteRange previousInput_{}; ByteRange previousInput_{};
folly::Optional<uint64_t> uncompressedLength_{}; folly::Optional<uint64_t> uncompressedLength_{};
bool progressMade_{true};
}; };
constexpr int COMPRESSION_LEVEL_FASTEST = -1; constexpr int COMPRESSION_LEVEL_FASTEST = -1;
......
...@@ -562,66 +562,68 @@ TEST_P(StreamingUnitTest, emptyData) { ...@@ -562,66 +562,68 @@ TEST_P(StreamingUnitTest, emptyData) {
codec_->uncompressStream(input, output, StreamCodec::FlushOp::END)); codec_->uncompressStream(input, output, StreamCodec::FlushOp::END));
} }
TEST_P(StreamingUnitTest, noForwardProgressOkay) { TEST_P(StreamingUnitTest, noForwardProgress) {
auto inBuffer = IOBuf::create(2); auto inBuffer = IOBuf::create(2);
inBuffer->writableData()[0] = 'a'; inBuffer->writableData()[0] = 'a';
inBuffer->writableData()[0] = 'a'; inBuffer->writableData()[1] = 'a';
inBuffer->append(2); inBuffer->append(2);
auto input = inBuffer->coalesce(); const auto compressed = codec_->compress(inBuffer.get());
auto compressed = codec_->compress(inBuffer.get());
auto outBuffer = IOBuf::create(codec_->maxCompressedLength(2)); auto outBuffer = IOBuf::create(codec_->maxCompressedLength(2));
MutableByteRange output{outBuffer->writableTail(), outBuffer->tailroom()};
ByteRange emptyInput; ByteRange emptyInput;
MutableByteRange emptyOutput; MutableByteRange emptyOutput;
// Compress some data to avoid empty data special casing const std::array<StreamCodec::FlushOp, 3> flushOps = {{
if (codec_->needsDataLength()) { StreamCodec::FlushOp::NONE,
codec_->resetStream(inBuffer->computeChainDataLength()); StreamCodec::FlushOp::FLUSH,
} else { StreamCodec::FlushOp::END,
codec_->resetStream(); }};
}
while (!input.empty()) {
codec_->compressStream(input, output);
}
// empty input and output is okay for flush NONE and FLUSH.
codec_->compressStream(emptyInput, emptyOutput);
codec_->compressStream(emptyInput, emptyOutput, StreamCodec::FlushOp::FLUSH);
// No progress is not okay twice in a row for all flush operations when
// compressing
for (const auto flushOp : flushOps) {
if (codec_->needsDataLength()) { if (codec_->needsDataLength()) {
codec_->resetStream(inBuffer->computeChainDataLength()); codec_->resetStream(inBuffer->computeChainDataLength());
} else { } else {
codec_->resetStream(); codec_->resetStream();
} }
input = inBuffer->coalesce(); auto input = inBuffer->coalesce();
output = {outBuffer->writableTail(), outBuffer->tailroom()}; MutableByteRange output = {outBuffer->writableTail(),
outBuffer->tailroom()};
// Compress some data to avoid empty data special casing
while (!input.empty()) { while (!input.empty()) {
codec_->compressStream(input, output); codec_->compressStream(input, output);
} }
// empty input and output is okay for flush END. EXPECT_FALSE(codec_->compressStream(emptyInput, emptyOutput, flushOp));
codec_->compressStream(emptyInput, emptyOutput, StreamCodec::FlushOp::END); EXPECT_THROW(
codec_->compressStream(emptyInput, emptyOutput, flushOp),
std::runtime_error);
}
// No progress is not okay twice in a row for all flush operations when
// uncompressing
for (const auto flushOp : flushOps) {
codec_->resetStream(); codec_->resetStream();
input = compressed->coalesce(); auto input = compressed->coalesce();
input.uncheckedSubtract(1); // Remove last byte so the operation is incomplete // Remove the last byte so the operation is incomplete
output = {inBuffer->writableData(), inBuffer->length()}; input.uncheckedSubtract(1);
MutableByteRange output = {inBuffer->writableData(), inBuffer->length()};
// Uncompress some data to avoid empty data special casing // Uncompress some data to avoid empty data special casing
while (!input.empty()) { while (!input.empty()) {
EXPECT_FALSE(codec_->uncompressStream(input, output)); EXPECT_FALSE(codec_->uncompressStream(input, output));
} }
// empty input and output is okay for all flush values. EXPECT_FALSE(codec_->uncompressStream(emptyInput, emptyOutput, flushOp));
EXPECT_FALSE(codec_->uncompressStream(emptyInput, emptyOutput)); EXPECT_THROW(
EXPECT_FALSE(codec_->uncompressStream( codec_->uncompressStream(emptyInput, emptyOutput, flushOp),
emptyInput, emptyOutput, StreamCodec::FlushOp::FLUSH)); std::runtime_error);
EXPECT_FALSE(codec_->uncompressStream( }
emptyInput, emptyOutput, StreamCodec::FlushOp::END));
} }
TEST_P(StreamingUnitTest, stateTransitions) { TEST_P(StreamingUnitTest, stateTransitions) {
auto inBuffer = IOBuf::create(1); auto inBuffer = IOBuf::create(2);
inBuffer->writableData()[0] = 'a'; inBuffer->writableData()[0] = 'a';
inBuffer->append(1); inBuffer->writableData()[1] = 'a';
inBuffer->append(2);
auto compressed = codec_->compress(inBuffer.get()); auto compressed = codec_->compress(inBuffer.get());
ByteRange const in = compressed->coalesce(); ByteRange const in = compressed->coalesce();
auto outBuffer = IOBuf::create(codec_->maxCompressedLength(in.size())); auto outBuffer = IOBuf::create(codec_->maxCompressedLength(in.size()));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment