Commit 3a033f27 authored by Kyle Nekritz's avatar Kyle Nekritz Committed by Facebook Github Bot

Add pre received data API to AsyncSSLSocket.

Summary: This allows something else (ie fizz) to read data from a socket, and then later decide to to accept an SSL connection with OpenSSL by inserting the data it read in front of future reads on the socket.

Reviewed By: anirudhvr

Differential Revision: D4325634

fbshipit-source-id: 05076d2d911fda681b9c4e5d9d3375559293ea35
parent f3c3434c
...@@ -452,6 +452,10 @@ void AsyncSSLSocket::sslAccept( ...@@ -452,6 +452,10 @@ void AsyncSSLSocket::sslAccept(
/* register for a read operation (waiting for CLIENT HELLO) */ /* register for a read operation (waiting for CLIENT HELLO) */
updateEventRegistration(EventHandler::READ, EventHandler::WRITE); updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
if (preReceivedData_) {
handleRead();
}
} }
#if OPENSSL_VERSION_NUMBER >= 0x009080bfL #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
...@@ -1610,12 +1614,31 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) { ...@@ -1610,12 +1614,31 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
if (!out) { if (!out) {
return 0; return 0;
} }
auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
BIO_clear_retry_flags(b); BIO_clear_retry_flags(b);
if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
BIO_set_retry_read(b); auto appData = OpenSSLUtils::getBioAppData(b);
CHECK(appData);
auto sslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
if (sslSock->preReceivedData_ && !sslSock->preReceivedData_->empty()) {
VLOG(5) << "AsyncSSLSocket::bioRead() this=" << sslSock
<< ", reading pre-received data";
Cursor cursor(sslSock->preReceivedData_.get());
auto len = cursor.pullAtMost(out, outl);
IOBufQueue queue;
queue.append(std::move(sslSock->preReceivedData_));
queue.trimStart(len);
sslSock->preReceivedData_ = queue.move();
return len;
} else {
auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
BIO_set_retry_read(b);
}
return result;
} }
return result;
} }
int AsyncSSLSocket::sslVerifyCallback( int AsyncSSLSocket::sslVerifyCallback(
...@@ -1632,6 +1655,12 @@ int AsyncSSLSocket::sslVerifyCallback( ...@@ -1632,6 +1655,12 @@ int AsyncSSLSocket::sslVerifyCallback(
preverifyOk; preverifyOk;
} }
void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
CHECK(!preReceivedData_);
preReceivedData_ = std::move(data);
}
void AsyncSSLSocket::enableClientHelloParsing() { void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true; parseClientHello_ = true;
clientHelloInfo_.reset(new ssl::ClientHelloInfo()); clientHelloInfo_.reset(new ssl::ClientHelloInfo());
......
...@@ -278,6 +278,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -278,6 +278,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
virtual size_t getRawBytesReceived() const override; virtual size_t getRawBytesReceived() const override;
void enableClientHelloParsing(); void enableClientHelloParsing();
void setPreReceivedData(std::unique_ptr<IOBuf> data);
/** /**
* Accept an SSL connection on the socket. * Accept an SSL connection on the socket.
* *
...@@ -818,6 +820,8 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -818,6 +820,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
std::chrono::steady_clock::time_point handshakeEndTime_; std::chrono::steady_clock::time_point handshakeEndTime_;
std::chrono::milliseconds handshakeConnectTimeout_{0}; std::chrono::milliseconds handshakeConnectTimeout_{0};
bool sessionResumptionAttempted_{false}; bool sessionResumptionAttempted_{false};
std::unique_ptr<IOBuf> preReceivedData_;
}; };
} // namespace } // namespace
...@@ -1961,6 +1961,41 @@ TEST(AsyncSSLSocketTest, HandshakeTFORefused) { ...@@ -1961,6 +1961,41 @@ TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
EXPECT_THAT(ccb.error, testing::HasSubstr("refused")); EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
} }
TEST(AsyncSSLSocketTest, TestPreReceivedData) {
EventBase clientEventBase;
EventBase serverEventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
std::array<int, 2> fds;
getfds(fds.data());
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSockPtr(
new AsyncSSLSocket(clientCtx, &clientEventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSockPtr(
new AsyncSSLSocket(dfServerCtx, &serverEventBase, fds[1], true));
auto clientSock = clientSockPtr.get();
auto serverSock = serverSockPtr.get();
SSLHandshakeClient client(std::move(clientSockPtr), true, true);
// Steal some data from the server.
clientEventBase.loopOnce();
std::array<uint8_t, 10> buf;
recv(fds[1], buf.data(), buf.size(), 0);
serverSock->setPreReceivedData(IOBuf::wrapBuffer(range(buf)));
SSLHandshakeServer server(std::move(serverSockPtr), true, true);
while (!client.handshakeSuccess_ && !client.handshakeError_) {
serverEventBase.loopOnce();
clientEventBase.loopOnce();
}
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_EQ(
serverSock->getRawBytesReceived(), clientSock->getRawBytesWritten());
}
#endif #endif
} // namespace } // namespace
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment