Commit fe6985dd authored by Alan Frindell's avatar Alan Frindell Committed by afrind

Move AsyncSocket tests from thrift to folly

Summary: These tests belong with the code that they test.  The old tests had a couple dependencies on TSocket/TSSLSocket, so I wrote a BlockingSocket wrapper for AsyncSocket/AsyncSSLSocket

Test Plan: Ran the tests

Reviewed By: alandau@fb.com

Subscribers: doug, net-systems@, alandau, bmatheny, mshneer, folly-diffs@, yfeldblum, chalfant

FB internal diff: D1959955

Signature: t1:1959955:1427917833:73d334846cf248f8bb215f3eb5b596df7f7cee4f
parent 773ee3cb
......@@ -162,6 +162,8 @@ nobase_follyinclude_HEADERS = \
io/async/Request.h \
io/async/SSLContext.h \
io/async/TimeoutManager.h \
io/async/test/AsyncSSLSocketTest.h \
io/async/test/BlockingSocket.h \
io/async/test/TimeUtil.h \
io/async/test/UndelayedDestruction.h \
io/async/test/Util.h \
......
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/test/AsyncSSLSocketTest.h>
#include <signal.h>
#include <pthread.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <gtest/gtest.h>
#include <iostream>
#include <list>
#include <set>
#include <unistd.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <folly/io/Cursor.h>
using std::string;
using std::vector;
using std::min;
using std::cerr;
using std::endl;
using std::list;
namespace folly {
uint32_t TestSSLAsyncCacheServer::asyncCallbacks_ = 0;
uint32_t TestSSLAsyncCacheServer::asyncLookups_ = 0;
uint32_t TestSSLAsyncCacheServer::lookupDelay_ = 0;
const char* testCert = "folly/io/async/test/certs/tests-cert.pem";
const char* testKey = "folly/io/async/test/certs/tests-key.pem";
const char* testCA = "folly/io/async/test/certs/ca-cert.pem";
TestSSLServer::TestSSLServer(SSLServerAcceptCallbackBase *acb) :
ctx_(new folly::SSLContext),
acb_(acb),
socket_(new folly::AsyncServerSocket(&evb_)) {
// Set up the SSL context
ctx_->loadCertificate(testCert);
ctx_->loadPrivateKey(testKey);
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
acb_->ctx_ = ctx_;
acb_->base_ = &evb_;
//set up the listening socket
socket_->bind(0);
socket_->getAddress(&address_);
socket_->listen(100);
socket_->addAcceptCallback(acb_, &evb_);
socket_->startAccepting();
int ret = pthread_create(&thread_, nullptr, Main, this);
assert(ret == 0);
std::cerr << "Accepting connections on " << address_ << std::endl;
}
void getfds(int fds[2]) {
if (socketpair(PF_LOCAL, SOCK_STREAM, 0, fds) != 0) {
FAIL() << "failed to create socketpair: " << strerror(errno);
}
for (int idx = 0; idx < 2; ++idx) {
int flags = fcntl(fds[idx], F_GETFL, 0);
if (flags == -1) {
FAIL() << "failed to get flags for socket " << idx << ": "
<< strerror(errno);
}
if (fcntl(fds[idx], F_SETFL, flags | O_NONBLOCK) != 0) {
FAIL() << "failed to put socket " << idx << " in non-blocking mode: "
<< strerror(errno);
}
}
}
void getctx(
std::shared_ptr<folly::SSLContext> clientCtx,
std::shared_ptr<folly::SSLContext> serverCtx) {
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadCertificate(
testCert);
serverCtx->loadPrivateKey(
testKey);
}
void sslsocketpair(
EventBase* eventBase,
AsyncSSLSocket::UniquePtr* clientSock,
AsyncSSLSocket::UniquePtr* serverSock) {
auto clientCtx = std::make_shared<folly::SSLContext>();
auto serverCtx = std::make_shared<folly::SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientSock->reset(new AsyncSSLSocket(
clientCtx, eventBase, fds[0], false));
serverSock->reset(new AsyncSSLSocket(
serverCtx, eventBase, fds[1], true));
// (*clientSock)->setSendTimeout(100);
// (*serverSock)->setSendTimeout(100);
}
/**
* Test connecting to, writing to, reading from, and closing the
* connection to the SSL server.
*/
TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
//sslContext->loadTrustedCertificates("./trusted-ca-certificate.pem");
//sslContext->authenticate(true, false);
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
socket->open();
// write()
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
socket->write(buf, sizeof(buf));
// read()
uint8_t readbuf[128];
uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
EXPECT_EQ(bytesRead, 128);
EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
// close()
socket->close();
cerr << "ConnectWriteReadClose test completed" << endl;
}
/**
* Negative test for handshakeError().
*/
TEST(AsyncSSLSocketTest, HandshakeError) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
HandshakeErrorCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
// read()
bool ex = false;
try {
socket->open();
uint8_t readbuf[128];
uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
} catch (AsyncSocketException &e) {
ex = true;
}
EXPECT_TRUE(ex);
// close()
socket->close();
cerr << "HandshakeError test completed" << endl;
}
/**
* Negative test for readError().
*/
TEST(AsyncSSLSocketTest, ReadError) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
socket->open();
// write something to trigger ssl handshake
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
socket->write(buf, sizeof(buf));
socket->close();
cerr << "ReadError test completed" << endl;
}
/**
* Negative test for writeError().
*/
TEST(AsyncSSLSocketTest, WriteError) {
// Start listening on a local port
WriteCallbackBase writeCallback;
WriteErrorCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
socket->open();
// write something to trigger ssl handshake
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
socket->write(buf, sizeof(buf));
socket->close();
cerr << "WriteError test completed" << endl;
}
/**
* Test a socket with TCP_NODELAY unset.
*/
TEST(AsyncSSLSocketTest, SocketWithDelay) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL context.
std::shared_ptr<SSLContext> sslContext(new SSLContext());
sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
// connect
auto socket = std::make_shared<BlockingSocket>(server.getAddress(),
sslContext);
socket->open();
// write()
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
socket->write(buf, sizeof(buf));
// read()
uint8_t readbuf[128];
uint32_t bytesRead = socket->readAll(readbuf, sizeof(readbuf));
EXPECT_EQ(bytesRead, 128);
EXPECT_EQ(memcmp(buf, readbuf, bytesRead), 0);
// close()
socket->close();
cerr << "SocketWithDelay test completed" << endl;
}
TEST(AsyncSSLSocketTest, NpnTestOverlap) {
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub","baz"});
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
EXPECT_EQ(selected.compare("baz"), 0);
}
TEST(AsyncSSLSocketTest, NpnTestUnset) {
// Identical to above test, except that we want unset NPN before
// looping.
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub","baz"});
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
// unsetting NPN for any of [client, server] is enought to make NPN not
// work
clientCtx->unsetNextProtocols();
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength == 0);
EXPECT_TRUE(server.nextProtoLength == 0);
EXPECT_TRUE(client.nextProto == nullptr);
EXPECT_TRUE(server.nextProto == nullptr);
}
TEST(AsyncSSLSocketTest, NpnTestNoOverlap) {
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> serverCtx(new SSLContext);;
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"blub"});
serverCtx->setAdvertisedNextProtocols({"foo","bar","baz"});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
EXPECT_EQ(selected.compare("blub"), 0);
}
TEST(AsyncSSLSocketTest, RandomizedNpnTest) {
// Probability that this test will fail is 2^-64, which could be considered
// as negligible.
const int kTries = 64;
std::set<string> selectedProtocols;
for (int i = 0; i < kTries; ++i) {
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx = std::make_shared<SSLContext>();
std::shared_ptr<SSLContext> serverCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, serverCtx);
clientCtx->setAdvertisedNextProtocols({"foo", "bar", "baz"});
serverCtx->setRandomizedAdvertisedNextProtocols({{1, {"foo"}},
{1, {"bar"}}});
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
NpnClient client(std::move(clientSock));
NpnServer server(std::move(serverSock));
eventBase.loop();
EXPECT_TRUE(client.nextProtoLength != 0);
EXPECT_EQ(client.nextProtoLength, server.nextProtoLength);
EXPECT_EQ(memcmp(client.nextProto, server.nextProto,
server.nextProtoLength), 0);
string selected((const char*)client.nextProto, client.nextProtoLength);
selectedProtocols.insert(selected);
}
EXPECT_EQ(selectedProtocols.size(), 2);
}
#ifndef OPENSSL_NO_TLSEXT
/**
* 1. Client sends TLSEXT_HOSTNAME in client hello.
* 2. Server found a match SSL_CTX and use this SSL_CTX to
* continue the SSL handshake.
* 3. Server sends back TLSEXT_HOSTNAME in server hello.
*/
TEST(AsyncSSLSocketTest, SNITestMatch) {
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
// Use the same SSLContext to continue the handshake after
// tlsext_hostname match.
std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
const std::string serverName("xyz.newdev.facebook.com");
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], serverName));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock),
dfServerCtx,
hskServerCtx,
serverName);
eventBase.loop();
EXPECT_TRUE(client.serverNameMatch);
EXPECT_TRUE(server.serverNameMatch);
}
/**
* 1. Client sends TLSEXT_HOSTNAME in client hello.
* 2. Server cannot find a matching SSL_CTX and continue to use
* the current SSL_CTX to do the handshake.
* 3. Server does not send back TLSEXT_HOSTNAME in server hello.
*/
TEST(AsyncSSLSocketTest, SNITestNotMatch) {
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
// Use the same SSLContext to continue the handshake after
// tlsext_hostname match.
std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
const std::string clientRequestingServerName("foo.com");
const std::string serverExpectedServerName("xyz.newdev.facebook.com");
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx,
&eventBase,
fds[0],
clientRequestingServerName));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock),
dfServerCtx,
hskServerCtx,
serverExpectedServerName);
eventBase.loop();
EXPECT_TRUE(!client.serverNameMatch);
EXPECT_TRUE(!server.serverNameMatch);
}
/**
* 1. Client does not send TLSEXT_HOSTNAME in client hello.
* 2. Server does not send back TLSEXT_HOSTNAME in server hello.
*/
TEST(AsyncSSLSocketTest, SNITestClientHelloNoHostname) {
EventBase eventBase;
std::shared_ptr<SSLContext> clientCtx(new SSLContext);
std::shared_ptr<SSLContext> dfServerCtx(new SSLContext);
// Use the same SSLContext to continue the handshake after
// tlsext_hostname match.
std::shared_ptr<SSLContext> hskServerCtx(dfServerCtx);
const std::string serverExpectedServerName("xyz.newdev.facebook.com");
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SNIClient client(std::move(clientSock));
SNIServer server(std::move(serverSock),
dfServerCtx,
hskServerCtx,
serverExpectedServerName);
eventBase.loop();
EXPECT_TRUE(!client.serverNameMatch);
EXPECT_TRUE(!server.serverNameMatch);
}
#endif
/**
* Test SSL client socket
*/
TEST(AsyncSSLSocketTest, SSLClientTest) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL client
EventBase eventBase;
std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
1));
client->connect();
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
EXPECT_EQ(client->getMiss(), 1);
EXPECT_EQ(client->getHit(), 0);
cerr << "SSLClientTest test completed" << endl;
}
/**
* Test SSL client socket session re-use
*/
TEST(AsyncSSLSocketTest, SSLClientTestReuse) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL client
EventBase eventBase;
std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
10));
client->connect();
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
EXPECT_EQ(client->getMiss(), 1);
EXPECT_EQ(client->getHit(), 9);
cerr << "SSLClientTestReuse test completed" << endl;
}
/**
* Test SSL client socket timeout
*/
TEST(AsyncSSLSocketTest, SSLClientTimeoutTest) {
// Start listening on a local port
EmptyReadCallback readCallback;
HandshakeCallback handshakeCallback(&readCallback,
HandshakeCallback::EXPECT_ERROR);
HandshakeTimeoutCallback acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
// Set up SSL client
EventBase eventBase;
std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
1, 10));
client->connect(true /* write before connect completes */);
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
usleep(100000);
// This is checking that the connectError callback precedes any queued
// writeError callbacks. This matches AsyncSocket's behavior
EXPECT_EQ(client->getWriteAfterConnectErrors(), 1);
EXPECT_EQ(client->getErrors(), 1);
EXPECT_EQ(client->getMiss(), 0);
EXPECT_EQ(client->getHit(), 0);
cerr << "SSLClientTimeoutTest test completed" << endl;
}
/**
* Test SSL server async cache
*/
TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTest) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
TestSSLAsyncCacheServer server(&acceptCallback);
// Set up SSL client
EventBase eventBase;
std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
10, 500));
client->connect();
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
EXPECT_EQ(server.getAsyncCallbacks(), 18);
EXPECT_EQ(server.getAsyncLookups(), 9);
EXPECT_EQ(client->getMiss(), 10);
EXPECT_EQ(client->getHit(), 0);
cerr << "SSLServerAsyncCacheTest test completed" << endl;
}
/**
* Test SSL server accept timeout with cache path
*/
TEST(AsyncSSLSocketTest, SSLServerTimeoutTest) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
EmptyReadCallback clientReadCallback;
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallback acceptCallback(&handshakeCallback, 50);
TestSSLAsyncCacheServer server(&acceptCallback);
// Set up SSL client
EventBase eventBase;
// only do a TCP connect
std::shared_ptr<AsyncSocket> sock = AsyncSocket::newSocket(&eventBase);
sock->connect(nullptr, server.getAddress());
clientReadCallback.tcpSocket_ = sock;
sock->setReadCB(&clientReadCallback);
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
EXPECT_EQ(readCallback.state, STATE_WAITING);
cerr << "SSLServerTimeoutTest test completed" << endl;
}
/**
* Test SSL server accept timeout with cache path
*/
TEST(AsyncSSLSocketTest, SSLServerAsyncCacheTimeoutTest) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback, 50);
TestSSLAsyncCacheServer server(&acceptCallback);
// Set up SSL client
EventBase eventBase;
std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
2));
client->connect();
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
EXPECT_EQ(server.getAsyncCallbacks(), 1);
EXPECT_EQ(server.getAsyncLookups(), 1);
EXPECT_EQ(client->getErrors(), 1);
EXPECT_EQ(client->getMiss(), 1);
EXPECT_EQ(client->getHit(), 0);
cerr << "SSLServerAsyncCacheTimeoutTest test completed" << endl;
}
/**
* Test SSL server accept timeout with cache path
*/
TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback,
HandshakeCallback::EXPECT_ERROR);
SSLServerAsyncCacheAcceptCallback acceptCallback(&handshakeCallback);
TestSSLAsyncCacheServer server(&acceptCallback, 500);
// Set up SSL client
EventBase eventBase;
std::shared_ptr<SSLClient> client(new SSLClient(&eventBase, server.getAddress(),
2, 100));
client->connect();
EventBaseAborter eba(&eventBase, 3000);
eventBase.loop();
server.getEventBase().runInEventBaseThread([&handshakeCallback]{
handshakeCallback.closeSocket();});
// give time for the cache lookup to come back and find it closed
usleep(500000);
EXPECT_EQ(server.getAsyncCallbacks(), 1);
EXPECT_EQ(server.getAsyncLookups(), 1);
EXPECT_EQ(client->getErrors(), 1);
EXPECT_EQ(client->getMiss(), 1);
EXPECT_EQ(client->getHit(), 0);
cerr << "SSLServerCacheCloseTest test completed" << endl;
}
/**
* Verify Client Ciphers obtained using SSL MSG Callback.
*/
TEST(AsyncSSLSocketTest, SSLParseClientHelloSuccess) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
serverCtx->ciphers("RSA:!SHA:!NULL:!SHA256@STRENGTH");
serverCtx->loadPrivateKey(testKey);
serverCtx->loadCertificate(testCert);
serverCtx->loadTrustedCertificates(testCA);
serverCtx->loadClientCAList(testCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5");
clientCtx->loadPrivateKey(testKey);
clientCtx->loadCertificate(testCert);
clientCtx->loadTrustedCertificates(testCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServerParseClientHello server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_EQ(server.clientCiphers_,
"RC4-SHA:AES128-SHA:AES256-SHA:RC4-MD5:00ff");
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
}
TEST(AsyncSSLSocketTest, SSLParseClientHelloOnePacket) {
EventBase eventBase;
auto ctx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
int bufLen = 42;
uint8_t majorVersion = 18;
uint8_t minorVersion = 25;
// Create callback buf
auto buf = IOBuf::create(bufLen);
buf->append(bufLen);
folly::io::RWPrivateCursor cursor(buf.get());
cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
cursor.write<uint16_t>(0);
cursor.write<uint8_t>(38);
cursor.write<uint8_t>(majorVersion);
cursor.write<uint8_t>(minorVersion);
cursor.skip(32);
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
// Test client hello parsing in one packet
AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, buf->data(), buf->length(), ssl, sock.get());
buf.reset();
auto parsedClientHello = sock->getClientHelloInfo();
EXPECT_TRUE(parsedClientHello != nullptr);
EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
TEST(AsyncSSLSocketTest, SSLParseClientHelloTwoPackets) {
EventBase eventBase;
auto ctx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
int bufLen = 42;
uint8_t majorVersion = 18;
uint8_t minorVersion = 25;
// Create callback buf
auto buf = IOBuf::create(bufLen);
buf->append(bufLen);
folly::io::RWPrivateCursor cursor(buf.get());
cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
cursor.write<uint16_t>(0);
cursor.write<uint8_t>(38);
cursor.write<uint8_t>(majorVersion);
cursor.write<uint8_t>(minorVersion);
cursor.skip(32);
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
// Test parsing with two packets with first packet size < 3
auto bufCopy = folly::IOBuf::copyBuffer(buf->data(), 2);
AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
ssl, sock.get());
bufCopy.reset();
bufCopy = folly::IOBuf::copyBuffer(buf->data() + 2, buf->length() - 2);
AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
ssl, sock.get());
bufCopy.reset();
auto parsedClientHello = sock->getClientHelloInfo();
EXPECT_TRUE(parsedClientHello != nullptr);
EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
TEST(AsyncSSLSocketTest, SSLParseClientHelloMultiplePackets) {
EventBase eventBase;
auto ctx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
int bufLen = 42;
uint8_t majorVersion = 18;
uint8_t minorVersion = 25;
// Create callback buf
auto buf = IOBuf::create(bufLen);
buf->append(bufLen);
folly::io::RWPrivateCursor cursor(buf.get());
cursor.write<uint8_t>(SSL3_MT_CLIENT_HELLO);
cursor.write<uint16_t>(0);
cursor.write<uint8_t>(38);
cursor.write<uint8_t>(majorVersion);
cursor.write<uint8_t>(minorVersion);
cursor.skip(32);
cursor.write<uint32_t>(0);
SSL* ssl = ctx->createSSL();
AsyncSSLSocket::UniquePtr sock(
new AsyncSSLSocket(ctx, &eventBase, fds[0], true));
sock->enableClientHelloParsing();
// Test parsing with multiple small packets
for (uint64_t i = 0; i < buf->length(); i += 3) {
auto bufCopy = folly::IOBuf::copyBuffer(
buf->data() + i, std::min((uint64_t)3, buf->length() - i));
AsyncSSLSocket::clientHelloParsingCallback(
0, 0, SSL3_RT_HANDSHAKE, bufCopy->data(), bufCopy->length(),
ssl, sock.get());
bufCopy.reset();
}
auto parsedClientHello = sock->getClientHelloInfo();
EXPECT_TRUE(parsedClientHello != nullptr);
EXPECT_EQ(parsedClientHello->clientHelloMajorVersion_, majorVersion);
EXPECT_EQ(parsedClientHello->clientHelloMinorVersion_, minorVersion);
}
/**
* Verify sucessful behavior of SSL certificate validation.
*/
TEST(AsyncSSLSocketTest, SSLHandshakeValidationSuccess) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
clientCtx->loadTrustedCertificates(testCA);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
}
/**
* Verify that the client's verification callback is able to fail SSL
* connection establishment.
*/
TEST(AsyncSSLSocketTest, SSLHandshakeValidationFailure) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, false);
clientCtx->loadTrustedCertificates(testCA);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(!client.handshakeSuccess_);
EXPECT_TRUE(client.handshakeError_);
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(!server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
/**
* Verify that the options in SSLContext can be overridden in
* sslConnect/Accept.i.e specifying that no validation should be performed
* allows an otherwise-invalid certificate to be accepted and doesn't fire
* the validation callback.
*/
TEST(AsyncSSLSocketTest, OverrideSSLCtxDisableVerify) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClientNoVerify client(std::move(clientSock), false, false);
clientCtx->loadTrustedCertificates(testCA);
SSLHandshakeServerNoVerify server(std::move(serverSock), false, false);
eventBase.loop();
EXPECT_TRUE(!client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
}
/**
* Verify that the options in SSLContext can be overridden in
* sslConnect/Accept. Enable verification even if context says otherwise.
* Test requireClientCert with client cert
*/
TEST(AsyncSSLSocketTest, OverrideSSLCtxEnableVerify) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(testKey);
serverCtx->loadCertificate(testCert);
serverCtx->loadTrustedCertificates(testCA);
serverCtx->loadClientCAList(testCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadPrivateKey(testKey);
clientCtx->loadCertificate(testCert);
clientCtx->loadTrustedCertificates(testCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClientDoVerify client(std::move(clientSock), true, true);
SSLHandshakeServerDoVerify server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
}
/**
* Verify that the client's verification callback is able to override
* the preverification failure and allow a successful connection.
*/
TEST(AsyncSSLSocketTest, SSLHandshakeValidationOverride) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
}
/**
* Verify that specifying that no validation should be performed allows an
* otherwise-invalid certificate to be accepted and doesn't fire the validation
* callback.
*/
TEST(AsyncSSLSocketTest, SSLHandshakeValidationSkip) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto dfServerCtx = std::make_shared<SSLContext>();
int fds[2];
getfds(fds);
getctx(clientCtx, dfServerCtx);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
dfServerCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServer server(std::move(serverSock), false, false);
eventBase.loop();
EXPECT_TRUE(!client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_TRUE(!client.handshakeError_);
EXPECT_TRUE(!server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_TRUE(!server.handshakeError_);
}
/**
* Test requireClientCert with client cert
*/
TEST(AsyncSSLSocketTest, ClientCertHandshakeSuccess) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(
SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(testKey);
serverCtx->loadCertificate(testCert);
serverCtx->loadTrustedCertificates(testCA);
serverCtx->loadClientCAList(testCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
clientCtx->loadPrivateKey(testKey);
clientCtx->loadCertificate(testCert);
clientCtx->loadTrustedCertificates(testCA);
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), true, true);
SSLHandshakeServer server(std::move(serverSock), true, true);
eventBase.loop();
EXPECT_TRUE(client.handshakeVerify_);
EXPECT_TRUE(client.handshakeSuccess_);
EXPECT_FALSE(client.handshakeError_);
EXPECT_TRUE(server.handshakeVerify_);
EXPECT_TRUE(server.handshakeSuccess_);
EXPECT_FALSE(server.handshakeError_);
}
/**
* Test requireClientCert with no client cert
*/
TEST(AsyncSSLSocketTest, NoClientCertHandshakeError) {
EventBase eventBase;
auto clientCtx = std::make_shared<SSLContext>();
auto serverCtx = std::make_shared<SSLContext>();
serverCtx->setVerificationOption(
SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
serverCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
serverCtx->loadPrivateKey(testKey);
serverCtx->loadCertificate(testCert);
serverCtx->loadTrustedCertificates(testCA);
serverCtx->loadClientCAList(testCA);
clientCtx->setVerificationOption(SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
clientCtx->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
int fds[2];
getfds(fds);
AsyncSSLSocket::UniquePtr clientSock(
new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
AsyncSSLSocket::UniquePtr serverSock(
new AsyncSSLSocket(serverCtx, &eventBase, fds[1], true));
SSLHandshakeClient client(std::move(clientSock), false, false);
SSLHandshakeServer server(std::move(serverSock), false, false);
eventBase.loop();
EXPECT_FALSE(server.handshakeVerify_);
EXPECT_FALSE(server.handshakeSuccess_);
EXPECT_TRUE(server.handshakeError_);
}
}
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
namespace {
struct Initializer {
Initializer() {
signal(SIGPIPE, SIG_IGN);
}
};
Initializer initializer;
} // anonymous
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <signal.h>
#include <pthread.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTransport.h>
#include <folly/io/async/EventBase.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/SocketAddress.h>
#include <gtest/gtest.h>
#include <iostream>
#include <list>
#include <unistd.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
namespace folly {
enum StateEnum {
STATE_WAITING,
STATE_SUCCEEDED,
STATE_FAILED
};
// The destructors of all callback classes assert that the state is
// STATE_SUCCEEDED, for both possitive and negative tests. The tests
// are responsible for setting the succeeded state properly before the
// destructors are called.
class WriteCallbackBase :
public AsyncTransportWrapper::WriteCallback {
public:
WriteCallbackBase()
: state(STATE_WAITING)
, bytesWritten(0)
, exception(AsyncSocketException::UNKNOWN, "none") {}
~WriteCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED);
}
void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
}
void writeSuccess() noexcept override {
std::cerr << "writeSuccess" << std::endl;
state = STATE_SUCCEEDED;
}
void writeErr(
size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
std::cerr << "writeError: bytesWritten " << bytesWritten
<< ", exception " << ex.what() << std::endl;
state = STATE_FAILED;
this->bytesWritten = bytesWritten;
exception = ex;
socket_->close();
socket_->detachEventBase();
}
std::shared_ptr<AsyncSSLSocket> socket_;
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
};
class ReadCallbackBase :
public AsyncTransportWrapper::ReadCallback {
public:
explicit ReadCallbackBase(WriteCallbackBase *wcb)
: wcb_(wcb)
, state(STATE_WAITING) {}
~ReadCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED);
}
void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
}
void setState(StateEnum s) {
state = s;
if (wcb_) {
wcb_->state = s;
}
}
void readErr(
const AsyncSocketException& ex) noexcept override {
std::cerr << "readError " << ex.what() << std::endl;
state = STATE_FAILED;
socket_->close();
socket_->detachEventBase();
}
void readEOF() noexcept override {
std::cerr << "readEOF" << std::endl;
socket_->close();
socket_->detachEventBase();
}
std::shared_ptr<AsyncSSLSocket> socket_;
WriteCallbackBase *wcb_;
StateEnum state;
};
class ReadCallback : public ReadCallbackBase {
public:
explicit ReadCallback(WriteCallbackBase *wcb)
: ReadCallbackBase(wcb)
, buffers() {}
~ReadCallback() {
for (std::vector<Buffer>::iterator it = buffers.begin();
it != buffers.end();
++it) {
it->free();
}
currentBuffer.free();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(4096);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
}
void readDataAvailable(size_t len) noexcept override {
std::cerr << "readDataAvailable, len " << len << std::endl;
currentBuffer.length = len;
wcb_->setSocket(socket_);
// Write back the same data.
socket_->write(wcb_, currentBuffer.buffer, len);
buffers.push_back(currentBuffer);
currentBuffer.reset();
state = STATE_SUCCEEDED;
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
void reset() {
buffer = nullptr;
length = 0;
}
void allocate(size_t length) {
assert(buffer == nullptr);
this->buffer = static_cast<char*>(malloc(length));
this->length = length;
}
void free() {
::free(buffer);
reset();
}
char* buffer;
size_t length;
};
std::vector<Buffer> buffers;
Buffer currentBuffer;
};
class ReadErrorCallback : public ReadCallbackBase {
public:
explicit ReadErrorCallback(WriteCallbackBase *wcb)
: ReadCallbackBase(wcb) {}
// Return nullptr buffer to trigger readError()
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = nullptr;
*lenReturn = 0;
}
void readDataAvailable(size_t len) noexcept override {
// This should never to called.
FAIL();
}
void readErr(
const AsyncSocketException& ex) noexcept override {
ReadCallbackBase::readErr(ex);
std::cerr << "ReadErrorCallback::readError" << std::endl;
setState(STATE_SUCCEEDED);
}
};
class WriteErrorCallback : public ReadCallback {
public:
explicit WriteErrorCallback(WriteCallbackBase *wcb)
: ReadCallback(wcb) {}
void readDataAvailable(size_t len) noexcept override {
std::cerr << "readDataAvailable, len " << len << std::endl;
currentBuffer.length = len;
// close the socket before writing to trigger writeError().
::close(socket_->getFd());
wcb_->setSocket(socket_);
// Write back the same data.
socket_->write(wcb_, currentBuffer.buffer, len);
if (wcb_->state == STATE_FAILED) {
setState(STATE_SUCCEEDED);
} else {
state = STATE_FAILED;
}
buffers.push_back(currentBuffer);
currentBuffer.reset();
}
void readErr(const AsyncSocketException& ex) noexcept override {
std::cerr << "readError " << ex.what() << std::endl;
// do nothing since this is expected
}
};
class EmptyReadCallback : public ReadCallback {
public:
explicit EmptyReadCallback()
: ReadCallback(nullptr) {}
void readErr(const AsyncSocketException& ex) noexcept override {
std::cerr << "readError " << ex.what() << std::endl;
state = STATE_FAILED;
tcpSocket_->close();
tcpSocket_->detachEventBase();
}
void readEOF() noexcept override {
std::cerr << "readEOF" << std::endl;
tcpSocket_->close();
tcpSocket_->detachEventBase();
state = STATE_SUCCEEDED;
}
std::shared_ptr<AsyncSocket> tcpSocket_;
};
class HandshakeCallback :
public AsyncSSLSocket::HandshakeCB {
public:
enum ExpectType {
EXPECT_SUCCESS,
EXPECT_ERROR
};
explicit HandshakeCallback(ReadCallbackBase *rcb,
ExpectType expect = EXPECT_SUCCESS):
state(STATE_WAITING),
rcb_(rcb),
expect_(expect) {}
void setSocket(
const std::shared_ptr<AsyncSSLSocket> &socket) {
socket_ = socket;
}
void setState(StateEnum s) {
state = s;
rcb_->setState(s);
}
// Functions inherited from AsyncSSLSocketHandshakeCallback
void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
EXPECT_EQ(sock, socket_.get());
std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
rcb_->setSocket(socket_);
sock->setReadCB(rcb_);
state = (expect_ == EXPECT_SUCCESS) ? STATE_SUCCEEDED : STATE_FAILED;
}
void handshakeErr(
AsyncSSLSocket *sock,
const AsyncSocketException& ex) noexcept override {
std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
if (expect_ == EXPECT_ERROR) {
// rcb will never be invoked
rcb_->setState(STATE_SUCCEEDED);
}
}
~HandshakeCallback() {
EXPECT_EQ(state, STATE_SUCCEEDED);
}
void closeSocket() {
socket_->close();
state = STATE_SUCCEEDED;
}
StateEnum state;
std::shared_ptr<AsyncSSLSocket> socket_;
ReadCallbackBase *rcb_;
ExpectType expect_;
};
class SSLServerAcceptCallbackBase:
public folly::AsyncServerSocket::AcceptCallback {
public:
explicit SSLServerAcceptCallbackBase(HandshakeCallback *hcb):
state(STATE_WAITING), hcb_(hcb) {}
~SSLServerAcceptCallbackBase() {
EXPECT_EQ(state, STATE_SUCCEEDED);
}
void acceptError(const std::exception& ex) noexcept override {
std::cerr << "SSLServerAcceptCallbackBase::acceptError "
<< ex.what() << std::endl;
state = STATE_FAILED;
}
void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
noexcept override{
printf("Connection accepted\n");
std::shared_ptr<AsyncSSLSocket> sslSock;
try {
// Create a AsyncSSLSocket object with the fd. The socket should be
// added to the event base and in the state of accepting SSL connection.
sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
} catch (const std::exception &e) {
LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
"object with socket " << e.what() << fd;
::close(fd);
acceptError(e);
return;
}
connAccepted(sslSock);
}
virtual void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
StateEnum state;
HandshakeCallback *hcb_;
std::shared_ptr<folly::SSLContext> ctx_;
folly::EventBase* base_;
};
class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
public:
uint32_t timeout_;
explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
uint32_t timeout = 0):
SSLServerAcceptCallbackBase(hcb),
timeout_(timeout) {}
virtual ~SSLServerAcceptCallback() {
if (timeout_ > 0) {
// if we set a timeout, we expect failure
EXPECT_EQ(hcb_->state, STATE_FAILED);
hcb_->setState(STATE_SUCCEEDED);
}
}
// Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
hcb_->setSocket(sock);
sock->sslAccept(hcb_, timeout_);
EXPECT_EQ(sock->getSSLState(),
AsyncSSLSocket::STATE_ACCEPTING);
state = STATE_SUCCEEDED;
}
};
class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
public:
explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
SSLServerAcceptCallback(hcb) {}
// Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
std::cerr << "SSLServerAcceptCallbackDelay::connAccepted"
<< std::endl;
int fd = sock->getFd();
#ifndef TCP_NOPUSH
{
// The accepted connection should already have TCP_NODELAY set
int value;
socklen_t valueLength = sizeof(value);
int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
EXPECT_EQ(rc, 0);
EXPECT_EQ(value, 1);
}
#endif
// Unset the TCP_NODELAY option.
int value = 0;
socklen_t valueLength = sizeof(value);
int rc = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, valueLength);
EXPECT_EQ(rc, 0);
rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
EXPECT_EQ(rc, 0);
EXPECT_EQ(value, 0);
SSLServerAcceptCallback::connAccepted(sock);
}
};
class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
public:
explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
uint32_t timeout = 0):
SSLServerAcceptCallback(hcb, timeout) {}
// Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
hcb_->setSocket(sock);
sock->sslAccept(hcb_, timeout_);
ASSERT_TRUE((sock->getSSLState() ==
AsyncSSLSocket::STATE_ACCEPTING) ||
(sock->getSSLState() ==
AsyncSSLSocket::STATE_CACHE_LOOKUP));
state = STATE_SUCCEEDED;
}
};
class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
public:
explicit HandshakeErrorCallback(HandshakeCallback *hcb):
SSLServerAcceptCallbackBase(hcb) {}
// Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
// The first call to sslAccept() should succeed.
hcb_->setSocket(sock);
sock->sslAccept(hcb_);
EXPECT_EQ(sock->getSSLState(),
AsyncSSLSocket::STATE_ACCEPTING);
// The second call to sslAccept() should fail.
HandshakeCallback callback2(hcb_->rcb_);
callback2.setSocket(sock);
sock->sslAccept(&callback2);
EXPECT_EQ(sock->getSSLState(),
AsyncSSLSocket::STATE_ERROR);
// Both callbacks should be in the error state.
EXPECT_EQ(hcb_->state, STATE_FAILED);
EXPECT_EQ(callback2.state, STATE_FAILED);
sock->detachEventBase();
state = STATE_SUCCEEDED;
hcb_->setState(STATE_SUCCEEDED);
callback2.setState(STATE_SUCCEEDED);
}
};
class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
public:
explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
SSLServerAcceptCallbackBase(hcb) {}
// Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
void connAccepted(
const std::shared_ptr<folly::AsyncSSLSocket> &s)
noexcept override {
std::cerr << "HandshakeErrorCallback::connAccepted" << std::endl;
auto sock = std::static_pointer_cast<AsyncSSLSocket>(s);
hcb_->setSocket(sock);
sock->getEventBase()->tryRunAfterDelay([=] {
std::cerr << "Delayed SSL accept, client will have close by now"
<< std::endl;
// SSL accept will fail
EXPECT_EQ(
sock->getSSLState(),
AsyncSSLSocket::STATE_UNINIT);
hcb_->socket_->sslAccept(hcb_);
// This registers for an event
EXPECT_EQ(
sock->getSSLState(),
AsyncSSLSocket::STATE_ACCEPTING);
state = STATE_SUCCEEDED;
}, 100);
}
};
class TestSSLServer {
protected:
EventBase evb_;
std::shared_ptr<folly::SSLContext> ctx_;
SSLServerAcceptCallbackBase *acb_;
folly::AsyncServerSocket *socket_;
folly::SocketAddress address_;
pthread_t thread_;
static void *Main(void *ctx) {
TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
self->evb_.loop();
std::cerr << "Server thread exited event loop" << std::endl;
return nullptr;
}
public:
// Create a TestSSLServer.
// This immediately starts listening on the given port.
explicit TestSSLServer(SSLServerAcceptCallbackBase *acb);
// Kill the thread.
~TestSSLServer() {
evb_.runInEventBaseThread([&](){
socket_->stopAccepting();
});
std::cerr << "Waiting for server thread to exit" << std::endl;
pthread_join(thread_, nullptr);
}
EventBase &getEventBase() { return evb_; }
const folly::SocketAddress& getAddress() const {
return address_;
}
};
class TestSSLAsyncCacheServer : public TestSSLServer {
public:
explicit TestSSLAsyncCacheServer(SSLServerAcceptCallbackBase *acb,
int lookupDelay = 100) :
TestSSLServer(acb) {
SSL_CTX *sslCtx = ctx_->getSSLCtx();
SSL_CTX_sess_set_get_cb(sslCtx,
TestSSLAsyncCacheServer::getSessionCallback);
SSL_CTX_set_session_cache_mode(
sslCtx, SSL_SESS_CACHE_NO_INTERNAL | SSL_SESS_CACHE_SERVER);
asyncCallbacks_ = 0;
asyncLookups_ = 0;
lookupDelay_ = lookupDelay;
}
uint32_t getAsyncCallbacks() const { return asyncCallbacks_; }
uint32_t getAsyncLookups() const { return asyncLookups_; }
private:
static uint32_t asyncCallbacks_;
static uint32_t asyncLookups_;
static uint32_t lookupDelay_;
static SSL_SESSION *getSessionCallback(SSL *ssl,
unsigned char *sess_id,
int id_len,
int *copyflag) {
*copyflag = 0;
asyncCallbacks_++;
#ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
if (!SSL_want_sess_cache_lookup(ssl)) {
// libssl.so mismatch
std::cerr << "no async support" << std::endl;
return nullptr;
}
AsyncSSLSocket *sslSocket =
AsyncSSLSocket::getFromSSL(ssl);
assert(sslSocket != nullptr);
// Going to simulate an async cache by just running delaying the miss 100ms
if (asyncCallbacks_ % 2 == 0) {
// This socket is already blocked on lookup, return miss
std::cerr << "returning miss" << std::endl;
} else {
// fresh meat - block it
std::cerr << "async lookup" << std::endl;
sslSocket->getEventBase()->tryRunAfterDelay(
std::bind(&AsyncSSLSocket::restartSSLAccept,
sslSocket), lookupDelay_);
*copyflag = SSL_SESSION_CB_WOULD_BLOCK;
asyncLookups_++;
}
#endif
return nullptr;
}
};
void getfds(int fds[2]);
void getctx(
std::shared_ptr<folly::SSLContext> clientCtx,
std::shared_ptr<folly::SSLContext> serverCtx);
void sslsocketpair(
EventBase* eventBase,
AsyncSSLSocket::UniquePtr* clientSock,
AsyncSSLSocket::UniquePtr* serverSock);
class BlockingWriteClient :
private AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::WriteCallback {
public:
explicit BlockingWriteClient(
AsyncSSLSocket::UniquePtr socket)
: socket_(std::move(socket)),
bufLen_(2500),
iovCount_(2000) {
// Fill buf_
buf_.reset(new uint8_t[bufLen_]);
for (uint32_t n = 0; n < sizeof(buf_); ++n) {
buf_[n] = n % 0xff;
}
// Initialize iov_
iov_.reset(new struct iovec[iovCount_]);
for (uint32_t n = 0; n < iovCount_; ++n) {
iov_[n].iov_base = buf_.get() + n;
if (n & 0x1) {
iov_[n].iov_len = n % bufLen_;
} else {
iov_[n].iov_len = bufLen_ - (n % bufLen_);
}
}
socket_->sslConn(this, 100);
}
struct iovec* getIovec() const {
return iov_.get();
}
uint32_t getIovecCount() const {
return iovCount_;
}
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
socket_->writev(this, iov_.get(), iovCount_);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client handshake error: " << ex.what();
}
void writeSuccess() noexcept override {
socket_->close();
}
void writeErr(
size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
<< ex.what();
}
AsyncSSLSocket::UniquePtr socket_;
uint32_t bufLen_;
uint32_t iovCount_;
std::unique_ptr<uint8_t[]> buf_;
std::unique_ptr<struct iovec[]> iov_;
};
class BlockingWriteServer :
private AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::ReadCallback {
public:
explicit BlockingWriteServer(
AsyncSSLSocket::UniquePtr socket)
: socket_(std::move(socket)),
bufSize_(2500 * 2000),
bytesRead_(0) {
buf_.reset(new uint8_t[bufSize_]);
socket_->sslAccept(this, 100);
}
void checkBuffer(struct iovec* iov, uint32_t count) const {
uint32_t idx = 0;
for (uint32_t n = 0; n < count; ++n) {
size_t bytesLeft = bytesRead_ - idx;
int rc = memcmp(buf_.get() + idx, iov[n].iov_base,
std::min(iov[n].iov_len, bytesLeft));
if (rc != 0) {
FAIL() << "buffer mismatch at iovec " << n << "/" << count
<< ": rc=" << rc;
}
if (iov[n].iov_len > bytesLeft) {
FAIL() << "server did not read enough data: "
<< "ended at byte " << bytesLeft << "/" << iov[n].iov_len
<< " in iovec " << n << "/" << count;
}
idx += iov[n].iov_len;
}
if (idx != bytesRead_) {
ADD_FAILURE() << "server read extra data: " << bytesRead_
<< " bytes read; expected " << idx;
}
}
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
// Wait 10ms before reading, so the client's writes will initially block.
socket_->getEventBase()->tryRunAfterDelay(
[this] { socket_->setReadCB(this); }, 10);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server handshake error: " << ex.what();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = buf_.get() + bytesRead_;
*lenReturn = bufSize_ - bytesRead_;
}
void readDataAvailable(size_t len) noexcept override {
bytesRead_ += len;
socket_->setReadCB(nullptr);
socket_->getEventBase()->tryRunAfterDelay(
[this] { socket_->setReadCB(this); }, 2);
}
void readEOF() noexcept override {
socket_->close();
}
void readErr(
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server read error: " << ex.what();
}
AsyncSSLSocket::UniquePtr socket_;
uint32_t bufSize_;
uint32_t bytesRead_;
std::unique_ptr<uint8_t[]> buf_;
};
class NpnClient :
private AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::WriteCallback {
public:
explicit NpnClient(
AsyncSSLSocket::UniquePtr socket)
: nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
socket_->sslConn(this);
}
const unsigned char* nextProto;
unsigned nextProtoLength;
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
socket_->getSelectedNextProtocol(&nextProto,
&nextProtoLength);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client handshake error: " << ex.what();
}
void writeSuccess() noexcept override {
socket_->close();
}
void writeErr(
size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
<< ex.what();
}
AsyncSSLSocket::UniquePtr socket_;
};
class NpnServer :
private AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::ReadCallback {
public:
explicit NpnServer(AsyncSSLSocket::UniquePtr socket)
: nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
socket_->sslAccept(this);
}
const unsigned char* nextProto;
unsigned nextProtoLength;
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
socket_->getSelectedNextProtocol(&nextProto,
&nextProtoLength);
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server handshake error: " << ex.what();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*lenReturn = 0;
}
void readDataAvailable(size_t len) noexcept override {
}
void readEOF() noexcept override {
socket_->close();
}
void readErr(
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server read error: " << ex.what();
}
AsyncSSLSocket::UniquePtr socket_;
};
#ifndef OPENSSL_NO_TLSEXT
class SNIClient :
private AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::WriteCallback {
public:
explicit SNIClient(
AsyncSSLSocket::UniquePtr socket)
: serverNameMatch(false), socket_(std::move(socket)) {
socket_->sslConn(this);
}
bool serverNameMatch;
private:
void handshakeSuc(AsyncSSLSocket*) noexcept override {
serverNameMatch = socket_->isServerNameMatch();
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client handshake error: " << ex.what();
}
void writeSuccess() noexcept override {
socket_->close();
}
void writeErr(
size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
<< ex.what();
}
AsyncSSLSocket::UniquePtr socket_;
};
class SNIServer :
private AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::ReadCallback {
public:
explicit SNIServer(
AsyncSSLSocket::UniquePtr socket,
const std::shared_ptr<folly::SSLContext>& ctx,
const std::shared_ptr<folly::SSLContext>& sniCtx,
const std::string& expectedServerName)
: serverNameMatch(false), socket_(std::move(socket)), sniCtx_(sniCtx),
expectedServerName_(expectedServerName) {
ctx->setServerNameCallback(std::bind(&SNIServer::serverNameCallback, this,
std::placeholders::_1));
socket_->sslAccept(this);
}
bool serverNameMatch;
private:
void handshakeSuc(AsyncSSLSocket* ssl) noexcept override {}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server handshake error: " << ex.what();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*lenReturn = 0;
}
void readDataAvailable(size_t len) noexcept override {
}
void readEOF() noexcept override {
socket_->close();
}
void readErr(
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "server read error: " << ex.what();
}
folly::SSLContext::ServerNameCallbackResult
serverNameCallback(SSL *ssl) {
const char *sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
if (sniCtx_ &&
sn &&
!strcasecmp(expectedServerName_.c_str(), sn)) {
AsyncSSLSocket *sslSocket =
AsyncSSLSocket::getFromSSL(ssl);
sslSocket->switchServerSSLContext(sniCtx_);
serverNameMatch = true;
return folly::SSLContext::SERVER_NAME_FOUND;
} else {
serverNameMatch = false;
return folly::SSLContext::SERVER_NAME_NOT_FOUND;
}
}
AsyncSSLSocket::UniquePtr socket_;
std::shared_ptr<folly::SSLContext> sniCtx_;
std::string expectedServerName_;
};
#endif
class SSLClient : public AsyncSocket::ConnectCallback,
public AsyncTransportWrapper::WriteCallback,
public AsyncTransportWrapper::ReadCallback
{
private:
EventBase *eventBase_;
std::shared_ptr<AsyncSSLSocket> sslSocket_;
SSL_SESSION *session_;
std::shared_ptr<folly::SSLContext> ctx_;
uint32_t requests_;
folly::SocketAddress address_;
uint32_t timeout_;
char buf_[128];
char readbuf_[128];
uint32_t bytesRead_;
uint32_t hit_;
uint32_t miss_;
uint32_t errors_;
uint32_t writeAfterConnectErrors_;
public:
SSLClient(EventBase *eventBase,
const folly::SocketAddress& address,
uint32_t requests, uint32_t timeout = 0)
: eventBase_(eventBase),
session_(nullptr),
requests_(requests),
address_(address),
timeout_(timeout),
bytesRead_(0),
hit_(0),
miss_(0),
errors_(0),
writeAfterConnectErrors_(0) {
ctx_.reset(new folly::SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET);
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
memset(buf_, 'a', sizeof(buf_));
}
~SSLClient() {
if (session_) {
SSL_SESSION_free(session_);
}
if (errors_ == 0) {
EXPECT_EQ(bytesRead_, sizeof(buf_));
}
}
uint32_t getHit() const { return hit_; }
uint32_t getMiss() const { return miss_; }
uint32_t getErrors() const { return errors_; }
uint32_t getWriteAfterConnectErrors() const {
return writeAfterConnectErrors_;
}
void connect(bool writeNow = false) {
sslSocket_ = AsyncSSLSocket::newSocket(
ctx_, eventBase_);
if (session_ != nullptr) {
sslSocket_->setSSLSession(session_);
}
requests_--;
sslSocket_->connect(this, address_, timeout_);
if (sslSocket_ && writeNow) {
// write some junk, used in an error test
sslSocket_->write(this, buf_, sizeof(buf_));
}
}
void connectSuccess() noexcept override {
std::cerr << "client SSL socket connected" << std::endl;
if (sslSocket_->getSSLSessionReused()) {
hit_++;
} else {
miss_++;
if (session_ != nullptr) {
SSL_SESSION_free(session_);
}
session_ = sslSocket_->getSSLSession();
}
// write()
sslSocket_->write(this, buf_, sizeof(buf_));
sslSocket_->setReadCB(this);
memset(readbuf_, 'b', sizeof(readbuf_));
bytesRead_ = 0;
}
void connectErr(
const AsyncSocketException& ex) noexcept override {
std::cerr << "SSLClient::connectError: " << ex.what() << std::endl;
errors_++;
sslSocket_.reset();
}
void writeSuccess() noexcept override {
std::cerr << "client write success" << std::endl;
}
void writeErr(
size_t bytesWritten,
const AsyncSocketException& ex)
noexcept override {
std::cerr << "client writeError: " << ex.what() << std::endl;
if (!sslSocket_) {
writeAfterConnectErrors_++;
}
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = readbuf_ + bytesRead_;
*lenReturn = sizeof(readbuf_) - bytesRead_;
}
void readEOF() noexcept override {
std::cerr << "client readEOF" << std::endl;
}
void readErr(
const AsyncSocketException& ex) noexcept override {
std::cerr << "client readError: " << ex.what() << std::endl;
}
void readDataAvailable(size_t len) noexcept override {
std::cerr << "client read data: " << len << std::endl;
bytesRead_ += len;
if (len == sizeof(buf_)) {
EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
sslSocket_->closeNow();
sslSocket_.reset();
if (requests_ != 0) {
connect();
}
}
}
};
class SSLHandshakeBase :
public AsyncSSLSocket::HandshakeCB,
private AsyncTransportWrapper::WriteCallback {
public:
explicit SSLHandshakeBase(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult) :
handshakeVerify_(false),
handshakeSuccess_(false),
handshakeError_(false),
socket_(std::move(socket)),
preverifyResult_(preverifyResult),
verifyResult_(verifyResult) {
}
bool handshakeVerify_;
bool handshakeSuccess_;
bool handshakeError_;
protected:
AsyncSSLSocket::UniquePtr socket_;
bool preverifyResult_;
bool verifyResult_;
// HandshakeCallback
bool handshakeVer(
AsyncSSLSocket* sock,
bool preverifyOk,
X509_STORE_CTX* ctx) noexcept override {
handshakeVerify_ = true;
EXPECT_EQ(preverifyResult_, preverifyOk);
return verifyResult_;
}
void handshakeSuc(AsyncSSLSocket*) noexcept override {
handshakeSuccess_ = true;
}
void handshakeErr(
AsyncSSLSocket*,
const AsyncSocketException& ex) noexcept override {
handshakeError_ = true;
}
// WriteCallback
void writeSuccess() noexcept override {
socket_->close();
}
void writeErr(
size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
ADD_FAILURE() << "client write error after " << bytesWritten << " bytes: "
<< ex.what();
}
};
class SSLHandshakeClient : public SSLHandshakeBase {
public:
SSLHandshakeClient(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslConn(this, 0);
}
};
class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
public:
SSLHandshakeClientNoVerify(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslConn(this, 0,
folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
}
};
class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
public:
SSLHandshakeClientDoVerify(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult) :
SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslConn(this, 0,
folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
}
};
class SSLHandshakeServer : public SSLHandshakeBase {
public:
SSLHandshakeServer(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslAccept(this, 0);
}
};
class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
public:
SSLHandshakeServerParseClientHello(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->enableClientHelloParsing();
socket_->sslAccept(this, 0);
}
std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
protected:
void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
handshakeSuccess_ = true;
sock->getSSLSharedCiphers(sharedCiphers_);
sock->getSSLServerCiphers(serverCiphers_);
sock->getSSLClientCiphers(clientCiphers_);
chosenCipher_ = sock->getNegotiatedCipherName();
}
};
class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
public:
SSLHandshakeServerNoVerify(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslAccept(this, 0,
folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
}
};
class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
public:
SSLHandshakeServerDoVerify(
AsyncSSLSocket::UniquePtr socket,
bool preverifyResult,
bool verifyResult)
: SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
socket_->sslAccept(this, 0,
folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
}
};
class EventBaseAborter : public AsyncTimeout {
public:
EventBaseAborter(EventBase* eventBase,
uint32_t timeoutMS)
: AsyncTimeout(
eventBase, AsyncTimeout::InternalEnum::INTERNAL)
, eventBase_(eventBase) {
scheduleTimeout(timeoutMS);
}
void timeoutExpired() noexcept override {
FAIL() << "test timed out";
eventBase_->terminateLoopSoon();
}
private:
EventBase* eventBase_;
};
}
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/test/AsyncSSLSocketTest.h>
#include <gtest/gtest.h>
#include <pthread.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/EventBase.h>
using std::string;
using std::vector;
using std::min;
using std::cerr;
using std::endl;
using std::list;
namespace folly {
class AttachDetachClient : public AsyncSocket::ConnectCallback,
public AsyncTransportWrapper::WriteCallback,
public AsyncTransportWrapper::ReadCallback {
private:
EventBase *eventBase_;
std::shared_ptr<AsyncSSLSocket> sslSocket_;
std::shared_ptr<SSLContext> ctx_;
folly::SocketAddress address_;
char buf_[128];
char readbuf_[128];
uint32_t bytesRead_;
public:
AttachDetachClient(EventBase *eventBase, const folly::SocketAddress& address)
: eventBase_(eventBase), address_(address), bytesRead_(0) {
ctx_.reset(new SSLContext());
ctx_->setOptions(SSL_OP_NO_TICKET);
ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
}
void connect() {
sslSocket_ = AsyncSSLSocket::newSocket(ctx_, eventBase_);
sslSocket_->connect(this, address_);
}
void connectSuccess() noexcept override {
cerr << "client SSL socket connected" << endl;
for (int i = 0; i < 1000; ++i) {
sslSocket_->detachSSLContext();
sslSocket_->attachSSLContext(ctx_);
}
EXPECT_EQ(ctx_->getSSLCtx()->references, 2);
sslSocket_->write(this, buf_, sizeof(buf_));
sslSocket_->setReadCB(this);
memset(readbuf_, 'b', sizeof(readbuf_));
bytesRead_ = 0;
}
void connectErr(const AsyncSocketException& ex) noexcept override
{
cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
sslSocket_.reset();
}
void writeSuccess() noexcept override {
cerr << "client write success" << endl;
}
void writeErr(size_t bytesWritten, const AsyncSocketException& ex)
noexcept override {
cerr << "client writeError: " << ex.what() << endl;
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = readbuf_ + bytesRead_;
*lenReturn = sizeof(readbuf_) - bytesRead_;
}
void readEOF() noexcept override {
cerr << "client readEOF" << endl;
}
void readErr(const AsyncSocketException& ex) noexcept override {
cerr << "client readError: " << ex.what() << endl;
}
void readDataAvailable(size_t len) noexcept override {
cerr << "client read data: " << len << endl;
bytesRead_ += len;
if (len == sizeof(buf_)) {
EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
sslSocket_->closeNow();
}
}
};
/**
* Test passing contexts between threads
*/
TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
// Start listening on a local port
WriteCallbackBase writeCallback;
ReadCallback readCallback(&writeCallback);
HandshakeCallback handshakeCallback(&readCallback);
SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
TestSSLServer server(&acceptCallback);
EventBase eventBase;
EventBaseAborter eba(&eventBase, 3000);
std::shared_ptr<AttachDetachClient> client(
new AttachDetachClient(&eventBase, server.getAddress()));
client->connect();
eventBase.loop();
}
}
///////////////////////////////////////////////////////////////////////////
// init_unit_test_suite
///////////////////////////////////////////////////////////////////////////
namespace {
using folly::SSLContext;
struct Initializer {
Initializer() {
signal(SIGPIPE, SIG_IGN);
SSLContext::setSSLLockTypes({
{CRYPTO_LOCK_EVP_PKEY, SSLContext::LOCK_NONE},
{CRYPTO_LOCK_SSL_SESSION, SSLContext::LOCK_SPINLOCK},
{CRYPTO_LOCK_SSL_CTX, SSLContext::LOCK_NONE}});
}
};
Initializer initializer;
} // anonymous
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Foreach.h>
#include <folly/io/Cursor.h>
#include <folly/io/async/AsyncSSLSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/EventBase.h>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <vector>
using std::string;
using namespace testing;
namespace folly {
class MockAsyncSSLSocket : public AsyncSSLSocket{
public:
static std::shared_ptr<MockAsyncSSLSocket> newSocket(
const std::shared_ptr<SSLContext>& ctx,
EventBase* evb) {
auto sock = std::shared_ptr<MockAsyncSSLSocket>(
new MockAsyncSSLSocket(ctx, evb),
Destructor());
sock->ssl_ = SSL_new(ctx->getSSLCtx());
SSL_set_fd(sock->ssl_, -1);
return sock;
}
// Fake constructor sets the state to established without call to connect
// or accept
MockAsyncSSLSocket(const std::shared_ptr<SSLContext>& ctx,
EventBase* evb)
: AsyncSocket(evb), AsyncSSLSocket(ctx, evb) {
state_ = AsyncSocket::StateEnum::ESTABLISHED;
sslState_ = AsyncSSLSocket::SSLStateEnum::STATE_ESTABLISHED;
}
// mock the calls to SSL_write to see the buffer length and contents
MOCK_METHOD3(sslWriteImpl, int(SSL *ssl, const void *buf, int n));
// mock the calls to getRawBytesWritten()
MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
// public wrapper for protected interface
ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
uint32_t* countWritten, uint32_t* partialWritten) {
return performWrite(vec, count, flags, countWritten, partialWritten);
}
void checkEor(size_t appEor, size_t rawEor) {
EXPECT_EQ(appEor, appEorByteNo_);
EXPECT_EQ(rawEor, minEorRawByteNo_);
}
void setAppBytesWritten(size_t n) {
appBytesWritten_ = n;
}
};
class AsyncSSLSocketWriteTest : public testing::Test {
public:
AsyncSSLSocketWriteTest() :
sslContext_(new SSLContext()),
sock_(MockAsyncSSLSocket::newSocket(sslContext_, &eventBase_)) {
for (int i = 0; i < 500; i++) {
memcpy(source_ + i * 26, "abcdefghijklmnopqrstuvwxyz", 26);
}
}
// Make an iovec containing chunks of the reference text with requested sizes
// for each chunk
iovec *makeVec(std::vector<uint32_t> sizes) {
iovec *vec = new iovec[sizes.size()];
int i = 0;
int pos = 0;
for (auto size: sizes) {
vec[i].iov_base = (void *)(source_ + pos);
vec[i++].iov_len = size;
pos += size;
}
return vec;
}
// Verify that the given buf/pos matches the reference text
void verifyVec(const void *buf, int n, int pos) {
ASSERT_EQ(memcmp(source_ + pos, buf, n), 0);
}
// Update a vec on partial write
void consumeVec(iovec *vec, uint32_t countWritten, uint32_t partialWritten) {
vec[countWritten].iov_base =
((char *)vec[countWritten].iov_base) + partialWritten;
vec[countWritten].iov_len -= partialWritten;
}
EventBase eventBase_;
std::shared_ptr<SSLContext> sslContext_;
std::shared_ptr<MockAsyncSSLSocket> sock_;
char source_[26 * 500];
};
// The entire vec fits in one packet
TEST_F(AsyncSSLSocketWriteTest, write_coalescing1) {
int n = 3;
iovec *vec = makeVec({3, 3, 3});
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 9))
.WillOnce(Invoke([this] (SSL *, const void *buf, int n) {
verifyVec(buf, n, 0);
return 9; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
&partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// First packet is full, second two go in one packet
TEST_F(AsyncSSLSocketWriteTest, write_coalescing2) {
int n = 3;
iovec *vec = makeVec({1500, 3, 3});
int pos = 0;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return n; }));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return n; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
&partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// Two exactly full packets (coalesce ends midway through second chunk)
TEST_F(AsyncSSLSocketWriteTest, write_coalescing3) {
int n = 3;
iovec *vec = makeVec({1000, 1000, 1000});
int pos = 0;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.Times(2)
.WillRepeatedly(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return n; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
&partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
}
// Partial write success midway through a coalesced vec
TEST_F(AsyncSSLSocketWriteTest, write_coalescing4) {
int n = 5;
iovec *vec = makeVec({300, 300, 300, 300, 300});
int pos = 0;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += 1000;
return 1000; /* 500 bytes "pending" */ }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
&partialWritten);
EXPECT_EQ(countWritten, 3);
EXPECT_EQ(partialWritten, 100);
consumeVec(vec, countWritten, partialWritten);
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return 500; }));
sock_->testPerformWrite(vec + countWritten, n - countWritten,
WriteFlags::NONE,
&countWritten, &partialWritten);
EXPECT_EQ(countWritten, 2);
EXPECT_EQ(partialWritten, 0);
}
// coalesce ends exactly on a buffer boundary
TEST_F(AsyncSSLSocketWriteTest, write_coalescing5) {
int n = 3;
iovec *vec = makeVec({1000, 500, 500});
int pos = 0;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return n; }));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 500))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return n; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
&partialWritten);
EXPECT_EQ(countWritten, 3);
EXPECT_EQ(partialWritten, 0);
}
// partial write midway through first chunk
TEST_F(AsyncSSLSocketWriteTest, write_coalescing6) {
int n = 2;
iovec *vec = makeVec({1000, 500});
int pos = 0;
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += 700;
return 700; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::NONE, &countWritten,
&partialWritten);
EXPECT_EQ(countWritten, 0);
EXPECT_EQ(partialWritten, 700);
consumeVec(vec, countWritten, partialWritten);
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 800))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
verifyVec(buf, n, pos);
pos += n;
return n; }));
sock_->testPerformWrite(vec + countWritten, n - countWritten,
WriteFlags::NONE,
&countWritten, &partialWritten);
EXPECT_EQ(countWritten, 2);
EXPECT_EQ(partialWritten, 0);
}
// Repeat coalescing2 with WriteFlags::EOR
TEST_F(AsyncSSLSocketWriteTest, write_with_eor1) {
int n = 3;
iovec *vec = makeVec({1500, 3, 3});
int pos = 0;
const size_t initAppBytesWritten = 500;
const size_t appEor = initAppBytesWritten + 1506;
sock_->setAppBytesWritten(initAppBytesWritten);
EXPECT_FALSE(sock_->isEorTrackingEnabled());
sock_->setEorTracking(true);
EXPECT_TRUE(sock_->isEorTrackingEnabled());
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
// rawBytesWritten after writting initAppBytesWritten + 1500
// + some random SSL overhead
.WillOnce(Return(3600))
// rawBytesWritten after writting last 6 bytes
// + some random SSL overhead
.WillOnce(Return(3728));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
// the first 1500 does not have the EOR byte
sock_->checkEor(0, 0);
verifyVec(buf, n, pos);
pos += n;
return n; }));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 6))
.WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
sock_->checkEor(appEor, 3600 + n);
verifyVec(buf, n, pos);
pos += n;
return n; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n , WriteFlags::EOR,
&countWritten, &partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
sock_->checkEor(0, 0);
}
// coalescing with left over at the last chunk
// WriteFlags::EOR turned on
TEST_F(AsyncSSLSocketWriteTest, write_with_eor2) {
int n = 3;
iovec *vec = makeVec({600, 600, 600});
int pos = 0;
const size_t initAppBytesWritten = 500;
const size_t appEor = initAppBytesWritten + 1800;
sock_->setAppBytesWritten(initAppBytesWritten);
sock_->setEorTracking(true);
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
// rawBytesWritten after writting initAppBytesWritten + 1500 bytes
// + some random SSL overhead
.WillOnce(Return(3600))
// rawBytesWritten after writting last 300 bytes
// + some random SSL overhead
.WillOnce(Return(4100));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1500))
.WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
// the first 1500 does not have the EOR byte
sock_->checkEor(0, 0);
verifyVec(buf, n, pos);
pos += n;
return n; }));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 300))
.WillOnce(Invoke([=, &pos] (SSL *, const void *buf, int n) {
sock_->checkEor(appEor, 3600 + n);
verifyVec(buf, n, pos);
pos += n;
return n; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::EOR,
&countWritten, &partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
sock_->checkEor(0, 0);
}
// WriteFlags::EOR set
// One buf in iovec
// Partial write at 1000-th byte
TEST_F(AsyncSSLSocketWriteTest, write_with_eor3) {
int n = 1;
iovec *vec = makeVec({1600});
int pos = 0;
const size_t initAppBytesWritten = 500;
const size_t appEor = initAppBytesWritten + 1600;
sock_->setAppBytesWritten(initAppBytesWritten);
sock_->setEorTracking(true);
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
// rawBytesWritten after the initAppBytesWritten
// + some random SSL overhead
.WillOnce(Return(2000))
// rawBytesWritten after the initAppBytesWritten + 1000 (with 100 overhead)
// + some random SSL overhead
.WillOnce(Return(3100));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 1600))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
sock_->checkEor(appEor, 2000 + n);
verifyVec(buf, n, pos);
pos += 1000;
return 1000; }));
uint32_t countWritten = 0;
uint32_t partialWritten = 0;
sock_->testPerformWrite(vec, n, WriteFlags::EOR,
&countWritten, &partialWritten);
EXPECT_EQ(countWritten, 0);
EXPECT_EQ(partialWritten, 1000);
sock_->checkEor(appEor, 2000 + 1600);
consumeVec(vec, countWritten, partialWritten);
EXPECT_CALL(*(sock_.get()), getRawBytesWritten())
.WillOnce(Return(3100))
.WillOnce(Return(3800));
EXPECT_CALL(*(sock_.get()), sslWriteImpl(_, _, 600))
.WillOnce(Invoke([this, &pos] (SSL *, const void *buf, int n) {
sock_->checkEor(appEor, 3100 + n);
verifyVec(buf, n, pos);
pos += n;
return n; }));
sock_->testPerformWrite(vec + countWritten, n - countWritten,
WriteFlags::EOR,
&countWritten, &partialWritten);
EXPECT_EQ(countWritten, n);
EXPECT_EQ(partialWritten, 0);
sock_->checkEor(0, 0);
}
}
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
#include <folly/io/async/EventBase.h>
#include <folly/SocketAddress.h>
#include <folly/io/IOBuf.h>
#include <folly/io/async/test/BlockingSocket.h>
#include <folly/io/async/test/Util.h>
#include <gtest/gtest.h>
#include <boost/scoped_array.hpp>
#include <iostream>
#include <unistd.h>
#include <fcntl.h>
#include <poll.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <thread>
using namespace boost;
using std::string;
using std::vector;
using std::min;
using std::cerr;
using std::endl;
using std::unique_ptr;
using std::chrono::milliseconds;
using boost::scoped_array;
using namespace folly;
enum StateEnum {
STATE_WAITING,
STATE_SUCCEEDED,
STATE_FAILED
};
typedef std::function<void()> VoidCallback;
class ConnCallback : public AsyncSocket::ConnectCallback {
public:
ConnCallback()
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none") {}
void connectSuccess() noexcept override {
state = STATE_SUCCEEDED;
if (successCallback) {
successCallback();
}
}
void connectErr(const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
if (errorCallback) {
errorCallback();
}
}
StateEnum state;
AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
class WriteCallback : public AsyncTransportWrapper::WriteCallback {
public:
WriteCallback()
: state(STATE_WAITING)
, bytesWritten(0)
, exception(AsyncSocketException::UNKNOWN, "none") {}
void writeSuccess() noexcept override {
state = STATE_SUCCEEDED;
if (successCallback) {
successCallback();
}
}
void writeErr(size_t bytesWritten,
const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
this->bytesWritten = bytesWritten;
exception = ex;
if (errorCallback) {
errorCallback();
}
}
StateEnum state;
size_t bytesWritten;
AsyncSocketException exception;
VoidCallback successCallback;
VoidCallback errorCallback;
};
class ReadCallback : public AsyncTransportWrapper::ReadCallback {
public:
ReadCallback()
: state(STATE_WAITING)
, exception(AsyncSocketException::UNKNOWN, "none")
, buffers() {}
~ReadCallback() {
for (vector<Buffer>::iterator it = buffers.begin();
it != buffers.end();
++it) {
it->free();
}
currentBuffer.free();
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
if (!currentBuffer.buffer) {
currentBuffer.allocate(4096);
}
*bufReturn = currentBuffer.buffer;
*lenReturn = currentBuffer.length;
}
void readDataAvailable(size_t len) noexcept override {
currentBuffer.length = len;
buffers.push_back(currentBuffer);
currentBuffer.reset();
if (dataAvailableCallback) {
dataAvailableCallback();
}
}
void readEOF() noexcept override {
state = STATE_SUCCEEDED;
}
void readErr(const AsyncSocketException& ex) noexcept override {
state = STATE_FAILED;
exception = ex;
}
void verifyData(const char* expected, size_t expectedLen) const {
size_t offset = 0;
for (size_t idx = 0; idx < buffers.size(); ++idx) {
const auto& buf = buffers[idx];
size_t cmpLen = std::min(buf.length, expectedLen - offset);
CHECK_EQ(memcmp(buf.buffer, expected + offset, cmpLen), 0);
CHECK_EQ(cmpLen, buf.length);
offset += cmpLen;
}
CHECK_EQ(offset, expectedLen);
}
class Buffer {
public:
Buffer() : buffer(nullptr), length(0) {}
Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
void reset() {
buffer = nullptr;
length = 0;
}
void allocate(size_t length) {
assert(buffer == nullptr);
this->buffer = static_cast<char*>(malloc(length));
this->length = length;
}
void free() {
::free(buffer);
reset();
}
char* buffer;
size_t length;
};
StateEnum state;
AsyncSocketException exception;
vector<Buffer> buffers;
Buffer currentBuffer;
VoidCallback dataAvailableCallback;
};
class ReadVerifier {
};
class TestServer {
public:
// Create a TestServer.
// This immediately starts listening on an ephemeral port.
TestServer()
: fd_(-1) {
fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
if (fd_ < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to create test server socket", errno);
}
if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to put test server socket in "
"non-blocking mode", errno);
}
if (listen(fd_, 10) != 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"failed to listen on test server socket",
errno);
}
address_.setFromLocalAddress(fd_);
// The local address will contain 0.0.0.0.
// Change it to 127.0.0.1, so it can be used to connect to the server
address_.setFromIpPort("127.0.0.1", address_.getPort());
}
// Get the address for connecting to the server
const folly::SocketAddress& getAddress() const {
return address_;
}
int acceptFD(int timeout=50) {
struct pollfd pfd;
pfd.fd = fd_;
pfd.events = POLLIN;
int ret = poll(&pfd, 1, timeout);
if (ret == 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() timed out");
} else if (ret < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() poll failed", errno);
}
int acceptedFd = ::accept(fd_, nullptr, nullptr);
if (acceptedFd < 0) {
throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
"test server accept() failed", errno);
}
return acceptedFd;
}
std::shared_ptr<BlockingSocket> accept(int timeout=50) {
int fd = acceptFD(timeout);
return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
}
std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
int fd = acceptFD(timeout);
return AsyncSocket::newSocket(evb, fd);
}
/**
* Accept a connection, read data from it, and verify that it matches the
* data in the specified buffer.
*/
void verifyConnection(const char* buf, size_t len) {
// accept a connection
std::shared_ptr<BlockingSocket> acceptedSocket = accept();
// read the data and compare it to the specified buffer
scoped_array<uint8_t> readbuf(new uint8_t[len]);
acceptedSocket->readAll(readbuf.get(), len);
CHECK_EQ(memcmp(buf, readbuf.get(), len), 0);
// make sure we get EOF next
uint32_t bytesRead = acceptedSocket->read(readbuf.get(), len);
CHECK_EQ(bytesRead, 0);
}
private:
int fd_;
folly::SocketAddress address_;
};
class DelayedWrite: public AsyncTimeout {
public:
DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
bool cork, bool lastWrite = false):
AsyncTimeout(socket->getEventBase()),
socket_(socket),
bufs_(std::move(bufs)),
wcb_(wcb),
cork_(cork),
lastWrite_(lastWrite) {}
private:
void timeoutExpired() noexcept override {
WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
socket_->writeChain(wcb_, std::move(bufs_), flags);
if (lastWrite_) {
socket_->shutdownWrite();
}
}
std::shared_ptr<AsyncSocket> socket_;
unique_ptr<IOBuf> bufs_;
AsyncTransportWrapper::WriteCallback* wcb_;
bool cork_;
bool lastWrite_;
};
///////////////////////////////////////////////////////////////////////////
// connect() tests
///////////////////////////////////////////////////////////////////////////
/**
* Test connecting to a server
*/
TEST(AsyncSocketTest, Connect) {
// Start listening on a local port
TestServer server;
// Connect using a AsyncSocket
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback cb;
socket->connect(&cb, server.getAddress(), 30);
evb.loop();
CHECK_EQ(cb.state, STATE_SUCCEEDED);
}
/**
* Test connecting to a server that isn't listening
*/
TEST(AsyncSocketTest, ConnectRefused) {
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
// Hopefully nothing is actually listening on this address
folly::SocketAddress addr("127.0.0.1", 65535);
ConnCallback cb;
socket->connect(&cb, addr, 30);
evb.loop();
CHECK_EQ(cb.state, STATE_FAILED);
CHECK_EQ(cb.exception.getType(), AsyncSocketException::NOT_OPEN);
}
/**
* Test connection timeout
*/
TEST(AsyncSocketTest, ConnectTimeout) {
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
// Try connecting to server that won't respond.
//
// This depends somewhat on the network where this test is run.
// Hopefully this IP will be routable but unresponsive.
// (Alternatively, we could try listening on a local raw socket, but that
// normally requires root privileges.)
folly::SocketAddress addr("8.8.8.8", 65535);
ConnCallback cb;
socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
evb.loop();
CHECK_EQ(cb.state, STATE_FAILED);
CHECK_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
// Verify that we can still get the peer address after a timeout.
// Use case is if the client was created from a client pool, and we want
// to log which peer failed.
folly::SocketAddress peer;
socket->getPeerAddress(&peer);
CHECK_EQ(peer, addr);
}
/**
* Test writing immediately after connecting, without waiting for connect
* to finish.
*/
TEST(AsyncSocketTest, ConnectAndWrite) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// write()
char buf[128];
memset(buf, 'a', sizeof(buf));
WriteCallback wcb;
socket->write(&wcb, buf, sizeof(buf));
// Loop. We don't bother accepting on the server socket yet.
// The kernel should be able to buffer the write request so it can succeed.
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
// Make sure the server got a connection and received the data
socket->close();
server.verifyConnection(buf, sizeof(buf));
}
/**
* Test connecting using a nullptr connect callback.
*/
TEST(AsyncSocketTest, ConnectNullCallback) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
socket->connect(nullptr, server.getAddress(), 30);
// write some data, just so we have some way of verifing
// that the socket works correctly after connecting
char buf[128];
memset(buf, 'a', sizeof(buf));
WriteCallback wcb;
socket->write(&wcb, buf, sizeof(buf));
evb.loop();
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
// Make sure the server got a connection and received the data
socket->close();
server.verifyConnection(buf, sizeof(buf));
}
/**
* Test calling both write() and close() immediately after connecting, without
* waiting for connect to finish.
*
* This exercises the STATE_CONNECTING_CLOSING code.
*/
TEST(AsyncSocketTest, ConnectWriteAndClose) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// write()
char buf[128];
memset(buf, 'a', sizeof(buf));
WriteCallback wcb;
socket->write(&wcb, buf, sizeof(buf));
// close()
socket->close();
// Loop. We don't bother accepting on the server socket yet.
// The kernel should be able to buffer the write request so it can succeed.
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
// Make sure the server got a connection and received the data
server.verifyConnection(buf, sizeof(buf));
}
/**
* Test calling close() immediately after connect()
*/
TEST(AsyncSocketTest, ConnectAndClose) {
TestServer server;
// Connect using a AsyncSocket
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the close-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; aborting test "
"of close-during-connect behavior";
return;
}
socket->close();
// Loop, although there shouldn't be anything to do.
evb.loop();
// Make sure the connection was aborted
CHECK_EQ(ccb.state, STATE_FAILED);
}
/**
* Test calling closeNow() immediately after connect()
*
* This should be identical to the normal close behavior.
*/
TEST(AsyncSocketTest, ConnectAndCloseNow) {
TestServer server;
// Connect using a AsyncSocket
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the close-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; aborting test "
"of closeNow()-during-connect behavior";
return;
}
socket->closeNow();
// Loop, although there shouldn't be anything to do.
evb.loop();
// Make sure the connection was aborted
CHECK_EQ(ccb.state, STATE_FAILED);
}
/**
* Test calling both write() and closeNow() immediately after connecting,
* without waiting for connect to finish.
*
* This should abort the pending write.
*/
TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the close-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; aborting test "
"of write-during-connect behavior";
return;
}
// write()
char buf[128];
memset(buf, 'a', sizeof(buf));
WriteCallback wcb;
socket->write(&wcb, buf, sizeof(buf));
// close()
socket->closeNow();
// Loop, although there shouldn't be anything to do.
evb.loop();
CHECK_EQ(ccb.state, STATE_FAILED);
CHECK_EQ(wcb.state, STATE_FAILED);
}
/**
* Test installing a read callback immediately, before connect() finishes.
*/
TEST(AsyncSocketTest, ConnectAndRead) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
ReadCallback rcb;
socket->setReadCB(&rcb);
// Even though we haven't looped yet, we should be able to accept
// the connection and send data to it.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
uint8_t buf[128];
memset(buf, 'a', sizeof(buf));
acceptedSocket->write(buf, sizeof(buf));
acceptedSocket->flush();
acceptedSocket->close();
// Loop, although there shouldn't be anything to do.
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length, sizeof(buf));
CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
}
/**
* Test installing a read callback and then closing immediately before the
* connect attempt finishes.
*/
TEST(AsyncSocketTest, ConnectReadAndClose) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the close-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; aborting test "
"of read-during-connect behavior";
return;
}
ReadCallback rcb;
socket->setReadCB(&rcb);
// close()
socket->close();
// Loop, although there shouldn't be anything to do.
evb.loop();
CHECK_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
CHECK_EQ(rcb.buffers.size(), 0);
CHECK_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
}
/**
* Test both writing and installing a read callback immediately,
* before connect() finishes.
*/
TEST(AsyncSocketTest, ConnectWriteAndRead) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// write()
char buf1[128];
memset(buf1, 'a', sizeof(buf1));
WriteCallback wcb;
socket->write(&wcb, buf1, sizeof(buf1));
// set a read callback
ReadCallback rcb;
socket->setReadCB(&rcb);
// Even though we haven't looped yet, we should be able to accept
// the connection and send data to it.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
uint8_t buf2[128];
memset(buf2, 'b', sizeof(buf2));
acceptedSocket->write(buf2, sizeof(buf2));
acceptedSocket->flush();
// shut down the write half of acceptedSocket, so that the AsyncSocket
// will stop reading and we can break out of the event loop.
shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
// Loop
evb.loop();
// Make sure the connect succeeded
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
// Make sure the AsyncSocket read the data written by the accepted socket
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length, sizeof(buf2));
CHECK_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
// Close the AsyncSocket so we'll see EOF on acceptedSocket
socket->close();
// Make sure the accepted socket saw the data written by the AsyncSocket
uint8_t readbuf[sizeof(buf1)];
acceptedSocket->readAll(readbuf, sizeof(readbuf));
CHECK_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
CHECK_EQ(bytesRead, 0);
}
/**
* Test writing to the socket then shutting down writes before the connect
* attempt finishes.
*/
TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the write-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; skipping test";
return;
}
// Ask to write some data
char wbuf[128];
memset(wbuf, 'a', sizeof(wbuf));
WriteCallback wcb;
socket->write(&wcb, wbuf, sizeof(wbuf));
socket->shutdownWrite();
// Shutdown writes
socket->shutdownWrite();
// Even though we haven't looped yet, we should be able to accept
// the connection.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
// Since the connection is still in progress, there should be no data to
// read yet. Verify that the accepted socket is not readable.
struct pollfd fds[1];
fds[0].fd = acceptedSocket->getSocketFD();
fds[0].events = POLLIN;
fds[0].revents = 0;
int rc = poll(fds, 1, 0);
CHECK_EQ(rc, 0);
// Write data to the accepted socket
uint8_t acceptedWbuf[192];
memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
acceptedSocket->flush();
// Loop
evb.loop();
// The loop should have completed the connection, written the queued data,
// and shutdown writes on the socket.
//
// Check that the connection was completed successfully and that the write
// callback succeeded.
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
// Check that we can read the data that was written to the socket, and that
// we see an EOF, since its socket was half-shutdown.
uint8_t readbuf[sizeof(wbuf)];
acceptedSocket->readAll(readbuf, sizeof(readbuf));
CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
CHECK_EQ(bytesRead, 0);
// Close the accepted socket. This will cause it to see EOF
// and uninstall the read callback when we loop next.
acceptedSocket->close();
// Install a read callback, then loop again.
ReadCallback rcb;
socket->setReadCB(&rcb);
evb.loop();
// This loop should have read the data and seen the EOF
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
CHECK_EQ(memcmp(rcb.buffers[0].buffer,
acceptedWbuf, sizeof(acceptedWbuf)), 0);
}
/**
* Test reading, writing, and shutting down writes before the connect attempt
* finishes.
*/
TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the write-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; skipping test";
return;
}
// Install a read callback
ReadCallback rcb;
socket->setReadCB(&rcb);
// Ask to write some data
char wbuf[128];
memset(wbuf, 'a', sizeof(wbuf));
WriteCallback wcb;
socket->write(&wcb, wbuf, sizeof(wbuf));
// Shutdown writes
socket->shutdownWrite();
// Even though we haven't looped yet, we should be able to accept
// the connection.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
// Since the connection is still in progress, there should be no data to
// read yet. Verify that the accepted socket is not readable.
struct pollfd fds[1];
fds[0].fd = acceptedSocket->getSocketFD();
fds[0].events = POLLIN;
fds[0].revents = 0;
int rc = poll(fds, 1, 0);
CHECK_EQ(rc, 0);
// Write data to the accepted socket
uint8_t acceptedWbuf[192];
memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
acceptedSocket->flush();
// Shutdown writes to the accepted socket. This will cause it to see EOF
// and uninstall the read callback.
::shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
// Loop
evb.loop();
// The loop should have completed the connection, written the queued data,
// shutdown writes on the socket, read the data we wrote to it, and see the
// EOF.
//
// Check that the connection was completed successfully and that the read
// and write callbacks were invoked as expected.
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
CHECK_EQ(memcmp(rcb.buffers[0].buffer,
acceptedWbuf, sizeof(acceptedWbuf)), 0);
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
// Check that we can read the data that was written to the socket, and that
// we see an EOF, since its socket was half-shutdown.
uint8_t readbuf[sizeof(wbuf)];
acceptedSocket->readAll(readbuf, sizeof(readbuf));
CHECK_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
CHECK_EQ(bytesRead, 0);
// Fully close both sockets
acceptedSocket->close();
socket->close();
}
/**
* Test reading, writing, and calling shutdownWriteNow() before the
* connect attempt finishes.
*/
TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the write-while-connecting code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; skipping test";
return;
}
// Install a read callback
ReadCallback rcb;
socket->setReadCB(&rcb);
// Ask to write some data
char wbuf[128];
memset(wbuf, 'a', sizeof(wbuf));
WriteCallback wcb;
socket->write(&wcb, wbuf, sizeof(wbuf));
// Shutdown writes immediately.
// This should immediately discard the data that we just tried to write.
socket->shutdownWriteNow();
// Verify that writeError() was invoked on the write callback.
CHECK_EQ(wcb.state, STATE_FAILED);
CHECK_EQ(wcb.bytesWritten, 0);
// Even though we haven't looped yet, we should be able to accept
// the connection.
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
// Since the connection is still in progress, there should be no data to
// read yet. Verify that the accepted socket is not readable.
struct pollfd fds[1];
fds[0].fd = acceptedSocket->getSocketFD();
fds[0].events = POLLIN;
fds[0].revents = 0;
int rc = poll(fds, 1, 0);
CHECK_EQ(rc, 0);
// Write data to the accepted socket
uint8_t acceptedWbuf[192];
memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
acceptedSocket->flush();
// Shutdown writes to the accepted socket. This will cause it to see EOF
// and uninstall the read callback.
::shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
// Loop
evb.loop();
// The loop should have completed the connection, written the queued data,
// shutdown writes on the socket, read the data we wrote to it, and see the
// EOF.
//
// Check that the connection was completed successfully and that the read
// callback was invoked as expected.
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
CHECK_EQ(memcmp(rcb.buffers[0].buffer,
acceptedWbuf, sizeof(acceptedWbuf)), 0);
// Since we used shutdownWriteNow(), it should have discarded all pending
// write data. Verify we see an immediate EOF when reading from the accepted
// socket.
uint8_t readbuf[sizeof(wbuf)];
uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
CHECK_EQ(bytesRead, 0);
// Fully close both sockets
acceptedSocket->close();
socket->close();
}
// Helper function for use in testConnectOptWrite()
// Temporarily disable the read callback
void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
// Uninstall the read callback
socket->setReadCB(nullptr);
// Schedule the read callback to be reinstalled after 1ms
socket->getEventBase()->runInLoop(
std::bind(&AsyncSocket::setReadCB, socket, rcb));
}
/**
* Test connect+write, then have the connect callback perform another write.
*
* This tests interaction of the optimistic writing after connect with
* additional write attempts that occur in the connect callback.
*/
void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
TestServer server;
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
// connect()
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Hopefully the connect didn't succeed immediately.
// If it did, we can't exercise the optimistic write code path.
if (ccb.state == STATE_SUCCEEDED) {
LOG(INFO) << "connect() succeeded immediately; aborting test "
"of optimistic write behavior";
return;
}
// Tell the connect callback to perform a write when the connect succeeds
WriteCallback wcb2;
scoped_array<char> buf2(new char[size2]);
memset(buf2.get(), 'b', size2);
if (size2 > 0) {
ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
// Tell the second write callback to close the connection when it is done
wcb2.successCallback = [&] { socket->closeNow(); };
}
// Schedule one write() immediately, before the connect finishes
scoped_array<char> buf1(new char[size1]);
memset(buf1.get(), 'a', size1);
WriteCallback wcb1;
if (size1 > 0) {
socket->write(&wcb1, buf1.get(), size1);
}
if (close) {
// immediately perform a close, before connect() completes
socket->close();
}
// Start reading from the other endpoint after 10ms.
// If we're using large buffers, we have to read so that the writes don't
// block forever.
std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcb;
rcb.dataAvailableCallback = std::bind(tmpDisableReads,
acceptedSocket.get(), &rcb);
socket->getEventBase()->tryRunAfterDelay(
std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
10);
// Loop. We don't bother accepting on the server socket yet.
// The kernel should be able to buffer the write request so it can succeed.
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
if (size1 > 0) {
CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
}
if (size2 > 0) {
CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
}
socket->close();
// Make sure the read callback received all of the data
size_t bytesRead = 0;
for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
it != rcb.buffers.end();
++it) {
size_t start = bytesRead;
bytesRead += it->length;
size_t end = bytesRead;
if (start < size1) {
size_t cmpLen = min(size1, end) - start;
CHECK_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
}
if (end > size1 && end <= size1 + size2) {
size_t itOffset;
size_t buf2Offset;
size_t cmpLen;
if (start >= size1) {
itOffset = 0;
buf2Offset = start - size1;
cmpLen = end - start;
} else {
itOffset = size1 - start;
buf2Offset = 0;
cmpLen = end - size1;
}
CHECK_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
cmpLen),
0);
}
}
CHECK_EQ(bytesRead, size1 + size2);
}
TEST(AsyncSocketTest, ConnectCallbackWrite) {
// Test using small writes that should both succeed immediately
testConnectOptWrite(100, 200);
// Test using a large buffer in the connect callback, that should block
const size_t largeSize = 8*1024*1024;
testConnectOptWrite(100, largeSize);
// Test using a large initial write
testConnectOptWrite(largeSize, 100);
// Test using two large buffers
testConnectOptWrite(largeSize, largeSize);
// Test a small write in the connect callback,
// but no immediate write before connect completes
testConnectOptWrite(0, 64);
// Test a large write in the connect callback,
// but no immediate write before connect completes
testConnectOptWrite(0, largeSize);
// Test connect, a small write, then immediately call close() before connect
// completes
testConnectOptWrite(211, 0, true);
// Test connect, a large immediate write (that will block), then immediately
// call close() before connect completes
testConnectOptWrite(largeSize, 0, true);
}
///////////////////////////////////////////////////////////////////////////
// write() related tests
///////////////////////////////////////////////////////////////////////////
/**
* Test writing using a nullptr callback
*/
TEST(AsyncSocketTest, WriteNullCallback) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket =
AsyncSocket::newSocket(&evb, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
// write() with a nullptr callback
char buf[128];
memset(buf, 'a', sizeof(buf));
socket->write(nullptr, buf, sizeof(buf));
evb.loop(); // loop until the data is sent
// Make sure the server got a connection and received the data
socket->close();
server.verifyConnection(buf, sizeof(buf));
}
/**
* Test writing with a send timeout
*/
TEST(AsyncSocketTest, WriteTimeout) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket =
AsyncSocket::newSocket(&evb, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
// write() a large chunk of data, with no-one on the other end reading
size_t writeLength = 8*1024*1024;
uint32_t timeout = 200;
socket->setSendTimeout(timeout);
scoped_array<char> buf(new char[writeLength]);
memset(buf.get(), 'a', writeLength);
WriteCallback wcb;
socket->write(&wcb, buf.get(), writeLength);
TimePoint start;
evb.loop();
TimePoint end;
// Make sure the write attempt timed out as requested
CHECK_EQ(wcb.state, STATE_FAILED);
CHECK_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
// Check that the write timed out within a reasonable period of time.
// We don't check for exactly the specified timeout, since AsyncSocket only
// times out when it hasn't made progress for that period of time.
//
// On linux, the first write sends a few hundred kb of data, then blocks for
// writability, and then unblocks again after 40ms and is able to write
// another smaller of data before blocking permanently. Therefore it doesn't
// time out until 40ms + timeout.
//
// I haven't fully verified the cause of this, but I believe it probably
// occurs because the receiving end delays sending an ack for up to 40ms.
// (This is the default value for TCP_DELACK_MIN.) Once the sender receives
// the ack, it can send some more data. However, after that point the
// receiver's kernel buffer is full. This 40ms delay happens even with
// TCP_NODELAY and TCP_QUICKACK enabled on both endpoints. However, the
// kernel may be automatically disabling TCP_QUICKACK after receiving some
// data.
//
// For now, we simply check that the timeout occurred within 160ms of
// the requested value.
T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
}
/**
* Test writing to a socket that the remote endpoint has closed
*/
TEST(AsyncSocketTest, WritePipeError) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket =
AsyncSocket::newSocket(&evb, server.getAddress(), 30);
socket->setSendTimeout(1000);
evb.loop(); // loop until the socket is connected
// accept and immediately close the socket
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
acceptedSocket.reset();
// write() a large chunk of data
size_t writeLength = 8*1024*1024;
scoped_array<char> buf(new char[writeLength]);
memset(buf.get(), 'a', writeLength);
WriteCallback wcb;
socket->write(&wcb, buf.get(), writeLength);
evb.loop();
// Make sure the write failed.
// It would be nice if AsyncSocketException could convey the errno value,
// so that we could check for EPIPE
CHECK_EQ(wcb.state, STATE_FAILED);
CHECK_EQ(wcb.exception.getType(),
AsyncSocketException::INTERNAL_ERROR);
}
/**
* Test writing a mix of simple buffers and IOBufs
*/
TEST(AsyncSocketTest, WriteIOBuf) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Accept the connection
std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcb;
acceptedSocket->setReadCB(&rcb);
// Write a simple buffer to the socket
size_t simpleBufLength = 5;
char simpleBuf[simpleBufLength];
memset(simpleBuf, 'a', simpleBufLength);
WriteCallback wcb;
socket->write(&wcb, simpleBuf, simpleBufLength);
// Write a single-element IOBuf chain
size_t buf1Length = 7;
unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
memset(buf1->writableData(), 'b', buf1Length);
buf1->append(buf1Length);
unique_ptr<IOBuf> buf1Copy(buf1->clone());
WriteCallback wcb2;
socket->writeChain(&wcb2, std::move(buf1));
// Write a multiple-element IOBuf chain
size_t buf2Length = 11;
unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
memset(buf2->writableData(), 'c', buf2Length);
buf2->append(buf2Length);
size_t buf3Length = 13;
unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
memset(buf3->writableData(), 'd', buf3Length);
buf3->append(buf3Length);
buf2->appendChain(std::move(buf3));
unique_ptr<IOBuf> buf2Copy(buf2->clone());
buf2Copy->coalesce();
WriteCallback wcb3;
socket->writeChain(&wcb3, std::move(buf2));
socket->shutdownWrite();
// Let the reads and writes run to completion
evb.loop();
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
// Make sure the reader got the right data in the right order
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 1);
CHECK_EQ(rcb.buffers[0].length,
simpleBufLength + buf1Length + buf2Length + buf3Length);
CHECK_EQ(
memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
CHECK_EQ(
memcmp(rcb.buffers[0].buffer + simpleBufLength,
buf1Copy->data(), buf1Copy->length()), 0);
CHECK_EQ(
memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
buf2Copy->data(), buf2Copy->length()), 0);
acceptedSocket->close();
socket->close();
}
TEST(AsyncSocketTest, WriteIOBufCorked) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Accept the connection
std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcb;
acceptedSocket->setReadCB(&rcb);
// Do three writes, 100ms apart, with the "cork" flag set
// on the second write. The reader should see the first write
// arrive by itself, followed by the second and third writes
// arriving together.
size_t buf1Length = 5;
unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
memset(buf1->writableData(), 'a', buf1Length);
buf1->append(buf1Length);
size_t buf2Length = 7;
unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
memset(buf2->writableData(), 'b', buf2Length);
buf2->append(buf2Length);
size_t buf3Length = 11;
unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
memset(buf3->writableData(), 'c', buf3Length);
buf3->append(buf3Length);
WriteCallback wcb1;
socket->writeChain(&wcb1, std::move(buf1));
WriteCallback wcb2;
DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
write2.scheduleTimeout(100);
WriteCallback wcb3;
DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
write3.scheduleTimeout(200);
evb.loop();
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
if (wcb3.state != STATE_SUCCEEDED) {
throw(wcb3.exception);
}
CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
// Make sure the reader got the data with the right grouping
CHECK_EQ(rcb.state, STATE_SUCCEEDED);
CHECK_EQ(rcb.buffers.size(), 2);
CHECK_EQ(rcb.buffers[0].length, buf1Length);
CHECK_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
acceptedSocket->close();
socket->close();
}
/**
* Test performing a zero-length write
*/
TEST(AsyncSocketTest, ZeroLengthWrite) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket =
AsyncSocket::newSocket(&evb, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcb;
acceptedSocket->setReadCB(&rcb);
size_t len1 = 1024*1024;
size_t len2 = 1024*1024;
std::unique_ptr<char[]> buf(new char[len1 + len2]);
memset(buf.get(), 'a', len1);
memset(buf.get(), 'b', len2);
WriteCallback wcb1;
WriteCallback wcb2;
WriteCallback wcb3;
WriteCallback wcb4;
socket->write(&wcb1, buf.get(), 0);
socket->write(&wcb2, buf.get(), len1);
socket->write(&wcb3, buf.get() + len1, 0);
socket->write(&wcb4, buf.get() + len1, len2);
socket->close();
evb.loop(); // loop until the data is sent
CHECK_EQ(wcb1.state, STATE_SUCCEEDED);
CHECK_EQ(wcb2.state, STATE_SUCCEEDED);
CHECK_EQ(wcb3.state, STATE_SUCCEEDED);
CHECK_EQ(wcb4.state, STATE_SUCCEEDED);
rcb.verifyData(buf.get(), len1 + len2);
}
TEST(AsyncSocketTest, ZeroLengthWritev) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket =
AsyncSocket::newSocket(&evb, server.getAddress(), 30);
evb.loop(); // loop until the socket is connected
auto acceptedSocket = server.acceptAsync(&evb);
ReadCallback rcb;
acceptedSocket->setReadCB(&rcb);
size_t len1 = 1024*1024;
size_t len2 = 1024*1024;
std::unique_ptr<char[]> buf(new char[len1 + len2]);
memset(buf.get(), 'a', len1);
memset(buf.get(), 'b', len2);
WriteCallback wcb;
size_t iovCount = 4;
struct iovec iov[iovCount];
iov[0].iov_base = buf.get();
iov[0].iov_len = len1;
iov[1].iov_base = buf.get() + len1;
iov[1].iov_len = 0;
iov[2].iov_base = buf.get() + len1;
iov[2].iov_len = len2;
iov[3].iov_base = buf.get() + len1 + len2;
iov[3].iov_len = 0;
socket->writev(&wcb, iov, iovCount);
socket->close();
evb.loop(); // loop until the data is sent
CHECK_EQ(wcb.state, STATE_SUCCEEDED);
rcb.verifyData(buf.get(), len1 + len2);
}
///////////////////////////////////////////////////////////////////////////
// close() related tests
///////////////////////////////////////////////////////////////////////////
/**
* Test calling close() with pending writes when the socket is already closing.
*/
TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
TestServer server;
// connect()
EventBase evb;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// accept the socket on the server side
std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
// Loop to ensure the connect has completed
evb.loop();
// Make sure we are connected
CHECK_EQ(ccb.state, STATE_SUCCEEDED);
// Schedule pending writes, until several write attempts have blocked
char buf[128];
memset(buf, 'a', sizeof(buf));
typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
WriteCallbackVector writeCallbacks;
writeCallbacks.reserve(5);
while (writeCallbacks.size() < 5) {
std::shared_ptr<WriteCallback> wcb(new WriteCallback);
socket->write(wcb.get(), buf, sizeof(buf));
if (wcb->state == STATE_SUCCEEDED) {
// Succeeded immediately. Keep performing more writes
continue;
}
// This write is blocked.
// Have the write callback call close() when writeError() is invoked
wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
writeCallbacks.push_back(wcb);
}
// Call closeNow() to immediately fail the pending writes
socket->closeNow();
// Make sure writeError() was invoked on all of the pending write callbacks
for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
it != writeCallbacks.end();
++it) {
CHECK_EQ((*it)->state, STATE_FAILED);
}
}
// TODO:
// - Test connect() and have the connect callback set the read callback
// - Test connect() and have the connect callback unset the read callback
// - Test reading/writing/closing/destroying the socket in the connect callback
// - Test reading/writing/closing/destroying the socket in the read callback
// - Test reading/writing/closing/destroying the socket in the write callback
// - Test one-way shutdown behavior
// - Test changing the EventBase
//
// - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
// in connectSuccess(), readDataAvailable(), writeSuccess()
///////////////////////////////////////////////////////////////////////////
// AsyncServerSocket tests
///////////////////////////////////////////////////////////////////////////
/**
* Helper AcceptCallback class for the test code
* It records the callbacks that were invoked, and also supports calling
* generic std::function objects in each callback.
*/
class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
public:
enum EventType {
TYPE_START,
TYPE_ACCEPT,
TYPE_ERROR,
TYPE_STOP
};
struct EventInfo {
EventInfo(int fd, const folly::SocketAddress& addr)
: type(TYPE_ACCEPT),
fd(fd),
address(addr),
errorMsg() {}
explicit EventInfo(const std::string& msg)
: type(TYPE_ERROR),
fd(-1),
address(),
errorMsg(msg) {}
explicit EventInfo(EventType et)
: type(et),
fd(-1),
address(),
errorMsg() {}
EventType type;
int fd; // valid for TYPE_ACCEPT
folly::SocketAddress address; // valid for TYPE_ACCEPT
string errorMsg; // valid for TYPE_ERROR
};
typedef std::deque<EventInfo> EventList;
TestAcceptCallback()
: connectionAcceptedFn_(),
acceptErrorFn_(),
acceptStoppedFn_(),
events_() {}
std::deque<EventInfo>* getEvents() {
return &events_;
}
void setConnectionAcceptedFn(
const std::function<void(int, const folly::SocketAddress&)>& fn) {
connectionAcceptedFn_ = fn;
}
void setAcceptErrorFn(const std::function<void(const std::exception&)>& fn) {
acceptErrorFn_ = fn;
}
void setAcceptStartedFn(const std::function<void()>& fn) {
acceptStartedFn_ = fn;
}
void setAcceptStoppedFn(const std::function<void()>& fn) {
acceptStoppedFn_ = fn;
}
void connectionAccepted(int fd, const folly::SocketAddress& clientAddr)
noexcept {
events_.push_back(EventInfo(fd, clientAddr));
if (connectionAcceptedFn_) {
connectionAcceptedFn_(fd, clientAddr);
}
}
void acceptError(const std::exception& ex) noexcept {
events_.push_back(EventInfo(ex.what()));
if (acceptErrorFn_) {
acceptErrorFn_(ex);
}
}
void acceptStarted() noexcept {
events_.push_back(EventInfo(TYPE_START));
if (acceptStartedFn_) {
acceptStartedFn_();
}
}
void acceptStopped() noexcept {
events_.push_back(EventInfo(TYPE_STOP));
if (acceptStoppedFn_) {
acceptStoppedFn_();
}
}
private:
std::function<void(int, const folly::SocketAddress&)> connectionAcceptedFn_;
std::function<void(const std::exception&)> acceptErrorFn_;
std::function<void()> acceptStartedFn_;
std::function<void()> acceptStoppedFn_;
std::deque<EventInfo> events_;
};
/**
* Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
*/
TEST(AsyncSocketTest, ServerAcceptOptions) {
EventBase eventBase;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
// Verify that the server accepted a connection
CHECK_EQ(acceptCallback.getEvents()->size(), 3);
CHECK_EQ(acceptCallback.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(acceptCallback.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(acceptCallback.getEvents()->at(2).type,
TestAcceptCallback::TYPE_STOP);
int fd = acceptCallback.getEvents()->at(1).fd;
// The accepted connection should already be in non-blocking mode
int flags = fcntl(fd, F_GETFL, 0);
CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
#ifndef TCP_NOPUSH
// The accepted connection should already have TCP_NODELAY set
int value;
socklen_t valueLength = sizeof(value);
int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
CHECK_EQ(rc, 0);
CHECK_EQ(value, 1);
#endif
}
/**
* Test AsyncServerSocket::removeAcceptCallback()
*/
TEST(AsyncSocketTest, RemoveAcceptCallback) {
// Create a new AsyncServerSocket
EventBase eventBase;
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add several accept callbacks
TestAcceptCallback cb1;
TestAcceptCallback cb2;
TestAcceptCallback cb3;
TestAcceptCallback cb4;
TestAcceptCallback cb5;
TestAcceptCallback cb6;
TestAcceptCallback cb7;
// Test having callbacks remove other callbacks before them on the list,
// after them on the list, or removing themselves.
//
// Have callback 2 remove callback 3 and callback 5 the first time it is
// called.
int cb2Count = 0;
cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
std::shared_ptr<AsyncSocket> sock2(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
});
cb3.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
});
cb4.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
std::shared_ptr<AsyncSocket> sock3(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
});
cb5.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
std::shared_ptr<AsyncSocket> sock5(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
});
cb2.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
if (cb2Count == 0) {
serverSocket->removeAcceptCallback(&cb3, nullptr);
serverSocket->removeAcceptCallback(&cb5, nullptr);
}
++cb2Count;
});
// Have callback 6 remove callback 4 the first time it is called,
// and destroy the server socket the second time it is called
int cb6Count = 0;
cb6.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
if (cb6Count == 0) {
serverSocket->removeAcceptCallback(&cb4, nullptr);
std::shared_ptr<AsyncSocket> sock6(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
std::shared_ptr<AsyncSocket> sock7(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
std::shared_ptr<AsyncSocket> sock8(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
} else {
serverSocket.reset();
}
++cb6Count;
});
// Have callback 7 remove itself
cb7.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
serverSocket->removeAcceptCallback(&cb7, nullptr);
});
serverSocket->addAcceptCallback(&cb1, nullptr);
serverSocket->addAcceptCallback(&cb2, nullptr);
serverSocket->addAcceptCallback(&cb3, nullptr);
serverSocket->addAcceptCallback(&cb4, nullptr);
serverSocket->addAcceptCallback(&cb5, nullptr);
serverSocket->addAcceptCallback(&cb6, nullptr);
serverSocket->addAcceptCallback(&cb7, nullptr);
serverSocket->startAccepting();
// Make several connections to the socket
std::shared_ptr<AsyncSocket> sock1(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
std::shared_ptr<AsyncSocket> sock4(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
// Loop until we are stopped
eventBase.loop();
// Check to make sure that the expected callbacks were invoked.
//
// NOTE: This code depends on the AsyncServerSocket operating calling all of
// the AcceptCallbacks in round-robin fashion, in the order that they were
// added. The code is implemented this way right now, but the API doesn't
// explicitly require it be done this way. If we change the code not to be
// exactly round robin in the future, we can simplify the test checks here.
// (We'll also need to update the termination code, since we expect cb6 to
// get called twice to terminate the loop.)
CHECK_EQ(cb1.getEvents()->size(), 4);
CHECK_EQ(cb1.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb1.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb1.getEvents()->at(2).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb1.getEvents()->at(3).type,
TestAcceptCallback::TYPE_STOP);
CHECK_EQ(cb2.getEvents()->size(), 4);
CHECK_EQ(cb2.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb2.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb2.getEvents()->at(2).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb2.getEvents()->at(3).type,
TestAcceptCallback::TYPE_STOP);
CHECK_EQ(cb3.getEvents()->size(), 2);
CHECK_EQ(cb3.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb3.getEvents()->at(1).type,
TestAcceptCallback::TYPE_STOP);
CHECK_EQ(cb4.getEvents()->size(), 3);
CHECK_EQ(cb4.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb4.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb4.getEvents()->at(2).type,
TestAcceptCallback::TYPE_STOP);
CHECK_EQ(cb5.getEvents()->size(), 2);
CHECK_EQ(cb5.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb5.getEvents()->at(1).type,
TestAcceptCallback::TYPE_STOP);
CHECK_EQ(cb6.getEvents()->size(), 4);
CHECK_EQ(cb6.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb6.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb6.getEvents()->at(2).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb6.getEvents()->at(3).type,
TestAcceptCallback::TYPE_STOP);
CHECK_EQ(cb7.getEvents()->size(), 3);
CHECK_EQ(cb7.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb7.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb7.getEvents()->at(2).type,
TestAcceptCallback::TYPE_STOP);
}
/**
* Test AsyncServerSocket::removeAcceptCallback()
*/
TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
// Create a new AsyncServerSocket
EventBase eventBase;
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
serverSocket->bind(0);
serverSocket->listen(16);
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
// Add several accept callbacks
TestAcceptCallback cb1;
auto thread_id = pthread_self();
cb1.setAcceptStartedFn([&](){
CHECK_NE(thread_id, pthread_self());
thread_id = pthread_self();
});
cb1.setConnectionAcceptedFn([&](int fd, const folly::SocketAddress& addr){
CHECK_EQ(thread_id, pthread_self());
serverSocket->removeAcceptCallback(&cb1, nullptr);
});
cb1.setAcceptStoppedFn([&](){
CHECK_EQ(thread_id, pthread_self());
});
// Test having callbacks remove other callbacks before them on the list,
serverSocket->addAcceptCallback(&cb1, nullptr);
serverSocket->startAccepting();
// Make several connections to the socket
std::shared_ptr<AsyncSocket> sock1(
AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
// Loop in another thread
auto other = std::thread([&](){
eventBase.loop();
});
other.join();
// Check to make sure that the expected callbacks were invoked.
//
// NOTE: This code depends on the AsyncServerSocket operating calling all of
// the AcceptCallbacks in round-robin fashion, in the order that they were
// added. The code is implemented this way right now, but the API doesn't
// explicitly require it be done this way. If we change the code not to be
// exactly round robin in the future, we can simplify the test checks here.
// (We'll also need to update the termination code, since we expect cb6 to
// get called twice to terminate the loop.)
CHECK_EQ(cb1.getEvents()->size(), 3);
CHECK_EQ(cb1.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(cb1.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(cb1.getEvents()->at(2).type,
TestAcceptCallback::TYPE_STOP);
}
void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
// Add a callback to accept one connection then stop accepting
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
EventBase* eventBase = serverSocket->getEventBase();
folly::SocketAddress serverAddress;
serverSocket->getAddress(&serverAddress);
AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
// Loop to process all events
eventBase->loop();
// Verify that the server accepted a connection
CHECK_EQ(acceptCallback.getEvents()->size(), 3);
CHECK_EQ(acceptCallback.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(acceptCallback.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(acceptCallback.getEvents()->at(2).type,
TestAcceptCallback::TYPE_STOP);
}
/* Verify that we don't leak sockets if we are destroyed()
* and there are still writes pending
*
* If destroy() only calls close() instead of closeNow(),
* it would shutdown(writes) on the socket, but it would
* never be close()'d, and the socket would leak
*/
TEST(AsyncSocketTest, DestroyCloseTest) {
TestServer server;
// connect()
EventBase clientEB;
EventBase serverEB;
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
ConnCallback ccb;
socket->connect(&ccb, server.getAddress(), 30);
// Accept the connection
std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
ReadCallback rcb;
acceptedSocket->setReadCB(&rcb);
// Write a large buffer to the socket that is larger than kernel buffer
size_t simpleBufLength = 5000000;
char* simpleBuf = new char[simpleBufLength];
memset(simpleBuf, 'a', simpleBufLength);
WriteCallback wcb;
// Let the reads and writes run to completion
int fd = acceptedSocket->getFd();
acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
socket.reset();
acceptedSocket.reset();
// Test that server socket was closed
ssize_t sz = read(fd, simpleBuf, simpleBufLength);
CHECK_EQ(sz, -1);
CHECK_EQ(errno, 9);
delete[] simpleBuf;
}
/**
* Test AsyncServerSocket::useExistingSocket()
*/
TEST(AsyncSocketTest, ServerExistingSocket) {
EventBase eventBase;
// Test creating a socket, and letting AsyncServerSocket bind and listen
{
// Manually create a socket
int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
ASSERT_GE(fd, 0);
// Create a server socket
AsyncServerSocket::UniquePtr serverSocket(
new AsyncServerSocket(&eventBase));
serverSocket->useExistingSocket(fd);
folly::SocketAddress address;
serverSocket->getAddress(&address);
address.setPort(0);
serverSocket->bind(address);
serverSocket->listen(16);
// Make sure the socket works
serverSocketSanityTest(serverSocket.get());
}
// Test creating a socket and binding manually,
// then letting AsyncServerSocket listen
{
// Manually create a socket
int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
ASSERT_GE(fd, 0);
// bind
struct sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = 0;
addr.sin_addr.s_addr = INADDR_ANY;
CHECK_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
sizeof(addr)), 0);
// Look up the address that we bound to
folly::SocketAddress boundAddress;
boundAddress.setFromLocalAddress(fd);
// Create a server socket
AsyncServerSocket::UniquePtr serverSocket(
new AsyncServerSocket(&eventBase));
serverSocket->useExistingSocket(fd);
serverSocket->listen(16);
// Make sure AsyncServerSocket reports the same address that we bound to
folly::SocketAddress serverSocketAddress;
serverSocket->getAddress(&serverSocketAddress);
CHECK_EQ(boundAddress, serverSocketAddress);
// Make sure the socket works
serverSocketSanityTest(serverSocket.get());
}
// Test creating a socket, binding and listening manually,
// then giving it to AsyncServerSocket
{
// Manually create a socket
int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
ASSERT_GE(fd, 0);
// bind
struct sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = 0;
addr.sin_addr.s_addr = INADDR_ANY;
CHECK_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
sizeof(addr)), 0);
// Look up the address that we bound to
folly::SocketAddress boundAddress;
boundAddress.setFromLocalAddress(fd);
// listen
CHECK_EQ(listen(fd, 16), 0);
// Create a server socket
AsyncServerSocket::UniquePtr serverSocket(
new AsyncServerSocket(&eventBase));
serverSocket->useExistingSocket(fd);
// Make sure AsyncServerSocket reports the same address that we bound to
folly::SocketAddress serverSocketAddress;
serverSocket->getAddress(&serverSocketAddress);
CHECK_EQ(boundAddress, serverSocketAddress);
// Make sure the socket works
serverSocketSanityTest(serverSocket.get());
}
}
TEST(AsyncSocketTest, UnixDomainSocketTest) {
EventBase eventBase;
// Create a server socket
std::shared_ptr<AsyncServerSocket> serverSocket(
AsyncServerSocket::newSocket(&eventBase));
string path(1, 0);
path.append("/anonymous");
folly::SocketAddress serverAddress;
serverAddress.setFromPath(path);
serverSocket->bind(serverAddress);
serverSocket->listen(16);
// Add a callback to accept one connection then stop the loop
TestAcceptCallback acceptCallback;
acceptCallback.setConnectionAcceptedFn(
[&](int fd, const folly::SocketAddress& addr) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
acceptCallback.setAcceptErrorFn([&](const std::exception& ex) {
serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
});
serverSocket->addAcceptCallback(&acceptCallback, nullptr);
serverSocket->startAccepting();
// Connect to the server socket
std::shared_ptr<AsyncSocket> socket(
AsyncSocket::newSocket(&eventBase, serverAddress));
eventBase.loop();
// Verify that the server accepted a connection
CHECK_EQ(acceptCallback.getEvents()->size(), 3);
CHECK_EQ(acceptCallback.getEvents()->at(0).type,
TestAcceptCallback::TYPE_START);
CHECK_EQ(acceptCallback.getEvents()->at(1).type,
TestAcceptCallback::TYPE_ACCEPT);
CHECK_EQ(acceptCallback.getEvents()->at(2).type,
TestAcceptCallback::TYPE_STOP);
int fd = acceptCallback.getEvents()->at(1).fd;
// The accepted connection should already be in non-blocking mode
int flags = fcntl(fd, F_GETFL, 0);
CHECK_EQ(flags & O_NONBLOCK, O_NONBLOCK);
}
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Optional.h>
#include <folly/io/async/SSLContext.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncSSLSocket.h>
class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
public folly::AsyncTransportWrapper::ReadCallback,
public folly::AsyncTransportWrapper::WriteCallback
{
public:
explicit BlockingSocket(int fd)
: sock_(new folly::AsyncSocket(&eventBase_, fd)) {
}
BlockingSocket(folly::SocketAddress address,
std::shared_ptr<folly::SSLContext> sslContext)
: sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
new folly::AsyncSocket(&eventBase_)),
address_(address) {}
void open() {
sock_->connect(this, address_);
eventBase_.loop();
if (err_.hasValue()) {
throw err_.value();
}
}
void close() {
sock_->close();
}
int32_t write(uint8_t const* buf, size_t len) {
sock_->write(this, buf, len);
eventBase_.loop();
if (err_.hasValue()) {
throw err_.value();
}
return len;
}
void flush() {}
int32_t readAll(uint8_t *buf, size_t len) {
return readHelper(buf, len, true);
}
int32_t read(uint8_t *buf, size_t len) {
return readHelper(buf, len, false);
}
int getSocketFD() const {
return sock_->getFd();
}
private:
folly::EventBase eventBase_;
folly::AsyncSocket::UniquePtr sock_;
folly::Optional<folly::AsyncSocketException> err_;
uint8_t *readBuf_{nullptr};
size_t readLen_{0};
folly::SocketAddress address_;
void connectSuccess() noexcept override {}
void connectErr(const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
*bufReturn = readBuf_;
*lenReturn = readLen_;
}
void readDataAvailable(size_t len) noexcept override {
readBuf_ += len;
readLen_ -= len;
if (readLen_ == 0) {
sock_->setReadCB(nullptr);
}
}
void readEOF() noexcept override {
}
void readErr(const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
void writeSuccess() noexcept override {}
void writeErr(size_t bytesWritten,
const folly::AsyncSocketException& ex) noexcept override {
err_ = ex;
}
int32_t readHelper(uint8_t *buf, size_t len, bool all) {
readBuf_ = buf;
readLen_ = len;
sock_->setReadCB(this);
while (!err_ && sock_->good() && readLen_ > 0) {
eventBase_.loop();
if (!all) {
break;
}
}
sock_->setReadCB(nullptr);
if (err_.hasValue()) {
throw err_.value();
}
if (all && readLen_ > 0) {
throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
"eof");
}
return len - readLen_;
}
};
-----BEGIN CERTIFICATE-----
MIIDXTCCAkWgAwIBAgIJAKMZICGWUzawMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV
BAYTAlVTMQ8wDQYDVQQKDAZUaHJpZnQxJTAjBgNVBAMMHFRocmlmdCBDZXJ0aWZp
Y2F0ZSBBdXRob3JpdHkwHhcNMTQwNTE2MjAyODUyWhcNNDExMDAxMjAyODUyWjBF
MQswCQYDVQQGEwJVUzEPMA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQg
Q2VydGlmaWNhdGUgQXV0aG9yaXR5MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB
CgKCAQEA1Bx2vUvXZ8PrvEBxwdH5qM1F2Xo7UkeC1jzQ+OLUBEcCiEduyStitSvB
NOAzAGdjt7NmHTP/7OJngp2vzQGjSQzm20XacyTieFUuPBuikUc0Ge3Tf+uQXtiU
zZPh+xn6arHH+zBWtmUCt3cBrpgRqdnWUsbl8eqo5HsczY781FxQbDoT9VP6A+9R
KGTsEhxxKbWJ1C7OngwLKc7Zv4DtTC1JFlFyKd8ryDtxP4s/GgsXJkoK0Hkpputr
cMxMm6OGt77mFvzR2qRY1CpEK/9rjBB6Gqd8GakXsvoOsqL/37k2wVhN/JoS/Pde
12Mp6TZ2rA8NW8vRujfWU0u55gnQnwIDAQABo1AwTjAdBgNVHQ4EFgQUQ00NGVmY
NZ6LJg8UQUOVLZX1Gh8wHwYDVR0jBBgwFoAUQ00NGVmYNZ6LJg8UQUOVLZX1Gh8w
DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOCAQEAdlxt5+z9uXCBr1Wt6r49
4MmOYw9lOnEOG1JPMRo108TLpmwXEWReCAtjQuR7BitRJW0kJtlO1M6t3qoIh6GA
sBkgsjQM1xNY3YEpx71MLt1V+JD+2WtSBKMyysj1TiOmIH66kkvXO3ptXzhjhZyX
G6B+kxLtxrqkn9SJULyN55X8T+dkW28UIBZVLavoREDU+UPrYU9JgZeIVObtGSWi
DvS4RIJZNjgG3vTrT00rfUGEfTlI54Vbcmv0cYvswP/nMsLtDStCdgI7c/ipyJve
dfuI4CedjE240AxK5OFxFg/k/IfnB4a5oojbdIR9hKrTU57TPaUVD50Na9WA1aqX
5Q==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIDKzCCAhOgAwIBAgIBCjANBgkqhkiG9w0BAQUFADBFMQswCQYDVQQGEwJVUzEP
MA0GA1UECgwGVGhyaWZ0MSUwIwYDVQQDDBxUaHJpZnQgQ2VydGlmaWNhdGUgQXV0
aG9yaXR5MB4XDTE0MDUxNjIwMjg1MloXDTQxMTAwMTIwMjg1MlowRjELMAkGA1UE
BhMCVVMxDTALBgNVBAgTBE9oaW8xETAPBgNVBAcTCEhpbGxpYXJkMRUwEwYDVQQD
EwxBc294IENvbXBhbnkwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCz
ZGrJ5XQHAuMYHlBgn32OOc9l0n3RjXccio2ceeWctXkSxDP3vFyZ4kBILF1lsY1g
o8UTjMkqSDYcytCLK0qavrv9BZRLB9FcpqJ9o4V9feaI/HsHa8DYHyEs8qyNTTNG
YQ3i4j+AA9iDSpezIYy/tyAOAjrSquUW1jI4tzKTBh8hk8MAMvR2/NPHPkrp4gI+
EMH6u4vWdr4F9bbriLFWoU04T9mWOMk7G+h8BS9sgINg2+v5cWvl3BC4kLk5L1yJ
FEyuofSSCEEe6dDf7uVh+RPKa4hEkIYo31AEOPFrN56d+pCj/5l67HTWXoQx3rjy
dNXMvgU75urm6TQe8dB5AgMBAAGjJTAjMCEGA1UdEQQaMBiHBH8AAAGHEAAAAAAA
AAAAAAAAAAAAAAEwDQYJKoZIhvcNAQEFBQADggEBAD26XYInaEvlWZJYgtl3yQyC
3NRQc3LG7XxWg4aFdXCxYLPRAL2HLoarKYH8GPFso57t5xnhA8WfP7iJxmgsKdCS
0pNIicOWsMmXvYLib0j9tMCFR+a8rn3f4n+clwnqas4w/vWBJUoMgyxtkP8NNNZO
kIl02JKRhuyiFyPLilVp5tu0e+lmyUER+ak53WjLq2yoytYAlHkzkOpc4MZ/TNt5
UTEtx/WVlZvlrPi3dsi7QikkjQgo1wCnm7owtuAHlPDMAB8wKk4+vvIOjsGM33T/
8ffq/4X1HeYM0w0fM+SVlX1rwkXA1RW/jn48VWFHpWbE10+m196OdiToGfm2OJI=
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAs2RqyeV0BwLjGB5QYJ99jjnPZdJ90Y13HIqNnHnlnLV5EsQz
97xcmeJASCxdZbGNYKPFE4zJKkg2HMrQiytKmr67/QWUSwfRXKaifaOFfX3miPx7
B2vA2B8hLPKsjU0zRmEN4uI/gAPYg0qXsyGMv7cgDgI60qrlFtYyOLcykwYfIZPD
ADL0dvzTxz5K6eICPhDB+ruL1na+BfW264ixVqFNOE/ZljjJOxvofAUvbICDYNvr
+XFr5dwQuJC5OS9ciRRMrqH0kghBHunQ3+7lYfkTymuIRJCGKN9QBDjxazeenfqQ
o/+Zeux01l6EMd648nTVzL4FO+bq5uk0HvHQeQIDAQABAoIBAQCSPcBYindF5/Kd
jMjVm+9M7I/IYAo1tG9vkvvSngSy9bWXuN7sjF+pCyqAK7qP1mh8acWVJGYx0+BZ
JHVRnp8Y+3hg0hWL/PmN4EICzjVakjJHZhwddpglF2uCKurD3jV4oFIjrXE6uOfe
UAbO/wCwoWa+RM8TQkGzljYmyiGufCcXlgEKMNA7TIvbJ9TVx3VTCOQy6EjZ13jd
M6X7byV/ZOFpZ2H0QV46LvZraw04riXQ/59gVmzizYdI+BwnxxapsCmalTJoV/Y0
LMI2ylat4PTMVTxPF+ti7Nt+rUkkEx6kuiAgfc+bzE4BSD5X4wy3fdLVLccoxXYw
4N3fOuQhAoGBAOLrMhiSCrzXGjDWTbPrwzxXDO0qm+wURELi3N5SXIkKUdG2/In6
wNdpXdvqblOm7SASgPf9KCwUSADrNw6R6nbfrrir5EHg66YydI/OW42QzJKcBUFh
5Q5na3fvoL/zRhsmh0gEymBg+OIfNel2LY69bl8aAko2y0R1kj7zb8X1AoGBAMph
9hlnkIBSw60+pKHaOqo2t/kihNNMFyfOgJvh8960eFeMDhMIXgxPUR8yaPX0bBMb
bCdEJJ2pmq7zUBPvxVJLedwkGMhywElA8yYVh+S6x4Cg+lYo4spIjrHQ/WTvJkHB
GrDskxdq80lbXjwRd0dPJZkxhKJec1o0n8S03Mn1AoGAGarK5taWGlgmYUHMVj6j
vc6G6si4DFMaiYpJu2gLiYC+Un9lP2I6r+L+N+LjidjG16rgJazf/2Rn5Jq2hpJg
uAODKuZekkkTvp/UaXPJDVFEooy9V3DwTNnL4SwcvbmRw35vLOlFzvMJE+K94WN5
sbyhoGY7vhNGmL7HxREaIoUCgYEAwpteVWFz3yE2ziF9l7FMVh7V23go9zGk1n9I
xhyJL26khbLEWeLi5L1kiTYlHdUSE3F8F2n8N6s+ddq79t/KA29WV6xSNHW7lvUg
mk975CMC8hpZfn5ETjVlGXGYJ/Wa+QGiE9z5ODx8gt6cB/DXnLdrtRqbqrJeA7C0
rScpY/0CgYBCC1QeuAiwWHOqQn3BwsZo9JQBTyT0QvDqLH/F+h9QbXep+4HvyAxG
nTMNDtGyfyKGDaDUn5hyeU7Oxvzq0K9P+eZD3MjQeaMEg/++GPGUPmDUTqyb2UT8
5s0NIUobxfKnTD6IpgOIq7ffvVY6cKBMyuLmu/gSvscsbONHjKti3Q==
-----END RSA PRIVATE KEY-----
/*
* Copyright 2015 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/io/ShutdownSocketSet.h>
#include <atomic>
#include <chrono>
#include <thread>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
using folly::ShutdownSocketSet;
namespace folly { namespace test {
ShutdownSocketSet shutdownSocketSet;
class Server {
public:
Server();
void stop(bool abortive);
void join();
int port() const { return port_; }
int closeClients(bool abortive);
private:
int acceptSocket_;
int port_;
enum StopMode {
NO_STOP,
ORDERLY,
ABORTIVE
};
std::atomic<StopMode> stop_;
std::thread serverThread_;
std::vector<int> fds_;
};
Server::Server()
: acceptSocket_(-1),
port_(0),
stop_(NO_STOP) {
acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0);
CHECK_ERR(acceptSocket_);
shutdownSocketSet.add(acceptSocket_);
sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = 0;
addr.sin_addr.s_addr = INADDR_ANY;
CHECK_ERR(bind(acceptSocket_,
reinterpret_cast<const sockaddr*>(&addr),
sizeof(addr)));
CHECK_ERR(listen(acceptSocket_, 10));
socklen_t addrLen = sizeof(addr);
CHECK_ERR(getsockname(acceptSocket_,
reinterpret_cast<sockaddr*>(&addr),
&addrLen));
port_ = ntohs(addr.sin_port);
serverThread_ = std::thread([this] {
while (stop_ == NO_STOP) {
sockaddr_in peer;
socklen_t peerLen = sizeof(peer);
int fd = accept(acceptSocket_,
reinterpret_cast<sockaddr*>(&peer),
&peerLen);
if (fd == -1) {
if (errno == EINTR) {
continue;
}
if (errno == EINVAL || errno == ENOTSOCK) { // socket broken
break;
}
}
CHECK_ERR(fd);
shutdownSocketSet.add(fd);
fds_.push_back(fd);
}
if (stop_ != NO_STOP) {
closeClients(stop_ == ABORTIVE);
}
shutdownSocketSet.close(acceptSocket_);
acceptSocket_ = -1;
port_ = 0;
});
}
int Server::closeClients(bool abortive) {
for (int fd : fds_) {
if (abortive) {
struct linger l = {1, 0};
CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
}
shutdownSocketSet.close(fd);
}
int n = fds_.size();
fds_.clear();
return n;
}
void Server::stop(bool abortive) {
stop_ = abortive ? ABORTIVE : ORDERLY;
shutdown(acceptSocket_, SHUT_RDWR);
}
void Server::join() {
serverThread_.join();
}
int createConnectedSocket(int port) {
int sock = socket(PF_INET, SOCK_STREAM, 0);
CHECK_ERR(sock);
sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
addr.sin_addr.s_addr = htonl((127 << 24) | 1); // XXX
CHECK_ERR(connect(sock,
reinterpret_cast<const sockaddr*>(&addr),
sizeof(addr)));
return sock;
}
void runCloseTest(bool abortive) {
Server server;
int sock = createConnectedSocket(server.port());
std::thread stopper([&server, abortive] {
std::this_thread::sleep_for(std::chrono::milliseconds(200));
server.stop(abortive);
server.join();
});
char c;
int r = read(sock, &c, 1);
if (abortive) {
int e = errno;
EXPECT_EQ(-1, r);
EXPECT_EQ(ECONNRESET, e);
} else {
EXPECT_EQ(0, r);
}
close(sock);
stopper.join();
EXPECT_EQ(0, server.closeClients(false)); // closed by server when it exited
}
TEST(ShutdownSocketSetTest, OrderlyClose) {
runCloseTest(false);
}
TEST(ShutdownSocketSetTest, AbortiveClose) {
runCloseTest(true);
}
void runKillTest(bool abortive) {
Server server;
int sock = createConnectedSocket(server.port());
std::thread killer([&server, abortive] {
std::this_thread::sleep_for(std::chrono::milliseconds(200));
shutdownSocketSet.shutdownAll(abortive);
server.join();
});
char c;
int r = read(sock, &c, 1);
// "abortive" is just a hint for ShutdownSocketSet, so accept both
// behaviors
if (abortive) {
if (r == -1) {
EXPECT_EQ(ECONNRESET, errno);
} else {
EXPECT_EQ(r, 0);
}
} else {
EXPECT_EQ(0, r);
}
close(sock);
killer.join();
// NOT closed by server when it exited
EXPECT_EQ(1, server.closeClients(false));
}
TEST(ShutdownSocketSetTest, OrderlyKill) {
runKillTest(false);
}
TEST(ShutdownSocketSetTest, AbortiveKill) {
runKillTest(true);
}
}} // namespaces
int main(int argc, char *argv[]) {
testing::InitGoogleTest(&argc, argv);
google::ParseCommandLineFlags(&argc, &argv, true);
return RUN_ALL_TESTS();
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment