Commit 50c13594 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot

add connect method to AsyncUDPSocket

Summary:
Calling connect on the UDP socket can avoid route lookups
on the write path. This would be an optimization that would be
useful for clients.

There are several caveats of using connect() which we discovered
during testing. These are documented and unit tests have been written
for each of them.

Reviewed By: yfeldblum

Differential Revision: D7802389

fbshipit-source-id: 09b71373a3a95c5dab73ee9345db0dbbf66d4ec5
parent 7a500026
......@@ -380,6 +380,14 @@ void AsyncUDPSocket::failErrMessageRead(const AsyncSocketException& ex) {
}
}
int AsyncUDPSocket::connect(const folly::SocketAddress& address) {
CHECK_NE(-1, fd_) << "Socket not yet bound";
sockaddr_storage addrStorage;
address.getAddress(&addrStorage);
return fsp::connect(
fd_, reinterpret_cast<sockaddr*>(&addrStorage), address.getActualSize());
}
void AsyncUDPSocket::handleRead() noexcept {
void* buf{nullptr};
size_t len{0};
......
......@@ -224,6 +224,26 @@ class AsyncUDPSocket : public EventHandler {
*/
virtual void setErrMessageCallback(ErrMessageCallback* errMessageCallback);
/**
* Connects the UDP socket to a remote destination address provided in
* address. This can speed up UDP writes on linux because it will cache flow
* state on connects.
* Using connect has many quirks, and you should be aware of them before using
* this API:
* 1. This must only be called after binding the socket.
* 2. Normally UDP can use the 2 tuple (src ip, src port) to steer packets
* sent by the peer to the socket, however after connecting the socket, only
* packets destined to the destination address specified in connect() will be
* forwarded and others will be dropped. If the server can send a packet
* from a different destination port / IP then you probably do not want to use
* this API.
* 3. It can be called repeatedly on either the client or server however it's
* normally only useful on the client and not server.
*
* Returns the result of calling the connect syscall.
*/
virtual int connect(const folly::SocketAddress& address);
protected:
virtual ssize_t sendmsg(int socket, const struct msghdr* message, int flags) {
return ::sendmsg(socket, message, flags);
......
......@@ -36,33 +36,38 @@ using namespace testing;
class UDPAcceptor : public AsyncUDPServerSocket::Callback {
public:
UDPAcceptor(EventBase* evb, int n) : evb_(evb), n_(n) {}
UDPAcceptor(EventBase* evb, int n, bool changePortForWrites)
: evb_(evb), n_(n), changePortForWrites_(changePortForWrites) {}
void onListenStarted() noexcept override {}
void onListenStopped() noexcept override {}
void onDataAvailable(
std::shared_ptr<folly::AsyncUDPSocket> /* socket */,
std::shared_ptr<folly::AsyncUDPSocket> socket,
const folly::SocketAddress& client,
std::unique_ptr<folly::IOBuf> data,
bool truncated) noexcept override {
lastClient_ = client;
lastMsg_ = data->moveToFbString().toStdString();
lastMsg_ = data->clone()->moveToFbString().toStdString();
auto len = data->computeChainDataLength();
VLOG(4) << "Worker " << n_ << " read " << len << " bytes "
<< "(trun:" << truncated << ") from " << client.describe() << " - "
<< lastMsg_;
sendPong();
sendPong(socket);
}
void sendPong() noexcept {
void sendPong(std::shared_ptr<folly::AsyncUDPSocket> socket) noexcept {
try {
AsyncUDPSocket socket(evb_);
socket.bind(folly::SocketAddress("127.0.0.1", 0));
socket.write(lastClient_, folly::IOBuf::copyBuffer(lastMsg_));
auto writeSocket = socket;
if (changePortForWrites_) {
writeSocket = std::make_shared<folly::AsyncUDPSocket>(evb_);
writeSocket->setReuseAddr(false);
writeSocket->bind(folly::SocketAddress("127.0.0.1", 0));
}
writeSocket->write(lastClient_, folly::IOBuf::copyBuffer(lastMsg_));
} catch (const std::exception& ex) {
VLOG(4) << "Failed to send PONG " << ex.what();
}
......@@ -71,6 +76,8 @@ class UDPAcceptor : public AsyncUDPServerSocket::Callback {
private:
EventBase* const evb_{nullptr};
const int n_{-1};
// Whether to create a new port per write.
bool changePortForWrites_{true};
folly::SocketAddress lastClient_;
std::string lastMsg_;
......@@ -99,7 +106,7 @@ class UDPServer {
// Add numWorkers thread
int i = 0;
for (auto& evb : evbs_) {
acceptors_.emplace_back(&evb, i);
acceptors_.emplace_back(&evb, i, changePortForWrites_);
std::thread t([&]() { evb.loopForever(); });
......@@ -131,6 +138,11 @@ class UDPServer {
}
}
// Whether writes from the UDP server should change the port for each message.
void setChangePortForWrites(bool changePortForWrites) {
changePortForWrites_ = changePortForWrites;
}
private:
EventBase* const evb_{nullptr};
const folly::SocketAddress addr_;
......@@ -139,6 +151,7 @@ class UDPServer {
std::vector<std::thread> threads_;
std::vector<folly::EventBase> evbs_;
std::vector<UDPAcceptor> acceptors_;
bool changePortForWrites_{true};
};
class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
......@@ -147,13 +160,15 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
void start(const folly::SocketAddress& server, int n) {
CHECK(evb_->isInEventBaseThread());
server_ = server;
socket_ = std::make_unique<AsyncUDPSocket>(evb_);
try {
socket_->bind(folly::SocketAddress("127.0.0.1", 0));
VLOG(4) << "Client bound to " << socket_->address().describe();
if (connectAddr_) {
connect();
}
VLOG(2) << "Client bound to " << socket_->address().describe();
} catch (const std::exception& ex) {
LOG(FATAL) << ex.what();
}
......@@ -166,6 +181,15 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
sendPing();
}
void connect() {
int ret = socket_->connect(*connectAddr_);
if (ret != 0) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "ConnectFail", errno);
}
VLOG(2) << "Client connected to address=" << *connectAddr_;
}
void shutdown() {
CHECK(evb_->isInEventBaseThread());
socket_->pauseRead();
......@@ -182,8 +206,11 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
--n_;
scheduleTimeout(5);
socket_->write(
server_, folly::IOBuf::copyBuffer(folly::to<std::string>("PING ", n_)));
writePing(folly::IOBuf::copyBuffer(folly::to<std::string>("PING ", n_)));
}
virtual void writePing(std::unique_ptr<folly::IOBuf> buf) {
socket_->write(server_, std::move(buf));
}
void getReadBuffer(void** buf, size_t* len) noexcept override {
......@@ -224,58 +251,182 @@ class UDPClient : private AsyncUDPSocket::ReadCallback, private AsyncTimeout {
return pongRecvd_;
}
private:
AsyncUDPSocket& getSocket() {
return *socket_;
}
void setShouldConnect(const folly::SocketAddress& connectAddr) {
connectAddr_ = connectAddr;
}
protected:
folly::Optional<folly::SocketAddress> connectAddr_;
EventBase* const evb_{nullptr};
folly::SocketAddress server_;
std::unique_ptr<AsyncUDPSocket> socket_;
private:
int pongRecvd_{0};
int n_{0};
char buf_[1024];
};
TEST(AsyncSocketTest, PingPong) {
folly::EventBase sevb;
UDPServer server(&sevb, folly::SocketAddress("127.0.0.1", 0), 4);
class ConnectedWriteUDPClient : public UDPClient {
public:
~ConnectedWriteUDPClient() override = default;
ConnectedWriteUDPClient(EventBase* evb) : UDPClient(evb) {}
// When the socket is connected you don't need to supply the address to send
// msg. This will test that connect worked.
void writePing(std::unique_ptr<folly::IOBuf> buf) override {
iovec vec[16];
size_t iovec_len = buf->fillIov(vec, sizeof(vec) / sizeof(vec[0]));
if (UNLIKELY(iovec_len == 0)) {
buf->coalesce();
vec[0].iov_base = const_cast<uint8_t*>(buf->data());
vec[0].iov_len = buf->length();
iovec_len = 1;
}
// Start event loop in a separate thread
auto serverThread = std::thread([&sevb]() { sevb.loopForever(); });
struct msghdr msg;
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = const_cast<struct iovec*>(vec);
msg.msg_iovlen = iovec_len;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
ssize_t ret = ::sendmsg(socket_->getFD(), &msg, 0);
if (ret == -1) {
if (errno != EAGAIN || errno != EWOULDBLOCK) {
throw folly::AsyncSocketException(
folly::AsyncSocketException::NOT_OPEN, "WriteFail", errno);
}
}
connect();
}
};
// Wait for event loop to start
sevb.waitUntilRunning();
class AsyncSocketIntegrationTest : public Test {
public:
void SetUp() override {
server = std::make_unique<UDPServer>(
&sevb, folly::SocketAddress("127.0.0.1", 0), 4);
// Start the server
sevb.runInEventBaseThreadAndWait([&]() { server.start(); });
// Start event loop in a separate thread
serverThread =
std::make_unique<std::thread>([this]() { sevb.loopForever(); });
// Wait for event loop to start
sevb.waitUntilRunning();
}
void startServer() {
// Start the server
sevb.runInEventBaseThreadAndWait([&]() { server->start(); });
LOG(INFO) << "Server listening=" << server->address();
}
void TearDown() override {
// Shutdown server
sevb.runInEventBaseThread([&]() {
server->shutdown();
sevb.terminateLoopSoon();
});
// Wait for server thread to joib
serverThread->join();
}
std::unique_ptr<UDPClient> performPingPongTest(
folly::Optional<folly::SocketAddress> connectedAddress,
bool useConnectedWrite);
folly::EventBase sevb;
folly::EventBase cevb;
UDPClient client(&cevb);
std::unique_ptr<std::thread> serverThread;
std::unique_ptr<UDPServer> server;
std::unique_ptr<UDPClient> client;
};
std::unique_ptr<UDPClient> AsyncSocketIntegrationTest::performPingPongTest(
folly::Optional<folly::SocketAddress> connectedAddress,
bool useConnectedWrite) {
if (useConnectedWrite) {
CHECK(connectedAddress.hasValue());
client = std::make_unique<ConnectedWriteUDPClient>(&cevb);
client->setShouldConnect(*connectedAddress);
} else {
client = std::make_unique<UDPClient>(&cevb);
if (connectedAddress) {
client->setShouldConnect(*connectedAddress);
}
}
// Start event loop in a separate thread
auto clientThread = std::thread([&cevb]() { cevb.loopForever(); });
auto clientThread = std::thread([this]() { cevb.loopForever(); });
// Wait for event loop to start
cevb.waitUntilRunning();
// Send ping
cevb.runInEventBaseThread([&]() { client.start(server.address(), 1000); });
cevb.runInEventBaseThread([&]() { client->start(server->address(), 100); });
// Wait for client to finish
clientThread.join();
return std::move(client);
}
TEST_F(AsyncSocketIntegrationTest, PingPong) {
startServer();
auto client = performPingPongTest(folly::none, false);
// This should succeed.
ASSERT_GT(client->pongRecvd(), 0);
}
// Check that some PING/PONGS were exchanged. Out of 1000 transactions
// at least 1 should succeed
CHECK_GT(client.pongRecvd(), 0);
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPong) {
server->setChangePortForWrites(false);
startServer();
auto client = performPingPongTest(server->address(), false);
// This should succeed
ASSERT_GT(client->pongRecvd(), 0);
}
// Shutdown server
sevb.runInEventBaseThread([&]() {
server.shutdown();
sevb.terminateLoopSoon();
});
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongServerWrongAddress) {
server->setChangePortForWrites(true);
startServer();
auto client = performPingPongTest(server->address(), false);
// This should fail.
ASSERT_EQ(client->pongRecvd(), 0);
}
// Wait for server thread to joib
serverThread.join();
TEST_F(AsyncSocketIntegrationTest, ConnectedPingPongClientWrongAddress) {
server->setChangePortForWrites(false);
startServer();
folly::SocketAddress connectAddr(
server->address().getIPAddress(), server->address().getPort() + 1);
auto client = performPingPongTest(connectAddr, false);
// This should fail.
ASSERT_EQ(client->pongRecvd(), 0);
}
TEST_F(AsyncSocketIntegrationTest, PingPongUseConnectedSendMsg) {
server->setChangePortForWrites(false);
startServer();
auto client = performPingPongTest(server->address(), true);
// This should succeed.
ASSERT_GT(client->pongRecvd(), 0);
}
TEST_F(AsyncSocketIntegrationTest, PingPongUseConnectedSendMsgServerWrongAddr) {
server->setChangePortForWrites(true);
startServer();
auto client = performPingPongTest(server->address(), true);
// This should fail.
ASSERT_EQ(client->pongRecvd(), 0);
}
class TestAsyncUDPSocket : public AsyncUDPSocket {
......@@ -346,6 +497,10 @@ class AsyncUDPSocketTest : public Test {
folly::SocketAddress addr_;
};
TEST_F(AsyncUDPSocketTest, TestConnect) {
EXPECT_EQ(socket_->connect(addr_), 0);
}
TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) {
socket_->resumeRead(&readCb);
socket_->setErrMessageCallback(&err);
......
......@@ -42,6 +42,7 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket {
MOCK_METHOD1(setReuseAddr, void(bool));
MOCK_METHOD1(dontFragment, void(bool));
MOCK_METHOD1(setErrMessageCallback, void(ErrMessageCallback*));
MOCK_METHOD1(connect, int(const SocketAddress&));
};
} // namespace test
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment