Commit 88ba2915 authored by Tom Jackson's avatar Tom Jackson Committed by Facebook Github Bot

Make pmap() handle exceptions instead of FATALing

Summary:
This switches the output queue from `Output` to `Expected<Output, std::exception_wrapper>`, allowing exceptions to be produced by the predicate. Prior to this change, any exceptions would FATAL, taking down the whole process.

Benchmarks appear unaffected. Memory use will be very similar: Small objects are already padded out to a cache line, so the increase in size will often be zero.

Reviewed By: luciang

Differential Revision: D14625210

fbshipit-source-id: 9bfa1197b7d2285bbc7e42e829beeb3525b3bd44
parent 848e0fa0
...@@ -20,11 +20,13 @@ ...@@ -20,11 +20,13 @@
#include <atomic> #include <atomic>
#include <cassert> #include <cassert>
#include <exception>
#include <thread> #include <thread>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <folly/Expected.h>
#include <folly/MPMCPipeline.h> #include <folly/MPMCPipeline.h>
#include <folly/experimental/EventCount.h> #include <folly/experimental/EventCount.h>
#include <folly/functional/Invoke.h> #include <folly/functional/Invoke.h>
...@@ -66,11 +68,13 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -66,11 +68,13 @@ class PMap : public Operator<PMap<Predicate>> {
Predicate pred_; Predicate pred_;
const size_t nThreads_; const size_t nThreads_;
using Result = folly::Expected<Output, std::exception_ptr>;
class ExecutionPipeline { class ExecutionPipeline {
std::vector<std::thread> workers_; std::vector<std::thread> workers_;
std::atomic<bool> done_{false}; std::atomic<bool> done_{false};
const Predicate& pred_; const Predicate& pred_;
MPMCPipeline<Input, Output> pipeline_; using Pipeline = MPMCPipeline<Input, Result>;
Pipeline pipeline_;
EventCount wake_; EventCount wake_;
public: public:
...@@ -109,12 +113,12 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -109,12 +113,12 @@ class PMap : public Operator<PMap<Predicate>> {
wake_.notify(); wake_.notify();
} }
bool read(Output& out) { bool read(Result& result) {
return pipeline_.read(out); return pipeline_.read(result);
} }
void blockingRead(Output& out) { void blockingRead(Result& result) {
pipeline_.blockingRead(out); pipeline_.blockingRead(result);
} }
private: private:
...@@ -128,11 +132,16 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -128,11 +132,16 @@ class PMap : public Operator<PMap<Predicate>> {
for (;;) { for (;;) {
auto key = wake_.prepareWait(); auto key = wake_.prepareWait();
typename MPMCPipeline<Input, Output>::template Ticket<0> ticket; typename Pipeline::template Ticket<0> ticket;
if (pipeline_.template readStage<0>(ticket, in)) { if (pipeline_.template readStage<0>(ticket, in)) {
wake_.cancelWait(); wake_.cancelWait();
Output out = pred_(std::move(in)); try {
pipeline_.template blockingWriteStage<0>(ticket, std::move(out)); Output out = pred_(std::move(in));
pipeline_.template blockingWriteStage<0>(ticket, std::move(out));
} catch (...) {
pipeline_.template blockingWriteStage<0>(
ticket, makeUnexpected(std::current_exception()));
}
continue; continue;
} }
...@@ -147,6 +156,13 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -147,6 +156,13 @@ class PMap : public Operator<PMap<Predicate>> {
} }
}; };
static Output&& getOutput(Result& result) {
if (result.hasError()) {
std::rethrow_exception(std::move(result).error());
}
return std::move(result).value();
}
public: public:
Generator(Source source, const Predicate& pred, size_t nThreads) Generator(Source source, const Predicate& pred, size_t nThreads)
: source_(std::move(source)), : source_(std::move(source)),
...@@ -168,10 +184,10 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -168,10 +184,10 @@ class PMap : public Operator<PMap<Predicate>> {
} }
// input queue full; drain ready items from the queue // input queue full; drain ready items from the queue
Output out; Result result;
while (pipeline.read(out)) { while (pipeline.read(result)) {
++read; ++read;
body(std::move(out)); body(getOutput(result));
} }
// write the value we were going to write before we made room. // write the value we were going to write before we made room.
...@@ -183,10 +199,10 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -183,10 +199,10 @@ class PMap : public Operator<PMap<Predicate>> {
// flush the output queue // flush the output queue
while (read < wrote) { while (read < wrote) {
Output out; Result result;
pipeline.blockingRead(out); pipeline.blockingRead(result);
++read; ++read;
body(std::move(out)); body(getOutput(result));
} }
} }
...@@ -206,10 +222,10 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -206,10 +222,10 @@ class PMap : public Operator<PMap<Predicate>> {
} }
// input queue full; drain ready items from the queue // input queue full; drain ready items from the queue
Output out; Result result;
while (pipeline.read(out)) { while (pipeline.read(result)) {
++read; ++read;
if (!handler(std::move(out))) { if (!handler(getOutput(result))) {
more = false; more = false;
return false; return false;
} }
...@@ -225,11 +241,11 @@ class PMap : public Operator<PMap<Predicate>> { ...@@ -225,11 +241,11 @@ class PMap : public Operator<PMap<Predicate>> {
// flush the output queue // flush the output queue
while (read < wrote) { while (read < wrote) {
Output out; Result result;
pipeline.blockingRead(out); pipeline.blockingRead(result);
++read; ++read;
if (more) { if (more && !handler(getOutput(result))) {
more = more && handler(std::move(out)); more = false;
} }
} }
return more; return more;
......
...@@ -33,7 +33,7 @@ DEFINE_int32( ...@@ -33,7 +33,7 @@ DEFINE_int32(
constexpr int kFib = 35; // unit of work constexpr int kFib = 35; // unit of work
size_t fib(int n) { size_t fib(int n) {
return n <= 1 ? 1 : fib(n - 1) * fib(n - 2); return n <= 1 ? 1 : fib(n - 1) + fib(n - 2);
} }
BENCHMARK(FibSumMap, n) { BENCHMARK(FibSumMap, n) {
......
...@@ -155,6 +155,10 @@ TEST(Pmap, Rvalues) { ...@@ -155,6 +155,10 @@ TEST(Pmap, Rvalues) {
} }
} }
TEST(Pmap, Exception) {
EXPECT_THROW(from({"a"}) | pmap(To<int>()) | count, std::runtime_error);
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
gflags::ParseCommandLineFlags(&argc, &argv, true); gflags::ParseCommandLineFlags(&argc, &argv, true);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment