Commit 61ff656c authored by Mathieu Stefani's avatar Mathieu Stefani

Merge branch 'server-client-request-timeout'

parents baf8b65d 32548c88
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
#include <cstdint> #include <cstdint>
#include <limits> #include <limits>
#include <chrono>
// Allow compile-time overload // Allow compile-time overload
namespace Pistache { namespace Pistache {
namespace Const { namespace Const {
...@@ -18,6 +20,8 @@ static constexpr size_t DefaultTimerPoolSize = 128; ...@@ -18,6 +20,8 @@ static constexpr size_t DefaultTimerPoolSize = 128;
static constexpr size_t DefaultMaxRequestSize = 4096; static constexpr size_t DefaultMaxRequestSize = 4096;
static constexpr size_t DefaultMaxResponseSize = static constexpr size_t DefaultMaxResponseSize =
std::numeric_limits<uint32_t>::max(); std::numeric_limits<uint32_t>::max();
static constexpr auto DefaultHeaderTimeout = std::chrono::seconds(60);
static constexpr auto DefaultBodyTimeout = std::chrono::seconds(60);
static constexpr size_t ChunkSize = 1024; static constexpr size_t ChunkSize = 1024;
static constexpr uint16_t HTTP_STANDARD_PORT = 80; static constexpr uint16_t HTTP_STANDARD_PORT = 80;
......
...@@ -8,8 +8,11 @@ ...@@ -8,8 +8,11 @@
#include <pistache/http.h> #include <pistache/http.h>
#include <pistache/listener.h> #include <pistache/listener.h>
#include <pistache/transport.h>
#include <pistache/net.h> #include <pistache/net.h>
#include <chrono>
namespace Pistache { namespace Pistache {
namespace Http { namespace Http {
...@@ -20,26 +23,55 @@ public: ...@@ -20,26 +23,55 @@ public:
Options &threads(int val); Options &threads(int val);
Options &threadsName(const std::string &val); Options &threadsName(const std::string &val);
Options &flags(Flags<Tcp::Options> flags); Options &flags(Flags<Tcp::Options> flags);
Options &flags(Tcp::Options tcp_opts) { Options &flags(Tcp::Options tcp_opts) {
flags(Flags<Tcp::Options>(tcp_opts)); flags(Flags<Tcp::Options>(tcp_opts));
return *this; return *this;
} }
Options &backlog(int val); Options &backlog(int val);
Options &maxRequestSize(size_t val); Options &maxRequestSize(size_t val);
Options &maxResponseSize(size_t val); Options &maxResponseSize(size_t val);
template<typename Duration>
Options &headerTimeout(Duration timeout)
{
headerTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
return *this;
}
template<typename Duration>
Options& bodyTimeout(Duration timeout)
{
bodyTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
return *this;
}
Options &logger(PISTACHE_STRING_LOGGER_T logger); Options &logger(PISTACHE_STRING_LOGGER_T logger);
[[deprecated("Replaced by maxRequestSize(val)")]] Options & [[deprecated("Replaced by maxRequestSize(val)")]] Options &
maxPayload(size_t val); maxPayload(size_t val);
private: private:
// Thread options
int threads_; int threads_;
std::string threadsName_; std::string threadsName_;
// TCP flags
Flags<Tcp::Options> flags_; Flags<Tcp::Options> flags_;
// Backlog size
int backlog_; int backlog_;
// Size options
size_t maxRequestSize_; size_t maxRequestSize_;
size_t maxResponseSize_; size_t maxResponseSize_;
// Timeout options
std::chrono::milliseconds headerTimeout_;
std::chrono::milliseconds bodyTimeout_;
PISTACHE_STRING_LOGGER_T logger_; PISTACHE_STRING_LOGGER_T logger_;
Options(); Options();
}; };
...@@ -151,8 +183,8 @@ private: ...@@ -151,8 +183,8 @@ private:
std::shared_ptr<Handler> handler_; std::shared_ptr<Handler> handler_;
Tcp::Listener listener; Tcp::Listener listener;
size_t maxRequestSize_ = Const::DefaultMaxRequestSize;
size_t maxResponseSize_ = Const::DefaultMaxResponseSize; Options options_;
PISTACHE_STRING_LOGGER_T logger_ = PISTACHE_NULL_STRING_LOGGER; PISTACHE_STRING_LOGGER_T logger_ = PISTACHE_NULL_STRING_LOGGER;
}; };
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <algorithm> #include <algorithm>
#include <array> #include <array>
#include <chrono>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
...@@ -21,6 +22,7 @@ ...@@ -21,6 +22,7 @@
#include <pistache/cookie.h> #include <pistache/cookie.h>
#include <pistache/http_defs.h> #include <pistache/http_defs.h>
#include <pistache/http_headers.h> #include <pistache/http_headers.h>
#include <pistache/meta.h>
#include <pistache/mime.h> #include <pistache/mime.h>
#include <pistache/net.h> #include <pistache/net.h>
#include <pistache/stream.h> #include <pistache/stream.h>
...@@ -212,7 +214,7 @@ public: ...@@ -212,7 +214,7 @@ public:
friend class ResponseWriter; friend class ResponseWriter;
explicit Timeout(Timeout &&other) explicit Timeout(Timeout &&other)
: handler(other.handler), request(std::move(other.request)), : handler(other.handler),
transport(other.transport), armed(other.armed), timerFd(other.timerFd), transport(other.transport), armed(other.armed), timerFd(other.timerFd),
peer(std::move(other.peer)) { peer(std::move(other.peer)) {
// cppcheck-suppress useInitializationList // cppcheck-suppress useInitializationList
...@@ -222,7 +224,7 @@ public: ...@@ -222,7 +224,7 @@ public:
Timeout &operator=(Timeout &&other) { Timeout &operator=(Timeout &&other) {
handler = other.handler; handler = other.handler;
transport = other.transport; transport = other.transport;
request = std::move(other.request); version = other.version;
armed = other.armed; armed = other.armed;
timerFd = other.timerFd; timerFd = other.timerFd;
other.timerFd = -1; other.timerFd = -1;
...@@ -256,13 +258,13 @@ public: ...@@ -256,13 +258,13 @@ public:
private: private:
Timeout(const Timeout &other) = default; Timeout(const Timeout &other) = default;
Timeout(Tcp::Transport *transport_, Handler *handler_, Timeout(Tcp::Transport *transport_, Http::Version version, Handler *handler_,
std::weak_ptr<Tcp::Peer> peer_); std::weak_ptr<Tcp::Peer> peer_);
void onTimeout(uint64_t numWakeup); void onTimeout(uint64_t numWakeup);
Handler *handler; Handler *handler;
Request request; Http::Version version;
Tcp::Transport *transport; Tcp::Transport *transport;
bool armed; bool armed;
Fd timerFd; Fd timerFd;
...@@ -354,6 +356,9 @@ public: ...@@ -354,6 +356,9 @@ public:
friend class Private::ResponseLineStep; friend class Private::ResponseLineStep;
ResponseWriter(Http::Version version, Tcp::Transport *transport,
Handler *handler, std::weak_ptr<Tcp::Peer> peer);
// //
// C++11: std::weak_ptr move constructor is C++14 only so the default // C++11: std::weak_ptr move constructor is C++14 only so the default
// version of move constructor / assignement operator does not work and we // version of move constructor / assignement operator does not work and we
...@@ -423,9 +428,6 @@ public: ...@@ -423,9 +428,6 @@ public:
} }
private: private:
ResponseWriter(Http::Version version, Tcp::Transport *transport,
Handler *handler, std::weak_ptr<Tcp::Peer> peer);
ResponseWriter(const ResponseWriter &other); ResponseWriter(const ResponseWriter &other);
Async::Promise<ssize_t> sendImpl(Code code, const char *data, Async::Promise<ssize_t> sendImpl(Code code, const char *data,
...@@ -449,12 +451,14 @@ serveFile(ResponseWriter &writer, const std::string &fileName, ...@@ -449,12 +451,14 @@ serveFile(ResponseWriter &writer, const std::string &fileName,
namespace Private { namespace Private {
enum class State { Again, Next, Done }; enum class State { Again, Next, Done };
using StepId = uint64_t;
struct Step { struct Step {
explicit Step(Message *request); explicit Step(Message *request);
virtual ~Step() = default; virtual ~Step() = default;
virtual StepId id() const = 0;
virtual State apply(StreamCursor &cursor) = 0; virtual State apply(StreamCursor &cursor) = 0;
static void raise(const char *msg, Code code = Code::Bad_Request); static void raise(const char *msg, Code code = Code::Bad_Request);
...@@ -465,30 +469,42 @@ protected: ...@@ -465,30 +469,42 @@ protected:
class RequestLineStep : public Step { class RequestLineStep : public Step {
public: public:
static constexpr StepId Id = Meta::Hash::fnv1a("RequestLine");
explicit RequestLineStep(Request *request) : Step(request) {} explicit RequestLineStep(Request *request) : Step(request) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override; State apply(StreamCursor &cursor) override;
}; };
class ResponseLineStep : public Step { class ResponseLineStep : public Step {
public: public:
static constexpr StepId Id = Meta::Hash::fnv1a("ResponseLine");
explicit ResponseLineStep(Response *response) : Step(response) {} explicit ResponseLineStep(Response *response) : Step(response) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override; State apply(StreamCursor &cursor) override;
}; };
class HeadersStep : public Step { class HeadersStep : public Step {
public: public:
static constexpr StepId Id = Meta::Hash::fnv1a("Headers");
explicit HeadersStep(Message *request) : Step(request) {} explicit HeadersStep(Message *request) : Step(request) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override; State apply(StreamCursor &cursor) override;
}; };
class BodyStep : public Step { class BodyStep : public Step {
public: public:
static constexpr auto Id = Meta::Hash::fnv1a("Headers");
explicit BodyStep(Message *message_) explicit BodyStep(Message *message_)
: Step(message_), chunk(message_), bytesRead(0) {} : Step(message_), chunk(message_), bytesRead(0) {}
StepId id() const override { return Id; }
State apply(StreamCursor &cursor) override; State apply(StreamCursor &cursor) override;
private: private:
...@@ -524,6 +540,8 @@ private: ...@@ -524,6 +540,8 @@ private:
class ParserBase { class ParserBase {
public: public:
static constexpr size_t StepsCount = 3;
explicit ParserBase(size_t maxDataSize); explicit ParserBase(size_t maxDataSize);
ParserBase(const ParserBase &) = delete; ParserBase(const ParserBase &) = delete;
...@@ -537,12 +555,18 @@ public: ...@@ -537,12 +555,18 @@ public:
virtual void reset(); virtual void reset();
State parse(); State parse();
protected: Step* step();
static constexpr size_t StepsCount = 3; std::chrono::steady_clock::time_point time() const
{
return time_;
}
protected:
std::array<std::unique_ptr<Step>, StepsCount> allSteps; std::array<std::unique_ptr<Step>, StepsCount> allSteps;
size_t currentStep = 0; size_t currentStep = 0;
std::chrono::steady_clock::time_point time_;
private: private:
ArrayStreamBuf<char> buffer; ArrayStreamBuf<char> buffer;
StreamCursor cursor; StreamCursor cursor;
...@@ -551,7 +575,6 @@ private: ...@@ -551,7 +575,6 @@ private:
template <typename Message> class ParserImpl; template <typename Message> class ParserImpl;
template <> class ParserImpl<Http::Request> : public ParserBase { template <> class ParserImpl<Http::Request> : public ParserBase {
public: public:
explicit ParserImpl(size_t maxDataSize); explicit ParserImpl(size_t maxDataSize);
...@@ -575,6 +598,8 @@ using ResponseParser = Private::ParserImpl<Http::Response>; ...@@ -575,6 +598,8 @@ using ResponseParser = Private::ParserImpl<Http::Response>;
class Handler : public Tcp::Handler { class Handler : public Tcp::Handler {
public: public:
static constexpr const char* ParserData = "__Parser";
virtual void onRequest(const Request &request, ResponseWriter response) = 0; virtual void onRequest(const Request &request, ResponseWriter response) = 0;
virtual void onTimeout(const Request &request, ResponseWriter response); virtual void onTimeout(const Request &request, ResponseWriter response);
...@@ -584,16 +609,42 @@ public: ...@@ -584,16 +609,42 @@ public:
void setMaxResponseSize(size_t value); void setMaxResponseSize(size_t value);
size_t getMaxResponseSize() const; size_t getMaxResponseSize() const;
template<typename Duration>
void setHeaderTimeout(Duration timeout)
{
headerTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
}
template<typename Duration>
void setBodyTimeout(Duration timeout)
{
bodyTimeout_ = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
}
std::chrono::milliseconds getHeaderTimeout() const
{
return headerTimeout_;
}
std::chrono::milliseconds getBodyTimeout() const
{
return bodyTimeout_;
}
static std::shared_ptr<RequestParser> getParser(const std::shared_ptr<Tcp::Peer> &peer);
virtual ~Handler() override {} virtual ~Handler() override {}
private: private:
void onConnection(const std::shared_ptr<Tcp::Peer> &peer) override; void onConnection(const std::shared_ptr<Tcp::Peer> &peer) override;
void onInput(const char *buffer, size_t len, void onInput(const char *buffer, size_t len,
const std::shared_ptr<Tcp::Peer> &peer) override; const std::shared_ptr<Tcp::Peer> &peer) override;
private: private:
size_t maxRequestSize_ = Const::DefaultMaxRequestSize; size_t maxRequestSize_ = Const::DefaultMaxRequestSize;
size_t maxResponseSize_ = Const::DefaultMaxResponseSize; size_t maxResponseSize_ = Const::DefaultMaxResponseSize;
std::chrono::milliseconds headerTimeout_ = Const::DefaultHeaderTimeout;
std::chrono::milliseconds bodyTimeout_ = Const::DefaultBodyTimeout;
}; };
template <typename H, typename... Args> template <typename H, typename... Args>
......
...@@ -45,7 +45,9 @@ public: ...@@ -45,7 +45,9 @@ public:
TimePoint tick; TimePoint tick;
}; };
Listener() = default; using TransportFactory = std::function<std::shared_ptr<Transport> ()>;
Listener();
~Listener(); ~Listener();
explicit Listener(const Address &address); explicit Listener(const Address &address);
...@@ -54,6 +56,8 @@ public: ...@@ -54,6 +56,8 @@ public:
const std::string &workersName = "", const std::string &workersName = "",
int backlog = Const::MaxBacklog, int backlog = Const::MaxBacklog,
PISTACHE_STRING_LOGGER_T logger = PISTACHE_NULL_STRING_LOGGER); PISTACHE_STRING_LOGGER_T logger = PISTACHE_NULL_STRING_LOGGER);
void setTransportFactory(TransportFactory factory);
void setHandler(const std::shared_ptr<Handler> &handler); void setHandler(const std::shared_ptr<Handler> &handler);
void bind(); void bind();
...@@ -96,6 +100,10 @@ private: ...@@ -96,6 +100,10 @@ private:
Aio::Reactor reactor_; Aio::Reactor reactor_;
Aio::Reactor::Key transportKey; Aio::Reactor::Key transportKey;
TransportFactory transportFactory_;
TransportFactory defaultTransportFactory() const;
void handleNewConnection(); void handleNewConnection();
int acceptConnection(struct sockaddr_in &peer_addr) const; int acceptConnection(struct sockaddr_in &peer_addr) const;
void dispatchPeer(const std::shared_ptr<Peer> &peer); void dispatchPeer(const std::shared_ptr<Peer> &peer);
......
#pragma once
#include <cstdint>
namespace Pistache
{
namespace Meta
{
namespace Hash
{
static constexpr uint64_t val64 = 0xcbf29ce484222325;
static constexpr uint64_t prime64 = 0x100000001b3;
inline constexpr uint64_t fnv1a(const char* const str, const uint64_t value = val64) noexcept {
return (str[0] == '\0') ? value : fnv1a(&str[1], (value ^ uint64_t(str[0])) * prime64);
}
}
}
}
\ No newline at end of file
...@@ -44,6 +44,10 @@ public: ...@@ -44,6 +44,10 @@ public:
void *ssl() const; void *ssl() const;
void putData(std::string name, std::shared_ptr<void> data);
std::shared_ptr<void> getData(std::string name) const;
std::shared_ptr<void> tryGetData(std::string name) const;
Async::Promise<ssize_t> send(const RawBuffer &buffer, int flags = 0); Async::Promise<ssize_t> send(const RawBuffer &buffer, int flags = 0);
size_t getID() const; size_t getID() const;
...@@ -51,11 +55,6 @@ protected: ...@@ -51,11 +55,6 @@ protected:
Peer(Fd fd, const Address &addr, void *ssl); Peer(Fd fd, const Address &addr, void *ssl);
private: private:
void setParser(std::shared_ptr<Http::RequestParser> parser);
std::shared_ptr<Http::RequestParser> getParser() const;
Http::Request &request();
void associateTransport(Transport *transport); void associateTransport(Transport *transport);
Transport *transport() const; Transport *transport() const;
...@@ -64,7 +63,7 @@ private: ...@@ -64,7 +63,7 @@ private:
Address addr; Address addr;
std::string hostname_; std::string hostname_;
std::shared_ptr<Http::RequestParser> parser_; std::unordered_map<std::string, std::shared_ptr<void>> data_;
void *ssl_ = nullptr; void *ssl_ = nullptr;
const size_t id_; const size_t id_;
......
...@@ -13,6 +13,7 @@ namespace Pistache { ...@@ -13,6 +13,7 @@ namespace Pistache {
/* In a sense, a Prototype is just a class that provides a clone() method */ /* In a sense, a Prototype is just a class that provides a clone() method */
template <typename Class> struct Prototype { template <typename Class> struct Prototype {
public:
virtual ~Prototype() {} virtual ~Prototype() {}
virtual std::shared_ptr<Class> clone() const = 0; virtual std::shared_ptr<Class> clone() const = 0;
}; };
...@@ -20,9 +21,7 @@ template <typename Class> struct Prototype { ...@@ -20,9 +21,7 @@ template <typename Class> struct Prototype {
} // namespace Pistache } // namespace Pistache
#define PROTOTYPE_OF(Base, Class) \ #define PROTOTYPE_OF(Base, Class) \
private: \ public: \
std::shared_ptr<Base> clone() const override { \ std::shared_ptr<Base> clone() const override { \
return std::make_shared<Class>(*this); \ return std::make_shared<Class>(*this); \
} \ } \
\
public:
...@@ -32,7 +32,7 @@ enum class Options : uint64_t { ...@@ -32,7 +32,7 @@ enum class Options : uint64_t {
DECLARE_FLAGS_OPERATORS(Options) DECLARE_FLAGS_OPERATORS(Options)
class Handler : private Prototype<Handler> { class Handler : public Prototype<Handler> {
public: public:
friend class Transport; friend class Transport;
......
...@@ -27,15 +27,16 @@ class Handler; ...@@ -27,15 +27,16 @@ class Handler;
class Transport : public Aio::Handler { class Transport : public Aio::Handler {
public: public:
explicit Transport(const std::shared_ptr<Tcp::Handler> &handler); explicit Transport(const std::shared_ptr<Tcp::Handler> &handler);
Transport(const Transport &) = delete; Transport(const Transport &) = delete;
Transport &operator=(const Transport &) = delete; Transport &operator=(const Transport &) = delete;
void init(const std::shared_ptr<Tcp::Handler> &handler); void init(const std::shared_ptr<Tcp::Handler> &handler);
void registerPoller(Polling::Epoll &poller) override; virtual void registerPoller(Polling::Epoll &poller) override;
void handleNewPeer(const std::shared_ptr<Peer> &peer); void handleNewPeer(const std::shared_ptr<Peer> &peer);
void onReady(const Aio::FdSet &fds) override; virtual void onReady(const Aio::FdSet &fds) override;
template <typename Buf> template <typename Buf>
Async::Promise<ssize_t> asyncWrite(Fd fd, const Buf &buffer, int flags = 0) { Async::Promise<ssize_t> asyncWrite(Fd fd, const Buf &buffer, int flags = 0) {
...@@ -167,13 +168,17 @@ private: ...@@ -167,13 +168,17 @@ private:
std::unordered_map<Fd, TimerEntry> timers; std::unordered_map<Fd, TimerEntry> timers;
PollableQueue<PeerEntry> peersQueue; PollableQueue<PeerEntry> peersQueue;
std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
Async::Deferred<rusage> loadRequest_; Async::Deferred<rusage> loadRequest_;
NotifyFd notifier; NotifyFd notifier;
std::shared_ptr<Tcp::Handler> handler_; std::shared_ptr<Tcp::Handler> handler_;
protected:
void removePeer(const std::shared_ptr<Peer>& peer);
std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
private:
bool isPeerFd(Fd fd) const; bool isPeerFd(Fd fd) const;
bool isTimerFd(Fd fd) const; bool isTimerFd(Fd fd) const;
bool isPeerFd(Polling::Tag tag) const; bool isPeerFd(Polling::Tag tag) const;
......
...@@ -245,7 +245,7 @@ State ResponseLineStep::apply(StreamCursor &cursor) { ...@@ -245,7 +245,7 @@ State ResponseLineStep::apply(StreamCursor &cursor) {
char *end; char *end;
auto code = strtol(codeToken.rawText(), &end, 10); auto code = strtol(codeToken.rawText(), &end, 10);
if (*end != ' ') if (*end != ' ')
raise("Failed to parsed return code"); raise("Failed to parse return code");
response->code_ = static_cast<Http::Code>(code); response->code_ = static_cast<Http::Code>(code);
if (!cursor.advance(1)) if (!cursor.advance(1))
...@@ -482,6 +482,7 @@ State ParserBase::parse() { ...@@ -482,6 +482,7 @@ State ParserBase::parse() {
} }
bool ParserBase::feed(const char *data, size_t len) { bool ParserBase::feed(const char *data, size_t len) {
time_ = std::chrono::steady_clock::now();
return buffer.feed(data, len); return buffer.feed(data, len);
} }
...@@ -490,6 +491,12 @@ void ParserBase::reset() { ...@@ -490,6 +491,12 @@ void ParserBase::reset() {
cursor.reset(); cursor.reset();
currentStep = 0; currentStep = 0;
time_ = std::chrono::steady_clock::time_point(std::chrono::steady_clock::duration(0));
}
Step* ParserBase::step()
{
return allSteps[currentStep].get();
} }
} // namespace Private } // namespace Private
...@@ -656,20 +663,29 @@ void ResponseStream::ends() { ...@@ -656,20 +663,29 @@ void ResponseStream::ends() {
} }
ResponseWriter::ResponseWriter(ResponseWriter &&other) ResponseWriter::ResponseWriter(ResponseWriter &&other)
: response_(std::move(other.response_)), peer_(other.peer_), : response_(std::move(other.response_))
buf_(std::move(other.buf_)), transport_(other.transport_), , peer_(other.peer_)
timeout_(std::move(other.timeout_)) {} , buf_(std::move(other.buf_))
, transport_(other.transport_)
, timeout_(std::move(other.timeout_))
{}
ResponseWriter::ResponseWriter(Http::Version version, Tcp::Transport *transport, ResponseWriter::ResponseWriter(Http::Version version, Tcp::Transport *transport,
Handler *handler, std::weak_ptr<Tcp::Peer> peer) Handler *handler, std::weak_ptr<Tcp::Peer> peer)
: response_(version), peer_(peer), : response_(version)
buf_(DefaultStreamSize, handler->getMaxResponseSize()), , peer_(peer)
transport_(transport), timeout_(transport, handler, peer) {} , buf_(DefaultStreamSize, handler->getMaxResponseSize())
, transport_(transport)
, timeout_(transport, version, handler, peer)
{}
ResponseWriter::ResponseWriter(const ResponseWriter &other) ResponseWriter::ResponseWriter(const ResponseWriter &other)
: response_(other.response_), peer_(other.peer_), : response_(other.response_)
buf_(DefaultStreamSize, other.buf_.maxSize()), , peer_(other.peer_)
transport_(other.transport_), timeout_(other.timeout_) {} , buf_(DefaultStreamSize, other.buf_.maxSize())
, transport_(other.transport_)
, timeout_(other.timeout_)
{}
void ResponseWriter::setMime(const Mime::MediaType &mime) { void ResponseWriter::setMime(const Mime::MediaType &mime) {
auto ct = response_.headers().tryGet<Header::ContentType>(); auto ct = response_.headers().tryGet<Header::ContentType>();
...@@ -916,8 +932,8 @@ Private::ParserImpl<Http::Response>::ParserImpl(size_t maxDataSize) ...@@ -916,8 +932,8 @@ Private::ParserImpl<Http::Response>::ParserImpl(size_t maxDataSize)
void Handler::onInput(const char *buffer, size_t len, void Handler::onInput(const char *buffer, size_t len,
const std::shared_ptr<Tcp::Peer> &peer) { const std::shared_ptr<Tcp::Peer> &peer) {
auto parser = peer->getParser(); auto parser = getParser(peer);
auto &request = peer->request(); auto &request = parser->request;
try { try {
if (!parser->feed(buffer, len)) { if (!parser->feed(buffer, len)) {
parser->reset(); parser->reset();
...@@ -962,11 +978,14 @@ void Handler::onInput(const char *buffer, size_t len, ...@@ -962,11 +978,14 @@ void Handler::onInput(const char *buffer, size_t len,
} }
void Handler::onConnection(const std::shared_ptr<Tcp::Peer> &peer) { void Handler::onConnection(const std::shared_ptr<Tcp::Peer> &peer) {
peer->setParser(std::make_shared<RequestParser>(maxRequestSize_)); peer->putData(ParserData, std::make_shared<RequestParser>(maxRequestSize_));
} }
void Handler::onTimeout(const Request & /*request*/, void Handler::onTimeout(const Request& /*request*/,
ResponseWriter /*response*/) {} ResponseWriter response)
{
response.send(Code::Request_Timeout);
}
Timeout::~Timeout() { disarm(); } Timeout::~Timeout() { disarm(); }
...@@ -978,10 +997,15 @@ void Timeout::disarm() { ...@@ -978,10 +997,15 @@ void Timeout::disarm() {
bool Timeout::isArmed() const { return armed; } bool Timeout::isArmed() const { return armed; }
Timeout::Timeout(Tcp::Transport *transport_, Handler *handler_, Timeout::Timeout(Tcp::Transport *transport_, Http::Version version, Handler *handler_,
std::weak_ptr<Tcp::Peer> peer_) std::weak_ptr<Tcp::Peer> peer_)
: handler(handler_), transport(transport_), armed(false), timerFd(-1), : handler(handler_)
peer(peer_) {} , transport(transport_)
, version(version)
, armed(false)
, timerFd(-1)
, peer(peer_)
{}
void Timeout::onTimeout(uint64_t numWakeup) { void Timeout::onTimeout(uint64_t numWakeup) {
UNUSED(numWakeup) UNUSED(numWakeup)
...@@ -989,9 +1013,10 @@ void Timeout::onTimeout(uint64_t numWakeup) { ...@@ -989,9 +1013,10 @@ void Timeout::onTimeout(uint64_t numWakeup) {
if (!sp) if (!sp)
return; return;
ResponseWriter response(sp->request().version(), transport, handler, peer); ResponseWriter response(version, transport, handler, peer);
auto parser = Handler::getParser(sp);
handler->onTimeout(sp->request(), std::move(response)); const auto& request = parser->request;
handler->onTimeout(request, std::move(response));
} }
void Handler::setMaxRequestSize(size_t value) { maxRequestSize_ = value; } void Handler::setMaxRequestSize(size_t value) { maxRequestSize_ = value; }
...@@ -1002,5 +1027,10 @@ void Handler::setMaxResponseSize(size_t value) { maxResponseSize_ = value; } ...@@ -1002,5 +1027,10 @@ void Handler::setMaxResponseSize(size_t value) { maxResponseSize_ = value; }
size_t Handler::getMaxResponseSize() const { return maxResponseSize_; } size_t Handler::getMaxResponseSize() const { return maxResponseSize_; }
std::shared_ptr<RequestParser>
Handler::getParser(const std::shared_ptr<Tcp::Peer> &peer) {
return std::static_pointer_cast<RequestParser>(peer->getData(ParserData));
}
} // namespace Http } // namespace Http
} // namespace Pistache } // namespace Pistache
...@@ -78,18 +78,30 @@ int Peer::fd() const { ...@@ -78,18 +78,30 @@ int Peer::fd() const {
return fd_; return fd_;
} }
void Peer::setParser(std::shared_ptr<Http::RequestParser> parser) { void Peer::putData(std::string name, std::shared_ptr<void> data) {
parser_ = parser; auto it = data_.find(name);
} if (it != std::end(data_)) {
throw std::runtime_error("The data already exists");
}
std::shared_ptr<Http::RequestParser> Peer::getParser() const { return parser_; } data_.insert(std::make_pair(std::move(name), std::move(data)));
}
Http::Request &Peer::request() { std::shared_ptr<void> Peer::getData(std::string name) const {
if (!parser_) { auto data = tryGetData(std::move(name));
throw std::runtime_error("The peer has no associated parser"); if (data == nullptr) {
throw std::runtime_error("The data does not exist");
} }
return parser_->request; return data;
}
std::shared_ptr<void> Peer::tryGetData(std::string(name)) const {
auto it = data_.find(name);
if (it == std::end(data_))
return nullptr;
return it->second;
} }
Async::Promise<ssize_t> Peer::send(const RawBuffer &buffer, int flags) { Async::Promise<ssize_t> Peer::send(const RawBuffer &buffer, int flags) {
......
...@@ -66,7 +66,6 @@ public: ...@@ -66,7 +66,6 @@ public:
Reactor::Key addHandler(const std::shared_ptr<Handler> &handler, Reactor::Key addHandler(const std::shared_ptr<Handler> &handler,
bool setKey = true) override { bool setKey = true) override {
handler->registerPoller(poller); handler->registerPoller(poller);
handler->reactor_ = reactor_; handler->reactor_ = reactor_;
......
...@@ -160,6 +160,11 @@ void Transport::handleIncoming(const std::shared_ptr<Peer> &peer) { ...@@ -160,6 +160,11 @@ void Transport::handleIncoming(const std::shared_ptr<Peer> &peer) {
void Transport::handlePeerDisconnection(const std::shared_ptr<Peer> &peer) { void Transport::handlePeerDisconnection(const std::shared_ptr<Peer> &peer) {
handler_->onDisconnection(peer); handler_->onDisconnection(peer);
removePeer(peer);
}
void Transport::removePeer(const std::shared_ptr<Peer>& peer)
{
int fd = peer->fd(); int fd = peer->fd();
auto it = peers.find(fd); auto it = peers.find(fd);
if (it == std::end(peers)) if (it == std::end(peers))
......
...@@ -9,14 +9,149 @@ ...@@ -9,14 +9,149 @@
#include <pistache/peer.h> #include <pistache/peer.h>
#include <pistache/tcp.h> #include <pistache/tcp.h>
#include <array>
#include <chrono>
namespace Pistache { namespace Pistache {
namespace Http { namespace Http {
class TransportImpl : public Tcp::Transport
{
public:
using Base = Tcp::Transport;
explicit TransportImpl(const std::shared_ptr<Tcp::Handler>& handler);
void registerPoller(Polling::Epoll& poller) override;
void onReady(const Aio::FdSet& fds) override;
void setHeaderTimeout(std::chrono::milliseconds timeout);
void setBodyTimeout(std::chrono::milliseconds timeout);
std::shared_ptr<Aio::Handler> clone() const override;
private:
std::shared_ptr<Tcp::Handler> handler_;
std::chrono::milliseconds headerTimeout_;
std::chrono::milliseconds bodyTimeout_;
int timerFd;
void checkIdlePeers();
};
TransportImpl::TransportImpl(const std::shared_ptr<Tcp::Handler>& handler)
: Tcp::Transport(handler)
, handler_(handler)
{ }
void TransportImpl::registerPoller(Polling::Epoll& poller)
{
Base::registerPoller(poller);
timerFd = TRY_RET(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK));
static constexpr auto TimerInterval = std::chrono::milliseconds(500);
static constexpr auto TimerIntervalNs = std::chrono::duration_cast<std::chrono::nanoseconds>(TimerInterval);
static_assert(
TimerInterval < std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::seconds(1)),
"Timer frequency should be less than 1 second"
);
itimerspec spec;
spec.it_value.tv_sec = 0;
spec.it_value.tv_nsec = TimerIntervalNs.count();
spec.it_interval.tv_sec = 0;
spec.it_interval.tv_nsec = TimerIntervalNs.count();
TRY(timerfd_settime(timerFd, 0, &spec, 0));
Polling::Tag tag(timerFd);
poller.addFd(timerFd, Flags<Polling::NotifyOn>(Polling::NotifyOn::Read), Polling::Tag(timerFd));
}
void TransportImpl::onReady(const Aio::FdSet& fds)
{
bool handled = false;
for (const auto& entry: fds)
{
if (entry.getTag() == Polling::Tag(timerFd))
{
uint64_t wakeups;
::read(timerFd, &wakeups, sizeof wakeups);
checkIdlePeers();
handled = true;
}
}
if (!handled)
Base::onReady(fds);
}
void TransportImpl::setHeaderTimeout(std::chrono::milliseconds timeout)
{
headerTimeout_ = timeout;
}
void TransportImpl::setBodyTimeout(std::chrono::milliseconds timeout)
{
bodyTimeout_ = timeout;
}
void TransportImpl::checkIdlePeers()
{
std::vector<std::shared_ptr<Tcp::Peer>> idlePeers;
for (const auto& peerPair: peers)
{
const auto& peer = peerPair.second;
auto parser = Http::Handler::getParser(peer);
auto time = parser->time();
auto now = std::chrono::steady_clock::now();
auto elapsed = now - time;
auto* step = parser->step();
if (step->id() == Private::RequestLineStep::Id)
{
if (elapsed > headerTimeout_ || elapsed > bodyTimeout_)
idlePeers.push_back(peer);
}
else if (step->id() == Private::HeadersStep::Id)
{
if (elapsed > bodyTimeout_)
idlePeers.push_back(peer);
}
}
for (const auto& idlePeer: idlePeers)
{
ResponseWriter response(Http::Version::Http11, this, static_cast<Http::Handler *>(handler_.get()), idlePeer);
response.send(Http::Code::Request_Timeout).then([=](ssize_t) {
removePeer(idlePeer);
}, [=](std::exception_ptr) {
removePeer(idlePeer);
});
}
}
std::shared_ptr<Aio::Handler> TransportImpl::clone() const
{
auto transport = std::make_shared<TransportImpl>(handler_->clone());
transport->setHeaderTimeout(headerTimeout_);
transport->setBodyTimeout(bodyTimeout_);
return transport;
}
Endpoint::Options::Options() Endpoint::Options::Options()
: threads_(1), flags_(), backlog_(Const::MaxBacklog), : threads_(1), flags_(), backlog_(Const::MaxBacklog)
maxRequestSize_(Const::DefaultMaxRequestSize), , maxRequestSize_(Const::DefaultMaxRequestSize)
maxResponseSize_(Const::DefaultMaxResponseSize), , maxResponseSize_(Const::DefaultMaxResponseSize)
logger_(PISTACHE_NULL_STRING_LOGGER) {} , headerTimeout_(Const::DefaultHeaderTimeout)
, bodyTimeout_(Const::DefaultBodyTimeout)
, logger_(PISTACHE_NULL_STRING_LOGGER)
{}
Endpoint::Options &Endpoint::Options::threads(int val) { Endpoint::Options &Endpoint::Options::threads(int val) {
threads_ = val; threads_ = val;
...@@ -63,15 +198,25 @@ Endpoint::Endpoint(const Address &addr) : listener(addr) {} ...@@ -63,15 +198,25 @@ Endpoint::Endpoint(const Address &addr) : listener(addr) {}
void Endpoint::init(const Endpoint::Options &options) { void Endpoint::init(const Endpoint::Options &options) {
listener.init(options.threads_, options.flags_, options.threadsName_); listener.init(options.threads_, options.flags_, options.threadsName_);
maxRequestSize_ = options.maxRequestSize_; listener.setTransportFactory([&] {
maxResponseSize_ = options.maxResponseSize_; if (!handler_)
throw std::runtime_error("Must call setHandler()");
auto transport = std::make_shared<TransportImpl>(handler_);
transport->setHeaderTimeout(options.headerTimeout_);
transport->setBodyTimeout(options.bodyTimeout_);
return transport;
});
options_ = options;
logger_ = options.logger_; logger_ = options.logger_;
} }
void Endpoint::setHandler(const std::shared_ptr<Handler> &handler) { void Endpoint::setHandler(const std::shared_ptr<Handler> &handler) {
handler_ = handler; handler_ = handler;
handler_->setMaxRequestSize(maxRequestSize_); handler_->setMaxRequestSize(options_.maxRequestSize_);
handler_->setMaxResponseSize(maxResponseSize_); handler_->setMaxResponseSize(options_.maxResponseSize_);
} }
void Endpoint::bind() { listener.bind(); } void Endpoint::bind() { listener.bind(); }
......
...@@ -148,7 +148,14 @@ void setSocketOptions(Fd fd, Flags<Options> options) { ...@@ -148,7 +148,14 @@ void setSocketOptions(Fd fd, Flags<Options> options) {
} }
} }
Listener::Listener(const Address &address) : addr_(address) {} Listener::Listener()
: transportFactory_(defaultTransportFactory())
{}
Listener::Listener(const Address &address)
: addr_(address)
, transportFactory_(defaultTransportFactory())
{}
Listener::~Listener() { Listener::~Listener() {
if (isBound()) if (isBound())
...@@ -177,6 +184,11 @@ void Listener::init(size_t workers, Flags<Options> options, ...@@ -177,6 +184,11 @@ void Listener::init(size_t workers, Flags<Options> options,
logger_ = logger; logger_ = logger;
} }
void Listener::setTransportFactory(TransportFactory factory)
{
transportFactory_ = std::move(factory);
}
void Listener::setHandler(const std::shared_ptr<Handler> &handler) { void Listener::setHandler(const std::shared_ptr<Handler> &handler) {
handler_ = handler; handler_ = handler;
} }
...@@ -200,8 +212,6 @@ void Listener::pinWorker(size_t worker, const CpuSet &set) { ...@@ -200,8 +212,6 @@ void Listener::pinWorker(size_t worker, const CpuSet &set) {
void Listener::bind() { bind(addr_); } void Listener::bind() { bind(addr_); }
void Listener::bind(const Address &address) { void Listener::bind(const Address &address) {
if (!handler_)
throw std::runtime_error("Call setHandler before calling bind()");
addr_ = address; addr_ = address;
struct addrinfo hints; struct addrinfo hints;
...@@ -252,7 +262,7 @@ void Listener::bind(const Address &address) { ...@@ -252,7 +262,7 @@ void Listener::bind(const Address &address) {
Polling::Tag(fd)); Polling::Tag(fd));
listen_fd = fd; listen_fd = fd;
auto transport = std::make_shared<Transport>(handler_); auto transport = transportFactory_();
reactor_.init(Aio::AsyncContext(workers_, workersName_)); reactor_.init(Aio::AsyncContext(workers_, workersName_));
transportKey = reactor_.addHandler(transport); transportKey = reactor_.addHandler(transport);
...@@ -451,6 +461,16 @@ void Listener::dispatchPeer(const std::shared_ptr<Peer> &peer) { ...@@ -451,6 +461,16 @@ void Listener::dispatchPeer(const std::shared_ptr<Peer> &peer) {
transport->handleNewPeer(peer); transport->handleNewPeer(peer);
} }
Listener::TransportFactory Listener::defaultTransportFactory() const
{
return [&] {
if (!handler_)
throw std::runtime_error("setHandler() has not been called");
return std::make_shared<Transport>(handler_);
};
}
#ifdef PISTACHE_USE_SSL #ifdef PISTACHE_USE_SSL
void Listener::setupSSLAuth(const std::string &ca_file, void Listener::setupSSLAuth(const std::string &ca_file,
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include <string> #include <string>
#include <thread> #include <thread>
#include "tcp_client.h"
using namespace Pistache; using namespace Pistache;
#define THREAD_INFO \ #define THREAD_INFO \
...@@ -96,6 +98,22 @@ struct AddressEchoHandler : public Http::Handler { ...@@ -96,6 +98,22 @@ struct AddressEchoHandler : public Http::Handler {
} }
}; };
struct PingHandler : public Http::Handler {
HTTP_PROTOTYPE(PingHandler)
PingHandler() = default;
void onRequest(const Http::Request& request,
Http::ResponseWriter writer) override {
if (request.resource() == "/ping") {
writer.send(Http::Code::Ok, "PONG");
}
else {
writer.send(Http::Code::Not_Found);
}
}
};
int clientLogicFunc(int response_size, const std::string &server_page, int clientLogicFunc(int response_size, const std::string &server_page,
int timeout_seconds, int wait_seconds) { int timeout_seconds, int wait_seconds) {
Http::Client client; Http::Client client;
...@@ -429,6 +447,79 @@ TEST(http_server_test, response_size_captured) { ...@@ -429,6 +447,79 @@ TEST(http_server_test, response_size_captured) {
ASSERT_EQ(rcode, Http::Code::Ok); ASSERT_EQ(rcode, Http::Code::Ok);
} }
TEST(http_server_test, client_request_header_timeout_raises_http_408) {
Pistache::Address address("localhost", Pistache::Port(0));
auto timeout = std::chrono::seconds(2);
Http::Endpoint server(address);
auto flags = Tcp::Options::ReuseAddr;
auto opts = Http::Endpoint::options()
.flags(flags)
.headerTimeout(timeout);
server.init(opts);
server.setHandler(Http::make_handler<PingHandler>());
server.serveThreaded();
auto port = server.getPort();
auto addr = "localhost:" + port.toString();
std::cout << "Server address: " << addr << "\n";
char recvBuf[1024];
std::memset(recvBuf, 0, sizeof(recvBuf));
size_t bytes;
TcpClient client;
ASSERT_TRUE(client.connect(Pistache::Address("localhost", port))) << client.lastError();
ASSERT_TRUE(client.receive(recvBuf, sizeof(recvBuf), &bytes, std::chrono::seconds(5))) << client.lastError();
server.shutdown();
}
TEST(http_server_test, client_request_body_timeout_raises_http_408) {
Pistache::Address address("localhost", Pistache::Port(0));
auto headerTimeout = std::chrono::seconds(1);
auto bodyTimeout = std::chrono::seconds(1);
Http::Endpoint server(address);
auto flags = Tcp::Options::ReuseAddr;
auto opts = Http::Endpoint::options()
.flags(flags)
.headerTimeout(headerTimeout)
.bodyTimeout(bodyTimeout);
server.init(opts);
server.setHandler(Http::make_handler<PingHandler>());
server.serveThreaded();
auto port = server.getPort();
auto addr = "localhost:" + port.toString();
std::cout << "Server address: " << addr << "\n";
std::string reqStr = "GET /ping HTTP/1.1\r\n";
std::string headerStr = "Host: localhost\r\nUser-Agent: test\r\n";
char recvBuf[1024];
std::memset(recvBuf, 0, sizeof(recvBuf));
size_t bytes;
TcpClient client;
ASSERT_TRUE(client.connect(Pistache::Address("localhost", port))) << client.lastError();
ASSERT_TRUE(client.send(reqStr)) << client.lastError();
std::this_thread::sleep_for(headerTimeout / 2);
ASSERT_TRUE(client.send(headerStr)) << client.lastError();
static constexpr const char* ExpectedResponseLine = "HTTP/1.1 408 Request Timeout";
ASSERT_TRUE(client.receive(recvBuf, sizeof(recvBuf), &bytes, std::chrono::seconds(5))) << client.lastError();
ASSERT_TRUE(!strncmp(recvBuf, ExpectedResponseLine, strlen(ExpectedResponseLine)));
server.shutdown();
}
namespace { namespace {
class WaitHelper { class WaitHelper {
......
#pragma once
#include <pistache/net.h>
#include <pistache/os.h>
#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <poll.h>
namespace Pistache
{
#define CLIENT_TRY(...) \
do { \
auto ret = __VA_ARGS__; \
if (ret < 0) { \
lastError_ = strerror(errno); \
return false; \
} \
} while (0) \
class TcpClient
{
public:
bool connect(const Pistache::Address& address)
{
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = address.family();
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = 0;
hints.ai_protocol = 0;
auto host = address.host();
auto port = address.port().toString();
AddrInfo addrInfo;
CLIENT_TRY(addrInfo.invoke(host.c_str(), port.c_str(), &hints));
const auto* addrs = addrInfo.get_info_ptr();
int sfd = -1;
auto* addr = addrs;
for (; addr; addr = addr->ai_next) {
sfd = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (sfd < 0)
continue;
break;
}
CLIENT_TRY(sfd);
CLIENT_TRY(::connect(sfd, addr->ai_addr, addr->ai_addrlen));
make_non_blocking(sfd);
fd_ = sfd;
return true;
}
bool send(const std::string& data)
{
return send(data.c_str(), data.size());
}
bool send(const void* data, size_t size)
{
CLIENT_TRY(::send(fd_, data, size, 0));
return true;
}
template<typename Duration>
bool receive(void* buffer, size_t size, size_t *bytes, Duration timeout)
{
struct pollfd fds[1];
fds[0].fd = fd_;
fds[0].events = POLLIN;
auto timeoutMs = std::chrono::duration_cast<std::chrono::milliseconds>(timeout);
auto ret = ::poll(fds, 1, static_cast<int>(timeoutMs.count()));
if (ret < 0)
{
lastError_ = strerror(errno);
return false;
}
if (ret == 0)
{
lastError_ = "Poll timeout";
return false;
}
if (fds[0].revents & POLLERR)
{
lastError_ = "An error has occured on the stream";
return false;
}
auto res = ::recv(fd_, buffer, size, 0);
if (res < 0)
{
lastError_ = strerror(errno);
return false;
}
*bytes = static_cast<size_t>(res);
return true;
}
std::string lastError() const
{
return lastError_;
}
private:
int fd_;
std::string lastError_;
};
#undef CLIENT_TRY
} // namespace Pistache
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment