Commit 48a8ecdb authored by James Sedgwick's avatar James Sedgwick Committed by Viswanath Sivakumar

make AsyncSocket::WriteRequest an interface

Summary: This will allow a subsequent diff to implement file transfers as another type of write request

Test Plan: unit

Reviewed By: davejwatson@fb.com

Subscribers: net-systems@, folly-diffs@, yfeldblum, chalfant, fugalh, bmatheny

FB internal diff: D2080257

Signature: t1:2080257:1432044566:bcc0724d349879f46e3e58ee672aff7bf37fa5f6
parent e1c97644
...@@ -53,123 +53,168 @@ const AsyncSocketException socketShutdownForWritesEx( ...@@ -53,123 +53,168 @@ const AsyncSocketException socketShutdownForWritesEx(
// the WriteRequest. // the WriteRequest.
/** /**
* A WriteRequest object tracks information about a pending write() or writev() * A WriteRequest object tracks information about a pending write operation.
* operation. */
class AsyncSocket::WriteRequest {
public:
WriteRequest(AsyncSocket* socket,
WriteRequest* next,
WriteCallback* callback,
uint32_t totalBytesWritten) :
socket_(socket), next_(next), callback_(callback),
totalBytesWritten_(totalBytesWritten) {}
virtual void destroy() = 0;
virtual bool performWrite() = 0;
virtual void consume() = 0;
virtual bool isComplete() = 0;
WriteRequest* getNext() const {
return next_;
}
WriteCallback* getCallback() const {
return callback_;
}
uint32_t getTotalBytesWritten() const {
return totalBytesWritten_;
}
void append(WriteRequest* next) {
assert(next_ == nullptr);
next_ = next;
}
protected:
// protected destructor, to ensure callers use destroy()
virtual ~WriteRequest() {}
AsyncSocket* socket_; ///< parent socket
WriteRequest* next_; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t totalBytesWritten_; ///< total bytes written
};
/* The default WriteRequest implementation, used for write(), writev() and
* writeChain()
* *
* A new WriteRequest operation is allocated on the heap for all write * A new BytesWriteRequest operation is allocated on the heap for all write
* operations that cannot be completed immediately. * operations that cannot be completed immediately.
*/ */
class AsyncSocket::WriteRequest { class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
public: public:
static WriteRequest* newRequest(WriteCallback* callback, static BytesWriteRequest* newRequest(AsyncSocket* socket,
WriteCallback* callback,
const iovec* ops, const iovec* ops,
uint32_t opCount, uint32_t opCount,
uint32_t partialWritten,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf, unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags) { WriteFlags flags) {
assert(opCount > 0); assert(opCount > 0);
// Since we put a variable size iovec array at the end // Since we put a variable size iovec array at the end
// of each WriteRequest, we have to manually allocate the memory. // of each BytesWriteRequest, we have to manually allocate the memory.
void* buf = malloc(sizeof(WriteRequest) + void* buf = malloc(sizeof(BytesWriteRequest) +
(opCount * sizeof(struct iovec))); (opCount * sizeof(struct iovec)));
if (buf == nullptr) { if (buf == nullptr) {
throw std::bad_alloc(); throw std::bad_alloc();
} }
return new(buf) WriteRequest(callback, ops, opCount, std::move(ioBuf), return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
flags); partialWritten, bytesWritten,
std::move(ioBuf), flags);
} }
void destroy() { void destroy() override {
this->~WriteRequest(); this->~BytesWriteRequest();
free(this); free(this);
} }
bool cork() const { bool performWrite() override {
return isSet(flags_, WriteFlags::CORK); WriteFlags writeFlags = flags_;
} if (getNext() != nullptr) {
writeFlags = writeFlags | WriteFlags::CORK;
WriteFlags flags() const {
return flags_;
}
WriteRequest* getNext() const {
return next_;
}
WriteCallback* getCallback() const {
return callback_;
}
uint32_t getBytesWritten() const {
return bytesWritten_;
} }
bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags,
const struct iovec* getOps() const { &opsWritten_, &partialBytes_);
assert(opCount_ > opIndex_); return bytesWritten_ >= 0;
return writeOps_ + opIndex_;
} }
uint32_t getOpCount() const { bool isComplete() override {
assert(opCount_ > opIndex_); return opsWritten_ == getOpCount();
return opCount_ - opIndex_;
} }
void consume(uint32_t wholeOps, uint32_t partialBytes, void consume() override {
uint32_t totalBytesWritten) { // Advance opIndex_ forward by opsWritten_
// Advance opIndex_ forward by wholeOps opIndex_ += opsWritten_;
opIndex_ += wholeOps;
assert(opIndex_ < opCount_); assert(opIndex_ < opCount_);
// If we've finished writing any IOBufs, release them // If we've finished writing any IOBufs, release them
if (ioBuf_) { if (ioBuf_) {
for (uint32_t i = wholeOps; i != 0; --i) { for (uint32_t i = opsWritten_; i != 0; --i) {
assert(ioBuf_); assert(ioBuf_);
ioBuf_ = ioBuf_->pop(); ioBuf_ = ioBuf_->pop();
} }
} }
// Move partialBytes forward into the current iovec buffer // Move partialBytes_ forward into the current iovec buffer
struct iovec* currentOp = writeOps_ + opIndex_; struct iovec* currentOp = writeOps_ + opIndex_;
assert((partialBytes < currentOp->iov_len) || (currentOp->iov_len == 0)); assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
currentOp->iov_base = currentOp->iov_base =
reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes; reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
currentOp->iov_len -= partialBytes; currentOp->iov_len -= partialBytes_;
// Increment the bytesWritten_ count by totalBytesWritten // Increment the totalBytesWritten_ count by bytesWritten_;
bytesWritten_ += totalBytesWritten; totalBytesWritten_ += bytesWritten_;
}
void append(WriteRequest* next) {
assert(next_ == nullptr);
next_ = next;
} }
private: private:
WriteRequest(WriteCallback* callback, BytesWriteRequest(AsyncSocket* socket,
WriteCallback* callback,
const struct iovec* ops, const struct iovec* ops,
uint32_t opCount, uint32_t opCount,
uint32_t partialBytes,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf, unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags) WriteFlags flags)
: next_(nullptr) : AsyncSocket::WriteRequest(socket, nullptr, callback, 0)
, callback_(callback)
, bytesWritten_(0)
, opCount_(opCount) , opCount_(opCount)
, opIndex_(0) , opIndex_(0)
, flags_(flags) , flags_(flags)
, ioBuf_(std::move(ioBuf)) { , ioBuf_(std::move(ioBuf))
, opsWritten_(0)
, partialBytes_(partialBytes)
, bytesWritten_(bytesWritten) {
memcpy(writeOps_, ops, sizeof(*ops) * opCount_); memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
} }
// Private destructor, to ensure callers use destroy() // private destructor, to ensure callers use destroy()
~WriteRequest() {} virtual ~BytesWriteRequest() {}
const struct iovec* getOps() const {
assert(opCount_ > opIndex_);
return writeOps_ + opIndex_;
}
uint32_t getOpCount() const {
assert(opCount_ > opIndex_);
return opCount_ - opIndex_;
}
WriteRequest* next_; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t bytesWritten_; ///< bytes written
uint32_t opCount_; ///< number of entries in writeOps_ uint32_t opCount_; ///< number of entries in writeOps_
uint32_t opIndex_; ///< current index into writeOps_ uint32_t opIndex_; ///< current index into writeOps_
WriteFlags flags_; ///< set for WriteFlags WriteFlags flags_; ///< set for WriteFlags
unique_ptr<IOBuf> ioBuf_; ///< underlying IOBuf, or nullptr if N/A unique_ptr<IOBuf> ioBuf_; ///< underlying IOBuf, or nullptr if N/A
// for consume(), how much we wrote on the last write
uint32_t opsWritten_; ///< complete ops written
uint32_t partialBytes_; ///< partial bytes of incomplete op written
ssize_t bytesWritten_; ///< bytes written altogether
struct iovec writeOps_[]; ///< write operation(s) list struct iovec writeOps_[]; ///< write operation(s) list
}; };
...@@ -687,16 +732,16 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, ...@@ -687,16 +732,16 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
// Create a new WriteRequest to add to the queue // Create a new WriteRequest to add to the queue
WriteRequest* req; WriteRequest* req;
try { try {
req = WriteRequest::newRequest(callback, vec + countWritten, req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
count - countWritten, std::move(ioBuf), count - countWritten, partialWritten,
flags); bytesWritten, std::move(ioBuf), flags);
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
// we mainly expect to catch std::bad_alloc here // we mainly expect to catch std::bad_alloc here
AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR, AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
withAddr(string("failed to append new WriteRequest: ") + ex.what())); withAddr(string("failed to append new WriteRequest: ") + ex.what()));
return failWrite(__func__, callback, bytesWritten, tex); return failWrite(__func__, callback, bytesWritten, tex);
} }
req->consume(0, partialWritten, bytesWritten); req->consume();
if (writeReqTail_ == nullptr) { if (writeReqTail_ == nullptr) {
assert(writeReqHead_ == nullptr); assert(writeReqHead_ == nullptr);
writeReqHead_ = writeReqTail_ = req; writeReqHead_ = writeReqTail_ = req;
...@@ -1346,20 +1391,11 @@ void AsyncSocket::handleWrite() noexcept { ...@@ -1346,20 +1391,11 @@ void AsyncSocket::handleWrite() noexcept {
// (See the comment in handleRead() explaining how this can happen.) // (See the comment in handleRead() explaining how this can happen.)
EventBase* originalEventBase = eventBase_; EventBase* originalEventBase = eventBase_;
while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) { while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
uint32_t countWritten; if (!writeReqHead_->performWrite()) {
uint32_t partialWritten;
WriteFlags writeFlags = writeReqHead_->flags();
if (writeReqHead_->getNext() != nullptr) {
writeFlags = writeFlags | WriteFlags::CORK;
}
int bytesWritten = performWrite(writeReqHead_->getOps(),
writeReqHead_->getOpCount(),
writeFlags, &countWritten, &partialWritten);
if (bytesWritten < 0) {
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR, AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
withAddr("writev() failed"), errno); withAddr("writev() failed"), errno);
return failWrite(__func__, ex); return failWrite(__func__, ex);
} else if (countWritten == writeReqHead_->getOpCount()) { } else if (writeReqHead_->isComplete()) {
// We finished this request // We finished this request
WriteRequest* req = writeReqHead_; WriteRequest* req = writeReqHead_;
writeReqHead_ = req->getNext(); writeReqHead_ = req->getNext();
...@@ -1424,7 +1460,7 @@ void AsyncSocket::handleWrite() noexcept { ...@@ -1424,7 +1460,7 @@ void AsyncSocket::handleWrite() noexcept {
// We'll continue around the loop, trying to write another request // We'll continue around the loop, trying to write another request
} else { } else {
// Partial write. // Partial write.
writeReqHead_->consume(countWritten, partialWritten, bytesWritten); writeReqHead_->consume();
// Stop after a partial write; it's highly likely that a subsequent write // Stop after a partial write; it's highly likely that a subsequent write
// attempt will just return EAGAIN. // attempt will just return EAGAIN.
// //
...@@ -1822,7 +1858,7 @@ void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) { ...@@ -1822,7 +1858,7 @@ void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
WriteRequest* req = writeReqHead_; WriteRequest* req = writeReqHead_;
writeReqHead_ = req->getNext(); writeReqHead_ = req->getNext();
WriteCallback* callback = req->getCallback(); WriteCallback* callback = req->getCallback();
uint32_t bytesWritten = req->getBytesWritten(); uint32_t bytesWritten = req->getTotalBytesWritten();
req->destroy(); req->destroy();
if (callback) { if (callback) {
callback->writeErr(bytesWritten, ex); callback->writeErr(bytesWritten, ex);
...@@ -1859,7 +1895,7 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) { ...@@ -1859,7 +1895,7 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
writeReqHead_ = req->getNext(); writeReqHead_ = req->getNext();
WriteCallback* callback = req->getCallback(); WriteCallback* callback = req->getCallback();
if (callback) { if (callback) {
callback->writeErr(req->getBytesWritten(), ex); callback->writeErr(req->getTotalBytesWritten(), ex);
} }
req->destroy(); req->destroy();
} }
......
...@@ -517,6 +517,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -517,6 +517,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
}; };
class WriteRequest; class WriteRequest;
class BytesWriteRequest;
class WriteTimeout : public AsyncTimeout { class WriteTimeout : public AsyncTimeout {
public: public:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment