Commit c2b9a896 authored by Neel Goyal's avatar Neel Goyal Committed by Facebook Github Bot

Use AsyncTransportCertificate interfaces

Summary: Have AsyncSSLSocket use AsyncTransportCertificate interfaces properly

Reviewed By: mingtaoy

Differential Revision: D9752031

fbshipit-source-id: c65c0b808d82843bf1111bb650fe140ac98723b8
parent 44535e79
...@@ -70,6 +70,25 @@ inline bool zero_return(int error, int rc) { ...@@ -70,6 +70,25 @@ inline bool zero_return(int error, int rc) {
return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0)); return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
} }
class AsyncSSLCertificate : public folly::AsyncTransportCertificate {
public:
// assumed to be non null
explicit AsyncSSLCertificate(folly::ssl::X509UniquePtr x509)
: x509_(std::move(x509)) {}
folly::ssl::X509UniquePtr getX509() const override {
X509_up_ref(x509_.get());
return folly::ssl::X509UniquePtr(x509_.get());
}
std::string getIdentity() const override {
return OpenSSLUtils::getCommonName(x509_.get());
}
private:
folly::ssl::X509UniquePtr x509_;
};
class AsyncSSLSocketConnector : public AsyncSocket::ConnectCallback, class AsyncSSLSocketConnector : public AsyncSocket::ConnectCallback,
public AsyncSSLSocket::HandshakeCB { public AsyncSSLSocket::HandshakeCB {
private: private:
...@@ -931,6 +950,38 @@ int AsyncSSLSocket::getSSLCertSize() const { ...@@ -931,6 +950,38 @@ int AsyncSSLSocket::getSSLCertSize() const {
return certSize; return certSize;
} }
const AsyncTransportCertificate* AsyncSSLSocket::getPeerCertificate() const {
if (peerCertData_) {
return peerCertData_.get();
}
if (ssl_ != nullptr) {
auto peerX509 = SSL_get_peer_certificate(ssl_);
if (peerX509) {
// already up ref'd
folly::ssl::X509UniquePtr peer(peerX509);
peerCertData_ = std::make_unique<AsyncSSLCertificate>(std::move(peer));
}
}
return peerCertData_.get();
}
const AsyncTransportCertificate* AsyncSSLSocket::getSelfCertificate() const {
if (selfCertData_) {
return selfCertData_.get();
}
if (ssl_ != nullptr) {
auto selfX509 = SSL_get_certificate(ssl_);
if (selfX509) {
// need to upref
X509_up_ref(selfX509);
folly::ssl::X509UniquePtr peer(selfX509);
selfCertData_ = std::make_unique<AsyncSSLCertificate>(std::move(peer));
}
}
return selfCertData_.get();
}
// TODO: deprecate/remove in favor of getSelfCertificate.
const X509* AsyncSSLSocket::getSelfCert() const { const X509* AsyncSSLSocket::getSelfCert() const {
return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr; return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
} }
......
...@@ -741,16 +741,18 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -741,16 +741,18 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/ */
void setBufferMovableEnabled(bool enabled); void setBufferMovableEnabled(bool enabled);
const AsyncTransportCertificate* getPeerCertificate() const override;
const AsyncTransportCertificate* getSelfCertificate() const override;
/** /**
* Returns the peer certificate, or nullptr if no peer certificate received. * Returns the peer certificate, or nullptr if no peer certificate received.
*/ */
ssl::X509UniquePtr getPeerCert() const override { ssl::X509UniquePtr getPeerCert() const override {
if (!ssl_) { auto peerCert = getPeerCertificate();
if (!peerCert) {
return nullptr; return nullptr;
} }
return peerCert->getX509();
X509* cert = SSL_get_peer_certificate(ssl_);
return ssl::X509UniquePtr(cert);
} }
/** /**
......
...@@ -1269,8 +1269,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1269,8 +1269,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
bool zeroCopyEnabled_{false}; bool zeroCopyEnabled_{false};
bool zeroCopyVal_{false}; bool zeroCopyVal_{false};
std::unique_ptr<const AsyncTransportCertificate> peerCertData_{nullptr}; // subclasses may cache these on first call to get
std::unique_ptr<const AsyncTransportCertificate> selfCertData_{nullptr}; mutable std::unique_ptr<const AsyncTransportCertificate> peerCertData_{
nullptr};
mutable std::unique_ptr<const AsyncTransportCertificate> selfCertData_{
nullptr};
}; };
#ifdef _MSC_VER #ifdef _MSC_VER
#pragma vtordisp(pop) #pragma vtordisp(pop)
......
...@@ -312,6 +312,18 @@ void OpenSSLUtils::setBioFd(BIO* b, int fd, int flags) { ...@@ -312,6 +312,18 @@ void OpenSSLUtils::setBioFd(BIO* b, int fd, int flags) {
BIO_set_fd(b, sock, flags); BIO_set_fd(b, sock, flags);
} }
std::string OpenSSLUtils::getCommonName(X509* x509) {
if (x509 == nullptr) {
return "";
}
X509_NAME* subject = X509_get_subject_name(x509);
std::string cn;
cn.resize(ub_common_name);
X509_NAME_get_text_by_NID(
subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
return cn;
}
} // namespace ssl } // namespace ssl
} // namespace folly } // namespace folly
......
...@@ -104,6 +104,11 @@ class OpenSSLUtils { ...@@ -104,6 +104,11 @@ class OpenSSLUtils {
static void setSSLInitialCtx(SSL* ssl, SSL_CTX* ctx); static void setSSLInitialCtx(SSL* ssl, SSL_CTX* ctx);
static SSL_CTX* getSSLInitialCtx(SSL* ssl); static SSL_CTX* getSSLInitialCtx(SSL* ssl);
/**
* Get the common name out of a cert. Return empty if x509 is null.
*/
static std::string getCommonName(X509* x509);
/** /**
* Wrappers for BIO operations that may be different across different * Wrappers for BIO operations that may be different across different
* versions/flavors of OpenSSL (including forks like BoringSSL) * versions/flavors of OpenSSL (including forks like BoringSSL)
......
...@@ -179,15 +179,6 @@ std::string getFileAsBuf(const char* fileName) { ...@@ -179,15 +179,6 @@ std::string getFileAsBuf(const char* fileName) {
return buffer; return buffer;
} }
std::string getCommonName(X509* cert) {
X509_NAME* subject = X509_get_subject_name(cert);
std::string cn;
cn.resize(ub_common_name);
X509_NAME_get_text_by_NID(
subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
return cn;
}
/** /**
* Test connecting to, writing to, reading from, and closing the * Test connecting to, writing to, reading from, and closing the
* connection to the SSL server. * connection to the SSL server.
...@@ -1586,6 +1577,23 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) { ...@@ -1586,6 +1577,23 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
EXPECT_TRUE(server.handshakeSuccess_); EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_); EXPECT_FALSE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count()); EXPECT_LE(0, server.handshakeTime.count());
// check certificates
auto clientSsl = std::move(client).moveSocket();
auto serverSsl = std::move(server).moveSocket();
auto clientPeer = clientSsl->getPeerCertificate();
auto clientSelf = clientSsl->getSelfCertificate();
auto serverPeer = serverSsl->getPeerCertificate();
auto serverSelf = serverSsl->getSelfCertificate();
EXPECT_NE(clientPeer, nullptr);
EXPECT_NE(clientSelf, nullptr);
EXPECT_NE(serverPeer, nullptr);
EXPECT_NE(serverSelf, nullptr);
EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity());
EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity());
} }
/** /**
...@@ -1878,6 +1886,7 @@ TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) { ...@@ -1878,6 +1886,7 @@ TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
#endif // FOLLY_OPENSSL_IS_110 #endif // FOLLY_OPENSSL_IS_110
TEST(AsyncSSLSocketTest, LoadCertFromMemory) { TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
using folly::ssl::OpenSSLUtils;
auto cert = getFileAsBuf(kTestCert); auto cert = getFileAsBuf(kTestCert);
auto key = getFileAsBuf(kTestKey); auto key = getFileAsBuf(kTestKey);
...@@ -1894,7 +1903,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) { ...@@ -1894,7 +1903,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
certBio = nullptr; certBio = nullptr;
keyBio = nullptr; keyBio = nullptr;
auto origCommonName = getCommonName(certStruct.get()); auto origCommonName = OpenSSLUtils::getCommonName(certStruct.get());
auto origKeySize = EVP_PKEY_bits(keyStruct.get()); auto origKeySize = EVP_PKEY_bits(keyStruct.get());
certStruct = nullptr; certStruct = nullptr;
keyStruct = nullptr; keyStruct = nullptr;
...@@ -1910,7 +1919,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) { ...@@ -1910,7 +1919,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
auto newKey = SSL_get_privatekey(ssl.get()); auto newKey = SSL_get_privatekey(ssl.get());
// Get properties from SSL struct // Get properties from SSL struct
auto newCommonName = getCommonName(newCert); auto newCommonName = OpenSSLUtils::getCommonName(newCert);
auto newKeySize = EVP_PKEY_bits(newKey); auto newKeySize = EVP_PKEY_bits(newKey);
// Check that the key and cert have the expected properties // Check that the key and cert have the expected properties
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment