Commit 4708133f authored by Subodh Iyengar's avatar Subodh Iyengar Committed by Facebook Github Bot 0

Stop abusing errno

Summary:
We abuse errno to propagate exceptions from AsyncSSLSocket.
Stop doing this and propagate exceptions correctly.

This also formats the exception messages better.

Reviewed By: anirudhvr

Differential Revision: D3226808

fb-gh-sync-id: 15a5e67b0332136857e5fb85b1765757e548e040
fbshipit-source-id: 15a5e67b0332136857e5fb85b1765757e548e040
parent 38c0b1ab
...@@ -234,6 +234,7 @@ nobase_follyinclude_HEADERS = \ ...@@ -234,6 +234,7 @@ nobase_follyinclude_HEADERS = \
io/async/HHWheelTimer.h \ io/async/HHWheelTimer.h \
io/async/ssl/OpenSSLPtrTypes.h \ io/async/ssl/OpenSSLPtrTypes.h \
io/async/ssl/OpenSSLUtils.h \ io/async/ssl/OpenSSLUtils.h \
io/async/ssl/SSLErrors.h \
io/async/ssl/TLSDefinitions.h \ io/async/ssl/TLSDefinitions.h \
io/async/Request.h \ io/async/Request.h \
io/async/SSLContext.h \ io/async/SSLContext.h \
...@@ -417,6 +418,7 @@ libfolly_la_SOURCES = \ ...@@ -417,6 +418,7 @@ libfolly_la_SOURCES = \
io/async/test/SocketPair.cpp \ io/async/test/SocketPair.cpp \
io/async/test/TimeUtil.cpp \ io/async/test/TimeUtil.cpp \
io/async/ssl/OpenSSLUtils.cpp \ io/async/ssl/OpenSSLUtils.cpp \
io/async/ssl/SSLErrors.cpp \
json.cpp \ json.cpp \
detail/MemoryIdler.cpp \ detail/MemoryIdler.cpp \
MacAddress.cpp \ MacAddress.cpp \
......
This diff is collapsed.
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <folly/io/async/TimeoutManager.h> #include <folly/io/async/TimeoutManager.h>
#include <folly/io/async/ssl/OpenSSLPtrTypes.h> #include <folly/io/async/ssl/OpenSSLPtrTypes.h>
#include <folly/io/async/ssl/OpenSSLUtils.h> #include <folly/io/async/ssl/OpenSSLUtils.h>
#include <folly/io/async/ssl/SSLErrors.h>
#include <folly/io/async/ssl/TLSDefinitions.h> #include <folly/io/async/ssl/TLSDefinitions.h>
#include <folly/Bits.h> #include <folly/Bits.h>
...@@ -35,14 +36,6 @@ ...@@ -35,14 +36,6 @@
namespace folly { namespace folly {
class SSLException: public folly::AsyncSocketException {
public:
SSLException(int sslError,
unsigned long errError,
int sslOperationReturnValue,
int errno_copy);
};
/** /**
* A class for performing asynchronous I/O on an SSL connection. * A class for performing asynchronous I/O on an SSL connection.
* *
...@@ -143,18 +136,6 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -143,18 +136,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
AsyncSSLSocket* sslSocket_; AsyncSSLSocket* sslSocket_;
}; };
/**
* These are passed to the application via errno, packed in an SSL err which
* are outside the valid errno range. The values are chosen to be unique
* against values in ssl.h
*/
enum SSLError {
SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900,
SSL_INVALID_RENEGOTIATION = 901,
SSL_EARLY_WRITE = 902
};
/** /**
* Create a client AsyncSSLSocket * Create a client AsyncSSLSocket
*/ */
...@@ -365,6 +346,11 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -365,6 +346,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
*/ */
SSL_SESSION *getSSLSession(); SSL_SESSION *getSSLSession();
/**
* Get a handle to the SSL struct.
*/
const SSL* getSSL() const;
/** /**
* Set the SSL session to be used during sslConn. AsyncSSLSocket will * Set the SSL session to be used during sslConn. AsyncSSLSocket will
* hold a reference to the session until it is destroyed or released by the * hold a reference to the session until it is destroyed or released by the
...@@ -760,11 +746,14 @@ class AsyncSSLSocket : public virtual AsyncSocket { ...@@ -760,11 +746,14 @@ class AsyncSSLSocket : public virtual AsyncSocket {
// AsyncSocket calls this at the wrong time for SSL // AsyncSocket calls this at the wrong time for SSL
void handleInitialReadWrite() noexcept override {} void handleInitialReadWrite() noexcept override {}
int interpretSSLError(int rc, int error); WriteResult interpretSSLError(int rc, int error);
ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override; ReadResult performRead(void** buf, size_t* buflen, size_t* offset) override;
ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags, WriteResult performWrite(
uint32_t* countWritten, uint32_t* partialWritten) const iovec* vec,
override; uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) override;
ssize_t performWriteIovec(const iovec* vec, uint32_t count, ssize_t performWriteIovec(const iovec* vec, uint32_t count,
WriteFlags flags, uint32_t* countWritten, WriteFlags flags, uint32_t* countWritten,
......
...@@ -91,14 +91,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest { ...@@ -91,14 +91,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
free(this); free(this);
} }
bool performWrite() override { WriteResult performWrite() override {
WriteFlags writeFlags = flags_; WriteFlags writeFlags = flags_;
if (getNext() != nullptr) { if (getNext() != nullptr) {
writeFlags = writeFlags | WriteFlags::CORK; writeFlags = writeFlags | WriteFlags::CORK;
} }
bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags, return socket_->performWrite(
&opsWritten_, &partialBytes_); getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
return bytesWritten_ >= 0;
} }
bool isComplete() override { bool isComplete() override {
...@@ -694,10 +693,14 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec, ...@@ -694,10 +693,14 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
assert(writeReqTail_ == nullptr); assert(writeReqTail_ == nullptr);
assert((eventFlags_ & EventHandler::WRITE) == 0); assert((eventFlags_ & EventHandler::WRITE) == 0);
bytesWritten = performWrite(vec, count, flags, auto writeResult =
&countWritten, &partialWritten); performWrite(vec, count, flags, &countWritten, &partialWritten);
bytesWritten = writeResult.writeReturn;
if (bytesWritten < 0) { if (bytesWritten < 0) {
auto errnoCopy = errno; auto errnoCopy = errno;
if (writeResult.exception) {
return failWrite(__func__, callback, 0, *writeResult.exception);
}
AsyncSocketException ex( AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, AsyncSocketException::INTERNAL_ERROR,
withAddr("writev failed"), withAddr("writev failed"),
...@@ -1259,11 +1262,10 @@ void AsyncSocket::ioReady(uint16_t events) noexcept { ...@@ -1259,11 +1262,10 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
} }
} }
ssize_t AsyncSocket::performRead(void** buf, AsyncSocket::ReadResult
size_t* buflen, AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
size_t* /* offset */) { VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buflen=" << *buflen;
<< ", buf=" << *buf << ", buflen=" << *buflen;
int recvFlags = 0; int recvFlags = 0;
if (peek_) { if (peek_) {
...@@ -1274,13 +1276,13 @@ ssize_t AsyncSocket::performRead(void** buf, ...@@ -1274,13 +1276,13 @@ ssize_t AsyncSocket::performRead(void** buf,
if (bytes < 0) { if (bytes < 0) {
if (errno == EAGAIN || errno == EWOULDBLOCK) { if (errno == EAGAIN || errno == EWOULDBLOCK) {
// No more data to read right now. // No more data to read right now.
return READ_BLOCKING; return ReadResult(READ_BLOCKING);
} else { } else {
return READ_ERROR; return ReadResult(READ_ERROR);
} }
} else { } else {
appBytesReceived_ += bytes; appBytesReceived_ += bytes;
return bytes; return ReadResult(bytes);
} }
} }
...@@ -1347,7 +1349,8 @@ void AsyncSocket::handleRead() noexcept { ...@@ -1347,7 +1349,8 @@ void AsyncSocket::handleRead() noexcept {
} }
// Perform the read // Perform the read
ssize_t bytesRead = performRead(&buf, &buflen, &offset); auto readResult = performRead(&buf, &buflen, &offset);
auto bytesRead = readResult.readReturn;
VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got " VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
<< bytesRead << " bytes"; << bytesRead << " bytes";
if (bytesRead > 0) { if (bytesRead > 0) {
...@@ -1376,6 +1379,9 @@ void AsyncSocket::handleRead() noexcept { ...@@ -1376,6 +1379,9 @@ void AsyncSocket::handleRead() noexcept {
return; return;
} else if (bytesRead == READ_ERROR) { } else if (bytesRead == READ_ERROR) {
readErr_ = READ_ERROR; readErr_ = READ_ERROR;
if (readResult.exception) {
return failRead(__func__, *readResult.exception);
}
auto errnoCopy = errno; auto errnoCopy = errno;
AsyncSocketException ex( AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, AsyncSocketException::INTERNAL_ERROR,
...@@ -1439,7 +1445,11 @@ void AsyncSocket::handleWrite() noexcept { ...@@ -1439,7 +1445,11 @@ void AsyncSocket::handleWrite() noexcept {
// (See the comment in handleRead() explaining how this can happen.) // (See the comment in handleRead() explaining how this can happen.)
EventBase* originalEventBase = eventBase_; EventBase* originalEventBase = eventBase_;
while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) { while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
if (!writeReqHead_->performWrite()) { auto writeResult = writeReqHead_->performWrite();
if (writeResult.writeReturn < 0) {
if (writeResult.exception) {
return failWrite(__func__, *writeResult.exception);
}
auto errnoCopy = errno; auto errnoCopy = errno;
AsyncSocketException ex( AsyncSocketException ex(
AsyncSocketException::INTERNAL_ERROR, AsyncSocketException::INTERNAL_ERROR,
...@@ -1697,11 +1707,12 @@ void AsyncSocket::timeoutExpired() noexcept { ...@@ -1697,11 +1707,12 @@ void AsyncSocket::timeoutExpired() noexcept {
} }
} }
ssize_t AsyncSocket::performWrite(const iovec* vec, AsyncSocket::WriteResult AsyncSocket::performWrite(
uint32_t count, const iovec* vec,
WriteFlags flags, uint32_t count,
uint32_t* countWritten, WriteFlags flags,
uint32_t* partialWritten) { uint32_t* countWritten,
uint32_t* partialWritten) {
// We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
// We correctly handle EPIPE errors, so we never want to receive SIGPIPE // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
// (since it may terminate the program if the main program doesn't explicitly // (since it may terminate the program if the main program doesn't explicitly
...@@ -1736,12 +1747,12 @@ ssize_t AsyncSocket::performWrite(const iovec* vec, ...@@ -1736,12 +1747,12 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
// TCP buffer is full; we can't write any more data right now. // TCP buffer is full; we can't write any more data right now.
*countWritten = 0; *countWritten = 0;
*partialWritten = 0; *partialWritten = 0;
return 0; return WriteResult(0);
} }
// error // error
*countWritten = 0; *countWritten = 0;
*partialWritten = 0; *partialWritten = 0;
return -1; return WriteResult(WRITE_ERROR);
} }
appBytesWritten_ += totalWritten; appBytesWritten_ += totalWritten;
...@@ -1754,7 +1765,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec, ...@@ -1754,7 +1765,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
// Partial write finished in the middle of this iovec // Partial write finished in the middle of this iovec
*countWritten = n; *countWritten = n;
*partialWritten = bytesWritten; *partialWritten = bytesWritten;
return totalWritten; return WriteResult(totalWritten);
} }
bytesWritten -= v->iov_len; bytesWritten -= v->iov_len;
...@@ -1763,7 +1774,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec, ...@@ -1763,7 +1774,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
assert(bytesWritten == 0); assert(bytesWritten == 0);
*countWritten = n; *countWritten = n;
*partialWritten = 0; *partialWritten = 0;
return totalWritten; return WriteResult(totalWritten);
} }
/** /**
......
...@@ -16,16 +16,17 @@ ...@@ -16,16 +16,17 @@
#pragma once #pragma once
#include <sys/types.h> #include <folly/Optional.h>
#include <sys/socket.h>
#include <folly/SocketAddress.h> #include <folly/SocketAddress.h>
#include <folly/io/ShutdownSocketSet.h>
#include <folly/io/IOBuf.h> #include <folly/io/IOBuf.h>
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/ShutdownSocketSet.h>
#include <folly/io/async/AsyncSocketException.h> #include <folly/io/async/AsyncSocketException.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h> #include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventHandler.h>
#include <folly/io/async/DelayedDestruction.h> #include <folly/io/async/DelayedDestruction.h>
#include <folly/io/async/EventHandler.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <chrono> #include <chrono>
#include <memory> #include <memory>
...@@ -517,6 +518,41 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -517,6 +518,41 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
void setBufferCallback(BufferCallback* cb); void setBufferCallback(BufferCallback* cb);
/**
* writeReturn is the total number of bytes written, or WRITE_ERROR on error.
* If no data has been written, 0 is returned.
* exception is a more specific exception that cause a write error.
* Not all writes have exceptions associated with them thus writeReturn
* should be checked to determine whether the operation resulted in an error.
*/
struct WriteResult {
explicit WriteResult(ssize_t ret) : writeReturn(ret) {}
WriteResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
: writeReturn(ret), exception(std::move(e)) {}
ssize_t writeReturn;
std::unique_ptr<const AsyncSocketException> exception;
};
/**
* readReturn is the number of bytes read, or READ_EOF on EOF, or
* READ_ERROR on error, or READ_BLOCKING if the operation will
* block.
* exception is a more specific exception that may have caused a read error.
* Not all read errors have exceptions associated with them thus readReturn
* should be checked to determine whether the operation resulted in an error.
*/
struct ReadResult {
explicit ReadResult(ssize_t ret) : readReturn(ret) {}
ReadResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
: readReturn(ret), exception(std::move(e)) {}
ssize_t readReturn;
std::unique_ptr<const AsyncSocketException> exception;
};
/** /**
* A WriteRequest object tracks information about a pending write operation. * A WriteRequest object tracks information about a pending write operation.
*/ */
...@@ -529,7 +565,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -529,7 +565,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
virtual void destroy() = 0; virtual void destroy() = 0;
virtual bool performWrite() = 0; virtual WriteResult performWrite() = 0;
virtual void consume() = 0; virtual void consume() = 0;
...@@ -579,6 +615,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -579,6 +615,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
READ_NO_ERROR = -3, READ_NO_ERROR = -3,
}; };
enum WriteResultEnum {
WRITE_ERROR = -1,
};
/** /**
* Protected destructor. * Protected destructor.
* *
...@@ -683,11 +723,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -683,11 +723,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* @param buf The buffer to read data into. * @param buf The buffer to read data into.
* @param buflen The length of the buffer. * @param buflen The length of the buffer.
* *
* @return Returns the number of bytes read, or READ_EOF on EOF, or * @return Returns a read result. See read result for details.
* READ_ERROR on error, or READ_BLOCKING if the operation will
* block.
*/ */
virtual ssize_t performRead(void** buf, size_t* buflen, size_t* offset); virtual ReadResult performRead(void** buf, size_t* buflen, size_t* offset);
/** /**
* Populate an iovec array from an IOBuf and attempt to write it. * Populate an iovec array from an IOBuf and attempt to write it.
...@@ -736,12 +774,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper { ...@@ -736,12 +774,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
* will contain the number of bytes written in the * will contain the number of bytes written in the
* partially written iovec entry. * partially written iovec entry.
* *
* @return Returns the total number of bytes written, or -1 on error. If no * @return Returns a WriteResult. See WriteResult for more details.
* data can be written immediately, 0 is returned.
*/ */
virtual ssize_t performWrite(const iovec* vec, uint32_t count, virtual WriteResult performWrite(
WriteFlags flags, uint32_t* countWritten, const iovec* vec,
uint32_t* partialWritten); uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten);
bool updateEventRegistration(); bool updateEventRegistration();
......
/*
* Copyright 2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/ssl/SSLErrors.h>
#include <folly/Range.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
using namespace folly;
namespace {
std::string decodeOpenSSLError(
int sslError,
unsigned long errError,
int sslOperationReturnValue) {
if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
if (sslOperationReturnValue == 0) {
return "SSL_ERROR_SYSCALL: EOF";
} else {
// In this case errno is set, AsyncSocketException will add it.
return "SSL_ERROR_SYSCALL";
}
} else if (sslError == SSL_ERROR_ZERO_RETURN) {
// This signifies a TLS closure alert.
return "SSL_ERROR_ZERO_RETURN";
} else {
std::array<char, 256> buf;
std::string msg(ERR_error_string(errError, buf.data()));
return msg;
}
}
const StringPiece getSSLErrorString(SSLError error) {
StringPiece ret;
switch (error) {
case SSLError::CLIENT_RENEGOTIATION:
ret = "Client tried to renegotiate with server";
break;
case SSLError::INVALID_RENEGOTIATION:
ret = "Attempt to start renegotiation, but unsupported";
break;
case SSLError::EARLY_WRITE:
ret = "Attempt to write before SSL connection established";
break;
case SSLError::OPENSSL_ERR:
// decodeOpenSSLError should be used for this type.
ret = "OPENSSL error";
break;
}
return ret;
}
}
namespace folly {
SSLException::SSLException(
int sslError,
unsigned long errError,
int sslOperationReturnValue,
int errno_copy)
: AsyncSocketException(
AsyncSocketException::SSL_ERROR,
decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
sslError == SSL_ERROR_SYSCALL ? errno_copy : 0),
sslError(SSLError::OPENSSL_ERR),
opensslSSLError(sslError),
opensslErr(errError) {}
SSLException::SSLException(SSLError error)
: AsyncSocketException(
AsyncSocketException::SSL_ERROR,
getSSLErrorString(error).str(),
0),
sslError(error) {}
}
/*
* Copyright 2016 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Optional.h>
#include <folly/io/async/AsyncSocketException.h>
namespace folly {
enum class SSLError {
CLIENT_RENEGOTIATION, // A client tried to renegotiate with this server
INVALID_RENEGOTIATION, // We attempted to start a renegotiation.
EARLY_WRITE, // Wrote before SSL connection established.
// An openssl error type. The openssl specific methods should be used
// to find the real error type.
// This exists for compatibility until all error types can be move to proper
// errors.
OPENSSL_ERR,
};
class SSLException : public folly::AsyncSocketException {
public:
SSLException(
int sslError,
unsigned long errError,
int sslOperationReturnValue,
int errno_copy);
explicit SSLException(SSLError error);
SSLError getType() const {
return sslError;
}
// These methods exist for compatibility until there are proper exceptions
// for all ssl error types.
int getOpensslSSLError() const {
return opensslSSLError;
}
unsigned long getOpensslErr() const {
return opensslErr;
}
private:
SSLError sslError;
int opensslSSLError;
unsigned long opensslErr;
};
}
...@@ -201,13 +201,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) { ...@@ -201,13 +201,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
cerr << "ConnectWriteReadClose test completed" << endl; cerr << "ConnectWriteReadClose test completed" << endl;
} }
/**
* Test reading after server close.
*/
TEST(AsyncSSLSocketTest, ReadAfterClose) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadEOFCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
// Set up SSL context.
auto sslContext = std::make_shared<SSLContext>();
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
auto socket =
std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
socket->open();
// This should trigger an EOF on the client.
auto evb = handshakeCallback.getSocket()->getEventBase();
evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
std::array<uint8_t, 128> readbuf;
auto bytesRead = socket->read(readbuf.data(), readbuf.size());
EXPECT_EQ(0, bytesRead);
}
/**
* Test bad renegotiation
*/
TEST(AsyncSSLSocketTest, Renegotiate) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
std::array<int, 2> fds;
getfds(fds.data());
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
RenegotiatingServer server(std::move(serverSock));
while (!client.handshakeSuccess_ && !client.handshakeError_) {
eventBase.loopOnce();
}
ASSERT_TRUE(client.handshakeSuccess_);
auto sslSock = std::move(client).moveSocket();
sslSock->detachEventBase();
// This is nasty, however we don't want to add support for
// renegotiation in AsyncSSLSocket.
SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
std::thread t([&]() { eventBase.loopForever(); });
// Trigger the renegotiation.
std::array<uint8_t, 128> buf;
memset(buf.data(), 'a', buf.size());
try {
socket->write(buf.data(), buf.size());
} catch (AsyncSocketException& e) {
LOG(INFO) << "client got error " << e.what();
}
eventBase.terminateLoopSoon();
t.join();
eventBase.loop();
ASSERT_TRUE(server.renegotiationError_);
}
/** /**
* Negative test for handshakeError(). * Negative test for handshakeError().
*/ */
TEST(AsyncSSLSocketTest, HandshakeError) { TEST(AsyncSSLSocketTest, HandshakeError) {
// Start listening on a local port // Start listening on a local port
WriteCallbackBase writeCallback; WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback); WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback); HandshakeCallback handshakeCallback(&readCallback);
HandshakeErrorCallback acceptCallback(&handshakeCallback); HandshakeErrorCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback); TestSSLServer server(&acceptCallback);
......
...@@ -18,13 +18,15 @@ ...@@ -18,13 +18,15 @@
#include <signal.h> #include <signal.h>
#include <pthread.h> #include <pthread.h>
#include <folly/io/async/AsyncServerSocket.h> #include <folly/ExceptionWrapper.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncSSLSocket.h> #include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h> #include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/AsyncTransport.h> #include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h> #include <folly/io/async/EventBase.h>
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/async/ssl/SSLErrors.h>
#include <folly/SocketAddress.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
...@@ -58,7 +60,7 @@ public: ...@@ -58,7 +60,7 @@ public:
, exception(AsyncSocketException::UNKNOWN, "none") {} , exception(AsyncSocketException::UNKNOWN, "none") {}
~WriteCallbackBase() { ~WriteCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED); EXPECT_EQ(STATE_SUCCEEDED, state);
} }
void setSocket( void setSocket(
...@@ -92,10 +94,9 @@ public: ...@@ -92,10 +94,9 @@ public:
class ReadCallbackBase : class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback { public AsyncTransportWrapper::ReadCallback {
public: public:
explicit ReadCallbackBase(WriteCallbackBase *wcb) explicit ReadCallbackBase(WriteCallbackBase* wcb)
: wcb_(wcb) : wcb_(wcb), state(STATE_WAITING) {}
, state(STATE_WAITING) {}
~ReadCallbackBase() { ~ReadCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED); EXPECT_EQ(state, STATE_SUCCEEDED);
...@@ -222,6 +223,27 @@ public: ...@@ -222,6 +223,27 @@ public:
} }
}; };
class ReadEOFCallback : public ReadCallbackBase {
public:
explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
// Return nullptr buffer to trigger readError()
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = nullptr;
*lenReturn = 0;
}
void readDataAvailable(size_t /* len */) noexcept override {
// This should never to called.
FAIL();
}
void readEOF() noexcept override {
ReadCallbackBase::readEOF();
setState(STATE_SUCCEEDED);
}
};
class WriteErrorCallback : public ReadCallback { class WriteErrorCallback : public ReadCallback {
public: public:
explicit WriteErrorCallback(WriteCallbackBase *wcb) explicit WriteErrorCallback(WriteCallbackBase *wcb)
...@@ -340,6 +362,10 @@ public: ...@@ -340,6 +362,10 @@ public:
state = STATE_SUCCEEDED; state = STATE_SUCCEEDED;
} }
std::shared_ptr<AsyncSSLSocket> getSocket() {
return socket_;
}
StateEnum state; StateEnum state;
std::shared_ptr<AsyncSSLSocket> socket_; std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_; ReadCallbackBase *rcb_;
...@@ -879,6 +905,48 @@ class NpnServer : ...@@ -879,6 +905,48 @@ class NpnServer :
AsyncSSLSocket::UniquePtr socket_; AsyncSSLSocket::UniquePtr socket_;
}; };
class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
public AsyncTransportWrapper::ReadCallback {
public:
explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
: socket_(std::move(socket)) {
socket_->sslAccept(this);
}
~RenegotiatingServer() {
socket_->setReadCB(nullptr);
}
void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
LOG(INFO) << "Renegotiating server handshake success";
socket_->setReadCB(this);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*lenReturn = sizeof(buf);
*bufReturn = buf;
}
void readDataAvailable(size_t /* len */) noexcept override {}
void readEOF() noexcept override {}
void readErr(const AsyncSocketException& ex) noexcept override {
LOG(INFO) << "server got read error " << ex.what();
auto exPtr = dynamic_cast<const SSLException*>(&ex);
ASSERT_NE(nullptr, exPtr);
std::string exStr(ex.what());
SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
renegotiationError_ = true;
}
AsyncSSLSocket::UniquePtr socket_;
unsigned char buf[128];
bool renegotiationError_{false};
};
#ifndef OPENSSL_NO_TLSEXT #ifndef OPENSSL_NO_TLSEXT
class SNIClient : class SNIClient :
private AsyncSSLSocket::HandshakeCB, private AsyncSSLSocket::HandshakeCB,
...@@ -1139,6 +1207,10 @@ class SSLHandshakeBase : ...@@ -1139,6 +1207,10 @@ class SSLHandshakeBase :
verifyResult_(verifyResult) { verifyResult_(verifyResult) {
} }
AsyncSSLSocket::UniquePtr moveSocket() && {
return std::move(socket_);
}
bool handshakeVerify_; bool handshakeVerify_;
bool handshakeSuccess_; bool handshakeSuccess_;
bool handshakeError_; bool handshakeError_;
...@@ -1160,12 +1232,15 @@ class SSLHandshakeBase : ...@@ -1160,12 +1232,15 @@ class SSLHandshakeBase :
} }
void handshakeSuc(AsyncSSLSocket*) noexcept override { void handshakeSuc(AsyncSSLSocket*) noexcept override {
LOG(INFO) << "Handshake success";
handshakeSuccess_ = true; handshakeSuccess_ = true;
handshakeTime = socket_->getHandshakeTime(); handshakeTime = socket_->getHandshakeTime();
} }
void handshakeErr(AsyncSSLSocket*, void handshakeErr(
const AsyncSocketException& /* ex */) noexcept override { AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
LOG(INFO) << "Handshake error " << ex.what();
handshakeError_ = true; handshakeError_ = true;
handshakeTime = socket_->getHandshakeTime(); handshakeTime = socket_->getHandshakeTime();
} }
......
...@@ -58,8 +58,12 @@ class MockAsyncSSLSocket : public AsyncSSLSocket{ ...@@ -58,8 +58,12 @@ class MockAsyncSSLSocket : public AsyncSSLSocket{
MOCK_CONST_METHOD0(getRawBytesWritten, size_t()); MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
// public wrapper for protected interface // public wrapper for protected interface
ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags, WriteResult testPerformWrite(
uint32_t* countWritten, uint32_t* partialWritten) { const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
return performWrite(vec, count, flags, countWritten, partialWritten); return performWrite(vec, count, flags, countWritten, partialWritten);
} }
......
...@@ -35,6 +35,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback, ...@@ -35,6 +35,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
new folly::AsyncSocket(&eventBase_)), new folly::AsyncSocket(&eventBase_)),
address_(address) {} address_(address) {}
explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
: sock_(std::move(socket)) {
sock_->attachEventBase(&eventBase_);
}
void open() { void open() {
sock_->connect(this, address_); sock_->connect(this, address_);
eventBase_.loop(); eventBase_.loop();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment