Commit 151e22b2 authored by Philip Pronin's avatar Philip Pronin Committed by Facebook Github Bot 4

fix ZSTD support

Summary:
Existing logic is broken (unable to correctly handle chained `IOBuf`
in case of both `compress` and `uncompress`) and has unnecessarly strict
`needsUncompressedLength() == true` requirement.

This diff switches `ZSTDCodec` to use streaming to handle chained `IOBuf`,
drops `needsUncompressedLength() == true`.

Reviewed By: luciang

Differential Revision: D3827579

fbshipit-source-id: 0ef6a9ea664ef585d0e181bff6ca17166b28efc2
parent 8d329050
...@@ -983,55 +983,137 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { ...@@ -983,55 +983,137 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
} }
bool ZSTDCodec::doNeedsUncompressedLength() const { bool ZSTDCodec::doNeedsUncompressedLength() const {
return true; return false;
}
void zstdThrowIfError(size_t rc) {
if (!ZSTD_isError(rc)) {
return;
}
throw std::runtime_error(
to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
} }
std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) { std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
size_t rc; // Support earlier versions of the codec (working with a single IOBuf,
size_t maxCompressedLength = ZSTD_compressBound(data->length()); // and using ZSTD_decompress which requires ZSTD frame to contain size,
auto out = IOBuf::createCombined(maxCompressedLength); // which isn't populated by streaming API).
if (!data->isChained()) {
auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
const auto rc = ZSTD_compress(
out->writableData(),
out->capacity(),
data->data(),
data->length(),
level_);
zstdThrowIfError(rc);
out->append(rc);
return out;
}
CHECK_EQ(out->length(), 0); auto zcs = ZSTD_createCStream();
SCOPE_EXIT {
ZSTD_freeCStream(zcs);
};
rc = ZSTD_compress(out->writableTail(), auto rc = ZSTD_initCStream(zcs, level_);
out->capacity(), zstdThrowIfError(rc);
data->data(),
data->length(),
level_);
if (ZSTD_isError(rc)) { Cursor cursor(data);
throw std::runtime_error(to<std::string>( auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
"ZSTD compression returned an error: ",
ZSTD_getErrorName(rc))); ZSTD_outBuffer out;
out.dst = result->writableTail();
out.size = result->capacity();
out.pos = 0;
for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
ZSTD_inBuffer in;
in.src = buffer.data();
in.size = buffer.size();
for (in.pos = 0; in.pos != in.size;) {
rc = ZSTD_compressStream(zcs, &out, &in);
zstdThrowIfError(rc);
}
cursor.skip(in.size);
buffer = cursor.peekBytes();
} }
out->append(rc); rc = ZSTD_endStream(zcs, &out);
CHECK_EQ(out->length(), rc); zstdThrowIfError(rc);
CHECK_EQ(rc, 0);
return out; result->append(out.pos);
return result;
} }
std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data, std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
uint64_t uncompressedLength) { const IOBuf* data,
size_t rc; uint64_t uncompressedLength) {
auto out = IOBuf::createCombined(uncompressedLength); auto zds = ZSTD_createDStream();
SCOPE_EXIT {
ZSTD_freeDStream(zds);
};
CHECK_GE(out->capacity(), uncompressedLength); auto rc = ZSTD_initDStream(zds);
CHECK_EQ(out->length(), 0); zstdThrowIfError(rc);
rc = ZSTD_decompress( ZSTD_outBuffer out{};
out->writableTail(), out->capacity(), data->data(), data->length()); ZSTD_inBuffer in{};
if (ZSTD_isError(rc)) { auto outputSize = ZSTD_DStreamOutSize();
throw std::runtime_error(to<std::string>( if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
"ZSTD decompression returned an error: ", outputSize = uncompressedLength;
ZSTD_getErrorName(rc))); } else {
auto decompressedSize =
ZSTD_getDecompressedSize(data->data(), data->length());
if (decompressedSize != 0 && decompressedSize < outputSize) {
outputSize = decompressedSize;
}
} }
out->append(rc); IOBufQueue queue(IOBufQueue::cacheChainLength());
CHECK_EQ(out->length(), rc);
Cursor cursor(data);
for (rc = 0;;) {
if (in.pos == in.size) {
auto buffer = cursor.peekBytes();
in.src = buffer.data();
in.size = buffer.size();
in.pos = 0;
cursor.skip(in.size);
if (rc > 1 && in.size == 0) {
throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
}
}
if (out.pos == out.size) {
if (out.pos != 0) {
queue.postallocate(out.pos);
}
auto buffer = queue.preallocate(outputSize, outputSize);
out.dst = buffer.first;
out.size = buffer.second;
out.pos = 0;
outputSize = ZSTD_DStreamOutSize();
}
rc = ZSTD_decompressStream(zds, &out, &in);
zstdThrowIfError(rc);
if (rc == 0) {
break;
}
}
if (out.pos != 0) {
queue.postallocate(out.pos);
}
if (in.pos != in.size || !cursor.isAtEnd()) {
throw std::runtime_error("ZSTD: junk after end of data");
}
if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
queue.chainLength() != uncompressedLength) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
}
return out; return queue.move();
} }
#endif // FOLLY_HAVE_LIBZSTD #endif // FOLLY_HAVE_LIBZSTD
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <folly/io/Compression.h> #include <folly/io/Compression.h>
#include <random> #include <random>
#include <set>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
...@@ -128,31 +129,35 @@ TEST(CompressionTestNeedsUncompressedLength, Simple) { ...@@ -128,31 +129,35 @@ TEST(CompressionTestNeedsUncompressedLength, Simple) {
EXPECT_TRUE(getCodec(CodecType::LZMA2)->needsUncompressedLength()); EXPECT_TRUE(getCodec(CodecType::LZMA2)->needsUncompressedLength());
EXPECT_FALSE(getCodec(CodecType::LZMA2_VARINT_SIZE) EXPECT_FALSE(getCodec(CodecType::LZMA2_VARINT_SIZE)
->needsUncompressedLength()); ->needsUncompressedLength());
EXPECT_TRUE(getCodec(CodecType::ZSTD)->needsUncompressedLength()); EXPECT_FALSE(getCodec(CodecType::ZSTD)->needsUncompressedLength());
EXPECT_FALSE(getCodec(CodecType::GZIP)->needsUncompressedLength()); EXPECT_FALSE(getCodec(CodecType::GZIP)->needsUncompressedLength());
} }
class CompressionTest class CompressionTest
: public testing::TestWithParam<std::tr1::tuple<int, CodecType>> { : public testing::TestWithParam<std::tr1::tuple<int, int, CodecType>> {
protected: protected:
void SetUp() override { void SetUp() override {
auto tup = GetParam(); auto tup = GetParam();
uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup); uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
codec_ = getCodec(std::tr1::get<1>(tup)); chunks_ = std::tr1::get<1>(tup);
} codec_ = getCodec(std::tr1::get<2>(tup));
}
void runSimpleTest(const DataHolder& dh); void runSimpleTest(const DataHolder& dh);
uint64_t uncompressedLength_; private:
std::unique_ptr<Codec> codec_; std::unique_ptr<IOBuf> split(std::unique_ptr<IOBuf> data) const;
uint64_t uncompressedLength_;
size_t chunks_;
std::unique_ptr<Codec> codec_;
}; };
void CompressionTest::runSimpleTest(const DataHolder& dh) { void CompressionTest::runSimpleTest(const DataHolder& dh) {
auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_)); const auto original = split(IOBuf::wrapBuffer(dh.data(uncompressedLength_)));
auto compressed = codec_->compress(original.get()); const auto compressed = split(codec_->compress(original.get()));
if (!codec_->needsUncompressedLength()) { if (!codec_->needsUncompressedLength()) {
auto uncompressed = codec_->uncompress(compressed.get()); auto uncompressed = codec_->uncompress(compressed.get());
EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength()); EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get())); EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
} }
...@@ -164,6 +169,32 @@ void CompressionTest::runSimpleTest(const DataHolder& dh) { ...@@ -164,6 +169,32 @@ void CompressionTest::runSimpleTest(const DataHolder& dh) {
} }
} }
// Uniformly split data into (potentially empty) chunks.
std::unique_ptr<IOBuf> CompressionTest::split(
std::unique_ptr<IOBuf> data) const {
if (data->isChained()) {
data->coalesce();
}
const size_t size = data->computeChainDataLength();
std::multiset<size_t> splits;
for (size_t i = 1; i < chunks_; ++i) {
splits.insert(Random::rand64(size));
}
folly::IOBufQueue result;
size_t offset = 0;
for (size_t split : splits) {
result.append(IOBuf::copyBuffer(data->data() + offset, split - offset));
offset = split;
}
result.append(IOBuf::copyBuffer(data->data() + offset, size - offset));
return result.move();
}
TEST_P(CompressionTest, RandomData) { TEST_P(CompressionTest, RandomData) {
runSimpleTest(randomDataHolder); runSimpleTest(randomDataHolder);
} }
...@@ -175,16 +206,19 @@ TEST_P(CompressionTest, ConstantData) { ...@@ -175,16 +206,19 @@ TEST_P(CompressionTest, ConstantData) {
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
CompressionTest, CompressionTest,
CompressionTest, CompressionTest,
testing::Combine(testing::Values(0, 1, 12, 22, 25, 27), testing::Combine(
testing::Values(CodecType::NO_COMPRESSION, testing::Values(0, 1, 12, 22, 25, 27),
CodecType::LZ4, testing::Values(1, 2, 3, 8, 65),
CodecType::SNAPPY, testing::Values(
CodecType::ZLIB, CodecType::NO_COMPRESSION,
CodecType::LZ4_VARINT_SIZE, CodecType::LZ4,
CodecType::LZMA2, CodecType::SNAPPY,
CodecType::LZMA2_VARINT_SIZE, CodecType::ZLIB,
CodecType::ZSTD, CodecType::LZ4_VARINT_SIZE,
CodecType::GZIP))); CodecType::LZMA2,
CodecType::LZMA2_VARINT_SIZE,
CodecType::ZSTD,
CodecType::GZIP)));
class CompressionVarintTest class CompressionVarintTest
: public testing::TestWithParam<std::tr1::tuple<int, CodecType>> { : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
...@@ -227,7 +261,9 @@ void CompressionVarintTest::runSimpleTest(const DataHolder& dh) { ...@@ -227,7 +261,9 @@ void CompressionVarintTest::runSimpleTest(const DataHolder& dh) {
EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get())); EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
} }
TEST_P(CompressionVarintTest, RandomData) { runSimpleTest(randomDataHolder); } TEST_P(CompressionVarintTest, RandomData) {
runSimpleTest(randomDataHolder);
}
TEST_P(CompressionVarintTest, ConstantData) { TEST_P(CompressionVarintTest, ConstantData) {
runSimpleTest(constantDataHolder); runSimpleTest(constantDataHolder);
...@@ -236,9 +272,11 @@ TEST_P(CompressionVarintTest, ConstantData) { ...@@ -236,9 +272,11 @@ TEST_P(CompressionVarintTest, ConstantData) {
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
CompressionVarintTest, CompressionVarintTest,
CompressionVarintTest, CompressionVarintTest,
testing::Combine(testing::Values(0, 1, 12, 22, 25, 27), testing::Combine(
testing::Values(CodecType::LZ4_VARINT_SIZE, testing::Values(0, 1, 12, 22, 25, 27),
CodecType::LZMA2_VARINT_SIZE))); testing::Values(
CodecType::LZ4_VARINT_SIZE,
CodecType::LZMA2_VARINT_SIZE)));
class CompressionCorruptionTest : public testing::TestWithParam<CodecType> { class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
protected: protected:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment