Commit 653703a3 authored by Xintong Hu's avatar Xintong Hu Committed by Facebook GitHub Bot

Add API to set cmsg for write

Summary: allow users to set/append a list of cmsgs to be sent for each write

Reviewed By: bschlinker

Differential Revision: D29313594

fbshipit-source-id: 8f78c59ecfe56ddb2c8c016d6105a676cd501c18
parent ca7ce442
...@@ -32,11 +32,17 @@ class SocketOptionKey { ...@@ -32,11 +32,17 @@ class SocketOptionKey {
public: public:
enum class ApplyPos { POST_BIND = 0, PRE_BIND = 1 }; enum class ApplyPos { POST_BIND = 0, PRE_BIND = 1 };
bool operator<(const SocketOptionKey& other) const { friend bool operator<(
if (level == other.level) { const SocketOptionKey& lhs, const SocketOptionKey& rhs) {
return optname < other.optname; if (lhs.level == rhs.level) {
return lhs.optname < rhs.optname;
} }
return level < other.level; return lhs.level < rhs.level;
}
friend bool operator==(
const SocketOptionKey& lhs, const SocketOptionKey& rhs) {
return lhs.level == rhs.level && lhs.optname == rhs.optname;
} }
int apply(NetworkSocket fd, int val) const; int apply(NetworkSocket fd, int val) const;
......
...@@ -591,19 +591,29 @@ ssize_t AsyncUDPSocket::writev( ...@@ -591,19 +591,29 @@ ssize_t AsyncUDPSocket::writev(
msg.msg_flags = 0; msg.msg_flags = 0;
#ifdef FOLLY_HAVE_MSG_ERRQUEUE #ifdef FOLLY_HAVE_MSG_ERRQUEUE
if (gso > 0) { constexpr size_t kSmallSizeMax = 5;
char control[CMSG_SPACE(sizeof(uint16_t))]; size_t controlBufSize = gso > 0 ? 1 : 0;
msg.msg_control = control; controlBufSize +=
msg.msg_controllen = sizeof(control); cmsgs_.size() * (CMSG_SPACE(sizeof(int)) / CMSG_SPACE(sizeof(uint16_t)));
struct cmsghdr* cm = CMSG_FIRSTHDR(&msg); if (controlBufSize <= kSmallSizeMax) {
cm->cmsg_level = SOL_UDP; // suppress "warning: variable length array 'control' is used [-Wvla]"
cm->cmsg_type = UDP_SEGMENT; FOLLY_PUSH_WARNING
cm->cmsg_len = CMSG_LEN(sizeof(uint16_t)); FOLLY_GNU_DISABLE_WARNING("-Wvla")
auto gso_len = static_cast<uint16_t>(gso); // we will allocate this on the stack anyway even if we do not use it
memcpy(CMSG_DATA(cm), &gso_len, sizeof(gso_len)); char control
[(BOOST_PP_IF(FOLLY_HAVE_VLA_01, controlBufSize, kSmallSizeMax)) *
return sendmsg(fd_, &msg, 0); (CMSG_SPACE(sizeof(uint16_t)))];
memset(control, 0, sizeof(control));
msg.msg_control = control;
FOLLY_POP_WARNING
return writevImpl(&msg, gso);
} else {
std::unique_ptr<char[]> control(
new char[controlBufSize * (CMSG_SPACE(sizeof(uint16_t)))]);
memset(control.get(), 0, controlBufSize * (CMSG_SPACE(sizeof(uint16_t))));
msg.msg_control = control.get();
return writevImpl(&msg, gso);
} }
#else #else
CHECK_LT(gso, 1) << "GSO not supported"; CHECK_LT(gso, 1) << "GSO not supported";
...@@ -612,6 +622,46 @@ ssize_t AsyncUDPSocket::writev( ...@@ -612,6 +622,46 @@ ssize_t AsyncUDPSocket::writev(
return sendmsg(fd_, &msg, 0); return sendmsg(fd_, &msg, 0);
} }
ssize_t AsyncUDPSocket::writevImpl(
struct msghdr* msg, FOLLY_MAYBE_UNUSED int gso) {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
struct cmsghdr* cm = nullptr;
for (auto itr = cmsgs_.begin(); itr != cmsgs_.end(); ++itr) {
const auto key = itr->first;
const auto val = itr->second;
msg->msg_controllen += CMSG_SPACE(sizeof(val));
if (cm) {
cm = CMSG_NXTHDR(msg, cm);
} else {
cm = CMSG_FIRSTHDR(msg);
}
if (cm) {
cm->cmsg_level = key.level;
cm->cmsg_type = key.optname;
cm->cmsg_len = CMSG_LEN(sizeof(val));
memcpy(CMSG_DATA(cm), &val, sizeof(val));
}
}
if (gso > 0) {
msg->msg_controllen += CMSG_SPACE(sizeof(uint16_t));
if (cm) {
cm = CMSG_NXTHDR(msg, cm);
} else {
cm = CMSG_FIRSTHDR(msg);
}
if (cm) {
cm->cmsg_level = SOL_UDP;
cm->cmsg_type = UDP_SEGMENT;
cm->cmsg_len = CMSG_LEN(sizeof(uint16_t));
auto gso_len = static_cast<uint16_t>(gso);
memcpy(CMSG_DATA(cm), &gso_len, sizeof(gso_len));
}
}
#endif // FOLLY_HAVE_MSG_ERRQUEUE
return sendmsg(fd_, msg, 0);
}
ssize_t AsyncUDPSocket::writev( ssize_t AsyncUDPSocket::writev(
const folly::SocketAddress& address, const folly::SocketAddress& address,
const struct iovec* vec, const struct iovec* vec,
...@@ -636,12 +686,16 @@ int AsyncUDPSocket::writemGSO( ...@@ -636,12 +686,16 @@ int AsyncUDPSocket::writemGSO(
size_t count, size_t count,
const int* gso) { const int* gso) {
int ret; int ret;
constexpr size_t kSmallSizeMax = 8; constexpr size_t kSmallSizeMax = 40;
char* gsoControl = nullptr; char* controlPtr = nullptr;
#ifndef FOLLY_HAVE_MSG_ERRQUEUE #ifndef FOLLY_HAVE_MSG_ERRQUEUE
CHECK(!gso) << "GSO not supported"; CHECK(!gso) << "GSO not supported";
#endif #endif
if (count <= kSmallSizeMax) { size_t singleControlBufSize = 1;
singleControlBufSize +=
cmsgs_.size() * (CMSG_SPACE(sizeof(int)) / CMSG_SPACE(sizeof(uint16_t)));
size_t controlBufSize = count * singleControlBufSize;
if (controlBufSize <= kSmallSizeMax) {
// suppress "warning: variable length array 'vec' is used [-Wvla]" // suppress "warning: variable length array 'vec' is used [-Wvla]"
FOLLY_PUSH_WARNING FOLLY_PUSH_WARNING
FOLLY_GNU_DISABLE_WARNING("-Wvla") FOLLY_GNU_DISABLE_WARNING("-Wvla")
...@@ -649,25 +703,22 @@ int AsyncUDPSocket::writemGSO( ...@@ -649,25 +703,22 @@ int AsyncUDPSocket::writemGSO(
#ifdef FOLLY_HAVE_MSG_ERRQUEUE #ifdef FOLLY_HAVE_MSG_ERRQUEUE
// we will allocate this on the stack anyway even if we do not use it // we will allocate this on the stack anyway even if we do not use it
char control char control
[(BOOST_PP_IF(FOLLY_HAVE_VLA_01, count, kSmallSizeMax)) * [(BOOST_PP_IF(FOLLY_HAVE_VLA_01, controlBufSize, kSmallSizeMax)) *
(CMSG_SPACE(sizeof(uint16_t)))]; (CMSG_SPACE(sizeof(uint16_t)))];
memset(control, 0, sizeof(control));
if (gso) { controlPtr = control;
gsoControl = control;
}
#endif #endif
FOLLY_POP_WARNING FOLLY_POP_WARNING
ret = writeImpl(addrs, bufs, count, vec, gso, gsoControl); ret = writeImpl(addrs, bufs, count, vec, gso, controlPtr);
} else { } else {
std::unique_ptr<mmsghdr[]> vec(new mmsghdr[count]); std::unique_ptr<mmsghdr[]> vec(new mmsghdr[count]);
#ifdef FOLLY_HAVE_MSG_ERRQUEUE #ifdef FOLLY_HAVE_MSG_ERRQUEUE
std::unique_ptr<char[]> control( std::unique_ptr<char[]> control(
gso ? (new char[count * (CMSG_SPACE(sizeof(uint16_t)))]) : nullptr); new char[controlBufSize * (CMSG_SPACE(sizeof(uint16_t)))]);
if (gso) { memset(control.get(), 0, controlBufSize * (CMSG_SPACE(sizeof(uint16_t))));
gsoControl = control.get(); controlPtr = control.get();
}
#endif #endif
ret = writeImpl(addrs, bufs, count, vec.get(), gso, gsoControl); ret = writeImpl(addrs, bufs, count, vec.get(), gso, controlPtr);
} }
return ret; return ret;
...@@ -681,7 +732,7 @@ void AsyncUDPSocket::fillMsgVec( ...@@ -681,7 +732,7 @@ void AsyncUDPSocket::fillMsgVec(
struct iovec* iov, struct iovec* iov,
size_t iov_count, size_t iov_count,
const int* gso, const int* gso,
char* gsoControl) { char* control) {
auto addr_count = addrs.size(); auto addr_count = addrs.size();
DCHECK(addr_count); DCHECK(addr_count);
size_t remaining = iov_count; size_t remaining = iov_count;
...@@ -705,23 +756,57 @@ void AsyncUDPSocket::fillMsgVec( ...@@ -705,23 +756,57 @@ void AsyncUDPSocket::fillMsgVec(
msg.msg_iov = &iov[iov_pos]; msg.msg_iov = &iov[iov_pos];
msg.msg_iovlen = iovec_len; msg.msg_iovlen = iovec_len;
#ifdef FOLLY_HAVE_MSG_ERRQUEUE #ifdef FOLLY_HAVE_MSG_ERRQUEUE
size_t controlBufSize = 1 +
cmsgs_.size() *
(CMSG_SPACE(sizeof(int)) / CMSG_SPACE(sizeof(uint16_t)));
// get the offset in the control buf allocated for this msg
msg.msg_control =
&control[i * controlBufSize * CMSG_SPACE(sizeof(uint16_t))];
msg.msg_controllen = 0;
struct cmsghdr* cm = nullptr;
// handle socket options
for (auto itr = cmsgs_.begin(); itr != cmsgs_.end(); ++itr) {
const auto key = itr->first;
const auto val = itr->second;
msg.msg_controllen += CMSG_SPACE(sizeof(val));
if (cm) {
cm = CMSG_NXTHDR(&msg, cm);
} else {
cm = CMSG_FIRSTHDR(&msg);
}
if (cm) {
cm->cmsg_level = key.level;
cm->cmsg_type = key.optname;
cm->cmsg_len = CMSG_LEN(sizeof(val));
memcpy(CMSG_DATA(cm), &val, sizeof(val));
}
}
// handle GSO
if (gso && gso[i] > 0) { if (gso && gso[i] > 0) {
msg.msg_control = &gsoControl[i * CMSG_SPACE(sizeof(uint16_t))]; msg.msg_controllen += CMSG_SPACE(sizeof(uint16_t));
msg.msg_controllen = CMSG_SPACE(sizeof(uint16_t)); if (cm) {
cm = CMSG_NXTHDR(&msg, cm);
struct cmsghdr* cm = CMSG_FIRSTHDR(&msg); } else {
cm->cmsg_level = SOL_UDP; cm = CMSG_FIRSTHDR(&msg);
cm->cmsg_type = UDP_SEGMENT; }
cm->cmsg_len = CMSG_LEN(sizeof(uint16_t)); if (cm) {
auto gso_len = static_cast<uint16_t>(gso[i]); cm->cmsg_level = SOL_UDP;
memcpy(CMSG_DATA(cm), &gso_len, sizeof(gso_len)); cm->cmsg_type = UDP_SEGMENT;
} else { cm->cmsg_len = CMSG_LEN(sizeof(uint16_t));
auto gso_len = static_cast<uint16_t>(gso[i]);
memcpy(CMSG_DATA(cm), &gso_len, sizeof(gso_len));
}
}
// there may be control buffer allocated, but nothing to put into it
// in this case, we null out the control fields
if (!cm) {
// no GSO, no socket options, null out control fields
msg.msg_control = nullptr; msg.msg_control = nullptr;
msg.msg_controllen = 0; msg.msg_controllen = 0;
} }
#else #else
(void)gso; (void)gso;
(void)gsoControl; (void)control;
msg.msg_control = nullptr; msg.msg_control = nullptr;
msg.msg_controllen = 0; msg.msg_controllen = 0;
#endif #endif
...@@ -739,7 +824,7 @@ int AsyncUDPSocket::writeImpl( ...@@ -739,7 +824,7 @@ int AsyncUDPSocket::writeImpl(
size_t count, size_t count,
struct mmsghdr* msgvec, struct mmsghdr* msgvec,
const int* gso, const int* gso,
char* gsoControl) { char* control) {
// most times we have a single destination addr // most times we have a single destination addr
auto addr_count = addrs.size(); auto addr_count = addrs.size();
constexpr size_t kAddrCountMax = 1; constexpr size_t kAddrCountMax = 1;
...@@ -764,14 +849,7 @@ int AsyncUDPSocket::writeImpl( ...@@ -764,14 +849,7 @@ int AsyncUDPSocket::writeImpl(
iovec iov[BOOST_PP_IF(FOLLY_HAVE_VLA_01, iov_count, kSmallSizeMax)]; iovec iov[BOOST_PP_IF(FOLLY_HAVE_VLA_01, iov_count, kSmallSizeMax)];
FOLLY_POP_WARNING FOLLY_POP_WARNING
fillMsgVec( fillMsgVec(
range(addrStorage), range(addrStorage), bufs, count, msgvec, iov, iov_count, gso, control);
bufs,
count,
msgvec,
iov,
iov_count,
gso,
gsoControl);
ret = sendmmsg(fd_, msgvec, count, 0); ret = sendmmsg(fd_, msgvec, count, 0);
} else { } else {
std::unique_ptr<iovec[]> iov(new iovec[iov_count]); std::unique_ptr<iovec[]> iov(new iovec[iov_count]);
...@@ -783,7 +861,7 @@ int AsyncUDPSocket::writeImpl( ...@@ -783,7 +861,7 @@ int AsyncUDPSocket::writeImpl(
iov.get(), iov.get(),
iov_count, iov_count,
gso, gso,
gsoControl); control);
ret = sendmmsg(fd_, msgvec, count, 0); ret = sendmmsg(fd_, msgvec, count, 0);
} }
...@@ -1288,4 +1366,14 @@ void AsyncUDPSocket::attachEventBase(folly::EventBase* evb) { ...@@ -1288,4 +1366,14 @@ void AsyncUDPSocket::attachEventBase(folly::EventBase* evb) {
updateRegistration(); updateRegistration();
} }
void AsyncUDPSocket::setCmsgs(const SocketOptionMap& cmsgs) {
cmsgs_ = cmsgs;
}
void AsyncUDPSocket::appendCmsgs(const SocketOptionMap& cmsgs) {
for (auto itr = cmsgs.begin(); itr != cmsgs.end(); ++itr) {
cmsgs_[itr->first] = itr->second;
}
}
} // namespace folly } // namespace folly
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/io/async/EventHandler.h> #include <folly/io/async/EventHandler.h>
#include <folly/net/NetOps.h> #include <folly/net/NetOps.h>
#include <folly/net/NetOpsDispatcher.h>
#include <folly/net/NetworkSocket.h> #include <folly/net/NetworkSocket.h>
namespace folly { namespace folly {
...@@ -230,6 +231,13 @@ class AsyncUDPSocket : public EventHandler { ...@@ -230,6 +231,13 @@ class AsyncUDPSocket : public EventHandler {
zeroCopyReenableThreshold_ = threshold; zeroCopyReenableThreshold_ = threshold;
} }
/**
* Set extra control messages to send
*/
virtual void setCmsgs(const SocketOptionMap& cmsgs);
virtual void appendCmsgs(const SocketOptionMap& cmsgs);
/** /**
* Send the data in buffer to destination. Returns the return code from * Send the data in buffer to destination. Returns the return code from
* ::sendmsg. * ::sendmsg.
...@@ -447,6 +455,28 @@ class AsyncUDPSocket : public EventHandler { ...@@ -447,6 +455,28 @@ class AsyncUDPSocket : public EventHandler {
void applyOptions( void applyOptions(
const SocketOptionMap& options, SocketOptionKey::ApplyPos pos); const SocketOptionMap& options, SocketOptionKey::ApplyPos pos);
/**
* Override netops::Dispatcher to be used for netops:: calls.
*
* Pass empty shared_ptr to reset to default.
* Override can be used by unit tests to intercept and mock netops:: calls.
*/
virtual void setOverrideNetOpsDispatcher(
std::shared_ptr<netops::Dispatcher> dispatcher) {
netops_.setOverride(std::move(dispatcher));
}
/**
* Returns override netops::Dispatcher being used for netops:: calls.
*
* Returns empty shared_ptr if no override set.
* Override can be used by unit tests to intercept and mock netops:: calls.
*/
virtual std::shared_ptr<netops::Dispatcher> getOverrideNetOpsDispatcher()
const {
return netops_.getOverride();
}
protected: protected:
struct full_sockaddr_storage { struct full_sockaddr_storage {
sockaddr_storage storage; sockaddr_storage storage;
...@@ -455,7 +485,7 @@ class AsyncUDPSocket : public EventHandler { ...@@ -455,7 +485,7 @@ class AsyncUDPSocket : public EventHandler {
virtual ssize_t sendmsg( virtual ssize_t sendmsg(
NetworkSocket socket, const struct msghdr* message, int flags) { NetworkSocket socket, const struct msghdr* message, int flags) {
return netops::sendmsg(socket, message, flags); return netops_->sendmsg(socket, message, flags);
} }
virtual int sendmmsg( virtual int sendmmsg(
...@@ -463,7 +493,7 @@ class AsyncUDPSocket : public EventHandler { ...@@ -463,7 +493,7 @@ class AsyncUDPSocket : public EventHandler {
struct mmsghdr* msgvec, struct mmsghdr* msgvec,
unsigned int vlen, unsigned int vlen,
int flags) { int flags) {
return netops::sendmmsg(socket, msgvec, vlen, flags); return netops_->sendmmsg(socket, msgvec, vlen, flags);
} }
void fillMsgVec( void fillMsgVec(
...@@ -474,7 +504,7 @@ class AsyncUDPSocket : public EventHandler { ...@@ -474,7 +504,7 @@ class AsyncUDPSocket : public EventHandler {
struct iovec* iov, struct iovec* iov,
size_t iov_count, size_t iov_count,
const int* gso, const int* gso,
char* gsoControl); char* control);
virtual int writeImpl( virtual int writeImpl(
Range<SocketAddress const*> addrs, Range<SocketAddress const*> addrs,
...@@ -482,7 +512,9 @@ class AsyncUDPSocket : public EventHandler { ...@@ -482,7 +512,9 @@ class AsyncUDPSocket : public EventHandler {
size_t count, size_t count,
struct mmsghdr* msgvec, struct mmsghdr* msgvec,
const int* gso, const int* gso,
char* gsoControl); char* control);
virtual ssize_t writevImpl(struct msghdr* msg, FOLLY_MAYBE_UNUSED int gso);
size_t handleErrMessages() noexcept; size_t handleErrMessages() noexcept;
...@@ -559,6 +591,10 @@ class AsyncUDPSocket : public EventHandler { ...@@ -559,6 +591,10 @@ class AsyncUDPSocket : public EventHandler {
std::unordered_map<uint32_t, std::unique_ptr<folly::IOBuf>> idZeroCopyBufMap_; std::unordered_map<uint32_t, std::unique_ptr<folly::IOBuf>> idZeroCopyBufMap_;
IOBufFreeFunc ioBufFreeFunc_; IOBufFreeFunc ioBufFreeFunc_;
SocketOptionMap cmsgs_;
netops::DispatcherContainer netops_;
}; };
} // namespace folly } // namespace folly
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncUDPServerSocket.h> #include <folly/io/async/AsyncUDPServerSocket.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/net/test/MockNetOpsDispatcher.h>
#include <folly/portability/GMock.h> #include <folly/portability/GMock.h>
#include <folly/portability/GTest.h> #include <folly/portability/GTest.h>
#include <folly/portability/Sockets.h> #include <folly/portability/Sockets.h>
...@@ -883,3 +884,227 @@ TEST_F(AsyncUDPSocketTest, TestDetachAttach) { ...@@ -883,3 +884,227 @@ TEST_F(AsyncUDPSocketTest, TestDetachAttach) {
t.join(); t.join();
EXPECT_EQ(packetsRecvd, 2); EXPECT_EQ(packetsRecvd, 2);
} }
MATCHER_P(HasCmsgs, cmsgs, "") {
struct msghdr* msg = const_cast<struct msghdr*>(arg);
if (msg == nullptr) {
return false;
}
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
folly::SocketOptionMap sentCmsgs;
struct cmsghdr* cmsg;
for (cmsg = CMSG_FIRSTHDR(msg); cmsg != nullptr;
cmsg = CMSG_NXTHDR(msg, cmsg)) {
if (cmsg->cmsg_level == SOL_UDP) {
if (cmsg->cmsg_type == UDP_SEGMENT) {
uint16_t gso;
memcpy(
&gso,
reinterpret_cast<struct timespec*>(CMSG_DATA(cmsg)),
sizeof(gso));
sentCmsgs[{SOL_UDP, UDP_SEGMENT}] = gso;
}
}
if (cmsg->cmsg_level == SOL_SOCKET) {
if (cmsg->cmsg_type == SO_MARK) {
uint32_t somark;
memcpy(
&somark,
reinterpret_cast<struct timespec*>(CMSG_DATA(cmsg)),
sizeof(somark));
sentCmsgs[{SOL_SOCKET, SO_MARK}] = somark;
}
}
if (cmsg->cmsg_level == IPPROTO_IP) {
if (cmsg->cmsg_type == IP_TOS) {
uint32_t tos;
memcpy(
&tos,
reinterpret_cast<struct timespec*>(CMSG_DATA(cmsg)),
sizeof(tos));
sentCmsgs[{IPPROTO_IP, IP_TOS}] = tos;
}
if (cmsg->cmsg_type == IP_TTL) {
uint32_t ttl;
memcpy(
&ttl,
reinterpret_cast<struct timespec*>(CMSG_DATA(cmsg)),
sizeof(ttl));
sentCmsgs[{IPPROTO_IP, IP_TTL}] = ttl;
}
}
}
return sentCmsgs == cmsgs;
#else
return false;
#endif
}
TEST_F(AsyncUDPSocketTest, TestWriteCmsg) {
folly::SocketAddress addr("127.0.0.1", 10000);
auto netOpsDispatcher =
std::make_shared<NiceMock<folly::netops::test::MockDispatcher>>();
socket_->setOverrideNetOpsDispatcher(netOpsDispatcher);
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
// empty
{
folly::SocketOptionMap cmsgs;
EXPECT_CALL(*netOpsDispatcher, sendmsg(_, HasCmsgs(cmsgs), _));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
}
// writeGSO
{
folly::SocketOptionMap cmsgs;
cmsgs[{SOL_UDP, UDP_SEGMENT}] = 1;
EXPECT_CALL(*netOpsDispatcher, sendmsg(_, HasCmsgs(cmsgs), _));
socket_->writeGSO(addr, folly::IOBuf::copyBuffer("hey"), 1);
}
// SO_MARK
{
folly::SocketOptionMap cmsgs;
cmsgs[{SOL_SOCKET, SO_MARK}] = 123;
socket_->setCmsgs(cmsgs);
EXPECT_CALL(*netOpsDispatcher, sendmsg(_, HasCmsgs(cmsgs), _));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
}
// append IP_TOS
{
folly::SocketOptionMap cmsgs;
cmsgs[{IPPROTO_IP, IP_TOS}] = 456;
socket_->appendCmsgs(cmsgs);
folly::SocketOptionMap expectedCmsgs;
expectedCmsgs[{IPPROTO_IP, IP_TOS}] = 456;
expectedCmsgs[{SOL_SOCKET, SO_MARK}] = 123;
EXPECT_CALL(*netOpsDispatcher, sendmsg(_, HasCmsgs(expectedCmsgs), _));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
}
// append IP_TOS with a different value
{
folly::SocketOptionMap cmsgs;
cmsgs[{IPPROTO_IP, IP_TOS}] = 789;
socket_->appendCmsgs(cmsgs);
folly::SocketOptionMap expectedCmsgs;
expectedCmsgs[{IPPROTO_IP, IP_TOS}] = 789;
expectedCmsgs[{SOL_SOCKET, SO_MARK}] = 123;
socket_->setCmsgs(expectedCmsgs);
EXPECT_CALL(*netOpsDispatcher, sendmsg(_, HasCmsgs(expectedCmsgs), _));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
}
// writeGSO with IP_TOS and SO_MARK
{
folly::SocketOptionMap expectedCmsgs;
expectedCmsgs[{IPPROTO_IP, IP_TOS}] = 789;
expectedCmsgs[{SOL_SOCKET, SO_MARK}] = 123;
expectedCmsgs[{SOL_UDP, UDP_SEGMENT}] = 1;
EXPECT_CALL(*netOpsDispatcher, sendmsg(_, HasCmsgs(expectedCmsgs), _));
socket_->writeGSO(addr, folly::IOBuf::copyBuffer("hey"), 1);
}
#endif // FOLLY_HAVE_MSG_ERRQUEUE
socket_->close();
}
MATCHER_P2(AllHaveCmsgs, cmsgs, count, "") {
if (arg == nullptr) {
return false;
}
for (size_t i = 0; i < count; ++i) {
auto msg = arg[i].msg_hdr;
if (!Matches(HasCmsgs(cmsgs))(&msg)) {
return false;
}
}
return true;
}
TEST(MatcherTest, AllHaveCmsgsTest) {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
size_t count = 2;
size_t controlSize = 2;
mmsghdr msgvec[count];
char control[count * controlSize * CMSG_SPACE(sizeof(uint16_t))];
memset(control, 0, sizeof(control));
// two messages, one with all cmsgs, the other with only one
{
auto& msg = msgvec[0].msg_hdr;
msg.msg_control = &control[0];
msg.msg_controllen = controlSize * CMSG_SPACE(sizeof(int));
struct cmsghdr* cm = nullptr;
cm = CMSG_FIRSTHDR(&msg);
int val1 = 20;
cm->cmsg_level = SOL_SOCKET;
cm->cmsg_type = SO_MARK;
cm->cmsg_len = CMSG_LEN(sizeof(val1));
memcpy(CMSG_DATA(cm), &val1, sizeof(val1));
cm = CMSG_NXTHDR(&msg, cm);
int val2 = 30;
cm->cmsg_level = IPPROTO_IP;
cm->cmsg_type = IP_TOS;
cm->cmsg_len = CMSG_LEN(sizeof(val2));
memcpy(CMSG_DATA(cm), &val2, sizeof(val2));
}
{
auto& msg = msgvec[1].msg_hdr;
msg.msg_control = &control[controlSize * CMSG_SPACE(sizeof(uint16_t))];
msg.msg_controllen = CMSG_SPACE(sizeof(int));
struct cmsghdr* cm = nullptr;
cm = CMSG_FIRSTHDR(&msg);
int val = 20;
cm->cmsg_level = SOL_SOCKET;
cm->cmsg_type = SO_MARK;
cm->cmsg_len = CMSG_LEN(sizeof(val));
memcpy(CMSG_DATA(cm), &val, sizeof(val));
}
folly::SocketOptionMap cmsgs;
cmsgs[{SOL_SOCKET, SO_MARK}] = 20;
cmsgs[{IPPROTO_IP, IP_TOS}] = 30;
struct mmsghdr* msgvecPtr = nullptr;
EXPECT_THAT(msgvecPtr, Not(AllHaveCmsgs(cmsgs, count)));
msgvecPtr = msgvec;
EXPECT_THAT(msgvecPtr, Not(AllHaveCmsgs(cmsgs, count)));
// true cases are tested in TestWritemCmsg
#endif // FOLLY_HAVE_MSG_ERRQUEUE
}
TEST_F(AsyncUDPSocketTest, TestWritemCmsg) {
folly::SocketAddress addr("127.0.0.1", 10000);
auto netOpsDispatcher =
std::make_shared<NiceMock<folly::netops::test::MockDispatcher>>();
socket_->setOverrideNetOpsDispatcher(netOpsDispatcher);
std::vector<std::unique_ptr<folly::IOBuf>> bufs;
bufs.emplace_back(folly::IOBuf::copyBuffer("hey1"));
bufs.emplace_back(folly::IOBuf::copyBuffer("hey2"));
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
// empty
{
folly::SocketOptionMap cmsgs;
EXPECT_CALL(
*netOpsDispatcher, sendmmsg(_, AllHaveCmsgs(cmsgs, bufs.size()), _, _));
socket_->writem(folly::range(&addr, &addr + 1), bufs.data(), bufs.size());
}
// set IP_TOS & SO_MARK
{
folly::SocketOptionMap cmsgs;
cmsgs[{IPPROTO_IP, IP_TOS}] = 456;
cmsgs[{SOL_SOCKET, SO_MARK}] = 123;
socket_->setCmsgs(cmsgs);
EXPECT_CALL(
*netOpsDispatcher, sendmmsg(_, AllHaveCmsgs(cmsgs, bufs.size()), _, _));
socket_->writem(folly::range(&addr, &addr + 1), bufs.data(), bufs.size());
}
// writemGSO
{
folly::SocketOptionMap expectedCmsgs;
expectedCmsgs[{IPPROTO_IP, IP_TOS}] = 456;
expectedCmsgs[{SOL_SOCKET, SO_MARK}] = 123;
expectedCmsgs[{SOL_UDP, UDP_SEGMENT}] = 1;
EXPECT_CALL(
*netOpsDispatcher,
sendmmsg(_, AllHaveCmsgs(expectedCmsgs, bufs.size()), _, _));
std::vector<int> gso{1, 1};
socket_->writemGSO(
folly::range(&addr, &addr + 1), bufs.data(), bufs.size(), gso.data());
}
#endif // FOLLY_HAVE_MSG_ERRQUEUE
socket_->close();
}
...@@ -58,6 +58,8 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket { ...@@ -58,6 +58,8 @@ struct MockAsyncUDPSocket : public AsyncUDPSocket {
MOCK_METHOD4( MOCK_METHOD4(
recvmmsg, recvmmsg,
int(struct mmsghdr*, unsigned int, unsigned int, struct timespec*)); int(struct mmsghdr*, unsigned int, unsigned int, struct timespec*));
MOCK_METHOD1(setCmsgs, void(const SocketOptionMap&));
MOCK_METHOD1(appendCmsgs, void(const SocketOptionMap&));
}; };
} // namespace test } // namespace test
......
...@@ -74,6 +74,10 @@ class MockDispatcher : public Dispatcher { ...@@ -74,6 +74,10 @@ class MockDispatcher : public Dispatcher {
MOCK_METHOD3( MOCK_METHOD3(
sendmsg, ssize_t(NetworkSocket s, const msghdr* message, int flags)); sendmsg, ssize_t(NetworkSocket s, const msghdr* message, int flags));
MOCK_METHOD4(
sendmmsg,
int(NetworkSocket s, mmsghdr* msgvec, unsigned int vlen, int flags));
MOCK_METHOD5( MOCK_METHOD5(
setsockopt, setsockopt,
int(NetworkSocket s, int(NetworkSocket s,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment