Commit 2584b1f2 authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot

Add errmsg callback to udp

Summary:
Adds the ability for UDP sockets to optionally
listen to errors that happen on the socket,
for example ICMP errors.

In this iteration, we explicitly do not register for
read events if an error callback is set. So even
if you set an error callback to get error events
you must register a read callback.

Reviewed By: yfeldblum

Differential Revision: D7632877

fbshipit-source-id: 43c922d0145e1da97b993f5bf4058c20d3469deb
parent 6ee198b0
...@@ -196,6 +196,27 @@ void AsyncUDPSocket::dontFragment(bool df) { ...@@ -196,6 +196,27 @@ void AsyncUDPSocket::dontFragment(bool df) {
#endif #endif
} }
void AsyncUDPSocket::setErrMessageCallback(
ErrMessageCallback* errMessageCallback) {
errMessageCallback_ = errMessageCallback;
int err = 1;
#if defined(IP_RECVERR)
if (address().getFamily() == AF_INET &&
fsp::setsockopt(fd_, IPPROTO_IP, IP_RECVERR, &err, sizeof(err))) {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN, "Failed to set IP_RECVERR", errno);
}
#endif
#if defined(IPV6_RECVERR)
if (address().getFamily() == AF_INET6 &&
fsp::setsockopt(fd_, IPPROTO_IPV6, IPV6_RECVERR, &err, sizeof(err))) {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN, "Failed to set IPV6_RECVERR", errno);
}
#endif
(void)err;
}
void AsyncUDPSocket::setFD(int fd, FDOwnership ownership) { void AsyncUDPSocket::setFD(int fd, FDOwnership ownership) {
CHECK_EQ(-1, fd_) << "Already bound to another FD"; CHECK_EQ(-1, fd_) << "Already bound to another FD";
...@@ -294,10 +315,84 @@ void AsyncUDPSocket::handlerReady(uint16_t events) noexcept { ...@@ -294,10 +315,84 @@ void AsyncUDPSocket::handlerReady(uint16_t events) noexcept {
} }
} }
size_t AsyncUDPSocket::handleErrMessages() noexcept {
#ifdef FOLLY_HAVE_MSG_ERRQUEUE
if (errMessageCallback_ == nullptr) {
return 0;
}
uint8_t ctrl[1024];
unsigned char data;
struct msghdr msg;
iovec entry;
entry.iov_base = &data;
entry.iov_len = sizeof(data);
msg.msg_iov = &entry;
msg.msg_iovlen = 1;
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_control = ctrl;
msg.msg_controllen = sizeof(ctrl);
msg.msg_flags = 0;
int ret;
size_t num = 0;
while (fd_ != -1) {
ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
if (ret < 0) {
if (errno != EAGAIN) {
auto errnoCopy = errno;
LOG(ERROR) << "::recvmsg exited with code " << ret
<< ", errno: " << errnoCopy;
AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR,
"recvmsg() failed",
errnoCopy);
failErrMessageRead(ex);
}
return num;
}
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg != nullptr && cmsg->cmsg_len != 0;
cmsg = CMSG_NXTHDR(&msg, cmsg)) {
++num;
errMessageCallback_->errMessage(*cmsg);
if (fd_ == -1) {
// once the socket is closed there is no use for more read errors.
return num;
}
}
}
return num;
#else
return 0;
#endif
}
void AsyncUDPSocket::failErrMessageRead(const AsyncSocketException& ex) {
if (errMessageCallback_ != nullptr) {
ErrMessageCallback* callback = errMessageCallback_;
errMessageCallback_ = nullptr;
callback->errMessageError(ex);
}
}
void AsyncUDPSocket::handleRead() noexcept { void AsyncUDPSocket::handleRead() noexcept {
void* buf{nullptr}; void* buf{nullptr};
size_t len{0}; size_t len{0};
if (handleErrMessages()) {
return;
}
if (fd_ == -1) {
// The socket may have been closed by the error callbacks.
return;
}
readCallback_->getReadBuffer(&buf, &len); readCallback_->getReadBuffer(&buf, &len);
if (buf == nullptr || len == 0) { if (buf == nullptr || len == 0) {
AsyncSocketException ex( AsyncSocketException ex(
......
...@@ -74,6 +74,29 @@ class AsyncUDPSocket : public EventHandler { ...@@ -74,6 +74,29 @@ class AsyncUDPSocket : public EventHandler {
virtual ~ReadCallback() = default; virtual ~ReadCallback() = default;
}; };
class ErrMessageCallback {
public:
virtual ~ErrMessageCallback() = default;
/**
* errMessage() will be invoked when kernel puts a message to
* the error queue associated with the socket.
*
* @param cmsg Reference to cmsghdr structure describing
* a message read from error queue associated
* with the socket.
*/
virtual void errMessage(const cmsghdr& cmsg) noexcept = 0;
/**
* errMessageError() will be invoked if an error occurs reading a message
* from the socket error stream.
*
* @param ex An exception describing the error that occurred.
*/
virtual void errMessageError(const AsyncSocketException& ex) noexcept = 0;
};
/** /**
* Create a new UDP socket that will run in the * Create a new UDP socket that will run in the
* given eventbase * given eventbase
...@@ -196,11 +219,20 @@ class AsyncUDPSocket : public EventHandler { ...@@ -196,11 +219,20 @@ class AsyncUDPSocket : public EventHandler {
*/ */
virtual void dontFragment(bool df); virtual void dontFragment(bool df);
/**
* Callback for receiving errors on the UDP sockets
*/
void setErrMessageCallback(ErrMessageCallback* errMessageCallback);
protected: protected:
virtual ssize_t sendmsg(int socket, const struct msghdr* message, int flags) { virtual ssize_t sendmsg(int socket, const struct msghdr* message, int flags) {
return ::sendmsg(socket, message, flags); return ::sendmsg(socket, message, flags);
} }
size_t handleErrMessages() noexcept;
void failErrMessageRead(const AsyncSocketException& ex);
// Non-null only when we are reading // Non-null only when we are reading
ReadCallback* readCallback_; ReadCallback* readCallback_;
...@@ -228,6 +260,8 @@ class AsyncUDPSocket : public EventHandler { ...@@ -228,6 +260,8 @@ class AsyncUDPSocket : public EventHandler {
int rcvBuf_{0}; int rcvBuf_{0};
int sndBuf_{0}; int sndBuf_{0};
int busyPollUs_{0}; int busyPollUs_{0};
ErrMessageCallback* errMessageCallback_{nullptr};
}; };
} // namespace folly } // namespace folly
...@@ -284,3 +284,115 @@ class TestAsyncUDPSocket : public AsyncUDPSocket { ...@@ -284,3 +284,115 @@ class TestAsyncUDPSocket : public AsyncUDPSocket {
MOCK_METHOD3(sendmsg, ssize_t(int, const struct msghdr*, int)); MOCK_METHOD3(sendmsg, ssize_t(int, const struct msghdr*, int));
}; };
class MockErrMessageCallback : public AsyncUDPSocket::ErrMessageCallback {
public:
~MockErrMessageCallback() override = default;
GMOCK_METHOD1_(, noexcept, , errMessage, void(const cmsghdr& cmsg));
GMOCK_METHOD1_(
,
noexcept,
,
errMessageError,
void(const folly::AsyncSocketException& ex));
};
class MockUDPReadCallback : public AsyncUDPSocket::ReadCallback {
public:
~MockUDPReadCallback() override = default;
GMOCK_METHOD2_(, noexcept, , getReadBuffer, void(void** buf, size_t* len));
GMOCK_METHOD3_(
,
noexcept,
,
onDataAvailable,
void(const folly::SocketAddress& client, size_t len, bool truncated));
GMOCK_METHOD1_(
,
noexcept,
,
onReadError,
void(const folly::AsyncSocketException& ex));
GMOCK_METHOD0_(, noexcept, , onReadClosed, void());
};
class AsyncUDPSocketTest : public Test {
public:
void SetUp() override {
socket_ = std::make_shared<AsyncUDPSocket>(&evb_);
addr_ = folly::SocketAddress("127.0.0.1", 0);
socket_->bind(addr_);
}
EventBase evb_;
MockErrMessageCallback err;
MockUDPReadCallback readCb;
std::shared_ptr<AsyncUDPSocket> socket_;
folly::SocketAddress addr_;
};
TEST_F(AsyncUDPSocketTest, TestErrToNonExistentServer) {
socket_->resumeRead(&readCb);
socket_->setErrMessageCallback(&err);
folly::SocketAddress addr("127.0.0.1", 10000);
bool errRecvd = false;
EXPECT_CALL(err, errMessage(_))
.WillOnce(Invoke([this, &errRecvd](auto& cmsg) {
if ((cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
(cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
const struct sock_extended_err* serr =
reinterpret_cast<const struct sock_extended_err*>(
CMSG_DATA(&cmsg));
errRecvd =
(serr->ee_origin == SO_EE_ORIGIN_ICMP || SO_EE_ORIGIN_ICMP6);
LOG(ERROR) << "errno " << strerror(serr->ee_errno);
}
evb_.terminateLoopSoon();
}));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
evb_.loopForever();
EXPECT_TRUE(errRecvd);
}
TEST_F(AsyncUDPSocketTest, CloseInErrorCallback) {
socket_->resumeRead(&readCb);
socket_->setErrMessageCallback(&err);
folly::SocketAddress addr("127.0.0.1", 10000);
bool errRecvd = false;
EXPECT_CALL(err, errMessage(_)).WillOnce(Invoke([this, &errRecvd](auto&) {
errRecvd = true;
socket_->close();
evb_.terminateLoopSoon();
}));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
evb_.loopForever();
EXPECT_TRUE(errRecvd);
}
TEST_F(AsyncUDPSocketTest, TestNonExistentServerNoErrCb) {
socket_->resumeRead(&readCb);
folly::SocketAddress addr("127.0.0.1", 10000);
bool errRecvd = false;
folly::IOBufQueue readBuf;
EXPECT_CALL(readCb, getReadBuffer(_, _))
.WillRepeatedly(Invoke([&readBuf](void** buf, size_t* len) {
auto readSpace = readBuf.preallocate(2000, 10000);
*buf = readSpace.first;
*len = readSpace.second;
}));
ON_CALL(readCb, onReadError(_)).WillByDefault(Invoke([&errRecvd](auto& ex) {
LOG(ERROR) << ex.what();
errRecvd = true;
}));
socket_->write(addr, folly::IOBuf::copyBuffer("hey"));
evb_.timer().scheduleTimeoutFn(
[&] { evb_.terminateLoopSoon(); }, std::chrono::milliseconds(30));
evb_.loopForever();
EXPECT_FALSE(errRecvd);
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment