Commit c2b9a896 authored by Neel Goyal's avatar Neel Goyal Committed by Facebook Github Bot

Use AsyncTransportCertificate interfaces

Summary: Have AsyncSSLSocket use AsyncTransportCertificate interfaces properly

Reviewed By: mingtaoy

Differential Revision: D9752031

fbshipit-source-id: c65c0b808d82843bf1111bb650fe140ac98723b8
parent 44535e79
......@@ -70,6 +70,25 @@ inline bool zero_return(int error, int rc) {
return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
}
class AsyncSSLCertificate : public folly::AsyncTransportCertificate {
public:
// assumed to be non null
explicit AsyncSSLCertificate(folly::ssl::X509UniquePtr x509)
: x509_(std::move(x509)) {}
folly::ssl::X509UniquePtr getX509() const override {
X509_up_ref(x509_.get());
return folly::ssl::X509UniquePtr(x509_.get());
}
std::string getIdentity() const override {
return OpenSSLUtils::getCommonName(x509_.get());
}
private:
folly::ssl::X509UniquePtr x509_;
};
class AsyncSSLSocketConnector : public AsyncSocket::ConnectCallback,
public AsyncSSLSocket::HandshakeCB {
private:
......@@ -931,6 +950,38 @@ int AsyncSSLSocket::getSSLCertSize() const {
return certSize;
}
const AsyncTransportCertificate* AsyncSSLSocket::getPeerCertificate() const {
if (peerCertData_) {
return peerCertData_.get();
}
if (ssl_ != nullptr) {
auto peerX509 = SSL_get_peer_certificate(ssl_);
if (peerX509) {
// already up ref'd
folly::ssl::X509UniquePtr peer(peerX509);
peerCertData_ = std::make_unique<AsyncSSLCertificate>(std::move(peer));
}
}
return peerCertData_.get();
}
const AsyncTransportCertificate* AsyncSSLSocket::getSelfCertificate() const {
if (selfCertData_) {
return selfCertData_.get();
}
if (ssl_ != nullptr) {
auto selfX509 = SSL_get_certificate(ssl_);
if (selfX509) {
// need to upref
X509_up_ref(selfX509);
folly::ssl::X509UniquePtr peer(selfX509);
selfCertData_ = std::make_unique<AsyncSSLCertificate>(std::move(peer));
}
}
return selfCertData_.get();
}
// TODO: deprecate/remove in favor of getSelfCertificate.
const X509* AsyncSSLSocket::getSelfCert() const {
return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
}
......
......@@ -741,16 +741,18 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/
void setBufferMovableEnabled(bool enabled);
const AsyncTransportCertificate* getPeerCertificate() const override;
const AsyncTransportCertificate* getSelfCertificate() const override;
/**
* Returns the peer certificate, or nullptr if no peer certificate received.
*/
ssl::X509UniquePtr getPeerCert() const override {
if (!ssl_) {
auto peerCert = getPeerCertificate();
if (!peerCert) {
return nullptr;
}
X509* cert = SSL_get_peer_certificate(ssl_);
return ssl::X509UniquePtr(cert);
return peerCert->getX509();
}
/**
......
......@@ -1269,8 +1269,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
bool zeroCopyEnabled_{false};
bool zeroCopyVal_{false};
std::unique_ptr<const AsyncTransportCertificate> peerCertData_{nullptr};
std::unique_ptr<const AsyncTransportCertificate> selfCertData_{nullptr};
// subclasses may cache these on first call to get
mutable std::unique_ptr<const AsyncTransportCertificate> peerCertData_{
nullptr};
mutable std::unique_ptr<const AsyncTransportCertificate> selfCertData_{
nullptr};
};
#ifdef _MSC_VER
#pragma vtordisp(pop)
......
......@@ -312,6 +312,18 @@ void OpenSSLUtils::setBioFd(BIO* b, int fd, int flags) {
BIO_set_fd(b, sock, flags);
}
std::string OpenSSLUtils::getCommonName(X509* x509) {
if (x509 == nullptr) {
return "";
}
X509_NAME* subject = X509_get_subject_name(x509);
std::string cn;
cn.resize(ub_common_name);
X509_NAME_get_text_by_NID(
subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
return cn;
}
} // namespace ssl
} // namespace folly
......
......@@ -104,6 +104,11 @@ class OpenSSLUtils {
static void setSSLInitialCtx(SSL* ssl, SSL_CTX* ctx);
static SSL_CTX* getSSLInitialCtx(SSL* ssl);
/**
* Get the common name out of a cert. Return empty if x509 is null.
*/
static std::string getCommonName(X509* x509);
/**
* Wrappers for BIO operations that may be different across different
* versions/flavors of OpenSSL (including forks like BoringSSL)
......
......@@ -179,15 +179,6 @@ std::string getFileAsBuf(const char* fileName) {
return buffer;
}
std::string getCommonName(X509* cert) {
X509_NAME* subject = X509_get_subject_name(cert);
std::string cn;
cn.resize(ub_common_name);
X509_NAME_get_text_by_NID(
subject, NID_commonName, const_cast<char*>(cn.data()), ub_common_name);
return cn;
}
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server.
......@@ -1586,6 +1577,23 @@ TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
// check certificates
auto clientSsl = std::move(client).moveSocket();
auto serverSsl = std::move(server).moveSocket();
auto clientPeer = clientSsl->getPeerCertificate();
auto clientSelf = clientSsl->getSelfCertificate();
auto serverPeer = serverSsl->getPeerCertificate();
auto serverSelf = serverSsl->getSelfCertificate();
EXPECT_NE(clientPeer, nullptr);
EXPECT_NE(clientSelf, nullptr);
EXPECT_NE(serverPeer, nullptr);
EXPECT_NE(serverSelf, nullptr);
EXPECT_EQ(clientPeer->getIdentity(), serverSelf->getIdentity());
EXPECT_EQ(clientSelf->getIdentity(), serverPeer->getIdentity());
}
/**
......@@ -1878,6 +1886,7 @@ TEST(AsyncSSLSocketTest, OpenSSL110AsyncTestFailure) {
#endif // FOLLY_OPENSSL_IS_110
TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
using folly::ssl::OpenSSLUtils;
auto cert = getFileAsBuf(kTestCert);
auto key = getFileAsBuf(kTestKey);
......@@ -1894,7 +1903,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
certBio = nullptr;
keyBio = nullptr;
auto origCommonName = getCommonName(certStruct.get());
auto origCommonName = OpenSSLUtils::getCommonName(certStruct.get());
auto origKeySize = EVP_PKEY_bits(keyStruct.get());
certStruct = nullptr;
keyStruct = nullptr;
......@@ -1910,7 +1919,7 @@ TEST(AsyncSSLSocketTest, LoadCertFromMemory) {
auto newKey = SSL_get_privatekey(ssl.get());
// Get properties from SSL struct
auto newCommonName = getCommonName(newCert);
auto newCommonName = OpenSSLUtils::getCommonName(newCert);
auto newKeySize = EVP_PKEY_bits(newKey);
// Check that the key and cert have the expected properties
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment