Commit 0ee04209 authored by Ilya Maykov's avatar Ilya Maykov Committed by Facebook GitHub Bot

make OpenSSLHash lazy-allocate the context, make move operations noexcept

Summary: See title.

Reviewed By: yfeldblum

Differential Revision: D30923907

fbshipit-source-id: 8cbe749fa4662a171a247c8c16f6b9bc7b587e30
parent 7b71f5e5
...@@ -31,17 +31,11 @@ class OpenSSLHash { ...@@ -31,17 +31,11 @@ class OpenSSLHash {
public: public:
class Digest { class Digest {
public: public:
Digest() { check_context_notnull(); } Digest() noexcept = default;
Digest(const Digest& that) { Digest(const Digest& that) { copy_impl(that); }
check_context_notnull();
copy_impl(that);
}
Digest(Digest&& that) noexcept(false) { Digest(Digest&& that) noexcept { move_impl(std::move(that)); }
check_context_notnull();
move_impl(std::move(that));
}
Digest& operator=(const Digest& that) { Digest& operator=(const Digest& that) {
if (this != &that) { if (this != &that) {
...@@ -50,7 +44,7 @@ class OpenSSLHash { ...@@ -50,7 +44,7 @@ class OpenSSLHash {
return *this; return *this;
} }
Digest& operator=(Digest&& that) noexcept(false) { Digest& operator=(Digest&& that) noexcept {
if (this != &that) { if (this != &that) {
move_impl(std::move(that)); move_impl(std::move(that));
that.hash_reset(); that.hash_reset();
...@@ -59,11 +53,22 @@ class OpenSSLHash { ...@@ -59,11 +53,22 @@ class OpenSSLHash {
} }
void hash_init(const EVP_MD* md) { void hash_init(const EVP_MD* md) {
md_ = md; if (nullptr == ctx_) {
ctx_.reset(EVP_MD_CTX_new());
if (nullptr == ctx_) {
throw_exception<std::runtime_error>(
"EVP_MD_CTX_new() returned nullptr");
}
}
check_libssl_result(1, EVP_DigestInit_ex(ctx_.get(), md, nullptr)); check_libssl_result(1, EVP_DigestInit_ex(ctx_.get(), md, nullptr));
md_ = md;
} }
void hash_update(ByteRange data) { void hash_update(ByteRange data) {
if (nullptr == ctx_) {
throw_exception<std::runtime_error>(
"hash_update() called without hash_init()");
}
check_libssl_result( check_libssl_result(
1, EVP_DigestUpdate(ctx_.get(), data.data(), data.size())); 1, EVP_DigestUpdate(ctx_.get(), data.data(), data.size()));
} }
...@@ -75,32 +80,29 @@ class OpenSSLHash { ...@@ -75,32 +80,29 @@ class OpenSSLHash {
} }
void hash_final(MutableByteRange out) { void hash_final(MutableByteRange out) {
if (nullptr == ctx_) {
throw_exception<std::runtime_error>(
"hash_final() called without hash_init()");
}
const auto size = EVP_MD_size(md_); const auto size = EVP_MD_size(md_);
check_out_size(size_t(size), out); check_out_size(size_t(size), out);
unsigned int len = 0; unsigned int len = 0;
check_libssl_result(1, EVP_DigestFinal_ex(ctx_.get(), out.data(), &len)); check_libssl_result(1, EVP_DigestFinal_ex(ctx_.get(), out.data(), &len));
check_libssl_result(size, int(len)); check_libssl_result(size, int(len));
md_ = nullptr; hash_reset();
} }
private: private:
const EVP_MD* md_{nullptr}; const EVP_MD* md_{nullptr};
EvpMdCtxUniquePtr ctx_{EVP_MD_CTX_new()}; EvpMdCtxUniquePtr ctx_{nullptr};
void hash_reset() { void hash_reset() noexcept {
ctx_.reset(nullptr);
md_ = nullptr; md_ = nullptr;
check_libssl_result(1, EVP_MD_CTX_reset(ctx_.get()));
}
void check_context_notnull() {
if (nullptr == ctx_) {
throw_exception<std::runtime_error>(
"EVP_MD_CTX_new() returned nullptr");
}
} }
void copy_impl(const Digest& that) { void copy_impl(const Digest& that) {
if (that.md_ != nullptr) { if (that.md_ != nullptr && that.ctx_ != nullptr) {
hash_init(that.md_); hash_init(that.md_);
check_libssl_result(1, EVP_MD_CTX_copy_ex(ctx_.get(), that.ctx_.get())); check_libssl_result(1, EVP_MD_CTX_copy_ex(ctx_.get(), that.ctx_.get()));
} else { } else {
......
...@@ -82,6 +82,27 @@ TEST_F(OpenSSLHashTest, sha256_hashcopy_self) { ...@@ -82,6 +82,27 @@ TEST_F(OpenSSLHashTest, sha256_hashcopy_self) {
EXPECT_EQ(expected, actual); EXPECT_EQ(expected, actual);
} }
TEST_F(OpenSSLHashTest, sha256_hashcopy_from_default_constructed) {
std::array<uint8_t, 32> expected, actual1, actual2;
constexpr StringPiece data{"foobar"};
OpenSSLHash::hash(range(expected), EVP_sha256(), data);
OpenSSLHash::Digest digest;
OpenSSLHash::Digest copy1(digest); // copy constructor
OpenSSLHash::Digest copy2 = digest; // copy assignment operator
copy1.hash_init(EVP_sha256());
copy1.hash_update(ByteRange(data));
copy1.hash_final(range(actual1));
EXPECT_EQ(expected, actual1);
copy2.hash_init(EVP_sha256());
copy2.hash_update(ByteRange(data));
copy2.hash_final(range(actual2));
EXPECT_EQ(expected, actual2);
}
TEST_F(OpenSSLHashTest, sha256_hashmove) { TEST_F(OpenSSLHashTest, sha256_hashmove) {
std::array<uint8_t, 32> expected, actual1, actual2; std::array<uint8_t, 32> expected, actual1, actual2;
constexpr StringPiece data{"foobar"}; constexpr StringPiece data{"foobar"};
...@@ -121,6 +142,28 @@ TEST_F(OpenSSLHashTest, sha256_hashmove_self) { ...@@ -121,6 +142,28 @@ TEST_F(OpenSSLHashTest, sha256_hashmove_self) {
EXPECT_EQ(expected, actual); EXPECT_EQ(expected, actual);
} }
TEST_F(OpenSSLHashTest, sha256_hashmove_from_default_constructed) {
std::array<uint8_t, 32> expected, actual1, actual2;
constexpr StringPiece data{"foobar"};
OpenSSLHash::hash(range(expected), EVP_sha256(), data);
OpenSSLHash::Digest digest1;
OpenSSLHash::Digest copy1(std::move(digest1)); // move constructor
OpenSSLHash::Digest digest2;
OpenSSLHash::Digest copy2 = std::move(digest2); // move assignment operator
copy1.hash_init(EVP_sha256());
copy1.hash_update(ByteRange(data));
copy1.hash_final(range(actual1));
EXPECT_EQ(expected, actual1);
copy2.hash_init(EVP_sha256());
copy2.hash_update(ByteRange(data));
copy2.hash_final(range(actual2));
EXPECT_EQ(expected, actual2);
}
TEST_F(OpenSSLHashTest, sha256_hashcopy_intermediate) { TEST_F(OpenSSLHashTest, sha256_hashcopy_intermediate) {
std::array<uint8_t, 32> expected, actual1, actual2; std::array<uint8_t, 32> expected, actual1, actual2;
constexpr StringPiece data1("foo"); constexpr StringPiece data1("foo");
...@@ -172,6 +215,17 @@ TEST_F(OpenSSLHashTest, sha256_hashmove_intermediate) { ...@@ -172,6 +215,17 @@ TEST_F(OpenSSLHashTest, sha256_hashmove_intermediate) {
digest.hash_init(EVP_sha256()); digest.hash_init(EVP_sha256());
} }
TEST_F(OpenSSLHashTest, digest_update_without_init_throws) {
OpenSSLHash::Digest digest;
EXPECT_THROW(digest.hash_update(ByteRange{}), std::runtime_error);
}
TEST_F(OpenSSLHashTest, digest_final_without_init_throws) {
OpenSSLHash::Digest digest;
std::array<uint8_t, 32> out;
EXPECT_THROW(digest.hash_final(range(out)), std::runtime_error);
}
TEST_F(OpenSSLHashTest, hmac_sha256) { TEST_F(OpenSSLHashTest, hmac_sha256) {
auto key = ByteRange(StringPiece("qwerty")); auto key = ByteRange(StringPiece("qwerty"));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment