Commit ffc3227f authored by Andrew Huang's avatar Andrew Huang Committed by Facebook GitHub Bot

Replace old AsyncSSLSocket session API with V2

Summary:
This new API has a few benefits:
1. It allows `getSSLSession` to support TLS 1.3 session resumption by returning a mutable session wrapper as opposed to the immutable `SSL_SESSION*` object.
2. OpenSSL `SSL_SESSION*` objects require the caller to keep accurate reference counts. Failure to do so can result in memory leaks or use-after-free errors.
3. This design abstracts away OpenSSL internals, which are unnecessary for the caller to perform session resumption.

Reviewed By: mingtaoy

Differential Revision: D24239802

fbshipit-source-id: cd3e90217717394f32dc6a2281e7a40c805990b2
parent 0cb5aa0f
...@@ -898,15 +898,7 @@ void AsyncSSLSocket::startSSLConnect() { ...@@ -898,15 +898,7 @@ void AsyncSSLSocket::startSSLConnect() {
handleConnect(); handleConnect();
} }
SSL_SESSION* AsyncSSLSocket::getSSLSession() { shared_ptr<ssl::SSLSession> AsyncSSLSocket::getSSLSession() {
if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
return SSL_get1_session(ssl_.get());
}
return sslSessionManager_.getRawSession().release();
}
shared_ptr<ssl::SSLSession> AsyncSSLSocket::getSSLSessionV2() {
return sslSessionManager_.getSession(); return sslSessionManager_.getSession();
} }
...@@ -914,17 +906,8 @@ const SSL* AsyncSSLSocket::getSSL() const { ...@@ -914,17 +906,8 @@ const SSL* AsyncSSLSocket::getSSL() const {
return ssl_.get(); return ssl_.get();
} }
void AsyncSSLSocket::setSSLSession(SSL_SESSION* session, bool takeOwnership) { void AsyncSSLSocket::setSSLSession(shared_ptr<ssl::SSLSession> session) {
if (!takeOwnership && session != nullptr) { sslSessionManager_.setSession(std::move(session));
// Increment the reference count
// This API exists in BoringSSL and OpenSSL 1.1.0
SSL_SESSION_up_ref(session);
}
sslSessionManager_.setRawSession(SSLSessionUniquePtr(session));
}
void AsyncSSLSocket::setSSLSessionV2(shared_ptr<ssl::SSLSession> session) {
sslSessionManager_.setSession(session);
} }
void AsyncSSLSocket::setRawSSLSession(SSLSessionUniquePtr session) { void AsyncSSLSocket::setRawSSLSession(SSLSessionUniquePtr session) {
......
...@@ -498,17 +498,13 @@ class AsyncSSLSocket : public AsyncSocket { ...@@ -498,17 +498,13 @@ class AsyncSSLSocket : public AsyncSocket {
SSLStateEnum getSSLState() const { return sslState_; } SSLStateEnum getSSLState() const { return sslState_; }
/** /**
* Get a handle to the negotiated SSL session. This increments the session * Retrieve the SSL session associated with this established connection.
* refcount and must be deallocated by the caller. *
*/ * The SSL Session object is a copyable, opaque token that can be set on other
SSL_SESSION* getSSLSession(); * unconnected AsyncSSLSockets. If AsyncSSLSocket::connect() is called with a
* previous session set, TLS resumption will be attempted.
/**
* Currently unsupported. Eventually intended to replace getSSLSession()
* once TLS 1.3 is enabled by default.
* Get an abstracted SSL Session.
*/ */
std::shared_ptr<ssl::SSLSession> getSSLSessionV2(); std::shared_ptr<ssl::SSLSession> getSSLSession();
/** /**
* Get a handle to the SSL struct. * Get a handle to the SSL struct.
...@@ -516,25 +512,13 @@ class AsyncSSLSocket : public AsyncSocket { ...@@ -516,25 +512,13 @@ class AsyncSSLSocket : public AsyncSocket {
const SSL* getSSL() const; const SSL* getSSL() const;
/** /**
* DEPRECATED. Will eventually be removed. Please use setSSLSessionV2. * Sets the SSL session that will be attempted for TLS resumption.
*
* Set the SSL session to be used during sslConn. AsyncSSLSocket will
* hold a reference to the session until it is destroyed or released by the
* underlying SSL structure.
*
* @param takeOwnership if true, AsyncSSLSocket will assume the caller's
* reference count to session.
*/
void setSSLSession(SSL_SESSION* session, bool takeOwnership = false);
/**
* Set the SSL session to be used during sslConn.
*/ */
void setSSLSessionV2(std::shared_ptr<ssl::SSLSession> session); void setSSLSession(std::shared_ptr<ssl::SSLSession> session);
/** /**
* Note: This function exists for compatibility reasons. It is strongly * Note: This function exists for compatibility reasons. It is strongly
* recommended to use setSSLSessionV2 instead. After setRawSSLSession is * recommended to use setSSLSession instead. After setRawSSLSession is
* called, subsequent calls to getSSLSession on the socket will return null. * called, subsequent calls to getSSLSession on the socket will return null.
* *
* Set the SSL session to be used during sslConn. * Set the SSL session to be used during sslConn.
......
...@@ -3204,7 +3204,7 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) { ...@@ -3204,7 +3204,7 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
// create another client, resuming with the prior session, but under a // create another client, resuming with the prior session, but under a
// different common name. // different common name.
clientSock = std::move(client).moveSocket(); clientSock = std::move(client).moveSocket();
resumptionSession = clientSock->getSSLSessionV2(); resumptionSession = clientSock->getSSLSession();
} }
{ {
...@@ -3216,7 +3216,7 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) { ...@@ -3216,7 +3216,7 @@ TEST(AsyncSSLSocketTest, TestSNIClientHelloBehavior) {
AsyncSSLSocket::UniquePtr serverSock( AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true)); new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
clientSock->setSSLSessionV2(resumptionSession); clientSock->setSSLSession(resumptionSession);
clientSock->setServerName("Baz"); clientSock->setServerName("Baz");
SSLHandshakeServerParseClientHello server( SSLHandshakeServerParseClientHello server(
std::move(serverSock), true, true); std::move(serverSock), true, true);
......
...@@ -1168,7 +1168,7 @@ class SSLClient : public AsyncSocket::ConnectCallback, ...@@ -1168,7 +1168,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
void connect(bool writeNow = false) { void connect(bool writeNow = false) {
sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_); sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
if (session_ != nullptr) { if (session_ != nullptr) {
sslSocket_->setSSLSessionV2(session_); sslSocket_->setSSLSession(session_);
} }
requests_--; requests_--;
sslSocket_->connect(this, address_, timeout_); sslSocket_->connect(this, address_, timeout_);
...@@ -1184,7 +1184,7 @@ class SSLClient : public AsyncSocket::ConnectCallback, ...@@ -1184,7 +1184,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
hit_++; hit_++;
} else { } else {
miss_++; miss_++;
session_ = sslSocket_->getSSLSessionV2(); session_ = sslSocket_->getSSLSession();
} }
// write() // write()
......
...@@ -85,7 +85,7 @@ TEST_F(SSLSessionTest, BasicTest) { ...@@ -85,7 +85,7 @@ TEST_F(SSLSessionTest, BasicTest) {
AsyncSSLSocket::UniquePtr clientSock( AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get(); auto clientPtr = clientSock.get();
sslSession = clientPtr->getSSLSessionV2(); sslSession = clientPtr->getSSLSession();
ASSERT_NE(sslSession, nullptr); ASSERT_NE(sslSession, nullptr);
{ {
auto opensslSession = auto opensslSession =
...@@ -111,57 +111,6 @@ TEST_F(SSLSessionTest, BasicTest) { ...@@ -111,57 +111,6 @@ TEST_F(SSLSessionTest, BasicTest) {
} }
} }
// Session resumption
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
clientPtr->setSSLSessionV2(sslSession);
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_TRUE(clientPtr->getSSLSessionReused());
}
}
/**
* To be removed when getSSLSessionV2() and setSSLSessionV2()
* replace getSSLSession() and setSSLSession(),
* respectively.
*/
TEST_F(SSLSessionTest, BasicRegressionTest) {
SSL_SESSION* sslSession;
// Full handshake
{
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get();
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServerParseClientHello server(
std::move(serverSock), false, false);
eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_FALSE(clientPtr->getSSLSessionReused());
sslSession = clientPtr->getSSLSession();
ASSERT_NE(sslSession, nullptr);
}
// Session resumption // Session resumption
{ {
NetworkSocket fds[2]; NetworkSocket fds[2];
...@@ -181,7 +130,6 @@ TEST_F(SSLSessionTest, BasicRegressionTest) { ...@@ -181,7 +130,6 @@ TEST_F(SSLSessionTest, BasicRegressionTest) {
eventBase.loop(); eventBase.loop();
ASSERT_TRUE(client.handshakeSuccess_); ASSERT_TRUE(client.handshakeSuccess_);
ASSERT_TRUE(clientPtr->getSSLSessionReused()); ASSERT_TRUE(clientPtr->getSSLSessionReused());
SSL_SESSION_free(sslSession);
} }
} }
...@@ -194,7 +142,7 @@ TEST_F(SSLSessionTest, NullSessionResumptionTest) { ...@@ -194,7 +142,7 @@ TEST_F(SSLSessionTest, NullSessionResumptionTest) {
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
auto clientPtr = clientSock.get(); auto clientPtr = clientSock.get();
clientPtr->setSSLSessionV2(nullptr); clientPtr->setSSLSession(nullptr);
AsyncSSLSocket::UniquePtr serverSock( AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment