Commit 399c4fa5 authored by Andre Pinto's avatar Andre Pinto Committed by Facebook Github Bot

Add getTotalBufferedBytes() method fo AsyncTransport

Summary:
When BufferCallback::onEgressBuffered() is called, applications usually want to
know how many bytes were buffered so that can take actions (such as avoiding
that connection, failing fast, etc).
This diff adds getTotalPendingBytes() to AsyncTransport to allow applications
to easily query how many bytes were actually buffered.

Reviewed By: spalamarchuk

Differential Revision: D13252677

fbshipit-source-id: 8ec203f6764e00f52d471321afb549376397eb84
parent 2e5a8ccb
...@@ -1069,22 +1069,24 @@ std::unique_ptr<IOBuf> IOBuf::takeOwnershipIov( ...@@ -1069,22 +1069,24 @@ std::unique_ptr<IOBuf> IOBuf::takeOwnershipIov(
return result; return result;
} }
size_t IOBuf::fillIov(struct iovec* iov, size_t len) const { IOBuf::FillIovResult IOBuf::fillIov(struct iovec* iov, size_t len) const {
IOBuf const* p = this; IOBuf const* p = this;
size_t i = 0; size_t i = 0;
size_t totalBytes = 0;
while (i < len) { while (i < len) {
// some code can get confused by empty iovs, so skip them // some code can get confused by empty iovs, so skip them
if (p->length() > 0) { if (p->length() > 0) {
iov[i].iov_base = const_cast<uint8_t*>(p->data()); iov[i].iov_base = const_cast<uint8_t*>(p->data());
iov[i].iov_len = p->length(); iov[i].iov_len = p->length();
totalBytes += p->length();
i++; i++;
} }
p = p->next(); p = p->next();
if (p == this) { if (p == this) {
return i; return {i, totalBytes};
} }
} }
return 0; return {0, 0};
} }
size_t IOBufHash::operator()(const IOBuf& buf) const noexcept { size_t IOBufHash::operator()(const IOBuf& buf) const noexcept {
......
...@@ -1244,17 +1244,25 @@ class IOBuf { ...@@ -1244,17 +1244,25 @@ class IOBuf {
*/ */
void appendToIov(folly::fbvector<struct iovec>* iov) const; void appendToIov(folly::fbvector<struct iovec>* iov) const;
struct FillIovResult {
// How many iovecs were filled (or 0 on error).
size_t numIovecs;
// The total length of filled iovecs (or 0 on error).
size_t totalLength;
};
/** /**
* Fill an iovec array with the IOBuf data. * Fill an iovec array with the IOBuf data.
* *
* Returns the number of iovec filled. If there are more buffer than * Returns a struct with two fields: the number of iovec filled, and total
* iovec, returns 0. This version is suitable to use with stack iovec * size of the iovecs filled. If there are more buffer than iovec, returns 0
* arrays. * in both fields.
* This version is suitable to use with stack iovec arrays.
* *
* Naturally, the filled iovec data will be invalid if you modify the * Naturally, the filled iovec data will be invalid if you modify the
* buffer chain. * buffer chain.
*/ */
size_t fillIov(struct iovec* iov, size_t len) const; FillIovResult fillIov(struct iovec* iov, size_t len) const;
/** /**
* A helper that wraps a number of iovecs into an IOBuf chain. If count == 0, * A helper that wraps a number of iovecs into an IOBuf chain. If count == 0,
......
...@@ -358,6 +358,7 @@ void AsyncSocket::init() { ...@@ -358,6 +358,7 @@ void AsyncSocket::init() {
wShutdownSocketSet_.reset(); wShutdownSocketSet_.reset();
appBytesWritten_ = 0; appBytesWritten_ = 0;
appBytesReceived_ = 0; appBytesReceived_ = 0;
totalAppBytesScheduledForWrite_ = 0;
sendMsgParamCallback_ = &defaultSendMsgParamsCallback; sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
} }
...@@ -1008,7 +1009,7 @@ void AsyncSocket::write( ...@@ -1008,7 +1009,7 @@ void AsyncSocket::write(
iovec op; iovec op;
op.iov_base = const_cast<void*>(buf); op.iov_base = const_cast<void*>(buf);
op.iov_len = bytes; op.iov_len = bytes;
writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags); writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), bytes, flags);
} }
void AsyncSocket::writev( void AsyncSocket::writev(
...@@ -1016,7 +1017,11 @@ void AsyncSocket::writev( ...@@ -1016,7 +1017,11 @@ void AsyncSocket::writev(
const iovec* vec, const iovec* vec,
size_t count, size_t count,
WriteFlags flags) { WriteFlags flags) {
writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags); size_t totalBytes = 0;
for (size_t i = 0; i < count; ++i) {
totalBytes += vec[i].iov_len;
}
writeImpl(callback, vec, count, unique_ptr<IOBuf>(), totalBytes, flags);
} }
void AsyncSocket::writeChain( void AsyncSocket::writeChain(
...@@ -1048,8 +1053,9 @@ void AsyncSocket::writeChainImpl( ...@@ -1048,8 +1053,9 @@ void AsyncSocket::writeChainImpl(
size_t count, size_t count,
unique_ptr<IOBuf>&& buf, unique_ptr<IOBuf>&& buf,
WriteFlags flags) { WriteFlags flags) {
size_t veclen = buf->fillIov(vec, count); auto res = buf->fillIov(vec, count);
writeImpl(callback, vec, veclen, std::move(buf), flags); writeImpl(
callback, vec, res.numIovecs, std::move(buf), res.totalLength, flags);
} }
void AsyncSocket::writeImpl( void AsyncSocket::writeImpl(
...@@ -1057,6 +1063,7 @@ void AsyncSocket::writeImpl( ...@@ -1057,6 +1063,7 @@ void AsyncSocket::writeImpl(
const iovec* vec, const iovec* vec,
size_t count, size_t count,
unique_ptr<IOBuf>&& buf, unique_ptr<IOBuf>&& buf,
size_t totalBytes,
WriteFlags flags) { WriteFlags flags) {
VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_ VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", count=" << count << ", callback=" << callback << ", count=" << count
...@@ -1065,6 +1072,8 @@ void AsyncSocket::writeImpl( ...@@ -1065,6 +1072,8 @@ void AsyncSocket::writeImpl(
unique_ptr<IOBuf> ioBuf(std::move(buf)); unique_ptr<IOBuf> ioBuf(std::move(buf));
eventBase_->dcheckIsInEventBaseThread(); eventBase_->dcheckIsInEventBaseThread();
totalAppBytesScheduledForWrite_ += totalBytes;
if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) { if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
// No new writes may be performed after the write side of the socket has // No new writes may be performed after the write side of the socket has
// been shutdown. // been shutdown.
...@@ -1118,9 +1127,6 @@ void AsyncSocket::writeImpl( ...@@ -1118,9 +1127,6 @@ void AsyncSocket::writeImpl(
if (bytesWritten && isZeroCopyRequest(flags)) { if (bytesWritten && isZeroCopyRequest(flags)) {
addZeroCopyBuf(ioBuf.get()); addZeroCopyBuf(ioBuf.get());
} }
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
} }
if (!connecting()) { if (!connecting()) {
// Writes might put the socket back into connecting state // Writes might put the socket back into connecting state
...@@ -1163,6 +1169,10 @@ void AsyncSocket::writeImpl( ...@@ -1163,6 +1169,10 @@ void AsyncSocket::writeImpl(
writeReqTail_ = req; writeReqTail_ = req;
} }
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
// Register for write events if are established and not currently // Register for write events if are established and not currently
// waiting on write events // waiting on write events
if (mustRegister) { if (mustRegister) {
...@@ -2117,10 +2127,10 @@ void AsyncSocket::handleWrite() noexcept { ...@@ -2117,10 +2127,10 @@ void AsyncSocket::handleWrite() noexcept {
// We'll continue around the loop, trying to write another request // We'll continue around the loop, trying to write another request
} else { } else {
// Partial write. // Partial write.
writeReqHead_->consume();
if (bufferCallback_) { if (bufferCallback_) {
bufferCallback_->onEgressBuffered(); bufferCallback_->onEgressBuffered();
} }
writeReqHead_->consume();
// Stop after a partial write; it's highly likely that a subsequent write // Stop after a partial write; it's highly likely that a subsequent write
// attempt will just return EAGAIN. // attempt will just return EAGAIN.
// //
...@@ -2684,6 +2694,9 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) { ...@@ -2684,6 +2694,9 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
} }
req->destroy(); req->destroy();
} }
// All pending writes have failed - reset totalAppBytesScheduledForWrite_
totalAppBytesScheduledForWrite_ = appBytesWritten_;
} }
void AsyncSocket::invalidState(ConnectCallback* callback) { void AsyncSocket::invalidState(ConnectCallback* callback) {
......
...@@ -606,6 +606,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -606,6 +606,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
return getAppBytesReceived(); return getAppBytesReceived();
} }
size_t getAppBytesBuffered() const override {
return totalAppBytesScheduledForWrite_ - appBytesWritten_;
}
size_t getRawBytesBuffered() const override {
return getAppBytesBuffered();
}
std::chrono::nanoseconds getConnectTime() const { std::chrono::nanoseconds getConnectTime() const {
return connectEndTime_ - connectStartTime_; return connectEndTime_ - connectStartTime_;
} }
...@@ -1110,6 +1117,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1110,6 +1117,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* associated with an IOBuf. Note that ownership of * associated with an IOBuf. Note that ownership of
* the IOBuf is transferred here; upon completion of * the IOBuf is transferred here; upon completion of
* the write, the AsyncSocket deletes the IOBuf. * the write, the AsyncSocket deletes the IOBuf.
* @param totalBytes The total number of bytes to be written.
* @param flags Set of write flags. * @param flags Set of write flags.
*/ */
void writeImpl( void writeImpl(
...@@ -1117,6 +1125,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1117,6 +1125,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
const iovec* vec, const iovec* vec,
size_t count, size_t count,
std::unique_ptr<folly::IOBuf>&& buf, std::unique_ptr<folly::IOBuf>&& buf,
size_t totalBytes,
WriteFlags flags = WriteFlags::NONE); WriteFlags flags = WriteFlags::NONE);
/** /**
...@@ -1263,6 +1272,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1263,6 +1272,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_; std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_;
size_t appBytesReceived_; ///< Num of bytes received from socket size_t appBytesReceived_; ///< Num of bytes received from socket
size_t appBytesWritten_; ///< Num of bytes written to socket size_t appBytesWritten_; ///< Num of bytes written to socket
// The total num of bytes passed to AsyncSocket's write functions. It doesn't
// include failed writes, but it does include buffered writes.
size_t totalAppBytesScheduledForWrite_;
// Pre-received data, to be returned to read callback before any data from the // Pre-received data, to be returned to read callback before any data from the
// socket. // socket.
......
...@@ -434,10 +434,34 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase { ...@@ -434,10 +434,34 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
virtual size_t getAppBytesReceived() const = 0; virtual size_t getAppBytesReceived() const = 0;
virtual size_t getRawBytesReceived() const = 0; virtual size_t getRawBytesReceived() const = 0;
/**
* Calculates the total number of bytes that are currently buffered in the
* transport to be written later.
*/
virtual size_t getAppBytesBuffered() const {
return 0;
}
virtual size_t getRawBytesBuffered() const {
return 0;
}
/**
* Callback class to signal changes in the transport's internal buffers.
*/
class BufferCallback { class BufferCallback {
public: public:
virtual ~BufferCallback() {} virtual ~BufferCallback() = default;
/**
* onEgressBuffered() will be invoked when there's a partial write and it
* is necessary to buffer the remaining data.
*/
virtual void onEgressBuffered() = 0; virtual void onEgressBuffered() = 0;
/**
* onEgressBufferCleared() will be invoked when whatever was buffered is
* written, or when it errors out.
*/
virtual void onEgressBufferCleared() = 0; virtual void onEgressBufferCleared() = 0;
}; };
......
...@@ -240,7 +240,7 @@ ssize_t AsyncUDPSocket::writeGSO( ...@@ -240,7 +240,7 @@ ssize_t AsyncUDPSocket::writeGSO(
// buffers less than 16, which is the highest I can think of // buffers less than 16, which is the highest I can think of
// for a real use case. // for a real use case.
iovec vec[16]; iovec vec[16];
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])); size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])).numIovecs;
if (UNLIKELY(iovec_len == 0)) { if (UNLIKELY(iovec_len == 0)) {
buf->coalesce(); buf->coalesce();
vec[0].iov_base = const_cast<uint8_t*>(buf->data()); vec[0].iov_base = const_cast<uint8_t*>(buf->data());
......
...@@ -183,13 +183,25 @@ class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback { ...@@ -183,13 +183,25 @@ class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
class BufferCallback : public folly::AsyncTransport::BufferCallback { class BufferCallback : public folly::AsyncTransport::BufferCallback {
public: public:
BufferCallback() : buffered_(false), bufferCleared_(false) {} BufferCallback(folly::AsyncSocket* socket, size_t expectedBytes)
: socket_(socket),
expectedBytes_(expectedBytes),
buffered_(false),
bufferCleared_(false) {}
void onEgressBuffered() override { void onEgressBuffered() override {
size_t bytesWritten = socket_->getAppBytesWritten();
size_t bytesBuffered = socket_->getAppBytesBuffered();
CHECK_GT(bytesBuffered, 0);
CHECK_EQ(expectedBytes_, bytesWritten + bytesBuffered);
buffered_ = true; buffered_ = true;
} }
void onEgressBufferCleared() override { void onEgressBufferCleared() override {
size_t bytesWritten = socket_->getAppBytesWritten();
size_t bytesBuffered = socket_->getAppBytesBuffered();
CHECK_EQ(0, bytesBuffered);
CHECK_EQ(expectedBytes_, bytesWritten);
bufferCleared_ = true; bufferCleared_ = true;
} }
...@@ -202,6 +214,8 @@ class BufferCallback : public folly::AsyncTransport::BufferCallback { ...@@ -202,6 +214,8 @@ class BufferCallback : public folly::AsyncTransport::BufferCallback {
} }
private: private:
folly::AsyncSocket* socket_{nullptr};
size_t expectedBytes_{0};
bool buffered_{false}; bool buffered_{false};
bool bufferCleared_{false}; bool bufferCleared_{false};
}; };
......
...@@ -2337,7 +2337,7 @@ TEST(AsyncSocketTest, BufferTest) { ...@@ -2337,7 +2337,7 @@ TEST(AsyncSocketTest, BufferTest) {
char buf[100 * 1024]; char buf[100 * 1024];
memset(buf, 'c', sizeof(buf)); memset(buf, 'c', sizeof(buf));
WriteCallback wcb; WriteCallback wcb;
BufferCallback bcb; BufferCallback bcb(socket.get(), sizeof(buf));
socket->setBufferCallback(&bcb); socket->setBufferCallback(&bcb);
socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE); socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
...@@ -2355,6 +2355,46 @@ TEST(AsyncSocketTest, BufferTest) { ...@@ -2355,6 +2355,46 @@ TEST(AsyncSocketTest, BufferTest) {
ASSERT_FALSE(socket->isClosedByPeer()); ASSERT_FALSE(socket->isClosedByPeer());
} }
TEST(AsyncSocketTest, BufferTestChain) {
TestServer server;
EventBase evb;
AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30, option);
char buf1[100 * 1024];
memset(buf1, 'c', sizeof(buf1));
char buf2[100 * 1024];
memset(buf2, 'f', sizeof(buf2));
auto buf = folly::IOBuf::copyBuffer(buf1, sizeof(buf1));
buf->appendChain(folly::IOBuf::copyBuffer(buf2, sizeof(buf2)));
ASSERT_EQ(sizeof(buf1) + sizeof(buf2), buf->computeChainDataLength());
BufferCallback bcb(socket.get(), buf->computeChainDataLength());
socket->setBufferCallback(&bcb);
WriteCallback wcb;
socket->writeChain(&wcb, buf->clone(), WriteFlags::NONE);
evb.loop();
ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
ASSERT_TRUE(bcb.hasBuffered());
ASSERT_TRUE(bcb.hasBufferCleared());
socket->close();
buf->coalesce();
server.verifyConnection(
reinterpret_cast<const char*>(buf->data()), buf->length());
ASSERT_TRUE(socket->isClosedBySelf());
ASSERT_FALSE(socket->isClosedByPeer());
}
TEST(AsyncSocketTest, BufferCallbackKill) { TEST(AsyncSocketTest, BufferCallbackKill) {
TestServer server; TestServer server;
EventBase evb; EventBase evb;
...@@ -2366,7 +2406,7 @@ TEST(AsyncSocketTest, BufferCallbackKill) { ...@@ -2366,7 +2406,7 @@ TEST(AsyncSocketTest, BufferCallbackKill) {
char buf[100 * 1024]; char buf[100 * 1024];
memset(buf, 'c', sizeof(buf)); memset(buf, 'c', sizeof(buf));
BufferCallback bcb; BufferCallback bcb(socket.get(), sizeof(buf));
socket->setBufferCallback(&bcb); socket->setBufferCallback(&bcb);
WriteCallback wcb; WriteCallback wcb;
wcb.successCallback = [&] { wcb.successCallback = [&] {
......
...@@ -294,7 +294,8 @@ class ConnectedWriteUDPClient : public UDPClient { ...@@ -294,7 +294,8 @@ class ConnectedWriteUDPClient : public UDPClient {
// msg. This will test that connect worked. // msg. This will test that connect worked.
void writePing(std::unique_ptr<folly::IOBuf> buf) override { void writePing(std::unique_ptr<folly::IOBuf> buf) override {
iovec vec[16]; iovec vec[16];
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])); size_t iovec_len =
buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])).numIovecs;
if (UNLIKELY(iovec_len == 0)) { if (UNLIKELY(iovec_len == 0)) {
buf->coalesce(); buf->coalesce();
vec[0].iov_base = const_cast<uint8_t*>(buf->data()); vec[0].iov_base = const_cast<uint8_t*>(buf->data());
......
...@@ -1500,3 +1500,53 @@ TEST(IOBuf, CloneCoalescedSingle) { ...@@ -1500,3 +1500,53 @@ TEST(IOBuf, CloneCoalescedSingle) {
EXPECT_EQ(b->data(), c->data()); EXPECT_EQ(b->data(), c->data());
EXPECT_EQ(b->length(), c->length()); EXPECT_EQ(b->length(), c->length());
} }
TEST(IOBuf, fillIov) {
auto buf = IOBuf::create(4096);
append(buf, "hello");
auto buf2 = IOBuf::create(4096);
append(buf2, "goodbye");
auto buf3 = IOBuf::create(4096);
append(buf3, "hello again");
buf2->appendChain(std::move(buf3));
buf->appendChain(std::move(buf2));
constexpr size_t iovCount = 3;
struct iovec vec[iovCount];
auto res = buf->fillIov(vec, iovCount);
EXPECT_EQ(iovCount, res.numIovecs);
EXPECT_EQ(23, res.totalLength);
EXPECT_EQ(
"hello",
std::string(
reinterpret_cast<const char*>(vec[0].iov_base), vec[0].iov_len));
EXPECT_EQ(
"goodbye",
std::string(
reinterpret_cast<const char*>(vec[1].iov_base), vec[1].iov_len));
EXPECT_EQ(
"hello again",
std::string(
reinterpret_cast<const char*>(vec[2].iov_base), vec[2].iov_len));
}
TEST(IOBuf, fillIov2) {
auto buf = IOBuf::create(4096);
append(buf, "hello");
auto buf2 = IOBuf::create(4096);
append(buf2, "goodbye");
auto buf3 = IOBuf::create(4096);
append(buf2, "hello again");
buf2->appendChain(std::move(buf3));
buf->appendChain(std::move(buf2));
constexpr size_t iovCount = 2;
struct iovec vec[iovCount];
auto res = buf->fillIov(vec, iovCount);
EXPECT_EQ(0, res.numIovecs);
EXPECT_EQ(0, res.totalLength);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment