Commit 98ceba40 authored by Maxim Georgiev's avatar Maxim Georgiev Committed by Facebook Github Bot

In AsyncSocket::handleErrMessages() inside the loop check if the callback was not uninstalled

Summary: AsyncSocket::handleErrMessages() should check if the error message callback is still installing before calling it, since the callback could be uninstaled on the previous loop iteration.

Reviewed By: yfeldblum

Differential Revision: D5051001

fbshipit-source-id: fc01932c0d36bd8f72bf1905f12211fb83d28674
parent 7bdb20f8
...@@ -1552,7 +1552,9 @@ void AsyncSocket::handleErrMessages() noexcept { ...@@ -1552,7 +1552,9 @@ void AsyncSocket::handleErrMessages() noexcept {
} }
for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg != nullptr && cmsg->cmsg_len != 0; cmsg != nullptr &&
cmsg->cmsg_len != 0 &&
errMessageCallback_ != nullptr;
cmsg = CMSG_NXTHDR(&msg, cmsg)) { cmsg = CMSG_NXTHDR(&msg, cmsg)) {
errMessageCallback_->errMessage(*cmsg); errMessageCallback_->errMessage(*cmsg);
} }
......
...@@ -211,11 +211,13 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback { ...@@ -211,11 +211,13 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
void errMessage(const cmsghdr& cmsg) noexcept override { void errMessage(const cmsghdr& cmsg) noexcept override {
if (cmsg.cmsg_level == SOL_SOCKET && if (cmsg.cmsg_level == SOL_SOCKET &&
cmsg.cmsg_type == SCM_TIMESTAMPING) { cmsg.cmsg_type == SCM_TIMESTAMPING) {
gotTimestamp_ = true; gotTimestamp_++;
checkResetCallback();
} else if ( } else if (
(cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) || (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
(cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) { (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
gotByteSeq_ = true; gotByteSeq_++;
checkResetCallback();
} }
} }
...@@ -224,9 +226,18 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback { ...@@ -224,9 +226,18 @@ class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
exception_ = ex; exception_ = ex;
} }
void checkResetCallback() noexcept {
if (socket_ != nullptr && resetAfter_ != -1 &&
gotTimestamp_ + gotByteSeq_ == resetAfter_) {
socket_->setErrMessageCB(nullptr);
}
}
folly::AsyncSocket* socket_{nullptr};
folly::AsyncSocketException exception_; folly::AsyncSocketException exception_;
bool gotTimestamp_{false}; int gotTimestamp_{0};
bool gotByteSeq_{false}; int gotByteSeq_{0};
int resetAfter_{-1};
}; };
class TestSendMsgParamsCallback : class TestSendMsgParamsCallback :
......
...@@ -2867,6 +2867,7 @@ enum SOF_TIMESTAMPING { ...@@ -2867,6 +2867,7 @@ enum SOF_TIMESTAMPING {
SOF_TIMESTAMPING_OPT_CMSG = (1 << 10), SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11), SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
}; };
TEST(AsyncSocketTest, ErrMessageCallback) { TEST(AsyncSocketTest, ErrMessageCallback) {
TestServer server; TestServer server;
...@@ -2895,6 +2896,9 @@ TEST(AsyncSocketTest, ErrMessageCallback) { ...@@ -2895,6 +2896,9 @@ TEST(AsyncSocketTest, ErrMessageCallback) {
ASSERT_EQ(socket->getErrMessageCallback(), ASSERT_EQ(socket->getErrMessageCallback(),
static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB)); static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
errMsgCB.socket_ = socket.get();
errMsgCB.resetAfter_ = 3;
// Enable timestamp notifications // Enable timestamp notifications
ASSERT_GT(socket->getFd(), 0); ASSERT_GT(socket->getFd(), 0);
int flags = SOF_TIMESTAMPING_OPT_ID int flags = SOF_TIMESTAMPING_OPT_ID
...@@ -2908,7 +2912,9 @@ TEST(AsyncSocketTest, ErrMessageCallback) { ...@@ -2908,7 +2912,9 @@ TEST(AsyncSocketTest, ErrMessageCallback) {
// write() // write()
std::vector<uint8_t> wbuf(128, 'a'); std::vector<uint8_t> wbuf(128, 'a');
WriteCallback wcb; WriteCallback wcb;
socket->write(&wcb, wbuf.data(), wbuf.size()); // Send two packets to get two EOM notifications
socket->write(&wcb, wbuf.data(), wbuf.size() / 2);
socket->write(&wcb, wbuf.data() + wbuf.size() / 2, wbuf.size() / 2);
// Accept the connection. // Accept the connection.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept(); std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
...@@ -2933,8 +2939,10 @@ TEST(AsyncSocketTest, ErrMessageCallback) { ...@@ -2933,8 +2939,10 @@ TEST(AsyncSocketTest, ErrMessageCallback) {
// Check for the timestamp notifications. // Check for the timestamp notifications.
ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN); ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN);
ASSERT_TRUE(errMsgCB.gotByteSeq_); ASSERT_GT(errMsgCB.gotByteSeq_, 0);
ASSERT_TRUE(errMsgCB.gotTimestamp_); ASSERT_GT(errMsgCB.gotTimestamp_, 0);
ASSERT_EQ(
errMsgCB.gotByteSeq_ + errMsgCB.gotTimestamp_, errMsgCB.resetAfter_);
} }
#endif // MSG_ERRQUEUE #endif // MSG_ERRQUEUE
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment