Commit d96d9550 authored by Alex Guzman's avatar Alex Guzman Committed by Facebook Github Bot

Let AsyncSSLSocket run accept via a runner.

Summary: Allows for a runner to execute the accept function and return the result via a callback. If no runner is specified, it simply does the accept as usual.

Reviewed By: yfeldblum

Differential Revision: D9849138

fbshipit-source-id: ef43ccc8668bcf1fe7f75b0b6fdcdba7adc891da
parent fb28875b
...@@ -1144,7 +1144,20 @@ void AsyncSSLSocket::handleAccept() noexcept { ...@@ -1144,7 +1144,20 @@ void AsyncSSLSocket::handleAccept() noexcept {
SSL_set_msg_callback_arg(ssl_, this); SSL_set_msg_callback_arg(ssl_, this);
} }
int ret = SSL_accept(ssl_); DCHECK(ctx_->sslAcceptRunner());
updateEventRegistration(
EventHandler::NONE, EventHandler::READ | EventHandler::WRITE);
DelayedDestruction::DestructorGuard dg(this);
ctx_->sslAcceptRunner()->run(
[this, dg]() { return SSL_accept(ssl_); },
[this, dg](int ret) { handleReturnFromSSLAccept(ret); });
}
void AsyncSSLSocket::handleReturnFromSSLAccept(int ret) {
if (sslState_ != STATE_ACCEPTING) {
return;
}
if (ret <= 0) { if (ret <= 0) {
VLOG(3) << "SSL_accept returned: " << ret; VLOG(3) << "SSL_accept returned: " << ret;
int sslError; int sslError;
......
...@@ -799,6 +799,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -799,6 +799,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
} }
private: private:
/**
* Handle the return from invoking SSL_accept
*/
void handleReturnFromSSLAccept(int ret);
void init(); void init();
protected: protected:
......
...@@ -66,6 +66,8 @@ SSLContext::SSLContext(SSLVersion version) { ...@@ -66,6 +66,8 @@ SSLContext::SSLContext(SSLVersion version) {
SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION); SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION);
sslAcceptRunner_ = std::make_unique<SSLAcceptRunner>();
#if FOLLY_OPENSSL_HAS_SNI #if FOLLY_OPENSSL_HAS_SNI
SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this); SSL_CTX_set_tlsext_servername_arg(ctx_, this);
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <folly/folly-config.h> #include <folly/folly-config.h>
#endif #endif
#include <folly/Function.h>
#include <folly/Portability.h> #include <folly/Portability.h>
#include <folly/Range.h> #include <folly/Range.h>
#include <folly/String.h> #include <folly/String.h>
...@@ -64,6 +65,24 @@ class PasswordCollector { ...@@ -64,6 +65,24 @@ class PasswordCollector {
virtual std::string describe() const = 0; virtual std::string describe() const = 0;
}; };
/**
* Run SSL_accept via a runner
*/
class SSLAcceptRunner {
public:
virtual ~SSLAcceptRunner() = default;
/**
* This is expected to run the first function and provide its return
* value to the second function. This can be used to run the SSL_accept
* in different contexts.
*/
virtual void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const {
finallyFunc(acceptFunc());
}
};
/** /**
* Wrap OpenSSL SSL_CTX into a class. * Wrap OpenSSL SSL_CTX into a class.
*/ */
...@@ -509,6 +528,22 @@ class SSLContext { ...@@ -509,6 +528,22 @@ class SSLContext {
void enableFalseStart(); void enableFalseStart();
#endif #endif
/**
* Sets the runner used for SSL_accept. If none is given, the accept will be
* done directly.
*/
void sslAcceptRunner(std::unique_ptr<SSLAcceptRunner> runner) {
if (nullptr == runner) {
LOG(ERROR) << "Ignore invalid runner";
return;
}
sslAcceptRunner_ = std::move(runner);
}
const SSLAcceptRunner* sslAcceptRunner() {
return sslAcceptRunner_.get();
}
/** /**
* Helper to match a hostname versus a pattern. * Helper to match a hostname versus a pattern.
*/ */
...@@ -534,6 +569,8 @@ class SSLContext { ...@@ -534,6 +569,8 @@ class SSLContext {
static bool initialized_; static bool initialized_;
std::unique_ptr<SSLAcceptRunner> sslAcceptRunner_;
#if FOLLY_OPENSSL_HAS_ALPN #if FOLLY_OPENSSL_HAS_ALPN
struct AdvertisedNextProtocolsItem { struct AdvertisedNextProtocolsItem {
......
...@@ -1897,6 +1897,146 @@ TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) { ...@@ -1897,6 +1897,146 @@ TEST(AsyncSSLSocketTest, ConnectUnencryptedTest) {
socket->close(); socket->close();
} }
/**
* Test acceptrunner in various situations
*/
TEST(AsyncSSLSocketTest, SSLAcceptRunnerBasic) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
serverCtx->sslAcceptRunner(std::make_unique<SSLAcceptEvbRunner>(&eventBase));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_LE(0, client.handshakeTime.count());
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
EXPECT_LE(0, server.handshakeTime.count());
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptError) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptErrorRunner>(&eventBase));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_FALSE(client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptClose) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptCloseRunner>(&eventBase, serverSock.get()));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_FALSE(client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, SSLAcceptRunnerAcceptDestroy) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(kTestKey);
serverCtx->loadCertificate(kTestCert);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadTrustedCertificates(kTestCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
serverCtx->sslAcceptRunner(
std::make_unique<SSLAcceptDestroyRunner>(&eventBase, &server));
eventBase.loop();
EXPECT_FALSE(client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
TEST(AsyncSSLSocketTest, ConnResetErrorString) { TEST(AsyncSSLSocketTest, ConnResetErrorString) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
......
...@@ -1309,21 +1309,27 @@ class SSLHandshakeBase : public AsyncSSLSocket::HandshakeCB, ...@@ -1309,21 +1309,27 @@ class SSLHandshakeBase : public AsyncSSLSocket::HandshakeCB,
void handshakeSuc(AsyncSSLSocket*) noexcept override { void handshakeSuc(AsyncSSLSocket*) noexcept override {
LOG(INFO) << "Handshake success"; LOG(INFO) << "Handshake success";
handshakeSuccess_ = true; handshakeSuccess_ = true;
if (socket_) {
handshakeTime = socket_->getHandshakeTime(); handshakeTime = socket_->getHandshakeTime();
} }
}
void handshakeErr( void handshakeErr(
AsyncSSLSocket*, AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override { const AsyncSocketException& ex) noexcept override {
LOG(INFO) << "Handshake error " << ex.what(); LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true; handshakeError_ = true;
if (socket_) {
handshakeTime = socket_->getHandshakeTime(); handshakeTime = socket_->getHandshakeTime();
} }
}
// WriteCallback // WriteCallback
void writeSuccess() noexcept override { void writeSuccess() noexcept override {
if (socket_) {
socket_->close(); socket_->close();
} }
}
void writeErr( void writeErr(
size_t bytesWritten, size_t bytesWritten,
...@@ -1451,4 +1457,75 @@ class EventBaseAborter : public AsyncTimeout { ...@@ -1451,4 +1457,75 @@ class EventBaseAborter : public AsyncTimeout {
EventBase* eventBase_; EventBase* eventBase_;
}; };
class SSLAcceptEvbRunner : public SSLAcceptRunner {
public:
explicit SSLAcceptEvbRunner(EventBase* evb) : evb_(evb) {}
~SSLAcceptEvbRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop([acceptFunc = std::move(acceptFunc),
finallyFunc = std::move(finallyFunc)]() mutable {
finallyFunc(acceptFunc());
});
}
protected:
EventBase* evb_;
};
class SSLAcceptErrorRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptErrorRunner(EventBase* evb) : SSLAcceptEvbRunner(evb) {}
~SSLAcceptErrorRunner() override = default;
void run(Function<int()> /*acceptFunc*/, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop(
[finallyFunc = std::move(finallyFunc)]() mutable { finallyFunc(-1); });
}
};
class SSLAcceptCloseRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptCloseRunner(EventBase* evb, folly::AsyncSSLSocket* sock)
: SSLAcceptEvbRunner(evb), socket_(sock) {}
~SSLAcceptCloseRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop([acceptFunc = std::move(acceptFunc),
finallyFunc = std::move(finallyFunc),
sock = socket_]() mutable {
auto ret = acceptFunc();
sock->closeNow();
finallyFunc(ret);
});
}
private:
folly::AsyncSSLSocket* socket_;
};
class SSLAcceptDestroyRunner : public SSLAcceptEvbRunner {
public:
explicit SSLAcceptDestroyRunner(EventBase* evb, SSLHandshakeBase* base)
: SSLAcceptEvbRunner(evb), sslBase_(base) {}
~SSLAcceptDestroyRunner() override = default;
void run(Function<int()> acceptFunc, Function<void(int)> finallyFunc)
const override {
evb_->runInLoop([acceptFunc = std::move(acceptFunc),
finallyFunc = std::move(finallyFunc),
sslBase = sslBase_]() mutable {
auto ret = acceptFunc();
std::move(*sslBase).moveSocket();
finallyFunc(ret);
});
}
private:
SSLHandshakeBase* sslBase_;
};
} // namespace folly } // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment