Commit db6d0a10 authored by Sergey Korytnikov, committed by Facebook GitHub Bot

Fix waiters signaling in BatchSemaphore

Summary: Fix BatchSemaphore to post the baton for each of multiple waiters in the list, for as long as there are tokens available.

Reviewed By: yfeldblum

Differential Revision: D29430799

fbshipit-source-id: 98b0a616d0ce863108dcf331e491fd2cc12429d1
parent b1fa3c6f
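
For readers skimming the diff below: the behavioral change is that one signal() call can now satisfy several queued waiters in FIFO order, instead of at most one. The following is a minimal, mutex-based model of that semantics, an illustrative sketch only (the SimpleBatchSemaphore name is invented here), not folly's implementation, which uses an atomic token counter and per-waiter batons:

#include <condition_variable>
#include <cstdint>
#include <deque>
#include <mutex>

class SimpleBatchSemaphore {
 public:
  explicit SimpleBatchSemaphore(int64_t capacity) : tokens_(capacity) {}

  void wait(int64_t tokens) {
    std::unique_lock<std::mutex> lock(mutex_);
    // Fast path: no queue and enough tokens banked.
    if (waiters_.empty() && tokens_ >= tokens) {
      tokens_ -= tokens;
      return;
    }
    // Slow path: queue up in FIFO order and sleep until signalled.
    Waiter w{tokens};
    waiters_.push_back(&w);
    cv_.wait(lock, [&] { return w.ready; });
  }

  void signal(int64_t tokens) {
    std::lock_guard<std::mutex> lock(mutex_);
    tokens_ += tokens;
    // The behavior this commit fixes: keep waking waiters at the head of
    // the queue for as long as the head request fits, rather than waking
    // at most one waiter per signal.
    while (!waiters_.empty() && waiters_.front()->tokens <= tokens_) {
      Waiter* w = waiters_.front();
      waiters_.pop_front();
      tokens_ -= w->tokens;
      w->ready = true;
    }
    cv_.notify_all();
  }

 private:
  struct Waiter {
    int64_t tokens;
    bool ready = false;
  };

  std::mutex mutex_;
  std::condition_variable cv_;
  int64_t tokens_;
  std::deque<Waiter*> waiters_;
};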
folly/fibers/BatchSemaphore.cpp
@@ -20,17 +20,7 @@ namespace folly {
 namespace fibers {
 
 void BatchSemaphore::signal(int64_t tokens) {
-  auto oldVal = tokens_.load(std::memory_order_acquire);
-  do {
-    if (signalSlow(tokens, oldVal)) {
-      return;
-    }
-    oldVal = tokens_.load(std::memory_order_acquire);
-  } while (!tokens_.compare_exchange_weak(
-      oldVal,
-      oldVal + tokens,
-      std::memory_order_release,
-      std::memory_order_acquire));
+  signalSlow(tokens);
 }
 
 void BatchSemaphore::wait(int64_t tokens) {
folly/fibers/SemaphoreBase.cpp
@@ -19,38 +19,49 @@
 namespace folly {
 namespace fibers {
 
-bool SemaphoreBase::signalSlow(int64_t tokens, int64_t oldVal) {
-  Waiter* waiter = nullptr;
-  {
-    // If we signalled a release, notify the waitlist
-    auto waitListLock = waitList_.wlock();
-    auto& waitList = *waitListLock;
-
-    auto testVal = tokens_.load(std::memory_order_acquire);
-    if (oldVal != testVal) {
-      return false;
-    }
-
-    if (waitList.empty()) {
-      // If the waitlist is now empty and tokens is 0, ensure the token
-      // count increments
-      return tokens_.compare_exchange_strong(
-          testVal, testVal + tokens, std::memory_order_relaxed);
-    }
-
-    waiter = &waitList.front();
-    if (waiter->tokens_ > testVal + tokens) {
-      // Not enough tokens to resume next in waitlist
-      return tokens_.compare_exchange_strong(
-          testVal, testVal + tokens, std::memory_order_relaxed);
-    }
-    waitList.pop_front();
-  }
-  // Trigger waiter if there is one
-  // Do it after releasing the waitList mutex, in case the waiter
-  // eagerly calls signal
-  waiter->baton.post();
-  return true;
-}
+bool SemaphoreBase::signalSlow(int64_t tokens) {
+  do {
+    auto testVal = tokens_.load(std::memory_order_acquire);
+    Waiter* waiter = nullptr;
+    {
+      // If we signalled a release, notify the waitlist
+      auto waitListLock = waitList_.wlock();
+      auto& waitList = *waitListLock;
+
+      if (waitList.empty() || waitList.front().tokens_ > testVal + tokens) {
+        // If the waitlist is now empty, or there are not enough tokens to
+        // resume the next waiter in the list, bank the signalled tokens in
+        // the counter; retry if the CAS raced with a fast-path wait()
+        if (tokens_.compare_exchange_strong(
+                testVal, testVal + tokens, std::memory_order_relaxed)) {
+          return true;
+        }
+        continue;
+      }
+
+      waiter = &waitList.front();
+
+      // Compute how much of the waiter's request the banked counter cannot
+      // cover; that shortage is charged against the signalled tokens
+      int64_t release =
+          (testVal > waiter->tokens_) ? 0 : waiter->tokens_ - testVal;
+      if (!tokens_.compare_exchange_strong(
+              testVal,
+              testVal + release - waiter->tokens_,
+              std::memory_order_relaxed)) {
+        continue;
+      }
+      tokens -= release;
+      waitList.pop_front();
+    }
+
+    // Trigger waiter if there is one
+    // Do it after releasing the waitList mutex, in case the waiter
+    // eagerly calls signal
+    waiter->baton.post();
+  } while (tokens > 0);
+  return true;
+}
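
The subtle part of the new signalSlow() is the shortage accounting: when the banked counter (testVal) cannot cover the head waiter's request, the difference (release) is taken out of the tokens being signalled, the counter is CAS'd to testVal + release - waiter->tokens_, and the loop keeps running while signalled tokens remain. Here is a standalone restatement of that arithmetic (SignalStep and step are names invented for illustration), traced against one step of the batchSemaphoreSignalAll scenario further down:

#include <cassert>
#include <cstdint>

struct SignalStep {
  int64_t newCounter;       // value stored into the token counter
  int64_t remainingSignal;  // signalled tokens left for later waiters
};

SignalStep step(int64_t testVal, int64_t waiterTokens, int64_t signalTokens) {
  // Shortage covered out of the signalled tokens; zero if the banked
  // counter already suffices.
  int64_t release = (testVal > waiterTokens) ? 0 : waiterTokens - testVal;
  return {testVal + release - waiterTokens, signalTokens - release};
}

int main() {
  // From the batchSemaphoreSignalAll comment below: the bank holds 2
  // tokens, the head waiter wants 4, and signal(3) is in flight.
  SignalStep s = step(/*testVal=*/2, /*waiterTokens=*/4, /*signalTokens=*/3);
  assert(s.newCounter == 0);       // the bank is drained for the waiter
  assert(s.remainingSignal == 1);  // 1 token left; waiter [5] stays blocked
  return 0;
}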
@@ -77,8 +88,9 @@ void SemaphoreBase::wait_common(int64_t tokens) {
   do {
     while (oldVal < tokens) {
       Waiter waiter{tokens};
-      // If waitSlow fails it is because the capacity is greater than requested
-      // by the time the lock is taken, so we can just continue round the loop
+      // If waitSlow fails it is because the capacity is greater than
+      // requested by the time the lock is taken, so we can just continue
+      // round the loop
       if (waitSlow(waiter, tokens)) {
         waiter.baton.wait();
         return;
@@ -116,8 +128,9 @@ coro::Task<void> SemaphoreBase::co_wait_common(int64_t tokens) {
   do {
     while (oldVal < tokens) {
       Waiter waiter{tokens};
-      // If waitSlow fails it is because the capacity is greater than requested
-      // by the time the lock is taken, so we can just continue round the loop
+      // If waitSlow fails it is because the capacity is greater than
+      // requested by the time the lock is taken, so we can just continue
+      // round the loop
       if (waitSlow(waiter, tokens)) {
         bool cancelled = false;
         {
@@ -143,9 +156,9 @@ coro::Task<void> SemaphoreBase::co_wait_common(int64_t tokens) {
         co_await waiter.baton;
       }
 
-      // Check 'cancelled' flag only after deregistering the callback so we're
-      // sure that we aren't reading it concurrently with a potential write
-      // from a thread requesting cancellation.
+      // Check 'cancelled' flag only after deregistering the callback so
+      // we're sure that we aren't reading it concurrently with a potential
+      // write from a thread requesting cancellation.
       if (cancelled) {
         co_yield folly::coro::co_cancelled;
       }
@@ -187,8 +200,9 @@ SemiFuture<Unit> SemaphoreBase::future_wait_common(int64_t tokens) {
   do {
     while (oldVal < tokens) {
       auto batonWaiterPtr = std::make_unique<FutureWaiter>(tokens);
-      // If waitSlow fails it is because the capacity is greater than requested
-      // by the time the lock is taken, so we can just continue round the loop
+      // If waitSlow fails it is because the capacity is greater than
+      // requested by the time the lock is taken, so we can just continue
+      // round the loop
       auto future = batonWaiterPtr->promise.getSemiFuture();
       if (waitSlow(batonWaiterPtr->semaphoreWaiter, tokens)) {
         (void)batonWaiterPtr.release();
folly/fibers/SemaphoreBase.h
@@ -99,7 +99,7 @@ class SemaphoreBase {
   SemiFuture<Unit> future_wait_common(int64_t tokens);
 
   bool waitSlow(Waiter& waiter, int64_t tokens);
-  bool signalSlow(int64_t tokens, int64_t oldVal);
+  bool signalSlow(int64_t tokens);
 
   size_t capacity_;
   // Atomic counter
folly/fibers/test/FibersTest.cpp
@@ -23,6 +23,7 @@
 #include <folly/Memory.h>
 #include <folly/Random.h>
 #include <folly/executors/CPUThreadPoolExecutor.h>
+#include <folly/experimental/coro/BlockingWait.h>
 #include <folly/fibers/AddTasks.h>
 #include <folly/fibers/AtomicBatchDispatcher.h>
 #include <folly/fibers/BatchDispatcher.h>
@@ -1831,8 +1832,8 @@ TEST(FiberManager, batchSemaphore) {
   for (size_t i = 0; i < kTasks; ++i) {
     manager.addTask([&, completionCounter]() {
       for (size_t j = 0; j < kIterations; ++j) {
-        int tokens = j % 3 + 1;
-        switch (j % 3) {
+        int tokens = j % 4 + 1;
+        switch (j % 4) {
           case 0:
             sem.wait(tokens);
             break;
@@ -1847,6 +1848,9 @@ TEST(FiberManager, batchSemaphore) {
             }
             break;
           }
+          case 3:
+            folly::coro::blockingWait(sem.co_wait(tokens));
+            break;
         }
         counter += tokens;
         sem.signal(tokens);
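
The new case 3 exercises the coroutine acquisition path from inside a fiber by driving co_wait() with folly::coro::blockingWait(). Isolated from the test loop, and assuming only the calls visible in the diff (acquireViaCoro is a hypothetical helper), the pattern is:

#include <folly/experimental/coro/BlockingWait.h>
#include <folly/fibers/BatchSemaphore.h>

void acquireViaCoro(folly::fibers::BatchSemaphore& sem, int64_t tokens) {
  // co_wait() returns a lazy coroutine task; blockingWait() drives it to
  // completion, suspending until the semaphore grants the tokens.
  folly::coro::blockingWait(sem.co_wait(tokens));
  // ... work under the acquired tokens ...
  sem.signal(tokens);
}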
@@ -1881,6 +1885,54 @@ TEST(FiberManager, batchSemaphore) {
   }
 }
 
+/**
+ * Verify that BatchSemaphore signals all waiters or fails by timeout.
+ * The overall idea is to linearize waiters in the semaphore's list,
+ * requesting an increasing number of tokens. For example, [1, 2, 3, 4, 5]
+ * for a total semaphore capacity of 5 tokens. When releasing all 5 tokens
+ * at once, the expected behavior is:
+ * - Return 5 tokens: notify [1, 2] and block [3, 4, 5] - 2 tokens left
+ * - Return 1 token: notify [3] and block [4, 5] - 0 tokens left
+ * - Return 2 tokens: notify none and block [4, 5] - 2 tokens left
+ * - Return 3 tokens: notify [4] and block [5] - 1 token left
+ * - Return 4 tokens: notify [5] - 0 tokens left
+ * - Return 5 tokens: done - 5 tokens left
+ */
+TEST(FiberManager, batchSemaphoreSignalAll) {
+  static constexpr size_t kNumWaiters = 5;
+
+  BatchSemaphore sem(kNumWaiters);
+  sem.wait(kNumWaiters);
+
+  folly::EventBase evb;
+  auto& fm = getFiberManager(evb);
+
+  for (size_t task = 0; task < kNumWaiters; task++) {
+    fm.addTask([&, tokens = int64_t(task) + 1] {
+      // Wait for the semaphore and fail if not notified
+      BatchSemaphore::Waiter waiter{tokens};
+      bool acquired = sem.try_wait(waiter, tokens);
+      if (!acquired &&
+          !waiter.baton.try_wait_for(std::chrono::milliseconds(1000))) {
+        FAIL() << "BatchSemaphore::Waiter has never been notified";
+      }
+
+      // Rotate to the next task
+      Baton b;
+      b.try_wait_for(std::chrono::milliseconds(1));
+
+      sem.signal(tokens);
+    });
+  }
+
+  fm.addTask([&] {
+    // Release all tokens and notify waiters
+    sem.signal(kNumWaiters);
+  });
+
+  evb.loop();
+  EXPECT_FALSE(fm.hasTasks());
+}
+
 template <typename ExecutorT>
 void singleBatchDispatch(ExecutorT& executor, int batchSize, int index) {
   thread_local BatchDispatcher<int, std::string, ExecutorT> batchDispatcher(
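
The new test acquires through the semaphore's lower-level Waiter interface rather than plain wait(), so that it can bound each wait with a timeout. Using only the calls that appear in the test, try_wait(Waiter&, int64_t) and the waiter's baton, the acquisition pattern looks like this in isolation (acquireWithWaiter is a hypothetical wrapper):

#include <folly/fibers/BatchSemaphore.h>

void acquireWithWaiter(folly::fibers::BatchSemaphore& sem, int64_t tokens) {
  folly::fibers::BatchSemaphore::Waiter waiter{tokens};
  // try_wait() either takes the tokens immediately or queues the waiter.
  if (!sem.try_wait(waiter, tokens)) {
    // The fixed signalSlow() posts this baton once enough tokens have been
    // returned for this waiter's turn in the FIFO list.
    waiter.baton.wait();
  }
  // ... work under the acquired tokens ...
  sem.signal(tokens);
}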