Commit 955ae9ac authored by octal's avatar octal

Promise::then now returns a new Promise so that chaining now becomes easier

parent 5455a471
...@@ -67,10 +67,10 @@ namespace Async { ...@@ -67,10 +67,10 @@ namespace Async {
class Core; class Core;
class Request { class Request {
public: public:
virtual void resolve(const std::shared_ptr<Private::Core>& core) = 0; virtual void resolve(const std::shared_ptr<Private::Core>& core) = 0;
virtual void reject(const std::shared_ptr<Private::Core>& core) = 0; virtual void reject(const std::shared_ptr<Private::Core>& core) = 0;
virtual void disconnect() = 0;
}; };
struct Core { struct Core {
...@@ -83,10 +83,37 @@ namespace Async { ...@@ -83,10 +83,37 @@ namespace Async {
std::vector<std::shared_ptr<Request>> requests; std::vector<std::shared_ptr<Request>> requests;
virtual void* memory() = 0; virtual void* memory() = 0;
template<typename T, typename... Args>
void construct(Args&&... args) {
void *mem = memory();
new (mem) T(std::forward<Args>(args)...);
state = State::Fulfilled;
}
}; };
} }
namespace detail {
template<typename T>
struct RemovePromise {
typedef T Type;
};
template<typename T>
struct RemovePromise<Promise<T>> {
typedef T Type;
};
template<typename Func>
struct result_of : public result_of<decltype(&Func::operator())> { };
template<typename R, typename Class, typename... Args>
struct result_of<R (Class::*) (Args...) const> {
typedef R Type;
};
}
class Resolver { class Resolver {
public: public:
Resolver(const std::shared_ptr<Private::Core> &core) Resolver(const std::shared_ptr<Private::Core> &core)
...@@ -100,9 +127,7 @@ namespace Async { ...@@ -100,9 +127,7 @@ namespace Async {
if (core_->state != State::Pending) if (core_->state != State::Pending)
throw Error("Attempt to resolve a fulfilled promise"); throw Error("Attempt to resolve a fulfilled promise");
void *mem = core_->memory(); core_->construct<Type>(std::forward<Arg>(arg));
new (mem) Type(std::forward<Arg>(arg));
core_->state = State::Fulfilled;
for (const auto& req: core_->requests) { for (const auto& req: core_->requests) {
req->resolve(core_); req->resolve(core_);
} }
...@@ -158,15 +183,9 @@ namespace Async { ...@@ -158,15 +183,9 @@ namespace Async {
class Promise : public PromiseBase class Promise : public PromiseBase
{ {
public: public:
typedef std::function<void (Resolver&, Rejection&)> ResolveFunc; template<typename U> friend class Promise;
struct Request : public Private::Request {
virtual Promise<T> chain() = 0;
void disconnect() {
}
}; typedef std::function<void (Resolver&, Rejection&)> ResolveFunc;
Promise(ResolveFunc func) Promise(ResolveFunc func)
: core_(std::make_shared<CoreT>()) : core_(std::make_shared<CoreT>())
...@@ -191,17 +210,32 @@ namespace Async { ...@@ -191,17 +210,32 @@ namespace Async {
bool isRejected() const { return core_->state == State::Rejected; } bool isRejected() const { return core_->state == State::Rejected; }
template<typename ResolveFunc, typename RejectFunc> template<typename ResolveFunc, typename RejectFunc>
std::shared_ptr<Request> then(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc, Continuation type = Continuation::Direct) { auto
then(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc, Continuation type = Continuation::Direct)
-> Promise<
typename detail::RemovePromise<
typename detail::result_of<ResolveFunc>::Type
>::Type
>
{
static_assert(detail::IsCallable<ResolveFunc, T>::value, "Function is not compatible with underlying promise type"); static_assert(detail::IsCallable<ResolveFunc, T>::value, "Function is not compatible with underlying promise type");
typedef typename detail::RemovePromise<
typename detail::result_of<ResolveFunc>::Type
>::Type RetType;
Promise<RetType> promise;
// Due to how template argument deduction works on universal references, we need to remove any reference from // Due to how template argument deduction works on universal references, we need to remove any reference from
// the deduced function type, fun fun fun // the deduced function type, fun fun fun
typedef typename std::remove_reference<ResolveFunc>::type ResolveFuncType; typedef typename std::remove_reference<ResolveFunc>::type ResolveFuncType;
typedef typename std::remove_reference<RejectFunc>::type RejectFuncType; typedef typename std::remove_reference<RejectFunc>::type RejectFuncType;
std::shared_ptr<Request> req; std::shared_ptr<Private::Request> req;
req.reset(ContinuationFactory<ResolveFuncType>::create( req.reset(ContinuationFactory<ResolveFuncType>::create(
std::forward<ResolveFunc>(resolveFunc), std::forward<RejectFunc>(rejectFunc))); promise.core_,
std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc)));
if (isFulfilled()) { if (isFulfilled()) {
req->resolve(core_); req->resolve(core_);
...@@ -210,8 +244,9 @@ namespace Async { ...@@ -210,8 +244,9 @@ namespace Async {
req->reject(core_); req->reject(core_);
} }
core_->requests.push_back(req); core_->requests.push_back(req);
return req;
return promise;
} }
private: private:
...@@ -243,15 +278,14 @@ namespace Async { ...@@ -243,15 +278,14 @@ namespace Async {
typedef Core<T> CoreT; typedef Core<T> CoreT;
Promise(const std::shared_ptr<CoreT>& core) Promise()
: core_(core) : core_(std::make_shared<Core<T>>())
, resolver_(core) , resolver_(core_)
, rejection_(core) , rejection_(core_)
{ {
} }
struct Continuable : public Private::Request {
struct Continuable : public Request {
Continuable() Continuable()
: resolveCount_(0) : resolveCount_(0)
, rejectCount_(0) , rejectCount_(0)
...@@ -286,7 +320,8 @@ namespace Async { ...@@ -286,7 +320,8 @@ namespace Async {
template<typename ResolveFunc, typename RejectFunc> template<typename ResolveFunc, typename RejectFunc>
struct ThenContinuation : public Continuable { struct ThenContinuation : public Continuable {
ThenContinuation(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) ThenContinuation(
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
: resolveFunc_(std::forward<ResolveFunc>(resolveFunc)) : resolveFunc_(std::forward<ResolveFunc>(resolveFunc))
, rejectFunc_(std::forward<RejectFunc>(rejectFunc)) , rejectFunc_(std::forward<RejectFunc>(rejectFunc))
{ {
...@@ -300,53 +335,43 @@ namespace Async { ...@@ -300,53 +335,43 @@ namespace Async {
rejectFunc_(core->exc); rejectFunc_(core->exc);
} }
Promise<T> chain() { std::shared_ptr<Private::Core> chain_;
throw Error("The request is not chainable");
}
ResolveFunc resolveFunc_; ResolveFunc resolveFunc_;
RejectFunc rejectFunc_; RejectFunc rejectFunc_;
}; };
template<typename ResolveFunc, typename RejectFunc, typename Return> template<typename ResolveFunc, typename RejectFunc, typename Return>
struct ThenReturnContinuation : public Continuable { struct ThenReturnContinuation : public Continuable {
ThenReturnContinuation(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) ThenReturnContinuation(
: resolveFunc_(std::forward<ResolveFunc>(resolveFunc)) const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
: chain_(chain)
, resolveFunc_(std::forward<ResolveFunc>(resolveFunc))
, rejectFunc_(std::forward<RejectFunc>(rejectFunc)) , rejectFunc_(std::forward<RejectFunc>(rejectFunc))
{ {
} }
void doResolve(const std::shared_ptr<CoreT>& core) const { void doResolve(const std::shared_ptr<CoreT>& core) const {
auto ret = resolveFunc_(core->value()); auto ret = resolveFunc_(core->value());
result = Some(std::move(ret)); chain_->construct<decltype(ret)>(std::move(ret));
} }
void doReject(const std::shared_ptr<CoreT>& core) const { void doReject(const std::shared_ptr<CoreT>& core) const {
rejectFunc_(core->exc); rejectFunc_(core->exc);
} }
Promise<Return> chain() { std::shared_ptr<Private::Core> chain_;
typedef typename CoreT::template Rebind<Return>::Type CoreType;
auto core = std::make_shared<CoreType>();
optionally_do(result, [&core](Return&& result) {
void *mem = core->memory();
new (mem) Return(std::move(result));
core->state = State::Fulfilled;
});
return Promise<Return>(core);
}
ResolveFunc resolveFunc_; ResolveFunc resolveFunc_;
RejectFunc rejectFunc_; RejectFunc rejectFunc_;
mutable Optional<Return> result;
}; };
template<typename ResolveFunc, typename RejectFunc> template<typename ResolveFunc, typename RejectFunc>
struct ThenChainContinuation : public Continuable { struct ThenChainContinuation : public Continuable {
ThenChainContinuation(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) ThenChainContinuation(
: resolveFunc_(std::forward<ResolveFunc>(resolveFunc)) const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc)
: chain_(chain)
, resolveFunc_(std::forward<ResolveFunc>(resolveFunc))
, rejectFunc_(std::forward<RejectFunc>(rejectFunc)) , rejectFunc_(std::forward<RejectFunc>(rejectFunc))
{ {
} }
...@@ -359,11 +384,7 @@ namespace Async { ...@@ -359,11 +384,7 @@ namespace Async {
rejectFunc_(core->exc); rejectFunc_(core->exc);
} }
Promise<T> chain() { std::shared_ptr<Private::Core> chain_;
auto core = std::make_shared<CoreT>();
return Promise<T>(core);
}
ResolveFunc resolveFunc_; ResolveFunc resolveFunc_;
RejectFunc rejectFunc_; RejectFunc rejectFunc_;
...@@ -375,8 +396,11 @@ namespace Async { ...@@ -375,8 +396,11 @@ namespace Async {
template<typename R, typename Class, typename... Args> template<typename R, typename Class, typename... Args>
struct ContinuationFactory<R (Class::*)(Args...) const> { struct ContinuationFactory<R (Class::*)(Args...) const> {
template<typename ResolveFunc, typename RejectFunc> template<typename ResolveFunc, typename RejectFunc>
static Continuable *create(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) { static Continuable *create(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) {
return new ThenReturnContinuation<ResolveFunc, RejectFunc, R>( return new ThenReturnContinuation<ResolveFunc, RejectFunc, R>(
chain,
std::forward<ResolveFunc>(resolveFunc), std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc)); std::forward<RejectFunc>(rejectFunc));
} }
...@@ -385,7 +409,9 @@ namespace Async { ...@@ -385,7 +409,9 @@ namespace Async {
template<typename Class, typename... Args> template<typename Class, typename... Args>
struct ContinuationFactory<void (Class::*)(Args ...) const> { struct ContinuationFactory<void (Class::*)(Args ...) const> {
template<typename ResolveFunc, typename RejectFunc> template<typename ResolveFunc, typename RejectFunc>
static Continuable *create(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) { static Continuable *create(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) {
return new ThenContinuation<ResolveFunc, RejectFunc>( return new ThenContinuation<ResolveFunc, RejectFunc>(
std::forward<ResolveFunc>(resolveFunc), std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc)); std::forward<RejectFunc>(rejectFunc));
...@@ -395,8 +421,11 @@ namespace Async { ...@@ -395,8 +421,11 @@ namespace Async {
template<typename U, typename Class, typename... Args> template<typename U, typename Class, typename... Args>
struct ContinuationFactory<Promise<U> (Class::*)(Args...) const> { struct ContinuationFactory<Promise<U> (Class::*)(Args...) const> {
template<typename ResolveFunc, typename RejectFunc> template<typename ResolveFunc, typename RejectFunc>
static Continuable* create(ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) { static Continuable* create(
const std::shared_ptr<Private::Core>& chain,
ResolveFunc&& resolveFunc, RejectFunc&& rejectFunc) {
return new ThenChainContinuation<ResolveFunc, RejectFunc>( return new ThenChainContinuation<ResolveFunc, RejectFunc>(
chain,
std::forward<ResolveFunc>(resolveFunc), std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc)); std::forward<RejectFunc>(rejectFunc));
} }
...@@ -407,4 +436,22 @@ namespace Async { ...@@ -407,4 +436,22 @@ namespace Async {
Rejection rejection_; Rejection rejection_;
}; };
template<>
class Promise<void>
{
public:
Promise() :
core_(std::make_shared<Core>())
{ }
public:
struct Core : public Private::Core {
Core() :
Private::Core(State::Pending)
{ }
void *memory() { return nullptr; }
};
std::shared_ptr<Core> core_;
};
} // namespace Async } // namespace Async
...@@ -30,6 +30,7 @@ TEST(async_test, basic_test) { ...@@ -30,6 +30,7 @@ TEST(async_test, basic_test) {
p1.then([&](int v) { val = v; }, Async::NoExcept); p1.then([&](int v) { val = v; }, Async::NoExcept);
ASSERT_EQ(val, 10); ASSERT_EQ(val, 10);
{ {
Async::Promise<int> p2 = doAsync(10); Async::Promise<int> p2 = doAsync(10);
p2.then([](int result) { ASSERT_EQ(result, 20); }, p2.then([](int result) { ASSERT_EQ(result, 20); },
...@@ -58,7 +59,6 @@ TEST(async_test, chain_test) { ...@@ -58,7 +59,6 @@ TEST(async_test, chain_test) {
p1 p1
.then([](int result) { return result * 2; }, Async::NoExcept) .then([](int result) { return result * 2; }, Async::NoExcept)
->chain()
.then([](int result) { std::cout << "Result = " << result << std::endl; }, .then([](int result) { std::cout << "Result = " << result << std::endl; },
Async::NoExcept); Async::NoExcept);
...@@ -67,13 +67,10 @@ TEST(async_test, chain_test) { ...@@ -67,13 +67,10 @@ TEST(async_test, chain_test) {
resolve(10); resolve(10);
}); });
#if 0
p2 p2
.then([](int result) { return result * 2.0; }, Async::IgnoreException) .then([](int result) { return result * 2.2901; }, Async::IgnoreException)
->chain()
.then([](double result) { std::cout << "Result = " << result << std::endl; }, .then([](double result) { std::cout << "Result = " << result << std::endl; },
Async::IgnoreException); Async::IgnoreException);
#endif
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment