Commit d1cadedd authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Improve folly::coro::Task<T> support for move-only types

Summary:
The folly::coro::Task<T> and folly::coro::Future<T> types would previously return an lvalue-reference to the result when you co_await them.

This means that for move-only types that code like `auto x = co_await someTask;` would fail to compile. Instead you would need to write `auto x = std::move(co_await someTask);`.

Awaiting a Task<T> or Future<T> now returns type T instead of T&.

As part of this change we now only allow co_awaiting an rvalue folly::coro::Task<T>, folly::coro::Future<T> as well as folly::coro::Future<T> to indicate that the operation is destructive and a one-time operation.

Reviewed By: andriigrynenko

Differential Revision: D9731525

fbshipit-source-id: bee9e633b57b203a0d048cf3eb0e2fc48b899481
parent 9933f34f
...@@ -37,40 +37,56 @@ class Future { ...@@ -37,40 +37,56 @@ class Future {
other.promise_ = nullptr; other.promise_ = nullptr;
} }
bool await_ready() { class Awaiter {
return promise_->state_.load(std::memory_order_acquire) == public:
Promise<T>::State::HAS_RESULT; explicit Awaiter(Promise<T>* promise) noexcept : promise_(promise) {}
}
bool await_suspend(std::experimental::coroutine_handle<> awaiter) { bool await_ready() noexcept {
auto state = promise_->state_.load(std::memory_order_acquire); return promise_->isReady();
}
bool await_suspend(std::experimental::coroutine_handle<> awaiter) noexcept {
auto state = promise_->state_.load(std::memory_order_acquire);
if (state == Promise<T>::State::HAS_RESULT) {
return false;
}
DCHECK(state == Promise<T>::State::EMPTY);
if (state == Promise<T>::State::HAS_RESULT) { promise_->awaiter_ = std::move(awaiter);
if (promise_->state_.compare_exchange_strong(
state,
Promise<T>::State::HAS_AWAITER,
std::memory_order_release,
std::memory_order_acquire)) {
return true;
}
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT);
return false; return false;
} }
DCHECK(state == Promise<T>::State::EMPTY);
promise_->awaiter_ = std::move(awaiter);
if (promise_->state_.compare_exchange_strong( T await_resume() {
state, DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT);
Promise<T>::State::HAS_AWAITER, return std::move(promise_->result_).value();
std::memory_order_release,
std::memory_order_acquire)) {
return true;
} }
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT); private:
return false; Promise<T>* promise_;
} };
typename std::add_lvalue_reference<T>::type await_resume() { Awaiter operator co_await() && noexcept {
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT); return Awaiter{promise_};
return *promise_->result_;
} }
auto toFuture() &&; auto toFuture() &&;
bool isReady() const noexcept {
DCHECK(promise_);
return promise_->isReady();
}
~Future() { ~Future() {
if (!promise_) { if (!promise_) {
return; return;
...@@ -109,7 +125,7 @@ class Future { ...@@ -109,7 +125,7 @@ class Future {
namespace detail { namespace detail {
inline SemiFuture<Unit> toSemiFuture(Future<void> future) { inline SemiFuture<Unit> toSemiFuture(Future<void> future) {
co_await future; co_await std::move(future);
co_return folly::unit; co_return folly::unit;
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
*/ */
#pragma once #pragma once
#include <type_traits>
#include <glog/logging.h> #include <glog/logging.h>
#include <folly/ExceptionWrapper.h> #include <folly/ExceptionWrapper.h>
...@@ -47,7 +49,10 @@ class PromiseBase { ...@@ -47,7 +49,10 @@ class PromiseBase {
public: public:
template <typename U> template <typename U>
void return_value(U&& value) { void return_value(U&& value) {
result_ = Try<T>(std::forward<U>(value)); static_assert(
std::is_convertible<U&&, T>::value,
"Returned value is not convertible to task result type");
result_.emplace(static_cast<U&&>(value));
} }
protected: protected:
...@@ -80,23 +85,23 @@ class Promise : public PromiseBase<T> { ...@@ -80,23 +85,23 @@ class Promise : public PromiseBase<T> {
return {}; return {};
} }
// Don't allow awaiting lvalues of these types.
template <typename U> template <typename U>
auto await_transform(Task<U>&& task) { void await_transform(folly::SemiFuture<U>& future) = delete;
return std::move(task).viaInline(executor_);
}
template <typename U> template <typename U>
decltype(auto) await_transform(folly::SemiFuture<U>& future) { void await_transform(folly::Future<U>& future) = delete;
return future.via(executor_); template <typename U>
} void await_transform(Future<U>& future) = delete;
template <typename U>
void await_transform(Task<U>& task) = delete;
template <typename U> template <typename U>
decltype(auto) await_transform(folly::SemiFuture<U>&& future) { auto await_transform(Task<U>&& task) {
return future.via(executor_); return std::move(task).viaInline(executor_);
} }
template <typename U> template <typename U>
decltype(auto) await_transform(folly::Future<U>& future) { decltype(auto) await_transform(folly::SemiFuture<U>&& future) {
return future.via(executor_); return future.via(executor_);
} }
...@@ -105,22 +110,13 @@ class Promise : public PromiseBase<T> { ...@@ -105,22 +110,13 @@ class Promise : public PromiseBase<T> {
return future.via(executor_); return future.via(executor_);
} }
template <typename U>
auto await_transform(Future<U>& future) {
if (future.promise_->executor_ == executor_) {
return createAwaitWrapper(future);
}
return createAwaitWrapper(future, executor_);
}
template <typename U> template <typename U>
auto await_transform(Future<U>&& future) { auto await_transform(Future<U>&& future) {
if (future.promise_->executor_ == executor_) { if (future.promise_->executor_ == executor_) {
return createAwaitWrapper(future); return createAwaitWrapper(std::move(future));
} }
return createAwaitWrapper(future, executor_); return createAwaitWrapper(std::move(future), executor_);
} }
template <typename U> template <typename U>
...@@ -147,6 +143,11 @@ class Promise : public PromiseBase<T> { ...@@ -147,6 +143,11 @@ class Promise : public PromiseBase<T> {
std::experimental::coroutine_handle<Promise>::from_promise (*this)(); std::experimental::coroutine_handle<Promise>::from_promise (*this)();
} }
bool isReady() const noexcept {
return state_.load(std::memory_order_acquire) ==
Promise<T>::State::HAS_RESULT;
}
private: private:
friend class Future<T>; friend class Future<T>;
friend class Task<T>; friend class Task<T>;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <type_traits> #include <type_traits>
#include <folly/Optional.h> #include <folly/Optional.h>
#include <folly/experimental/coro/Traits.h>
#include <folly/experimental/coro/Wait.h> #include <folly/experimental/coro/Wait.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
...@@ -78,8 +79,7 @@ class TimedWaitAwaitable { ...@@ -78,8 +79,7 @@ class TimedWaitAwaitable {
static_assert( static_assert(
std::is_same<Awaitable, std::decay_t<Awaitable>>::value, std::is_same<Awaitable, std::decay_t<Awaitable>>::value,
"Awaitable should be decayed."); "Awaitable should be decayed.");
using await_resume_return_type = using await_resume_return_type = await_result_t<Awaitable>;
decltype((operator co_await(std::declval<Awaitable>())).await_resume());
TimedWaitAwaitable(Awaitable&& awaitable, std::chrono::milliseconds duration) TimedWaitAwaitable(Awaitable&& awaitable, std::chrono::milliseconds duration)
: awaitable_(std::move(awaitable)), duration_(duration) {} : awaitable_(std::move(awaitable)), duration_(duration) {}
...@@ -116,7 +116,7 @@ class TimedWaitAwaitable { ...@@ -116,7 +116,7 @@ class TimedWaitAwaitable {
return; return;
} }
assume(!storage_.hasValue() && !storage_.hasException()); assume(!storage_.hasValue() && !storage_.hasException());
storage_ = Try<await_resume_return_type>(std::move(value)); tryEmplace(storage_, static_cast<await_resume_return_type&&>(value));
ch_(); ch_();
} }
...@@ -125,7 +125,7 @@ class TimedWaitAwaitable { ...@@ -125,7 +125,7 @@ class TimedWaitAwaitable {
return; return;
} }
assume(!storage_.hasValue() && !storage_.hasException()); assume(!storage_.hasValue() && !storage_.hasException());
storage_ = Try<await_resume_return_type>(std::move(e)); storage_.emplaceException(std::move(e));
ch_(); ch_();
} }
...@@ -146,7 +146,7 @@ class TimedWaitAwaitable { ...@@ -146,7 +146,7 @@ class TimedWaitAwaitable {
Awaitable awaitable, Awaitable awaitable,
std::shared_ptr<SharedState> sharedState) { std::shared_ptr<SharedState> sharedState) {
try { try {
sharedState->setValue(co_await awaitable); sharedState->setValue(co_await std::forward<Awaitable>(awaitable));
} catch (const std::exception& e) { } catch (const std::exception& e) {
sharedState->setException(exception_wrapper(std::current_exception(), e)); sharedState->setException(exception_wrapper(std::current_exception(), e));
} catch (...) { } catch (...) {
......
...@@ -35,12 +35,12 @@ TEST(Coro, Basic) { ...@@ -35,12 +35,12 @@ TEST(Coro, Basic) {
ManualExecutor executor; ManualExecutor executor;
auto future = via(&executor, task42()); auto future = via(&executor, task42());
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
executor.drive(); executor.drive();
EXPECT_TRUE(future.await_ready()); EXPECT_TRUE(future.isReady());
EXPECT_EQ(42, folly::coro::blockingWait(future)); EXPECT_EQ(42, folly::coro::blockingWait(std::move(future)));
} }
TEST(Coro, BasicFuture) { TEST(Coro, BasicFuture) {
...@@ -62,11 +62,11 @@ TEST(Coro, Basic2) { ...@@ -62,11 +62,11 @@ TEST(Coro, Basic2) {
ManualExecutor executor; ManualExecutor executor;
auto future = via(&executor, taskVoid()); auto future = via(&executor, taskVoid());
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
executor.drive(); executor.drive();
EXPECT_TRUE(future.await_ready()); EXPECT_TRUE(future.isReady());
} }
coro::Task<void> taskSleep() { coro::Task<void> taskSleep() {
...@@ -74,15 +74,25 @@ coro::Task<void> taskSleep() { ...@@ -74,15 +74,25 @@ coro::Task<void> taskSleep() {
co_return; co_return;
} }
TEST(Coro, TaskOfMoveOnly) {
auto f = []() -> coro::Task<std::unique_ptr<int>> {
co_return std::make_unique<int>(123);
};
auto p = coro::blockingWait(f().scheduleVia(&InlineExecutor::instance()));
EXPECT_TRUE(p);
EXPECT_EQ(123, *p);
}
TEST(Coro, Sleep) { TEST(Coro, Sleep) {
ScopedEventBaseThread evbThread; ScopedEventBaseThread evbThread;
auto startTime = std::chrono::steady_clock::now(); auto startTime = std::chrono::steady_clock::now();
auto future = via(evbThread.getEventBase(), taskSleep()); auto future = via(evbThread.getEventBase(), taskSleep());
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
coro::blockingWait(future); coro::blockingWait(std::move(future));
// The total time should be roughly 1 second. Some builds, especially // The total time should be roughly 1 second. Some builds, especially
// optimized ones, may result in slightly less than 1 second, so we perform // optimized ones, may result in slightly less than 1 second, so we perform
...@@ -90,8 +100,6 @@ TEST(Coro, Sleep) { ...@@ -90,8 +100,6 @@ TEST(Coro, Sleep) {
auto totalTime = std::chrono::steady_clock::now() - startTime; auto totalTime = std::chrono::steady_clock::now() - startTime;
EXPECT_GE( EXPECT_GE(
chrono::round<std::chrono::seconds>(totalTime), std::chrono::seconds{1}); chrono::round<std::chrono::seconds>(totalTime), std::chrono::seconds{1});
EXPECT_TRUE(future.await_ready());
} }
coro::Task<int> taskException() { coro::Task<int> taskException() {
...@@ -103,12 +111,12 @@ TEST(Coro, Throw) { ...@@ -103,12 +111,12 @@ TEST(Coro, Throw) {
ManualExecutor executor; ManualExecutor executor;
auto future = via(&executor, taskException()); auto future = via(&executor, taskException());
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
executor.drive(); executor.drive();
EXPECT_TRUE(future.await_ready()); EXPECT_TRUE(future.isReady());
EXPECT_THROW(coro::blockingWait(future), std::runtime_error); EXPECT_THROW(coro::blockingWait(std::move(future)), std::runtime_error);
} }
TEST(Coro, FutureThrow) { TEST(Coro, FutureThrow) {
...@@ -137,7 +145,7 @@ TEST(Coro, LargeStack) { ...@@ -137,7 +145,7 @@ TEST(Coro, LargeStack) {
ScopedEventBaseThread evbThread; ScopedEventBaseThread evbThread;
auto future = via(evbThread.getEventBase(), taskRecursion(5000)); auto future = via(evbThread.getEventBase(), taskRecursion(5000));
EXPECT_EQ(5000, coro::blockingWait(future)); EXPECT_EQ(5000, coro::blockingWait(std::move(future)));
} }
coro::Task<void> taskThreadNested(std::thread::id threadId) { coro::Task<void> taskThreadNested(std::thread::id threadId) {
...@@ -163,7 +171,7 @@ TEST(Coro, NestedThreads) { ...@@ -163,7 +171,7 @@ TEST(Coro, NestedThreads) {
ScopedEventBaseThread evbThread; ScopedEventBaseThread evbThread;
auto future = via(evbThread.getEventBase(), taskThread()); auto future = via(evbThread.getEventBase(), taskThread());
EXPECT_EQ(42, coro::blockingWait(future)); EXPECT_EQ(42, coro::blockingWait(std::move(future)));
} }
coro::Task<int> taskYield(Executor* executor) { coro::Task<int> taskYield(Executor* executor) {
...@@ -171,12 +179,12 @@ coro::Task<int> taskYield(Executor* executor) { ...@@ -171,12 +179,12 @@ coro::Task<int> taskYield(Executor* executor) {
EXPECT_EQ(executor, currentExecutor); EXPECT_EQ(executor, currentExecutor);
auto future = via(currentExecutor, task42()); auto future = via(currentExecutor, task42());
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
co_await coro::yield(); co_await coro::yield();
EXPECT_TRUE(future.await_ready()); EXPECT_TRUE(future.isReady());
co_return co_await future; co_return co_await std::move(future);
} }
TEST(Coro, CurrentExecutor) { TEST(Coro, CurrentExecutor) {
...@@ -184,7 +192,7 @@ TEST(Coro, CurrentExecutor) { ...@@ -184,7 +192,7 @@ TEST(Coro, CurrentExecutor) {
auto future = auto future =
via(evbThread.getEventBase(), taskYield(evbThread.getEventBase())); via(evbThread.getEventBase(), taskYield(evbThread.getEventBase()));
EXPECT_EQ(42, coro::blockingWait(future)); EXPECT_EQ(42, coro::blockingWait(std::move(future)));
} }
coro::Task<void> taskTimedWait() { coro::Task<void> taskTimedWait() {
...@@ -286,17 +294,17 @@ TEST(Coro, Baton) { ...@@ -286,17 +294,17 @@ TEST(Coro, Baton) {
fibers::Baton baton; fibers::Baton baton;
auto future = via(&executor, taskBaton(baton)); auto future = via(&executor, taskBaton(baton));
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
executor.run(); executor.run();
EXPECT_FALSE(future.await_ready()); EXPECT_FALSE(future.isReady());
baton.post(); baton.post();
executor.run(); executor.run();
EXPECT_TRUE(future.await_ready()); EXPECT_TRUE(future.isReady());
EXPECT_EQ(42, coro::blockingWait(future)); EXPECT_EQ(42, coro::blockingWait(std::move(future)));
} }
#endif #endif
...@@ -150,9 +150,9 @@ TEST(Mutex, ThreadSafety) { ...@@ -150,9 +150,9 @@ TEST(Mutex, ThreadSafety) {
auto f2 = makeTask().scheduleVia(&threadPool); auto f2 = makeTask().scheduleVia(&threadPool);
auto f3 = makeTask().scheduleVia(&threadPool); auto f3 = makeTask().scheduleVia(&threadPool);
coro::blockingWait(f1); coro::blockingWait(std::move(f1));
coro::blockingWait(f2); coro::blockingWait(std::move(f2));
coro::blockingWait(f3); coro::blockingWait(std::move(f3));
CHECK_EQ(30'000, value); CHECK_EQ(30'000, value);
} }
......
...@@ -863,7 +863,7 @@ class SemiFuture : private futures::detail::FutureBase<T> { ...@@ -863,7 +863,7 @@ class SemiFuture : private futures::detail::FutureBase<T> {
template <typename Awaitable> template <typename Awaitable>
static SemiFuture fromAwaitable(Awaitable&& awaitable) { static SemiFuture fromAwaitable(Awaitable&& awaitable) {
return [](Awaitable awaitable) -> SemiFuture { return [](Awaitable awaitable) -> SemiFuture {
co_return co_await awaitable; co_return co_await std::forward<Awaitable>(awaitable);
}(std::forward<Awaitable>(awaitable)); }(std::forward<Awaitable>(awaitable));
} }
#endif #endif
...@@ -1996,10 +1996,11 @@ std::pair<Promise<T>, Future<T>> makePromiseContract(Executor* e) { ...@@ -1996,10 +1996,11 @@ std::pair<Promise<T>, Future<T>> makePromiseContract(Executor* e) {
namespace folly { namespace folly {
namespace detail { namespace detail {
template <typename T> template <typename T>
class FutureAwaitable { class FutureAwaitable {
public: public:
explicit FutureAwaitable(folly::Future<T>&& future) explicit FutureAwaitable(folly::Future<T>&& future) noexcept
: future_(std::move(future)) {} : future_(std::move(future)) {}
bool await_ready() const { bool await_ready() const {
...@@ -2007,49 +2008,23 @@ class FutureAwaitable { ...@@ -2007,49 +2008,23 @@ class FutureAwaitable {
} }
T await_resume() { T await_resume() {
return std::move(future_.value()); return std::move(future_).value();
} }
void await_suspend(std::experimental::coroutine_handle<> h) { void await_suspend(std::experimental::coroutine_handle<> h) {
future_.setCallback_([h](Try<T>&&) mutable { h(); }); future_.setCallback_([h](Try<T>&&) mutable { h.resume(); });
} }
private: private:
folly::Future<T> future_; folly::Future<T> future_;
}; };
template <typename T>
class FutureRefAwaitable {
public:
explicit FutureRefAwaitable(folly::Future<T>& future) : future_(future) {}
bool await_ready() const {
return future_.isReady();
}
T await_resume() {
return std::move(future_.value());
}
void await_suspend(std::experimental::coroutine_handle<> h) {
future_.setCallback_([h](Try<T>&&) mutable { h(); });
}
private:
folly::Future<T>& future_;
};
} // namespace detail } // namespace detail
template <typename T> template <typename T>
inline detail::FutureRefAwaitable<T> inline detail::FutureAwaitable<T>
/* implicit */ operator co_await(Future<T>& future) { /* implicit */ operator co_await(Future<T>&& future) noexcept {
return detail::FutureRefAwaitable<T>(future); return detail::FutureAwaitable<T>(std::move(future));
}
template <typename T>
inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>&& future) {
return detail::FutureRefAwaitable<T>(future);
} }
} // namespace folly } // namespace folly
#endif #endif
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment