Commit 168487dc authored by Neel Goyal's avatar Neel Goyal Committed by Facebook Github Bot

Add evb change callback to SSL Socket

Summary: Allow observers to be notified when AsyncSocket attaches and detaches from EVB

Reviewed By: siyengar

Differential Revision: D4256175

fbshipit-source-id: a3ff96811f885e508f20cf11ce52e0f00e5ee461
parent 4824fb83
......@@ -464,18 +464,36 @@ void AsyncSSLSocket::attachSSLContext(
DCHECK(ctx->getSSLCtx());
ctx_ = ctx;
// It's possible this could be attached before ssl_ is set up
if (!ssl_) {
return;
}
// In order to call attachSSLContext, detachSSLContext must have been
// previously called which sets the socket's context to the dummy
// context. Thus we must acquire this lock.
// previously called.
// We need to update the initial_ctx if necessary
auto sslCtx = ctx->getSSLCtx();
#ifndef OPENSSL_NO_TLSEXT
CRYPTO_add(&sslCtx->references, 1, CRYPTO_LOCK_SSL_CTX);
// note that detachSSLContext has already freed ssl_->initial_ctx
ssl_->initial_ctx = sslCtx;
#endif
// Detach sets the socket's context to the dummy context. Thus we must acquire
// this lock.
SpinLockGuard guard(dummyCtxLock);
SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
SSL_set_SSL_CTX(ssl_, sslCtx);
}
void AsyncSSLSocket::detachSSLContext() {
DCHECK(ctx_);
ctx_.reset();
// We aren't using the initial_ctx for now, and it can introduce race
// conditions in the destructor of the SSL object.
// It's possible for this to be called before ssl_ has been
// set up
if (!ssl_) {
return;
}
// Detach the initial_ctx as well. Internally w/ OPENSSL_NO_TLSEXT
// it is used for session info. It will be reattached in attachSSLContext
#ifndef OPENSSL_NO_TLSEXT
if (ssl_->initial_ctx) {
SSL_CTX_free(ssl_->initial_ctx);
......
......@@ -1121,6 +1121,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
eventBase_ = eventBase;
ioHandler_.attachEventBase(eventBase);
writeTimeout_.attachEventBase(eventBase);
if (evbChangeCb_) {
evbChangeCb_->evbAttached(this);
}
}
void AsyncSocket::detachEventBase() {
......@@ -1133,6 +1136,9 @@ void AsyncSocket::detachEventBase() {
eventBase_ = nullptr;
ioHandler_.detachEventBase();
writeTimeout_.detachEventBase();
if (evbChangeCb_) {
evbChangeCb_->evbDetached(this);
}
}
bool AsyncSocket::isDetachable() const {
......
......@@ -94,6 +94,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
noexcept = 0;
};
class EvbChangeCallback {
public:
virtual ~EvbChangeCallback() = default;
// Called when the socket has been attached to a new EVB
// and is called from within that EVB thread
virtual void evbAttached(AsyncSocket* socket) = 0;
// Called when the socket is detached from an EVB and
// is called from the EVB thread being detached
virtual void evbDetached(AsyncSocket* socket) = 0;
};
explicit AsyncSocket();
/**
* Create a new unconnected AsyncSocket.
......@@ -560,6 +573,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
void setBufferCallback(BufferCallback* cb);
// Callers should set this prior to connecting the socket for the safest
// behavior.
void setEvbChangedCallback(std::unique_ptr<EvbChangeCallback> cb) {
evbChangeCb_ = std::move(cb);
}
/**
* writeReturn is the total number of bytes written, or WRITE_ERROR on error.
* If no data has been written, 0 is returned.
......@@ -930,6 +949,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
bool tfoEnabled_{false};
bool tfoAttempted_{false};
bool tfoFinished_{false};
std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
};
#ifdef _MSC_VER
#pragma vtordisp(pop)
......
......@@ -17,9 +17,11 @@
#include <pthread.h>
#include <folly/futures/Promise.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <folly/portability/GTest.h>
using std::string;
......@@ -31,44 +33,100 @@ using std::list;
namespace folly {
struct EvbAndContext {
EvbAndContext() {
ctx_.reset(new SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET);
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
}
std::shared_ptr<AsyncSSLSocket> createSocket() {
return AsyncSSLSocket::newSocket(ctx_, getEventBase());
}
EventBase* getEventBase() {
return evb_.getEventBase();
}
void attach(AsyncSSLSocket& socket) {
socket.attachEventBase(getEventBase());
socket.attachSSLContext(ctx_);
}
folly::ScopedEventBaseThread evb_;
std::shared_ptr<SSLContext> ctx_;
};
class AttachDetachClient : public AsyncSocket::ConnectCallback,
public AsyncTransportWrapper::WriteCallback,
public AsyncTransportWrapper::ReadCallback {
private:
EventBase *eventBase_;
// two threads here - we'll create the socket in one, connect
// in the other, and then read/write in the initial one
EvbAndContext t1_;
EvbAndContext t2_;
std::shared_ptr<AsyncSSLSocket> sslSocket_;
std::shared_ptr<SSLContext> ctx_;
folly::SocketAddress address_;
char buf_[128];
char readbuf_[128];
uint32_t bytesRead_;
// promise to fulfill when done
folly::Promise<bool> promise_;
void detach() {
sslSocket_->detachEventBase();
sslSocket_->detachSSLContext();
}
public:
AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address)
: eventBase_(eventBase), address_(address), bytesRead_(0) {
ctx_.reset(new SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET);
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
explicit AttachDetachClient(const folly::SocketAddress& address)
: address_(address), bytesRead_(0) {}
Future<bool> getFuture() {
return promise_.getFuture();
}
void connect() {
sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
sslSocket_->connect(this, address_);
// create in one and then move to another
auto t1Evb = t1_.getEventBase();
t1Evb->runInEventBaseThread([this] {
sslSocket_ = t1_.createSocket();
// ensure we can detach and reattach the context before connecting
for (int i = 0; i < 1000; ++i) {
sslSocket_->detachSSLContext();
sslSocket_->attachSSLContext(t1_.ctx_);
}
// detach from t1 and connect in t2
detach();
auto t2Evb = t2_.getEventBase();
t2Evb->runInEventBaseThread([this] {
t2_.attach(*sslSocket_);
sslSocket_->connect(this, address_);
});
});
}
void connectSuccess() noexcept override {
auto t2Evb = t2_.getEventBase();
EXPECT_TRUE(t2Evb->isInEventBaseThread());
cerr << "client SSL socket connected" << endl;
for (int i = 0; i < 1000; ++i) {
sslSocket_->detachSSLContext();
sslSocket_->attachSSLContext(ctx_);
sslSocket_->attachSSLContext(t2_.ctx_);
}
EXPECT_EQ(ctx_->getSSLCtx()->references, 2);
sslSocket_->write(this, buf_, sizeof(buf_));
sslSocket_->setReadCB(this);
memset(readbuf_, 'b', sizeof(readbuf_));
bytesRead_ = 0;
// detach from t2 and then read/write in t1
t2Evb->runInEventBaseThread([this] {
detach();
auto t1Evb = t1_.getEventBase();
t1Evb->runInEventBaseThread([this] {
t1_.attach(*sslSocket_);
sslSocket_->write(this, buf_, sizeof(buf_));
sslSocket_->setReadCB(this);
memset(readbuf_, 'b', sizeof(readbuf_));
bytesRead_ = 0;
});
});
}
void connectErr(const AsyncSocketException& ex) noexcept override
......@@ -96,14 +154,19 @@ class AttachDetachClient : public AsyncSocket::ConnectCallback,
void readErr(const AsyncSocketException& ex) noexcept override {
cerr << "client readError: " << ex.what() << endl;
promise_.setException(ex);
}
void readDataAvailable(size_t len) noexcept override {
EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
cerr << "client read data: " << len << endl;
bytesRead_ += len;
if (len == sizeof(buf_)) {
EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
sslSocket_->closeNow();
sslSocket_.reset();
promise_.setValue(true);
}
}
};
......@@ -119,13 +182,12 @@ TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
EventBase eventBase;
EventBaseAborter eba(&eventBase, 3000);
std::shared_ptr<AttachDetachClient> client(
new AttachDetachClient(&eventBase, server.getAddress()));
new AttachDetachClient(server.getAddress()));
auto f = client->getFuture();
client->connect();
eventBase.loop();
EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
}
} // folly
......
......@@ -2996,4 +2996,24 @@ TEST(AsyncSocketTest, ConnectTFOWithBigData) {
EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
public:
MOCK_METHOD1(evbAttached, void(AsyncSocket*));
MOCK_METHOD1(evbDetached, void(AsyncSocket*));
};
TEST(AsyncSocketTest, EvbCallbacks) {
auto cb = folly::make_unique<MockEvbChangeCallback>();
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
InSequence seq;
EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);
socket->setEvbChangedCallback(std::move(cb));
socket->detachEventBase();
socket->attachEventBase(&evb);
}
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment