Commit dea4f330 authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Replace AwaitWraper with new co_viaIfAsync operator

Summary:
Refactored the `AwaitWrapper` class to expose its functionality as a new `co_viaIfAsync()` operator that returns a `ViaIfAsyncAwaitable<Awaitable>` type.

The `folly::coro::Task` promise_type now delegates wrapping the awaited operations to the new `co_viaIfAsync()` operator. This allows SemiAwaitable types to customise how they are awaited from within a folly::coro::Task coroutine by providing an overload for `co_viaIfAsync()` rather than intrusively modifying the `folly::coro::Task` implementation.

Reviewed By: andriigrynenko

Differential Revision: D9983074

fbshipit-source-id: 87d80001879ac7aa3be473339ba259229dafc262
parent cce0bb0a
/*
* Copyright 2017-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <experimental/coroutine>
#include <glog/logging.h>
#include <folly/ExceptionString.h>
#include <folly/Executor.h>
#include <folly/Optional.h>
#include <folly/experimental/coro/Traits.h>
namespace folly {
namespace coro {
namespace detail {
template <typename T>
T&& getRef(T&& t) {
return std::forward<T>(t);
}
template <typename T>
T& getRef(std::reference_wrapper<T> t) {
return t.get();
}
template <typename Awaitable>
class AwaitWrapper {
public:
struct promise_type {
std::experimental::suspend_always initial_suspend() {
return {};
}
auto final_suspend() {
struct FinalAwaiter {
bool await_ready() noexcept {
return false;
}
void await_suspend(
std::experimental::coroutine_handle<promise_type> h) noexcept {
auto& p = h.promise();
p.executor_->add(p.awaiter_);
}
void await_resume() noexcept {}
};
return FinalAwaiter{};
}
void return_void() {}
void unhandled_exception() {
LOG(FATAL) << "Failed to schedule a task to awake a coroutine: "
<< exceptionStr(std::current_exception());
}
AwaitWrapper get_return_object() {
return {*this};
}
Executor* executor_;
std::experimental::coroutine_handle<> awaiter_;
};
AwaitWrapper(AwaitWrapper&& other)
: promise_(std::exchange(other.promise_, nullptr)),
awaitable_(std::move(other.awaitable_)) {}
AwaitWrapper& operator=(AwaitWrapper&&) = delete;
static AwaitWrapper create(Awaitable&& awaitable) {
return {std::move(awaitable)};
}
static AwaitWrapper create(Awaitable&& awaitable, Executor* executor) {
auto ret = awaitWrapper();
ret.awaitable_.emplace(std::move(awaitable));
ret.promise_->executor_ = executor;
return ret;
}
bool await_ready() {
return getRef(*awaitable_).await_ready();
}
decltype(auto) await_suspend(std::experimental::coroutine_handle<> awaiter) {
if (promise_) {
promise_->awaiter_ = std::move(awaiter);
return getRef(*awaitable_)
.await_suspend(
std::experimental::coroutine_handle<promise_type>::from_promise(
*promise_));
}
return getRef(*awaitable_).await_suspend(awaiter);
}
decltype(auto) await_resume() {
return getRef(*awaitable_).await_resume();
}
~AwaitWrapper() {
if (promise_) {
std::experimental::coroutine_handle<promise_type>::from_promise(*promise_)
.destroy();
}
}
private:
AwaitWrapper(Awaitable&& awaitable) {
awaitable_.emplace(std::move(awaitable));
}
AwaitWrapper(promise_type& promise) : promise_(&promise) {}
static AwaitWrapper awaitWrapper() {
co_return;
}
promise_type* promise_{nullptr};
Optional<Awaitable> awaitable_;
};
} // namespace detail
template <typename Awaitable>
auto createAwaitWrapper(Awaitable&& awaitable) {
using Awaiter = folly::coro::awaiter_type_t<Awaitable>;
using Wrapper = std::conditional_t<
std::is_reference<Awaiter>::value,
std::reference_wrapper<std::remove_reference_t<Awaiter>>,
Awaiter>;
return detail::AwaitWrapper<Wrapper>::create(
::folly::coro::get_awaiter(std::forward<Awaitable>(awaitable)));
}
template <typename Awaitable>
auto createAwaitWrapper(Awaitable&& awaitable, folly::Executor* executor) {
using Awaiter = folly::coro::awaiter_type_t<Awaitable>;
using Wrapper = std::conditional_t<
std::is_reference<Awaiter>::value,
std::reference_wrapper<std::remove_reference_t<Awaiter>>,
Awaiter>;
return detail::AwaitWrapper<Wrapper>::create(
::folly::coro::get_awaiter(std::forward<Awaitable>(awaitable)), executor);
}
} // namespace coro
} // namespace folly
...@@ -21,9 +21,9 @@ ...@@ -21,9 +21,9 @@
#include <folly/ExceptionWrapper.h> #include <folly/ExceptionWrapper.h>
#include <folly/Try.h> #include <folly/Try.h>
#include <folly/experimental/coro/AwaitWrapper.h>
#include <folly/experimental/coro/Task.h> #include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/Utils.h> #include <folly/experimental/coro/Utils.h>
#include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/futures/Future.h> #include <folly/futures/Future.h>
namespace folly { namespace folly {
...@@ -85,43 +85,15 @@ class Promise : public PromiseBase<T> { ...@@ -85,43 +85,15 @@ class Promise : public PromiseBase<T> {
return {}; return {};
} }
// Don't allow awaiting lvalues of these types.
template <typename U>
void await_transform(folly::SemiFuture<U>& future) = delete;
template <typename U>
void await_transform(folly::Future<U>& future) = delete;
template <typename U>
void await_transform(Future<U>& future) = delete;
template <typename U>
void await_transform(Task<U>& task) = delete;
template <typename U> template <typename U>
auto await_transform(Task<U>&& task) { auto await_transform(Task<U>&& task) {
return std::move(task).viaInline(executor_); return std::move(task).viaInline(executor_);
} }
template <typename U> template <typename Awaitable>
decltype(auto) await_transform(folly::SemiFuture<U>&& future) { decltype(auto) await_transform(Awaitable&& awaitable) {
return future.via(executor_); using folly::coro::co_viaIfAsync;
} return co_viaIfAsync(executor_, std::forward<Awaitable>(awaitable));
template <typename U>
decltype(auto) await_transform(folly::Future<U>&& future) {
return future.via(executor_);
}
template <typename U>
auto await_transform(Future<U>&& future) {
if (future.promise_->executor_ == executor_) {
return createAwaitWrapper(std::move(future));
}
return createAwaitWrapper(std::move(future), executor_);
}
template <typename U>
auto await_transform(U&& awaitable) {
return createAwaitWrapper(std::forward<U>(awaitable), executor_);
} }
auto await_transform(getCurrentExecutor) { auto await_transform(getCurrentExecutor) {
......
/*
* Copyright 2017-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <experimental/coroutine>
#include <memory>
#include <folly/Executor.h>
#include <folly/experimental/coro/Traits.h>
#include <glog/logging.h>
namespace folly {
class InlineExecutor;
namespace coro {
namespace detail {
class ViaCoroutine {
public:
class promise_type {
public:
promise_type(folly::Executor* executor) noexcept : executor_(executor) {}
ViaCoroutine get_return_object() noexcept {
return ViaCoroutine{
std::experimental::coroutine_handle<promise_type>::from_promise(
*this)};
}
std::experimental::suspend_always initial_suspend() {
return {};
}
auto final_suspend() {
struct Awaiter {
bool await_ready() noexcept {
return false;
}
void await_suspend(
std::experimental::coroutine_handle<promise_type> coro) noexcept {
// Schedule resumption of the coroutine on the executor.
auto& promise = coro.promise();
promise.executor_->add(promise.continuation_);
}
void await_resume() noexcept {}
};
return Awaiter{};
}
[[noreturn]] void unhandled_exception() noexcept {
LOG(FATAL) << "ViaCoroutine threw an unhandled exception";
}
void return_void() noexcept {}
void setContinuation(
std::experimental::coroutine_handle<> continuation) noexcept {
DCHECK(!continuation_);
continuation_ = continuation;
}
private:
folly::Executor* executor_;
std::experimental::coroutine_handle<> continuation_;
};
ViaCoroutine(ViaCoroutine&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {}
~ViaCoroutine() {
destroy();
}
ViaCoroutine& operator=(ViaCoroutine other) noexcept {
swap(other);
return *this;
}
void swap(ViaCoroutine& other) noexcept {
std::swap(coro_, other.coro_);
}
std::experimental::coroutine_handle<> getWrappedCoroutine(
std::experimental::coroutine_handle<> continuation) noexcept {
if (coro_) {
coro_.promise().setContinuation(continuation);
return coro_;
} else {
return continuation;
}
}
void destroy() {
if (coro_) {
std::exchange(coro_, {}).destroy();
}
}
static ViaCoroutine create(folly::Executor* executor) {
co_return;
}
static ViaCoroutine createInline() noexcept {
return ViaCoroutine{std::experimental::coroutine_handle<promise_type>{}};
}
private:
friend class promise_type;
explicit ViaCoroutine(
std::experimental::coroutine_handle<promise_type> coro) noexcept
: coro_(coro) {}
std::experimental::coroutine_handle<promise_type> coro_;
};
} // namespace detail
template <typename Awaiter>
class ViaIfAsyncAwaiter {
public:
static_assert(
folly::coro::is_awaiter_v<Awaiter>,
"Awaiter type does not implement the Awaiter interface.");
template <typename Awaitable>
explicit ViaIfAsyncAwaiter(folly::InlineExecutor*, Awaitable&& awaitable)
: viaCoroutine_(detail::ViaCoroutine::createInline()),
awaiter_(
folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {}
template <typename Awaitable>
explicit ViaIfAsyncAwaiter(folly::Executor* executor, Awaitable&& awaitable)
: viaCoroutine_(detail::ViaCoroutine::create(executor)),
awaiter_(
folly::coro::get_awaiter(static_cast<Awaitable&&>(awaitable))) {}
bool await_ready() noexcept(
noexcept(std::declval<Awaiter&>().await_ready())) {
return awaiter_.await_ready();
}
auto
await_suspend(std::experimental::coroutine_handle<> continuation) noexcept(
noexcept(std::declval<Awaiter&>().await_suspend(continuation))) {
return awaiter_.await_suspend(
viaCoroutine_.getWrappedCoroutine(continuation));
}
decltype(auto) await_resume() noexcept(
noexcept(std::declval<Awaiter&>().await_resume())) {
viaCoroutine_.destroy();
return awaiter_.await_resume();
}
detail::ViaCoroutine viaCoroutine_;
Awaiter awaiter_;
};
template <typename Awaitable>
class ViaIfAsyncAwaitable {
public:
explicit ViaIfAsyncAwaitable(
folly::Executor* executor,
Awaitable&&
awaitable) noexcept(std::is_nothrow_move_constructible<Awaitable>::
value)
: executor_(executor), awaitable_(static_cast<Awaitable&&>(awaitable)) {}
template <typename Awaitable2>
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable2>>;
template <typename Awaitable2>
friend auto operator co_await(ViaIfAsyncAwaitable<Awaitable2>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable2&>>;
template <typename Awaitable2>
friend auto operator co_await(
const ViaIfAsyncAwaitable<Awaitable2>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&&>>;
template <typename Awaitable2>
friend auto operator co_await(
const ViaIfAsyncAwaitable<Awaitable2>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable2&>>;
private:
folly::Executor* executor_;
Awaitable awaitable_;
};
template <typename Awaitable>
auto operator co_await(ViaIfAsyncAwaitable<Awaitable>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable>>{
awaitable.executor_, static_cast<Awaitable&&>(awaitable.awaitable_)};
}
template <typename Awaitable>
auto operator co_await(ViaIfAsyncAwaitable<Awaitable>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable&>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<Awaitable&>>{
awaitable.executor_, awaitable.awaitable_};
}
template <typename Awaitable>
auto operator co_await(const ViaIfAsyncAwaitable<Awaitable>&& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&&>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&&>>{
awaitable.executor_,
static_cast<const Awaitable&&>(awaitable.awaitable_)};
}
template <typename Awaitable>
auto operator co_await(const ViaIfAsyncAwaitable<Awaitable>& awaitable)
-> ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&>> {
return ViaIfAsyncAwaiter<folly::coro::awaiter_type_t<const Awaitable&>>{
awaitable.executor_, awaitable.awaitable_};
}
/// Returns a new awaitable that will resume execution of the awaiting coroutine
/// on a specified executor in the case that the operation does not complete
/// synchronously.
///
/// If the operation completes synchronously then the awaiting coroutine
/// will continue execution on the current thread without transitioning
/// execution to the specified executor.
template <typename Awaitable>
auto co_viaIfAsync(folly::Executor* executor, Awaitable&& awaitable)
-> ViaIfAsyncAwaitable<Awaitable> {
static_assert(
folly::coro::is_awaitable_v<Awaitable>,
"co_viaIfAsync() argument 2 is not awaitable.");
return ViaIfAsyncAwaitable<Awaitable>{executor,
static_cast<Awaitable&&>(awaitable)};
}
} // namespace coro
} // namespace folly
...@@ -870,6 +870,15 @@ class SemiFuture : private futures::detail::FutureBase<T> { ...@@ -870,6 +870,15 @@ class SemiFuture : private futures::detail::FutureBase<T> {
co_return co_await std::forward<Awaitable>(awaitable); co_return co_await std::forward<Awaitable>(awaitable);
}(std::forward<Awaitable>(awaitable)); }(std::forward<Awaitable>(awaitable));
} }
// Customise the co_viaIfAsync() operator so that SemiFuture<T> can be
// directly awaited within a folly::coro::Task coroutine.
friend Future<T> co_viaIfAsync(
folly::Executor* executor,
SemiFuture<T>&& future) noexcept {
return std::move(future).via(executor);
}
#endif #endif
private: private:
...@@ -1911,6 +1920,18 @@ class Future : private futures::detail::FutureBase<T> { ...@@ -1911,6 +1920,18 @@ class Future : private futures::detail::FutureBase<T> {
return SemiFuture<T>{std::move(*this)}; return SemiFuture<T>{std::move(*this)};
} }
#if FOLLY_HAS_COROUTINES
// Overload needed to customise behaviour of awaiting a Future<T>
// inside a folly::coro::Task coroutine.
friend Future<T> co_viaIfAsync(
folly::Executor* executor,
Future<T>&& future) noexcept {
return std::move(future).via(executor);
}
#endif
protected: protected:
friend class Promise<T>; friend class Promise<T>;
template <class> template <class>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment