Commit c7b21fd1 authored by Orvid King's avatar Orvid King Committed by Facebook Github Bot

NetworkSocket support for AsyncSocket

Summary: This is a big one, but adds support for NetworkSocket to AsyncSocket itself.

Reviewed By: yfeldblum

Differential Revision: D12818493

fbshipit-source-id: d7b73f356414b006f4fc6b1c380a8fa1c32b9a46
parent 083c73bd
No related merge requests found
...@@ -326,7 +326,7 @@ void AsyncSSLSocket::init() { ...@@ -326,7 +326,7 @@ void AsyncSSLSocket::init() {
void AsyncSSLSocket::closeNow() { void AsyncSSLSocket::closeNow() {
// Close the SSL connection. // Close the SSL connection.
if (ssl_ != nullptr && fd_ != -1) { if (ssl_ != nullptr && fd_ != NetworkSocket()) {
int rc = SSL_shutdown(ssl_.get()); int rc = SSL_shutdown(ssl_.get());
if (rc == 0) { if (rc == 0) {
rc = SSL_shutdown(ssl_.get()); rc = SSL_shutdown(ssl_.get());
......
...@@ -254,7 +254,10 @@ namespace { ...@@ -254,7 +254,10 @@ namespace {
static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback; static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
// Based on flags, signal the transparent handler to disable certain functions // Based on flags, signal the transparent handler to disable certain functions
void disableTransparentFunctions(int fd, bool noTransparentTls, bool noTSocks) { void disableTransparentFunctions(
NetworkSocket fd,
bool noTransparentTls,
bool noTSocks) {
(void)fd; (void)fd;
(void)noTransparentTls; (void)noTransparentTls;
(void)noTSocks; (void)noTSocks;
...@@ -262,12 +265,12 @@ void disableTransparentFunctions(int fd, bool noTransparentTls, bool noTSocks) { ...@@ -262,12 +265,12 @@ void disableTransparentFunctions(int fd, bool noTransparentTls, bool noTSocks) {
if (noTransparentTls) { if (noTransparentTls) {
// Ignore return value, errors are ok // Ignore return value, errors are ok
VLOG(5) << "Disabling TTLS for fd " << fd; VLOG(5) << "Disabling TTLS for fd " << fd;
::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0); netops::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
} }
if (noTSocks) { if (noTSocks) {
VLOG(5) << "Disabling TSOCKS for fd " << fd; VLOG(5) << "Disabling TSOCKS for fd " << fd;
// Ignore return value, errors are ok // Ignore return value, errors are ok
::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0); netops::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
} }
#endif #endif
} }
...@@ -309,7 +312,10 @@ AsyncSocket::AsyncSocket( ...@@ -309,7 +312,10 @@ AsyncSocket::AsyncSocket(
connect(nullptr, ip, port, connectTimeout); connect(nullptr, ip, port, connectTimeout);
} }
AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId) AsyncSocket::AsyncSocket(
EventBase* evb,
NetworkSocket fd,
uint32_t zeroCopyBufId)
: zeroCopyBufId_(zeroCopyBufId), : zeroCopyBufId_(zeroCopyBufId),
eventBase_(evb), eventBase_(evb),
writeTimeout_(this, evb), writeTimeout_(this, evb),
...@@ -341,7 +347,7 @@ void AsyncSocket::init() { ...@@ -341,7 +347,7 @@ void AsyncSocket::init() {
shutdownFlags_ = 0; shutdownFlags_ = 0;
state_ = StateEnum::UNINIT; state_ = StateEnum::UNINIT;
eventFlags_ = EventHandler::NONE; eventFlags_ = EventHandler::NONE;
fd_ = -1; fd_ = NetworkSocket();
sendTimeout_ = 0; sendTimeout_ = 0;
maxReadsPerEvent_ = 16; maxReadsPerEvent_ = 16;
connectCallback_ = nullptr; connectCallback_ = nullptr;
...@@ -381,14 +387,14 @@ int AsyncSocket::detachFd() { ...@@ -381,14 +387,14 @@ int AsyncSocket::detachFd() {
if (const auto socketSet = wShutdownSocketSet_.lock()) { if (const auto socketSet = wShutdownSocketSet_.lock()) {
socketSet->remove(fd_); socketSet->remove(fd_);
} }
int fd = fd_; auto fd = fd_;
fd_ = -1; fd_ = NetworkSocket();
// Call closeNow() to invoke all pending callbacks with an error. // Call closeNow() to invoke all pending callbacks with an error.
closeNow(); closeNow();
// Update the EventHandler to stop using this fd. // Update the EventHandler to stop using this fd.
// This can only be done after closeNow() unregisters the handler. // This can only be done after closeNow() unregisters the handler.
ioHandler_.changeHandlerFD(-1); ioHandler_.changeHandlerFD(-1);
return fd; return fd.toFd();
} }
const folly::SocketAddress& AsyncSocket::anyAddress() { const folly::SocketAddress& AsyncSocket::anyAddress() {
...@@ -406,11 +412,11 @@ void AsyncSocket::setShutdownSocketSet( ...@@ -406,11 +412,11 @@ void AsyncSocket::setShutdownSocketSet(
return; return;
} }
if (shutdownSocketSet && fd_ != -1) { if (shutdownSocketSet && fd_ != NetworkSocket()) {
shutdownSocketSet->remove(fd_); shutdownSocketSet->remove(fd_);
} }
if (newSS && fd_ != -1) { if (newSS && fd_ != NetworkSocket()) {
newSS->add(fd_); newSS->add(fd_);
} }
...@@ -418,7 +424,7 @@ void AsyncSocket::setShutdownSocketSet( ...@@ -418,7 +424,7 @@ void AsyncSocket::setShutdownSocketSet(
} }
void AsyncSocket::setCloseOnExec() { void AsyncSocket::setCloseOnExec() {
int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC); int rv = netops::set_socket_close_on_exec(fd_);
if (rv != 0) { if (rv != 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
throw AsyncSocketException( throw AsyncSocketException(
...@@ -449,7 +455,7 @@ void AsyncSocket::connect( ...@@ -449,7 +455,7 @@ void AsyncSocket::connect(
// Make connect end time at least >= connectStartTime. // Make connect end time at least >= connectStartTime.
connectEndTime_ = connectStartTime_; connectEndTime_ = connectStartTime_;
assert(fd_ == -1); assert(fd_ == NetworkSocket());
state_ = StateEnum::CONNECTING; state_ = StateEnum::CONNECTING;
connectCallback_ = callback; connectCallback_ = callback;
...@@ -462,8 +468,8 @@ void AsyncSocket::connect( ...@@ -462,8 +468,8 @@ void AsyncSocket::connect(
// constant (PF_xxx) rather than an address family (AF_xxx), but the // constant (PF_xxx) rather than an address family (AF_xxx), but the
// distinction is mainly just historical. In pretty much all // distinction is mainly just historical. In pretty much all
// implementations the PF_foo and AF_foo constants are identical. // implementations the PF_foo and AF_foo constants are identical.
fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0); fd_ = netops::socket(address.getFamily(), SOCK_STREAM, 0);
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
auto errnoCopy = errno; auto errnoCopy = errno;
throw AsyncSocketException( throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR, AsyncSocketException::INTERNAL_ERROR,
...@@ -479,15 +485,7 @@ void AsyncSocket::connect( ...@@ -479,15 +485,7 @@ void AsyncSocket::connect(
setCloseOnExec(); setCloseOnExec();
// Put the socket in non-blocking mode // Put the socket in non-blocking mode
int flags = fcntl(fd_, F_GETFL, 0); int rv = netops::set_socket_non_blocking(fd_);
if (flags == -1) {
auto errnoCopy = errno;
throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
withAddr("failed to get socket flags"),
errnoCopy);
}
int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
if (rv == -1) { if (rv == -1) {
auto errnoCopy = errno; auto errnoCopy = errno;
throw AsyncSocketException( throw AsyncSocketException(
...@@ -498,7 +496,7 @@ void AsyncSocket::connect( ...@@ -498,7 +496,7 @@ void AsyncSocket::connect(
#if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE) #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
// iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
rv = fcntl(fd_, F_SETNOSIGPIPE, 1); rv = fcntl(fd_.toFd(), F_SETNOSIGPIPE, 1);
if (rv == -1) { if (rv == -1) {
auto errnoCopy = errno; auto errnoCopy = errno;
throw AsyncSocketException( throw AsyncSocketException(
...@@ -524,7 +522,8 @@ void AsyncSocket::connect( ...@@ -524,7 +522,8 @@ void AsyncSocket::connect(
// bind the socket // bind the socket
if (bindAddr != anyAddress()) { if (bindAddr != anyAddress()) {
int one = 1; int one = 1;
if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) { if (netops::setsockopt(
fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
auto errnoCopy = errno; auto errnoCopy = errno;
doClose(); doClose();
throw AsyncSocketException( throw AsyncSocketException(
...@@ -535,7 +534,7 @@ void AsyncSocket::connect( ...@@ -535,7 +534,7 @@ void AsyncSocket::connect(
bindAddr.getAddress(&addrStorage); bindAddr.getAddress(&addrStorage);
if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) { if (netops::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
doClose(); doClose();
throw AsyncSocketException( throw AsyncSocketException(
...@@ -598,7 +597,7 @@ void AsyncSocket::connect( ...@@ -598,7 +597,7 @@ void AsyncSocket::connect(
} }
int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) { int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
int rv = fsp::connect(fd_, saddr, len); int rv = netops::connect(fd_, saddr, len);
if (rv < 0) { if (rv < 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
if (errnoCopy == EINPROGRESS) { if (errnoCopy == EINPROGRESS) {
...@@ -861,12 +860,13 @@ bool AsyncSocket::setZeroCopy(bool enable) { ...@@ -861,12 +860,13 @@ bool AsyncSocket::setZeroCopy(bool enable) {
if (msgErrQueueSupported) { if (msgErrQueueSupported) {
zeroCopyVal_ = enable; zeroCopyVal_ = enable;
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
return false; return false;
} }
int val = enable ? 1 : 0; int val = enable ? 1 : 0;
int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val)); int ret =
netops::setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
// if enable == false, set zeroCopyEnabled_ = false regardless // if enable == false, set zeroCopyEnabled_ = false regardless
// if SO_ZEROCOPY is set or not // if SO_ZEROCOPY is set or not
...@@ -881,7 +881,7 @@ bool AsyncSocket::setZeroCopy(bool enable) { ...@@ -881,7 +881,7 @@ bool AsyncSocket::setZeroCopy(bool enable) {
if (ret) { if (ret) {
val = 0; val = 0;
socklen_t optlen = sizeof(val); socklen_t optlen = sizeof(val);
ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen); ret = netops::getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
if (!ret) { if (!ret) {
enable = val ? true : false; enable = val ? true : false;
...@@ -1282,8 +1282,8 @@ void AsyncSocket::closeNow() { ...@@ -1282,8 +1282,8 @@ void AsyncSocket::closeNow() {
immediateReadHandler_.cancelLoopCallback(); immediateReadHandler_.cancelLoopCallback();
} }
if (fd_ >= 0) { if (fd_ != NetworkSocket()) {
ioHandler_.changeHandlerFD(-1); ioHandler_.changeHandlerFD(NetworkSocket());
doClose(); doClose();
} }
...@@ -1324,7 +1324,7 @@ void AsyncSocket::closeNow() { ...@@ -1324,7 +1324,7 @@ void AsyncSocket::closeNow() {
void AsyncSocket::closeWithReset() { void AsyncSocket::closeWithReset() {
// Enable SO_LINGER, with the linger timeout set to 0. // Enable SO_LINGER, with the linger timeout set to 0.
// This will trigger a TCP reset when we close the socket. // This will trigger a TCP reset when we close the socket.
if (fd_ >= 0) { if (fd_ != NetworkSocket()) {
struct linger optLinger = {1, 0}; struct linger optLinger = {1, 0};
if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) { if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER " VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
...@@ -1394,7 +1394,7 @@ void AsyncSocket::shutdownWriteNow() { ...@@ -1394,7 +1394,7 @@ void AsyncSocket::shutdownWriteNow() {
} }
// Shutdown writes on the file descriptor // Shutdown writes on the file descriptor
shutdown(fd_, SHUT_WR); netops::shutdown(fd_, SHUT_WR);
// Immediately fail all write requests // Immediately fail all write requests
failAllWrites(socketShutdownForWritesEx); failAllWrites(socketShutdownForWritesEx);
...@@ -1441,26 +1441,26 @@ void AsyncSocket::shutdownWriteNow() { ...@@ -1441,26 +1441,26 @@ void AsyncSocket::shutdownWriteNow() {
} }
bool AsyncSocket::readable() const { bool AsyncSocket::readable() const {
if (fd_ == -1) { if (fd_ == NetworkSocket()) {
return false; return false;
} }
struct pollfd fds[1]; netops::PollDescriptor fds[1];
fds[0].fd = fd_; fds[0].fd = fd_;
fds[0].events = POLLIN; fds[0].events = POLLIN;
fds[0].revents = 0; fds[0].revents = 0;
int rc = poll(fds, 1, 0); int rc = netops::poll(fds, 1, 0);
return rc == 1; return rc == 1;
} }
bool AsyncSocket::writable() const { bool AsyncSocket::writable() const {
if (fd_ == -1) { if (fd_ == NetworkSocket()) {
return false; return false;
} }
struct pollfd fds[1]; netops::PollDescriptor fds[1];
fds[0].fd = fd_; fds[0].fd = fd_;
fds[0].events = POLLOUT; fds[0].events = POLLOUT;
fds[0].revents = 0; fds[0].revents = 0;
int rc = poll(fds, 1, 0); int rc = netops::poll(fds, 1, 0);
return rc == 1; return rc == 1;
} }
...@@ -1469,17 +1469,17 @@ bool AsyncSocket::isPending() const { ...@@ -1469,17 +1469,17 @@ bool AsyncSocket::isPending() const {
} }
bool AsyncSocket::hangup() const { bool AsyncSocket::hangup() const {
if (fd_ == -1) { if (fd_ == NetworkSocket()) {
// sanity check, no one should ask for hangup if we are not connected. // sanity check, no one should ask for hangup if we are not connected.
assert(false); assert(false);
return false; return false;
} }
#ifdef POLLRDHUP // Linux-only #ifdef POLLRDHUP // Linux-only
struct pollfd fds[1]; netops::PollDescriptor fds[1];
fds[0].fd = fd_; fds[0].fd = fd_;
fds[0].events = POLLRDHUP | POLLHUP; fds[0].events = POLLRDHUP | POLLHUP;
fds[0].revents = 0; fds[0].revents = 0;
poll(fds, 1, 0); netops::poll(fds, 1, 0);
return (fds[0].revents & (POLLRDHUP | POLLHUP)) != 0; return (fds[0].revents & (POLLRDHUP | POLLHUP)) != 0;
#else #else
return false; return false;
...@@ -1542,7 +1542,7 @@ bool AsyncSocket::isDetachable() const { ...@@ -1542,7 +1542,7 @@ bool AsyncSocket::isDetachable() const {
} }
void AsyncSocket::cacheAddresses() { void AsyncSocket::cacheAddresses() {
if (fd_ >= 0) { if (fd_ != NetworkSocket()) {
try { try {
cacheLocalAddress(); cacheLocalAddress();
cachePeerAddress(); cachePeerAddress();
...@@ -1587,14 +1587,15 @@ bool AsyncSocket::getTFOSucceded() const { ...@@ -1587,14 +1587,15 @@ bool AsyncSocket::getTFOSucceded() const {
} }
int AsyncSocket::setNoDelay(bool noDelay) { int AsyncSocket::setNoDelay(bool noDelay) {
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket " << this VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket " << this
<< "(state=" << state_ << ")"; << "(state=" << state_ << ")";
return EINVAL; return EINVAL;
} }
int value = noDelay ? 1 : 0; int value = noDelay ? 1 : 0;
if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) { if (netops::setsockopt(
fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
int errnoCopy = errno; int errnoCopy = errno;
VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket " << this VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket " << this
<< " (fd=" << fd_ << ", state=" << state_ << " (fd=" << fd_ << ", state=" << state_
...@@ -1610,13 +1611,13 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) { ...@@ -1610,13 +1611,13 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) {
#define TCP_CONGESTION 13 #define TCP_CONGESTION 13
#endif #endif
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open " VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
<< "socket " << this << "(state=" << state_ << ")"; << "socket " << this << "(state=" << state_ << ")";
return EINVAL; return EINVAL;
} }
if (setsockopt( if (netops::setsockopt(
fd_, fd_,
IPPROTO_TCP, IPPROTO_TCP,
TCP_CONGESTION, TCP_CONGESTION,
...@@ -1634,7 +1635,7 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) { ...@@ -1634,7 +1635,7 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) {
int AsyncSocket::setQuickAck(bool quickack) { int AsyncSocket::setQuickAck(bool quickack) {
(void)quickack; (void)quickack;
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket " << this VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket " << this
<< "(state=" << state_ << ")"; << "(state=" << state_ << ")";
return EINVAL; return EINVAL;
...@@ -1642,7 +1643,8 @@ int AsyncSocket::setQuickAck(bool quickack) { ...@@ -1642,7 +1643,8 @@ int AsyncSocket::setQuickAck(bool quickack) {
#ifdef TCP_QUICKACK // Linux-only #ifdef TCP_QUICKACK // Linux-only
int value = quickack ? 1 : 0; int value = quickack ? 1 : 0;
if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) { if (netops::setsockopt(
fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
int errnoCopy = errno; int errnoCopy = errno;
VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket" << this VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_ << "(fd=" << fd_ << ", state=" << state_
...@@ -1657,13 +1659,14 @@ int AsyncSocket::setQuickAck(bool quickack) { ...@@ -1657,13 +1659,14 @@ int AsyncSocket::setQuickAck(bool quickack) {
} }
int AsyncSocket::setSendBufSize(size_t bufsize) { int AsyncSocket::setSendBufSize(size_t bufsize) {
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket " VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
<< this << "(state=" << state_ << ")"; << this << "(state=" << state_ << ")";
return EINVAL; return EINVAL;
} }
if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) { if (netops::setsockopt(
fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) {
int errnoCopy = errno; int errnoCopy = errno;
VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket" << this VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_ << "(fd=" << fd_ << ", state=" << state_
...@@ -1675,13 +1678,14 @@ int AsyncSocket::setSendBufSize(size_t bufsize) { ...@@ -1675,13 +1678,14 @@ int AsyncSocket::setSendBufSize(size_t bufsize) {
} }
int AsyncSocket::setRecvBufSize(size_t bufsize) { int AsyncSocket::setRecvBufSize(size_t bufsize) {
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket " VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
<< this << "(state=" << state_ << ")"; << this << "(state=" << state_ << ")";
return EINVAL; return EINVAL;
} }
if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) { if (netops::setsockopt(
fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) {
int errnoCopy = errno; int errnoCopy = errno;
VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket" << this VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_ << "(fd=" << fd_ << ", state=" << state_
...@@ -1693,13 +1697,14 @@ int AsyncSocket::setRecvBufSize(size_t bufsize) { ...@@ -1693,13 +1697,14 @@ int AsyncSocket::setRecvBufSize(size_t bufsize) {
} }
int AsyncSocket::setTCPProfile(int profd) { int AsyncSocket::setTCPProfile(int profd) {
if (fd_ < 0) { if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket " << this VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket " << this
<< "(state=" << state_ << ")"; << "(state=" << state_ << ")";
return EINVAL; return EINVAL;
} }
if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) { if (netops::setsockopt(
fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) {
int errnoCopy = errno; int errnoCopy = errno;
VLOG(2) << "failed to set socket namespace option on AsyncSocket" << this VLOG(2) << "failed to set socket namespace option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_ << "(fd=" << fd_ << ", state=" << state_
...@@ -1781,7 +1786,7 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) { ...@@ -1781,7 +1786,7 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
return ReadResult(len); return ReadResult(len);
} }
ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT); ssize_t bytes = netops::recv(fd_, *buf, *buflen, MSG_DONTWAIT);
if (bytes < 0) { if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) { if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now. // No more data to read right now.
...@@ -1831,7 +1836,7 @@ size_t AsyncSocket::handleErrMessages() noexcept { ...@@ -1831,7 +1836,7 @@ size_t AsyncSocket::handleErrMessages() noexcept {
int ret; int ret;
size_t num = 0; size_t num = 0;
while (true) { while (true) {
ret = recvmsg(fd_, &msg, MSG_ERRQUEUE); ret = netops::recvmsg(fd_, &msg, MSG_ERRQUEUE);
VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret; VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
if (ret < 0) { if (ret < 0) {
...@@ -2092,13 +2097,13 @@ void AsyncSocket::handleWrite() noexcept { ...@@ -2092,13 +2097,13 @@ void AsyncSocket::handleWrite() noexcept {
// this point. // this point.
assert(readCallback_ == nullptr); assert(readCallback_ == nullptr);
state_ = StateEnum::CLOSED; state_ = StateEnum::CLOSED;
if (fd_ >= 0) { if (fd_ != NetworkSocket()) {
ioHandler_.changeHandlerFD(-1); ioHandler_.changeHandlerFD(NetworkSocket());
doClose(); doClose();
} }
} else { } else {
// Reads are still enabled, so we are only doing a half-shutdown // Reads are still enabled, so we are only doing a half-shutdown
shutdown(fd_, SHUT_WR); netops::shutdown(fd_, SHUT_WR);
} }
} }
} }
...@@ -2223,7 +2228,7 @@ void AsyncSocket::handleConnect() noexcept { ...@@ -2223,7 +2228,7 @@ void AsyncSocket::handleConnect() noexcept {
// Call getsockopt() to check if the connect succeeded // Call getsockopt() to check if the connect succeeded
int error; int error;
socklen_t len = sizeof(error); socklen_t len = sizeof(error);
int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len); int rv = netops::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
if (rv != 0) { if (rv != 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
AsyncSocketException ex( AsyncSocketException ex(
...@@ -2253,7 +2258,7 @@ void AsyncSocket::handleConnect() noexcept { ...@@ -2253,7 +2258,7 @@ void AsyncSocket::handleConnect() noexcept {
// are still connecting we just abort the connect rather than waiting for // are still connecting we just abort the connect rather than waiting for
// it to complete. // it to complete.
assert((shutdownFlags_ & SHUT_READ) == 0); assert((shutdownFlags_ & SHUT_READ) == 0);
shutdown(fd_, SHUT_WR); netops::shutdown(fd_, SHUT_WR);
shutdownFlags_ |= SHUT_WRITE; shutdownFlags_ |= SHUT_WRITE;
} }
...@@ -2313,12 +2318,15 @@ void AsyncSocket::timeoutExpired() noexcept { ...@@ -2313,12 +2318,15 @@ void AsyncSocket::timeoutExpired() noexcept {
} }
} }
ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) { ssize_t
AsyncSocket::tfoSendMsg(NetworkSocket fd, struct msghdr* msg, int msg_flags) {
return detail::tfo_sendmsg(fd, msg, msg_flags); return detail::tfo_sendmsg(fd, msg, msg_flags);
} }
AsyncSocket::WriteResult AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { NetworkSocket fd,
struct msghdr* msg,
int msg_flags) {
ssize_t totalWritten = 0; ssize_t totalWritten = 0;
if (state_ == StateEnum::FAST_OPEN) { if (state_ == StateEnum::FAST_OPEN) {
sockaddr_storage addr; sockaddr_storage addr;
...@@ -2379,7 +2387,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) { ...@@ -2379,7 +2387,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
AsyncSocketException::UNKNOWN, "No more free local ports")); AsyncSocketException::UNKNOWN, "No more free local ports"));
} }
} else { } else {
totalWritten = ::sendmsg(fd, msg, msg_flags); totalWritten = netops::sendmsg(fd, msg, msg_flags);
} }
return WriteResult(totalWritten); return WriteResult(totalWritten);
} }
...@@ -2533,8 +2541,8 @@ void AsyncSocket::startFail() { ...@@ -2533,8 +2541,8 @@ void AsyncSocket::startFail() {
} }
writeTimeout_.cancelTimeout(); writeTimeout_.cancelTimeout();
if (fd_ >= 0) { if (fd_ != NetworkSocket()) {
ioHandler_.changeHandlerFD(-1); ioHandler_.changeHandlerFD(NetworkSocket());
doClose(); doClose();
} }
} }
...@@ -2788,15 +2796,15 @@ void AsyncSocket::invalidState(WriteCallback* callback) { ...@@ -2788,15 +2796,15 @@ void AsyncSocket::invalidState(WriteCallback* callback) {
} }
void AsyncSocket::doClose() { void AsyncSocket::doClose() {
if (fd_ == -1) { if (fd_ == NetworkSocket()) {
return; return;
} }
if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) { if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
shutdownSocketSet->close(fd_); shutdownSocketSet->close(fd_);
} else { } else {
::close(fd_); netops::close(fd_);
} }
fd_ = -1; fd_ = NetworkSocket();
// we also want to clear the zerocopy maps // we also want to clear the zerocopy maps
// if the fd has been closed // if the fd has been closed
......
...@@ -262,7 +262,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -262,7 +262,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param fd File descriptor to take over (should be a connected socket). * @param fd File descriptor to take over (should be a connected socket).
* @param zeroCopyBufId Zerocopy buf id to start with. * @param zeroCopyBufId Zerocopy buf id to start with.
*/ */
AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0); AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0)
: AsyncSocket(evb, NetworkSocket::fromFd(fd), zeroCopyBufId) {}
AsyncSocket(EventBase* evb, NetworkSocket fd, uint32_t zeroCopyBufId = 0);
/** /**
* Create an AsyncSocket from a different, already connected AsyncSocket. * Create an AsyncSocket from a different, already connected AsyncSocket.
...@@ -309,6 +311,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -309,6 +311,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* Helper function to create a shared_ptr<AsyncSocket>. * Helper function to create a shared_ptr<AsyncSocket>.
*/ */
static std::shared_ptr<AsyncSocket> newSocket(EventBase* evb, int fd) { static std::shared_ptr<AsyncSocket> newSocket(EventBase* evb, int fd) {
return newSocket(evb, NetworkSocket::fromFd(fd));
}
static std::shared_ptr<AsyncSocket> newSocket(
EventBase* evb,
NetworkSocket fd) {
return std::shared_ptr<AsyncSocket>(new AsyncSocket(evb, fd), Destructor()); return std::shared_ptr<AsyncSocket>(new AsyncSocket(evb, fd), Destructor());
} }
...@@ -333,7 +340,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -333,7 +340,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* Get the file descriptor used by the AsyncSocket. * Get the file descriptor used by the AsyncSocket.
*/ */
virtual int getFd() const { virtual int getFd() const {
return fd_; return fd_.toFd();
} }
/** /**
...@@ -365,8 +372,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -365,8 +372,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
} }
return level < other.level; return level < other.level;
} }
int apply(NetworkSocket fd, int val) const {
return netops::setsockopt(fd, level, optname, &val, sizeof(val));
}
int apply(int fd, int val) const { int apply(int fd, int val) const {
return setsockopt(fd, level, optname, &val, sizeof(val)); return apply(NetworkSocket::fromFd(fd), val);
} }
int level; int level;
int optname; int optname;
...@@ -712,7 +722,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -712,7 +722,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/ */
template <typename T> template <typename T>
int getSockOpt(int level, int optname, T* optval, socklen_t* optlen) { int getSockOpt(int level, int optname, T* optval, socklen_t* optlen) {
return getsockopt(fd_, level, optname, (void*)optval, optlen); return netops::getsockopt(fd_, level, optname, (void*)optval, optlen);
} }
/** /**
...@@ -725,7 +735,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -725,7 +735,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/ */
template <typename T> template <typename T>
int setSockOpt(int level, int optname, const T* optval) { int setSockOpt(int level, int optname, const T* optval) {
return setsockopt(fd_, level, optname, optval, sizeof(T)); return netops::setsockopt(fd_, level, optname, optval, sizeof(T));
} }
/** /**
...@@ -741,7 +751,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -741,7 +751,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/ */
virtual int virtual int
getSockOptVirtual(int level, int optname, void* optval, socklen_t* optlen) { getSockOptVirtual(int level, int optname, void* optval, socklen_t* optlen) {
return getsockopt(fd_, level, optname, optval, optlen); return netops::getsockopt(fd_, level, optname, optval, optlen);
} }
/** /**
...@@ -760,7 +770,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -760,7 +770,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
int optname, int optname,
void const* optval, void const* optval,
socklen_t optlen) { socklen_t optlen) {
return setsockopt(fd_, level, optname, optval, optlen); return netops::setsockopt(fd_, level, optname, optval, optlen);
} }
/** /**
...@@ -1000,6 +1010,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1000,6 +1010,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
IoHandler(AsyncSocket* socket, EventBase* eventBase) IoHandler(AsyncSocket* socket, EventBase* eventBase)
: EventHandler(eventBase, -1), socket_(socket) {} : EventHandler(eventBase, -1), socket_(socket) {}
IoHandler(AsyncSocket* socket, EventBase* eventBase, int fd) IoHandler(AsyncSocket* socket, EventBase* eventBase, int fd)
: IoHandler(socket, eventBase, NetworkSocket::fromFd(fd)) {}
IoHandler(AsyncSocket* socket, EventBase* eventBase, NetworkSocket fd)
: EventHandler(eventBase, fd), socket_(socket) {} : EventHandler(eventBase, fd), socket_(socket) {}
void handlerReady(uint16_t events) noexcept override { void handlerReady(uint16_t events) noexcept override {
...@@ -1138,9 +1150,17 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1138,9 +1150,17 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param msg_flags Flags to pass to sendmsg * @param msg_flags Flags to pass to sendmsg
*/ */
AsyncSocket::WriteResult AsyncSocket::WriteResult
sendSocketMessage(int fd, struct msghdr* msg, int msg_flags); sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
return sendSocketMessage(NetworkSocket::fromFd(fd), msg, msg_flags);
}
AsyncSocket::WriteResult
sendSocketMessage(NetworkSocket fd, struct msghdr* msg, int msg_flags);
virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags); ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
return tfoSendMsg(NetworkSocket::fromFd(fd), msg, msg_flags);
}
virtual ssize_t
tfoSendMsg(NetworkSocket fd, struct msghdr* msg, int msg_flags);
int socketConnect(const struct sockaddr* addr, socklen_t len); int socketConnect(const struct sockaddr* addr, socklen_t len);
...@@ -1226,7 +1246,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -1226,7 +1246,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
StateEnum state_; ///< StateEnum describing current state StateEnum state_; ///< StateEnum describing current state
uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags) uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags)
uint16_t eventFlags_; ///< EventBase::HandlerFlags settings uint16_t eventFlags_; ///< EventBase::HandlerFlags settings
int fd_; ///< The socket file descriptor NetworkSocket fd_; ///< The socket file descriptor
mutable folly::SocketAddress addr_; ///< The address we tried to connect to mutable folly::SocketAddress addr_; ///< The address we tried to connect to
mutable folly::SocketAddress localAddr_; mutable folly::SocketAddress localAddr_;
///< The address we are connecting from ///< The address we are connecting from
......
...@@ -2229,7 +2229,9 @@ class MockAsyncTFOSSLSocket : public AsyncSSLSocket { ...@@ -2229,7 +2229,9 @@ class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
EventBase* evb) EventBase* evb)
: AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {} : AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); MOCK_METHOD3(
tfoSendMsg,
ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
}; };
#if defined __linux__ #if defined __linux__
...@@ -2372,10 +2374,10 @@ MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback( ...@@ -2372,10 +2374,10 @@ MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.Times(cardinality) .Times(cardinality)
.WillOnce(Invoke([&](int fd, struct msghdr*, int) { .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr; sockaddr_storage addr;
auto len = address.getAddress(&addr); auto len = address.getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len); return netops::connect(fd, (const struct sockaddr*)&addr, len);
})); }));
return socket; return socket;
} }
......
...@@ -2585,7 +2585,9 @@ class MockAsyncTFOSocket : public AsyncSocket { ...@@ -2585,7 +2585,9 @@ class MockAsyncTFOSocket : public AsyncSocket {
explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {} explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags)); MOCK_METHOD3(
tfoSendMsg,
ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
}; };
TEST(AsyncSocketTest, TestTFOUnsupported) { TEST(AsyncSocketTest, TestTFOUnsupported) {
...@@ -2646,10 +2648,10 @@ TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) { ...@@ -2646,10 +2648,10 @@ TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
// Hopefully this fails // Hopefully this fails
folly::SocketAddress fakeAddr("127.0.0.1", 65535); folly::SocketAddress fakeAddr("127.0.0.1", 65535);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) { .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr; sockaddr_storage addr;
auto len = fakeAddr.getAddress(&addr); auto len = fakeAddr.getAddress(&addr);
int ret = connect(fd, (const struct sockaddr*)&addr, len); auto ret = netops::connect(fd, (const struct sockaddr*)&addr, len);
LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : " LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
<< errno; << errno;
return ret; return ret;
...@@ -2735,10 +2737,10 @@ TEST(AsyncSocketTest, TestTFOFallbackToConnect) { ...@@ -2735,10 +2737,10 @@ TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
socket->setReadCB(&rcb); socket->setReadCB(&rcb);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) { .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr; sockaddr_storage addr;
auto len = server.getAddress().getAddress(&addr); auto len = server.getAddress().getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len); return netops::connect(fd, (const struct sockaddr*)&addr, len);
})); }));
WriteCallback write; WriteCallback write;
auto sendBuf = IOBuf::copyBuffer("hey"); auto sendBuf = IOBuf::copyBuffer("hey");
...@@ -2800,10 +2802,10 @@ TEST(AsyncSocketTest, TestTFOFallbackTimeout) { ...@@ -2800,10 +2802,10 @@ TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
socket->setReadCB(&rcb); socket->setReadCB(&rcb);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _)) EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) { .WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr2; sockaddr_storage addr2;
auto len = addr.getAddress(&addr2); auto len = addr.getAddress(&addr2);
return connect(fd, (const struct sockaddr*)&addr2, len); return netops::connect(fd, (const struct sockaddr*)&addr2, len);
})); }));
WriteCallback write; WriteCallback write;
socket->writeChain(&write, IOBuf::copyBuffer("hey")); socket->writeChain(&write, IOBuf::copyBuffer("hey"));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment