Commit 118a3923 authored by Cameron Pickett's avatar Cameron Pickett Committed by Facebook GitHub Bot

Do not swallow child exceptions on cancellation

Summary:
Aligns `collectAll` and `collectAny` behaviour and simplifies the logic for propagating child exceptions on cancellation.

Prior to this change, we were conditionally setting `firstException` dependent on whether cancellation was requested or not. Additionally, `collectAny` was unconditionally returning `co_cancelled` even if all the child tasks completed successfully.

Now, the behaviour is consistent between general errors and cancellation: If any child task completes in error (including folly::OperationCancelled), then the parent `collectAll`/`collectAny` will propagate that failure. Otherwise, if no child fails, then `collectAll`/`collectAny` will propagate that success.

(Note: this ignores all push blocking failures!)

Reviewed By: iahs

Differential Revision: D31266586

fbshipit-source-id: b6eba6ab2a0a3634b112318b1810819d7916acdb
parent ff7c3177
...@@ -135,7 +135,6 @@ auto collectAllImpl( ...@@ -135,7 +135,6 @@ auto collectAllImpl(
CancellationToken::merge(parentCancelToken, cancelSource.getToken()); CancellationToken::merge(parentCancelToken, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures{false};
auto makeTask = [&](auto&& awaitable, auto& result) -> BarrierTask { auto makeTask = [&](auto&& awaitable, auto& result) -> BarrierTask {
using await_result = semi_await_result_t<decltype(awaitable)>; using await_result = semi_await_result_t<decltype(awaitable)>;
...@@ -153,9 +152,7 @@ auto collectAllImpl( ...@@ -153,9 +152,7 @@ auto collectAllImpl(
cancelToken, static_cast<decltype(awaitable)>(awaitable)))); cancelToken, static_cast<decltype(awaitable)>(awaitable))));
} }
} catch (...) { } catch (...) {
anyFailures.store(true, std::memory_order_relaxed); if (!cancelSource.requestCancellation()) {
if (!cancelSource.requestCancellation() &&
!parentCancelToken.isCancellationRequested()) {
// This was the first failure, remember its error. // This was the first failure, remember its error.
firstException = exception_wrapper{std::current_exception()}; firstException = exception_wrapper{std::current_exception()};
} }
...@@ -194,15 +191,8 @@ auto collectAllImpl( ...@@ -194,15 +191,8 @@ auto collectAllImpl(
// the use of co_viaIfAsync() within makeBarrierTask(). // the use of co_viaIfAsync() within makeBarrierTask().
co_await UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()}; co_await UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
if (anyFailures.load(std::memory_order_relaxed)) { if (firstException) {
if (firstException) { co_yield co_error(std::move(firstException));
co_yield co_error(std::move(firstException));
}
// Parent task was cancelled before any child tasks failed.
// Complete with the OperationCancelled error instead of the
// child task's errors.
co_yield co_cancelled;
} }
co_return std::tuple<collect_all_component_t<SemiAwaitables>...>{ co_return std::tuple<collect_all_component_t<SemiAwaitables>...>{
...@@ -311,10 +301,6 @@ auto collectAnyImpl( ...@@ -311,10 +301,6 @@ auto collectAnyImpl(
} }
}))...); }))...);
if (parentCancelToken.isCancellationRequested()) {
co_yield co_cancelled;
}
co_return firstCompletion; co_return firstCompletion;
} }
...@@ -377,7 +363,6 @@ auto collectAllRange(InputRange awaitables) ...@@ -377,7 +363,6 @@ auto collectAllRange(InputRange awaitables)
tryResults; tryResults;
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>; using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>;
auto makeTask = [&](awaitable_type semiAwaitable, auto makeTask = [&](awaitable_type semiAwaitable,
...@@ -389,7 +374,6 @@ auto collectAllRange(InputRange awaitables) ...@@ -389,7 +374,6 @@ auto collectAllRange(InputRange awaitables)
executor.get_alias(), executor.get_alias(),
co_withCancellation(cancelToken, std::move(semiAwaitable)))); co_withCancellation(cancelToken, std::move(semiAwaitable))));
} catch (...) { } catch (...) {
anyFailures.store(true, std::memory_order_relaxed);
if (!cancelSource.requestCancellation()) { if (!cancelSource.requestCancellation()) {
firstException = exception_wrapper{std::current_exception()}; firstException = exception_wrapper{std::current_exception()};
} }
...@@ -429,14 +413,8 @@ auto collectAllRange(InputRange awaitables) ...@@ -429,14 +413,8 @@ auto collectAllRange(InputRange awaitables)
} }
// Check if there were any exceptions and rethrow the first one. // Check if there were any exceptions and rethrow the first one.
if (anyFailures.load(std::memory_order_relaxed)) { if (firstException) {
if (firstException) { co_yield co_error(std::move(firstException));
co_yield co_error(std::move(firstException));
}
// Cancellation was requested of the parent Task before any of the
// child tasks failed.
co_yield co_cancelled;
} }
std::vector<detail::collect_all_range_component_t< std::vector<detail::collect_all_range_component_t<
...@@ -463,7 +441,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -463,7 +441,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
co_await co_current_cancellation_token, cancelSource.getToken()); co_await co_current_cancellation_token, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>; using awaitable_type = remove_cvref_t<detail::range_reference_t<InputRange>>;
auto makeTask = [&](awaitable_type semiAwaitable) -> detail::BarrierTask { auto makeTask = [&](awaitable_type semiAwaitable) -> detail::BarrierTask {
...@@ -472,7 +449,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -472,7 +449,6 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
executor.get_alias(), executor.get_alias(),
co_withCancellation(cancelToken, std::move(semiAwaitable))); co_withCancellation(cancelToken, std::move(semiAwaitable)));
} catch (...) { } catch (...) {
anyFailures.store(true, std::memory_order_relaxed);
if (!cancelSource.requestCancellation()) { if (!cancelSource.requestCancellation()) {
firstException = exception_wrapper{std::current_exception()}; firstException = exception_wrapper{std::current_exception()};
} }
...@@ -509,10 +485,8 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> { ...@@ -509,10 +485,8 @@ auto collectAllRange(InputRange awaitables) -> folly::coro::Task<void> {
} }
// Check if there were any exceptions and rethrow the first one. // Check if there were any exceptions and rethrow the first one.
if (anyFailures.load(std::memory_order_relaxed)) { if (firstException) {
if (firstException) { co_yield co_error(std::move(firstException));
co_yield co_error(std::move(firstException));
}
} }
} }
...@@ -607,10 +581,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -607,10 +581,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
co_await co_current_cancellation_token, cancelSource.getToken()); co_await co_current_cancellation_token, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
const auto trySetFirstException = [&](exception_wrapper&& e) noexcept { const auto trySetFirstException = [&](exception_wrapper&& e) noexcept {
anyFailures.store(true, std::memory_order_relaxed);
if (!cancelSource.requestCancellation()) { if (!cancelSource.requestCancellation()) {
// This is first entity to request cancellation. // This is first entity to request cancellation.
firstException = std::move(e); firstException = std::move(e);
...@@ -700,14 +672,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -700,14 +672,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()}; co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
if (iterationException) { if (auto& ex = iterationException ? iterationException : firstException) {
co_yield co_error(std::move(iterationException)); co_yield co_error(std::move(ex));
} else if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) {
co_yield co_error(std::move(firstException));
}
co_yield co_cancelled;
} }
} }
...@@ -731,12 +697,9 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -731,12 +697,9 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
CancellationToken::merge(parentCancelToken, cancelSource.getToken()); CancellationToken::merge(parentCancelToken, cancelSource.getToken());
exception_wrapper firstException; exception_wrapper firstException;
std::atomic<bool> anyFailures = false;
auto trySetFirstException = [&](exception_wrapper&& e) noexcept { auto trySetFirstException = [&](exception_wrapper&& e) noexcept {
anyFailures.store(true, std::memory_order_relaxed); if (!cancelSource.requestCancellation()) {
if (!cancelSource.requestCancellation() &&
!parentCancelToken.isCancellationRequested()) {
// This is first entity to request cancellation. // This is first entity to request cancellation.
firstException = std::move(e); firstException = std::move(e);
} }
...@@ -846,16 +809,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -846,16 +809,8 @@ auto collectAllWindowed(InputRange awaitables, std::size_t maxConcurrency)
co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()}; co_await detail::UnsafeResumeInlineSemiAwaitable{barrier.arriveAndWait()};
if (iterationException) { if (auto& ex = iterationException ? iterationException : firstException) {
co_yield co_error(std::move(iterationException)); co_yield co_error(std::move(ex));
} else if (anyFailures.load(std::memory_order_relaxed)) {
if (firstException) {
co_yield co_error(std::move(firstException));
}
// Otherwise, cancellation was requested before any of the child tasks
// failed so complete with the OperationCancelled error.
co_yield co_cancelled;
} }
std::vector<detail::collect_all_range_component_t< std::vector<detail::collect_all_range_component_t<
...@@ -1100,10 +1055,6 @@ auto collectAnyRange(InputRange awaitables) -> folly::coro::Task<std::pair< ...@@ -1100,10 +1055,6 @@ auto collectAnyRange(InputRange awaitables) -> folly::coro::Task<std::pair<
co_await folly::coro::co_withCancellation( co_await folly::coro::co_withCancellation(
cancelToken, folly::coro::collectAllRange(tasks | ranges::views::move)); cancelToken, folly::coro::collectAllRange(tasks | ranges::views::move));
if (parentCancelToken.isCancellationRequested()) {
co_yield co_cancelled;
}
co_return firstCompletion; co_return firstCompletion;
} }
......
...@@ -2039,31 +2039,27 @@ TEST_F(CollectAnyTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) { ...@@ -2039,31 +2039,27 @@ TEST_F(CollectAnyTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> { folly::coro::blockingWait([]() -> folly::coro::Task<void> {
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
folly::CancellationSource cancelSource; folly::CancellationSource cancelSource;
try { auto [index, result] = co_await folly::coro::co_withCancellation(
auto [index, result] = co_await folly::coro::co_withCancellation( cancelSource.getToken(),
cancelSource.getToken(), folly::coro::collectAny(
folly::coro::collectAny( [&]() -> folly::coro::Task<int> {
[&]() -> folly::coro::Task<int> { co_await sleepThatShouldBeCancelled(10s);
co_await sleepThatShouldBeCancelled(10s); co_return 42;
co_return 42; }(),
}(), [&]() -> folly::coro::Task<int> {
[&]() -> folly::coro::Task<int> { co_await sleepThatShouldBeCancelled(5s);
co_await sleepThatShouldBeCancelled(5s); co_return 314;
co_return 314; }(),
}(), [&]() -> folly::coro::Task<int> {
[&]() -> folly::coro::Task<int> { co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor; cancelSource.requestCancellation();
cancelSource.requestCancellation(); co_await sleepThatShouldBeCancelled(15s);
co_await sleepThatShouldBeCancelled(15s); co_return 123;
co_return 123; }()));
}())); auto end = std::chrono::steady_clock::now();
ADD_FAILURE() << "Hit unexpected codepath"; EXPECT_LT(end - start, 1s);
} catch (const folly::OperationCancelled&) {
auto end = std::chrono::steady_clock::now();
EXPECT_LT(end - start, 1s);
}
}()); }());
} }
...@@ -2530,34 +2526,29 @@ TEST_F(CollectAnyRangeTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) { ...@@ -2530,34 +2526,29 @@ TEST_F(CollectAnyRangeTest, CollectAnyCancelsSubtasksWhenParentTaskCancelled) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> { folly::coro::blockingWait([]() -> folly::coro::Task<void> {
auto start = std::chrono::steady_clock::now(); auto start = std::chrono::steady_clock::now();
folly::CancellationSource cancelSource; folly::CancellationSource cancelSource;
try { auto generateTasks =
auto generateTasks = [&]() -> folly::coro::Generator<folly::coro::Task<int>&&> {
[&]() -> folly::coro::Generator<folly::coro::Task<int>&&> { co_yield [&]() -> folly::coro::Task<int> {
co_yield [&]() -> folly::coro::Task<int> { co_await sleepThatShouldBeCancelled(10s);
co_await sleepThatShouldBeCancelled(10s); co_return 42;
co_return 42; }();
}(); co_yield [&]() -> folly::coro::Task<int> {
co_yield [&]() -> folly::coro::Task<int> { co_await sleepThatShouldBeCancelled(5s);
co_await sleepThatShouldBeCancelled(5s); co_return 314;
co_return 314; }();
}(); co_yield [&]() -> folly::coro::Task<int> {
co_yield [&]() -> folly::coro::Task<int> { co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor; co_await folly::coro::co_reschedule_on_current_executor;
co_await folly::coro::co_reschedule_on_current_executor; cancelSource.requestCancellation();
cancelSource.requestCancellation(); co_await sleepThatShouldBeCancelled(15s);
co_await sleepThatShouldBeCancelled(15s); co_return 123;
co_return 123; }();
}(); };
}; auto [index, result] = co_await folly::coro::co_withCancellation(
auto [index, result] = co_await folly::coro::co_withCancellation( cancelSource.getToken(), folly::coro::collectAnyRange(generateTasks()));
cancelSource.getToken(), auto end = std::chrono::steady_clock::now();
folly::coro::collectAnyRange(generateTasks())); EXPECT_LT(end - start, 1s);
ADD_FAILURE() << "Hit unexpected codepath";
} catch (const folly::OperationCancelled&) {
auto end = std::chrono::steady_clock::now();
EXPECT_LT(end - start, 1s);
}
}()); }());
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment