Commit 2ea73e0b authored by Alex Guzman's avatar Alex Guzman Committed by Facebook Github Bot

Don't call SSL_shutdown when SSL_accept is pending

Summary: As it says on tin.

Reviewed By: knekritz

Differential Revision: D13144407

fbshipit-source-id: 8fc69f9005ca54c2fb82b501547de2aaa892c1fa
parent 2bacf890
...@@ -326,7 +326,7 @@ void AsyncSSLSocket::init() { ...@@ -326,7 +326,7 @@ void AsyncSSLSocket::init() {
void AsyncSSLSocket::closeNow() { void AsyncSSLSocket::closeNow() {
// Close the SSL connection. // Close the SSL connection.
if (ssl_ != nullptr && fd_ != NetworkSocket()) { if (ssl_ != nullptr && fd_ != NetworkSocket() && !waitingOnAccept_) {
int rc = SSL_shutdown(ssl_.get()); int rc = SSL_shutdown(ssl_.get());
if (rc == 0) { if (rc == 0) {
rc = SSL_shutdown(ssl_.get()); rc = SSL_shutdown(ssl_.get());
...@@ -1148,8 +1148,14 @@ void AsyncSSLSocket::handleAccept() noexcept { ...@@ -1148,8 +1148,14 @@ void AsyncSSLSocket::handleAccept() noexcept {
EventHandler::NONE, EventHandler::READ | EventHandler::WRITE); EventHandler::NONE, EventHandler::READ | EventHandler::WRITE);
DelayedDestruction::DestructorGuard dg(this); DelayedDestruction::DestructorGuard dg(this);
ctx_->sslAcceptRunner()->run( ctx_->sslAcceptRunner()->run(
[this, dg]() { return SSL_accept(ssl_.get()); }, [this, dg]() {
[this, dg](int ret) { handleReturnFromSSLAccept(ret); }); waitingOnAccept_ = true;
return SSL_accept(ssl_.get());
},
[this, dg](int ret) {
waitingOnAccept_ = false;
handleReturnFromSSLAccept(ret);
});
} }
void AsyncSSLSocket::handleReturnFromSSLAccept(int ret) { void AsyncSSLSocket::handleReturnFromSSLAccept(int ret) {
......
...@@ -1000,6 +1000,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -1000,6 +1000,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
bool sessionIDResumed_{false}; bool sessionIDResumed_{false};
// This can be called for OpenSSL 1.1.0 async operation finishes // This can be called for OpenSSL 1.1.0 async operation finishes
std::unique_ptr<ReadCallback> asyncOperationFinishCallback_; std::unique_ptr<ReadCallback> asyncOperationFinishCallback_;
// Whether this socket is currently waiting on SSL_accept
bool waitingOnAccept_{false};
}; };
} // namespace folly } // namespace folly
...@@ -2089,6 +2089,95 @@ TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) { ...@@ -2089,6 +2089,95 @@ TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
EXPECT_TRUE(server.handshakeError_); EXPECT_TRUE(server.handshakeError_);
} }
TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiber) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptFiberRunner>(&eventBase));
eventBase.loop();
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
}
static int newCloseCb(SSL* ssl, SSL_SESSION*) {
AsyncSSLSocket::getFromSSL(ssl)->closeNow();
return 1;
}
#if FOLLY_OPENSSL_IS_110
static SSL_SESSION* getCloseCb(SSL* ssl, const unsigned char*, int, int*) {
#else
static SSL_SESSION* getCloseCb(SSL* ssl, unsigned char*, int, int*) {
#endif
AsyncSSLSocket::getFromSSL(ssl)->closeNow();
return nullptr;
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerFiberCloseSessionCb) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
SSL_CTX_set_session_cache_mode(
serverCtx->getSSLCtx(),
SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
SSL_CTX_sess_set_new_cb(serverCtx->getSSLCtx(), &newCloseCb);
SSL_CTX_sess_set_get_cb(serverCtx->getSSLCtx(), &getCloseCb);
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptFiberRunner>(&eventBase));
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("AES128-SHA256");
clientCtx->loadTrustedCertificates(kTestCA);
clientCtx->setOptions(SSL_OP_NO_TICKET);
NetworkSocket fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
// As close() is called during session callbacks, client sees it as a
// successful connection
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, ConnResetErrorString) { TEST(AsyncSSLSocketTest, ConnResetErrorString) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <folly/ExceptionWrapper.h> #include <folly/ExceptionWrapper.h>
#include <folly/SocketAddress.h> #include <folly/SocketAddress.h>
#include <folly/experimental/TestUtil.h> #include <folly/experimental/TestUtil.h>
#include <folly/fibers/FiberManagerMap.h>
#include <folly/io/async/AsyncSSLSocket.h> #include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h> #include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h> #include <folly/io/async/AsyncSocket.h>
...@@ -1530,4 +1531,19 @@ class SSLAcceptDestroyRunner : public SSLAcceptEvbRunner { ...@@ -1530,4 +1531,19 @@ class SSLAcceptDestroyRunner : public SSLAcceptEvbRunner {
SSLHandshakeBase* sslBase_; SSLHandshakeBase* sslBase_;
}; };
class SSLAcceptFiberRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptFiberRunner(EventBase* evb) : SSLAcceptEvbRunner(evb) {}
~SSLAcceptFiberRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
auto& fiberManager = folly::fibers::getFiberManager(*evb_);
fiberManager.addTaskFinally(
std::move(acceptFunc),
[finally = std::move(finallyFunc)](folly::Try<int>&& res) mutable {
finally(res.value());
});
}
};
} // namespace folly } // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment