Commit 73a507a2 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Add support for types that implement operator co_await()

Summary: AwaitWrapper was previously only working with types which are Awaitable (i.e. have await_ready()/await_suspend()/await_resume()). This extends it to support types which have operator co_await() implemented.

Reviewed By: lewissbaker

Differential Revision: D8580759

fbshipit-source-id: 980f9c9f34c5a2302badb2ab7c4644883b14aa2c
parent 48a252cf
...@@ -19,9 +19,21 @@ ...@@ -19,9 +19,21 @@
#include <folly/ExceptionString.h> #include <folly/ExceptionString.h>
#include <folly/Executor.h> #include <folly/Executor.h>
#include <folly/Optional.h>
namespace folly { namespace folly {
namespace coro { namespace coro {
namespace detail {
template <typename T>
T&& getRef(T&& t) {
return std::forward<T>(t);
}
template <typename T>
T& getRef(std::reference_wrapper<T> t) {
return t.get();
}
template <typename Awaitable> template <typename Awaitable>
class AwaitWrapper { class AwaitWrapper {
...@@ -61,39 +73,40 @@ class AwaitWrapper { ...@@ -61,39 +73,40 @@ class AwaitWrapper {
std::experimental::coroutine_handle<> awaiter_; std::experimental::coroutine_handle<> awaiter_;
}; };
static AwaitWrapper create(Awaitable* awaitable) { AwaitWrapper(AwaitWrapper&& other)
return {awaitable}; : promise_(std::exchange(other.promise_, nullptr)),
awaitable_(std::move(other.awaitable_)) {}
AwaitWrapper& operator=(AwaitWrapper&&) = delete;
static AwaitWrapper create(Awaitable&& awaitable) {
return {std::move(awaitable)};
} }
static AwaitWrapper create(Awaitable* awaitable, Executor* executor) { static AwaitWrapper create(Awaitable&& awaitable, Executor* executor) {
auto ret = awaitWrapper(); auto ret = awaitWrapper();
ret.awaitable_ = awaitable; ret.awaitable_.emplace(std::move(awaitable));
ret.promise_->executor_ = executor; ret.promise_->executor_ = executor;
return ret; return ret;
} }
bool await_ready() { bool await_ready() {
return awaitable_->await_ready(); return getRef(*awaitable_).await_ready();
} }
using await_suspend_return_type = decltype(auto) await_suspend(std::experimental::coroutine_handle<> awaiter) {
decltype((*static_cast<Awaitable*>(nullptr))
.await_suspend(std::experimental::coroutine_handle<>()));
await_suspend_return_type await_suspend(
std::experimental::coroutine_handle<> awaiter) {
if (promise_) { if (promise_) {
promise_->awaiter_ = std::move(awaiter); promise_->awaiter_ = std::move(awaiter);
return awaitable_->await_suspend( return getRef(*awaitable_)
std::experimental::coroutine_handle<promise_type>::from_promise( .await_suspend(
*promise_)); std::experimental::coroutine_handle<promise_type>::from_promise(
*promise_));
} }
return awaitable_->await_suspend(awaiter); return getRef(*awaitable_).await_suspend(awaiter);
} }
decltype((*static_cast<Awaitable*>(nullptr)).await_resume()) await_resume() { decltype(auto) await_resume() {
return awaitable_->await_resume(); return getRef(*awaitable_).await_resume();
} }
~AwaitWrapper() { ~AwaitWrapper() {
...@@ -104,7 +117,9 @@ class AwaitWrapper { ...@@ -104,7 +117,9 @@ class AwaitWrapper {
} }
private: private:
AwaitWrapper(Awaitable* awaitable) : awaitable_(awaitable) {} AwaitWrapper(Awaitable&& awaitable) {
awaitable_.emplace(std::move(awaitable));
}
AwaitWrapper(promise_type& promise) : promise_(&promise) {} AwaitWrapper(promise_type& promise) : promise_(&promise) {}
static AwaitWrapper awaitWrapper() { static AwaitWrapper awaitWrapper() {
...@@ -112,7 +127,68 @@ class AwaitWrapper { ...@@ -112,7 +127,68 @@ class AwaitWrapper {
} }
promise_type* promise_{nullptr}; promise_type* promise_{nullptr};
Awaitable* awaitable_{nullptr};
Optional<Awaitable> awaitable_;
}; };
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wc++17-extensions"
FOLLY_CREATE_MEMBER_INVOKE_TRAITS(
member_operator_co_await_traits,
operator co_await);
template <typename Awaitable>
inline constexpr bool has_member_operator_co_await_v =
member_operator_co_await_traits::is_invocable<Awaitable>::value;
FOLLY_CREATE_FREE_INVOKE_TRAITS(
non_member_operator_co_await_traits,
operator co_await);
template <typename Awaitable>
inline constexpr bool has_non_member_operator_co_await_v =
non_member_operator_co_await_traits::is_invocable<Awaitable>::value;
} // namespace detail
template <typename Awaitable>
decltype(auto) get_awaiter(Awaitable&& awaitable) {
if constexpr (detail::has_member_operator_co_await_v<Awaitable&&>) {
return std::forward<Awaitable>(awaitable).operator co_await();
} else if constexpr (detail::has_non_member_operator_co_await_v<
Awaitable&&>) {
return operator co_await(std::forward<Awaitable>(awaitable));
} else {
// This is necessary for it to work with std::reference_wrapper
return static_cast<Awaitable&>(awaitable);
}
}
template <typename Awaitable>
auto createAwaitWrapper(Awaitable&& awaitable) {
using Awaiter =
decltype(::folly::coro::get_awaiter(std::declval<Awaitable&&>()));
using Wrapper = std::conditional_t<
std::is_reference<Awaiter>::value,
std::reference_wrapper<std::remove_reference_t<Awaiter>>,
Awaiter>;
return detail::AwaitWrapper<Wrapper>::create(
::folly::coro::get_awaiter(std::forward<Awaitable>(awaitable)));
}
template <typename Awaitable>
auto createAwaitWrapper(Awaitable&& awaitable, folly::Executor* executor) {
using Awaiter =
decltype(::folly::coro::get_awaiter(std::declval<Awaitable&&>()));
using Wrapper = std::conditional_t<
std::is_reference<Awaiter>::value,
std::reference_wrapper<std::remove_reference_t<Awaiter>>,
Awaiter>;
return detail::AwaitWrapper<Wrapper>::create(
::folly::coro::get_awaiter(std::forward<Awaitable>(awaitable)), executor);
}
#pragma clang diagnostic pop
} // namespace coro } // namespace coro
} // namespace folly } // namespace folly
...@@ -86,48 +86,46 @@ class Promise : public PromiseBase<T> { ...@@ -86,48 +86,46 @@ class Promise : public PromiseBase<T> {
} }
template <typename U> template <typename U>
auto await_transform(folly::SemiFuture<U>& future) { decltype(auto) await_transform(folly::SemiFuture<U>& future) {
return folly::detail::FutureAwaitable<U>(future.via(executor_)); return future.via(executor_);
} }
template <typename U> template <typename U>
auto await_transform(folly::SemiFuture<U>&& future) { decltype(auto) await_transform(folly::SemiFuture<U>&& future) {
return folly::detail::FutureAwaitable<U>(future.via(executor_)); return future.via(executor_);
} }
template <typename U> template <typename U>
auto await_transform(folly::Future<U>& future) { decltype(auto) await_transform(folly::Future<U>& future) {
future = future.via(executor_); return future.via(executor_);
return folly::detail::FutureRefAwaitable<U>(future);
} }
template <typename U> template <typename U>
auto await_transform(folly::Future<U>&& future) { decltype(auto) await_transform(folly::Future<U>&& future) {
future = future.via(executor_); return future.via(executor_);
return folly::detail::FutureRefAwaitable<U>(future);
} }
template <typename U> template <typename U>
AwaitWrapper<Future<U>> await_transform(Future<U>& future) { auto await_transform(Future<U>& future) {
if (future.promise_->executor_ == executor_) { if (future.promise_->executor_ == executor_) {
return AwaitWrapper<Future<U>>::create(future); return createAwaitWrapper(future);
} }
return AwaitWrapper<Future<U>>::create(future, executor_); return createAwaitWrapper(future, executor_);
} }
template <typename U> template <typename U>
AwaitWrapper<Future<U>> await_transform(Future<U>&& future) { auto await_transform(Future<U>&& future) {
if (future.promise_->executor_ == executor_) { if (future.promise_->executor_ == executor_) {
return AwaitWrapper<Future<U>>::create(&future); return createAwaitWrapper(future);
} }
return AwaitWrapper<Future<U>>::create(&future, executor_); return createAwaitWrapper(future, executor_);
} }
template <typename U> template <typename U>
AwaitWrapper<U> await_transform(U&& awaitable) { auto await_transform(U&& awaitable) {
return AwaitWrapper<U>::create(&awaitable, executor_); return createAwaitWrapper(std::forward<U>(awaitable), executor_);
} }
auto await_transform(getCurrentExecutor) { auto await_transform(getCurrentExecutor) {
......
...@@ -223,4 +223,59 @@ TEST(Coro, TimedWait) { ...@@ -223,4 +223,59 @@ TEST(Coro, TimedWait) {
via(&executor, taskTimedWait()).toFuture().getVia(&executor); via(&executor, taskTimedWait()).toFuture().getVia(&executor);
} }
template <int value>
struct AwaitableInt {
bool await_ready() const {
return true;
}
bool await_suspend(std::experimental::coroutine_handle<>) {
LOG(FATAL) << "Should never be called.";
}
int await_resume() {
return value;
}
};
struct AwaitableWithOperator {};
AwaitableInt<42> operator co_await(const AwaitableWithOperator&) {
return {};
}
coro::Task<int> taskAwaitableWithOperator() {
co_return co_await AwaitableWithOperator();
}
TEST(Coro, AwaitableWithOperator) {
ManualExecutor executor;
EXPECT_EQ(
42,
via(&executor, taskAwaitableWithOperator()).toFuture().getVia(&executor));
}
struct AwaitableWithMemberOperator {
AwaitableInt<42> operator co_await() {
return {};
}
};
AwaitableInt<24> operator co_await(const AwaitableWithMemberOperator&) {
return {};
}
coro::Task<int> taskAwaitableWithMemberOperator() {
co_return co_await AwaitableWithMemberOperator();
}
TEST(Coro, AwaitableWithMemberOperator) {
ManualExecutor executor;
EXPECT_EQ(
42,
via(&executor, taskAwaitableWithMemberOperator())
.toFuture()
.getVia(&executor));
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment