Commit 571e7b97 authored by Maged Michael's avatar Maged Michael Committed by Facebook Github Bot

Add integrated reference counting

Summary:
Add support for reference counting integrated with the internal structures and operations of the hazard pointer library. The operations are wait-free.
The advantages of this approach over combining reference counting with hazard pointers externally are:
(1) A long list of linked objects that are protected by one reference can all be reclaimed together, instead of going through a potentially long series of alternating reclamations and calls to retire() for descendants.
(2) Support for iterative deletion as opposed to potential deep recursion of alternating calls to release reference count and object destructors.

Reviewed By: djwatson

Differential Revision: D6142066

fbshipit-source-id: 02bdfcbd5a2c2d5486d937bb2f9cfb6f192f5e1a
parent 3ace27b8
...@@ -118,9 +118,7 @@ static_assert( ...@@ -118,9 +118,7 @@ static_assert(
struct hazptr_tc { struct hazptr_tc {
hazptr_tc_entry entry_[HAZPTR_TC_SIZE]; hazptr_tc_entry entry_[HAZPTR_TC_SIZE];
size_t count_; size_t count_;
#ifndef NDEBUG bool local_; // for debug mode only
bool local_;
#endif
public: public:
hazptr_tc_entry& operator[](size_t i); hazptr_tc_entry& operator[](size_t i);
...@@ -206,6 +204,63 @@ inline void hazptr_obj_base<T, D>::retire(hazptr_domain& domain, D deleter) { ...@@ -206,6 +204,63 @@ inline void hazptr_obj_base<T, D>::retire(hazptr_domain& domain, D deleter) {
domain.objRetire(this); domain.objRetire(this);
} }
/**
 * hazptr_obj_base_refcounted
 */

/** Retire this (already removed) object. The deleter runs only after the
 *  object is both unreferenced (refcount_ reaches zero) and unprotected by
 *  any hazard pointer. */
template <typename T, typename D>
inline void hazptr_obj_base_refcounted<T, D>::retire(
    hazptr_domain& domain,
    D deleter) {
  DEBUG_PRINT(this << " " << &domain);
  // Stash the deleter before the object is published to any retired list.
  deleter_ = std::move(deleter);
  // Reclamation hook invoked by the domain once no hazard pointer protects
  // this object. It frees the object only if this was the last reference;
  // otherwise the remaining reference holders keep the object alive and the
  // final release_ref() caller becomes responsible for reclamation.
  reclaim_ = [](hazptr_obj* p) {
    auto hrobp = static_cast<hazptr_obj_base_refcounted*>(p);
    if (hrobp->release_ref()) {
      auto obj = static_cast<T*>(hrobp);
      hrobp->deleter_(obj);
    }
  };
  // Fast path — presumably a thread-private retired list (see
  // hazptr_priv_try_retire); only valid when retiring into the default
  // domain. Falls through to the shared domain list on failure.
  if (HAZPTR_PRIV &&
      (HAZPTR_ONE_DOMAIN || (&domain == &default_hazptr_domain()))) {
    if (hazptr_priv_try_retire(this)) {
      return;
    }
  }
  domain.objRetire(this);
}
/** Increment the reference count. Thread-safe: a single atomic
 *  read-modify-write (fetch_add). */
template <typename T, typename D>
inline void hazptr_obj_base_refcounted<T, D>::acquire_ref() {
  DEBUG_PRINT(this);
  auto oldval = refcount_.fetch_add(1);
  // NOTE(review): refcount_ is std::atomic<uint32_t>, so oldval is unsigned
  // and this check is vacuously true — it documents intent rather than
  // catching anything.
  DCHECK(oldval >= 0);
}
/** Increment the reference count without an atomic read-modify-write,
 *  saving one atomic RMW relative to acquire_ref(). The caller must
 *  guarantee a thread-safe context (e.g., the object is not yet shared);
 *  otherwise the separate load/store pair below can lose concurrent
 *  increments. */
template <typename T, typename D>
inline void hazptr_obj_base_refcounted<T, D>::acquire_ref_safe() {
  DEBUG_PRINT(this);
  auto oldval = refcount_.load(std::memory_order_acquire);
  DCHECK(oldval >= 0);
  refcount_.store(oldval + 1, std::memory_order_release);
}
/** Decrement the reference count. Returns true when the count was already
 *  zero, i.e., the caller released the last reference and the object is
 *  safe to reclaim (subject to hazard pointer protection). */
template <typename T, typename D>
inline bool hazptr_obj_base_refcounted<T, D>::release_ref() {
  DEBUG_PRINT(this);
  auto oldval = refcount_.load(std::memory_order_acquire);
  if (oldval > 0) {
    oldval = refcount_.fetch_sub(1);
  } else {
    // Count already zero: this must be the final release. Debug builds
    // poison the count so a double release is detectable.
    // NOTE(review): refcount_ is uint32_t, so -1 wraps to UINT32_MAX and
    // the unsigned DCHECK below can never fire — presumably written with a
    // signed count in mind; confirm the intended double-release detection.
    if (kIsDebug) {
      refcount_.store(-1);
    }
  }
  DEBUG_PRINT(this << " " << oldval);
  DCHECK(oldval >= 0);
  return oldval == 0;
}
/** /**
* hazptr_rec * hazptr_rec
*/ */
...@@ -481,10 +536,10 @@ FOLLY_ALWAYS_INLINE hazptr_local<M>::hazptr_local() { ...@@ -481,10 +536,10 @@ FOLLY_ALWAYS_INLINE hazptr_local<M>::hazptr_local() {
auto& tc = *ptc; auto& tc = *ptc;
auto count = tc.count(); auto count = tc.count();
if (M <= count) { if (M <= count) {
#ifndef NDEBUG if (kIsDebug) {
DCHECK(!tc.local_); DCHECK(!tc.local_);
tc.local_ = true; tc.local_ = true;
#endif }
// Fast path // Fast path
for (size_t i = 0; i < M; ++i) { for (size_t i = 0; i < M; ++i) {
auto hprec = tc[i].hprec_; auto hprec = tc[i].hprec_;
...@@ -511,13 +566,13 @@ FOLLY_ALWAYS_INLINE hazptr_local<M>::hazptr_local() { ...@@ -511,13 +566,13 @@ FOLLY_ALWAYS_INLINE hazptr_local<M>::hazptr_local() {
template <size_t M> template <size_t M>
FOLLY_ALWAYS_INLINE hazptr_local<M>::~hazptr_local() { FOLLY_ALWAYS_INLINE hazptr_local<M>::~hazptr_local() {
if (LIKELY(!need_destruct_)) { if (LIKELY(!need_destruct_)) {
#ifndef NDEBUG if (kIsDebug) {
auto ptc = hazptr_tc_tls(); auto ptc = hazptr_tc_tls();
DCHECK(ptc != nullptr); DCHECK(ptc != nullptr);
auto& tc = *ptc; auto& tc = *ptc;
DCHECK(tc.local_); DCHECK(tc.local_);
tc.local_ = false; tc.local_ = false;
#endif }
return; return;
} }
// Slow path // Slow path
...@@ -602,6 +657,7 @@ inline hazptr_domain::~hazptr_domain() { ...@@ -602,6 +657,7 @@ inline hazptr_domain::~hazptr_domain() {
while (retired) { while (retired) {
for (auto p = retired; p; p = next) { for (auto p = retired; p; p = next) {
next = p->next_; next = p->next_;
DEBUG_PRINT(this << " " << p << " " << p->reclaim_);
(*(p->reclaim_))(p); (*(p->reclaim_))(p);
} }
retired = retired_.exchange(nullptr); retired = retired_.exchange(nullptr);
...@@ -866,9 +922,9 @@ inline void hazptr_tc_init() { ...@@ -866,9 +922,9 @@ inline void hazptr_tc_init() {
auto& tc = tls_tc_data_; auto& tc = tls_tc_data_;
DEBUG_PRINT(&tc); DEBUG_PRINT(&tc);
tc.count_ = 0; tc.count_ = 0;
#ifndef NDEBUG if (kIsDebug) {
tc.local_ = false; tc.local_ = false;
#endif }
} }
inline void hazptr_tc_shutdown() { inline void hazptr_tc_shutdown() {
......
...@@ -34,6 +34,12 @@ class hazptr_obj; ...@@ -34,6 +34,12 @@ class hazptr_obj;
template <typename T, typename Deleter> template <typename T, typename Deleter>
class hazptr_obj_base; class hazptr_obj_base;
/** hazptr_obj_base_refcounted:
 * Base template for reference counted objects protected by hazard pointers.
 * Reclamation is deferred until the object is both unreferenced and
 * unprotected. Defined below.
 */
template <typename T, typename Deleter>
class hazptr_obj_base_refcounted;
/** hazptr_local: Optimized template for bulk construction and destruction of /** hazptr_local: Optimized template for bulk construction and destruction of
* hazard pointers */ * hazard pointers */
template <size_t M> template <size_t M>
...@@ -60,6 +66,8 @@ class hazptr_domain { ...@@ -60,6 +66,8 @@ class hazptr_domain {
friend class hazptr_holder; friend class hazptr_holder;
template <typename, typename> template <typename, typename>
friend class hazptr_obj_base; friend class hazptr_obj_base;
template <typename, typename>
friend class hazptr_obj_base_refcounted;
friend struct hazptr_priv; friend struct hazptr_priv;
memory_resource* mr_; memory_resource* mr_;
...@@ -87,10 +95,13 @@ class hazptr_obj { ...@@ -87,10 +95,13 @@ class hazptr_obj {
friend class hazptr_domain; friend class hazptr_domain;
template <typename, typename> template <typename, typename>
friend class hazptr_obj_base; friend class hazptr_obj_base;
template <typename, typename>
friend class hazptr_obj_base_refcounted;
friend struct hazptr_priv; friend struct hazptr_priv;
void (*reclaim_)(hazptr_obj*); void (*reclaim_)(hazptr_obj*);
hazptr_obj* next_; hazptr_obj* next_;
const void* getObjPtr() const; const void* getObjPtr() const;
}; };
...@@ -106,6 +117,33 @@ class hazptr_obj_base : public hazptr_obj { ...@@ -106,6 +117,33 @@ class hazptr_obj_base : public hazptr_obj {
D deleter_; D deleter_;
}; };
/** Definition of hazptr_obj_base_refcounted:
 *  Base template for reference counted objects protected by hazard
 *  pointers. A retired object is reclaimed (via the deleter) only after
 *  its reference count reaches zero AND no hazard pointer protects it. */
template <typename T, typename D = std::default_delete<T>>
class hazptr_obj_base_refcounted : public hazptr_obj {
 public:
  /* Retire a removed object and pass the responsibility for
   * reclaiming it to the hazptr library */
  void retire(hazptr_domain& domain = default_hazptr_domain(), D deleter = {});

  /* acquire_ref() increments the reference count
   *
   * acquire_ref_safe() is the same as acquire_ref() except that in
   * addition the caller guarantees that the call is made in a
   * thread-safe context, e.g., the object is not yet shared. This is
   * just an optimization to save an atomic operation.
   *
   * release_ref() decrements the reference count and returns true if
   * the object is safe to reclaim.
   */
  void acquire_ref();
  void acquire_ref_safe();
  bool release_ref();

 private:
  // Number of outstanding references; starts at zero (retire() itself
  // accounts for the implicit "retired" reference in release_ref()).
  std::atomic<uint32_t> refcount_{0};
  // Deleter stashed by retire(); invoked by the reclamation hook.
  D deleter_;
};
/** hazptr_holder: Class for automatic acquisition and release of /** hazptr_holder: Class for automatic acquisition and release of
* hazard pointers, and interface for hazard pointer operations. */ * hazard pointers, and interface for hazard pointer operations. */
class hazptr_holder { class hazptr_holder {
......
...@@ -50,12 +50,16 @@ class HazptrTest : public testing::Test { ...@@ -50,12 +50,16 @@ class HazptrTest : public testing::Test {
TEST_F(HazptrTest, Test1) { TEST_F(HazptrTest, Test1) {
DEBUG_PRINT(""); DEBUG_PRINT("");
Node1* node0 = (Node1*)malloc(sizeof(Node1)); Node1* node0 = (Node1*)malloc(sizeof(Node1));
DEBUG_PRINT("=== new node0 " << node0 << " " << sizeof(*node0)); node0 = new (node0) Node1;
DEBUG_PRINT("=== malloc node0 " << node0 << " " << sizeof(*node0));
Node1* node1 = (Node1*)malloc(sizeof(Node1)); Node1* node1 = (Node1*)malloc(sizeof(Node1));
node1 = new (node1) Node1;
DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1)); DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
Node1* node2 = (Node1*)malloc(sizeof(Node1)); Node1* node2 = (Node1*)malloc(sizeof(Node1));
node2 = new (node2) Node1;
DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2)); DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
Node1* node3 = (Node1*)malloc(sizeof(Node1)); Node1* node3 = (Node1*)malloc(sizeof(Node1));
node3 = new (node3) Node1;
DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3)); DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
DEBUG_PRINT(""); DEBUG_PRINT("");
...@@ -90,12 +94,12 @@ TEST_F(HazptrTest, Test1) { ...@@ -90,12 +94,12 @@ TEST_F(HazptrTest, Test1) {
Node1* n2 = shared2.load(); Node1* n2 = shared2.load();
Node1* n3 = shared3.load(); Node1* n3 = shared3.load();
if (hptr0.try_protect(n0, shared0)) {} CHECK(hptr0.try_protect(n0, shared0));
if (hptr1.try_protect(n1, shared1)) {} CHECK(hptr1.try_protect(n1, shared1));
hptr1.reset(); hptr1.reset();
hptr1.reset(nullptr); hptr1.reset(nullptr);
hptr1.reset(n2); hptr1.reset(n2);
if (hptr2.try_protect(n3, shared3)) {} CHECK(hptr2.try_protect(n3, shared3));
swap(hptr1, hptr2); swap(hptr1, hptr2);
hptr3.reset(); hptr3.reset();
...@@ -115,10 +119,13 @@ TEST_F(HazptrTest, Test2) { ...@@ -115,10 +119,13 @@ TEST_F(HazptrTest, Test2) {
Node2* node0 = new Node2; Node2* node0 = new Node2;
DEBUG_PRINT("=== new node0 " << node0 << " " << sizeof(*node0)); DEBUG_PRINT("=== new node0 " << node0 << " " << sizeof(*node0));
Node2* node1 = (Node2*)malloc(sizeof(Node2)); Node2* node1 = (Node2*)malloc(sizeof(Node2));
node1 = new (node1) Node2;
DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1)); DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
Node2* node2 = (Node2*)malloc(sizeof(Node2)); Node2* node2 = (Node2*)malloc(sizeof(Node2));
node2 = new (node2) Node2;
DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2)); DEBUG_PRINT("=== malloc node2 " << node2 << " " << sizeof(*node2));
Node2* node3 = (Node2*)malloc(sizeof(Node2)); Node2* node3 = (Node2*)malloc(sizeof(Node2));
node3 = new (node3) Node2;
DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3)); DEBUG_PRINT("=== malloc node3 " << node3 << " " << sizeof(*node3));
DEBUG_PRINT(""); DEBUG_PRINT("");
...@@ -153,11 +160,11 @@ TEST_F(HazptrTest, Test2) { ...@@ -153,11 +160,11 @@ TEST_F(HazptrTest, Test2) {
Node2* n2 = shared2.load(); Node2* n2 = shared2.load();
Node2* n3 = shared3.load(); Node2* n3 = shared3.load();
if (hptr0.try_protect(n0, shared0)) {} CHECK(hptr0.try_protect(n0, shared0));
if (hptr1.try_protect(n1, shared1)) {} CHECK(hptr1.try_protect(n1, shared1));
hptr1.reset(); hptr1.reset();
hptr1.reset(n2); hptr1.reset(n2);
if (hptr2.try_protect(n3, shared3)) {} CHECK(hptr2.try_protect(n3, shared3));
swap(hptr1, hptr2); swap(hptr1, hptr2);
hptr3.reset(); hptr3.reset();
...@@ -185,7 +192,9 @@ TEST_F(HazptrTest, LIFO) { ...@@ -185,7 +192,9 @@ TEST_F(HazptrTest, LIFO) {
for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) { for (int j = tid; j < FLAGS_num_ops; j += FLAGS_num_threads) {
s.push(j); s.push(j);
T res; T res;
while (!s.pop(res)) {} while (!s.pop(res)) {
/* keep trying */
}
} }
}); });
} }
...@@ -394,3 +403,151 @@ TEST_F(HazptrTest, Local) { ...@@ -394,3 +403,151 @@ TEST_F(HazptrTest, Local) {
hazptr_local<HAZPTR_TC_SIZE + 1> h; hazptr_local<HAZPTR_TC_SIZE + 1> h;
} }
} }
/* Test ref counting */

// Counters observed by the refcount tests to verify construction/reclamation.
std::atomic<int> constructed;
std::atomic<int> destroyed;

struct Foo : hazptr_obj_base_refcounted<Foo> {
  int val_;
  bool marked_;
  Foo* next_;

  // Builds a node carrying |value| that links to |link|.
  Foo(int value, Foo* link) : val_(value), marked_(false), next_(link) {
    DEBUG_PRINT("");
    ++constructed;
  }

  // Iteratively releases references down the chain instead of recursing
  // through nested destructor calls. A node whose marked_ flag is set is
  // being deleted by an ancestor in the chain and must not walk the list
  // itself.
  ~Foo() {
    DEBUG_PRINT("");
    ++destroyed;
    if (marked_) {
      return;
    }
    for (Foo* victim = next_; victim != nullptr;) {
      if (!victim->release_ref()) {
        break; // a live reference remains; stop the cascade here
      }
      Foo* following = victim->next_;
      victim->marked_ = true; // prevent victim's dtor from re-walking
      delete victim;
      victim = following;
    }
  }
};
// Minimal non-refcounted retireable object; retired in bulk by the tests to
// push the domain past its scan threshold and force reclamation of the
// previously retired Foo objects.
struct Dummy : hazptr_obj_base<Dummy> {};
// Single-threaded end-to-end check of integrated reference counting: build a
// linked list of Foo where each node holds one reference, protect the head
// with a hazard pointer, retire everything, and verify that nothing is
// reclaimed until the last reference is released AND the hazard pointer is
// cleared.
TEST_F(HazptrTest, basic_refcount) {
  constructed.store(0);
  destroyed.store(0);
  Foo* p = nullptr;
  int num = 20;
  // Build the list head-first; take one reference per node, alternating the
  // two acquire flavors to exercise both code paths.
  for (int i = 0; i < num; ++i) {
    p = new Foo(i, p);
    if (i & 1) {
      p->acquire_ref_safe();
    } else {
      p->acquire_ref();
    }
  }
  hazptr_holder hptr;
  hptr.reset(p);
  // Retire every node except the head, which the hazard pointer protects.
  for (auto q = p->next_; q; q = q->next_) {
    q->retire();
  }
  // The full list must remain traversable after retirement.
  int v = num;
  for (auto q = p; q; q = q->next_) {
    CHECK_GT(v, 0);
    --v;
    CHECK_EQ(q->val_, v);
  }
  // The head's count was 1, so this release is not the one that makes it
  // reclaimable (release_ref() returns true only when the count was zero).
  CHECK(!p->release_ref());
  CHECK_EQ(constructed.load(), num);
  CHECK_EQ(destroyed.load(), 0);
  p->retire();
  // Still protected by hptr: nothing may be destroyed yet.
  CHECK_EQ(constructed.load(), num);
  CHECK_EQ(destroyed.load(), 0);
  hptr.reset();
  /* retire enough objects to guarantee reclamation of Foo objects */
  for (int i = 0; i < 100; ++i) {
    auto a = new Dummy;
    a->retire();
  }
  CHECK_EQ(constructed.load(), num);
  CHECK_EQ(destroyed.load(), num);
}
// Multi-threaded check: reader threads protect the list head with hazard
// pointers and traverse it while the main thread concurrently unlinks and
// retires every node. Nodes must survive until all references are released
// AND no hazard pointer protects them.
TEST_F(HazptrTest, mt_refcount) {
  constructed.store(0);
  destroyed.store(0);
  std::atomic<bool> ready(false);
  std::atomic<int> setHazptrs(0);
  // Value-initialize: a default-constructed std::atomic is uninitialized
  // (until C++20), so it must not be left indeterminate before the store
  // below.
  std::atomic<Foo*> head(nullptr);
  int num = 20;
  int nthr = 10;
  std::vector<std::thread> thr(nthr);
  for (int i = 0; i < nthr; ++i) {
    thr[i] = std::thread([&] {
      while (!ready.load()) {
        /* spin */
      }
      hazptr_holder hptr;
      auto p = hptr.get_protected(head);
      ++setHazptrs;
      /* Concurrent with removal */
      // The protected snapshot of the list must be fully traversable even
      // while the writer retires nodes.
      int v = num;
      for (auto q = p; q; q = q->next_) {
        CHECK_GT(v, 0);
        --v;
        CHECK_EQ(q->val_, v);
      }
      CHECK_EQ(v, 0);
    });
  }
  // Build the list; single-threaded here, so the cheaper acquire is safe.
  Foo* p = nullptr;
  for (int i = 0; i < num; ++i) {
    p = new Foo(i, p);
    p->acquire_ref_safe();
  }
  head.store(p);
  ready.store(true);
  // Wait until every reader has protected the head before removing it.
  while (setHazptrs.load() < nthr) {
    /* spin */
  }
  /* this is concurrent with traversal by reader */
  head.store(nullptr);
  for (auto q = p; q; q = q->next_) {
    q->retire();
  }
  DEBUG_PRINT("Foo should not be destroyed");
  CHECK_EQ(constructed.load(), num);
  CHECK_EQ(destroyed.load(), 0);
  DEBUG_PRINT("Foo may be destroyed after releasing the last reference");
  if (p->release_ref()) {
    delete p;
  }
  /* retire enough objects to guarantee reclamation of Foo objects */
  for (int i = 0; i < 100; ++i) {
    auto a = new Dummy;
    a->retire();
  }
  for (int i = 0; i < nthr; ++i) {
    thr[i].join();
  }
  CHECK_EQ(constructed.load(), num);
  CHECK_EQ(destroyed.load(), num);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment