Commit 3e6addba authored by Mathieu Stefani's avatar Mathieu Stefani Committed by Mathieu STEFANI

Follow-up of work on type safe headers and improved http parser

parent 5b65633d
......@@ -12,10 +12,11 @@ class MyHandler : public Net::Http::Handler {
if (req.resource == "/ping") {
if (req.method == Net::Http::Method::Get) {
auto host = req.headers.getHeader<Net::Http::Host>();
cout << "Host = " << host->host() << endl;
using namespace Net::Http;
Net::Http::Response response(Net::Http::Code::Ok, "PONG");
// response.headers
// .add(std::make_shared<Server>("lys"));
response.writeTo(peer);
}
......@@ -35,7 +36,7 @@ int main(int argc, char *argv[]) {
cout << "Cores = " << hardware_concurrency() << endl;
Net::Http::Server server(addr);
Net::Http::Endpoint server(addr);
server.setHandler(std::make_shared<MyHandler>());
server.serve();
......
/* flags.h
Mathieu Stefani, 18 August 2015
Make it easy to have bitwise operators for scoped or unscoped enumerations
*/
#pragma once
#include <type_traits>
#include <iostream>
// Looks like gcc 4.6 does not implement std::underlying_type
namespace detail {
template<size_t N> struct TypeStorage;
template<> struct TypeStorage<sizeof(uint8_t)> {
typedef uint8_t Type;
};
template<> struct TypeStorage<sizeof(uint16_t)> {
typedef uint16_t Type;
};
template<> struct TypeStorage<sizeof(uint32_t)> {
typedef uint32_t Type;
};
template<> struct TypeStorage<sizeof(uint64_t)> {
typedef uint64_t Type;
};
template<typename T> struct UnderlyingType {
typedef typename TypeStorage<sizeof(T)>::Type Type;
};
}
template<typename T>
class Flags {
public:
static_assert(std::is_enum<T>::value, "Flags only works with enumerations");
typedef typename detail::UnderlyingType<T>::Type Type;
Flags() { }
Flags(T val) : val(val)
{
}
#define DEFINE_BITWISE_OP_CONST(Op) \
Flags<T> operator Op (T rhs) const { \
return Flags<T>( \
static_cast<T>(static_cast<Type>(val) Op static_cast<Type>(rhs)) \
); \
} \
\
Flags<T> operator Op (Flags<T> rhs) const { \
return Flags<T>( \
static_cast<T>(static_cast<Type>(val) Op static_cast<Type>(rhs.val)) \
); \
}
DEFINE_BITWISE_OP_CONST(|)
DEFINE_BITWISE_OP_CONST(&)
DEFINE_BITWISE_OP_CONST(^)
#undef DEFINE_BITWISE_OP_CONST
#define DEFINE_BITWISE_OP(Op) \
Flags<T>& operator Op##=(T rhs) { \
val = static_cast<T>( \
static_cast<Type>(val) Op static_cast<Type>(rhs) \
); \
return *this; \
} \
\
Flags<T>& operator Op##=(Flags<T> rhs) { \
val = static_cast<T>( \
static_cast<Type>(val) Op static_cast<Type>(rhs.val) \
); \
return *this; \
}
DEFINE_BITWISE_OP(|)
DEFINE_BITWISE_OP(&)
DEFINE_BITWISE_OP(^)
#undef DEFINE_BITWISE_OP
bool hasFlag(T flag) const {
return static_cast<T>(
static_cast<Type>(val) & static_cast<Type>(flag)
) == flag;
}
Flags<T>& setFlag(T flag) {
return *this &= flag;
}
Flags<T>& toggleFlag(T flag) {
return *this ^= flag;
}
operator T() const {
return val;
}
private:
T val;
};
#define DEFINE_BITWISE_OP(Op, T) \
inline T operator Op (T lhs, T rhs) { \
typedef detail::UnderlyingType<T>::Type UnderlyingType; \
return static_cast<T>( \
static_cast<UnderlyingType>(lhs) Op static_cast<UnderlyingType>(rhs) \
); \
}
#define DECLARE_FLAGS_OPERATORS(T) \
DEFINE_BITWISE_OP(&, T) \
DEFINE_BITWISE_OP(|, T)
......@@ -24,46 +24,71 @@ static constexpr char CRLF[] = {CR, LF};
namespace Private {
void
Parser::advance(size_t count) {
if (cursor + count >= len) {
raise("Early EOF");
bool
Parser::Cursor::advance(size_t count)
{
if (value + count >= sizeof (buff.data)) {
//parser->raise("Early EOF");
}
else if (value + count >= buff.len) {
return false;
}
cursor += count;
value += count;
return true;
}
bool
Parser::eol() const {
return buffer[cursor] == CR && next() == LF;
Parser::Cursor::eol() const {
return buff.data[value] == CR && next() == LF;
}
char
Parser::next() const {
if (cursor + 1 >= len) {
raise("Early EOF");
int
Parser::Cursor::next() const {
if (value + 1 >= sizeof (buff.data)) {
//parser->raise("Early EOF");
}
else if (value + 1 >= buff.len) {
return Eof;
}
return buffer[cursor + 1];
return buff.data[value + 1];
}
Request
Parser::expectRequest() {
expectRequestLine();
expectHeaders();
expectBody();
char
Parser::Cursor::current() const {
return buff.data[value];
}
const char *
Parser::Cursor::offset() const {
return buff.data + value;
}
const char *
Parser::Cursor::offset(size_t off) const {
return buff.data + off;
}
return request;
size_t
Parser::Cursor::diff(size_t previous) const {
return value - previous;
}
// 5.1 Request-Line
void
Parser::expectRequestLine() {
Parser::Step::raise(const char* msg) {
throw ParsingError(msg);
}
Parser::State
Parser::RequestLineStep::apply(Cursor& cursor) {
Reverter reverter(cursor);
auto tryMatch = [&](const char* const str) {
const size_t len = std::strlen(str);
if (strncmp(buffer, str, len) == 0) {
cursor += len - 1;
if (strncmp(cursor.offset(), str, len) == 0) {
cursor.advance(len - 1);
return true;
}
return false;
......@@ -72,123 +97,152 @@ namespace Private {
// Method
if (tryMatch("OPTIONS")) {
request.method = Method::Options;
request->method = Method::Options;
}
else if (tryMatch("GET")) {
request.method = Method::Get;
request->method = Method::Get;
}
else if (tryMatch("POST")) {
request.method = Method::Post;
request->method = Method::Post;
}
else if (tryMatch("HEAD")) {
request.method = Method::Head;
request->method = Method::Head;
}
else if (tryMatch("PUT")) {
request.method = Method::Put;
request->method = Method::Put;
}
else if (tryMatch("DELETE")) {
request.method = Method::Delete;
request->method = Method::Delete;
}
if (next() != ' ') {
raise("Malformed HTTP request after Method");
}
auto n = cursor.next();
if (n == Cursor::Eof) return State::Again;
// SP
advance(2);
// Request-URI
else if (n != ' ') raise("Malformed HTTP Request");
if (!cursor.advance(2)) return State::Again;
size_t start = cursor;
while (next() != ' ') {
advance(1);
if (eol()) {
raise("Malformed HTTP request after Request-URI");
}
while ((n = cursor.next()) != Cursor::Eof && n != ' ') {
if (!cursor.advance(1)) return State::Again;
}
request.resource = std::string(buffer + start, cursor - start + 1);
if (next() != ' ') {
request->resource = std::string(cursor.offset(start), cursor.diff(start) + 1);
if ((n = cursor.next()) == Cursor::Eof) return State::Again;
if (n != ' ')
raise("Malformed HTTP request after Request-URI");
}
// SP
advance(2);
if (!cursor.advance(2)) return State::Again;
// HTTP-Version
start = cursor;
while (!eol())
advance(1);
while (!cursor.eol())
if (!cursor.advance(1)) return State::Again;
const size_t diff = cursor - start;
if (strncmp(buffer + start, "HTTP/1.0", diff) == 0) {
request.version = Version::Http10;
const size_t diff = cursor.diff(start);
if (strncmp(cursor.offset(start), "HTTP/1.0", diff) == 0) {
request->version = Version::Http10;
}
else if (strncmp(buffer + start, "HTTP/1.1", diff) == 0) {
request.version = Version::Http11;
else if (strncmp(cursor.offset(start), "HTTP/1.1", diff) == 0) {
request->version = Version::Http11;
}
else {
raise("Encountered invalid HTTP version");
}
advance(2);
if (!cursor.advance(2)) return State::Again;
reverter.clear();
return State::Next;
}
void
Parser::expectHeaders() {
while (!eol()) {
Parser::State
Parser::HeadersStep::apply(Cursor& cursor) {
Reverter reverter(cursor);
while (!cursor.eol()) {
Reverter headerReverter(cursor);
// Read the header name
size_t start = cursor;
while (buffer[cursor] != ':')
advance(1);
while (cursor.current() != ':')
if (!cursor.advance(1)) return State::Again;
advance(1);
if (!cursor.advance(1)) return State::Again;
std::string name = std::string(buffer + start, cursor - start - 1);
std::string name = std::string(cursor.offset(start), cursor.diff(start) - 1);
// Skip the ':'
advance(1);
if (!cursor.advance(1)) return State::Again;
// Read the header value
start = cursor;
while (!eol())
advance(1);
while (!cursor.eol()) {
if (!cursor.advance(1)) return State::Again;
}
if (HeaderRegistry::isRegistered(name)) {
std::shared_ptr<Header> header = HeaderRegistry::makeHeader(name);
header->parseRaw(buffer + start, cursor - start);
request.headers.add(header);
}
else {
std::string value = std::string(buffer + start, cursor - start);
header->parseRaw(cursor.offset(start), cursor.diff(start));
request->headers.add(header);
}
// CRLF
advance(2);
if (!cursor.advance(2)) return State::Again;
headerReverter.clear();
}
return Parser::State::Next;
}
void
Parser::expectBody() {
if (contentLength > 0) {
advance(2); // CRLF
Parser::State
Parser::BodyStep::apply(Cursor& cursor) {
auto cl = request->headers.tryGet<ContentLength>();
if (cursor + contentLength > len) {
throw std::runtime_error("Corrupted HTTP Body");
}
if (cl) {
// CRLF
if (!cursor.advance(2)) return State::Again;
auto len = cl->value();
auto start = cursor;
if (!cursor.advance(len)) return State::Again;
request.body = std::string(buffer + cursor, contentLength);
request->body = std::string(cursor.offset(start), cursor.diff(start));
}
return Parser::State::Done;
}
void
Parser::raise(const char* msg) const {
throw ParsingError(msg);
Parser::State
Parser::parse() {
State state = State::Again;
do {
Step *step = allSteps[currentStep].get();
state = step->apply(cursor);
if (state == State::Next) {
++currentStep;
}
} while (state == State::Next);
// Should be either Again or Done
return state;
}
bool
Parser::feed(const char* data, size_t len) {
if (len + buffer.len >= sizeof (buffer.data)) {
return false;
}
memcpy(buffer.data + buffer.len, data, len);
buffer.len += len;
}
ssize_t
......@@ -280,6 +334,15 @@ Response::writeTo(Tcp::Peer& peer)
fmt.writeString(codeString(static_cast<Code>(code_)));
fmt.writeRaw(CRLF, 2);
for (const auto& header: headers.list()) {
std::ostringstream oss;
header->write(oss);
std::string str = oss.str();
fmt.writeRaw(str.c_str(), str.size());
fmt.writeRaw(CRLF, 2);
}
fmt.writeHeader("Content-Length", body.size());
fmt.writeRaw(CRLF, 2);
......@@ -288,15 +351,16 @@ Response::writeTo(Tcp::Peer& peer)
const size_t len = fmt.cursor() - buffer;
ssize_t bytes = send(fd, buffer, len, 0);
//cout << bytes << " bytes sent" << endl;
}
void
Handler::onInput(const char* buffer, size_t len, Tcp::Peer& peer) {
Private::Parser parser(buffer, len);
try {
auto request = parser.expectRequest();
onRequest(request, peer);
Private::Parser::State state = parser.parse();
if (state == Private::Parser::State::Done) {
onRequest(parser.request, peer);
}
} catch (const Private::ParsingError &err) {
cerr << "Error when parsing HTTP request: " << err.what() << endl;
}
......@@ -306,28 +370,25 @@ void
Handler::onOutput() {
}
Server::Server()
Endpoint::Endpoint()
{ }
Server::Server(const Net::Address& addr)
Endpoint::Endpoint(const Net::Address& addr)
: listener(addr)
{ }
void
Server::setHandler(const std::shared_ptr<Handler>& handler) {
Endpoint::setHandler(const std::shared_ptr<Handler>& handler) {
handler_ = handler;
}
void
Server::serve()
Endpoint::serve()
{
if (!handler_)
throw std::runtime_error("Must call setHandler() prior to serve()");
listener.init(8,
Tcp::Options::NoDelay |
Tcp::Options::InstallSignalHandler |
Tcp::Options::ReuseAddr);
listener.init(8, Tcp::Options::InstallSignalHandler);
listener.setHandler(handler_);
if (listener.bind()) {
......
......@@ -8,6 +8,7 @@
#include <type_traits>
#include <stdexcept>
#include <array>
#include "listener.h"
#include "net.h"
#include "http_headers.h"
......@@ -129,30 +130,127 @@ namespace Private {
struct Parser {
Parser(const char* buffer, size_t len)
: buffer(buffer)
, len(len)
, cursor(0)
, contentLength(-1)
{ }
struct Buffer {
char data[Const::MaxBuffer];
size_t len;
};
Request expectRequest();
void expectRequestLine();
void expectHeaders();
void expectBody();
struct Cursor {
static constexpr int Eof = -1;
void advance(size_t count);
bool eol() const;
Cursor(const Buffer &buffer, size_t initialPos = 0)
: buff(buffer)
, value(initialPos)
{ }
const char* buffer;
size_t len;
size_t cursor;
bool advance(size_t count);
operator size_t() const { return value; }
bool eol() const;
int next() const;
char current() const;
const char *offset() const;
const char *offset(size_t off) const;
size_t diff(size_t before) const;
const Buffer& buff;
size_t value;
};
struct Reverter {
Reverter(Cursor& cursor)
: cursor(cursor)
, pos(cursor.value)
, active(true)
{ }
void revert() {
cursor.value = pos;
}
void clear() {
active = false;
}
~Reverter() {
if (active) cursor.value = pos;
}
Cursor& cursor;
size_t pos;
bool active;
};
enum class State { Again, Next, Done };
struct Step {
Step(Request* request)
: request(request)
{ }
virtual State apply(Cursor& cursor) = 0;
void raise(const char* msg);
Request *request;
};
struct RequestLineStep : public Step {
RequestLineStep(Request* request)
: Step(request)
{ }
State apply(Cursor& cursor);
};
struct HeadersStep : public Step {
HeadersStep(Request* request)
: Step(request)
{ }
State apply(Cursor& cursor);
};
struct BodyStep : public Step {
BodyStep(Request* request)
: Step(request)
{ }
State apply(Cursor& cursor);
};
Parser(const char* data, size_t len)
: contentLength(-1)
, currentStep(0)
, cursor(buffer)
{
allSteps[0].reset(new RequestLineStep(&request));
allSteps[1].reset(new HeadersStep(&request));
allSteps[2].reset(new BodyStep(&request));
feed(data, len);
}
bool feed(const char* data, size_t len);
State parse();
Buffer buffer;
Cursor cursor;
Request request;
char next() const;
private:
void raise(const char* msg) const;
static constexpr size_t StepsCount = 3;
std::unique_ptr<Step> allSteps[StepsCount];
size_t currentStep;
ssize_t contentLength;
Request request;
};
struct Writer {
......@@ -206,10 +304,10 @@ public:
virtual void onRequest(const Request& request, Tcp::Peer& peer) = 0;
};
class Server {
class Endpoint {
public:
Server();
Server(const Net::Address& addr);
Endpoint();
Endpoint(const Net::Address& addr);
void setHandler(const std::shared_ptr<Handler>& handler);
void serve();
......
......@@ -5,7 +5,9 @@
*/
#include "http_header.h"
#include "common.h"
#include <stdexcept>
#include <iterator>
using namespace std;
......@@ -13,6 +15,28 @@ namespace Net {
namespace Http {
const char* encodingString(Encoding encoding) {
switch (encoding) {
case Encoding::Gzip:
return "gzip";
case Encoding::Compress:
return "compress";
case Encoding::Deflate:
return "deflate";
case Encoding::Identity:
return "identity";
case Encoding::Unknown:
return "unknown";
}
unreachable();
}
void
Header::parse(const std::string& data) {
parseRaw(data.c_str(), data.size());
}
void
Header::parseRaw(const char *str, size_t len) {
parse(std::string(str, len));
......@@ -31,9 +55,103 @@ ContentLength::parse(const std::string& data) {
}
}
void
ContentLength::write(std::ostream& os) const {
os << "Content-Length: " << value_;
}
void
Host::parse(const std::string& data) {
host_ = data;
auto pos = data.find(':');
if (pos != std::string::npos) {
std::string h = data.substr(0, pos);
int16_t p = std::stoi(data.substr(pos + 1));
host_ = h;
port_ = p;
} else {
host_ = data;
port_ = -1;
}
}
void
Host::write(std::ostream& os) const {
os << host_;
if (port_ != -1) {
os << ":" << port_;
}
}
void
UserAgent::parse(const std::string& data) {
ua_ = data;
}
void
UserAgent::write(std::ostream& os) const {
os << "User-Agent: " << ua_;
}
void
Accept::parseRaw(const char *str, size_t len) {
}
void
Accept::write(std::ostream& os) const {
}
void
ContentEncoding::parseRaw(const char* str, size_t len) {
// TODO: case-insensitive
//
if (!strncmp(str, "gzip", len)) {
encoding_ = Encoding::Gzip;
}
else if (!strncmp(str, "deflate", len)) {
encoding_ = Encoding::Deflate;
}
else if (!strncmp(str, "compress", len)) {
encoding_ = Encoding::Compress;
}
else if (!strncmp(str, "identity", len)) {
encoding_ = Encoding::Identity;
}
else {
encoding_ = Encoding::Unknown;
}
}
void
ContentEncoding::write(std::ostream& os) const {
os << "Content-Encoding: " << encodingString(encoding_);
}
Server::Server(const std::vector<std::string>& tokens)
: tokens_(tokens)
{ }
Server::Server(const std::string& token)
{
tokens_.push_back(token);
}
Server::Server(const char* token)
{
tokens_.emplace_back(token);
}
void
Server::parse(const std::string& data)
{
}
void
Server::write(std::ostream& os) const
{
os << "Server: ";
std::copy(std::begin(tokens_), std::end(tokens_),
std::ostream_iterator<std::string>(os, " "));
}
} // namespace Http
......
......@@ -7,30 +7,121 @@
#pragma once
#include <string>
#include <type_traits>
#include <memory>
#include <ostream>
#include <vector>
#define NAME(header_name) \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
#define SAFE_HEADER_CAST
namespace Net {
namespace Http {
#ifdef SAFE_HEADER_CAST
namespace detail {
// compile-time FNV-1a hashing algorithm
static constexpr uint64_t basis = 14695981039346656037ULL;
static constexpr uint64_t prime = 1099511628211ULL;
constexpr uint64_t hash_one(char c, const char* remain, unsigned long long value)
{
return c == 0 ? value : hash_one(remain[0], remain + 1, (value ^ c) * prime);
}
constexpr uint64_t hash(const char* str)
{
return hash_one(str[0], str + 1, basis);
}
} // namespace detail
#endif
#ifdef SAFE_HEADER_CAST
#define NAME(header_name) \
static constexpr uint64_t Hash = detail::hash(header_name); \
uint64_t hash() const { return Hash; } \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
#else
#define NAME(header_name) \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
#endif
// 3.5 Content Codings
enum class Encoding {
Gzip,
Compress,
Deflate,
Identity,
Unknown
};
const char* encodingString(Encoding encoding);
class Header {
public:
virtual void parse(const std::string& data) = 0;
virtual const char *name() const = 0;
virtual void parse(const std::string& data);
virtual void parseRaw(const char* str, size_t len);
virtual const char *name() const = 0;
virtual void write(std::ostream& stream) const = 0;
#ifdef SAFE_HEADER_CAST
virtual uint64_t hash() const = 0;
#endif
//virtual void write(Net::Tcp::Stream& writer) = 0;
};
template<typename H> struct IsHeader {
template<typename T>
static std::true_type test(decltype(T::Name) *);
template<typename T>
static std::false_type test(...);
static constexpr bool value
= std::is_base_of<Header, H>::value
&& std::is_same<decltype(test<H>(nullptr)), std::true_type>::value;
};
#ifdef SAFE_HEADER_CAST
template<typename To>
typename std::enable_if<IsHeader<To>::value, std::shared_ptr<To>>::type
header_cast(const std::shared_ptr<Header>& from)
{
return static_cast<To *>(0)->Hash == from->hash() ?
std::static_pointer_cast<To>(from) : nullptr;
}
template<typename To>
typename std::enable_if<IsHeader<To>::value, std::shared_ptr<const To>>::type
header_cast(const std::shared_ptr<const Header>& from)
{
return static_cast<To *>(0)->Hash == from->hash() ?
std::static_pointer_cast<const To>(from) : nullptr;
}
#endif
class ContentLength : public Header {
public:
NAME("Content-Length");
ContentLength()
: value_(0)
{ }
ContentLength(uint64_t val)
: value_(val)
{ }
void parse(const std::string& data);
void write(std::ostream& os) const;
uint64_t value() const { return value_; }
private:
......@@ -41,15 +132,98 @@ class Host : public Header {
public:
NAME("Host");
Host()
: host_()
, port_(-1)
{ }
Host(const std::string& host, int16_t port = -1)
: host_(host)
, port_(port)
{ }
void parse(const std::string& data);
void write(std::ostream& os) const;
std::string host() const { return host_; }
int16_t port() const { return port_; }
private:
std::string host_;
int16_t port_;
};
class UserAgent : public Header {
public:
NAME("User-Agent")
UserAgent() { }
UserAgent(const std::string& ua) :
ua_(ua)
{ }
void parse(const std::string& data);
void write(std::ostream& os) const;
std::string ua() const { return ua_; }
private:
std::string ua_;
};
class Accept : public Header {
public:
NAME("Accept")
Accept() { }
void parseRaw(const char *str, size_t len);
void write(std::ostream& os) const;
private:
std::string data;
};
class ContentEncoding : public Header {
public:
NAME("Content-Encoding")
ContentEncoding()
: encoding_(Encoding::Identity)
{ }
ContentEncoding(Encoding encoding)
: encoding_(encoding)
{ }
void parseRaw(const char* str, size_t len);
void write(std::ostream& os) const;
Encoding encoding() const { return encoding_; }
private:
Encoding encoding_;
};
class Server : public Header {
public:
NAME("Server")
Server() { }
Server(const std::vector<std::string>& tokens);
Server(const std::string& token);
Server(const char* token);
void parse(const std::string& data);
void write(std::ostream& os) const;
std::vector<std::string> tokens() const { return tokens_; }
private:
std::vector<std::string> tokens_;
};
} // namespace Http
} // namespace Net
#undef NAME
......@@ -57,19 +57,73 @@ HeaderRegistry::isRegistered(const std::string& name) {
return it != std::end(registry);
}
void
Headers&
Headers::add(const std::shared_ptr<Header>& header) {
headers.insert(std::make_pair(header->name(), header));
return *this;
}
std::shared_ptr<const Header>
Headers::get(const std::string& name) const {
auto header = getImpl(name);
if (!header.first) {
throw std::runtime_error("Could not find header");
}
return header.second;
}
std::shared_ptr<Header>
Headers::get(const std::string& name) {
auto header = getImpl(name);
if (!header.first) {
throw std::runtime_error("Could not find header");
}
return header.second;
}
std::shared_ptr<const Header>
Headers::tryGet(const std::string& name) const {
auto header = getImpl(name);
if (!header.first) return nullptr;
return header.second;
}
std::shared_ptr<Header>
Headers::getHeader(const std::string& name) const {
Headers::tryGet(const std::string& name) {
auto header = getImpl(name);
if (!header.first) return nullptr;
return header.second;
}
bool
Headers::has(const std::string& name) const {
return getImpl(name).first;
}
std::vector<std::shared_ptr<Header>>
Headers::list() const {
std::vector<std::shared_ptr<Header>> ret;
ret.reserve(headers.size());
for (const auto& h: headers) {
ret.push_back(h.second);
}
return ret;
}
std::pair<bool, std::shared_ptr<Header>>
Headers::getImpl(const std::string& name) const {
auto it = headers.find(name);
if (it == std::end(headers)) {
throw std::runtime_error("Could not find header");
return std::make_pair(false, nullptr);
}
return it->second;
return std::make_pair(true, it->second);
}
namespace {
......@@ -77,6 +131,9 @@ namespace {
AtInit() {
HeaderRegistry::registerHeader<ContentLength>();
HeaderRegistry::registerHeader<Host>();
HeaderRegistry::registerHeader<Accept>();
HeaderRegistry::registerHeader<UserAgent>();
HeaderRegistry::registerHeader<ContentEncoding>();
}
} atInit;
}
......
......@@ -19,21 +19,57 @@ class Headers {
public:
template<typename H>
/*
typename std::enable_if<
std::is_base_of<H, Header>::value, std::shared_ptr<Header>
IsHeader<H>::value, std::shared_ptr<const H>
>::type
*/
std::shared_ptr<H>
getHeader() const {
return std::static_pointer_cast<H>(getHeader(H::Name));
get() const {
return std::static_pointer_cast<const H>(get(H::Name));
}
template<typename H>
typename std::enable_if<
IsHeader<H>::value, std::shared_ptr<H>
>::type
get() {
return std::static_pointer_cast<H>(get(H::Name));
}
template<typename H>
typename std::enable_if<
IsHeader<H>::value, std::shared_ptr<const H>
>::type
tryGet() const {
return std::static_pointer_cast<const H>(tryGet(H::Name));
}
template<typename H>
typename std::enable_if<
IsHeader<H>::value, std::shared_ptr<H>
>::type
tryGet() {
return std::static_pointer_cast<H>(tryGet(H::Name));
}
Headers& add(const std::shared_ptr<Header>& header);
void add(const std::shared_ptr<Header>& header);
std::shared_ptr<const Header> get(const std::string& name) const;
std::shared_ptr<Header> get(const std::string& name);
std::shared_ptr<Header> getHeader(const std::string& name) const;
std::shared_ptr<const Header> tryGet(const std::string& name) const;
std::shared_ptr<Header> tryGet(const std::string& name);
template<typename H>
typename std::enable_if<
IsHeader<H>::value, bool
>::type
has() const {
return has(H::Name);
}
bool has(const std::string& name) const;
std::vector<std::shared_ptr<Header>> list() const;
private:
std::pair<bool, std::shared_ptr<Header>> getImpl(const std::string& name) const;
std::unordered_map<std::string, std::shared_ptr<Header>> headers;
};
......@@ -43,14 +79,12 @@ struct HeaderRegistry {
template<typename H>
static
/* typename std::enable_if<
std::is_base_of<H, Header>::value, void
typename std::enable_if<
IsHeader<H>::value, void
>::type
*/
void
registerHeader() {
registerHeader(H::Name, []() -> std::unique_ptr<Header> {
return std::unique_ptr<Header>(new H());
return std::unique_ptr<Header>(new H());
});
}
......
......@@ -184,14 +184,17 @@ IoWorker::run() {
mailbox.bind(poller);
std::chrono::milliseconds timeout(-1);
for (;;) {
std::vector<Polling::Event> events;
int ready_fds;
switch(ready_fds = poller.poll(events, 32, std::chrono::milliseconds(0))) {
switch(ready_fds = poller.poll(events, 1024, timeout)) {
case -1:
break;
case 0:
timeout = std::chrono::milliseconds(-1);
break;
default:
for (const auto& event: events) {
......@@ -206,6 +209,7 @@ IoWorker::run() {
}
}
}
timeout = std::chrono::milliseconds(0);
break;
}
}
......@@ -380,10 +384,6 @@ Listener::dispatchPeer(const std::shared_ptr<Peer>& peer) {
}
void
Listener::handleSigint(int) {
}
} // namespace Tcp
} // namespace Net
......@@ -95,7 +95,6 @@ private:
std::shared_ptr<Handler> handler_;
void dispatchPeer(const std::shared_ptr<Peer>& peer);
void handleSigint(int);
};
class Handler {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment