Commit 4abb4669 authored by Dave Watson's avatar Dave Watson Committed by Noam Lerner

Udp Acceptor

Summary:
major changes:

1) ServerSocketFactory and AsyncSocketBase to abstract the differences between UDP and TCP async socket.  Could possibly push some of this to the sockets themselves eventually
2) pipeline() is a pipeline between accept/receive of a UDP message, and before sending it to workers.  Default impl for TCP is to fan out to worker threads.  This is the same as Netty.  Since we don't know if the data is a TCP socket or a UDP message, it's a void*, which sucks (netty uses Object msg, so it isn't any different).

Test Plan: Added lots of new tests.  Doesn't test any data passing yet though, just connects/simple receipt of UDP message.

Reviewed By: hans@fb.com

Subscribers: alandau, bmatheny, mshneer, jsedgwick, yfeldblum, trunkagent, doug, fugalh, folly-diffs@

FB internal diff: D1736670

Tasks: 5788116

Signature: t1:1736670:1424372992:e109450604ed905004bd40dfbb508b5808332c15
parent 3e3b5e35
...@@ -148,6 +148,7 @@ nobase_follyinclude_HEADERS = \ ...@@ -148,6 +148,7 @@ nobase_follyinclude_HEADERS = \
io/async/AsyncUDPSocket.h \ io/async/AsyncUDPSocket.h \
io/async/AsyncServerSocket.h \ io/async/AsyncServerSocket.h \
io/async/AsyncSocket.h \ io/async/AsyncSocket.h \
io/async/AsyncSocketBase.h \
io/async/AsyncSSLSocket.h \ io/async/AsyncSSLSocket.h \
io/async/AsyncSocketException.h \ io/async/AsyncSocketException.h \
io/async/DelayedDestruction.h \ io/async/DelayedDestruction.h \
...@@ -234,6 +235,7 @@ nobase_follyinclude_HEADERS = \ ...@@ -234,6 +235,7 @@ nobase_follyinclude_HEADERS = \
wangle/acceptor/TransportInfo.h \ wangle/acceptor/TransportInfo.h \
wangle/bootstrap/ServerBootstrap.h \ wangle/bootstrap/ServerBootstrap.h \
wangle/bootstrap/ServerBootstrap-inl.h \ wangle/bootstrap/ServerBootstrap-inl.h \
wangle/bootstrap/ServerSocketFactory.h \
wangle/bootstrap/ClientBootstrap.h \ wangle/bootstrap/ClientBootstrap.h \
wangle/channel/AsyncSocketHandler.h \ wangle/channel/AsyncSocketHandler.h \
wangle/channel/ChannelHandler.h \ wangle/channel/ChannelHandler.h \
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/io/async/NotificationQueue.h> #include <folly/io/async/NotificationQueue.h>
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncSocketBase.h>
#include <folly/io/ShutdownSocketSet.h> #include <folly/io/ShutdownSocketSet.h>
#include <folly/SocketAddress.h> #include <folly/SocketAddress.h>
#include <memory> #include <memory>
...@@ -56,7 +57,8 @@ namespace folly { ...@@ -56,7 +57,8 @@ namespace folly {
* modify the AsyncServerSocket state may only be performed from the primary * modify the AsyncServerSocket state may only be performed from the primary
* EventBase thread. * EventBase thread.
*/ */
class AsyncServerSocket : public DelayedDestruction { class AsyncServerSocket : public DelayedDestruction
, public AsyncSocketBase {
public: public:
typedef std::unique_ptr<AsyncServerSocket, Destructor> UniquePtr; typedef std::unique_ptr<AsyncServerSocket, Destructor> UniquePtr;
// Disallow copy, move, and default construction. // Disallow copy, move, and default construction.
......
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/SocketAddress.h>
#include <folly/io/async/EventBase.h>
namespace folly {
class AsyncSocketBase {
public:
virtual EventBase* getEventBase() const = 0;
virtual ~AsyncSocketBase() = default;
virtual void getAddress(SocketAddress*) const = 0;
};
} // namespace
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include <sys/uio.h> #include <sys/uio.h>
#include <folly/io/async/DelayedDestruction.h> #include <folly/io/async/DelayedDestruction.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/AsyncSocketBase.h>
namespace folly { namespace folly {
...@@ -111,7 +113,7 @@ inline bool isSet(WriteFlags a, WriteFlags b) { ...@@ -111,7 +113,7 @@ inline bool isSet(WriteFlags a, WriteFlags b) {
* timeout, since most callers want to give up if the remote end stops * timeout, since most callers want to give up if the remote end stops
* responding and no further progress can be made sending the data. * responding and no further progress can be made sending the data.
*/ */
class AsyncTransport : public DelayedDestruction { class AsyncTransport : public DelayedDestruction, public AsyncSocketBase {
public: public:
typedef std::unique_ptr<AsyncTransport, Destructor> UniquePtr; typedef std::unique_ptr<AsyncTransport, Destructor> UniquePtr;
...@@ -256,14 +258,6 @@ class AsyncTransport : public DelayedDestruction { ...@@ -256,14 +258,6 @@ class AsyncTransport : public DelayedDestruction {
*/ */
virtual bool isDetachable() const = 0; virtual bool isDetachable() const = 0;
/**
* Get the EventBase used by this transport.
*
* Returns nullptr if this transport is not currently attached to a
* EventBase.
*/
virtual EventBase* getEventBase() const = 0;
/** /**
* Set the send timeout. * Set the send timeout.
* *
...@@ -296,6 +290,10 @@ class AsyncTransport : public DelayedDestruction { ...@@ -296,6 +290,10 @@ class AsyncTransport : public DelayedDestruction {
*/ */
virtual void getLocalAddress(SocketAddress* address) const = 0; virtual void getLocalAddress(SocketAddress* address) const = 0;
virtual void getAddress(SocketAddress* address) const {
getLocalAddress(address);
}
/** /**
* Get the address of the remote endpoint to which this transport is * Get the address of the remote endpoint to which this transport is
* connected. * connected.
......
...@@ -36,7 +36,8 @@ namespace folly { ...@@ -36,7 +36,8 @@ namespace folly {
* more than 1 packet will not work because they will end up with * more than 1 packet will not work because they will end up with
* different event base to process. * different event base to process.
*/ */
class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback { class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback
, public AsyncSocketBase {
public: public:
class Callback { class Callback {
public: public:
...@@ -93,6 +94,10 @@ class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback { ...@@ -93,6 +94,10 @@ class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback {
return socket_->address(); return socket_->address();
} }
void getAddress(SocketAddress* a) const {
*a = address();
}
/** /**
* Add a listener to the round robin list * Add a listener to the round robin list
*/ */
...@@ -124,6 +129,10 @@ class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback { ...@@ -124,6 +129,10 @@ class AsyncUDPServerSocket : private AsyncUDPSocket::ReadCallback {
socket_.reset(); socket_.reset();
} }
EventBase* getEventBase() const {
return evb_;
}
private: private:
// AsyncUDPSocket::ReadCallback // AsyncUDPSocket::ReadCallback
void getReadBuffer(void** buf, size_t* len) noexcept { void getReadBuffer(void** buf, size_t* len) noexcept {
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <folly/io/IOBuf.h> #include <folly/io/IOBuf.h>
#include <folly/ScopeGuard.h> #include <folly/ScopeGuard.h>
#include <folly/io/async/AsyncSocketException.h> #include <folly/io/async/AsyncSocketException.h>
#include <folly/io/async/AsyncSocketBase.h>
#include <folly/io/async/EventHandler.h> #include <folly/io/async/EventHandler.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/SocketAddress.h> #include <folly/SocketAddress.h>
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <event.h> #include <event.h>
#include <folly/io/async/AsyncSSLSocket.h> #include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h> #include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncUDPServerSocket.h>
namespace folly { namespace wangle { namespace folly { namespace wangle {
class ManagedConnection; class ManagedConnection;
...@@ -46,7 +47,8 @@ class SSLContextManager; ...@@ -46,7 +47,8 @@ class SSLContextManager;
*/ */
class Acceptor : class Acceptor :
public folly::AsyncServerSocket::AcceptCallback, public folly::AsyncServerSocket::AcceptCallback,
public folly::wangle::ConnectionManager::Callback { public folly::wangle::ConnectionManager::Callback,
public AsyncUDPServerSocket::Callback {
public: public:
enum class State : uint32_t { enum class State : uint32_t {
...@@ -229,6 +231,10 @@ class Acceptor : ...@@ -229,6 +231,10 @@ class Acceptor :
const std::string& nextProtocolName, const std::string& nextProtocolName,
const TransportInfo& tinfo) = 0; const TransportInfo& tinfo) = 0;
void onListenStarted() noexcept {}
void onListenStopped() noexcept {}
void onDataAvailable(const SocketAddress&, std::unique_ptr<IOBuf>, bool) noexcept {}
virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) { virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) {
return AsyncSocket::UniquePtr(new AsyncSocket(base, fd)); return AsyncSocket::UniquePtr(new AsyncSocket(base, fd));
} }
......
...@@ -52,6 +52,27 @@ class TestPipelineFactory : public PipelineFactory<Pipeline> { ...@@ -52,6 +52,27 @@ class TestPipelineFactory : public PipelineFactory<Pipeline> {
std::atomic<int> pipelines{0}; std::atomic<int> pipelines{0};
}; };
class TestAcceptor : public Acceptor {
EventBase base_;
public:
TestAcceptor() : Acceptor(ServerSocketConfig()) {
Acceptor::init(nullptr, &base_);
}
void onNewConnection(
AsyncSocket::UniquePtr sock,
const folly::SocketAddress* address,
const std::string& nextProtocolName,
const TransportInfo& tinfo) {
}
};
class TestAcceptorFactory : public AcceptorFactory {
public:
std::shared_ptr<Acceptor> newAcceptor(EventBase* base) {
return std::make_shared<TestAcceptor>();
}
};
TEST(Bootstrap, Basic) { TEST(Bootstrap, Basic) {
TestServer server; TestServer server;
TestClient client; TestClient client;
...@@ -64,6 +85,13 @@ TEST(Bootstrap, ServerWithPipeline) { ...@@ -64,6 +85,13 @@ TEST(Bootstrap, ServerWithPipeline) {
server.stop(); server.stop();
} }
TEST(Bootstrap, ServerWithChildHandler) {
TestServer server;
server.childHandler(std::make_shared<TestAcceptorFactory>());
server.bind(0);
server.stop();
}
TEST(Bootstrap, ClientServerTest) { TEST(Bootstrap, ClientServerTest) {
TestServer server; TestServer server;
auto factory = std::make_shared<TestPipelineFactory>(); auto factory = std::make_shared<TestPipelineFactory>();
...@@ -236,3 +264,107 @@ TEST(Bootstrap, ExistingSocket) { ...@@ -236,3 +264,107 @@ TEST(Bootstrap, ExistingSocket) {
folly::AsyncServerSocket::UniquePtr socket(new AsyncServerSocket); folly::AsyncServerSocket::UniquePtr socket(new AsyncServerSocket);
server.bind(std::move(socket)); server.bind(std::move(socket));
} }
std::atomic<int> connections{0};
class TestHandlerPipeline
: public ChannelHandlerAdapter<void*,
std::exception> {
public:
void read(Context* ctx, void* conn) {
connections++;
return ctx->fireRead(conn);
}
Future<void> write(Context* ctx, std::exception e) {
return ctx->fireWrite(e);
}
};
template <typename HandlerPipeline>
class TestHandlerPipelineFactory
: public PipelineFactory<ServerBootstrap<Pipeline>::AcceptPipeline> {
public:
ServerBootstrap<Pipeline>::AcceptPipeline* newPipeline(std::shared_ptr<AsyncSocket>) {
auto pipeline = new ServerBootstrap<Pipeline>::AcceptPipeline;
auto handler = std::make_shared<HandlerPipeline>();
pipeline->addBack(ChannelHandlerPtr<HandlerPipeline>(handler));
return pipeline;
}
};
TEST(Bootstrap, LoadBalanceHandler) {
TestServer server;
auto factory = std::make_shared<TestPipelineFactory>();
server.childPipeline(factory);
auto pipelinefactory =
std::make_shared<TestHandlerPipelineFactory<TestHandlerPipeline>>();
server.pipeline(pipelinefactory);
server.bind(0);
auto base = EventBaseManager::get()->getEventBase();
SocketAddress address;
server.getSockets()[0]->getAddress(&address);
TestClient client;
client.pipelineFactory(std::make_shared<TestClientPipelineFactory>());
client.connect(address);
base->loop();
server.stop();
CHECK(factory->pipelines == 1);
CHECK(connections == 1);
}
class TestUDPPipeline
: public ChannelHandlerAdapter<void*,
std::exception> {
public:
void read(Context* ctx, void* conn) {
connections++;
}
Future<void> write(Context* ctx, std::exception e) {
return ctx->fireWrite(e);
}
};
TEST(Bootstrap, UDP) {
TestServer server;
auto factory = std::make_shared<TestPipelineFactory>();
auto pipelinefactory =
std::make_shared<TestHandlerPipelineFactory<TestUDPPipeline>>();
server.pipeline(pipelinefactory);
server.channelFactory(std::make_shared<AsyncUDPServerSocketFactory>());
server.bind(0);
}
TEST(Bootstrap, UDPClientServerTest) {
connections = 0;
TestServer server;
auto factory = std::make_shared<TestPipelineFactory>();
auto pipelinefactory =
std::make_shared<TestHandlerPipelineFactory<TestUDPPipeline>>();
server.pipeline(pipelinefactory);
server.channelFactory(std::make_shared<AsyncUDPServerSocketFactory>());
server.bind(0);
auto base = EventBaseManager::get()->getEventBase();
SocketAddress address;
server.getSockets()[0]->getAddress(&address);
SocketAddress localhost("::1", 0);
AsyncUDPSocket client(base);
client.bind(localhost);
auto data = IOBuf::create(1);
data->append(1);
*(data->writableData()) = 'a';
client.write(address, std::move(data));
base->loop();
server.stop();
CHECK(connections == 1);
}
...@@ -16,15 +16,19 @@ ...@@ -16,15 +16,19 @@
#pragma once #pragma once
#include <folly/wangle/acceptor/Acceptor.h> #include <folly/wangle/acceptor/Acceptor.h>
#include <folly/wangle/bootstrap/ServerSocketFactory.h>
#include <folly/io/async/EventBaseManager.h> #include <folly/io/async/EventBaseManager.h>
#include <folly/wangle/concurrent/IOThreadPoolExecutor.h> #include <folly/wangle/concurrent/IOThreadPoolExecutor.h>
#include <folly/wangle/acceptor/ManagedConnection.h> #include <folly/wangle/acceptor/ManagedConnection.h>
#include <folly/wangle/channel/ChannelPipeline.h> #include <folly/wangle/channel/ChannelPipeline.h>
#include <folly/wangle/channel/ChannelHandler.h>
namespace folly { namespace folly {
template <typename Pipeline> template <typename Pipeline>
class ServerAcceptor : public Acceptor { class ServerAcceptor
: public Acceptor
, public folly::wangle::ChannelHandlerAdapter<void*, std::exception> {
typedef std::unique_ptr<Pipeline, typedef std::unique_ptr<Pipeline,
folly::DelayedDestruction::Destructor> PipelinePtr; folly::DelayedDestruction::Destructor> PipelinePtr;
...@@ -56,20 +60,25 @@ class ServerAcceptor : public Acceptor { ...@@ -56,20 +60,25 @@ class ServerAcceptor : public Acceptor {
public: public:
explicit ServerAcceptor( explicit ServerAcceptor(
std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory, std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory,
std::shared_ptr<folly::wangle::ChannelPipeline<
void*, std::exception>> acceptorPipeline,
EventBase* base) EventBase* base)
: Acceptor(ServerSocketConfig()) : Acceptor(ServerSocketConfig())
, pipelineFactory_(pipelineFactory) { , base_(base)
Acceptor::init(nullptr, base); , childPipelineFactory_(pipelineFactory)
, acceptorPipeline_(acceptorPipeline) {
Acceptor::init(nullptr, base_);
CHECK(acceptorPipeline_);
acceptorPipeline_->addBack(folly::wangle::ChannelHandlerPtr<ServerAcceptor, false>(this));
acceptorPipeline_->finalize();
} }
/* See Acceptor::onNewConnection for details */ void read(Context* ctx, void* conn) {
void onNewConnection( AsyncSocket::UniquePtr transport((AsyncSocket*)conn);
AsyncSocket::UniquePtr transport, const SocketAddress* address,
const std::string& nextProtocolName, const TransportInfo& tinfo) {
std::unique_ptr<Pipeline, std::unique_ptr<Pipeline,
folly::DelayedDestruction::Destructor> folly::DelayedDestruction::Destructor>
pipeline(pipelineFactory_->newPipeline( pipeline(childPipelineFactory_->newPipeline(
std::shared_ptr<AsyncSocket>( std::shared_ptr<AsyncSocket>(
transport.release(), transport.release(),
folly::DelayedDestruction::Destructor()))); folly::DelayedDestruction::Destructor())));
...@@ -77,22 +86,53 @@ class ServerAcceptor : public Acceptor { ...@@ -77,22 +86,53 @@ class ServerAcceptor : public Acceptor {
Acceptor::addConnection(connection); Acceptor::addConnection(connection);
} }
folly::Future<void> write(Context* ctx, std::exception e) {
return ctx->fireWrite(e);
}
/* See Acceptor::onNewConnection for details */
void onNewConnection(
AsyncSocket::UniquePtr transport, const SocketAddress* address,
const std::string& nextProtocolName, const TransportInfo& tinfo) {
acceptorPipeline_->read(transport.release());
}
// UDP thunk
void onDataAvailable(const folly::SocketAddress& addr,
std::unique_ptr<folly::IOBuf> buf,
bool truncated) noexcept {
acceptorPipeline_->read(buf.release());
}
private: private:
std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_; EventBase* base_;
std::shared_ptr<PipelineFactory<Pipeline>> childPipelineFactory_;
std::shared_ptr<folly::wangle::ChannelPipeline<
void*, std::exception>> acceptorPipeline_;
}; };
template <typename Pipeline> template <typename Pipeline>
class ServerAcceptorFactory : public AcceptorFactory { class ServerAcceptorFactory : public AcceptorFactory {
public: public:
explicit ServerAcceptorFactory( explicit ServerAcceptorFactory(
std::shared_ptr<PipelineFactory<Pipeline>> factory) std::shared_ptr<PipelineFactory<Pipeline>> factory,
: factory_(factory) {} std::shared_ptr<PipelineFactory<folly::wangle::ChannelPipeline<
void*, std::exception>>> pipeline)
std::shared_ptr<Acceptor> newAcceptor(folly::EventBase* base) { : factory_(factory)
return std::make_shared<ServerAcceptor<Pipeline>>(factory_, base); , pipeline_(pipeline) {}
std::shared_ptr<Acceptor> newAcceptor(EventBase* base) {
std::shared_ptr<folly::wangle::ChannelPipeline<
void*, std::exception>> pipeline(
pipeline_->newPipeline(nullptr));
return std::make_shared<ServerAcceptor<Pipeline>>(factory_, pipeline, base);
} }
private: private:
std::shared_ptr<PipelineFactory<Pipeline>> factory_; std::shared_ptr<PipelineFactory<Pipeline>> factory_;
std::shared_ptr<PipelineFactory<
folly::wangle::ChannelPipeline<
void*, std::exception>>> pipeline_;
}; };
class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer { class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer {
...@@ -100,10 +140,12 @@ class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer { ...@@ -100,10 +140,12 @@ class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer {
explicit ServerWorkerPool( explicit ServerWorkerPool(
std::shared_ptr<AcceptorFactory> acceptorFactory, std::shared_ptr<AcceptorFactory> acceptorFactory,
folly::wangle::IOThreadPoolExecutor* exec, folly::wangle::IOThreadPoolExecutor* exec,
std::vector<std::shared_ptr<folly::AsyncServerSocket>>* sockets) std::vector<std::shared_ptr<folly::AsyncSocketBase>>* sockets,
std::shared_ptr<ServerSocketFactory> socketFactory)
: acceptorFactory_(acceptorFactory) : acceptorFactory_(acceptorFactory)
, exec_(exec) , exec_(exec)
, sockets_(sockets) { , sockets_(sockets)
, socketFactory_(socketFactory) {
CHECK(exec); CHECK(exec);
} }
...@@ -128,7 +170,8 @@ class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer { ...@@ -128,7 +170,8 @@ class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer {
std::shared_ptr<Acceptor>> workers_; std::shared_ptr<Acceptor>> workers_;
std::shared_ptr<AcceptorFactory> acceptorFactory_; std::shared_ptr<AcceptorFactory> acceptorFactory_;
folly::wangle::IOThreadPoolExecutor* exec_{nullptr}; folly::wangle::IOThreadPoolExecutor* exec_{nullptr};
std::vector<std::shared_ptr<folly::AsyncServerSocket>>* sockets_; std::vector<std::shared_ptr<folly::AsyncSocketBase>>* sockets_;
std::shared_ptr<ServerSocketFactory> socketFactory_;
}; };
template <typename F> template <typename F>
...@@ -138,4 +181,16 @@ void ServerWorkerPool::forEachWorker(F&& f) const { ...@@ -138,4 +181,16 @@ void ServerWorkerPool::forEachWorker(F&& f) const {
} }
} }
class DefaultAcceptPipelineFactory
: public PipelineFactory<wangle::ChannelPipeline<void*, std::exception>> {
typedef wangle::ChannelPipeline<
void*,
std::exception> AcceptPipeline;
public:
AcceptPipeline* newPipeline(std::shared_ptr<AsyncSocket>) {
return new AcceptPipeline;
}
};
} // namespace } // namespace
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
*/ */
#include <folly/wangle/bootstrap/ServerBootstrap.h> #include <folly/wangle/bootstrap/ServerBootstrap.h>
#include <folly/wangle/concurrent/NamedThreadFactory.h> #include <folly/wangle/concurrent/NamedThreadFactory.h>
#include <folly/wangle/channel/ChannelHandler.h>
#include <folly/io/async/EventBaseManager.h> #include <folly/io/async/EventBaseManager.h>
namespace folly { namespace folly {
...@@ -25,8 +26,9 @@ void ServerWorkerPool::threadStarted( ...@@ -25,8 +26,9 @@ void ServerWorkerPool::threadStarted(
workers_.insert({h, worker}); workers_.insert({h, worker});
for(auto socket : *sockets_) { for(auto socket : *sockets_) {
socket->getEventBase()->runInEventBaseThread([this, worker, socket](){ socket->getEventBase()->runInEventBaseThreadAndWait([this, worker, socket](){
socket->addAcceptCallback(worker.get(), worker->getEventBase()); socketFactory_->addAcceptCB(
socket, worker.get(), worker->getEventBase());
}); });
} }
} }
...@@ -38,22 +40,22 @@ void ServerWorkerPool::threadStopped( ...@@ -38,22 +40,22 @@ void ServerWorkerPool::threadStopped(
for (auto& socket : *sockets_) { for (auto& socket : *sockets_) {
folly::Baton<> barrier; folly::Baton<> barrier;
socket->getEventBase()->runInEventBaseThread([&]() { socket->getEventBase()->runInEventBaseThreadAndWait([&]() {
socket->removeAcceptCallback(worker->second.get(), nullptr); socketFactory_->removeAcceptCB(
socket, worker->second.get(), nullptr);
barrier.post(); barrier.post();
}); });
barrier.wait(); barrier.wait();
} }
CHECK(worker->second->getEventBase() != nullptr); if (!worker->second->getEventBase()->isInEventBaseThread()) {
CHECK(!worker->second->getEventBase()->isInEventBaseThread()); worker->second->getEventBase()->runInEventBaseThreadAndWait([=]() {
folly::Baton<> barrier;
worker->second->getEventBase()->runInEventBaseThread([&]() {
worker->second->dropAllConnections(); worker->second->dropAllConnections();
barrier.post();
}); });
} else {
worker->second->dropAllConnections();
}
barrier.wait();
workers_.erase(worker); workers_.erase(worker);
} }
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <folly/wangle/bootstrap/ServerBootstrap-inl.h> #include <folly/wangle/bootstrap/ServerBootstrap-inl.h>
#include <folly/Baton.h> #include <folly/Baton.h>
#include <folly/wangle/channel/ChannelPipeline.h>
namespace folly { namespace folly {
...@@ -44,16 +45,24 @@ class ServerBootstrap { ...@@ -44,16 +45,24 @@ class ServerBootstrap {
~ServerBootstrap() { ~ServerBootstrap() {
stop(); stop();
} }
/* TODO(davejwatson)
* typedef wangle::ChannelPipeline<
* If there is any work to be done BEFORE handing the work to IO void*,
* threads, this handler is where the pipeline to do it would be std::exception> AcceptPipeline;
* set. /*
* * Pipeline used to add connections to event bases.
* This could be used for things like logging, load balancing, or * This is used for UDP or for load balancing
* advanced load balancing on IO threads. Netty also provides this. * TCP connections to IO threads explicitly
*/ */
ServerBootstrap* handler() { ServerBootstrap* pipeline(
std::shared_ptr<PipelineFactory<AcceptPipeline>> factory) {
pipeline_ = factory;
return this;
}
ServerBootstrap* channelFactory(
std::shared_ptr<ServerSocketFactory> factory) {
socketFactory_ = factory;
return this; return this;
} }
...@@ -75,7 +84,7 @@ class ServerBootstrap { ...@@ -75,7 +84,7 @@ class ServerBootstrap {
*/ */
ServerBootstrap* childPipeline( ServerBootstrap* childPipeline(
std::shared_ptr<PipelineFactory<Pipeline>> factory) { std::shared_ptr<PipelineFactory<Pipeline>> factory) {
pipelineFactory_ = factory; childPipelineFactory_ = factory;
return this; return this;
} }
...@@ -111,15 +120,19 @@ class ServerBootstrap { ...@@ -111,15 +120,19 @@ class ServerBootstrap {
32, std::make_shared<wangle::NamedThreadFactory>("IO Thread")); 32, std::make_shared<wangle::NamedThreadFactory>("IO Thread"));
} }
CHECK(acceptorFactory_ || pipelineFactory_); // TODO better config checking
// CHECK(acceptorFactory_ || childPipelineFactory_);
CHECK(!(acceptorFactory_ && childPipelineFactory_));
if (acceptorFactory_) { if (acceptorFactory_) {
workerFactory_ = std::make_shared<ServerWorkerPool>( workerFactory_ = std::make_shared<ServerWorkerPool>(
acceptorFactory_, io_group.get(), &sockets_); acceptorFactory_, io_group.get(), &sockets_, socketFactory_);
} else { } else {
workerFactory_ = std::make_shared<ServerWorkerPool>( workerFactory_ = std::make_shared<ServerWorkerPool>(
std::make_shared<ServerAcceptorFactory<Pipeline>>(pipelineFactory_), std::make_shared<ServerAcceptorFactory<Pipeline>>(
io_group.get(), &sockets_); childPipelineFactory_,
pipeline_),
io_group.get(), &sockets_, socketFactory_);
} }
io_group->addObserver(workerFactory_); io_group->addObserver(workerFactory_);
...@@ -143,13 +156,14 @@ class ServerBootstrap { ...@@ -143,13 +156,14 @@ class ServerBootstrap {
// Since only a single socket is given, // Since only a single socket is given,
// we can only accept on a single thread // we can only accept on a single thread
CHECK(acceptor_group_->numThreads() == 1); CHECK(acceptor_group_->numThreads() == 1);
std::shared_ptr<folly::AsyncServerSocket> socket( std::shared_ptr<folly::AsyncServerSocket> socket(
s.release(), DelayedDestruction::Destructor()); s.release(), DelayedDestruction::Destructor());
folly::Baton<> barrier; folly::Baton<> barrier;
acceptor_group_->add([&](){ acceptor_group_->add([&](){
socket->attachEventBase(EventBaseManager::get()->getEventBase()); socket->attachEventBase(EventBaseManager::get()->getEventBase());
socket->listen(1024); socket->listen(socketConfig.acceptBacklog);
socket->startAccepting(); socket->startAccepting();
barrier.post(); barrier.post();
}); });
...@@ -157,8 +171,9 @@ class ServerBootstrap { ...@@ -157,8 +171,9 @@ class ServerBootstrap {
// Startup all the threads // Startup all the threads
workerFactory_->forEachWorker([this, socket](Acceptor* worker){ workerFactory_->forEachWorker([this, socket](Acceptor* worker){
socket->getEventBase()->runInEventBaseThread([this, worker, socket](){ socket->getEventBase()->runInEventBaseThreadAndWait(
socket->addAcceptCallback(worker, worker->getEventBase()); [this, worker, socket](){
socketFactory_->addAcceptCB(socket, worker, worker->getEventBase());
}); });
}); });
...@@ -192,25 +207,27 @@ class ServerBootstrap { ...@@ -192,25 +207,27 @@ class ServerBootstrap {
} }
std::mutex sock_lock; std::mutex sock_lock;
std::vector<std::shared_ptr<folly::AsyncServerSocket>> new_sockets; std::vector<std::shared_ptr<folly::AsyncSocketBase>> new_sockets;
std::exception_ptr exn; std::exception_ptr exn;
auto startupFunc = [&](std::shared_ptr<folly::Baton<>> barrier){ auto startupFunc = [&](std::shared_ptr<folly::Baton<>> barrier){
auto socket = folly::AsyncServerSocket::newSocket();
socket->setReusePortEnabled(reusePort);
socket->attachEventBase(EventBaseManager::get()->getEventBase());
try { try {
if (port >= 0) { auto socket = socketFactory_->newSocket(
socket->bind(port); port, address, socketConfig.acceptBacklog, reusePort, socketConfig);
} else {
socket->bind(address); sock_lock.lock();
new_sockets.push_back(socket);
sock_lock.unlock();
if (port == 0) {
socket->getAddress(&address);
port = address.getPort(); port = address.getPort();
} }
socket->listen(socketConfig.acceptBacklog); barrier->post();
socket->startAccepting();
} catch (...) { } catch (...) {
exn = std::current_exception(); exn = std::current_exception();
barrier->post(); barrier->post();
...@@ -218,16 +235,8 @@ class ServerBootstrap { ...@@ -218,16 +235,8 @@ class ServerBootstrap {
return; return;
} }
sock_lock.lock();
new_sockets.push_back(socket);
sock_lock.unlock();
if (port == 0) {
socket->getAddress(&address);
port = address.getPort();
}
barrier->post();
}; };
auto wait0 = std::make_shared<folly::Baton<>>(); auto wait0 = std::make_shared<folly::Baton<>>();
...@@ -244,16 +253,14 @@ class ServerBootstrap { ...@@ -244,16 +253,14 @@ class ServerBootstrap {
std::rethrow_exception(exn); std::rethrow_exception(exn);
} }
for (auto& socket : new_sockets) {
// Startup all the threads // Startup all the threads
for(auto socket : new_sockets) {
workerFactory_->forEachWorker([this, socket](Acceptor* worker){ workerFactory_->forEachWorker([this, socket](Acceptor* worker){
socket->getEventBase()->runInEventBaseThread([this, worker, socket](){ socket->getEventBase()->runInEventBaseThreadAndWait([this, worker, socket](){
socket->addAcceptCallback(worker, worker->getEventBase()); socketFactory_->addAcceptCB(socket, worker, worker->getEventBase());
}); });
}); });
}
for (auto& socket : new_sockets) {
sockets_.push_back(socket); sockets_.push_back(socket);
} }
} }
...@@ -264,9 +271,8 @@ class ServerBootstrap { ...@@ -264,9 +271,8 @@ class ServerBootstrap {
void stop() { void stop() {
for (auto socket : sockets_) { for (auto socket : sockets_) {
folly::Baton<> barrier; folly::Baton<> barrier;
socket->getEventBase()->runInEventBaseThread([&barrier, socket]() { socket->getEventBase()->runInEventBaseThread([&]() mutable {
socket->stopAccepting(); socketFactory_->stopSocket(socket);
socket->detachEventBase();
barrier.post(); barrier.post();
}); });
barrier.wait(); barrier.wait();
...@@ -284,7 +290,7 @@ class ServerBootstrap { ...@@ -284,7 +290,7 @@ class ServerBootstrap {
/* /*
* Get the list of listening sockets * Get the list of listening sockets
*/ */
const std::vector<std::shared_ptr<folly::AsyncServerSocket>>& const std::vector<std::shared_ptr<folly::AsyncSocketBase>>&
getSockets() const { getSockets() const {
return sockets_; return sockets_;
} }
...@@ -305,10 +311,14 @@ class ServerBootstrap { ...@@ -305,10 +311,14 @@ class ServerBootstrap {
std::shared_ptr<wangle::IOThreadPoolExecutor> io_group_; std::shared_ptr<wangle::IOThreadPoolExecutor> io_group_;
std::shared_ptr<ServerWorkerPool> workerFactory_; std::shared_ptr<ServerWorkerPool> workerFactory_;
std::vector<std::shared_ptr<folly::AsyncServerSocket>> sockets_; std::vector<std::shared_ptr<folly::AsyncSocketBase>> sockets_;
std::shared_ptr<AcceptorFactory> acceptorFactory_; std::shared_ptr<AcceptorFactory> acceptorFactory_;
std::shared_ptr<PipelineFactory<Pipeline>> pipelineFactory_; std::shared_ptr<PipelineFactory<Pipeline>> childPipelineFactory_;
std::shared_ptr<PipelineFactory<AcceptPipeline>> pipeline_{
std::make_shared<DefaultAcceptPipelineFactory>()};
std::shared_ptr<ServerSocketFactory> socketFactory_{
std::make_shared<AsyncServerSocketFactory>()};
}; };
} // namespace } // namespace
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/wangle/bootstrap/ServerBootstrap-inl.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/EventBaseManager.h>
#include <folly/io/async/AsyncUDPServerSocket.h>
namespace folly {
class ServerSocketFactory {
public:
virtual std::shared_ptr<AsyncSocketBase> newSocket(
int port, SocketAddress address, int backlog,
bool reuse, ServerSocketConfig& config) = 0;
virtual void stopSocket(
std::shared_ptr<AsyncSocketBase>& socket) = 0;
virtual void removeAcceptCB(std::shared_ptr<AsyncSocketBase> sock, Acceptor *callback, EventBase* base) = 0;
virtual void addAcceptCB(std::shared_ptr<AsyncSocketBase> sock, Acceptor* callback, EventBase* base) = 0 ;
virtual ~ServerSocketFactory() = default;
};
class AsyncServerSocketFactory : public ServerSocketFactory {
public:
std::shared_ptr<AsyncSocketBase> newSocket(
int port, SocketAddress address, int backlog, bool reuse,
ServerSocketConfig& config) {
auto socket = folly::AsyncServerSocket::newSocket();
socket->setReusePortEnabled(reuse);
socket->attachEventBase(EventBaseManager::get()->getEventBase());
if (port >= 0) {
socket->bind(port);
} else {
socket->bind(address);
}
socket->listen(config.acceptBacklog);
socket->startAccepting();
return socket;
}
virtual void stopSocket(
std::shared_ptr<AsyncSocketBase>& s) {
auto socket = std::dynamic_pointer_cast<AsyncServerSocket>(s);
DCHECK(socket);
socket->stopAccepting();
socket->detachEventBase();
}
virtual void removeAcceptCB(std::shared_ptr<AsyncSocketBase> s,
Acceptor *callback, EventBase* base) {
auto socket = std::dynamic_pointer_cast<AsyncServerSocket>(s);
CHECK(socket);
socket->removeAcceptCallback(callback, base);
}
virtual void addAcceptCB(std::shared_ptr<AsyncSocketBase> s,
Acceptor* callback, EventBase* base) {
auto socket = std::dynamic_pointer_cast<AsyncServerSocket>(s);
CHECK(socket);
socket->addAcceptCallback(callback, base);
}
};
class AsyncUDPServerSocketFactory : public ServerSocketFactory {
public:
std::shared_ptr<AsyncSocketBase> newSocket(
int port, SocketAddress address, int backlog, bool reuse,
ServerSocketConfig& config) {
auto socket = std::make_shared<AsyncUDPServerSocket>(
EventBaseManager::get()->getEventBase());
//socket->setReusePortEnabled(reuse);
SocketAddress addressr("::1", port);
socket->bind(addressr);
socket->listen();
return socket;
}
virtual void stopSocket(
std::shared_ptr<AsyncSocketBase>& s) {
auto socket = std::dynamic_pointer_cast<AsyncUDPServerSocket>(s);
DCHECK(socket);
socket->close();
}
virtual void removeAcceptCB(std::shared_ptr<AsyncSocketBase> s,
Acceptor *callback, EventBase* base) {
}
virtual void addAcceptCB(std::shared_ptr<AsyncSocketBase> s,
Acceptor* callback, EventBase* base) {
auto socket = std::dynamic_pointer_cast<AsyncUDPServerSocket>(s);
DCHECK(socket);
socket->addListener(base, callback);
}
};
} // namespace
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment