Unverified Commit 53c08440 authored by Kip's avatar Kip Committed by GitHub

Implemented support for Authorization headers using basic method... (#725)

* {src/common/http_header.cc,include/pistache/http_header.h}: Implemented support for Authorization headers using basic method...
include/pistache/base64.{cc,h}: Added utility classes and helper functions for base 64 encoding and decoding...
tests/headers_test.cc: Added unit test authorization_basic_test...
version.txt: Bumped versioning metadata...

* Reformatted last commit with clang-format(1) per Dennis' request. Looks awful.
parent 7bb01321
/*
Copyright (C) 2019-2020, Kip Warner.
Released under the terms of Apache License 2.0.
*/
// Multiple include protection...
#ifndef _BASE_64_H_
#define _BASE_64_H_
// Includes...
// Build environment configuration...
#include <pistache/config.h>
// Standard C++ / POSIX system headers...
#include <cstddef>
#include <filesystem>
#include <string>
#include <vector>
// A class for performing decoding to raw bytes from base 64 encoding...
class Base64Decoder {
// Public methods...
public:
// Constructor...
explicit Base64Decoder(const std::string &Base64EncodedString)
: m_Base64EncodedString(Base64EncodedString) {}
// Calculate length of decoded raw bytes from that would be generated if
// the base 64 encoded input buffer was decoded. This is not a static
// method because we need to examine the string...
std::vector<std::byte>::size_type CalculateDecodedSize() const;
// Decode base 64 encoding into raw bytes...
const std::vector<std::byte> &Decode();
// Get raw decoded data...
const std::vector<std::byte> &GetRawDecodedData() const noexcept {
return m_DecodedData;
}
// Protected methods...
protected:
// Convert an octet character to corresponding sextet, provided it can
// safely be represented as such. Otherwise return 0xff...
std::byte DecodeCharacter(const unsigned char Character) const;
// Protected attributes...
protected:
// Base 64 encoded string to decode...
const std::string &m_Base64EncodedString;
// Decoded raw data...
std::vector<std::byte> m_DecodedData;
};
// A class for performing base 64 encoding from raw bytes...
class Base64Encoder {
// Public methods...
public:
// Construct encoder to encode from a raw input buffer...
explicit Base64Encoder(const std::vector<std::byte> &InputBuffer)
: m_InputBuffer(InputBuffer) {}
// Calculate length of base 64 string that would need to be generated
// for raw data of a given length...
static std::string::size_type CalculateEncodedSize(
const std::vector<std::byte>::size_type DecodedSize) noexcept;
// Encode raw data input buffer to base 64...
const std::string &Encode() noexcept;
// Encode a string into base 64 format...
static std::string EncodeString(const std::string &StringInput);
// Get the encoded data...
const std::string &GetBase64EncodedString() const noexcept {
return m_Base64EncodedString;
}
// Protected methods...
protected:
// Encode single binary byte to 6-bit base 64 character...
unsigned char EncodeByte(const std::byte Byte) const;
// Protected attributes...
protected:
// Raw bytes to encode to base 64 string...
const std::vector<std::byte> &m_InputBuffer;
// Base64 encoded string...
std::string m_Base64EncodedString;
};
// Multiple include protection...
#endif
...@@ -326,11 +326,27 @@ class Authorization : public Header { ...@@ -326,11 +326,27 @@ class Authorization : public Header {
public: public:
NAME("Authorization"); NAME("Authorization");
enum class Method { Basic, Bearer, Unknown };
Authorization() : value_("NONE") {} Authorization() : value_("NONE") {}
explicit Authorization(std::string &&val) : value_(std::move(val)) {} explicit Authorization(std::string &&val) : value_(std::move(val)) {}
explicit Authorization(const std::string &val) : value_(val) {} explicit Authorization(const std::string &val) : value_(val) {}
// What type of authorization method was used?
Method getMethod() const noexcept;
// Check if a particular authorization method was used...
template <Method M> bool hasMethod() const noexcept { return hasMethod<M>(); }
// Get decoded user ID and password if basic method was used...
std::string getBasicUser() const;
std::string getBasicPassword() const;
// Set encoded user ID and password for basic method...
void setBasicUserPassword(const std::string &User,
const std::string &Password);
void parse(const std::string &data) override; void parse(const std::string &data) override;
void write(std::ostream &os) const override; void write(std::ostream &os) const override;
...@@ -340,6 +356,12 @@ private: ...@@ -340,6 +356,12 @@ private:
std::string value_; std::string value_;
}; };
template <>
bool Authorization::hasMethod<Authorization::Method::Basic>() const noexcept;
template <>
bool Authorization::hasMethod<Authorization::Method::Bearer>() const noexcept;
class ContentType : public Header { class ContentType : public Header {
public: public:
NAME("Content-Type") NAME("Content-Type")
......
/*
Copyright (C) 2019-2020, Kip Warner.
Released under the terms of Apache License 2.0
*/
// Includes...
// Our headers...
#include <pistache/base64.h>
// Standard C++ / POSIX system headers...
#include <algorithm>
#include <cassert>
#include <cmath>
#include <fstream>
#include <stdexcept>
// Using the standard namespace and Pistache...
using namespace std;
// Calculate length of decoded raw bytes from that would be generated if the
// base 64 encoded input buffer was decoded. This is not a static method
// because we need to examine the string...
vector<byte>::size_type Base64Decoder::CalculateDecodedSize() const {
// If encoded size was zero, so is decoded size...
if (m_Base64EncodedString.size() == 0)
return 0;
// If non-zero, should always be at least four characters...
if (m_Base64EncodedString.size() < 4)
throw runtime_error(
"Base64 encoded stream should always be at least four bytes.");
// ...and always a multiple of four bytes because every three decoded bytes
// produce four encoded base 64 bytes, which may include padding...
if ((m_Base64EncodedString.size() % 4) != 0)
throw runtime_error("Base64 encoded stream length should always be evenly "
"divisible by four.");
// Iterator to walk the encoded string from the beginning...
auto EndIterator = m_Base64EncodedString.begin();
// Keep walking along the input buffer trying to decode characters, but
// without storing them, until we hit the first character we cannot decode.
// This should be the first padding character or end of string...
while (DecodeCharacter(*EndIterator) < static_cast<byte>(64))
++EndIterator;
// The length of the encoded string is the distance from the beginning to
// the first non-decodable character, such as padding...
const auto InputSize = distance(m_Base64EncodedString.begin(), EndIterator);
// Calculate decoded size before account for any more decoded bytes within
// the trailing padding block...
const auto DecodedSize = InputSize / 4 * 3;
// True decoded size depends on how much padding needed to be applied...
switch (InputSize % 4) {
case 2:
return DecodedSize + 1;
case 3:
return DecodedSize + 2;
default:
return DecodedSize;
}
}
// Decode base 64 encoding into raw bytes...
const vector<byte> &Base64Decoder::Decode() {
// Calculate required size of output buffer...
const auto DecodedSize = CalculateDecodedSize();
// Allocate sufficient storage...
m_DecodedData = vector<byte>(DecodedSize, byte(0x00));
m_DecodedData.shrink_to_fit();
// Initialize decode input and output iterators...
string::size_type InputOffset = 0;
string::size_type OutputOffset = 0;
// While there is at least one set of three octets remaining to decode...
for (string::size_type Index = 2; Index < DecodedSize; Index += 3) {
// Construct octets from sextets...
m_DecodedData.at(OutputOffset + 0) =
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 0)) << 2 |
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 1)) >> 4;
m_DecodedData.at(OutputOffset + 1) =
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 1)) << 4 |
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 2)) >> 2;
m_DecodedData.at(OutputOffset + 2) =
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 2)) << 6 |
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 3));
// Reseek i/o pointers...
InputOffset += 4;
OutputOffset += 3;
}
// There's less than three octets remaining...
switch (DecodedSize % 3) {
// One octet left to construct...
case 1:
m_DecodedData.at(OutputOffset + 0) =
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 0)) << 2 |
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 1)) >> 4;
break;
// Two octets left to construct...
case 2:
m_DecodedData.at(OutputOffset + 0) =
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 0)) << 2 |
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 1)) >> 4;
m_DecodedData.at(OutputOffset + 1) =
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 1)) << 4 |
DecodeCharacter(m_Base64EncodedString.at(InputOffset + 2)) >> 2;
break;
}
// All done. Return constant reference to buffer containing decoded data...
return m_DecodedData;
}
// Convert an octet character to corresponding sextet, provided it can safely be
// represented as such. Otherwise return 0xff...
inline byte
Base64Decoder::DecodeCharacter(const unsigned char Character) const {
// Capital letter 'A' is ASCII 65 and zero in base 64...
if ('A' <= Character && Character <= 'Z')
return static_cast<byte>(Character - 'A');
// Lowercase letter 'a' is ASCII 97 and 26 in base 64...
if ('a' <= Character && Character <= 'z')
return static_cast<byte>(Character - (97 - 26));
// Numeric digit '0' is ASCII 48 and 52 in base 64...
if ('0' <= Character && Character <= '9')
return static_cast<byte>(Character - (48 - 52));
// '+' is ASCII 43 and 62 in base 64...
if (Character == '+')
return static_cast<byte>(62);
// '/' is ASCII 47 and 63 in base 64...
if (Character == '/')
return static_cast<byte>(63);
// Anything else that's not a 6-bit representation, signal to caller...
return static_cast<byte>(255);
}
// Calculate length of base 64 string that would need to be generated for raw
// data of a given length...
string::size_type Base64Encoder::CalculateEncodedSize(
const vector<byte>::size_type DecodedSize) noexcept {
// First term calcualtes the unpadded length. The bitwise and rounds up to
// the nearest multiple of four to add padding...
return ((4 * DecodedSize / 3) + 3) & ~3;
}
// Encode raw data input buffer to base 64...
const string &Base64Encoder::Encode() noexcept {
// Allocate precise storage for the output buffer...
m_Base64EncodedString =
string(CalculateEncodedSize(m_InputBuffer.size()), '!');
m_Base64EncodedString.shrink_to_fit();
// Number of complete octet triplets...
const auto OctetTriplets = m_InputBuffer.size() / 3;
// Initialize encode input and output offset registers...
string::size_type InputOffset = 0;
string::size_type OutputOffset = 0;
// While there are still complete octet triplets remaining...
for (string::size_type Index = 0; Index < OctetTriplets; ++Index) {
// Encode first sextet from first octet...
m_Base64EncodedString.at(OutputOffset + 0) =
EncodeByte(m_InputBuffer.at(InputOffset + 0) >> 2);
// Encode second sextet from first and second octet....
m_Base64EncodedString.at(OutputOffset + 1) = EncodeByte(
(m_InputBuffer.at(InputOffset + 0) & static_cast<byte>(0x03)) << 4 |
m_InputBuffer.at(InputOffset + 1) >> 4);
// Encode third sextet from second and third octet...
m_Base64EncodedString.at(OutputOffset + 2) = EncodeByte(
(m_InputBuffer.at(InputOffset + 1) & static_cast<byte>(0x0F)) << 2 |
m_InputBuffer.at(InputOffset + 2) >> 6);
// Encode fourth sextet from third octet...
m_Base64EncodedString.at(OutputOffset + 3) =
EncodeByte(m_InputBuffer.at(InputOffset + 2) & static_cast<byte>(0x3F));
// Stride i/o pointers...
InputOffset += 3;
OutputOffset += 4;
}
// Since the length of padded base 64 encoding must always be a multiple of
// four, after the last octet triplet, were there any additional octets in
// the input to encode that were less than three in number?
switch (m_InputBuffer.size() % 3) {
// Exactly one trailing octet followed...
case 1:
// Encode first sextet from remaining octet...
m_Base64EncodedString.at(OutputOffset + 0) =
EncodeByte(m_InputBuffer.at(InputOffset + 0) >> 2);
// Encode second sextet from remaining octet and empty second one...
m_Base64EncodedString.at(OutputOffset + 1) = EncodeByte(
(m_InputBuffer.at(InputOffset + 0) & static_cast<byte>(0x03)) << 4);
// Padd the two sextets with two padding characters to ensure the
// total length is a multiple of four...
m_Base64EncodedString.at(OutputOffset + 2) = '=';
m_Base64EncodedString.at(OutputOffset + 3) = '=';
break;
// Exactly two trailing octets followed...
case 2:
// Encode first sextet from first octet...
m_Base64EncodedString.at(OutputOffset + 0) =
EncodeByte(m_InputBuffer.at(InputOffset + 0) >> 2);
// Encode second sextet from first and second octet...
m_Base64EncodedString.at(OutputOffset + 1) = EncodeByte(
(m_InputBuffer.at(InputOffset + 0) & static_cast<byte>(0x03)) << 4 |
m_InputBuffer.at(InputOffset + 1) >> 4);
// Encode third sextet from second and dummy third octet...
m_Base64EncodedString.at(OutputOffset + 2) = EncodeByte(
(m_InputBuffer.at(InputOffset + 1) & static_cast<byte>(0x0F)) << 2);
// Padd three sextets with a single padding character to ensure the
// total length is a multiple of four...
m_Base64EncodedString.at(OutputOffset + 3) = '=';
break;
}
// Return constant reference to encoded data to caller...
return m_Base64EncodedString;
}
// Encode single binary byte to 6-bit base 64 character...
inline unsigned char Base64Encoder::EncodeByte(const byte Byte) const {
// Capital letter 'A' is ASCII 65 and zero in base 64...
if (static_cast<unsigned char>(Byte) < 26)
return static_cast<unsigned char>(Byte) + 'A';
// Lowercase letter 'a' is ASCII 97 and 26 in base 64...
if (static_cast<unsigned char>(Byte) < 52)
return static_cast<unsigned char>(Byte) + 71;
// Numeric digit '0' is ASCII 48 and 52 in base 64...
if (static_cast<unsigned char>(Byte) < 62)
return static_cast<unsigned char>(Byte) - 4;
// '+' is ASCII 43 and 62 in base 64...
if (static_cast<unsigned char>(Byte) == 62)
return '+';
// '/' is ASCII 47 and 63 in base 64...
if (static_cast<unsigned char>(Byte) == 63)
return '/';
// And lastly anything that can't be represented in 6-bits we return 64...
return 64;
}
// Encode a string into base 64 format...
string Base64Encoder::EncodeString(const string &StringInput) {
// Allocate storage for binary form of message...
vector<byte> BinaryInput(StringInput.size());
// Convert message to binary form...
transform(StringInput.begin(), StringInput.end(), BinaryInput.begin(),
[](const char Character) { return byte(Character); });
// Encode to base 64...
Base64Encoder Encoder(BinaryInput);
// Return encoded string to caller by value...
return Encoder.Encode();
}
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
Implementation of common HTTP headers described by the RFC Implementation of common HTTP headers described by the RFC
*/ */
#include <pistache/base64.h>
#include <pistache/common.h> #include <pistache/common.h>
#include <pistache/config.h> #include <pistache/config.h>
#include <pistache/http.h> #include <pistache/http.h>
...@@ -283,6 +284,127 @@ void ContentLength::parse(const std::string &data) { ...@@ -283,6 +284,127 @@ void ContentLength::parse(const std::string &data) {
void ContentLength::write(std::ostream &os) const { os << value_; } void ContentLength::write(std::ostream &os) const { os << value_; }
// What type of authorization method was used?
Authorization::Method Authorization::getMethod() const noexcept {
// Basic...
if (hasMethod<Method::Basic>())
return Method::Basic;
// Bearer...
else if (hasMethod<Method::Bearer>())
return Method::Bearer;
// Unknown...
else
return Method::Unknown;
}
// Authorization is basic method...
template <>
bool Authorization::hasMethod<Authorization::Method::Basic>() const noexcept {
// Method should begin with "Basic: "
if (value().rfind("Basic ", 0) == std::string::npos)
return false;
// Verify value is long enough to contain basic method's credentials...
if (value().length() <= std::string("Basic ").length())
return false;
// Looks good...
return true;
}
// Authorization is bearer method...
template <>
bool Authorization::hasMethod<Authorization::Method::Bearer>() const noexcept {
// Method should begin with "Bearer: "
if (value().rfind("Bearer ", 0) == std::string::npos)
return false;
// Verify value is long enough to contain basic method's credentials...
if (value().length() <= std::string("Bearer ").length())
return false;
// Looks good...
return true;
}
// Get decoded user ID if basic method was used...
std::string Authorization::getBasicUser() const {
// Verify basic authorization method was used...
if (!hasMethod<Authorization::Method::Basic>())
throw std::runtime_error("Authorization header does not use Basic method.");
// Extract encoded credentials...
const std::string EncodedCredentials(
value_.begin() + std::string("Basic ").length(), value_.end());
// Decode them...
Base64Decoder Decoder(EncodedCredentials);
const std::vector<std::byte> &BinaryDecodedCredentials = Decoder.Decode();
// Transform to string...
std::string DecodedCredentials;
for (std::byte CurrentByte : BinaryDecodedCredentials)
DecodedCredentials.push_back(static_cast<char>(CurrentByte));
// Find user ID and password delimiter...
const auto Delimiter = DecodedCredentials.find_first_of(':');
// None detected. Assume this is a malformed header...
if (Delimiter == std::string::npos)
return std::string();
// Extract and return just the user ID...
return std::string(DecodedCredentials.begin(),
DecodedCredentials.begin() + Delimiter);
}
// Get decoded password if basic method was used...
std::string Authorization::getBasicPassword() const {
// Verify basic authorization method was used...
if (!hasMethod<Authorization::Method::Basic>())
throw std::runtime_error("Authorization header does not use Basic method.");
// Extract encoded credentials...
const std::string EncodedCredentials(
value_.begin() + std::string("Basic ").length(), value_.end());
// Decode them...
Base64Decoder Decoder(EncodedCredentials);
const std::vector<std::byte> &BinaryDecodedCredentials = Decoder.Decode();
// Transform to string...
std::string DecodedCredentials;
for (std::byte CurrentByte : BinaryDecodedCredentials)
DecodedCredentials.push_back(static_cast<char>(CurrentByte));
// Find user ID and password delimiter...
const auto Delimiter = DecodedCredentials.find_first_of(':');
// None detected. Assume this is a malformed header...
if (Delimiter == std::string::npos)
return std::string();
// Extract and return just the password...
return std::string(DecodedCredentials.begin() + Delimiter + 1,
DecodedCredentials.end());
}
// Set encoded user ID and password for basic method...
void Authorization::setBasicUserPassword(const std::string &User,
const std::string &Password) {
// Verify user ID does not contain a colon...
if (User.find_first_of(':') != std::string::npos)
throw std::runtime_error("User ID cannot contain a colon.");
// Format credentials string...
const std::string Credentials = User + std::string(":") + Password;
// Encode credentials...
value_ = std::string("Basic ") + Base64Encoder::EncodeString(Credentials);
}
void Authorization::parse(const std::string &data) { void Authorization::parse(const std::string &data) {
try { try {
value_ = data; value_ = data;
......
...@@ -271,7 +271,43 @@ TEST(headers_test, content_length) { ...@@ -271,7 +271,43 @@ TEST(headers_test, content_length) {
ASSERT_TRUE(cl.value() == 3495U); ASSERT_TRUE(cl.value() == 3495U);
} }
TEST(headers_test, authorization_test) { // Verify authorization header with basic method works correctly...
TEST(headers_test, authorization_basic_test) {
Pistache::Http::Header::Authorization au;
std::ostringstream oss;
// Sample basic method authorization header for credentials
// Aladdin:OpenSesame base 64 encoded...
const std::string BasicEncodedValue = "Basic QWxhZGRpbjpPcGVuU2VzYW1l";
// Try parsing the raw basic authorization value...
au.parse(BasicEncodedValue);
// Verify what went in is what came out...
au.write(oss);
ASSERT_TRUE(BasicEncodedValue == oss.str());
oss = std::ostringstream();
// Verify authorization header recognizes it is basic method and no other...
ASSERT_TRUE(
au.hasMethod<Pistache::Http::Header::Authorization::Method::Basic>());
ASSERT_FALSE(
au.hasMethod<Pistache::Http::Header::Authorization::Method::Bearer>());
// Set credentials from decoded user and password...
au.setBasicUserPassword("Aladdin", "OpenSesame");
// Verify it encoded correctly...
au.write(oss);
ASSERT_TRUE(BasicEncodedValue == oss.str());
oss = std::ostringstream();
// Verify it decoded correctly...
ASSERT_TRUE(au.getBasicUser() == "Aladdin");
ASSERT_TRUE(au.getBasicPassword() == "OpenSesame");
}
TEST(headers_test, authorization_bearer_test) {
Pistache::Http::Header::Authorization au; Pistache::Http::Header::Authorization au;
std::ostringstream oss; std::ostringstream oss;
au.parse("Bearer " au.parse("Bearer "
...@@ -281,6 +317,11 @@ TEST(headers_test, authorization_test) { ...@@ -281,6 +317,11 @@ TEST(headers_test, authorization_test) {
"d0131JxqX4xSZLlO5xMRrCPBgn_00OxKJ9CQdnpjpuzblNQd2-A"); "d0131JxqX4xSZLlO5xMRrCPBgn_00OxKJ9CQdnpjpuzblNQd2-A");
au.write(oss); au.write(oss);
ASSERT_TRUE(
au.hasMethod<Pistache::Http::Header::Authorization::Method::Bearer>());
ASSERT_FALSE(
au.hasMethod<Pistache::Http::Header::Authorization::Method::Basic>());
ASSERT_TRUE( ASSERT_TRUE(
"Bearer " "Bearer "
"eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXUyJ9." "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXUyJ9."
......
VERSION_MAJOR 0 VERSION_MAJOR 0
VERSION_MINOR 0 VERSION_MINOR 0
VERSION_PATCH 002 VERSION_PATCH 002
VERSION_GIT_DATE 20200117 VERSION_GIT_DATE 20200301
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment