Commit fb9776fc authored by Kyle Nekritz's avatar Kyle Nekritz Committed by Facebook Github Bot

Replace MSG_PEEK with a pre-received data interface.

Summary: MSG_PEEK was difficult if not impossible to use well since we do not provide a way wait for more data to arrive. If you are using setPeek on AsyncSocket, and you do not receive the amount of data you want, you must either abandon your peek attempt, or spin around the event base waiting for more data. This diff replaces the peek interface on AsyncSocket with a pre-received data interface, allowing users to insert data back onto the front of connections after reading some data in another layer.

Reviewed By: djwatson

Differential Revision: D4626315

fbshipit-source-id: c552e64f5b3ac9e40ea3358d65b4b9db848f5d74
parent ad9b56c1
...@@ -218,14 +218,38 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx, ...@@ -218,14 +218,38 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
/** /**
* Create a server/client AsyncSSLSocket * Create a server/client AsyncSSLSocket
*/ */
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, AsyncSSLSocket::AsyncSSLSocket(
EventBase* evb, int fd, bool server, const shared_ptr<SSLContext>& ctx,
bool deferSecurityNegotiation) : EventBase* evb,
AsyncSocket(evb, fd), int fd,
server_(server), bool server,
ctx_(ctx), bool deferSecurityNegotiation)
handshakeTimeout_(this, evb), : AsyncSocket(evb, fd),
connectionTimeout_(this, evb) { server_(server),
ctx_(ctx),
handshakeTimeout_(this, evb),
connectionTimeout_(this, evb) {
noTransparentTls_ = true;
init();
if (server) {
SSL_CTX_set_info_callback(
ctx_->getSSLCtx(), AsyncSSLSocket::sslInfoCallback);
}
if (deferSecurityNegotiation) {
sslState_ = STATE_UNENCRYPTED;
}
}
AsyncSSLSocket::AsyncSSLSocket(
const shared_ptr<SSLContext>& ctx,
AsyncSocket::UniquePtr oldAsyncSocket,
bool server,
bool deferSecurityNegotiation)
: AsyncSocket(std::move(oldAsyncSocket)),
server_(server),
ctx_(ctx),
handshakeTimeout_(this, oldAsyncSocket->getEventBase()),
connectionTimeout_(this, oldAsyncSocket->getEventBase()) {
noTransparentTls_ = true; noTransparentTls_ = true;
init(); init();
if (server) { if (server) {
...@@ -254,11 +278,13 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx, ...@@ -254,11 +278,13 @@ AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
* Create a client AsyncSSLSocket from an already connected fd * Create a client AsyncSSLSocket from an already connected fd
* and allow tlsext_hostname to be sent in Client Hello. * and allow tlsext_hostname to be sent in Client Hello.
*/ */
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx, AsyncSSLSocket::AsyncSSLSocket(
EventBase* evb, int fd, const shared_ptr<SSLContext>& ctx,
const std::string& serverName, EventBase* evb,
bool deferSecurityNegotiation) : int fd,
AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) { const std::string& serverName,
bool deferSecurityNegotiation)
: AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
tlsextHostname_ = serverName; tlsextHostname_ = serverName;
} }
#endif // FOLLY_OPENSSL_HAS_SNI #endif // FOLLY_OPENSSL_HAS_SNI
...@@ -451,9 +477,7 @@ void AsyncSSLSocket::sslAccept( ...@@ -451,9 +477,7 @@ void AsyncSSLSocket::sslAccept(
/* register for a read operation (waiting for CLIENT HELLO) */ /* register for a read operation (waiting for CLIENT HELLO) */
updateEventRegistration(EventHandler::READ, EventHandler::WRITE); updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
if (preReceivedData_) { checkForImmediateRead();
handleRead();
}
} }
#if OPENSSL_VERSION_NUMBER >= 0x009080bfL #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
...@@ -985,6 +1009,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept { ...@@ -985,6 +1009,8 @@ void AsyncSSLSocket::checkForImmediateRead() noexcept {
// the socket to become readable again. // the socket to become readable again.
if (ssl_ != nullptr && SSL_pending(ssl_) > 0) { if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
AsyncSocket::handleRead(); AsyncSocket::handleRead();
} else {
AsyncSocket::checkForImmediateRead();
} }
} }
...@@ -1684,12 +1710,6 @@ int AsyncSSLSocket::sslVerifyCallback( ...@@ -1684,12 +1710,6 @@ int AsyncSSLSocket::sslVerifyCallback(
preverifyOk; preverifyOk;
} }
void AsyncSSLSocket::setPreReceivedData(std::unique_ptr<IOBuf> data) {
CHECK(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
CHECK(!preReceivedData_);
preReceivedData_ = std::move(data);
}
void AsyncSSLSocket::enableClientHelloParsing() { void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true; parseClientHello_ = true;
clientHelloInfo_.reset(new ssl::ClientHelloInfo()); clientHelloInfo_.reset(new ssl::ClientHelloInfo());
......
...@@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* @param deferSecurityNegotiation * @param deferSecurityNegotiation
* unencrypted data can be sent before sslConn/Accept * unencrypted data can be sent before sslConn/Accept
*/ */
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx, AsyncSSLSocket(
EventBase* evb, int fd, const std::shared_ptr<folly::SSLContext>& ctx,
bool server = true, bool deferSecurityNegotiation = false); EventBase* evb,
int fd,
bool server = true,
bool deferSecurityNegotiation = false);
/**
* Create a server/client AsyncSSLSocket from an already connected
* AsyncSocket.
*/
AsyncSSLSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
AsyncSocket::UniquePtr oldAsyncSocket,
bool server = true,
bool deferSecurityNegotiation = false);
/** /**
* Helper function to create a server/client shared_ptr<AsyncSSLSocket>. * Helper function to create a server/client shared_ptr<AsyncSSLSocket>.
...@@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
* @param fd File descriptor to take over (should be a connected socket). * @param fd File descriptor to take over (should be a connected socket).
* @param serverName tlsext_hostname that will be sent in ClientHello. * @param serverName tlsext_hostname that will be sent in ClientHello.
*/ */
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx, AsyncSSLSocket(
EventBase* evb, const std::shared_ptr<folly::SSLContext>& ctx,
int fd, EventBase* evb,
const std::string& serverName, int fd,
bool deferSecurityNegotiation = false); const std::string& serverName,
bool deferSecurityNegotiation = false);
static std::shared_ptr<AsyncSSLSocket> newSocket( static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx, const std::shared_ptr<folly::SSLContext>& ctx,
...@@ -276,8 +289,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -276,8 +289,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
virtual size_t getRawBytesReceived() const override; virtual size_t getRawBytesReceived() const override;
void enableClientHelloParsing(); void enableClientHelloParsing();
void setPreReceivedData(std::unique_ptr<IOBuf> data);
/** /**
* Accept an SSL connection on the socket. * Accept an SSL connection on the socket.
* *
...@@ -864,7 +875,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -864,7 +875,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
bool sessionResumptionAttempted_{false}; bool sessionResumptionAttempted_{false};
std::chrono::milliseconds totalConnectTimeout_{0}; std::chrono::milliseconds totalConnectTimeout_{0};
std::unique_ptr<IOBuf> preReceivedData_;
std::string sslVerificationAlert_; std::string sslVerificationAlert_;
}; };
......
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#include <folly/io/async/AsyncSocket.h> #include <folly/io/async/AsyncSocket.h>
#include <folly/ExceptionWrapper.h> #include <folly/ExceptionWrapper.h>
#include <folly/Portability.h>
#include <folly/SocketAddress.h> #include <folly/SocketAddress.h>
#include <folly/io/Cursor.h>
#include <folly/io/IOBuf.h> #include <folly/io/IOBuf.h>
#include <folly/Portability.h> #include <folly/io/IOBufQueue.h>
#include <folly/portability/Fcntl.h> #include <folly/portability/Fcntl.h>
#include <folly/portability/Sockets.h> #include <folly/portability/Sockets.h>
#include <folly/portability/SysUio.h> #include <folly/portability/SysUio.h>
...@@ -229,6 +231,11 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd) ...@@ -229,6 +231,11 @@ AsyncSocket::AsyncSocket(EventBase* evb, int fd)
state_ = StateEnum::ESTABLISHED; state_ = StateEnum::ESTABLISHED;
} }
AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
: AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
}
// init() method, since constructor forwarding isn't supported in most // init() method, since constructor forwarding isn't supported in most
// compilers yet. // compilers yet.
void AsyncSocket::init() { void AsyncSocket::init() {
...@@ -1406,12 +1413,23 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) { ...@@ -1406,12 +1413,23 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
<< ", buflen=" << *buflen; << ", buflen=" << *buflen;
int recvFlags = 0; if (preReceivedData_ && !preReceivedData_->empty()) {
if (peek_) { VLOG(5) << "AsyncSocket::performRead() this=" << this
recvFlags |= MSG_PEEK; << ", reading pre-received data";
io::Cursor cursor(preReceivedData_.get());
auto len = cursor.pullAtMost(*buf, *buflen);
IOBufQueue queue;
queue.append(std::move(preReceivedData_));
queue.trimStart(len);
preReceivedData_ = queue.move();
appBytesReceived_ += len;
return ReadResult(len);
} }
ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags); ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT);
if (bytes < 0) { if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) { if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now. // No more data to read right now.
...@@ -1762,6 +1780,12 @@ void AsyncSocket::checkForImmediateRead() noexcept { ...@@ -1762,6 +1780,12 @@ void AsyncSocket::checkForImmediateRead() noexcept {
// be a pessimism. In most cases it probably wouldn't be readable, and we // be a pessimism. In most cases it probably wouldn't be readable, and we
// would just waste an extra system call. Even if it is readable, waiting to // would just waste an extra system call. Even if it is readable, waiting to
// find out from libevent on the next event loop doesn't seem that bad. // find out from libevent on the next event loop doesn't seem that bad.
//
// The exception to this is if we have pre-received data. In that case there
// is definitely data available immediately.
if (preReceivedData_ && !preReceivedData_->empty()) {
handleRead();
}
} }
void AsyncSocket::handleInitialReadWrite() noexcept { void AsyncSocket::handleInitialReadWrite() noexcept {
......
...@@ -189,6 +189,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -189,6 +189,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/ */
AsyncSocket(EventBase* evb, int fd); AsyncSocket(EventBase* evb, int fd);
/**
* Create an AsyncSocket from a different, already connected AsyncSocket.
*
* Similar to AsyncSocket(evb, fd) when fd was previously owned by an
* AsyncSocket.
*/
explicit AsyncSocket(AsyncSocket::UniquePtr);
/** /**
* Helper function to create a shared_ptr<AsyncSocket>. * Helper function to create a shared_ptr<AsyncSocket>.
* *
...@@ -264,6 +272,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -264,6 +272,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* error. The AsyncSocket may no longer be used after the file descriptor * error. The AsyncSocket may no longer be used after the file descriptor
* has been extracted. * has been extracted.
* *
* This method should be used with care as the resulting fd is not guaranteed
* to perfectly reflect the state of the AsyncSocket (security state,
* pre-received data, etc.).
*
* Returns the file descriptor. The caller assumes ownership of the * Returns the file descriptor. The caller assumes ownership of the
* descriptor, and it will not be closed when the AsyncSocket is destroyed. * descriptor, and it will not be closed when the AsyncSocket is destroyed.
*/ */
...@@ -601,8 +613,16 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -601,8 +613,16 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
return setsockopt(fd_, level, optname, optval, sizeof(T)); return setsockopt(fd_, level, optname, optval, sizeof(T));
} }
virtual void setPeek(bool peek) { /**
peek_ = peek; * Set pre-received data, to be returned to read callback before any data
* from the socket.
*/
virtual void setPreReceivedData(std::unique_ptr<IOBuf> data) {
if (preReceivedData_) {
preReceivedData_->prependChain(std::move(data));
} else {
preReceivedData_ = std::move(data);
}
} }
/** /**
...@@ -998,7 +1018,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -998,7 +1018,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
size_t appBytesWritten_; ///< Num of bytes written to socket size_t appBytesWritten_; ///< Num of bytes written to socket
bool isBufferMovable_{false}; bool isBufferMovable_{false};
bool peek_{false}; // Peek bytes. // Pre-received data, to be returned to read callback before any data from the
// socket.
std::unique_ptr<IOBuf> preReceivedData_;
int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any. int8_t readErr_{READ_NO_ERROR}; ///< The read error encountered, if any.
......
...@@ -2909,3 +2909,133 @@ TEST(AsyncSocketTest, ErrMessageCallback) { ...@@ -2909,3 +2909,133 @@ TEST(AsyncSocketTest, ErrMessageCallback) {
ASSERT_TRUE(errMsgCB.gotTimestamp_); ASSERT_TRUE(errMsgCB.gotTimestamp_);
} }
#endif // MSG_ERRQUEUE #endif // MSG_ERRQUEUE
TEST(AsyncSocket, PreReceivedData) {
TestServer server;
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
socket->connect(nullptr, server.getAddress(), 30);
evb.loop();
socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback peekCallback(2);
ReadCallback readCallback;
peekCallback.dataAvailableCallback = [&]() {
peekCallback.verifyData("he", 2);
acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
acceptedSocket->setReadCB(nullptr);
acceptedSocket->setReadCB(&readCallback);
};
readCallback.dataAvailableCallback = [&]() {
if (readCallback.dataRead() == 5) {
readCallback.verifyData("hello", 5);
acceptedSocket->setReadCB(nullptr);
}
};
acceptedSocket->setReadCB(&peekCallback);
evb.loop();
}
TEST(AsyncSocket, PreReceivedDataOnly) {
TestServer server;
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
socket->connect(nullptr, server.getAddress(), 30);
evb.loop();
socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback peekCallback;
ReadCallback readCallback;
peekCallback.dataAvailableCallback = [&]() {
peekCallback.verifyData("hello", 5);
acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
acceptedSocket->setReadCB(&readCallback);
};
readCallback.dataAvailableCallback = [&]() {
readCallback.verifyData("hello", 5);
acceptedSocket->setReadCB(nullptr);
};
acceptedSocket->setReadCB(&peekCallback);
evb.loop();
}
TEST(AsyncSocket, PreReceivedDataPartial) {
TestServer server;
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
socket->connect(nullptr, server.getAddress(), 30);
evb.loop();
socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback peekCallback;
ReadCallback smallReadCallback(3);
ReadCallback normalReadCallback;
peekCallback.dataAvailableCallback = [&]() {
peekCallback.verifyData("hello", 5);
acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
acceptedSocket->setReadCB(&smallReadCallback);
};
smallReadCallback.dataAvailableCallback = [&]() {
smallReadCallback.verifyData("hel", 3);
acceptedSocket->setReadCB(&normalReadCallback);
};
normalReadCallback.dataAvailableCallback = [&]() {
normalReadCallback.verifyData("lo", 2);
acceptedSocket->setReadCB(nullptr);
};
acceptedSocket->setReadCB(&peekCallback);
evb.loop();
}
TEST(AsyncSocket, PreReceivedDataTakeover) {
TestServer server;
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
socket->connect(nullptr, server.getAddress(), 30);
evb.loop();
socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
auto acceptedSocket =
AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
AsyncSocket::UniquePtr takeoverSocket;
ReadCallback peekCallback(3);
ReadCallback readCallback;
peekCallback.dataAvailableCallback = [&]() {
peekCallback.verifyData("hel", 3);
acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
acceptedSocket->setReadCB(nullptr);
takeoverSocket =
AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
takeoverSocket->setReadCB(&readCallback);
};
readCallback.dataAvailableCallback = [&]() {
readCallback.verifyData("hello", 5);
takeoverSocket->setReadCB(nullptr);
};
acceptedSocket->setReadCB(&peekCallback);
evb.loop();
}
...@@ -50,7 +50,6 @@ class MockAsyncSSLSocket : public AsyncSSLSocket { ...@@ -50,7 +50,6 @@ class MockAsyncSSLSocket : public AsyncSSLSocket {
bool(const unsigned char**, bool(const unsigned char**,
unsigned*, unsigned*,
SSLContext::NextProtocolType*)); SSLContext::NextProtocolType*));
MOCK_METHOD1(setPeek, void(bool));
MOCK_METHOD1(setReadCB, void(ReadCallback*)); MOCK_METHOD1(setReadCB, void(ReadCallback*));
void sslConn( void sslConn(
......
...@@ -45,8 +45,11 @@ class MockAsyncSocket : public AsyncSocket { ...@@ -45,8 +45,11 @@ class MockAsyncSocket : public AsyncSocket {
MOCK_CONST_METHOD0(good, bool()); MOCK_CONST_METHOD0(good, bool());
MOCK_CONST_METHOD0(readable, bool()); MOCK_CONST_METHOD0(readable, bool());
MOCK_CONST_METHOD0(hangup, bool()); MOCK_CONST_METHOD0(hangup, bool());
MOCK_METHOD1(setPeek, void(bool));
MOCK_METHOD1(setReadCB, void(ReadCallback*)); MOCK_METHOD1(setReadCB, void(ReadCallback*));
MOCK_METHOD1(_setPreReceivedData, void(std::unique_ptr<IOBuf>&));
void setPreReceivedData(std::unique_ptr<IOBuf> data) override {
return _setPreReceivedData(data);
}
}; };
}} }}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment