Commit cbed8954 authored by Sarang Masti's avatar Sarang Masti Committed by Sara Golemon

Check readCallback before calling handleRead

Summary: Since readCallback_ could be uninstalled in any of callbacks,
we need to ensure that readCallback_ != nullptr before calling
handleRead.

Reviewed By: @djwatson

Differential Revision: D2140054
parent bf87ffac
...@@ -527,6 +527,12 @@ void AsyncSocket::setReadCB(ReadCallback *callback) { ...@@ -527,6 +527,12 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
return; return;
} }
/* We are removing a read callback */
if (callback == nullptr &&
immediateReadHandler_.isLoopCallbackScheduled()) {
immediateReadHandler_.cancelLoopCallback();
}
if (shutdownFlags_ & SHUT_READ) { if (shutdownFlags_ & SHUT_READ) {
// Reads have already been shut down on this socket. // Reads have already been shut down on this socket.
// //
...@@ -1330,9 +1336,11 @@ void AsyncSocket::handleRead() noexcept { ...@@ -1330,9 +1336,11 @@ void AsyncSocket::handleRead() noexcept {
return; return;
} }
if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) { if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
// We might still have data in the socket. if (readCallback_ != nullptr) {
// (e.g. see comment in AsyncSSLSocket::checkForImmediateRead) // We might still have data in the socket.
scheduleImmediateRead(); // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
scheduleImmediateRead();
}
return; return;
} }
} }
......
...@@ -93,10 +93,11 @@ class WriteCallback : public AsyncTransportWrapper::WriteCallback { ...@@ -93,10 +93,11 @@ class WriteCallback : public AsyncTransportWrapper::WriteCallback {
class ReadCallback : public AsyncTransportWrapper::ReadCallback { class ReadCallback : public AsyncTransportWrapper::ReadCallback {
public: public:
ReadCallback() explicit ReadCallback(size_t _maxBufferSz = 4096)
: state(STATE_WAITING) : state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none") , exception(AsyncSocketException::UNKNOWN, "none")
, buffers() {} , buffers()
, maxBufferSz(_maxBufferSz) {}
~ReadCallback() { ~ReadCallback() {
for (std::vector<Buffer>::iterator it = buffers.begin(); for (std::vector<Buffer>::iterator it = buffers.begin();
...@@ -109,7 +110,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback { ...@@ -109,7 +110,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
void getReadBuffer(void** bufReturn, size_t* lenReturn) override { void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) { if (!currentBuffer.buffer) {
currentBuffer.allocate(4096); currentBuffer.allocate(maxBufferSz);
} }
*bufReturn = currentBuffer.buffer; *bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length; *lenReturn = currentBuffer.length;
...@@ -145,6 +146,14 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback { ...@@ -145,6 +146,14 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
CHECK_EQ(offset, expectedLen); CHECK_EQ(offset, expectedLen);
} }
size_t dataRead() const {
size_t ret = 0;
for (const auto& buf : buffers) {
ret += buf.length;
}
return ret;
}
class Buffer { class Buffer {
public: public:
Buffer() : buffer(nullptr), length(0) {} Buffer() : buffer(nullptr), length(0) {}
...@@ -173,6 +182,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback { ...@@ -173,6 +182,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
std::vector<Buffer> buffers; std::vector<Buffer> buffers;
Buffer currentBuffer; Buffer currentBuffer;
VoidCallback dataAvailableCallback; VoidCallback dataAvailableCallback;
const size_t maxBufferSz;
}; };
class ReadVerifier { class ReadVerifier {
......
...@@ -1253,6 +1253,115 @@ TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) { ...@@ -1253,6 +1253,115 @@ TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
} }
} }
///////////////////////////////////////////////////////////////////////////
// ImmediateRead related tests
///////////////////////////////////////////////////////////////////////////
/* AsyncSocket use to verify immediate read works */
class AsyncSocketImmediateRead : public folly::AsyncSocket {
public:
bool immediateReadCalled = false;
explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
protected:
virtual void checkForImmediateRead() noexcept override {
immediateReadCalled = true;
AsyncSocket::handleRead();
}
};
TEST(AsyncSocket, ConnectReadImmediateRead) {
TestServer server;
const size_t maxBufferSz = 100;
const size_t maxReadsPerEvent = 1;
const size_t expectedDataSz = maxBufferSz * 3;
char expectedData[expectedDataSz];
memset(expectedData, 'j', expectedDataSz);
EventBase evb;
ReadCallback rcb(maxBufferSz);
AsyncSocketImmediateRead socket(&evb);
socket.connect(nullptr, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
socket.setReadCB(&rcb);
socket.setMaxReadsPerEvent(maxReadsPerEvent);
socket.immediateReadCalled = false;
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcbServer;
WriteCallback wcbServer;
rcbServer.dataAvailableCallback = [&]() {
if (rcbServer.dataRead() == expectedDataSz) {
// write back all data read
rcbServer.verifyData(expectedData, expectedDataSz);
acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
acceptedSocket->close();
}
};
acceptedSocket->setReadCB(&rcbServer);
// write data
WriteCallback wcb1;
socket.write(&wcb1, expectedData, expectedDataSz);
evb.loop();
CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
rcb.verifyData(expectedData, expectedDataSz);
CHECK_EQ(socket.immediateReadCalled, true);
}
TEST(AsyncSocket, ConnectReadUninstallRead) {
TestServer server;
const size_t maxBufferSz = 100;
const size_t maxReadsPerEvent = 1;
const size_t expectedDataSz = maxBufferSz * 3;
char expectedData[expectedDataSz];
memset(expectedData, 'k', expectedDataSz);
EventBase evb;
ReadCallback rcb(maxBufferSz);
AsyncSocketImmediateRead socket(&evb);
socket.connect(nullptr, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
socket.setReadCB(&rcb);
socket.setMaxReadsPerEvent(maxReadsPerEvent);
socket.immediateReadCalled = false;
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcbServer;
WriteCallback wcbServer;
rcbServer.dataAvailableCallback = [&]() {
if (rcbServer.dataRead() == expectedDataSz) {
// write back all data read
rcbServer.verifyData(expectedData, expectedDataSz);
acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
acceptedSocket->close();
}
};
acceptedSocket->setReadCB(&rcbServer);
rcb.dataAvailableCallback = [&]() {
// we read data and reset readCB
socket.setReadCB(nullptr);
};
// write data
WriteCallback wcb;
socket.write(&wcb, expectedData, expectedDataSz);
evb.loop();
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
/* we shoud've only read maxBufferSz data since readCallback_
* was reset in dataAvailableCallback */
CHECK_EQ(rcb.dataRead(), maxBufferSz);
CHECK_EQ(socket.immediateReadCalled, false);
}
// TODO: // TODO:
// - Test connect() and have the connect callback set the read callback // - Test connect() and have the connect callback set the read callback
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment