Commit ff9b70f3 authored by Dave Watson's avatar Dave Watson Committed by Pavlo Kushnir

Move SSL socket to folly

Summary: One of the last thrift -> folly moves.  The only change was the exception types - there are small wrapper classes in thrift/lib/cpp/async left to convert from AsyncSocketException to TTransportException.

Test Plan: run unit tests

Reviewed By: dcsommer@fb.com

Subscribers: jdperlow, trunkagent, doug, bmatheny, ssl-diffs@, njormrod, mshneer, folly-diffs@, fugalh, jsedgwick, andrewcox, alandau

FB internal diff: D1632425

Signature: t1:1632425:1414526483:339ae107bacb073bdd8cf0942fd0f6b70990feb4
parent a91ed7e3
......@@ -132,7 +132,9 @@ nobase_follyinclude_HEADERS = \
io/async/AsyncTimeout.h \
io/async/AsyncTransport.h \
io/async/AsyncServerSocket.h \
io/async/AsyncSSLServerSocket.h \
io/async/AsyncSocket.h \
io/async/AsyncSSLSocket.h \
io/async/AsyncSocketException.h \
io/async/DelayedDestruction.h \
io/async/EventBase.h \
......@@ -265,7 +267,9 @@ libfolly_la_SOURCES = \
io/ShutdownSocketSet.cpp \
io/async/AsyncTimeout.cpp \
io/async/AsyncServerSocket.cpp \
io/async/AsyncSSLServerSocket.cpp \
io/async/AsyncSocket.cpp \
io/async/AsyncSSLSocket.cpp \
io/async/EventBase.cpp \
io/async/EventBaseManager.cpp \
io/async/EventHandler.cpp \
......
/*
* Copyright 2014 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/AsyncSSLServerSocket.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/SocketAddress.h>
using std::shared_ptr;
namespace folly {
AsyncSSLServerSocket::AsyncSSLServerSocket(
const shared_ptr<SSLContext>& ctx,
EventBase* eventBase)
: eventBase_(eventBase)
, serverSocket_(new AsyncServerSocket(eventBase))
, ctx_(ctx)
, sslCallback_(nullptr) {
}
AsyncSSLServerSocket::~AsyncSSLServerSocket() {
}
void AsyncSSLServerSocket::destroy() {
// Stop accepting on the underlying socket as soon as destroy is called
if (sslCallback_ != nullptr) {
serverSocket_->pauseAccepting();
serverSocket_->removeAcceptCallback(this, nullptr);
}
serverSocket_->destroy();
serverSocket_ = nullptr;
sslCallback_ = nullptr;
DelayedDestruction::destroy();
}
void AsyncSSLServerSocket::setSSLAcceptCallback(SSLAcceptCallback* callback) {
SSLAcceptCallback *oldCallback = sslCallback_;
sslCallback_ = callback;
if (callback != nullptr && oldCallback == nullptr) {
serverSocket_->addAcceptCallback(this, nullptr);
serverSocket_->startAccepting();
} else if (callback == nullptr && oldCallback != nullptr) {
serverSocket_->removeAcceptCallback(this, nullptr);
serverSocket_->pauseAccepting();
}
}
void AsyncSSLServerSocket::attachEventBase(EventBase* eventBase) {
assert(sslCallback_ == nullptr);
eventBase_ = eventBase;
serverSocket_->attachEventBase(eventBase);
}
void AsyncSSLServerSocket::detachEventBase() {
serverSocket_->detachEventBase();
eventBase_ = nullptr;
}
void
AsyncSSLServerSocket::connectionAccepted(
int fd,
const folly::SocketAddress& clientAddr) noexcept {
shared_ptr<AsyncSSLSocket> sslSock;
try {
// Create a AsyncSSLSocket object with the fd. The socket should be
// added to the event base and in the state of accepting SSL connection.
sslSock = AsyncSSLSocket::newSocket(ctx_, eventBase_, fd);
} catch (const std::exception &e) {
LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
"object with socket " << e.what() << fd;
::close(fd);
sslCallback_->acceptError(e);
return;
}
// TODO: Perform the SSL handshake before invoking the callback
sslCallback_->connectionAccepted(sslSock);
}
void AsyncSSLServerSocket::acceptError(const std::exception& ex)
noexcept {
LOG(ERROR) << "AsyncSSLServerSocket accept error: " << ex.what();
sslCallback_->acceptError(ex);
}
} // namespace
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#pragma once
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/AsyncServerSocket.h>
namespace folly {
class SocketAddress;
class AsyncSSLSocket;
class AsyncSSLServerSocket : public DelayedDestruction,
private AsyncServerSocket::AcceptCallback {
public:
class SSLAcceptCallback {
public:
virtual ~SSLAcceptCallback() {}
/**
* connectionAccepted() is called whenever a new client connection is
* received.
*
* The SSLAcceptCallback will remain installed after connectionAccepted()
* returns.
*
* @param sock The newly accepted client socket. The
* SSLAcceptCallback
* assumes ownership of this socket, and is responsible
* for closing it when done.
*/
virtual void connectionAccepted(
const std::shared_ptr<AsyncSSLSocket> &sock)
noexcept = 0;
/**
* acceptError() is called if an error occurs while accepting.
*
* The SSLAcceptCallback will remain installed even after an accept error.
* If the callback wants to uninstall itself and stop trying to accept new
* connections, it must explicit call setAcceptCallback(nullptr).
*
* @param ex An exception representing the error.
*/
virtual void acceptError(const std::exception& ex) noexcept = 0;
};
/**
* Create a new TAsyncSSLServerSocket with the specified EventBase.
*
* @param eventBase The EventBase to use for driving the asynchronous I/O.
* If this parameter is nullptr, attachEventBase() must be
* called before this socket can begin accepting
* connections. All TAsyncSSLSocket objects accepted by
* this server socket will be attached to this EventBase
* when they are created.
*/
explicit AsyncSSLServerSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* eventBase = nullptr);
/**
* Destroy the socket.
*
* destroy() must be called to destroy the socket. The normal destructor is
* private, and should not be invoked directly. This prevents callers from
* deleting a TAsyncSSLServerSocket while it is invoking a callback.
*/
virtual void destroy();
virtual void bind(const folly::SocketAddress& address) {
serverSocket_->bind(address);
}
virtual void bind(uint16_t port) {
serverSocket_->bind(port);
}
void getAddress(folly::SocketAddress* addressReturn) {
serverSocket_->getAddress(addressReturn);
}
virtual void listen(int backlog) {
serverSocket_->listen(backlog);
}
/**
* Helper function to create a shared_ptr<TAsyncSSLServerSocket>.
*
* This passes in the correct destructor object, since TAsyncSSLServerSocket's
* destructor is protected and cannot be invoked directly.
*/
static std::shared_ptr<AsyncSSLServerSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb) {
return std::shared_ptr<AsyncSSLServerSocket>(
new AsyncSSLServerSocket(ctx, evb),
Destructor());
}
/**
* Set the accept callback.
*
* This method may only be invoked from the EventBase's loop thread.
*
* @param callback The callback to invoke when a new socket
* connection is accepted and a new TAsyncSSLSocket is
* created.
*
* Throws TTransportException on error.
*/
void setSSLAcceptCallback(SSLAcceptCallback* callback);
SSLAcceptCallback *getSSLAcceptCallback() const {
return sslCallback_;
}
void attachEventBase(EventBase* eventBase);
void detachEventBase();
/**
* Returns the EventBase that the handler is currently attached to.
*/
EventBase* getEventBase() const {
return eventBase_;
}
protected:
/**
* Protected destructor.
*
* Invoke destroy() instead to destroy the TAsyncSSLServerSocket.
*/
virtual ~AsyncSSLServerSocket();
protected:
virtual void connectionAccepted(int fd,
const folly::SocketAddress& clientAddr)
noexcept;
virtual void acceptError(const std::exception& ex) noexcept;
EventBase* eventBase_;
AsyncServerSocket* serverSocket_;
// SSL context
std::shared_ptr<folly::SSLContext> ctx_;
// The accept callback
SSLAcceptCallback* sslCallback_;
};
} // namespace
/*
* Copyright 2014 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io//async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
#include <boost/noncopyable.hpp>
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <openssl/err.h>
#include <openssl/asn1.h>
#include <openssl/ssl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <chrono>
#include <folly/Bits.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
#include <folly/io/Cursor.h>
#include <folly/io/PortableSpinLock.h>
using folly::SocketAddress;
using folly::SSLContext;
using std::string;
using std::shared_ptr;
using folly::Endian;
using folly::IOBuf;
using folly::io::Cursor;
using folly::io::PortableSpinLock;
using folly::io::PortableSpinLockGuard;
using std::unique_ptr;
using std::bind;
namespace {
using folly::AsyncSocket;
using folly::AsyncSocketException;
using folly::AsyncSSLSocket;
using folly::Optional;
/** Try to avoid calling SSL_write() for buffers smaller than this: */
size_t MIN_WRITE_SIZE = 1500;
// We have one single dummy SSL context so that we can implement attach
// and detach methods in a thread safe fashion without modifying opnessl.
static SSLContext *dummyCtx = nullptr;
static PortableSpinLock dummyCtxLock;
// Numbers chosen as to not collide with functions in ssl.h
const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
// This converts "illegal" shutdowns into ZERO_RETURN
inline bool zero_return(int error, int rc) {
return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
}
class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
public AsyncSSLSocket::HandshakeCB {
private:
AsyncSSLSocket *sslSocket_;
AsyncSSLSocket::ConnectCallback *callback_;
int timeout_;
int64_t startTime_;
protected:
virtual ~AsyncSSLSocketConnector() {
}
public:
AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
AsyncSocket::ConnectCallback *callback,
int timeout) :
sslSocket_(sslSocket),
callback_(callback),
timeout_(timeout),
startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch()).count()) {
}
virtual void connectSuccess() noexcept {
VLOG(7) << "client socket connected";
int64_t timeoutLeft = 0;
if (timeout_ > 0) {
auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now().time_since_epoch()).count();
timeoutLeft = timeout_ - (curTime - startTime_);
if (timeoutLeft <= 0) {
AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
"SSL connect timed out");
fail(ex);
delete this;
return;
}
}
sslSocket_->sslConn(this, timeoutLeft);
}
virtual void connectErr(const AsyncSocketException& ex) noexcept {
LOG(ERROR) << "TCP connect failed: " << ex.what();
fail(ex);
delete this;
}
virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
VLOG(7) << "client handshake success";
if (callback_) {
callback_->connectSuccess();
}
delete this;
}
virtual void handshakeErr(AsyncSSLSocket *socket,
const AsyncSocketException& ex) noexcept {
LOG(ERROR) << "client handshakeErr: " << ex.what();
fail(ex);
delete this;
}
void fail(const AsyncSocketException &ex) {
// fail is a noop if called twice
if (callback_) {
AsyncSSLSocket::ConnectCallback *cb = callback_;
callback_ = nullptr;
cb->connectErr(ex);
sslSocket_->closeNow();
// closeNow can call handshakeErr if it hasn't been called already.
// So this may have been deleted, no member variable access beyond this
// point
// Note that closeNow may invoke writeError callbacks if the socket had
// write data pending connection completion.
}
}
};
// XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
#ifdef TCP_CORK // Linux-only
/**
* Utility class that corks a TCP socket upon construction or uncorks
* the socket upon destruction
*/
class CorkGuard : private boost::noncopyable {
public:
CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
fd_(fd), haveMore_(haveMore), corked_(corked) {
if (*corked_) {
// socket is already corked; nothing to do
return;
}
if (multipleWrites || haveMore) {
// We are performing multiple writes in this performWrite() call,
// and/or there are more calls to performWrite() that will be invoked
// later, so enable corking
int flag = 1;
setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
*corked_ = true;
}
}
~CorkGuard() {
if (haveMore_) {
// more data to come; don't uncork yet
return;
}
if (!*corked_) {
// socket isn't corked; nothing to do
return;
}
int flag = 0;
setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
*corked_ = false;
}
private:
int fd_;
bool haveMore_;
bool* corked_;
};
#else
class CorkGuard : private boost::noncopyable {
public:
CorkGuard(int, bool, bool, bool*) {}
};
#endif
void setup_SSL_CTX(SSL_CTX *ctx) {
#ifdef SSL_MODE_RELEASE_BUFFERS
SSL_CTX_set_mode(ctx,
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
SSL_MODE_ENABLE_PARTIAL_WRITE
| SSL_MODE_RELEASE_BUFFERS
);
#else
SSL_CTX_set_mode(ctx,
SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
SSL_MODE_ENABLE_PARTIAL_WRITE
);
#endif
}
BIO_METHOD eorAwareBioMethod;
__attribute__((__constructor__))
void initEorBioMethod(void) {
memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
// override the bwrite method for MSG_EOR support
eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
// Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
// set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
// then have specific handlings. The eorAwareBioWrite should be compatible
// with the one in openssl.
}
} // anonymous namespace
namespace folly {
SSLException::SSLException(int sslError, int errno_copy):
AsyncSocketException(
AsyncSocketException::SSL_ERROR,
ERR_error_string(sslError, msg_),
sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
/**
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb) :
AsyncSocket(evb),
ctx_(ctx),
handshakeTimeout_(this, evb) {
setup_SSL_CTX(ctx_->getSSLCtx());
}
/**
* Create a server/client AsyncSSLSocket
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd, bool server) :
AsyncSocket(evb, fd),
server_(server),
ctx_(ctx),
handshakeTimeout_(this, evb) {
setup_SSL_CTX(ctx_->getSSLCtx());
if (server) {
SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
AsyncSSLSocket::sslInfoCallback);
}
}
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
/**
* Create a client AsyncSSLSocket and allow tlsext_hostname
* to be sent in Client Hello.
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
EventBase* evb,
const std::string& serverName) :
AsyncSocket(evb),
ctx_(ctx),
handshakeTimeout_(this, evb),
tlsextHostname_(serverName) {
setup_SSL_CTX(ctx_->getSSLCtx());
}
/**
* Create a client AsyncSSLSocket from an already connected fd
* and allow tlsext_hostname to be sent in Client Hello.
*/
AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
EventBase* evb, int fd,
const std::string& serverName) :
AsyncSocket(evb, fd),
ctx_(ctx),
handshakeTimeout_(this, evb),
tlsextHostname_(serverName) {
setup_SSL_CTX(ctx_->getSSLCtx());
}
#endif
AsyncSSLSocket::~AsyncSSLSocket() {
VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
<< ", evb=" << eventBase_ << ", fd=" << fd_
<< ", state=" << int(state_) << ", sslState="
<< sslState_ << ", events=" << eventFlags_ << ")";
}
void AsyncSSLSocket::closeNow() {
// Close the SSL connection.
if (ssl_ != nullptr && fd_ != -1) {
int rc = SSL_shutdown(ssl_);
if (rc == 0) {
rc = SSL_shutdown(ssl_);
}
if (rc < 0) {
ERR_clear_error();
}
}
if (sslSession_ != nullptr) {
SSL_SESSION_free(sslSession_);
sslSession_ = nullptr;
}
sslState_ = STATE_CLOSED;
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
DestructorGuard dg(this);
if (handshakeCallback_) {
AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
"SSL connection closed locally");
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
if (ssl_ != nullptr) {
SSL_free(ssl_);
ssl_ = nullptr;
}
// Close the socket.
AsyncSocket::closeNow();
}
void AsyncSSLSocket::shutdownWrite() {
// SSL sockets do not support half-shutdown, so just perform a full shutdown.
//
// (Performing a full shutdown here is more desirable than doing nothing at
// all. The purpose of shutdownWrite() is normally to notify the other end
// of the connection that no more data will be sent. If we do nothing, the
// other end will never know that no more data is coming, and this may result
// in protocol deadlock.)
close();
}
void AsyncSSLSocket::shutdownWriteNow() {
closeNow();
}
bool AsyncSSLSocket::good() const {
return (AsyncSocket::good() &&
(sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
sslState_ == STATE_ESTABLISHED));
}
// The TAsyncTransport definition of 'good' states that the transport is
// ready to perform reads and writes, so sslState_ == UNINIT must report !good.
// connecting can be true when the sslState_ == UNINIT because the AsyncSocket
// is connected but we haven't initiated the call to SSL_connect.
bool AsyncSSLSocket::connecting() const {
return (!server_ &&
(AsyncSocket::connecting() ||
(AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
sslState_ == STATE_CONNECTING))));
}
bool AsyncSSLSocket::isEorTrackingEnabled() const {
const BIO *wb = SSL_get_wbio(ssl_);
return wb && wb->method == &eorAwareBioMethod;
}
void AsyncSSLSocket::setEorTracking(bool track) {
BIO *wb = SSL_get_wbio(ssl_);
if (!wb) {
throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
"setting EOR tracking without an initialized "
"BIO");
}
if (track) {
if (wb->method != &eorAwareBioMethod) {
// only do this if we didn't
wb->method = &eorAwareBioMethod;
BIO_set_app_data(wb, this);
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
}
} else if (wb->method == &eorAwareBioMethod) {
wb->method = BIO_s_socket();
BIO_set_app_data(wb, nullptr);
appEorByteNo_ = 0;
minEorRawByteNo_ = 0;
} else {
CHECK(wb->method == BIO_s_socket());
}
}
size_t AsyncSSLSocket::getRawBytesWritten() const {
BIO *b;
if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
return 0;
}
return BIO_number_written(b);
}
size_t AsyncSSLSocket::getRawBytesReceived() const {
BIO *b;
if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
return 0;
}
return BIO_number_read(b);
}
void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
<< ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
<< "events=" << eventFlags_ << ", server=" << short(server_) << "): "
<< "sslAccept/Connect() called in invalid "
<< "state, handshake callback " << handshakeCallback_ << ", new callback "
<< callback;
assert(!handshakeTimeout_.isScheduled());
sslState_ = STATE_ERROR;
AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
"sslAccept() called with socket in invalid state");
if (callback) {
callback->handshakeErr(this, ex);
}
// Check the socket state not the ssl state here.
if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
failHandshake(__func__, ex);
}
}
void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this);
assert(eventBase_->isInEventBaseThread());
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
if (!server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
return invalidState(callback);
}
sslState_ = STATE_ACCEPTING;
handshakeCallback_ = callback;
if (timeout > 0) {
handshakeTimeout_.scheduleTimeout(timeout);
}
/* register for a read operation (waiting for CLIENT HELLO) */
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
}
#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
void AsyncSSLSocket::attachSSLContext(
const std::shared_ptr<SSLContext>& ctx) {
// Check to ensure we are in client mode. Changing a server's ssl
// context doesn't make sense since clients of that server would likely
// become confused when the server's context changes.
DCHECK(!server_);
DCHECK(!ctx_);
DCHECK(ctx);
DCHECK(ctx->getSSLCtx());
ctx_ = ctx;
// In order to call attachSSLContext, detachSSLContext must have been
// previously called which sets the socket's context to the dummy
// context. Thus we must acquire this lock.
PortableSpinLockGuard guard(dummyCtxLock);
SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
}
void AsyncSSLSocket::detachSSLContext() {
DCHECK(ctx_);
ctx_.reset();
// We aren't using the initial_ctx for now, and it can introduce race
// conditions in the destructor of the SSL object.
#ifndef OPENSSL_NO_TLSEXT
if (ssl_->initial_ctx) {
SSL_CTX_free(ssl_->initial_ctx);
ssl_->initial_ctx = nullptr;
}
#endif
PortableSpinLockGuard guard(dummyCtxLock);
if (nullptr == dummyCtx) {
// We need to lazily initialize the dummy context so we don't
// accidentally override any programmatic settings to openssl
dummyCtx = new SSLContext;
}
// We must remove this socket's references to its context right now
// since this socket could get passed to any thread. If the context has
// had its locking disabled, just doing a set in attachSSLContext()
// would not be thread safe.
SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
}
#endif
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
void AsyncSSLSocket::switchServerSSLContext(
const std::shared_ptr<SSLContext>& handshakeCtx) {
CHECK(server_);
if (sslState_ != STATE_ACCEPTING) {
// We log it here and allow the switch.
// It should not affect our re-negotiation support (which
// is not supported now).
VLOG(6) << "fd=" << getFd()
<< " renegotation detected when switching SSL_CTX";
}
setup_SSL_CTX(handshakeCtx->getSSLCtx());
SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
AsyncSSLSocket::sslInfoCallback);
handshakeCtx_ = handshakeCtx;
SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
}
bool AsyncSSLSocket::isServerNameMatch() const {
CHECK(!server_);
if (!ssl_) {
return false;
}
SSL_SESSION *ss = SSL_get_session(ssl_);
if (!ss) {
return false;
}
return (ss->tlsext_hostname ? true : false);
}
void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
tlsextHostname_ = std::move(serverName);
}
#endif
void AsyncSSLSocket::timeoutExpired() noexcept {
if (state_ == StateEnum::ESTABLISHED &&
(sslState_ == STATE_CACHE_LOOKUP ||
sslState_ == STATE_RSA_ASYNC_PENDING)) {
sslState_ = STATE_ERROR;
// We are expecting a callback in restartSSLAccept. The cache lookup
// and rsa-call necessarily have pointers to this ssl socket, so delay
// the cleanup until he calls us back.
} else {
assert(state_ == StateEnum::ESTABLISHED &&
(sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
DestructorGuard dg(this);
AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
(sslState_ == STATE_CONNECTING) ?
"SSL connect timed out" : "SSL accept timed out");
failHandshake(__func__, ex);
}
}
int AsyncSSLSocket::sslExDataIndex_ = -1;
std::mutex AsyncSSLSocket::mutex_;
int AsyncSSLSocket::getSSLExDataIndex() {
if (sslExDataIndex_ < 0) {
std::lock_guard<std::mutex> g(mutex_);
if (sslExDataIndex_ < 0) {
sslExDataIndex_ = SSL_get_ex_new_index(0,
(void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
}
}
return sslExDataIndex_;
}
AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
getSSLExDataIndex()));
}
void AsyncSSLSocket::failHandshake(const char* fn,
const AsyncSocketException& ex) {
startFail();
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
if (handshakeCallback_ != nullptr) {
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeErr(this, ex);
}
finishFail();
}
void AsyncSSLSocket::invokeHandshakeCB() {
if (handshakeTimeout_.isScheduled()) {
handshakeTimeout_.cancelTimeout();
}
if (handshakeCallback_) {
HandshakeCB* callback = handshakeCallback_;
handshakeCallback_ = nullptr;
callback->handshakeSuc(this);
}
}
void AsyncSSLSocket::connect(ConnectCallback* callback,
const folly::SocketAddress& address,
int timeout,
const OptionMap &options,
const folly::SocketAddress& bindAddr)
noexcept {
assert(!server_);
assert(state_ == StateEnum::UNINIT);
assert(sslState_ == STATE_UNINIT);
AsyncSSLSocketConnector *connector =
new AsyncSSLSocketConnector(this, callback, timeout);
AsyncSocket::connect(connector, address, timeout, options, bindAddr);
}
void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
// apply the settings specified in verifyPeer_
if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
if(ctx_->needsPeerVerification()) {
SSL_set_verify(ssl, ctx_->getVerificationMode(),
AsyncSSLSocket::sslVerifyCallback);
}
} else {
if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
AsyncSSLSocket::sslVerifyCallback);
}
}
}
void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
DestructorGuard dg(this);
assert(eventBase_->isInEventBaseThread());
verifyPeer_ = verifyPeer;
// Make sure we're in the uninitialized state
if (server_ || sslState_ != STATE_UNINIT || handshakeCallback_ != nullptr) {
return invalidState(callback);
}
sslState_ = STATE_CONNECTING;
handshakeCallback_ = callback;
try {
ssl_ = ctx_->createSSL();
} catch (std::exception &e) {
sslState_ = STATE_ERROR;
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
"error calling SSLContext::createSSL()");
LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
<< fd_ << "): " << e.what();
return failHandshake(__func__, ex);
}
applyVerificationOptions(ssl_);
SSL_set_fd(ssl_, fd_);
if (sslSession_ != nullptr) {
SSL_set_session(ssl_, sslSession_);
SSL_SESSION_free(sslSession_);
sslSession_ = nullptr;
}
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
if (tlsextHostname_.size()) {
SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
}
#endif
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
if (timeout > 0) {
handshakeTimeout_.scheduleTimeout(timeout);
}
handleConnect();
}
SSL_SESSION *AsyncSSLSocket::getSSLSession() {
if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
return SSL_get1_session(ssl_);
}
return sslSession_;
}
void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
sslSession_ = session;
if (!takeOwnership && session != nullptr) {
// Increment the reference count
CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
}
}
void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
unsigned* protoLen) const {
if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
"NPN not supported");
}
}
bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
const unsigned char** protoName,
unsigned* protoLen) const {
*protoName = nullptr;
*protoLen = 0;
#ifdef OPENSSL_NPN_NEGOTIATED
SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
return true;
#else
return false;
#endif
}
bool AsyncSSLSocket::getSSLSessionReused() const {
if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
return SSL_session_reused(ssl_);
}
return false;
}
const char *AsyncSSLSocket::getNegotiatedCipherName() const {
return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
}
const char *AsyncSSLSocket::getSSLServerName() const {
#ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
: nullptr;
#else
throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
"SNI not supported");
#endif
}
const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
try {
return getSSLServerName();
} catch (AsyncSocketException& ex) {
return nullptr;
}
}
int AsyncSSLSocket::getSSLVersion() const {
return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
}
int AsyncSSLSocket::getSSLCertSize() const {
int certSize = 0;
X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
if (cert) {
EVP_PKEY *key = X509_get_pubkey(cert);
certSize = EVP_PKEY_bits(key);
EVP_PKEY_free(key);
}
return certSize;
}
bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
int error = *errorOut = SSL_get_error(ssl_, ret);
if (error == SSL_ERROR_WANT_READ) {
// Register for read event if not already.
updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
return true;
} else if (error == SSL_ERROR_WANT_WRITE) {
VLOG(3) << "AsyncSSLSocket(fd=" << fd_
<< ", state=" << int(state_) << ", sslState="
<< sslState_ << ", events=" << eventFlags_ << "): "
<< "SSL_ERROR_WANT_WRITE";
// Register for write event if not already.
updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
return true;
#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
} else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
// We will block but we can't register our own socket. The callback that
// triggered this code will re-call handleAccept at the appropriate time.
// We can only get here if the linked libssl.so has support for this feature
// as well, otherwise SSL_get_error cannot return our error code.
sslState_ = STATE_CACHE_LOOKUP;
// Unregister for all events while blocked here
updateEventRegistration(EventHandler::NONE,
EventHandler::READ | EventHandler::WRITE);
// The timeout (if set) keeps running here
return true;
#endif
#ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
} else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
// Our custom openssl function has kicked off an async request to do
// modular exponentiation. When that call returns, a callback will
// be invoked that will re-call handleAccept.
sslState_ = STATE_RSA_ASYNC_PENDING;
// Unregister for all events while blocked here
updateEventRegistration(
EventHandler::NONE,
EventHandler::READ | EventHandler::WRITE
);
// The timeout (if set) keeps running here
return true;
#endif
} else {
// SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
// in the log
long lastError = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", "
<< "events=" << std::hex << eventFlags_ << "): "
<< "SSL error: " << error << ", "
<< "errno: " << errno << ", "
<< "ret: " << ret << ", "
<< "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
<< "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
<< "func: " << ERR_func_error_string(lastError) << ", "
<< "reason: " << ERR_reason_error_string(lastError);
if (error != SSL_ERROR_SYSCALL) {
if (error == SSL_ERROR_SSL) {
*errorOut = lastError;
}
if ((unsigned long)lastError < 0x8000) {
errno = ENOSYS;
} else {
errno = lastError;
}
}
ERR_clear_error();
return false;
}
}
void AsyncSSLSocket::checkForImmediateRead() noexcept {
// openssl may have buffered data that it read from the socket already.
// In this case we have to process it immediately, rather than waiting for
// the socket to become readable again.
if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
AsyncSocket::handleRead();
}
}
void
AsyncSSLSocket::restartSSLAccept()
{
VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
<< ", state=" << int(state_) << ", "
<< "sslState=" << sslState_ << ", events=" << eventFlags_;
DestructorGuard dg(this);
assert(
sslState_ == STATE_CACHE_LOOKUP ||
sslState_ == STATE_RSA_ASYNC_PENDING ||
sslState_ == STATE_ERROR ||
sslState_ == STATE_CLOSED
);
if (sslState_ == STATE_CLOSED) {
// I sure hope whoever closed this socket didn't delete it already,
// but this is not strictly speaking an error
return;
}
if (sslState_ == STATE_ERROR) {
// go straight to fail if timeout expired during lookup
AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
"SSL accept timed out");
failHandshake(__func__, ex);
return;
}
sslState_ = STATE_ACCEPTING;
this->handleAccept();
}
void
AsyncSSLSocket::handleAccept() noexcept {
VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
<< ", fd=" << fd_ << ", state=" << int(state_) << ", "
<< "sslState=" << sslState_ << ", events=" << eventFlags_;
assert(server_);
assert(state_ == StateEnum::ESTABLISHED &&
sslState_ == STATE_ACCEPTING);
if (!ssl_) {
/* lazily create the SSL structure */
try {
ssl_ = ctx_->createSSL();
} catch (std::exception &e) {
sslState_ = STATE_ERROR;
AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
"error calling SSLContext::createSSL()");
LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
<< ", fd=" << fd_ << "): " << e.what();
return failHandshake(__func__, ex);
}
SSL_set_fd(ssl_, fd_);
SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
applyVerificationOptions(ssl_);
}
if (server_ && parseClientHello_) {
SSL_set_msg_callback_arg(ssl_, this);
SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
}
errno = 0;
int ret = SSL_accept(ssl_);
if (ret <= 0) {
int error;
if (willBlock(ret, &error)) {
return;
} else {
sslState_ = STATE_ERROR;
SSLException ex(error, errno);
return failHandshake(__func__, ex);
}
}
handshakeComplete_ = true;
updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
// Move into STATE_ESTABLISHED in the normal case that we are in
// STATE_ACCEPTING.
sslState_ = STATE_ESTABLISHED;
VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
<< " successfully accepted; state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_;
// Remember the EventBase we are attached to, before we start invoking any
// callbacks (since the callbacks may call detachEventBase()).
EventBase* originalEventBase = eventBase_;
// Call the accept callback.
invokeHandshakeCB();
// Note that the accept callback may have changed our state.
// (set or unset the read callback, called write(), closed the socket, etc.)
// The following code needs to handle these situations correctly.
//
// If the socket has been closed, readCallback_ and writeReqHead_ will
// always be nullptr, so that will prevent us from trying to read or write.
//
// The main thing to check for is if eventBase_ is still originalEventBase.
// If not, we have been detached from this event base, so we shouldn't
// perform any more operations.
if (eventBase_ != originalEventBase) {
return;
}
AsyncSocket::handleInitialReadWrite();
}
void
AsyncSSLSocket::handleConnect() noexcept {
VLOG(3) << "AsyncSSLSocket::handleConnect() this=" << this
<< ", fd=" << fd_ << ", state=" << int(state_) << ", "
<< "sslState=" << sslState_ << ", events=" << eventFlags_;
assert(!server_);
if (state_ < StateEnum::ESTABLISHED) {
return AsyncSocket::handleConnect();
}
assert(state_ == StateEnum::ESTABLISHED &&
sslState_ == STATE_CONNECTING);
assert(ssl_);
errno = 0;
int ret = SSL_connect(ssl_);
if (ret <= 0) {
int error;
if (willBlock(ret, &error)) {
return;
} else {
sslState_ = STATE_ERROR;
SSLException ex(error, errno);
return failHandshake(__func__, ex);
}
}
handshakeComplete_ = true;
updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
// Move into STATE_ESTABLISHED in the normal case that we are in
// STATE_CONNECTING.
sslState_ = STATE_ESTABLISHED;
VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
<< "state=" << int(state_) << ", sslState=" << sslState_
<< ", events=" << eventFlags_;
// Remember the EventBase we are attached to, before we start invoking any
// callbacks (since the callbacks may call detachEventBase()).
EventBase* originalEventBase = eventBase_;
// Call the handshake callback.
invokeHandshakeCB();
// Note that the connect callback may have changed our state.
// (set or unset the read callback, called write(), closed the socket, etc.)
// The following code needs to handle these situations correctly.
//
// If the socket has been closed, readCallback_ and writeReqHead_ will
// always be nullptr, so that will prevent us from trying to read or write.
//
// The main thing to check for is if eventBase_ is still originalEventBase.
// If not, we have been detached from this event base, so we shouldn't
// perform any more operations.
if (eventBase_ != originalEventBase) {
return;
}
AsyncSocket::handleInitialReadWrite();
}
void
AsyncSSLSocket::handleRead() noexcept {
VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
<< ", state=" << int(state_) << ", "
<< "sslState=" << sslState_ << ", events=" << eventFlags_;
if (state_ < StateEnum::ESTABLISHED) {
return AsyncSocket::handleRead();
}
if (sslState_ == STATE_ACCEPTING) {
assert(server_);
handleAccept();
return;
}
else if (sslState_ == STATE_CONNECTING) {
assert(!server_);
handleConnect();
return;
}
// Normal read
AsyncSocket::handleRead();
}
ssize_t
AsyncSSLSocket::performRead(void* buf, size_t buflen) {
errno = 0;
ssize_t bytes = SSL_read(ssl_, buf, buflen);
if (server_ && renegotiateAttempted_) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslstate=" << sslState_ << ", events=" << eventFlags_ << "): "
<< "client intitiated SSL renegotiation not permitted";
// We pack our own SSLerr here with a dummy function
errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
SSL_CLIENT_RENEGOTIATION_ATTEMPT);
ERR_clear_error();
return READ_ERROR;
}
if (bytes <= 0) {
int error = SSL_get_error(ssl_, bytes);
if (error == SSL_ERROR_WANT_READ) {
// The caller will register for read event if not already.
return READ_BLOCKING;
} else if (error == SSL_ERROR_WANT_WRITE) {
// TODO: Even though we are attempting to read data, SSL_read() may
// need to write data if renegotiation is being performed. We currently
// don't support this and just fail the read.
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
<< "unsupported SSL renegotiation during read",
errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
SSL_INVALID_RENEGOTIATION);
ERR_clear_error();
return READ_ERROR;
} else {
// TODO: Fix this code so that it can return a proper error message
// to the callback, rather than relying on AsyncSocket code which
// can't handle SSL errors.
long lastError = ERR_get_error();
VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
<< "state=" << state_ << ", "
<< "sslState=" << sslState_ << ", "
<< "events=" << std::hex << eventFlags_ << "): "
<< "bytes: " << bytes << ", "
<< "error: " << error << ", "
<< "errno: " << errno << ", "
<< "func: " << ERR_func_error_string(lastError) << ", "
<< "reason: " << ERR_reason_error_string(lastError);
ERR_clear_error();
if (zero_return(error, bytes)) {
return bytes;
}
if (error != SSL_ERROR_SYSCALL) {
if ((unsigned long)lastError < 0x8000) {
errno = ENOSYS;
} else {
errno = lastError;
}
}
return READ_ERROR;
}
} else {
appBytesReceived_ += bytes;
return bytes;
}
}
void AsyncSSLSocket::handleWrite() noexcept {
VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
<< ", state=" << int(state_) << ", "
<< "sslState=" << sslState_ << ", events=" << eventFlags_;
if (state_ < StateEnum::ESTABLISHED) {
return AsyncSocket::handleWrite();
}
if (sslState_ == STATE_ACCEPTING) {
assert(server_);
handleAccept();
return;
}
if (sslState_ == STATE_CONNECTING) {
assert(!server_);
handleConnect();
return;
}
// Normal write
AsyncSocket::handleWrite();
}
ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
uint32_t count,
WriteFlags flags,
uint32_t* countWritten,
uint32_t* partialWritten) {
if (sslState_ != STATE_ESTABLISHED) {
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
<< "TODO: AsyncSSLSocket currently does not support calling "
<< "write() before the handshake has fully completed";
errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
SSL_EARLY_WRITE);
return -1;
}
bool cork = isSet(flags, WriteFlags::CORK);
CorkGuard guard(fd_, count > 1, cork, &corked_);
*countWritten = 0;
*partialWritten = 0;
ssize_t totalWritten = 0;
size_t bytesStolenFromNextBuffer = 0;
for (uint32_t i = 0; i < count; i++) {
const iovec* v = vec + i;
size_t offset = bytesStolenFromNextBuffer;
bytesStolenFromNextBuffer = 0;
size_t len = v->iov_len - offset;
const void* buf;
if (len == 0) {
(*countWritten)++;
continue;
}
buf = ((const char*)v->iov_base) + offset;
ssize_t bytes;
errno = 0;
uint32_t buffersStolen = 0;
if ((len < MIN_WRITE_SIZE) && ((i + 1) < count)) {
// Combine this buffer with part or all of the next buffers in
// order to avoid really small-grained calls to SSL_write().
// Each call to SSL_write() produces a separate record in
// the egress SSL stream, and we've found that some low-end
// mobile clients can't handle receiving an HTTP response
// header and the first part of the response body in two
// separate SSL records (even if those two records are in
// the same TCP packet).
char combinedBuf[MIN_WRITE_SIZE];
memcpy(combinedBuf, buf, len);
do {
// INVARIANT: i + buffersStolen == complete chunks serialized
uint32_t nextIndex = i + buffersStolen + 1;
bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
MIN_WRITE_SIZE - len);
memcpy(combinedBuf + len, vec[nextIndex].iov_base,
bytesStolenFromNextBuffer);
len += bytesStolenFromNextBuffer;
if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
// couldn't steal the whole buffer
break;
} else {
bytesStolenFromNextBuffer = 0;
buffersStolen++;
}
} while ((i + buffersStolen + 1) < count && (len < MIN_WRITE_SIZE));
bytes = eorAwareSSLWrite(
ssl_, combinedBuf, len,
(isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
} else {
bytes = eorAwareSSLWrite(ssl_, buf, len,
(isSet(flags, WriteFlags::EOR) && i + 1 == count));
}
if (bytes <= 0) {
int error = SSL_get_error(ssl_, bytes);
if (error == SSL_ERROR_WANT_WRITE) {
// The caller will register for write event if not already.
*partialWritten = offset;
return totalWritten;
} else if (error == SSL_ERROR_WANT_READ) {
// TODO: Even though we are attempting to write data, SSL_write() may
// need to read data if renegotiation is being performed. We currently
// don't support this and just fail the write.
LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
<< "unsupported SSL renegotiation during write",
errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
SSL_INVALID_RENEGOTIATION);
ERR_clear_error();
return -1;
} else {
// TODO: Fix this code so that it can return a proper error message
// to the callback, rather than relying on AsyncSocket code which
// can't handle SSL errors.
long lastError = ERR_get_error();
VLOG(3) <<
"ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
<< ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
<< "SSL error: " << error << ", errno: " << errno
<< ", func: " << ERR_func_error_string(lastError)
<< ", reason: " << ERR_reason_error_string(lastError);
if (error != SSL_ERROR_SYSCALL) {
if ((unsigned long)lastError < 0x8000) {
errno = ENOSYS;
} else {
errno = lastError;
}
}
ERR_clear_error();
if (!zero_return(error, bytes)) {
return -1;
} // else fall through to below to correctly record totalWritten
}
}
totalWritten += bytes;
if (bytes == (ssize_t)len) {
// The full iovec is written.
(*countWritten) += 1 + buffersStolen;
i += buffersStolen;
// continue
} else {
bytes += offset; // adjust bytes to account for all of v
while (bytes >= (ssize_t)v->iov_len) {
// We combined this buf with part or all of the next one, and
// we managed to write all of this buf but not all of the bytes
// from the next one that we'd hoped to write.
bytes -= v->iov_len;
(*countWritten)++;
v = &(vec[++i]);
}
*partialWritten = bytes;
return totalWritten;
}
}
return totalWritten;
}
int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
bool eor) {
if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
if (appEorByteNo_) {
// cannot track for more than one app byte EOR
CHECK(appEorByteNo_ == appBytesWritten_ + n);
} else {
appEorByteNo_ = appBytesWritten_ + n;
}
// 1. It is fine to keep updating minEorRawByteNo_.
// 2. It is _min_ in the sense that SSL record will add some overhead.
minEorRawByteNo_ = getRawBytesWritten() + n;
}
n = sslWriteImpl(ssl, buf, n);
if (n > 0) {
appBytesWritten_ += n;
if (appEorByteNo_) {
if (getRawBytesWritten() >= minEorRawByteNo_) {
minEorRawByteNo_ = 0;
}
if(appBytesWritten_ == appEorByteNo_) {
appEorByteNo_ = 0;
} else {
CHECK(appBytesWritten_ < appEorByteNo_);
}
}
}
return n;
}
void
AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
sslSocket->renegotiateAttempted_ = true;
}
}
int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
int ret;
struct msghdr msg;
struct iovec iov;
int flags = 0;
AsyncSSLSocket *tsslSock;
iov.iov_base = const_cast<char *>(in);
iov.iov_len = inl;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
tsslSock =
reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
if (tsslSock &&
tsslSock->minEorRawByteNo_ &&
tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
flags = MSG_EOR;
}
errno = 0;
ret = sendmsg(b->num, &msg, flags);
BIO_clear_retry_flags(b);
if (ret <= 0) {
if (BIO_sock_should_retry(ret))
BIO_set_retry_write(b);
}
return(ret);
}
int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
X509_STORE_CTX* x509Ctx) {
SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
VLOG(3) << "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
<< "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
return (self->handshakeCallback_) ?
self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
preverifyOk;
}
void AsyncSSLSocket::enableClientHelloParsing() {
parseClientHello_ = true;
clientHelloInfo_.reset(new ClientHelloInfo());
}
void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl) {
SSL_set_msg_callback(ssl, nullptr);
SSL_set_msg_callback_arg(ssl, nullptr);
clientHelloInfo_->clientHelloBuf_.clear();
}
void
AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
{
AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
if (written != 0) {
sock->resetClientHelloParsing(ssl);
return;
}
if (contentType != SSL3_RT_HANDSHAKE) {
sock->resetClientHelloParsing(ssl);
return;
}
if (len == 0) {
return;
}
auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
try {
Cursor cursor(clientHelloBuf.front());
if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
sock->resetClientHelloParsing(ssl);
return;
}
if (cursor.totalLength() < 3) {
clientHelloBuf.trimEnd(len);
clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
return;
}
uint32_t messageLength = cursor.read<uint8_t>();
messageLength <<= 8;
messageLength |= cursor.read<uint8_t>();
messageLength <<= 8;
messageLength |= cursor.read<uint8_t>();
if (cursor.totalLength() < messageLength) {
clientHelloBuf.trimEnd(len);
clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
return;
}
sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
cursor.skip(4); // gmt_unix_time
cursor.skip(28); // random_bytes
cursor.skip(cursor.read<uint8_t>()); // session_id
uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
for (int i = 0; i < cipherSuitesLength; i += 2) {
sock->clientHelloInfo_->
clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
}
uint8_t compressionMethodsLength = cursor.read<uint8_t>();
for (int i = 0; i < compressionMethodsLength; ++i) {
sock->clientHelloInfo_->
clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
}
if (cursor.totalLength() > 0) {
uint16_t extensionsLength = cursor.readBE<uint16_t>();
while (extensionsLength) {
sock->clientHelloInfo_->
clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
extensionsLength -= 2;
uint16_t extensionDataLength = cursor.readBE<uint16_t>();
extensionsLength -= 2;
cursor.skip(extensionDataLength);
extensionsLength -= extensionDataLength;
}
}
} catch (std::out_of_range& e) {
// we'll use what we found and cleanup below.
VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
<< "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
}
sock->resetClientHelloParsing(ssl);
}
} // namespace
/*
* Copyright 2014 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <arpa/inet.h>
#include <iomanip>
#include <openssl/ssl.h>
#include <folly/Optional.h>
#include <folly/String.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/TimeoutManager.h>
#include <folly/Bits.h>
#include <folly/io/IOBuf.h>
#include <folly/io/Cursor.h>
using folly::io::Cursor;
using std::unique_ptr;
namespace folly {
class SSLException: public folly::AsyncSocketException {
public:
SSLException(int sslError, int errno_copy);
int getSSLError() const { return error_; }
protected:
int error_;
char msg_[256];
};
/**
* A class for performing asynchronous I/O on an SSL connection.
*
* AsyncSSLSocket allows users to asynchronously wait for data on an
* SSL connection, and to asynchronously send data.
*
* The APIs for reading and writing are intentionally asymmetric.
* Waiting for data to read is a persistent API: a callback is
* installed, and is notified whenever new data is available. It
* continues to be notified of new events until it is uninstalled.
*
* AsyncSSLSocket does not provide read timeout functionality,
* because it typically cannot determine when the timeout should be
* active. Generally, a timeout should only be enabled when
* processing is blocked waiting on data from the remote endpoint.
* For server connections, the timeout should not be active if the
* server is currently processing one or more outstanding requests for
* this connection. For client connections, the timeout should not be
* active if there are no requests pending on the connection.
* Additionally, if a client has multiple pending requests, it will
* ususally want a separate timeout for each request, rather than a
* single read timeout.
*
* The write API is fairly intuitive: a user can request to send a
* block of data, and a callback will be informed once the entire
* block has been transferred to the kernel, or on error.
* AsyncSSLSocket does provide a send timeout, since most callers
* want to give up if the remote end stops responding and no further
* progress can be made sending the data.
*/
class AsyncSSLSocket : public virtual AsyncSocket {
public:
typedef std::unique_ptr<AsyncSSLSocket, Destructor> UniquePtr;
class HandshakeCB {
public:
virtual ~HandshakeCB() {}
/**
* handshakeVer() is invoked during handshaking to give the
* application chance to validate it's peer's certificate.
*
* Note that OpenSSL performs only rudimentary internal
* consistency verification checks by itself. Any other validation
* like whether or not the certificate was issued by a trusted CA.
* The default implementation of this callback mimics what what
* OpenSSL does internally if SSL_VERIFY_PEER is set with no
* verification callback.
*
* See the passages on verify_callback in SSL_CTX_set_verify(3)
* for more details.
*/
virtual bool handshakeVer(AsyncSSLSocket* sock,
bool preverifyOk,
X509_STORE_CTX* ctx) noexcept {
return preverifyOk;
}
/**
* handshakeSuc() is called when a new SSL connection is
* established, i.e., after SSL_accept/connect() returns successfully.
*
* The HandshakeCB will be uninstalled before handshakeSuc()
* is called.
*
* @param sock SSL socket on which the handshake was initiated
*/
virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept = 0;
/**
* handshakeErr() is called if an error occurs while
* establishing the SSL connection.
*
* The HandshakeCB will be uninstalled before handshakeErr()
* is called.
*
* @param sock SSL socket on which the handshake was initiated
* @param ex An exception representing the error.
*/
virtual void handshakeErr(
AsyncSSLSocket *sock,
const AsyncSocketException& ex)
noexcept = 0;
};
class HandshakeTimeout : public AsyncTimeout {
public:
HandshakeTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
: AsyncTimeout(eventBase)
, sslSocket_(sslSocket) {}
virtual void timeoutExpired() noexcept {
sslSocket_->timeoutExpired();
}
private:
AsyncSSLSocket* sslSocket_;
};
/**
* These are passed to the application via errno, packed in an SSL err which
* are outside the valid errno range. The values are chosen to be unique
* against values in ssl.h
*/
enum SSLError {
SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900,
SSL_INVALID_RENEGOTIATION = 901,
SSL_EARLY_WRITE = 902
};
/**
* Create a client AsyncSSLSocket
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb);
/**
* Create a server/client AsyncSSLSocket from an already connected
* socket file descriptor.
*
* Note that while AsyncSSLSocket enables TCP_NODELAY for sockets it creates
* when connecting, it does not change the socket options when given an
* existing file descriptor. If callers want TCP_NODELAY enabled when using
* this version of the constructor, they need to explicitly call
* setNoDelay(true) after the constructor returns.
*
* @param ctx SSL context for this connection.
* @param evb EventBase that will manage this socket.
* @param fd File descriptor to take over (should be a connected socket).
* @param server Is socket in server mode?
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, int fd, bool server = true);
/**
* Helper function to create a server/client shared_ptr<AsyncSSLSocket>.
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb, int fd, bool server=true) {
return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb, fd, server),
Destructor());
}
/**
* Helper function to create a client shared_ptr<AsyncSSLSocket>.
*/
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb) {
return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb),
Destructor());
}
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
/**
* Create a client AsyncSSLSocket with tlsext_servername in
* the Client Hello message.
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
EventBase* evb,
const std::string& serverName);
/**
* Create a client AsyncSSLSocket from an already connected
* socket file descriptor.
*
* Note that while AsyncSSLSocket enables TCP_NODELAY for sockets it creates
* when connecting, it does not change the socket options when given an
* existing file descriptor. If callers want TCP_NODELAY enabled when using
* this version of the constructor, they need to explicitly call
* setNoDelay(true) after the constructor returns.
*
* @param ctx SSL context for this connection.
* @param evb EventBase that will manage this socket.
* @param fd File descriptor to take over (should be a connected socket).
* @param serverName tlsext_hostname that will be sent in ClientHello.
*/
AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
int fd,
const std::string& serverName);
static std::shared_ptr<AsyncSSLSocket> newSocket(
const std::shared_ptr<folly::SSLContext>& ctx,
EventBase* evb,
const std::string& serverName) {
return std::shared_ptr<AsyncSSLSocket>(
new AsyncSSLSocket(ctx, evb, serverName),
Destructor());
}
#endif
/**
* TODO: implement support for SSL renegotiation.
*
* This involves proper handling of the SSL_ERROR_WANT_READ/WRITE
* code as a result of SSL_write/read(), instead of returning an
* error. In that case, the READ/WRITE event should be registered,
* and a flag (e.g., writeBlockedOnRead) should be set to indiciate
* the condition. In the next invocation of read/write callback, if
* the flag is on, performWrite()/performRead() should be called in
* addition to the normal call to performRead()/performWrite(), and
* the flag should be reset.
*/
// Inherit TAsyncTransport methods from AsyncSocket except the
// following.
// See the documentation in TAsyncTransport.h
// TODO: implement graceful shutdown in close()
// TODO: implement detachSSL() that returns the SSL connection
virtual void closeNow();
virtual void shutdownWrite();
virtual void shutdownWriteNow();
virtual bool good() const;
virtual bool connecting() const;
bool isEorTrackingEnabled() const override;
virtual void setEorTracking(bool track);
virtual size_t getRawBytesWritten() const;
virtual size_t getRawBytesReceived() const;
void enableClientHelloParsing();
/**
* Accept an SSL connection on the socket.
*
* The callback will be invoked and uninstalled when an SSL
* connection has been established on the underlying socket.
* The value of verifyPeer determines the client verification method.
* By default, its set to use the value in the underlying context
*
* @param callback callback object to invoke on success/failure
* @param timeout timeout for this function in milliseconds, or 0 for no
* timeout
* @param verifyPeer SSLVerifyPeerEnum uses the options specified in the
* context by default, can be set explcitly to override the
* method in the context
*/
virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0,
const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
/**
* Invoke SSL accept following an asynchronous session cache lookup
*/
void restartSSLAccept();
/**
* Connect to the given address, invoking callback when complete or on error
*
* Note timeout applies to TCP + SSL connection time
*/
void connect(ConnectCallback* callback,
const folly::SocketAddress& address,
int timeout = 0,
const OptionMap &options = emptyOptionMap,
const folly::SocketAddress& bindAddr = anyAddress)
noexcept;
using AsyncSocket::connect;
/**
* Initiate an SSL connection on the socket
* THe callback will be invoked and uninstalled when an SSL connection
* has been establshed on the underlying socket.
* The verification option verifyPeer is applied if its passed explicitly.
* If its not, the options in SSLContext set on the underying SSLContext
* are applied.
*
* @param callback callback object to invoke on success/failure
* @param timeout timeout for this function in milliseconds, or 0 for no
* timeout
* @param verifyPeer SSLVerifyPeerEnum uses the options specified in the
* context by default, can be set explcitly to override the
* method in the context. If verification is turned on sets
* SSL_VERIFY_PEER and invokes
* HandshakeCB::handshakeVer().
*/
virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0,
const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
enum SSLStateEnum {
STATE_UNINIT,
STATE_ACCEPTING,
STATE_CACHE_LOOKUP,
STATE_RSA_ASYNC_PENDING,
STATE_CONNECTING,
STATE_ESTABLISHED,
STATE_REMOTE_CLOSED, /// remote end closed; we can still write
STATE_CLOSING, ///< close() called, but waiting on writes to complete
/// close() called with pending writes, before connect() has completed
STATE_CONNECTING_CLOSING,
STATE_CLOSED,
STATE_ERROR
};
SSLStateEnum getSSLState() const { return sslState_;}
/**
* Get a handle to the negotiated SSL session. This increments the session
* refcount and must be deallocated by the caller.
*/
SSL_SESSION *getSSLSession();
/**
* Set the SSL session to be used during sslConn. AsyncSSLSocket will
* hold a reference to the session until it is destroyed or released by the
* underlying SSL structure.
*
* @param takeOwnership if true, AsyncSSLSocket will assume the caller's
* reference count to session.
*/
void setSSLSession(SSL_SESSION *session, bool takeOwnership = false);
/**
* Get the name of the protocol selected by the client during
* Next Protocol Negotiation (NPN)
*
* Throw an exception if openssl does not support NPN
*
* @param protoName Name of the protocol (not guaranteed to be
* null terminated); will be set to nullptr if
* the client did not negotiate a protocol.
* Note: the AsyncSSLSocket retains ownership
* of this string.
* @param protoNameLen Length of the name.
*/
virtual void getSelectedNextProtocol(const unsigned char** protoName,
unsigned* protoLen) const;
/**
* Get the name of the protocol selected by the client during
* Next Protocol Negotiation (NPN)
*
* @param protoName Name of the protocol (not guaranteed to be
* null terminated); will be set to nullptr if
* the client did not negotiate a protocol.
* Note: the AsyncSSLSocket retains ownership
* of this string.
* @param protoNameLen Length of the name.
* @return false if openssl does not support NPN
*/
virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
unsigned* protoLen) const;
/**
* Determine if the session specified during setSSLSession was reused
* or if the server rejected it and issued a new session.
*/
bool getSSLSessionReused() const;
/**
* true if the session was resumed using session ID
*/
bool sessionIDResumed() const { return sessionIDResumed_; }
void setSessionIDResumed(bool resumed) {
sessionIDResumed_ = resumed;
}
/**
* Get the negociated cipher name for this SSL connection.
* Returns the cipher used or the constant value "NONE" when no SSL session
* has been established.
*/
const char *getNegotiatedCipherName() const;
/**
* Get the server name for this SSL connection.
* Returns the server name used or the constant value "NONE" when no SSL
* session has been established.
* If openssl has no SNI support, throw TTransportException.
*/
const char *getSSLServerName() const;
/**
* Get the server name for this SSL connection.
* Returns the server name used or the constant value "NONE" when no SSL
* session has been established.
* If openssl has no SNI support, return "NONE"
*/
const char *getSSLServerNameNoThrow() const;
/**
* Get the SSL version for this connection.
* Possible return values are SSL2_VERSION, SSL3_VERSION, TLS1_VERSION,
* with hexa representations 0x200, 0x300, 0x301,
* or 0 if no SSL session has been established.
*/
int getSSLVersion() const;
/**
* Get the certificate size used for this SSL connection.
*/
int getSSLCertSize() const;
/* Get the number of bytes read from the wire (including protocol
* overhead). Returns 0 once the connection has been closed.
*/
unsigned long getBytesRead() const {
if (ssl_ != nullptr) {
return BIO_number_read(SSL_get_rbio(ssl_));
}
return 0;
}
/* Get the number of bytes written to the wire (including protocol
* overhead). Returns 0 once the connection has been closed.
*/
unsigned long getBytesWritten() const {
if (ssl_ != nullptr) {
return BIO_number_written(SSL_get_wbio(ssl_));
}
return 0;
}
virtual void attachEventBase(EventBase* eventBase) {
AsyncSocket::attachEventBase(eventBase);
handshakeTimeout_.attachEventBase(eventBase);
}
virtual void detachEventBase() {
AsyncSocket::detachEventBase();
handshakeTimeout_.detachEventBase();
}
virtual void attachTimeoutManager(TimeoutManager* manager) {
handshakeTimeout_.attachTimeoutManager(manager);
}
virtual void detachTimeoutManager() {
handshakeTimeout_.detachTimeoutManager();
}
#if OPENSSL_VERSION_NUMBER >= 0x009080bfL
/**
* This function will set the SSL context for this socket to the
* argument. This should only be used on client SSL Sockets that have
* already called detachSSLContext();
*/
void attachSSLContext(const std::shared_ptr<folly::SSLContext>& ctx);
/**
* Detaches the SSL context for this socket.
*/
void detachSSLContext();
#endif
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
/**
* Switch the SSLContext to continue the SSL handshake.
* It can only be used in server mode.
*/
void switchServerSSLContext(
const std::shared_ptr<folly::SSLContext>& handshakeCtx);
/**
* Did server recognize/support the tlsext_hostname in Client Hello?
* It can only be used in client mode.
*
* @return true - tlsext_hostname is matched by the server
* false - tlsext_hostname is not matched or
* is not supported by server
*/
bool isServerNameMatch() const;
/**
* Set the SNI hostname that we'll advertise to the server in the
* ClientHello message.
*/
void setServerName(std::string serverName) noexcept;
#endif
void timeoutExpired() noexcept;
/**
* Get the list of supported ciphers sent by the client in the client's
* preference order.
*/
void getSSLClientCiphers(std::string& clientCiphers) {
std::stringstream ciphersStream;
std::string cipherName;
if (parseClientHello_ == false
|| clientHelloInfo_->clientHelloCipherSuites_.empty()) {
clientCiphers = "";
return;
}
for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_)
{
// OpenSSL expects code as a big endian char array
auto cipherCode = htons(originalCipherCode);
#if defined(SSL_OP_NO_TLSv1_2)
const SSL_CIPHER* cipher =
TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode);
#elif defined(SSL_OP_NO_TLSv1_1)
const SSL_CIPHER* cipher =
TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
#elif defined(SSL_OP_NO_TLSv1)
const SSL_CIPHER* cipher =
TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
#else
const SSL_CIPHER* cipher =
SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode);
#endif
if (cipher == nullptr) {
ciphersStream << std::setfill('0') << std::setw(4) << std::hex
<< originalCipherCode << ":";
} else {
ciphersStream << SSL_CIPHER_get_name(cipher) << ":";
}
}
clientCiphers = ciphersStream.str();
clientCiphers.erase(clientCiphers.end() - 1);
}
/**
* Get the list of compression methods sent by the client in TLS Hello.
*/
std::string getSSLClientComprMethods() {
if (!parseClientHello_) {
return "";
}
return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_);
}
/**
* Get the list of TLS extensions sent by the client in the TLS Hello.
*/
std::string getSSLClientExts() {
if (!parseClientHello_) {
return "";
}
return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
}
/**
* Get the list of shared ciphers between the server and the client.
* Works well for only SSLv2, not so good for SSLv3 or TLSv1.
*/
void getSSLSharedCiphers(std::string& sharedCiphers) {
char ciphersBuffer[1024];
ciphersBuffer[0] = '\0';
SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1);
sharedCiphers = ciphersBuffer;
}
/**
* Get the list of ciphers supported by the server in the server's
* preference order.
*/
void getSSLServerCiphers(std::string& serverCiphers) {
serverCiphers = SSL_get_cipher_list(ssl_, 0);
int i = 1;
const char *cipher;
while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) {
serverCiphers.append(":");
serverCiphers.append(cipher);
i++;
}
}
static int getSSLExDataIndex();
static AsyncSSLSocket* getFromSSL(const SSL *ssl);
static int eorAwareBioWrite(BIO *b, const char *in, int inl);
void resetClientHelloParsing(SSL *ssl);
static void clientHelloParsingCallback(int write_p, int version,
int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
struct ClientHelloInfo {
folly::IOBufQueue clientHelloBuf_;
uint8_t clientHelloMajorVersion_;
uint8_t clientHelloMinorVersion_;
std::vector<uint16_t> clientHelloCipherSuites_;
std::vector<uint8_t> clientHelloCompressionMethods_;
std::vector<uint16_t> clientHelloExtensions_;
};
// For unit-tests
ClientHelloInfo* getClientHelloInfo() {
return clientHelloInfo_.get();
}
protected:
/**
* Protected destructor.
*
* Users of AsyncSSLSocket must never delete it directly. Instead, invoke
* destroy() instead. (See the documentation in TDelayedDestruction.h for
* more details.)
*/
~AsyncSSLSocket();
// Inherit event notification methods from AsyncSocket except
// the following.
void handleRead() noexcept;
void handleWrite() noexcept;
void handleAccept() noexcept;
void handleConnect() noexcept;
void invalidState(HandshakeCB* callback);
bool willBlock(int ret, int *errorOut) noexcept;
virtual void checkForImmediateRead() noexcept;
// AsyncSocket calls this at the wrong time for SSL
void handleInitialReadWrite() noexcept {}
ssize_t performRead(void* buf, size_t buflen);
ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags,
uint32_t* countWritten, uint32_t* partialWritten);
// This virtual wrapper around SSL_write exists solely for testing/mockability
virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) {
return SSL_write(ssl, buf, n);
}
/**
* Apply verification options passed to sslConn/sslAccept or those set
* in the underlying SSLContext object.
*
* @param ssl pointer to the SSL object on which verification options will be
* applied. If verifyPeer_ was explicitly set either via sslConn/sslAccept,
* those options override the settings in the underlying SSLContext.
*/
void applyVerificationOptions(SSL * ssl);
/**
* A SSL_write wrapper that understand EOR
*
* @param ssl: SSL* object
* @param buf: Buffer to be written
* @param n: Number of bytes to be written
* @param eor: Does the last byte (buf[n-1]) have the app-last-byte?
* @return: The number of app bytes successfully written to the socket
*/
int eorAwareSSLWrite(SSL *ssl, const void *buf, int n, bool eor);
// Inherit error handling methods from AsyncSocket, plus the following.
void failHandshake(const char* fn, const AsyncSocketException& ex);
void invokeHandshakeCB();
static void sslInfoCallback(const SSL *ssl, int type, int val);
static std::mutex mutex_;
static int sslExDataIndex_;
// Whether we've applied the TCP_CORK option to the socket
bool corked_{false};
// SSL related members.
bool server_{false};
// Used to prevent client-initiated renegotiation. Note that AsyncSSLSocket
// doesn't fully support renegotiation, so we could just fail all attempts
// to enforce this. Once it is supported, we should make it an option
// to disable client-initiated renegotiation.
bool handshakeComplete_{false};
bool renegotiateAttempted_{false};
SSLStateEnum sslState_{STATE_UNINIT};
std::shared_ptr<folly::SSLContext> ctx_;
// Callback for SSL_accept() or SSL_connect()
HandshakeCB* handshakeCallback_{nullptr};
SSL* ssl_{nullptr};
SSL_SESSION *sslSession_{nullptr};
HandshakeTimeout handshakeTimeout_;
// whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false};
// The app byte num that we are tracking for the MSG_EOR
// Only one app EOR byte can be tracked.
size_t appEorByteNo_{0};
// When openssl is about to sendmsg() across the minEorRawBytesNo_,
// it will pass MSG_EOR to sendmsg().
size_t minEorRawByteNo_{0};
#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
std::shared_ptr<folly::SSLContext> handshakeCtx_;
std::string tlsextHostname_;
#endif
folly::SSLContext::SSLVerifyPeerEnum
verifyPeer_{folly::SSLContext::SSLVerifyPeerEnum::USE_CTX};
// Callback for SSL_CTX_set_verify()
static int sslVerifyCallback(int preverifyOk, X509_STORE_CTX* ctx);
bool parseClientHello_{false};
unique_ptr<ClientHelloInfo> clientHelloInfo_;
};
} // namespace
......@@ -175,6 +175,14 @@ class AsyncSocket::WriteRequest {
struct iovec writeOps_[]; ///< write operation(s) list
};
AsyncSocket::AsyncSocket()
: eventBase_(nullptr)
, writeTimeout_(this, nullptr)
, ioHandler_(this, nullptr) {
VLOG(5) << "new AsyncSocket(" << ")";
init();
}
AsyncSocket::AsyncSocket(EventBase* evb)
: eventBase_(evb)
, writeTimeout_(this, evb)
......
......@@ -184,6 +184,7 @@ class AsyncSocket : virtual public AsyncTransport {
noexcept = 0;
};
explicit AsyncSocket();
/**
* Create a new unconnected AsyncSocket.
*
......@@ -549,6 +550,14 @@ class AsyncSocket : virtual public AsyncTransport {
return setsockopt(fd_, level, optname, optval, sizeof(T));
}
enum class StateEnum : uint8_t {
UNINIT,
CONNECTING,
ESTABLISHED,
CLOSED,
ERROR
};
protected:
enum ReadResultEnum {
READ_EOF = 0,
......@@ -565,14 +574,6 @@ class AsyncSocket : virtual public AsyncTransport {
*/
~AsyncSocket();
enum class StateEnum : uint8_t {
UNINIT,
CONNECTING,
ESTABLISHED,
CLOSED,
ERROR
};
friend std::ostream& operator << (std::ostream& os, const StateEnum& state);
enum ShutdownFlags {
......
......@@ -44,9 +44,11 @@ AsyncTimeout::AsyncTimeout(EventBase* eventBase)
event_set(&event_, -1, EV_TIMEOUT, &AsyncTimeout::libeventCallback, this);
event_.ev_base = nullptr;
if (eventBase) {
timeoutManager_->attachTimeoutManager(
this,
TimeoutManager::InternalEnum::NORMAL);
}
RequestContext::getStaticContext();
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment