Commit 5b65633d authored by Mathieu Stefani's avatar Mathieu Stefani Committed by Mathieu STEFANI

Polished the code a little bit and introduced a new Poller class that wraps system calls to epoll.

Also started working on type-safe http headers parsing
parent 9fcc1835
#include "net.h" #include "net.h"
#include "peer.h" #include "peer.h"
#include "http.h" #include "http.h"
#include "http_headers.h"
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
...@@ -8,14 +9,15 @@ using namespace std; ...@@ -8,14 +9,15 @@ using namespace std;
class MyHandler : public Net::Http::Handler { class MyHandler : public Net::Http::Handler {
void onRequest(const Net::Http::Request& req, Net::Tcp::Peer& peer) { void onRequest(const Net::Http::Request& req, Net::Tcp::Peer& peer) {
cout << "Received " << methodString(req.method) << " request on " << req.resource << endl;
if (req.resource == "/ping") { if (req.resource == "/ping") {
if (req.method == Net::Http::Method::Get) { if (req.method == Net::Http::Method::Get) {
cout << "PONG" << endl;
auto host = req.headers.getHeader<Net::Http::Host>();
cout << "Host = " << host->host() << endl;
Net::Http::Response response(Net::Http::Code::Ok, "PONG"); Net::Http::Response response(Net::Http::Code::Ok, "PONG");
response.writeTo(peer); response.writeTo(peer);
} }
} }
} }
......
...@@ -5,6 +5,8 @@ set(SOURCE_FILES ...@@ -5,6 +5,8 @@ set(SOURCE_FILES
os.cc os.cc
peer.cc peer.cc
http.cc http.cc
http_header.cc
http_headers.cc
) )
add_library(net ${SOURCE_FILES}) add_library(net ${SOURCE_FILES})
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <sstream> #include <sstream>
#include <cstdio> #include <cstdio>
#include <cassert> #include <cassert>
#include <cstring>
#include <stdexcept>
#include <sys/types.h> #include <sys/types.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <netdb.h> #include <netdb.h>
...@@ -17,9 +19,15 @@ ...@@ -17,9 +19,15 @@
do { \ do { \
auto ret = __VA_ARGS__; \ auto ret = __VA_ARGS__; \
if (ret < 0) { \ if (ret < 0) { \
perror(#__VA_ARGS__); \ const char* str = #__VA_ARGS__; \
cerr << gai_strerror(ret) << endl; \ std::ostringstream oss; \
return false; \ oss << str << ": "; \
if (errno == 0) { \
oss << gai_strerror(ret); \
} else { \
oss << strerror(errno); \
} \
throw std::runtime_error(oss.str()); \
} \ } \
} while (0) } while (0)
...@@ -36,6 +44,8 @@ ...@@ -36,6 +44,8 @@
}(); \ }(); \
(void) 0 (void) 0
#define unreachable() __builtin_unreachable()
namespace Const { namespace Const {
static constexpr int MaxBacklog = 128; static constexpr int MaxBacklog = 128;
......
...@@ -27,7 +27,7 @@ namespace Private { ...@@ -27,7 +27,7 @@ namespace Private {
void void
Parser::advance(size_t count) { Parser::advance(size_t count) {
if (cursor + count >= len) { if (cursor + count >= len) {
throw std::runtime_error("Early EOF"); raise("Early EOF");
} }
cursor += count; cursor += count;
...@@ -41,7 +41,7 @@ namespace Private { ...@@ -41,7 +41,7 @@ namespace Private {
char char
Parser::next() const { Parser::next() const {
if (cursor + 1 >= len) { if (cursor + 1 >= len) {
throw std::runtime_error("Early EOF"); raise("Early EOF");
} }
return buffer[cursor + 1]; return buffer[cursor + 1];
...@@ -63,13 +63,13 @@ namespace Private { ...@@ -63,13 +63,13 @@ namespace Private {
auto tryMatch = [&](const char* const str) { auto tryMatch = [&](const char* const str) {
const size_t len = std::strlen(str); const size_t len = std::strlen(str);
if (strncmp(buffer, str, len) == 0) { if (strncmp(buffer, str, len) == 0) {
cursor += len; cursor += len - 1;
return true; return true;
} }
return false; return false;
}; };
// 5.1.1 Method // Method
if (tryMatch("OPTIONS")) { if (tryMatch("OPTIONS")) {
request.method = Method::Options; request.method = Method::Options;
...@@ -90,29 +90,49 @@ namespace Private { ...@@ -90,29 +90,49 @@ namespace Private {
request.method = Method::Delete; request.method = Method::Delete;
} }
advance(1); if (next() != ' ') {
raise("Malformed HTTP request after Method");
if (buffer[cursor] != ' ') {
// Exceptionnado
} }
// 5.1.2 Request-URI // SP
advance(2);
// Request-URI
size_t start = cursor; size_t start = cursor;
while (buffer[cursor] != ' ') { while (next() != ' ') {
advance(1); advance(1);
if (eol()) { if (eol()) {
// Exceptionnado raise("Malformed HTTP request after Request-URI");
} }
} }
request.resource = std::string(buffer + start, cursor - start); request.resource = std::string(buffer + start, cursor - start + 1);
if (next() != ' ') {
raise("Malformed HTTP request after Request-URI");
}
// SP
advance(2);
// HTTP-Version
start = cursor;
// Skip HTTP-Version for now
while (!eol()) while (!eol())
advance(1); advance(1);
const size_t diff = cursor - start;
if (strncmp(buffer + start, "HTTP/1.0", diff) == 0) {
request.version = Version::Http10;
}
else if (strncmp(buffer + start, "HTTP/1.1", diff) == 0) {
request.version = Version::Http11;
}
else {
raise("Encountered invalid HTTP version");
}
advance(2); advance(2);
} }
...@@ -128,7 +148,7 @@ namespace Private { ...@@ -128,7 +148,7 @@ namespace Private {
advance(1); advance(1);
std::string fieldName = std::string(buffer + start, cursor - start - 1); std::string name = std::string(buffer + start, cursor - start - 1);
// Skip the ':' // Skip the ':'
advance(1); advance(1);
...@@ -138,15 +158,15 @@ namespace Private { ...@@ -138,15 +158,15 @@ namespace Private {
while (!eol()) while (!eol())
advance(1); advance(1);
std::string fieldValue = std::string(buffer + start, cursor - start);
if (fieldName == "Content-Length") { if (HeaderRegistry::isRegistered(name)) {
size_t pos; std::shared_ptr<Header> header = HeaderRegistry::makeHeader(name);
contentLength = std::stol(fieldValue, &pos); header->parseRaw(buffer + start, cursor - start);
request.headers.add(header);
}
else {
std::string value = std::string(buffer + start, cursor - start);
} }
request.headers.push_back(make_pair(std::move(fieldName), std::move(fieldValue)));
// CRLF // CRLF
advance(2); advance(2);
...@@ -166,6 +186,35 @@ namespace Private { ...@@ -166,6 +186,35 @@ namespace Private {
} }
} }
void
Parser::raise(const char* msg) const {
throw ParsingError(msg);
}
ssize_t
Writer::writeRaw(const void* data, size_t len) {
buf = static_cast<char *>(memcpy(buf, data, len));
buf += len;
return 0;
}
ssize_t
Writer::writeString(const char* str) {
const size_t len = std::strlen(str);
return writeRaw(str, std::strlen(str));
}
ssize_t
Writer::writeHeader(const char* name, const char* value) {
writeString(name);
writeChar(':');
writeString(value);
writeRaw(CRLF, 2);
return 0;
}
} // namespace Private } // namespace Private
...@@ -179,7 +228,7 @@ const char* methodString(Method method) ...@@ -179,7 +228,7 @@ const char* methodString(Method method)
#undef METHOD #undef METHOD
} }
return nullptr; unreachable();
} }
const char* codeString(Code code) const char* codeString(Code code)
...@@ -192,21 +241,28 @@ const char* codeString(Code code) ...@@ -192,21 +241,28 @@ const char* codeString(Code code)
#undef CODE #undef CODE
} }
return nullptr; unreachable();
} }
Message::Message() Message::Message()
: version(Version::Http11)
{ } { }
Request::Request() Request::Request()
: Message() : Message()
{ } { }
Response::Response(int code, std::string body)
{
this->body = std::move(body);
code_ = code;
}
Response::Response(Code code, std::string body) Response::Response(Code code, std::string body)
: Message() : Message()
{ {
this->body = std::move(body); this->body = std::move(body);
code_ = code; code_ = static_cast<int>(code);
} }
void void
...@@ -216,52 +272,34 @@ Response::writeTo(Tcp::Peer& peer) ...@@ -216,52 +272,34 @@ Response::writeTo(Tcp::Peer& peer)
char buffer[Const::MaxBuffer]; char buffer[Const::MaxBuffer];
std::memset(buffer, 0, Const::MaxBuffer); std::memset(buffer, 0, Const::MaxBuffer);
char *p_buf = buffer; Private::Writer fmt(buffer, sizeof buffer);
auto writeRaw = [&](const void* data, size_t len) {
p_buf = static_cast<char *>(memcpy(p_buf, data, len));
p_buf += len;
};
auto writeString = [&](const char* str) {
const size_t len = std::strlen(str);
writeRaw(str, std::strlen(str));
};
auto writeInt = [&](uint64_t value) { fmt.writeString("HTTP/1.1 ");
auto str = std::to_string(value); fmt.writeInt(code_);
writeRaw(str.c_str(), str.size()); fmt.writeChar(' ');
}; fmt.writeString(codeString(static_cast<Code>(code_)));
fmt.writeRaw(CRLF, 2);
auto writeChar = [&](char c) { fmt.writeHeader("Content-Length", body.size());
*p_buf++ = c;
};
writeString("HTTP/1.1 "); fmt.writeRaw(CRLF, 2);
writeInt(static_cast<int>(code_)); fmt.writeString(body.c_str());
writeChar(' ');
writeString(codeString(code_));
writeRaw(CRLF, 2);
writeString("Content-Length:"); const size_t len = fmt.cursor() - buffer;
writeInt(body.size());
writeRaw(CRLF, 2);
writeRaw(CRLF, 2);
writeString(body.c_str());
const size_t len = p_buf - buffer;
ssize_t bytes = send(fd, buffer, len, 0); ssize_t bytes = send(fd, buffer, len, 0);
cout << bytes << " bytes sent" << endl; //cout << bytes << " bytes sent" << endl;
} }
void void
Handler::onInput(const char* buffer, size_t len, Tcp::Peer& peer) { Handler::onInput(const char* buffer, size_t len, Tcp::Peer& peer) {
Private::Parser parser(buffer, len); Private::Parser parser(buffer, len);
auto request = parser.expectRequest(); try {
auto request = parser.expectRequest();
onRequest(request, peer); onRequest(request, peer);
} catch (const Private::ParsingError &err) {
cerr << "Error when parsing HTTP request: " << err.what() << endl;
}
} }
void void
...@@ -286,7 +324,10 @@ Server::serve() ...@@ -286,7 +324,10 @@ Server::serve()
if (!handler_) if (!handler_)
throw std::runtime_error("Must call setHandler() prior to serve()"); throw std::runtime_error("Must call setHandler() prior to serve()");
listener.init(4); listener.init(8,
Tcp::Options::NoDelay |
Tcp::Options::InstallSignalHandler |
Tcp::Options::ReuseAddr);
listener.setHandler(handler_); listener.setHandler(handler_);
if (listener.bind()) { if (listener.bind()) {
......
...@@ -6,8 +6,11 @@ ...@@ -6,8 +6,11 @@
#pragma once #pragma once
#include <type_traits>
#include <stdexcept>
#include "listener.h" #include "listener.h"
#include "net.h" #include "net.h"
#include "http_headers.h"
namespace Net { namespace Net {
...@@ -78,6 +81,11 @@ enum class Code { ...@@ -78,6 +81,11 @@ enum class Code {
#undef CODE #undef CODE
}; };
enum class Version {
Http10, // HTTP/1.0
Http11 // HTTP/1.1
};
const char* methodString(Method method); const char* methodString(Method method);
const char* codeString(Code code); const char* codeString(Code code);
...@@ -85,7 +93,9 @@ const char* codeString(Code code); ...@@ -85,7 +93,9 @@ const char* codeString(Code code);
class Message { class Message {
public: public:
Message(); Message();
std::vector<std::pair<std::string, std::string>> headers; Version version;
Headers headers;
std::string body; std::string body;
}; };
...@@ -101,17 +111,24 @@ public: ...@@ -101,17 +111,24 @@ public:
// 6. Response // 6. Response
class Response : public Message { class Response : public Message {
public: public:
Response(int code, std::string body);
Response(Code code, std::string body); Response(Code code, std::string body);
void writeTo(Tcp::Peer& peer); void writeTo(Tcp::Peer& peer);
std::string mimeType;
private: private:
Code code_; int code_;
}; };
namespace Private { namespace Private {
struct ParsingError : public std::runtime_error {
ParsingError(const char* msg) : std::runtime_error(msg) { }
};
struct Parser { struct Parser {
Parser(const char* buffer, size_t len) Parser(const char* buffer, size_t len)
: buffer(buffer) : buffer(buffer)
, len(len) , len(len)
...@@ -133,10 +150,52 @@ namespace Private { ...@@ -133,10 +150,52 @@ namespace Private {
char next() const; char next() const;
private: private:
void raise(const char* msg) const;
ssize_t contentLength; ssize_t contentLength;
Request request; Request request;
}; };
struct Writer {
Writer(char* buffer, size_t len)
: buf(buffer)
, len(len)
{ }
ssize_t writeRaw(const void* data, size_t len);
ssize_t writeString(const char* str);
template<typename T>
typename std::enable_if<
std::is_integral<T>::value, ssize_t
>::type
writeInt(T value) {
auto str = std::to_string(value);
return writeRaw(str.c_str(), str.size());
}
ssize_t writeChar(char c) {
*buf++ = c;
return 0;
}
ssize_t writeHeader(const char* name, const char* value);
template<typename T>
typename std::enable_if<
std::is_arithmetic<T>::value, ssize_t
>::type
writeHeader(const char* name, T value) {
auto str = std::to_string(value);
return writeHeader(name, str.c_str());
}
char *cursor() const { return buf; }
private:
char* buf;
size_t len;
};
} }
class Handler : public Net::Tcp::Handler { class Handler : public Net::Tcp::Handler {
......
/* http_header.cc
Mathieu Stefani, 19 August 2015
Implementation of common HTTP headers described by the RFC
*/
#include "http_header.h"
#include <stdexcept>
using namespace std;
namespace Net {
namespace Http {
void
Header::parseRaw(const char *str, size_t len) {
parse(std::string(str, len));
}
void
ContentLength::parse(const std::string& data) {
try {
size_t pos;
uint64_t val = std::stoi(data, &pos);
if (pos != 0) {
}
value_ = val;
} catch (const std::invalid_argument& e) {
}
}
void
Host::parse(const std::string& data) {
host_ = data;
}
} // namespace Http
} // namespace Net
/* http_header.h
Mathieu Stefani, 19 August 2015
Declaration of common http headers
*/
#pragma once
#include <string>
#define NAME(header_name) \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
namespace Net {
namespace Http {
class Header {
public:
virtual void parse(const std::string& data) = 0;
virtual void parseRaw(const char* str, size_t len);
virtual const char *name() const = 0;
//virtual void write(Net::Tcp::Stream& writer) = 0;
};
class ContentLength : public Header {
public:
NAME("Content-Length");
void parse(const std::string& data);
uint64_t value() const { return value_; }
private:
uint64_t value_;
};
class Host : public Header {
public:
NAME("Host");
void parse(const std::string& data);
std::string host() const { return host_; }
private:
std::string host_;
};
} // namespace Http
} // namespace Net
#undef NAME
/* http_headers.cc
Mathieu Stefani, 19 August 2015
Headers registry
*/
#include "http_headers.h"
#include <unordered_map>
#include <iterator>
#include <stdexcept>
#include <iostream>
namespace Net {
namespace Http {
namespace {
std::unordered_map<std::string, HeaderRegistry::RegistryFunc> registry;
}
void
HeaderRegistry::registerHeader(std::string name, HeaderRegistry::RegistryFunc func)
{
auto it = registry.find(name);
if (it != std::end(registry)) {
throw std::runtime_error("Header already registered");
}
registry.insert(std::make_pair(name, std::move(func)));
}
std::vector<std::string>
HeaderRegistry::headersList() {
std::vector<std::string> names;
names.reserve(registry.size());
for (const auto &header: registry) {
names.push_back(header.first);
}
return names;
}
std::unique_ptr<Header>
HeaderRegistry::makeHeader(const std::string& name) {
auto it = registry.find(name);
if (it == std::end(registry)) {
throw std::runtime_error("Unknown header");
}
return it->second();
}
bool
HeaderRegistry::isRegistered(const std::string& name) {
auto it = registry.find(name);
return it != std::end(registry);
}
void
Headers::add(const std::shared_ptr<Header>& header) {
headers.insert(std::make_pair(header->name(), header));
}
std::shared_ptr<Header>
Headers::getHeader(const std::string& name) const {
auto it = headers.find(name);
if (it == std::end(headers)) {
throw std::runtime_error("Could not find header");
}
return it->second;
}
namespace {
struct AtInit {
AtInit() {
HeaderRegistry::registerHeader<ContentLength>();
HeaderRegistry::registerHeader<Host>();
}
} atInit;
}
} // namespace Http
} // namespace Net
/* http_headers.h
Mathieu Stefani, 19 August 2015
A list of HTTP headers
*/
#pragma once
#include <unordered_map>
#include <vector>
#include <memory>
#include "http_header.h"
namespace Net {
namespace Http {
class Headers {
public:
template<typename H>
/*
typename std::enable_if<
std::is_base_of<H, Header>::value, std::shared_ptr<Header>
>::type
*/
std::shared_ptr<H>
getHeader() const {
return std::static_pointer_cast<H>(getHeader(H::Name));
}
void add(const std::shared_ptr<Header>& header);
std::shared_ptr<Header> getHeader(const std::string& name) const;
private:
std::unordered_map<std::string, std::shared_ptr<Header>> headers;
};
struct HeaderRegistry {
typedef std::function<std::unique_ptr<Header>()> RegistryFunc;
template<typename H>
static
/* typename std::enable_if<
std::is_base_of<H, Header>::value, void
>::type
*/
void
registerHeader() {
registerHeader(H::Name, []() -> std::unique_ptr<Header> {
return std::unique_ptr<Header>(new H());
});
}
static void registerHeader(std::string name, RegistryFunc func);
static std::vector<std::string> headersList();
static std::unique_ptr<Header> makeHeader(const std::string& name);
static bool isRegistered(const std::string& name);
};
} // namespace Http
} // namespace Net
...@@ -8,9 +8,11 @@ ...@@ -8,9 +8,11 @@
#include <sys/socket.h> #include <sys/socket.h>
#include <unistd.h> #include <unistd.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h> #include <netdb.h>
#include <sys/epoll.h> #include <sys/epoll.h>
#include <signal.h>
#include <cassert> #include <cassert>
#include <cstring> #include <cstring>
#include "listener.h" #include "listener.h"
...@@ -24,24 +26,29 @@ namespace Net { ...@@ -24,24 +26,29 @@ namespace Net {
namespace Tcp { namespace Tcp {
namespace {
volatile sig_atomic_t g_listen_fd = -1;
void handle_sigint(int) {
if (g_listen_fd != -1) {
close(g_listen_fd);
g_listen_fd = -1;
}
}
}
using Polling::NotifyOn;
struct Message { struct Message {
virtual ~Message() { } virtual ~Message() { }
enum class Type { NewPeer, Shutdown }; enum class Type { Shutdown };
virtual Type type() const = 0; virtual Type type() const = 0;
}; };
struct PeerMessage : public Message { struct ShutdownMessage : public Message {
PeerMessage(const std::shared_ptr<Peer>& peer) Type type() const { return Type::Shutdown; }
: peer_(peer)
{ }
Type type() const { return Type::NewPeer; }
std::shared_ptr<Peer> peer() const { return peer_; }
private:
std::shared_ptr<Peer> peer_;
}; };
template<typename To> template<typename To>
...@@ -50,9 +57,31 @@ To *message_cast(const std::unique_ptr<Message>& from) ...@@ -50,9 +57,31 @@ To *message_cast(const std::unique_ptr<Message>& from)
return static_cast<To *>(from.get()); return static_cast<To *>(from.get());
} }
void setSocketOptions(Fd fd, Flags<Options> options) {
if (options.hasFlag(Options::ReuseAddr)) {
int one = 1;
TRY(::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof (one)));
}
if (options.hasFlag(Options::Linger)) {
struct linger opt;
opt.l_onoff = 1;
opt.l_linger = 1;
TRY(::setsockopt(fd, SOL_SOCKET, SO_LINGER, &opt, sizeof (opt)));
}
if (options.hasFlag(Options::FastOpen)) {
int hint = 5;
TRY(::setsockopt(fd, SOL_TCP, TCP_FASTOPEN, &hint, sizeof (hint)));
}
if (options.hasFlag(Options::NoDelay)) {
int one = 1;
TRY(::setsockopt(fd, SOL_TCP, TCP_NODELAY, &one, sizeof (one)));
}
}
IoWorker::IoWorker() { IoWorker::IoWorker() {
epoll_fd = TRY_RET(epoll_create(128));
} }
IoWorker::~IoWorker() { IoWorker::~IoWorker() {
...@@ -62,8 +91,10 @@ IoWorker::~IoWorker() { ...@@ -62,8 +91,10 @@ IoWorker::~IoWorker() {
} }
void void
IoWorker::start(const std::shared_ptr<Handler>& handler) { IoWorker::start(const std::shared_ptr<Handler>& handler, Flags<Options> options) {
handler_ = handler; handler_ = handler;
options_ = options;
thread.reset(new std::thread([this]() { thread.reset(new std::thread([this]() {
this->run(); this->run();
})); }));
...@@ -72,6 +103,7 @@ IoWorker::start(const std::shared_ptr<Handler>& handler) { ...@@ -72,6 +103,7 @@ IoWorker::start(const std::shared_ptr<Handler>& handler) {
std::shared_ptr<Peer> std::shared_ptr<Peer>
IoWorker::getPeer(Fd fd) const IoWorker::getPeer(Fd fd) const
{ {
std::unique_lock<std::mutex> guard(peersMutex);
auto it = peers.find(fd); auto it = peers.find(fd);
if (it == std::end(peers)) if (it == std::end(peers))
{ {
...@@ -80,6 +112,11 @@ IoWorker::getPeer(Fd fd) const ...@@ -80,6 +112,11 @@ IoWorker::getPeer(Fd fd) const
return it->second; return it->second;
} }
std::shared_ptr<Peer>
IoWorker::getPeer(Polling::Tag tag) const
{
return getPeer(tag.value());
}
void void
IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) { IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
...@@ -97,14 +134,22 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) { ...@@ -97,14 +134,22 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
bytes = recv(fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0); bytes = recv(fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0);
if (bytes == -1) { if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) { if (errno == EAGAIN || errno == EWOULDBLOCK) {
handler_->onInput(buffer, totalBytes, *peer); if (totalBytes > 0) {
handler_->onInput(buffer, totalBytes, *peer);
}
} else { } else {
throw std::runtime_error(strerror(errno)); if (errno == ECONNRESET) {
handler_->onDisconnection(peer);
close(fd);
}
else {
throw std::runtime_error(strerror(errno));
}
} }
break; break;
} }
else if (bytes == 0) { else if (bytes == 0) {
cout << "Peer " << *peer << " has disconnected" << endl; handler_->onDisconnection(peer);
close(fd); close(fd);
break; break;
} }
...@@ -123,49 +168,53 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) { ...@@ -123,49 +168,53 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
void void
IoWorker::handleNewPeer(const std::shared_ptr<Peer>& peer) IoWorker::handleNewPeer(const std::shared_ptr<Peer>& peer)
{ {
std::cout << "New peer: " << *peer << std::endl;
int fd = peer->fd(); int fd = peer->fd();
{
std::unique_lock<std::mutex> guard(peersMutex);
peers.insert(std::make_pair(fd, peer));
}
struct epoll_event event; handler_->onConnection(peer);
event.events = EPOLLIN; poller.addFd(fd, NotifyOn::Read, Polling::Tag(fd), Polling::Mode::Edge);
event.data.fd = fd;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &event);
peers.insert(std::make_pair(fd, peer));
} }
void void
IoWorker::run() { IoWorker::run() {
struct epoll_event events[Const::MaxEvents];
mailbox.bind(poller);
for (;;) { for (;;) {
std::vector<Polling::Event> events;
int ready_fds; int ready_fds;
switch(ready_fds = epoll_wait(epoll_fd, events, Const::MaxEvents, 100)) { switch(ready_fds = poller.poll(events, 32, std::chrono::milliseconds(0))) {
case -1: case -1:
break; break;
case 0: case 0:
if (!mailbox.isEmpty()) {
std::unique_ptr<Message> msg(mailbox.clear());
if (msg->type() == Message::Type::NewPeer) {
auto peer_msg = message_cast<PeerMessage>(msg);
handleNewPeer(peer_msg->peer());
}
}
break; break;
default: default:
for (int i = 0; i < ready_fds; ++i) { for (const auto& event: events) {
const struct epoll_event *event = events + i; if (event.tag == mailbox.tag()) {
handleIncoming(getPeer(event->data.fd)); std::unique_ptr<Message> msg(mailbox.clear());
if (msg->type() == Message::Type::Shutdown) {
return;
}
} else {
if (event.flags.hasFlag(NotifyOn::Read)) {
handleIncoming(getPeer(event.tag));
}
}
} }
break; break;
} }
} }
} }
void
IoWorker::handleMailbox() {
}
Handler::Handler() Handler::Handler()
{ } { }
...@@ -173,11 +222,11 @@ Handler::~Handler() ...@@ -173,11 +222,11 @@ Handler::~Handler()
{ } { }
void void
Handler::onConnection() { Handler::onConnection(const std::shared_ptr<Peer>& peer) {
} }
void void
Handler::onDisconnection() { Handler::onDisconnection(const std::shared_ptr<Peer>& peer) {
} }
Listener::Listener() Listener::Listener()
...@@ -192,12 +241,20 @@ Listener::Listener(const Address& address) ...@@ -192,12 +241,20 @@ Listener::Listener(const Address& address)
void void
Listener::init(size_t workers) Listener::init(size_t workers, Flags<Options> options)
{ {
if (workers > hardware_concurrency()) { if (workers > hardware_concurrency()) {
// Log::warning() << "More workers than available cores" // Log::warning() << "More workers than available cores"
} }
options_ = options;
if (options_.hasFlag(Options::InstallSignalHandler)) {
if (signal(SIGINT, handle_sigint) == SIG_ERR) {
throw std::runtime_error("Could not install signal handler");
}
}
for (size_t i = 0; i < workers; ++i) { for (size_t i = 0; i < workers; ++i) {
auto wrk = std::unique_ptr<IoWorker>(new IoWorker); auto wrk = std::unique_ptr<IoWorker>(new IoWorker);
ioGroup.push_back(std::move(wrk)); ioGroup.push_back(std::move(wrk));
...@@ -250,6 +307,8 @@ Listener::bind(const Address& address) { ...@@ -250,6 +307,8 @@ Listener::bind(const Address& address) {
fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (fd < 0) continue; if (fd < 0) continue;
setSocketOptions(fd, options_);
if (::bind(fd, addr->ai_addr, addr->ai_addrlen) < 0) { if (::bind(fd, addr->ai_addr, addr->ai_addrlen) < 0) {
close(fd); close(fd);
continue; continue;
...@@ -258,10 +317,12 @@ Listener::bind(const Address& address) { ...@@ -258,10 +317,12 @@ Listener::bind(const Address& address) {
TRY(::listen(fd, Const::MaxBacklog)); TRY(::listen(fd, Const::MaxBacklog));
} }
listen_fd = fd; listen_fd = fd;
g_listen_fd = fd;
for (auto& io: ioGroup) { for (auto& io: ioGroup) {
io->start(handler_); io->start(handler_, options_);
} }
return true; return true;
...@@ -272,20 +333,31 @@ Listener::run() { ...@@ -272,20 +333,31 @@ Listener::run() {
for (;;) { for (;;) {
struct sockaddr_in peer_addr; struct sockaddr_in peer_addr;
socklen_t peer_addr_len = sizeof(peer_addr); socklen_t peer_addr_len = sizeof(peer_addr);
int client_fd = TRY_RET(::accept(listen_fd, (struct sockaddr *)&peer_addr, &peer_addr_len)); int client_fd = ::accept(listen_fd, (struct sockaddr *)&peer_addr, &peer_addr_len);
if (client_fd < 0) {
if (g_listen_fd == -1) {
cout << "SIGINT Signal received, shutdowning !" << endl;
shutdown();
break;
make_non_blocking(client_fd); } else {
throw std::runtime_error(strerror(errno));
}
}
char peer_host[NI_MAXHOST]; make_non_blocking(client_fd);
if (getnameinfo((struct sockaddr *)&peer_addr, peer_addr_len, peer_host, NI_MAXHOST, nullptr, 0, 0) == 0) { auto peer = make_shared<Peer>(Address::fromUnix((struct sockaddr *)&peer_addr));
Address addr = Address::fromUnix((struct sockaddr *)&peer_addr); peer->associateFd(client_fd);
auto peer = make_shared<Peer>(addr, peer_host);
peer->associateFd(client_fd);
dispatchPeer(peer); dispatchPeer(peer);
} }
}
void
Listener::shutdown() {
for (auto &worker: ioGroup) {
worker->mailbox.post(new ShutdownMessage());
} }
} }
...@@ -294,33 +366,22 @@ Listener::address() const { ...@@ -294,33 +366,22 @@ Listener::address() const {
return addr_; return addr_;
} }
Options
Listener::options() const {
return options_;
}
void void
Listener::dispatchPeer(const std::shared_ptr<Peer>& peer) { Listener::dispatchPeer(const std::shared_ptr<Peer>& peer) {
const size_t workers = ioGroup.size(); const size_t workers = ioGroup.size();
size_t start = peer->fd() % workers; size_t worker = peer->fd() % workers;
/* Find the first available worker */
size_t current = start;
for (;;) {
auto& mailbox = ioGroup[current]->mailbox;
if (mailbox.isEmpty()) {
auto message = new PeerMessage(peer);
auto *old = mailbox.post(message);
assert(old == nullptr);
return;
}
current = (current + 1) % workers; ioGroup[worker]->handleNewPeer(peer);
if (current == start) {
break;
}
}
/* We did not find any available worker, what do we do ? */ }
void
Listener::handleSigint(int) {
} }
} // namespace Tcp } // namespace Tcp
......
...@@ -9,10 +9,12 @@ ...@@ -9,10 +9,12 @@
#include "net.h" #include "net.h"
#include "mailbox.h" #include "mailbox.h"
#include "os.h" #include "os.h"
#include "flags.h"
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <thread> #include <thread>
#include <unordered_map> #include <unordered_map>
#include <mutex>
namespace Net { namespace Net {
...@@ -20,54 +22,80 @@ namespace Tcp { ...@@ -20,54 +22,80 @@ namespace Tcp {
class Peer; class Peer;
class Message; class Message;
class Handler; class Handler;
enum class Options : uint64_t {
None = 0,
NoDelay = 1,
Linger = NoDelay << 1,
FastOpen = Linger << 1,
QuickAck = FastOpen << 1,
ReuseAddr = QuickAck << 1,
ReverseLookup = ReuseAddr << 1,
InstallSignalHandler = ReverseLookup << 1
};
DECLARE_FLAGS_OPERATORS(Options)
void setSocketOptions(Fd fd, Flags<Options> options);
class IoWorker { class IoWorker {
public: public:
Mailbox<Message> mailbox; PollableMailbox<Message> mailbox;
IoWorker(); IoWorker();
~IoWorker(); ~IoWorker();
void start(const std::shared_ptr<Handler> &handler); void start(const std::shared_ptr<Handler> &handler, Flags<Options> options);
void handleNewPeer(const std::shared_ptr<Peer>& peer);
private: private:
int epoll_fd; Polling::Epoll poller;
std::unique_ptr<std::thread> thread; std::unique_ptr<std::thread> thread;
mutable std::mutex peersMutex;
std::unordered_map<Fd, std::shared_ptr<Peer>> peers; std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
std::shared_ptr<Handler> handler_; std::shared_ptr<Handler> handler_;
Flags<Options> options_;
std::shared_ptr<Peer> getPeer(Fd fd) const; std::shared_ptr<Peer> getPeer(Fd fd) const;
std::shared_ptr<Peer> getPeer(Polling::Tag tag) const;
void handleIncoming(const std::shared_ptr<Peer>& peer); void handleIncoming(const std::shared_ptr<Peer>& peer);
void handleNewPeer(const std::shared_ptr<Peer>& peer); void handleMailbox();
void run(); void run();
}; };
class Listener { class Listener {
public: public:
friend class IoWorker; friend class IoWorker;
friend class Peer;
Listener(); Listener();
Listener(const Address& address); Listener(const Address& address);
void init(size_t workers); void init(size_t workers, Flags<Options> options = Options::None);
void setHandler(const std::shared_ptr<Handler>& handler); void setHandler(const std::shared_ptr<Handler>& handler);
bool bind(); bool bind();
bool bind(const Address& adress); bool bind(const Address& adress);
void run(); void run();
void shutdown();
Options options() const;
Address address() const; Address address() const;
private: private:
Address addr_; Address addr_;
int listen_fd; int listen_fd;
std::vector<std::unique_ptr<IoWorker>> ioGroup; std::vector<std::unique_ptr<IoWorker>> ioGroup;
Flags<Options> options_;
std::shared_ptr<Handler> handler_; std::shared_ptr<Handler> handler_;
void dispatchPeer(const std::shared_ptr<Peer>& peer); void dispatchPeer(const std::shared_ptr<Peer>& peer);
void handleSigint(int);
}; };
class Handler { class Handler {
...@@ -78,8 +106,8 @@ public: ...@@ -78,8 +106,8 @@ public:
virtual void onInput(const char *buffer, size_t len, Tcp::Peer& peer) = 0; virtual void onInput(const char *buffer, size_t len, Tcp::Peer& peer) = 0;
virtual void onOutput() = 0; virtual void onOutput() = 0;
virtual void onConnection(); virtual void onConnection(const std::shared_ptr<Peer>& peer);
virtual void onDisconnection(); virtual void onDisconnection(const std::shared_ptr<Peer>& peer);
}; };
......
...@@ -6,9 +6,12 @@ ...@@ -6,9 +6,12 @@
*/ */
#pragma once #pragma once
#include "common.h"
#include "os.h"
#include <atomic> #include <atomic>
#include <stdexcept> #include <stdexcept>
#include <sys/eventfd.h>
template<typename T> template<typename T>
class Mailbox { class Mailbox {
...@@ -17,6 +20,8 @@ public: ...@@ -17,6 +20,8 @@ public:
data.store(nullptr); data.store(nullptr);
} }
virtual ~Mailbox() { }
const T *get() const { const T *get() const {
if (isEmpty()) { if (isEmpty()) {
throw std::runtime_error("Can not retrieve mail from empty mailbox"); throw std::runtime_error("Can not retrieve mail from empty mailbox");
...@@ -25,7 +30,7 @@ public: ...@@ -25,7 +30,7 @@ public:
return data.load(); return data.load();
} }
T *post(T *newData) { virtual T *post(T *newData) {
T *old = data.load(); T *old = data.load();
while (!data.compare_exchange_weak(old, newData)) while (!data.compare_exchange_weak(old, newData))
{ } { }
...@@ -33,7 +38,7 @@ public: ...@@ -33,7 +38,7 @@ public:
return old; return old;
} }
T *clear() { virtual T *clear() {
return data.exchange(nullptr); return data.exchange(nullptr);
} }
...@@ -44,3 +49,85 @@ public: ...@@ -44,3 +49,85 @@ public:
private: private:
std::atomic<T *> data; std::atomic<T *> data;
}; };
template<typename T>
class PollableMailbox : public Mailbox<T>
{
public:
PollableMailbox()
: event_fd(-1) {
}
~PollableMailbox() {
if (event_fd != -1) close(event_fd);
}
bool isBound() const {
return event_fd != -1;
}
Polling::Tag bind(Polling::Epoll& poller) {
using namespace Polling;
if (isBound()) {
throw std::runtime_error("The mailbox has already been bound");
}
event_fd = TRY_RET(eventfd(0, EFD_NONBLOCK));
Tag tag(event_fd);
poller.addFd(event_fd, NotifyOn::Read, tag);
return tag;
}
T *post(T *newData) {
auto *ret = Mailbox<T>::post(newData);
if (isBound()) {
uint64_t val = 1;
TRY_RET(write(event_fd, &val, sizeof val));
}
return ret;
}
T *clear() {
auto ret = Mailbox<T>::clear();
if (isBound()) {
uint64_t val;
for (;;) {
ssize_t bytes = read(event_fd, &val, sizeof val);
if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
break;
else {
// TODO
}
}
}
}
return ret;
}
Polling::Tag tag() const {
if (!isBound())
throw std::runtime_error("Can not retrieve tag of an unbound mailbox");
return Polling::Tag(event_fd);
}
void unbind(Polling::Epoll& poller) {
if (event_fd == -1) {
throw std::runtime_error("The mailbox is not bound");
}
poller.removeFd(event_fd);
close(event_fd), event_fd = -1;
}
private:
int event_fd;
};
...@@ -8,6 +8,15 @@ ...@@ -8,6 +8,15 @@
#include <string> #include <string>
#include <sys/socket.h> #include <sys/socket.h>
#ifndef _KERNEL_FASTOPEN
#define _KERNEL_FASTOPEN
/* conditional define for TCP_FASTOPEN */
#ifndef TCP_FASTOPEN
#define TCP_FASTOPEN 23
#endif
#endif
namespace Net { namespace Net {
class Port { class Port {
......
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
*/ */
#include "os.h" #include "os.h"
#include "common.h"
#include <unistd.h> #include <unistd.h>
#include <fcntl.h> #include <fcntl.h>
#include <fstream> #include <fstream>
#include <iterator> #include <iterator>
#include <algorithm> #include <algorithm>
#include <sys/epoll.h>
int hardware_concurrency() { int hardware_concurrency() {
std::ifstream cpuinfo("/proc/cpuinfo"); std::ifstream cpuinfo("/proc/cpuinfo");
...@@ -33,3 +35,99 @@ bool make_non_blocking(int sfd) ...@@ -33,3 +35,99 @@ bool make_non_blocking(int sfd)
return true; return true;
} }
namespace Polling {
Epoll::Epoll(size_t max) {
epoll_fd = TRY_RET(epoll_create(max));
}
void
Epoll::addFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode) {
struct epoll_event ev;
ev.events = toEpollEvents(interest);
if (mode == Mode::Edge)
ev.events |= EPOLLET;
ev.data.u64 = tag.value_;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev));
}
void
Epoll::addFdOneShot(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode) {
struct epoll_event ev;
ev.events = toEpollEvents(interest);
ev.events |= EPOLLONESHOT;
if (mode == Mode::Edge)
ev.events |= EPOLLET;
ev.data.u64 = tag.value_;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev));
}
void
Epoll::removeFd(Fd fd) {
struct epoll_event ev;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_DEL, fd, &ev));
}
void
Epoll::rearmFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode) {
struct epoll_event ev;
ev.events = toEpollEvents(interest);
if (mode == Mode::Edge)
ev.events |= EPOLLET;
ev.data.u64 = tag.value_;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_MOD, fd, &ev));
}
int
Epoll::poll(std::vector<Event>& events, size_t maxEvents, std::chrono::milliseconds timeout) const {
struct epoll_event evs[Const::MaxEvents];
int ready_fds = epoll_wait(epoll_fd, evs, maxEvents, timeout.count());
if (ready_fds > 0) {
for (int i = 0; i < ready_fds; ++i) {
const struct epoll_event *ev = evs + i;
const Tag tag(ev->data.u64);
Event event(tag);
event.flags = toNotifyOn(ev->events);
events.push_back(tag);
}
}
return ready_fds;
}
int
Epoll::toEpollEvents(Flags<NotifyOn> interest) const {
int events;
if (interest.hasFlag(NotifyOn::Read))
events |= EPOLLIN;
if (interest.hasFlag(NotifyOn::Write))
events |= EPOLLOUT;
if (interest.hasFlag(NotifyOn::Hangup))
events |= EPOLLHUP;
return events;
}
Flags<NotifyOn>
Epoll::toNotifyOn(int events) const {
Flags<NotifyOn> flags;
if (events & EPOLLIN)
flags.setFlag(NotifyOn::Read);
if (events & EPOLLOUT)
flags.setFlag(NotifyOn::Write);
if (events & EPOLLHUP)
flags.setFlag(NotifyOn::Hangup);
return flags;
}
} // namespace Poller
...@@ -6,8 +6,78 @@ ...@@ -6,8 +6,78 @@
#pragma once #pragma once
#include <chrono>
#include <vector>
#include "flags.h"
#include "common.h"
typedef int Fd; typedef int Fd;
int hardware_concurrency(); int hardware_concurrency();
bool make_non_blocking(int fd); bool make_non_blocking(int fd);
namespace Polling {
enum class Mode {
Level,
Edge
};
enum class NotifyOn {
Read,
Write,
Hangup
};
DECLARE_FLAGS_OPERATORS(NotifyOn);
struct Tag {
friend class Epoll;
explicit constexpr Tag(uint64_t value)
: value_(value)
{ }
constexpr uint64_t value() const { return value_; }
friend constexpr bool operator==(Tag lhs, Tag rhs);
private:
uint64_t value_;
};
inline constexpr bool operator==(Tag lhs, Tag rhs) {
return lhs.value_ == rhs.value_;
}
struct Event {
Event(Tag tag) :
tag(tag)
{ }
Flags<NotifyOn> flags;
Fd fd;
Tag tag;
};
class Epoll {
public:
Epoll(size_t max = 128);
void addFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode = Mode::Level);
void addFdOneShot(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode = Mode::Level);
void removeFd(Fd fd);
void rearmFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode = Mode::Level);
int poll(std::vector<Event>& events,
size_t maxEvents = Const::MaxEvents,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0)) const;
private:
int toEpollEvents(Flags<NotifyOn> interest) const;
Flags<NotifyOn> toNotifyOn(int events) const;
int epoll_fd;
};
} // namespace Polling
...@@ -17,9 +17,8 @@ Peer::Peer() ...@@ -17,9 +17,8 @@ Peer::Peer()
: fd_(-1) : fd_(-1)
{ } { }
Peer::Peer(const Address& addr, const string& host) Peer::Peer(const Address& addr)
: addr(addr) : addr(addr)
, host(host)
, fd_(-1) , fd_(-1)
{ } { }
...@@ -30,7 +29,7 @@ Peer::address() const { ...@@ -30,7 +29,7 @@ Peer::address() const {
string string
Peer::hostname() const { Peer::hostname() const {
return host; return hostname_;
} }
void void
......
...@@ -18,7 +18,7 @@ namespace Tcp { ...@@ -18,7 +18,7 @@ namespace Tcp {
class Peer { class Peer {
public: public:
Peer(); Peer();
Peer(const Address& addr, const std::string& hostname); Peer(const Address& addr);
Address address() const; Address address() const;
std::string hostname() const; std::string hostname() const;
...@@ -28,7 +28,7 @@ public: ...@@ -28,7 +28,7 @@ public:
private: private:
Address addr; Address addr;
std::string host; std::string hostname_;
Fd fd_; Fd fd_;
}; };
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment