Commit 399c4fa5 authored by Andre Pinto's avatar Andre Pinto Committed by Facebook Github Bot

Add getTotalBufferedBytes() method fo AsyncTransport

Summary:
When BufferCallback::onEgressBuffered() is called, applications usually want to
know how many bytes were buffered so that can take actions (such as avoiding
that connection, failing fast, etc).
This diff adds getTotalPendingBytes() to AsyncTransport to allow applications
to easily query how many bytes were actually buffered.

Reviewed By: spalamarchuk

Differential Revision: D13252677

fbshipit-source-id: 8ec203f6764e00f52d471321afb549376397eb84
parent 2e5a8ccb
......@@ -1069,22 +1069,24 @@ std::unique_ptr<IOBuf> IOBuf::takeOwnershipIov(
return result;
}
size_t IOBuf::fillIov(struct iovec* iov, size_t len) const {
IOBuf::FillIovResult IOBuf::fillIov(struct iovec* iov, size_t len) const {
IOBuf const* p = this;
size_t i = 0;
size_t totalBytes = 0;
while (i < len) {
// some code can get confused by empty iovs, so skip them
if (p->length() > 0) {
iov[i].iov_base = const_cast<uint8_t*>(p->data());
iov[i].iov_len = p->length();
totalBytes += p->length();
i++;
}
p = p->next();
if (p == this) {
return i;
return {i, totalBytes};
}
}
return 0;
return {0, 0};
}
size_t IOBufHash::operator()(const IOBuf& buf) const noexcept {
......
......@@ -1244,17 +1244,25 @@ class IOBuf {
*/
void appendToIov(folly::fbvector<struct iovec>* iov) const;
struct FillIovResult {
// How many iovecs were filled (or 0 on error).
size_t numIovecs;
// The total length of filled iovecs (or 0 on error).
size_t totalLength;
};
/**
* Fill an iovec array with the IOBuf data.
*
* Returns the number of iovec filled. If there are more buffer than
* iovec, returns 0. This version is suitable to use with stack iovec
* arrays.
* Returns a struct with two fields: the number of iovec filled, and total
* size of the iovecs filled. If there are more buffer than iovec, returns 0
* in both fields.
* This version is suitable to use with stack iovec arrays.
*
* Naturally, the filled iovec data will be invalid if you modify the
* buffer chain.
*/
size_t fillIov(struct iovec* iov, size_t len) const;
FillIovResult fillIov(struct iovec* iov, size_t len) const;
/**
* A helper that wraps a number of iovecs into an IOBuf chain. If count == 0,
......
......@@ -358,6 +358,7 @@ void AsyncSocket::init() {
wShutdownSocketSet_.reset();
appBytesWritten_ = 0;
appBytesReceived_ = 0;
totalAppBytesScheduledForWrite_ = 0;
sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
}
......@@ -1008,7 +1009,7 @@ void AsyncSocket::write(
iovec op;
op.iov_base = const_cast<void*>(buf);
op.iov_len = bytes;
writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), bytes, flags);
}
void AsyncSocket::writev(
......@@ -1016,7 +1017,11 @@ void AsyncSocket::writev(
const iovec* vec,
size_t count,
WriteFlags flags) {
writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
size_t totalBytes = 0;
for (size_t i = 0; i < count; ++i) {
totalBytes += vec[i].iov_len;
}
writeImpl(callback, vec, count, unique_ptr<IOBuf>(), totalBytes, flags);
}
void AsyncSocket::writeChain(
......@@ -1048,8 +1053,9 @@ void AsyncSocket::writeChainImpl(
size_t count,
unique_ptr<IOBuf>&& buf,
WriteFlags flags) {
size_t veclen = buf->fillIov(vec, count);
writeImpl(callback, vec, veclen, std::move(buf), flags);
auto res = buf->fillIov(vec, count);
writeImpl(
callback, vec, res.numIovecs, std::move(buf), res.totalLength, flags);
}
void AsyncSocket::writeImpl(
......@@ -1057,6 +1063,7 @@ void AsyncSocket::writeImpl(
const iovec* vec,
size_t count,
unique_ptr<IOBuf>&& buf,
size_t totalBytes,
WriteFlags flags) {
VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
<< ", callback=" << callback << ", count=" << count
......@@ -1065,6 +1072,8 @@ void AsyncSocket::writeImpl(
unique_ptr<IOBuf> ioBuf(std::move(buf));
eventBase_->dcheckIsInEventBaseThread();
totalAppBytesScheduledForWrite_ += totalBytes;
if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
// No new writes may be performed after the write side of the socket has
// been shutdown.
......@@ -1118,9 +1127,6 @@ void AsyncSocket::writeImpl(
if (bytesWritten && isZeroCopyRequest(flags)) {
addZeroCopyBuf(ioBuf.get());
}
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
}
if (!connecting()) {
// Writes might put the socket back into connecting state
......@@ -1163,6 +1169,10 @@ void AsyncSocket::writeImpl(
writeReqTail_ = req;
}
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
// Register for write events if are established and not currently
// waiting on write events
if (mustRegister) {
......@@ -2117,10 +2127,10 @@ void AsyncSocket::handleWrite() noexcept {
// We'll continue around the loop, trying to write another request
} else {
// Partial write.
writeReqHead_->consume();
if (bufferCallback_) {
bufferCallback_->onEgressBuffered();
}
writeReqHead_->consume();
// Stop after a partial write; it's highly likely that a subsequent write
// attempt will just return EAGAIN.
//
......@@ -2684,6 +2694,9 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
}
req->destroy();
}
// All pending writes have failed - reset totalAppBytesScheduledForWrite_
totalAppBytesScheduledForWrite_ = appBytesWritten_;
}
void AsyncSocket::invalidState(ConnectCallback* callback) {
......
......@@ -606,6 +606,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
return getAppBytesReceived();
}
size_t getAppBytesBuffered() const override {
return totalAppBytesScheduledForWrite_ - appBytesWritten_;
}
size_t getRawBytesBuffered() const override {
return getAppBytesBuffered();
}
std::chrono::nanoseconds getConnectTime() const {
return connectEndTime_ - connectStartTime_;
}
......@@ -1100,23 +1107,25 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* and queue up any leftover data to send when the socket can
* handle writes again.
*
* @param callback The callback to invoke when the write is completed.
* @param vec Array of buffers to write; this method will make a
* copy of the vector (but not the buffers themselves)
* if the write has to be completed asynchronously.
* @param count Number of elements in vec.
* @param buf The IOBuf that manages the buffers referenced by
* vec, or a pointer to nullptr if the buffers are not
* associated with an IOBuf. Note that ownership of
* the IOBuf is transferred here; upon completion of
* the write, the AsyncSocket deletes the IOBuf.
* @param flags Set of write flags.
* @param callback The callback to invoke when the write is completed.
* @param vec Array of buffers to write; this method will make a
* copy of the vector (but not the buffers themselves)
* if the write has to be completed asynchronously.
* @param count Number of elements in vec.
* @param buf The IOBuf that manages the buffers referenced by
* vec, or a pointer to nullptr if the buffers are not
* associated with an IOBuf. Note that ownership of
* the IOBuf is transferred here; upon completion of
* the write, the AsyncSocket deletes the IOBuf.
* @param totalBytes The total number of bytes to be written.
* @param flags Set of write flags.
*/
void writeImpl(
WriteCallback* callback,
const iovec* vec,
size_t count,
std::unique_ptr<folly::IOBuf>&& buf,
size_t totalBytes,
WriteFlags flags = WriteFlags::NONE);
/**
......@@ -1263,6 +1272,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
std::weak_ptr<ShutdownSocketSet> wShutdownSocketSet_;
size_t appBytesReceived_; ///< Num of bytes received from socket
size_t appBytesWritten_; ///< Num of bytes written to socket
// The total num of bytes passed to AsyncSocket's write functions. It doesn't
// include failed writes, but it does include buffered writes.
size_t totalAppBytesScheduledForWrite_;
// Pre-received data, to be returned to read callback before any data from the
// socket.
......
......@@ -434,10 +434,34 @@ class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
virtual size_t getAppBytesReceived() const = 0;
virtual size_t getRawBytesReceived() const = 0;
/**
* Calculates the total number of bytes that are currently buffered in the
* transport to be written later.
*/
virtual size_t getAppBytesBuffered() const {
return 0;
}
virtual size_t getRawBytesBuffered() const {
return 0;
}
/**
* Callback class to signal changes in the transport's internal buffers.
*/
class BufferCallback {
public:
virtual ~BufferCallback() {}
virtual ~BufferCallback() = default;
/**
* onEgressBuffered() will be invoked when there's a partial write and it
* is necessary to buffer the remaining data.
*/
virtual void onEgressBuffered() = 0;
/**
* onEgressBufferCleared() will be invoked when whatever was buffered is
* written, or when it errors out.
*/
virtual void onEgressBufferCleared() = 0;
};
......
......@@ -240,7 +240,7 @@ ssize_t AsyncUDPSocket::writeGSO(
// buffers less than 16, which is the highest I can think of
// for a real use case.
iovec vec[16];
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0]));
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])).numIovecs;
if (UNLIKELY(iovec_len == 0)) {
buf->coalesce();
vec[0].iov_base = const_cast<uint8_t*>(buf->data());
......
......@@ -183,13 +183,25 @@ class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
class BufferCallback : public folly::AsyncTransport::BufferCallback {
public:
BufferCallback() : buffered_(false), bufferCleared_(false) {}
BufferCallback(folly::AsyncSocket* socket, size_t expectedBytes)
: socket_(socket),
expectedBytes_(expectedBytes),
buffered_(false),
bufferCleared_(false) {}
void onEgressBuffered() override {
size_t bytesWritten = socket_->getAppBytesWritten();
size_t bytesBuffered = socket_->getAppBytesBuffered();
CHECK_GT(bytesBuffered, 0);
CHECK_EQ(expectedBytes_, bytesWritten + bytesBuffered);
buffered_ = true;
}
void onEgressBufferCleared() override {
size_t bytesWritten = socket_->getAppBytesWritten();
size_t bytesBuffered = socket_->getAppBytesBuffered();
CHECK_EQ(0, bytesBuffered);
CHECK_EQ(expectedBytes_, bytesWritten);
bufferCleared_ = true;
}
......@@ -202,6 +214,8 @@ class BufferCallback : public folly::AsyncTransport::BufferCallback {
}
private:
folly::AsyncSocket* socket_{nullptr};
size_t expectedBytes_{0};
bool buffered_{false};
bool bufferCleared_{false};
};
......
......@@ -2337,7 +2337,7 @@ TEST(AsyncSocketTest, BufferTest) {
char buf[100 * 1024];
memset(buf, 'c', sizeof(buf));
WriteCallback wcb;
BufferCallback bcb;
BufferCallback bcb(socket.get(), sizeof(buf));
socket->setBufferCallback(&bcb);
socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
......@@ -2355,6 +2355,46 @@ TEST(AsyncSocketTest, BufferTest) {
ASSERT_FALSE(socket->isClosedByPeer());
}
TEST(AsyncSocketTest, BufferTestChain) {
TestServer server;
EventBase evb;
AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30, option);
char buf1[100 * 1024];
memset(buf1, 'c', sizeof(buf1));
char buf2[100 * 1024];
memset(buf2, 'f', sizeof(buf2));
auto buf = folly::IOBuf::copyBuffer(buf1, sizeof(buf1));
buf->appendChain(folly::IOBuf::copyBuffer(buf2, sizeof(buf2)));
ASSERT_EQ(sizeof(buf1) + sizeof(buf2), buf->computeChainDataLength());
BufferCallback bcb(socket.get(), buf->computeChainDataLength());
socket->setBufferCallback(&bcb);
WriteCallback wcb;
socket->writeChain(&wcb, buf->clone(), WriteFlags::NONE);
evb.loop();
ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
ASSERT_TRUE(bcb.hasBuffered());
ASSERT_TRUE(bcb.hasBufferCleared());
socket->close();
buf->coalesce();
server.verifyConnection(
reinterpret_cast<const char*>(buf->data()), buf->length());
ASSERT_TRUE(socket->isClosedBySelf());
ASSERT_FALSE(socket->isClosedByPeer());
}
TEST(AsyncSocketTest, BufferCallbackKill) {
TestServer server;
EventBase evb;
......@@ -2366,7 +2406,7 @@ TEST(AsyncSocketTest, BufferCallbackKill) {
char buf[100 * 1024];
memset(buf, 'c', sizeof(buf));
BufferCallback bcb;
BufferCallback bcb(socket.get(), sizeof(buf));
socket->setBufferCallback(&bcb);
WriteCallback wcb;
wcb.successCallback = [&] {
......
......@@ -294,7 +294,8 @@ class ConnectedWriteUDPClient : public UDPClient {
// msg. This will test that connect worked.
void writePing(std::unique_ptr<folly::IOBuf> buf) override {
iovec vec[16];
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0]));
size_t iovec_len =
buf->fillIov(vec, sizeof(vec) / sizeof(vec[0])).numIovecs;
if (UNLIKELY(iovec_len == 0)) {
buf->coalesce();
vec[0].iov_base = const_cast<uint8_t*>(buf->data());
......
......@@ -1500,3 +1500,53 @@ TEST(IOBuf, CloneCoalescedSingle) {
EXPECT_EQ(b->data(), c->data());
EXPECT_EQ(b->length(), c->length());
}
TEST(IOBuf, fillIov) {
auto buf = IOBuf::create(4096);
append(buf, "hello");
auto buf2 = IOBuf::create(4096);
append(buf2, "goodbye");
auto buf3 = IOBuf::create(4096);
append(buf3, "hello again");
buf2->appendChain(std::move(buf3));
buf->appendChain(std::move(buf2));
constexpr size_t iovCount = 3;
struct iovec vec[iovCount];
auto res = buf->fillIov(vec, iovCount);
EXPECT_EQ(iovCount, res.numIovecs);
EXPECT_EQ(23, res.totalLength);
EXPECT_EQ(
"hello",
std::string(
reinterpret_cast<const char*>(vec[0].iov_base), vec[0].iov_len));
EXPECT_EQ(
"goodbye",
std::string(
reinterpret_cast<const char*>(vec[1].iov_base), vec[1].iov_len));
EXPECT_EQ(
"hello again",
std::string(
reinterpret_cast<const char*>(vec[2].iov_base), vec[2].iov_len));
}
TEST(IOBuf, fillIov2) {
auto buf = IOBuf::create(4096);
append(buf, "hello");
auto buf2 = IOBuf::create(4096);
append(buf2, "goodbye");
auto buf3 = IOBuf::create(4096);
append(buf2, "hello again");
buf2->appendChain(std::move(buf3));
buf->appendChain(std::move(buf2));
constexpr size_t iovCount = 2;
struct iovec vec[iovCount];
auto res = buf->fillIov(vec, iovCount);
EXPECT_EQ(0, res.numIovecs);
EXPECT_EQ(0, res.totalLength);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment