Commit 3ae548f8 authored by Mathieu STEFANI's avatar Mathieu STEFANI

Async: implemented a whenAny primitive that returns a Promise that will be...

Async: implemented a whenAny primitive that returns a Promise that will be resolved when one of the Promise is resolved
parent 1d0330ed
......@@ -38,6 +38,11 @@ namespace Async {
TypeId id_;
};
class BadAnyCast : public std::bad_cast {
public:
virtual const char* what() const noexcept { return "Bad any cast"; }
};
/*
- Direct: The continuation will be directly called in the context
of the same thread
......@@ -456,8 +461,6 @@ namespace Async {
std::forward<RejectFunc>(rejectFunc));
}
};
}
class Resolver {
......@@ -676,16 +679,151 @@ namespace Async {
};
namespace Impl {
struct Any;
}
class Any {
public:
friend class Impl::Any;
Any(const Any& other) = default;
Any& operator=(const Any& other) = default;
Any(Any&& other) = default;
Any& operator=(Any&& other) = default;
template<typename T>
bool is() const {
return core_->id == TypeId::of<T>();
}
template<typename T>
T cast() const {
if (!is<T>()) throw BadAnyCast();
auto core = std::static_pointer_cast<Private::CoreT<T>>(core_);
return core->value();
}
private:
Any(const std::shared_ptr<Private::Core>& core)
: core_(core)
{ }
std::shared_ptr<Private::Core> core_;
};
namespace Impl {
/* Instead of duplicating the code between whenAll and whenAny functions, the main implementation
* is in the When class below and we configure the class with a policy instead, depending if we
* are executing an "all" or "any" operation, how cool is that ?
*/
struct All {
struct Data {
Data(const size_t total, Resolver resolver, Rejection rejection)
: total(total)
, resolved(0)
, rejected(false)
, resolve(std::move(resolver))
, reject(std::move(rejection))
{ }
const size_t total;
std::atomic<size_t> resolved;
std::atomic<bool> rejected;
Resolver resolve;
Rejection reject;
};
template<size_t Index, typename T, typename Data>
static void resolveT(const T& val, Data& data) {
if (data->rejected) return;
// @Check thread-safety of std::get ?
std::get<Index>(data->results) = val;
data->resolved.fetch_add(1, std::memory_order_relaxed);
if (data->resolved == data->total) {
data->resolve(data->results);
}
}
template<typename Data>
static void resolveVoid(Data& data) {
if (data->rejected) return;
data->resolved.fetch_add(1, std::memory_order_relaxed);
if (data->resolved == data->total) {
data->resolve(data->results);
}
}
struct WhenAll {
WhenAll(Resolver resolver, Rejection rejection)
template<typename Data>
static void reject(std::exception_ptr exc, Data& data) {
data->rejected.store(true);
data->reject(exc);
}
};
struct Any {
struct Data {
Data(size_t, Resolver resolver, Rejection rejection)
: done(false)
, resolve(std::move(resolver))
, reject(std::move(rejection))
{ }
std::atomic<bool> done;
Resolver resolve;
Rejection reject;
};
template<size_t Index, typename T, typename Data>
static void resolveT(const T& val, Data& data) {
if (data->done) return;
// Instead of allocating a new core, ideally we could share the same core as
// the relevant promise but we do not have access to the promise here is so meh
auto core = std::make_shared<Private::CoreT<T>>();
core->template construct<T>(val);
data->resolve(Async::Any(std::move(core)));
data->done = true;
}
template<typename Data>
static void resolveVoid(Data& data) {
if (data->done) return;
auto core = std::make_shared<Private::CoreT<void>>();
data->resolve(Async::Any(std::move(core)));
data->done = true;
}
template<typename Data>
static void reject(std::exception_ptr exc, Data& data) {
data->done.store(true);
data->reject(exc);
}
};
template<typename ContinuationPolicy>
struct When {
When(Resolver resolver, Rejection rejection)
: resolve(resolver)
, reject(rejection)
{ }
template<typename... Args>
void operator()(Args&&... args) {
whenAll(std::forward<Args>(args)...);
whenArgs(std::forward<Args>(args)...);
}
private:
......@@ -696,15 +834,7 @@ namespace Async {
{ }
void operator()(const T& val) const {
if (data->rejected) return;
// @Check thread-safety of std::get ?
std::get<Index>(data->results) = val;
data->resolved.fetch_add(1, std::memory_order_relaxed);
if (data->resolved == data->total) {
data->resolve(data->results);
}
ContinuationPolicy::template resolveT<Index>(val, data);
}
Data data;
......@@ -717,13 +847,7 @@ namespace Async {
{ }
void operator()() const {
if (data->rejected) return;
data->resolved.fetch_add(1, std::memory_order_relaxed);
if (data->resolved == data->total) {
data->resolve(data->results);
}
ContinuationPolicy::resolveVoid(data);
}
Data data;
......@@ -739,8 +863,7 @@ namespace Async {
template<size_t Index, typename Data, typename T>
void when(const Data& data, Promise<T>& promise) {
promise.then(makeContinuation<T, Index>(data), [=](std::exception_ptr ptr) {
data->rejected.store(true);
data->reject(ptr);
ContinuationPolicy::reject(std::move(ptr), data);
});
}
......@@ -752,7 +875,7 @@ namespace Async {
}
template<typename... Args>
void whenAll(Args&& ...args) {
void whenArgs(Args&& ...args) {
typedef std::tuple<
typename detail::RemovePromise<
typename std::remove_reference<Args>::type
......@@ -761,37 +884,34 @@ namespace Async {
/* We need to keep the results alive until the last promise
* finishes its execution
*/
struct Data {
/* See the trick here ? Basically, we only have access to the real type of the results
* in this function. The policy classes do not have access to the full type (std::tuple),
* but, instead, take a generic template data type as a parameter. They only need to know
* that results is a tuple, they do not need to know the real type of the results.
*
* This is some sort of compile-time template type-erasing, hue
*/
struct Data : public ContinuationPolicy::Data {
Data(size_t total, Resolver resolver, Rejection rejection)
: total(total)
, resolved(0)
, rejected(false)
, resolve(std::move(resolver))
, reject(std::move(rejection))
: ContinuationPolicy::Data(total, std::move(resolver), std::move(rejection))
{ }
const size_t total;
std::atomic<size_t> resolved;
std::atomic<bool> rejected;
Resolver resolve;
Rejection reject;
Results results;
};
auto data = std::make_shared<Data>(sizeof...(Args), std::move(resolve), std::move(reject));
whenAll<0>(data, std::forward<Args>(args)...);
whenArgs<0>(data, std::forward<Args>(args)...);
}
template<size_t Index, typename Data, typename Head, typename... Rest>
void whenAll(const Data& data, Head&& head, Rest&& ...rest) {
void whenArgs(const Data& data, Head&& head, Rest&& ...rest) {
when<Index>(data, std::forward<Head>(head));
whenAll<Index + 1>(data, std::forward<Rest>(rest)...);
whenArgs<Index + 1>(data, std::forward<Rest>(rest)...);
}
template<size_t Index, typename Data, typename Head>
void whenAll(const Data& data, Head&& head) {
void whenArgs(const Data& data, Head&& head) {
when<Index>(data, std::forward<Head>(head));
}
......@@ -799,7 +919,8 @@ namespace Async {
Rejection reject;
};
template<typename T,
template<
typename T,
typename Results
>
struct WhenAllRange {
......@@ -949,7 +1070,7 @@ namespace Async {
reject = &rejection;
});
Impl::WhenAll impl(*resolve, *reject);
Impl::When<Impl::All> impl(*resolve, *reject);
// So we capture everything we need inside the lambda and then call the
// implementation and expand the parameters pack here
impl(std::forward<Args>(args)...);
......@@ -957,6 +1078,22 @@ namespace Async {
return promise;
}
template<typename... Args>
Promise<Any> whenAny(Args&& ...args) {
// Same trick as above;
Resolver* resolve;
Rejection* reject;
Promise<Any> promise([&](Resolver& resolver, Rejection& rejection) {
resolve = &resolver;
reject = &rejection;
});
Impl::When<Impl::Any> impl(*resolve, *reject);
impl(std::forward<Args>(args)...);
return promise;
}
template<
typename Iterator,
typename ValueType
......
#include "gtest/gtest.h"
#include "async.h"
#include <thread>
#include <algorithm>
Async::Promise<int> doAsync(int N)
{
......@@ -17,6 +18,23 @@ Async::Promise<int> doAsync(int N)
return promise;
}
template<typename T, typename Func>
Async::Promise<T> doAsyncTimed(std::chrono::seconds time, T val, Func func)
{
Async::Promise<T> promise(
[=](Async::Resolver& resolve, Async::Rejection& reject) {
std::thread thr([=]() mutable {
std::this_thread::sleep_for(time);
resolve(func(val));
});
thr.detach();
});
return promise;
}
TEST(async_test, basic_test) {
Async::Promise<int> p1(
[](Async::Resolver& resolv, Async::Rejection& reject) {
......@@ -242,6 +260,27 @@ TEST(async_test, when_all) {
ASSERT_TRUE(resolved);
}
TEST(async_test, when_any) {
auto p1 = doAsyncTimed(std::chrono::seconds(2), 10.0,
[](double val) { return -val; });
auto p2 = doAsyncTimed(std::chrono::seconds(1), std::string("Hello"),
[](std::string val) { std::transform(std::begin(val), std::end(val), std::begin(val), ::toupper); return val; });
bool resolved = false;
Async::whenAny(p1, p2).then([&](const Async::Any& any) {
ASSERT_TRUE(any.is<std::string>());
auto val = any.cast<std::string>();
ASSERT_EQ(val, "HELLO");
ASSERT_THROW(any.cast<double>(), Async::BadAnyCast);
resolved = true;
}, Async::NoExcept);
std::this_thread::sleep_for(std::chrono::seconds(3));
ASSERT_TRUE(resolved);
}
TEST(async_test, rethrow_test) {
auto p1 = Async::Promise<void>([](Async::Resolver& resolve, Async::Rejection& reject) {
reject(std::runtime_error("Because"));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment