Commit c7b21fd1 authored by Orvid King's avatar Orvid King Committed by Facebook Github Bot

NetworkSocket support for AsyncSocket

Summary: This is a big one, but adds support for NetworkSocket to AsyncSocket itself.

Reviewed By: yfeldblum

Differential Revision: D12818493

fbshipit-source-id: d7b73f356414b006f4fc6b1c380a8fa1c32b9a46
parent 083c73bd
......@@ -326,7 +326,7 @@ void AsyncSSLSocket::init() {
void AsyncSSLSocket::closeNow() {
// Close the SSL connection.
if (ssl_ != nullptr && fd_ != -1) {
if (ssl_ != nullptr && fd_ != NetworkSocket()) {
int rc = SSL_shutdown(ssl_.get());
if (rc == 0) {
rc = SSL_shutdown(ssl_.get());
......
This diff is collapsed.
......@@ -262,7 +262,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param fd File descriptor to take over (should be a connected socket).
* @param zeroCopyBufId Zerocopy buf id to start with.
*/
AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0);
AsyncSocket(EventBase* evb, int fd, uint32_t zeroCopyBufId = 0)
: AsyncSocket(evb, NetworkSocket::fromFd(fd), zeroCopyBufId) {}
AsyncSocket(EventBase* evb, NetworkSocket fd, uint32_t zeroCopyBufId = 0);
/**
* Create an AsyncSocket from a different, already connected AsyncSocket.
......@@ -309,6 +311,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* Helper function to create a shared_ptr<AsyncSocket>.
*/
static std::shared_ptr<AsyncSocket> newSocket(EventBase* evb, int fd) {
return newSocket(evb, NetworkSocket::fromFd(fd));
}
static std::shared_ptr<AsyncSocket> newSocket(
EventBase* evb,
NetworkSocket fd) {
return std::shared_ptr<AsyncSocket>(new AsyncSocket(evb, fd), Destructor());
}
......@@ -333,7 +340,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* Get the file descriptor used by the AsyncSocket.
*/
virtual int getFd() const {
return fd_;
return fd_.toFd();
}
/**
......@@ -365,8 +372,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
}
return level < other.level;
}
int apply(NetworkSocket fd, int val) const {
return netops::setsockopt(fd, level, optname, &val, sizeof(val));
}
int apply(int fd, int val) const {
return setsockopt(fd, level, optname, &val, sizeof(val));
return apply(NetworkSocket::fromFd(fd), val);
}
int level;
int optname;
......@@ -712,7 +722,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
template <typename T>
int getSockOpt(int level, int optname, T* optval, socklen_t* optlen) {
return getsockopt(fd_, level, optname, (void*)optval, optlen);
return netops::getsockopt(fd_, level, optname, (void*)optval, optlen);
}
/**
......@@ -725,7 +735,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
template <typename T>
int setSockOpt(int level, int optname, const T* optval) {
return setsockopt(fd_, level, optname, optval, sizeof(T));
return netops::setsockopt(fd_, level, optname, optval, sizeof(T));
}
/**
......@@ -741,7 +751,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
*/
virtual int
getSockOptVirtual(int level, int optname, void* optval, socklen_t* optlen) {
return getsockopt(fd_, level, optname, optval, optlen);
return netops::getsockopt(fd_, level, optname, optval, optlen);
}
/**
......@@ -760,7 +770,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
int optname,
void const* optval,
socklen_t optlen) {
return setsockopt(fd_, level, optname, optval, optlen);
return netops::setsockopt(fd_, level, optname, optval, optlen);
}
/**
......@@ -1000,6 +1010,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
IoHandler(AsyncSocket* socket, EventBase* eventBase)
: EventHandler(eventBase, -1), socket_(socket) {}
IoHandler(AsyncSocket* socket, EventBase* eventBase, int fd)
: IoHandler(socket, eventBase, NetworkSocket::fromFd(fd)) {}
IoHandler(AsyncSocket* socket, EventBase* eventBase, NetworkSocket fd)
: EventHandler(eventBase, fd), socket_(socket) {}
void handlerReady(uint16_t events) noexcept override {
......@@ -1138,9 +1150,17 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param msg_flags Flags to pass to sendmsg
*/
AsyncSocket::WriteResult
sendSocketMessage(int fd, struct msghdr* msg, int msg_flags);
sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
return sendSocketMessage(NetworkSocket::fromFd(fd), msg, msg_flags);
}
AsyncSocket::WriteResult
sendSocketMessage(NetworkSocket fd, struct msghdr* msg, int msg_flags);
virtual ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags);
ssize_t tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
return tfoSendMsg(NetworkSocket::fromFd(fd), msg, msg_flags);
}
virtual ssize_t
tfoSendMsg(NetworkSocket fd, struct msghdr* msg, int msg_flags);
int socketConnect(const struct sockaddr* addr, socklen_t len);
......@@ -1226,7 +1246,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
StateEnum state_; ///< StateEnum describing current state
uint8_t shutdownFlags_; ///< Shutdown state (ShutdownFlags)
uint16_t eventFlags_; ///< EventBase::HandlerFlags settings
int fd_; ///< The socket file descriptor
NetworkSocket fd_; ///< The socket file descriptor
mutable folly::SocketAddress addr_; ///< The address we tried to connect to
mutable folly::SocketAddress localAddr_;
///< The address we are connecting from
......
......@@ -2229,7 +2229,9 @@ class MockAsyncTFOSSLSocket : public AsyncSSLSocket {
EventBase* evb)
: AsyncSocket(evb), AsyncSSLSocket(sslCtx, evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
MOCK_METHOD3(
tfoSendMsg,
ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
};
#if defined __linux__
......@@ -2372,10 +2374,10 @@ MockAsyncTFOSSLSocket::UniquePtr setupSocketWithFallback(
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.Times(cardinality)
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = address.getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len);
return netops::connect(fd, (const struct sockaddr*)&addr, len);
}));
return socket;
}
......
......@@ -2585,7 +2585,9 @@ class MockAsyncTFOSocket : public AsyncSocket {
explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
MOCK_METHOD3(
tfoSendMsg,
ssize_t(NetworkSocket fd, struct msghdr* msg, int msg_flags));
};
TEST(AsyncSocketTest, TestTFOUnsupported) {
......@@ -2646,10 +2648,10 @@ TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
// Hopefully this fails
folly::SocketAddress fakeAddr("127.0.0.1", 65535);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = fakeAddr.getAddress(&addr);
int ret = connect(fd, (const struct sockaddr*)&addr, len);
auto ret = netops::connect(fd, (const struct sockaddr*)&addr, len);
LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
<< errno;
return ret;
......@@ -2735,10 +2737,10 @@ TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
socket->setReadCB(&rcb);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr;
auto len = server.getAddress().getAddress(&addr);
return connect(fd, (const struct sockaddr*)&addr, len);
return netops::connect(fd, (const struct sockaddr*)&addr, len);
}));
WriteCallback write;
auto sendBuf = IOBuf::copyBuffer("hey");
......@@ -2800,10 +2802,10 @@ TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
socket->setReadCB(&rcb);
EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
.WillOnce(Invoke([&](int fd, struct msghdr*, int) {
.WillOnce(Invoke([&](NetworkSocket fd, struct msghdr*, int) {
sockaddr_storage addr2;
auto len = addr.getAddress(&addr2);
return connect(fd, (const struct sockaddr*)&addr2, len);
return netops::connect(fd, (const struct sockaddr*)&addr2, len);
}));
WriteCallback write;
socket->writeChain(&write, IOBuf::copyBuffer("hey"));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment