Commit c321eb58 authored by Kyle Nekritz's avatar Kyle Nekritz Committed by Facebook Github Bot

Handle close_notify as standard writeErr in AsyncSSLSocket.

Summary: Fixes CVE-2019-11934

Reviewed By: mingtaoy

Differential Revision: D18020613

fbshipit-source-id: db82bb250e53f0d225f1280bd67bc74abd417836
parent df7b6652
......@@ -1450,9 +1450,6 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
WRITE_ERROR,
std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
} else {
if (zero_return(error, rc, errno)) {
return WriteResult(0);
}
auto errError = ERR_get_error();
VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
......@@ -1589,10 +1586,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
*partialWritten = uint32_t(offset);
return WriteResult(totalWritten);
}
auto writeResult = interpretSSLError(int(bytes), error);
if (writeResult.writeReturn < 0) {
return writeResult;
} // else fall through to below to correctly record totalWritten
return interpretSSLError(int(bytes), error);
}
totalWritten += bytes;
......
......@@ -808,6 +808,114 @@ TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
cerr << "SSLClientTimeoutTest test completed" << endl;
}
class PerLoopReadCallback : public AsyncTransportWrapper::ReadCallback {
public:
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = buf_.data();
*lenReturn = buf_.size();
}
void readDataAvailable(size_t len) noexcept override {
VLOG(3) << "Read of size: " << len;
s_->setReadCB(nullptr);
s_->getEventBase()->runInLoop([this]() { s_->setReadCB(this); });
}
void readErr(const AsyncSocketException&) noexcept override {}
void readEOF() noexcept override {}
void setSocket(AsyncSocket* s) {
s_ = s;
}
private:
AsyncSocket* s_;
std::array<uint8_t, 1000> buf_;
};
class CloseNotifyConnector : public AsyncSocket::ConnectCallback {
public:
CloseNotifyConnector(EventBase* evb, const SocketAddress& addr) {
evb_ = evb;
ssl_ = AsyncSSLSocket::newSocket(std::make_shared<SSLContext>(), evb_);
ssl_->connect(this, addr);
}
void connectSuccess() noexcept override {
ssl_->writeChain(nullptr, IOBuf::copyBuffer("hi"));
auto ssl = const_cast<SSL*>(ssl_->getSSL());
SSL_shutdown(ssl);
auto fd = ssl_->detachNetworkSocket();
tcp_.reset(new AsyncSocket(evb_, fd), AsyncSocket::Destructor());
evb_->runAfterDelay(
[this]() {
perLoopReads_.setSocket(tcp_.get());
tcp_->setReadCB(&perLoopReads_);
evb_->runAfterDelay([this]() { tcp_->closeNow(); }, 10);
},
100);
}
void connectErr(const AsyncSocketException& ex) noexcept override {
FAIL() << ex.what();
}
private:
EventBase* evb_;
std::shared_ptr<AsyncSSLSocket> ssl_;
std::shared_ptr<AsyncSocket> tcp_;
PerLoopReadCallback perLoopReads_;
};
class ErrorCheckingWriteCallback : public AsyncSocket::WriteCallback {
public:
void writeSuccess() noexcept override {}
void writeErr(size_t, const AsyncSocketException& ex) noexcept override {
LOG(ERROR) << "write error: " << ex.what();
EXPECT_NE(
ex.getType(),
AsyncSocketException::AsyncSocketExceptionType::SSL_ERROR);
}
};
class WriteOnEofReadCallback : public ReadCallback {
public:
using ReadCallback::ReadCallback;
void readEOF() noexcept override {
LOG(INFO) << "Got EOF";
auto chain = IOBuf::create(0);
for (size_t i = 0; i < 1000 * 1000; i++) {
auto buf = IOBuf::create(10);
buf->append(10);
memset(buf->writableData(), 'x', 10);
chain->prependChain(std::move(buf));
}
socket_->writeChain(&writeCallback_, std::move(chain));
}
void readErr(const AsyncSocketException& ex) noexcept override {
LOG(ERROR) << ex.what();
}
private:
ErrorCheckingWriteCallback writeCallback_;
};
TEST(AsyncSSLSocketTest, EarlyCloseNotify) {
WriteOnEofReadCallback readCallback(nullptr);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
EventBase eventBase;
CloseNotifyConnector cnc(&eventBase, server.getAddress());
eventBase.loop();
}
/**
* Verify Client Ciphers obtained using SSL MSG Callback.
*/
......
......@@ -389,7 +389,9 @@ class ReadCallback : public ReadCallbackBase {
currentBuffer.length = len;
wcb_->setSocket(socket_);
if (wcb_) {
wcb_->setSocket(socket_);
}
// Write back the same data.
socket_->write(wcb_, currentBuffer.buffer, len, writeFlags);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment