Commit 0fc4facd authored by Petr Lapukhov's avatar Petr Lapukhov Committed by Facebook Github Bot

Properly disable TTLS with TFO

Summary: socketConnect() is not invoked if we have TFO enabled. Need to explicitly disable it once new socket is created, and noTransparentTls flag is set.

Reviewed By: djwatson

Differential Revision: D7945457

fbshipit-source-id: 739b7ae5bc146b50255254e644fba2618147c4f4
parent 46905dd2
...@@ -241,6 +241,23 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags( ...@@ -241,6 +241,23 @@ int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(
namespace { namespace {
static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback; static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
// Based on flags, signal the transparent handler to disable certain functions
void disableTransparentFunctions(int fd, bool noTransparentTls, bool noTSocks) {
#if __linux__
if (noTransparentTls) {
// Ignore return value, errors are ok
VLOG(5) << "Disabling TTLS for fd " << fd;
::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
}
if (noTSocks) {
VLOG(5) << "Disabling TSOCKS for fd " << fd;
// Ignore return value, errors are ok
::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
}
#endif
}
} // namespace } // namespace
AsyncSocket::AsyncSocket() AsyncSocket::AsyncSocket()
...@@ -286,6 +303,7 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId) ...@@ -286,6 +303,7 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId)
<< ", zeroCopyBufId=" << zeroCopyBufId << ")"; << ", zeroCopyBufId=" << zeroCopyBufId << ")";
init(); init();
fd_ = fd; fd_ = fd;
disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
setCloseOnExec(); setCloseOnExec();
state_ = StateEnum::ESTABLISHED; state_ = StateEnum::ESTABLISHED;
} }
...@@ -435,6 +453,7 @@ void AsyncSocket::connect(ConnectCallback* callback, ...@@ -435,6 +453,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
withAddr("failed to create socket"), withAddr("failed to create socket"),
errnoCopy); errnoCopy);
} }
disableTransparentFunctions(fd_, noTransparentTls_, noTSocks_);
if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) { if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
shutdownSocketSet->add(fd_); shutdownSocketSet->add(fd_);
} }
...@@ -562,17 +581,6 @@ void AsyncSocket::connect(ConnectCallback* callback, ...@@ -562,17 +581,6 @@ void AsyncSocket::connect(ConnectCallback* callback,
} }
int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
#if __linux__
if (noTransparentTls_) {
// Ignore return value, errors are ok
setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
}
if (noTSocks_) {
VLOG(4) << "Disabling TSOCKS for fd " << fd_;
// Ignore return value, errors are ok
setsockopt(fd_, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
}
#endif
int rv = fsp::connect(fd_, saddr, len); int rv = fsp::connect(fd_, saddr, len);
if (rv < 0) { if (rv < 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <folly/io/async/test/BlockingSocket.h> #include <folly/io/async/test/BlockingSocket.h>
#include <dlfcn.h>
#include <fcntl.h> #include <fcntl.h>
#include <signal.h> #include <signal.h>
#include <sys/types.h> #include <sys/types.h>
...@@ -57,6 +58,47 @@ using std::list; ...@@ -57,6 +58,47 @@ using std::list;
using namespace testing; using namespace testing;
#if defined __linux__
namespace {
// to store libc's original setsockopt()
typedef int (*setsockopt_ptr)(int, int, int, const void*, socklen_t);
setsockopt_ptr real_setsockopt_ = nullptr;
// global struct to initialize before main runs. we can init within a test,
// or in main, but this method seems to be least intrsive and universal
struct GlobalStatic {
GlobalStatic() {
real_setsockopt_ = (setsockopt_ptr)dlsym(RTLD_NEXT, "setsockopt");
}
void reset() noexcept {
ttlsDisabledSet.clear();
}
// for each fd, tracks whether TTLS is disabled or not
std::set<int /* fd */> ttlsDisabledSet;
};
// the constructor will be called before main() which is all we care about
GlobalStatic globalStatic;
} // namespace
// we intercept setsoctopt to test setting NO_TRANSPARENT_TLS opt
// this name has to be global
int setsockopt(
int sockfd,
int level,
int optname,
const void* optval,
socklen_t optlen) {
if (optname == SO_NO_TRANSPARENT_TLS) {
globalStatic.ttlsDisabledSet.insert(sockfd);
return 0;
}
return real_setsockopt_(sockfd, level, optname, optval, optlen);
}
#endif
namespace folly { namespace folly {
uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0; uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0; uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
...@@ -2111,6 +2153,41 @@ TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) { ...@@ -2111,6 +2153,41 @@ TEST(AsyncSSLSocketTest, TestSSLCipherCodeToNameMap) {
EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), ""); EXPECT_EQ(OpenSSLUtils::getCipherName(0x00ff), "");
} }
#if defined __linux__
/**
* Ensure TransparentTLS flag is disabled with AsyncSSLSocket
*/
TEST(AsyncSSLSocketTest, TTLSDisabled) {
// clear all setsockopt tracking history
globalStatic.reset();
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, false);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
// connect
auto socket =
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open();
EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getSocketFD()));
// write()
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
socket->write(buf.data(), buf.size());
// close()
socket->close();
}
#endif
#if FOLLY_ALLOW_TFO #if FOLLY_ALLOW_TFO
class MockAsyncTFOSSLSocket : public AsyncSSLSocket { class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
...@@ -2125,6 +2202,42 @@ class MockAsyncTFOSSLSocket : public AsyncSSLSocket { ...@@ -2125,6 +2202,42 @@ class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
}; };
#if defined __linux__
/**
* Ensure TransparentTLS flag is disabled with AsyncSSLSocket + TFO
*/
TEST(AsyncSSLSocketTest, TTLSDisabledWithTFO) {
// clear all setsockopt tracking history
globalStatic.reset();
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback, true);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
// connect
auto socket =
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->enableTFO();
socket->open();
EXPECT_EQ(1, globalStatic.ttlsDisabledSet.count(socket->getSocketFD()));
// write()
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
socket->write(buf.data(), buf.size());
// close()
socket->close();
}
#endif
/** /**
* Test connecting to, writing to, reading from, and closing the * Test connecting to, writing to, reading from, and closing the
* connection to the SSL server with TFO. * connection to the SSL server with TFO.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment