Commit d39ab557 authored by octal's avatar octal

More work on Promises

parent 955ae9ac
......@@ -84,8 +84,13 @@ namespace Async {
virtual void* memory() = 0;
virtual bool isVoid() const = 0;
template<typename T, typename... Args>
void construct(Args&&... args) {
if (isVoid())
throw Error("Can not construct a void core");
void *mem = memory();
new (mem) T(std::forward<Args>(args)...);
state = State::Fulfilled;
......@@ -94,26 +99,6 @@ namespace Async {
}
namespace detail {
template<typename T>
struct RemovePromise {
typedef T Type;
};
template<typename T>
struct RemovePromise<Promise<T>> {
typedef T Type;
};
template<typename Func>
struct result_of : public result_of<decltype(&Func::operator())> { };
template<typename R, typename Class, typename... Args>
struct result_of<R (Class::*) (Args...) const> {
typedef R Type;
};
}
class Resolver {
public:
Resolver(const std::shared_ptr<Private::Core> &core)
......@@ -127,6 +112,13 @@ namespace Async {
if (core_->state != State::Pending)
throw Error("Attempt to resolve a fulfilled promise");
/* In a ideal world, this should be checked at compile-time rather
* than runtime. However, since types are erased, this looks like
* a difficult task
*/
if (core_->isVoid())
throw Error("Attempt to resolve a void promise with arguments");
core_->construct<Type>(std::forward<Arg>(arg));
for (const auto& req: core_->requests) {
req->resolve(core_);
......@@ -135,6 +127,21 @@ namespace Async {
return true;
}
bool operator()() {
if (core_->state != State::Pending)
throw Error("Attempt to resolve a fulfilled promise");
if (!core_->isVoid())
throw Error("Attempt ro resolve a non-void promise with no argument");
core_->state = State::Fulfilled;
for (const auto& req: core_->requests) {
req->resolve(core_);
}
return true;
}
private:
std::shared_ptr<Private::Core> core_;
};
......@@ -170,13 +177,34 @@ namespace Async {
struct IsCallable {
template<typename U>
static auto test(U *) -> decltype(std::declval<Func>()(std::declval<T>()), std::true_type());
static auto test(U *) -> decltype(std::declval<Func>()(std::declval<U>()), std::true_type());
template<typename U>
static auto test(...) -> std::false_type;
static constexpr bool value = std::is_same<decltype(test<T>(0)), std::true_type>::value;
};
template<typename Func>
struct FunctionTrait : public FunctionTrait<decltype(&Func::operator())> { };
template<typename R, typename Class, typename... Args>
struct FunctionTrait<R (Class::*)(Args...) const> {
typedef R ReturnType;
static constexpr size_t ArgsCount = sizeof...(Args);
};
template<typename T>
struct RemovePromise {
typedef T Type;
};
template<typename T>
struct RemovePromise<Promise<T>> {
typedef T Type;
};
}
template<typename T>
......@@ -214,14 +242,14 @@ namespace Async {
then(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc, Continuation type = Continuation::Direct)
-> Promise<
typename detail::RemovePromise<
typename detail::result_of<ResolveFunc>::Type
typename detail::FunctionTrait<ResolveFunc>::ReturnType
>::Type
>
{
static_assert(detail::IsCallable<ResolveFunc, T>::value, "Function is not compatible with underlying promise type");
typedef typename detail::RemovePromise<
typename detail::result_of<ResolveFunc>::Type
typename detail::FunctionTrait<ResolveFunc>::ReturnType
>::Type RetType;
Promise<RetType> promise;
......@@ -271,6 +299,8 @@ namespace Async {
return *reinterpret_cast<const U*>(&storage);
}
bool isVoid() const { return false; }
void *memory() {
return &storage;
}
......@@ -335,7 +365,6 @@ namespace Async {
rejectFunc_(core->exc);
}
std::shared_ptr<Private::Core> chain_;
ResolveFunc resolveFunc_;
RejectFunc rejectFunc_;
};
......@@ -378,6 +407,11 @@ namespace Async {
void doResolve(const std::shared_ptr<CoreT>& core) const {
auto promise = resolveFunc_(core->value());
auto chainer = makeChainer(promise);
promise.then(std::move(chainer), [=](std::exception_ptr exc) {
chain_->exc = std::move(exc);
chain_->state = State::Rejected;
});
}
void doReject(const std::shared_ptr<CoreT>& core) const {
......@@ -388,6 +422,27 @@ namespace Async {
ResolveFunc resolveFunc_;
RejectFunc rejectFunc_;
template<typename PromiseType>
struct Chainer {
Chainer(const std::shared_ptr<Private::Core>& core)
: chainCore(core)
{ }
void operator()(const PromiseType& val) const {
chainCore->construct<PromiseType>(val);
}
std::shared_ptr<Private::Core> chainCore;
};
template<
typename Promise,
typename Type = typename detail::RemovePromise<Promise>::Type>
Chainer<Type>
makeChainer(const Promise&) const {
return Chainer<Type>(chain_);
}
};
template<typename ResolveFunc>
......@@ -437,21 +492,242 @@ namespace Async {
};
template<>
class Promise<void>
class Promise<void> : public PromiseBase
{
public:
Promise() :
core_(std::make_shared<Core>())
template<typename U> friend class Promise;
typedef std::function<void (Resolver&, Rejection&)> ResolveFunc;
Promise(ResolveFunc func)
: core_(std::make_shared<VoidCore>())
, resolver_(core_)
, rejection_(core_)
{
func(resolver_, rejection_);
}
bool isPending() const {
return core_->state == State::Pending;
}
bool isFulfilled() const {
return core_->state == State::Fulfilled;
}
bool isRejected() const {
return core_->state == State::Rejected;
}
template<typename ResolveFunc, typename RejectFunc>
auto then(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
-> Promise<
typename detail::RemovePromise<
typename detail::FunctionTrait<ResolveFunc>::ReturnType
>::Type
>
{
static_assert(detail::FunctionTrait<ResolveFunc>::ArgsCount == 0,
"Continuation function of a void promise should not take any argument");
typedef typename detail::RemovePromise<
typename detail::FunctionTrait<ResolveFunc>::ReturnType
>::Type RetType;
Promise<RetType> promise;
typedef typename std::remove_reference<ResolveFunc>::type ResolveFuncType;
typedef typename std::remove_reference<RejectFunc>::type RejectFuncType;
std::shared_ptr<Private::Request> req;
req.reset(ContinuationFactory<ResolveFuncType>::create(
promise.core_,
std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc)));
if (isFulfilled()) {
req->resolve(core_);
}
else if (isRejected()) {
req->reject(core_);
}
core_->requests.push_back(req);
return promise;
}
private:
Promise()
: core_(std::make_shared<VoidCore>())
, resolver_(core_)
, rejection_(core_)
{ }
public:
struct Core : public Private::Core {
Core() :
struct VoidCore : public Private::Core {
VoidCore() :
Private::Core(State::Pending)
{ }
bool isVoid() const { return true; }
void *memory() { return nullptr; }
};
std::shared_ptr<Core> core_;
/* TODO: Crazy code duplication between Promise<T> and its void specialization.
*
* Find a way to get rid of the duplication but still presevering encapsulation
*/
struct Continuable : public Private::Request {
Continuable()
: resolveCount_(0)
, rejectCount_(0)
{ }
void resolve(const std::shared_ptr<Private::Core>& core) {
if (resolveCount_ >= 1)
throw Error("Resolve must not be called more than once");
doResolve(coreCast(core));
++resolveCount_;
}
void reject(const std::shared_ptr<Private::Core>& core) {
if (rejectCount_ >= 1)
throw Error("Reject must not be called more than once");
doReject(coreCast(core));
++rejectCount_;
}
std::shared_ptr<VoidCore> coreCast(const std::shared_ptr<Private::Core>& core) const {
return std::static_pointer_cast<VoidCore>(core);
}
virtual void doResolve(const std::shared_ptr<VoidCore>& core) const = 0;
virtual void doReject(const std::shared_ptr<VoidCore>& core) const = 0;
size_t resolveCount_;
size_t rejectCount_;
};
template<typename ResolveFunc, typename RejectFunc>
struct ThenContinuation : public Continuable {
ThenContinuation(
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
: resolveFunc_(std::forward<ResolveFunc>(resolveFunc))
, rejectFunc_(std::forward<RejectFunc>(rejectFunc))
{
}
void doResolve(const std::shared_ptr<VoidCore>& core) const {
resolveFunc_();
}
void doReject(const std::shared_ptr<VoidCore>& core) const {
rejectFunc_(core->exc);
}
ResolveFunc resolveFunc_;
RejectFunc rejectFunc_;
};
template<typename ResolveFunc, typename RejectFunc, typename Return>
struct ThenReturnContinuation : public Continuable {
ThenReturnContinuation(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
: chain_(chain)
, resolveFunc_(std::forward<ResolveFunc>(resolveFunc))
, rejectFunc_(std::forward<RejectFunc>(rejectFunc))
{
}
void doResolve(const std::shared_ptr<VoidCore>& core) const {
auto ret = resolveFunc_();
chain_->construct<decltype(ret)>(std::move(ret));
}
void doReject(const std::shared_ptr<VoidCore>& core) const {
rejectFunc_(core->exc);
}
std::shared_ptr<Private::Core> chain_;
ResolveFunc resolveFunc_;
RejectFunc rejectFunc_;
};
template<typename ResolveFunc, typename RejectFunc>
struct ThenChainContinuation : public Continuable {
ThenChainContinuation(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
: chain_(chain)
, resolveFunc_(std::forward<ResolveFunc>(resolveFunc))
, rejectFunc_(std::forward<RejectFunc>(rejectFunc))
{
}
void doResolve(const std::shared_ptr<VoidCore>& core) const {
auto promise = resolveFunc_();
}
void doReject(const std::shared_ptr<VoidCore>& core) const {
rejectFunc_(core->exc);
}
std::shared_ptr<Private::Core> chain_;
ResolveFunc resolveFunc_;
RejectFunc rejectFunc_;
};
template<typename ResolveFunc>
struct ContinuationFactory : public ContinuationFactory<decltype(&ResolveFunc::operator())> { };
template<typename R, typename Class, typename... Args>
struct ContinuationFactory<R (Class::*)(Args...) const> {
template<typename ResolveFunc, typename RejectFunc>
static Continuable *create(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) {
return new ThenReturnContinuation<ResolveFunc, RejectFunc, R>(
chain,
std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc));
}
};
template<typename Class, typename... Args>
struct ContinuationFactory<void (Class::*)(Args ...) const> {
template<typename ResolveFunc, typename RejectFunc>
static Continuable *create(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) {
return new ThenContinuation<ResolveFunc, RejectFunc>(
std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc));
}
};
template<typename U, typename Class, typename... Args>
struct ContinuationFactory<Promise<U> (Class::*)(Args...) const> {
template<typename ResolveFunc, typename RejectFunc>
static Continuable* create(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) {
return new ThenChainContinuation<ResolveFunc, RejectFunc>(
chain,
std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc));
}
};
std::shared_ptr<VoidCore> core_;
Resolver resolver_;
Rejection rejection_;
};
} // namespace Async
......@@ -51,6 +51,32 @@ TEST(async_test, basic_test) {
}
TEST(async_test, void_promise) {
Async::Promise<void> p1(
[](Async::Resolver& resolve, Async::Rejection& reject) {
resolve();
});
ASSERT_TRUE(p1.isFulfilled());
bool thenCalled { false };
p1.then([&]() {
thenCalled = true;
}, Async::NoExcept);
ASSERT_TRUE(thenCalled);
Async::Promise<int> p2(
[](Async::Resolver& resolve, Async::Rejection& reject) {
ASSERT_THROW(resolve(), Async::Error);
});
Async::Promise<void> p3(
[](Async::Resolver& resolve, Async::Rejection& reject) {
ASSERT_THROW(resolve(10), Async::Error);
});
}
TEST(async_test, chain_test) {
Async::Promise<int> p1(
[](Async::Resolver& resolve, Async::Rejection& reject) {
......@@ -72,5 +98,56 @@ TEST(async_test, chain_test) {
.then([](double result) { std::cout << "Result = " << result << std::endl; },
Async::IgnoreException);
enum class Test { Foo, Bar };
Async::Promise<Test> p3(
[](Async::Resolver& resolve, Async::Rejection& reject) {
resolve(Test::Foo);
});
p3
.then([](Test result) {
return Async::Promise<std::string>(
[=](Async::Resolver& resolve, Async::Rejection&) {
switch (result) {
case Test::Foo:
resolve(std::string("Foo"));
break;
case Test::Bar:
resolve(std::string("Bar"));
}
}); }, Async::NoExcept)
.then([](std::string str) {
ASSERT_EQ(str, "Foo");
}, Async::NoExcept);
Async::Promise<Test> p4(
[](Async::Resolver& resolve, Async::Rejection& reject) {
resolve(Test::Bar);
});
p4
.then(
[](Test result) {
return Async::Promise<std::string>(
[=](Async::Resolver& resolve, Async::Rejection& reject) {
switch (result) {
case Test::Foo:
resolve(std::string("Foo"));
break;
case Test::Bar:
reject(std::runtime_error("Invalid"));
}
});
},
Async::NoExcept)
.then(
[](std::string str) {
ASSERT_TRUE(false);
},
[](std::exception_ptr exc) {
ASSERT_THROW(std::rethrow_exception(exc), std::runtime_error);
});
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment