Commit 5b65633d authored by Mathieu Stefani's avatar Mathieu Stefani Committed by Mathieu STEFANI

Polished the code a little bit and introduced a new Poller class that wraps system calls to epoll.

Also started working on type-safe http headers parsing
parent 9fcc1835
#include "net.h"
#include "peer.h"
#include "http.h"
#include "http_headers.h"
#include <iostream>
#include <cstring>
......@@ -8,14 +9,15 @@ using namespace std;
class MyHandler : public Net::Http::Handler {
void onRequest(const Net::Http::Request& req, Net::Tcp::Peer& peer) {
cout << "Received " << methodString(req.method) << " request on " << req.resource << endl;
if (req.resource == "/ping") {
if (req.method == Net::Http::Method::Get) {
cout << "PONG" << endl;
auto host = req.headers.getHeader<Net::Http::Host>();
cout << "Host = " << host->host() << endl;
Net::Http::Response response(Net::Http::Code::Ok, "PONG");
response.writeTo(peer);
}
}
}
......
......@@ -5,6 +5,8 @@ set(SOURCE_FILES
os.cc
peer.cc
http.cc
http_header.cc
http_headers.cc
)
add_library(net ${SOURCE_FILES})
......
......@@ -9,6 +9,8 @@
#include <sstream>
#include <cstdio>
#include <cassert>
#include <cstring>
#include <stdexcept>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
......@@ -17,9 +19,15 @@
do { \
auto ret = __VA_ARGS__; \
if (ret < 0) { \
perror(#__VA_ARGS__); \
cerr << gai_strerror(ret) << endl; \
return false; \
const char* str = #__VA_ARGS__; \
std::ostringstream oss; \
oss << str << ": "; \
if (errno == 0) { \
oss << gai_strerror(ret); \
} else { \
oss << strerror(errno); \
} \
throw std::runtime_error(oss.str()); \
} \
} while (0)
......@@ -36,6 +44,8 @@
}(); \
(void) 0
#define unreachable() __builtin_unreachable()
namespace Const {
static constexpr int MaxBacklog = 128;
......
......@@ -27,7 +27,7 @@ namespace Private {
void
Parser::advance(size_t count) {
if (cursor + count >= len) {
throw std::runtime_error("Early EOF");
raise("Early EOF");
}
cursor += count;
......@@ -41,7 +41,7 @@ namespace Private {
char
Parser::next() const {
if (cursor + 1 >= len) {
throw std::runtime_error("Early EOF");
raise("Early EOF");
}
return buffer[cursor + 1];
......@@ -63,13 +63,13 @@ namespace Private {
auto tryMatch = [&](const char* const str) {
const size_t len = std::strlen(str);
if (strncmp(buffer, str, len) == 0) {
cursor += len;
cursor += len - 1;
return true;
}
return false;
};
// 5.1.1 Method
// Method
if (tryMatch("OPTIONS")) {
request.method = Method::Options;
......@@ -90,29 +90,49 @@ namespace Private {
request.method = Method::Delete;
}
advance(1);
if (buffer[cursor] != ' ') {
// Exceptionnado
if (next() != ' ') {
raise("Malformed HTTP request after Method");
}
// 5.1.2 Request-URI
// SP
advance(2);
// Request-URI
size_t start = cursor;
while (buffer[cursor] != ' ') {
while (next() != ' ') {
advance(1);
if (eol()) {
// Exceptionnado
raise("Malformed HTTP request after Request-URI");
}
}
request.resource = std::string(buffer + start, cursor - start);
request.resource = std::string(buffer + start, cursor - start + 1);
if (next() != ' ') {
raise("Malformed HTTP request after Request-URI");
}
// SP
advance(2);
// HTTP-Version
start = cursor;
// Skip HTTP-Version for now
while (!eol())
advance(1);
const size_t diff = cursor - start;
if (strncmp(buffer + start, "HTTP/1.0", diff) == 0) {
request.version = Version::Http10;
}
else if (strncmp(buffer + start, "HTTP/1.1", diff) == 0) {
request.version = Version::Http11;
}
else {
raise("Encountered invalid HTTP version");
}
advance(2);
}
......@@ -128,7 +148,7 @@ namespace Private {
advance(1);
std::string fieldName = std::string(buffer + start, cursor - start - 1);
std::string name = std::string(buffer + start, cursor - start - 1);
// Skip the ':'
advance(1);
......@@ -138,15 +158,15 @@ namespace Private {
while (!eol())
advance(1);
std::string fieldValue = std::string(buffer + start, cursor - start);
if (fieldName == "Content-Length") {
size_t pos;
contentLength = std::stol(fieldValue, &pos);
if (HeaderRegistry::isRegistered(name)) {
std::shared_ptr<Header> header = HeaderRegistry::makeHeader(name);
header->parseRaw(buffer + start, cursor - start);
request.headers.add(header);
}
else {
std::string value = std::string(buffer + start, cursor - start);
}
request.headers.push_back(make_pair(std::move(fieldName), std::move(fieldValue)));
// CRLF
advance(2);
......@@ -166,6 +186,35 @@ namespace Private {
}
}
void
Parser::raise(const char* msg) const {
throw ParsingError(msg);
}
ssize_t
Writer::writeRaw(const void* data, size_t len) {
buf = static_cast<char *>(memcpy(buf, data, len));
buf += len;
return 0;
}
ssize_t
Writer::writeString(const char* str) {
const size_t len = std::strlen(str);
return writeRaw(str, std::strlen(str));
}
ssize_t
Writer::writeHeader(const char* name, const char* value) {
writeString(name);
writeChar(':');
writeString(value);
writeRaw(CRLF, 2);
return 0;
}
} // namespace Private
......@@ -179,7 +228,7 @@ const char* methodString(Method method)
#undef METHOD
}
return nullptr;
unreachable();
}
const char* codeString(Code code)
......@@ -192,21 +241,28 @@ const char* codeString(Code code)
#undef CODE
}
return nullptr;
unreachable();
}
Message::Message()
: version(Version::Http11)
{ }
Request::Request()
: Message()
{ }
Response::Response(int code, std::string body)
{
this->body = std::move(body);
code_ = code;
}
Response::Response(Code code, std::string body)
: Message()
{
this->body = std::move(body);
code_ = code;
code_ = static_cast<int>(code);
}
void
......@@ -216,52 +272,34 @@ Response::writeTo(Tcp::Peer& peer)
char buffer[Const::MaxBuffer];
std::memset(buffer, 0, Const::MaxBuffer);
char *p_buf = buffer;
auto writeRaw = [&](const void* data, size_t len) {
p_buf = static_cast<char *>(memcpy(p_buf, data, len));
p_buf += len;
};
Private::Writer fmt(buffer, sizeof buffer);
auto writeString = [&](const char* str) {
const size_t len = std::strlen(str);
writeRaw(str, std::strlen(str));
};
auto writeInt = [&](uint64_t value) {
auto str = std::to_string(value);
writeRaw(str.c_str(), str.size());
};
auto writeChar = [&](char c) {
*p_buf++ = c;
};
fmt.writeString("HTTP/1.1 ");
fmt.writeInt(code_);
fmt.writeChar(' ');
fmt.writeString(codeString(static_cast<Code>(code_)));
fmt.writeRaw(CRLF, 2);
writeString("HTTP/1.1 ");
writeInt(static_cast<int>(code_));
writeChar(' ');
writeString(codeString(code_));
writeRaw(CRLF, 2);
writeString("Content-Length:");
writeInt(body.size());
writeRaw(CRLF, 2);
fmt.writeHeader("Content-Length", body.size());
writeRaw(CRLF, 2);
writeString(body.c_str());
fmt.writeRaw(CRLF, 2);
fmt.writeString(body.c_str());
const size_t len = p_buf - buffer;
const size_t len = fmt.cursor() - buffer;
ssize_t bytes = send(fd, buffer, len, 0);
cout << bytes << " bytes sent" << endl;
//cout << bytes << " bytes sent" << endl;
}
void
Handler::onInput(const char* buffer, size_t len, Tcp::Peer& peer) {
Private::Parser parser(buffer, len);
try {
auto request = parser.expectRequest();
onRequest(request, peer);
} catch (const Private::ParsingError &err) {
cerr << "Error when parsing HTTP request: " << err.what() << endl;
}
}
void
......@@ -286,7 +324,10 @@ Server::serve()
if (!handler_)
throw std::runtime_error("Must call setHandler() prior to serve()");
listener.init(4);
listener.init(8,
Tcp::Options::NoDelay |
Tcp::Options::InstallSignalHandler |
Tcp::Options::ReuseAddr);
listener.setHandler(handler_);
if (listener.bind()) {
......
......@@ -6,8 +6,11 @@
#pragma once
#include <type_traits>
#include <stdexcept>
#include "listener.h"
#include "net.h"
#include "http_headers.h"
namespace Net {
......@@ -78,6 +81,11 @@ enum class Code {
#undef CODE
};
enum class Version {
Http10, // HTTP/1.0
Http11 // HTTP/1.1
};
const char* methodString(Method method);
const char* codeString(Code code);
......@@ -85,7 +93,9 @@ const char* codeString(Code code);
class Message {
public:
Message();
std::vector<std::pair<std::string, std::string>> headers;
Version version;
Headers headers;
std::string body;
};
......@@ -101,17 +111,24 @@ public:
// 6. Response
class Response : public Message {
public:
Response(int code, std::string body);
Response(Code code, std::string body);
void writeTo(Tcp::Peer& peer);
std::string mimeType;
private:
Code code_;
int code_;
};
namespace Private {
struct ParsingError : public std::runtime_error {
ParsingError(const char* msg) : std::runtime_error(msg) { }
};
struct Parser {
Parser(const char* buffer, size_t len)
: buffer(buffer)
, len(len)
......@@ -133,10 +150,52 @@ namespace Private {
char next() const;
private:
void raise(const char* msg) const;
ssize_t contentLength;
Request request;
};
struct Writer {
Writer(char* buffer, size_t len)
: buf(buffer)
, len(len)
{ }
ssize_t writeRaw(const void* data, size_t len);
ssize_t writeString(const char* str);
template<typename T>
typename std::enable_if<
std::is_integral<T>::value, ssize_t
>::type
writeInt(T value) {
auto str = std::to_string(value);
return writeRaw(str.c_str(), str.size());
}
ssize_t writeChar(char c) {
*buf++ = c;
return 0;
}
ssize_t writeHeader(const char* name, const char* value);
template<typename T>
typename std::enable_if<
std::is_arithmetic<T>::value, ssize_t
>::type
writeHeader(const char* name, T value) {
auto str = std::to_string(value);
return writeHeader(name, str.c_str());
}
char *cursor() const { return buf; }
private:
char* buf;
size_t len;
};
}
class Handler : public Net::Tcp::Handler {
......
/* http_header.cc
Mathieu Stefani, 19 August 2015
Implementation of common HTTP headers described by the RFC
*/
#include "http_header.h"
#include <stdexcept>
using namespace std;
namespace Net {
namespace Http {
void
Header::parseRaw(const char *str, size_t len) {
parse(std::string(str, len));
}
void
ContentLength::parse(const std::string& data) {
try {
size_t pos;
uint64_t val = std::stoi(data, &pos);
if (pos != 0) {
}
value_ = val;
} catch (const std::invalid_argument& e) {
}
}
void
Host::parse(const std::string& data) {
host_ = data;
}
} // namespace Http
} // namespace Net
/* http_header.h
Mathieu Stefani, 19 August 2015
Declaration of common http headers
*/
#pragma once
#include <string>
#define NAME(header_name) \
static constexpr const char *Name = header_name; \
const char *name() const { return Name; }
namespace Net {
namespace Http {
class Header {
public:
virtual void parse(const std::string& data) = 0;
virtual void parseRaw(const char* str, size_t len);
virtual const char *name() const = 0;
//virtual void write(Net::Tcp::Stream& writer) = 0;
};
class ContentLength : public Header {
public:
NAME("Content-Length");
void parse(const std::string& data);
uint64_t value() const { return value_; }
private:
uint64_t value_;
};
class Host : public Header {
public:
NAME("Host");
void parse(const std::string& data);
std::string host() const { return host_; }
private:
std::string host_;
};
} // namespace Http
} // namespace Net
#undef NAME
/* http_headers.cc
Mathieu Stefani, 19 August 2015
Headers registry
*/
#include "http_headers.h"
#include <unordered_map>
#include <iterator>
#include <stdexcept>
#include <iostream>
namespace Net {
namespace Http {
namespace {
std::unordered_map<std::string, HeaderRegistry::RegistryFunc> registry;
}
void
HeaderRegistry::registerHeader(std::string name, HeaderRegistry::RegistryFunc func)
{
auto it = registry.find(name);
if (it != std::end(registry)) {
throw std::runtime_error("Header already registered");
}
registry.insert(std::make_pair(name, std::move(func)));
}
std::vector<std::string>
HeaderRegistry::headersList() {
std::vector<std::string> names;
names.reserve(registry.size());
for (const auto &header: registry) {
names.push_back(header.first);
}
return names;
}
std::unique_ptr<Header>
HeaderRegistry::makeHeader(const std::string& name) {
auto it = registry.find(name);
if (it == std::end(registry)) {
throw std::runtime_error("Unknown header");
}
return it->second();
}
bool
HeaderRegistry::isRegistered(const std::string& name) {
auto it = registry.find(name);
return it != std::end(registry);
}
void
Headers::add(const std::shared_ptr<Header>& header) {
headers.insert(std::make_pair(header->name(), header));
}
std::shared_ptr<Header>
Headers::getHeader(const std::string& name) const {
auto it = headers.find(name);
if (it == std::end(headers)) {
throw std::runtime_error("Could not find header");
}
return it->second;
}
namespace {
struct AtInit {
AtInit() {
HeaderRegistry::registerHeader<ContentLength>();
HeaderRegistry::registerHeader<Host>();
}
} atInit;
}
} // namespace Http
} // namespace Net
/* http_headers.h
Mathieu Stefani, 19 August 2015
A list of HTTP headers
*/
#pragma once
#include <unordered_map>
#include <vector>
#include <memory>
#include "http_header.h"
namespace Net {
namespace Http {
class Headers {
public:
template<typename H>
/*
typename std::enable_if<
std::is_base_of<H, Header>::value, std::shared_ptr<Header>
>::type
*/
std::shared_ptr<H>
getHeader() const {
return std::static_pointer_cast<H>(getHeader(H::Name));
}
void add(const std::shared_ptr<Header>& header);
std::shared_ptr<Header> getHeader(const std::string& name) const;
private:
std::unordered_map<std::string, std::shared_ptr<Header>> headers;
};
struct HeaderRegistry {
typedef std::function<std::unique_ptr<Header>()> RegistryFunc;
template<typename H>
static
/* typename std::enable_if<
std::is_base_of<H, Header>::value, void
>::type
*/
void
registerHeader() {
registerHeader(H::Name, []() -> std::unique_ptr<Header> {
return std::unique_ptr<Header>(new H());
});
}
static void registerHeader(std::string name, RegistryFunc func);
static std::vector<std::string> headersList();
static std::unique_ptr<Header> makeHeader(const std::string& name);
static bool isRegistered(const std::string& name);
};
} // namespace Http
} // namespace Net
......@@ -8,9 +8,11 @@
#include <sys/socket.h>
#include <unistd.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <sys/epoll.h>
#include <signal.h>
#include <cassert>
#include <cstring>
#include "listener.h"
......@@ -24,24 +26,29 @@ namespace Net {
namespace Tcp {
namespace {
volatile sig_atomic_t g_listen_fd = -1;
void handle_sigint(int) {
if (g_listen_fd != -1) {
close(g_listen_fd);
g_listen_fd = -1;
}
}
}
using Polling::NotifyOn;
struct Message {
virtual ~Message() { }
enum class Type { NewPeer, Shutdown };
enum class Type { Shutdown };
virtual Type type() const = 0;
};
struct PeerMessage : public Message {
PeerMessage(const std::shared_ptr<Peer>& peer)
: peer_(peer)
{ }
Type type() const { return Type::NewPeer; }
std::shared_ptr<Peer> peer() const { return peer_; }
private:
std::shared_ptr<Peer> peer_;
struct ShutdownMessage : public Message {
Type type() const { return Type::Shutdown; }
};
template<typename To>
......@@ -50,9 +57,31 @@ To *message_cast(const std::unique_ptr<Message>& from)
return static_cast<To *>(from.get());
}
void setSocketOptions(Fd fd, Flags<Options> options) {
if (options.hasFlag(Options::ReuseAddr)) {
int one = 1;
TRY(::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof (one)));
}
if (options.hasFlag(Options::Linger)) {
struct linger opt;
opt.l_onoff = 1;
opt.l_linger = 1;
TRY(::setsockopt(fd, SOL_SOCKET, SO_LINGER, &opt, sizeof (opt)));
}
if (options.hasFlag(Options::FastOpen)) {
int hint = 5;
TRY(::setsockopt(fd, SOL_TCP, TCP_FASTOPEN, &hint, sizeof (hint)));
}
if (options.hasFlag(Options::NoDelay)) {
int one = 1;
TRY(::setsockopt(fd, SOL_TCP, TCP_NODELAY, &one, sizeof (one)));
}
}
IoWorker::IoWorker() {
epoll_fd = TRY_RET(epoll_create(128));
}
IoWorker::~IoWorker() {
......@@ -62,8 +91,10 @@ IoWorker::~IoWorker() {
}
void
IoWorker::start(const std::shared_ptr<Handler>& handler) {
IoWorker::start(const std::shared_ptr<Handler>& handler, Flags<Options> options) {
handler_ = handler;
options_ = options;
thread.reset(new std::thread([this]() {
this->run();
}));
......@@ -72,6 +103,7 @@ IoWorker::start(const std::shared_ptr<Handler>& handler) {
std::shared_ptr<Peer>
IoWorker::getPeer(Fd fd) const
{
std::unique_lock<std::mutex> guard(peersMutex);
auto it = peers.find(fd);
if (it == std::end(peers))
{
......@@ -80,6 +112,11 @@ IoWorker::getPeer(Fd fd) const
return it->second;
}
std::shared_ptr<Peer>
IoWorker::getPeer(Polling::Tag tag) const
{
return getPeer(tag.value());
}
void
IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
......@@ -97,14 +134,22 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
bytes = recv(fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0);
if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
if (totalBytes > 0) {
handler_->onInput(buffer, totalBytes, *peer);
}
} else {
if (errno == ECONNRESET) {
handler_->onDisconnection(peer);
close(fd);
}
else {
throw std::runtime_error(strerror(errno));
}
}
break;
}
else if (bytes == 0) {
cout << "Peer " << *peer << " has disconnected" << endl;
handler_->onDisconnection(peer);
close(fd);
break;
}
......@@ -123,49 +168,53 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
void
IoWorker::handleNewPeer(const std::shared_ptr<Peer>& peer)
{
std::cout << "New peer: " << *peer << std::endl;
int fd = peer->fd();
struct epoll_event event;
event.events = EPOLLIN;
event.data.fd = fd;
epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &event);
{
std::unique_lock<std::mutex> guard(peersMutex);
peers.insert(std::make_pair(fd, peer));
}
handler_->onConnection(peer);
poller.addFd(fd, NotifyOn::Read, Polling::Tag(fd), Polling::Mode::Edge);
}
void
IoWorker::run() {
struct epoll_event events[Const::MaxEvents];
mailbox.bind(poller);
for (;;) {
std::vector<Polling::Event> events;
int ready_fds;
switch(ready_fds = epoll_wait(epoll_fd, events, Const::MaxEvents, 100)) {
switch(ready_fds = poller.poll(events, 32, std::chrono::milliseconds(0))) {
case -1:
break;
case 0:
if (!mailbox.isEmpty()) {
break;
default:
for (const auto& event: events) {
if (event.tag == mailbox.tag()) {
std::unique_ptr<Message> msg(mailbox.clear());
if (msg->type() == Message::Type::NewPeer) {
auto peer_msg = message_cast<PeerMessage>(msg);
handleNewPeer(peer_msg->peer());
if (msg->type() == Message::Type::Shutdown) {
return;
}
} else {
if (event.flags.hasFlag(NotifyOn::Read)) {
handleIncoming(getPeer(event.tag));
}
}
break;
default:
for (int i = 0; i < ready_fds; ++i) {
const struct epoll_event *event = events + i;
handleIncoming(getPeer(event->data.fd));
}
break;
}
}
}
void
IoWorker::handleMailbox() {
}
Handler::Handler()
{ }
......@@ -173,11 +222,11 @@ Handler::~Handler()
{ }
void
Handler::onConnection() {
Handler::onConnection(const std::shared_ptr<Peer>& peer) {
}
void
Handler::onDisconnection() {
Handler::onDisconnection(const std::shared_ptr<Peer>& peer) {
}
Listener::Listener()
......@@ -192,12 +241,20 @@ Listener::Listener(const Address& address)
void
Listener::init(size_t workers)
Listener::init(size_t workers, Flags<Options> options)
{
if (workers > hardware_concurrency()) {
// Log::warning() << "More workers than available cores"
}
options_ = options;
if (options_.hasFlag(Options::InstallSignalHandler)) {
if (signal(SIGINT, handle_sigint) == SIG_ERR) {
throw std::runtime_error("Could not install signal handler");
}
}
for (size_t i = 0; i < workers; ++i) {
auto wrk = std::unique_ptr<IoWorker>(new IoWorker);
ioGroup.push_back(std::move(wrk));
......@@ -250,6 +307,8 @@ Listener::bind(const Address& address) {
fd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (fd < 0) continue;
setSocketOptions(fd, options_);
if (::bind(fd, addr->ai_addr, addr->ai_addrlen) < 0) {
close(fd);
continue;
......@@ -258,10 +317,12 @@ Listener::bind(const Address& address) {
TRY(::listen(fd, Const::MaxBacklog));
}
listen_fd = fd;
g_listen_fd = fd;
for (auto& io: ioGroup) {
io->start(handler_);
io->start(handler_, options_);
}
return true;
......@@ -272,20 +333,31 @@ Listener::run() {
for (;;) {
struct sockaddr_in peer_addr;
socklen_t peer_addr_len = sizeof(peer_addr);
int client_fd = TRY_RET(::accept(listen_fd, (struct sockaddr *)&peer_addr, &peer_addr_len));
int client_fd = ::accept(listen_fd, (struct sockaddr *)&peer_addr, &peer_addr_len);
if (client_fd < 0) {
if (g_listen_fd == -1) {
cout << "SIGINT Signal received, shutdowning !" << endl;
shutdown();
break;
make_non_blocking(client_fd);
} else {
throw std::runtime_error(strerror(errno));
}
}
char peer_host[NI_MAXHOST];
make_non_blocking(client_fd);
if (getnameinfo((struct sockaddr *)&peer_addr, peer_addr_len, peer_host, NI_MAXHOST, nullptr, 0, 0) == 0) {
Address addr = Address::fromUnix((struct sockaddr *)&peer_addr);
auto peer = make_shared<Peer>(addr, peer_host);
auto peer = make_shared<Peer>(Address::fromUnix((struct sockaddr *)&peer_addr));
peer->associateFd(client_fd);
dispatchPeer(peer);
}
}
void
Listener::shutdown() {
for (auto &worker: ioGroup) {
worker->mailbox.post(new ShutdownMessage());
}
}
......@@ -294,33 +366,22 @@ Listener::address() const {
return addr_;
}
Options
Listener::options() const {
return options_;
}
void
Listener::dispatchPeer(const std::shared_ptr<Peer>& peer) {
const size_t workers = ioGroup.size();
size_t start = peer->fd() % workers;
/* Find the first available worker */
size_t current = start;
for (;;) {
auto& mailbox = ioGroup[current]->mailbox;
size_t worker = peer->fd() % workers;
if (mailbox.isEmpty()) {
auto message = new PeerMessage(peer);
ioGroup[worker]->handleNewPeer(peer);
auto *old = mailbox.post(message);
assert(old == nullptr);
return;
}
current = (current + 1) % workers;
if (current == start) {
break;
}
}
/* We did not find any available worker, what do we do ? */
}
void
Listener::handleSigint(int) {
}
} // namespace Tcp
......
......@@ -9,10 +9,12 @@
#include "net.h"
#include "mailbox.h"
#include "os.h"
#include "flags.h"
#include <vector>
#include <memory>
#include <thread>
#include <unordered_map>
#include <mutex>
namespace Net {
......@@ -20,54 +22,80 @@ namespace Tcp {
class Peer;
class Message;
class Handler;
enum class Options : uint64_t {
None = 0,
NoDelay = 1,
Linger = NoDelay << 1,
FastOpen = Linger << 1,
QuickAck = FastOpen << 1,
ReuseAddr = QuickAck << 1,
ReverseLookup = ReuseAddr << 1,
InstallSignalHandler = ReverseLookup << 1
};
DECLARE_FLAGS_OPERATORS(Options)
void setSocketOptions(Fd fd, Flags<Options> options);
class IoWorker {
public:
Mailbox<Message> mailbox;
PollableMailbox<Message> mailbox;
IoWorker();
~IoWorker();
void start(const std::shared_ptr<Handler> &handler);
void start(const std::shared_ptr<Handler> &handler, Flags<Options> options);
void handleNewPeer(const std::shared_ptr<Peer>& peer);
private:
int epoll_fd;
Polling::Epoll poller;
std::unique_ptr<std::thread> thread;
mutable std::mutex peersMutex;
std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
std::shared_ptr<Handler> handler_;
Flags<Options> options_;
std::shared_ptr<Peer> getPeer(Fd fd) const;
std::shared_ptr<Peer> getPeer(Polling::Tag tag) const;
void handleIncoming(const std::shared_ptr<Peer>& peer);
void handleNewPeer(const std::shared_ptr<Peer>& peer);
void handleMailbox();
void run();
};
class Listener {
public:
friend class IoWorker;
friend class Peer;
Listener();
Listener(const Address& address);
void init(size_t workers);
void init(size_t workers, Flags<Options> options = Options::None);
void setHandler(const std::shared_ptr<Handler>& handler);
bool bind();
bool bind(const Address& adress);
void run();
void shutdown();
Options options() const;
Address address() const;
private:
Address addr_;
int listen_fd;
std::vector<std::unique_ptr<IoWorker>> ioGroup;
Flags<Options> options_;
std::shared_ptr<Handler> handler_;
void dispatchPeer(const std::shared_ptr<Peer>& peer);
void handleSigint(int);
};
class Handler {
......@@ -78,8 +106,8 @@ public:
virtual void onInput(const char *buffer, size_t len, Tcp::Peer& peer) = 0;
virtual void onOutput() = 0;
virtual void onConnection();
virtual void onDisconnection();
virtual void onConnection(const std::shared_ptr<Peer>& peer);
virtual void onDisconnection(const std::shared_ptr<Peer>& peer);
};
......
......@@ -6,9 +6,12 @@
*/
#pragma once
#include "common.h"
#include "os.h"
#include <atomic>
#include <stdexcept>
#include <sys/eventfd.h>
template<typename T>
class Mailbox {
......@@ -17,6 +20,8 @@ public:
data.store(nullptr);
}
virtual ~Mailbox() { }
const T *get() const {
if (isEmpty()) {
throw std::runtime_error("Can not retrieve mail from empty mailbox");
......@@ -25,7 +30,7 @@ public:
return data.load();
}
T *post(T *newData) {
virtual T *post(T *newData) {
T *old = data.load();
while (!data.compare_exchange_weak(old, newData))
{ }
......@@ -33,7 +38,7 @@ public:
return old;
}
T *clear() {
virtual T *clear() {
return data.exchange(nullptr);
}
......@@ -44,3 +49,85 @@ public:
private:
std::atomic<T *> data;
};
template<typename T>
class PollableMailbox : public Mailbox<T>
{
public:
PollableMailbox()
: event_fd(-1) {
}
~PollableMailbox() {
if (event_fd != -1) close(event_fd);
}
bool isBound() const {
return event_fd != -1;
}
Polling::Tag bind(Polling::Epoll& poller) {
using namespace Polling;
if (isBound()) {
throw std::runtime_error("The mailbox has already been bound");
}
event_fd = TRY_RET(eventfd(0, EFD_NONBLOCK));
Tag tag(event_fd);
poller.addFd(event_fd, NotifyOn::Read, tag);
return tag;
}
T *post(T *newData) {
auto *ret = Mailbox<T>::post(newData);
if (isBound()) {
uint64_t val = 1;
TRY_RET(write(event_fd, &val, sizeof val));
}
return ret;
}
T *clear() {
auto ret = Mailbox<T>::clear();
if (isBound()) {
uint64_t val;
for (;;) {
ssize_t bytes = read(event_fd, &val, sizeof val);
if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK)
break;
else {
// TODO
}
}
}
}
return ret;
}
Polling::Tag tag() const {
if (!isBound())
throw std::runtime_error("Can not retrieve tag of an unbound mailbox");
return Polling::Tag(event_fd);
}
void unbind(Polling::Epoll& poller) {
if (event_fd == -1) {
throw std::runtime_error("The mailbox is not bound");
}
poller.removeFd(event_fd);
close(event_fd), event_fd = -1;
}
private:
int event_fd;
};
......@@ -8,6 +8,15 @@
#include <string>
#include <sys/socket.h>
#ifndef _KERNEL_FASTOPEN
#define _KERNEL_FASTOPEN
/* conditional define for TCP_FASTOPEN */
#ifndef TCP_FASTOPEN
#define TCP_FASTOPEN 23
#endif
#endif
namespace Net {
class Port {
......
......@@ -4,11 +4,13 @@
*/
#include "os.h"
#include "common.h"
#include <unistd.h>
#include <fcntl.h>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <sys/epoll.h>
int hardware_concurrency() {
std::ifstream cpuinfo("/proc/cpuinfo");
......@@ -33,3 +35,99 @@ bool make_non_blocking(int sfd)
return true;
}
namespace Polling {
Epoll::Epoll(size_t max) {
epoll_fd = TRY_RET(epoll_create(max));
}
void
Epoll::addFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode) {
struct epoll_event ev;
ev.events = toEpollEvents(interest);
if (mode == Mode::Edge)
ev.events |= EPOLLET;
ev.data.u64 = tag.value_;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev));
}
void
Epoll::addFdOneShot(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode) {
struct epoll_event ev;
ev.events = toEpollEvents(interest);
ev.events |= EPOLLONESHOT;
if (mode == Mode::Edge)
ev.events |= EPOLLET;
ev.data.u64 = tag.value_;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev));
}
void
Epoll::removeFd(Fd fd) {
struct epoll_event ev;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_DEL, fd, &ev));
}
void
Epoll::rearmFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode) {
struct epoll_event ev;
ev.events = toEpollEvents(interest);
if (mode == Mode::Edge)
ev.events |= EPOLLET;
ev.data.u64 = tag.value_;
TRY(epoll_ctl(epoll_fd, EPOLL_CTL_MOD, fd, &ev));
}
int
Epoll::poll(std::vector<Event>& events, size_t maxEvents, std::chrono::milliseconds timeout) const {
struct epoll_event evs[Const::MaxEvents];
int ready_fds = epoll_wait(epoll_fd, evs, maxEvents, timeout.count());
if (ready_fds > 0) {
for (int i = 0; i < ready_fds; ++i) {
const struct epoll_event *ev = evs + i;
const Tag tag(ev->data.u64);
Event event(tag);
event.flags = toNotifyOn(ev->events);
events.push_back(tag);
}
}
return ready_fds;
}
int
Epoll::toEpollEvents(Flags<NotifyOn> interest) const {
int events;
if (interest.hasFlag(NotifyOn::Read))
events |= EPOLLIN;
if (interest.hasFlag(NotifyOn::Write))
events |= EPOLLOUT;
if (interest.hasFlag(NotifyOn::Hangup))
events |= EPOLLHUP;
return events;
}
Flags<NotifyOn>
Epoll::toNotifyOn(int events) const {
Flags<NotifyOn> flags;
if (events & EPOLLIN)
flags.setFlag(NotifyOn::Read);
if (events & EPOLLOUT)
flags.setFlag(NotifyOn::Write);
if (events & EPOLLHUP)
flags.setFlag(NotifyOn::Hangup);
return flags;
}
} // namespace Poller
......@@ -6,8 +6,78 @@
#pragma once
#include <chrono>
#include <vector>
#include "flags.h"
#include "common.h"
typedef int Fd;
int hardware_concurrency();
bool make_non_blocking(int fd);
namespace Polling {
enum class Mode {
Level,
Edge
};
enum class NotifyOn {
Read,
Write,
Hangup
};
DECLARE_FLAGS_OPERATORS(NotifyOn);
struct Tag {
friend class Epoll;
explicit constexpr Tag(uint64_t value)
: value_(value)
{ }
constexpr uint64_t value() const { return value_; }
friend constexpr bool operator==(Tag lhs, Tag rhs);
private:
uint64_t value_;
};
inline constexpr bool operator==(Tag lhs, Tag rhs) {
return lhs.value_ == rhs.value_;
}
struct Event {
Event(Tag tag) :
tag(tag)
{ }
Flags<NotifyOn> flags;
Fd fd;
Tag tag;
};
class Epoll {
public:
Epoll(size_t max = 128);
void addFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode = Mode::Level);
void addFdOneShot(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode = Mode::Level);
void removeFd(Fd fd);
void rearmFd(Fd fd, Flags<NotifyOn> interest, Tag tag, Mode mode = Mode::Level);
int poll(std::vector<Event>& events,
size_t maxEvents = Const::MaxEvents,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0)) const;
private:
int toEpollEvents(Flags<NotifyOn> interest) const;
Flags<NotifyOn> toNotifyOn(int events) const;
int epoll_fd;
};
} // namespace Polling
......@@ -17,9 +17,8 @@ Peer::Peer()
: fd_(-1)
{ }
Peer::Peer(const Address& addr, const string& host)
Peer::Peer(const Address& addr)
: addr(addr)
, host(host)
, fd_(-1)
{ }
......@@ -30,7 +29,7 @@ Peer::address() const {
string
Peer::hostname() const {
return host;
return hostname_;
}
void
......
......@@ -18,7 +18,7 @@ namespace Tcp {
class Peer {
public:
Peer();
Peer(const Address& addr, const std::string& hostname);
Peer(const Address& addr);
Address address() const;
std::string hostname() const;
......@@ -28,7 +28,7 @@ public:
private:
Address addr;
std::string host;
std::string hostname_;
Fd fd_;
};
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment