Commit 17d04308 authored by Neel Goyal's avatar Neel Goyal Committed by facebook-github-bot-4

D2741855 broke my wangle. Reverting

Summary: Revert D2741855

Reviewed By: mzlee

Differential Revision: D2744015

fb-gh-sync-id: b1e9b0a5ab95cb988d2b5c08c86139452b092465
parent 97c7b417
...@@ -84,10 +84,6 @@ SSLContext::SSLContext(SSLVersion version) { ...@@ -84,10 +84,6 @@ SSLContext::SSLContext(SSLVersion version) {
SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback); SSL_CTX_set_tlsext_servername_callback(ctx_, baseServerNameOpenSSLCallback);
SSL_CTX_set_tlsext_servername_arg(ctx_, this); SSL_CTX_set_tlsext_servername_arg(ctx_, this);
#endif #endif
#ifdef OPENSSL_NPN_NEGOTIATED
Random::seed(nextProtocolPicker_);
#endif
} }
SSLContext::~SSLContext() { SSLContext::~SSLContext() {
...@@ -378,16 +374,16 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols( ...@@ -378,16 +374,16 @@ bool SSLContext::setRandomizedAdvertisedNextProtocols(
dst += protoLength; dst += protoLength;
} }
total_weight += item.weight; total_weight += item.weight;
advertised_item.probability = item.weight;
advertisedNextProtocols_.push_back(advertised_item); advertisedNextProtocols_.push_back(advertised_item);
advertisedNextProtocolWeights_.push_back(item.weight);
} }
if (total_weight == 0) { if (total_weight == 0) {
deleteNextProtocolsStrings(); deleteNextProtocolsStrings();
return false; return false;
} }
nextProtocolDistribution_ = for (auto& advertised_item : advertisedNextProtocols_) {
std::discrete_distribution<>(advertisedNextProtocolWeights_.begin(), advertised_item.probability /= total_weight;
advertisedNextProtocolWeights_.end()); }
if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) { if ((uint8_t)protocolType & (uint8_t)NextProtocolType::NPN) {
SSL_CTX_set_next_protos_advertised_cb( SSL_CTX_set_next_protos_advertised_cb(
ctx_, advertisedNextProtocolCallback, this); ctx_, advertisedNextProtocolCallback, this);
...@@ -410,7 +406,6 @@ void SSLContext::deleteNextProtocolsStrings() { ...@@ -410,7 +406,6 @@ void SSLContext::deleteNextProtocolsStrings() {
delete[] protocols.protocols; delete[] protocols.protocols;
} }
advertisedNextProtocols_.clear(); advertisedNextProtocols_.clear();
advertisedNextProtocolWeights_.clear();
} }
void SSLContext::unsetNextProtocols() { void SSLContext::unsetNextProtocols() {
...@@ -424,8 +419,18 @@ void SSLContext::unsetNextProtocols() { ...@@ -424,8 +419,18 @@ void SSLContext::unsetNextProtocols() {
} }
size_t SSLContext::pickNextProtocols() { size_t SSLContext::pickNextProtocols() {
CHECK(!advertisedNextProtocols_.empty()) << "Failed to pickNextProtocols"; unsigned char random_byte;
return nextProtocolDistribution_(nextProtocolPicker_); RAND_bytes(&random_byte, 1);
double random_value = random_byte / 255.0;
double sum = 0;
for (size_t i = 0; i < advertisedNextProtocols_.size(); ++i) {
sum += advertisedNextProtocols_[i].probability;
if (sum < random_value && i + 1 < advertisedNextProtocols_.size()) {
continue;
}
return i;
}
CHECK(false) << "Failed to pickNextProtocols";
} }
int SSLContext::advertisedNextProtocolCallback(SSL* ssl, int SSLContext::advertisedNextProtocolCallback(SSL* ssl,
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <random>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/tls1.h> #include <openssl/tls1.h>
...@@ -36,8 +35,6 @@ ...@@ -36,8 +35,6 @@
#include <folly/folly-config.h> #include <folly/folly-config.h>
#endif #endif
#include <folly/Random.h>
namespace folly { namespace folly {
/** /**
...@@ -90,6 +87,12 @@ class SSLContext { ...@@ -90,6 +87,12 @@ class SSLContext {
std::list<std::string> protocols; std::list<std::string> protocols;
}; };
struct AdvertisedNextProtocolsItem {
unsigned char* protocols;
unsigned length;
double probability;
};
// Function that selects a client protocol given the server's list // Function that selects a client protocol given the server's list
using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*, using ClientProtocolFilterCallback = bool (*)(unsigned char**, unsigned int*,
const unsigned char*, unsigned int); const unsigned char*, unsigned int);
...@@ -455,20 +458,10 @@ class SSLContext { ...@@ -455,20 +458,10 @@ class SSLContext {
static bool initialized_; static bool initialized_;
#ifdef OPENSSL_NPN_NEGOTIATED #ifdef OPENSSL_NPN_NEGOTIATED
struct AdvertisedNextProtocolsItem {
unsigned char* protocols;
unsigned length;
};
/** /**
* Wire-format list of advertised protocols for use in NPN. * Wire-format list of advertised protocols for use in NPN.
*/ */
std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_; std::vector<AdvertisedNextProtocolsItem> advertisedNextProtocols_;
std::vector<int> advertisedNextProtocolWeights_;
std::discrete_distribution<int> nextProtocolDistribution_;
Random::DefaultGenerator nextProtocolPicker_;
static int sNextProtocolsExDataIndex_; static int sNextProtocolsExDataIndex_;
static int advertisedNextProtocolCallback(SSL* ssl, static int advertisedNextProtocolCallback(SSL* ssl,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment