Commit dbfa7c45 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook GitHub Bot

Make sure all observer adaptors correctly capture dependencies

Reviewed By: praihan

Differential Revision: D27800735

fbshipit-source-id: 26f2c1be72bb66864173beb87244129dd6194c8d
parent 1366baf8
...@@ -142,10 +142,14 @@ AtomicObserver<T>& AtomicObserver<T>::operator=(Observer<T> observer) { ...@@ -142,10 +142,14 @@ AtomicObserver<T>& AtomicObserver<T>::operator=(Observer<T> observer) {
template <typename T> template <typename T>
T AtomicObserver<T>::get() const { T AtomicObserver<T>::get() const {
auto version = cachedVersion_.load(std::memory_order_acquire); auto version = cachedVersion_.load(std::memory_order_acquire);
if (UNLIKELY(observer_.needRefresh(version))) { if (UNLIKELY(
observer_.needRefresh(version) ||
observer_detail::ObserverManager::inManagerThread())) {
SharedMutex::WriteHolder guard{refreshLock_}; SharedMutex::WriteHolder guard{refreshLock_};
version = cachedVersion_.load(std::memory_order_acquire); version = cachedVersion_.load(std::memory_order_acquire);
if (LIKELY(observer_.needRefresh(version))) { if (LIKELY(
observer_.needRefresh(version) ||
observer_detail::ObserverManager::inManagerThread())) {
auto snapshot = *observer_; auto snapshot = *observer_;
cachedValue_.store(*snapshot, std::memory_order_relaxed); cachedValue_.store(*snapshot, std::memory_order_relaxed);
cachedVersion_.store(snapshot.getVersion(), std::memory_order_release); cachedVersion_.store(snapshot.getVersion(), std::memory_order_release);
...@@ -184,16 +188,17 @@ ReadMostlyAtomicObserver<T>::ReadMostlyAtomicObserver(Observer<T> observer) ...@@ -184,16 +188,17 @@ ReadMostlyAtomicObserver<T>::ReadMostlyAtomicObserver(Observer<T> observer)
template <typename T> template <typename T>
T ReadMostlyAtomicObserver<T>::get() const { T ReadMostlyAtomicObserver<T>::get() const {
if (UNLIKELY(observer_detail::ObserverManager::inManagerThread())) {
return **observer_;
}
return cachedValue_.load(std::memory_order_relaxed); return cachedValue_.load(std::memory_order_relaxed);
} }
template <typename T> template <typename T>
ReadMostlyTLObserver<T>::ReadMostlyTLObserver(Observer<T> observer) ReadMostlyTLObserver<T>::ReadMostlyTLObserver(Observer<T> observer)
: observer_(std::move(observer)), : observer_(std::move(observer)) {
callback_(observer_.addCallback([this](Snapshot<T> snapshot) { refresh();
globalData_.lock()->reset(snapshot.getShared()); }
globalVersion_ = snapshot.getVersion();
})) {}
template <typename T> template <typename T>
ReadMostlyTLObserver<T>::ReadMostlyTLObserver( ReadMostlyTLObserver<T>::ReadMostlyTLObserver(
...@@ -202,7 +207,8 @@ ReadMostlyTLObserver<T>::ReadMostlyTLObserver( ...@@ -202,7 +207,8 @@ ReadMostlyTLObserver<T>::ReadMostlyTLObserver(
template <typename T> template <typename T>
ReadMostlySharedPtr<const T> ReadMostlyTLObserver<T>::getShared() const { ReadMostlySharedPtr<const T> ReadMostlyTLObserver<T>::getShared() const {
if (localSnapshot_->version_ == globalVersion_.load()) { if (!observer_.needRefresh(localSnapshot_->version_) &&
!observer_detail::ObserverManager::inManagerThread()) {
if (auto data = localSnapshot_->data_.lock()) { if (auto data = localSnapshot_->data_.lock()) {
return data; return data;
} }
...@@ -212,9 +218,13 @@ ReadMostlySharedPtr<const T> ReadMostlyTLObserver<T>::getShared() const { ...@@ -212,9 +218,13 @@ ReadMostlySharedPtr<const T> ReadMostlyTLObserver<T>::getShared() const {
template <typename T> template <typename T>
ReadMostlySharedPtr<const T> ReadMostlyTLObserver<T>::refresh() const { ReadMostlySharedPtr<const T> ReadMostlyTLObserver<T>::refresh() const {
auto version = globalVersion_.load(); auto snapshot = observer_.getSnapshot();
auto globalData = globalData_.lock(); auto globalData = globalData_.lock();
*localSnapshot_ = LocalSnapshot(*globalData, version); if (globalVersion_.load() < snapshot.getVersion()) {
globalData->reset(snapshot.getShared());
globalVersion_ = snapshot.getVersion();
}
*localSnapshot_ = LocalSnapshot(*globalData, globalVersion_.load());
return globalData->getShared(); return globalData->getShared();
} }
...@@ -297,5 +307,25 @@ Observer<observer_detail::ResultOfUnwrapSharedPtr<F>> makeValueObserver( ...@@ -297,5 +307,25 @@ Observer<observer_detail::ResultOfUnwrapSharedPtr<F>> makeValueObserver(
return activeValue; return activeValue;
}); });
} }
template <typename T>
typename HazptrObserver<T>::DefaultSnapshot HazptrObserver<T>::getSnapshot()
const {
if (UNLIKELY(observer_detail::ObserverManager::inManagerThread())) {
// Wait for updates
observer_.getSnapshot();
}
return DefaultSnapshot(state_);
}
template <typename T>
typename HazptrObserver<T>::LocalSnapshot HazptrObserver<T>::getLocalSnapshot()
const {
if (UNLIKELY(observer_detail::ObserverManager::inManagerThread())) {
// Wait for updates
observer_.getSnapshot();
}
return LocalSnapshot(state_);
}
} // namespace observer } // namespace observer
} // namespace folly } // namespace folly
...@@ -152,7 +152,7 @@ class ReadMostlyTLObserver; ...@@ -152,7 +152,7 @@ class ReadMostlyTLObserver;
* observer changes. This implementation incurs an additional allocation * observer changes. This implementation incurs an additional allocation
* on updates making it less suitable for write-heavy workloads. * on updates making it less suitable for write-heavy workloads.
* *
* There are 3 main APIs: * There are 2 main APIs:
* 1) getSnapshot: Returns a Snapshot containing a const pointer to T and guards * 1) getSnapshot: Returns a Snapshot containing a const pointer to T and guards
* access to it using folly::hazptr_holder. The pointer is only safe to use * access to it using folly::hazptr_holder. The pointer is only safe to use
* while the returned Snapshot object is alive. * while the returned Snapshot object is alive.
...@@ -160,10 +160,6 @@ class ReadMostlyTLObserver; ...@@ -160,10 +160,6 @@ class ReadMostlyTLObserver;
* This API is ~3ns faster than getSnapshot but is unsafe for the current * This API is ~3ns faster than getSnapshot but is unsafe for the current
* thread to construct any other hazptr holder type objects (hazptr_holder, * thread to construct any other hazptr holder type objects (hazptr_holder,
* hazptr_array and other hazptr_local) while the returned snapshot exists. * hazptr_array and other hazptr_local) while the returned snapshot exists.
* 3) getUnderlyingObserver: This can be used to trigger dependent observer
* updates inside the lambda passed to folly::observer::makeObserver(...).
* Using getSnapshot or getLocalSnapshot is not sufficient since they don't
* access the underlying observer's snapshot.
* *
* See folly/synchronization/Hazptr.h for more details on hazptrs. * See folly/synchronization/Hazptr.h for more details on hazptrs.
*/ */
...@@ -397,14 +393,10 @@ class ReadMostlyTLObserver { ...@@ -397,14 +393,10 @@ class ReadMostlyTLObserver {
Observer<T> observer_; Observer<T> observer_;
Synchronized<ReadMostlyMainPtr<const T>, std::mutex> globalData_; mutable Synchronized<ReadMostlyMainPtr<const T>, std::mutex> globalData_;
std::atomic<int64_t> globalVersion_; mutable std::atomic<int64_t> globalVersion_{0};
ThreadLocal<LocalSnapshot> localSnapshot_; ThreadLocal<LocalSnapshot> localSnapshot_;
// Construct callback last so that it's joined before members it may
// be accessing are destructed
CallbackHandle callback_;
}; };
template <typename T> template <typename T>
...@@ -432,14 +424,20 @@ class HazptrObserver { ...@@ -432,14 +424,20 @@ class HazptrObserver {
using LocalSnapshot = HazptrSnapshot<hazptr_local<1>>; using LocalSnapshot = HazptrSnapshot<hazptr_local<1>>;
explicit HazptrObserver(Observer<T> observer) explicit HazptrObserver(Observer<T> observer)
: observer_(std::move(observer)), : observer_(
callback_(observer_.addCallback([this](Snapshot<T> snapshot) { makeObserver([o = std::move(observer), alive = alive_, this]() {
auto* newState = new State(std::move(snapshot)); auto snapshot = o.getSnapshot();
auto* oldState = state_.exchange(newState, std::memory_order_acq_rel); auto rAlive = alive->rlock();
if (oldState) { if (*rAlive) {
oldState->retire(); auto* newState = new State(snapshot);
} auto* oldState =
})) {} state_.exchange(newState, std::memory_order_acq_rel);
if (oldState) {
oldState->retire();
}
}
return snapshot.getShared();
})) {}
HazptrObserver(const HazptrObserver<T>& r) : HazptrObserver(r.observer_) {} HazptrObserver(const HazptrObserver<T>& r) : HazptrObserver(r.observer_) {}
HazptrObserver& operator=(const HazptrObserver<T>&) = delete; HazptrObserver& operator=(const HazptrObserver<T>&) = delete;
...@@ -448,17 +446,15 @@ class HazptrObserver { ...@@ -448,17 +446,15 @@ class HazptrObserver {
HazptrObserver& operator=(HazptrObserver<T>&&) = default; HazptrObserver& operator=(HazptrObserver<T>&&) = default;
~HazptrObserver() { ~HazptrObserver() {
*alive_->wlock() = false;
auto* state = state_.load(std::memory_order_acquire); auto* state = state_.load(std::memory_order_acquire);
if (state) { if (state) {
state->retire(); state->retire();
} }
} }
DefaultSnapshot getSnapshot() const { return DefaultSnapshot(state_); } DefaultSnapshot getSnapshot() const;
LocalSnapshot getLocalSnapshot() const;
LocalSnapshot getLocalSnapshot() const { return LocalSnapshot(state_); }
Observer<T> getUnderlyingObserver() const { return observer_; }
private: private:
struct State : public hazptr_obj_base<State> { struct State : public hazptr_obj_base<State> {
...@@ -468,8 +464,9 @@ class HazptrObserver { ...@@ -468,8 +464,9 @@ class HazptrObserver {
}; };
std::atomic<State*> state_{nullptr}; std::atomic<State*> state_{nullptr};
std::shared_ptr<Synchronized<bool>> alive_{
std::make_shared<Synchronized<bool>>(true)};
Observer<T> observer_; Observer<T> observer_;
CallbackHandle callback_;
}; };
/** /**
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <thread> #include <thread>
#include <folly/Singleton.h> #include <folly/Singleton.h>
#include <folly/experimental/observer/Observer.h>
#include <folly/experimental/observer/SimpleObservable.h> #include <folly/experimental/observer/SimpleObservable.h>
#include <folly/experimental/observer/WithJitter.h> #include <folly/experimental/observer/WithJitter.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
...@@ -659,8 +660,8 @@ void runHazptrObserverTest(bool useLocalSnapshot) { ...@@ -659,8 +660,8 @@ void runHazptrObserverTest(bool useLocalSnapshot) {
EXPECT_EQ(value(observer), 24); EXPECT_EQ(value(observer), 24);
EXPECT_EQ(value(observerCopy), 24); EXPECT_EQ(value(observerCopy), 24);
auto dependentObserver = makeHazptrObserver([=] { auto dependentObserver = makeHazptrObserver([o = observable.getObserver()] {
return IntHolder{observer.getUnderlyingObserver().getSnapshot()->val_ + 1}; return IntHolder{o.getSnapshot()->val_ + 1};
}); });
EXPECT_EQ(value(dependentObserver), 25); EXPECT_EQ(value(dependentObserver), 25);
...@@ -811,3 +812,51 @@ TEST(SimpleObservable, DefaultConstructible) { ...@@ -811,3 +812,51 @@ TEST(SimpleObservable, DefaultConstructible) {
SimpleObservable<Data> observable; SimpleObservable<Data> observable;
EXPECT_EQ((**observable.getObserver()).i, 42); EXPECT_EQ((**observable.getObserver()).i, 42);
} }
TEST(Observer, MakeObserverUpdatesTracking) {
SimpleObservable<int> observable(0);
auto slowObserver = makeObserver([o = observable.getObserver()] {
std::this_thread::sleep_for(std::chrono::milliseconds{10});
return **o;
});
auto tlObserver = makeTLObserver(slowObserver);
auto rmtlObserver = makeReadMostlyTLObserver(slowObserver);
auto atomicObserver = makeAtomicObserver(slowObserver);
auto rmatomicObserver = makeReadMostlyAtomicObserver(slowObserver);
auto hazptrObserver = makeHazptrObserver(slowObserver);
EXPECT_EQ(0, **tlObserver);
EXPECT_EQ(0, *(rmtlObserver.getShared()));
EXPECT_EQ(0, *atomicObserver);
EXPECT_EQ(0, *rmatomicObserver);
EXPECT_EQ(0, *(hazptrObserver.getSnapshot()));
EXPECT_EQ(0, *(hazptrObserver.getLocalSnapshot()));
auto tlObserverCheck = makeObserver([&]() mutable { return **tlObserver; });
auto rmtlObserverCheck =
makeObserver([&]() mutable { return *(rmtlObserver.getShared()); });
auto atomicObserverCheck =
makeObserver([&]() mutable { return *atomicObserver; });
auto rmatomicObserverCheck =
makeObserver([&]() mutable { return *rmatomicObserver; });
auto hazptrObserverGetSnapshotCheck =
makeObserver([&]() mutable { return *(hazptrObserver.getSnapshot()); });
auto hazptrObserverGetLocalSnapshotCheck = makeObserver(
[&]() mutable { return *(hazptrObserver.getLocalSnapshot()); });
for (size_t i = 1; i <= 10; ++i) {
observable.setValue(i);
folly::observer_detail::ObserverManager::waitForAllUpdates();
EXPECT_EQ(i, **tlObserverCheck);
EXPECT_EQ(i, **rmtlObserverCheck);
EXPECT_EQ(i, **atomicObserverCheck);
EXPECT_EQ(i, **rmatomicObserverCheck);
EXPECT_EQ(i, **hazptrObserverGetSnapshotCheck);
EXPECT_EQ(i, **hazptrObserverGetLocalSnapshotCheck);
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment