Commit 15906ea2 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook GitHub Bot

Add upcasting support

Summary: Store a raw pointer to the managed object outside of the ReadMostlySharedPtrCore to make casting easier.

Reviewed By: yfeldblum

Differential Revision: D23920580

fbshipit-source-id: 4d3d3423f4ef3a78ebbd5efd45bb365ca0f72a18
parent a52c2b20
...@@ -36,14 +36,10 @@ using DefaultRefCount = TLRefCount; ...@@ -36,14 +36,10 @@ using DefaultRefCount = TLRefCount;
namespace detail { namespace detail {
template <typename T, typename RefCount = DefaultRefCount> template <typename RefCount = DefaultRefCount>
class ReadMostlySharedPtrCore { class ReadMostlySharedPtrCore {
public: public:
T* get() { std::shared_ptr<const void> getShared() {
return ptrRaw_;
}
std::shared_ptr<T> getShared() {
return ptr_; return ptr_;
} }
...@@ -53,7 +49,6 @@ class ReadMostlySharedPtrCore { ...@@ -53,7 +49,6 @@ class ReadMostlySharedPtrCore {
void decref() { void decref() {
if (--count_ == 0) { if (--count_ == 0) {
ptrRaw_ = nullptr;
ptr_.reset(); ptr_.reset();
decrefWeak(); decrefWeak();
...@@ -81,16 +76,16 @@ class ReadMostlySharedPtrCore { ...@@ -81,16 +76,16 @@ class ReadMostlySharedPtrCore {
} }
private: private:
friend class ReadMostlyMainPtr<T, RefCount>; template <typename T, typename RefCount2>
friend class folly::ReadMostlyMainPtr;
friend class ReadMostlyMainPtrDeleter<RefCount>; friend class ReadMostlyMainPtrDeleter<RefCount>;
explicit ReadMostlySharedPtrCore(std::shared_ptr<T> ptr) explicit ReadMostlySharedPtrCore(std::shared_ptr<const void> ptr)
: ptrRaw_(ptr.get()), ptr_(std::move(ptr)) {} : ptr_(std::move(ptr)) {}
T* ptrRaw_;
RefCount count_; RefCount count_;
RefCount weakCount_; RefCount weakCount_;
std::shared_ptr<T> ptr_; std::shared_ptr<const void> ptr_;
}; };
} // namespace detail } // namespace detail
...@@ -113,7 +108,7 @@ class ReadMostlyMainPtr { ...@@ -113,7 +108,7 @@ class ReadMostlyMainPtr {
ReadMostlyMainPtr& operator=(ReadMostlyMainPtr&& other) noexcept { ReadMostlyMainPtr& operator=(ReadMostlyMainPtr&& other) noexcept {
std::swap(impl_, other.impl_); std::swap(impl_, other.impl_);
std::swap(ptrRaw_, other.ptrRaw_);
return *this; return *this;
} }
...@@ -135,6 +130,7 @@ class ReadMostlyMainPtr { ...@@ -135,6 +130,7 @@ class ReadMostlyMainPtr {
void reset() noexcept { void reset() noexcept {
if (impl_) { if (impl_) {
ptrRaw_ = nullptr;
impl_->count_.useGlobal(); impl_->count_.useGlobal();
impl_->weakCount_.useGlobal(); impl_->weakCount_.useGlobal();
impl_->decref(); impl_->decref();
...@@ -145,21 +141,18 @@ class ReadMostlyMainPtr { ...@@ -145,21 +141,18 @@ class ReadMostlyMainPtr {
void reset(std::shared_ptr<T> ptr) { void reset(std::shared_ptr<T> ptr) {
reset(); reset();
if (ptr) { if (ptr) {
impl_ = new detail::ReadMostlySharedPtrCore<T, RefCount>(std::move(ptr)); ptrRaw_ = ptr.get();
impl_ = new detail::ReadMostlySharedPtrCore<RefCount>(std::move(ptr));
} }
} }
T* get() const { T* get() const {
if (impl_) { return ptrRaw_;
return impl_->ptrRaw_;
} else {
return nullptr;
}
} }
std::shared_ptr<T> getStdShared() const { std::shared_ptr<T> getStdShared() const {
if (impl_) { if (impl_) {
return impl_->getShared(); return {impl_->getShared(), ptrRaw_};
} else { } else {
return {}; return {};
} }
...@@ -182,11 +175,14 @@ class ReadMostlyMainPtr { ...@@ -182,11 +175,14 @@ class ReadMostlyMainPtr {
} }
private: private:
friend class ReadMostlyWeakPtr<T, RefCount>; template <typename U, typename RefCount2>
friend class ReadMostlySharedPtr<T, RefCount>; friend class ReadMostlyWeakPtr;
template <typename U, typename RefCount2>
friend class ReadMostlySharedPtr;
friend class ReadMostlyMainPtrDeleter<RefCount>; friend class ReadMostlyMainPtrDeleter<RefCount>;
detail::ReadMostlySharedPtrCore<T, RefCount>* impl_{nullptr}; detail::ReadMostlySharedPtrCore<RefCount>* impl_{nullptr};
T* ptrRaw_{nullptr};
}; };
template <typename T, typename RefCount = DefaultRefCount> template <typename T, typename RefCount = DefaultRefCount>
...@@ -194,39 +190,91 @@ class ReadMostlyWeakPtr { ...@@ -194,39 +190,91 @@ class ReadMostlyWeakPtr {
public: public:
ReadMostlyWeakPtr() {} ReadMostlyWeakPtr() {}
explicit ReadMostlyWeakPtr(const ReadMostlyMainPtr<T, RefCount>& mainPtr) { ReadMostlyWeakPtr(const ReadMostlyWeakPtr& other) {
reset(mainPtr.impl_); *this = other;
} }
explicit ReadMostlyWeakPtr(const ReadMostlySharedPtr<T, RefCount>& ptr) { ReadMostlyWeakPtr(ReadMostlyWeakPtr&& other) noexcept {
reset(ptr.impl_); *this = std::move(other);
} }
ReadMostlyWeakPtr(const ReadMostlyWeakPtr& other) { template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlyWeakPtr(const ReadMostlyWeakPtr<T2, RefCount>& other) {
*this = other; *this = other;
} }
ReadMostlyWeakPtr& operator=(const ReadMostlyWeakPtr& other) { template <
reset(other.impl_); typename T2,
return *this; typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlyWeakPtr(ReadMostlyWeakPtr<T2, RefCount>&& other) noexcept {
*this = std::move(other);
} }
ReadMostlyWeakPtr& operator=(const ReadMostlyMainPtr<T, RefCount>& mainPtr) { template <
reset(mainPtr.impl_); typename T2,
return *this; typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
explicit ReadMostlyWeakPtr(const ReadMostlyMainPtr<T2, RefCount>& other) {
*this = other;
} }
ReadMostlyWeakPtr(ReadMostlyWeakPtr&& other) noexcept { template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
explicit ReadMostlyWeakPtr(const ReadMostlySharedPtr<T2, RefCount>& other) {
*this = other; *this = other;
} }
ReadMostlyWeakPtr& operator=(const ReadMostlyWeakPtr& other) {
reset(other.impl_, other.ptrRaw_);
return *this;
}
ReadMostlyWeakPtr& operator=(ReadMostlyWeakPtr&& other) noexcept { ReadMostlyWeakPtr& operator=(ReadMostlyWeakPtr&& other) noexcept {
std::swap(impl_, other.impl_); std::swap(impl_, other.impl_);
std::swap(ptrRaw_, other.ptrRaw_);
return *this;
}
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlyWeakPtr& operator=(const ReadMostlyWeakPtr<T2, RefCount>& other) {
reset(other.impl_, other.ptrRaw_);
return *this;
}
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlyWeakPtr& operator=(
ReadMostlyWeakPtr<T2, RefCount>&& other) noexcept {
reset();
impl_ = std::exchange(other.impl_, nullptr);
ptrRaw_ = std::exchange(other.ptrRaw_, nullptr);
return *this;
}
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlyWeakPtr& operator=(const ReadMostlyMainPtr<T2, RefCount>& mainPtr) {
reset(mainPtr.impl_, mainPtr.ptrRaw_);
return *this;
}
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlyWeakPtr& operator=(
const ReadMostlySharedPtr<T2, RefCount>& mainPtr) {
reset(mainPtr.impl_, mainPtr.ptrRaw_);
return *this; return *this;
} }
~ReadMostlyWeakPtr() noexcept { ~ReadMostlyWeakPtr() noexcept {
reset(nullptr); reset(nullptr, nullptr);
} }
ReadMostlySharedPtr<T, RefCount> lock() { ReadMostlySharedPtr<T, RefCount> lock() {
...@@ -234,19 +282,28 @@ class ReadMostlyWeakPtr { ...@@ -234,19 +282,28 @@ class ReadMostlyWeakPtr {
} }
private: private:
friend class ReadMostlySharedPtr<T, RefCount>; template <typename U, typename RefCount2>
friend class ReadMostlyWeakPtr;
template <typename U, typename RefCount2>
friend class ReadMostlySharedPtr;
void reset(detail::ReadMostlySharedPtrCore<RefCount>* impl, T* ptrRaw) {
if (impl_ == impl) {
return;
}
void reset(detail::ReadMostlySharedPtrCore<T, RefCount>* impl) {
if (impl_) { if (impl_) {
impl_->decrefWeak(); impl_->decrefWeak();
} }
impl_ = impl; impl_ = impl;
ptrRaw_ = ptrRaw;
if (impl_) { if (impl_) {
impl_->increfWeak(); impl_->increfWeak();
} }
} }
detail::ReadMostlySharedPtrCore<T, RefCount>* impl_{nullptr}; detail::ReadMostlySharedPtrCore<RefCount>* impl_{nullptr};
T* ptrRaw_{nullptr};
}; };
template <typename T, typename RefCount = DefaultRefCount> template <typename T, typename RefCount = DefaultRefCount>
...@@ -254,48 +311,94 @@ class ReadMostlySharedPtr { ...@@ -254,48 +311,94 @@ class ReadMostlySharedPtr {
public: public:
ReadMostlySharedPtr() {} ReadMostlySharedPtr() {}
explicit ReadMostlySharedPtr(const ReadMostlyWeakPtr<T, RefCount>& weakPtr) { ReadMostlySharedPtr(const ReadMostlySharedPtr& other) {
reset(weakPtr.impl_); *this = other;
} }
// Generally, this shouldn't be used. ReadMostlySharedPtr(ReadMostlySharedPtr&& other) noexcept {
explicit ReadMostlySharedPtr(const ReadMostlyMainPtr<T, RefCount>& mainPtr) { *this = std::move(other);
reset(mainPtr.impl_);
} }
ReadMostlySharedPtr(const ReadMostlySharedPtr& other) { template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlySharedPtr(const ReadMostlySharedPtr<T2, RefCount>& other) {
*this = other;
}
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlySharedPtr(ReadMostlySharedPtr<T2, RefCount>&& other) noexcept {
*this = std::move(other);
}
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
explicit ReadMostlySharedPtr(const ReadMostlyWeakPtr<T2, RefCount>& other) {
*this = other;
}
// Generally, this shouldn't be used.
template <
typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
explicit ReadMostlySharedPtr(const ReadMostlyMainPtr<T2, RefCount>& other) {
*this = other; *this = other;
} }
ReadMostlySharedPtr& operator=(const ReadMostlySharedPtr& other) { ReadMostlySharedPtr& operator=(const ReadMostlySharedPtr& other) {
reset(other.impl_); reset(other.impl_, other.ptrRaw_);
return *this; return *this;
} }
ReadMostlySharedPtr& operator=(const ReadMostlyWeakPtr<T, RefCount>& other) { ReadMostlySharedPtr& operator=(ReadMostlySharedPtr&& other) noexcept {
reset(other.impl_); std::swap(impl_, other.impl_);
std::swap(ptrRaw_, other.ptrRaw_);
return *this; return *this;
} }
ReadMostlySharedPtr& operator=(const ReadMostlyMainPtr<T, RefCount>& other) { template <
reset(other.impl_); typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlySharedPtr& operator=(
const ReadMostlySharedPtr<T2, RefCount>& other) {
reset(other.impl_, other.ptrRaw_);
return *this; return *this;
} }
ReadMostlySharedPtr(ReadMostlySharedPtr&& other) noexcept { template <
*this = std::move(other); typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlySharedPtr& operator=(
ReadMostlySharedPtr<T2, RefCount>&& other) noexcept {
reset();
impl_ = std::exchange(other.impl_, nullptr);
ptrRaw_ = std::exchange(other.ptrRaw_, nullptr);
return *this;
} }
~ReadMostlySharedPtr() noexcept { template <
reset(nullptr); typename T2,
typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlySharedPtr& operator=(const ReadMostlyWeakPtr<T2, RefCount>& other) {
reset(other.impl_, other.ptrRaw_);
return *this;
} }
ReadMostlySharedPtr& operator=(ReadMostlySharedPtr&& other) noexcept { template <
std::swap(ptr_, other.ptr_); typename T2,
std::swap(impl_, other.impl_); typename = std::enable_if_t<std::is_convertible<T2*, T*>::value>>
ReadMostlySharedPtr& operator=(const ReadMostlyMainPtr<T2, RefCount>& other) {
reset(other.impl_, other.ptrRaw_);
return *this; return *this;
} }
~ReadMostlySharedPtr() noexcept {
reset(nullptr, nullptr);
}
bool operator==(const ReadMostlyMainPtr<T, RefCount>& other) const { bool operator==(const ReadMostlyMainPtr<T, RefCount>& other) const {
return get() == other.get(); return get() == other.get();
} }
...@@ -309,16 +412,16 @@ class ReadMostlySharedPtr { ...@@ -309,16 +412,16 @@ class ReadMostlySharedPtr {
} }
void reset() { void reset() {
reset(nullptr); reset(nullptr, nullptr);
} }
T* get() const { T* get() const {
return ptr_; return ptrRaw_;
} }
std::shared_ptr<T> getStdShared() const { std::shared_ptr<T> getStdShared() const {
if (impl_) { if (impl_) {
return impl_->getShared(); return {impl_->getShared(), ptrRaw_};
} else { } else {
return {}; return {};
} }
...@@ -345,23 +448,30 @@ class ReadMostlySharedPtr { ...@@ -345,23 +448,30 @@ class ReadMostlySharedPtr {
} }
private: private:
friend class ReadMostlyWeakPtr<T, RefCount>; template <typename U, typename RefCount2>
friend class ReadMostlyWeakPtr;
template <typename U, typename RefCount2>
friend class ReadMostlySharedPtr;
void reset(detail::ReadMostlySharedPtrCore<RefCount>* impl, T* ptrRaw) {
if (impl_ == impl) {
return;
}
void reset(detail::ReadMostlySharedPtrCore<T, RefCount>* impl) {
if (impl_) { if (impl_) {
impl_->decref(); impl_->decref();
impl_ = nullptr; impl_ = nullptr;
ptr_ = nullptr; ptrRaw_ = nullptr;
} }
if (impl && impl->incref()) { if (impl && impl->incref()) {
impl_ = impl; impl_ = impl;
ptr_ = impl->get(); ptrRaw_ = ptrRaw;
} }
} }
T* ptr_{nullptr}; T* ptrRaw_{nullptr};
detail::ReadMostlySharedPtrCore<T, RefCount>* impl_{nullptr}; detail::ReadMostlySharedPtrCore<RefCount>* impl_{nullptr};
}; };
/** /**
...@@ -387,6 +497,7 @@ class ReadMostlyMainPtrDeleter { ...@@ -387,6 +497,7 @@ class ReadMostlyMainPtrDeleter {
refCounts_.push_back(&ptr.impl_->weakCount_); refCounts_.push_back(&ptr.impl_->weakCount_);
decrefs_.push_back([impl = ptr.impl_] { impl->decref(); }); decrefs_.push_back([impl = ptr.impl_] { impl->decref(); });
ptr.impl_ = nullptr; ptr.impl_ = nullptr;
ptr.ptrRaw_ = nullptr;
} }
private: private:
......
...@@ -346,3 +346,51 @@ TEST_F(ReadMostlySharedPtrTest, getStdShared) { ...@@ -346,3 +346,51 @@ TEST_F(ReadMostlySharedPtrTest, getStdShared) {
// No conditions to check; we just wanted to ensure this compiles. // No conditions to check; we just wanted to ensure this compiles.
SUCCEED(); SUCCEED();
} }
struct Base {
virtual ~Base() = default;
virtual std::string getName() const {
return "Base";
}
};
struct Derived : public Base {
std::string getName() const override {
return "Derived";
}
};
TEST_F(ReadMostlySharedPtrTest, casts) {
ReadMostlyMainPtr<Derived> rmmp(std::make_shared<Derived>());
ReadMostlySharedPtr<Derived> rmsp(rmmp);
{
ReadMostlySharedPtr<Base> rmspbase(rmmp);
EXPECT_EQ("Derived", rmspbase->getName());
EXPECT_EQ("Derived", rmspbase.getStdShared()->getName());
}
{
ReadMostlySharedPtr<Base> rmspbase(rmsp);
EXPECT_EQ("Derived", rmspbase->getName());
EXPECT_EQ("Derived", rmspbase.getStdShared()->getName());
}
{
ReadMostlySharedPtr<Base> rmspbase;
rmspbase = rmsp;
EXPECT_EQ("Derived", rmspbase->getName());
EXPECT_EQ("Derived", rmspbase.getStdShared()->getName());
}
{
auto rmspcopy = rmsp;
ReadMostlySharedPtr<Base> rmspbase(std::move(rmspcopy));
EXPECT_EQ("Derived", rmspbase->getName());
EXPECT_EQ("Derived", rmspbase.getStdShared()->getName());
}
{
auto rmspcopy = rmsp;
ReadMostlySharedPtr<Base> rmspbase;
rmspbase = std::move(rmspcopy);
EXPECT_EQ("Derived", rmspbase->getName());
EXPECT_EQ("Derived", rmspbase.getStdShared()->getName());
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment