Commit 16dc0043 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot

Fix TFO refused case

Summary:
When TFO falls back, it's possible that
the fallback can also error out.

We handle this correctly in AsyncSocket,
however because AsyncSSLSocket is so
inter-twined with AsyncSocket, we missed
the case of error as well.

This changes it so that a connect error on
fallback will cause a handshake error

Differential Revision: D4226477

fbshipit-source-id: c6e845e4a907bfef1e6ad1b4118db47184d047e0
parent 26c4e8d2
...@@ -1127,11 +1127,21 @@ AsyncSSLSocket::handleConnect() noexcept { ...@@ -1127,11 +1127,21 @@ AsyncSSLSocket::handleConnect() noexcept {
void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) { void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
connectionTimeout_.cancelTimeout(); connectionTimeout_.cancelTimeout();
AsyncSocket::invokeConnectErr(ex); AsyncSocket::invokeConnectErr(ex);
if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
assert(tfoAttempted_);
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
// If we fell back to connecting state during TFO and the connection
// failed, it would be an SSL failure as well.
invokeHandshakeErr(ex);
}
} }
void AsyncSSLSocket::invokeConnectSuccess() { void AsyncSSLSocket::invokeConnectSuccess() {
connectionTimeout_.cancelTimeout(); connectionTimeout_.cancelTimeout();
if (sslState_ == SSLStateEnum::STATE_CONNECTING) { if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
assert(tfoAttempted_);
// If we failed TFO, we'd fall back to trying to connect the socket, // If we failed TFO, we'd fall back to trying to connect the socket,
// to setup things like timeouts. // to setup things like timeouts.
startSSLConnect(); startSSLConnect();
......
...@@ -1798,8 +1798,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { ...@@ -1798,8 +1798,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
errno = EAGAIN; errno = EAGAIN;
totalWritten = -1; totalWritten = -1;
} else if (errno == EOPNOTSUPP) { } else if (errno == EOPNOTSUPP) {
VLOG(4) << "TFO not supported";
// Try falling back to connecting. // Try falling back to connecting.
VLOG(4) << "TFO not supported";
state_ = StateEnum::CONNECTING; state_ = StateEnum::CONNECTING;
try { try {
int ret = socketConnect((const sockaddr*)&addr, len); int ret = socketConnect((const sockaddr*)&addr, len);
...@@ -1977,12 +1977,7 @@ void AsyncSocket::startFail() { ...@@ -1977,12 +1977,7 @@ void AsyncSocket::startFail() {
} }
} }
void AsyncSocket::finishFail() { void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
assert(state_ == StateEnum::ERROR);
assert(getDestructorGuardCount() > 0);
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
withAddr("socket closing after error"));
invokeConnectErr(ex); invokeConnectErr(ex);
failAllWrites(ex); failAllWrites(ex);
...@@ -1993,6 +1988,22 @@ void AsyncSocket::finishFail() { ...@@ -1993,6 +1988,22 @@ void AsyncSocket::finishFail() {
} }
} }
void AsyncSocket::finishFail() {
assert(state_ == StateEnum::ERROR);
assert(getDestructorGuardCount() > 0);
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
withAddr("socket closing after error"));
invokeAllErrors(ex);
}
void AsyncSocket::finishFail(const AsyncSocketException& ex) {
assert(state_ == StateEnum::ERROR);
assert(getDestructorGuardCount() > 0);
invokeAllErrors(ex);
}
void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) { void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state=" VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
<< state_ << " host=" << addr_.describe() << state_ << " host=" << addr_.describe()
...@@ -2010,7 +2021,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) { ...@@ -2010,7 +2021,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
startFail(); startFail();
invokeConnectErr(ex); invokeConnectErr(ex);
finishFail(); finishFail(ex);
} }
void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) { void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
......
...@@ -877,6 +877,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -877,6 +877,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
// error handling methods // error handling methods
void startFail(); void startFail();
void finishFail(); void finishFail();
void finishFail(const AsyncSocketException& ex);
void invokeAllErrors(const AsyncSocketException& ex);
void fail(const char* fn, const AsyncSocketException& ex); void fail(const char* fn, const AsyncSocketException& ex);
void failConnect(const char* fn, const AsyncSocketException& ex); void failConnect(const char* fn, const AsyncSocketException& ex);
void failRead(const char* fn, const AsyncSocketException& ex); void failRead(const char* fn, const AsyncSocketException& ex);
......
...@@ -1918,6 +1918,21 @@ TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) { ...@@ -1918,6 +1918,21 @@ TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out")); EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
} }
TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
// Start listening on a local port
EventBase evb;
// Hopefully nothing is listening on this address
SocketAddress addr("127.0.0.1", 65535);
auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
ConnCallback ccb;
socket->connect(&ccb, addr, 100);
evb.loop();
EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
}
#endif #endif
} // namespace } // namespace
......
...@@ -2524,7 +2524,7 @@ TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) { ...@@ -2524,7 +2524,7 @@ TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
/** /**
* Test connecting to a server that isn't listening * Test connecting to a server that isn't listening
*/ */
TEST(AsyncSocketTest, ConnectRefusedTFO) { TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
EventBase evb; EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb); std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
...@@ -2541,7 +2541,6 @@ TEST(AsyncSocketTest, ConnectRefusedTFO) { ...@@ -2541,7 +2541,6 @@ TEST(AsyncSocketTest, ConnectRefusedTFO) {
WriteCallback write1; WriteCallback write1;
// Trigger the connect if TFO attempt is supported. // Trigger the connect if TFO attempt is supported.
socket->writeChain(&write1, IOBuf::copyBuffer("hey")); socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
evb.loop();
WriteCallback write2; WriteCallback write2;
socket->writeChain(&write2, IOBuf::copyBuffer("hey")); socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
evb.loop(); evb.loop();
...@@ -2675,6 +2674,51 @@ TEST(AsyncSocketTest, TestTFOUnsupported) { ...@@ -2675,6 +2674,51 @@ TEST(AsyncSocketTest, TestTFOUnsupported) {
EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded()); EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
} }
TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
EventBase evb;
auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
socket->enableTFO();
// Hopefully this fails
folly::SocketAddress fakeAddr("127.0.0.1", 65535);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = fakeAddr.getAddress(&addr);
int ret = connect(fd, (const struct sockaddr*)&addr, len);
LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
<< errno;
return ret;
}));
// Hopefully nothing is actually listening on this address
ConnCallback cb;
socket->connect(&cb, fakeAddr, 30);
WriteCallback write1;
// Trigger the connect if TFO attempt is supported.
socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
if (socket->getTFOFinished()) {
// This test is useless now.
return;
}
WriteCallback write2;
// Trigger the connect if TFO attempt is supported.
socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
evb.loop();
EXPECT_EQ(STATE_FAILED, write1.state);
EXPECT_EQ(STATE_FAILED, write2.state);
EXPECT_FALSE(socket->getTFOSucceded());
EXPECT_EQ(STATE_SUCCEEDED, cb.state);
EXPECT_LE(0, socket->getConnectTime().count());
EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
EXPECT_TRUE(socket->getTFOAttempted());
}
TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) { TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
// Try connecting to server that won't respond. // Try connecting to server that won't respond.
// //
......
...@@ -27,6 +27,9 @@ DEFINE_int32(port, 0, "port"); ...@@ -27,6 +27,9 @@ DEFINE_int32(port, 0, "port");
DEFINE_bool(tfo, false, "enable tfo"); DEFINE_bool(tfo, false, "enable tfo");
DEFINE_string(msg, "", "Message to send"); DEFINE_string(msg, "", "Message to send");
DEFINE_bool(ssl, false, "use ssl"); DEFINE_bool(ssl, false, "use ssl");
DEFINE_int32(timeout_ms, 0, "timeout");
DEFINE_int32(sendtimeout_ms, 0, "send timeout");
DEFINE_int32(num_writes, 1, "number of writes");
int main(int argc, char** argv) { int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
...@@ -53,6 +56,10 @@ int main(int argc, char** argv) { ...@@ -53,6 +56,10 @@ int main(int argc, char** argv) {
#endif #endif
} }
if (FLAGS_sendtimeout_ms != 0) {
socket->setSendTimeout(FLAGS_sendtimeout_ms);
}
// Keep this around // Keep this around
auto sockAddr = socket.get(); auto sockAddr = socket.get();
...@@ -60,10 +67,13 @@ int main(int argc, char** argv) { ...@@ -60,10 +67,13 @@ int main(int argc, char** argv) {
SocketAddress addr; SocketAddress addr;
addr.setFromHostPort(FLAGS_host, FLAGS_port); addr.setFromHostPort(FLAGS_host, FLAGS_port);
sock.setAddress(addr); sock.setAddress(addr);
sock.open(); std::chrono::milliseconds timeout(FLAGS_timeout_ms);
sock.open(timeout);
LOG(INFO) << "connected to " << addr.getAddressStr(); LOG(INFO) << "connected to " << addr.getAddressStr();
sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size()); for (int32_t i = 0; i < FLAGS_num_writes; ++i) {
sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size());
}
LOG(INFO) << "TFO attempted: " << sockAddr->getTFOAttempted(); LOG(INFO) << "TFO attempted: " << sockAddr->getTFOAttempted();
LOG(INFO) << "TFO finished: " << sockAddr->getTFOFinished(); LOG(INFO) << "TFO finished: " << sockAddr->getTFOFinished();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment