Commit f63877d7 authored by Matthieu Martin's avatar Matthieu Martin Committed by Facebook Github Bot

Introduce folly ShallowCopyRequestContextScopeGuard

Summary:
This guard maintains all the RequestData pointers of the parent (through shallow copy).
This allows to overwrite a specific RequestData pointer for the scope's duration, without breaking others.

We decided to keep the raw ptr interface, which required to implement a pseudo shared ptr to achieve the shallow copy behaviour.
Rest of the code is pretty straight forward. A few more lines than expected, due to introducing overrideContextData to avoid unecessary memory management (clearData) or warnings (setData).

The performance should be neutral for code not using the guard (std::atomic incr/decr).
The guard itself is pretty efficient at copying the values, though there is a slight worry about the keys (std::string). This might be a generic concern about current implementation, some form of cheap static would be better.
It also calls unecessarily onSet/onUnset. I will fix on top as it makes the change more complex.

Reviewed By: djwatson

Differential Revision: D8911351

fbshipit-source-id: 1692428382ace1d0b79bbc84a1db50efb4c7b489
parent 4941d624
...@@ -24,22 +24,44 @@ ...@@ -24,22 +24,44 @@
namespace folly { namespace folly {
void RequestData::DestructPtr::operator()(RequestData* ptr) {
if (ptr) {
auto keepAliveCounter =
ptr->keepAliveCounter_.fetch_sub(1, std::memory_order_acq_rel);
// Note: this is the value before decrement, hence == 1 check
DCHECK(keepAliveCounter > 0);
if (keepAliveCounter == 1) {
delete ptr;
}
}
}
/* static */ RequestData::SharedPtr RequestData::constructPtr(
RequestData* ptr) {
if (ptr) {
auto keepAliveCounter =
ptr->keepAliveCounter_.fetch_add(1, std::memory_order_relaxed);
DCHECK(keepAliveCounter >= 0);
}
return SharedPtr(ptr);
}
bool RequestContext::doSetContextData( bool RequestContext::doSetContextData(
const std::string& val, const std::string& val,
std::unique_ptr<RequestData>& data, std::unique_ptr<RequestData>& data,
bool strict) { DoSetBehaviour behaviour) {
auto ulock = state_.ulock(); auto ulock = state_.ulock();
bool conflict = false; bool conflict = false;
auto it = ulock->requestData_.find(val); auto it = ulock->requestData_.find(val);
if (it != ulock->requestData_.end()) { if (it != ulock->requestData_.end()) {
if (strict) { if (behaviour == DoSetBehaviour::SET_IF_ABSENT) {
return false; return false;
} else { } else if (behaviour == DoSetBehaviour::SET) {
LOG_FIRST_N(WARNING, 1) << "Calling RequestContext::setContextData for " LOG_FIRST_N(WARNING, 1) << "Calling RequestContext::setContextData for "
<< val << " but it is already set"; << val << " but it is already set";
conflict = true;
} }
conflict = true;
} }
auto wlock = ulock.moveFromUpgradeToWrite(); auto wlock = ulock.moveFromUpgradeToWrite();
...@@ -51,14 +73,16 @@ bool RequestContext::doSetContextData( ...@@ -51,14 +73,16 @@ bool RequestContext::doSetContextData(
} }
it->second.reset(nullptr); it->second.reset(nullptr);
} }
return true; if (behaviour == DoSetBehaviour::SET) {
return true;
}
} }
if (data && data->hasCallback()) { if (data && data->hasCallback()) {
wlock->callbackData_.insert(data.get()); wlock->callbackData_.insert(data.get());
data->onSet(); data->onSet();
} }
wlock->requestData_[val] = std::move(data); wlock->requestData_[val] = RequestData::constructPtr(data.release());
return true; return true;
} }
...@@ -66,13 +90,19 @@ bool RequestContext::doSetContextData( ...@@ -66,13 +90,19 @@ bool RequestContext::doSetContextData(
void RequestContext::setContextData( void RequestContext::setContextData(
const std::string& val, const std::string& val,
std::unique_ptr<RequestData> data) { std::unique_ptr<RequestData> data) {
doSetContextData(val, data, false /* strict */); doSetContextData(val, data, DoSetBehaviour::SET);
} }
bool RequestContext::setContextDataIfAbsent( bool RequestContext::setContextDataIfAbsent(
const std::string& val, const std::string& val,
std::unique_ptr<RequestData> data) { std::unique_ptr<RequestData> data) {
return doSetContextData(val, data, true /* strict */); return doSetContextData(val, data, DoSetBehaviour::SET_IF_ABSENT);
}
void RequestContext::overwriteContextData(
const std::string& val,
std::unique_ptr<RequestData> data) {
doSetContextData(val, data, DoSetBehaviour::OVERWRITE);
} }
bool RequestContext::hasContextData(const std::string& val) const { bool RequestContext::hasContextData(const std::string& val) const {
...@@ -80,13 +110,13 @@ bool RequestContext::hasContextData(const std::string& val) const { ...@@ -80,13 +110,13 @@ bool RequestContext::hasContextData(const std::string& val) const {
} }
RequestData* RequestContext::getContextData(const std::string& val) { RequestData* RequestContext::getContextData(const std::string& val) {
const std::unique_ptr<RequestData> dflt{nullptr}; const RequestData::SharedPtr dflt{nullptr};
return get_ref_default(state_.rlock()->requestData_, val, dflt).get(); return get_ref_default(state_.rlock()->requestData_, val, dflt).get();
} }
const RequestData* RequestContext::getContextData( const RequestData* RequestContext::getContextData(
const std::string& val) const { const std::string& val) const {
const std::unique_ptr<RequestData> dflt{nullptr}; const RequestData::SharedPtr dflt{nullptr};
return get_ref_default(state_.rlock()->requestData_, val, dflt).get(); return get_ref_default(state_.rlock()->requestData_, val, dflt).get();
} }
...@@ -105,7 +135,7 @@ void RequestContext::onUnset() { ...@@ -105,7 +135,7 @@ void RequestContext::onUnset() {
} }
void RequestContext::clearContextData(const std::string& val) { void RequestContext::clearContextData(const std::string& val) {
std::unique_ptr<RequestData> requestData; RequestData::SharedPtr requestData;
// Delete the RequestData after giving up the wlock just in case one of the // Delete the RequestData after giving up the wlock just in case one of the
// RequestData destructors will try to grab the lock again. // RequestData destructors will try to grab the lock again.
{ {
...@@ -147,6 +177,21 @@ std::shared_ptr<RequestContext>& RequestContext::getStaticContext() { ...@@ -147,6 +177,21 @@ std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
return SingletonT::get(); return SingletonT::get();
} }
/* static */ std::shared_ptr<RequestContext> RequestContext::shallowCopy() {
auto* parent = get();
auto child = std::make_shared<RequestContext>();
if (parent) {
auto rlock = parent->state_.rlock();
auto wlock = child->state_.wlock();
wlock->callbackData_ = rlock->callbackData_;
for (const auto& entry : rlock->requestData_) {
wlock->requestData_[entry.first] =
RequestData::constructPtr(entry.second.get());
}
}
return child;
}
RequestContext* RequestContext::get() { RequestContext* RequestContext::get() {
auto& context = getStaticContext(); auto& context = getStaticContext();
if (!context) { if (!context) {
......
...@@ -45,6 +45,28 @@ class RequestData { ...@@ -45,6 +45,28 @@ class RequestData {
// instance overrides the hasCallback method to return true otherwise // instance overrides the hasCallback method to return true otherwise
// the callback will not be executed // the callback will not be executed
virtual void onUnset() {} virtual void onUnset() {}
private:
// Start shallow copy implementation details:
// For efficiency, RequestContext provides a raw ptr interface.
// To support shallow copy, we need a shared ptr.
// To keep it as safe as possible (even if a raw ptr is passed back),
// the counter lives directly in RequestData.
friend class RequestContext;
// Unique ptr with custom destructor, decrement the counter
// and only free if 0
struct DestructPtr {
void operator()(RequestData* ptr);
};
using SharedPtr = std::unique_ptr<RequestData, DestructPtr>;
// Initialize the pseudo-shared ptr, increment the counter
static SharedPtr constructPtr(RequestData* ptr);
std::atomic<int> keepAliveCounter_{0};
// End shallow copy
}; };
// If you do not call create() to create a unique request context, // If you do not call create() to create a unique request context,
...@@ -118,18 +140,40 @@ class RequestContext { ...@@ -118,18 +140,40 @@ class RequestContext {
private: private:
static std::shared_ptr<RequestContext>& getStaticContext(); static std::shared_ptr<RequestContext>& getStaticContext();
friend struct ShallowCopyRequestContextScopeGuard;
// Private to encourage using shallow copy guard
static std::shared_ptr<RequestContext> shallowCopy();
// Similar to setContextData, except it overwrites the data
// if already set (instead of warn + reset ptr).
// Private to encourage using shallow copy guard
void overwriteContextData(
const std::string& val,
std::unique_ptr<RequestData> data);
enum class DoSetBehaviour {
SET,
SET_IF_ABSENT,
OVERWRITE,
};
bool doSetContextData( bool doSetContextData(
const std::string& val, const std::string& val,
std::unique_ptr<RequestData>& data, std::unique_ptr<RequestData>& data,
bool strict); DoSetBehaviour behaviour);
struct State { struct State {
std::map<std::string, std::unique_ptr<RequestData>> requestData_; std::map<std::string, RequestData::SharedPtr> requestData_;
std::set<RequestData*> callbackData_; std::set<RequestData*> callbackData_;
}; };
folly::Synchronized<State> state_; folly::Synchronized<State> state_;
}; };
/**
* Note: you probably want to use ShallowCopyRequestContextScopeGuard
* This resets all other RequestData for the duration of the scope!
*/
class RequestContextScopeGuard { class RequestContextScopeGuard {
private: private:
std::shared_ptr<RequestContext> prev_; std::shared_ptr<RequestContext> prev_;
...@@ -155,4 +199,30 @@ class RequestContextScopeGuard { ...@@ -155,4 +199,30 @@ class RequestContextScopeGuard {
RequestContext::setContext(std::move(prev_)); RequestContext::setContext(std::move(prev_));
} }
}; };
/**
* This guard maintains all the RequestData pointers of the parent.
* This allows to overwrite a specific RequestData pointer for the
* scope's duration, without breaking others.
*
* TODO: currently calls onSet and onUnset incorrectly, will fix on top
*/
struct ShallowCopyRequestContextScopeGuard : public RequestContextScopeGuard {
ShallowCopyRequestContextScopeGuard()
: RequestContextScopeGuard(RequestContext::shallowCopy()) {}
/**
* Shallow copy then overwrite one specific RequestData
*
* Helper constructor which is a more efficient equivalent to
* "clearRequestData" then "setRequestData" after the guard.
*/
ShallowCopyRequestContextScopeGuard(
const std::string& val,
std::unique_ptr<RequestData> data)
: ShallowCopyRequestContextScopeGuard() {
RequestContext::get()->overwriteContextData(val, std::move(data));
}
};
} // namespace folly } // namespace folly
...@@ -50,20 +50,24 @@ RequestContext& getContext() { ...@@ -50,20 +50,24 @@ RequestContext& getContext() {
return *ctx; return *ctx;
} }
void setData(int data = 0) { void setData(int data = 0, std::string key = "test") {
getContext().setContextData("test", std::make_unique<TestData>(data)); getContext().setContextData(key, std::make_unique<TestData>(data));
} }
bool hasData() { bool hasData(std::string key = "test") {
return getContext().getContextData("test") != nullptr; return getContext().hasContextData(key);
} }
const TestData& getData() { const TestData& getData(std::string key = "test") {
auto* ptr = dynamic_cast<TestData*>(getContext().getContextData("test")); auto* ptr = dynamic_cast<TestData*>(getContext().getContextData(key));
EXPECT_TRUE(ptr != nullptr); EXPECT_TRUE(ptr != nullptr);
return *ptr; return *ptr;
} }
void clearData(std::string key = "test") {
getContext().clearContextData(key);
}
TEST(RequestContext, SimpleTest) { TEST(RequestContext, SimpleTest) {
EventBase base; EventBase base;
...@@ -207,3 +211,71 @@ TEST(RequestContext, deadlockTest) { ...@@ -207,3 +211,71 @@ TEST(RequestContext, deadlockTest) {
"test", std::make_unique<DeadlockTestData>("test2")); "test", std::make_unique<DeadlockTestData>("test2"));
RequestContext::get()->clearContextData("test"); RequestContext::get()->clearContextData("test");
} }
TEST(RequestContext, ShallowCopyBasic) {
ShallowCopyRequestContextScopeGuard g0;
setData(123, "immutable");
EXPECT_EQ(123, getData("immutable").data_);
EXPECT_FALSE(hasData());
{
ShallowCopyRequestContextScopeGuard g1;
EXPECT_EQ(123, getData("immutable").data_);
setData(789);
EXPECT_EQ(789, getData().data_);
}
EXPECT_FALSE(hasData());
EXPECT_EQ(123, getData("immutable").data_);
// TODO: Should be 1/0
EXPECT_EQ(3, getData("immutable").set_);
EXPECT_EQ(2, getData("immutable").unset_);
}
TEST(RequestContext, ShallowCopyOverwrite) {
RequestContextScopeGuard g0;
setData(123);
EXPECT_EQ(123, getData().data_);
{
ShallowCopyRequestContextScopeGuard g1(
"test", std::make_unique<TestData>(789));
EXPECT_EQ(789, getData().data_);
EXPECT_EQ(1, getData().set_);
EXPECT_EQ(0, getData().unset_);
}
EXPECT_EQ(123, getData().data_);
// TODO: Should be 2/1
EXPECT_EQ(3, getData().set_);
EXPECT_EQ(2, getData().unset_);
}
TEST(RequestContext, ShallowCopyDefaultContext) {
// Don't set global scope guard
setData(123);
EXPECT_EQ(123, getData().data_);
{
ShallowCopyRequestContextScopeGuard g1(
"test", std::make_unique<TestData>(789));
EXPECT_EQ(789, getData().data_);
}
EXPECT_EQ(123, getData().data_);
EXPECT_EQ(2, getData().set_);
EXPECT_EQ(1, getData().unset_);
}
TEST(RequestContext, ShallowCopyClear) {
RequestContextScopeGuard g0;
setData(123);
EXPECT_EQ(123, getData().data_);
{
ShallowCopyRequestContextScopeGuard g1;
EXPECT_EQ(123, getData().data_);
clearData();
setData(789);
EXPECT_EQ(789, getData().data_);
}
EXPECT_EQ(123, getData().data_);
// TODO: Should be 2/1
EXPECT_EQ(3, getData().set_);
EXPECT_EQ(2, getData().unset_);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment