Commit 187d8422 authored by Andrew Smith, committed by Facebook GitHub Bot

Move AtomicQueue to folly

Summary: This diff moves thrift's AtomicQueue to folly/experimental/channels/detail. This file will be shared between thrift and a new channels framework that will be added to folly/experimental/channels.
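For context, AtomicQueue packs its entire state into a single atomic word (a tagged pointer), and a Consumer is any type providing consume() and canceled() callbacks. A minimal usage sketch, mirroring the unit tests below (the Waiter type here is illustrative only, not part of this diff):

#include <folly/experimental/channels/detail/AtomicQueue.h>
#include <folly/synchronization/Baton.h>

// Illustrative consumer: AtomicQueue only requires consume() and canceled().
struct Waiter {
  void consume() { baton.post(); }   // a message became available
  void canceled() { baton.post(); }  // the queue was closed
  folly::Baton<> baton;
};

void example() {
  folly::channels::detail::AtomicQueue<Waiter, int> queue;
  Waiter waiter;
  queue.push(42);             // producer side
  if (queue.wait(&waiter)) {  // true => callback armed, block until notified
    waiter.baton.wait();
  }
  auto messages = queue.getMessages();  // drains all messages, FIFO order
  while (!messages.empty()) {
    messages.pop();
  }
}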

Reviewed By: iahs

Differential Revision: D28549810

fbshipit-source-id: de9d66c0f9fd73e89917df997526539b6b92f172
parent 17a3ed12
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <atomic>
#include <cassert>
#include <memory>
#include <utility>

#include <glog/logging.h>

#include <folly/lang/Assume.h>
namespace folly {
namespace channels {
namespace detail {
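// Queue is a simple, non-thread-safe FIFO of messages backed by a singly
// linked list. AtomicQueue accumulates pushed messages as a reversed (LIFO)
// list and uses Queue::fromReversed to hand them back in FIFO order.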
template <typename T>
class Queue {
 public:
  Queue() {}
  Queue(Queue&& other) : head_(std::exchange(other.head_, nullptr)) {}
  Queue& operator=(Queue&& other) {
    clear();
    std::swap(head_, other.head_);
    return *this;
  }
  ~Queue() { clear(); }

  bool empty() const { return !head_; }

  T& front() { return head_->value; }

  // The temporary unique_ptr takes ownership of the old head and deletes it.
  void pop() { std::unique_ptr<Node>(std::exchange(head_, head_->next)); }

  void clear() {
    while (!empty()) {
      pop();
    }
  }

  explicit operator bool() const { return !empty(); }

  struct Node {
    explicit Node(T&& t) : value(std::move(t)) {}

    T value;
    Node* next{nullptr};
  };

  explicit Queue(Node* head) : head_(head) {}

  static Queue fromReversed(Node* tail) {
    // Reverse a linked list.
    Node* head{nullptr};
    while (tail) {
      head = std::exchange(tail, std::exchange(tail->next, head));
    }
    return Queue(head);
  }

  Node* head_{nullptr};
};
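// AtomicQueue is a lock-free handoff between one producer and one consumer.
// The low two bits of storage_ encode the state (EMPTY, CONSUMER, TAIL,
// CLOSED); the remaining bits hold a pointer: either the registered Consumer
// or the tail of a reversed linked list of messages. Every transition is a
// CAS or exchange on that single atomic word.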
template <typename Consumer, typename Message>
class AtomicQueue {
 public:
  using MessageQueue = Queue<Message>;

  AtomicQueue() {}
  ~AtomicQueue() {
    auto storage = storage_.load(std::memory_order_acquire);
    auto type = static_cast<Type>(storage & kTypeMask);
    auto ptr = storage & kPointerMask;
    switch (type) {
      case Type::EMPTY:
      case Type::CLOSED:
        return;
      case Type::TAIL:
        // Free any messages that were never consumed.
        MessageQueue::fromReversed(
            reinterpret_cast<typename MessageQueue::Node*>(ptr));
        return;
      case Type::CONSUMER:
      default:
        folly::assume_unreachable();
    }
  }
  AtomicQueue(const AtomicQueue&) = delete;
  AtomicQueue& operator=(const AtomicQueue&) = delete;
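  // Pushes a message. In the EMPTY/TAIL states the new node is prepended to
  // the reversed message list. If a consumer is registered, the push claims
  // it, installs the node as the new tail, and invokes consume(). Pushing to
  // a closed queue frees the message and returns.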
  void push(Message&& value) {
    std::unique_ptr<typename MessageQueue::Node> node(
        new typename MessageQueue::Node(std::move(value)));
    assert(!(reinterpret_cast<intptr_t>(node.get()) & kTypeMask));
    auto storage = storage_.load(std::memory_order_relaxed);
    while (true) {
      auto type = static_cast<Type>(storage & kTypeMask);
      auto ptr = storage & kPointerMask;
      switch (type) {
        case Type::EMPTY:
        case Type::TAIL:
          node->next = reinterpret_cast<typename MessageQueue::Node*>(ptr);
          if (storage_.compare_exchange_weak(
                  storage,
                  reinterpret_cast<intptr_t>(node.get()) |
                      static_cast<intptr_t>(Type::TAIL),
                  std::memory_order_release,
                  std::memory_order_relaxed)) {
            node.release();
            return;
          }
          break;
        case Type::CLOSED:
          return;
        case Type::CONSUMER:
          node->next = nullptr;
          if (storage_.compare_exchange_weak(
                  storage,
                  reinterpret_cast<intptr_t>(node.get()) |
                      static_cast<intptr_t>(Type::TAIL),
                  std::memory_order_acq_rel,
                  std::memory_order_relaxed)) {
            node.release();
            auto consumer = reinterpret_cast<Consumer*>(ptr);
            consumer->consume();
            return;
          }
          break;
        default:
          folly::assume_unreachable();
      }
    }
  }
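  // Registers a consumer to be notified when a message arrives. Returns true
  // if the callback was installed (or if the queue is closed, in which case
  // canceled() is invoked immediately). Returns false if messages are already
  // available; the caller should then call getMessages() instead of blocking.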
  bool wait(Consumer* consumer) {
    assert(!(reinterpret_cast<intptr_t>(consumer) & kTypeMask));
    auto storage = storage_.load(std::memory_order_relaxed);
    while (true) {
      auto type = static_cast<Type>(storage & kTypeMask);
      switch (type) {
        case Type::EMPTY:
          if (storage_.compare_exchange_weak(
                  storage,
                  reinterpret_cast<intptr_t>(consumer) |
                      static_cast<intptr_t>(Type::CONSUMER),
                  std::memory_order_release,
                  std::memory_order_relaxed)) {
            return true;
          }
          break;
        case Type::CLOSED:
          consumer->canceled();
          return true;
        case Type::TAIL:
          return false;
        case Type::CONSUMER:
        default:
          folly::assume_unreachable();
      }
    }
  }
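  // Closes the queue. Pending messages are destroyed, and a registered
  // consumer (if any) has its canceled() callback invoked. close() must not
  // be called twice.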
  void close() {
    auto storage = storage_.exchange(
        static_cast<intptr_t>(Type::CLOSED), std::memory_order_acquire);
    auto type = static_cast<Type>(storage & kTypeMask);
    auto ptr = storage & kPointerMask;
    switch (type) {
      case Type::EMPTY:
        return;
      case Type::TAIL:
        MessageQueue::fromReversed(
            reinterpret_cast<typename MessageQueue::Node*>(ptr));
        return;
      case Type::CONSUMER:
        reinterpret_cast<Consumer*>(ptr)->canceled();
        return;
      case Type::CLOSED:
      default:
        folly::assume_unreachable();
    }
  }
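  // Returns true if close() has been called. Must not be called concurrently
  // with getMessages(), which may transiently re-open the queue.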
  bool isClosed() {
    auto type = static_cast<Type>(storage_ & kTypeMask);
    return type == Type::CLOSED;
  }
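  // Atomically takes all accumulated messages, returning them in FIFO order
  // and resetting the state to EMPTY. If the queue was closed, it is closed
  // again before returning an empty queue.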
  MessageQueue getMessages() {
    auto storage = storage_.exchange(
        static_cast<intptr_t>(Type::EMPTY), std::memory_order_acquire);
    auto type = static_cast<Type>(storage & kTypeMask);
    auto ptr = storage & kPointerMask;
    switch (type) {
      case Type::TAIL:
        return MessageQueue::fromReversed(
            reinterpret_cast<typename MessageQueue::Node*>(ptr));
      case Type::EMPTY:
        return MessageQueue();
      case Type::CLOSED:
        // We accidentally re-opened the queue, so close it again.
        // This is only safe to do because isClosed() can't be called
        // concurrently with getMessages().
        close();
        return MessageQueue();
      case Type::CONSUMER:
      default:
        folly::assume_unreachable();
    }
  }
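  // Atomically removes a previously registered consumer. Returns the consumer
  // if one was installed, guaranteeing that neither consume() nor canceled()
  // will be called for it; returns nullptr otherwise.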
  Consumer* cancelCallback() {
    auto storage = storage_.load(std::memory_order_acquire);
    while (true) {
      auto type = static_cast<Type>(storage & kTypeMask);
      auto ptr = storage & kPointerMask;
      switch (type) {
        case Type::CONSUMER:
          if (storage_.compare_exchange_weak(
                  storage,
                  static_cast<intptr_t>(Type::EMPTY),
                  std::memory_order_relaxed,
                  std::memory_order_relaxed)) {
            return reinterpret_cast<Consumer*>(ptr);
          }
          break;
        case Type::TAIL:
        case Type::EMPTY:
        case Type::CLOSED:
        default:
          return nullptr;
      }
    }
  }
 private:
  enum class Type : intptr_t { EMPTY = 0, CONSUMER = 1, TAIL = 2, CLOSED = 3 };

  static constexpr intptr_t kTypeMask = 3;
  static constexpr intptr_t kPointerMask = ~kTypeMask;

  std::atomic<intptr_t> storage_{0};
};
} // namespace detail
} // namespace channels
} // namespace folly
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/experimental/channels/detail/AtomicQueue.h>

#include <thread>

#include <folly/portability/GTest.h>
#include <folly/synchronization/Baton.h>
namespace folly {
namespace channels {
namespace detail {
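// Exercises the wait()/consume()/getMessages() handoff between a producer
// thread and a consumer, plus cancelCallback() semantics.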
TEST(AtomicQueueTest, Basic) {
  folly::Baton<> producerBaton;
  folly::Baton<> consumerBaton;

  struct Consumer {
    void consume() { baton.post(); }
    void canceled() { ADD_FAILURE() << "canceled() shouldn't be called"; }
    folly::Baton<> baton;
  };

  AtomicQueue<Consumer, int> atomicQueue;
  Consumer consumer;

  std::thread producerThread([&] {
    producerBaton.wait();
    producerBaton.reset();
    atomicQueue.push(1);

    producerBaton.wait();
    producerBaton.reset();
    atomicQueue.push(2);
    atomicQueue.push(3);
    consumerBaton.post();
  });

  EXPECT_TRUE(atomicQueue.wait(&consumer));
  producerBaton.post();
  consumer.baton.wait();
  consumer.baton.reset();
  {
    auto q = atomicQueue.getMessages();
    EXPECT_FALSE(q.empty());
    EXPECT_EQ(1, q.front());
    q.pop();
    EXPECT_TRUE(q.empty());
  }

  producerBaton.post();
  consumerBaton.wait();
  consumerBaton.reset();
  EXPECT_FALSE(atomicQueue.wait(&consumer));
  {
    auto q = atomicQueue.getMessages();
    EXPECT_FALSE(q.empty());
    EXPECT_EQ(2, q.front());
    q.pop();
    EXPECT_FALSE(q.empty());
    EXPECT_EQ(3, q.front());
    q.pop();
    EXPECT_TRUE(q.empty());
  }

  EXPECT_TRUE(atomicQueue.wait(&consumer));
  EXPECT_EQ(atomicQueue.cancelCallback(), &consumer);
  EXPECT_TRUE(atomicQueue.wait(&consumer));
  EXPECT_EQ(atomicQueue.cancelCallback(), &consumer);
  EXPECT_EQ(atomicQueue.cancelCallback(), nullptr);

  producerThread.join();
}
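// Verifies that close() cancels a registered consumer and that a closed
// queue drops subsequent pushes.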
TEST(AtomicQueueTest, Canceled) {
  struct Consumer {
    void consume() { ADD_FAILURE() << "consume() shouldn't be called"; }
    void canceled() { canceledCalled = true; }
    bool canceledCalled{false};
  };

  AtomicQueue<Consumer, int> atomicQueue;
  Consumer consumer;

  EXPECT_TRUE(atomicQueue.wait(&consumer));
  atomicQueue.close();
  EXPECT_TRUE(consumer.canceledCalled);
  EXPECT_TRUE(atomicQueue.isClosed());
  EXPECT_TRUE(atomicQueue.getMessages().empty());
  EXPECT_TRUE(atomicQueue.isClosed());
  atomicQueue.push(42);
  EXPECT_TRUE(atomicQueue.getMessages().empty());
  EXPECT_TRUE(atomicQueue.isClosed());
}
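// Producer and consumer threads hammer the queue, synchronizing every
// kSynchronizeEvery iterations so that both the empty-queue (wait) path and
// the non-empty (TAIL) path are exercised.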
TEST(AtomicQueueTest, Stress) {
  struct Consumer {
    void consume() { baton.post(); }
    void canceled() { ADD_FAILURE() << "canceled() shouldn't be called"; }
    folly::Baton<> baton;
  };

  AtomicQueue<Consumer, int> atomicQueue;
  auto getNext = [&atomicQueue, queue = Queue<int>()]() mutable {
    Consumer consumer;
    if (queue.empty()) {
      if (atomicQueue.wait(&consumer)) {
        consumer.baton.wait();
      }
      queue = atomicQueue.getMessages();
      EXPECT_FALSE(queue.empty());
    }
    auto next = queue.front();
    queue.pop();
    return next;
  };

  constexpr ssize_t kNumIters = 100000;
  constexpr ssize_t kSynchronizeEvery = 1000;
  std::atomic<ssize_t> producerIndex{0};
  std::atomic<ssize_t> consumerIndex{0};

  std::thread producerThread([&] {
    for (producerIndex = 1; producerIndex <= kNumIters; ++producerIndex) {
      atomicQueue.push(producerIndex);
      if (producerIndex % kSynchronizeEvery == 0) {
        while (producerIndex > consumerIndex.load(std::memory_order_relaxed)) {
          std::this_thread::yield();
        }
      }
    }
  });

  for (consumerIndex = 1; consumerIndex <= kNumIters; ++consumerIndex) {
    EXPECT_EQ(consumerIndex, getNext());
    if (consumerIndex % kSynchronizeEvery == 0) {
      while (consumerIndex > producerIndex.load(std::memory_order_relaxed)) {
        std::this_thread::yield();
      }
    }
  }

  producerThread.join();
}
} // namespace detail
} // namespace channels
} // namespace folly