Commit 68112fa0 authored by Christopher Dykes's avatar Christopher Dykes Committed by Facebook Github Bot

Fix some implicit truncations in the interaction with OpenSSL APIs

Summary: MSVC has the ability to warn about implicit truncations and places where implicit sign coercions are occuring, so do some cleanup to make it possible to compile with the warnings enabled.

Reviewed By: yfeldblum

Differential Revision: D4288028

fbshipit-source-id: f8330c62b2dcb76f696dfc47888f0e3e1eefc21a
parent c9fc2e32
...@@ -278,7 +278,7 @@ void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) { ...@@ -278,7 +278,7 @@ void SSLContext::loadCertificateFromBufferPEM(folly::StringPiece cert) {
throw std::runtime_error("BIO_new: " + getErrors()); throw std::runtime_error("BIO_new: " + getErrors());
} }
int written = BIO_write(bio.get(), cert.data(), cert.size()); int written = BIO_write(bio.get(), cert.data(), int(cert.size()));
if (written <= 0 || static_cast<unsigned>(written) != cert.size()) { if (written <= 0 || static_cast<unsigned>(written) != cert.size()) {
throw std::runtime_error("BIO_write: " + getErrors()); throw std::runtime_error("BIO_write: " + getErrors());
} }
...@@ -318,7 +318,7 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) { ...@@ -318,7 +318,7 @@ void SSLContext::loadPrivateKeyFromBufferPEM(folly::StringPiece pkey) {
throw std::runtime_error("BIO_new: " + getErrors()); throw std::runtime_error("BIO_new: " + getErrors());
} }
int written = BIO_write(bio.get(), pkey.data(), pkey.size()); int written = BIO_write(bio.get(), pkey.data(), int(pkey.size()));
if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) { if (written <= 0 || static_cast<unsigned>(written) != pkey.size()) {
throw std::runtime_error("BIO_write: " + getErrors()); throw std::runtime_error("BIO_write: " + getErrors());
} }
...@@ -517,12 +517,12 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( ...@@ -517,12 +517,12 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
advertised_item.length = 0; advertised_item.length = 0;
for (const auto& proto : item.protocols) { for (const auto& proto : item.protocols) {
++advertised_item.length; ++advertised_item.length;
unsigned protoLength = proto.length(); auto protoLength = proto.length();
if (protoLength >= 256) { if (protoLength >= 256) {
deleteNextProtocolsStrings(); deleteNextProtocolsStrings();
return false; return false;
} }
advertised_item.length += protoLength; advertised_item.length += unsigned(protoLength);
} }
advertised_item.protocols = new unsigned char[advertised_item.length]; advertised_item.protocols = new unsigned char[advertised_item.length];
if (!advertised_item.protocols) { if (!advertised_item.protocols) {
...@@ -530,7 +530,7 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( ...@@ -530,7 +530,7 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
} }
unsigned char* dst = advertised_item.protocols; unsigned char* dst = advertised_item.protocols;
for (auto& proto : item.protocols) { for (auto& proto : item.protocols) {
unsigned protoLength = proto.length(); uint8_t protoLength = uint8_t(proto.length());
*dst++ = (unsigned char)protoLength; *dst++ = (unsigned char)protoLength;
memcpy(dst, proto.data(), protoLength); memcpy(dst, proto.data(), protoLength);
dst += protoLength; dst += protoLength;
...@@ -715,7 +715,7 @@ int SSLContext::passwordCallback(char* password, ...@@ -715,7 +715,7 @@ int SSLContext::passwordCallback(char* password,
std::string userPassword; std::string userPassword;
// call user defined password collector to get password // call user defined password collector to get password
context->passwordCollector()->getPassword(userPassword, size); context->passwordCollector()->getPassword(userPassword, size);
int length = userPassword.size(); auto length = int(userPassword.size());
if (length > size) { if (length > size) {
length = size; length = size;
} }
......
...@@ -169,7 +169,7 @@ static std::unordered_map<uint16_t, std::string> getOpenSSLCipherNames() { ...@@ -169,7 +169,7 @@ static std::unordered_map<uint16_t, std::string> getOpenSSLCipherNames() {
}; };
STACK_OF(SSL_CIPHER)* sk = SSL_get_ciphers(ssl); STACK_OF(SSL_CIPHER)* sk = SSL_get_ciphers(ssl);
for (size_t i = 0; i < (size_t)sk_SSL_CIPHER_num(sk); i++) { for (int i = 0; i < sk_SSL_CIPHER_num(sk); i++) {
const SSL_CIPHER* c = sk_SSL_CIPHER_value(sk, i); const SSL_CIPHER* c = sk_SSL_CIPHER_value(sk, i);
unsigned long id = SSL_CIPHER_get_id(c); unsigned long id = SSL_CIPHER_get_id(c);
// OpenSSL 1.0.2 and prior does weird things such as stuff the SSL/TLS // OpenSSL 1.0.2 and prior does weird things such as stuff the SSL/TLS
......
...@@ -106,7 +106,7 @@ class OpenSSLHash { ...@@ -106,7 +106,7 @@ class OpenSSLHash {
void hash_init(const EVP_MD* md, ByteRange key) { void hash_init(const EVP_MD* md, ByteRange key) {
md_ = md; md_ = md;
check_libssl_result( check_libssl_result(
1, HMAC_Init_ex(&ctx_, key.data(), key.size(), md_, nullptr)); 1, HMAC_Init_ex(&ctx_, key.data(), int(key.size()), md_, nullptr));
} }
void hash_update(ByteRange data) { void hash_update(ByteRange data) {
check_libssl_result(1, HMAC_Update(&ctx_, data.data(), data.size())); check_libssl_result(1, HMAC_Update(&ctx_, data.data(), data.size()));
...@@ -121,7 +121,7 @@ class OpenSSLHash { ...@@ -121,7 +121,7 @@ class OpenSSLHash {
check_out_size(size, out); check_out_size(size, out);
unsigned int len = 0; unsigned int len = 0;
check_libssl_result(1, HMAC_Final(&ctx_, out.data(), &len)); check_libssl_result(1, HMAC_Final(&ctx_, out.data(), &len));
check_libssl_result(size, len); check_libssl_result(size, int(len));
md_ = nullptr; md_ = nullptr;
} }
private: private:
......
...@@ -40,8 +40,8 @@ SSLSessionImpl::SSLSessionImpl(SSL_SESSION* session, bool takeOwnership) ...@@ -40,8 +40,8 @@ SSLSessionImpl::SSLSessionImpl(SSL_SESSION* session, bool takeOwnership)
SSLSessionImpl::SSLSessionImpl(const std::string& serializedSession) { SSLSessionImpl::SSLSessionImpl(const std::string& serializedSession) {
auto sessionData = auto sessionData =
reinterpret_cast<const unsigned char*>(serializedSession.data()); reinterpret_cast<const unsigned char*>(serializedSession.data());
if ((session_ = d2i_SSL_SESSION( auto longLen = long(serializedSession.length());
nullptr, &sessionData, serializedSession.length())) == nullptr) { if ((session_ = d2i_SSL_SESSION(nullptr, &sessionData, longLen)) == nullptr) {
throw std::runtime_error("Cannot deserialize SSLSession string"); throw std::runtime_error("Cannot deserialize SSLSession string");
} }
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment