Commit bccbb102 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot 9

Fix races in TLRefCount

Summary:
This fixes 2 races in TLRefCount:
1. Thread-local constructor race, exposed by the stress test. It was possible for LocalRefCount to be created (grabbing collectGuard), but not be added to the thread-local list, so that accessAllThreads wasn't collecting it. collectAll() was then blocking waiting on baton to be posted, causing a dead-lock.
2. LocalRefCount::count_ has to be made atomic, because otherwise += operation may be not flushed (nbronson explained the race in D3133443).

Reviewed By: djwatson

Differential Revision: D3166956

fb-gh-sync-id: 17d58a215ebfc572f8316ed46bafaa5e6a9e2368
fbshipit-source-id: 17d58a215ebfc572f8316ed46bafaa5e6a9e2368
parent 68cf03b1
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
*/ */
#pragma once #pragma once
#include <folly/Baton.h>
#include <folly/ThreadLocal.h> #include <folly/ThreadLocal.h>
namespace folly { namespace folly {
...@@ -24,15 +23,9 @@ class TLRefCount { ...@@ -24,15 +23,9 @@ class TLRefCount {
public: public:
using Int = int64_t; using Int = int64_t;
TLRefCount() : TLRefCount()
localCount_([&]() { : localCount_([&]() { return new LocalRefCount(*this); }),
return new LocalRefCount(*this); collectGuard_(this, [](void*) {}) {}
}),
collectGuard_(&collectBaton_, [](void* p) {
auto baton = reinterpret_cast<folly::Baton<>*>(p);
baton->post();
}) {
}
~TLRefCount() noexcept { ~TLRefCount() noexcept {
assert(globalCount_.load() == 0); assert(globalCount_.load() == 0);
...@@ -91,13 +84,17 @@ class TLRefCount { ...@@ -91,13 +84,17 @@ class TLRefCount {
state_ = State::GLOBAL_TRANSITION; state_ = State::GLOBAL_TRANSITION;
auto accessor = localCount_.accessAllThreads(); std::weak_ptr<void> collectGuardWeak = collectGuard_;
for (auto& count : accessor) {
count.collect();
}
// Make sure we can't create new LocalRefCounts
collectGuard_.reset(); collectGuard_.reset();
collectBaton_.wait();
while (!collectGuardWeak.expired()) {
auto accessor = localCount_.accessAllThreads();
for (auto& count : accessor) {
count.collect();
}
}
state_ = State::GLOBAL; state_ = State::GLOBAL;
} }
...@@ -131,7 +128,7 @@ class TLRefCount { ...@@ -131,7 +128,7 @@ class TLRefCount {
return; return;
} }
collectCount_ = count_; collectCount_ = count_.load();
refCount_.globalCount_.fetch_add(collectCount_); refCount_.globalCount_.fetch_add(collectCount_);
collectGuard_.reset(); collectGuard_.reset();
} }
...@@ -166,7 +163,7 @@ class TLRefCount { ...@@ -166,7 +163,7 @@ class TLRefCount {
return true; return true;
} }
Int count_{0}; AtomicInt count_{0};
TLRefCount& refCount_; TLRefCount& refCount_;
std::mutex collectMutex_; std::mutex collectMutex_;
...@@ -178,7 +175,6 @@ class TLRefCount { ...@@ -178,7 +175,6 @@ class TLRefCount {
folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_; folly::ThreadLocal<LocalRefCount, TLRefCount> localCount_;
std::atomic<int64_t> globalCount_{1}; std::atomic<int64_t> globalCount_{1};
std::mutex globalMutex_; std::mutex globalMutex_;
folly::Baton<> collectBaton_;
std::shared_ptr<void> collectGuard_; std::shared_ptr<void> collectGuard_;
}; };
......
...@@ -83,6 +83,40 @@ void basicTest() { ...@@ -83,6 +83,40 @@ void basicTest() {
EXPECT_EQ(0, ++count); EXPECT_EQ(0, ++count);
} }
template <typename RefCount>
void stressTest() {
constexpr size_t kItersCount = 10000;
for (size_t i = 0; i < kItersCount; ++i) {
RefCount count;
std::mutex mutex;
int a{1};
std::thread t1([&]() {
if (++count) {
{
std::lock_guard<std::mutex> lg(mutex);
EXPECT_EQ(1, a);
}
--count;
}
});
std::thread t2([&]() {
count.useGlobal();
if (--count == 0) {
std::lock_guard<std::mutex> lg(mutex);
a = 0;
}
});
t1.join();
t2.join();
EXPECT_EQ(0, ++count);
}
}
TEST(RCURefCount, Basic) { TEST(RCURefCount, Basic) {
basicTest<RCURefCount>(); basicTest<RCURefCount>();
} }
...@@ -91,4 +125,11 @@ TEST(TLRefCount, Basic) { ...@@ -91,4 +125,11 @@ TEST(TLRefCount, Basic) {
basicTest<TLRefCount>(); basicTest<TLRefCount>();
} }
TEST(RCURefCount, Stress) {
stressTest<TLRefCount>();
}
TEST(TLRefCount, Stress) {
stressTest<TLRefCount>();
}
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment