Commit 3e6addba authored by Mathieu Stefani's avatar Mathieu Stefani Committed by Mathieu STEFANI

Follow-up of work on type safe headers and improved http parser

parent 5b65633d
...@@ -12,10 +12,11 @@ class MyHandler : public Net::Http::Handler { ...@@ -12,10 +12,11 @@ class MyHandler : public Net::Http::Handler {
if (req.resource == "/ping") { if (req.resource == "/ping") {
if (req.method == Net::Http::Method::Get) { if (req.method == Net::Http::Method::Get) {
auto host = req.headers.getHeader<Net::Http::Host>(); using namespace Net::Http;
cout << "Host = " << host->host() << endl;
Net::Http::Response response(Net::Http::Code::Ok, "PONG"); Net::Http::Response response(Net::Http::Code::Ok, "PONG");
// response.headers
// .add(std::make_shared<Server>("lys"));
response.writeTo(peer); response.writeTo(peer);
} }
...@@ -35,7 +36,7 @@ int main(int argc, char *argv[]) { ...@@ -35,7 +36,7 @@ int main(int argc, char *argv[]) {
cout << "Cores = " << hardware_concurrency() << endl; cout << "Cores = " << hardware_concurrency() << endl;
Net::Http::Server server(addr); Net::Http::Endpoint server(addr);
server.setHandler(std::make_shared<MyHandler>()); server.setHandler(std::make_shared<MyHandler>());
server.serve(); server.serve();
......
/* flags.h
Mathieu Stefani, 18 August 2015
Make it easy to have bitwise operators for scoped or unscoped enumerations
*/
#pragma once
#include <type_traits>
#include <iostream>
// Looks like gcc 4.6 does not implement std::underlying_type
namespace detail {
template<size_t N> struct TypeStorage;
template<> struct TypeStorage<sizeof(uint8_t)> {
typedef uint8_t Type;
};
template<> struct TypeStorage<sizeof(uint16_t)> {
typedef uint16_t Type;
};
template<> struct TypeStorage<sizeof(uint32_t)> {
typedef uint32_t Type;
};
template<> struct TypeStorage<sizeof(uint64_t)> {
typedef uint64_t Type;
};
template<typename T> struct UnderlyingType {
typedef typename TypeStorage<sizeof(T)>::Type Type;
};
}
template<typename T>
class Flags {
public:
static_assert(std::is_enum<T>::value, "Flags only works with enumerations");
typedef typename detail::UnderlyingType<T>::Type Type;
Flags() { }
Flags(T val) : val(val)
{
}
#define DEFINE_BITWISE_OP_CONST(Op) \
Flags<T> operator Op (T rhs) const { \
return Flags<T>( \
static_cast<T>(static_cast<Type>(val) Op static_cast<Type>(rhs)) \
); \
} \
\
Flags<T> operator Op (Flags<T> rhs) const { \
return Flags<T>( \
static_cast<T>(static_cast<Type>(val) Op static_cast<Type>(rhs.val)) \
); \
}
DEFINE_BITWISE_OP_CONST(|)
DEFINE_BITWISE_OP_CONST(&)
DEFINE_BITWISE_OP_CONST(^)
#undef DEFINE_BITWISE_OP_CONST
#define DEFINE_BITWISE_OP(Op) \
Flags<T>& operator Op##=(T rhs) { \
val = static_cast<T>( \
static_cast<Type>(val) Op static_cast<Type>(rhs) \
); \
return *this; \
} \
\
Flags<T>& operator Op##=(Flags<T> rhs) { \
val = static_cast<T>( \
static_cast<Type>(val) Op static_cast<Type>(rhs.val) \
); \
return *this; \
}
DEFINE_BITWISE_OP(|)
DEFINE_BITWISE_OP(&)
DEFINE_BITWISE_OP(^)
#undef DEFINE_BITWISE_OP
bool hasFlag(T flag) const {
return static_cast<T>(
static_cast<Type>(val) & static_cast<Type>(flag)
) == flag;
}
Flags<T>& setFlag(T flag) {
return *this &= flag;
}
Flags<T>& toggleFlag(T flag) {
return *this ^= flag;
}
operator T() const {
return val;
}
private:
T val;
};
#define DEFINE_BITWISE_OP(Op, T) \
inline T operator Op (T lhs, T rhs) { \
typedef detail::UnderlyingType<T>::Type UnderlyingType; \
return static_cast<T>( \
static_cast<UnderlyingType>(lhs) Op static_cast<UnderlyingType>(rhs) \
); \
}
#define DECLARE_FLAGS_OPERATORS(T) \
DEFINE_BITWISE_OP(&, T) \
DEFINE_BITWISE_OP(|, T)
...@@ -24,46 +24,71 @@ static constexpr char CRLF[] = {CR, LF}; ...@@ -24,46 +24,71 @@ static constexpr char CRLF[] = {CR, LF};
namespace Private { namespace Private {
void bool
Parser::advance(size_t count) { Parser::Cursor::advance(size_t count)
if (cursor + count >= len) { {
raise("Early EOF"); if (value + count >= sizeof (buff.data)) {
//parser->raise("Early EOF");
}
else if (value + count >= buff.len) {
return false;
} }
cursor += count; value += count;
return true;
} }
bool bool
Parser::eol() const { Parser::Cursor::eol() const {
return buffer[cursor] == CR && next() == LF; return buff.data[value] == CR && next() == LF;
}
int
Parser::Cursor::next() const {
if (value + 1 >= sizeof (buff.data)) {
//parser->raise("Early EOF");
}
else if (value + 1 >= buff.len) {
return Eof;
}
return buff.data[value + 1];
} }
char char
Parser::next() const { Parser::Cursor::current() const {
if (cursor + 1 >= len) { return buff.data[value];
raise("Early EOF");
} }
return buffer[cursor + 1]; const char *
Parser::Cursor::offset() const {
return buff.data + value;
} }
Request const char *
Parser::expectRequest() { Parser::Cursor::offset(size_t off) const {
expectRequestLine(); return buff.data + off;
expectHeaders(); }
expectBody();
return request; size_t
Parser::Cursor::diff(size_t previous) const {
return value - previous;
} }
// 5.1 Request-Line
void void
Parser::expectRequestLine() { Parser::Step::raise(const char* msg) {
throw ParsingError(msg);
}
Parser::State
Parser::RequestLineStep::apply(Cursor& cursor) {
Reverter reverter(cursor);
auto tryMatch = [&](const char* const str) { auto tryMatch = [&](const char* const str) {
const size_t len = std::strlen(str); const size_t len = std::strlen(str);
if (strncmp(buffer, str, len) == 0) { if (strncmp(cursor.offset(), str, len) == 0) {
cursor += len - 1; cursor.advance(len - 1);
return true; return true;
} }
return false; return false;
...@@ -72,123 +97,152 @@ namespace Private { ...@@ -72,123 +97,152 @@ namespace Private {
// Method // Method
if (tryMatch("OPTIONS")) { if (tryMatch("OPTIONS")) {
request.method = Method::Options; request->method = Method::Options;
} }
else if (tryMatch("GET")) { else if (tryMatch("GET")) {
request.method = Method::Get; request->method = Method::Get;
} }
else if (tryMatch("POST")) { else if (tryMatch("POST")) {
request.method = Method::Post; request->method = Method::Post;
} }
else if (tryMatch("HEAD")) { else if (tryMatch("HEAD")) {
request.method = Method::Head; request->method = Method::Head;
} }
else if (tryMatch("PUT")) { else if (tryMatch("PUT")) {
request.method = Method::Put; request->method = Method::Put;
} }
else if (tryMatch("DELETE")) { else if (tryMatch("DELETE")) {
request.method = Method::Delete; request->method = Method::Delete;
} }
if (next() != ' ') { auto n = cursor.next();
raise("Malformed HTTP request after Method"); if (n == Cursor::Eof) return State::Again;
}
// SP else if (n != ' ') raise("Malformed HTTP Request");
advance(2); if (!cursor.advance(2)) return State::Again;
// Request-URI
size_t start = cursor; size_t start = cursor;
while (next() != ' ') { while ((n = cursor.next()) != Cursor::Eof && n != ' ') {
advance(1); if (!cursor.advance(1)) return State::Again;
if (eol()) {
raise("Malformed HTTP request after Request-URI");
}
} }
request.resource = std::string(buffer + start, cursor - start + 1); request->resource = std::string(cursor.offset(start), cursor.diff(start) + 1);
if (next() != ' ') { if ((n = cursor.next()) == Cursor::Eof) return State::Again;
if (n != ' ')
raise("Malformed HTTP request after Request-URI"); raise("Malformed HTTP request after Request-URI");
}
// SP // SP
advance(2); if (!cursor.advance(2)) return State::Again;
// HTTP-Version // HTTP-Version
start = cursor; start = cursor;
while (!eol()) while (!cursor.eol())
advance(1); if (!cursor.advance(1)) return State::Again;
const size_t diff = cursor - start; const size_t diff = cursor.diff(start);
if (strncmp(buffer + start, "HTTP/1.0", diff) == 0) { if (strncmp(cursor.offset(start), "HTTP/1.0", diff) == 0) {
request.version = Version::Http10; request->version = Version::Http10;
} }
else if (strncmp(buffer + start, "HTTP/1.1", diff) == 0) { else if (strncmp(cursor.offset(start), "HTTP/1.1", diff) == 0) {
request.version = Version::Http11; request->version = Version::Http11;
} }
else { else {
raise("Encountered invalid HTTP version"); raise("Encountered invalid HTTP version");
} }
advance(2); if (!cursor.advance(2)) return State::Again;
reverter.clear();
return State::Next;
} }
void Parser::State
Parser::expectHeaders() { Parser::HeadersStep::apply(Cursor& cursor) {
while (!eol()) { Reverter reverter(cursor);
while (!cursor.eol()) {
Reverter headerReverter(cursor);
// Read the header name // Read the header name
size_t start = cursor; size_t start = cursor;
while (buffer[cursor] != ':') while (cursor.current() != ':')
advance(1); if (!cursor.advance(1)) return State::Again;
advance(1); if (!cursor.advance(1)) return State::Again;
std::string name = std::string(buffer + start, cursor - start - 1); std::string name = std::string(cursor.offset(start), cursor.diff(start) - 1);
// Skip the ':' // Skip the ':'
advance(1); if (!cursor.advance(1)) return State::Again;
// Read the header value // Read the header value
start = cursor; start = cursor;
while (!eol()) while (!cursor.eol()) {
advance(1); if (!cursor.advance(1)) return State::Again;
}
if (HeaderRegistry::isRegistered(name)) { if (HeaderRegistry::isRegistered(name)) {
std::shared_ptr<Header> header = HeaderRegistry::makeHeader(name); std::shared_ptr<Header> header = HeaderRegistry::makeHeader(name);
header->parseRaw(buffer + start, cursor - start); header->parseRaw(cursor.offset(start), cursor.diff(start));
request.headers.add(header); request->headers.add(header);
}
else {
std::string value = std::string(buffer + start, cursor - start);
} }
// CRLF // CRLF
advance(2); if (!cursor.advance(2)) return State::Again;
headerReverter.clear();
} }
return Parser::State::Next;
} }
void Parser::State
Parser::expectBody() { Parser::BodyStep::apply(Cursor& cursor) {
if (contentLength > 0) { auto cl = request->headers.tryGet<ContentLength>();
advance(2); // CRLF
if (cursor + contentLength > len) { if (cl) {
throw std::runtime_error("Corrupted HTTP Body"); // CRLF
if (!cursor.advance(2)) return State::Again;
auto len = cl->value();
auto start = cursor;
if (!cursor.advance(len)) return State::Again;
request->body = std::string(cursor.offset(start), cursor.diff(start));
} }
request.body = std::string(buffer + cursor, contentLength); return Parser::State::Done;
} }
Parser::State
Parser::parse() {
State state = State::Again;
do {
Step *step = allSteps[currentStep].get();
state = step->apply(cursor);
if (state == State::Next) {
++currentStep;
} }
} while (state == State::Next);
void // Should be either Again or Done
Parser::raise(const char* msg) const { return state;
throw ParsingError(msg); }
bool
Parser::feed(const char* data, size_t len) {
if (len + buffer.len >= sizeof (buffer.data)) {
return false;
}
memcpy(buffer.data + buffer.len, data, len);
buffer.len += len;
} }
ssize_t ssize_t
...@@ -280,6 +334,15 @@ Response::writeTo(Tcp::Peer& peer) ...@@ -280,6 +334,15 @@ Response::writeTo(Tcp::Peer& peer)
fmt.writeString(codeString(static_cast<Code>(code_))); fmt.writeString(codeString(static_cast<Code>(code_)));
fmt.writeRaw(CRLF, 2); fmt.writeRaw(CRLF, 2);
for (const auto& header: headers.list()) {
std::ostringstream oss;
header->write(oss);
std::string str = oss.str();
fmt.writeRaw(str.c_str(), str.size());
fmt.writeRaw(CRLF, 2);
}
fmt.writeHeader("Content-Length", body.size()); fmt.writeHeader("Content-Length", body.size());
fmt.writeRaw(CRLF, 2); fmt.writeRaw(CRLF, 2);
...@@ -288,15 +351,16 @@ Response::writeTo(Tcp::Peer& peer) ...@@ -288,15 +351,16 @@ Response::writeTo(Tcp::Peer& peer)
const size_t len = fmt.cursor() - buffer; const size_t len = fmt.cursor() - buffer;
ssize_t bytes = send(fd, buffer, len, 0); ssize_t bytes = send(fd, buffer, len, 0);
//cout << bytes << " bytes sent" << endl;
} }
void void
Handler::onInput(const char* buffer, size_t len, Tcp::Peer& peer) { Handler::onInput(const char* buffer, size_t len, Tcp::Peer& peer) {
Private::Parser parser(buffer, len); Private::Parser parser(buffer, len);
try { try {
auto request = parser.expectRequest(); Private::Parser::State state = parser.parse();
onRequest(request, peer); if (state == Private::Parser::State::Done) {
onRequest(parser.request, peer);
}
} catch (const Private::ParsingError &err) { } catch (const Private::ParsingError &err) {
cerr << "Error when parsing HTTP request: " << err.what() << endl; cerr << "Error when parsing HTTP request: " << err.what() << endl;
} }
...@@ -306,28 +370,25 @@ void ...@@ -306,28 +370,25 @@ void
Handler::onOutput() { Handler::onOutput() {
} }
Server::Server() Endpoint::Endpoint()
{ } { }
Server::Server(const Net::Address& addr) Endpoint::Endpoint(const Net::Address& addr)
: listener(addr) : listener(addr)
{ } { }
void void
Server::setHandler(const std::shared_ptr<Handler>& handler) { Endpoint::setHandler(const std::shared_ptr<Handler>& handler) {
handler_ = handler; handler_ = handler;
} }
void void
Server::serve() Endpoint::serve()
{ {
if (!handler_) if (!handler_)
throw std::runtime_error("Must call setHandler() prior to serve()"); throw std::runtime_error("Must call setHandler() prior to serve()");
listener.init(8, listener.init(8, Tcp::Options::InstallSignalHandler);
Tcp::Options::NoDelay |
Tcp::Options::InstallSignalHandler |
Tcp::Options::ReuseAddr);
listener.setHandler(handler_); listener.setHandler(handler_);
if (listener.bind()) { if (listener.bind()) {
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <type_traits> #include <type_traits>
#include <stdexcept> #include <stdexcept>
#include <array>
#include "listener.h" #include "listener.h"
#include "net.h" #include "net.h"
#include "http_headers.h" #include "http_headers.h"
...@@ -129,30 +130,127 @@ namespace Private { ...@@ -129,30 +130,127 @@ namespace Private {
struct Parser { struct Parser {
Parser(const char* buffer, size_t len) struct Buffer {
: buffer(buffer) char data[Const::MaxBuffer];
, len(len) size_t len;
, cursor(0) };
, contentLength(-1)
struct Cursor {
static constexpr int Eof = -1;
Cursor(const Buffer &buffer, size_t initialPos = 0)
: buff(buffer)
, value(initialPos)
{ } { }
Request expectRequest(); bool advance(size_t count);
void expectRequestLine();
void expectHeaders(); operator size_t() const { return value; }
void expectBody();
void advance(size_t count);
bool eol() const; bool eol() const;
int next() const;
char current() const;
const char* buffer; const char *offset() const;
size_t len; const char *offset(size_t off) const;
size_t cursor;
size_t diff(size_t before) const;
const Buffer& buff;
size_t value;
};
struct Reverter {
Reverter(Cursor& cursor)
: cursor(cursor)
, pos(cursor.value)
, active(true)
{ }
void revert() {
cursor.value = pos;
}
void clear() {
active = false;
}
~Reverter() {
if (active) cursor.value = pos;
}
Cursor& cursor;
size_t pos;
bool active;
};
enum class State { Again, Next, Done };
struct Step {
Step(Request* request)
: request(request)
{ }
virtual State apply(Cursor& cursor) = 0;
void raise(const char* msg);
Request *request;
};
struct RequestLineStep : public Step {
RequestLineStep(Request* request)
: Step(request)
{ }
State apply(Cursor& cursor);
};
struct HeadersStep : public Step {
HeadersStep(Request* request)
: Step(request)
{ }
State apply(Cursor& cursor);
};
struct BodyStep : public Step {
BodyStep(Request* request)
: Step(request)
{ }
State apply(Cursor& cursor);
};
Parser(const char* data, size_t len)
: contentLength(-1)
, currentStep(0)
, cursor(buffer)
{
allSteps[0].reset(new RequestLineStep(&request));
allSteps[1].reset(new HeadersStep(&request));
allSteps[2].reset(new BodyStep(&request));
feed(data, len);
}
bool feed(const char* data, size_t len);
State parse();
Buffer buffer;
Cursor cursor;
Request request;
char next() const;
private: private:
void raise(const char* msg) const; static constexpr size_t StepsCount = 3;
std::unique_ptr<Step> allSteps[StepsCount];
size_t currentStep;
ssize_t contentLength; ssize_t contentLength;
Request request;
}; };
struct Writer { struct Writer {
...@@ -206,10 +304,10 @@ public: ...@@ -206,10 +304,10 @@ public:
virtual void onRequest(const Request& request, Tcp::Peer& peer) = 0; virtual void onRequest(const Request& request, Tcp::Peer& peer) = 0;
}; };
class Server { class Endpoint {
public: public:
Server(); Endpoint();
Server(const Net::Address& addr); Endpoint(const Net::Address& addr);
void setHandler(const std::shared_ptr<Handler>& handler); void setHandler(const std::shared_ptr<Handler>& handler);
void serve(); void serve();
......
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
*/ */
#include "http_header.h" #include "http_header.h"
#include "common.h"
#include <stdexcept> #include <stdexcept>
#include <iterator>
using namespace std; using namespace std;
...@@ -13,6 +15,28 @@ namespace Net { ...@@ -13,6 +15,28 @@ namespace Net {
namespace Http { namespace Http {
const char* encodingString(Encoding encoding) {
switch (encoding) {
case Encoding::Gzip:
return "gzip";
case Encoding::Compress:
return "compress";
case Encoding::Deflate:
return "deflate";
case Encoding::Identity:
return "identity";
case Encoding::Unknown:
return "unknown";
}
unreachable();
}
void
Header::parse(const std::string& data) {
parseRaw(data.c_str(), data.size());
}
void void
Header::parseRaw(const char *str, size_t len) { Header::parseRaw(const char *str, size_t len) {
parse(std::string(str, len)); parse(std::string(str, len));
...@@ -31,9 +55,103 @@ ContentLength::parse(const std::string& data) { ...@@ -31,9 +55,103 @@ ContentLength::parse(const std::string& data) {
} }
} }
void
ContentLength::write(std::ostream& os) const {
os << "Content-Length: " << value_;
}
void void
Host::parse(const std::string& data) { Host::parse(const std::string& data) {
auto pos = data.find(':');
if (pos != std::string::npos) {
std::string h = data.substr(0, pos);
int16_t p = std::stoi(data.substr(pos + 1));
host_ = h;
port_ = p;
} else {
host_ = data; host_ = data;
port_ = -1;
}
}
void
Host::write(std::ostream& os) const {
os << host_;
if (port_ != -1) {
os << ":" << port_;
}
}
void
UserAgent::parse(const std::string& data) {
ua_ = data;
}
void
UserAgent::write(std::ostream& os) const {
os << "User-Agent: " << ua_;
}
void
Accept::parseRaw(const char *str, size_t len) {
}
void
Accept::write(std::ostream& os) const {
}
void
ContentEncoding::parseRaw(const char* str, size_t len) {
// TODO: case-insensitive
//
if (!strncmp(str, "gzip", len)) {
encoding_ = Encoding::Gzip;
}
else if (!strncmp(str, "deflate", len)) {
encoding_ = Encoding::Deflate;
}
else if (!strncmp(str, "compress", len)) {
encoding_ = Encoding::Compress;
}
else if (!strncmp(str, "identity", len)) {
encoding_ = Encoding::Identity;
}
else {
encoding_ = Encoding::Unknown;
}
}
void
ContentEncoding::write(std::ostream& os) const {
os << "Content-Encoding: " << encodingString(encoding_);
}
Server::Server(const std::vector<std::string>& tokens)
: tokens_(tokens)
{ }
Server::Server(const std::string& token)
{
tokens_.push_back(token);
}
Server::Server(const char* token)
{
tokens_.emplace_back(token);
}
void
Server::parse(const std::string& data)
{
}
void
Server::write(std::ostream& os) const
{
os << "Server: ";
std::copy(std::begin(tokens_), std::end(tokens_),
std::ostream_iterator<std::string>(os, " "));
} }
} // namespace Http } // namespace Http
......
...@@ -7,30 +7,121 @@ ...@@ -7,30 +7,121 @@
#pragma once #pragma once
#include <string> #include <string>
#include <type_traits>
#include <memory>
#include <ostream>
#include <vector>
#define NAME(header_name) \ #define SAFE_HEADER_CAST
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
namespace Net { namespace Net {
namespace Http { namespace Http {
#ifdef SAFE_HEADER_CAST
namespace detail {
// compile-time FNV-1a hashing algorithm
static constexpr uint64_t basis = 14695981039346656037ULL;
static constexpr uint64_t prime = 1099511628211ULL;
constexpr uint64_t hash_one(char c, const char* remain, unsigned long long value)
{
return c == 0 ? value : hash_one(remain[0], remain + 1, (value ^ c) * prime);
}
constexpr uint64_t hash(const char* str)
{
return hash_one(str[0], str + 1, basis);
}
} // namespace detail
#endif
#ifdef SAFE_HEADER_CAST
#define NAME(header_name) \
static constexpr uint64_t Hash = detail::hash(header_name); \
uint64_t hash() const { return Hash; } \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
#else
#define NAME(header_name) \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
#endif
// 3.5 Content Codings
enum class Encoding {
Gzip,
Compress,
Deflate,
Identity,
Unknown
};
const char* encodingString(Encoding encoding);
class Header { class Header {
public: public:
virtual void parse(const std::string& data) = 0; virtual const char *name() const = 0;
virtual void parse(const std::string& data);
virtual void parseRaw(const char* str, size_t len); virtual void parseRaw(const char* str, size_t len);
virtual const char *name() const = 0; virtual void write(std::ostream& stream) const = 0;
#ifdef SAFE_HEADER_CAST
virtual uint64_t hash() const = 0;
#endif
};
//virtual void write(Net::Tcp::Stream& writer) = 0; template<typename H> struct IsHeader {
template<typename T>
static std::true_type test(decltype(T::Name) *);
template<typename T>
static std::false_type test(...);
static constexpr bool value
= std::is_base_of<Header, H>::value
&& std::is_same<decltype(test<H>(nullptr)), std::true_type>::value;
}; };
#ifdef SAFE_HEADER_CAST
template<typename To>
typename std::enable_if<IsHeader<To>::value, std::shared_ptr<To>>::type
header_cast(const std::shared_ptr<Header>& from)
{
return static_cast<To *>(0)->Hash == from->hash() ?
std::static_pointer_cast<To>(from) : nullptr;
}
template<typename To>
typename std::enable_if<IsHeader<To>::value, std::shared_ptr<const To>>::type
header_cast(const std::shared_ptr<const Header>& from)
{
return static_cast<To *>(0)->Hash == from->hash() ?
std::static_pointer_cast<const To>(from) : nullptr;
}
#endif
class ContentLength : public Header { class ContentLength : public Header {
public: public:
NAME("Content-Length"); NAME("Content-Length");
ContentLength()
: value_(0)
{ }
ContentLength(uint64_t val)
: value_(val)
{ }
void parse(const std::string& data); void parse(const std::string& data);
void write(std::ostream& os) const;
uint64_t value() const { return value_; } uint64_t value() const { return value_; }
private: private:
...@@ -41,15 +132,98 @@ class Host : public Header { ...@@ -41,15 +132,98 @@ class Host : public Header {
public: public:
NAME("Host"); NAME("Host");
Host()
: host_()
, port_(-1)
{ }
Host(const std::string& host, int16_t port = -1)
: host_(host)
, port_(port)
{ }
void parse(const std::string& data); void parse(const std::string& data);
void write(std::ostream& os) const;
std::string host() const { return host_; } std::string host() const { return host_; }
int16_t port() const { return port_; }
private: private:
std::string host_; std::string host_;
int16_t port_;
};
class UserAgent : public Header {
public:
NAME("User-Agent")
UserAgent() { }
UserAgent(const std::string& ua) :
ua_(ua)
{ }
void parse(const std::string& data);
void write(std::ostream& os) const;
std::string ua() const { return ua_; }
private:
std::string ua_;
};
class Accept : public Header {
public:
NAME("Accept")
Accept() { }
void parseRaw(const char *str, size_t len);
void write(std::ostream& os) const;
private:
std::string data;
};
class ContentEncoding : public Header {
public:
NAME("Content-Encoding")
ContentEncoding()
: encoding_(Encoding::Identity)
{ }
ContentEncoding(Encoding encoding)
: encoding_(encoding)
{ }
void parseRaw(const char* str, size_t len);
void write(std::ostream& os) const;
Encoding encoding() const { return encoding_; }
private:
Encoding encoding_;
};
class Server : public Header {
public:
NAME("Server")
Server() { }
Server(const std::vector<std::string>& tokens);
Server(const std::string& token);
Server(const char* token);
void parse(const std::string& data);
void write(std::ostream& os) const;
std::vector<std::string> tokens() const { return tokens_; }
private:
std::vector<std::string> tokens_;
}; };
} // namespace Http } // namespace Http
} // namespace Net } // namespace Net
#undef NAME
...@@ -57,19 +57,73 @@ HeaderRegistry::isRegistered(const std::string& name) { ...@@ -57,19 +57,73 @@ HeaderRegistry::isRegistered(const std::string& name) {
return it != std::end(registry); return it != std::end(registry);
} }
void Headers&
Headers::add(const std::shared_ptr<Header>& header) { Headers::add(const std::shared_ptr<Header>& header) {
headers.insert(std::make_pair(header->name(), header)); headers.insert(std::make_pair(header->name(), header));
return *this;
}
std::shared_ptr<const Header>
Headers::get(const std::string& name) const {
auto header = getImpl(name);
if (!header.first) {
throw std::runtime_error("Could not find header");
}
return header.second;
} }
std::shared_ptr<Header> std::shared_ptr<Header>
Headers::getHeader(const std::string& name) const { Headers::get(const std::string& name) {
auto header = getImpl(name);
if (!header.first) {
throw std::runtime_error("Could not find header");
}
return header.second;
}
std::shared_ptr<const Header>
Headers::tryGet(const std::string& name) const {
auto header = getImpl(name);
if (!header.first) return nullptr;
return header.second;
}
std::shared_ptr<Header>
Headers::tryGet(const std::string& name) {
auto header = getImpl(name);
if (!header.first) return nullptr;
return header.second;
}
bool
Headers::has(const std::string& name) const {
return getImpl(name).first;
}
std::vector<std::shared_ptr<Header>>
Headers::list() const {
std::vector<std::shared_ptr<Header>> ret;
ret.reserve(headers.size());
for (const auto& h: headers) {
ret.push_back(h.second);
}
return ret;
}
std::pair<bool, std::shared_ptr<Header>>
Headers::getImpl(const std::string& name) const {
auto it = headers.find(name); auto it = headers.find(name);
if (it == std::end(headers)) { if (it == std::end(headers)) {
throw std::runtime_error("Could not find header"); return std::make_pair(false, nullptr);
} }
return it->second; return std::make_pair(true, it->second);
} }
namespace { namespace {
...@@ -77,6 +131,9 @@ namespace { ...@@ -77,6 +131,9 @@ namespace {
AtInit() { AtInit() {
HeaderRegistry::registerHeader<ContentLength>(); HeaderRegistry::registerHeader<ContentLength>();
HeaderRegistry::registerHeader<Host>(); HeaderRegistry::registerHeader<Host>();
HeaderRegistry::registerHeader<Accept>();
HeaderRegistry::registerHeader<UserAgent>();
HeaderRegistry::registerHeader<ContentEncoding>();
} }
} atInit; } atInit;
} }
......
...@@ -19,21 +19,57 @@ class Headers { ...@@ -19,21 +19,57 @@ class Headers {
public: public:
template<typename H> template<typename H>
/*
typename std::enable_if< typename std::enable_if<
std::is_base_of<H, Header>::value, std::shared_ptr<Header> IsHeader<H>::value, std::shared_ptr<const H>
>::type >::type
*/ get() const {
std::shared_ptr<H> return std::static_pointer_cast<const H>(get(H::Name));
getHeader() const {
return std::static_pointer_cast<H>(getHeader(H::Name));
} }
template<typename H>
typename std::enable_if<
IsHeader<H>::value, std::shared_ptr<H>
>::type
get() {
return std::static_pointer_cast<H>(get(H::Name));
}
template<typename H>
typename std::enable_if<
IsHeader<H>::value, std::shared_ptr<const H>
>::type
tryGet() const {
return std::static_pointer_cast<const H>(tryGet(H::Name));
}
template<typename H>
typename std::enable_if<
IsHeader<H>::value, std::shared_ptr<H>
>::type
tryGet() {
return std::static_pointer_cast<H>(tryGet(H::Name));
}
Headers& add(const std::shared_ptr<Header>& header);
void add(const std::shared_ptr<Header>& header); std::shared_ptr<const Header> get(const std::string& name) const;
std::shared_ptr<Header> get(const std::string& name);
std::shared_ptr<Header> getHeader(const std::string& name) const; std::shared_ptr<const Header> tryGet(const std::string& name) const;
std::shared_ptr<Header> tryGet(const std::string& name);
template<typename H>
typename std::enable_if<
IsHeader<H>::value, bool
>::type
has() const {
return has(H::Name);
}
bool has(const std::string& name) const;
std::vector<std::shared_ptr<Header>> list() const;
private: private:
std::pair<bool, std::shared_ptr<Header>> getImpl(const std::string& name) const;
std::unordered_map<std::string, std::shared_ptr<Header>> headers; std::unordered_map<std::string, std::shared_ptr<Header>> headers;
}; };
...@@ -43,11 +79,9 @@ struct HeaderRegistry { ...@@ -43,11 +79,9 @@ struct HeaderRegistry {
template<typename H> template<typename H>
static static
/* typename std::enable_if< typename std::enable_if<
std::is_base_of<H, Header>::value, void IsHeader<H>::value, void
>::type >::type
*/
void
registerHeader() { registerHeader() {
registerHeader(H::Name, []() -> std::unique_ptr<Header> { registerHeader(H::Name, []() -> std::unique_ptr<Header> {
return std::unique_ptr<Header>(new H()); return std::unique_ptr<Header>(new H());
......
...@@ -184,14 +184,17 @@ IoWorker::run() { ...@@ -184,14 +184,17 @@ IoWorker::run() {
mailbox.bind(poller); mailbox.bind(poller);
std::chrono::milliseconds timeout(-1);
for (;;) { for (;;) {
std::vector<Polling::Event> events; std::vector<Polling::Event> events;
int ready_fds; int ready_fds;
switch(ready_fds = poller.poll(events, 32, std::chrono::milliseconds(0))) { switch(ready_fds = poller.poll(events, 1024, timeout)) {
case -1: case -1:
break; break;
case 0: case 0:
timeout = std::chrono::milliseconds(-1);
break; break;
default: default:
for (const auto& event: events) { for (const auto& event: events) {
...@@ -206,6 +209,7 @@ IoWorker::run() { ...@@ -206,6 +209,7 @@ IoWorker::run() {
} }
} }
} }
timeout = std::chrono::milliseconds(0);
break; break;
} }
} }
...@@ -380,10 +384,6 @@ Listener::dispatchPeer(const std::shared_ptr<Peer>& peer) { ...@@ -380,10 +384,6 @@ Listener::dispatchPeer(const std::shared_ptr<Peer>& peer) {
} }
void
Listener::handleSigint(int) {
}
} // namespace Tcp } // namespace Tcp
} // namespace Net } // namespace Net
...@@ -95,7 +95,6 @@ private: ...@@ -95,7 +95,6 @@ private:
std::shared_ptr<Handler> handler_; std::shared_ptr<Handler> handler_;
void dispatchPeer(const std::shared_ptr<Peer>& peer); void dispatchPeer(const std::shared_ptr<Peer>& peer);
void handleSigint(int);
}; };
class Handler { class Handler {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment