Commit d1cadedd authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Improve folly::coro::Task<T> support for move-only types

Summary:
The folly::coro::Task<T> and folly::coro::Future<T> types would previously return an lvalue-reference to the result when you co_await them.

This means that for move-only types that code like `auto x = co_await someTask;` would fail to compile. Instead you would need to write `auto x = std::move(co_await someTask);`.

Awaiting a Task<T> or Future<T> now returns type T instead of T&.

As part of this change we now only allow co_awaiting an rvalue folly::coro::Task<T>, folly::coro::Future<T> as well as folly::coro::Future<T> to indicate that the operation is destructive and a one-time operation.

Reviewed By: andriigrynenko

Differential Revision: D9731525

fbshipit-source-id: bee9e633b57b203a0d048cf3eb0e2fc48b899481
parent 9933f34f
......@@ -37,40 +37,56 @@ class Future {
other.promise_ = nullptr;
}
bool await_ready() {
return promise_->state_.load(std::memory_order_acquire) ==
Promise<T>::State::HAS_RESULT;
}
class Awaiter {
public:
explicit Awaiter(Promise<T>* promise) noexcept : promise_(promise) {}
bool await_suspend(std::experimental::coroutine_handle<> awaiter) {
auto state = promise_->state_.load(std::memory_order_acquire);
bool await_ready() noexcept {
return promise_->isReady();
}
bool await_suspend(std::experimental::coroutine_handle<> awaiter) noexcept {
auto state = promise_->state_.load(std::memory_order_acquire);
if (state == Promise<T>::State::HAS_RESULT) {
return false;
}
DCHECK(state == Promise<T>::State::EMPTY);
if (state == Promise<T>::State::HAS_RESULT) {
promise_->awaiter_ = std::move(awaiter);
if (promise_->state_.compare_exchange_strong(
state,
Promise<T>::State::HAS_AWAITER,
std::memory_order_release,
std::memory_order_acquire)) {
return true;
}
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT);
return false;
}
DCHECK(state == Promise<T>::State::EMPTY);
promise_->awaiter_ = std::move(awaiter);
if (promise_->state_.compare_exchange_strong(
state,
Promise<T>::State::HAS_AWAITER,
std::memory_order_release,
std::memory_order_acquire)) {
return true;
T await_resume() {
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT);
return std::move(promise_->result_).value();
}
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT);
return false;
}
private:
Promise<T>* promise_;
};
typename std::add_lvalue_reference<T>::type await_resume() {
DCHECK(promise_->state_ == Promise<T>::State::HAS_RESULT);
return *promise_->result_;
Awaiter operator co_await() && noexcept {
return Awaiter{promise_};
}
auto toFuture() &&;
bool isReady() const noexcept {
DCHECK(promise_);
return promise_->isReady();
}
~Future() {
if (!promise_) {
return;
......@@ -109,7 +125,7 @@ class Future {
namespace detail {
inline SemiFuture<Unit> toSemiFuture(Future<void> future) {
co_await future;
co_await std::move(future);
co_return folly::unit;
}
......
......@@ -15,6 +15,8 @@
*/
#pragma once
#include <type_traits>
#include <glog/logging.h>
#include <folly/ExceptionWrapper.h>
......@@ -47,7 +49,10 @@ class PromiseBase {
public:
template <typename U>
void return_value(U&& value) {
result_ = Try<T>(std::forward<U>(value));
static_assert(
std::is_convertible<U&&, T>::value,
"Returned value is not convertible to task result type");
result_.emplace(static_cast<U&&>(value));
}
protected:
......@@ -80,23 +85,23 @@ class Promise : public PromiseBase<T> {
return {};
}
// Don't allow awaiting lvalues of these types.
template <typename U>
auto await_transform(Task<U>&& task) {
return std::move(task).viaInline(executor_);
}
void await_transform(folly::SemiFuture<U>& future) = delete;
template <typename U>
decltype(auto) await_transform(folly::SemiFuture<U>& future) {
return future.via(executor_);
}
void await_transform(folly::Future<U>& future) = delete;
template <typename U>
void await_transform(Future<U>& future) = delete;
template <typename U>
void await_transform(Task<U>& task) = delete;
template <typename U>
decltype(auto) await_transform(folly::SemiFuture<U>&& future) {
return future.via(executor_);
auto await_transform(Task<U>&& task) {
return std::move(task).viaInline(executor_);
}
template <typename U>
decltype(auto) await_transform(folly::Future<U>& future) {
decltype(auto) await_transform(folly::SemiFuture<U>&& future) {
return future.via(executor_);
}
......@@ -105,22 +110,13 @@ class Promise : public PromiseBase<T> {
return future.via(executor_);
}
template <typename U>
auto await_transform(Future<U>& future) {
if (future.promise_->executor_ == executor_) {
return createAwaitWrapper(future);
}
return createAwaitWrapper(future, executor_);
}
template <typename U>
auto await_transform(Future<U>&& future) {
if (future.promise_->executor_ == executor_) {
return createAwaitWrapper(future);
return createAwaitWrapper(std::move(future));
}
return createAwaitWrapper(future, executor_);
return createAwaitWrapper(std::move(future), executor_);
}
template <typename U>
......@@ -147,6 +143,11 @@ class Promise : public PromiseBase<T> {
std::experimental::coroutine_handle<Promise>::from_promise (*this)();
}
bool isReady() const noexcept {
return state_.load(std::memory_order_acquire) ==
Promise<T>::State::HAS_RESULT;
}
private:
friend class Future<T>;
friend class Task<T>;
......
......@@ -20,6 +20,7 @@
#include <type_traits>
#include <folly/Optional.h>
#include <folly/experimental/coro/Traits.h>
#include <folly/experimental/coro/Wait.h>
#include <folly/futures/Future.h>
......@@ -78,8 +79,7 @@ class TimedWaitAwaitable {
static_assert(
std::is_same<Awaitable, std::decay_t<Awaitable>>::value,
"Awaitable should be decayed.");
using await_resume_return_type =
decltype((operator co_await(std::declval<Awaitable>())).await_resume());
using await_resume_return_type = await_result_t<Awaitable>;
TimedWaitAwaitable(Awaitable&& awaitable, std::chrono::milliseconds duration)
: awaitable_(std::move(awaitable)), duration_(duration) {}
......@@ -116,7 +116,7 @@ class TimedWaitAwaitable {
return;
}
assume(!storage_.hasValue() && !storage_.hasException());
storage_ = Try<await_resume_return_type>(std::move(value));
tryEmplace(storage_, static_cast<await_resume_return_type&&>(value));
ch_();
}
......@@ -125,7 +125,7 @@ class TimedWaitAwaitable {
return;
}
assume(!storage_.hasValue() && !storage_.hasException());
storage_ = Try<await_resume_return_type>(std::move(e));
storage_.emplaceException(std::move(e));
ch_();
}
......@@ -146,7 +146,7 @@ class TimedWaitAwaitable {
Awaitable awaitable,
std::shared_ptr<SharedState> sharedState) {
try {
sharedState->setValue(co_await awaitable);
sharedState->setValue(co_await std::forward<Awaitable>(awaitable));
} catch (const std::exception& e) {
sharedState->setException(exception_wrapper(std::current_exception(), e));
} catch (...) {
......
......@@ -35,12 +35,12 @@ TEST(Coro, Basic) {
ManualExecutor executor;
auto future = via(&executor, task42());
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
executor.drive();
EXPECT_TRUE(future.await_ready());
EXPECT_EQ(42, folly::coro::blockingWait(future));
EXPECT_TRUE(future.isReady());
EXPECT_EQ(42, folly::coro::blockingWait(std::move(future)));
}
TEST(Coro, BasicFuture) {
......@@ -62,11 +62,11 @@ TEST(Coro, Basic2) {
ManualExecutor executor;
auto future = via(&executor, taskVoid());
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
executor.drive();
EXPECT_TRUE(future.await_ready());
EXPECT_TRUE(future.isReady());
}
coro::Task<void> taskSleep() {
......@@ -74,15 +74,25 @@ coro::Task<void> taskSleep() {
co_return;
}
TEST(Coro, TaskOfMoveOnly) {
auto f = []() -> coro::Task<std::unique_ptr<int>> {
co_return std::make_unique<int>(123);
};
auto p = coro::blockingWait(f().scheduleVia(&InlineExecutor::instance()));
EXPECT_TRUE(p);
EXPECT_EQ(123, *p);
}
TEST(Coro, Sleep) {
ScopedEventBaseThread evbThread;
auto startTime = std::chrono::steady_clock::now();
auto future = via(evbThread.getEventBase(), taskSleep());
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
coro::blockingWait(future);
coro::blockingWait(std::move(future));
// The total time should be roughly 1 second. Some builds, especially
// optimized ones, may result in slightly less than 1 second, so we perform
......@@ -90,8 +100,6 @@ TEST(Coro, Sleep) {
auto totalTime = std::chrono::steady_clock::now() - startTime;
EXPECT_GE(
chrono::round<std::chrono::seconds>(totalTime), std::chrono::seconds{1});
EXPECT_TRUE(future.await_ready());
}
coro::Task<int> taskException() {
......@@ -103,12 +111,12 @@ TEST(Coro, Throw) {
ManualExecutor executor;
auto future = via(&executor, taskException());
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
executor.drive();
EXPECT_TRUE(future.await_ready());
EXPECT_THROW(coro::blockingWait(future), std::runtime_error);
EXPECT_TRUE(future.isReady());
EXPECT_THROW(coro::blockingWait(std::move(future)), std::runtime_error);
}
TEST(Coro, FutureThrow) {
......@@ -137,7 +145,7 @@ TEST(Coro, LargeStack) {
ScopedEventBaseThread evbThread;
auto future = via(evbThread.getEventBase(), taskRecursion(5000));
EXPECT_EQ(5000, coro::blockingWait(future));
EXPECT_EQ(5000, coro::blockingWait(std::move(future)));
}
coro::Task<void> taskThreadNested(std::thread::id threadId) {
......@@ -163,7 +171,7 @@ TEST(Coro, NestedThreads) {
ScopedEventBaseThread evbThread;
auto future = via(evbThread.getEventBase(), taskThread());
EXPECT_EQ(42, coro::blockingWait(future));
EXPECT_EQ(42, coro::blockingWait(std::move(future)));
}
coro::Task<int> taskYield(Executor* executor) {
......@@ -171,12 +179,12 @@ coro::Task<int> taskYield(Executor* executor) {
EXPECT_EQ(executor, currentExecutor);
auto future = via(currentExecutor, task42());
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
co_await coro::yield();
EXPECT_TRUE(future.await_ready());
co_return co_await future;
EXPECT_TRUE(future.isReady());
co_return co_await std::move(future);
}
TEST(Coro, CurrentExecutor) {
......@@ -184,7 +192,7 @@ TEST(Coro, CurrentExecutor) {
auto future =
via(evbThread.getEventBase(), taskYield(evbThread.getEventBase()));
EXPECT_EQ(42, coro::blockingWait(future));
EXPECT_EQ(42, coro::blockingWait(std::move(future)));
}
coro::Task<void> taskTimedWait() {
......@@ -286,17 +294,17 @@ TEST(Coro, Baton) {
fibers::Baton baton;
auto future = via(&executor, taskBaton(baton));
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
executor.run();
EXPECT_FALSE(future.await_ready());
EXPECT_FALSE(future.isReady());
baton.post();
executor.run();
EXPECT_TRUE(future.await_ready());
EXPECT_EQ(42, coro::blockingWait(future));
EXPECT_TRUE(future.isReady());
EXPECT_EQ(42, coro::blockingWait(std::move(future)));
}
#endif
......@@ -150,9 +150,9 @@ TEST(Mutex, ThreadSafety) {
auto f2 = makeTask().scheduleVia(&threadPool);
auto f3 = makeTask().scheduleVia(&threadPool);
coro::blockingWait(f1);
coro::blockingWait(f2);
coro::blockingWait(f3);
coro::blockingWait(std::move(f1));
coro::blockingWait(std::move(f2));
coro::blockingWait(std::move(f3));
CHECK_EQ(30'000, value);
}
......
......@@ -863,7 +863,7 @@ class SemiFuture : private futures::detail::FutureBase<T> {
template <typename Awaitable>
static SemiFuture fromAwaitable(Awaitable&& awaitable) {
return [](Awaitable awaitable) -> SemiFuture {
co_return co_await awaitable;
co_return co_await std::forward<Awaitable>(awaitable);
}(std::forward<Awaitable>(awaitable));
}
#endif
......@@ -1996,10 +1996,11 @@ std::pair<Promise<T>, Future<T>> makePromiseContract(Executor* e) {
namespace folly {
namespace detail {
template <typename T>
class FutureAwaitable {
public:
explicit FutureAwaitable(folly::Future<T>&& future)
explicit FutureAwaitable(folly::Future<T>&& future) noexcept
: future_(std::move(future)) {}
bool await_ready() const {
......@@ -2007,49 +2008,23 @@ class FutureAwaitable {
}
T await_resume() {
return std::move(future_.value());
return std::move(future_).value();
}
void await_suspend(std::experimental::coroutine_handle<> h) {
future_.setCallback_([h](Try<T>&&) mutable { h(); });
future_.setCallback_([h](Try<T>&&) mutable { h.resume(); });
}
private:
folly::Future<T> future_;
};
template <typename T>
class FutureRefAwaitable {
public:
explicit FutureRefAwaitable(folly::Future<T>& future) : future_(future) {}
bool await_ready() const {
return future_.isReady();
}
T await_resume() {
return std::move(future_.value());
}
void await_suspend(std::experimental::coroutine_handle<> h) {
future_.setCallback_([h](Try<T>&&) mutable { h(); });
}
private:
folly::Future<T>& future_;
};
} // namespace detail
template <typename T>
inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>& future) {
return detail::FutureRefAwaitable<T>(future);
}
template <typename T>
inline detail::FutureRefAwaitable<T>
/* implicit */ operator co_await(Future<T>&& future) {
return detail::FutureRefAwaitable<T>(future);
inline detail::FutureAwaitable<T>
/* implicit */ operator co_await(Future<T>&& future) noexcept {
return detail::FutureAwaitable<T>(std::move(future));
}
} // namespace folly
#endif
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment