diff --git a/folly/io/async/AsyncSSLSocket.cpp b/folly/io/async/AsyncSSLSocket.cpp index 91a2d962429ccf8e8fdf1ad0c5ab47285c75cca1..7ceefbe3f946f51db7f2b4885207d663eef26ce5 100644 --- a/folly/io/async/AsyncSSLSocket.cpp +++ b/folly/io/async/AsyncSSLSocket.cpp @@ -833,6 +833,7 @@ void AsyncSSLSocket::sslConn( #endif SSL_set_ex_data(ssl_.get(), getSSLExDataIndex(), this); + sslSessionManager_.attachToSSL(ssl_.get()); handshakeConnectTimeout_ = timeout; startSSLConnect(); diff --git a/folly/io/async/SSLContext.cpp b/folly/io/async/SSLContext.cpp index c509263a7a6bd4a853de686f40fff99225a4f37e..89b72e90c361c9503f9ca7176f4fae6dfdd22382 100644 --- a/folly/io/async/SSLContext.cpp +++ b/folly/io/async/SSLContext.cpp @@ -22,6 +22,7 @@ #include <folly/SharedMutex.h> #include <folly/SpinLock.h> #include <folly/ssl/Init.h> +#include <folly/ssl/SSLSessionManager.h> #include <folly/system/ThreadId.h> // --------------------------------------------------------------------- @@ -38,6 +39,7 @@ int getExDataIndex() { } // namespace namespace folly { + // // For OpenSSL portability API @@ -709,7 +711,14 @@ int SSLContext::newSessionCallback(SSL* ssl, SSL_SESSION* session) { cb->onNewSession(ssl, std::move(sessionPtr)); } - SSL_SESSION_free(session); + // Session will either be moved to session manager or + // freed when the unique_ptr goes out of scope + auto sessionPtr = folly::ssl::SSLSessionUniquePtr(session); + auto sessionManager = folly::ssl::SSLSessionManager::getFromSSL(ssl); + if (sessionManager) { + sessionManager->onNewSession(std::move(sessionPtr)); + } + return 1; } diff --git a/folly/io/async/test/SSLSessionTest.cpp b/folly/io/async/test/SSLSessionTest.cpp index f9a34c950311bf5e4ce5405469c77f1e94422ef4..cb61d21ebe82509794e60f2fbf37b18afa1250f5 100644 --- a/folly/io/async/test/SSLSessionTest.cpp +++ b/folly/io/async/test/SSLSessionTest.cpp @@ -30,21 +30,6 @@ using folly::ssl::detail::OpenSSLSession; namespace folly { -class SimpleCallbackManager - : public folly::SSLContext::SessionLifecycleCallbacks { - public: - void onNewSession(SSL* ssl, folly::ssl::SSLSessionUniquePtr session) - override { - auto socket = folly::AsyncSSLSocket::getFromSSL(ssl); - auto sslSession = - std::dynamic_pointer_cast<folly::ssl::detail::OpenSSLSession>( - socket->getSSLSessionV2()); - if (sslSession) { - sslSession->setActiveSession(std::move(session)); - } - } -}; - void getfds(NetworkSocket fds[2]) { if (netops::socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) { FAIL() << "failed to create socketpair: " << errnoStr(errno); @@ -62,12 +47,12 @@ void getctx( std::shared_ptr<folly::SSLContext> serverCtx) { clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); clientCtx->loadTrustedCertificates(kTestCA); - clientCtx->setSessionLifecycleCallbacks( - std::make_unique<SimpleCallbackManager>()); + clientCtx->enableTLS13(); serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); serverCtx->loadCertificate(kTestCert); serverCtx->loadPrivateKey(kTestKey); + serverCtx->enableTLS13(); } class SSLSessionTest : public testing::Test { @@ -98,9 +83,19 @@ TEST_F(SSLSessionTest, BasicTest) { { NetworkSocket fds[2]; getfds(fds); + AsyncSSLSocket::UniquePtr clientSock( new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName)); auto clientPtr = clientSock.get(); + sslSession = clientPtr->getSSLSessionV2(); + ASSERT_NE(sslSession, nullptr); + { + auto opensslSession = + std::dynamic_pointer_cast<OpenSSLSession>(sslSession); + auto sessionPtr = opensslSession->getActiveSession(); + ASSERT_EQ(sessionPtr.get(), nullptr); + } + AsyncSSLSocket::UniquePtr serverSock( new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true)); SSLHandshakeClient client(std::move(clientSock), false, false); @@ -110,16 +105,12 @@ TEST_F(SSLSessionTest, BasicTest) { eventBase.loop(); ASSERT_TRUE(client.handshakeSuccess_); ASSERT_FALSE(clientPtr->getSSLSessionReused()); - - sslSession = clientPtr->getSSLSessionV2(); - ASSERT_NE(sslSession, nullptr); - - // The underlying SSL_SESSION is set in the session callback - // that is attached to the SSL_CTX. The session is guaranteed to - // be resumable here in TLS 1.2, but not in TLS 1.3 - auto opensslSession = std::dynamic_pointer_cast<OpenSSLSession>(sslSession); - auto sessionPtr = opensslSession->getActiveSession(); - ASSERT_NE(sessionPtr.get(), nullptr); + { + auto opensslSession = + std::dynamic_pointer_cast<OpenSSLSession>(sslSession); + auto sessionPtr = opensslSession->getActiveSession(); + ASSERT_NE(sessionPtr.get(), nullptr); + } } // Session resumption diff --git a/folly/ssl/SSLSessionManager.cpp b/folly/ssl/SSLSessionManager.cpp index 88a48f09386c83d2626c3dbfb2382baaa89ad752..de6bdac9bfa3efa12f365ff2e4945b4b3ca93f6e 100644 --- a/folly/ssl/SSLSessionManager.cpp +++ b/folly/ssl/SSLSessionManager.cpp @@ -15,6 +15,7 @@ */ #include <folly/ssl/SSLSessionManager.h> +#include <folly/portability/OpenSSL.h> #include <folly/ssl/OpenSSLPtrTypes.h> #include <folly/ssl/detail/OpenSSLSession.h> @@ -61,6 +62,29 @@ class SSLSessionRetrievalVisitor } }; +class SessionForwarderVisitor : boost::static_visitor<> { + public: + explicit SessionForwarderVisitor(SSLSessionUniquePtr sessionArg) + : sessionArg_{std::move(sessionArg)} {} + + void operator()(const SSLSessionUniquePtr&) {} + + void operator()(const std::shared_ptr<OpenSSLSession>& session) { + if (session) { + session->setActiveSession(std::move(sessionArg_)); + } + } + + private: + SSLSessionUniquePtr sessionArg_{nullptr}; +}; + +int getSSLExDataIndex() { + static auto index = + SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + return index; +} + } // namespace namespace folly { @@ -96,5 +120,19 @@ shared_ptr<SSLSession> SSLSessionManager::getSession() const { return boost::apply_visitor(visitor, session_); } +void SSLSessionManager::attachToSSL(SSL* ssl) { + SSL_set_ex_data(ssl, getSSLExDataIndex(), this); +} + +SSLSessionManager* SSLSessionManager::getFromSSL(const SSL* ssl) { + return static_cast<SSLSessionManager*>( + SSL_get_ex_data(ssl, getSSLExDataIndex())); +} + +void SSLSessionManager::onNewSession(SSLSessionUniquePtr session) { + auto visitor = SessionForwarderVisitor(std::move(session)); + boost::apply_visitor(visitor, session_); +} + } // namespace ssl } // namespace folly diff --git a/folly/ssl/SSLSessionManager.h b/folly/ssl/SSLSessionManager.h index 0ba972bdf7fc8d794bcf7e87be3fd0a8e91e56fe..7e30d9bb1a0fe2ac164177ac388fa9d3ee8836ce 100644 --- a/folly/ssl/SSLSessionManager.h +++ b/folly/ssl/SSLSessionManager.h @@ -17,10 +17,14 @@ #pragma once #include <boost/variant.hpp> +#include <folly/portability/OpenSSL.h> #include <folly/ssl/OpenSSLPtrTypes.h> #include <folly/ssl/SSLSession.h> namespace folly { + +class SSLContext; + namespace ssl { namespace detail { @@ -51,9 +55,32 @@ class SSLSessionManager { folly::ssl::SSLSessionUniquePtr getRawSession() const; + /** + * Add SSLSessionManager instance to the ex data of ssl. + * Needs to be called for SSLSessionManager::getFromSSL to return + * a non-null pointer. + */ + void attachToSSL(SSL* ssl); + + /** + * Get pointer to a SSLSessionManager instance that was added to + * the ex data of ssl through attachToSSL() + */ + static SSLSessionManager* getFromSSL(const SSL* ssl); + private: - // The SSL session. Which type the variant contains depends on the - // session API that is used. + friend class folly::SSLContext; + + /** + * Called by SSLContext when a new session is negotiated for the + * SSL connection that SSLSessionManager is attached to. + */ + void onNewSession(folly::ssl::SSLSessionUniquePtr session); + + /** + * The SSL session. Which type the variant contains depends on the + * session API that is used. + */ boost::variant< folly::ssl::SSLSessionUniquePtr, std::shared_ptr<folly::ssl::detail::OpenSSLSession>> diff --git a/folly/ssl/test/SSLSessionManagerTest.cpp b/folly/ssl/test/SSLSessionManagerTest.cpp index d06542bdc2ae85228364292a9914c15d6cf3cbc7..5dab4a8b5ffeb197a8a1e9e11010ce828af87e64 100644 --- a/folly/ssl/test/SSLSessionManagerTest.cpp +++ b/folly/ssl/test/SSLSessionManagerTest.cpp @@ -64,4 +64,20 @@ TEST(SSLSessionManagerTest, SetRawSesionTest) { EXPECT_EQ(nullptr, manager.getRawSession().get()); } +TEST(SSLSessionManagerTest, GetFromSSLTest) { + SSLSessionManager manager; + SSL_CTX* ctx = SSL_CTX_new(SSLv23_method()); + + SSL* ssl1 = SSL_new(ctx); + EXPECT_EQ(nullptr, SSLSessionManager::getFromSSL(ssl1)); + SSL_free(ssl1); + + SSL* ssl2 = SSL_new(ctx); + manager.attachToSSL(ssl2); + EXPECT_EQ(&manager, SSLSessionManager::getFromSSL(ssl2)); + SSL_free(ssl2); + + SSL_CTX_free(ctx); +} + } // namespace folly