Commit cbed8954 authored by Sarang Masti's avatar Sarang Masti Committed by Sara Golemon

Check readCallback before calling handleRead

Summary: Since readCallback_ could be uninstalled in any of callbacks,
we need to ensure that readCallback_ != nullptr before calling
handleRead.

Reviewed By: @djwatson

Differential Revision: D2140054
parent bf87ffac
......@@ -527,6 +527,12 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
return;
}
/* We are removing a read callback */
if (callback == nullptr &&
immediateReadHandler_.isLoopCallbackScheduled()) {
immediateReadHandler_.cancelLoopCallback();
}
if (shutdownFlags_ & SHUT_READ) {
// Reads have already been shut down on this socket.
//
......@@ -1330,9 +1336,11 @@ void AsyncSocket::handleRead() noexcept {
return;
}
if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
// We might still have data in the socket.
// (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
scheduleImmediateRead();
if (readCallback_ != nullptr) {
// We might still have data in the socket.
// (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
scheduleImmediateRead();
}
return;
}
}
......
......@@ -93,10 +93,11 @@ class WriteCallback : public AsyncTransportWrapper::WriteCallback {
class ReadCallback : public AsyncTransportWrapper::ReadCallback {
public:
ReadCallback()
explicit ReadCallback(size_t _maxBufferSz = 4096)
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none")
, buffers() {}
, buffers()
, maxBufferSz(_maxBufferSz) {}
~ReadCallback() {
for (std::vector<Buffer>::iterator it = buffers.begin();
......@@ -109,7 +110,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(4096);
currentBuffer.allocate(maxBufferSz);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
......@@ -145,6 +146,14 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
CHECK_EQ(offset, expectedLen);
}
size_t dataRead() const {
size_t ret = 0;
for (const auto& buf : buffers) {
ret += buf.length;
}
return ret;
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
......@@ -173,6 +182,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
std::vector<Buffer> buffers;
Buffer currentBuffer;
VoidCallback dataAvailableCallback;
const size_t maxBufferSz;
};
class ReadVerifier {
......
......@@ -1253,6 +1253,115 @@ TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
}
}
///////////////////////////////////////////////////////////////////////////
// ImmediateRead related tests
///////////////////////////////////////////////////////////////////////////
/* AsyncSocket use to verify immediate read works */
class AsyncSocketImmediateRead : public folly::AsyncSocket {
public:
bool immediateReadCalled = false;
explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
protected:
virtual void checkForImmediateRead() noexcept override {
immediateReadCalled = true;
AsyncSocket::handleRead();
}
};
TEST(AsyncSocket, ConnectReadImmediateRead) {
TestServer server;
const size_t maxBufferSz = 100;
const size_t maxReadsPerEvent = 1;
const size_t expectedDataSz = maxBufferSz * 3;
char expectedData[expectedDataSz];
memset(expectedData, 'j', expectedDataSz);
EventBase evb;
ReadCallback rcb(maxBufferSz);
AsyncSocketImmediateRead socket(&evb);
socket.connect(nullptr, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
socket.setReadCB(&rcb);
socket.setMaxReadsPerEvent(maxReadsPerEvent);
socket.immediateReadCalled = false;
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcbServer;
WriteCallback wcbServer;
rcbServer.dataAvailableCallback = [&]() {
if (rcbServer.dataRead() == expectedDataSz) {
// write back all data read
rcbServer.verifyData(expectedData, expectedDataSz);
acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
acceptedSocket->close();
}
};
acceptedSocket->setReadCB(&rcbServer);
// write data
WriteCallback wcb1;
socket.write(&wcb1, expectedData, expectedDataSz);
evb.loop();
CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
rcb.verifyData(expectedData, expectedDataSz);
CHECK_EQ(socket.immediateReadCalled, true);
}
TEST(AsyncSocket, ConnectReadUninstallRead) {
TestServer server;
const size_t maxBufferSz = 100;
const size_t maxReadsPerEvent = 1;
const size_t expectedDataSz = maxBufferSz * 3;
char expectedData[expectedDataSz];
memset(expectedData, 'k', expectedDataSz);
EventBase evb;
ReadCallback rcb(maxBufferSz);
AsyncSocketImmediateRead socket(&evb);
socket.connect(nullptr, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
socket.setReadCB(&rcb);
socket.setMaxReadsPerEvent(maxReadsPerEvent);
socket.immediateReadCalled = false;
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcbServer;
WriteCallback wcbServer;
rcbServer.dataAvailableCallback = [&]() {
if (rcbServer.dataRead() == expectedDataSz) {
// write back all data read
rcbServer.verifyData(expectedData, expectedDataSz);
acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
acceptedSocket->close();
}
};
acceptedSocket->setReadCB(&rcbServer);
rcb.dataAvailableCallback = [&]() {
// we read data and reset readCB
socket.setReadCB(nullptr);
};
// write data
WriteCallback wcb;
socket.write(&wcb, expectedData, expectedDataSz);
evb.loop();
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
/* we shoud've only read maxBufferSz data since readCallback_
* was reset in dataAvailableCallback */
CHECK_EQ(rcb.dataRead(), maxBufferSz);
CHECK_EQ(socket.immediateReadCalled, false);
}
// TODO:
// - Test connect() and have the connect callback set the read callback
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment