Commit c7b21fd1 authored by Orvid King's avatar Orvid King Committed by Facebook Github Bot

NetworkSocket support for AsyncSocket

Summary: This is a big one, but adds support for NetworkSocket to AsyncSocket itself.

Reviewed By: yfeldblum

Differential Revision: D12818493

fbshipit-source-id: d7b73f356414b006f4fc6b1c380a8fa1c32b9a46
parent 083c73bd
......@@ -326,7 +326,7 @@ void AsyncSSLSocket::init() {
void AsyncSSLSocket::closeNow() {
// Close the SSL connection.
if (ssl_ != nullptr && fd_ != -1) {
if (ssl_ != nullptr && fd_ != NetworkSocket()) {
int rc = SSL_shutdown(ssl_.get());
if (rc == 0) {
rc = SSL_shutdown(ssl_.get());
......
......@@ -254,7 +254,10 @@ namespace {
static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
// Based on flags, signal the transparent handler to disable certain functions
void disableTransparentFunctions(int fd, bool noTransparentTls, bool noTSocks) {
void disableTransparentFunctions(
NetworkSocket fd,
bool noTransparentTls,
bool noTSocks) {
(void)fd;
(void)noTransparentTls;
(void)noTSocks;
......@@ -262,12 +265,12 @@ void disableTransparentFunctions(int fd, bool noTransparentTls, bool noTSocks) {
if (noTransparentTls) {
// Ignore return value, errors are ok
VLOG(5) << "Disabling TTLS for fd " << fd;
::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
netops::setsockopt(fd, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
}
if (noTSocks) {
VLOG(5) << "Disabling TSOCKS for fd " << fd;
// Ignore return value, errors are ok
::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
netops::setsockopt(fd, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
}
#endif
}
......@@ -309,7 +312,10 @@ AsyncSocket::AsyncSocket(
connect(nullptr, ip, port, connectTimeout);
}
AsyncSocket::AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId)
AsyncSocket::AsyncSocket(
EventBase* evb,
NetworkSocket fd,
uint32_t zeroCopyBufId)
: zeroCopyBufId_(zeroCopyBufId),
eventBase_(evb),
writeTimeout_(this, evb),
......@@ -341,7 +347,7 @@ void AsyncSocket::init() {
shutdownFlags_ = 0;
state_ = StateEnum::UNINIT;
eventFlags_ = EventHandler::NONE;
fd_ = -1;
fd_ = NetworkSocket();
sendTimeout_ = 0;
maxReadsPerEvent_ = 16;
connectCallback_ = nullptr;
......@@ -381,14 +387,14 @@ int AsyncSocket::detachFd() {
if (const auto socketSet = wShutdownSocketSet_.lock()) {
socketSet->remove(fd_);
}
int fd = fd_;
fd_ = -1;
auto fd = fd_;
fd_ = NetworkSocket();
// Call closeNow() to invoke all pending callbacks with an error.
closeNow();
// Update the EventHandler to stop using this fd.
// This can only be done after closeNow() unregisters the handler.
ioHandler_.changeHandlerFD(-1);
return fd;
return fd.toFd();
}
const folly::SocketAddress& AsyncSocket::anyAddress() {
......@@ -406,11 +412,11 @@ void AsyncSocket::setShutdownSocketSet(
return;
}
if (shutdownSocketSet && fd_ != -1) {
if (shutdownSocketSet && fd_ != NetworkSocket()) {
shutdownSocketSet->remove(fd_);
}
if (newSS && fd_ != -1) {
if (newSS && fd_ != NetworkSocket()) {
newSS->add(fd_);
}
......@@ -418,7 +424,7 @@ void AsyncSocket::setShutdownSocketSet(
}
void AsyncSocket::setCloseOnExec() {
int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
int rv = netops::set_socket_close_on_exec(fd_);
if (rv != 0) {
auto errnoCopy = errno;
throw AsyncSocketException(
......@@ -449,7 +455,7 @@ void AsyncSocket::connect(
// Make connect end time at least >= connectStartTime.
connectEndTime_ = connectStartTime_;
assert(fd_ == -1);
assert(fd_ == NetworkSocket());
state_ = StateEnum::CONNECTING;
connectCallback_ = callback;
......@@ -462,8 +468,8 @@ void AsyncSocket::connect(
// constant (PF_xxx) rather than an address family (AF_xxx), but the
// distinction is mainly just historical. In pretty much all
// implementations the PF_foo and AF_foo constants are identical.
fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
if (fd_ < 0) {
fd_ = netops::socket(address.getFamily(), SOCK_STREAM, 0);
if (fd_ == NetworkSocket()) {
auto errnoCopy = errno;
throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
......@@ -479,15 +485,7 @@ void AsyncSocket::connect(
setCloseOnExec();
// Put the socket in non-blocking mode
int flags = fcntl(fd_, F_GETFL, 0);
if (flags == -1) {
auto errnoCopy = errno;
throw AsyncSocketException(
AsyncSocketException::INTERNAL_ERROR,
withAddr("failed to get socket flags"),
errnoCopy);
}
int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
int rv = netops::set_socket_non_blocking(fd_);
if (rv == -1) {
auto errnoCopy = errno;
throw AsyncSocketException(
......@@ -498,7 +496,7 @@ void AsyncSocket::connect(
#if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
// iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
rv = fcntl(fd_.toFd(), F_SETNOSIGPIPE, 1);
if (rv == -1) {
auto errnoCopy = errno;
throw AsyncSocketException(
......@@ -524,7 +522,8 @@ void AsyncSocket::connect(
// bind the socket
if (bindAddr != anyAddress()) {
int one = 1;
if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
if (netops::setsockopt(
fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
auto errnoCopy = errno;
doClose();
throw AsyncSocketException(
......@@ -535,7 +534,7 @@ void AsyncSocket::connect(
bindAddr.getAddress(&addrStorage);
if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
if (netops::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
auto errnoCopy = errno;
doClose();
throw AsyncSocketException(
......@@ -598,7 +597,7 @@ void AsyncSocket::connect(
}
int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
int rv = fsp::connect(fd_, saddr, len);
int rv = netops::connect(fd_, saddr, len);
if (rv < 0) {
auto errnoCopy = errno;
if (errnoCopy == EINPROGRESS) {
......@@ -861,12 +860,13 @@ bool AsyncSocket::setZeroCopy(bool enable) {
if (msgErrQueueSupported) {
zeroCopyVal_ = enable;
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
return false;
}
int val = enable ? 1 : 0;
int ret = setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
int ret =
netops::setsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, sizeof(val));
// if enable == false, set zeroCopyEnabled_ = false regardless
// if SO_ZEROCOPY is set or not
......@@ -881,7 +881,7 @@ bool AsyncSocket::setZeroCopy(bool enable) {
if (ret) {
val = 0;
socklen_t optlen = sizeof(val);
ret = getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
ret = netops::getsockopt(fd_, SOL_SOCKET, SO_ZEROCOPY, &val, &optlen);
if (!ret) {
enable = val ? true : false;
......@@ -1282,8 +1282,8 @@ void AsyncSocket::closeNow() {
immediateReadHandler_.cancelLoopCallback();
}
if (fd_ >= 0) {
ioHandler_.changeHandlerFD(-1);
if (fd_ != NetworkSocket()) {
ioHandler_.changeHandlerFD(NetworkSocket());
doClose();
}
......@@ -1324,7 +1324,7 @@ void AsyncSocket::closeNow() {
void AsyncSocket::closeWithReset() {
// Enable SO_LINGER, with the linger timeout set to 0.
// This will trigger a TCP reset when we close the socket.
if (fd_ >= 0) {
if (fd_ != NetworkSocket()) {
struct linger optLinger = {1, 0};
if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
......@@ -1394,7 +1394,7 @@ void AsyncSocket::shutdownWriteNow() {
}
// Shutdown writes on the file descriptor
shutdown(fd_, SHUT_WR);
netops::shutdown(fd_, SHUT_WR);
// Immediately fail all write requests
failAllWrites(socketShutdownForWritesEx);
......@@ -1441,26 +1441,26 @@ void AsyncSocket::shutdownWriteNow() {
}
bool AsyncSocket::readable() const {
if (fd_ == -1) {
if (fd_ == NetworkSocket()) {
return false;
}
struct pollfd fds[1];
netops::PollDescriptor fds[1];
fds[0].fd = fd_;
fds[0].events = POLLIN;
fds[0].revents = 0;
int rc = poll(fds, 1, 0);
int rc = netops::poll(fds, 1, 0);
return rc == 1;
}
bool AsyncSocket::writable() const {
if (fd_ == -1) {
if (fd_ == NetworkSocket()) {
return false;
}
struct pollfd fds[1];
netops::PollDescriptor fds[1];
fds[0].fd = fd_;
fds[0].events = POLLOUT;
fds[0].revents = 0;
int rc = poll(fds, 1, 0);
int rc = netops::poll(fds, 1, 0);
return rc == 1;
}
......@@ -1469,17 +1469,17 @@ bool AsyncSocket::isPending() const {
}
bool AsyncSocket::hangup() const {
if (fd_ == -1) {
if (fd_ == NetworkSocket()) {
// sanity check, no one should ask for hangup if we are not connected.
assert(false);
return false;
}
#ifdef POLLRDHUP // Linux-only
struct pollfd fds[1];
netops::PollDescriptor fds[1];
fds[0].fd = fd_;
fds[0].events = POLLRDHUP | POLLHUP;
fds[0].revents = 0;
poll(fds, 1, 0);
netops::poll(fds, 1, 0);
return (fds[0].revents & (POLLRDHUP | POLLHUP)) != 0;
#else
return false;
......@@ -1542,7 +1542,7 @@ bool AsyncSocket::isDetachable() const {
}
void AsyncSocket::cacheAddresses() {
if (fd_ >= 0) {
if (fd_ != NetworkSocket()) {
try {
cacheLocalAddress();
cachePeerAddress();
......@@ -1587,14 +1587,15 @@ bool AsyncSocket::getTFOSucceded() const {
}
int AsyncSocket::setNoDelay(bool noDelay) {
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket " << this
<< "(state=" << state_ << ")";
return EINVAL;
}
int value = noDelay ? 1 : 0;
if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
if (netops::setsockopt(
fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket " << this
<< " (fd=" << fd_ << ", state=" << state_
......@@ -1610,13 +1611,13 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) {
#define TCP_CONGESTION 13
#endif
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
<< "socket " << this << "(state=" << state_ << ")";
return EINVAL;
}
if (setsockopt(
if (netops::setsockopt(
fd_,
IPPROTO_TCP,
TCP_CONGESTION,
......@@ -1634,7 +1635,7 @@ int AsyncSocket::setCongestionFlavor(const std::string& cname) {
int AsyncSocket::setQuickAck(bool quickack) {
(void)quickack;
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket " << this
<< "(state=" << state_ << ")";
return EINVAL;
......@@ -1642,7 +1643,8 @@ int AsyncSocket::setQuickAck(bool quickack) {
#ifdef TCP_QUICKACK // Linux-only
int value = quickack ? 1 : 0;
if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
if (netops::setsockopt(
fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_
......@@ -1657,13 +1659,14 @@ int AsyncSocket::setQuickAck(bool quickack) {
}
int AsyncSocket::setSendBufSize(size_t bufsize) {
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
<< this << "(state=" << state_ << ")";
return EINVAL;
}
if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) {
if (netops::setsockopt(
fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_
......@@ -1675,13 +1678,14 @@ int AsyncSocket::setSendBufSize(size_t bufsize) {
}
int AsyncSocket::setRecvBufSize(size_t bufsize) {
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
<< this << "(state=" << state_ << ")";
return EINVAL;
}
if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) {
if (netops::setsockopt(
fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_
......@@ -1693,13 +1697,14 @@ int AsyncSocket::setRecvBufSize(size_t bufsize) {
}
int AsyncSocket::setTCPProfile(int profd) {
if (fd_ < 0) {
if (fd_ == NetworkSocket()) {
VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket " << this
<< "(state=" << state_ << ")";
return EINVAL;
}
if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) {
if (netops::setsockopt(
fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) != 0) {
int errnoCopy = errno;
VLOG(2) << "failed to set socket namespace option on AsyncSocket" << this
<< "(fd=" << fd_ << ", state=" << state_
......@@ -1781,7 +1786,7 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
return ReadResult(len);
}
ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT);
ssize_t bytes = netops::recv(fd_, *buf, *buflen, MSG_DONTWAIT);
if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now.
......@@ -1831,7 +1836,7 @@ size_t AsyncSocket::handleErrMessages() noexcept {
int ret;
size_t num = 0;
while (true) {
ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
ret = netops::recvmsg(fd_, &msg, MSG_ERRQUEUE);
VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
if (ret < 0) {
......@@ -2092,13 +2097,13 @@ void AsyncSocket::handleWrite() noexcept {
// this point.
assert(readCallback_ == nullptr);
state_ = StateEnum::CLOSED;
if (fd_ >= 0) {
ioHandler_.changeHandlerFD(-1);
if (fd_ != NetworkSocket()) {
ioHandler_.changeHandlerFD(NetworkSocket());
doClose();
}
} else {
// Reads are still enabled, so we are only doing a half-shutdown
shutdown(fd_, SHUT_WR);
netops::shutdown(fd_, SHUT_WR);
}
}
}
......@@ -2223,7 +2228,7 @@ void AsyncSocket::handleConnect() noexcept {
// Call getsockopt() to check if the connect succeeded
int error;
socklen_t len = sizeof(error);
int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
int rv = netops::getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
if (rv != 0) {
auto errnoCopy = errno;
AsyncSocketException ex(
......@@ -2253,7 +2258,7 @@ void AsyncSocket::handleConnect() noexcept {
// are still connecting we just abort the connect rather than waiting for
// it to complete.
assert((shutdownFlags_ & SHUT_READ) == 0);
shutdown(fd_, SHUT_WR);
netops::shutdown(fd_, SHUT_WR);
shutdownFlags_ |= SHUT_WRITE;
}
......@@ -2313,12 +2318,15 @@ void AsyncSocket::timeoutExpired() noexcept {
}
}
ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
ssize_t
AsyncSocket::tfoSendMsg(NetworkSocket fd, struct msghdr* msg, int msg_flags) {
return detail::tfo_sendmsg(fd, msg, msg_flags);
}
AsyncSocket::WriteResult
AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
AsyncSocket::WriteResult AsyncSocket::sendSocketMessage(
NetworkSocket fd,
struct msghdr* msg,
int msg_flags) {
ssize_t totalWritten = 0;
if (state_ == StateEnum::FAST_OPEN) {
sockaddr_storage addr;
......@@ -2379,7 +2387,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
AsyncSocketException::UNKNOWN, "No more free local ports"));
}
} else {
totalWritten = ::sendmsg(fd, msg, msg_flags);
totalWritten = netops::sendmsg(fd, msg, msg_flags);
}
return WriteResult(totalWritten);
}
......@@ -2533,8 +2541,8 @@ void AsyncSocket::startFail() {
}
writeTimeout_.cancelTimeout();
if (fd_ >= 0) {
ioHandler_.changeHandlerFD(-1);
if (fd_ != NetworkSocket()) {
ioHandler_.changeHandlerFD(NetworkSocket());
doClose();
}
}
......@@ -2788,15 +2796,15 @@ void AsyncSocket::invalidState(WriteCallback* callback) {
}
void AsyncSocket::doClose() {
if (fd_ == -1) {
if (fd_ == NetworkSocket()) {
return;
}
if (const auto shutdownSocketSet = wShutdownSocketSet_.lock()) {
shutdownSocketSet->close(fd_);
} else {
::close(fd_);
netops::close(fd_);
}
fd_ = -1;
fd_ = NetworkSocket();
// we also want to clear the zerocopy maps
// if the fd has been closed
......
......@@ -262,7 +262,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param fd File descriptor to take over (should be a connected socket).
* @param zeroCopyBufId Zerocopy buf id to start with.
*/
AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0);
AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0)
: AsyncSocket(evb, NetworkSocket::fromFd(fd), zeroCopyBufId) {}
AsyncSocket(EventBase* evb, NetworkSocket fd, uint32_t zeroCopyBufId = 0);
/**
* Create an AsyncSocket from a different, already connected AsyncSocket.
......@@ -309,6 +311,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* Helper function to create a shared_ptr<AsyncSocket>.
*/
static std::shared_ptr<AsyncSocket> newSocket(EventBase* evb, int fd) {
return newSocket(evb, NetworkSocket::fromFd(fd));
}
static std::shared_ptr<AsyncSocket> newSocket(
EventBase* evb,
NetworkSocket fd) {
return std::shared_ptr<AsyncSocket>(new AsyncSocket(evb, fd), Destructor());
}
......@@ -333,7 +340,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* Get the file descriptor used by the AsyncSocket.
*/
virtual int getFd() const {
return fd_;
return fd_.toFd();
}
/**
......@@ -365,8 +372,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
}
return level < other.level;
}
int apply(NetworkSocket fd, int val) const {
return netops::setsockopt(fd, level, optname, &val, sizeof(val));
}
int apply(int fd, int val) const {
return setsockopt(fd, level, optname, &val, sizeof(val));
return apply(NetworkSocket::fromFd(fd), val);
}
int level;
int optname;
......@@ -712,7 +722,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
template <typename T>
int getSockOpt(int level, int optname, T* optval, socklen_t* optlen) {
return getsockopt(fd_, level, optname, (void*)optval, optlen);
return netops::getsockopt(fd_, level, optname, (void*)optval, optlen);
}
/**
......@@ -725,7 +735,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
template <typename T>
int setSockOpt(int level, int optname, const T* optval) {
return setsockopt(fd_, level, optname, optval, sizeof(T));
return netops::setsockopt(fd_, level, optname, optval, sizeof(T));
}
/**
......@@ -741,7 +751,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
virtual int
getSockOptVirtual(int level, int optname, void* optval, socklen_t* optlen) {
return getsockopt(fd_, level, optname, optval, optlen);
return netops::getsockopt(fd_, level, optname, optval, optlen);
}
/**
......@@ -760,7 +770,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
int optname,
void const* optval,
socklen_t optlen) {
return setsockopt(fd_, level, optname, optval, optlen);
return netops::setsockopt(fd_, level, optname, optval, optlen);
}
/**
......@@ -1000,6 +1010,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
IoHandler(AsyncSocket* socket, EventBase* eventBase)
: EventHandler(eventBase, -1), socket_(socket) {}
IoHandler(AsyncSocket* socket, EventBase* eventBase, int fd)
: IoHandler(socket, eventBase, NetworkSocket::fromFd(fd)) {}
IoHandler(AsyncSocket* socket, EventBase* eventBase, NetworkSocket fd)
: EventHandler(eventBase, fd), socket_(socket) {}
void handlerReady(uint16_t events) noexcept override {
......@@ -1138,9 +1150,17 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param msg_flags Flags to pass to sendmsg
*/
AsyncSocket::WriteResult
sendSocketMessage(int fd, struct msghdr* msg, int msg_flags);
sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
return sendSocketMessage(NetworkSocket::fromFd(fd), msg, msg_flags);
}
AsyncSocket::WriteResult
sendSocketMessage(NetworkSocket fd, struct msghdr* msg, int msg_flags);
virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
return tfoSendMsg(NetworkSocket::fromFd(fd), msg, msg_flags);
}
virtual ssize_t
tfoSendMsg(NetworkSocket fd, struct msghdr* msg, int msg_flags);
int socketConnect(const struct sockaddr* addr, socklen_t len);
......@@ -1226,7 +1246,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
StateEnum state_; ///< StateEnum describing current state
uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags)
uint16_t eventFlags_; ///< EventBase::HandlerFlags settings
int fd_; ///< The socket file descriptor
NetworkSocket fd_; ///< The socket file descriptor
mutable folly::SocketAddress addr_; ///< The address we tried to connect to
mutable folly::SocketAddress localAddr_;
///< The address we are connecting from
......
......@@ -2229,7 +2229,9 @@ class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
EventBase* evb)
: AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
MOCK_METHOD3(
tfoSendMsg,
ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
};
#if defined __linux__
......@@ -2372,10 +2374,10 @@ MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.Times(cardinality)
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = address.getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len);
return netops::connect(fd, (const struct sockaddr*)&addr, len);
}));
return socket;
}
......
......@@ -2585,7 +2585,9 @@ class MockAsyncTFOSocket : public AsyncSocket {
explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
MOCK_METHOD3(
tfoSendMsg,
ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
};
TEST(AsyncSocketTest, TestTFOUnsupported) {
......@@ -2646,10 +2648,10 @@ TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
// Hopefully this fails
folly::SocketAddress fakeAddr("127.0.0.1", 65535);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = fakeAddr.getAddress(&addr);
int ret = connect(fd, (const struct sockaddr*)&addr, len);
auto ret = netops::connect(fd, (const struct sockaddr*)&addr, len);
LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
<< errno;
return ret;
......@@ -2735,10 +2737,10 @@ TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
socket->setReadCB(&rcb);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = server.getAddress().getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len);
return netops::connect(fd, (const struct sockaddr*)&addr, len);
}));
WriteCallback write;
auto sendBuf = IOBuf::copyBuffer("hey");
......@@ -2800,10 +2802,10 @@ TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
socket->setReadCB(&rcb);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr2;
auto len = addr.getAddress(&addr2);
return connect(fd, (const struct sockaddr*)&addr2, len);
return netops::connect(fd, (const struct sockaddr*)&addr2, len);
}));
WriteCallback write;
socket->writeChain(&write, IOBuf::copyBuffer("hey"));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment