Commit d3489f9e authored by Dan Melnic's avatar Dan Melnic Committed by Facebook GitHub Bot

Allow recycling of std::unique_ptr<IOBuf>

Summary: Allow recycling of std::unique_ptr<IOBuf>

Reviewed By: yfeldblum

Differential Revision: D24650191

fbshipit-source-id: 85d219052cc62c35098085abb3eed6cfe00beefc
parent 4087512f
...@@ -120,6 +120,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { ...@@ -120,6 +120,9 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
} }
void destroy() override { void destroy() override {
if (ioBuf_ && releaseIOBufCallback_) {
releaseIOBufCallback_->releaseIOBuf(std::move(ioBuf_));
}
this->~BytesWriteRequest(); this->~BytesWriteRequest();
free(this); free(this);
} }
...@@ -138,7 +141,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { ...@@ -138,7 +141,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
if (bytesWritten_) { if (bytesWritten_) {
if (socket_->isZeroCopyRequest(writeFlags)) { if (socket_->isZeroCopyRequest(writeFlags)) {
if (isComplete()) { if (isComplete()) {
socket_->addZeroCopyBuf(std::move(ioBuf_)); socket_->addZeroCopyBuf(std::move(ioBuf_), releaseIOBufCallback_);
} else { } else {
socket_->addZeroCopyBuf(ioBuf_.get()); socket_->addZeroCopyBuf(ioBuf_.get());
} }
...@@ -147,7 +150,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { ...@@ -147,7 +150,7 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
// with zero copy but not the last one // with zero copy but not the last one
if (isComplete() && zeroCopyRequest_ && if (isComplete() && zeroCopyRequest_ &&
socket_->containsZeroCopyBuf(ioBuf_.get())) { socket_->containsZeroCopyBuf(ioBuf_.get())) {
socket_->setZeroCopyBuf(std::move(ioBuf_)); socket_->setZeroCopyBuf(std::move(ioBuf_), releaseIOBufCallback_);
} }
} }
} }
...@@ -172,7 +175,11 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { ...@@ -172,7 +175,11 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
if (ioBuf_) { if (ioBuf_) {
for (uint32_t i = opsWritten_; i != 0; --i) { for (uint32_t i = opsWritten_; i != 0; --i) {
assert(ioBuf_); assert(ioBuf_);
ioBuf_ = ioBuf_->pop(); auto next = ioBuf_->pop();
if (releaseIOBufCallback_) {
releaseIOBufCallback_->releaseIOBuf(std::move(ioBuf_));
}
ioBuf_ = std::move(next);
} }
} }
} }
...@@ -968,7 +975,9 @@ void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) { ...@@ -968,7 +975,9 @@ void AsyncSocket::adjustZeroCopyFlags(folly::WriteFlags& flags) {
} }
} }
void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) { void AsyncSocket::addZeroCopyBuf(
std::unique_ptr<folly::IOBuf>&& buf,
ReleaseIOBufCallback* cb) {
uint32_t id = getNextZeroCopyBufId(); uint32_t id = getNextZeroCopyBufId();
folly::IOBuf* ptr = buf.get(); folly::IOBuf* ptr = buf.get();
...@@ -977,6 +986,7 @@ void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) { ...@@ -977,6 +986,7 @@ void AsyncSocket::addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) {
p.count_++; p.count_++;
CHECK(p.buf_.get() == nullptr); CHECK(p.buf_.get() == nullptr);
p.buf_ = std::move(buf); p.buf_ = std::move(buf);
p.cb_ = cb;
} }
void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) { void AsyncSocket::addZeroCopyBuf(folly::IOBuf* ptr) {
...@@ -993,18 +1003,24 @@ void AsyncSocket::releaseZeroCopyBuf(uint32_t id) { ...@@ -993,18 +1003,24 @@ void AsyncSocket::releaseZeroCopyBuf(uint32_t id) {
auto iter1 = idZeroCopyBufInfoMap_.find(ptr); auto iter1 = idZeroCopyBufInfoMap_.find(ptr);
CHECK(iter1 != idZeroCopyBufInfoMap_.end()); CHECK(iter1 != idZeroCopyBufInfoMap_.end());
if (0 == --iter1->second.count_) { if (0 == --iter1->second.count_) {
if (iter1->second.cb_) {
iter1->second.cb_->releaseIOBuf(std::move(iter1->second.buf_));
}
idZeroCopyBufInfoMap_.erase(iter1); idZeroCopyBufInfoMap_.erase(iter1);
} }
idZeroCopyBufPtrMap_.erase(iter); idZeroCopyBufPtrMap_.erase(iter);
} }
void AsyncSocket::setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf) { void AsyncSocket::setZeroCopyBuf(
std::unique_ptr<folly::IOBuf>&& buf,
ReleaseIOBufCallback* cb) {
folly::IOBuf* ptr = buf.get(); folly::IOBuf* ptr = buf.get();
auto& p = idZeroCopyBufInfoMap_[ptr]; auto& p = idZeroCopyBufInfoMap_[ptr];
CHECK(p.buf_.get() == nullptr); CHECK(p.buf_.get() == nullptr);
p.buf_ = std::move(buf); p.buf_ = std::move(buf);
p.cb_ = cb;
} }
bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) { bool AsyncSocket::containsZeroCopyBuf(folly::IOBuf* ptr) {
...@@ -1123,6 +1139,15 @@ void AsyncSocket::writeImpl( ...@@ -1123,6 +1139,15 @@ void AsyncSocket::writeImpl(
unique_ptr<IOBuf> ioBuf(std::move(buf)); unique_ptr<IOBuf> ioBuf(std::move(buf));
eventBase_->dcheckIsInEventBaseThread(); eventBase_->dcheckIsInEventBaseThread();
auto* releaseIOBufCallback =
callback ? callback->getReleaseIOBufCallback() : nullptr;
SCOPE_EXIT {
if (ioBuf && releaseIOBufCallback) {
releaseIOBufCallback->releaseIOBuf(std::move(ioBuf));
}
};
totalAppBytesScheduledForWrite_ += totalBytes; totalAppBytesScheduledForWrite_ += totalBytes;
if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) { if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
...@@ -1165,10 +1190,14 @@ void AsyncSocket::writeImpl( ...@@ -1165,10 +1190,14 @@ void AsyncSocket::writeImpl(
} else if (countWritten == count) { } else if (countWritten == count) {
// done, add the whole buffer // done, add the whole buffer
if (countWritten && isZeroCopyRequest(flags)) { if (countWritten && isZeroCopyRequest(flags)) {
addZeroCopyBuf(std::move(ioBuf)); addZeroCopyBuf(std::move(ioBuf), releaseIOBufCallback);
} else {
if (releaseIOBufCallback) {
releaseIOBufCallback->releaseIOBuf(std::move(ioBuf));
} else { } else {
ioBuf.reset(); ioBuf.reset();
} }
}
// We successfully wrote everything. // We successfully wrote everything.
// Invoke the callback and return. // Invoke the callback and return.
......
...@@ -895,7 +895,10 @@ class AsyncSocket : public AsyncTransport { ...@@ -895,7 +895,10 @@ class AsyncSocket : public AsyncTransport {
class WriteRequest { class WriteRequest {
public: public:
WriteRequest(AsyncSocket* socket, WriteCallback* callback) WriteRequest(AsyncSocket* socket, WriteCallback* callback)
: socket_(socket), callback_(callback) {} : socket_(socket),
callback_(callback),
releaseIOBufCallback_(
callback ? callback->getReleaseIOBufCallback() : nullptr) {}
virtual void start() {} virtual void start() {}
...@@ -934,6 +937,7 @@ class AsyncSocket : public AsyncTransport { ...@@ -934,6 +937,7 @@ class AsyncSocket : public AsyncTransport {
AsyncSocket* socket_; ///< parent socket AsyncSocket* socket_; ///< parent socket
WriteRequest* next_{nullptr}; ///< pointer to next WriteRequest WriteRequest* next_{nullptr}; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback WriteCallback* callback_; ///< completion callback
ReleaseIOBufCallback* releaseIOBufCallback_; ///< release IOBuf callback
uint32_t totalBytesWritten_{0}; ///< total bytes written uint32_t totalBytesWritten_{0}; ///< total bytes written
}; };
...@@ -1291,9 +1295,13 @@ class AsyncSocket : public AsyncTransport { ...@@ -1291,9 +1295,13 @@ class AsyncSocket : public AsyncTransport {
uint32_t getNextZeroCopyBufId() { return zeroCopyBufId_++; } uint32_t getNextZeroCopyBufId() { return zeroCopyBufId_++; }
void adjustZeroCopyFlags(folly::WriteFlags& flags); void adjustZeroCopyFlags(folly::WriteFlags& flags);
void addZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf); void addZeroCopyBuf(
std::unique_ptr<folly::IOBuf>&& buf,
ReleaseIOBufCallback* cb);
void addZeroCopyBuf(folly::IOBuf* ptr); void addZeroCopyBuf(folly::IOBuf* ptr);
void setZeroCopyBuf(std::unique_ptr<folly::IOBuf>&& buf); void setZeroCopyBuf(
std::unique_ptr<folly::IOBuf>&& buf,
ReleaseIOBufCallback* cb);
bool containsZeroCopyBuf(folly::IOBuf* ptr); bool containsZeroCopyBuf(folly::IOBuf* ptr);
void releaseZeroCopyBuf(uint32_t id); void releaseZeroCopyBuf(uint32_t id);
...@@ -1307,6 +1315,7 @@ class AsyncSocket : public AsyncTransport { ...@@ -1307,6 +1315,7 @@ class AsyncSocket : public AsyncTransport {
struct IOBufInfo { struct IOBufInfo {
uint32_t count_{0}; uint32_t count_{0};
ReleaseIOBufCallback* cb_{nullptr};
std::unique_ptr<folly::IOBuf> buf_; std::unique_ptr<folly::IOBuf> buf_;
}; };
......
...@@ -272,6 +272,13 @@ class AsyncReader { ...@@ -272,6 +272,13 @@ class AsyncReader {
class AsyncWriter { class AsyncWriter {
public: public:
class ReleaseIOBufCallback {
public:
virtual ~ReleaseIOBufCallback() = default;
virtual void releaseIOBuf(std::unique_ptr<folly::IOBuf>) noexcept = 0;
};
class WriteCallback { class WriteCallback {
public: public:
virtual ~WriteCallback() = default; virtual ~WriteCallback() = default;
...@@ -298,6 +305,10 @@ class AsyncWriter { ...@@ -298,6 +305,10 @@ class AsyncWriter {
virtual void writeErr( virtual void writeErr(
size_t bytesWritten, size_t bytesWritten,
const AsyncSocketException& ex) noexcept = 0; const AsyncSocketException& ex) noexcept = 0;
virtual ReleaseIOBufCallback* getReleaseIOBufCallback() noexcept {
return nullptr;
}
}; };
/** /**
......
...@@ -55,12 +55,16 @@ class ConnCallback : public folly::AsyncSocket::ConnectCallback { ...@@ -55,12 +55,16 @@ class ConnCallback : public folly::AsyncSocket::ConnectCallback {
VoidCallback errorCallback; VoidCallback errorCallback;
}; };
class WriteCallback : public folly::AsyncTransport::WriteCallback { class WriteCallback : public folly::AsyncTransport::WriteCallback,
public folly::AsyncWriter::ReleaseIOBufCallback {
public: public:
WriteCallback() explicit WriteCallback(bool enableReleaseIOBufCallback = false)
: state(STATE_WAITING), : state(STATE_WAITING),
bytesWritten(0), bytesWritten(0),
exception(folly::AsyncSocketException::UNKNOWN, "none") {} numIoBufCount(0),
numIoBufBytes(0),
exception(folly::AsyncSocketException::UNKNOWN, "none"),
releaseIOBufCallback(enableReleaseIOBufCallback ? this : nullptr) {}
void writeSuccess() noexcept override { void writeSuccess() noexcept override {
state = STATE_SUCCEEDED; state = STATE_SUCCEEDED;
...@@ -81,11 +85,24 @@ class WriteCallback : public folly::AsyncTransport::WriteCallback { ...@@ -81,11 +85,24 @@ class WriteCallback : public folly::AsyncTransport::WriteCallback {
} }
} }
folly::AsyncWriter::ReleaseIOBufCallback*
getReleaseIOBufCallback() noexcept override {
return releaseIOBufCallback;
}
void releaseIOBuf(std::unique_ptr<folly::IOBuf> ioBuf) noexcept override {
numIoBufCount += ioBuf->countChainElements();
numIoBufBytes += ioBuf->computeChainDataLength();
}
StateEnum state; StateEnum state;
std::atomic<size_t> bytesWritten; std::atomic<size_t> bytesWritten;
std::atomic<size_t> numIoBufCount;
std::atomic<size_t> numIoBufBytes;
folly::AsyncSocketException exception; folly::AsyncSocketException exception;
VoidCallback successCallback; VoidCallback successCallback;
VoidCallback errorCallback; VoidCallback errorCallback;
ReleaseIOBufCallback* releaseIOBufCallback;
}; };
class ReadCallback : public folly::AsyncTransport::ReadCallback { class ReadCallback : public folly::AsyncTransport::ReadCallback {
......
...@@ -219,8 +219,9 @@ TEST_P(AsyncSocketConnectTest, ConnectAndWrite) { ...@@ -219,8 +219,9 @@ TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
// write() // write()
char buf[128]; char buf[128];
memset(buf, 'a', sizeof(buf)); memset(buf, 'a', sizeof(buf));
WriteCallback wcb; WriteCallback wcb(true /*enableReleaseIOBufCallback*/);
socket->write(&wcb, buf, sizeof(buf)); // use writeChain so we can pass an IOBuf
socket->writeChain(&wcb, IOBuf::copyBuffer(buf, sizeof(buf)));
// Loop. We don't bother accepting on the server socket yet. // Loop. We don't bother accepting on the server socket yet.
// The kernel should be able to buffer the write request so it can succeed. // The kernel should be able to buffer the write request so it can succeed.
...@@ -228,6 +229,8 @@ TEST_P(AsyncSocketConnectTest, ConnectAndWrite) { ...@@ -228,6 +229,8 @@ TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
ASSERT_EQ(ccb.state, STATE_SUCCEEDED); ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
ASSERT_EQ(wcb.state, STATE_SUCCEEDED); ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
ASSERT_EQ(wcb.numIoBufCount, 1);
ASSERT_EQ(wcb.numIoBufBytes, sizeof(buf));
// Make sure the server got a connection and received the data // Make sure the server got a connection and received the data
socket->close(); socket->close();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment