Commit e1d2ddd5 authored by Nick Terrell's avatar Nick Terrell Committed by Facebook Github Bot

Add zstd streaming interface

Summary:
* Add streaming interface to the `ZstdCodec`
* Implement `ZstdCodec::doCompress()` and `ZstdCodec::doUncompress()` using the streaming interface.
  [fbgs CodecType::ZSTD](https://fburl.com/pr8chg64) and check that no caller requires thread-safety.

Reviewed By: yfeldblum

Differential Revision: D5026558

fbshipit-source-id: 61faa25c71f5aef06ca2d7e0700f43214353c650
parent 74560278
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#endif #endif
#if FOLLY_HAVE_LIBZSTD #if FOLLY_HAVE_LIBZSTD
#define ZSTD_STATIC_LINKING_ONLY
#include <zstd.h> #include <zstd.h>
#endif #endif
...@@ -1584,13 +1585,24 @@ std::unique_ptr<IOBuf> LZMA2Codec::doUncompress( ...@@ -1584,13 +1585,24 @@ std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(
#ifdef FOLLY_HAVE_LIBZSTD #ifdef FOLLY_HAVE_LIBZSTD
namespace {
void zstdFreeCStream(ZSTD_CStream* zcs) {
ZSTD_freeCStream(zcs);
}
void zstdFreeDStream(ZSTD_DStream* zds) {
ZSTD_freeDStream(zds);
}
}
/** /**
* ZSTD compression * ZSTD compression
*/ */
class ZSTDCodec final : public Codec { class ZSTDStreamCodec final : public StreamCodec {
public: public:
static std::unique_ptr<Codec> create(int level, CodecType); static std::unique_ptr<Codec> createCodec(int level, CodecType);
explicit ZSTDCodec(int level, CodecType type); static std::unique_ptr<StreamCodec> createStream(int level, CodecType);
explicit ZSTDStreamCodec(int level, CodecType type);
std::vector<std::string> validPrefixes() const override; std::vector<std::string> validPrefixes() const override;
bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength) bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
...@@ -1599,29 +1611,61 @@ class ZSTDCodec final : public Codec { ...@@ -1599,29 +1611,61 @@ class ZSTDCodec final : public Codec {
private: private:
bool doNeedsUncompressedLength() const override; bool doNeedsUncompressedLength() const override;
uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override; Optional<uint64_t> doGetUncompressedLength(
std::unique_ptr<IOBuf> doUncompress( IOBuf const* data,
const IOBuf* data, Optional<uint64_t> uncompressedLength) const override;
Optional<uint64_t> uncompressedLength) override;
void doResetStream() override;
bool doCompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) override;
bool doUncompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) override;
void resetCStream();
void resetDStream();
bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const;
bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const;
int level_; int level_;
bool needReset_{true};
std::unique_ptr<
ZSTD_CStream,
folly::static_function_deleter<ZSTD_CStream, &zstdFreeCStream>>
cstream_{nullptr};
std::unique_ptr<
ZSTD_DStream,
folly::static_function_deleter<ZSTD_DStream, &zstdFreeDStream>>
dstream_{nullptr};
}; };
static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528; static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
std::vector<std::string> ZSTDCodec::validPrefixes() const { std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
return {prefixToStringLE(kZSTDMagicLE)}; return {prefixToStringLE(kZSTDMagicLE)};
} }
bool ZSTDCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const { bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
const {
return dataStartsWithLE(data, kZSTDMagicLE); return dataStartsWithLE(data, kZSTDMagicLE);
} }
std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) { std::unique_ptr<Codec> ZSTDStreamCodec::createCodec(int level, CodecType type) {
return std::make_unique<ZSTDCodec>(level, type); return make_unique<ZSTDStreamCodec>(level, type);
}
std::unique_ptr<StreamCodec> ZSTDStreamCodec::createStream(
int level,
CodecType type) {
return make_unique<ZSTDStreamCodec>(level, type);
} }
ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type)
: StreamCodec(type) {
DCHECK(type == CodecType::ZSTD); DCHECK(type == CodecType::ZSTD);
switch (level) { switch (level) {
case COMPRESSION_LEVEL_FASTEST: case COMPRESSION_LEVEL_FASTEST:
...@@ -1641,11 +1685,12 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { ...@@ -1641,11 +1685,12 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
level_ = level; level_ = level;
} }
bool ZSTDCodec::doNeedsUncompressedLength() const { bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
return false; return false;
} }
uint64_t ZSTDCodec::doMaxCompressedLength(uint64_t uncompressedLength) const { uint64_t ZSTDStreamCodec::doMaxCompressedLength(
uint64_t uncompressedLength) const {
return ZSTD_compressBound(uncompressedLength); return ZSTD_compressBound(uncompressedLength);
} }
...@@ -1657,160 +1702,155 @@ void zstdThrowIfError(size_t rc) { ...@@ -1657,160 +1702,155 @@ void zstdThrowIfError(size_t rc) {
to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc))); to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
} }
std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) { Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
// Support earlier versions of the codec (working with a single IOBuf, IOBuf const* data,
// and using ZSTD_decompress which requires ZSTD frame to contain size, Optional<uint64_t> uncompressedLength) const {
// which isn't populated by streaming API). // Read decompressed size from frame if available in first IOBuf.
if (!data->isChained()) { auto const decompressedSize =
auto out = IOBuf::createCombined(ZSTD_compressBound(data->length())); ZSTD_getDecompressedSize(data->data(), data->length());
const auto rc = ZSTD_compress( if (decompressedSize != 0) {
out->writableData(), if (uncompressedLength && *uncompressedLength != decompressedSize) {
out->capacity(), throw std::runtime_error("ZSTD: invalid uncompressed length");
data->data(),
data->length(),
level_);
zstdThrowIfError(rc);
out->append(rc);
return out;
}
auto zcs = ZSTD_createCStream();
SCOPE_EXIT {
ZSTD_freeCStream(zcs);
};
auto rc = ZSTD_initCStream(zcs, level_);
zstdThrowIfError(rc);
Cursor cursor(data);
auto result =
IOBuf::createCombined(maxCompressedLength(cursor.totalLength()));
ZSTD_outBuffer out;
out.dst = result->writableTail();
out.size = result->capacity();
out.pos = 0;
for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
ZSTD_inBuffer in;
in.src = buffer.data();
in.size = buffer.size();
for (in.pos = 0; in.pos != in.size;) {
rc = ZSTD_compressStream(zcs, &out, &in);
zstdThrowIfError(rc);
} }
cursor.skip(in.size); uncompressedLength = decompressedSize;
buffer = cursor.peekBytes();
} }
return uncompressedLength;
}
rc = ZSTD_endStream(zcs, &out); void ZSTDStreamCodec::doResetStream() {
zstdThrowIfError(rc); needReset_ = true;
CHECK_EQ(rc, 0);
result->append(out.pos);
return result;
} }
static std::unique_ptr<IOBuf> zstdUncompressBuffer( bool ZSTDStreamCodec::tryBlockCompress(
const IOBuf* data, ByteRange& input,
Optional<uint64_t> uncompressedLength) { MutableByteRange& output) const {
// Check preconditions DCHECK(needReset_);
DCHECK(!data->isChained()); // We need to know that we have enough output space to use block compression
DCHECK(uncompressedLength.hasValue()); if (output.size() < ZSTD_compressBound(input.size())) {
return false;
}
size_t const length = ZSTD_compress(
output.data(), output.size(), input.data(), input.size(), level_);
zstdThrowIfError(length);
input.uncheckedAdvance(input.size());
output.uncheckedAdvance(length);
return true;
}
auto uncompressed = IOBuf::create(*uncompressedLength); void ZSTDStreamCodec::resetCStream() {
const auto decompressedSize = ZSTD_decompress( if (!cstream_) {
uncompressed->writableTail(), cstream_.reset(ZSTD_createCStream());
uncompressed->tailroom(), if (!cstream_) {
data->data(), throw std::bad_alloc{};
data->length()); }
zstdThrowIfError(decompressedSize);
if (decompressedSize != uncompressedLength) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
} }
uncompressed->append(decompressedSize); // Advanced API usage works for all supported versions of zstd.
return uncompressed; // Required to set contentSizeFlag.
auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0);
params.fParams.contentSizeFlag = uncompressedLength().hasValue();
zstdThrowIfError(ZSTD_initCStream_advanced(
cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0)));
} }
static std::unique_ptr<IOBuf> zstdUncompressStream( bool ZSTDStreamCodec::doCompressStream(
const IOBuf* data, ByteRange& input,
Optional<uint64_t> uncompressedLength) { MutableByteRange& output,
auto zds = ZSTD_createDStream(); StreamCodec::FlushOp flushOp) {
if (needReset_) {
// If we are given all the input in one chunk try to use block compression
if (flushOp == StreamCodec::FlushOp::END &&
tryBlockCompress(input, output)) {
return true;
}
resetCStream();
needReset_ = false;
}
ZSTD_inBuffer in = {input.data(), input.size(), 0};
ZSTD_outBuffer out = {output.data(), output.size(), 0};
SCOPE_EXIT { SCOPE_EXIT {
ZSTD_freeDStream(zds); input.uncheckedAdvance(in.pos);
output.uncheckedAdvance(out.pos);
}; };
if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) {
auto rc = ZSTD_initDStream(zds); zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in));
zstdThrowIfError(rc);
ZSTD_outBuffer out{};
ZSTD_inBuffer in{};
auto outputSize = uncompressedLength.value_or(ZSTD_DStreamOutSize());
IOBufQueue queue(IOBufQueue::cacheChainLength());
Cursor cursor(data);
for (rc = 0;;) {
if (in.pos == in.size) {
auto buffer = cursor.peekBytes();
in.src = buffer.data();
in.size = buffer.size();
in.pos = 0;
cursor.skip(in.size);
if (rc > 1 && in.size == 0) {
throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
}
}
if (out.pos == out.size) {
if (out.pos != 0) {
queue.postallocate(out.pos);
}
auto buffer = queue.preallocate(outputSize, outputSize);
out.dst = buffer.first;
out.size = buffer.second;
out.pos = 0;
outputSize = ZSTD_DStreamOutSize();
}
rc = ZSTD_decompressStream(zds, &out, &in);
zstdThrowIfError(rc);
if (rc == 0) {
break;
} }
if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) {
size_t rc;
switch (flushOp) {
case StreamCodec::FlushOp::FLUSH:
rc = ZSTD_flushStream(cstream_.get(), &out);
break;
case StreamCodec::FlushOp::END:
rc = ZSTD_endStream(cstream_.get(), &out);
break;
default:
throw std::invalid_argument("ZSTD: invalid FlushOp");
} }
if (out.pos != 0) { zstdThrowIfError(rc);
queue.postallocate(out.pos); if (rc == 0) {
return true;
} }
if (in.pos != in.size || !cursor.isAtEnd()) {
throw std::runtime_error("ZSTD: junk after end of data");
}
if (uncompressedLength && queue.chainLength() != *uncompressedLength) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
} }
return false;
}
return queue.move(); bool ZSTDStreamCodec::tryBlockUncompress(
ByteRange& input,
MutableByteRange& output) const {
DCHECK(needReset_);
#if ZSTD_VERSION_NUMBER < 10104
// We require ZSTD_findFrameCompressedSize() to perform this optimization.
return false;
#else
// We need to know the uncompressed length and have enough output space.
if (!uncompressedLength() || output.size() < *uncompressedLength()) {
return false;
}
size_t const compressedLength =
ZSTD_findFrameCompressedSize(input.data(), input.size());
zstdThrowIfError(compressedLength);
size_t const length = ZSTD_decompress(
output.data(), *uncompressedLength(), input.data(), compressedLength);
zstdThrowIfError(length);
DCHECK_EQ(length, *uncompressedLength());
input.uncheckedAdvance(compressedLength);
output.uncheckedAdvance(length);
return true;
#endif
} }
std::unique_ptr<IOBuf> ZSTDCodec::doUncompress( void ZSTDStreamCodec::resetDStream() {
const IOBuf* data, if (!dstream_) {
Optional<uint64_t> uncompressedLength) { dstream_.reset(ZSTD_createDStream());
{ if (!dstream_) {
// Read decompressed size from frame if available in first IOBuf. throw std::bad_alloc{};
const auto decompressedSize =
ZSTD_getDecompressedSize(data->data(), data->length());
if (decompressedSize != 0) {
if (uncompressedLength && *uncompressedLength != decompressedSize) {
throw std::runtime_error("ZSTD: invalid uncompressed length");
} }
uncompressedLength = decompressedSize;
} }
zstdThrowIfError(ZSTD_initDStream(dstream_.get()));
}
bool ZSTDStreamCodec::doUncompressStream(
ByteRange& input,
MutableByteRange& output,
StreamCodec::FlushOp flushOp) {
if (needReset_) {
// If we are given all the input in one chunk try to use block uncompression
if (flushOp == StreamCodec::FlushOp::END &&
tryBlockUncompress(input, output)) {
return true;
} }
// Faster to decompress using ZSTD_decompress() if we can. resetDStream();
if (uncompressedLength && !data->isChained()) { needReset_ = false;
return zstdUncompressBuffer(data, uncompressedLength);
} }
// Fall back to slower streaming decompression. ZSTD_inBuffer in = {input.data(), input.size(), 0};
return zstdUncompressStream(data, uncompressedLength); ZSTD_outBuffer out = {output.data(), output.size(), 0};
SCOPE_EXIT {
input.uncheckedAdvance(in.pos);
output.uncheckedAdvance(out.pos);
};
size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in);
zstdThrowIfError(rc);
return rc == 0;
} }
#endif // FOLLY_HAVE_LIBZSTD #endif // FOLLY_HAVE_LIBZSTD
...@@ -2229,7 +2269,7 @@ constexpr Factory ...@@ -2229,7 +2269,7 @@ constexpr Factory
#endif #endif
#if FOLLY_HAVE_LIBZSTD #if FOLLY_HAVE_LIBZSTD
{ZSTDCodec::create, nullptr}, {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream},
#else #else
{}, {},
#endif #endif
......
...@@ -34,6 +34,10 @@ ...@@ -34,6 +34,10 @@
#include <folly/io/IOBufQueue.h> #include <folly/io/IOBufQueue.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
#if FOLLY_HAVE_LIBZSTD
#include <zstd.h>
#endif
namespace folly { namespace io { namespace test { namespace folly { namespace io { namespace test {
class DataHolder : private boost::noncopyable { class DataHolder : private boost::noncopyable {
...@@ -1084,6 +1088,31 @@ TEST(CheckCompatibleTest, ZlibIsPrefix) { ...@@ -1084,6 +1088,31 @@ TEST(CheckCompatibleTest, ZlibIsPrefix) {
EXPECT_THROW_IF_DEBUG( EXPECT_THROW_IF_DEBUG(
getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument); getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
} }
#if FOLLY_HAVE_LIBZSTD
TEST(ZstdTest, BackwardCompatible) {
auto codec = getCodec(CodecType::ZSTD);
{
auto const data = IOBuf::wrapBuffer(randomDataHolder.data(size_t(1) << 20));
auto compressed = codec->compress(data.get());
compressed->coalesce();
EXPECT_EQ(
data->length(),
ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
}
{
auto const data =
IOBuf::wrapBuffer(randomDataHolder.data(size_t(100) << 20));
auto compressed = codec->compress(data.get());
compressed->coalesce();
EXPECT_EQ(
data->length(),
ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
}
}
#endif
}}} // namespaces }}} // namespaces
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment