Commit 6d3bf646 authored by Mohammad Husain's avatar Mohammad Husain Committed by facebook-github-bot-4

Add connection event callback to AsyncServerSocket

Summary: Adding a callback to AsyncServerSocket to get notified of client connection events. This can be used for example to record stats about these events.

Reviewed By: @afrind

Differential Revision: D2544776

fb-gh-sync-id: 20d22cfc939c5b937abec2b600c10b7228923ff3
parent dd631eb7
...@@ -91,6 +91,10 @@ void AsyncServerSocket::RemoteAcceptor::messageAvailable( ...@@ -91,6 +91,10 @@ void AsyncServerSocket::RemoteAcceptor::messageAvailable(
switch (msg.type) { switch (msg.type) {
case MessageType::MSG_NEW_CONN: case MessageType::MSG_NEW_CONN:
{ {
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDequeuedByAcceptCallback(
msg.fd, msg.address);
}
callback_->connectionAccepted(msg.fd, msg.address); callback_->connectionAccepted(msg.fd, msg.address);
break; break;
} }
...@@ -515,7 +519,7 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback, ...@@ -515,7 +519,7 @@ void AsyncServerSocket::addAcceptCallback(AcceptCallback *callback,
// callback more efficiently without having to use a notification queue. // callback more efficiently without having to use a notification queue.
RemoteAcceptor* acceptor = nullptr; RemoteAcceptor* acceptor = nullptr;
try { try {
acceptor = new RemoteAcceptor(callback); acceptor = new RemoteAcceptor(callback, connectionEventCallback_);
acceptor->start(eventBase, maxAtOnce, maxNumMsgsInQueue_); acceptor->start(eventBase, maxAtOnce, maxNumMsgsInQueue_);
} catch (...) { } catch (...) {
callbacks_.pop_back(); callbacks_.pop_back();
...@@ -722,6 +726,10 @@ void AsyncServerSocket::handlerReady( ...@@ -722,6 +726,10 @@ void AsyncServerSocket::handlerReady(
address.setFromSockaddr(saddr, addrLen); address.setFromSockaddr(saddr, addrLen);
if (clientSocket >= 0 && connectionEventCallback_) {
connectionEventCallback_->onConnectionAccepted(clientSocket, address);
}
std::chrono::time_point<std::chrono::steady_clock> nowMs = std::chrono::time_point<std::chrono::steady_clock> nowMs =
std::chrono::steady_clock::now(); std::chrono::steady_clock::now();
auto timeSinceLastAccept = std::max<int64_t>( auto timeSinceLastAccept = std::max<int64_t>(
...@@ -737,6 +745,10 @@ void AsyncServerSocket::handlerReady( ...@@ -737,6 +745,10 @@ void AsyncServerSocket::handlerReady(
++numDroppedConnections_; ++numDroppedConnections_;
if (clientSocket >= 0) { if (clientSocket >= 0) {
closeNoInt(clientSocket); closeNoInt(clientSocket);
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped(clientSocket,
address);
}
} }
continue; continue;
} }
...@@ -760,6 +772,9 @@ void AsyncServerSocket::handlerReady( ...@@ -760,6 +772,9 @@ void AsyncServerSocket::handlerReady(
} else { } else {
dispatchError("accept() failed", errno); dispatchError("accept() failed", errno);
} }
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionAcceptError(errno);
}
return; return;
} }
...@@ -769,6 +784,9 @@ void AsyncServerSocket::handlerReady( ...@@ -769,6 +784,9 @@ void AsyncServerSocket::handlerReady(
closeNoInt(clientSocket); closeNoInt(clientSocket);
dispatchError("failed to set accepted socket to non-blocking mode", dispatchError("failed to set accepted socket to non-blocking mode",
errno); errno);
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped(clientSocket, address);
}
return; return;
} }
#endif #endif
...@@ -795,6 +813,7 @@ void AsyncServerSocket::dispatchSocket(int socket, ...@@ -795,6 +813,7 @@ void AsyncServerSocket::dispatchSocket(int socket,
return; return;
} }
const SocketAddress addr(address);
// Create a message to send over the notification queue // Create a message to send over the notification queue
QueueMessage msg; QueueMessage msg;
msg.type = MessageType::MSG_NEW_CONN; msg.type = MessageType::MSG_NEW_CONN;
...@@ -804,6 +823,10 @@ void AsyncServerSocket::dispatchSocket(int socket, ...@@ -804,6 +823,10 @@ void AsyncServerSocket::dispatchSocket(int socket,
// Loop until we find a free queue to write to // Loop until we find a free queue to write to
while (true) { while (true) {
if (info->consumer->getQueue()->tryPutMessageNoThrow(std::move(msg))) { if (info->consumer->getQueue()->tryPutMessageNoThrow(std::move(msg))) {
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionEnqueuedForAcceptCallback(socket,
addr);
}
// Success! return. // Success! return.
return; return;
} }
...@@ -831,6 +854,9 @@ void AsyncServerSocket::dispatchSocket(int socket, ...@@ -831,6 +854,9 @@ void AsyncServerSocket::dispatchSocket(int socket,
LOG(ERROR) << "failed to dispatch newly accepted socket:" LOG(ERROR) << "failed to dispatch newly accepted socket:"
<< " all accept callback queues are full"; << " all accept callback queues are full";
closeNoInt(socket); closeNoInt(socket);
if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped(socket, addr);
}
return; return;
} }
...@@ -886,6 +912,9 @@ void AsyncServerSocket::enterBackoff() { ...@@ -886,6 +912,9 @@ void AsyncServerSocket::enterBackoff() {
// since we won't be able to re-enable ourselves later. // since we won't be able to re-enable ourselves later.
LOG(ERROR) << "failed to allocate AsyncServerSocket backoff" LOG(ERROR) << "failed to allocate AsyncServerSocket backoff"
<< " timer; unable to temporarly pause accepting"; << " timer; unable to temporarly pause accepting";
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffError();
}
return; return;
} }
} }
...@@ -903,6 +932,9 @@ void AsyncServerSocket::enterBackoff() { ...@@ -903,6 +932,9 @@ void AsyncServerSocket::enterBackoff() {
if (!backoffTimeout_->scheduleTimeout(timeoutMS)) { if (!backoffTimeout_->scheduleTimeout(timeoutMS)) {
LOG(ERROR) << "failed to schedule AsyncServerSocket backoff timer;" LOG(ERROR) << "failed to schedule AsyncServerSocket backoff timer;"
<< "unable to temporarly pause accepting"; << "unable to temporarly pause accepting";
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffError();
}
return; return;
} }
...@@ -912,6 +944,9 @@ void AsyncServerSocket::enterBackoff() { ...@@ -912,6 +944,9 @@ void AsyncServerSocket::enterBackoff() {
for (auto& handler : sockets_) { for (auto& handler : sockets_) {
handler.unregisterHandler(); handler.unregisterHandler();
} }
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffStarted();
}
} }
void AsyncServerSocket::backoffTimeoutExpired() { void AsyncServerSocket::backoffTimeoutExpired() {
...@@ -924,6 +959,9 @@ void AsyncServerSocket::backoffTimeoutExpired() { ...@@ -924,6 +959,9 @@ void AsyncServerSocket::backoffTimeoutExpired() {
// If all of the callbacks were removed, we shouldn't re-enable accepts // If all of the callbacks were removed, we shouldn't re-enable accepts
if (callbacks_.empty()) { if (callbacks_.empty()) {
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffEnded();
}
return; return;
} }
...@@ -942,6 +980,9 @@ void AsyncServerSocket::backoffTimeoutExpired() { ...@@ -942,6 +980,9 @@ void AsyncServerSocket::backoffTimeoutExpired() {
abort(); abort();
} }
} }
if (connectionEventCallback_) {
connectionEventCallback_->onBackoffEnded();
}
} }
......
...@@ -64,6 +64,71 @@ class AsyncServerSocket : public DelayedDestruction ...@@ -64,6 +64,71 @@ class AsyncServerSocket : public DelayedDestruction
// Disallow copy, move, and default construction. // Disallow copy, move, and default construction.
AsyncServerSocket(AsyncServerSocket&&) = delete; AsyncServerSocket(AsyncServerSocket&&) = delete;
/**
* A callback interface to get notified of client socket events.
*
* The ConnectionEventCallback implementations need to be thread-safe as the
* callbacks may be called from different threads.
*/
class ConnectionEventCallback {
public:
virtual ~ConnectionEventCallback() = default;
/**
* onConnectionAccepted() is called right after a client connection
* is accepted using the system accept()/accept4() APIs.
*/
virtual void onConnectionAccepted(const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onConnectionAcceptError() is called when an error occurred accepting
* a connection.
*/
virtual void onConnectionAcceptError(const int err) noexcept = 0;
/**
* onConnectionDropped() is called when a connection is dropped,
* probably because of some error encountered.
*/
virtual void onConnectionDropped(const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onConnectionEnqueuedForAcceptCallback() is called when the
* connection is successfully enqueued for an AcceptCallback to pick up.
*/
virtual void onConnectionEnqueuedForAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onConnectionDequeuedByAcceptCallback() is called when the
* connection is successfully dequeued by an AcceptCallback.
*/
virtual void onConnectionDequeuedByAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept = 0;
/**
* onBackoffStarted is called when the socket has successfully started
* backing off accepting new client sockets.
*/
virtual void onBackoffStarted() noexcept = 0;
/**
* onBackoffEnded is called when the backoff period has ended and the socket
* has successfully resumed accepting new connections if there is any
* AcceptCallback registered.
*/
virtual void onBackoffEnded() noexcept = 0;
/**
* onBackoffError is called when there is an error entering backoff
*/
virtual void onBackoffError() noexcept = 0;
};
class AcceptCallback { class AcceptCallback {
public: public:
virtual ~AcceptCallback() = default; virtual ~AcceptCallback() = default;
...@@ -320,8 +385,8 @@ class AsyncServerSocket : public DelayedDestruction ...@@ -320,8 +385,8 @@ class AsyncServerSocket : public DelayedDestruction
* *
* When a new socket is accepted, one of the AcceptCallbacks will be invoked * When a new socket is accepted, one of the AcceptCallbacks will be invoked
* with the new socket. The AcceptCallbacks are invoked in a round-robin * with the new socket. The AcceptCallbacks are invoked in a round-robin
* fashion. This allows the accepted sockets to distributed among a pool of * fashion. This allows the accepted sockets to be distributed among a pool
* threads, each running its own EventBase object. This is a common model, * of threads, each running its own EventBase object. This is a common model,
* since most asynchronous-style servers typically run one EventBase thread * since most asynchronous-style servers typically run one EventBase thread
* per CPU. * per CPU.
* *
...@@ -584,6 +649,21 @@ class AsyncServerSocket : public DelayedDestruction ...@@ -584,6 +649,21 @@ class AsyncServerSocket : public DelayedDestruction
return accepting_; return accepting_;
} }
/**
* Set the ConnectionEventCallback
*/
void setConnectionEventCallback(
ConnectionEventCallback* const connectionEventCallback) {
connectionEventCallback_ = connectionEventCallback;
}
/**
* Get the ConnectionEventCallback
*/
ConnectionEventCallback* getConnectionEventCallback() const {
return connectionEventCallback_;
}
protected: protected:
/** /**
* Protected destructor. * Protected destructor.
...@@ -618,8 +698,10 @@ class AsyncServerSocket : public DelayedDestruction ...@@ -618,8 +698,10 @@ class AsyncServerSocket : public DelayedDestruction
class RemoteAcceptor class RemoteAcceptor
: private NotificationQueue<QueueMessage>::Consumer { : private NotificationQueue<QueueMessage>::Consumer {
public: public:
explicit RemoteAcceptor(AcceptCallback *callback) explicit RemoteAcceptor(AcceptCallback *callback,
: callback_(callback) {} ConnectionEventCallback *connectionEventCallback)
: callback_(callback),
connectionEventCallback_(connectionEventCallback) {}
~RemoteAcceptor() = default; ~RemoteAcceptor() = default;
...@@ -634,6 +716,7 @@ class AsyncServerSocket : public DelayedDestruction ...@@ -634,6 +716,7 @@ class AsyncServerSocket : public DelayedDestruction
private: private:
AcceptCallback *callback_; AcceptCallback *callback_;
ConnectionEventCallback* connectionEventCallback_;
NotificationQueue<QueueMessage> queue_; NotificationQueue<QueueMessage> queue_;
}; };
...@@ -738,6 +821,7 @@ class AsyncServerSocket : public DelayedDestruction ...@@ -738,6 +821,7 @@ class AsyncServerSocket : public DelayedDestruction
bool reusePortEnabled_{false}; bool reusePortEnabled_{false};
bool closeOnExec_; bool closeOnExec_;
ShutdownSocketSet* shutdownSocketSet_; ShutdownSocketSet* shutdownSocketSet_;
ConnectionEventCallback* connectionEventCallback_{nullptr};
}; };
} // folly } // folly
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <folly/io/async/AsyncSocket.h> #include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/RWSpinLock.h>
#include <folly/SocketAddress.h> #include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h> #include <folly/io/IOBuf.h>
...@@ -1452,6 +1453,113 @@ TEST(AsyncSocket, ConnectReadUninstallRead) { ...@@ -1452,6 +1453,113 @@ TEST(AsyncSocket, ConnectReadUninstallRead) {
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// AsyncServerSocket tests // AsyncServerSocket tests
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
namespace {
/**
* Helper ConnectionEventCallback class for the test code.
* It maintains counters protected by a spin lock.
*/
class TestConnectionEventCallback :
public AsyncServerSocket::ConnectionEventCallback {
public:
virtual void onConnectionAccepted(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionAccepted_++;
}
virtual void onConnectionAcceptError(const int err) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionAcceptedError_++;
}
virtual void onConnectionDropped(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionDropped_++;
}
virtual void onConnectionEnqueuedForAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionEnqueuedForAcceptCallback_++;
}
virtual void onConnectionDequeuedByAcceptCallback(
const int socket,
const SocketAddress& addr) noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
connectionDequeuedByAcceptCallback_++;
}
virtual void onBackoffStarted() noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
backoffStarted_++;
}
virtual void onBackoffEnded() noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
backoffEnded_++;
}
virtual void onBackoffError() noexcept override {
folly::RWSpinLock::WriteHolder holder(spinLock_);
backoffError_++;
}
unsigned int getConnectionAccepted() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionAccepted_;
}
unsigned int getConnectionAcceptedError() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionAcceptedError_;
}
unsigned int getConnectionDropped() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionDropped_;
}
unsigned int getConnectionEnqueuedForAcceptCallback() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionEnqueuedForAcceptCallback_;
}
unsigned int getConnectionDequeuedByAcceptCallback() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return connectionDequeuedByAcceptCallback_;
}
unsigned int getBackoffStarted() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return backoffStarted_;
}
unsigned int getBackoffEnded() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return backoffEnded_;
}
unsigned int getBackoffError() const {
folly::RWSpinLock::ReadHolder holder(spinLock_);
return backoffError_;
}
private:
mutable folly::RWSpinLock spinLock_;
unsigned int connectionAccepted_{0};
unsigned int connectionAcceptedError_{0};
unsigned int connectionDropped_{0};
unsigned int connectionEnqueuedForAcceptCallback_{0};
unsigned int connectionDequeuedByAcceptCallback_{0};
unsigned int backoffStarted_{0};
unsigned int backoffEnded_{0};
unsigned int backoffError_{0};
};
/** /**
* Helper AcceptCallback class for the test code * Helper AcceptCallback class for the test code
...@@ -1552,6 +1660,7 @@ class TestAcceptCallback : public AsyncServerSocket::AcceptCallback { ...@@ -1552,6 +1660,7 @@ class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
std::deque<EventInfo> events_; std::deque<EventInfo> events_;
}; };
}
/** /**
* Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
...@@ -2043,3 +2152,46 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) { ...@@ -2043,3 +2152,46 @@ TEST(AsyncSocketTest, UnixDomainSocketTest) {
int flags = fcntl(fd, F_GETFL, 0); int flags = fcntl(fd, F_GETFL, 0);
CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK); CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
} }
TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
EventBase eventBase;
TestConnectionEventCallback connectionEventCallback;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->setConnectionEventCallback(&connectionEventCallback);
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
// Validate the connection event counters
ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
ASSERT_EQ(
connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment