Commit 61ff656c authored by Mathieu Stefani's avatar Mathieu Stefani

Merge branch 'server-client-request-timeout'

parents baf8b65d 32548c88
......@@ -4,6 +4,8 @@
#include <cstdint>
#include <limits>
#include <chrono>
// Allow compile-time overload
namespace Pistache {
namespace Const {
......@@ -18,6 +20,8 @@ static constexpr size_t DefaultTimerPoolSize = 128;
static constexpr size_t DefaultMaxRequestSize = 4096;
static constexpr size_t DefaultMaxResponseSize =
std::numeric_limits<uint32_t>::max();
static constexpr auto DefaultHeaderTimeout = std::chrono::seconds(60);
static constexpr auto DefaultBodyTimeout = std::chrono::seconds(60);
static constexpr size_t ChunkSize = 1024;
static constexpr uint16_t HTTP_STANDARD_PORT = 80;
......
......@@ -8,8 +8,11 @@
#include <pistache/http.h>
#include <pistache/listener.h>
#include <pistache/transport.h>
#include <pistache/net.h>
#include <chrono>
namespace Pistache {
namespace Http {
......@@ -20,26 +23,55 @@ public:
Options &threads(int val);
Options &threadsName(const std::string &val);
Options &flags(Flags<Tcp::Options> flags);
Options &flags(Tcp::Options tcp_opts) {
flags(Flags<Tcp::Options>(tcp_opts));
return *this;
}
Options &backlog(int val);
Options &maxRequestSize(size_t val);
Options &maxResponseSize(size_t val);
template<typename Duration>
Options &headerTimeout(Duration timeout)
{
headerTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
return *this;
}
template<typename Duration>
Options& bodyTimeout(Duration timeout)
{
bodyTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
return *this;
}
Options &logger(PISTACHE_STRING_LOGGER_T logger);
[[deprecated("Replaced by maxRequestSize(val)")]] Options &
maxPayload(size_t val);
private:
// Thread options
int threads_;
std::string threadsName_;
// TCP flags
Flags<Tcp::Options> flags_;
// Backlog size
int backlog_;
// Size options
size_t maxRequestSize_;
size_t maxResponseSize_;
// Timeout options
std::chrono::milliseconds headerTimeout_;
std::chrono::milliseconds bodyTimeout_;
PISTACHE_STRING_LOGGER_T logger_;
Options();
};
......@@ -151,8 +183,8 @@ private:
std::shared_ptr<Handler> handler_;
Tcp::Listener listener;
size_t maxRequestSize_ = Const::DefaultMaxRequestSize;
size_t maxResponseSize_ = Const::DefaultMaxResponseSize;
Options options_;
PISTACHE_STRING_LOGGER_T logger_ = PISTACHE_NULL_STRING_LOGGER;
};
......
......@@ -8,6 +8,7 @@
#include <algorithm>
#include <array>
#include <chrono>
#include <memory>
#include <sstream>
#include <stdexcept>
......@@ -21,6 +22,7 @@
#include <pistache/cookie.h>
#include <pistache/http_defs.h>
#include <pistache/http_headers.h>
#include <pistache/meta.h>
#include <pistache/mime.h>
#include <pistache/net.h>
#include <pistache/stream.h>
......@@ -212,7 +214,7 @@ public:
friend class ResponseWriter;
explicit Timeout(Timeout &&other)
: handler(other.handler), request(std::move(other.request)),
: handler(other.handler),
transport(other.transport), armed(other.armed), timerFd(other.timerFd),
peer(std::move(other.peer)) {
// cppcheck-suppress useInitializationList
......@@ -222,7 +224,7 @@ public:
Timeout &operator=(Timeout &&other) {
handler = other.handler;
transport = other.transport;
request = std::move(other.request);
version = other.version;
armed = other.armed;
timerFd = other.timerFd;
other.timerFd = -1;
......@@ -256,13 +258,13 @@ public:
private:
Timeout(const Timeout &other) = default;
Timeout(Tcp::Transport *transport_, Handler *handler_,
Timeout(Tcp::Transport *transport_, Http::Version version, Handler *handler_,
std::weak_ptr<Tcp::Peer> peer_);
void onTimeout(uint64_t numWakeup);
Handler *handler;
Request request;
Http::Version version;
Tcp::Transport *transport;
bool armed;
Fd timerFd;
......@@ -354,6 +356,9 @@ public:
friend class Private::ResponseLineStep;
ResponseWriter(Http::Version version, Tcp::Transport *transport,
Handler *handler, std::weak_ptr<Tcp::Peer> peer);
//
// C++11: std::weak_ptr move constructor is C++14 only so the default
// version of move constructor / assignement operator does not work and we
......@@ -423,9 +428,6 @@ public:
}
private:
ResponseWriter(Http::Version version, Tcp::Transport *transport,
Handler *handler, std::weak_ptr<Tcp::Peer> peer);
ResponseWriter(const ResponseWriter &other);
Async::Promise<ssize_t> sendImpl(Code code, const char *data,
......@@ -449,12 +451,14 @@ serveFile(ResponseWriter &writer, const std::string &fileName,
namespace Private {
enum class State { Again, Next, Done };
using StepId = uint64_t;
struct Step {
explicit Step(Message *request);
virtual ~Step() = default;
virtual StepId id() const = 0;
virtual State apply(StreamCursor &cursor) = 0;
static void raise(const char *msg, Code code = Code::Bad_Request);
......@@ -465,30 +469,42 @@ protected:
class RequestLineStep : public Step {
public:
static constexpr StepId Id = Meta::Hash::fnv1a("RequestLine");
explicit RequestLineStep(Request *request) : Step(request) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override;
};
class ResponseLineStep : public Step {
public:
static constexpr StepId Id = Meta::Hash::fnv1a("ResponseLine");
explicit ResponseLineStep(Response *response) : Step(response) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override;
};
class HeadersStep : public Step {
public:
static constexpr StepId Id = Meta::Hash::fnv1a("Headers");
explicit HeadersStep(Message *request) : Step(request) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override;
};
class BodyStep : public Step {
public:
static constexpr auto Id = Meta::Hash::fnv1a("Headers");
explicit BodyStep(Message *message_)
: Step(message_), chunk(message_), bytesRead(0) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override;
private:
......@@ -524,6 +540,8 @@ private:
class ParserBase {
public:
static constexpr size_t StepsCount = 3;
explicit ParserBase(size_t maxDataSize);
ParserBase(const ParserBase &) = delete;
......@@ -537,12 +555,18 @@ public:
virtual void reset();
State parse();
protected:
static constexpr size_t StepsCount = 3;
Step* step();
std::chrono::steady_clock::time_point time() const
{
return time_;
}
protected:
std::array<std::unique_ptr<Step>, StepsCount> allSteps;
size_t currentStep = 0;
std::chrono::steady_clock::time_point time_;
private:
ArrayStreamBuf<char> buffer;
StreamCursor cursor;
......@@ -551,7 +575,6 @@ private:
template <typename Message> class ParserImpl;
template <> class ParserImpl<Http::Request> : public ParserBase {
public:
explicit ParserImpl(size_t maxDataSize);
......@@ -575,6 +598,8 @@ using ResponseParser = Private::ParserImpl<Http::Response>;
class Handler : public Tcp::Handler {
public:
static constexpr const char* ParserData = "__Parser";
virtual void onRequest(const Request &request, ResponseWriter response) = 0;
virtual void onTimeout(const Request &request, ResponseWriter response);
......@@ -584,16 +609,42 @@ public:
void setMaxResponseSize(size_t value);
size_t getMaxResponseSize() const;
template<typename Duration>
void setHeaderTimeout(Duration timeout)
{
headerTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
}
template<typename Duration>
void setBodyTimeout(Duration timeout)
{
bodyTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
}
std::chrono::milliseconds getHeaderTimeout() const
{
return headerTimeout_;
}
std::chrono::milliseconds getBodyTimeout() const
{
return bodyTimeout_;
}
static std::shared_ptr<RequestParser> getParser(const std::shared_ptr<Tcp::Peer> &peer);
virtual ~Handler() override {}
private:
void onConnection(const std::shared_ptr<Tcp::Peer> &peer) override;
void onInput(const char *buffer, size_t len,
const std::shared_ptr<Tcp::Peer> &peer) override;
private:
size_t maxRequestSize_ = Const::DefaultMaxRequestSize;
size_t maxResponseSize_ = Const::DefaultMaxResponseSize;
std::chrono::milliseconds headerTimeout_ = Const::DefaultHeaderTimeout;
std::chrono::milliseconds bodyTimeout_ = Const::DefaultBodyTimeout;
};
template <typename H, typename... Args>
......
......@@ -45,7 +45,9 @@ public:
TimePoint tick;
};
Listener() = default;
using TransportFactory = std::function<std::shared_ptr<Transport> ()>;
Listener();
~Listener();
explicit Listener(const Address &address);
......@@ -54,6 +56,8 @@ public:
const std::string &workersName = "",
int backlog = Const::MaxBacklog,
PISTACHE_STRING_LOGGER_T logger = PISTACHE_NULL_STRING_LOGGER);
void setTransportFactory(TransportFactory factory);
void setHandler(const std::shared_ptr<Handler> &handler);
void bind();
......@@ -96,6 +100,10 @@ private:
Aio::Reactor reactor_;
Aio::Reactor::Key transportKey;
TransportFactory transportFactory_;
TransportFactory defaultTransportFactory() const;
void handleNewConnection();
int acceptConnection(struct sockaddr_in &peer_addr) const;
void dispatchPeer(const std::shared_ptr<Peer> &peer);
......
#pragma once
#include <cstdint>
namespace Pistache
{
namespace Meta
{
namespace Hash
{
static constexpr uint64_t val64 = 0xcbf29ce484222325;
static constexpr uint64_t prime64 = 0x100000001b3;
inline constexpr uint64_t fnv1a(const char* const str, const uint64_t value = val64) noexcept {
return (str[0] == '\0') ? value : fnv1a(&str[1], (value ^ uint64_t(str[0])) * prime64);
}
}
}
}
\ No newline at end of file
......@@ -44,6 +44,10 @@ public:
void *ssl() const;
void putData(std::string name, std::shared_ptr<void> data);
std::shared_ptr<void> getData(std::string name) const;
std::shared_ptr<void> tryGetData(std::string name) const;
Async::Promise<ssize_t> send(const RawBuffer &buffer, int flags = 0);
size_t getID() const;
......@@ -51,11 +55,6 @@ protected:
Peer(Fd fd, const Address &addr, void *ssl);
private:
void setParser(std::shared_ptr<Http::RequestParser> parser);
std::shared_ptr<Http::RequestParser> getParser() const;
Http::Request &request();
void associateTransport(Transport *transport);
Transport *transport() const;
......@@ -64,7 +63,7 @@ private:
Address addr;
std::string hostname_;
std::shared_ptr<Http::RequestParser> parser_;
std::unordered_map<std::string, std::shared_ptr<void>> data_;
void *ssl_ = nullptr;
const size_t id_;
......
......@@ -13,16 +13,15 @@ namespace Pistache {
/* In a sense, a Prototype is just a class that provides a clone() method */
template <typename Class> struct Prototype {
public:
virtual ~Prototype() {}
virtual std::shared_ptr<Class> clone() const = 0;
};
} // namespace Pistache
#define PROTOTYPE_OF(Base, Class) \
private: \
std::shared_ptr<Base> clone() const override { \
return std::make_shared<Class>(*this); \
} \
\
public:
#define PROTOTYPE_OF(Base, Class) \
public: \
std::shared_ptr<Base> clone() const override { \
return std::make_shared<Class>(*this); \
} \
......@@ -32,7 +32,7 @@ enum class Options : uint64_t {
DECLARE_FLAGS_OPERATORS(Options)
class Handler : private Prototype<Handler> {
class Handler : public Prototype<Handler> {
public:
friend class Transport;
......
......@@ -27,15 +27,16 @@ class Handler;
class Transport : public Aio::Handler {
public:
explicit Transport(const std::shared_ptr<Tcp::Handler> &handler);
Transport(const Transport &) = delete;
Transport &operator=(const Transport &) = delete;
void init(const std::shared_ptr<Tcp::Handler> &handler);
void registerPoller(Polling::Epoll &poller) override;
virtual void registerPoller(Polling::Epoll &poller) override;
void handleNewPeer(const std::shared_ptr<Peer> &peer);
void onReady(const Aio::FdSet &fds) override;
virtual void onReady(const Aio::FdSet &fds) override;
template <typename Buf>
Async::Promise<ssize_t> asyncWrite(Fd fd, const Buf &buffer, int flags = 0) {
......@@ -167,13 +168,17 @@ private:
std::unordered_map<Fd, TimerEntry> timers;
PollableQueue<PeerEntry> peersQueue;
std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
Async::Deferred<rusage> loadRequest_;
NotifyFd notifier;
std::shared_ptr<Tcp::Handler> handler_;
protected:
void removePeer(const std::shared_ptr<Peer>& peer);
std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
private:
bool isPeerFd(Fd fd) const;
bool isTimerFd(Fd fd) const;
bool isPeerFd(Polling::Tag tag) const;
......
......@@ -245,7 +245,7 @@ State ResponseLineStep::apply(StreamCursor &cursor) {
char *end;
auto code = strtol(codeToken.rawText(), &end, 10);
if (*end != ' ')
raise("Failed to parsed return code");
raise("Failed to parse return code");
response->code_ = static_cast<Http::Code>(code);
if (!cursor.advance(1))
......@@ -482,6 +482,7 @@ State ParserBase::parse() {
}
bool ParserBase::feed(const char *data, size_t len) {
time_ = std::chrono::steady_clock::now();
return buffer.feed(data, len);
}
......@@ -490,6 +491,12 @@ void ParserBase::reset() {
cursor.reset();
currentStep = 0;
time_ = std::chrono::steady_clock::time_point(std::chrono::steady_clock::duration(0));
}
Step* ParserBase::step()
{
return allSteps[currentStep].get();
}
} // namespace Private
......@@ -656,20 +663,29 @@ void ResponseStream::ends() {
}
ResponseWriter::ResponseWriter(ResponseWriter &&other)
: response_(std::move(other.response_)), peer_(other.peer_),
buf_(std::move(other.buf_)), transport_(other.transport_),
timeout_(std::move(other.timeout_)) {}
: response_(std::move(other.response_))
, peer_(other.peer_)
, buf_(std::move(other.buf_))
, transport_(other.transport_)
, timeout_(std::move(other.timeout_))
{}
ResponseWriter::ResponseWriter(Http::Version version, Tcp::Transport *transport,
Handler *handler, std::weak_ptr<Tcp::Peer> peer)
: response_(version), peer_(peer),
buf_(DefaultStreamSize, handler->getMaxResponseSize()),
transport_(transport), timeout_(transport, handler, peer) {}
: response_(version)
, peer_(peer)
, buf_(DefaultStreamSize, handler->getMaxResponseSize())
, transport_(transport)
, timeout_(transport, version, handler, peer)
{}
ResponseWriter::ResponseWriter(const ResponseWriter &other)
: response_(other.response_), peer_(other.peer_),
buf_(DefaultStreamSize, other.buf_.maxSize()),
transport_(other.transport_), timeout_(other.timeout_) {}
: response_(other.response_)
, peer_(other.peer_)
, buf_(DefaultStreamSize, other.buf_.maxSize())
, transport_(other.transport_)
, timeout_(other.timeout_)
{}
void ResponseWriter::setMime(const Mime::MediaType &mime) {
auto ct = response_.headers().tryGet<Header::ContentType>();
......@@ -916,8 +932,8 @@ Private::ParserImpl<Http::Response>::ParserImpl(size_t maxDataSize)
void Handler::onInput(const char *buffer, size_t len,
const std::shared_ptr<Tcp::Peer> &peer) {
auto parser = peer->getParser();
auto &request = peer->request();
auto parser = getParser(peer);
auto &request = parser->request;
try {
if (!parser->feed(buffer, len)) {
parser->reset();
......@@ -962,11 +978,14 @@ void Handler::onInput(const char *buffer, size_t len,
}
void Handler::onConnection(const std::shared_ptr<Tcp::Peer> &peer) {
peer->setParser(std::make_shared<RequestParser>(maxRequestSize_));
peer->putData(ParserData, std::make_shared<RequestParser>(maxRequestSize_));
}
void Handler::onTimeout(const Request & /*request*/,
ResponseWriter /*response*/) {}
void Handler::onTimeout(const Request& /*request*/,
ResponseWriter response)
{
response.send(Code::Request_Timeout);
}
Timeout::~Timeout() { disarm(); }
......@@ -978,10 +997,15 @@ void Timeout::disarm() {
bool Timeout::isArmed() const { return armed; }
Timeout::Timeout(Tcp::Transport *transport_, Handler *handler_,
Timeout::Timeout(Tcp::Transport *transport_, Http::Version version, Handler *handler_,
std::weak_ptr<Tcp::Peer> peer_)
: handler(handler_), transport(transport_), armed(false), timerFd(-1),
peer(peer_) {}
: handler(handler_)
, transport(transport_)
, version(version)
, armed(false)
, timerFd(-1)
, peer(peer_)
{}
void Timeout::onTimeout(uint64_t numWakeup) {
UNUSED(numWakeup)
......@@ -989,9 +1013,10 @@ void Timeout::onTimeout(uint64_t numWakeup) {
if (!sp)
return;
ResponseWriter response(sp->request().version(), transport, handler, peer);
handler->onTimeout(sp->request(), std::move(response));
ResponseWriter response(version, transport, handler, peer);
auto parser = Handler::getParser(sp);
const auto& request = parser->request;
handler->onTimeout(request, std::move(response));
}
void Handler::setMaxRequestSize(size_t value) { maxRequestSize_ = value; }
......@@ -1002,5 +1027,10 @@ void Handler::setMaxResponseSize(size_t value) { maxResponseSize_ = value; }
size_t Handler::getMaxResponseSize() const { return maxResponseSize_; }
std::shared_ptr<RequestParser>
Handler::getParser(const std::shared_ptr<Tcp::Peer> &peer) {
return std::static_pointer_cast<RequestParser>(peer->getData(ParserData));
}
} // namespace Http
} // namespace Pistache
......@@ -78,18 +78,30 @@ int Peer::fd() const {
return fd_;
}
void Peer::setParser(std::shared_ptr<Http::RequestParser> parser) {
parser_ = parser;
}
void Peer::putData(std::string name, std::shared_ptr<void> data) {
auto it = data_.find(name);
if (it != std::end(data_)) {
throw std::runtime_error("The data already exists");
}
std::shared_ptr<Http::RequestParser> Peer::getParser() const { return parser_; }
data_.insert(std::make_pair(std::move(name), std::move(data)));
}
Http::Request &Peer::request() {
if (!parser_) {
throw std::runtime_error("The peer has no associated parser");
std::shared_ptr<void> Peer::getData(std::string name) const {
auto data = tryGetData(std::move(name));
if (data == nullptr) {
throw std::runtime_error("The data does not exist");
}
return parser_->request;
return data;
}
std::shared_ptr<void> Peer::tryGetData(std::string(name)) const {
auto it = data_.find(name);
if (it == std::end(data_))
return nullptr;
return it->second;
}
Async::Promise<ssize_t> Peer::send(const RawBuffer &buffer, int flags) {
......
......@@ -66,7 +66,6 @@ public:
Reactor::Key addHandler(const std::shared_ptr<Handler> &handler,
bool setKey = true) override {
handler->registerPoller(poller);
handler->reactor_ = reactor_;
......
......@@ -160,6 +160,11 @@ void Transport::handleIncoming(const std::shared_ptr<Peer> &peer) {
void Transport::handlePeerDisconnection(const std::shared_ptr<Peer> &peer) {
handler_->onDisconnection(peer);
removePeer(peer);
}
void Transport::removePeer(const std::shared_ptr<Peer>& peer)
{
int fd = peer->fd();
auto it = peers.find(fd);
if (it == std::end(peers))
......
......@@ -9,14 +9,149 @@
#include <pistache/peer.h>
#include <pistache/tcp.h>
#include <array>
#include <chrono>
namespace Pistache {
namespace Http {
class TransportImpl : public Tcp::Transport
{
public:
using Base = Tcp::Transport;
explicit TransportImpl(const std::shared_ptr<Tcp::Handler>& handler);
void registerPoller(Polling::Epoll& poller) override;
void onReady(const Aio::FdSet& fds) override;
void setHeaderTimeout(std::chrono::milliseconds timeout);
void setBodyTimeout(std::chrono::milliseconds timeout);
std::shared_ptr<Aio::Handler> clone() const override;
private:
std::shared_ptr<Tcp::Handler> handler_;
std::chrono::milliseconds headerTimeout_;
std::chrono::milliseconds bodyTimeout_;
int timerFd;
void checkIdlePeers();
};
TransportImpl::TransportImpl(const std::shared_ptr<Tcp::Handler>& handler)
: Tcp::Transport(handler)
, handler_(handler)
{ }
void TransportImpl::registerPoller(Polling::Epoll& poller)
{
Base::registerPoller(poller);
timerFd = TRY_RET(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK));
static constexpr auto TimerInterval = std::chrono::milliseconds(500);
static constexpr auto TimerIntervalNs = std::chrono::duration_cast<std::chrono::nanoseconds>(TimerInterval);
static_assert(
TimerInterval < std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::seconds(1)),
"Timer frequency should be less than 1 second"
);
itimerspec spec;
spec.it_value.tv_sec = 0;
spec.it_value.tv_nsec = TimerIntervalNs.count();
spec.it_interval.tv_sec = 0;
spec.it_interval.tv_nsec = TimerIntervalNs.count();
TRY(timerfd_settime(timerFd, 0, &spec, 0));
Polling::Tag tag(timerFd);
poller.addFd(timerFd, Flags<Polling::NotifyOn>(Polling::NotifyOn::Read), Polling::Tag(timerFd));
}
void TransportImpl::onReady(const Aio::FdSet& fds)
{
bool handled = false;
for (const auto& entry: fds)
{
if (entry.getTag() == Polling::Tag(timerFd))
{
uint64_t wakeups;
::read(timerFd, &wakeups, sizeof wakeups);
checkIdlePeers();
handled = true;
}
}
if (!handled)
Base::onReady(fds);
}
void TransportImpl::setHeaderTimeout(std::chrono::milliseconds timeout)
{
headerTimeout_ = timeout;
}
void TransportImpl::setBodyTimeout(std::chrono::milliseconds timeout)
{
bodyTimeout_ = timeout;
}
void TransportImpl::checkIdlePeers()
{
std::vector<std::shared_ptr<Tcp::Peer>> idlePeers;
for (const auto& peerPair: peers)
{
const auto& peer = peerPair.second;
auto parser = Http::Handler::getParser(peer);
auto time = parser->time();
auto now = std::chrono::steady_clock::now();
auto elapsed = now - time;
auto* step = parser->step();
if (step->id() == Private::RequestLineStep::Id)
{
if (elapsed > headerTimeout_ || elapsed > bodyTimeout_)
idlePeers.push_back(peer);
}
else if (step->id() == Private::HeadersStep::Id)
{
if (elapsed > bodyTimeout_)
idlePeers.push_back(peer);
}
}
for (const auto& idlePeer: idlePeers)
{
ResponseWriter response(Http::Version::Http11, this, static_cast<Http::Handler *>(handler_.get()), idlePeer);
response.send(Http::Code::Request_Timeout).then([=](ssize_t) {
removePeer(idlePeer);
}, [=](std::exception_ptr) {
removePeer(idlePeer);
});
}
}
std::shared_ptr<Aio::Handler> TransportImpl::clone() const
{
auto transport = std::make_shared<TransportImpl>(handler_->clone());
transport->setHeaderTimeout(headerTimeout_);
transport->setBodyTimeout(bodyTimeout_);
return transport;
}
Endpoint::Options::Options()
: threads_(1), flags_(), backlog_(Const::MaxBacklog),
maxRequestSize_(Const::DefaultMaxRequestSize),
maxResponseSize_(Const::DefaultMaxResponseSize),
logger_(PISTACHE_NULL_STRING_LOGGER) {}
: threads_(1), flags_(), backlog_(Const::MaxBacklog)
, maxRequestSize_(Const::DefaultMaxRequestSize)
, maxResponseSize_(Const::DefaultMaxResponseSize)
, headerTimeout_(Const::DefaultHeaderTimeout)
, bodyTimeout_(Const::DefaultBodyTimeout)
, logger_(PISTACHE_NULL_STRING_LOGGER)
{}
Endpoint::Options &Endpoint::Options::threads(int val) {
threads_ = val;
......@@ -63,15 +198,25 @@ Endpoint::Endpoint(const Address &addr) : listener(addr) {}
void Endpoint::init(const Endpoint::Options &options) {
listener.init(options.threads_, options.flags_, options.threadsName_);
maxRequestSize_ = options.maxRequestSize_;
maxResponseSize_ = options.maxResponseSize_;
listener.setTransportFactory([&] {
if (!handler_)
throw std::runtime_error("Must call setHandler()");
auto transport = std::make_shared<TransportImpl>(handler_);
transport->setHeaderTimeout(options.headerTimeout_);
transport->setBodyTimeout(options.bodyTimeout_);
return transport;
});
options_ = options;
logger_ = options.logger_;
}
void Endpoint::setHandler(const std::shared_ptr<Handler> &handler) {
handler_ = handler;
handler_->setMaxRequestSize(maxRequestSize_);
handler_->setMaxResponseSize(maxResponseSize_);
handler_->setMaxRequestSize(options_.maxRequestSize_);
handler_->setMaxResponseSize(options_.maxResponseSize_);
}
void Endpoint::bind() { listener.bind(); }
......
......@@ -148,7 +148,14 @@ void setSocketOptions(Fd fd, Flags<Options> options) {
}
}
Listener::Listener(const Address &address) : addr_(address) {}
Listener::Listener()
: transportFactory_(defaultTransportFactory())
{}
Listener::Listener(const Address &address)
: addr_(address)
, transportFactory_(defaultTransportFactory())
{}
Listener::~Listener() {
if (isBound())
......@@ -165,16 +172,21 @@ Listener::~Listener() {
void Listener::init(size_t workers, Flags<Options> options,
const std::string &workersName, int backlog,
PISTACHE_STRING_LOGGER_T logger) {
if (workers > hardware_concurrency()) {
if (workers > hardware_concurrency()) {
// Log::warning() << "More workers than available cores"
}
}
options_ = options;
backlog_ = backlog;
useSSL_ = false;
workers_ = workers;
workersName_ = workersName;
logger_ = logger;
}
options_ = options;
backlog_ = backlog;
useSSL_ = false;
workers_ = workers;
workersName_ = workersName;
logger_ = logger;
void Listener::setTransportFactory(TransportFactory factory)
{
transportFactory_ = std::move(factory);
}
void Listener::setHandler(const std::shared_ptr<Handler> &handler) {
......@@ -200,8 +212,6 @@ void Listener::pinWorker(size_t worker, const CpuSet &set) {
void Listener::bind() { bind(addr_); }
void Listener::bind(const Address &address) {
if (!handler_)
throw std::runtime_error("Call setHandler before calling bind()");
addr_ = address;
struct addrinfo hints;
......@@ -252,7 +262,7 @@ void Listener::bind(const Address &address) {
Polling::Tag(fd));
listen_fd = fd;
auto transport = std::make_shared<Transport>(handler_);
auto transport = transportFactory_();
reactor_.init(Aio::AsyncContext(workers_, workersName_));
transportKey = reactor_.addHandler(transport);
......@@ -451,6 +461,16 @@ void Listener::dispatchPeer(const std::shared_ptr<Peer> &peer) {
transport->handleNewPeer(peer);
}
Listener::TransportFactory Listener::defaultTransportFactory() const
{
return [&] {
if (!handler_)
throw std::runtime_error("setHandler() has not been called");
return std::make_shared<Transport>(handler_);
};
}
#ifdef PISTACHE_USE_SSL
void Listener::setupSSLAuth(const std::string &ca_file,
......
......@@ -15,6 +15,8 @@
#include <string>
#include <thread>
#include "tcp_client.h"
using namespace Pistache;
#define THREAD_INFO \
......@@ -96,6 +98,22 @@ struct AddressEchoHandler : public Http::Handler {
}
};
struct PingHandler : public Http::Handler {
HTTP_PROTOTYPE(PingHandler)
PingHandler() = default;
void onRequest(const Http::Request& request,
Http::ResponseWriter writer) override {
if (request.resource() == "/ping") {
writer.send(Http::Code::Ok, "PONG");
}
else {
writer.send(Http::Code::Not_Found);
}
}
};
int clientLogicFunc(int response_size, const std::string &server_page,
int timeout_seconds, int wait_seconds) {
Http::Client client;
......@@ -429,6 +447,79 @@ TEST(http_server_test, response_size_captured) {
ASSERT_EQ(rcode, Http::Code::Ok);
}
TEST(http_server_test, client_request_header_timeout_raises_http_408) {
Pistache::Address address("localhost", Pistache::Port(0));
auto timeout = std::chrono::seconds(2);
Http::Endpoint server(address);
auto flags = Tcp::Options::ReuseAddr;
auto opts = Http::Endpoint::options()
.flags(flags)
.headerTimeout(timeout);
server.init(opts);
server.setHandler(Http::make_handler<PingHandler>());
server.serveThreaded();
auto port = server.getPort();
auto addr = "localhost:" + port.toString();
std::cout << "Server address: " << addr << "\n";
char recvBuf[1024];
std::memset(recvBuf, 0, sizeof(recvBuf));
size_t bytes;
TcpClient client;
ASSERT_TRUE(client.connect(Pistache::Address("localhost", port))) << client.lastError();
ASSERT_TRUE(client.receive(recvBuf, sizeof(recvBuf), &bytes, std::chrono::seconds(5))) << client.lastError();
server.shutdown();
}
TEST(http_server_test, client_request_body_timeout_raises_http_408) {
Pistache::Address address("localhost", Pistache::Port(0));
auto headerTimeout = std::chrono::seconds(1);
auto bodyTimeout = std::chrono::seconds(1);
Http::Endpoint server(address);
auto flags = Tcp::Options::ReuseAddr;
auto opts = Http::Endpoint::options()
.flags(flags)
.headerTimeout(headerTimeout)
.bodyTimeout(bodyTimeout);
server.init(opts);
server.setHandler(Http::make_handler<PingHandler>());
server.serveThreaded();
auto port = server.getPort();
auto addr = "localhost:" + port.toString();
std::cout << "Server address: " << addr << "\n";
std::string reqStr = "GET /ping HTTP/1.1\r\n";
std::string headerStr = "Host: localhost\r\nUser-Agent: test\r\n";
char recvBuf[1024];
std::memset(recvBuf, 0, sizeof(recvBuf));
size_t bytes;
TcpClient client;
ASSERT_TRUE(client.connect(Pistache::Address("localhost", port))) << client.lastError();
ASSERT_TRUE(client.send(reqStr)) << client.lastError();
std::this_thread::sleep_for(headerTimeout / 2);
ASSERT_TRUE(client.send(headerStr)) << client.lastError();
static constexpr const char* ExpectedResponseLine = "HTTP/1.1 408 Request Timeout";
ASSERT_TRUE(client.receive(recvBuf, sizeof(recvBuf), &bytes, std::chrono::seconds(5))) << client.lastError();
ASSERT_TRUE(!strncmp(recvBuf, ExpectedResponseLine, strlen(ExpectedResponseLine)));
server.shutdown();
}
namespace {
class WaitHelper {
......
#pragma once
#include <pistache/net.h>
#include <pistache/os.h>
#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <poll.h>
namespace Pistache
{
#define CLIENT_TRY(...) \
do { \
auto ret = __VA_ARGS__; \
if (ret < 0) { \
lastError_ = strerror(errno); \
return false; \
} \
} while (0) \
class TcpClient
{
public:
bool connect(const Pistache::Address& address)
{
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = address.family();
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = 0;
hints.ai_protocol = 0;
auto host = address.host();
auto port = address.port().toString();
AddrInfo addrInfo;
CLIENT_TRY(addrInfo.invoke(host.c_str(), port.c_str(), &hints));
const auto* addrs = addrInfo.get_info_ptr();
int sfd = -1;
auto* addr = addrs;
for (; addr; addr = addr->ai_next) {
sfd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (sfd < 0)
continue;
break;
}
CLIENT_TRY(sfd);
CLIENT_TRY(::connect(sfd, addr->ai_addr, addr->ai_addrlen));
make_non_blocking(sfd);
fd_ = sfd;
return true;
}
bool send(const std::string& data)
{
return send(data.c_str(), data.size());
}
bool send(const void* data, size_t size)
{
CLIENT_TRY(::send(fd_, data, size, 0));
return true;
}
template<typename Duration>
bool receive(void* buffer, size_t size, size_t *bytes, Duration timeout)
{
struct pollfd fds[1];
fds[0].fd = fd_;
fds[0].events = POLLIN;
auto timeoutMs = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
auto ret = ::poll(fds, 1, static_cast<int>(timeoutMs.count()));
if (ret < 0)
{
lastError_ = strerror(errno);
return false;
}
if (ret == 0)
{
lastError_ = "Poll timeout";
return false;
}
if (fds[0].revents & POLLERR)
{
lastError_ = "An error has occured on the stream";
return false;
}
auto res = ::recv(fd_, buffer, size, 0);
if (res < 0)
{
lastError_ = strerror(errno);
return false;
}
*bytes = static_cast<size_t>(res);
return true;
}
std::string lastError() const
{
return lastError_;
}
private:
int fd_;
std::string lastError_;
};
#undef CLIENT_TRY
} // namespace Pistache
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment