Commit 9009c2b4 authored by Dave Watson's avatar Dave Watson Committed by Viswanath Sivakumar

Thread Observer

Summary: Observer methods, so users of IOThreadPoolExecutor can do stuff when threads are added/removed.  As a use case, previously the thrift server only used the threads already started when it started up, and assumed iothreadpool was never resized.

Test Plan: Added several unittests

Reviewed By: jsedgwick@fb.com

Subscribers: trunkagent, doug, fugalh, alandau, bmatheny, mshneer, folly-diffs@

FB internal diff: D1753861

Signature: t1:1753861:1420236825:54cbdfee0efb3b97dea35faba29c134f2b10a480
parent c9a5ee23
...@@ -105,6 +105,10 @@ void CPUThreadPoolExecutor::threadRun(std::shared_ptr<Thread> thread) { ...@@ -105,6 +105,10 @@ void CPUThreadPoolExecutor::threadRun(std::shared_ptr<Thread> thread) {
auto task = taskQueue_->take(); auto task = taskQueue_->take();
if (UNLIKELY(task.poison)) { if (UNLIKELY(task.poison)) {
CHECK(threadsToStop_-- > 0); CHECK(threadsToStop_-- > 0);
for (auto& o : observers_) {
o->threadStopped(thread.get());
}
stoppedThreads_.add(thread); stoppedThreads_.add(thread);
return; return;
} else { } else {
......
...@@ -117,6 +117,17 @@ EventBase* IOThreadPoolExecutor::getEventBase() { ...@@ -117,6 +117,17 @@ EventBase* IOThreadPoolExecutor::getEventBase() {
return pickThread()->eventBase; return pickThread()->eventBase;
} }
EventBase* IOThreadPoolExecutor::getEventBase(
ThreadPoolExecutor::ThreadHandle* h) {
auto thread = dynamic_cast<IOThread*>(h);
if (thread) {
return thread->eventBase;
}
return nullptr;
}
std::shared_ptr<ThreadPoolExecutor::Thread> std::shared_ptr<ThreadPoolExecutor::Thread>
IOThreadPoolExecutor::makeThread() { IOThreadPoolExecutor::makeThread() {
return std::make_shared<IOThread>(this); return std::make_shared<IOThread>(this);
...@@ -148,21 +159,14 @@ void IOThreadPoolExecutor::stopThreads(size_t n) { ...@@ -148,21 +159,14 @@ void IOThreadPoolExecutor::stopThreads(size_t n) {
for (size_t i = 0; i < n; i++) { for (size_t i = 0; i < n; i++) {
const auto ioThread = std::static_pointer_cast<IOThread>( const auto ioThread = std::static_pointer_cast<IOThread>(
threadList_.get()[i]); threadList_.get()[i]);
for (auto& o : observers_) {
o->threadStopped(ioThread.get());
}
ioThread->shouldRun = false; ioThread->shouldRun = false;
ioThread->eventBase->terminateLoopSoon(); ioThread->eventBase->terminateLoopSoon();
} }
} }
std::vector<EventBase*> IOThreadPoolExecutor::getEventBases() {
std::vector<EventBase*> bases;
RWSpinLock::ReadHolder{&threadListLock_};
for (const auto& thread : threadList_.get()) {
auto ioThread = std::static_pointer_cast<IOThread>(thread);
bases.push_back(ioThread->eventBase);
}
return bases;
}
// threadListLock_ is readlocked // threadListLock_ is readlocked
uint64_t IOThreadPoolExecutor::getPendingTaskCount() { uint64_t IOThreadPoolExecutor::getPendingTaskCount() {
uint64_t count = 0; uint64_t count = 0;
......
...@@ -41,7 +41,7 @@ class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor { ...@@ -41,7 +41,7 @@ class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor {
EventBase* getEventBase() override; EventBase* getEventBase() override;
std::vector<EventBase*> getEventBases(); EventBase* getEventBase(ThreadPoolExecutor::ThreadHandle*);
private: private:
struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread { struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread {
......
...@@ -99,6 +99,11 @@ void ThreadPoolExecutor::addThreads(size_t n) { ...@@ -99,6 +99,11 @@ void ThreadPoolExecutor::addThreads(size_t n) {
for (auto& thread : newThreads) { for (auto& thread : newThreads) {
thread->startupBaton.wait(); thread->startupBaton.wait();
} }
for (auto& o : observers_) {
for (auto& thread : newThreads) {
o->threadStarted(thread.get());
}
}
} }
// threadListLock_ is writelocked // threadListLock_ is writelocked
...@@ -171,4 +176,27 @@ size_t ThreadPoolExecutor::StoppedThreadQueue::size() { ...@@ -171,4 +176,27 @@ size_t ThreadPoolExecutor::StoppedThreadQueue::size() {
return queue_.size(); return queue_.size();
} }
void ThreadPoolExecutor::addObserver(std::shared_ptr<Observer> o) {
RWSpinLock::ReadHolder{&threadListLock_};
observers_.push_back(o);
for (auto& thread : threadList_.get()) {
o->threadPreviouslyStarted(thread.get());
}
}
void ThreadPoolExecutor::removeObserver(std::shared_ptr<Observer> o) {
RWSpinLock::ReadHolder{&threadListLock_};
for (auto& thread : threadList_.get()) {
o->threadNotYetStopped(thread.get());
}
for (auto it = observers_.begin(); it != observers_.end(); it++) {
if (*it == o) {
observers_.erase(it);
return;
}
}
DCHECK(false);
}
}} // folly::wangle }} // folly::wangle
...@@ -85,13 +85,40 @@ class ThreadPoolExecutor : public virtual Executor { ...@@ -85,13 +85,40 @@ class ThreadPoolExecutor : public virtual Executor {
return taskStatsSubject_->subscribe(observer); return taskStatsSubject_->subscribe(observer);
} }
/**
* Base class for threads created with ThreadPoolExecutor.
* Some subclasses have methods that operate on these
* handles.
*/
class ThreadHandle {
public:
virtual ~ThreadHandle() = default;
};
/**
* Observer interface for thread start/stop.
* Provides hooks so actions can be taken when
* threads are created
*/
class Observer {
public:
virtual void threadStarted(ThreadHandle*) = 0;
virtual void threadStopped(ThreadHandle*) = 0;
virtual void threadPreviouslyStarted(ThreadHandle*) = 0;
virtual void threadNotYetStopped(ThreadHandle*) = 0;
virtual ~Observer() = default;
};
void addObserver(std::shared_ptr<Observer>);
void removeObserver(std::shared_ptr<Observer>);
protected: protected:
// Prerequisite: threadListLock_ writelocked // Prerequisite: threadListLock_ writelocked
void addThreads(size_t n); void addThreads(size_t n);
// Prerequisite: threadListLock_ writelocked // Prerequisite: threadListLock_ writelocked
void removeThreads(size_t n, bool isJoin); void removeThreads(size_t n, bool isJoin);
struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread { struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread : public ThreadHandle {
explicit Thread(ThreadPoolExecutor* pool) explicit Thread(ThreadPoolExecutor* pool)
: id(nextId++), : id(nextId++),
handle(), handle(),
...@@ -185,6 +212,7 @@ class ThreadPoolExecutor : public virtual Executor { ...@@ -185,6 +212,7 @@ class ThreadPoolExecutor : public virtual Executor {
std::atomic<bool> isJoin_; // whether the current downsizing is a join std::atomic<bool> isJoin_; // whether the current downsizing is a join
std::shared_ptr<Subject<TaskStats>> taskStatsSubject_; std::shared_ptr<Subject<TaskStats>> taskStatsSubject_;
std::vector<std::shared_ptr<Observer>> observers_;
}; };
}} // folly::wangle }} // folly::wangle
...@@ -318,3 +318,56 @@ TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) { ...@@ -318,3 +318,56 @@ TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) {
pool.join(); pool.join();
EXPECT_EQ(100, completed); EXPECT_EQ(100, completed);
} }
class TestObserver : public ThreadPoolExecutor::Observer {
public:
void threadStarted(ThreadPoolExecutor::ThreadHandle*) {
threads_++;
}
void threadStopped(ThreadPoolExecutor::ThreadHandle*) {
threads_--;
}
void threadPreviouslyStarted(ThreadPoolExecutor::ThreadHandle*) {
threads_++;
}
void threadNotYetStopped(ThreadPoolExecutor::ThreadHandle*) {
threads_--;
}
void checkCalls() {
ASSERT_EQ(threads_, 0);
}
private:
int threads_{0};
};
TEST(ThreadPoolExecutorTest, IOObserver) {
auto observer = std::make_shared<TestObserver>();
{
IOThreadPoolExecutor exe(10);
exe.addObserver(observer);
exe.setNumThreads(3);
exe.setNumThreads(0);
exe.setNumThreads(7);
exe.removeObserver(observer);
exe.setNumThreads(10);
}
observer->checkCalls();
}
TEST(ThreadPoolExecutorTest, CPUObserver) {
auto observer = std::make_shared<TestObserver>();
{
CPUThreadPoolExecutor exe(10);
exe.addObserver(observer);
exe.setNumThreads(3);
exe.setNumThreads(0);
exe.setNumThreads(7);
exe.removeObserver(observer);
exe.setNumThreads(10);
}
observer->checkCalls();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment