Commit 189c99a6 authored by Kirk Shoop's avatar Kirk Shoop Committed by Facebook Github Bot

Add folly::coro::merge() operator for AsyncGenerator.

Summary: Adds a new `folly::coro::merge()` algorithm that accepts an async stream of async streams and returns an async stream that merges and interleaves the results.

Reviewed By: lewissbaker

Differential Revision: D15617383

fbshipit-source-id: cafb8e31725cdb21cfa00ce186901329828c2b68
parent 6e4965ce
......@@ -513,7 +513,7 @@ auto co_invoke(Func func, Args... args) -> std::enable_if_t<
auto asyncRange =
folly::invoke(static_cast<Func&&>(func), static_cast<Args&&>(args)...);
while (auto result = co_await asyncRange.next()) {
co_yield* result;
co_yield* std::move(result);
}
}
......
/*
* Copyright 2019-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/CancellationToken.h>
#include <folly/Executor.h>
#include <folly/ScopeGuard.h>
#include <folly/experimental/coro/Baton.h>
#include <folly/experimental/coro/Materialize.h>
#include <folly/experimental/coro/Mutex.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/ViaIfAsync.h>
#include <folly/experimental/coro/WithCancellation.h>
#include <folly/experimental/coro/detail/Barrier.h>
#include <folly/experimental/coro/detail/BarrierTask.h>
#include <folly/experimental/coro/detail/Helpers.h>
#include <exception>
#include <memory>
namespace folly {
namespace coro {
template <typename Reference, typename Value>
AsyncGenerator<Reference, Value> merge(
folly::Executor::KeepAlive<> executor,
AsyncGenerator<AsyncGenerator<Reference, Value>> sources) {
struct SharedState {
explicit SharedState(folly::Executor::KeepAlive<> executor)
: executor(std::move(executor)) {}
const folly::Executor::KeepAlive<> executor;
const folly::CancellationSource cancelSource;
coro::Mutex mutex;
coro::Baton recordPublished;
coro::Baton recordConsumed;
CallbackRecord<Reference> record;
};
auto makeConsumerTask =
[](std::shared_ptr<SharedState> state,
AsyncGenerator<AsyncGenerator<Reference, Value>> sources)
-> Task<void> {
auto makeWorkerTask = [](std::shared_ptr<SharedState> state,
AsyncGenerator<Reference, Value> generator)
-> detail::DetachedBarrierTask {
exception_wrapper ex;
auto cancelToken = state->cancelSource.getToken();
try {
while (auto item = co_await co_viaIfAsync(
state->executor.get_alias(),
co_withCancellation(cancelToken, generator.next()))) {
// We have a new value to emit in the merged stream.
{
auto lock = co_await co_viaIfAsync(
state->executor.get_alias(), state->mutex.co_scoped_lock());
if (cancelToken.isCancellationRequested()) {
// Consumer has detached and doesn't want any more values.
// Discard this value.
break;
}
// Publish the value.
state->record = CallbackRecord<Reference>{callback_record_value,
*std::move(item)};
state->recordPublished.post();
// Wait until the consumer is finished with it.
co_await co_viaIfAsync(
state->executor.get_alias(), state->recordConsumed);
state->recordConsumed.reset();
// Clear the result before releasing the lock.
state->record = {};
}
if (cancelToken.isCancellationRequested()) {
break;
}
}
} catch (const std::exception& e) {
ex = exception_wrapper{std::current_exception(), e};
} catch (...) {
ex = exception_wrapper{std::current_exception()};
}
if (ex) {
state->cancelSource.requestCancellation();
auto lock = co_await co_viaIfAsync(
state->executor.get_alias(), state->mutex.co_scoped_lock());
if (!state->record.hasError()) {
state->record =
CallbackRecord<Reference>{callback_record_error, std::move(ex)};
state->recordPublished.post();
}
};
};
detail::Barrier barrier{1};
exception_wrapper ex;
try {
while (auto item = co_await sources.next()) {
if (state->cancelSource.isCancellationRequested()) {
break;
}
makeWorkerTask(state, *std::move(item)).start(&barrier);
}
} catch (const std::exception& e) {
ex = exception_wrapper{std::current_exception(), e};
} catch (...) {
ex = exception_wrapper{std::current_exception()};
}
if (ex) {
state->cancelSource.requestCancellation();
auto lock = co_await co_viaIfAsync(
state->executor.get_alias(), state->mutex.co_scoped_lock());
if (!state->record.hasError()) {
state->record =
CallbackRecord<Reference>{callback_record_error, std::move(ex)};
state->recordPublished.post();
}
};
// Wait for all worker tasks to finish consuming the entirety of their
// input streams.
co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
// Guaranteed there are no more concurrent producers trying to acquire
// the mutex here.
if (!state->record.hasError()) {
// Stream not yet been terminated with an error.
// Terminate the stream with the 'end()' signal.
assert(!state->record.hasValue());
state->record = CallbackRecord<Reference>{callback_record_none};
state->recordPublished.post();
}
};
auto state = std::make_shared<SharedState>(executor);
SCOPE_EXIT {
state->cancelSource.requestCancellation();
};
// Start a task that consumes the stream of input streams.
makeConsumerTask(state, std::move(sources))
.scheduleOn(executor)
.start([](auto&&) {}, state->cancelSource.getToken());
// Consume values produced by the input streams.
while (true) {
if (!state->recordPublished.ready()) {
folly::CancellationCallback cb{
co_await co_current_cancellation_token,
[&] { state->cancelSource.requestCancellation(); }};
co_await state->recordPublished;
}
state->recordPublished.reset();
SCOPE_EXIT {
state->recordConsumed.post();
};
if (state->record.hasValue()) {
// next value
co_yield std::move(state->record).value();
} else if (state->record.hasError()) {
std::move(state->record).error().throw_exception();
} else {
// none
assert(state->record.hasNone());
break;
}
}
}
} // namespace coro
} // namespace folly
/*
* Copyright 2019-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/experimental/coro/AsyncGenerator.h>
namespace folly {
namespace coro {
// Merge the results of a number of input streams.
//
// The 'executor' parameter represents specifies the execution context to
// be used for awaiting each value from the sources.
// The 'sources' parameter represents an async-stream of async-streams.
// The resulting generator merges the results from each of the streams
// produced by 'sources', interleaving them in the order that the values
// are produced.
//
// If any of the input streams completes with an error then the error
// is produced from the output stream and the remainder of the input streams
// are truncated, discarding any remaining values.
//
// The resulting stream will terminate only when the end of the 'sources'
// stream has been reached and the ends of all of the input streams it
// produced have been reached.
template <typename Reference, typename Value>
AsyncGenerator<Reference, Value> merge(
folly::Executor::KeepAlive<> executor,
AsyncGenerator<AsyncGenerator<Reference, Value>> sources);
} // namespace coro
} // namespace folly
#include <folly/experimental/coro/Merge-inl.h>
......@@ -45,7 +45,7 @@ class Barrier {
assert(SIZE_MAX - oldCount >= count);
}
std::experimental::coroutine_handle<> arrive() noexcept {
[[nodiscard]] std::experimental::coroutine_handle<> arrive() noexcept {
const std::size_t oldCount = count_.fetch_sub(1, std::memory_order_acq_rel);
// Invalid to call arrive() if you haven't previously incremented the
......
......@@ -25,10 +25,6 @@ namespace folly {
namespace coro {
namespace detail {
class BarrierTask;
class BarrierTaskPromise {};
class BarrierTask {
public:
class promise_type {
......@@ -134,6 +130,79 @@ class BarrierTask {
handle_t coro_;
};
class DetachedBarrierTask {
public:
class promise_type {
public:
DetachedBarrierTask get_return_object() noexcept {
return DetachedBarrierTask{
std::experimental::coroutine_handle<promise_type>::from_promise(
*this)};
}
std::experimental::suspend_always initial_suspend() noexcept {
return {};
}
auto final_suspend() noexcept {
struct awaiter {
bool await_ready() {
return false;
}
auto await_suspend(
std::experimental::coroutine_handle<promise_type> h) {
assert(h.promise().barrier_ != nullptr);
auto continuation = h.promise().barrier_->arrive();
h.destroy();
return continuation;
}
void await_resume() {}
};
return awaiter{};
}
[[noreturn]] void unhandled_exception() noexcept {
std::terminate();
}
void return_void() noexcept {}
void setBarrier(Barrier* barrier) noexcept {
barrier_ = barrier;
}
private:
Barrier* barrier_;
};
private:
using handle_t = std::experimental::coroutine_handle<promise_type>;
explicit DetachedBarrierTask(handle_t coro) : coro_(coro) {}
public:
DetachedBarrierTask(DetachedBarrierTask&& other) noexcept
: coro_(std::exchange(other.coro_, {})) {}
~DetachedBarrierTask() {
if (coro_) {
coro_.destroy();
}
}
void start(Barrier* barrier) && noexcept {
assert(coro_);
assert(barrier != nullptr);
barrier->add(1);
auto coro = std::exchange(coro_, {});
coro.promise().setBarrier(barrier);
coro.resume();
}
private:
handle_t coro_;
};
} // namespace detail
} // namespace coro
} // namespace folly
/*
* Copyright 2017-present Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/CancellationToken.h>
#include <folly/ScopeGuard.h>
#include <folly/experimental/coro/AsyncGenerator.h>
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/experimental/coro/CurrentExecutor.h>
#include <folly/experimental/coro/Merge.h>
#include <folly/experimental/coro/Task.h>
#include <folly/portability/GTest.h>
using namespace folly::coro;
TEST(Merge, SimpleMerge) {
blockingWait([]() -> Task<void> {
auto generator = merge(
co_await co_current_executor,
[]() -> AsyncGenerator<AsyncGenerator<int>> {
auto makeGenerator = [](int start, int count) -> AsyncGenerator<int> {
for (int i = start; i < start + count; ++i) {
co_yield i;
co_await co_reschedule_on_current_executor;
}
};
co_yield makeGenerator(0, 3);
co_yield makeGenerator(3, 2);
}());
const std::array<int, 5> expectedValues = {0, 3, 1, 4, 2};
auto item = co_await generator.next();
for (int expectedValue : expectedValues) {
CHECK(!!item);
CHECK_EQ(expectedValue, *item);
item = co_await generator.next();
}
CHECK(!item);
}());
}
TEST(Merge, TruncateStream) {
blockingWait([]() -> Task<void> {
int started = 0;
int completed = 0;
{
auto generator = merge(
co_await co_current_executor,
co_invoke([&]() -> AsyncGenerator<AsyncGenerator<int>> {
auto makeGenerator = [&]() -> AsyncGenerator<int> {
++started;
SCOPE_EXIT {
++completed;
};
co_yield 1;
co_await co_reschedule_on_current_executor;
co_yield 2;
};
co_yield co_invoke(makeGenerator);
co_yield co_invoke(makeGenerator);
co_yield co_invoke(makeGenerator);
}));
auto item = co_await generator.next();
CHECK_EQ(1, *item);
item = co_await generator.next();
CHECK_EQ(1, *item);
CHECK_EQ(3, started);
// Truncate the stream after consuming only 2 of the 6 values it
// would have produced.
}
// Spin the executor until the generators finish responding to cancellation.
for (int i = 0; completed != started && i < 10; ++i) {
co_await co_reschedule_on_current_executor;
}
CHECK_EQ(3, completed);
}());
}
TEST(Merge, SequencesOfRValueReferences) {
blockingWait([]() -> Task<void> {
auto makeStreamOfStreams =
[]() -> AsyncGenerator<AsyncGenerator<std::vector<int>&&>> {
auto makeStreamOfVectors = []() -> AsyncGenerator<std::vector<int>&&> {
co_yield std::vector{1, 2, 3};
co_await co_reschedule_on_current_executor;
co_yield std::vector{2, 4, 6};
};
co_yield makeStreamOfVectors();
co_yield makeStreamOfVectors();
};
auto gen = merge(co_await co_current_executor, makeStreamOfStreams());
int resultCount = 0;
while (auto item = co_await gen.next()) {
++resultCount;
std::vector<int>&& v = *item;
CHECK_EQ(3, v.size());
}
CHECK_EQ(4, resultCount);
}());
}
TEST(Merge, SequencesOfLValueReferences) {
blockingWait([]() -> Task<void> {
auto makeStreamOfStreams =
[]() -> AsyncGenerator<AsyncGenerator<std::vector<int>&>> {
auto makeStreamOfVectors = []() -> AsyncGenerator<std::vector<int>&> {
std::vector<int> v{1, 2, 3};
co_yield v;
CHECK_EQ(4, v.size());
co_await co_reschedule_on_current_executor;
v.push_back(v.back());
co_yield v;
};
co_yield makeStreamOfVectors();
co_yield makeStreamOfVectors();
};
auto gen = merge(co_await co_current_executor, makeStreamOfStreams());
int resultCount = 0;
while (auto item = co_await gen.next()) {
++resultCount;
std::vector<int>& v = *item;
if (v.size() == 3) {
CHECK_EQ(1, v[0]);
CHECK_EQ(2, v[1]);
CHECK_EQ(3, v[2]);
v.push_back(7);
} else {
CHECK_EQ(5, v.size());
CHECK_EQ(1, v[0]);
CHECK_EQ(2, v[1]);
CHECK_EQ(3, v[2]);
CHECK_EQ(7, v[3]);
CHECK_EQ(7, v[4]);
}
}
CHECK_EQ(4, resultCount);
}());
}
template <typename Ref, typename Value = folly::remove_cvref_t<Ref>>
folly::coro::AsyncGenerator<Ref, Value> neverStream() {
folly::coro::Baton baton;
folly::CancellationCallback cb{
co_await folly::coro::co_current_cancellation_token,
[&] { baton.post(); }};
co_await baton;
}
TEST(Merge, CancellationTokenPropagatesToOuterFromConsumer) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> {
folly::CancellationSource cancelSource;
bool suspended = false;
bool done = false;
co_await folly::coro::collectAll(
folly::coro::co_withCancellation(
cancelSource.getToken(),
[&]() -> folly::coro::Task<void> {
auto stream = merge(
co_await co_current_executor,
neverStream<AsyncGenerator<int>>());
suspended = true;
auto result = co_await stream.next();
CHECK(!result.has_value());
done = true;
}()),
[&]() -> folly::coro::Task<void> {
co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor;
CHECK(suspended);
CHECK(!done);
cancelSource.requestCancellation();
}());
CHECK(done);
}());
}
TEST(Merge, CancellationTokenPropagatesToInnerFromConsumer) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> {
folly::CancellationSource cancelSource;
bool suspended = false;
bool done = false;
auto makeStreamOfStreams = []() -> AsyncGenerator<AsyncGenerator<int>> {
co_yield neverStream<int>();
co_yield neverStream<int>();
};
co_await folly::coro::collectAll(
folly::coro::co_withCancellation(
cancelSource.getToken(),
[&]() -> folly::coro::Task<void> {
auto stream =
merge(co_await co_current_executor, makeStreamOfStreams());
suspended = true;
auto result = co_await stream.next();
CHECK(!result.has_value());
done = true;
}()),
[&]() -> folly::coro::Task<void> {
co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor;
CHECK(suspended);
CHECK(!done);
cancelSource.requestCancellation();
}());
CHECK(done);
}());
}
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment