Commit af581bd8 authored by Andrew Smith's avatar Andrew Smith Committed by Facebook GitHub Bot

FanoutChannel: Add support for custom context

Summary:
One common pattern with FanoutChannel is to send to new subscribers an initial update indicating the current state of the stream. This is currently the pattern for all uses of FanoutChannel.

This pattern is currently accomplished by adding a transform to the input receiver before passing it to fanout channel. The transform updates the current state of the stream in some shared state object. That shared state object is then captured and used in the getInitialValues function passed to subscribe, in order to let new subscribers know the current state of the stream.

However, this approach can lead to a race condition. In this approach, the transform function is executed (on the transform executor) before the getInitialValues function is executed (on the FanoutChannel executor). If someone adds a new subscriber in between, getInitialValues will use the updated shared state, even though the corresponding update has not yet been sent from the output of the transform to the input of the fanout channel.

To solve this race condition, we need to ensure that new subscribers are not added between the shared state change and the fanning out of the update that led to the shared state change. To do this, this diff adds explicit support for a custom context object that can store state. The context object's update function is called on every update, allowing the shared state to be updated on each new value. The context object is also accessible to the getInitialValues function, which allows sending an update with the current state (based on the context) to new subscribers.

This enables the desired pattern without a race condition, and avoids the need for a transform.

Reviewed By: aary

Differential Revision: D30889893

fbshipit-source-id: 9a79fd5a823db1ae477b6b63170978925b791dda
parent c7e095e7
...@@ -25,16 +25,18 @@ ...@@ -25,16 +25,18 @@
namespace folly { namespace folly {
namespace channels { namespace channels {
template <typename ValueType> template <typename ValueType, typename ContextType>
FanoutChannel<ValueType>::FanoutChannel(TProcessor* processor) FanoutChannel<ValueType, ContextType>::FanoutChannel(TProcessor* processor)
: processor_(processor) {} : processor_(processor) {}
template <typename ValueType> template <typename ValueType, typename ContextType>
FanoutChannel<ValueType>::FanoutChannel(FanoutChannel&& other) noexcept FanoutChannel<ValueType, ContextType>::FanoutChannel(
FanoutChannel&& other) noexcept
: processor_(std::exchange(other.processor_, nullptr)) {} : processor_(std::exchange(other.processor_, nullptr)) {}
template <typename ValueType> template <typename ValueType, typename ContextType>
FanoutChannel<ValueType>& FanoutChannel<ValueType>::operator=( FanoutChannel<ValueType, ContextType>&
FanoutChannel<ValueType, ContextType>::operator=(
FanoutChannel&& other) noexcept { FanoutChannel&& other) noexcept {
if (&other == this) { if (&other == this) {
return *this; return *this;
...@@ -46,31 +48,33 @@ FanoutChannel<ValueType>& FanoutChannel<ValueType>::operator=( ...@@ -46,31 +48,33 @@ FanoutChannel<ValueType>& FanoutChannel<ValueType>::operator=(
return *this; return *this;
} }
template <typename ValueType> template <typename ValueType, typename ContextType>
FanoutChannel<ValueType>::~FanoutChannel() { FanoutChannel<ValueType, ContextType>::~FanoutChannel() {
if (processor_ != nullptr) { if (processor_ != nullptr) {
std::move(*this).close(folly::exception_wrapper()); std::move(*this).close(folly::exception_wrapper());
} }
} }
template <typename ValueType> template <typename ValueType, typename ContextType>
FanoutChannel<ValueType>::operator bool() const { FanoutChannel<ValueType, ContextType>::operator bool() const {
return processor_ != nullptr; return processor_ != nullptr;
} }
template <typename ValueType> template <typename ValueType, typename ContextType>
Receiver<ValueType> FanoutChannel<ValueType>::subscribe( Receiver<ValueType> FanoutChannel<ValueType, ContextType>::subscribe(
folly::Function<std::vector<ValueType>()> getInitialValues) { folly::Function<std::vector<ValueType>(const ContextType&)>
getInitialValues) {
return processor_->subscribe(std::move(getInitialValues)); return processor_->subscribe(std::move(getInitialValues));
} }
template <typename ValueType> template <typename ValueType, typename ContextType>
bool FanoutChannel<ValueType>::anySubscribers() { bool FanoutChannel<ValueType, ContextType>::anySubscribers() {
return processor_->anySubscribers(); return processor_->anySubscribers();
} }
template <typename ValueType> template <typename ValueType, typename ContextType>
void FanoutChannel<ValueType>::close(folly::exception_wrapper ex) && { void FanoutChannel<ValueType, ContextType>::close(
folly::exception_wrapper ex) && {
processor_->destroyHandle( processor_->destroyHandle(
ex ? detail::CloseResult(std::move(ex)) : detail::CloseResult()); ex ? detail::CloseResult(std::move(ex)) : detail::CloseResult());
processor_ = nullptr; processor_ = nullptr;
...@@ -78,11 +82,12 @@ void FanoutChannel<ValueType>::close(folly::exception_wrapper ex) && { ...@@ -78,11 +82,12 @@ void FanoutChannel<ValueType>::close(folly::exception_wrapper ex) && {
namespace detail { namespace detail {
template <typename ValueType> template <typename ValueType, typename ContextType>
class IFanoutChannelProcessor : public IChannelCallback { class IFanoutChannelProcessor : public IChannelCallback {
public: public:
virtual Receiver<ValueType> subscribe( virtual Receiver<ValueType> subscribe(
folly::Function<std::vector<ValueType>()> getInitialValues) = 0; folly::Function<std::vector<ValueType>(const ContextType&)>
getInitialValues) = 0;
virtual bool anySubscribers() = 0; virtual bool anySubscribers() = 0;
...@@ -108,16 +113,20 @@ class IFanoutChannelProcessor : public IChannelCallback { ...@@ -108,16 +113,20 @@ class IFanoutChannelProcessor : public IChannelCallback {
* then be deleted once the input receiver transitions to the * then be deleted once the input receiver transitions to the
* CancellationProcessed state. * CancellationProcessed state.
*/ */
template <typename ValueType> template <typename ValueType, typename ContextType>
class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> { class FanoutChannelProcessor
: public IFanoutChannelProcessor<ValueType, ContextType> {
private: private:
struct State { struct State {
State(ContextType _context) : context(std::move(_context)) {}
ChannelState getReceiverState() { ChannelState getReceiverState() {
return detail::getReceiverState(receiver.get()); return detail::getReceiverState(receiver.get());
} }
ChannelBridgePtr<ValueType> receiver; ChannelBridgePtr<ValueType> receiver;
FanoutSender<ValueType> fanoutSender; FanoutSender<ValueType> fanoutSender;
ContextType context;
bool handleDeleted{false}; bool handleDeleted{false};
}; };
...@@ -125,8 +134,9 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> { ...@@ -125,8 +134,9 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> {
public: public:
explicit FanoutChannelProcessor( explicit FanoutChannelProcessor(
folly::Executor::KeepAlive<folly::SequencedExecutor> executor) folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
: executor_(std::move(executor)) {} ContextType context)
: executor_(std::move(executor)), state_(std::move(context)) {}
/** /**
* Starts fanning out values from the input receiver to all output receivers. * Starts fanning out values from the input receiver to all output receivers.
...@@ -150,10 +160,11 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> { ...@@ -150,10 +160,11 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> {
* receiver. * receiver.
*/ */
Receiver<ValueType> subscribe( Receiver<ValueType> subscribe(
folly::Function<std::vector<ValueType>()> getInitialValues) override { folly::Function<std::vector<ValueType>(const ContextType&)>
getInitialValues) override {
auto state = state_.wlock(); auto state = state_.wlock();
auto initialValues = auto initialValues = getInitialValues ? getInitialValues(state->context)
getInitialValues ? getInitialValues() : std::vector<ValueType>(); : std::vector<ValueType>();
if (!state->receiver) { if (!state->receiver) {
auto [receiver, sender] = Channel<ValueType>::create(); auto [receiver, sender] = Channel<ValueType>::create();
for (auto&& value : initialValues) { for (auto&& value : initialValues) {
...@@ -251,6 +262,7 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> { ...@@ -251,6 +262,7 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> {
if (inputResult.hasValue()) { if (inputResult.hasValue()) {
// We have received a normal value from the input receiver. Write it to // We have received a normal value from the input receiver. Write it to
// all output senders. // all output senders.
state->context.update(inputResult.value());
state->fanoutSender.write(std::move(inputResult.value())); state->fanoutSender.write(std::move(inputResult.value()));
} else { } else {
// The input receiver was closed. // The input receiver was closed.
...@@ -312,14 +324,15 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> { ...@@ -312,14 +324,15 @@ class FanoutChannelProcessor : public IFanoutChannelProcessor<ValueType> {
}; };
} // namespace detail } // namespace detail
template <typename TReceiver, typename ValueType> template <typename TReceiver, typename ValueType, typename ContextType>
FanoutChannel<ValueType> createFanoutChannel( FanoutChannel<ValueType, ContextType> createFanoutChannel(
TReceiver inputReceiver, TReceiver inputReceiver,
folly::Executor::KeepAlive<folly::SequencedExecutor> executor) { folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
auto* processor = ContextType context) {
new detail::FanoutChannelProcessor<ValueType>(std::move(executor)); auto* processor = new detail::FanoutChannelProcessor<ValueType, ContextType>(
std::move(executor), std::move(context));
processor->start(std::move(inputReceiver)); processor->start(std::move(inputReceiver));
return FanoutChannel<ValueType>(processor); return FanoutChannel<ValueType, ContextType>(processor);
} }
} // namespace channels } // namespace channels
} // namespace folly } // namespace folly
...@@ -23,10 +23,15 @@ namespace folly { ...@@ -23,10 +23,15 @@ namespace folly {
namespace channels { namespace channels {
namespace detail { namespace detail {
template <typename ValueType> template <typename ValueType, typename ContextType>
class IFanoutChannelProcessor; class IFanoutChannelProcessor;
} }
template <typename TValue>
struct NoContext {
void update(const TValue&) {}
};
/** /**
* A fanout channel allows fanning out updates from a single input receiver * A fanout channel allows fanning out updates from a single input receiver
* to multiple output receivers. * to multiple output receivers.
...@@ -35,7 +40,17 @@ class IFanoutChannelProcessor; ...@@ -35,7 +40,17 @@ class IFanoutChannelProcessor;
* computes a set of initial values. These initial values will only be sent to * computes a set of initial values. These initial values will only be sent to
* the new receiver. * the new receiver.
* *
* Example: * FanoutChannel allows specifying an optional context object. If specified, the
* context object must have a void update function:
*
* void update(const ValueType&);
*
* This update function will be called on every value from the input receiver.
* The context will be passed to the getInitialUpdates argument to subscribe,
* allowing for initial updates to depend on the context. This facilitates the
* common pattern of letting new subscribers know where they are starting from.
*
* Example without context:
* *
* // Function that returns a receiver: * // Function that returns a receiver:
* Receiver<int> getInputReceiver(); * Receiver<int> getInputReceiver();
...@@ -47,10 +62,28 @@ class IFanoutChannelProcessor; ...@@ -47,10 +62,28 @@ class IFanoutChannelProcessor;
* auto receiver1 = fanoutChannel.subscribe(); * auto receiver1 = fanoutChannel.subscribe();
* auto receiver2 = fanoutChannel.subscribe(); * auto receiver2 = fanoutChannel.subscribe();
* auto receiver3 = fanoutChannel.subscribe([]{ return {1, 2, 3}; }); * auto receiver3 = fanoutChannel.subscribe([]{ return {1, 2, 3}; });
*
* Example with context:
*
* struct Context {
* int lastValue{-1};
*
* void update(const int& value) {
* lastValue = value;
* }
* };
*
* auto fanoutChannel =
* createFanoutChannel(getReceiver(), getExecutor(), Context());
* auto receiver1 = fanoutChannel.subscribe(
* [](const Context& context) { return {context.latestValue}; });
* auto receiver2 = fanoutChannel.subscribe(
* [](const Context& context) { return {context.latestValue}; });
* std::move(fanoutChannel).close();
*/ */
template <typename ValueType> template <typename ValueType, typename ContextType = NoContext<ValueType>>
class FanoutChannel { class FanoutChannel {
using TProcessor = detail::IFanoutChannelProcessor<ValueType>; using TProcessor = detail::IFanoutChannelProcessor<ValueType, ContextType>;
public: public:
explicit FanoutChannel(TProcessor* processor); explicit FanoutChannel(TProcessor* processor);
...@@ -73,7 +106,8 @@ class FanoutChannel { ...@@ -73,7 +106,8 @@ class FanoutChannel {
* getInitialValues, or a deadlock will occur. * getInitialValues, or a deadlock will occur.
*/ */
Receiver<ValueType> subscribe( Receiver<ValueType> subscribe(
folly::Function<std::vector<ValueType>()> getInitialValues = {}); folly::Function<std::vector<ValueType>(const ContextType&)>
getInitialValues = {});
/** /**
* Returns whether this fanout channel has any output receivers. * Returns whether this fanout channel has any output receivers.
...@@ -94,10 +128,12 @@ class FanoutChannel { ...@@ -94,10 +128,12 @@ class FanoutChannel {
*/ */
template < template <
typename ReceiverType, typename ReceiverType,
typename ValueType = typename ReceiverType::ValueType> typename ValueType = typename ReceiverType::ValueType,
FanoutChannel<ValueType> createFanoutChannel( typename ContextType = NoContext<typename ReceiverType::ValueType>>
FanoutChannel<ValueType, ContextType> createFanoutChannel(
ReceiverType inputReceiver, ReceiverType inputReceiver,
folly::Executor::KeepAlive<folly::SequencedExecutor> executor); folly::Executor::KeepAlive<folly::SequencedExecutor> executor,
ContextType context = ContextType());
} // namespace channels } // namespace channels
} // namespace folly } // namespace folly
......
...@@ -60,16 +60,22 @@ class FanoutChannelFixture : public Test { ...@@ -60,16 +60,22 @@ class FanoutChannelFixture : public Test {
}; };
TEST_F(FanoutChannelFixture, ReceiveValue_FanoutBroadcastsValues) { TEST_F(FanoutChannelFixture, ReceiveValue_FanoutBroadcastsValues) {
struct LatestVersion {
int version{-1};
void update(const int& newVersion) { version = newVersion; }
};
auto [inputReceiver, sender] = Channel<int>::create(); auto [inputReceiver, sender] = Channel<int>::create();
auto fanoutChannel = auto fanoutChannel = createFanoutChannel(
createFanoutChannel(std::move(inputReceiver), &executor_); std::move(inputReceiver), &executor_, LatestVersion());
EXPECT_FALSE(fanoutChannel.anySubscribers()); EXPECT_FALSE(fanoutChannel.anySubscribers());
auto [handle1, callback1] = processValues(fanoutChannel.subscribe( auto [handle1, callback1] = processValues(fanoutChannel.subscribe(
[]() { return toVector(100); } /* getInitialValues */)); [](const auto&) { return toVector(100); } /* getInitialValues */));
auto [handle2, callback2] = processValues(fanoutChannel.subscribe( auto [handle2, callback2] = processValues(fanoutChannel.subscribe(
[]() { return toVector(200); } /* getInitialValues */)); [](const auto&) { return toVector(200); } /* getInitialValues */));
EXPECT_TRUE(fanoutChannel.anySubscribers()); EXPECT_TRUE(fanoutChannel.anySubscribers());
EXPECT_CALL(*callback1, onValue(100)); EXPECT_CALL(*callback1, onValue(100));
...@@ -84,10 +90,12 @@ TEST_F(FanoutChannelFixture, ReceiveValue_FanoutBroadcastsValues) { ...@@ -84,10 +90,12 @@ TEST_F(FanoutChannelFixture, ReceiveValue_FanoutBroadcastsValues) {
sender.write(2); sender.write(2);
executor_.drain(); executor_.drain();
auto [handle3, callback3] = processValues(fanoutChannel.subscribe( auto [handle3, callback3] = processValues(
[]() { return toVector(300); } /* getInitialValues */)); fanoutChannel.subscribe([](const LatestVersion& latestVersion) {
return toVector(latestVersion.version);
} /* getInitialValues */));
EXPECT_CALL(*callback3, onValue(300)); EXPECT_CALL(*callback3, onValue(2));
executor_.drain(); executor_.drain();
sender.write(3); sender.write(3);
...@@ -203,9 +211,9 @@ TEST_F(FanoutChannelFixture, VectorBool) { ...@@ -203,9 +211,9 @@ TEST_F(FanoutChannelFixture, VectorBool) {
createFanoutChannel(std::move(inputReceiver), &executor_); createFanoutChannel(std::move(inputReceiver), &executor_);
auto [handle1, callback1] = processValues(fanoutChannel.subscribe( auto [handle1, callback1] = processValues(fanoutChannel.subscribe(
[] { return toVector(true); } /* getInitialValues */)); [](const auto&) { return toVector(true); } /* getInitialValues */));
auto [handle2, callback2] = processValues(fanoutChannel.subscribe( auto [handle2, callback2] = processValues(fanoutChannel.subscribe(
[] { return toVector(false); } /* getInitialValues */)); [](const auto&) { return toVector(false); } /* getInitialValues */));
EXPECT_CALL(*callback1, onValue(true)); EXPECT_CALL(*callback1, onValue(true));
EXPECT_CALL(*callback2, onValue(false)); EXPECT_CALL(*callback2, onValue(false));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment