Commit 16dc0043 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot

Fix TFO refused case

Summary:
When TFO falls back, it's possible that
the fallback can also error out.

We handle this correctly in AsyncSocket,
however because AsyncSSLSocket is so
inter-twined with AsyncSocket, we missed
the case of error as well.

This changes it so that a connect error on
fallback will cause a handshake error

Differential Revision: D4226477

fbshipit-source-id: c6e845e4a907bfef1e6ad1b4118db47184d047e0
parent 26c4e8d2
......@@ -1127,11 +1127,21 @@ AsyncSSLSocket::handleConnect() noexcept {
void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
connectionTimeout_.cancelTimeout();
AsyncSocket::invokeConnectErr(ex);
if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
assert(tfoAttempted_);
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
// If we fell back to connecting state during TFO and the connection
// failed, it would be an SSL failure as well.
invokeHandshakeErr(ex);
}
}
void AsyncSSLSocket::invokeConnectSuccess() {
connectionTimeout_.cancelTimeout();
if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
assert(tfoAttempted_);
// If we failed TFO, we'd fall back to trying to connect the socket,
// to setup things like timeouts.
startSSLConnect();
......
......@@ -1798,8 +1798,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
errno = EAGAIN;
totalWritten = -1;
} else if (errno == EOPNOTSUPP) {
VLOG(4) << "TFO not supported";
// Try falling back to connecting.
VLOG(4) << "TFO not supported";
state_ = StateEnum::CONNECTING;
try {
int ret = socketConnect((const sockaddr*)&addr, len);
......@@ -1977,12 +1977,7 @@ void AsyncSocket::startFail() {
}
}
void AsyncSocket::finishFail() {
assert(state_ == StateEnum::ERROR);
assert(getDestructorGuardCount() > 0);
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
withAddr("socket closing after error"));
void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
invokeConnectErr(ex);
failAllWrites(ex);
......@@ -1993,6 +1988,22 @@ void AsyncSocket::finishFail() {
}
}
void AsyncSocket::finishFail() {
assert(state_ == StateEnum::ERROR);
assert(getDestructorGuardCount() > 0);
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
withAddr("socket closing after error"));
invokeAllErrors(ex);
}
void AsyncSocket::finishFail(const AsyncSocketException& ex) {
assert(state_ == StateEnum::ERROR);
assert(getDestructorGuardCount() > 0);
invokeAllErrors(ex);
}
void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
<< state_ << " host=" << addr_.describe()
......@@ -2010,7 +2021,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
startFail();
invokeConnectErr(ex);
finishFail();
finishFail(ex);
}
void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
......
......@@ -877,6 +877,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
// error handling methods
void startFail();
void finishFail();
void finishFail(const AsyncSocketException& ex);
void invokeAllErrors(const AsyncSocketException& ex);
void fail(const char* fn, const AsyncSocketException& ex);
void failConnect(const char* fn, const AsyncSocketException& ex);
void failRead(const char* fn, const AsyncSocketException& ex);
......
......@@ -1918,6 +1918,21 @@ TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
}
TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
// Start listening on a local port
EventBase evb;
// Hopefully nothing is listening on this address
SocketAddress addr("127.0.0.1", 65535);
auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
ConnCallback ccb;
socket->connect(&ccb, addr, 100);
evb.loop();
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
}
#endif
} // namespace
......
......@@ -2524,7 +2524,7 @@ TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
/**
* Test connecting to a server that isn't listening
*/
TEST(AsyncSocketTest, ConnectRefusedTFO) {
TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
......@@ -2541,7 +2541,6 @@ TEST(AsyncSocketTest, ConnectRefusedTFO) {
WriteCallback write1;
// Trigger the connect if TFO attempt is supported.
socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
evb.loop();
WriteCallback write2;
socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
evb.loop();
......@@ -2675,6 +2674,51 @@ TEST(AsyncSocketTest, TestTFOUnsupported) {
EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
}
TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
EventBase evb;
auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
socket->enableTFO();
// Hopefully this fails
folly::SocketAddress fakeAddr("127.0.0.1", 65535);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = fakeAddr.getAddress(&addr);
int ret = connect(fd, (const struct sockaddr*)&addr, len);
LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
<< errno;
return ret;
}));
// Hopefully nothing is actually listening on this address
ConnCallback cb;
socket->connect(&cb, fakeAddr, 30);
WriteCallback write1;
// Trigger the connect if TFO attempt is supported.
socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
if (socket->getTFOFinished()) {
// This test is useless now.
return;
}
WriteCallback write2;
// Trigger the connect if TFO attempt is supported.
socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
evb.loop();
EXPECT_EQ(STATE_FAILED, write1.state);
EXPECT_EQ(STATE_FAILED, write2.state);
EXPECT_FALSE(socket->getTFOSucceded());
EXPECT_EQ(STATE_SUCCEEDED, cb.state);
EXPECT_LE(0, socket->getConnectTime().count());
EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
EXPECT_TRUE(socket->getTFOAttempted());
}
TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
// Try connecting to server that won't respond.
//
......
......@@ -27,6 +27,9 @@ DEFINE_int32(port, 0, "port");
DEFINE_bool(tfo, false, "enable tfo");
DEFINE_string(msg, "", "Message to send");
DEFINE_bool(ssl, false, "use ssl");
DEFINE_int32(timeout_ms, 0, "timeout");
DEFINE_int32(sendtimeout_ms, 0, "send timeout");
DEFINE_int32(num_writes, 1, "number of writes");
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
......@@ -53,6 +56,10 @@ int main(int argc, char** argv) {
#endif
}
if (FLAGS_sendtimeout_ms != 0) {
socket->setSendTimeout(FLAGS_sendtimeout_ms);
}
// Keep this around
auto sockAddr = socket.get();
......@@ -60,10 +67,13 @@ int main(int argc, char** argv) {
SocketAddress addr;
addr.setFromHostPort(FLAGS_host, FLAGS_port);
sock.setAddress(addr);
sock.open();
std::chrono::milliseconds timeout(FLAGS_timeout_ms);
sock.open(timeout);
LOG(INFO) << "connected to " << addr.getAddressStr();
sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size());
for (int32_t i = 0; i < FLAGS_num_writes; ++i) {
sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size());
}
LOG(INFO) << "TFO attempted: " << sockAddr->getTFOAttempted();
LOG(INFO) << "TFO finished: " << sockAddr->getTFOFinished();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment