Commit 0ec1d370 authored by Andrii Grynenko's avatar Andrii Grynenko Committed by Facebook Github Bot

Improve tail call optimization

Summary:
Existing tail call optimization implementation only works if the returned future already had callback set by the time the future returned by .then() callback is completed, which makes it's behavior racy.
This diff takes a slightly different approach by adding a new mode to Core, which makes it forward callback to another Core, once the callback is set.

Reviewed By: yfeldblum

Differential Revision: D12837044

fbshipit-source-id: 13ad822d8dfdcef8bb5a28108c24195ae911550b
parent 20f2b04c
...@@ -266,16 +266,8 @@ void FutureBase<T>::raise(exception_wrapper exception) { ...@@ -266,16 +266,8 @@ void FutureBase<T>::raise(exception_wrapper exception) {
template <class T> template <class T>
template <class F> template <class F>
void FutureBase<T>::setCallback_(F&& func) { void FutureBase<T>::setCallback_(F&& func) {
setCallback_(std::forward<F>(func), RequestContext::saveContext());
}
template <class T>
template <class F>
void FutureBase<T>::setCallback_(
F&& func,
std::shared_ptr<folly::RequestContext> context) {
throwIfContinued(); throwIfContinued();
getCore().setCallback(std::forward<F>(func), std::move(context)); getCore().setCallback(std::forward<F>(func), RequestContext::saveContext());
} }
template <class T> template <class T>
...@@ -423,17 +415,8 @@ FutureBase<T>::thenImplementation(F&& func, R) { ...@@ -423,17 +415,8 @@ FutureBase<T>::thenImplementation(F&& func, R) {
auto statePromise = state.stealPromise(); auto statePromise = state.stealPromise();
auto tf3 = auto tf3 =
chainExecutor(statePromise.core_->getExecutor(), *std::move(tf2)); chainExecutor(statePromise.core_->getExecutor(), *std::move(tf2));
if (statePromise.getCore().hasCallback()) { std::exchange(statePromise.core_, nullptr)
tf3.core_->setExecutor(statePromise.core_->getExecutor()); ->setProxy(std::exchange(tf3.core_, nullptr));
auto callbackAndContext = statePromise.getCore().stealCallback();
tf3.setCallback_(
std::move(callbackAndContext.first),
std::move(callbackAndContext.second));
} else {
tf3.setCallback_([p2 = std::move(statePromise)](Try<B>&& b) mutable {
p2.setTry(std::move(b));
});
}
} }
} }
}); });
......
...@@ -44,7 +44,9 @@ enum class State : uint8_t { ...@@ -44,7 +44,9 @@ enum class State : uint8_t {
Start = 1 << 0, Start = 1 << 0,
OnlyResult = 1 << 1, OnlyResult = 1 << 1,
OnlyCallback = 1 << 2, OnlyCallback = 1 << 2,
Done = 1 << 3, Proxy = 1 << 3,
Done = 1 << 4,
Empty = 1 << 5,
}; };
constexpr State operator&(State a, State b) { constexpr State operator&(State a, State b) {
return State(uint8_t(a) & uint8_t(b)); return State(uint8_t(a) & uint8_t(b));
...@@ -118,12 +120,18 @@ static_assert(sizeof(SpinLock) == 1, "missized"); ...@@ -118,12 +120,18 @@ static_assert(sizeof(SpinLock) == 1, "missized");
/// | (setResult()) (setCallback()) | /// | (setResult()) (setCallback()) |
/// | / \ | /// | / \ |
/// | Start ---------> ------> Done | /// | Start ---------> ------> Done |
/// | \ \ / | /// | \ \ / |
/// | \ (setCallback()) (setResult()) | /// | \ (setCallback()) (setResult()) |
/// | \ \ / | /// | \ \ / |
/// | \ ---> OnlyCallback --- | /// | \ ---> OnlyCallback --- |
/// | \ / | /// | \ \ |
/// | <- (stealCallback()) - | /// | (setProxy()) (setProxy()) |
/// | \ \ |
/// | \ ------> Empty |
/// | \ / |
/// | \ (setCallback()) |
/// | \ / |
/// | ---------> Proxy ---------- |
/// +----------------------------------------------------------------+ /// +----------------------------------------------------------------+
/// ///
/// States and the corresponding producer-to-consumer data status & ownership: /// States and the corresponding producer-to-consumer data status & ownership:
...@@ -143,14 +151,15 @@ static_assert(sizeof(SpinLock) == 1, "missized"); ...@@ -143,14 +151,15 @@ static_assert(sizeof(SpinLock) == 1, "missized");
/// point forward only the producer thread can safely access that callback /// point forward only the producer thread can safely access that callback
/// (see `setResult()` and `doCallback()` where the producer thread can both /// (see `setResult()` and `doCallback()` where the producer thread can both
/// read and modify the callback). /// read and modify the callback).
/// Alternatively producer thread can also steal the callback (see /// - Proxy: producer thread has set a proxy core which the callback should be
/// `stealCallback()`). /// proxied to.
/// - Done: callback can be safely accessed only within `doCallback()`, which /// - Done: callback can be safely accessed only within `doCallback()`, which
/// gets called on exactly one thread exactly once just after the transition /// gets called on exactly one thread exactly once just after the transition
/// to Done. The future object will have determined whether that callback /// to Done. The future object will have determined whether that callback
/// has/will move-out the result, but either way the result remains logically /// has/will move-out the result, but either way the result remains logically
/// owned exclusively by the consumer thread (the code of Future/SemiFuture, /// owned exclusively by the consumer thread (the code of Future/SemiFuture,
/// of the continuation, and/or of callers of `future.result()`, etc.). /// of the continuation, and/or of callers of `future.result()`, etc.).
/// - Empty: the core successfully proxied the callback and is now empty.
/// ///
/// Start state: /// Start state:
/// ///
...@@ -165,6 +174,8 @@ static_assert(sizeof(SpinLock) == 1, "missized"); ...@@ -165,6 +174,8 @@ static_assert(sizeof(SpinLock) == 1, "missized");
/// `future.wait()` and/or `future.get()` are used. /// `future.wait()` and/or `future.get()` are used.
/// - Done: a terminal state when `future.then()` is used, and sometimes also /// - Done: a terminal state when `future.then()` is used, and sometimes also
/// when `future.wait()` and/or `future.get()` are used. /// when `future.wait()` and/or `future.get()` are used.
/// - Proxy: a terminal state if proxy core was set, but callback was never set.
/// - Empty: a terminal state when proxying a callback was successful.
/// ///
/// Notes and caveats: /// Notes and caveats:
/// ///
...@@ -230,7 +241,12 @@ class Core final { ...@@ -230,7 +241,12 @@ class Core final {
/// Identical to `this->ready()` /// Identical to `this->ready()`
bool hasResult() const noexcept { bool hasResult() const noexcept {
constexpr auto allowed = State::OnlyResult | State::Done; constexpr auto allowed = State::OnlyResult | State::Done;
auto const state = state_.load(std::memory_order_acquire); auto core = this;
auto state = core->state_.load(std::memory_order_acquire);
while (state == State::Proxy) {
core = core->proxy_;
state = core->state_.load(std::memory_order_acquire);
}
return State() != (state & allowed); return State() != (state & allowed);
} }
...@@ -263,11 +279,19 @@ class Core final { ...@@ -263,11 +279,19 @@ class Core final {
/// all callbacks modify (possibly move-out) the result.) /// all callbacks modify (possibly move-out) the result.)
Try<T>& getTry() { Try<T>& getTry() {
DCHECK(hasResult()); DCHECK(hasResult());
return result_; auto core = this;
while (core->state_.load(std::memory_order_relaxed) == State::Proxy) {
core = core->proxy_;
}
return core->result_;
} }
Try<T> const& getTry() const { Try<T> const& getTry() const {
DCHECK(hasResult()); DCHECK(hasResult());
return result_; auto core = this;
while (core->state_.load(std::memory_order_relaxed) == State::Proxy) {
core = core->proxy_;
}
return core->result_;
} }
/// Call only from consumer thread. /// Call only from consumer thread.
...@@ -287,28 +311,26 @@ class Core final { ...@@ -287,28 +311,26 @@ class Core final {
::new (&context_) Context(std::move(context)); ::new (&context_) Context(std::move(context));
auto state = state_.load(std::memory_order_acquire); auto state = state_.load(std::memory_order_acquire);
while (true) {
switch (state) {
case State::Start:
if (state_.compare_exchange_strong(
state, State::OnlyCallback, std::memory_order_release)) {
return;
}
assume(state == State::OnlyResult);
FOLLY_FALLTHROUGH;
case State::OnlyResult:
if (state_.compare_exchange_strong(
state, State::Done, std::memory_order_release)) {
doCallback();
return;
}
FOLLY_FALLTHROUGH;
default: if (state == State::Start) {
terminate_with<std::logic_error>("setCallback unexpected state"); if (state_.compare_exchange_strong(
state, State::OnlyCallback, std::memory_order_release)) {
return;
} }
assume(state == State::OnlyResult || state == State::Proxy);
}
if (state == State::OnlyResult) {
state_.store(State::Done, std::memory_order_relaxed);
doCallback();
return;
} }
if (state == State::Proxy) {
return proxyCallback();
}
terminate_with<std::logic_error>("setCallback unexpected state");
} }
/// Call only from producer thread. /// Call only from producer thread.
...@@ -316,17 +338,31 @@ class Core final { ...@@ -316,17 +338,31 @@ class Core final {
/// ///
/// See FSM graph for allowed transitions. /// See FSM graph for allowed transitions.
/// ///
/// Call only if hasCallback() == true && hasResult() == false
/// This can not be called concurrently with setResult(). /// This can not be called concurrently with setResult().
/// void setProxy(Core* proxy) {
/// Extracts callback from the Core and transitions it back into Start state. DCHECK(!hasResult());
std::pair<Callback, std::shared_ptr<folly::RequestContext>> stealCallback() {
DCHECK(state_.load(std::memory_order_relaxed) == State::OnlyCallback); proxy_ = proxy;
auto ret = std::make_pair(std::move(callback_), std::move(context_));
context_.~Context(); auto state = state_.load(std::memory_order_acquire);
callback_.~Callback(); switch (state) {
state_.store(State::Start, std::memory_order_relaxed); case State::Start:
return ret; if (state_.compare_exchange_strong(
state, State::Proxy, std::memory_order_release)) {
break;
}
assume(state == State::OnlyCallback);
FOLLY_FALLTHROUGH;
case State::OnlyCallback:
proxyCallback();
break;
default:
terminate_with<std::logic_error>("setCallback unexpected state");
}
detachOne();
} }
/// Call only from producer thread. /// Call only from producer thread.
...@@ -476,8 +512,25 @@ class Core final { ...@@ -476,8 +512,25 @@ class Core final {
~Core() { ~Core() {
DCHECK(attached_ == 0); DCHECK(attached_ == 0);
DCHECK(hasResult()); auto state = state_.load(std::memory_order_relaxed);
result_.~Result(); switch (state) {
case State::OnlyResult:
FOLLY_FALLTHROUGH;
case State::Done:
result_.~Result();
break;
case State::Proxy:
proxy_->detachFuture();
break;
case State::Empty:
break;
default:
terminate_with<std::logic_error>("~Core unexpected state");
}
} }
// Helper class that stores a pointer to the `Core` object and calls // Helper class that stores a pointer to the `Core` object and calls
...@@ -577,6 +630,15 @@ class Core final { ...@@ -577,6 +630,15 @@ class Core final {
} }
} }
void proxyCallback() {
state_.store(State::Empty, std::memory_order_relaxed);
proxy_->setExecutor(std::move(executor_), priority_);
proxy_->setCallback(std::move(callback_), std::move(context_));
proxy_->detachFuture();
context_.~Context();
callback_.~Callback();
}
void detachOne() noexcept { void detachOne() noexcept {
auto a = attached_.fetch_sub(1, std::memory_order_acq_rel); auto a = attached_.fetch_sub(1, std::memory_order_acq_rel);
assert(a >= 1); assert(a >= 1);
...@@ -603,6 +665,7 @@ class Core final { ...@@ -603,6 +665,7 @@ class Core final {
// contained entirely in one cache line // contained entirely in one cache line
union { union {
Result result_; Result result_;
Core* proxy_;
}; };
std::atomic<State> state_; std::atomic<State> state_;
std::atomic<unsigned char> attached_; std::atomic<unsigned char> attached_;
......
...@@ -1608,12 +1608,14 @@ Future<bool> call(int depth, Executor* executor) { ...@@ -1608,12 +1608,14 @@ Future<bool> call(int depth, Executor* executor) {
} }
Future<int> recursion(Executor* executor, int depth) { Future<int> recursion(Executor* executor, int depth) {
return call(depth, executor).thenValue([=](auto result) { return makeFuture().thenValue([=](auto) {
if (result) { return call(depth, executor).thenValue([=](auto result) {
return folly::makeFuture(42); if (result) {
} return folly::makeFuture(42);
}
return recursion(executor, depth - 1); return recursion(executor, depth - 1);
});
}); });
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment