Commit 93db3df4 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot 2

Allow adding tasks to TaskIterator dynamically

Reviewed By: yfeldblum

Differential Revision: D3244669

fb-gh-sync-id: 73fa4ecb0432a802e67ef922255a896d96f32374
fbshipit-source-id: 73fa4ecb0432a802e67ef922255a896d96f32374
parent 4598dd70
...@@ -16,20 +16,12 @@ ...@@ -16,20 +16,12 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <folly/experimental/fibers/FiberManager.h>
namespace folly { namespace folly {
namespace fibers { namespace fibers {
template <typename T> template <typename T>
TaskIterator<T>::TaskIterator(TaskIterator&& other) noexcept TaskIterator<T>::TaskIterator(TaskIterator&& other) noexcept
: context_(std::move(other.context_)), id_(other.id_) {} : context_(std::move(other.context_)), id_(other.id_), fm_(other.fm_) {}
template <typename T>
TaskIterator<T>::TaskIterator(std::shared_ptr<Context> context)
: context_(std::move(context)), id_(-1) {
assert(context_);
}
template <typename T> template <typename T>
inline bool TaskIterator<T>::hasCompleted() const { inline bool TaskIterator<T>::hasCompleted() const {
...@@ -92,6 +84,30 @@ inline size_t TaskIterator<T>::getTaskID() const { ...@@ -92,6 +84,30 @@ inline size_t TaskIterator<T>::getTaskID() const {
return id_; return id_;
} }
template <typename T>
template <typename F>
void TaskIterator<T>::addTask(F&& func) {
static_assert(
std::is_convertible<typename std::result_of<F()>::type, T>::value,
"TaskIterator<T>: T must be convertible from func()'s return type");
auto taskId = context_->totalTasks++;
fm_.addTask(
[ taskId, context = context_, func = std::forward<F>(func) ]() mutable {
context->results.emplace_back(
taskId, folly::makeTryWith(std::move(func)));
// Check for awaiting iterator.
if (context->promise.hasValue()) {
if (--context->tasksToFulfillPromise == 0) {
context->promise->setValue();
context->promise.clear();
}
}
});
}
template <class InputIterator> template <class InputIterator>
TaskIterator<typename std::result_of< TaskIterator<typename std::result_of<
typename std::iterator_traits<InputIterator>::value_type()>::type> typename std::iterator_traits<InputIterator>::value_type()>::type>
...@@ -101,32 +117,15 @@ addTasks(InputIterator first, InputIterator last) { ...@@ -101,32 +117,15 @@ addTasks(InputIterator first, InputIterator last) {
ResultType; ResultType;
typedef TaskIterator<ResultType> IteratorType; typedef TaskIterator<ResultType> IteratorType;
auto context = std::make_shared<typename IteratorType::Context>(); IteratorType iterator;
context->totalTasks = std::distance(first, last);
context->results.reserve(context->totalTasks); for (; first != last; ++first) {
iterator.addTask(std::move(*first));
for (size_t i = 0; first != last; ++i, ++first) {
#ifdef __clang__
#pragma clang diagnostic push // ignore generalized lambda capture warning
#pragma clang diagnostic ignored "-Wc++1y-extensions"
#endif
addTask([ i, context, f = std::move(*first) ]() {
context->results.emplace_back(i, folly::makeTryWith(std::move(f)));
// Check for awaiting iterator.
if (context->promise.hasValue()) {
if (--context->tasksToFulfillPromise == 0) {
context->promise->setValue();
context->promise.clear();
}
}
});
#ifdef __clang__
#pragma clang diagnostic pop
#endif
} }
return IteratorType(std::move(context)); iterator.context_->results.reserve(iterator.context_->totalTasks);
return std::move(iterator);
} }
} }
} }
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include <folly/Optional.h> #include <folly/Optional.h>
#include <folly/experimental/fibers/FiberManager.h>
#include <folly/experimental/fibers/Promise.h> #include <folly/experimental/fibers/Promise.h>
#include <folly/futures/Try.h> #include <folly/futures/Try.h>
...@@ -49,6 +50,8 @@ class TaskIterator { ...@@ -49,6 +50,8 @@ class TaskIterator {
public: public:
typedef T value_type; typedef T value_type;
TaskIterator() : fm_(FiberManager::getFiberManager()) {}
// not copyable // not copyable
TaskIterator(const TaskIterator& other) = delete; TaskIterator(const TaskIterator& other) = delete;
TaskIterator& operator=(const TaskIterator& other) = delete; TaskIterator& operator=(const TaskIterator& other) = delete;
...@@ -57,6 +60,14 @@ class TaskIterator { ...@@ -57,6 +60,14 @@ class TaskIterator {
TaskIterator(TaskIterator&& other) noexcept; TaskIterator(TaskIterator&& other) noexcept;
TaskIterator& operator=(TaskIterator&& other) = delete; TaskIterator& operator=(TaskIterator&& other) = delete;
/**
* Add one more task to the TaskIterator.
*
* @param func task to be added, will be scheduled on current FiberManager
*/
template <typename F>
void addTask(F&& func);
/** /**
* @return True if there are tasks immediately available to be consumed (no * @return True if there are tasks immediately available to be consumed (no
* need to await on them). * need to await on them).
...@@ -111,10 +122,9 @@ class TaskIterator { ...@@ -111,10 +122,9 @@ class TaskIterator {
size_t tasksToFulfillPromise{0}; size_t tasksToFulfillPromise{0};
}; };
std::shared_ptr<Context> context_; std::shared_ptr<Context> context_{std::make_shared<Context>()};
size_t id_; size_t id_{std::numeric_limits<size_t>::max()};
FiberManager& fm_;
explicit TaskIterator(std::shared_ptr<Context> context);
folly::Try<T> awaitNextResult(); folly::Try<T> awaitNextResult();
}; };
......
...@@ -463,7 +463,7 @@ TEST(FiberManager, addTasksVoidThrow) { ...@@ -463,7 +463,7 @@ TEST(FiberManager, addTasksVoidThrow) {
loopController.loop(std::move(loopFunc)); loopController.loop(std::move(loopFunc));
} }
TEST(FiberManager, reserve) { TEST(FiberManager, addTasksReserve) {
std::vector<Promise<int>> pendingFibers; std::vector<Promise<int>> pendingFibers;
bool taskAdded = false; bool taskAdded = false;
...@@ -517,6 +517,42 @@ TEST(FiberManager, reserve) { ...@@ -517,6 +517,42 @@ TEST(FiberManager, reserve) {
loopController.loop(std::move(loopFunc)); loopController.loop(std::move(loopFunc));
} }
TEST(FiberManager, addTaskDynamic) {
folly::EventBase evb;
Baton batons[3];
auto makeTask = [&](size_t taskId) {
return [&, taskId]() -> size_t {
batons[taskId].wait();
return taskId;
};
};
getFiberManager(evb)
.addTaskFuture([&]() {
TaskIterator<size_t> iterator;
iterator.addTask(makeTask(0));
iterator.addTask(makeTask(1));
batons[1].post();
EXPECT_EQ(1, iterator.awaitNext());
iterator.addTask(makeTask(2));
batons[2].post();
EXPECT_EQ(2, iterator.awaitNext());
batons[0].post();
EXPECT_EQ(0, iterator.awaitNext());
})
.waitVia(&evb);
}
TEST(FiberManager, forEach) { TEST(FiberManager, forEach) {
std::vector<Promise<int>> pendingFibers; std::vector<Promise<int>> pendingFibers;
bool taskAdded = false; bool taskAdded = false;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment