Commit b31ebff2 authored by Orvid King's avatar Orvid King Committed by Facebook Github Bot

NetworkSocket support for AsyncServerSocket

Summary:
The end.

(Note: this ignores all push blocking failures!)

Reviewed By: yfeldblum

Differential Revision: D12878966

fbshipit-source-id: 23cf64b9f66560ddeb87631f5ab955334944ea55
parent c73706ee
No related merge requests found
...@@ -35,8 +35,6 @@ ...@@ -35,8 +35,6 @@
#include <string.h> #include <string.h>
#include <sys/types.h> #include <sys/types.h>
namespace fsp = folly::portability::sockets;
namespace folly { namespace folly {
#ifndef TCP_SAVE_SYN #ifndef TCP_SAVE_SYN
...@@ -58,27 +56,6 @@ const uint32_t AsyncServerSocket::kDefaultMaxAcceptAtOnce; ...@@ -58,27 +56,6 @@ const uint32_t AsyncServerSocket::kDefaultMaxAcceptAtOnce;
const uint32_t AsyncServerSocket::kDefaultCallbackAcceptAtOnce; const uint32_t AsyncServerSocket::kDefaultCallbackAcceptAtOnce;
const uint32_t AsyncServerSocket::kDefaultMaxMessagesInQueue; const uint32_t AsyncServerSocket::kDefaultMaxMessagesInQueue;
int setCloseOnExec(int fd, int value) {
// Read the current flags
int old_flags = fcntl(fd, F_GETFD, 0);
// If reading the flags failed, return error indication now
if (old_flags < 0) {
return -1;
}
// Set just the flag we want to set
int new_flags;
if (value != 0) {
new_flags = old_flags | FD_CLOEXEC;
} else {
new_flags = old_flags & ~FD_CLOEXEC;
}
// Store modified flag word in the descriptor
return fcntl(fd, F_SETFD, new_flags);
}
void AsyncServerSocket::RemoteAcceptor::start( void AsyncServerSocket::RemoteAcceptor::start(
EventBase* eventBase, EventBase* eventBase,
uint32_t maxAtOnce, uint32_t maxAtOnce,
...@@ -292,7 +269,8 @@ void AsyncServerSocket::detachEventBase() { ...@@ -292,7 +269,8 @@ void AsyncServerSocket::detachEventBase() {
} }
} }
void AsyncServerSocket::useExistingSockets(const std::vector<int>& fds) { void AsyncServerSocket::useExistingSockets(
const std::vector<NetworkSocket>& fds) {
if (eventBase_) { if (eventBase_) {
eventBase_->dcheckIsInEventBaseThread(); eventBase_->dcheckIsInEventBaseThread();
} }
...@@ -314,7 +292,7 @@ void AsyncServerSocket::useExistingSockets(const std::vector<int>& fds) { ...@@ -314,7 +292,7 @@ void AsyncServerSocket::useExistingSockets(const std::vector<int>& fds) {
#if __linux__ #if __linux__
if (noTransparentTls_) { if (noTransparentTls_) {
// Ignore return value, errors are ok // Ignore return value, errors are ok
setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0); netops::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
} }
#endif #endif
...@@ -324,19 +302,19 @@ void AsyncServerSocket::useExistingSockets(const std::vector<int>& fds) { ...@@ -324,19 +302,19 @@ void AsyncServerSocket::useExistingSockets(const std::vector<int>& fds) {
} }
} }
void AsyncServerSocket::useExistingSocket(int fd) { void AsyncServerSocket::useExistingSocket(NetworkSocket fd) {
useExistingSockets({fd}); useExistingSockets({fd});
} }
void AsyncServerSocket::bindSocket( void AsyncServerSocket::bindSocket(
int fd, NetworkSocket fd,
const SocketAddress& address, const SocketAddress& address,
bool isExistingSocket) { bool isExistingSocket) {
sockaddr_storage addrStorage; sockaddr_storage addrStorage;
address.getAddress(&addrStorage); address.getAddress(&addrStorage);
sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage); sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
if (fsp::bind(fd, saddr, address.getActualSize()) != 0) { if (netops::bind(fd, saddr, address.getActualSize()) != 0) {
if (!isExistingSocket) { if (!isExistingSocket) {
closeNoInt(fd); closeNoInt(fd);
} }
...@@ -347,7 +325,7 @@ void AsyncServerSocket::bindSocket( ...@@ -347,7 +325,7 @@ void AsyncServerSocket::bindSocket(
#if __linux__ #if __linux__
if (noTransparentTls_) { if (noTransparentTls_) {
// Ignore return value, errors are ok // Ignore return value, errors are ok
setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0); netops::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
} }
#endif #endif
...@@ -378,7 +356,7 @@ void AsyncServerSocket::bind(const SocketAddress& address) { ...@@ -378,7 +356,7 @@ void AsyncServerSocket::bind(const SocketAddress& address) {
// However, in the normal case we need to create a new socket now. // However, in the normal case we need to create a new socket now.
// Don't set socket_ yet, so that socket_ will remain uninitialized if an // Don't set socket_ yet, so that socket_ will remain uninitialized if an
// error occurs. // error occurs.
int fd; NetworkSocket fd;
if (sockets_.size() == 0) { if (sockets_.size() == 0) {
fd = createSocket(address.getFamily()); fd = createSocket(address.getFamily());
} else if (sockets_.size() == 1) { } else if (sockets_.size() == 1) {
...@@ -409,7 +387,7 @@ void AsyncServerSocket::bind( ...@@ -409,7 +387,7 @@ void AsyncServerSocket::bind(
for (const IPAddress& ipAddress : ipAddresses) { for (const IPAddress& ipAddress : ipAddresses) {
SocketAddress address(ipAddress.toFullyQualified(), port); SocketAddress address(ipAddress.toFullyQualified(), port);
int fd = createSocket(address.getFamily()); auto fd = createSocket(address.getFamily());
bindSocket(fd, address, false); bindSocket(fd, address, false);
} }
...@@ -443,12 +421,12 @@ void AsyncServerSocket::bind(uint16_t port) { ...@@ -443,12 +421,12 @@ void AsyncServerSocket::bind(uint16_t port) {
}; };
auto setupAddress = [&](struct addrinfo* res) { auto setupAddress = [&](struct addrinfo* res) {
int s = fsp::socket(res->ai_family, res->ai_socktype, res->ai_protocol); auto s = netops::socket(res->ai_family, res->ai_socktype, res->ai_protocol);
// IPv6/IPv4 may not be supported by the kernel // IPv6/IPv4 may not be supported by the kernel
if (s < 0 && errno == EAFNOSUPPORT) { if (s == NetworkSocket() && errno == EAFNOSUPPORT) {
return; return;
} }
CHECK_GE(s, 0); CHECK_NE(s, NetworkSocket());
try { try {
setupSocket(s, res->ai_family); setupSocket(s, res->ai_family);
...@@ -461,11 +439,12 @@ void AsyncServerSocket::bind(uint16_t port) { ...@@ -461,11 +439,12 @@ void AsyncServerSocket::bind(uint16_t port) {
int v6only = 1; int v6only = 1;
CHECK( CHECK(
0 == 0 ==
setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only))); netops::setsockopt(
s, IPPROTO_IPV6, IPV6_V6ONLY, &v6only, sizeof(v6only)));
} }
// Bind to the socket // Bind to the socket
if (fsp::bind(s, res->ai_addr, socklen_t(res->ai_addrlen)) != 0) { if (netops::bind(s, res->ai_addr, socklen_t(res->ai_addrlen)) != 0) {
folly::throwSystemError( folly::throwSystemError(
errno, errno,
"failed to bind to async server socket for port ", "failed to bind to async server socket for port ",
...@@ -477,7 +456,7 @@ void AsyncServerSocket::bind(uint16_t port) { ...@@ -477,7 +456,7 @@ void AsyncServerSocket::bind(uint16_t port) {
#if __linux__ #if __linux__
if (noTransparentTls_) { if (noTransparentTls_) {
// Ignore return value, errors are ok // Ignore return value, errors are ok
setsockopt(s, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0); netops::setsockopt(s, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
} }
#endif #endif
...@@ -524,7 +503,7 @@ void AsyncServerSocket::bind(uint16_t port) { ...@@ -524,7 +503,7 @@ void AsyncServerSocket::bind(uint16_t port) {
// were opened, then restarting from scratch. // were opened, then restarting from scratch.
if (port == 0 && !sockets_.empty() && tries != kNumTries) { if (port == 0 && !sockets_.empty() && tries != kNumTries) {
for (const auto& socket : sockets_) { for (const auto& socket : sockets_) {
if (socket.socket_ <= 0) { if (socket.socket_ == NetworkSocket()) {
continue; continue;
} else if ( } else if (
const auto shutdownSocketSet = wShutdownSocketSet_.lock()) { const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
...@@ -558,7 +537,7 @@ void AsyncServerSocket::listen(int backlog) { ...@@ -558,7 +537,7 @@ void AsyncServerSocket::listen(int backlog) {
// Start listening // Start listening
for (auto& handler : sockets_) { for (auto& handler : sockets_) {
if (fsp::listen(handler.socket_, backlog) == -1) { if (netops::listen(handler.socket_, backlog) == -1) {
folly::throwSystemError(errno, "failed to listen on async server socket"); folly::throwSystemError(errno, "failed to listen on async server socket");
} }
} }
...@@ -735,9 +714,9 @@ void AsyncServerSocket::pauseAccepting() { ...@@ -735,9 +714,9 @@ void AsyncServerSocket::pauseAccepting() {
} }
} }
int AsyncServerSocket::createSocket(int family) { NetworkSocket AsyncServerSocket::createSocket(int family) {
int fd = fsp::socket(family, SOCK_STREAM, 0); auto fd = netops::socket(family, SOCK_STREAM, 0);
if (fd == -1) { if (fd == NetworkSocket()) {
folly::throwSystemError(errno, "error creating async server socket"); folly::throwSystemError(errno, "error creating async server socket");
} }
...@@ -762,12 +741,12 @@ void AsyncServerSocket::setTosReflect(bool enable) { ...@@ -762,12 +741,12 @@ void AsyncServerSocket::setTosReflect(bool enable) {
} }
for (auto& handler : sockets_) { for (auto& handler : sockets_) {
if (handler.socket_ < 0) { if (handler.socket_ == NetworkSocket()) {
continue; continue;
} }
int val = (enable) ? 1 : 0; int val = (enable) ? 1 : 0;
int ret = setsockopt( int ret = netops::setsockopt(
handler.socket_, IPPROTO_TCP, TCP_SAVE_SYN, &val, sizeof(val)); handler.socket_, IPPROTO_TCP, TCP_SAVE_SYN, &val, sizeof(val));
if (ret == 0) { if (ret == 0) {
...@@ -779,15 +758,16 @@ void AsyncServerSocket::setTosReflect(bool enable) { ...@@ -779,15 +758,16 @@ void AsyncServerSocket::setTosReflect(bool enable) {
tosReflect_ = true; tosReflect_ = true;
} }
void AsyncServerSocket::setupSocket(int fd, int family) { void AsyncServerSocket::setupSocket(NetworkSocket fd, int family) {
// Put the socket in non-blocking mode // Put the socket in non-blocking mode
if (fcntl(fd, F_SETFL, O_NONBLOCK) != 0) { if (netops::set_socket_non_blocking(fd) != 0) {
folly::throwSystemError(errno, "failed to put socket in non-blocking mode"); folly::throwSystemError(errno, "failed to put socket in non-blocking mode");
} }
// Set reuseaddr to avoid 2MSL delay on server restart // Set reuseaddr to avoid 2MSL delay on server restart
int one = 1; int one = 1;
if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) != 0) { if (netops::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) !=
0) {
// This isn't a fatal error; just log an error message and continue // This isn't a fatal error; just log an error message and continue
LOG(ERROR) << "failed to set SO_REUSEADDR on async server socket " << errno; LOG(ERROR) << "failed to set SO_REUSEADDR on async server socket " << errno;
} }
...@@ -795,7 +775,8 @@ void AsyncServerSocket::setupSocket(int fd, int family) { ...@@ -795,7 +775,8 @@ void AsyncServerSocket::setupSocket(int fd, int family) {
// Set reuseport to support multiple accept threads // Set reuseport to support multiple accept threads
int zero = 0; int zero = 0;
if (reusePortEnabled_ && if (reusePortEnabled_ &&
setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(int)) != 0) { netops::setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(int)) !=
0) {
LOG(ERROR) << "failed to set SO_REUSEPORT on async server socket " LOG(ERROR) << "failed to set SO_REUSEPORT on async server socket "
<< errnoStr(errno); << errnoStr(errno);
#ifdef WIN32 #ifdef WIN32
...@@ -809,7 +790,7 @@ void AsyncServerSocket::setupSocket(int fd, int family) { ...@@ -809,7 +790,7 @@ void AsyncServerSocket::setupSocket(int fd, int family) {
} }
// Set keepalive as desired // Set keepalive as desired
if (setsockopt( if (netops::setsockopt(
fd, fd,
SOL_SOCKET, SOL_SOCKET,
SO_KEEPALIVE, SO_KEEPALIVE,
...@@ -820,7 +801,7 @@ void AsyncServerSocket::setupSocket(int fd, int family) { ...@@ -820,7 +801,7 @@ void AsyncServerSocket::setupSocket(int fd, int family) {
} }
// Setup FD_CLOEXEC flag // Setup FD_CLOEXEC flag
if (closeOnExec_ && (-1 == folly::setCloseOnExec(fd, closeOnExec_))) { if (closeOnExec_ && (-1 == netops::set_socket_close_on_exec(fd))) {
LOG(ERROR) << "failed to set FD_CLOEXEC on async server socket: " LOG(ERROR) << "failed to set FD_CLOEXEC on async server socket: "
<< errnoStr(errno); << errnoStr(errno);
} }
...@@ -829,7 +810,8 @@ void AsyncServerSocket::setupSocket(int fd, int family) { ...@@ -829,7 +810,8 @@ void AsyncServerSocket::setupSocket(int fd, int family) {
// See http://lists.danga.com/pipermail/memcached/2005-March/001240.html // See http://lists.danga.com/pipermail/memcached/2005-March/001240.html
#ifndef TCP_NOPUSH #ifndef TCP_NOPUSH
if (family != AF_UNIX) { if (family != AF_UNIX) {
if (setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) != 0) { if (netops::setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)) !=
0) {
// This isn't a fatal error; just log an error message and continue // This isn't a fatal error; just log an error message and continue
LOG(ERROR) << "failed to set TCP_NODELAY on async server socket: " LOG(ERROR) << "failed to set TCP_NODELAY on async server socket: "
<< errnoStr(errno); << errnoStr(errno);
...@@ -854,7 +836,7 @@ void AsyncServerSocket::setupSocket(int fd, int family) { ...@@ -854,7 +836,7 @@ void AsyncServerSocket::setupSocket(int fd, int family) {
void AsyncServerSocket::handlerReady( void AsyncServerSocket::handlerReady(
uint16_t /* events */, uint16_t /* events */,
int fd, NetworkSocket fd,
sa_family_t addressFamily) noexcept { sa_family_t addressFamily) noexcept {
assert(!callbacks_.empty()); assert(!callbacks_.empty());
DestructorGuard dg(this); DestructorGuard dg(this);
...@@ -876,30 +858,31 @@ void AsyncServerSocket::handlerReady( ...@@ -876,30 +858,31 @@ void AsyncServerSocket::handlerReady(
// Accept a new client socket // Accept a new client socket
#ifdef SOCK_NONBLOCK #ifdef SOCK_NONBLOCK
int clientSocket = accept4(fd, saddr, &addrLen, SOCK_NONBLOCK); auto clientSocket = NetworkSocket::fromFd(
accept4(fd.toFd(), saddr, &addrLen, SOCK_NONBLOCK));
#else #else
int clientSocket = accept(fd, saddr, &addrLen); auto clientSocket = netops::accept(fd, saddr, &addrLen);
#endif #endif
address.setFromSockaddr(saddr, addrLen); address.setFromSockaddr(saddr, addrLen);
if (clientSocket >= 0 && connectionEventCallback_) { if (clientSocket != NetworkSocket() && connectionEventCallback_) {
connectionEventCallback_->onConnectionAccepted(clientSocket, address); connectionEventCallback_->onConnectionAccepted(clientSocket, address);
} }
// Connection accepted, get the SYN packet from the client if // Connection accepted, get the SYN packet from the client if
// TOS reflect is enabled // TOS reflect is enabled
if (kIsLinux && clientSocket >= 0 && tosReflect_) { if (kIsLinux && clientSocket != NetworkSocket() && tosReflect_) {
std::array<uint32_t, 64> buffer; std::array<uint32_t, 64> buffer;
socklen_t len = sizeof(buffer); socklen_t len = sizeof(buffer);
int ret = int ret = netops::getsockopt(
getsockopt(clientSocket, IPPROTO_TCP, TCP_SAVED_SYN, &buffer, &len); clientSocket, IPPROTO_TCP, TCP_SAVED_SYN, &buffer, &len);
if (ret == 0) { if (ret == 0) {
uint32_t tosWord = folly::Endian::big(buffer[0]); uint32_t tosWord = folly::Endian::big(buffer[0]);
if (addressFamily == AF_INET6) { if (addressFamily == AF_INET6) {
tosWord = (tosWord & 0x0FC00000) >> 20; tosWord = (tosWord & 0x0FC00000) >> 20;
ret = setsockopt( ret = netops::setsockopt(
clientSocket, clientSocket,
IPPROTO_IPV6, IPPROTO_IPV6,
IPV6_TCLASS, IPV6_TCLASS,
...@@ -907,7 +890,7 @@ void AsyncServerSocket::handlerReady( ...@@ -907,7 +890,7 @@ void AsyncServerSocket::handlerReady(
sizeof(tosWord)); sizeof(tosWord));
} else if (addressFamily == AF_INET) { } else if (addressFamily == AF_INET) {
tosWord = (tosWord & 0x00FC0000) >> 16; tosWord = (tosWord & 0x00FC0000) >> 16;
ret = setsockopt( ret = netops::setsockopt(
clientSocket, IPPROTO_IP, IP_TOS, &tosWord, sizeof(tosWord)); clientSocket, IPPROTO_IP, IP_TOS, &tosWord, sizeof(tosWord));
} }
...@@ -934,7 +917,7 @@ void AsyncServerSocket::handlerReady( ...@@ -934,7 +917,7 @@ void AsyncServerSocket::handlerReady(
acceptRate_ = 1; acceptRate_ = 1;
} else if (rand() > acceptRate_ * RAND_MAX) { } else if (rand() > acceptRate_ * RAND_MAX) {
++numDroppedConnections_; ++numDroppedConnections_;
if (clientSocket >= 0) { if (clientSocket != NetworkSocket()) {
closeNoInt(clientSocket); closeNoInt(clientSocket);
if (connectionEventCallback_) { if (connectionEventCallback_) {
connectionEventCallback_->onConnectionDropped( connectionEventCallback_->onConnectionDropped(
...@@ -945,7 +928,7 @@ void AsyncServerSocket::handlerReady( ...@@ -945,7 +928,7 @@ void AsyncServerSocket::handlerReady(
} }
} }
if (clientSocket < 0) { if (clientSocket == NetworkSocket()) {
if (errno == EAGAIN) { if (errno == EAGAIN) {
// No more sockets to accept right now. // No more sockets to accept right now.
// Check for this code first, since it's the most common. // Check for this code first, since it's the most common.
...@@ -971,7 +954,7 @@ void AsyncServerSocket::handlerReady( ...@@ -971,7 +954,7 @@ void AsyncServerSocket::handlerReady(
#ifndef SOCK_NONBLOCK #ifndef SOCK_NONBLOCK
// Explicitly set the new connection to non-blocking mode // Explicitly set the new connection to non-blocking mode
if (fcntl(clientSocket, F_SETFL, O_NONBLOCK) != 0) { if (netops::set_socket_non_blocking(clientSocket) != 0) {
closeNoInt(clientSocket); closeNoInt(clientSocket);
dispatchError( dispatchError(
"failed to set accepted socket to non-blocking mode", errno); "failed to set accepted socket to non-blocking mode", errno);
...@@ -992,7 +975,9 @@ void AsyncServerSocket::handlerReady( ...@@ -992,7 +975,9 @@ void AsyncServerSocket::handlerReady(
} }
} }
void AsyncServerSocket::dispatchSocket(int socket, SocketAddress&& address) { void AsyncServerSocket::dispatchSocket(
NetworkSocket socket,
SocketAddress&& address) {
uint32_t startingIndex = callbackIndex_; uint32_t startingIndex = callbackIndex_;
// Short circuit if the callback is in the primary EventBase thread // Short circuit if the callback is in the primary EventBase thread
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/io/async/EventHandler.h> #include <folly/io/async/EventHandler.h>
#include <folly/io/async/NotificationQueue.h> #include <folly/io/async/NotificationQueue.h>
#include <folly/net/NetOps.h>
#include <folly/net/NetworkSocket.h>
#include <folly/portability/Sockets.h> #include <folly/portability/Sockets.h>
#include <limits.h> #include <limits.h>
...@@ -85,6 +87,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -85,6 +87,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
virtual void onConnectionAccepted( virtual void onConnectionAccepted(
const int socket, const int socket,
const SocketAddress& addr) noexcept = 0; const SocketAddress& addr) noexcept = 0;
void onConnectionAccepted(
const NetworkSocket socket,
const SocketAddress& addr) noexcept {
onConnectionAccepted(socket.toFd(), addr);
}
/** /**
* onConnectionAcceptError() is called when an error occurred accepting * onConnectionAcceptError() is called when an error occurred accepting
...@@ -99,6 +106,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -99,6 +106,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
virtual void onConnectionDropped( virtual void onConnectionDropped(
const int socket, const int socket,
const SocketAddress& addr) noexcept = 0; const SocketAddress& addr) noexcept = 0;
void onConnectionDropped(
const NetworkSocket socket,
const SocketAddress& addr) noexcept {
onConnectionDropped(socket.toFd(), addr);
}
/** /**
* onConnectionEnqueuedForAcceptorCallback() is called when the * onConnectionEnqueuedForAcceptorCallback() is called when the
...@@ -107,6 +119,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -107,6 +119,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
virtual void onConnectionEnqueuedForAcceptorCallback( virtual void onConnectionEnqueuedForAcceptorCallback(
const int socket, const int socket,
const SocketAddress& addr) noexcept = 0; const SocketAddress& addr) noexcept = 0;
void onConnectionEnqueuedForAcceptorCallback(
const NetworkSocket socket,
const SocketAddress& addr) noexcept {
onConnectionEnqueuedForAcceptorCallback(socket.toFd(), addr);
}
/** /**
* onConnectionDequeuedByAcceptorCallback() is called when the * onConnectionDequeuedByAcceptorCallback() is called when the
...@@ -115,6 +132,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -115,6 +132,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
virtual void onConnectionDequeuedByAcceptorCallback( virtual void onConnectionDequeuedByAcceptorCallback(
const int socket, const int socket,
const SocketAddress& addr) noexcept = 0; const SocketAddress& addr) noexcept = 0;
void onConnectionDequeuedByAcceptorCallback(
const NetworkSocket socket,
const SocketAddress& addr) noexcept {
onConnectionDequeuedByAcceptorCallback(socket.toFd(), addr);
}
/** /**
* onBackoffStarted is called when the socket has successfully started * onBackoffStarted is called when the socket has successfully started
...@@ -158,6 +180,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -158,6 +180,11 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
virtual void connectionAccepted( virtual void connectionAccepted(
int fd, int fd,
const SocketAddress& clientAddr) noexcept = 0; const SocketAddress& clientAddr) noexcept = 0;
void connectionAccepted(
NetworkSocket fd,
const SocketAddress& clientAddr) noexcept {
connectionAccepted(fd.toFd(), clientAddr);
}
/** /**
* acceptError() is called if an error occurs while accepting. * acceptError() is called if an error occurs while accepting.
...@@ -292,8 +319,21 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -292,8 +319,21 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
* On error a TTransportException will be thrown and the caller will retain * On error a TTransportException will be thrown and the caller will retain
* ownership of the file descriptor. * ownership of the file descriptor.
*/ */
void useExistingSocket(int fd); void useExistingSocket(int fd) {
void useExistingSockets(const std::vector<int>& fds); useExistingSocket(NetworkSocket::fromFd(fd));
}
void useExistingSocket(NetworkSocket fd);
void useExistingSockets(const std::vector<int>& fds) {
// This isn't a big enough perf impact to matter, as it's only really used
// for long-lived servers :)
std::vector<NetworkSocket> socks;
socks.reserve(fds.size());
for (size_t i = 0; i < fds.size(); ++i) {
socks.push_back(NetworkSocket::fromFd(fds[i]));
}
useExistingSockets(socks);
}
void useExistingSockets(const std::vector<NetworkSocket>& fds);
/** /**
* Return the underlying file descriptor * Return the underlying file descriptor
...@@ -301,7 +341,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -301,7 +341,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
std::vector<int> getSockets() const { std::vector<int> getSockets() const {
std::vector<int> sockets; std::vector<int> sockets;
for (auto& handler : sockets_) { for (auto& handler : sockets_) {
sockets.push_back(handler.socket_); sockets.push_back(handler.socket_.toFd());
} }
return sockets; return sockets;
} }
...@@ -317,7 +357,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -317,7 +357,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
if (sockets_.size() == 0) { if (sockets_.size() == 0) {
return -1; return -1;
} else { } else {
return sockets_[0].socket_; return sockets_[0].socket_.toFd();
} }
} }
...@@ -627,12 +667,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -627,12 +667,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
keepAliveEnabled_ = enabled; keepAliveEnabled_ = enabled;
for (auto& handler : sockets_) { for (auto& handler : sockets_) {
if (handler.socket_ < 0) { if (handler.socket_ == NetworkSocket()) {
continue; continue;
} }
int val = (enabled) ? 1 : 0; int val = (enabled) ? 1 : 0;
if (setsockopt( if (netops::setsockopt(
handler.socket_, SOL_SOCKET, SO_KEEPALIVE, &val, sizeof(val)) != handler.socket_, SOL_SOCKET, SO_KEEPALIVE, &val, sizeof(val)) !=
0) { 0) {
LOG(ERROR) << "failed to set SO_KEEPALIVE on async server socket: %s" LOG(ERROR) << "failed to set SO_KEEPALIVE on async server socket: %s"
...@@ -656,12 +696,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -656,12 +696,12 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
reusePortEnabled_ = enabled; reusePortEnabled_ = enabled;
for (auto& handler : sockets_) { for (auto& handler : sockets_) {
if (handler.socket_ < 0) { if (handler.socket_ == NetworkSocket()) {
continue; continue;
} }
int val = (enabled) ? 1 : 0; int val = (enabled) ? 1 : 0;
if (setsockopt( if (netops::setsockopt(
handler.socket_, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val)) != handler.socket_, SOL_SOCKET, SO_REUSEPORT, &val, sizeof(val)) !=
0) { 0) {
LOG(ERROR) << "failed to set SO_REUSEPORT on async server socket " LOG(ERROR) << "failed to set SO_REUSEPORT on async server socket "
...@@ -743,7 +783,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -743,7 +783,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
struct QueueMessage { struct QueueMessage {
MessageType type; MessageType type;
int fd; NetworkSocket fd;
int err; int err;
SocketAddress address; SocketAddress address;
std::string msg; std::string msg;
...@@ -800,13 +840,18 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -800,13 +840,18 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
class BackoffTimeout; class BackoffTimeout;
virtual void virtual void handlerReady(
handlerReady(uint16_t events, int socket, sa_family_t family) noexcept; uint16_t events,
NetworkSocket socket,
int createSocket(int family); sa_family_t family) noexcept;
void setupSocket(int fd, int family);
void bindSocket(int fd, const SocketAddress& address, bool isExistingSocket); NetworkSocket createSocket(int family);
void dispatchSocket(int socket, SocketAddress&& address); void setupSocket(NetworkSocket fd, int family);
void bindSocket(
NetworkSocket fd,
const SocketAddress& address,
bool isExistingSocket);
void dispatchSocket(NetworkSocket socket, SocketAddress&& address);
void dispatchError(const char* msg, int errnoValue); void dispatchError(const char* msg, int errnoValue);
void enterBackoff(); void enterBackoff();
void backoffTimeoutExpired(); void backoffTimeoutExpired();
...@@ -825,7 +870,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -825,7 +870,7 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
struct ServerEventHandler : public EventHandler { struct ServerEventHandler : public EventHandler {
ServerEventHandler( ServerEventHandler(
EventBase* eventBase, EventBase* eventBase,
int socket, NetworkSocket socket,
AsyncServerSocket* parent, AsyncServerSocket* parent,
sa_family_t addressFamily) sa_family_t addressFamily)
: EventHandler(eventBase, socket), : EventHandler(eventBase, socket),
...@@ -861,14 +906,14 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase { ...@@ -861,14 +906,14 @@ class AsyncServerSocket : public DelayedDestruction, public AsyncSocketBase {
} }
EventBase* eventBase_; EventBase* eventBase_;
int socket_; NetworkSocket socket_;
AsyncServerSocket* parent_; AsyncServerSocket* parent_;
sa_family_t addressFamily_; sa_family_t addressFamily_;
}; };
EventBase* eventBase_; EventBase* eventBase_;
std::vector<ServerEventHandler> sockets_; std::vector<ServerEventHandler> sockets_;
std::vector<int> pendingCloseSockets_; std::vector<NetworkSocket> pendingCloseSockets_;
bool accepting_; bool accepting_;
uint32_t maxAcceptAtOnce_; uint32_t maxAcceptAtOnce_;
uint32_t maxNumMsgsInQueue_; uint32_t maxNumMsgsInQueue_;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment