Commit c47d0c77 authored by Fred Qiu's avatar Fred Qiu Committed by Facebook GitHub Bot

Parse and capture client alpns from client hello info

Summary:
Add code to capture client alpns from client hello packet and store them in
ssl socket object for later logging.

Reviewed By: AjanthanAsogamoorthy

Differential Revision: D31176714

fbshipit-source-id: 888fd9949ede5209234bb3ab1959a6f9c14043b2
parent dfe13560
...@@ -1927,6 +1927,21 @@ void AsyncSSLSocket::resetClientHelloParsing(SSL* ssl) { ...@@ -1927,6 +1927,21 @@ void AsyncSSLSocket::resetClientHelloParsing(SSL* ssl) {
clientHelloInfo_->clientHelloBuf_.clear(); clientHelloInfo_->clientHelloBuf_.clear();
} }
void AsyncSSLSocket::parseClientAlpns(
AsyncSSLSocket* sock,
folly::io::Cursor& cursor,
uint16_t& extensionDataLength) {
cursor.skip(2);
extensionDataLength -= 2;
while (extensionDataLength) {
auto protoLength = cursor.readBE<uint8_t>();
extensionDataLength--;
auto proto = cursor.readFixedString(protoLength);
sock->clientHelloInfo_->clientAlpns_.push_back(proto);
extensionDataLength -= protoLength;
}
}
void AsyncSSLSocket::clientHelloParsingCallback( void AsyncSSLSocket::clientHelloParsingCallback(
int written, int written,
int /* version */, int /* version */,
...@@ -2051,6 +2066,10 @@ void AsyncSSLSocket::clientHelloParsingCallback( ...@@ -2051,6 +2066,10 @@ void AsyncSSLSocket::clientHelloParsingCallback(
extensionDataLength -= extensionDataLength -=
sizeof(typ) + sizeof(nameLength) + nameLength; sizeof(typ) + sizeof(nameLength) + nameLength;
} }
} else if (
extensionType ==
ssl::TLSExtension::APPLICATION_LAYER_PROTOCOL_NEGOTIATION) {
parseClientAlpns(sock, cursor, extensionDataLength);
} else { } else {
cursor.skip(extensionDataLength); cursor.skip(extensionDataLength);
} }
...@@ -2189,4 +2208,13 @@ void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const { ...@@ -2189,4 +2208,13 @@ void AsyncSSLSocket::getSSLServerCiphers(std::string& serverCiphers) const {
} }
} }
const std::vector<std::string>& AsyncSSLSocket::getClientAlpns() const {
if (!parseClientHello_) {
static std::vector<std::string> emptyAlpns{};
return emptyAlpns;
} else {
return clientHelloInfo_->clientAlpns_;
}
}
} // namespace folly } // namespace folly
...@@ -768,6 +768,13 @@ class AsyncSSLSocket : public AsyncSocket { ...@@ -768,6 +768,13 @@ class AsyncSSLSocket : public AsyncSocket {
*/ */
void getSSLServerCiphers(std::string& serverCiphers) const; void getSSLServerCiphers(std::string& serverCiphers) const;
/**
* Get the list of next protocols sent from the client. The protocols are
* directly as the client passed them and may be arbitrary byte sequences
* of arbitrary length.
*/
const std::vector<std::string>& getClientAlpns() const;
/** /**
* Method to check if peer verfication is set. * Method to check if peer verfication is set.
* *
...@@ -780,6 +787,10 @@ class AsyncSSLSocket : public AsyncSocket { ...@@ -780,6 +787,10 @@ class AsyncSSLSocket : public AsyncSocket {
static int bioWrite(BIO* b, const char* in, int inl); static int bioWrite(BIO* b, const char* in, int inl);
static int bioRead(BIO* b, char* out, int outl); static int bioRead(BIO* b, char* out, int outl);
void resetClientHelloParsing(SSL* ssl); void resetClientHelloParsing(SSL* ssl);
static void parseClientAlpns(
AsyncSSLSocket* sock,
folly::io::Cursor& cursor,
uint16_t& extensionDataLength);
static void clientHelloParsingCallback( static void clientHelloParsingCallback(
int written, int written,
int version, int version,
......
...@@ -600,6 +600,22 @@ int SSLContext::alpnSelectCallback( ...@@ -600,6 +600,22 @@ int SSLContext::alpnSelectCallback(
return SSL_TLSEXT_ERR_OK; return SSL_TLSEXT_ERR_OK;
} }
std::string SSLContext::getAdvertisedNextProtocols() {
if (advertisedNextProtocols_.empty()) {
return "";
}
std::string alpns(
(const char*)advertisedNextProtocols_[0].protocols + 1,
advertisedNextProtocols_[0].length - 1);
auto len = advertisedNextProtocols_[0].protocols[0];
for (size_t i = len; i < alpns.length();) {
len = alpns[i];
alpns[i] = ',';
i += len + 1;
}
return alpns;
}
bool SSLContext::setAdvertisedNextProtocols( bool SSLContext::setAdvertisedNextProtocols(
const std::list<std::string>& protocols) { const std::list<std::string>& protocols) {
return setRandomizedAdvertisedNextProtocols({{1, protocols}}); return setRandomizedAdvertisedNextProtocols({{1, protocols}});
......
...@@ -487,6 +487,8 @@ class SSLContext { ...@@ -487,6 +487,8 @@ class SSLContext {
void setOptions(long options); void setOptions(long options);
#if FOLLY_OPENSSL_HAS_ALPN #if FOLLY_OPENSSL_HAS_ALPN
std::string getAdvertisedNextProtocols();
/** /**
* Set the list of protocols that this SSL context supports. In client * Set the list of protocols that this SSL context supports. In client
* mode, this is the list of protocols that will be advertised for Application * mode, this is the list of protocols that will be advertised for Application
......
...@@ -96,6 +96,7 @@ struct ClientHelloInfo { ...@@ -96,6 +96,7 @@ struct ClientHelloInfo {
// long as each ServerName has a distinct type). In practice, the only one // long as each ServerName has a distinct type). In practice, the only one
// we really care about is HOST_NAME. // we really care about is HOST_NAME.
std::string clientHelloSNIHostname_; std::string clientHelloSNIHostname_;
std::vector<std::string> clientAlpns_;
}; };
} // namespace ssl } // namespace ssl
......
...@@ -597,6 +597,7 @@ TEST_F(NextProtocolTest, AlpnNotAllowMismatchNoClientProtocol) { ...@@ -597,6 +597,7 @@ TEST_F(NextProtocolTest, AlpnNotAllowMismatchNoClientProtocol) {
expectHandshakeSuccess(); expectHandshakeSuccess();
expectNoProtocol(); expectNoProtocol();
EXPECT_EQ(server->getClientAlpns(), std::vector<std::string>({}));
} }
TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) { TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) {
...@@ -607,6 +608,8 @@ TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) { ...@@ -607,6 +608,8 @@ TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithOverlap) {
connect(); connect();
expectProtocol("baz"); expectProtocol("baz");
EXPECT_EQ(
server->getClientAlpns(), std::vector<std::string>({"blub", "baz"}));
} }
TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) { TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) {
...@@ -617,6 +620,7 @@ TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) { ...@@ -617,6 +620,7 @@ TEST_F(NextProtocolTest, AlpnNotAllowMismatchWithoutOverlap) {
connect(); connect();
expectHandshakeError(); expectHandshakeError();
EXPECT_EQ(server->getClientAlpns(), std::vector<std::string>({"blub"}));
} }
#endif #endif
......
...@@ -875,11 +875,15 @@ class AlpnServer : private AsyncSSLSocket::HandshakeCB, ...@@ -875,11 +875,15 @@ class AlpnServer : private AsyncSSLSocket::HandshakeCB,
explicit AlpnServer(AsyncSSLSocket::UniquePtr socket) explicit AlpnServer(AsyncSSLSocket::UniquePtr socket)
: nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) { : nextProto(nullptr), nextProtoLength(0), socket_(std::move(socket)) {
socket_->sslAccept(this); socket_->sslAccept(this);
socket_->enableClientHelloParsing();
} }
const unsigned char* nextProto; const unsigned char* nextProto;
unsigned nextProtoLength; unsigned nextProtoLength;
folly::Optional<AsyncSocketException> except; folly::Optional<AsyncSocketException> except;
const std::vector<std::string>& getClientAlpns() const {
return socket_->getClientAlpns();
}
private: private:
void handshakeSuc(AsyncSSLSocket*) noexcept override { void handshakeSuc(AsyncSSLSocket*) noexcept override {
......
...@@ -235,4 +235,15 @@ TEST_F(SSLContextTest, TestSetInvalidCiphersuite) { ...@@ -235,4 +235,15 @@ TEST_F(SSLContextTest, TestSetInvalidCiphersuite) {
TEST_F(SSLContextTest, TestTLS13MinVersionThrow) { TEST_F(SSLContextTest, TestTLS13MinVersionThrow) {
EXPECT_THROW(SSLContext{SSLContext::SSLVersion::TLSv1_3}, std::runtime_error); EXPECT_THROW(SSLContext{SSLContext::SSLVersion::TLSv1_3}, std::runtime_error);
} }
TEST_F(SSLContextTest, AdvertisedNextProtocols) {
EXPECT_EQ(ctx.getAdvertisedNextProtocols(), "");
ctx.setAdvertisedNextProtocols({"blub"});
EXPECT_EQ(ctx.getAdvertisedNextProtocols(), "blub");
ctx.setAdvertisedNextProtocols({"foo", "bar", "baz"});
EXPECT_EQ(ctx.getAdvertisedNextProtocols(), "foo,bar,baz");
}
} // namespace folly } // namespace folly
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment