Commit 4a92da5b authored by Hannes Roth's avatar Hannes Roth Committed by Praveen Kumar Ramakrishnan

(Wangle) Implement collect* using mapSetCallback and shared_ptrs

Summary:
I figured it would make sense to implement all the collect* functions using a shared_ptr<Context>, instead of doing our manual reference counting and all that. Fulfilling the promise in the destructor seemed like the icing on the cake. Also saves some line of code.

Test Plan: Run all the tests.

Reviewed By: hans@fb.com

Subscribers: folly-diffs@, jsedgwick, yfeldblum, chalfant

FB internal diff: D2015320

Signature: t1:2015320:1431106133:ac3001b3696fc75230afe70908ed349102b02a45
parent 5cc2f994
...@@ -531,22 +531,31 @@ inline Future<void> via(Executor* executor) { ...@@ -531,22 +531,31 @@ inline Future<void> via(Executor* executor) {
return makeFuture().via(executor); return makeFuture().via(executor);
} }
// when (variadic) // mapSetCallback calls func(i, Try<T>) when every future completes
template <class T, class InputIterator, class F>
void mapSetCallback(InputIterator first, InputIterator last, F func) {
for (size_t i = 0; first != last; ++first, ++i) {
first->setCallback_([func, i](Try<T>&& t) {
func(i, std::move(t));
});
}
}
// collectAll (variadic)
template <typename... Fs> template <typename... Fs>
typename detail::VariadicContext< typename detail::VariadicContext<
typename std::decay<Fs>::type::value_type...>::type typename std::decay<Fs>::type::value_type...>::type
collectAll(Fs&&... fs) { collectAll(Fs&&... fs) {
auto ctx = auto ctx = std::make_shared<detail::VariadicContext<
new detail::VariadicContext<typename std::decay<Fs>::type::value_type...>(); typename std::decay<Fs>::type::value_type...>>();
ctx->total = sizeof...(fs);
auto f_saved = ctx->p.getFuture();
detail::collectAllVariadicHelper(ctx, detail::collectAllVariadicHelper(ctx,
std::forward<typename std::decay<Fs>::type>(fs)...); std::forward<typename std::decay<Fs>::type>(fs)...);
return f_saved; return ctx->p.getFuture();
} }
// when (iterator) // collectAll (iterator)
template <class InputIterator> template <class InputIterator>
Future< Future<
...@@ -556,155 +565,87 @@ collectAll(InputIterator first, InputIterator last) { ...@@ -556,155 +565,87 @@ collectAll(InputIterator first, InputIterator last) {
typedef typedef
typename std::iterator_traits<InputIterator>::value_type::value_type T; typename std::iterator_traits<InputIterator>::value_type::value_type T;
if (first >= last) { struct CollectAllContext {
return makeFuture(std::vector<Try<T>>()); CollectAllContext(int n) : results(n) {}
} ~CollectAllContext() {
size_t n = std::distance(first, last); p.setValue(std::move(results));
}
auto ctx = new detail::WhenAllContext<T>(); Promise<std::vector<Try<T>>> p;
std::vector<Try<T>> results;
ctx->results.resize(n); };
auto f_saved = ctx->p.getFuture();
for (size_t i = 0; first != last; ++first, ++i) {
assert(i < n);
auto& f = *first;
f.setCallback_([ctx, i, n](Try<T> t) {
ctx->results[i] = std::move(t);
if (++ctx->count == n) {
ctx->p.setValue(std::move(ctx->results));
delete ctx;
}
});
}
return f_saved; auto ctx = std::make_shared<CollectAllContext>(std::distance(first, last));
mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
ctx->results[i] = std::move(t);
});
return ctx->p.getFuture();
} }
namespace detail { namespace detail {
template <class, class, typename = void> struct CollectContextHelper;
template <class T, class VecT>
struct CollectContextHelper<T, VecT,
typename std::enable_if<std::is_same<T, VecT>::value>::type> {
static inline std::vector<T>&& getResults(std::vector<VecT>& results) {
return std::move(results);
}
};
template <class T, class VecT>
struct CollectContextHelper<T, VecT,
typename std::enable_if<!std::is_same<T, VecT>::value>::type> {
static inline std::vector<T> getResults(std::vector<VecT>& results) {
std::vector<T> finalResults;
finalResults.reserve(results.size());
for (auto& opt : results) {
finalResults.push_back(std::move(opt.value()));
}
return finalResults;
}
};
template <typename T> template <typename T>
struct CollectContext { struct CollectContext {
struct Nothing { explicit Nothing(int n) {} };
typedef typename std::conditional<
std::is_default_constructible<T>::value, using Result = typename std::conditional<
T, std::is_void<T>::value,
Optional<T> void,
>::type VecT; std::vector<T>>::type;
explicit CollectContext(int n) : count(0), success_count(0), threw(false) { using InternalResult = typename std::conditional<
results.resize(n); std::is_void<T>::value,
} Nothing,
std::vector<Optional<T>>>::type;
Promise<std::vector<T>> p;
std::vector<VecT> results; explicit CollectContext(int n) : result(n) {}
std::atomic<size_t> count, success_count; ~CollectContext() {
std::atomic_bool threw; if (!threw.exchange(true)) {
// map Optional<T> -> T
typedef std::vector<T> result_type; std::vector<T> finalResult;
finalResult.reserve(result.size());
static inline Future<std::vector<T>> makeEmptyFuture() { std::transform(result.begin(), result.end(),
return makeFuture(std::vector<T>()); std::back_inserter(finalResult),
} [](Optional<T>& o) { return std::move(o.value()); });
p.setValue(std::move(finalResult));
inline void setValue() { }
p.setValue(CollectContextHelper<T, VecT>::getResults(results));
} }
inline void setPartialResult(size_t i, Try<T>& t) {
inline void addResult(int i, Try<T>& t) { result[i] = std::move(t.value());
results[i] = std::move(t.value());
} }
Promise<Result> p;
InternalResult result;
std::atomic<bool> threw;
}; };
template <> // Specialize for void (implementations in Future.cpp)
struct CollectContext<void> {
explicit CollectContext(int n) : count(0), success_count(0), threw(false) {}
Promise<void> p; template <>
std::atomic<size_t> count, success_count; CollectContext<void>::~CollectContext();
std::atomic_bool threw;
typedef void result_type;
static inline Future<void> makeEmptyFuture() {
return makeFuture();
}
inline void setValue() {
p.setValue();
}
inline void addResult(int i, Try<void>& t) { template <>
// do nothing void CollectContext<void>::setPartialResult(size_t i, Try<void>& t);
}
};
} // detail }
template <class InputIterator> template <class InputIterator>
Future<typename detail::CollectContext< Future<typename detail::CollectContext<
typename std::iterator_traits<InputIterator>::value_type::value_type typename std::iterator_traits<InputIterator>::value_type::value_type>::Result>
>::result_type>
collect(InputIterator first, InputIterator last) { collect(InputIterator first, InputIterator last) {
typedef typedef
typename std::iterator_traits<InputIterator>::value_type::value_type T; typename std::iterator_traits<InputIterator>::value_type::value_type T;
if (first >= last) { auto ctx = std::make_shared<detail::CollectContext<T>>(
return detail::CollectContext<T>::makeEmptyFuture(); std::distance(first, last));
} mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
if (t.hasException()) {
size_t n = std::distance(first, last); if (!ctx->threw.exchange(true)) {
auto ctx = new detail::CollectContext<T>(n); ctx->p.setException(std::move(t.exception()));
auto f_saved = ctx->p.getFuture();
for (size_t i = 0; first != last; ++first, ++i) {
assert(i < n);
auto& f = *first;
f.setCallback_([ctx, i, n](Try<T> t) {
if (t.hasException()) {
if (!ctx->threw.exchange(true)) {
ctx->p.setException(std::move(t.exception()));
}
} else if (!ctx->threw) {
ctx->addResult(i, t);
if (++ctx->success_count == n) {
ctx->setValue();
}
}
if (++ctx->count == n) {
delete ctx;
} }
}); } else if (!ctx->threw) {
} ctx->setPartialResult(i, t);
}
return f_saved; });
return ctx->p.getFuture();
} }
template <class InputIterator> template <class InputIterator>
...@@ -712,25 +653,24 @@ Future< ...@@ -712,25 +653,24 @@ Future<
std::pair<size_t, std::pair<size_t,
Try< Try<
typename typename
std::iterator_traits<InputIterator>::value_type::value_type> > > std::iterator_traits<InputIterator>::value_type::value_type>>>
collectAny(InputIterator first, InputIterator last) { collectAny(InputIterator first, InputIterator last) {
typedef typedef
typename std::iterator_traits<InputIterator>::value_type::value_type T; typename std::iterator_traits<InputIterator>::value_type::value_type T;
auto ctx = new detail::WhenAnyContext<T>(std::distance(first, last)); struct CollectAnyContext {
auto f_saved = ctx->p.getFuture(); CollectAnyContext(size_t n) : done(false) {};
Promise<std::pair<size_t, Try<T>>> p;
for (size_t i = 0; first != last; first++, i++) { std::atomic<bool> done;
auto& f = *first; };
f.setCallback_([i, ctx](Try<T>&& t) {
if (!ctx->done.exchange(true)) {
ctx->p.setValue(std::make_pair(i, std::move(t)));
}
ctx->decref();
});
}
return f_saved; auto ctx = std::make_shared<CollectAnyContext>(std::distance(first, last));
mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
if (!ctx->done.exchange(true)) {
ctx->p.setValue(std::make_pair(i, std::move(t)));
}
});
return ctx->p.getFuture();
} }
template <class InputIterator> template <class InputIterator>
...@@ -741,38 +681,29 @@ collectN(InputIterator first, InputIterator last, size_t n) { ...@@ -741,38 +681,29 @@ collectN(InputIterator first, InputIterator last, size_t n) {
std::iterator_traits<InputIterator>::value_type::value_type T; std::iterator_traits<InputIterator>::value_type::value_type T;
typedef std::vector<std::pair<size_t, Try<T>>> V; typedef std::vector<std::pair<size_t, Try<T>>> V;
struct ctx_t { struct CollectNContext {
V v; V v;
size_t completed; std::atomic<size_t> completed = {0};
Promise<V> p; Promise<V> p;
}; };
auto ctx = std::make_shared<ctx_t>(); auto ctx = std::make_shared<CollectNContext>();
ctx->completed = 0;
if (std::distance(first, last) < n) {
// for each completed Future, increase count and add to vector, until we ctx->p.setException(std::runtime_error("Not enough futures"));
// have n completed futures at which point we fulfill our Promise with the } else {
// vector // for each completed Future, increase count and add to vector, until we
auto it = first; // have n completed futures at which point we fulfil our Promise with the
size_t i = 0; // vector
while (it != last) { mapSetCallback<T>(first, last, [ctx, n](size_t i, Try<T>&& t) {
it->then([ctx, n, i](Try<T>&& t) {
auto& v = ctx->v;
auto c = ++ctx->completed; auto c = ++ctx->completed;
if (c <= n) { if (c <= n) {
assert(ctx->v.size() < n); assert(ctx->v.size() < n);
v.push_back(std::make_pair(i, std::move(t))); ctx->v.push_back(std::make_pair(i, std::move(t)));
if (c == n) { if (c == n) {
ctx->p.setTry(Try<V>(std::move(v))); ctx->p.setTry(Try<V>(std::move(ctx->v)));
} }
} }
}); });
it++;
i++;
}
if (i < n) {
ctx->p.setException(std::runtime_error("Not enough futures"));
} }
return ctx->p.getFuture(); return ctx->p.getFuture();
......
...@@ -39,3 +39,19 @@ Future<void> sleep(Duration dur, Timekeeper* tk) { ...@@ -39,3 +39,19 @@ Future<void> sleep(Duration dur, Timekeeper* tk) {
} }
}} }}
namespace folly { namespace detail {
template <>
CollectContext<void>::~CollectContext() {
if (!threw.exchange(true)) {
p.setValue();
}
}
template <>
void CollectContext<void>::setPartialResult(size_t i, Try<void>& t) {
// Nothing to do for void
}
}}
...@@ -319,59 +319,33 @@ class Core { ...@@ -319,59 +319,33 @@ class Core {
template <typename... Ts> template <typename... Ts>
struct VariadicContext { struct VariadicContext {
VariadicContext() : total(0), count(0) {} VariadicContext() {}
Promise<std::tuple<Try<Ts>... > > p; ~VariadicContext() {
p.setValue(std::move(results));
}
Promise<std::tuple<Try<Ts>... >> p;
std::tuple<Try<Ts>... > results; std::tuple<Try<Ts>... > results;
size_t total;
std::atomic<size_t> count;
typedef Future<std::tuple<Try<Ts>...>> type; typedef Future<std::tuple<Try<Ts>...>> type;
}; };
template <typename... Ts, typename THead, typename... Fs> template <typename... Ts, typename THead, typename... Fs>
typename std::enable_if<sizeof...(Fs) == 0, void>::type typename std::enable_if<sizeof...(Fs) == 0, void>::type
collectAllVariadicHelper(VariadicContext<Ts...> *ctx, THead&& head, Fs&&... tail) { collectAllVariadicHelper(std::shared_ptr<VariadicContext<Ts...>> ctx,
THead&& head, Fs&&... tail) {
head.setCallback_([ctx](Try<typename THead::value_type>&& t) { head.setCallback_([ctx](Try<typename THead::value_type>&& t) {
std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t); std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t);
if (++ctx->count == ctx->total) {
ctx->p.setValue(std::move(ctx->results));
delete ctx;
}
}); });
} }
template <typename... Ts, typename THead, typename... Fs> template <typename... Ts, typename THead, typename... Fs>
typename std::enable_if<sizeof...(Fs) != 0, void>::type typename std::enable_if<sizeof...(Fs) != 0, void>::type
collectAllVariadicHelper(VariadicContext<Ts...> *ctx, THead&& head, Fs&&... tail) { collectAllVariadicHelper(std::shared_ptr<VariadicContext<Ts...>> ctx,
THead&& head, Fs&&... tail) {
head.setCallback_([ctx](Try<typename THead::value_type>&& t) { head.setCallback_([ctx](Try<typename THead::value_type>&& t) {
std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t); std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t);
if (++ctx->count == ctx->total) {
ctx->p.setValue(std::move(ctx->results));
delete ctx;
}
}); });
// template tail-recursion // template tail-recursion
collectAllVariadicHelper(ctx, std::forward<Fs>(tail)...); collectAllVariadicHelper(ctx, std::forward<Fs>(tail)...);
} }
template <typename T>
struct WhenAllContext {
WhenAllContext() : count(0) {}
Promise<std::vector<Try<T> > > p;
std::vector<Try<T> > results;
std::atomic<size_t> count;
};
template <typename T>
struct WhenAnyContext {
explicit WhenAnyContext(size_t n) : done(false), ref_count(n) {};
Promise<std::pair<size_t, Try<T>>> p;
std::atomic<bool> done;
std::atomic<size_t> ref_count;
void decref() {
if (--ref_count == 0) {
delete this;
}
}
};
}} // folly::detail }} // folly::detail
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment