Commit 0e29494c authored by Maged Michael's avatar Maged Michael Committed by Facebook GitHub Bot

RequestContext: Fix overwrite of null data

Summary:
Fix a bug in overwrite of null data.

Null old data should be erased so that the insert of the new data can succeed.

Reviewed By: davidtgoldblatt

Differential Revision: D20908574

fbshipit-source-id: 6873d9c76daf1289f1d5b97b001c21473c41ea3f
parent e1cdd680
...@@ -301,23 +301,20 @@ RequestContext::StateHazptr::doSetContextDataHelper( ...@@ -301,23 +301,20 @@ RequestContext::StateHazptr::doSetContextDataHelper(
nullptr /* combined not replaced */}; nullptr /* combined not replaced */};
} }
RequestData* oldData = it.value(); RequestData* oldData = it.value();
if (oldData) { // Always erase old data (and run onUnset callback, if any).
// Always erase non-null old data (and run its onUnset callback, // Old data will always be overwritten either by the new data
// if any). Non-null old data will always be overwritten either // (if behavior is OVERWRITE) or by nullptr (if behavior is SET).
// by the new data (if behavior is OVERWRITE) or by nullptr (if Combined* newCombined = eraseOldData(cur, token, oldData, safe);
// behavior is SET). DCHECK(oldData != nullptr || newCombined == nullptr);
Combined* newCombined = eraseOldData(cur, token, oldData, safe); if (newCombined) {
if (newCombined) { replaced = cur;
replaced = cur; cur = newCombined;
cur = newCombined;
}
} }
if (behaviour == DoSetBehaviour::SET) { if (behaviour == DoSetBehaviour::SET) {
// The expected behavior for SET when found is to reset the // The expected behavior for SET when found is to reset the
// pointer and warn, without updating to the new data. // pointer and warn, without updating to the new data.
if (oldData) { bool inserted = cur->requestData_.insert(token, nullptr);
cur->requestData_.insert(token, nullptr); DCHECK(inserted);
}
unexpected = true; unexpected = true;
} else { } else {
DCHECK(behaviour == DoSetBehaviour::OVERWRITE); DCHECK(behaviour == DoSetBehaviour::OVERWRITE);
...@@ -349,22 +346,26 @@ RequestContext::StateHazptr::eraseOldData( ...@@ -349,22 +346,26 @@ RequestContext::StateHazptr::eraseOldData(
bool safe) { bool safe) {
Combined* newCombined = nullptr; Combined* newCombined = nullptr;
// Call onUnset, if any. // Call onUnset, if any.
if (olddata->hasCallback()) { if (olddata && olddata->hasCallback()) {
olddata->onUnset(); olddata->onUnset();
bool erased = cur->callbackData_.erase(olddata); bool erased = cur->callbackData_.erase(olddata);
DCHECK(erased); DCHECK(erased);
} }
if (safe) { if (safe || olddata == nullptr) {
// If the caller guarantees thread-safety, then erase the // If the caller guarantees thread-safety or the old data is null,
// entry in the current version. // then erase the entry in the current version.
cur->requestData_.erase(token); bool erased = cur->requestData_.erase(token);
olddata->releaseRefClearDelete(); DCHECK(erased);
if (olddata) {
olddata->releaseRefClearDelete();
}
} else { } else {
// If there may be concurrent readers, then copy-on-erase. // If there may be concurrent readers, then copy-on-erase.
// Update the data reference counts to account for the // Update the data reference counts to account for the
// existence of the new copy. // existence of the new copy.
newCombined = new Combined(*cur); newCombined = new Combined(*cur);
newCombined->requestData_.erase(token); bool erased = newCombined->requestData_.erase(token);
DCHECK(erased);
newCombined->acquireDataRefs(); newCombined->acquireDataRefs();
} }
return newCombined; return newCombined;
...@@ -387,13 +388,15 @@ RequestContext::StateHazptr::insertNewData( ...@@ -387,13 +388,15 @@ RequestContext::StateHazptr::insertNewData(
} }
if (data && data->hasCallback()) { if (data && data->hasCallback()) {
// If data has callback, insert in callback structure, call onSet // If data has callback, insert in callback structure, call onSet
cur->callbackData_.insert(data.get(), true); bool inserted = cur->callbackData_.insert(data.get(), true);
DCHECK(inserted);
data->onSet(); data->onSet();
} }
if (data) { if (data) {
data->acquireRef(); data->acquireRef();
} }
cur->requestData_.insert(token, data.release()); bool inserted = cur->requestData_.insert(token, data.release());
DCHECK(inserted);
return newCombined; return newCombined;
} }
...@@ -474,16 +477,19 @@ void RequestContext::StateHazptr::clearContextData(const RequestToken& token) { ...@@ -474,16 +477,19 @@ void RequestContext::StateHazptr::clearContextData(const RequestToken& token) {
} }
data = it.value(); data = it.value();
if (!data) { if (!data) {
cur->requestData_.erase(token); bool erased = cur->requestData_.erase(token);
DCHECK(erased);
return; return;
} }
if (data->hasCallback()) { if (data->hasCallback()) {
data->onUnset(); data->onUnset();
cur->callbackData_.erase(data); bool erased = cur->callbackData_.erase(data);
DCHECK(erased);
} }
replaced = cur; replaced = cur;
cur = new Combined(*replaced); cur = new Combined(*replaced);
cur->requestData_.erase(token); bool erased = cur->requestData_.erase(token);
DCHECK(erased);
cur->acquireDataRefs(); cur->acquireDataRefs();
setCombined(cur); setCombined(cur);
} // Unlock mutex_ } // Unlock mutex_
......
...@@ -446,3 +446,12 @@ TEST_F(RequestContextTest, Clear) { ...@@ -446,3 +446,12 @@ TEST_F(RequestContextTest, Clear) {
EXPECT_TRUE(deleted); EXPECT_TRUE(deleted);
} }
} }
TEST_F(RequestContextTest, OverwriteNullData) {
folly::ShallowCopyRequestContextScopeGuard g0("token", nullptr);
{
folly::ShallowCopyRequestContextScopeGuard g1(
"token", std::make_unique<TestData>(0));
EXPECT_NE(folly::RequestContext::get()->getContextData("token"), nullptr);
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment