Commit 284c5f6f authored by octal's avatar octal

async: Implemented a new Async::whenAll primitive.

The whenAll function returns a Promise<T> that will resolve
when all of the promises have resolved or will reject when of
the promise is rejected.
parent aa845e40
......@@ -10,6 +10,7 @@
#include <type_traits>
#include <functional>
#include <memory>
#include <atomic>
#include "optional.h"
namespace Async {
......@@ -62,6 +63,21 @@ namespace Async {
static constexpr bool value = std::is_same<decltype(test<T>(0)), std::true_type>::value;
};
template<typename Func>
struct IsMoveCallable : public IsMoveCallable<decltype(&Func::operator())> { };
template<typename R, typename Class, typename Arg>
struct IsMoveCallable<R (Class::*)(Arg) const> : public std::is_rvalue_reference<Arg> { };
template<typename Func, typename Arg>
typename std::conditional<
IsMoveCallable<Func>::value,
Arg&&,
const Arg&
>::type tryMove(Arg& arg) {
return std::move(arg);
}
template<typename Func>
struct FunctionTrait : public FunctionTrait<decltype(&Func::operator())> { };
......@@ -140,11 +156,11 @@ namespace Async {
typedef typename std::aligned_storage<sizeof(T), alignof(T)>::type Storage;
Storage storage;
const T& value() const {
T& value() {
if (state != State::Fulfilled)
throw Error("Attempted to take the value of a not fulfilled promise");
return *reinterpret_cast<const T*>(&storage);
return *reinterpret_cast<T*>(&storage);
}
bool isVoid() const { return false; }
......@@ -247,16 +263,26 @@ namespace Async {
void doReject(const std::shared_ptr<CoreT<T>>& core) const {
rejectFunc_(core->exc);
for (const auto& req: chain_->requests) {
req->reject(chain_);
}
}
void doResolveImpl(const std::shared_ptr<CoreT<T>>& core, std::true_type /* is_void */) const {
auto ret = resolveFunc_();
chain_->construct<decltype(ret)>(std::move(ret));
finishResolve(resolveFunc_());
}
void doResolveImpl(const std::shared_ptr<CoreT<T>>& core, std::false_type /* is_void */) const {
auto ret = resolveFunc_(core->value());
chain_->construct<decltype(ret)>(std::move(ret));
finishResolve(resolveFunc_(detail::tryMove<ResolveFunc>(core->value())));
}
template<typename Ret>
void finishResolve(Ret&& ret) const {
typedef typename std::remove_reference<Ret>::type CleanRet;
chain_->construct<CleanRet>(std::forward<Ret>(ret));
for (const auto& req: chain_->requests) {
req->resolve(chain_);
}
}
std::shared_ptr<Core> chain_;
......@@ -281,6 +307,9 @@ namespace Async {
void doReject(const std::shared_ptr<CoreT<T>>& core) const {
rejectFunc_(core->exc);
for (const auto& req: core->requests) {
req->reject(core);
}
}
void doResolveImpl(const std::shared_ptr<CoreT<T>>& core, std::true_type /* is_void */) const {
......@@ -289,7 +318,7 @@ namespace Async {
}
void doResolveImpl(const std::shared_ptr<CoreT<T>>& core, std::false_type /* is_void */) const {
auto promise = resolveFunc_(core->value());
auto promise = resolveFunc_(detail::tryMove<ResolveFunc>(core->value()));
finishResolve(promise);
}
......@@ -299,6 +328,10 @@ namespace Async {
promise.then(std::move(chainer), [=](std::exception_ptr exc) {
chain_->exc = std::move(exc);
chain_->state = State::Rejected;
for (const auto& req: chain_->requests) {
req->reject(chain_);
}
});
}
......@@ -314,6 +347,9 @@ namespace Async {
void operator()(const PromiseType& val) const {
chainCore->construct<PromiseType>(val);
for (const auto& req: chainCore->requests) {
req->resolve(chainCore);
}
}
std::shared_ptr<Core> chainCore;
......@@ -586,4 +622,208 @@ namespace Async {
Rejection rejection_;
};
namespace Impl {
struct WhenAll {
WhenAll(Resolver resolver, Rejection rejection)
: resolve(resolver)
, reject(rejection)
{ }
template<typename... Args>
void operator()(Args&&... args) {
whenAll(std::forward<Args>(args)...);
}
private:
template<typename T, size_t Index, typename Data>
struct WhenContinuation {
WhenContinuation(Data data)
: data(std::move(data))
{ }
void operator()(const T& val) const {
if (data->rejected) return;
// @Check thread-safety of std::get ?
std::get<Index>(data->results) = val;
data->resolved.fetch_add(1, std::memory_order_relaxed);
if (data->resolved == data->total) {
data->resolve(data->results);
}
}
Data data;
};
template<size_t Index, typename Data>
struct WhenContinuation<void, Index, Data> {
WhenContinuation(Data data)
: data(std::move(data))
{ }
void operator()() const {
if (data->rejected) return;
data->resolved.fetch_add(1, std::memory_order_relaxed);
if (data->resolved == data->total) {
data->resolve(data->results);
}
}
Data data;
};
template<typename T, size_t Index, typename Data>
WhenContinuation<T, Index, Data>
makeContinuation(const Data& data) {
return WhenContinuation<T, Index, Data>(data);
}
template<size_t Index, typename Data, typename T>
void when(const Data& data, Promise<T>& promise) {
promise.then(makeContinuation<T, Index>(data), [=](std::exception_ptr ptr) {
data->rejected.store(true);
data->reject(ptr);
});
}
template<size_t Index, typename Data, typename T>
void when(const Data& data, T&& arg) {
typedef typename std::remove_reference<T>::type CleanT;
auto promise = Promise<CleanT>::resolved(std::forward<T>(arg));
when<Index>(data, promise);
}
template<typename... Args>
void whenAll(Args&& ...args) {
typedef std::tuple<
typename detail::RemovePromise<
typename std::remove_reference<Args>::type
>::Type...
> Results;
/* We need to keep the results alive until the last promise
* finishes its execution
*/
struct Data {
Data(size_t total, Resolver resolver, Rejection rejection)
: total(total)
, resolved(0)
, rejected(false)
, resolve(std::move(resolver))
, reject(std::move(rejection))
{ }
const size_t total;
std::atomic<size_t> resolved;
std::atomic<bool> rejected;
Resolver resolve;
Rejection reject;
Results results;
};
auto data = std::make_shared<Data>(sizeof...(Args), std::move(resolve), std::move(reject));
whenAll<0>(data, std::forward<Args>(args)...);
}
template<size_t Index, typename Data, typename Head, typename... Rest>
void whenAll(const Data& data, Head&& head, Rest&& ...rest) {
when<Index>(data, std::forward<Head>(head));
whenAll<Index + 1>(data, std::forward<Rest>(rest)...);
}
template<size_t Index, typename Data, typename Head>
void whenAll(const Data& data, Head&& head) {
when<Index>(data, std::forward<Head>(head));
}
Resolver resolve;
Rejection reject;
};
}
template<
typename... Args,
typename Results =
std::tuple<
typename detail::RemovePromise<
typename std::remove_reference<Args>::type
>::Type...
>
>
Promise<Results> whenAll(Args&& ...args) {
return Promise<Results>([&](Resolver& resolver, Rejection& rejection) {
Impl::WhenAll impl(resolver, rejection);
impl(std::forward<Args>(args)...);
});
}
template<
typename Iterator,
typename ValueType
= typename detail::RemovePromise<
typename std::iterator_traits<Iterator>::value_type
>::Type,
typename Results = std::vector<ValueType>
>
Promise<Results> whenAll(Iterator first, Iterator last) {
return Promise<Results>([=](Resolver& resolve, Rejection& rejection) {
struct Data {
Data(const size_t total,
Resolver resolver,
Rejection rejection)
: total(total)
, resolved(0)
, rejected(false)
, resolve(std::move(resolver))
, reject(std::move(rejection))
{
results.resize(total);
}
const size_t total;
std::atomic<size_t> resolved;
std::atomic<bool> rejected;
Results results;
Resolver resolve;
Rejection reject;
};
auto data = std::make_shared<Data>(
std::distance(first, last),
resolve,
rejection
);
size_t index = 0;
for (auto it = first; it != last; ++it) {
it->then([=](const ValueType& val) {
if (data->rejected) return;
data->results[index] = val;
data->resolved.fetch_add(1);
if (data->resolved == data->total) {
data->resolve(data->results);
}
},
[=](std::exception_ptr ptr) {
data->rejected.store(true);
data->reject(std::move(ptr));
});
++index;
}
});
}
} // namespace Async
......@@ -17,7 +17,6 @@ Async::Promise<int> doAsync(int N)
return promise;
}
TEST(async_test, basic_test) {
Async::Promise<int> p1(
[](Async::Resolver& resolv, Async::Rejection& reject) {
......@@ -157,5 +156,71 @@ TEST(async_test, chain_test) {
ASSERT_THROW(std::rethrow_exception(exc), std::runtime_error);
});
auto p5 = doAsync(10);
p5
.then([](int result) { return result * 3.51; }, Async::NoExcept)
.then([](double result) { ASSERT_EQ(result, 20 * 3.51); }, Async::NoExcept);
auto p6 = doAsync(20);
p6
.then([](int result) { return doAsync(result - 5); }, Async::NoExcept)
.then([](int result) { ASSERT_EQ(result, 70); }, Async::NoExcept);
std::this_thread::sleep_for(std::chrono::seconds(2));
}
TEST(async_test, when_all) {
auto p1 = Async::Promise<int>::resolved(10);
int p2 = 123;
auto p3 = Async::Promise<std::string>::resolved("Hello");
auto p4 = Async::Promise<void>::resolved();
bool resolved { false };
Async::whenAll(p1, p2, p3).then([&](const std::tuple<int, int, std::string>& results) {
resolved = true;
ASSERT_EQ(std::get<0>(results), 10);
ASSERT_EQ(std::get<1>(results), 123);
ASSERT_EQ(std::get<2>(results), "Hello");
}, Async::NoExcept);
ASSERT_TRUE(resolved);
std::vector<Async::Promise<int>> vec;
vec.push_back(std::move(p1));
vec.push_back(Async::Promise<int>::resolved(p2));
resolved = false;
Async::whenAll(std::begin(vec), std::end(vec)).then([&](const std::vector<int>& results) {
resolved = true;
ASSERT_EQ(results.size(), 2);
ASSERT_EQ(results[0], 10);
ASSERT_EQ(results[1], 123);
},
Async::NoExcept);
ASSERT_TRUE(resolved);
auto p5 = doAsync(10);
auto p6 = p5.then([](int result) { return result * 3.1415; }, Async::NoExcept);
resolved = false;
Async::whenAll(p5, p6).then([&](std::tuple<int, double> results) {
ASSERT_EQ(std::get<0>(results), 20);
ASSERT_EQ(std::get<1>(results), 20 * 3.1415);
resolved = true;
}, Async::NoExcept);
std::this_thread::sleep_for(std::chrono::seconds(3));
ASSERT_TRUE(resolved);
// @Todo: does not compile yet. Figure out why it does not compile with void
// promises
#if 0
Async::whenAll(p3, p4).then([](const std::tuple<std::string, void>& results) {
}, Async::NoExcept);
#endif
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment