Commit f54c24f8 authored by Lewis Baker's avatar Lewis Baker Committed by Facebook Github Bot

Make folly::coro::collectAllRange/Windowed work with vector<TaskWithExecutor<T>>

Summary:
Generalises the specialisation of collectAllRange() and collectAllWindowed() on a std::vector<Task<T>> to now work for any std::vector<SemiAwaitable>.

This will allow them to work when passed a std::vector<TaskWithExecutor<T>>.

Reviewed By: andriigrynenko

Differential Revision: D18049190

fbshipit-source-id: 658213f93ff3eadda041ac2262cc6bcb759e331c
parent c2a285ae
...@@ -210,20 +210,20 @@ auto collectAllTryRange(InputRange awaitables) ...@@ -210,20 +210,20 @@ auto collectAllTryRange(InputRange awaitables)
detail::range_reference_t<InputRange>>>>; detail::range_reference_t<InputRange>>>>;
// collectAllRange()/collectAllTryRange() overloads that simplifies the // collectAllRange()/collectAllTryRange() overloads that simplifies the
// common-case where an rvalue std::vector<Task<T>> is passed. // common-case where an rvalue std::vector<SemiAwaitable> is passed.
// //
// This avoids the caller needing to pipe the input through ranges::views::move // This avoids the caller needing to pipe the input through ranges::views::move
// transform to force the Task<T> elements to be rvalue-references since the // transform to force the elements to be rvalue-references since the
// std::vector<T>::reference type is T& rather than T&& and Task<T>& is not // std::vector<T>::reference type is T& rather than T&& and some awaitables,
// awaitable. // such as Task<U>, are not lvalue awaitable.
template <typename T> template <typename SemiAwaitable>
auto collectAllRange(std::vector<Task<T>> awaitables) auto collectAllRange(std::vector<SemiAwaitable> awaitables)
-> decltype(collectAllRange(awaitables | ranges::views::move)) { -> decltype(collectAllRange(awaitables | ranges::views::move)) {
co_return co_await collectAllRange(awaitables | ranges::views::move); co_return co_await collectAllRange(awaitables | ranges::views::move);
} }
template <typename T> template <typename SemiAwaitable>
auto collectAllTryRange(std::vector<Task<T>> awaitables) auto collectAllTryRange(std::vector<SemiAwaitable> awaitables)
-> decltype(collectAllTryRange(awaitables | ranges::views::move)) { -> decltype(collectAllTryRange(awaitables | ranges::views::move)) {
co_return co_await collectAllTryRange(awaitables | ranges::views::move); co_return co_await collectAllTryRange(awaitables | ranges::views::move);
} }
...@@ -289,10 +289,10 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency) ...@@ -289,10 +289,10 @@ auto collectAllTryWindowed(InputRange awaitables, std::size_t maxConcurrency)
detail::range_reference_t<InputRange>>>>; detail::range_reference_t<InputRange>>>>;
// collectAllWindowed()/collectAllTryWindowed() overloads that simplify the // collectAllWindowed()/collectAllTryWindowed() overloads that simplify the
// use of these functions with std::vector<Task<T>>. // use of these functions with std::vector<SemiAwaitable>.
template <typename T> template <typename SemiAwaitable>
auto collectAllWindowed( auto collectAllWindowed(
std::vector<Task<T>> awaitables, std::vector<SemiAwaitable> awaitables,
std::size_t maxConcurrency) std::size_t maxConcurrency)
-> decltype( -> decltype(
collectAllWindowed(awaitables | ranges::views::move, maxConcurrency)) { collectAllWindowed(awaitables | ranges::views::move, maxConcurrency)) {
...@@ -300,9 +300,9 @@ auto collectAllWindowed( ...@@ -300,9 +300,9 @@ auto collectAllWindowed(
awaitables | ranges::views::move, maxConcurrency); awaitables | ranges::views::move, maxConcurrency);
} }
template <typename T> template <typename SemiAwaitable>
auto collectAllTryWindowed( auto collectAllTryWindowed(
std::vector<Task<T>> awaitables, std::vector<SemiAwaitable> awaitables,
std::size_t maxConcurrency) std::size_t maxConcurrency)
-> decltype(collectAllTryWindowed( -> decltype(collectAllTryWindowed(
awaitables | ranges::views::move, awaitables | ranges::views::move,
......
...@@ -357,6 +357,21 @@ TEST(CollectAll, CollectAllKeepsRequestContextOfChildTasksIndependent) { ...@@ -357,6 +357,21 @@ TEST(CollectAll, CollectAllKeepsRequestContextOfChildTasksIndependent) {
}()); }());
} }
TEST(CollectAll, TaskWithExecutorUsage) {
folly::CPUThreadPoolExecutor threadPool{
4, std::make_shared<folly::NamedThreadFactory>("TestThreadPool")};
folly::coro::blockingWait([&]() -> folly::coro::Task<void> {
auto [a, b] = co_await folly::coro::collectAll(
[]() -> folly::coro::Task<int> { co_return 42; }().scheduleOn(
&threadPool),
[]() -> folly::coro::Task<std::string> { co_return "coroutine"; }()
.scheduleOn(&threadPool));
CHECK(a == 42);
CHECK(b == "coroutine");
}());
}
///////////////////////////////////////////////////////// /////////////////////////////////////////////////////////
// folly::coro::collectAllTry() tests // folly::coro::collectAllTry() tests
...@@ -762,6 +777,27 @@ TEST( ...@@ -762,6 +777,27 @@ TEST(
}()); }());
} }
TEST(CollectAllRange, VectorOfTaskWithExecutorUsage) {
folly::CPUThreadPoolExecutor threadPool{
4, std::make_shared<folly::NamedThreadFactory>("TestThreadPool")};
folly::coro::blockingWait([&]() -> folly::coro::Task<void> {
std::vector<folly::coro::TaskWithExecutor<int>> tasks;
for (int i = 0; i < 4; ++i) {
tasks.push_back(
[](int i) -> folly::coro::Task<int> { co_return i + 1; }(i)
.scheduleOn(&threadPool));
}
auto results = co_await folly::coro::collectAllRange(std::move(tasks));
CHECK(results.size() == 4);
CHECK(results[0] == 1);
CHECK(results[1] == 2);
CHECK(results[2] == 3);
CHECK(results[3] == 4);
}());
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
// folly::coro::collectAllTryRange() tests // folly::coro::collectAllTryRange() tests
...@@ -1239,6 +1275,28 @@ TEST( ...@@ -1239,6 +1275,28 @@ TEST(
}()); }());
} }
TEST(CollectAllWindowed, VectorOfTaskWithExecutorUsage) {
folly::CPUThreadPoolExecutor threadPool{
4, std::make_shared<folly::NamedThreadFactory>("TestThreadPool")};
folly::coro::blockingWait([&]() -> folly::coro::Task<void> {
std::vector<folly::coro::TaskWithExecutor<int>> tasks;
for (int i = 0; i < 4; ++i) {
tasks.push_back(
[](int i) -> folly::coro::Task<int> { co_return i + 1; }(i)
.scheduleOn(&threadPool));
}
auto results =
co_await folly::coro::collectAllWindowed(std::move(tasks), 2);
CHECK(results.size() == 4);
CHECK(results[0] == 1);
CHECK(results[1] == 2);
CHECK(results[2] == 3);
CHECK(results[3] == 4);
}());
}
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
// folly::coro::collectAllTryWindowed() tests // folly::coro::collectAllTryWindowed() tests
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment