Commit 168487dc authored by Neel Goyal's avatar Neel Goyal Committed by Facebook Github Bot

Add evb change callback to SSL Socket

Summary: Allow observers to be notified when AsyncSocket attaches and detaches from EVB

Reviewed By: siyengar

Differential Revision: D4256175

fbshipit-source-id: a3ff96811f885e508f20cf11ce52e0f00e5ee461
parent 4824fb83
...@@ -464,18 +464,36 @@ void AsyncSSLSocket::attachSSLContext( ...@@ -464,18 +464,36 @@ void AsyncSSLSocket::attachSSLContext(
DCHECK(ctx->getSSLCtx()); DCHECK(ctx->getSSLCtx());
ctx_ = ctx; ctx_ = ctx;
// It's possible this could be attached before ssl_ is set up
if (!ssl_) {
return;
}
// In order to call attachSSLContext, detachSSLContext must have been // In order to call attachSSLContext, detachSSLContext must have been
// previously called which sets the socket's context to the dummy // previously called.
// context. Thus we must acquire this lock. // We need to update the initial_ctx if necessary
auto sslCtx = ctx->getSSLCtx();
#ifndef OPENSSL_NO_TLSEXT
CRYPTO_add(&sslCtx->references, 1, CRYPTO_LOCK_SSL_CTX);
// note that detachSSLContext has already freed ssl_->initial_ctx
ssl_->initial_ctx = sslCtx;
#endif
// Detach sets the socket's context to the dummy context. Thus we must acquire
// this lock.
SpinLockGuard guard(dummyCtxLock); SpinLockGuard guard(dummyCtxLock);
SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx()); SSL_set_SSL_CTX(ssl_, sslCtx);
} }
void AsyncSSLSocket::detachSSLContext() { void AsyncSSLSocket::detachSSLContext() {
DCHECK(ctx_); DCHECK(ctx_);
ctx_.reset(); ctx_.reset();
// We aren't using the initial_ctx for now, and it can introduce race // It's possible for this to be called before ssl_ has been
// conditions in the destructor of the SSL object. // set up
if (!ssl_) {
return;
}
// Detach the initial_ctx as well. Internally w/ OPENSSL_NO_TLSEXT
// it is used for session info. It will be reattached in attachSSLContext
#ifndef OPENSSL_NO_TLSEXT #ifndef OPENSSL_NO_TLSEXT
if (ssl_->initial_ctx) { if (ssl_->initial_ctx) {
SSL_CTX_free(ssl_->initial_ctx); SSL_CTX_free(ssl_->initial_ctx);
......
...@@ -1121,6 +1121,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) { ...@@ -1121,6 +1121,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
eventBase_ = eventBase; eventBase_ = eventBase;
ioHandler_.attachEventBase(eventBase); ioHandler_.attachEventBase(eventBase);
writeTimeout_.attachEventBase(eventBase); writeTimeout_.attachEventBase(eventBase);
if (evbChangeCb_) {
evbChangeCb_->evbAttached(this);
}
} }
void AsyncSocket::detachEventBase() { void AsyncSocket::detachEventBase() {
...@@ -1133,6 +1136,9 @@ void AsyncSocket::detachEventBase() { ...@@ -1133,6 +1136,9 @@ void AsyncSocket::detachEventBase() {
eventBase_ = nullptr; eventBase_ = nullptr;
ioHandler_.detachEventBase(); ioHandler_.detachEventBase();
writeTimeout_.detachEventBase(); writeTimeout_.detachEventBase();
if (evbChangeCb_) {
evbChangeCb_->evbDetached(this);
}
} }
bool AsyncSocket::isDetachable() const { bool AsyncSocket::isDetachable() const {
......
...@@ -94,6 +94,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -94,6 +94,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
noexcept = 0; noexcept = 0;
}; };
class EvbChangeCallback {
public:
virtual ~EvbChangeCallback() = default;
// Called when the socket has been attached to a new EVB
// and is called from within that EVB thread
virtual void evbAttached(AsyncSocket* socket) = 0;
// Called when the socket is detached from an EVB and
// is called from the EVB thread being detached
virtual void evbDetached(AsyncSocket* socket) = 0;
};
explicit AsyncSocket(); explicit AsyncSocket();
/** /**
* Create a new unconnected AsyncSocket. * Create a new unconnected AsyncSocket.
...@@ -560,6 +573,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -560,6 +573,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
void setBufferCallback(BufferCallback* cb); void setBufferCallback(BufferCallback* cb);
// Callers should set this prior to connecting the socket for the safest
// behavior.
void setEvbChangedCallback(std::unique_ptr<EvbChangeCallback> cb) {
evbChangeCb_ = std::move(cb);
}
/** /**
* writeReturn is the total number of bytes written, or WRITE_ERROR on error. * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
* If no data has been written, 0 is returned. * If no data has been written, 0 is returned.
...@@ -930,6 +949,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -930,6 +949,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
bool tfoEnabled_{false}; bool tfoEnabled_{false};
bool tfoAttempted_{false}; bool tfoAttempted_{false};
bool tfoFinished_{false}; bool tfoFinished_{false};
std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
}; };
#ifdef _MSC_VER #ifdef _MSC_VER
#pragma vtordisp(pop) #pragma vtordisp(pop)
......
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#include <pthread.h> #include <pthread.h>
#include <folly/futures/Promise.h>
#include <folly/io/async/AsyncSSLSocket.h> #include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/io/async/SSLContext.h> #include <folly/io/async/SSLContext.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
using std::string; using std::string;
...@@ -31,44 +33,100 @@ using std::list; ...@@ -31,44 +33,100 @@ using std::list;
namespace folly { namespace folly {
struct EvbAndContext {
EvbAndContext() {
ctx_.reset(new SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET);
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
}
std::shared_ptr<AsyncSSLSocket> createSocket() {
return AsyncSSLSocket::newSocket(ctx_, getEventBase());
}
EventBase* getEventBase() {
return evb_.getEventBase();
}
void attach(AsyncSSLSocket& socket) {
socket.attachEventBase(getEventBase());
socket.attachSSLContext(ctx_);
}
folly::ScopedEventBaseThread evb_;
std::shared_ptr<SSLContext> ctx_;
};
class AttachDetachClient : public AsyncSocket::ConnectCallback, class AttachDetachClient : public AsyncSocket::ConnectCallback,
public AsyncTransportWrapper::WriteCallback, public AsyncTransportWrapper::WriteCallback,
public AsyncTransportWrapper::ReadCallback { public AsyncTransportWrapper::ReadCallback {
private: private:
EventBase *eventBase_; // two threads here - we'll create the socket in one, connect
// in the other, and then read/write in the initial one
EvbAndContext t1_;
EvbAndContext t2_;
std::shared_ptr<AsyncSSLSocket> sslSocket_; std::shared_ptr<AsyncSSLSocket> sslSocket_;
std::shared_ptr<SSLContext> ctx_;
folly::SocketAddress address_; folly::SocketAddress address_;
char buf_[128]; char buf_[128];
char readbuf_[128]; char readbuf_[128];
uint32_t bytesRead_; uint32_t bytesRead_;
// promise to fulfill when done
folly::Promise<bool> promise_;
void detach() {
sslSocket_->detachEventBase();
sslSocket_->detachSSLContext();
}
public: public:
AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address) explicit AttachDetachClient(const folly::SocketAddress& address)
: eventBase_(eventBase), address_(address), bytesRead_(0) { : address_(address), bytesRead_(0) {}
ctx_.reset(new SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET); Future<bool> getFuture() {
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); return promise_.getFuture();
} }
void connect() { void connect() {
sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_); // create in one and then move to another
auto t1Evb = t1_.getEventBase();
t1Evb->runInEventBaseThread([this] {
sslSocket_ = t1_.createSocket();
// ensure we can detach and reattach the context before connecting
for (int i = 0; i < 1000; ++i) {
sslSocket_->detachSSLContext();
sslSocket_->attachSSLContext(t1_.ctx_);
}
// detach from t1 and connect in t2
detach();
auto t2Evb = t2_.getEventBase();
t2Evb->runInEventBaseThread([this] {
t2_.attach(*sslSocket_);
sslSocket_->connect(this, address_); sslSocket_->connect(this, address_);
});
});
} }
void connectSuccess() noexcept override { void connectSuccess() noexcept override {
auto t2Evb = t2_.getEventBase();
EXPECT_TRUE(t2Evb->isInEventBaseThread());
cerr << "client SSL socket connected" << endl; cerr << "client SSL socket connected" << endl;
for (int i = 0; i < 1000; ++i) { for (int i = 0; i < 1000; ++i) {
sslSocket_->detachSSLContext(); sslSocket_->detachSSLContext();
sslSocket_->attachSSLContext(ctx_); sslSocket_->attachSSLContext(t2_.ctx_);
} }
EXPECT_EQ(ctx_->getSSLCtx()->references, 2); // detach from t2 and then read/write in t1
t2Evb->runInEventBaseThread([this] {
detach();
auto t1Evb = t1_.getEventBase();
t1Evb->runInEventBaseThread([this] {
t1_.attach(*sslSocket_);
sslSocket_->write(this, buf_, sizeof(buf_)); sslSocket_->write(this, buf_, sizeof(buf_));
sslSocket_->setReadCB(this); sslSocket_->setReadCB(this);
memset(readbuf_, 'b', sizeof(readbuf_)); memset(readbuf_, 'b', sizeof(readbuf_));
bytesRead_ = 0; bytesRead_ = 0;
});
});
} }
void connectErr(const AsyncSocketException& ex) noexcept override void connectErr(const AsyncSocketException& ex) noexcept override
...@@ -96,14 +154,19 @@ class AttachDetachClient : public AsyncSocket::ConnectCallback, ...@@ -96,14 +154,19 @@ class AttachDetachClient : public AsyncSocket::ConnectCallback,
void readErr(const AsyncSocketException& ex) noexcept override { void readErr(const AsyncSocketException& ex) noexcept override {
cerr << "client readError: " << ex.what() << endl; cerr << "client readError: " << ex.what() << endl;
promise_.setException(ex);
} }
void readDataAvailable(size_t len) noexcept override { void readDataAvailable(size_t len) noexcept override {
EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
cerr << "client read data: " << len << endl; cerr << "client read data: " << len << endl;
bytesRead_ += len; bytesRead_ += len;
if (len == sizeof(buf_)) { if (len == sizeof(buf_)) {
EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0); EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
sslSocket_->closeNow(); sslSocket_->closeNow();
sslSocket_.reset();
promise_.setValue(true);
} }
} }
}; };
...@@ -119,13 +182,12 @@ TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) { ...@@ -119,13 +182,12 @@ TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback); SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback); TestSSLServer server(&acceptCallback);
EventBase eventBase;
EventBaseAborter eba(&eventBase, 3000);
std::shared_ptr<AttachDetachClient> client( std::shared_ptr<AttachDetachClient> client(
new AttachDetachClient(&eventBase, server.getAddress())); new AttachDetachClient(server.getAddress()));
auto f = client->getFuture();
client->connect(); client->connect();
eventBase.loop(); EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
} }
} // folly } // folly
......
...@@ -2996,4 +2996,24 @@ TEST(AsyncSocketTest, ConnectTFOWithBigData) { ...@@ -2996,4 +2996,24 @@ TEST(AsyncSocketTest, ConnectTFOWithBigData) {
EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded()); EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
} }
class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
public:
MOCK_METHOD1(evbAttached, void(AsyncSocket*));
MOCK_METHOD1(evbDetached, void(AsyncSocket*));
};
TEST(AsyncSocketTest, EvbCallbacks) {
auto cb = folly::make_unique<MockEvbChangeCallback>();
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
InSequence seq;
EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);
socket->setEvbChangedCallback(std::move(cb));
socket->detachEventBase();
socket->attachEventBase(&evb);
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment