Commit eccdeae9 authored by Brandon Schlinker's avatar Brandon Schlinker Committed by Facebook Github Bot

Support TIMESTAMP_TX flag

Summary:
The `TIMESTAMP_TX` flag can be used to signal a request TX / NIC timestamping. This flag needs to be passed through to the application via the `SendMsgParamsCallback::getAncillaryData` callback so that the application can populate a socket control message with the ancillary data required to signal the timestamping request.

`folly::AsyncSSLSocket` has extra logic for tracking the end of record (EOR) byte, but this logic currently only passes through `WriteFlags::EOR`. This diff adds support for passing through any flag specified as an EOR-related write flag when the EOR byte is written. As part of this change, the relevant unit tests are extended / cleaned up (some of them still reference MSG_EOR, which is out of date).

Reviewed By: yfeldblum

Differential Revision: D15465432

fbshipit-source-id: 2ab5619607959dd829427a695aefd95a33b4abce
parent 1308f879
......@@ -386,6 +386,7 @@ void AsyncSSLSocket::setEorTracking(bool track) {
if (isEorTrackingEnabled() != track) {
AsyncSocket::setEorTracking(track);
appEorByteNo_ = 0;
appEorByteWriteFlags_ = {};
minEorRawByteNo_ = 0;
}
}
......@@ -1536,13 +1537,18 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
}
}
// cork the current write if the original flags included CORK or if there
// are remaining iovec to write
corkCurrentWrite_ =
isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count);
bytes = eorAwareSSLWrite(
ssl_,
sslWriteBuf,
int(len),
(isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
// track the EoR if:
// (1) there are write flags that require EoR tracking (EOR / TIMESTAMP_TX)
// (2) if the buffer includes the EOR byte
appEorByteWriteFlags_ = flags & kEorRelevantWriteFlags;
bool trackEor = appEorByteWriteFlags_ != folly::WriteFlags::NONE &&
(i + buffersStolen + 1 == count);
bytes = eorAwareSSLWrite(ssl_, sslWriteBuf, int(len), trackEor);
if (bytes <= 0) {
int error = SSL_get_error(ssl_.get(), int(bytes));
......@@ -1609,6 +1615,7 @@ int AsyncSSLSocket::eorAwareSSLWrite(
}
if (appBytesWritten_ == appEorByteNo_) {
appEorByteNo_ = 0;
appEorByteWriteFlags_ = {};
} else {
CHECK(appBytesWritten_ < appEorByteNo_);
}
......@@ -1658,7 +1665,7 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
WriteFlags flags = WriteFlags::NONE;
if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags |= WriteFlags::EOR;
flags |= tsslSock->appEorByteWriteFlags_;
}
if (tsslSock->corkCurrentWrite_) {
......
......@@ -894,16 +894,31 @@ class AsyncSSLSocket : public virtual AsyncSocket {
Timeout handshakeTimeout_;
Timeout connectionTimeout_;
// The app byte num that we are tracking for the MSG_EOR
// The app byte num that we are tracking for EOR.
//
// Only one app EOR byte can be tracked.
// See appEorByteWriteFlags_ for details.
size_t appEorByteNo_{0};
// The WriteFlags to pass for the app byte num that is tracked for EOR.
//
// When openssl is about to send appEorByteNo_, these flags will be passed to
// the application via the getAncillaryData callback. The application can then
// generate a control message containing socket timestamping flags or other
// commands that will be included when the corresponding buffer is passed to
// the kernel via sendmsg().
//
// See AsyncSSLSocket::bioWrite (which overrides OpenSSL biowrite).
WriteFlags appEorByteWriteFlags_{};
// Try to avoid calling SSL_write() for buffers smaller than this.
// It doesn't take effect when it is 0.
size_t minWriteSize_{1500};
// When openssl is about to sendmsg() across the minEorRawBytesNo_,
// it will pass MSG_EOR to sendmsg().
// it will trigger logic to include an application defined control message.
//
// See appEorByteWriteFlags_ for details.
size_t minEorRawByteNo_{0};
#if FOLLY_OPENSSL_HAS_SNI
std::shared_ptr<folly::SSLContext> handshakeCtx_;
......
......@@ -146,10 +146,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
/**
* getFlags() will be invoked to retrieve the desired flags to be passed
* to ::sendmsg() system call. This method was intentionally declared
* non-virtual, so there is no way to override it. Instead feel free to
* override getFlagsImpl(flags, defaultFlags) method instead, and enjoy
* the convenience of defaultFlags passed there.
* to ::sendmsg() system call. It is responsible for converting flags set in
* the passed folly::WriteFlags enum into a integer flag bitmask that can be
* passed to ::sendmsg. Some flags in folly::WriteFlags do not correspond to
* flags that can be passed to ::sendmsg and may instead be handled via
* getAncillaryData.
*
* This method was intentionally declared non-virtual, so there is no way to
* override it. Instead feel free to override getFlagsImpl(...) instead, and
* enjoy the convenience of defaultFlags passed there.
*
* @param flags Write flags requested for the given write operation
*/
......@@ -160,7 +165,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
/**
* getAncillaryData() will be invoked to initialize ancillary data
* buffer referred by "msg_control" field of msghdr structure passed to
* ::sendmsg() system call. The function assumes that the size of buffer
* ::sendmsg() system call based on the flags set in the passed
* folly::WriteFlags enum. Some flags in folly::WriteFlags are not relevant
* during this process. The function assumes that the size of buffer
* is not smaller than the value returned by getAncillaryDataSize() method
* for the same combination of flags.
*
......
......@@ -78,7 +78,7 @@ enum class WriteFlags : uint32_t {
/*
* union operator
*/
inline WriteFlags operator|(WriteFlags a, WriteFlags b) {
constexpr WriteFlags operator|(WriteFlags a, WriteFlags b) {
return static_cast<WriteFlags>(
static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}
......@@ -86,7 +86,7 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) {
/*
* compound assignment union operator
*/
inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
constexpr WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
a = a | b;
return a;
}
......@@ -94,7 +94,7 @@ inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
/*
* intersection operator
*/
inline WriteFlags operator&(WriteFlags a, WriteFlags b) {
constexpr WriteFlags operator&(WriteFlags a, WriteFlags b) {
return static_cast<WriteFlags>(
static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
}
......@@ -102,7 +102,7 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) {
/*
* compound assignment intersection operator
*/
inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
constexpr WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
a = a & b;
return a;
}
......@@ -110,24 +110,37 @@ inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
/*
* exclusion parameter
*/
inline WriteFlags operator~(WriteFlags a) {
constexpr WriteFlags operator~(WriteFlags a) {
return static_cast<WriteFlags>(~static_cast<uint32_t>(a));
}
/*
* unset operator
*/
inline WriteFlags unSet(WriteFlags a, WriteFlags b) {
constexpr WriteFlags unSet(WriteFlags a, WriteFlags b) {
return a & ~b;
}
/*
* inclusion operator
*/
inline bool isSet(WriteFlags a, WriteFlags b) {
constexpr bool isSet(WriteFlags a, WriteFlags b) {
return (a & b) == b;
}
/**
* Write flags that are specifically for the final write call of a buffer.
*
* In some cases, buffers passed to send may be coalesced or split by the socket
* write handling logic. For instance, a buffer passed to AsyncSSLSocket may be
* split across multiple TLS records (and therefore multiple calls to write).
*
* When a buffer is split up, these flags will only be applied for the final
* call to write for that buffer.
*/
constexpr WriteFlags kEorRelevantWriteFlags =
WriteFlags::EOR | WriteFlags::TIMESTAMP_TX;
/**
* AsyncTransport defines an asynchronous API for streaming I/O.
*
......
......@@ -2561,7 +2561,7 @@ TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server.
* connection to the SSL server with ancillary data from the application.
*/
TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
// This test requires Linux kernel v4.6 or later
......@@ -2580,7 +2580,7 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
}
// Start listening on a local port
SendMsgDataCallback msgCallback;
SendMsgAncillaryDataCallback msgCallback;
WriteCheckTimestampCallback writeCallback(&msgCallback);
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
......@@ -2596,11 +2596,17 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open();
// Adding MSG_EOR flag to the message flags - it'll trigger
// timestamp generation for the last byte of the message.
msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR);
// we'll pass the EOR and TIMESTAMP_TX flags with the write back
// EOR tracking must be enabled for WriteFlags be passed
const auto writeFlags =
folly::WriteFlags::EOR | folly::WriteFlags::TIMESTAMP_TX;
readCallback.setWriteFlags(writeFlags);
msgCallback.setEorTracking(true);
// Init ancillary data buffer to trigger timestamp notification
//
// We generate the same ancillary data regardless of the specific WriteFlags,
// we verify that the WriteFlags are observed as expected below.
union {
uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
struct cmsghdr cmsg;
......@@ -2615,9 +2621,9 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
msgCallback.resetData(std::move(ctrl));
// write()
// write(), including flags
std::vector<uint8_t> buf(128, 'a');
socket->write(buf.data(), buf.size());
socket->write(buf.data(), buf.size(), writeFlags);
// read()
std::vector<uint8_t> readbuf(buf.size());
......@@ -2625,11 +2631,29 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
EXPECT_EQ(bytesRead, buf.size());
EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
writeCallback.checkForTimestampNotifications();
// should receive three timestamps (schedule, TX/SND, ACK)
// may take some time for all to arrive, so loop to wait
//
// socket error queue does not have the equivalent of an EOF, so we must
// loop on it unless we want to use libevent for this test...
const std::vector<int32_t> timestampsExpected = {
SCM_TSTAMP_SCHED, SCM_TSTAMP_SND, SCM_TSTAMP_ACK};
std::vector<int32_t> timestampsReceived;
while (timestampsExpected.size() != timestampsReceived.size()) {
const auto timestamps = writeCallback.getTimestampNotifications();
timestampsReceived.insert(
timestampsReceived.end(), timestamps.begin(), timestamps.end());
}
EXPECT_THAT(timestampsReceived, ElementsAreArray(timestampsExpected));
// check the observed write flags
EXPECT_EQ(
static_cast<std::underlying_type<folly::WriteFlags>::type>(
msgCallback.getObservedWriteFlags()),
static_cast<std::underlying_type<folly::WriteFlags>::type>(writeFlags));
// close()
socket->close();
cerr << "SendMsgDataCallback test completed" << endl;
}
#endif // FOLLY_HAVE_MSG_ERRQUEUE
......
......@@ -57,6 +57,12 @@ class SendMsgParamsCallbackBase
socket_ = socket;
oldCallback_ = socket_->getSendMsgParamsCB();
socket_->setSendMsgParamCB(this);
socket_->setEorTracking(trackEor_);
}
void setEorTracking(bool track) {
CHECK(!socket_); // should only be called during setup
trackEor_ = track;
}
int getFlagsImpl(
......@@ -74,6 +80,7 @@ class SendMsgParamsCallbackBase
}
std::shared_ptr<AsyncSSLSocket> socket_;
bool trackEor_{false};
folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
};
......@@ -98,15 +105,29 @@ class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
int flags_{0};
};
class SendMsgDataCallback : public SendMsgFlagsCallback {
class SendMsgAncillaryDataCallback : public SendMsgParamsCallbackBase {
public:
SendMsgDataCallback() {}
SendMsgAncillaryDataCallback() {}
/**
* This data will be returned on calls to getAncillaryData.
*/
void resetData(std::vector<char>&& data) {
ancillaryData_.swap(data);
}
/**
* These flags were observed on the last call to getAncillaryData.
*/
folly::WriteFlags getObservedWriteFlags() {
return observedWriteFlags_;
}
void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
// getAncillaryData is called through a long chain of functions after send
// record the observed write flags so we can compare later
observedWriteFlags_ = flags;
if (ancillaryData_.size()) {
std::cerr << "getAncillaryData: copying data" << std::endl;
memcpy(data, ancillaryData_.data(), ancillaryData_.size());
......@@ -124,6 +145,7 @@ class SendMsgDataCallback : public SendMsgFlagsCallback {
}
}
folly::WriteFlags observedWriteFlags_{};
std::vector<char> ancillaryData_;
};
......@@ -205,8 +227,6 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
~WriteCheckTimestampCallback() override {
EXPECT_EQ(STATE_SUCCEEDED, state);
EXPECT_TRUE(gotTimestamp_);
EXPECT_TRUE(gotByteSeq_);
}
void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) override {
......@@ -220,7 +240,7 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
EXPECT_EQ(ret, 0);
}
void checkForTimestampNotifications() noexcept {
std::vector<int32_t> getTimestampNotifications() noexcept {
auto fd = socket_->getNetworkSocket();
std::vector<char> ctrl(1024, 0);
unsigned char data;
......@@ -235,6 +255,11 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
msg.msg_control = ctrl.data();
msg.msg_controllen = ctrl.size();
std::vector<int32_t> timestampsFound;
folly::Optional<int32_t> timestampType;
bool gotTimestamp = false;
bool gotByteSeq = false;
int ret;
while (true) {
ret = netops::recvmsg(fd, &msg, MSG_ERRQUEUE);
......@@ -249,7 +274,7 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
errnoCopy);
exception = ex;
}
return;
return timestampsFound;
}
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
......@@ -257,21 +282,39 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SCM_TIMESTAMPING) {
gotTimestamp_ = true;
continue;
CHECK(!gotTimestamp); // shouldn't already be set
gotTimestamp = true;
}
if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
(cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
gotByteSeq_ = true;
const struct cmsghdr& cmsgh = *cmsg;
const auto serr = reinterpret_cast<const struct sock_extended_err*>(
CMSG_DATA(&cmsgh));
if (serr->ee_errno != ENOMSG ||
serr->ee_origin != SO_EE_ORIGIN_TIMESTAMPING) {
// not a timestamp
continue;
}
CHECK(!timestampType); // shouldn't already be set
CHECK(!gotByteSeq); // shouldn't already be set
gotByteSeq = true;
timestampType = serr->ee_info;
}
// check if we have both a timestamp and byte sequence
if (gotTimestamp && gotByteSeq) {
timestampsFound.push_back(*timestampType);
timestampType = folly::none;
gotTimestamp = false;
gotByteSeq = false;
}
}
} // for(...)
} // while(true)
bool gotTimestamp_{false};
bool gotByteSeq_{false};
return timestampsFound;
}
};
#endif // FOLLY_HAVE_MSG_ERRQUEUE
......@@ -312,10 +355,16 @@ class ReadCallbackBase : public AsyncTransportWrapper::ReadCallback {
StateEnum state;
};
/**
* ReadCallback reads data from the socket and then writes it back.
*
* It includes any folly::WriteFlags set via setWriteFlags(...) in its write
* back operation.
*/
class ReadCallback : public ReadCallbackBase {
public:
explicit ReadCallback(WriteCallbackBase* wcb)
: ReadCallbackBase(wcb), buffers() {}
: ReadCallbackBase(wcb), buffers(), writeFlags(folly::WriteFlags::NONE) {}
~ReadCallback() override {
for (std::vector<Buffer>::iterator it = buffers.begin();
......@@ -342,13 +391,20 @@ class ReadCallback : public ReadCallbackBase {
wcb_->setSocket(socket_);
// Write back the same data.
socket_->write(wcb_, currentBuffer.buffer, len);
socket_->write(wcb_, currentBuffer.buffer, len, writeFlags);
buffers.push_back(currentBuffer);
currentBuffer.reset();
state = STATE_SUCCEEDED;
}
/**
* These flags will be used when writing the read data back to the socket.
*/
void setWriteFlags(folly::WriteFlags flags) {
writeFlags = flags;
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
......@@ -374,6 +430,7 @@ class ReadCallback : public ReadCallbackBase {
std::vector<Buffer> buffers;
Buffer currentBuffer;
folly::WriteFlags writeFlags;
};
class ReadErrorCallback : public ReadCallbackBase {
......
......@@ -45,6 +45,10 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
sock_->enableTFO();
}
void setEorTracking(bool track) {
sock_->setEorTracking(track);
}
void setAddress(folly::SocketAddress address) {
address_ = address;
}
......@@ -65,8 +69,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
sock_->closeWithReset();
}
int32_t write(uint8_t const* buf, size_t len) {
sock_->write(this, buf, len);
int32_t write(
uint8_t const* buf,
size_t len,
folly::WriteFlags flags = folly::WriteFlags::NONE) {
sock_->write(this, buf, len, flags);
eventBase_.loop();
if (err_.hasValue()) {
throw err_.value();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment