Commit eccdeae9 authored by Brandon Schlinker's avatar Brandon Schlinker Committed by Facebook Github Bot

Support TIMESTAMP_TX flag

Summary:
The `TIMESTAMP_TX` flag can be used to signal a request TX / NIC timestamping. This flag needs to be passed through to the application via the `SendMsgParamsCallback::getAncillaryData` callback so that the application can populate a socket control message with the ancillary data required to signal the timestamping request.

`folly::AsyncSSLSocket` has extra logic for tracking the end of record (EOR) byte, but this logic currently only passes through `WriteFlags::EOR`. This diff adds support for passing through any flag specified as an EOR-related write flag when the EOR byte is written. As part of this change, the relevant unit tests are extended / cleaned up (some of them still reference MSG_EOR, which is out of date).

Reviewed By: yfeldblum

Differential Revision: D15465432

fbshipit-source-id: 2ab5619607959dd829427a695aefd95a33b4abce
parent 1308f879
...@@ -386,6 +386,7 @@ void AsyncSSLSocket::setEorTracking(bool track) { ...@@ -386,6 +386,7 @@ void AsyncSSLSocket::setEorTracking(bool track) {
if (isEorTrackingEnabled() != track) { if (isEorTrackingEnabled() != track) {
AsyncSocket::setEorTracking(track); AsyncSocket::setEorTracking(track);
appEorByteNo_ = 0; appEorByteNo_ = 0;
appEorByteWriteFlags_ = {};
minEorRawByteNo_ = 0; minEorRawByteNo_ = 0;
} }
} }
...@@ -1536,13 +1537,18 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite( ...@@ -1536,13 +1537,18 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
} }
} }
// cork the current write if the original flags included CORK or if there
// are remaining iovec to write
corkCurrentWrite_ = corkCurrentWrite_ =
isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count); isSet(flags, WriteFlags::CORK) || (i + buffersStolen + 1 < count);
bytes = eorAwareSSLWrite(
ssl_, // track the EoR if:
sslWriteBuf, // (1) there are write flags that require EoR tracking (EOR / TIMESTAMP_TX)
int(len), // (2) if the buffer includes the EOR byte
(isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count)); appEorByteWriteFlags_ = flags & kEorRelevantWriteFlags;
bool trackEor = appEorByteWriteFlags_ != folly::WriteFlags::NONE &&
(i + buffersStolen + 1 == count);
bytes = eorAwareSSLWrite(ssl_, sslWriteBuf, int(len), trackEor);
if (bytes <= 0) { if (bytes <= 0) {
int error = SSL_get_error(ssl_.get(), int(bytes)); int error = SSL_get_error(ssl_.get(), int(bytes));
...@@ -1609,6 +1615,7 @@ int AsyncSSLSocket::eorAwareSSLWrite( ...@@ -1609,6 +1615,7 @@ int AsyncSSLSocket::eorAwareSSLWrite(
} }
if (appBytesWritten_ == appEorByteNo_) { if (appBytesWritten_ == appEorByteNo_) {
appEorByteNo_ = 0; appEorByteNo_ = 0;
appEorByteWriteFlags_ = {};
} else { } else {
CHECK(appBytesWritten_ < appEorByteNo_); CHECK(appBytesWritten_ < appEorByteNo_);
} }
...@@ -1658,7 +1665,7 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) { ...@@ -1658,7 +1665,7 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
WriteFlags flags = WriteFlags::NONE; WriteFlags flags = WriteFlags::NONE;
if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ && if (tsslSock->isEorTrackingEnabled() && tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) { tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags |= WriteFlags::EOR; flags |= tsslSock->appEorByteWriteFlags_;
} }
if (tsslSock->corkCurrentWrite_) { if (tsslSock->corkCurrentWrite_) {
......
...@@ -894,16 +894,31 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -894,16 +894,31 @@ class AsyncSSLSocket : public virtual AsyncSocket {
Timeout handshakeTimeout_; Timeout handshakeTimeout_;
Timeout connectionTimeout_; Timeout connectionTimeout_;
// The app byte num that we are tracking for the MSG_EOR // The app byte num that we are tracking for EOR.
//
// Only one app EOR byte can be tracked. // Only one app EOR byte can be tracked.
// See appEorByteWriteFlags_ for details.
size_t appEorByteNo_{0}; size_t appEorByteNo_{0};
// The WriteFlags to pass for the app byte num that is tracked for EOR.
//
// When openssl is about to send appEorByteNo_, these flags will be passed to
// the application via the getAncillaryData callback. The application can then
// generate a control message containing socket timestamping flags or other
// commands that will be included when the corresponding buffer is passed to
// the kernel via sendmsg().
//
// See AsyncSSLSocket::bioWrite (which overrides OpenSSL biowrite).
WriteFlags appEorByteWriteFlags_{};
// Try to avoid calling SSL_write() for buffers smaller than this. // Try to avoid calling SSL_write() for buffers smaller than this.
// It doesn't take effect when it is 0. // It doesn't take effect when it is 0.
size_t minWriteSize_{1500}; size_t minWriteSize_{1500};
// When openssl is about to sendmsg() across the minEorRawBytesNo_, // When openssl is about to sendmsg() across the minEorRawBytesNo_,
// it will pass MSG_EOR to sendmsg(). // it will trigger logic to include an application defined control message.
//
// See appEorByteWriteFlags_ for details.
size_t minEorRawByteNo_{0}; size_t minEorRawByteNo_{0};
#if FOLLY_OPENSSL_HAS_SNI #if FOLLY_OPENSSL_HAS_SNI
std::shared_ptr<folly::SSLContext> handshakeCtx_; std::shared_ptr<folly::SSLContext> handshakeCtx_;
......
...@@ -146,10 +146,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -146,10 +146,15 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
/** /**
* getFlags() will be invoked to retrieve the desired flags to be passed * getFlags() will be invoked to retrieve the desired flags to be passed
* to ::sendmsg() system call. This method was intentionally declared * to ::sendmsg() system call. It is responsible for converting flags set in
* non-virtual, so there is no way to override it. Instead feel free to * the passed folly::WriteFlags enum into a integer flag bitmask that can be
* override getFlagsImpl(flags, defaultFlags) method instead, and enjoy * passed to ::sendmsg. Some flags in folly::WriteFlags do not correspond to
* the convenience of defaultFlags passed there. * flags that can be passed to ::sendmsg and may instead be handled via
* getAncillaryData.
*
* This method was intentionally declared non-virtual, so there is no way to
* override it. Instead feel free to override getFlagsImpl(...) instead, and
* enjoy the convenience of defaultFlags passed there.
* *
* @param flags Write flags requested for the given write operation * @param flags Write flags requested for the given write operation
*/ */
...@@ -160,7 +165,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -160,7 +165,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
/** /**
* getAncillaryData() will be invoked to initialize ancillary data * getAncillaryData() will be invoked to initialize ancillary data
* buffer referred by "msg_control" field of msghdr structure passed to * buffer referred by "msg_control" field of msghdr structure passed to
* ::sendmsg() system call. The function assumes that the size of buffer * ::sendmsg() system call based on the flags set in the passed
* folly::WriteFlags enum. Some flags in folly::WriteFlags are not relevant
* during this process. The function assumes that the size of buffer
* is not smaller than the value returned by getAncillaryDataSize() method * is not smaller than the value returned by getAncillaryDataSize() method
* for the same combination of flags. * for the same combination of flags.
* *
......
...@@ -78,7 +78,7 @@ enum class WriteFlags : uint32_t { ...@@ -78,7 +78,7 @@ enum class WriteFlags : uint32_t {
/* /*
* union operator * union operator
*/ */
inline WriteFlags operator|(WriteFlags a, WriteFlags b) { constexpr WriteFlags operator|(WriteFlags a, WriteFlags b) {
return static_cast<WriteFlags>( return static_cast<WriteFlags>(
static_cast<uint32_t>(a) | static_cast<uint32_t>(b)); static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
} }
...@@ -86,7 +86,7 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) { ...@@ -86,7 +86,7 @@ inline WriteFlags operator|(WriteFlags a, WriteFlags b) {
/* /*
* compound assignment union operator * compound assignment union operator
*/ */
inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) { constexpr WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
a = a | b; a = a | b;
return a; return a;
} }
...@@ -94,7 +94,7 @@ inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) { ...@@ -94,7 +94,7 @@ inline WriteFlags& operator|=(WriteFlags& a, WriteFlags b) {
/* /*
* intersection operator * intersection operator
*/ */
inline WriteFlags operator&(WriteFlags a, WriteFlags b) { constexpr WriteFlags operator&(WriteFlags a, WriteFlags b) {
return static_cast<WriteFlags>( return static_cast<WriteFlags>(
static_cast<uint32_t>(a) & static_cast<uint32_t>(b)); static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
} }
...@@ -102,7 +102,7 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) { ...@@ -102,7 +102,7 @@ inline WriteFlags operator&(WriteFlags a, WriteFlags b) {
/* /*
* compound assignment intersection operator * compound assignment intersection operator
*/ */
inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) { constexpr WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
a = a & b; a = a & b;
return a; return a;
} }
...@@ -110,24 +110,37 @@ inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) { ...@@ -110,24 +110,37 @@ inline WriteFlags& operator&=(WriteFlags& a, WriteFlags b) {
/* /*
* exclusion parameter * exclusion parameter
*/ */
inline WriteFlags operator~(WriteFlags a) { constexpr WriteFlags operator~(WriteFlags a) {
return static_cast<WriteFlags>(~static_cast<uint32_t>(a)); return static_cast<WriteFlags>(~static_cast<uint32_t>(a));
} }
/* /*
* unset operator * unset operator
*/ */
inline WriteFlags unSet(WriteFlags a, WriteFlags b) { constexpr WriteFlags unSet(WriteFlags a, WriteFlags b) {
return a & ~b; return a & ~b;
} }
/* /*
* inclusion operator * inclusion operator
*/ */
inline bool isSet(WriteFlags a, WriteFlags b) { constexpr bool isSet(WriteFlags a, WriteFlags b) {
return (a & b) == b; return (a & b) == b;
} }
/**
* Write flags that are specifically for the final write call of a buffer.
*
* In some cases, buffers passed to send may be coalesced or split by the socket
* write handling logic. For instance, a buffer passed to AsyncSSLSocket may be
* split across multiple TLS records (and therefore multiple calls to write).
*
* When a buffer is split up, these flags will only be applied for the final
* call to write for that buffer.
*/
constexpr WriteFlags kEorRelevantWriteFlags =
WriteFlags::EOR | WriteFlags::TIMESTAMP_TX;
/** /**
* AsyncTransport defines an asynchronous API for streaming I/O. * AsyncTransport defines an asynchronous API for streaming I/O.
* *
......
...@@ -2561,7 +2561,7 @@ TEST(AsyncSSLSocketTest, SendMsgParamsCallback) { ...@@ -2561,7 +2561,7 @@ TEST(AsyncSSLSocketTest, SendMsgParamsCallback) {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE #ifdef FOLLY_HAVE_MSG_ERRQUEUE
/** /**
* Test connecting to, writing to, reading from, and closing the * Test connecting to, writing to, reading from, and closing the
* connection to the SSL server. * connection to the SSL server with ancillary data from the application.
*/ */
TEST(AsyncSSLSocketTest, SendMsgDataCallback) { TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
// This test requires Linux kernel v4.6 or later // This test requires Linux kernel v4.6 or later
...@@ -2580,7 +2580,7 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { ...@@ -2580,7 +2580,7 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
} }
// Start listening on a local port // Start listening on a local port
SendMsgDataCallback msgCallback; SendMsgAncillaryDataCallback msgCallback;
WriteCheckTimestampCallback writeCallback(&msgCallback); WriteCheckTimestampCallback writeCallback(&msgCallback);
ReadCallback readCallback(&writeCallback); ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback); HandshakeCallback handshakeCallback(&readCallback);
...@@ -2596,11 +2596,17 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { ...@@ -2596,11 +2596,17 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
std::make_shared<BlockingSocket>(server.getAddress(), sslContext); std::make_shared<BlockingSocket>(server.getAddress(), sslContext);
socket->open(); socket->open();
// Adding MSG_EOR flag to the message flags - it'll trigger // we'll pass the EOR and TIMESTAMP_TX flags with the write back
// timestamp generation for the last byte of the message. // EOR tracking must be enabled for WriteFlags be passed
msgCallback.resetFlags(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_EOR); const auto writeFlags =
folly::WriteFlags::EOR | folly::WriteFlags::TIMESTAMP_TX;
readCallback.setWriteFlags(writeFlags);
msgCallback.setEorTracking(true);
// Init ancillary data buffer to trigger timestamp notification // Init ancillary data buffer to trigger timestamp notification
//
// We generate the same ancillary data regardless of the specific WriteFlags,
// we verify that the WriteFlags are observed as expected below.
union { union {
uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))]; uint8_t ctrl_data[CMSG_LEN(sizeof(uint32_t))];
struct cmsghdr cmsg; struct cmsghdr cmsg;
...@@ -2615,9 +2621,9 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { ...@@ -2615,9 +2621,9 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t))); memcpy(ctrl.data(), u.ctrl_data, CMSG_LEN(sizeof(uint32_t)));
msgCallback.resetData(std::move(ctrl)); msgCallback.resetData(std::move(ctrl));
// write() // write(), including flags
std::vector<uint8_t> buf(128, 'a'); std::vector<uint8_t> buf(128, 'a');
socket->write(buf.data(), buf.size()); socket->write(buf.data(), buf.size(), writeFlags);
// read() // read()
std::vector<uint8_t> readbuf(buf.size()); std::vector<uint8_t> readbuf(buf.size());
...@@ -2625,11 +2631,29 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) { ...@@ -2625,11 +2631,29 @@ TEST(AsyncSSLSocketTest, SendMsgDataCallback) {
EXPECT_EQ(bytesRead, buf.size()); EXPECT_EQ(bytesRead, buf.size());
EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin())); EXPECT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
writeCallback.checkForTimestampNotifications(); // should receive three timestamps (schedule, TX/SND, ACK)
// may take some time for all to arrive, so loop to wait
//
// socket error queue does not have the equivalent of an EOF, so we must
// loop on it unless we want to use libevent for this test...
const std::vector<int32_t> timestampsExpected = {
SCM_TSTAMP_SCHED, SCM_TSTAMP_SND, SCM_TSTAMP_ACK};
std::vector<int32_t> timestampsReceived;
while (timestampsExpected.size() != timestampsReceived.size()) {
const auto timestamps = writeCallback.getTimestampNotifications();
timestampsReceived.insert(
timestampsReceived.end(), timestamps.begin(), timestamps.end());
}
EXPECT_THAT(timestampsReceived, ElementsAreArray(timestampsExpected));
// check the observed write flags
EXPECT_EQ(
static_cast<std::underlying_type<folly::WriteFlags>::type>(
msgCallback.getObservedWriteFlags()),
static_cast<std::underlying_type<folly::WriteFlags>::type>(writeFlags));
// close() // close()
socket->close(); socket->close();
cerr << "SendMsgDataCallback test completed" << endl; cerr << "SendMsgDataCallback test completed" << endl;
} }
#endif // FOLLY_HAVE_MSG_ERRQUEUE #endif // FOLLY_HAVE_MSG_ERRQUEUE
......
...@@ -57,6 +57,12 @@ class SendMsgParamsCallbackBase ...@@ -57,6 +57,12 @@ class SendMsgParamsCallbackBase
socket_ = socket; socket_ = socket;
oldCallback_ = socket_->getSendMsgParamsCB(); oldCallback_ = socket_->getSendMsgParamsCB();
socket_->setSendMsgParamCB(this); socket_->setSendMsgParamCB(this);
socket_->setEorTracking(trackEor_);
}
void setEorTracking(bool track) {
CHECK(!socket_); // should only be called during setup
trackEor_ = track;
} }
int getFlagsImpl( int getFlagsImpl(
...@@ -74,6 +80,7 @@ class SendMsgParamsCallbackBase ...@@ -74,6 +80,7 @@ class SendMsgParamsCallbackBase
} }
std::shared_ptr<AsyncSSLSocket> socket_; std::shared_ptr<AsyncSSLSocket> socket_;
bool trackEor_{false};
folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr}; folly::AsyncSocket::SendMsgParamsCallback* oldCallback_{nullptr};
}; };
...@@ -98,15 +105,29 @@ class SendMsgFlagsCallback : public SendMsgParamsCallbackBase { ...@@ -98,15 +105,29 @@ class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
int flags_{0}; int flags_{0};
}; };
class SendMsgDataCallback : public SendMsgFlagsCallback { class SendMsgAncillaryDataCallback : public SendMsgParamsCallbackBase {
public: public:
SendMsgDataCallback() {} SendMsgAncillaryDataCallback() {}
/**
* This data will be returned on calls to getAncillaryData.
*/
void resetData(std::vector<char>&& data) { void resetData(std::vector<char>&& data) {
ancillaryData_.swap(data); ancillaryData_.swap(data);
} }
/**
* These flags were observed on the last call to getAncillaryData.
*/
folly::WriteFlags getObservedWriteFlags() {
return observedWriteFlags_;
}
void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override { void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
// getAncillaryData is called through a long chain of functions after send
// record the observed write flags so we can compare later
observedWriteFlags_ = flags;
if (ancillaryData_.size()) { if (ancillaryData_.size()) {
std::cerr << "getAncillaryData: copying data" << std::endl; std::cerr << "getAncillaryData: copying data" << std::endl;
memcpy(data, ancillaryData_.data(), ancillaryData_.size()); memcpy(data, ancillaryData_.data(), ancillaryData_.size());
...@@ -124,6 +145,7 @@ class SendMsgDataCallback : public SendMsgFlagsCallback { ...@@ -124,6 +145,7 @@ class SendMsgDataCallback : public SendMsgFlagsCallback {
} }
} }
folly::WriteFlags observedWriteFlags_{};
std::vector<char> ancillaryData_; std::vector<char> ancillaryData_;
}; };
...@@ -205,8 +227,6 @@ class WriteCheckTimestampCallback : public WriteCallbackBase { ...@@ -205,8 +227,6 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
~WriteCheckTimestampCallback() override { ~WriteCheckTimestampCallback() override {
EXPECT_EQ(STATE_SUCCEEDED, state); EXPECT_EQ(STATE_SUCCEEDED, state);
EXPECT_TRUE(gotTimestamp_);
EXPECT_TRUE(gotByteSeq_);
} }
void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) override { void setSocket(const std::shared_ptr<AsyncSSLSocket>& socket) override {
...@@ -220,7 +240,7 @@ class WriteCheckTimestampCallback : public WriteCallbackBase { ...@@ -220,7 +240,7 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
EXPECT_EQ(ret, 0); EXPECT_EQ(ret, 0);
} }
void checkForTimestampNotifications() noexcept { std::vector<int32_t> getTimestampNotifications() noexcept {
auto fd = socket_->getNetworkSocket(); auto fd = socket_->getNetworkSocket();
std::vector<char> ctrl(1024, 0); std::vector<char> ctrl(1024, 0);
unsigned char data; unsigned char data;
...@@ -235,6 +255,11 @@ class WriteCheckTimestampCallback : public WriteCallbackBase { ...@@ -235,6 +255,11 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
msg.msg_control = ctrl.data(); msg.msg_control = ctrl.data();
msg.msg_controllen = ctrl.size(); msg.msg_controllen = ctrl.size();
std::vector<int32_t> timestampsFound;
folly::Optional<int32_t> timestampType;
bool gotTimestamp = false;
bool gotByteSeq = false;
int ret; int ret;
while (true) { while (true) {
ret = netops::recvmsg(fd, &msg, MSG_ERRQUEUE); ret = netops::recvmsg(fd, &msg, MSG_ERRQUEUE);
...@@ -249,7 +274,7 @@ class WriteCheckTimestampCallback : public WriteCallbackBase { ...@@ -249,7 +274,7 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
errnoCopy); errnoCopy);
exception = ex; exception = ex;
} }
return; return timestampsFound;
} }
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
...@@ -257,21 +282,39 @@ class WriteCheckTimestampCallback : public WriteCallbackBase { ...@@ -257,21 +282,39 @@ class WriteCheckTimestampCallback : public WriteCallbackBase {
cmsg = CMSG_NXTHDR(&msg, cmsg)) { cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level == SOL_SOCKET && if (cmsg->cmsg_level == SOL_SOCKET &&
cmsg->cmsg_type == SCM_TIMESTAMPING) { cmsg->cmsg_type == SCM_TIMESTAMPING) {
gotTimestamp_ = true; CHECK(!gotTimestamp); // shouldn't already be set
continue; gotTimestamp = true;
} }
if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) || if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) ||
(cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) { (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) {
gotByteSeq_ = true; const struct cmsghdr& cmsgh = *cmsg;
continue; const auto serr = reinterpret_cast<const struct sock_extended_err*>(
CMSG_DATA(&cmsgh));
if (serr->ee_errno != ENOMSG ||
serr->ee_origin != SO_EE_ORIGIN_TIMESTAMPING) {
// not a timestamp
continue;
}
CHECK(!timestampType); // shouldn't already be set
CHECK(!gotByteSeq); // shouldn't already be set
gotByteSeq = true;
timestampType = serr->ee_info;
} }
}
}
}
bool gotTimestamp_{false}; // check if we have both a timestamp and byte sequence
bool gotByteSeq_{false}; if (gotTimestamp && gotByteSeq) {
timestampsFound.push_back(*timestampType);
timestampType = folly::none;
gotTimestamp = false;
gotByteSeq = false;
}
} // for(...)
} // while(true)
return timestampsFound;
}
}; };
#endif // FOLLY_HAVE_MSG_ERRQUEUE #endif // FOLLY_HAVE_MSG_ERRQUEUE
...@@ -312,10 +355,16 @@ class ReadCallbackBase : public AsyncTransportWrapper::ReadCallback { ...@@ -312,10 +355,16 @@ class ReadCallbackBase : public AsyncTransportWrapper::ReadCallback {
StateEnum state; StateEnum state;
}; };
/**
* ReadCallback reads data from the socket and then writes it back.
*
* It includes any folly::WriteFlags set via setWriteFlags(...) in its write
* back operation.
*/
class ReadCallback : public ReadCallbackBase { class ReadCallback : public ReadCallbackBase {
public: public:
explicit ReadCallback(WriteCallbackBase* wcb) explicit ReadCallback(WriteCallbackBase* wcb)
: ReadCallbackBase(wcb), buffers() {} : ReadCallbackBase(wcb), buffers(), writeFlags(folly::WriteFlags::NONE) {}
~ReadCallback() override { ~ReadCallback() override {
for (std::vector<Buffer>::iterator it = buffers.begin(); for (std::vector<Buffer>::iterator it = buffers.begin();
...@@ -342,13 +391,20 @@ class ReadCallback : public ReadCallbackBase { ...@@ -342,13 +391,20 @@ class ReadCallback : public ReadCallbackBase {
wcb_->setSocket(socket_); wcb_->setSocket(socket_);
// Write back the same data. // Write back the same data.
socket_->write(wcb_, currentBuffer.buffer, len); socket_->write(wcb_, currentBuffer.buffer, len, writeFlags);
buffers.push_back(currentBuffer); buffers.push_back(currentBuffer);
currentBuffer.reset(); currentBuffer.reset();
state = STATE_SUCCEEDED; state = STATE_SUCCEEDED;
} }
/**
* These flags will be used when writing the read data back to the socket.
*/
void setWriteFlags(folly::WriteFlags flags) {
writeFlags = flags;
}
class Buffer { class Buffer {
public: public:
Buffer() : buffer(nullptr), length(0) {} Buffer() : buffer(nullptr), length(0) {}
...@@ -374,6 +430,7 @@ class ReadCallback : public ReadCallbackBase { ...@@ -374,6 +430,7 @@ class ReadCallback : public ReadCallbackBase {
std::vector<Buffer> buffers; std::vector<Buffer> buffers;
Buffer currentBuffer; Buffer currentBuffer;
folly::WriteFlags writeFlags;
}; };
class ReadErrorCallback : public ReadCallbackBase { class ReadErrorCallback : public ReadCallbackBase {
......
...@@ -45,6 +45,10 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, ...@@ -45,6 +45,10 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
sock_->enableTFO(); sock_->enableTFO();
} }
void setEorTracking(bool track) {
sock_->setEorTracking(track);
}
void setAddress(folly::SocketAddress address) { void setAddress(folly::SocketAddress address) {
address_ = address; address_ = address;
} }
...@@ -65,8 +69,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, ...@@ -65,8 +69,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
sock_->closeWithReset(); sock_->closeWithReset();
} }
int32_t write(uint8_t const* buf, size_t len) { int32_t write(
sock_->write(this, buf, len); uint8_t const* buf,
size_t len,
folly::WriteFlags flags = folly::WriteFlags::NONE) {
sock_->write(this, buf, len, flags);
eventBase_.loop(); eventBase_.loop();
if (err_.hasValue()) { if (err_.hasValue()) {
throw err_.value(); throw err_.value();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment