Commit 97e9c15e authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Make folly::coro::Baton safe to await concurrently by multiple coroutines

Summary:
This modifies the Baton data-structure to store a linked list of awaiters rather than storing a single awaiter.

When a coroutine awaits the Baton it now does a lock-free push onto a list of awaiters.

When the Baton is posted it atomically dequeues all awaiters from the list and then resumes each of them in turn.

Reviewed By: andriigrynenko

Differential Revision: D15310137

fbshipit-source-id: 895ebcf2b113fb270ad7abfedbaab68ea51de84c
parent d8bd1ecf
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
using namespace folly::coro; using namespace folly::coro;
Baton::~Baton() { Baton::~Baton() {
// Should not be any waiting coroutines when the baton is destruced. // Should not be any waiting coroutines when the baton is destructed.
// Caller should ensure the baton is posted before destructing. // Caller should ensure the baton is posted before destructing.
assert( assert(
state_.load(std::memory_order_relaxed) == static_cast<void*>(this) || state_.load(std::memory_order_relaxed) == static_cast<void*>(this) ||
...@@ -33,29 +33,30 @@ Baton::~Baton() { ...@@ -33,29 +33,30 @@ Baton::~Baton() {
} }
void Baton::post() noexcept { void Baton::post() noexcept {
void* signalledState = static_cast<void*>(this); void* const signalledState = static_cast<void*>(this);
void* oldValue = state_.exchange(signalledState, std::memory_order_acq_rel); void* oldValue = state_.exchange(signalledState, std::memory_order_acq_rel);
if (oldValue != signalledState && oldValue != nullptr) { if (oldValue != signalledState) {
// We are the first thread to set the state to signalled and there is // We are the first thread to set the state to signalled and there is
// a waiting coroutine. We are responsible for resuming it. // a waiting coroutine. We are responsible for resuming it.
WaitOperation* awaiter = static_cast<WaitOperation*>(oldValue); WaitOperation* awaiter = static_cast<WaitOperation*>(oldValue);
awaiter->awaitingCoroutine_.resume(); while (awaiter != nullptr) {
std::exchange(awaiter, awaiter->next_)->awaitingCoroutine_.resume();
}
} }
} }
bool Baton::waitImpl(WaitOperation* awaiter) const noexcept { bool Baton::waitImpl(WaitOperation* awaiter) const noexcept {
void* oldValue = nullptr; // Try to push the awaiter onto the front of the queue of waiters.
if (!state_.compare_exchange_strong( const auto signalledState = static_cast<const void*>(this);
oldValue, void* oldValue = state_.load(std::memory_order_acquire);
static_cast<void*>(awaiter), do {
std::memory_order_release, if (oldValue == signalledState) {
std::memory_order_acquire)) { // Already in the signalled state, don't enqueue it.
// If the compare-exchange fails it should be because the baton was return false;
// set to the signalled state. If this not the case then this could }
// indicate that there are two awaiting coroutines. awaiter->next_ = static_cast<WaitOperation*>(oldValue);
assert(oldValue == static_cast<const void*>(this)); } while (!state_.compare_exchange_weak(
return false; oldValue, awaiter, std::memory_order_release, std::memory_order_acquire));
}
return true; return true;
} }
......
...@@ -21,13 +21,14 @@ ...@@ -21,13 +21,14 @@
namespace folly { namespace folly {
namespace coro { namespace coro {
/// A baton is a synchronisation primitive for coroutines that allows one /// A baton is a synchronisation primitive for coroutines that allows a
/// coroutine to co_await the baton and suspend until the baton is posted /// coroutine to co_await the baton and suspend until the baton is posted by
/// by some other thread via a call to .post(). /// some thread via a call to .post().
/// ///
/// The Baton supports being awaited by a single coroutine at a time. If the /// The Baton supports being awaited by multiple coroutines at a time. If the
/// baton is not ready at the time it is awaited then the awaiting coroutine /// baton is not ready at the time it is awaited then an awaiting coroutine
/// suspends and is later resumed when some thread calls .post(). /// suspends. All suspended coroutines waiting for the baton to be posted will
/// be resumed when some thread next calls .post().
/// ///
/// Example usage: /// Example usage:
/// ///
...@@ -71,19 +72,12 @@ class Baton { ...@@ -71,19 +72,12 @@ class Baton {
/// suspending. Otherwise, if the Baton is not yet signalled then the /// suspending. Otherwise, if the Baton is not yet signalled then the
/// awaiting coroutine will suspend execution and will be resumed when some /// awaiting coroutine will suspend execution and will be resumed when some
/// thread later calls post(). /// thread later calls post().
///
/// You may optionally specify an executor on which to resume executing the
/// awaiting coroutine if the baton was not already in the signalled state
/// by chaining a .via(executor) call. If you do not specify an executor then
/// the behaviour is as if an inline executor was specified.
/// i.e. the coroutine will be resumed inside the call to .post() on the
/// thread that next calls .post().
[[nodiscard]] WaitOperation operator co_await() const noexcept; [[nodiscard]] WaitOperation operator co_await() const noexcept;
/// Set the Baton to the signalled state if it is not already signalled. /// Set the Baton to the signalled state if it is not already signalled.
/// ///
/// This will resume any coroutines that are currently suspended waiting /// This will resume any coroutines that are currently suspended waiting
/// for the Baton inside 'co_await baton.waitAsync()'. /// for the Baton inside 'co_await baton'.
void post() noexcept; void post() noexcept;
/// Atomically reset the baton back to the non-signalled state. /// Atomically reset the baton back to the non-signalled state.
......
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include <folly/experimental/coro/Task.h> #include <folly/experimental/coro/Task.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
#include <stdio.h>
using namespace folly; using namespace folly;
TEST(Baton, Ready) { TEST(Baton, Ready) {
...@@ -66,4 +68,44 @@ TEST(Baton, AwaitBaton) { ...@@ -66,4 +68,44 @@ TEST(Baton, AwaitBaton) {
CHECK(reachedAfterAwait); CHECK(reachedAfterAwait);
} }
TEST(Baton, MultiAwaitBaton) {
coro::Baton baton;
bool reachedBeforeAwait1 = false;
bool reachedBeforeAwait2 = false;
bool reachedAfterAwait1 = false;
bool reachedAfterAwait2 = false;
auto makeTask1 = [&]() -> coro::Task<void> {
reachedBeforeAwait1 = true;
co_await baton;
reachedAfterAwait1 = true;
};
auto makeTask2 = [&]() -> coro::Task<void> {
reachedBeforeAwait2 = true;
co_await baton;
reachedAfterAwait2 = true;
};
coro::Task<void> t1 = makeTask1();
coro::Task<void> t2 = makeTask2();
auto f1 = std::move(t1).scheduleOn(&InlineExecutor::instance()).start();
auto f2 = std::move(t2).scheduleOn(&InlineExecutor::instance()).start();
CHECK(reachedBeforeAwait1);
CHECK(reachedBeforeAwait2);
CHECK(!reachedAfterAwait1);
CHECK(!reachedAfterAwait2);
baton.post();
CHECK(f1.isReady());
CHECK(f2.isReady());
CHECK(reachedAfterAwait1);
CHECK(reachedAfterAwait2);
}
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment