Commit 328c1a56 authored by Maged Michael's avatar Maged Michael Committed by Facebook Github Bot

Update hazard pointers prototype

Summary:
Interface:
- Got rid of std::function reclamation functions and added a Deleter template parameter.
- Got rid of the flush() member functions of hazptr_domain
- Added a lock-free non-bool member function to get a protected pointer.
Implementation:
- Implemented the interface changes.
- Changed the order of accesses in reading the shared list of objects vs reading the hazard pointers. I think the previous order would have allowed recently protected objects to be reclaimed incorrectly.
Updated the examples and tests accordingly.

Reviewed By: davidtgoldblatt

Differential Revision: D3981284

fbshipit-source-id: 35ff60da3aea1f67c58d82437dda58f6d8b07bf5
parent a62c9841
...@@ -55,13 +55,15 @@ class LockFreeLIFO { ...@@ -55,13 +55,15 @@ class LockFreeLIFO {
bool pop(T& val) { bool pop(T& val) {
DEBUG_PRINT(this); DEBUG_PRINT(this);
hazptr_owner<Node> hptr; hazptr_owner<Node> hptr;
Node* pnode; Node* pnode = head_.load();
while (true) { do {
if ((pnode = head_.load()) == nullptr) return false; if (pnode == nullptr)
if (!hptr.protect(pnode, head_)) continue; return false;
if (!hptr.try_protect(pnode, head_))
continue;
auto next = pnode->next_; auto next = pnode->next_;
if (head_.compare_exchange_weak(pnode, next)) break; if (head_.compare_exchange_weak(pnode, next)) break;
} } while (true);
hptr.clear(); hptr.clear();
val = pnode->value_; val = pnode->value_;
pnode->retire(); pnode->retire();
......
...@@ -29,7 +29,15 @@ namespace hazptr { ...@@ -29,7 +29,15 @@ namespace hazptr {
*/ */
template <typename T> template <typename T>
class SWMRListSet { class SWMRListSet {
class Node : public hazptr_obj_base<Node> { template <typename Node>
struct Reclaimer {
void operator()(Node* p) {
DEBUG_PRINT(p << " " << sizeof(Node));
delete p;
}
};
class Node : public hazptr_obj_base<Node, Reclaimer<Node>> {
friend SWMRListSet; friend SWMRListSet;
T elem_; T elem_;
std::atomic<Node*> next_; std::atomic<Node*> next_;
...@@ -45,16 +53,10 @@ class SWMRListSet { ...@@ -45,16 +53,10 @@ class SWMRListSet {
}; };
std::atomic<Node*> head_ = {nullptr}; std::atomic<Node*> head_ = {nullptr};
hazptr_domain* domain_; hazptr_domain& domain_;
hazptr_obj_reclaim<Node> reclaim_ = [](Node* p) { reclaim(p); };
static void reclaim(Node* p) {
DEBUG_PRINT(p << " " << sizeof(Node));
delete p;
};
/* Used by the single writer */ /* Used by the single writer */
void locate_lower_bound(T v, std::atomic<Node*>*& prev) { void locate_lower_bound(const T v, std::atomic<Node*>*& prev) const {
auto curr = prev->load(); auto curr = prev->load();
while (curr) { while (curr) {
if (curr->elem_ >= v) break; if (curr->elem_ >= v) break;
...@@ -65,7 +67,7 @@ class SWMRListSet { ...@@ -65,7 +67,7 @@ class SWMRListSet {
} }
public: public:
explicit SWMRListSet(hazptr_domain* domain = default_hazptr_domain()) explicit SWMRListSet(hazptr_domain& domain = default_hazptr_domain())
: domain_(domain) {} : domain_(domain) {}
~SWMRListSet() { ~SWMRListSet() {
...@@ -74,10 +76,9 @@ class SWMRListSet { ...@@ -74,10 +76,9 @@ class SWMRListSet {
next = p->next_.load(); next = p->next_.load();
delete p; delete p;
} }
domain_->flush(&reclaim_); /* avoid destruction order fiasco */
} }
bool add(T v) { bool add(const T v) {
auto prev = &head_; auto prev = &head_;
locate_lower_bound(v, prev); locate_lower_bound(v, prev);
auto curr = prev->load(); auto curr = prev->load();
...@@ -86,17 +87,17 @@ class SWMRListSet { ...@@ -86,17 +87,17 @@ class SWMRListSet {
return true; return true;
} }
bool remove(T v) { bool remove(const T v) {
auto prev = &head_; auto prev = &head_;
locate_lower_bound(v, prev); locate_lower_bound(v, prev);
auto curr = prev->load(); auto curr = prev->load();
if (!curr || curr->elem_ != v) return false; if (!curr || curr->elem_ != v) return false;
prev->store(curr->next_.load()); prev->store(curr->next_.load());
curr->retire(domain_, &reclaim_); curr->retire(domain_);
return true; return true;
} }
/* Used by readers */ /* Used by readers */
bool contains(T val) { bool contains(const T val) const {
/* Acquire two hazard pointers for hand-over-hand traversal. */ /* Acquire two hazard pointers for hand-over-hand traversal. */
hazptr_owner<Node> hptr_prev(domain_); hazptr_owner<Node> hptr_prev(domain_);
hazptr_owner<Node> hptr_curr(domain_); hazptr_owner<Node> hptr_curr(domain_);
...@@ -107,11 +108,10 @@ class SWMRListSet { ...@@ -107,11 +108,10 @@ class SWMRListSet {
auto curr = prev->load(); auto curr = prev->load();
while (true) { while (true) {
if (!curr) { done = true; break; } if (!curr) { done = true; break; }
if (!hptr_curr.protect(curr, *prev)) break; if (!hptr_curr.try_protect(curr, *prev))
break;
auto next = curr->next_.load(); auto next = curr->next_.load();
elem = curr->elem_; elem = curr->elem_;
// Load-load order
std::atomic_thread_fence(std::memory_order_acquire);
if (prev->load() != curr) break; if (prev->load() != curr) break;
if (elem >= val) { done = true; break; } if (elem >= val) { done = true; break; }
prev = &(curr->next_); prev = &(curr->next_);
......
...@@ -49,9 +49,9 @@ class WideCAS { ...@@ -49,9 +49,9 @@ class WideCAS {
DEBUG_PRINT(this << " " << u << " " << v); DEBUG_PRINT(this << " " << u << " " << v);
Node* n = new Node(v); Node* n = new Node(v);
hazptr_owner<Node> hptr; hazptr_owner<Node> hptr;
Node* p = p_.load(); Node* p;
do { do {
if (!hptr.protect(p, p_)) continue; p = hptr.get_protected(p_);
if (p->val_ != u) { delete n; return false; } if (p->val_ != u) { delete n; return false; }
if (p_.compare_exchange_weak(p, n)) break; if (p_.compare_exchange_weak(p, n)) break;
} while (true); } while (true);
......
...@@ -31,40 +31,21 @@ namespace hazptr { ...@@ -31,40 +31,21 @@ namespace hazptr {
constexpr hazptr_domain::hazptr_domain(memory_resource* mr) noexcept constexpr hazptr_domain::hazptr_domain(memory_resource* mr) noexcept
: mr_(mr) {} : mr_(mr) {}
template <typename T>
void hazptr_domain::flush(const hazptr_obj_reclaim<T>* reclaim) {
DEBUG_PRINT(this << " " << reclaim);
flush(reinterpret_cast<const hazptr_obj_reclaim<void>*>(reclaim));
}
template <typename T>
inline void hazptr_domain::objRetire(hazptr_obj_base<T>* p) {
DEBUG_PRINT(this << " " << p);
objRetire(reinterpret_cast<hazptr_obj_base<void>*>(p));
}
/** hazptr_obj_base */ /** hazptr_obj_base */
template <typename T> template <typename T, typename D>
inline void hazptr_obj_base<T>::retire( inline void hazptr_obj_base<T, D>::retire(
hazptr_domain* domain, hazptr_domain& domain,
const hazptr_obj_reclaim<T>* reclaim, D deleter,
const storage_policy /* policy */) { const storage_policy /* policy */) {
DEBUG_PRINT(this << " " << reclaim << " " << &domain); DEBUG_PRINT(this << " " << &domain);
reclaim_ = reclaim; deleter_ = std::move(deleter);
domain->objRetire<T>(this); reclaim_ = [](hazptr_obj* p) {
} auto hobp = static_cast<hazptr_obj_base*>(p);
auto obj = static_cast<T*>(hobp);
/* Definition of default_hazptr_obj_reclaim */ hobp->deleter_(obj);
template <typename T>
inline hazptr_obj_reclaim<T>* default_hazptr_obj_reclaim() {
static hazptr_obj_reclaim<T> fn = [](T* p) {
DEBUG_PRINT("default_hazptr_obj_reclaim " << p << " " << sizeof(T));
delete p;
}; };
DEBUG_PRINT(&fn); domain.objRetire(this);
return &fn;
} }
/** hazptr_rec */ /** hazptr_rec */
...@@ -87,37 +68,51 @@ class hazptr_rec { ...@@ -87,37 +68,51 @@ class hazptr_rec {
template <typename T> template <typename T>
inline hazptr_owner<T>::hazptr_owner( inline hazptr_owner<T>::hazptr_owner(
hazptr_domain* domain, hazptr_domain& domain,
const cache_policy /* policy */) { const cache_policy /* policy */) {
domain_ = domain; domain_ = &domain;
hazptr_ = domain_->hazptrAcquire(); hazptr_ = domain_->hazptrAcquire();
DEBUG_PRINT(this << " " << domain_ << " " << hazptr_); DEBUG_PRINT(this << " " << domain_ << " " << hazptr_);
if (hazptr_ == nullptr) { std::bad_alloc e; throw e; } if (hazptr_ == nullptr) { std::bad_alloc e; throw e; }
} }
template <typename T> template <typename T>
hazptr_owner<T>::~hazptr_owner() noexcept { hazptr_owner<T>::~hazptr_owner() {
DEBUG_PRINT(this); DEBUG_PRINT(this);
domain_->hazptrRelease(hazptr_); domain_->hazptrRelease(hazptr_);
} }
template <typename T> template <typename T>
inline bool hazptr_owner<T>::protect(const T* ptr, const std::atomic<T*>& src) inline bool hazptr_owner<T>::try_protect(
const noexcept { T*& ptr,
const std::atomic<T*>& src) noexcept {
DEBUG_PRINT(this << " " << ptr << " " << &src); DEBUG_PRINT(this << " " << ptr << " " << &src);
hazptr_->set(ptr); set(ptr);
// ORDER: store-load T* p = src.load();
return (src.load() == ptr); if (p != ptr) {
ptr = p;
clear();
return false;
}
return true;
} }
template <typename T> template <typename T>
inline void hazptr_owner<T>::set(const T* ptr) const noexcept { inline T* hazptr_owner<T>::get_protected(const std::atomic<T*>& src) noexcept {
T* p = src.load();
while (!try_protect(p, src)) {}
DEBUG_PRINT(this << " " << p << " " << &src);
return p;
}
template <typename T>
inline void hazptr_owner<T>::set(const T* ptr) noexcept {
DEBUG_PRINT(this << " " << ptr); DEBUG_PRINT(this << " " << ptr);
hazptr_->set(ptr); hazptr_->set(ptr);
} }
template <typename T> template <typename T>
inline void hazptr_owner<T>::clear() const noexcept { inline void hazptr_owner<T>::clear() noexcept {
DEBUG_PRINT(this); DEBUG_PRINT(this);
hazptr_->clear(); hazptr_->clear();
} }
...@@ -145,9 +140,9 @@ inline void swap(hazptr_owner<T>& lhs, hazptr_owner<T>& rhs) noexcept { ...@@ -145,9 +140,9 @@ inline void swap(hazptr_owner<T>& lhs, hazptr_owner<T>& rhs) noexcept {
// - Optimized memory order // - Optimized memory order
/** Definition of default_hazptr_domain() */ /** Definition of default_hazptr_domain() */
inline hazptr_domain* default_hazptr_domain() { inline hazptr_domain& default_hazptr_domain() {
static hazptr_domain d; static hazptr_domain d;
return &d; return d;
} }
/** hazptr_rec */ /** hazptr_rec */
...@@ -164,17 +159,21 @@ inline const void* hazptr_rec::get() const noexcept { ...@@ -164,17 +159,21 @@ inline const void* hazptr_rec::get() const noexcept {
inline void hazptr_rec::clear() noexcept { inline void hazptr_rec::clear() noexcept {
DEBUG_PRINT(this); DEBUG_PRINT(this);
// ORDER: release
hazptr_.store(nullptr); hazptr_.store(nullptr);
} }
inline void hazptr_rec::release() noexcept { inline void hazptr_rec::release() noexcept {
DEBUG_PRINT(this); DEBUG_PRINT(this);
clear(); clear();
// ORDER: release
active_.store(false); active_.store(false);
} }
/** hazptr_obj */
inline const void* hazptr_obj::getObjPtr() const {
return this;
}
/** hazptr_domain */ /** hazptr_domain */
inline hazptr_domain::~hazptr_domain() { inline hazptr_domain::~hazptr_domain() {
...@@ -195,17 +194,10 @@ inline hazptr_domain::~hazptr_domain() { ...@@ -195,17 +194,10 @@ inline hazptr_domain::~hazptr_domain() {
} }
} }
inline void hazptr_domain::flush() { inline void hazptr_domain::try_reclaim() {
DEBUG_PRINT(this); DEBUG_PRINT(this);
auto rcount = rcount_.exchange(0); rcount_.exchange(0);
auto p = retired_.exchange(nullptr); bulkReclaim();
hazptr_obj* next;
for (; p; p = next) {
next = p->next_;
(*(p->reclaim_))(p);
--rcount;
}
rcount_.fetch_add(rcount);
} }
inline hazptr_rec* hazptr_domain::hazptrAcquire() { inline hazptr_rec* hazptr_domain::hazptrAcquire() {
...@@ -237,7 +229,7 @@ inline hazptr_rec* hazptr_domain::hazptrAcquire() { ...@@ -237,7 +229,7 @@ inline hazptr_rec* hazptr_domain::hazptrAcquire() {
return p; return p;
} }
inline void hazptr_domain::hazptrRelease(hazptr_rec* p) const noexcept { inline void hazptr_domain::hazptrRelease(hazptr_rec* p) noexcept {
DEBUG_PRINT(this << " " << p); DEBUG_PRINT(this << " " << p);
p->release(); p->release();
} }
...@@ -245,7 +237,6 @@ inline void hazptr_domain::hazptrRelease(hazptr_rec* p) const noexcept { ...@@ -245,7 +237,6 @@ inline void hazptr_domain::hazptrRelease(hazptr_rec* p) const noexcept {
inline int inline int
hazptr_domain::pushRetired(hazptr_obj* head, hazptr_obj* tail, int count) { hazptr_domain::pushRetired(hazptr_obj* head, hazptr_obj* tail, int count) {
tail->next_ = retired_.load(); tail->next_ = retired_.load();
// ORDER: store-store order
while (!retired_.compare_exchange_weak(tail->next_, head)) {} while (!retired_.compare_exchange_weak(tail->next_, head)) {}
return rcount_.fetch_add(count); return rcount_.fetch_add(count);
} }
...@@ -253,16 +244,15 @@ hazptr_domain::pushRetired(hazptr_obj* head, hazptr_obj* tail, int count) { ...@@ -253,16 +244,15 @@ hazptr_domain::pushRetired(hazptr_obj* head, hazptr_obj* tail, int count) {
inline void hazptr_domain::objRetire(hazptr_obj* p) { inline void hazptr_domain::objRetire(hazptr_obj* p) {
auto rcount = pushRetired(p, p, 1) + 1; auto rcount = pushRetired(p, p, 1) + 1;
if (rcount >= kScanThreshold * hcount_.load()) { if (rcount >= kScanThreshold * hcount_.load()) {
bulkReclaim(); tryBulkReclaim();
} }
} }
inline void hazptr_domain::bulkReclaim() { inline void hazptr_domain::tryBulkReclaim() {
DEBUG_PRINT(this); DEBUG_PRINT(this);
auto h = hazptrs_.load(); do {
auto hcount = hcount_.load(); auto hcount = hcount_.load();
auto rcount = rcount_.load(); auto rcount = rcount_.load();
do {
if (rcount < kScanThreshold * hcount) { if (rcount < kScanThreshold * hcount) {
return; return;
} }
...@@ -270,20 +260,25 @@ inline void hazptr_domain::bulkReclaim() { ...@@ -270,20 +260,25 @@ inline void hazptr_domain::bulkReclaim() {
break; break;
} }
} while (true); } while (true);
/* ORDER: store-load order between removing each object and scanning bulkReclaim();
* the hazard pointers -- can be combined in one fence */ }
inline void hazptr_domain::bulkReclaim() {
DEBUG_PRINT(this);
auto p = retired_.exchange(nullptr);
auto h = hazptrs_.load();
std::unordered_set<const void*> hs; std::unordered_set<const void*> hs;
for (; h; h = h->next_) { for (; h; h = h->next_) {
hs.insert(h->hazptr_.load()); hs.insert(h->hazptr_.load());
} }
rcount = 0; int rcount = 0;
hazptr_obj* retired = nullptr; hazptr_obj* retired = nullptr;
hazptr_obj* tail = nullptr; hazptr_obj* tail = nullptr;
auto p = retired_.exchange(nullptr);
hazptr_obj* next; hazptr_obj* next;
for (; p; p = next) { for (; p; p = next) {
next = p->next_; next = p->next_;
if (hs.count(p) == 0) { if (hs.count(p->getObjPtr()) == 0) {
DEBUG_PRINT(this << " " << p << " " << p->reclaim_);
(*(p->reclaim_))(p); (*(p->reclaim_))(p);
} else { } else {
p->next_ = retired; p->next_ = retired;
...@@ -299,31 +294,6 @@ inline void hazptr_domain::bulkReclaim() { ...@@ -299,31 +294,6 @@ inline void hazptr_domain::bulkReclaim() {
} }
} }
inline void hazptr_domain::flush(const hazptr_obj_reclaim<void>* reclaim) {
DEBUG_PRINT(this << " " << reclaim);
auto rcount = rcount_.exchange(0);
auto p = retired_.exchange(nullptr);
hazptr_obj* retired = nullptr;
hazptr_obj* tail = nullptr;
hazptr_obj* next;
for (; p; p = next) {
next = p->next_;
if (p->reclaim_ == reclaim) {
(*reclaim)(p);
} else {
p->next_ = retired;
retired = p;
if (tail == nullptr) {
tail = p;
}
++rcount;
}
}
if (tail) {
pushRetired(retired, tail, rcount);
}
}
/** hazptr_user */ /** hazptr_user */
inline void hazptr_user::flush() { inline void hazptr_user::flush() {
......
...@@ -29,11 +29,12 @@ namespace hazptr { ...@@ -29,11 +29,12 @@ namespace hazptr {
/** hazptr_rec: Private class that contains hazard pointers. */ /** hazptr_rec: Private class that contains hazard pointers. */
class hazptr_rec; class hazptr_rec;
/** hazptr_obj_base: Base template for objects protected by hazard pointers. */ /** hazptr_obj: Private class for objects protected by hazard pointers. */
template <typename T> class hazptr_obj_base; class hazptr_obj;
/** Alias for object reclamation function template */ /** hazptr_obj_base: Base template for objects protected by hazard pointers. */
template <typename T> using hazptr_obj_reclaim = std::function<void(T*)>; template <typename T, typename Deleter>
class hazptr_obj_base;
/** hazptr_domain: Class of hazard pointer domains. Each domain manages a set /** hazptr_domain: Class of hazard pointer domains. Each domain manages a set
* of hazard pointers and a set of retired objects. */ * of hazard pointers and a set of retired objects. */
...@@ -48,18 +49,13 @@ class hazptr_domain { ...@@ -48,18 +49,13 @@ class hazptr_domain {
hazptr_domain& operator=(const hazptr_domain&) = delete; hazptr_domain& operator=(const hazptr_domain&) = delete;
hazptr_domain& operator=(hazptr_domain&&) = delete; hazptr_domain& operator=(hazptr_domain&&) = delete;
/* Reclaim all retired objects with a specific reclamation void try_reclaim();
* function currently stored by this domain */
template <typename T> void flush(const hazptr_obj_reclaim<T>* reclaim);
/* Reclaim all retired objects currently stored by this domain */
void flush();
private: private:
template <typename> friend class hazptr_obj_base; template <typename, typename>
friend class hazptr_obj_base;
template <typename> friend class hazptr_owner; template <typename> friend class hazptr_owner;
using hazptr_obj = hazptr_obj_base<void>;
/** Constant -- May be changed to parameter in the future */ /** Constant -- May be changed to parameter in the future */
enum { kScanThreshold = 3 }; enum { kScanThreshold = 3 };
...@@ -69,23 +65,31 @@ class hazptr_domain { ...@@ -69,23 +65,31 @@ class hazptr_domain {
std::atomic<int> hcount_ = {0}; std::atomic<int> hcount_ = {0};
std::atomic<int> rcount_ = {0}; std::atomic<int> rcount_ = {0};
template <typename T> void objRetire(hazptr_obj_base<T>*);
hazptr_rec* hazptrAcquire();
void hazptrRelease(hazptr_rec*) const noexcept;
void objRetire(hazptr_obj*); void objRetire(hazptr_obj*);
hazptr_rec* hazptrAcquire();
void hazptrRelease(hazptr_rec*) noexcept;
int pushRetired(hazptr_obj* head, hazptr_obj* tail, int count); int pushRetired(hazptr_obj* head, hazptr_obj* tail, int count);
void tryBulkReclaim();
void bulkReclaim(); void bulkReclaim();
void flush(const hazptr_obj_reclaim<void>* reclaim);
}; };
/** Get the default hazptr_domain */ /** Get the default hazptr_domain */
hazptr_domain* default_hazptr_domain(); hazptr_domain& default_hazptr_domain();
/** Definition of hazptr_obj */
class hazptr_obj {
friend class hazptr_domain;
template <typename, typename>
friend class hazptr_obj_base;
/** Declaration of default reclamation function template */ void (*reclaim_)(hazptr_obj*);
template <typename T> hazptr_obj_reclaim<T>* default_hazptr_obj_reclaim(); hazptr_obj* next_;
const void* getObjPtr() const;
};
/** Definition of hazptr_obj_base */ /** Definition of hazptr_obj_base */
template <typename T> class hazptr_obj_base { template <typename T, typename Deleter = std::default_delete<T>>
class hazptr_obj_base : private hazptr_obj {
public: public:
/* Policy for storing retired objects */ /* Policy for storing retired objects */
enum class storage_policy { priv, shared }; enum class storage_policy { priv, shared };
...@@ -93,29 +97,16 @@ template <typename T> class hazptr_obj_base { ...@@ -93,29 +97,16 @@ template <typename T> class hazptr_obj_base {
/* Retire a removed object and pass the responsibility for /* Retire a removed object and pass the responsibility for
* reclaiming it to the hazptr library */ * reclaiming it to the hazptr library */
void retire( void retire(
hazptr_domain* domain = default_hazptr_domain(), hazptr_domain& domain = default_hazptr_domain(),
const hazptr_obj_reclaim<T>* reclaim = default_hazptr_obj_reclaim<T>(), Deleter reclaim = {},
const storage_policy policy = storage_policy::shared); const storage_policy policy = storage_policy::shared);
private: private:
friend class hazptr_domain; Deleter deleter_;
template <typename> friend class hazptr_owner;
const hazptr_obj_reclaim<T>* reclaim_;
hazptr_obj_base* next_;
}; };
/** hazptr_owner: Template for automatic acquisition and release of /** hazptr_owner: Template for automatic acquisition and release of
* hazard pointers, and interface for hazard pointer operations. */ * hazard pointers, and interface for hazard pointer operations. */
template <typename T> class hazptr_owner;
/* Swap ownership of hazard ponters between hazptr_owner-s. */
/* Note: The owned hazard pointers remain unmodified during the swap
* and continue to protect the respective objects that they were
* protecting before the swap, if any. */
template <typename T>
void swap(hazptr_owner<T>&, hazptr_owner<T>&) noexcept;
template <typename T> class hazptr_owner { template <typename T> class hazptr_owner {
public: public:
/* Policy for caching hazard pointers */ /* Policy for caching hazard pointers */
...@@ -123,11 +114,11 @@ template <typename T> class hazptr_owner { ...@@ -123,11 +114,11 @@ template <typename T> class hazptr_owner {
/* Constructor automatically acquires a hazard pointer. */ /* Constructor automatically acquires a hazard pointer. */
explicit hazptr_owner( explicit hazptr_owner(
hazptr_domain* domain = default_hazptr_domain(), hazptr_domain& domain = default_hazptr_domain(),
const cache_policy policy = cache_policy::nocache); const cache_policy policy = cache_policy::nocache);
/* Destructor automatically clears and releases the owned hazard pointer. */ /* Destructor automatically clears and releases the owned hazard pointer. */
~hazptr_owner() noexcept; ~hazptr_owner();
/* Copy and move constructors and assignment operators are /* Copy and move constructors and assignment operators are
* disallowed because: * disallowed because:
...@@ -139,21 +130,30 @@ template <typename T> class hazptr_owner { ...@@ -139,21 +130,30 @@ template <typename T> class hazptr_owner {
hazptr_owner& operator=(hazptr_owner&&) = delete; hazptr_owner& operator=(hazptr_owner&&) = delete;
/** Hazard pointer operations */ /** Hazard pointer operations */
/* Return true if successful in protecting the object */ /* Returns a protected pointer from the source */
bool protect(const T* ptr, const std::atomic<T*>& src) const noexcept; T* get_protected(const std::atomic<T*>& src) noexcept;
/* Return true if successful in protecting ptr if src == ptr after
* setting the hazard pointer. Otherwise sets ptr to src. */
bool try_protect(T*& ptr, const std::atomic<T*>& src) noexcept;
/* Set the hazard pointer to ptr */ /* Set the hazard pointer to ptr */
void set(const T* ptr) const noexcept; void set(const T* ptr) noexcept;
/* Clear the hazard pointer */ /* Clear the hazard pointer */
void clear() const noexcept; void clear() noexcept;
/* Swap ownership of hazard ponters between hazptr_owner-s. */
/* Note: The owned hazard pointers remain unmodified during the swap
* and continue to protect the respective objects that they were
* protecting before the swap, if any. */
void swap(hazptr_owner&) noexcept; void swap(hazptr_owner&) noexcept;
private: private:
hazptr_domain* domain_; hazptr_domain* domain_;
hazptr_rec* hazptr_; hazptr_rec* hazptr_;
}; };
template <typename T>
void swap(hazptr_owner<T>&, hazptr_owner<T>&) noexcept;
/** hazptr_user: Thread-specific interface for users of hazard /** hazptr_user: Thread-specific interface for users of hazard
* pointers (i.e., threads that own hazard pointers by using * pointers (i.e., threads that own hazard pointers by using
* hazptr_owner. */ * hazptr_owner. */
......
...@@ -26,20 +26,16 @@ ...@@ -26,20 +26,16 @@
#include <thread> #include <thread>
using namespace folly::hazptr; DEFINE_int32(num_threads, 1, "Number of threads");
DEFINE_int64(num_reps, 1, "Number of test reps");
static hazptr_obj_reclaim<Node1> myReclaim_ = [](Node1* p) { DEFINE_int64(num_ops, 10, "Number of ops or pairs of ops per rep");
myReclaimFn(p);
};
static hazptr_obj_reclaim<Node2> mineReclaim_ = [](Node2* p) { using namespace folly::hazptr;
mineReclaimFn(p);
};
TEST(Hazptr, Test1) { TEST(Hazptr, Test1) {
DEBUG_PRINT("========== start of scope"); DEBUG_PRINT("========== start of scope");
DEBUG_PRINT(""); DEBUG_PRINT("");
Node1* node0 = new Node1; Node1* node0 = (Node1*)malloc(sizeof(Node1));
DEBUG_PRINT("=== new node0 " << node0 << " " << sizeof(*node0)); DEBUG_PRINT("=== new node0 " << node0 << " " << sizeof(*node0));
Node1* node1 = (Node1*)malloc(sizeof(Node1)); Node1* node1 = (Node1*)malloc(sizeof(Node1));
DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1)); DEBUG_PRINT("=== malloc node1 " << node1 << " " << sizeof(*node1));
...@@ -67,9 +63,9 @@ TEST(Hazptr, Test1) { ...@@ -67,9 +63,9 @@ TEST(Hazptr, Test1) {
DEBUG_PRINT("=== hptr0"); DEBUG_PRINT("=== hptr0");
hazptr_owner<Node1> hptr0; hazptr_owner<Node1> hptr0;
DEBUG_PRINT("=== hptr1"); DEBUG_PRINT("=== hptr1");
hazptr_owner<Node1> hptr1(&myDomain0); hazptr_owner<Node1> hptr1(myDomain0);
DEBUG_PRINT("=== hptr2"); DEBUG_PRINT("=== hptr2");
hazptr_owner<Node1> hptr2(&myDomain1); hazptr_owner<Node1> hptr2(myDomain1);
DEBUG_PRINT("=== hptr3"); DEBUG_PRINT("=== hptr3");
hazptr_owner<Node1> hptr3; hazptr_owner<Node1> hptr3;
...@@ -80,11 +76,11 @@ TEST(Hazptr, Test1) { ...@@ -80,11 +76,11 @@ TEST(Hazptr, Test1) {
Node1* n2 = shared2.load(); Node1* n2 = shared2.load();
Node1* n3 = shared3.load(); Node1* n3 = shared3.load();
if (hptr0.protect(n0, shared0)) {} if (hptr0.try_protect(n0, shared0)) {}
if (hptr1.protect(n1, shared1)) {} if (hptr1.try_protect(n1, shared1)) {}
hptr1.clear(); hptr1.clear();
hptr1.set(n2); hptr1.set(n2);
if (hptr2.protect(n3, shared3)) {} if (hptr2.try_protect(n3, shared3)) {}
swap(hptr1, hptr2); swap(hptr1, hptr2);
hptr3.clear(); hptr3.clear();
...@@ -93,12 +89,11 @@ TEST(Hazptr, Test1) { ...@@ -93,12 +89,11 @@ TEST(Hazptr, Test1) {
DEBUG_PRINT("=== retire n0 " << n0); DEBUG_PRINT("=== retire n0 " << n0);
n0->retire(); n0->retire();
DEBUG_PRINT("=== retire n1 " << n1); DEBUG_PRINT("=== retire n1 " << n1);
n1->retire(default_hazptr_domain());
n1->retire(default_hazptr_domain(), &myReclaim_);
DEBUG_PRINT("=== retire n2 " << n2); DEBUG_PRINT("=== retire n2 " << n2);
n2->retire(&myDomain0, &myReclaim_); n2->retire(myDomain0);
DEBUG_PRINT("=== retire n3 " << n3); DEBUG_PRINT("=== retire n3 " << n3);
n3->retire(&myDomain1, &myReclaim_); n3->retire(myDomain1);
DEBUG_PRINT("========== end of scope"); DEBUG_PRINT("========== end of scope");
} }
...@@ -133,9 +128,9 @@ TEST(Hazptr, Test2) { ...@@ -133,9 +128,9 @@ TEST(Hazptr, Test2) {
DEBUG_PRINT("=== hptr0"); DEBUG_PRINT("=== hptr0");
hazptr_owner<Node2> hptr0; hazptr_owner<Node2> hptr0;
DEBUG_PRINT("=== hptr1"); DEBUG_PRINT("=== hptr1");
hazptr_owner<Node2> hptr1(&mineDomain0); hazptr_owner<Node2> hptr1(mineDomain0);
DEBUG_PRINT("=== hptr2"); DEBUG_PRINT("=== hptr2");
hazptr_owner<Node2> hptr2(&mineDomain1); hazptr_owner<Node2> hptr2(mineDomain1);
DEBUG_PRINT("=== hptr3"); DEBUG_PRINT("=== hptr3");
hazptr_owner<Node2> hptr3; hazptr_owner<Node2> hptr3;
...@@ -146,33 +141,28 @@ TEST(Hazptr, Test2) { ...@@ -146,33 +141,28 @@ TEST(Hazptr, Test2) {
Node2* n2 = shared2.load(); Node2* n2 = shared2.load();
Node2* n3 = shared3.load(); Node2* n3 = shared3.load();
if (hptr0.protect(n0, shared0)) {} if (hptr0.try_protect(n0, shared0)) {}
if (hptr1.protect(n1, shared1)) {} if (hptr1.try_protect(n1, shared1)) {}
hptr1.clear(); hptr1.clear();
hptr1.set(n2); hptr1.set(n2);
if (hptr2.protect(n3, shared3)) {} if (hptr2.try_protect(n3, shared3)) {}
swap(hptr1, hptr2); swap(hptr1, hptr2);
hptr3.clear(); hptr3.clear();
DEBUG_PRINT(""); DEBUG_PRINT("");
DEBUG_PRINT("=== retire n0 " << n0); DEBUG_PRINT("=== retire n0 " << n0);
n0->retire(); n0->retire(default_hazptr_domain(), &mineReclaimFnDelete);
DEBUG_PRINT("=== retire n1 " << n1); DEBUG_PRINT("=== retire n1 " << n1);
n1->retire(default_hazptr_domain(), &mineReclaimFnFree);
n1->retire(default_hazptr_domain(), &mineReclaim_);
DEBUG_PRINT("=== retire n2 " << n2); DEBUG_PRINT("=== retire n2 " << n2);
n2->retire(&mineDomain0, &mineReclaim_); n2->retire(mineDomain0, &mineReclaimFnFree);
DEBUG_PRINT("=== retire n3 " << n3); DEBUG_PRINT("=== retire n3 " << n3);
n3->retire(&mineDomain1, &mineReclaim_); n3->retire(mineDomain1, &mineReclaimFnFree);
DEBUG_PRINT("========== end of scope"); DEBUG_PRINT("========== end of scope");
} }
DEFINE_int32(num_threads, 1, "Number of threads");
DEFINE_int64(num_reps, 1, "Number of test reps");
DEFINE_int64(num_ops, 10, "Number of ops or pairs of ops per rep");
TEST(Hazptr, LIFO) { TEST(Hazptr, LIFO) {
using T = uint32_t; using T = uint32_t;
DEBUG_PRINT("========== start of test scope"); DEBUG_PRINT("========== start of test scope");
...@@ -206,7 +196,7 @@ TEST(Hazptr, SWMRLIST) { ...@@ -206,7 +196,7 @@ TEST(Hazptr, SWMRLIST) {
CHECK_GT(FLAGS_num_threads, 0); CHECK_GT(FLAGS_num_threads, 0);
for (int i = 0; i < FLAGS_num_reps; ++i) { for (int i = 0; i < FLAGS_num_reps; ++i) {
DEBUG_PRINT("========== start of rep scope"); DEBUG_PRINT("========== start of rep scope");
SWMRListSet<T> s(&custom_domain); SWMRListSet<T> s(custom_domain);
std::vector<std::thread> threads(FLAGS_num_threads); std::vector<std::thread> threads(FLAGS_num_threads);
for (int tid = 0; tid < FLAGS_num_threads; ++tid) { for (int tid = 0; tid < FLAGS_num_threads; ++tid) {
threads[tid] = std::thread([&s, tid]() { threads[tid] = std::thread([&s, tid]() {
...@@ -254,11 +244,12 @@ TEST(Hazptr, WIDECAS) { ...@@ -254,11 +244,12 @@ TEST(Hazptr, WIDECAS) {
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
DEBUG_PRINT("=================================================== start main"); DEBUG_PRINT("================================================= start main");
testing::InitGoogleTest(&argc, argv); testing::InitGoogleTest(&argc, argv);
google::ParseCommandLineFlags(&argc, &argv, true); google::ParseCommandLineFlags(&argc, &argv, true);
auto ret = RUN_ALL_TESTS(); auto ret = RUN_ALL_TESTS();
default_hazptr_domain()->flush(); DEBUG_PRINT("================================================= after tests");
DEBUG_PRINT("===================================================== end main"); default_hazptr_domain().try_reclaim();
DEBUG_PRINT("================================================= end main");
return ret; return ret;
} }
...@@ -35,14 +35,17 @@ class MyMemoryResource : public memory_resource { ...@@ -35,14 +35,17 @@ class MyMemoryResource : public memory_resource {
} }
}; };
class Node1 : public hazptr_obj_base<Node1> { template <typename Node1>
char a[100]; struct MyReclaimerFree {
}; inline void operator()(Node1* p) {
inline void myReclaimFn(Node1* p) {
DEBUG_PRINT(p << " " << sizeof(Node1)); DEBUG_PRINT(p << " " << sizeof(Node1));
free(p); free(p);
} }
};
class Node1 : public hazptr_obj_base<Node1, MyReclaimerFree<Node1>> {
char a[100];
};
} // namespace folly { } // namespace folly {
} // namespace hazptr { } // namespace hazptr {
...@@ -35,14 +35,19 @@ class MineMemoryResource : public memory_resource { ...@@ -35,14 +35,19 @@ class MineMemoryResource : public memory_resource {
} }
}; };
class Node2 : public hazptr_obj_base<Node2> { class Node2 : public hazptr_obj_base<Node2, void (*)(Node2*)> {
char a[200]; char a[200];
}; };
inline void mineReclaimFn(Node2* p) { inline void mineReclaimFnFree(Node2* p) {
DEBUG_PRINT(p << " " << sizeof(Node2)); DEBUG_PRINT(p << " " << sizeof(Node2));
free(p); free(p);
} }
inline void mineReclaimFnDelete(Node2* p) {
DEBUG_PRINT(p << " " << sizeof(Node2));
delete p;
}
} // namespace folly { } // namespace folly {
} // namespace hazptr { } // namespace hazptr {
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment