Commit 48a8ecdb authored by James Sedgwick's avatar James Sedgwick Committed by Viswanath Sivakumar

make AsyncSocket::WriteRequest an interface

Summary: This will allow a subsequent diff to implement file transfers as another type of write request

Test Plan: unit

Reviewed By: davejwatson@fb.com

Subscribers: net-systems@, folly-diffs@, yfeldblum, chalfant, fugalh, bmatheny

FB internal diff: D2080257

Signature: t1:2080257:1432044566:bcc0724d349879f46e3e58ee672aff7bf37fa5f6
parent e1c97644
......@@ -53,44 +53,24 @@ const AsyncSocketException socketShutdownForWritesEx(
// the WriteRequest.
/**
* A WriteRequest object tracks information about a pending write() or writev()
* operation.
*
* A new WriteRequest operation is allocated on the heap for all write
* operations that cannot be completed immediately.
* A WriteRequest object tracks information about a pending write operation.
*/
class AsyncSocket::WriteRequest {
public:
static WriteRequest* newRequest(WriteCallback* callback,
const iovec* ops,
uint32_t opCount,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags) {
assert(opCount > 0);
// Since we put a variable size iovec array at the end
// of each WriteRequest, we have to manually allocate the memory.
void* buf = malloc(sizeof(WriteRequest) +
(opCount * sizeof(struct iovec)));
if (buf == nullptr) {
throw std::bad_alloc();
}
WriteRequest(AsyncSocket* socket,
WriteRequest* next,
WriteCallback* callback,
uint32_t totalBytesWritten) :
socket_(socket), next_(next), callback_(callback),
totalBytesWritten_(totalBytesWritten) {}
return new(buf) WriteRequest(callback, ops, opCount, std::move(ioBuf),
flags);
}
virtual void destroy() = 0;
void destroy() {
this->~WriteRequest();
free(this);
}
virtual bool performWrite() = 0;
bool cork() const {
return isSet(flags_, WriteFlags::CORK);
}
virtual void consume() = 0;
WriteFlags flags() const {
return flags_;
}
virtual bool isComplete() = 0;
WriteRequest* getNext() const {
return next_;
......@@ -100,76 +80,141 @@ class AsyncSocket::WriteRequest {
return callback_;
}
uint32_t getBytesWritten() const {
return bytesWritten_;
uint32_t getTotalBytesWritten() const {
return totalBytesWritten_;
}
const struct iovec* getOps() const {
assert(opCount_ > opIndex_);
return writeOps_ + opIndex_;
void append(WriteRequest* next) {
assert(next_ == nullptr);
next_ = next;
}
uint32_t getOpCount() const {
assert(opCount_ > opIndex_);
return opCount_ - opIndex_;
protected:
// protected destructor, to ensure callers use destroy()
virtual ~WriteRequest() {}
AsyncSocket* socket_; ///< parent socket
WriteRequest* next_; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t totalBytesWritten_; ///< total bytes written
};
/* The default WriteRequest implementation, used for write(), writev() and
* writeChain()
*
* A new BytesWriteRequest operation is allocated on the heap for all write
* operations that cannot be completed immediately.
*/
class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
public:
static BytesWriteRequest* newRequest(AsyncSocket* socket,
WriteCallback* callback,
const iovec* ops,
uint32_t opCount,
uint32_t partialWritten,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags) {
assert(opCount > 0);
// Since we put a variable size iovec array at the end
// of each BytesWriteRequest, we have to manually allocate the memory.
void* buf = malloc(sizeof(BytesWriteRequest) +
(opCount * sizeof(struct iovec)));
if (buf == nullptr) {
throw std::bad_alloc();
}
return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
partialWritten, bytesWritten,
std::move(ioBuf), flags);
}
void consume(uint32_t wholeOps, uint32_t partialBytes,
uint32_t totalBytesWritten) {
// Advance opIndex_ forward by wholeOps
opIndex_ += wholeOps;
void destroy() override {
this->~BytesWriteRequest();
free(this);
}
bool performWrite() override {
WriteFlags writeFlags = flags_;
if (getNext() != nullptr) {
writeFlags = writeFlags | WriteFlags::CORK;
}
bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags,
&opsWritten_, &partialBytes_);
return bytesWritten_ >= 0;
}
bool isComplete() override {
return opsWritten_ == getOpCount();
}
void consume() override {
// Advance opIndex_ forward by opsWritten_
opIndex_ += opsWritten_;
assert(opIndex_ < opCount_);
// If we've finished writing any IOBufs, release them
if (ioBuf_) {
for (uint32_t i = wholeOps; i != 0; --i) {
for (uint32_t i = opsWritten_; i != 0; --i) {
assert(ioBuf_);
ioBuf_ = ioBuf_->pop();
}
}
// Move partialBytes forward into the current iovec buffer
// Move partialBytes_ forward into the current iovec buffer
struct iovec* currentOp = writeOps_ + opIndex_;
assert((partialBytes < currentOp->iov_len) || (currentOp->iov_len == 0));
assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
currentOp->iov_base =
reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes;
currentOp->iov_len -= partialBytes;
reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
currentOp->iov_len -= partialBytes_;
// Increment the bytesWritten_ count by totalBytesWritten
bytesWritten_ += totalBytesWritten;
}
void append(WriteRequest* next) {
assert(next_ == nullptr);
next_ = next;
// Increment the totalBytesWritten_ count by bytesWritten_;
totalBytesWritten_ += bytesWritten_;
}
private:
WriteRequest(WriteCallback* callback,
const struct iovec* ops,
uint32_t opCount,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags)
: next_(nullptr)
, callback_(callback)
, bytesWritten_(0)
BytesWriteRequest(AsyncSocket* socket,
WriteCallback* callback,
const struct iovec* ops,
uint32_t opCount,
uint32_t partialBytes,
uint32_t bytesWritten,
unique_ptr<IOBuf>&& ioBuf,
WriteFlags flags)
: AsyncSocket::WriteRequest(socket, nullptr, callback, 0)
, opCount_(opCount)
, opIndex_(0)
, flags_(flags)
, ioBuf_(std::move(ioBuf)) {
, ioBuf_(std::move(ioBuf))
, opsWritten_(0)
, partialBytes_(partialBytes)
, bytesWritten_(bytesWritten) {
memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
}
// Private destructor, to ensure callers use destroy()
~WriteRequest() {}
// private destructor, to ensure callers use destroy()
virtual ~BytesWriteRequest() {}
const struct iovec* getOps() const {
assert(opCount_ > opIndex_);
return writeOps_ + opIndex_;
}
uint32_t getOpCount() const {
assert(opCount_ > opIndex_);
return opCount_ - opIndex_;
}
WriteRequest* next_; ///< pointer to next WriteRequest
WriteCallback* callback_; ///< completion callback
uint32_t bytesWritten_; ///< bytes written
uint32_t opCount_; ///< number of entries in writeOps_
uint32_t opIndex_; ///< current index into writeOps_
WriteFlags flags_; ///< set for WriteFlags
unique_ptr<IOBuf> ioBuf_; ///< underlying IOBuf, or nullptr if N/A
// for consume(), how much we wrote on the last write
uint32_t opsWritten_; ///< complete ops written
uint32_t partialBytes_; ///< partial bytes of incomplete op written
ssize_t bytesWritten_; ///< bytes written altogether
struct iovec writeOps_[]; ///< write operation(s) list
};
......@@ -687,16 +732,16 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
// Create a new WriteRequest to add to the queue
WriteRequest* req;
try {
req = WriteRequest::newRequest(callback, vec + countWritten,
count - countWritten, std::move(ioBuf),
flags);
req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
count - countWritten, partialWritten,
bytesWritten, std::move(ioBuf), flags);
} catch (const std::exception& ex) {
// we mainly expect to catch std::bad_alloc here
AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
withAddr(string("failed to append new WriteRequest: ") + ex.what()));
return failWrite(__func__, callback, bytesWritten, tex);
}
req->consume(0, partialWritten, bytesWritten);
req->consume();
if (writeReqTail_ == nullptr) {
assert(writeReqHead_ == nullptr);
writeReqHead_ = writeReqTail_ = req;
......@@ -1346,20 +1391,11 @@ void AsyncSocket::handleWrite() noexcept {
// (See the comment in handleRead() explaining how this can happen.)
EventBase* originalEventBase = eventBase_;
while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
uint32_t countWritten;
uint32_t partialWritten;
WriteFlags writeFlags = writeReqHead_->flags();
if (writeReqHead_->getNext() != nullptr) {
writeFlags = writeFlags | WriteFlags::CORK;
}
int bytesWritten = performWrite(writeReqHead_->getOps(),
writeReqHead_->getOpCount(),
writeFlags, &countWritten, &partialWritten);
if (bytesWritten < 0) {
if (!writeReqHead_->performWrite()) {
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
withAddr("writev() failed"), errno);
return failWrite(__func__, ex);
} else if (countWritten == writeReqHead_->getOpCount()) {
} else if (writeReqHead_->isComplete()) {
// We finished this request
WriteRequest* req = writeReqHead_;
writeReqHead_ = req->getNext();
......@@ -1424,7 +1460,7 @@ void AsyncSocket::handleWrite() noexcept {
// We'll continue around the loop, trying to write another request
} else {
// Partial write.
writeReqHead_->consume(countWritten, partialWritten, bytesWritten);
writeReqHead_->consume();
// Stop after a partial write; it's highly likely that a subsequent write
// attempt will just return EAGAIN.
//
......@@ -1822,7 +1858,7 @@ void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
WriteRequest* req = writeReqHead_;
writeReqHead_ = req->getNext();
WriteCallback* callback = req->getCallback();
uint32_t bytesWritten = req->getBytesWritten();
uint32_t bytesWritten = req->getTotalBytesWritten();
req->destroy();
if (callback) {
callback->writeErr(bytesWritten, ex);
......@@ -1859,7 +1895,7 @@ void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
writeReqHead_ = req->getNext();
WriteCallback* callback = req->getCallback();
if (callback) {
callback->writeErr(req->getBytesWritten(), ex);
callback->writeErr(req->getTotalBytesWritten(), ex);
}
req->destroy();
}
......
......@@ -517,6 +517,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
};
class WriteRequest;
class BytesWriteRequest;
class WriteTimeout : public AsyncTimeout {
public:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment