Commit 73a507a2 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Add support for types that implement operator co_await()

Summary: AwaitWrapper was previously only working with types which are Awaitable (i.e. have await_ready()/await_suspend()/await_resume()). This extends it to support types which have operator co_await() implemented.

Reviewed By: lewissbaker

Differential Revision: D8580759

fbshipit-source-id: 980f9c9f34c5a2302badb2ab7c4644883b14aa2c
parent 48a252cf
......@@ -19,9 +19,21 @@
#include <folly/ExceptionString.h>
#include <folly/Executor.h>
#include <folly/Optional.h>
namespace folly {
namespace coro {
namespace detail {
template <typename T>
T&& getRef(T&& t) {
return std::forward<T>(t);
}
template <typename T>
T& getRef(std::reference_wrapper<T> t) {
return t.get();
}
template <typename Awaitable>
class AwaitWrapper {
......@@ -61,39 +73,40 @@ class AwaitWrapper {
std::experimental::coroutine_handle<> awaiter_;
};
static AwaitWrapper create(Awaitable* awaitable) {
return {awaitable};
AwaitWrapper(AwaitWrapper&& other)
: promise_(std::exchange(other.promise_, nullptr)),
awaitable_(std::move(other.awaitable_)) {}
AwaitWrapper& operator=(AwaitWrapper&&) = delete;
static AwaitWrapper create(Awaitable&& awaitable) {
return {std::move(awaitable)};
}
static AwaitWrapper create(Awaitable* awaitable, Executor* executor) {
static AwaitWrapper create(Awaitable&& awaitable, Executor* executor) {
auto ret = awaitWrapper();
ret.awaitable_ = awaitable;
ret.awaitable_.emplace(std::move(awaitable));
ret.promise_->executor_ = executor;
return ret;
}
bool await_ready() {
return awaitable_->await_ready();
return getRef(*awaitable_).await_ready();
}
using await_suspend_return_type =
decltype((*static_cast<Awaitable*>(nullptr))
.await_suspend(std::experimental::coroutine_handle<>()));
await_suspend_return_type await_suspend(
std::experimental::coroutine_handle<> awaiter) {
decltype(auto) await_suspend(std::experimental::coroutine_handle<> awaiter) {
if (promise_) {
promise_->awaiter_ = std::move(awaiter);
return awaitable_->await_suspend(
std::experimental::coroutine_handle<promise_type>::from_promise(
*promise_));
return getRef(*awaitable_)
.await_suspend(
std::experimental::coroutine_handle<promise_type>::from_promise(
*promise_));
}
return awaitable_->await_suspend(awaiter);
return getRef(*awaitable_).await_suspend(awaiter);
}
decltype((*static_cast<Awaitable*>(nullptr)).await_resume()) await_resume() {
return awaitable_->await_resume();
decltype(auto) await_resume() {
return getRef(*awaitable_).await_resume();
}
~AwaitWrapper() {
......@@ -104,7 +117,9 @@ class AwaitWrapper {
}
private:
AwaitWrapper(Awaitable* awaitable) : awaitable_(awaitable) {}
AwaitWrapper(Awaitable&& awaitable) {
awaitable_.emplace(std::move(awaitable));
}
AwaitWrapper(promise_type& promise) : promise_(&promise) {}
static AwaitWrapper awaitWrapper() {
......@@ -112,7 +127,68 @@ class AwaitWrapper {
}
promise_type* promise_{nullptr};
Awaitable* awaitable_{nullptr};
Optional<Awaitable> awaitable_;
};
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wc++17-extensions"
FOLLY_CREATE_MEMBER_INVOKE_TRAITS(
member_operator_co_await_traits,
operator co_await);
template <typename Awaitable>
inline constexpr bool has_member_operator_co_await_v =
member_operator_co_await_traits::is_invocable<Awaitable>::value;
FOLLY_CREATE_FREE_INVOKE_TRAITS(
non_member_operator_co_await_traits,
operator co_await);
template <typename Awaitable>
inline constexpr bool has_non_member_operator_co_await_v =
non_member_operator_co_await_traits::is_invocable<Awaitable>::value;
} // namespace detail
template <typename Awaitable>
decltype(auto) get_awaiter(Awaitable&& awaitable) {
if constexpr (detail::has_member_operator_co_await_v<Awaitable&&>) {
return std::forward<Awaitable>(awaitable).operator co_await();
} else if constexpr (detail::has_non_member_operator_co_await_v<
Awaitable&&>) {
return operator co_await(std::forward<Awaitable>(awaitable));
} else {
// This is necessary for it to work with std::reference_wrapper
return static_cast<Awaitable&>(awaitable);
}
}
template <typename Awaitable>
auto createAwaitWrapper(Awaitable&& awaitable) {
using Awaiter =
decltype(::folly::coro::get_awaiter(std::declval<Awaitable&&>()));
using Wrapper = std::conditional_t<
std::is_reference<Awaiter>::value,
std::reference_wrapper<std::remove_reference_t<Awaiter>>,
Awaiter>;
return detail::AwaitWrapper<Wrapper>::create(
::folly::coro::get_awaiter(std::forward<Awaitable>(awaitable)));
}
template <typename Awaitable>
auto createAwaitWrapper(Awaitable&& awaitable, folly::Executor* executor) {
using Awaiter =
decltype(::folly::coro::get_awaiter(std::declval<Awaitable&&>()));
using Wrapper = std::conditional_t<
std::is_reference<Awaiter>::value,
std::reference_wrapper<std::remove_reference_t<Awaiter>>,
Awaiter>;
return detail::AwaitWrapper<Wrapper>::create(
::folly::coro::get_awaiter(std::forward<Awaitable>(awaitable)), executor);
}
#pragma clang diagnostic pop
} // namespace coro
} // namespace folly
......@@ -86,48 +86,46 @@ class Promise : public PromiseBase<T> {
}
template <typename U>
auto await_transform(folly::SemiFuture<U>& future) {
return folly::detail::FutureAwaitable<U>(future.via(executor_));
decltype(auto) await_transform(folly::SemiFuture<U>& future) {
return future.via(executor_);
}
template <typename U>
auto await_transform(folly::SemiFuture<U>&& future) {
return folly::detail::FutureAwaitable<U>(future.via(executor_));
decltype(auto) await_transform(folly::SemiFuture<U>&& future) {
return future.via(executor_);
}
template <typename U>
auto await_transform(folly::Future<U>& future) {
future = future.via(executor_);
return folly::detail::FutureRefAwaitable<U>(future);
decltype(auto) await_transform(folly::Future<U>& future) {
return future.via(executor_);
}
template <typename U>
auto await_transform(folly::Future<U>&& future) {
future = future.via(executor_);
return folly::detail::FutureRefAwaitable<U>(future);
decltype(auto) await_transform(folly::Future<U>&& future) {
return future.via(executor_);
}
template <typename U>
AwaitWrapper<Future<U>> await_transform(Future<U>& future) {
auto await_transform(Future<U>& future) {
if (future.promise_->executor_ == executor_) {
return AwaitWrapper<Future<U>>::create(future);
return createAwaitWrapper(future);
}
return AwaitWrapper<Future<U>>::create(future, executor_);
return createAwaitWrapper(future, executor_);
}
template <typename U>
AwaitWrapper<Future<U>> await_transform(Future<U>&& future) {
auto await_transform(Future<U>&& future) {
if (future.promise_->executor_ == executor_) {
return AwaitWrapper<Future<U>>::create(&future);
return createAwaitWrapper(future);
}
return AwaitWrapper<Future<U>>::create(&future, executor_);
return createAwaitWrapper(future, executor_);
}
template <typename U>
AwaitWrapper<U> await_transform(U&& awaitable) {
return AwaitWrapper<U>::create(&awaitable, executor_);
auto await_transform(U&& awaitable) {
return createAwaitWrapper(std::forward<U>(awaitable), executor_);
}
auto await_transform(getCurrentExecutor) {
......
......@@ -223,4 +223,59 @@ TEST(Coro, TimedWait) {
via(&executor, taskTimedWait()).toFuture().getVia(&executor);
}
template <int value>
struct AwaitableInt {
bool await_ready() const {
return true;
}
bool await_suspend(std::experimental::coroutine_handle<>) {
LOG(FATAL) << "Should never be called.";
}
int await_resume() {
return value;
}
};
struct AwaitableWithOperator {};
AwaitableInt<42> operator co_await(const AwaitableWithOperator&) {
return {};
}
coro::Task<int> taskAwaitableWithOperator() {
co_return co_await AwaitableWithOperator();
}
TEST(Coro, AwaitableWithOperator) {
ManualExecutor executor;
EXPECT_EQ(
42,
via(&executor, taskAwaitableWithOperator()).toFuture().getVia(&executor));
}
struct AwaitableWithMemberOperator {
AwaitableInt<42> operator co_await() {
return {};
}
};
AwaitableInt<24> operator co_await(const AwaitableWithMemberOperator&) {
return {};
}
coro::Task<int> taskAwaitableWithMemberOperator() {
co_return co_await AwaitableWithMemberOperator();
}
TEST(Coro, AwaitableWithMemberOperator) {
ManualExecutor executor;
EXPECT_EQ(
42,
via(&executor, taskAwaitableWithMemberOperator())
.toFuture()
.getVia(&executor));
}
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment