Commit c8aadaad authored by Philip Pronin's avatar Philip Pronin Committed by Facebook Github Bot

AsyncIO::cancel

Summary:
It should be implemented with `io_cancel`, but it is not
supported (D682836), so still have to drain events, ignoring only
op callbacks.

Reviewed By: luciang, ot

Differential Revision: D5044020

fbshipit-source-id: 0bcd04c91a437fccaf2189ccf771a1cb61c68942
parent e17fce32
...@@ -66,6 +66,11 @@ void AsyncIOOp::complete(ssize_t result) { ...@@ -66,6 +66,11 @@ void AsyncIOOp::complete(ssize_t result) {
} }
} }
void AsyncIOOp::cancel() {
DCHECK_EQ(state_, State::PENDING);
state_ = State::CANCELED;
}
ssize_t AsyncIOOp::result() const { ssize_t AsyncIOOp::result() const {
CHECK_EQ(state_, State::COMPLETED); CHECK_EQ(state_, State::COMPLETED);
return result_; return result_;
...@@ -104,13 +109,7 @@ void AsyncIOOp::init() { ...@@ -104,13 +109,7 @@ void AsyncIOOp::init() {
state_ = State::INITIALIZED; state_ = State::INITIALIZED;
} }
AsyncIO::AsyncIO(size_t capacity, PollMode pollMode) AsyncIO::AsyncIO(size_t capacity, PollMode pollMode) : capacity_(capacity) {
: ctx_(0),
ctxSet_(false),
pending_(0),
submitted_(0),
capacity_(capacity),
pollFd_(-1) {
CHECK_GT(capacity_, 0); CHECK_GT(capacity_, 0);
completed_.reserve(capacity_); completed_.reserve(capacity_);
if (pollMode == POLLABLE) { if (pollMode == POLLABLE) {
...@@ -194,7 +193,15 @@ Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) { ...@@ -194,7 +193,15 @@ Range<AsyncIO::Op**> AsyncIO::wait(size_t minRequests) {
CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object"; CHECK_EQ(pollFd_, -1) << "wait() only allowed on non-pollable object";
auto p = pending_.load(std::memory_order_acquire); auto p = pending_.load(std::memory_order_acquire);
CHECK_LE(minRequests, p); CHECK_LE(minRequests, p);
return doWait(minRequests, p); doWait(WaitType::COMPLETE, minRequests, p, &completed_);
return Range<Op**>(completed_.data(), completed_.size());
}
size_t AsyncIO::cancel() {
CHECK(ctx_);
auto p = pending_.load(std::memory_order_acquire);
doWait(WaitType::CANCEL, p, p, nullptr);
return p;
} }
Range<AsyncIO::Op**> AsyncIO::pollCompleted() { Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
...@@ -217,12 +224,19 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() { ...@@ -217,12 +224,19 @@ Range<AsyncIO::Op**> AsyncIO::pollCompleted() {
DCHECK_LE(numEvents, pending_); DCHECK_LE(numEvents, pending_);
// Don't reap more than numEvents, as we've just reset the counter to 0. // Don't reap more than numEvents, as we've just reset the counter to 0.
return doWait(numEvents, numEvents); doWait(WaitType::COMPLETE, numEvents, numEvents, &completed_);
return Range<Op**>(completed_.data(), completed_.size());
} }
Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) { void AsyncIO::doWait(
WaitType type,
size_t minRequests,
size_t maxRequests,
std::vector<Op*>* result) {
io_event events[maxRequests]; io_event events[maxRequests];
// Unfortunately, Linux AIO doesn't implement io_cancel, so even for
// WaitType::CANCEL we have to wait for IO completion.
size_t count = 0; size_t count = 0;
do { do {
int ret; int ret;
...@@ -237,27 +251,32 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) { ...@@ -237,27 +251,32 @@ Range<AsyncIO::Op**> AsyncIO::doWait(size_t minRequests, size_t maxRequests) {
/* timeout */ nullptr); // wait forever /* timeout */ nullptr); // wait forever
} while (ret == -EINTR); } while (ret == -EINTR);
// Check as may not be able to recover without leaking events. // Check as may not be able to recover without leaking events.
CHECK_GE(ret, 0) CHECK_GE(ret, 0) << "AsyncIO: io_getevents failed with error "
<< "AsyncIO: io_getevents failed with error " << errnoStr(-ret); << errnoStr(-ret);
count += ret; count += ret;
} while (count < minRequests); } while (count < minRequests);
DCHECK_LE(count, maxRequests); DCHECK_LE(count, maxRequests);
completed_.clear(); if (result != nullptr) {
if (count == 0) { result->clear();
return folly::Range<Op**>();
} }
for (size_t i = 0; i < count; ++i) { for (size_t i = 0; i < count; ++i) {
DCHECK(events[i].obj); DCHECK(events[i].obj);
Op* op = boost::intrusive::get_parent_from_member( Op* op = boost::intrusive::get_parent_from_member(
events[i].obj, &AsyncIOOp::iocb_); events[i].obj, &AsyncIOOp::iocb_);
decrementPending(); decrementPending();
op->complete(events[i].res); switch (type) {
completed_.push_back(op); case WaitType::COMPLETE:
op->complete(events[i].res);
break;
case WaitType::CANCEL:
op->cancel();
break;
}
if (result != nullptr) {
result->push_back(op);
}
} }
return folly::Range<Op**>(&completed_.front(), count);
} }
AsyncIOQueue::AsyncIOQueue(AsyncIO* asyncIO) AsyncIOQueue::AsyncIOQueue(AsyncIO* asyncIO)
...@@ -308,6 +327,7 @@ const char* asyncIoOpStateToString(AsyncIOOp::State state) { ...@@ -308,6 +327,7 @@ const char* asyncIoOpStateToString(AsyncIOOp::State state) {
X(AsyncIOOp::State::INITIALIZED); X(AsyncIOOp::State::INITIALIZED);
X(AsyncIOOp::State::PENDING); X(AsyncIOOp::State::PENDING);
X(AsyncIOOp::State::COMPLETED); X(AsyncIOOp::State::COMPLETED);
X(AsyncIOOp::State::CANCELED);
} }
return "<INVALID AsyncIOOp::State>"; return "<INVALID AsyncIOOp::State>";
} }
......
...@@ -40,25 +40,24 @@ namespace folly { ...@@ -40,25 +40,24 @@ namespace folly {
* An AsyncIOOp represents a pending operation. You may set a notification * An AsyncIOOp represents a pending operation. You may set a notification
* callback or you may use this class's methods directly. * callback or you may use this class's methods directly.
* *
* The op must remain allocated until completion. * The op must remain allocated until it is completed or canceled.
*/ */
class AsyncIOOp : private boost::noncopyable { class AsyncIOOp : private boost::noncopyable {
friend class AsyncIO; friend class AsyncIO;
friend std::ostream& operator<<(std::ostream& stream, const AsyncIOOp& o); friend std::ostream& operator<<(std::ostream& stream, const AsyncIOOp& o);
public: public:
typedef std::function<void(AsyncIOOp*)> NotificationCallback; typedef std::function<void(AsyncIOOp*)> NotificationCallback;
explicit AsyncIOOp(NotificationCallback cb = NotificationCallback()); explicit AsyncIOOp(NotificationCallback cb = NotificationCallback());
~AsyncIOOp(); ~AsyncIOOp();
// There would be a cancel() method here if Linux AIO actually implemented
// it. But let's not get your hopes up.
enum class State { enum class State {
UNINITIALIZED, UNINITIALIZED,
INITIALIZED, INITIALIZED,
PENDING, PENDING,
COMPLETED COMPLETED,
CANCELED,
}; };
/** /**
...@@ -95,8 +94,7 @@ class AsyncIOOp : private boost::noncopyable { ...@@ -95,8 +94,7 @@ class AsyncIOOp : private boost::noncopyable {
* conventions). Use checkKernelError (folly/Exception.h) on the result to * conventions). Use checkKernelError (folly/Exception.h) on the result to
* throw a std::system_error in case of error instead. * throw a std::system_error in case of error instead.
* *
* It is an error to call this if the Op hasn't yet started or is still * It is an error to call this if the Op hasn't completed.
* pending.
*/ */
ssize_t result() const; ssize_t result() const;
...@@ -104,6 +102,7 @@ class AsyncIOOp : private boost::noncopyable { ...@@ -104,6 +102,7 @@ class AsyncIOOp : private boost::noncopyable {
void init(); void init();
void start(); void start();
void complete(ssize_t result); void complete(ssize_t result);
void cancel();
NotificationCallback cb_; NotificationCallback cb_;
iocb iocb_; iocb iocb_;
...@@ -123,7 +122,7 @@ class AsyncIO : private boost::noncopyable { ...@@ -123,7 +122,7 @@ class AsyncIO : private boost::noncopyable {
enum PollMode { enum PollMode {
NOT_POLLABLE, NOT_POLLABLE,
POLLABLE POLLABLE,
}; };
/** /**
...@@ -141,12 +140,12 @@ class AsyncIO : private boost::noncopyable { ...@@ -141,12 +140,12 @@ class AsyncIO : private boost::noncopyable {
* file descriptor directly. * file descriptor directly.
* *
* You may use the same AsyncIO object from multiple threads, as long as * You may use the same AsyncIO object from multiple threads, as long as
* there is only one concurrent caller of wait() / pollCompleted() (perhaps * there is only one concurrent caller of wait() / pollCompleted() / cancel()
* by always calling it from the same thread, or by providing appropriate * (perhaps by always calling it from the same thread, or by providing
* mutual exclusion) In this case, pending() returns a snapshot * appropriate mutual exclusion). In this case, pending() returns a snapshot
* of the current number of pending requests. * of the current number of pending requests.
*/ */
explicit AsyncIO(size_t capacity, PollMode pollMode=NOT_POLLABLE); explicit AsyncIO(size_t capacity, PollMode pollMode = NOT_POLLABLE);
~AsyncIO(); ~AsyncIO();
/** /**
...@@ -156,6 +155,11 @@ class AsyncIO : private boost::noncopyable { ...@@ -156,6 +155,11 @@ class AsyncIO : private boost::noncopyable {
*/ */
Range<Op**> wait(size_t minRequests); Range<Op**> wait(size_t minRequests);
/**
* Cancel all pending requests and return their number.
*/
size_t cancel();
/** /**
* Return the number of pending requests. * Return the number of pending requests.
*/ */
...@@ -196,16 +200,21 @@ class AsyncIO : private boost::noncopyable { ...@@ -196,16 +200,21 @@ class AsyncIO : private boost::noncopyable {
void decrementPending(); void decrementPending();
void initializeContext(); void initializeContext();
Range<Op**> doWait(size_t minRequests, size_t maxRequests); enum class WaitType { COMPLETE, CANCEL };
void doWait(
WaitType type,
size_t minRequests,
size_t maxRequests,
std::vector<Op*>* result);
io_context_t ctx_; io_context_t ctx_{nullptr};
std::atomic<bool> ctxSet_; std::atomic<bool> ctxSet_{false};
std::mutex initMutex_; std::mutex initMutex_;
std::atomic<size_t> pending_; std::atomic<size_t> pending_{0};
std::atomic<size_t> submitted_; std::atomic<size_t> submitted_{0};
const size_t capacity_; const size_t capacity_;
int pollFd_; int pollFd_{-1};
std::vector<Op*> completed_; std::vector<Op*> completed_;
}; };
......
...@@ -36,7 +36,9 @@ ...@@ -36,7 +36,9 @@
#include <folly/portability/Sockets.h> #include <folly/portability/Sockets.h>
namespace fs = folly::fs; namespace fs = folly::fs;
using folly::AsyncIO; using folly::AsyncIO;
using folly::AsyncIOOp;
using folly::AsyncIOQueue; using folly::AsyncIOQueue;
namespace { namespace {
...@@ -85,7 +87,7 @@ class TemporaryFile { ...@@ -85,7 +87,7 @@ class TemporaryFile {
}; };
TemporaryFile::TemporaryFile(size_t size) TemporaryFile::TemporaryFile(size_t size)
: path_(fs::temp_directory_path() / fs::unique_path()) { : path_(fs::temp_directory_path() / fs::unique_path()) {
CHECK_EQ(size % sizeof(uint32_t), 0); CHECK_EQ(size % sizeof(uint32_t), 0);
size /= sizeof(uint32_t); size /= sizeof(uint32_t);
const uint32_t seed = 42; const uint32_t seed = 42;
...@@ -370,7 +372,7 @@ TEST(AsyncIO, NonBlockingWait) { ...@@ -370,7 +372,7 @@ TEST(AsyncIO, NonBlockingWait) {
SCOPE_EXIT { SCOPE_EXIT {
::close(fd); ::close(fd);
}; };
size_t size = 2*kAlign; size_t size = 2 * kAlign;
auto buf = allocateAligned(size); auto buf = allocateAligned(size);
op.pread(fd, buf.get(), size, 0); op.pread(fd, buf.get(), size, 0);
aioReader.submit(&op); aioReader.submit(&op);
...@@ -389,3 +391,50 @@ TEST(AsyncIO, NonBlockingWait) { ...@@ -389,3 +391,50 @@ TEST(AsyncIO, NonBlockingWait) {
EXPECT_EQ(size, res); EXPECT_EQ(size, res);
EXPECT_EQ(aioReader.pending(), 0); EXPECT_EQ(aioReader.pending(), 0);
} }
TEST(AsyncIO, Cancel) {
constexpr size_t kNumOps = 10;
AsyncIO aioReader(kNumOps, AsyncIO::NOT_POLLABLE);
int fd = ::open(tempFile.path().c_str(), O_DIRECT | O_RDONLY);
PCHECK(fd != -1);
SCOPE_EXIT {
::close(fd);
};
std::vector<AsyncIO::Op> ops(kNumOps);
std::vector<ManagedBuffer> bufs;
size_t completed = 0;
for (auto& op : ops) {
const size_t size = 2 * kAlign;
bufs.push_back(allocateAligned(size));
op.setNotificationCallback([&](AsyncIOOp*) { ++completed; });
op.pread(fd, bufs.back().get(), size, 0);
aioReader.submit(&op);
}
EXPECT_EQ(aioReader.pending(), kNumOps);
EXPECT_EQ(completed, 0);
{
auto result = aioReader.wait(1);
EXPECT_EQ(result.size(), 1);
}
EXPECT_EQ(completed, 1);
EXPECT_EQ(aioReader.pending(), kNumOps - 1);
EXPECT_EQ(aioReader.cancel(), kNumOps - 1);
EXPECT_EQ(aioReader.pending(), 0);
EXPECT_EQ(completed, 1);
completed = 0;
for (auto& op : ops) {
if (op.state() == AsyncIOOp::State::COMPLETED) {
++completed;
} else {
EXPECT_TRUE(op.state() == AsyncIOOp::State::CANCELED) << op;
}
}
EXPECT_EQ(completed, 1);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment