Commit e31eb32a authored by Junqi Wang's avatar Junqi Wang Committed by Facebook GitHub Bot

Allow connect before bind

Summary: connect will automatically bind the socket if the socket is not bound yet

Reviewed By: yangchi

Differential Revision: D21845740

fbshipit-source-id: 27a5b44476dfc0b2ae5ff2f0a6c1bd4e976eadc9
parent d11cbbc9
...@@ -415,7 +415,7 @@ void testAsyncUDPRecvmsg(bool useRegisteredFds) { ...@@ -415,7 +415,7 @@ void testAsyncUDPRecvmsg(bool useRegisteredFds) {
serverSocketVec.emplace_back(std::move(serverSock)); serverSocketVec.emplace_back(std::move(serverSock));
// connect the client // connect the client
CHECK_EQ(clientSock->connect(addr), 0); clientSock->connect(addr);
for (size_t j = 0; j < kNumPackets; j++) { for (size_t j = 0; j < kNumPackets; j++) {
auto buf = folly::IOBuf::copyBuffer(data.c_str(), data.size()); auto buf = folly::IOBuf::copyBuffer(data.c_str(), data.size());
CHECK_EQ(clientSock->write(addr, std::move(buf)), data.size()); CHECK_EQ(clientSock->write(addr, std::move(buf)), data.size());
......
...@@ -58,11 +58,9 @@ AsyncUDPSocket::~AsyncUDPSocket() { ...@@ -58,11 +58,9 @@ AsyncUDPSocket::~AsyncUDPSocket() {
} }
} }
void AsyncUDPSocket::bind(const folly::SocketAddress& address) { void AsyncUDPSocket::init(sa_family_t family) {
NetworkSocket socket = netops::socket( NetworkSocket socket =
address.getFamily(), netops::socket(family, SOCK_DGRAM, family != AF_UNIX ? IPPROTO_UDP : 0);
SOCK_DGRAM,
address.getFamily() != AF_UNIX ? IPPROTO_UDP : 0);
if (socket == NetworkSocket()) { if (socket == NetworkSocket()) {
throw AsyncSocketException( throw AsyncSocketException(
AsyncSocketException::NOT_OPEN, AsyncSocketException::NOT_OPEN,
...@@ -148,7 +146,7 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) { ...@@ -148,7 +146,7 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
} }
// If we're using IPv6, make sure we don't accept V4-mapped connections // If we're using IPv6, make sure we don't accept V4-mapped connections
if (address.getFamily() == AF_INET6) { if (family == AF_INET6) {
int flag = 1; int flag = 1;
if (netops::setsockopt( if (netops::setsockopt(
socket, IPPROTO_IPV6, IPV6_V6ONLY, &flag, sizeof(flag))) { socket, IPPROTO_IPV6, IPV6_V6ONLY, &flag, sizeof(flag))) {
...@@ -157,25 +155,29 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) { ...@@ -157,25 +155,29 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
} }
} }
// success
g.dismiss();
fd_ = socket;
ownership_ = FDOwnership::OWNS;
// attach to EventHandler
EventHandler::changeHandlerFD(fd_);
}
void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
init(address.getFamily());
// bind to the address // bind to the address
sockaddr_storage addrStorage; sockaddr_storage addrStorage;
address.getAddress(&addrStorage); address.getAddress(&addrStorage);
auto& saddr = reinterpret_cast<sockaddr&>(addrStorage); auto& saddr = reinterpret_cast<sockaddr&>(addrStorage);
if (netops::bind(socket, &saddr, address.getActualSize()) != 0) { if (netops::bind(fd_, &saddr, address.getActualSize()) != 0) {
throw AsyncSocketException( throw AsyncSocketException(
AsyncSocketException::NOT_OPEN, AsyncSocketException::NOT_OPEN,
"failed to bind the async udp socket for:" + address.describe(), "failed to bind the async udp socket for:" + address.describe(),
errno); errno);
} }
// success
g.dismiss();
fd_ = socket;
ownership_ = FDOwnership::OWNS;
// attach to EventHandler
EventHandler::changeHandlerFD(fd_);
if (address.getFamily() == AF_UNIX || address.getPort() != 0) { if (address.getFamily() == AF_UNIX || address.getPort() != 0) {
localAddress_ = address; localAddress_ = address;
} else { } else {
...@@ -183,17 +185,29 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) { ...@@ -183,17 +185,29 @@ void AsyncUDPSocket::bind(const folly::SocketAddress& address) {
} }
} }
int AsyncUDPSocket::connect(const folly::SocketAddress& address) { void AsyncUDPSocket::connect(const folly::SocketAddress& address) {
CHECK_NE(NetworkSocket(), fd_) << "Socket not yet bound"; // not bound yet
if (fd_ == NetworkSocket()) {
init(address.getFamily());
}
sockaddr_storage addrStorage; sockaddr_storage addrStorage;
address.getAddress(&addrStorage); address.getAddress(&addrStorage);
int ret = netops::connect( if (netops::connect(
fd_, reinterpret_cast<sockaddr*>(&addrStorage), address.getActualSize()); fd_,
if (ret == 0) { reinterpret_cast<sockaddr*>(&addrStorage),
connected_ = true; address.getActualSize()) != 0) {
connectedAddress_ = address; throw AsyncSocketException(
AsyncSocketException::NOT_OPEN,
"Failed to connect the udp socket to:" + address.describe(),
errno);
}
connected_ = true;
connectedAddress_ = address;
if (!localAddress_.isInitialized()) {
localAddress_.setFromLocalAddress(fd_);
} }
return ret;
} }
void AsyncUDPSocket::dontFragment(bool df) { void AsyncUDPSocket::dontFragment(bool df) {
......
...@@ -152,7 +152,8 @@ class AsyncUDPSocket : public EventHandler { ...@@ -152,7 +152,8 @@ class AsyncUDPSocket : public EventHandler {
* state on connects. * state on connects.
* Using connect has many quirks, and you should be aware of them before using * Using connect has many quirks, and you should be aware of them before using
* this API: * this API:
* 1. This must only be called after binding the socket. * 1. If this is called before bind, the socket will be automatically bound to
* the IP address of the current default network interface.
* 2. Normally UDP can use the 2 tuple (src ip, src port) to steer packets * 2. Normally UDP can use the 2 tuple (src ip, src port) to steer packets
* sent by the peer to the socket, however after connecting the socket, only * sent by the peer to the socket, however after connecting the socket, only
* packets destined to the destination address specified in connect() will be * packets destined to the destination address specified in connect() will be
...@@ -164,7 +165,7 @@ class AsyncUDPSocket : public EventHandler { ...@@ -164,7 +165,7 @@ class AsyncUDPSocket : public EventHandler {
* *
* Returns the result of calling the connect syscall. * Returns the result of calling the connect syscall.
*/ */
virtual int connect(const folly::SocketAddress& address); virtual void connect(const folly::SocketAddress& address);
/** /**
* Use an already bound file descriptor. You can either transfer ownership * Use an already bound file descriptor. You can either transfer ownership
...@@ -440,6 +441,8 @@ class AsyncUDPSocket : public EventHandler { ...@@ -440,6 +441,8 @@ class AsyncUDPSocket : public EventHandler {
AsyncUDPSocket(const AsyncUDPSocket&) = delete; AsyncUDPSocket(const AsyncUDPSocket&) = delete;
AsyncUDPSocket& operator=(const AsyncUDPSocket&) = delete; AsyncUDPSocket& operator=(const AsyncUDPSocket&) = delete;
void init(sa_family_t family);
// EventHandler // EventHandler
void handlerReady(uint16_t events) noexcept override; void handlerReady(uint16_t events) noexcept override;
......
...@@ -303,11 +303,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -303,11 +303,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
} }
void connect() { void connect() {
int ret = socket_->connect(*connectAddr_); socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
VLOG(2) << "Client connected to address=" << *connectAddr_; VLOG(2) << "Client connected to address=" << *connectAddr_;
} }
......
...@@ -212,11 +212,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -212,11 +212,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
} }
void connect() { void connect() {
int ret = socket_->connect(*connectAddr_); socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
VLOG(2) << "Client connected to address=" << *connectAddr_; VLOG(2) << "Client connected to address=" << *connectAddr_;
} }
......
...@@ -173,6 +173,8 @@ class UDPServer { ...@@ -173,6 +173,8 @@ class UDPServer {
bool changePortForWrites_{true}; bool changePortForWrites_{true};
}; };
enum class BindSocket { YES, NO };
class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
public: public:
using AsyncUDPSocket::ReadCallback::OnDataAvailableParams; using AsyncUDPSocket::ReadCallback::OnDataAvailableParams;
...@@ -188,9 +190,12 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -188,9 +190,12 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
socket_ = std::make_unique<AsyncUDPSocket>(evb_); socket_ = std::make_unique<AsyncUDPSocket>(evb_);
try { try {
socket_->bind(folly::SocketAddress("127.0.0.1", 0)); if (bindSocket_ == BindSocket::YES) {
socket_->bind(folly::SocketAddress("127.0.0.1", 0));
}
if (connectAddr_) { if (connectAddr_) {
connect(); socket_->connect(*connectAddr_);
VLOG(2) << "Client connected to address=" << *connectAddr_;
} }
VLOG(2) << "Client bound to " << socket_->address().describe(); VLOG(2) << "Client bound to " << socket_->address().describe();
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
...@@ -209,15 +214,6 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -209,15 +214,6 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
} }
} }
void connect() {
int ret = socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
void shutdown() { void shutdown() {
CHECK(evb_->isInEventBaseThread()); CHECK(evb_->isInEventBaseThread());
socket_->pauseRead(); socket_->pauseRead();
...@@ -295,8 +291,11 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -295,8 +291,11 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
return *socket_; return *socket_;
} }
void setShouldConnect(const folly::SocketAddress& connectAddr) { void setShouldConnect(
const folly::SocketAddress& connectAddr,
BindSocket bindSocket) {
connectAddr_ = connectAddr; connectAddr_ = connectAddr;
bindSocket_ = bindSocket;
} }
bool error() const { bool error() const {
...@@ -309,6 +308,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -309,6 +308,7 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
protected: protected:
folly::Optional<folly::SocketAddress> connectAddr_; folly::Optional<folly::SocketAddress> connectAddr_;
BindSocket bindSocket_{BindSocket::YES};
EventBase* const evb_{nullptr}; EventBase* const evb_{nullptr};
folly::SocketAddress server_; folly::SocketAddress server_;
...@@ -473,16 +473,19 @@ class AsyncSocketIntegrationTest : public Test { ...@@ -473,16 +473,19 @@ class AsyncSocketIntegrationTest : public Test {
std::unique_ptr<UDPClient> performPingPongTest( std::unique_ptr<UDPClient> performPingPongTest(
folly::SocketAddress writeAddress, folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress); folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket = BindSocket::YES);
std::unique_ptr<UDPNotifyClient> performPingPongNotifyTest( std::unique_ptr<UDPNotifyClient> performPingPongNotifyTest(
folly::SocketAddress writeAddress, folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress); folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket = BindSocket::YES);
std::unique_ptr<UDPNotifyClient> performPingPongNotifyMmsgTest( std::unique_ptr<UDPNotifyClient> performPingPongNotifyMmsgTest(
folly::SocketAddress writeAddress, folly::SocketAddress writeAddress,
unsigned int numMsgs, unsigned int numMsgs,
folly::Optional<folly::SocketAddress> connectedAddress); folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket = BindSocket::YES);
folly::EventBase sevb; folly::EventBase sevb;
folly::EventBase cevb; folly::EventBase cevb;
...@@ -492,10 +495,11 @@ class AsyncSocketIntegrationTest : public Test { ...@@ -492,10 +495,11 @@ class AsyncSocketIntegrationTest : public Test {
std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest( std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
folly::SocketAddress writeAddress, folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress) { folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket) {
auto client = std::make_unique<UDPClient>(&cevb); auto client = std::make_unique<UDPClient>(&cevb);
if (connectedAddress) { if (connectedAddress) {
client->setShouldConnect(*connectedAddress); client->setShouldConnect(*connectedAddress, bindSocket);
} }
// Start event loop in a separate thread // Start event loop in a separate thread
auto clientThread = std::thread([this]() { cevb.loopForever(); }); auto clientThread = std::thread([this]() { cevb.loopForever(); });
...@@ -514,10 +518,11 @@ std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest( ...@@ -514,10 +518,11 @@ std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
std::unique_ptr<UDPNotifyClient> std::unique_ptr<UDPNotifyClient>
AsyncSocketIntegrationTest::performPingPongNotifyTest( AsyncSocketIntegrationTest::performPingPongNotifyTest(
folly::SocketAddress writeAddress, folly::SocketAddress writeAddress,
folly::Optional<folly::SocketAddress> connectedAddress) { folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket) {
auto client = std::make_unique<UDPNotifyClient>(&cevb); auto client = std::make_unique<UDPNotifyClient>(&cevb);
if (connectedAddress) { if (connectedAddress) {
client->setShouldConnect(*connectedAddress); client->setShouldConnect(*connectedAddress, bindSocket);
} }
// Start event loop in a separate thread // Start event loop in a separate thread
auto clientThread = std::thread([this]() { cevb.loopForever(); }); auto clientThread = std::thread([this]() { cevb.loopForever(); });
...@@ -537,10 +542,11 @@ std::unique_ptr<UDPNotifyClient> ...@@ -537,10 +542,11 @@ std::unique_ptr<UDPNotifyClient>
AsyncSocketIntegrationTest::performPingPongNotifyMmsgTest( AsyncSocketIntegrationTest::performPingPongNotifyMmsgTest(
folly::SocketAddress writeAddress, folly::SocketAddress writeAddress,
unsigned int numMsgs, unsigned int numMsgs,
folly::Optional<folly::SocketAddress> connectedAddress) { folly::Optional<folly::SocketAddress> connectedAddress,
BindSocket bindSocket) {
auto client = std::make_unique<UDPNotifyClient>(&cevb, true, numMsgs); auto client = std::make_unique<UDPNotifyClient>(&cevb, true, numMsgs);
if (connectedAddress) { if (connectedAddress) {
client->setShouldConnect(*connectedAddress); client->setShouldConnect(*connectedAddress, bindSocket);
} }
// Start event loop in a separate thread // Start event loop in a separate thread
auto clientThread = std::thread([this]() { cevb.loopForever(); }); auto clientThread = std::thread([this]() { cevb.loopForever(); });
...@@ -581,44 +587,63 @@ TEST_F(AsyncSocketIntegrationTest, PingPongNotifyMmsg) { ...@@ -581,44 +587,63 @@ TEST_F(AsyncSocketIntegrationTest, PingPongNotifyMmsg) {
ASSERT_TRUE(pingClient->notifyInvoked); ASSERT_TRUE(pingClient->notifyInvoked);
} }
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPong) { class ConnectedAsyncSocketIntegrationTest
: public AsyncSocketIntegrationTest,
public WithParamInterface<BindSocket> {};
TEST_P(ConnectedAsyncSocketIntegrationTest, ConnectedPingPong) {
server->setChangePortForWrites(false); server->setChangePortForWrites(false);
startServer(); startServer();
auto pingClient = performPingPongTest(server->address(), server->address()); auto pingClient =
performPingPongTest(server->address(), server->address(), GetParam());
// This should succeed // This should succeed
ASSERT_GT(pingClient->pongRecvd(), 0); ASSERT_GT(pingClient->pongRecvd(), 0);
} }
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongServerWrongAddress) { TEST_P(
ConnectedAsyncSocketIntegrationTest,
ConnectedPingPongServerWrongAddress) {
server->setChangePortForWrites(true); server->setChangePortForWrites(true);
startServer(); startServer();
auto pingClient = performPingPongTest(server->address(), server->address()); auto pingClient =
performPingPongTest(server->address(), server->address(), GetParam());
// This should fail. // This should fail.
ASSERT_EQ(pingClient->pongRecvd(), 0); ASSERT_EQ(pingClient->pongRecvd(), 0);
} }
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongClientWrongAddress) { TEST_P(
ConnectedAsyncSocketIntegrationTest,
ConnectedPingPongClientWrongAddress) {
server->setChangePortForWrites(false); server->setChangePortForWrites(false);
startServer(); startServer();
folly::SocketAddress connectAddr( folly::SocketAddress connectAddr(
server->address().getIPAddress(), server->address().getPort() + 1); server->address().getIPAddress(), server->address().getPort() + 1);
auto pingClient = performPingPongTest(server->address(), connectAddr); auto pingClient =
performPingPongTest(server->address(), connectAddr, GetParam());
// This should fail. // This should fail.
ASSERT_EQ(pingClient->pongRecvd(), 0); ASSERT_EQ(pingClient->pongRecvd(), 0);
EXPECT_TRUE(pingClient->error()); EXPECT_TRUE(pingClient->error());
} }
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongDifferentWriteAddress) { TEST_P(
ConnectedAsyncSocketIntegrationTest,
ConnectedPingPongDifferentWriteAddress) {
server->setChangePortForWrites(false); server->setChangePortForWrites(false);
startServer(); startServer();
folly::SocketAddress connectAddr( folly::SocketAddress connectAddr(
server->address().getIPAddress(), server->address().getPort() + 1); server->address().getIPAddress(), server->address().getPort() + 1);
auto pingClient = performPingPongTest(connectAddr, server->address()); auto pingClient =
performPingPongTest(connectAddr, server->address(), GetParam());
// This should fail. // This should fail.
ASSERT_EQ(pingClient->pongRecvd(), 0); ASSERT_EQ(pingClient->pongRecvd(), 0);
EXPECT_TRUE(pingClient->error()); EXPECT_TRUE(pingClient->error());
} }
INSTANTIATE_TEST_CASE_P(
ConnectedAsyncSocketIntegrationTests,
ConnectedAsyncSocketIntegrationTest,
Values(BindSocket::YES, BindSocket::NO));
TEST_F(AsyncSocketIntegrationTest, PingPongPauseResumeListening) { TEST_F(AsyncSocketIntegrationTest, PingPongPauseResumeListening) {
startServer(); startServer();
...@@ -703,8 +728,20 @@ class AsyncUDPSocketTest : public Test { ...@@ -703,8 +728,20 @@ class AsyncUDPSocketTest : public Test {
folly::SocketAddress addr_; folly::SocketAddress addr_;
}; };
TEST_F(AsyncUDPSocketTest, TestConnectAfterBind) {
socket_->connect(addr_);
}
TEST_F(AsyncUDPSocketTest, TestConnect) { TEST_F(AsyncUDPSocketTest, TestConnect) {
EXPECT_EQ(socket_->connect(addr_), 0); AsyncUDPSocket socket(&evb_);
EXPECT_FALSE(socket.isBound());
folly::SocketAddress address("127.0.0.1", 443);
socket.connect(address);
EXPECT_TRUE(socket.isBound());
const auto& localAddr = socket.address();
EXPECT_TRUE(localAddr.isInitialized());
EXPECT_GT(localAddr.getPort(), 0);
} }
TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) { TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) {
......
...@@ -50,7 +50,7 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket { ...@@ -50,7 +50,7 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket {
MOCK_METHOD1(setReuseAddr, void(bool)); MOCK_METHOD1(setReuseAddr, void(bool));
MOCK_METHOD1(dontFragment, void(bool)); MOCK_METHOD1(dontFragment, void(bool));
MOCK_METHOD1(setErrMessageCallback, void(ErrMessageCallback*)); MOCK_METHOD1(setErrMessageCallback, void(ErrMessageCallback*));
MOCK_METHOD1(connect, int(const SocketAddress&)); MOCK_METHOD1(connect, void(const SocketAddress&));
MOCK_CONST_METHOD0(isBound, bool()); MOCK_CONST_METHOD0(isBound, bool());
MOCK_METHOD0(getGSO, int()); MOCK_METHOD0(getGSO, int());
MOCK_METHOD1(setGSO, bool(int)); MOCK_METHOD1(setGSO, bool(int));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment