Commit a73640d7 authored by Andrew Smith's avatar Andrew Smith Committed by Facebook GitHub Bot

Add optional parameter to consume() and cancel()

Summary:
This diff adds an optional parameter to consume() and cancel(). This parameter will allow callback objects to know the source bridge that is calling them back. With this parameter, we will not need to create a separate callback object for each bridge (which saves memory).

The parameter is optional. If the parameter type is set to void, no parameter will be passed. This preserves backward compatibility with existing users of AtomicQueue that don't need a parameter.

Reviewed By: iahs

Differential Revision: D28550867

fbshipit-source-id: e7d998538c880c2a5c7649d3262cb7f8913e1439
parent 187d8422
...@@ -100,7 +100,8 @@ class AtomicQueue { ...@@ -100,7 +100,8 @@ class AtomicQueue {
AtomicQueue(const AtomicQueue&) = delete; AtomicQueue(const AtomicQueue&) = delete;
AtomicQueue& operator=(const AtomicQueue&) = delete; AtomicQueue& operator=(const AtomicQueue&) = delete;
void push(Message&& value) { template <typename... ConsumerArgs>
void push(Message&& value, ConsumerArgs&&... consumerArgs) {
std::unique_ptr<typename MessageQueue::Node> node( std::unique_ptr<typename MessageQueue::Node> node(
new typename MessageQueue::Node(std::move(value))); new typename MessageQueue::Node(std::move(value)));
assert(!(reinterpret_cast<intptr_t>(node.get()) & kTypeMask)); assert(!(reinterpret_cast<intptr_t>(node.get()) & kTypeMask));
...@@ -135,7 +136,7 @@ class AtomicQueue { ...@@ -135,7 +136,7 @@ class AtomicQueue {
std::memory_order_relaxed)) { std::memory_order_relaxed)) {
node.release(); node.release();
auto consumer = reinterpret_cast<Consumer*>(ptr); auto consumer = reinterpret_cast<Consumer*>(ptr);
consumer->consume(); consumer->consume(std::forward<ConsumerArgs>(consumerArgs)...);
return; return;
} }
break; break;
...@@ -145,7 +146,8 @@ class AtomicQueue { ...@@ -145,7 +146,8 @@ class AtomicQueue {
} }
} }
bool wait(Consumer* consumer) { template <typename... ConsumerArgs>
bool wait(Consumer* consumer, ConsumerArgs&&... consumerArgs) {
assert(!(reinterpret_cast<intptr_t>(consumer) & kTypeMask)); assert(!(reinterpret_cast<intptr_t>(consumer) & kTypeMask));
auto storage = storage_.load(std::memory_order_relaxed); auto storage = storage_.load(std::memory_order_relaxed);
while (true) { while (true) {
...@@ -162,7 +164,7 @@ class AtomicQueue { ...@@ -162,7 +164,7 @@ class AtomicQueue {
} }
break; break;
case Type::CLOSED: case Type::CLOSED:
consumer->canceled(); consumer->canceled(std::forward<ConsumerArgs>(consumerArgs)...);
return true; return true;
case Type::TAIL: case Type::TAIL:
return false; return false;
...@@ -173,7 +175,8 @@ class AtomicQueue { ...@@ -173,7 +175,8 @@ class AtomicQueue {
} }
} }
void close() { template <typename... ConsumerArgs>
void close(ConsumerArgs&&... consumerArgs) {
auto storage = storage_.exchange( auto storage = storage_.exchange(
static_cast<intptr_t>(Type::CLOSED), std::memory_order_acquire); static_cast<intptr_t>(Type::CLOSED), std::memory_order_acquire);
auto type = static_cast<Type>(storage & kTypeMask); auto type = static_cast<Type>(storage & kTypeMask);
...@@ -186,7 +189,8 @@ class AtomicQueue { ...@@ -186,7 +189,8 @@ class AtomicQueue {
reinterpret_cast<typename MessageQueue::Node*>(ptr)); reinterpret_cast<typename MessageQueue::Node*>(ptr));
return; return;
case Type::CONSUMER: case Type::CONSUMER:
reinterpret_cast<Consumer*>(ptr)->canceled(); reinterpret_cast<Consumer*>(ptr)->canceled(
std::forward<ConsumerArgs>(consumerArgs)...);
return; return;
case Type::CLOSED: case Type::CLOSED:
default: default:
...@@ -199,7 +203,8 @@ class AtomicQueue { ...@@ -199,7 +203,8 @@ class AtomicQueue {
return type == Type::CLOSED; return type == Type::CLOSED;
} }
MessageQueue getMessages() { template <typename... ConsumerArgs>
MessageQueue getMessages(ConsumerArgs&&... consumerArgs) {
auto storage = storage_.exchange( auto storage = storage_.exchange(
static_cast<intptr_t>(Type::EMPTY), std::memory_order_acquire); static_cast<intptr_t>(Type::EMPTY), std::memory_order_acquire);
auto type = static_cast<Type>(storage & kTypeMask); auto type = static_cast<Type>(storage & kTypeMask);
...@@ -214,7 +219,7 @@ class AtomicQueue { ...@@ -214,7 +219,7 @@ class AtomicQueue {
// We accidentally re-opened the queue, so close it again. // We accidentally re-opened the queue, so close it again.
// This is only safe to do because isClosed() can't be called // This is only safe to do because isClosed() can't be called
// concurrently with getMessages(). // concurrently with getMessages().
close(); close(std::forward<ConsumerArgs>(consumerArgs)...);
return MessageQueue(); return MessageQueue();
case Type::CONSUMER: case Type::CONSUMER:
default: default:
......
...@@ -24,13 +24,20 @@ namespace folly { ...@@ -24,13 +24,20 @@ namespace folly {
namespace channels { namespace channels {
namespace detail { namespace detail {
static int* getConsumerParam() {
return reinterpret_cast<int*>(1);
}
TEST(AtomicQueueTest, Basic) { TEST(AtomicQueueTest, Basic) {
folly::Baton<> producerBaton; folly::Baton<> producerBaton;
folly::Baton<> consumerBaton; folly::Baton<> consumerBaton;
struct Consumer { struct Consumer {
void consume() { baton.post(); } void consume(int* consumerParam) {
void canceled() { ADD_FAILURE() << "canceled() shouldn't be called"; } EXPECT_EQ(consumerParam, getConsumerParam());
baton.post();
}
void canceled(int*) { ADD_FAILURE() << "canceled() shouldn't be called"; }
folly::Baton<> baton; folly::Baton<> baton;
}; };
AtomicQueue<Consumer, int> atomicQueue; AtomicQueue<Consumer, int> atomicQueue;
...@@ -40,23 +47,23 @@ TEST(AtomicQueueTest, Basic) { ...@@ -40,23 +47,23 @@ TEST(AtomicQueueTest, Basic) {
producerBaton.wait(); producerBaton.wait();
producerBaton.reset(); producerBaton.reset();
atomicQueue.push(1); atomicQueue.push(1, getConsumerParam());
producerBaton.wait(); producerBaton.wait();
producerBaton.reset(); producerBaton.reset();
atomicQueue.push(2); atomicQueue.push(2, getConsumerParam());
atomicQueue.push(3); atomicQueue.push(3, getConsumerParam());
consumerBaton.post(); consumerBaton.post();
}); });
EXPECT_TRUE(atomicQueue.wait(&consumer)); EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
producerBaton.post(); producerBaton.post();
consumer.baton.wait(); consumer.baton.wait();
consumer.baton.reset(); consumer.baton.reset();
{ {
auto q = atomicQueue.getMessages(); auto q = atomicQueue.getMessages(getConsumerParam());
EXPECT_FALSE(q.empty()); EXPECT_FALSE(q.empty());
EXPECT_EQ(1, q.front()); EXPECT_EQ(1, q.front());
q.pop(); q.pop();
...@@ -67,9 +74,9 @@ TEST(AtomicQueueTest, Basic) { ...@@ -67,9 +74,9 @@ TEST(AtomicQueueTest, Basic) {
consumerBaton.wait(); consumerBaton.wait();
consumerBaton.reset(); consumerBaton.reset();
EXPECT_FALSE(atomicQueue.wait(&consumer)); EXPECT_FALSE(atomicQueue.wait(&consumer, getConsumerParam()));
{ {
auto q = atomicQueue.getMessages(); auto q = atomicQueue.getMessages(getConsumerParam());
EXPECT_FALSE(q.empty()); EXPECT_FALSE(q.empty());
EXPECT_EQ(2, q.front()); EXPECT_EQ(2, q.front());
q.pop(); q.pop();
...@@ -79,10 +86,10 @@ TEST(AtomicQueueTest, Basic) { ...@@ -79,10 +86,10 @@ TEST(AtomicQueueTest, Basic) {
EXPECT_TRUE(q.empty()); EXPECT_TRUE(q.empty());
} }
EXPECT_TRUE(atomicQueue.wait(&consumer)); EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
EXPECT_EQ(atomicQueue.cancelCallback(), &consumer); EXPECT_EQ(atomicQueue.cancelCallback(), &consumer);
EXPECT_TRUE(atomicQueue.wait(&consumer)); EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
EXPECT_EQ(atomicQueue.cancelCallback(), &consumer); EXPECT_EQ(atomicQueue.cancelCallback(), &consumer);
EXPECT_EQ(atomicQueue.cancelCallback(), nullptr); EXPECT_EQ(atomicQueue.cancelCallback(), nullptr);
...@@ -92,41 +99,47 @@ TEST(AtomicQueueTest, Basic) { ...@@ -92,41 +99,47 @@ TEST(AtomicQueueTest, Basic) {
TEST(AtomicQueueTest, Canceled) { TEST(AtomicQueueTest, Canceled) {
struct Consumer { struct Consumer {
void consume() { ADD_FAILURE() << "consume() shouldn't be called"; } void consume(int*) { ADD_FAILURE() << "consume() shouldn't be called"; }
void canceled() { canceledCalled = true; } void canceled(int* consumerParam) {
EXPECT_EQ(consumerParam, getConsumerParam());
canceledCalled = true;
}
bool canceledCalled{false}; bool canceledCalled{false};
}; };
AtomicQueue<Consumer, int> atomicQueue; AtomicQueue<Consumer, int> atomicQueue;
Consumer consumer; Consumer consumer;
EXPECT_TRUE(atomicQueue.wait(&consumer)); EXPECT_TRUE(atomicQueue.wait(&consumer, getConsumerParam()));
atomicQueue.close(); atomicQueue.close(getConsumerParam());
EXPECT_TRUE(consumer.canceledCalled); EXPECT_TRUE(consumer.canceledCalled);
EXPECT_TRUE(atomicQueue.isClosed()); EXPECT_TRUE(atomicQueue.isClosed());
EXPECT_TRUE(atomicQueue.getMessages().empty()); EXPECT_TRUE(atomicQueue.getMessages(getConsumerParam()).empty());
EXPECT_TRUE(atomicQueue.isClosed()); EXPECT_TRUE(atomicQueue.isClosed());
atomicQueue.push(42); atomicQueue.push(42, getConsumerParam());
EXPECT_TRUE(atomicQueue.getMessages().empty()); EXPECT_TRUE(atomicQueue.getMessages(getConsumerParam()).empty());
EXPECT_TRUE(atomicQueue.isClosed()); EXPECT_TRUE(atomicQueue.isClosed());
} }
TEST(AtomicQueueTest, Stress) { TEST(AtomicQueueTest, Stress) {
struct Consumer { struct Consumer {
void consume() { baton.post(); } void consume(int* consumerParam) {
void canceled() { ADD_FAILURE() << "canceled() shouldn't be called"; } EXPECT_EQ(consumerParam, getConsumerParam());
baton.post();
}
void canceled(int*) { ADD_FAILURE() << "canceled() shouldn't be called"; }
folly::Baton<> baton; folly::Baton<> baton;
}; };
AtomicQueue<Consumer, int> atomicQueue; AtomicQueue<Consumer, int> atomicQueue;
auto getNext = [&atomicQueue, queue = Queue<int>()]() mutable { auto getNext = [&atomicQueue, queue = Queue<int>()]() mutable {
Consumer consumer; Consumer consumer;
if (queue.empty()) { if (queue.empty()) {
if (atomicQueue.wait(&consumer)) { if (atomicQueue.wait(&consumer, getConsumerParam())) {
consumer.baton.wait(); consumer.baton.wait();
} }
queue = atomicQueue.getMessages(); queue = atomicQueue.getMessages(getConsumerParam());
EXPECT_FALSE(queue.empty()); EXPECT_FALSE(queue.empty());
} }
auto next = queue.front(); auto next = queue.front();
...@@ -142,7 +155,7 @@ TEST(AtomicQueueTest, Stress) { ...@@ -142,7 +155,7 @@ TEST(AtomicQueueTest, Stress) {
std::thread producerThread([&] { std::thread producerThread([&] {
for (producerIndex = 1; producerIndex <= kNumIters; ++producerIndex) { for (producerIndex = 1; producerIndex <= kNumIters; ++producerIndex) {
atomicQueue.push(producerIndex); atomicQueue.push(producerIndex, getConsumerParam());
if (producerIndex % kSynchronizeEvery == 0) { if (producerIndex % kSynchronizeEvery == 0) {
while (producerIndex > consumerIndex.load(std::memory_order_relaxed)) { while (producerIndex > consumerIndex.load(std::memory_order_relaxed)) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment