Commit f07d48a0 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Always manage lifetime of a SerialExecutor using KeepAlive tokens

Summary: SerialExecutor is always wrapping another Executor, so we shouldn't force users to coordinate their lifetimes.

Reviewed By: yfeldblum

Differential Revision: D7856146

fbshipit-source-id: ac7caaa0f181406dfb6b59d36ae4efe6d1343590
parent d233e724
...@@ -72,18 +72,54 @@ class SerialExecutor::TaskQueueImpl { ...@@ -72,18 +72,54 @@ class SerialExecutor::TaskQueueImpl {
std::queue<Func> queue_; std::queue<Func> queue_;
}; };
SerialExecutor::SerialExecutor(std::shared_ptr<folly::Executor> parent) SerialExecutor::SerialExecutor(KeepAlive<Executor> parent)
: parent_(std::move(parent)), : parent_(std::move(parent)),
taskQueueImpl_(std::make_shared<TaskQueueImpl>()) {} taskQueueImpl_(std::make_shared<TaskQueueImpl>()) {}
SerialExecutor::~SerialExecutor() {
DCHECK(!keepAliveCounter_);
}
Executor::KeepAlive<SerialExecutor> SerialExecutor::create(
KeepAlive<Executor> parent) {
return makeKeepAlive<SerialExecutor>(new SerialExecutor(std::move(parent)));
}
SerialExecutor::UniquePtr SerialExecutor::createUnique(
std::shared_ptr<Executor> parent) {
auto executor = new SerialExecutor(getKeepAliveToken(parent.get()));
return {executor, Deleter{std::move(parent)}};
}
bool SerialExecutor::keepAliveAcquire() {
auto keepAliveCounter =
keepAliveCounter_.fetch_add(1, std::memory_order_relaxed);
DCHECK(keepAliveCounter > 0);
return true;
}
void SerialExecutor::keepAliveRelease() {
auto keepAliveCounter = --keepAliveCounter_;
DCHECK(keepAliveCounter >= 0);
if (!keepAliveCounter) {
delete this;
}
}
void SerialExecutor::add(Func func) { void SerialExecutor::add(Func func) {
taskQueueImpl_->add(std::move(func)); taskQueueImpl_->add(std::move(func));
parent_->add([impl = taskQueueImpl_] { impl->run(); }); parent_->add([impl = taskQueueImpl_, keepAlive = getKeepAliveToken(this)] {
impl->run();
});
} }
void SerialExecutor::addWithPriority(Func func, int8_t priority) { void SerialExecutor::addWithPriority(Func func, int8_t priority) {
taskQueueImpl_->add(std::move(func)); taskQueueImpl_->add(std::move(func));
parent_->addWithPriority([impl = taskQueueImpl_] { impl->run(); }, priority); parent_->addWithPriority(
[impl = taskQueueImpl_, keepAlive = getKeepAliveToken(this)] {
impl->run();
},
priority);
} }
} // namespace folly } // namespace folly
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <atomic>
#include <memory> #include <memory>
#include <folly/executors/GlobalExecutor.h> #include <folly/executors/GlobalExecutor.h>
...@@ -48,14 +49,35 @@ namespace folly { ...@@ -48,14 +49,35 @@ namespace folly {
class SerialExecutor : public SequencedExecutor { class SerialExecutor : public SequencedExecutor {
public: public:
~SerialExecutor() override = default; ~SerialExecutor() override;
SerialExecutor(SerialExecutor const&) = delete; SerialExecutor(SerialExecutor const&) = delete;
SerialExecutor& operator=(SerialExecutor const&) = delete; SerialExecutor& operator=(SerialExecutor const&) = delete;
SerialExecutor(SerialExecutor&&) = default; SerialExecutor(SerialExecutor&&) = default;
SerialExecutor& operator=(SerialExecutor&&) = default; SerialExecutor& operator=(SerialExecutor&&) = default;
explicit SerialExecutor( static KeepAlive<SerialExecutor> create(
std::shared_ptr<folly::Executor> parent = folly::getCPUExecutor()); KeepAlive<Executor> parent = getKeepAliveToken(getCPUExecutor().get()));
class Deleter {
public:
Deleter() {}
void operator()(SerialExecutor* executor) {
executor->keepAliveRelease();
}
private:
friend class SerialExecutor;
explicit Deleter(std::shared_ptr<Executor> parent)
: parent_(std::move(parent)) {}
std::shared_ptr<Executor> parent_;
};
using UniquePtr = std::unique_ptr<SerialExecutor, Deleter>;
[[deprecated("Replaced by create")]]
static UniquePtr createUnique(
std::shared_ptr<Executor> parent = getCPUExecutor());
/** /**
* Add one task for execution in the parent executor * Add one task for execution in the parent executor
...@@ -77,11 +99,20 @@ class SerialExecutor : public SequencedExecutor { ...@@ -77,11 +99,20 @@ class SerialExecutor : public SequencedExecutor {
return parent_->getNumPriorities(); return parent_->getNumPriorities();
} }
protected:
bool keepAliveAcquire() override;
void keepAliveRelease() override;
private: private:
explicit SerialExecutor(KeepAlive<Executor> parent);
class TaskQueueImpl; class TaskQueueImpl;
std::shared_ptr<folly::Executor> parent_; KeepAlive<Executor> parent_;
std::shared_ptr<TaskQueueImpl> taskQueueImpl_; std::shared_ptr<TaskQueueImpl> taskQueueImpl_;
std::atomic<ssize_t> keepAliveCounter_{1};
}; };
} // namespace folly } // namespace folly
...@@ -26,7 +26,9 @@ namespace folly { ...@@ -26,7 +26,9 @@ namespace folly {
bool isSequencedExecutor(folly::Executor& executor) { bool isSequencedExecutor(folly::Executor& executor) {
// Add can be called from different threads, but it should be sequenced. // Add can be called from different threads, but it should be sequenced.
SerialExecutor producer(std::make_shared<CPUThreadPoolExecutor>(4)); auto cpuExecutor = std::make_shared<CPUThreadPoolExecutor>(4);
auto producer =
SerialExecutor::create(Executor::getKeepAliveToken(cpuExecutor.get()));
std::atomic<size_t> nextCallIndex{0}; std::atomic<size_t> nextCallIndex{0};
std::atomic<bool> result{true}; std::atomic<bool> result{true};
...@@ -36,7 +38,7 @@ bool isSequencedExecutor(folly::Executor& executor) { ...@@ -36,7 +38,7 @@ bool isSequencedExecutor(folly::Executor& executor) {
constexpr size_t kNumCalls = 10000; constexpr size_t kNumCalls = 10000;
for (size_t callIndex = 0; callIndex < kNumCalls; ++callIndex) { for (size_t callIndex = 0; callIndex < kNumCalls; ++callIndex) {
producer.add([&result, &executor, &nextCallIndex, callIndex, joinPromise] { producer->add([&result, &executor, &nextCallIndex, callIndex, joinPromise] {
executor.add([&result, &nextCallIndex, callIndex, joinPromise] { executor.add([&result, &nextCallIndex, callIndex, joinPromise] {
if (nextCallIndex != callIndex) { if (nextCallIndex != callIndex) {
result = false; result = false;
...@@ -69,8 +71,10 @@ TEST(SequencedExecutor, CPUThreadPoolExecutor) { ...@@ -69,8 +71,10 @@ TEST(SequencedExecutor, CPUThreadPoolExecutor) {
} }
TEST(SequencedExecutor, SerialCPUThreadPoolExecutor) { TEST(SequencedExecutor, SerialCPUThreadPoolExecutor) {
SerialExecutor executor(std::make_shared<CPUThreadPoolExecutor>(4)); auto cpuExecutor = std::make_shared<CPUThreadPoolExecutor>(4);
testExecutor(executor); auto executor =
SerialExecutor::create(Executor::getKeepAliveToken(cpuExecutor.get()));
testExecutor(*executor);
} }
TEST(SequencedExecutor, EventBase) { TEST(SequencedExecutor, EventBase) {
......
...@@ -32,13 +32,14 @@ void burnMs(uint64_t ms) { ...@@ -32,13 +32,14 @@ void burnMs(uint64_t ms) {
} // namespace } // namespace
void SimpleTest(std::shared_ptr<folly::Executor> const& parent) { void SimpleTest(std::shared_ptr<folly::Executor> const& parent) {
SerialExecutor executor(parent); auto executor =
SerialExecutor::create(folly::getKeepAliveToken(parent.get()));
std::vector<int> values; std::vector<int> values;
std::vector<int> expected; std::vector<int> expected;
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
executor.add([i, &values] { executor->add([i, &values] {
// make this extra vulnerable to concurrent execution // make this extra vulnerable to concurrent execution
values.push_back(0); values.push_back(0);
burnMs(10); burnMs(10);
...@@ -49,7 +50,7 @@ void SimpleTest(std::shared_ptr<folly::Executor> const& parent) { ...@@ -49,7 +50,7 @@ void SimpleTest(std::shared_ptr<folly::Executor> const& parent) {
// wait until last task has executed // wait until last task has executed
folly::Baton<> finished_baton; folly::Baton<> finished_baton;
executor.add([&finished_baton] { finished_baton.post(); }); executor->add([&finished_baton] { finished_baton.post(); });
finished_baton.wait(); finished_baton.wait();
EXPECT_EQ(expected, values); EXPECT_EQ(expected, values);
...@@ -67,7 +68,8 @@ TEST(SerialExecutor, SimpleInline) { ...@@ -67,7 +68,8 @@ TEST(SerialExecutor, SimpleInline) {
// destroy the SerialExecutor // destroy the SerialExecutor
TEST(SerialExecutor, Afterlife) { TEST(SerialExecutor, Afterlife) {
auto cpu_executor = std::make_shared<folly::CPUThreadPoolExecutor>(4); auto cpu_executor = std::make_shared<folly::CPUThreadPoolExecutor>(4);
auto executor = std::make_unique<SerialExecutor>(cpu_executor); auto executor =
SerialExecutor::create(folly::getKeepAliveToken(cpu_executor.get()));
// block executor until we call start_baton.post() // block executor until we call start_baton.post()
folly::Baton<> start_baton; folly::Baton<> start_baton;
...@@ -102,7 +104,8 @@ TEST(SerialExecutor, Afterlife) { ...@@ -102,7 +104,8 @@ TEST(SerialExecutor, Afterlife) {
} }
void RecursiveAddTest(std::shared_ptr<folly::Executor> const& parent) { void RecursiveAddTest(std::shared_ptr<folly::Executor> const& parent) {
SerialExecutor executor(parent); auto executor =
SerialExecutor::create(folly::getKeepAliveToken(parent.get()));
folly::Baton<> finished_baton; folly::Baton<> finished_baton;
...@@ -116,7 +119,7 @@ void RecursiveAddTest(std::shared_ptr<folly::Executor> const& parent) { ...@@ -116,7 +119,7 @@ void RecursiveAddTest(std::shared_ptr<folly::Executor> const& parent) {
values.push_back(0); values.push_back(0);
burnMs(10); burnMs(10);
values.back() = i; values.back() = i;
executor.add(lambda); executor->add(lambda);
} else if (i < 12) { } else if (i < 12) {
// Below we will post this lambda three times to the executor. When // Below we will post this lambda three times to the executor. When
// executed, the lambda will re-post itself during the first ten // executed, the lambda will re-post itself during the first ten
...@@ -128,9 +131,9 @@ void RecursiveAddTest(std::shared_ptr<folly::Executor> const& parent) { ...@@ -128,9 +131,9 @@ void RecursiveAddTest(std::shared_ptr<folly::Executor> const& parent) {
++i; ++i;
}; };
executor.add(lambda); executor->add(lambda);
executor.add(lambda); executor->add(lambda);
executor.add(lambda); executor->add(lambda);
// wait until last task has executed // wait until last task has executed
finished_baton.wait(); finished_baton.wait();
...@@ -146,9 +149,9 @@ TEST(SerialExecutor, RecursiveAddInline) { ...@@ -146,9 +149,9 @@ TEST(SerialExecutor, RecursiveAddInline) {
} }
TEST(SerialExecutor, ExecutionThrows) { TEST(SerialExecutor, ExecutionThrows) {
SerialExecutor executor(std::make_shared<folly::InlineExecutor>()); auto executor = SerialExecutor::create();
// an empty Func will throw std::bad_function_call when invoked, // an empty Func will throw std::bad_function_call when invoked,
// but SerialExecutor should catch that exception // but SerialExecutor should catch that exception
executor.add(folly::Func{}); executor->add(folly::Func{});
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment