Commit 50c13594 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot

add connect method to AsyncUDPSocket

Summary:
Calling connect on the UDP socket can avoid route lookups
on the write path. This would be an optimization that would be
useful for clients.

There are several caveats of using connect() which we discovered
during testing. These are documented and unit tests have been written
for each of them.

Reviewed By: yfeldblum

Differential Revision: D7802389

fbshipit-source-id: 09b71373a3a95c5dab73ee9345db0dbbf66d4ec5
parent 7a500026
...@@ -380,6 +380,14 @@ void AsyncUDPSocket::failErrMessageRead(const AsyncSocketException& ex) { ...@@ -380,6 +380,14 @@ void AsyncUDPSocket::failErrMessageRead(const AsyncSocketException& ex) {
} }
} }
int AsyncUDPSocket::connect(const folly::SocketAddress& address) {
CHECK_NE(-1, fd_) << "Socket not yet bound";
sockaddr_storage addrStorage;
address.getAddress(&addrStorage);
return fsp::connect(
fd_, reinterpret_cast<sockaddr*>(&addrStorage), address.getActualSize());
}
void AsyncUDPSocket::handleRead() noexcept { void AsyncUDPSocket::handleRead() noexcept {
void* buf{nullptr}; void* buf{nullptr};
size_t len{0}; size_t len{0};
......
...@@ -224,6 +224,26 @@ class AsyncUDPSocket : public EventHandler { ...@@ -224,6 +224,26 @@ class AsyncUDPSocket : public EventHandler {
*/ */
virtual void setErrMessageCallback(ErrMessageCallback* errMessageCallback); virtual void setErrMessageCallback(ErrMessageCallback* errMessageCallback);
/**
* Connects the UDP socket to a remote destination address provided in
* address. This can speed up UDP writes on linux because it will cache flow
* state on connects.
* Using connect has many quirks, and you should be aware of them before using
* this API:
* 1. This must only be called after binding the socket.
* 2. Normally UDP can use the 2 tuple (src ip, src port) to steer packets
* sent by the peer to the socket, however after connecting the socket, only
* packets destined to the destination address specified in connect() will be
* forwarded and others will be dropped. If the server can send a packet
* from a different destination port / IP then you probably do not want to use
* this API.
* 3. It can be called repeatedly on either the client or server however it's
* normally only useful on the client and not server.
*
* Returns the result of calling the connect syscall.
*/
virtual int connect(const folly::SocketAddress& address);
protected: protected:
virtual ssize_t sendmsg(int socket, const struct msghdr* message, int flags) { virtual ssize_t sendmsg(int socket, const struct msghdr* message, int flags) {
return ::sendmsg(socket, message, flags); return ::sendmsg(socket, message, flags);
......
...@@ -36,33 +36,38 @@ using namespace testing; ...@@ -36,33 +36,38 @@ using namespace testing;
class UDPAcceptor : public AsyncUDPServerSocket::Callback { class UDPAcceptor : public AsyncUDPServerSocket::Callback {
public: public:
UDPAcceptor(EventBase* evb, int n) : evb_(evb), n_(n) {} UDPAcceptor(EventBase* evb, int n, bool changePortForWrites)
: evb_(evb), n_(n), changePortForWrites_(changePortForWrites) {}
void onListenStarted() noexcept override {} void onListenStarted() noexcept override {}
void onListenStopped() noexcept override {} void onListenStopped() noexcept override {}
void onDataAvailable( void onDataAvailable(
std::shared_ptr<folly::AsyncUDPSocket> /* socket */, std::shared_ptr<folly::AsyncUDPSocket> socket,
const folly::SocketAddress& client, const folly::SocketAddress& client,
std::unique_ptr<folly::IOBuf> data, std::unique_ptr<folly::IOBuf> data,
bool truncated) noexcept override { bool truncated) noexcept override {
lastClient_ = client; lastClient_ = client;
lastMsg_ = data->moveToFbString().toStdString(); lastMsg_ = data->clone()->moveToFbString().toStdString();
auto len = data->computeChainDataLength(); auto len = data->computeChainDataLength();
VLOG(4) << "Worker " << n_ << " read " << len << " bytes " VLOG(4) << "Worker " << n_ << " read " << len << " bytes "
<< "(trun:" << truncated << ") from " << client.describe() << " - " << "(trun:" << truncated << ") from " << client.describe() << " - "
<< lastMsg_; << lastMsg_;
sendPong(); sendPong(socket);
} }
void sendPong() noexcept { void sendPong(std::shared_ptr<folly::AsyncUDPSocket> socket) noexcept {
try { try {
AsyncUDPSocket socket(evb_); auto writeSocket = socket;
socket.bind(folly::SocketAddress("127.0.0.1", 0)); if (changePortForWrites_) {
socket.write(lastClient_, folly::IOBuf::copyBuffer(lastMsg_)); writeSocket = std::make_shared<folly::AsyncUDPSocket>(evb_);
writeSocket->setReuseAddr(false);
writeSocket->bind(folly::SocketAddress("127.0.0.1", 0));
}
writeSocket->write(lastClient_, folly::IOBuf::copyBuffer(lastMsg_));
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
VLOG(4) << "Failed to send PONG " << ex.what(); VLOG(4) << "Failed to send PONG " << ex.what();
} }
...@@ -71,6 +76,8 @@ class UDPAcceptor : public AsyncUDPServerSocket::Callback { ...@@ -71,6 +76,8 @@ class UDPAcceptor : public AsyncUDPServerSocket::Callback {
private: private:
EventBase* const evb_{nullptr}; EventBase* const evb_{nullptr};
const int n_{-1}; const int n_{-1};
// Whether to create a new port per write.
bool changePortForWrites_{true};
folly::SocketAddress lastClient_; folly::SocketAddress lastClient_;
std::string lastMsg_; std::string lastMsg_;
...@@ -99,7 +106,7 @@ class UDPServer { ...@@ -99,7 +106,7 @@ class UDPServer {
// Add numWorkers thread // Add numWorkers thread
int i = 0; int i = 0;
for (auto& evb : evbs_) { for (auto& evb : evbs_) {
acceptors_.emplace_back(&evb, i); acceptors_.emplace_back(&evb, i, changePortForWrites_);
std::thread t([&]() { evb.loopForever(); }); std::thread t([&]() { evb.loopForever(); });
...@@ -131,6 +138,11 @@ class UDPServer { ...@@ -131,6 +138,11 @@ class UDPServer {
} }
} }
// Whether writes from the UDP server should change the port for each message.
void setChangePortForWrites(bool changePortForWrites) {
changePortForWrites_ = changePortForWrites;
}
private: private:
EventBase* const evb_{nullptr}; EventBase* const evb_{nullptr};
const folly::SocketAddress addr_; const folly::SocketAddress addr_;
...@@ -139,6 +151,7 @@ class UDPServer { ...@@ -139,6 +151,7 @@ class UDPServer {
std::vector<std::thread> threads_; std::vector<std::thread> threads_;
std::vector<folly::EventBase> evbs_; std::vector<folly::EventBase> evbs_;
std::vector<UDPAcceptor> acceptors_; std::vector<UDPAcceptor> acceptors_;
bool changePortForWrites_{true};
}; };
class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
...@@ -147,13 +160,15 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -147,13 +160,15 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
void start(const folly::SocketAddress& server, int n) { void start(const folly::SocketAddress& server, int n) {
CHECK(evb_->isInEventBaseThread()); CHECK(evb_->isInEventBaseThread());
server_ = server; server_ = server;
socket_ = std::make_unique<AsyncUDPSocket>(evb_); socket_ = std::make_unique<AsyncUDPSocket>(evb_);
try { try {
socket_->bind(folly::SocketAddress("127.0.0.1", 0)); socket_->bind(folly::SocketAddress("127.0.0.1", 0));
VLOG(4) << "Client bound to " << socket_->address().describe(); if (connectAddr_) {
connect();
}
VLOG(2) << "Client bound to " << socket_->address().describe();
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
LOG(FATAL) << ex.what(); LOG(FATAL) << ex.what();
} }
...@@ -166,6 +181,15 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -166,6 +181,15 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
sendPing(); sendPing();
} }
void connect() {
int ret = socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
void shutdown() { void shutdown() {
CHECK(evb_->isInEventBaseThread()); CHECK(evb_->isInEventBaseThread());
socket_->pauseRead(); socket_->pauseRead();
...@@ -182,8 +206,11 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -182,8 +206,11 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
--n_; --n_;
scheduleTimeout(5); scheduleTimeout(5);
socket_->write( writePing(folly::IOBuf::copyBuffer(folly::to<std::string>("PING ", n_)));
server_, folly::IOBuf::copyBuffer(folly::to<std::string>("PING ", n_))); }
virtual void writePing(std::unique_ptr<folly::IOBuf> buf) {
socket_->write(server_, std::move(buf));
} }
void getReadBuffer(void** buf, size_t* len) noexcept override { void getReadBuffer(void** buf, size_t* len) noexcept override {
...@@ -224,58 +251,182 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout { ...@@ -224,58 +251,182 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
return pongRecvd_; return pongRecvd_;
} }
private: AsyncUDPSocket& getSocket() {
return *socket_;
}
void setShouldConnect(const folly::SocketAddress& connectAddr) {
connectAddr_ = connectAddr;
}
protected:
folly::Optional<folly::SocketAddress> connectAddr_;
EventBase* const evb_{nullptr}; EventBase* const evb_{nullptr};
folly::SocketAddress server_; folly::SocketAddress server_;
std::unique_ptr<AsyncUDPSocket> socket_; std::unique_ptr<AsyncUDPSocket> socket_;
private:
int pongRecvd_{0}; int pongRecvd_{0};
int n_{0}; int n_{0};
char buf_[1024]; char buf_[1024];
}; };
TEST(AsyncSocketTest, PingPong) { class ConnectedWriteUDPClient : public UDPClient {
folly::EventBase sevb; public:
UDPServer server(&sevb, folly::SocketAddress("127.0.0.1", 0), 4); ~ConnectedWriteUDPClient() override = default;
ConnectedWriteUDPClient(EventBase* evb) : UDPClient(evb) {}
// When the socket is connected you don't need to supply the address to send
// msg. This will test that connect worked.
void writePing(std::unique_ptr<folly::IOBuf> buf) override {
iovec vec[16];
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0]));
if (UNLIKELY(iovec_len == 0)) {
buf->coalesce();
vec[0].iov_base = const_cast<uint8_t*>(buf->data());
vec[0].iov_len = buf->length();
iovec_len = 1;
}
// Start event loop in a separate thread struct msghdr msg;
auto serverThread = std::thread([&sevb]() { sevb.loopForever(); }); msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = const_cast<struct iovec*>(vec);
msg.msg_iovlen = iovec_len;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ssize_t ret = ::sendmsg(socket_->getFD(), &msg, 0);
if (ret == -1) {
if (errno != EAGAIN || errno != EWOULDBLOCK) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "WriteFail", errno);
}
}
connect();
}
};
// Wait for event loop to start class AsyncSocketIntegrationTest : public Test {
sevb.waitUntilRunning(); public:
void SetUp() override {
server = std::make_unique<UDPServer>(
&sevb, folly::SocketAddress("127.0.0.1", 0), 4);
// Start the server // Start event loop in a separate thread
sevb.runInEventBaseThreadAndWait([&]() { server.start(); }); serverThread =
std::make_unique<std::thread>([this]() { sevb.loopForever(); });
// Wait for event loop to start
sevb.waitUntilRunning();
}
void startServer() {
// Start the server
sevb.runInEventBaseThreadAndWait([&]() { server->start(); });
LOG(INFO) << "Server listening=" << server->address();
}
void TearDown() override {
// Shutdown server
sevb.runInEventBaseThread([&]() {
server->shutdown();
sevb.terminateLoopSoon();
});
// Wait for server thread to joib
serverThread->join();
}
std::unique_ptr<UDPClient> performPingPongTest(
folly::Optional<folly::SocketAddress> connectedAddress,
bool useConnectedWrite);
folly::EventBase sevb;
folly::EventBase cevb; folly::EventBase cevb;
UDPClient client(&cevb); std::unique_ptr<std::thread> serverThread;
std::unique_ptr<UDPServer> server;
std::unique_ptr<UDPClient> client;
};
std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
folly::Optional<folly::SocketAddress> connectedAddress,
bool useConnectedWrite) {
if (useConnectedWrite) {
CHECK(connectedAddress.hasValue());
client = std::make_unique<ConnectedWriteUDPClient>(&cevb);
client->setShouldConnect(*connectedAddress);
} else {
client = std::make_unique<UDPClient>(&cevb);
if (connectedAddress) {
client->setShouldConnect(*connectedAddress);
}
}
// Start event loop in a separate thread // Start event loop in a separate thread
auto clientThread = std::thread([&cevb]() { cevb.loopForever(); }); auto clientThread = std::thread([this]() { cevb.loopForever(); });
// Wait for event loop to start // Wait for event loop to start
cevb.waitUntilRunning(); cevb.waitUntilRunning();
// Send ping // Send ping
cevb.runInEventBaseThread([&]() { client.start(server.address(), 1000); }); cevb.runInEventBaseThread([&]() { client->start(server->address(), 100); });
// Wait for client to finish // Wait for client to finish
clientThread.join(); clientThread.join();
return std::move(client);
}
TEST_F(AsyncSocketIntegrationTest, PingPong) {
startServer();
auto client = performPingPongTest(folly::none, false);
// This should succeed.
ASSERT_GT(client->pongRecvd(), 0);
}
// Check that some PING/PONGS were exchanged. Out of 1000 transactions TEST_F(AsyncSocketIntegrationTest, ConnectedPingPong) {
// at least 1 should succeed server->setChangePortForWrites(false);
CHECK_GT(client.pongRecvd(), 0); startServer();
auto client = performPingPongTest(server->address(), false);
// This should succeed
ASSERT_GT(client->pongRecvd(), 0);
}
// Shutdown server TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongServerWrongAddress) {
sevb.runInEventBaseThread([&]() { server->setChangePortForWrites(true);
server.shutdown(); startServer();
sevb.terminateLoopSoon(); auto client = performPingPongTest(server->address(), false);
}); // This should fail.
ASSERT_EQ(client->pongRecvd(), 0);
}
// Wait for server thread to joib TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongClientWrongAddress) {
serverThread.join(); server->setChangePortForWrites(false);
startServer();
folly::SocketAddress connectAddr(
server->address().getIPAddress(), server->address().getPort() + 1);
auto client = performPingPongTest(connectAddr, false);
// This should fail.
ASSERT_EQ(client->pongRecvd(), 0);
}
TEST_F(AsyncSocketIntegrationTest, PingPongUseConnectedSendMsg) {
server->setChangePortForWrites(false);
startServer();
auto client = performPingPongTest(server->address(), true);
// This should succeed.
ASSERT_GT(client->pongRecvd(), 0);
}
TEST_F(AsyncSocketIntegrationTest, PingPongUseConnectedSendMsgServerWrongAddr) {
server->setChangePortForWrites(true);
startServer();
auto client = performPingPongTest(server->address(), true);
// This should fail.
ASSERT_EQ(client->pongRecvd(), 0);
} }
class TestAsyncUDPSocket : public AsyncUDPSocket { class TestAsyncUDPSocket : public AsyncUDPSocket {
...@@ -346,6 +497,10 @@ class AsyncUDPSocketTest : public Test { ...@@ -346,6 +497,10 @@ class AsyncUDPSocketTest : public Test {
folly::SocketAddress addr_; folly::SocketAddress addr_;
}; };
TEST_F(AsyncUDPSocketTest, TestConnect) {
EXPECT_EQ(socket_->connect(addr_), 0);
}
TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) { TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) {
socket_->resumeRead(&readCb); socket_->resumeRead(&readCb);
socket_->setErrMessageCallback(&err); socket_->setErrMessageCallback(&err);
......
...@@ -42,6 +42,7 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket { ...@@ -42,6 +42,7 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket {
MOCK_METHOD1(setReuseAddr, void(bool)); MOCK_METHOD1(setReuseAddr, void(bool));
MOCK_METHOD1(dontFragment, void(bool)); MOCK_METHOD1(dontFragment, void(bool));
MOCK_METHOD1(setErrMessageCallback, void(ErrMessageCallback*)); MOCK_METHOD1(setErrMessageCallback, void(ErrMessageCallback*));
MOCK_METHOD1(connect, int(const SocketAddress&));
}; };
} // namespace test } // namespace test
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment