From f63877d7f42992ffbb8046cd9518a19fb1c84e03 Mon Sep 17 00:00:00 2001
From: Matthieu Martin <matthieu@fb.com>
Date: Wed, 25 Jul 2018 23:50:17 -0700
Subject: [PATCH] Introduce folly ShallowCopyRequestContextScopeGuard

Summary:
This guard maintains all the RequestData pointers of the parent (through shallow copy).
This allows to overwrite a specific RequestData pointer for the scope's duration, without breaking others.

We decided to keep the raw ptr interface, which required to implement a pseudo shared ptr to achieve the shallow copy behaviour.
Rest of the code is pretty straight forward. A few more lines than expected, due to introducing overrideContextData to avoid unecessary memory management (clearData) or warnings (setData).

The performance should be neutral for code not using the guard (std::atomic incr/decr).
The guard itself is pretty efficient at copying the values, though there is a slight worry about the keys (std::string). This might be a generic concern about current implementation, some form of cheap static would be better.
It also calls unecessarily onSet/onUnset. I will fix on top as it makes the change more complex.

Reviewed By: djwatson

Differential Revision: D8911351

fbshipit-source-id: 1692428382ace1d0b79bbc84a1db50efb4c7b489
---
 folly/io/async/Request.cpp                 | 67 ++++++++++++++---
 folly/io/async/Request.h                   | 74 ++++++++++++++++++-
 folly/io/async/test/RequestContextTest.cpp | 84 ++++++++++++++++++++--
 3 files changed, 206 insertions(+), 19 deletions(-)

diff --git a/folly/io/async/Request.cpp b/folly/io/async/Request.cpp
index fa14358e6..67708798c 100644
--- a/folly/io/async/Request.cpp
+++ b/folly/io/async/Request.cpp
@@ -24,22 +24,44 @@
 
 namespace folly {
 
+void RequestData::DestructPtr::operator()(RequestData* ptr) {
+  if (ptr) {
+    auto keepAliveCounter =
+        ptr->keepAliveCounter_.fetch_sub(1, std::memory_order_acq_rel);
+    // Note: this is the value before decrement, hence == 1 check
+    DCHECK(keepAliveCounter > 0);
+    if (keepAliveCounter == 1) {
+      delete ptr;
+    }
+  }
+}
+
+/* static */ RequestData::SharedPtr RequestData::constructPtr(
+    RequestData* ptr) {
+  if (ptr) {
+    auto keepAliveCounter =
+        ptr->keepAliveCounter_.fetch_add(1, std::memory_order_relaxed);
+    DCHECK(keepAliveCounter >= 0);
+  }
+  return SharedPtr(ptr);
+}
+
 bool RequestContext::doSetContextData(
     const std::string& val,
     std::unique_ptr<RequestData>& data,
-    bool strict) {
+    DoSetBehaviour behaviour) {
   auto ulock = state_.ulock();
 
   bool conflict = false;
   auto it = ulock->requestData_.find(val);
   if (it != ulock->requestData_.end()) {
-    if (strict) {
+    if (behaviour == DoSetBehaviour::SET_IF_ABSENT) {
       return false;
-    } else {
+    } else if (behaviour == DoSetBehaviour::SET) {
       LOG_FIRST_N(WARNING, 1) << "Calling RequestContext::setContextData for "
                               << val << " but it is already set";
-      conflict = true;
     }
+    conflict = true;
   }
 
   auto wlock = ulock.moveFromUpgradeToWrite();
@@ -51,14 +73,16 @@ bool RequestContext::doSetContextData(
       }
       it->second.reset(nullptr);
     }
-    return true;
+    if (behaviour == DoSetBehaviour::SET) {
+      return true;
+    }
   }
 
   if (data && data->hasCallback()) {
     wlock->callbackData_.insert(data.get());
     data->onSet();
   }
-  wlock->requestData_[val] = std::move(data);
+  wlock->requestData_[val] = RequestData::constructPtr(data.release());
 
   return true;
 }
@@ -66,13 +90,19 @@ bool RequestContext::doSetContextData(
 void RequestContext::setContextData(
     const std::string& val,
     std::unique_ptr<RequestData> data) {
-  doSetContextData(val, data, false /* strict */);
+  doSetContextData(val, data, DoSetBehaviour::SET);
 }
 
 bool RequestContext::setContextDataIfAbsent(
     const std::string& val,
     std::unique_ptr<RequestData> data) {
-  return doSetContextData(val, data, true /* strict */);
+  return doSetContextData(val, data, DoSetBehaviour::SET_IF_ABSENT);
+}
+
+void RequestContext::overwriteContextData(
+    const std::string& val,
+    std::unique_ptr<RequestData> data) {
+  doSetContextData(val, data, DoSetBehaviour::OVERWRITE);
 }
 
 bool RequestContext::hasContextData(const std::string& val) const {
@@ -80,13 +110,13 @@ bool RequestContext::hasContextData(const std::string& val) const {
 }
 
 RequestData* RequestContext::getContextData(const std::string& val) {
-  const std::unique_ptr<RequestData> dflt{nullptr};
+  const RequestData::SharedPtr dflt{nullptr};
   return get_ref_default(state_.rlock()->requestData_, val, dflt).get();
 }
 
 const RequestData* RequestContext::getContextData(
     const std::string& val) const {
-  const std::unique_ptr<RequestData> dflt{nullptr};
+  const RequestData::SharedPtr dflt{nullptr};
   return get_ref_default(state_.rlock()->requestData_, val, dflt).get();
 }
 
@@ -105,7 +135,7 @@ void RequestContext::onUnset() {
 }
 
 void RequestContext::clearContextData(const std::string& val) {
-  std::unique_ptr<RequestData> requestData;
+  RequestData::SharedPtr requestData;
   // Delete the RequestData after giving up the wlock just in case one of the
   // RequestData destructors will try to grab the lock again.
   {
@@ -147,6 +177,21 @@ std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
   return SingletonT::get();
 }
 
+/* static */ std::shared_ptr<RequestContext> RequestContext::shallowCopy() {
+  auto* parent = get();
+  auto child = std::make_shared<RequestContext>();
+  if (parent) {
+    auto rlock = parent->state_.rlock();
+    auto wlock = child->state_.wlock();
+    wlock->callbackData_ = rlock->callbackData_;
+    for (const auto& entry : rlock->requestData_) {
+      wlock->requestData_[entry.first] =
+          RequestData::constructPtr(entry.second.get());
+    }
+  }
+  return child;
+}
+
 RequestContext* RequestContext::get() {
   auto& context = getStaticContext();
   if (!context) {
diff --git a/folly/io/async/Request.h b/folly/io/async/Request.h
index a865eceaa..4edf5c9df 100644
--- a/folly/io/async/Request.h
+++ b/folly/io/async/Request.h
@@ -45,6 +45,28 @@ class RequestData {
   // instance overrides the hasCallback method to return true otherwise
   // the callback will not be executed
   virtual void onUnset() {}
+
+ private:
+  // Start shallow copy implementation details:
+  // For efficiency, RequestContext provides a raw ptr interface.
+  // To support shallow copy, we need a shared ptr.
+  // To keep it as safe as possible (even if a raw ptr is passed back),
+  // the counter lives directly in RequestData.
+
+  friend class RequestContext;
+
+  // Unique ptr with custom destructor, decrement the counter
+  // and only free if 0
+  struct DestructPtr {
+    void operator()(RequestData* ptr);
+  };
+  using SharedPtr = std::unique_ptr<RequestData, DestructPtr>;
+
+  // Initialize the pseudo-shared ptr, increment the counter
+  static SharedPtr constructPtr(RequestData* ptr);
+
+  std::atomic<int> keepAliveCounter_{0};
+  // End shallow copy
 };
 
 // If you do not call create() to create a unique request context,
@@ -118,18 +140,40 @@ class RequestContext {
  private:
   static std::shared_ptr<RequestContext>& getStaticContext();
 
+  friend struct ShallowCopyRequestContextScopeGuard;
+
+  // Private to encourage using shallow copy guard
+  static std::shared_ptr<RequestContext> shallowCopy();
+
+  // Similar to setContextData, except it overwrites the data
+  // if already set (instead of warn + reset ptr).
+  // Private to encourage using shallow copy guard
+  void overwriteContextData(
+      const std::string& val,
+      std::unique_ptr<RequestData> data);
+
+  enum class DoSetBehaviour {
+    SET,
+    SET_IF_ABSENT,
+    OVERWRITE,
+  };
+
   bool doSetContextData(
       const std::string& val,
       std::unique_ptr<RequestData>& data,
-      bool strict);
+      DoSetBehaviour behaviour);
 
   struct State {
-    std::map<std::string, std::unique_ptr<RequestData>> requestData_;
+    std::map<std::string, RequestData::SharedPtr> requestData_;
     std::set<RequestData*> callbackData_;
   };
   folly::Synchronized<State> state_;
 };
 
+/**
+ * Note: you probably want to use ShallowCopyRequestContextScopeGuard
+ * This resets all other RequestData for the duration of the scope!
+ */
 class RequestContextScopeGuard {
  private:
   std::shared_ptr<RequestContext> prev_;
@@ -155,4 +199,30 @@ class RequestContextScopeGuard {
     RequestContext::setContext(std::move(prev_));
   }
 };
+
+/**
+ * This guard maintains all the RequestData pointers of the parent.
+ * This allows to overwrite a specific RequestData pointer for the
+ * scope's duration, without breaking others.
+ *
+ * TODO: currently calls onSet and onUnset incorrectly, will fix on top
+ */
+struct ShallowCopyRequestContextScopeGuard : public RequestContextScopeGuard {
+  ShallowCopyRequestContextScopeGuard()
+      : RequestContextScopeGuard(RequestContext::shallowCopy()) {}
+
+  /**
+   * Shallow copy then overwrite one specific RequestData
+   *
+   * Helper constructor which is a more efficient equivalent to
+   * "clearRequestData" then "setRequestData" after the guard.
+   */
+  ShallowCopyRequestContextScopeGuard(
+      const std::string& val,
+      std::unique_ptr<RequestData> data)
+      : ShallowCopyRequestContextScopeGuard() {
+    RequestContext::get()->overwriteContextData(val, std::move(data));
+  }
+};
+
 } // namespace folly
diff --git a/folly/io/async/test/RequestContextTest.cpp b/folly/io/async/test/RequestContextTest.cpp
index 45a4daa62..e44f5b402 100644
--- a/folly/io/async/test/RequestContextTest.cpp
+++ b/folly/io/async/test/RequestContextTest.cpp
@@ -50,20 +50,24 @@ RequestContext& getContext() {
   return *ctx;
 }
 
-void setData(int data = 0) {
-  getContext().setContextData("test", std::make_unique<TestData>(data));
+void setData(int data = 0, std::string key = "test") {
+  getContext().setContextData(key, std::make_unique<TestData>(data));
 }
 
-bool hasData() {
-  return getContext().getContextData("test") != nullptr;
+bool hasData(std::string key = "test") {
+  return getContext().hasContextData(key);
 }
 
-const TestData& getData() {
-  auto* ptr = dynamic_cast<TestData*>(getContext().getContextData("test"));
+const TestData& getData(std::string key = "test") {
+  auto* ptr = dynamic_cast<TestData*>(getContext().getContextData(key));
   EXPECT_TRUE(ptr != nullptr);
   return *ptr;
 }
 
+void clearData(std::string key = "test") {
+  getContext().clearContextData(key);
+}
+
 TEST(RequestContext, SimpleTest) {
   EventBase base;
 
@@ -207,3 +211,71 @@ TEST(RequestContext, deadlockTest) {
       "test", std::make_unique<DeadlockTestData>("test2"));
   RequestContext::get()->clearContextData("test");
 }
+
+TEST(RequestContext, ShallowCopyBasic) {
+  ShallowCopyRequestContextScopeGuard g0;
+  setData(123, "immutable");
+  EXPECT_EQ(123, getData("immutable").data_);
+  EXPECT_FALSE(hasData());
+
+  {
+    ShallowCopyRequestContextScopeGuard g1;
+    EXPECT_EQ(123, getData("immutable").data_);
+    setData(789);
+    EXPECT_EQ(789, getData().data_);
+  }
+
+  EXPECT_FALSE(hasData());
+  EXPECT_EQ(123, getData("immutable").data_);
+  // TODO: Should be 1/0
+  EXPECT_EQ(3, getData("immutable").set_);
+  EXPECT_EQ(2, getData("immutable").unset_);
+}
+
+TEST(RequestContext, ShallowCopyOverwrite) {
+  RequestContextScopeGuard g0;
+  setData(123);
+  EXPECT_EQ(123, getData().data_);
+  {
+    ShallowCopyRequestContextScopeGuard g1(
+        "test", std::make_unique<TestData>(789));
+    EXPECT_EQ(789, getData().data_);
+    EXPECT_EQ(1, getData().set_);
+    EXPECT_EQ(0, getData().unset_);
+  }
+  EXPECT_EQ(123, getData().data_);
+  // TODO: Should be 2/1
+  EXPECT_EQ(3, getData().set_);
+  EXPECT_EQ(2, getData().unset_);
+}
+
+TEST(RequestContext, ShallowCopyDefaultContext) {
+  // Don't set global scope guard
+  setData(123);
+  EXPECT_EQ(123, getData().data_);
+  {
+    ShallowCopyRequestContextScopeGuard g1(
+        "test", std::make_unique<TestData>(789));
+    EXPECT_EQ(789, getData().data_);
+  }
+  EXPECT_EQ(123, getData().data_);
+  EXPECT_EQ(2, getData().set_);
+  EXPECT_EQ(1, getData().unset_);
+}
+
+TEST(RequestContext, ShallowCopyClear) {
+  RequestContextScopeGuard g0;
+  setData(123);
+  EXPECT_EQ(123, getData().data_);
+  {
+    ShallowCopyRequestContextScopeGuard g1;
+    EXPECT_EQ(123, getData().data_);
+    clearData();
+    setData(789);
+    EXPECT_EQ(789, getData().data_);
+  }
+  EXPECT_EQ(123, getData().data_);
+  // TODO: Should be 2/1
+  EXPECT_EQ(3, getData().set_);
+  EXPECT_EQ(2, getData().unset_);
+}
-- 
2.26.2