Unverified Commit b605712c authored by Igor [hyperxor]'s avatar Igor [hyperxor] Committed by GitHub

Fix issue #610 and refactoring (#711)

* Fix inssue #610 and refactoring

* Fix ResponseWriter constructors
parent e3ac500d
...@@ -24,13 +24,20 @@ ...@@ -24,13 +24,20 @@
namespace Pistache { namespace Pistache {
namespace Http { namespace Http {
namespace Default {
constexpr int Threads = 1;
constexpr int MaxConnectionsPerHost = 8;
constexpr bool KeepAlive = true;
constexpr size_t MaxResponseSize = std::numeric_limits<uint32_t>::max();
} // namespace Default
class Transport; class Transport;
struct Connection : public std::enable_shared_from_this<Connection> { struct Connection : public std::enable_shared_from_this<Connection> {
using OnDone = std::function<void()>; using OnDone = std::function<void()>;
Connection(); explicit Connection(size_t maxResponseSize);
struct RequestData { struct RequestData {
...@@ -103,9 +110,9 @@ private: ...@@ -103,9 +110,9 @@ private:
class ConnectionPool { class ConnectionPool {
public: public:
ConnectionPool() : connsLock(), conns(), maxConnectionsPerHost() {} ConnectionPool() = default;
void init(size_t maxConnsPerHost); void init(size_t maxConnsPerHost, size_t maxResponseSize);
std::shared_ptr<Connection> pickConnection(const std::string &domain); std::shared_ptr<Connection> pickConnection(const std::string &domain);
static void releaseConnection(const std::shared_ptr<Connection> &connection); static void releaseConnection(const std::shared_ptr<Connection> &connection);
...@@ -125,14 +132,9 @@ private: ...@@ -125,14 +132,9 @@ private:
mutable Lock connsLock; mutable Lock connsLock;
std::unordered_map<std::string, Connections> conns; std::unordered_map<std::string, Connections> conns;
size_t maxConnectionsPerHost; size_t maxConnectionsPerHost;
size_t maxResponseSize;
}; };
namespace Default {
constexpr int Threads = 1;
constexpr int MaxConnectionsPerHost = 8;
constexpr bool KeepAlive = true;
} // namespace Default
class Client; class Client;
class RequestBuilder { class RequestBuilder {
...@@ -175,16 +177,19 @@ public: ...@@ -175,16 +177,19 @@ public:
Options() Options()
: threads_(Default::Threads), : threads_(Default::Threads),
maxConnectionsPerHost_(Default::MaxConnectionsPerHost), maxConnectionsPerHost_(Default::MaxConnectionsPerHost),
keepAlive_(Default::KeepAlive) {} keepAlive_(Default::KeepAlive),
maxResponseSize_(Default::MaxResponseSize) {}
Options &threads(int val); Options &threads(int val);
Options &keepAlive(bool val); Options &keepAlive(bool val);
Options &maxConnectionsPerHost(int val); Options &maxConnectionsPerHost(int val);
Options &maxResponseSize(size_t val);
private: private:
int threads_; int threads_;
int maxConnectionsPerHost_; int maxConnectionsPerHost_;
bool keepAlive_; bool keepAlive_;
size_t maxResponseSize_;
}; };
Client(); Client();
......
...@@ -149,6 +149,8 @@ private: ...@@ -149,6 +149,8 @@ private:
std::shared_ptr<Handler> handler_; std::shared_ptr<Handler> handler_;
Tcp::Listener listener; Tcp::Listener listener;
size_t maxRequestSize_ = Const::DefaultMaxRequestSize;
size_t maxResponseSize_ = Const::DefaultMaxResponseSize;
}; };
template <typename Handler> template <typename Handler>
......
...@@ -49,7 +49,6 @@ template <typename P> struct IsHttpPrototype { ...@@ -49,7 +49,6 @@ template <typename P> struct IsHttpPrototype {
typedef Pistache::Http::details::prototype_tag tag; typedef Pistache::Http::details::prototype_tag tag;
namespace Private { namespace Private {
class ParserBase;
template <typename T> class Parser; template <typename T> class Parser;
class RequestLineStep; class RequestLineStep;
class ResponseLineStep; class ResponseLineStep;
...@@ -288,7 +287,8 @@ public: ...@@ -288,7 +287,8 @@ public:
private: private:
ResponseStream(Message &&other, std::weak_ptr<Tcp::Peer> peer, ResponseStream(Message &&other, std::weak_ptr<Tcp::Peer> peer,
Tcp::Transport *transport, Timeout timeout, size_t streamSize); Tcp::Transport *transport, Timeout timeout, size_t streamSize,
size_t maxResponseSize);
std::shared_ptr<Tcp::Peer> peer() const; std::shared_ptr<Tcp::Peer> peer() const;
...@@ -518,16 +518,17 @@ private: ...@@ -518,16 +518,17 @@ private:
class ParserBase { class ParserBase {
public: public:
ParserBase(); explicit ParserBase(size_t maxDataSize);
ParserBase(const ParserBase &other) = delete; ParserBase(const ParserBase &) = delete;
ParserBase(ParserBase &&other) = default; ParserBase &operator=(const ParserBase &) = delete;
ParserBase(ParserBase &&) = default;
bool feed(const char *data, size_t len); ParserBase &operator=(ParserBase &&) = default;
virtual void reset();
virtual ~ParserBase() = default; virtual ~ParserBase() = default;
bool feed(const char *data, size_t len);
virtual void reset();
State parse(); State parse();
protected: protected:
...@@ -546,47 +547,16 @@ template <typename Message> class Parser; ...@@ -546,47 +547,16 @@ template <typename Message> class Parser;
template <> class Parser<Http::Request> : public ParserBase { template <> class Parser<Http::Request> : public ParserBase {
public: public:
Parser() : ParserBase(), request() { explicit Parser(size_t maxDataSize);
allSteps[0].reset(new RequestLineStep(&request));
allSteps[1].reset(new HeadersStep(&request));
allSteps[2].reset(new BodyStep(&request));
}
Parser(const char *data, size_t len) : ParserBase(), request() {
allSteps[0].reset(new RequestLineStep(&request));
allSteps[1].reset(new HeadersStep(&request));
allSteps[2].reset(new BodyStep(&request));
feed(data, len);
}
void reset() override { void reset() override;
ParserBase::reset();
request.headers_.clear();
request.body_.clear();
request.resource_.clear();
request.query_.clear();
}
Request request; Request request;
}; };
template <> class Parser<Http::Response> : public ParserBase { template <> class Parser<Http::Response> : public ParserBase {
public: public:
Parser() : ParserBase(), response() { explicit Parser(size_t maxDataSize);
allSteps[0].reset(new ResponseLineStep(&response));
allSteps[1].reset(new HeadersStep(&response));
allSteps[2].reset(new BodyStep(&response));
}
Parser(const char *data, size_t len) : ParserBase(), response() {
allSteps[0].reset(new ResponseLineStep(&response));
allSteps[1].reset(new HeadersStep(&response));
allSteps[2].reset(new BodyStep(&response));
feed(data, len);
}
Response response; Response response;
}; };
...@@ -605,11 +575,20 @@ public: ...@@ -605,11 +575,20 @@ public:
virtual void onTimeout(const Request &request, ResponseWriter response); virtual void onTimeout(const Request &request, ResponseWriter response);
void setMaxRequestSize(size_t value);
size_t getMaxRequestSize() const;
void setMaxResponseSize(size_t value);
size_t getMaxResponseSize() const;
virtual ~Handler() override {} virtual ~Handler() override {}
private: private:
Private::Parser<Http::Request> & Private::Parser<Http::Request> &
getParser(const std::shared_ptr<Tcp::Peer> &peer) const; getParser(const std::shared_ptr<Tcp::Peer> &peer) const;
private:
size_t maxRequestSize_ = Const::DefaultMaxRequestSize;
size_t maxResponseSize_ = Const::DefaultMaxResponseSize;
}; };
template <typename H, typename... Args> template <typename H, typename... Args>
......
...@@ -66,10 +66,10 @@ public: ...@@ -66,10 +66,10 @@ public:
template <typename CharT = char> template <typename CharT = char>
class ArrayStreamBuf : public StreamBuf<CharT> { class ArrayStreamBuf : public StreamBuf<CharT> {
public: public:
typedef StreamBuf<CharT> Base; using Base = StreamBuf<CharT>;
static size_t maxSize;
ArrayStreamBuf() : StreamBuf<CharT>(), bytes() { explicit ArrayStreamBuf(size_t maxSize)
: StreamBuf<CharT>(), bytes(), maxSize(maxSize) {
bytes.clear(); bytes.clear();
Base::setg(bytes.data(), bytes.data(), bytes.data() + bytes.size()); Base::setg(bytes.data(), bytes.data(), bytes.data() + bytes.size());
} }
...@@ -100,11 +100,9 @@ public: ...@@ -100,11 +100,9 @@ public:
private: private:
std::vector<CharT> bytes; std::vector<CharT> bytes;
size_t maxSize = Const::MaxBuffer;
}; };
template <typename CharT>
size_t ArrayStreamBuf<CharT>::maxSize = Const::DefaultMaxRequestSize;
struct RawBuffer { struct RawBuffer {
RawBuffer(); RawBuffer();
RawBuffer(std::string data, size_t length, bool isDetached = false); RawBuffer(std::string data, size_t length, bool isDetached = false);
...@@ -135,37 +133,23 @@ private: ...@@ -135,37 +133,23 @@ private:
class DynamicStreamBuf : public StreamBuf<char> { class DynamicStreamBuf : public StreamBuf<char> {
public: public:
typedef StreamBuf<char> Base; using Base = StreamBuf<char>;
typedef typename Base::traits_type traits_type; using traits_type = typename Base::traits_type;
typedef typename Base::int_type int_type; using int_type = typename Base::int_type;
static size_t maxSize;
explicit DynamicStreamBuf(size_t size) : data_() { reserve(size); } DynamicStreamBuf(size_t size, size_t maxSize);
DynamicStreamBuf(const DynamicStreamBuf &other) = delete; DynamicStreamBuf(const DynamicStreamBuf &other) = delete;
DynamicStreamBuf &operator=(const DynamicStreamBuf &other) = delete; DynamicStreamBuf &operator=(const DynamicStreamBuf &other) = delete;
DynamicStreamBuf(DynamicStreamBuf &&other) : data_(std::move(other.data_)) { DynamicStreamBuf(DynamicStreamBuf &&other);
setp(other.pptr(), other.epptr()); DynamicStreamBuf &operator=(DynamicStreamBuf &&other);
other.setp(nullptr, nullptr);
}
DynamicStreamBuf &operator=(DynamicStreamBuf &&other) { RawBuffer buffer() const;
data_ = std::move(other.data_);
setp(other.pptr(), other.epptr());
other.setp(nullptr, nullptr);
return *this;
}
RawBuffer buffer() const { void clear();
return RawBuffer(data_.data(), pptr() - data_.data());
}
void clear() { size_t maxSize() const;
// reset stream buffer to the whole backing storage.
this->setp(data_.data(), data_.data() + data_.size());
}
protected: protected:
int_type overflow(int_type ch) override; int_type overflow(int_type ch) override;
...@@ -174,6 +158,7 @@ private: ...@@ -174,6 +158,7 @@ private:
void reserve(size_t size); void reserve(size_t size);
std::vector<char> data_; std::vector<char> data_;
size_t maxSize_ = Const::MaxBuffer;
}; };
class StreamCursor { class StreamCursor {
......
...@@ -396,44 +396,34 @@ void Transport::handleHangupEntry(const Aio::FdSet::Entry &entry) { ...@@ -396,44 +396,34 @@ void Transport::handleHangupEntry(const Aio::FdSet::Entry &entry) {
} }
void Transport::handleIncoming(std::shared_ptr<Connection> connection) { void Transport::handleIncoming(std::shared_ptr<Connection> connection) {
char buffer[Const::MaxBuffer] = {0};
ssize_t totalBytes = 0; ssize_t totalBytes = 0;
for (;;) { for (;;) {
ssize_t bytes = recv(connection->fd(), buffer + totalBytes, char buffer[Const::MaxBuffer] = {
Const::MaxBuffer - totalBytes, 0); 0,
};
const ssize_t bytes = recv(connection->fd(), buffer, Const::MaxBuffer, 0);
if (bytes == -1) { if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) { if (errno != EAGAIN && errno != EWOULDBLOCK) {
if (totalBytes > 0) {
connection->handleResponsePacket(buffer, totalBytes);
}
} else {
connection->handleError(strerror(errno)); connection->handleError(strerror(errno));
} }
break; break;
} else if (bytes == 0) { } else if (bytes == 0) {
if (totalBytes > 0) { if (totalBytes == 0) {
connection->handleResponsePacket(buffer, totalBytes);
} else {
connection->handleError("Remote closed connection"); connection->handleError("Remote closed connection");
} }
connections.erase(connection->fd()); connections.erase(connection->fd());
connection->close(); connection->close();
break; break;
} } else {
else {
totalBytes += bytes; totalBytes += bytes;
if (static_cast<size_t>(totalBytes) > Const::MaxBuffer) { connection->handleResponsePacket(buffer, bytes);
std::cerr << "Client: Too long packet" << std::endl;
break;
}
} }
} }
} }
Connection::Connection() : fd_(-1), requestEntry(nullptr) { Connection::Connection(size_t maxResponseSize)
: fd_(-1), requestEntry(nullptr), parser(maxResponseSize) {
state_.store(static_cast<uint32_t>(State::Idle)); state_.store(static_cast<uint32_t>(State::Idle));
connectionState_.store(NotConnected); connectionState_.store(NotConnected);
} }
...@@ -532,7 +522,11 @@ Fd Connection::fd() const { ...@@ -532,7 +522,11 @@ Fd Connection::fd() const {
void Connection::handleResponsePacket(const char *buffer, size_t totalBytes) { void Connection::handleResponsePacket(const char *buffer, size_t totalBytes) {
try { try {
parser.feed(buffer, totalBytes); const bool result = parser.feed(buffer, totalBytes);
if (!result) {
handleError("Client: Too long packet");
return;
}
if (parser.parse() == Private::State::Done) { if (parser.parse() == Private::State::Done) {
if (requestEntry) { if (requestEntry) {
if (requestEntry->timer) { if (requestEntry->timer) {
...@@ -642,8 +636,10 @@ void Connection::processRequestQueue() { ...@@ -642,8 +636,10 @@ void Connection::processRequestQueue() {
} }
} }
void ConnectionPool::init(size_t maxConnsPerHost) { void ConnectionPool::init(size_t maxConnectionsPerHost,
maxConnectionsPerHost = maxConnsPerHost; size_t maxResponseSize) {
this->maxConnectionsPerHost = maxConnectionsPerHost;
this->maxResponseSize = maxResponseSize;
} }
std::shared_ptr<Connection> std::shared_ptr<Connection>
...@@ -656,7 +652,7 @@ ConnectionPool::pickConnection(const std::string &domain) { ...@@ -656,7 +652,7 @@ ConnectionPool::pickConnection(const std::string &domain) {
if (poolIt == std::end(conns)) { if (poolIt == std::end(conns)) {
Connections connections; Connections connections;
for (size_t i = 0; i < maxConnectionsPerHost; ++i) { for (size_t i = 0; i < maxConnectionsPerHost; ++i) {
connections.push_back(std::make_shared<Connection>()); connections.push_back(std::make_shared<Connection>(maxResponseSize));
} }
poolIt = poolIt =
...@@ -780,6 +776,11 @@ Client::Options &Client::Options::maxConnectionsPerHost(int val) { ...@@ -780,6 +776,11 @@ Client::Options &Client::Options::maxConnectionsPerHost(int val) {
return *this; return *this;
} }
Client::Options &Client::Options::maxResponseSize(size_t val) {
maxResponseSize_ = val;
return *this;
}
Client::Client() Client::Client()
: reactor_(Aio::Reactor::create()), pool(), transportKey(), ioIndex(0), : reactor_(Aio::Reactor::create()), pool(), transportKey(), ioIndex(0),
queuesLock(), requestsQueues(), stopProcessPequestsQueues(false) {} queuesLock(), requestsQueues(), stopProcessPequestsQueues(false) {}
...@@ -792,7 +793,7 @@ Client::~Client() { ...@@ -792,7 +793,7 @@ Client::~Client() {
Client::Options Client::options() { return Client::Options(); } Client::Options Client::options() { return Client::Options(); }
void Client::init(const Client::Options &options) { void Client::init(const Client::Options &options) {
pool.init(options.maxConnectionsPerHost_); pool.init(options.maxConnectionsPerHost_, options.maxResponseSize_);
reactor_->init(Aio::AsyncContext(options.threads_)); reactor_->init(Aio::AsyncContext(options.threads_));
transportKey = reactor_->addHandler(std::make_shared<Transport>()); transportKey = reactor_->addHandler(std::make_shared<Transport>());
reactor_->run(); reactor_->run();
......
...@@ -464,7 +464,8 @@ State BodyStep::parseTransferEncoding( ...@@ -464,7 +464,8 @@ State BodyStep::parseTransferEncoding(
return State::Done; return State::Done;
} }
ParserBase::ParserBase() : cursor(&buffer) {} ParserBase::ParserBase(size_t maxDataSize)
: buffer(maxDataSize), cursor(&buffer) {}
State ParserBase::parse() { State ParserBase::parse() {
State state; State state;
...@@ -576,9 +577,10 @@ ResponseStream::ResponseStream(ResponseStream &&other) ...@@ -576,9 +577,10 @@ ResponseStream::ResponseStream(ResponseStream &&other)
ResponseStream::ResponseStream(Message &&other, std::weak_ptr<Tcp::Peer> peer, ResponseStream::ResponseStream(Message &&other, std::weak_ptr<Tcp::Peer> peer,
Tcp::Transport *transport, Timeout timeout, Tcp::Transport *transport, Timeout timeout,
size_t streamSize) size_t streamSize, size_t maxResponseSize)
: response_(std::move(other)), peer_(std::move(peer)), buf_(streamSize), : response_(std::move(other)), peer_(std::move(peer)),
transport_(transport), timeout_(std::move(timeout)) { buf_(streamSize, maxResponseSize), transport_(transport),
timeout_(std::move(timeout)) {
if (!writeStatusLine(response_.version(), response_.code(), buf_)) if (!writeStatusLine(response_.version(), response_.code(), buf_))
throw Error("Response exceeded buffer size"); throw Error("Response exceeded buffer size");
...@@ -657,12 +659,14 @@ ResponseWriter::ResponseWriter(ResponseWriter &&other) ...@@ -657,12 +659,14 @@ ResponseWriter::ResponseWriter(ResponseWriter &&other)
ResponseWriter::ResponseWriter(Tcp::Transport *transport, Request request, ResponseWriter::ResponseWriter(Tcp::Transport *transport, Request request,
Handler *handler, std::weak_ptr<Tcp::Peer> peer) Handler *handler, std::weak_ptr<Tcp::Peer> peer)
: response_(request.version()), peer_(peer), buf_(DefaultStreamSize), : response_(request.version()), peer_(peer),
buf_(DefaultStreamSize, handler->getMaxResponseSize()),
transport_(transport), transport_(transport),
timeout_(transport, handler, std::move(request), peer), sent_bytes_(0) {} timeout_(transport, handler, std::move(request), peer), sent_bytes_(0) {}
ResponseWriter::ResponseWriter(const ResponseWriter &other) ResponseWriter::ResponseWriter(const ResponseWriter &other)
: response_(other.response_), peer_(other.peer_), buf_(DefaultStreamSize), : response_(other.response_), peer_(other.peer_),
buf_(DefaultStreamSize, other.buf_.maxSize()),
transport_(other.transport_), timeout_(other.timeout_), sent_bytes_(0) {} transport_(other.transport_), timeout_(other.timeout_), sent_bytes_(0) {}
void ResponseWriter::setMime(const Mime::MediaType &mime) { void ResponseWriter::setMime(const Mime::MediaType &mime) {
...@@ -716,7 +720,7 @@ ResponseStream ResponseWriter::stream(Code code, size_t streamSize) { ...@@ -716,7 +720,7 @@ ResponseStream ResponseWriter::stream(Code code, size_t streamSize) {
response_.code_ = code; response_.code_ = code;
return ResponseStream(std::move(response_), peer_, transport_, return ResponseStream(std::move(response_), peer_, transport_,
std::move(timeout_), streamSize); std::move(timeout_), streamSize, buf_.maxSize());
} }
const CookieJar &ResponseWriter::cookies() const { return response_.cookies(); } const CookieJar &ResponseWriter::cookies() const { return response_.cookies(); }
...@@ -888,6 +892,29 @@ Async::Promise<ssize_t> serveFile(ResponseWriter &writer, ...@@ -888,6 +892,29 @@ Async::Promise<ssize_t> serveFile(ResponseWriter &writer,
#undef OUT #undef OUT
} }
Private::Parser<Http::Request>::Parser(size_t maxDataSize)
: ParserBase(maxDataSize), request() {
allSteps[0].reset(new RequestLineStep(&request));
allSteps[1].reset(new HeadersStep(&request));
allSteps[2].reset(new BodyStep(&request));
}
void Private::Parser<Http::Request>::reset() {
ParserBase::reset();
request.headers_.clear();
request.body_.clear();
request.resource_.clear();
request.query_.clear();
}
Private::Parser<Http::Response>::Parser(size_t maxDataSize)
: ParserBase(maxDataSize), response() {
allSteps[0].reset(new ResponseLineStep(&response));
allSteps[1].reset(new HeadersStep(&response));
allSteps[2].reset(new BodyStep(&response));
}
void Handler::onInput(const char *buffer, size_t len, void Handler::onInput(const char *buffer, size_t len,
const std::shared_ptr<Tcp::Peer> &peer) { const std::shared_ptr<Tcp::Peer> &peer) {
auto &parser = getParser(peer); auto &parser = getParser(peer);
...@@ -936,7 +963,8 @@ void Handler::onInput(const char *buffer, size_t len, ...@@ -936,7 +963,8 @@ void Handler::onInput(const char *buffer, size_t len,
} }
void Handler::onConnection(const std::shared_ptr<Tcp::Peer> &peer) { void Handler::onConnection(const std::shared_ptr<Tcp::Peer> &peer) {
peer->putData(ParserData, std::make_shared<Private::Parser<Http::Request>>()); peer->putData(ParserData, std::make_shared<Private::Parser<Http::Request>>(
maxRequestSize_));
} }
void Handler::onDisconnection(const std::shared_ptr<Tcp::Peer> & /*peer*/) {} void Handler::onDisconnection(const std::shared_ptr<Tcp::Peer> & /*peer*/) {}
...@@ -969,6 +997,14 @@ void Timeout::onTimeout(uint64_t numWakeup) { ...@@ -969,6 +997,14 @@ void Timeout::onTimeout(uint64_t numWakeup) {
handler->onTimeout(request, std::move(response)); handler->onTimeout(request, std::move(response));
} }
void Handler::setMaxRequestSize(size_t value) { maxRequestSize_ = value; }
size_t Handler::getMaxRequestSize() const { return maxRequestSize_; }
void Handler::setMaxResponseSize(size_t value) { maxResponseSize_ = value; }
size_t Handler::getMaxResponseSize() const { return maxResponseSize_; }
Private::Parser<Http::Request> & Private::Parser<Http::Request> &
Handler::getParser(const std::shared_ptr<Tcp::Peer> &peer) const { Handler::getParser(const std::shared_ptr<Tcp::Peer> &peer) const {
return static_cast<Private::Parser<Http::Request> &>( return static_cast<Private::Parser<Http::Request> &>(
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <pistache/stream.h> #include <pistache/stream.h>
#include <algorithm> #include <algorithm>
#include <cassert>
#include <iostream> #include <iostream>
#include <string> #include <string>
...@@ -73,13 +74,46 @@ Fd FileBuffer::fd() const { return fd_; } ...@@ -73,13 +74,46 @@ Fd FileBuffer::fd() const { return fd_; }
size_t FileBuffer::size() const { return size_; } size_t FileBuffer::size() const { return size_; }
size_t DynamicStreamBuf::maxSize = Const::DefaultMaxResponseSize; DynamicStreamBuf::DynamicStreamBuf(size_t size, size_t maxSize)
: data_(), maxSize_(maxSize) {
assert(size <= maxSize);
reserve(size);
}
DynamicStreamBuf::DynamicStreamBuf(DynamicStreamBuf &&other)
: data_(std::move(other.data_)), maxSize_(std::move(other.maxSize_)) {
setp(other.pptr(), other.epptr());
other.setp(nullptr, nullptr);
}
DynamicStreamBuf &DynamicStreamBuf::operator=(DynamicStreamBuf &&other) {
if (&other != this) {
data_ = std::move(other.data_);
maxSize_ = std::move(other.maxSize_);
setp(other.pptr(), other.epptr());
other.setp(nullptr, nullptr);
}
return *this;
}
RawBuffer DynamicStreamBuf::buffer() const {
return RawBuffer(data_.data(), pptr() - data_.data());
}
size_t DynamicStreamBuf::maxSize() const { return maxSize_; }
void DynamicStreamBuf::clear() {
// reset stream buffer to the whole backing storage.
this->setp(data_.data(), data_.data() + data_.size());
}
DynamicStreamBuf::int_type DynamicStreamBuf::int_type
DynamicStreamBuf::overflow(DynamicStreamBuf::int_type ch) { DynamicStreamBuf::overflow(DynamicStreamBuf::int_type ch) {
if (!traits_type::eq_int_type(ch, traits_type::eof())) { if (!traits_type::eq_int_type(ch, traits_type::eof())) {
const auto size = data_.size(); const auto size = data_.size();
if (size < maxSize) { if (size < maxSize_) {
reserve((size ? size : 1u) * 2); reserve((size ? size : 1u) * 2);
*pptr() = ch; *pptr() = ch;
pbump(1); pbump(1);
...@@ -91,8 +125,10 @@ DynamicStreamBuf::overflow(DynamicStreamBuf::int_type ch) { ...@@ -91,8 +125,10 @@ DynamicStreamBuf::overflow(DynamicStreamBuf::int_type ch) {
} }
void DynamicStreamBuf::reserve(size_t size) { void DynamicStreamBuf::reserve(size_t size) {
if (size > maxSize) if (size > maxSize_) {
size = maxSize; size = maxSize_;
}
const size_t oldSize = data_.size(); const size_t oldSize = data_.size();
data_.resize(size); data_.resize(size);
this->setp(data_.data() + oldSize, data_.data() + size); this->setp(data_.data() + oldSize, data_.data() + size);
......
...@@ -57,12 +57,14 @@ Endpoint::Endpoint(const Address &addr) : listener(addr) {} ...@@ -57,12 +57,14 @@ Endpoint::Endpoint(const Address &addr) : listener(addr) {}
void Endpoint::init(const Endpoint::Options &options) { void Endpoint::init(const Endpoint::Options &options) {
listener.init(options.threads_, options.flags_, options.threadsName_); listener.init(options.threads_, options.flags_, options.threadsName_);
ArrayStreamBuf<char>::maxSize = options.maxRequestSize_; maxRequestSize_ = options.maxRequestSize_;
DynamicStreamBuf::maxSize = options.maxResponseSize_; maxResponseSize_ = options.maxResponseSize_;
} }
void Endpoint::setHandler(const std::shared_ptr<Handler> &handler) { void Endpoint::setHandler(const std::shared_ptr<Handler> &handler) {
handler_ = handler; handler_ = handler;
handler_->setMaxRequestSize(maxRequestSize_);
handler_->setMaxResponseSize(maxResponseSize_);
} }
void Endpoint::bind() { listener.bind(); } void Endpoint::bind() { listener.bind(); }
......
...@@ -53,6 +53,19 @@ struct QueryBounceHandler : public Http::Handler { ...@@ -53,6 +53,19 @@ struct QueryBounceHandler : public Http::Handler {
} }
}; };
namespace {
std::string largeContent(4097, 'a');
}
struct LargeContentHandler : public Http::Handler {
HTTP_PROTOTYPE(LargeContentHandler)
void onRequest(const Http::Request & /*request*/,
Http::ResponseWriter writer) override {
writer.send(Http::Code::Ok, largeContent);
}
};
TEST(http_client_test, one_client_with_one_request) { TEST(http_client_test, one_client_with_one_request) {
const Pistache::Address address("localhost", Pistache::Port(0)); const Pistache::Address address("localhost", Pistache::Port(0));
...@@ -418,12 +431,14 @@ TEST(http_client_test, client_sends_query) { ...@@ -418,12 +431,14 @@ TEST(http_client_test, client_sends_query) {
Async::Barrier<Http::Response> barrier(response); Async::Barrier<Http::Response> barrier(response);
barrier.wait_for(std::chrono::seconds(5)); barrier.wait_for(std::chrono::seconds(5));
server.shutdown();
client.shutdown();
EXPECT_EQ(queryStr[0], '?'); EXPECT_EQ(queryStr[0], '?');
std::unordered_map<std::string, std::string> results; std::unordered_map<std::string, std::string> results;
bool key = true; bool key = true;
std::string keyStr, valueStr; std::string keyStr, valueStr;
;
for (auto it = std::next(queryStr.begin()); it != queryStr.end(); it++) { for (auto it = std::next(queryStr.begin()); it != queryStr.end(); it++) {
if (*it == '&' || std::next(it) == queryStr.end()) { if (*it == '&' || std::next(it) == queryStr.end()) {
...@@ -448,7 +463,77 @@ TEST(http_client_test, client_sends_query) { ...@@ -448,7 +463,77 @@ TEST(http_client_test, client_sends_query) {
ASSERT_TRUE(query.has(entry.first)); ASSERT_TRUE(query.has(entry.first));
EXPECT_EQ(entry.second, query.get(entry.first).get()); EXPECT_EQ(entry.second, query.get(entry.first).get());
} }
}
TEST(http_client_test, client_get_large_content) {
const Pistache::Address address("localhost", Pistache::Port(0));
Http::Endpoint server(address);
auto flags = Tcp::Options::ReuseAddr;
auto server_opts = Http::Endpoint::options().flags(flags);
server.init(server_opts);
server.setHandler(Http::make_handler<LargeContentHandler>());
server.serveThreaded();
const std::string server_address = "localhost:" + server.getPort().toString();
std::cout << "Server address: " << server_address << "\n";
Http::Client client;
auto opts = Http::Client::options().maxResponseSize(8192);
client.init(opts);
auto response = client.get(server_address).send();
bool done = false;
std::string rcvContent;
response.then(
[&done, &rcvContent](Http::Response rsp) {
if (rsp.code() == Http::Code::Ok) {
done = true;
rcvContent = rsp.body();
}
},
Async::IgnoreException);
Async::Barrier<Http::Response> barrier(response);
barrier.wait_for(std::chrono::seconds(5));
server.shutdown();
client.shutdown();
ASSERT_TRUE(done);
ASSERT_EQ(largeContent, rcvContent);
}
TEST(http_client_test, client_do_not_get_large_content) {
const Pistache::Address address("localhost", Pistache::Port(0));
Http::Endpoint server(address);
auto flags = Tcp::Options::ReuseAddr;
auto server_opts = Http::Endpoint::options().flags(flags);
server.init(server_opts);
server.setHandler(Http::make_handler<LargeContentHandler>());
server.serveThreaded();
const std::string server_address = "localhost:" + server.getPort().toString();
std::cout << "Server address: " << server_address << "\n";
Http::Client client;
auto opts = Http::Client::options().maxResponseSize(4096);
client.init(opts);
auto response = client.get(server_address).send();
bool ok_flag = false;
bool exception_flag = false;
response.then(
[&ok_flag](Http::Response /*rsp*/) { ok_flag = true; },
[&exception_flag](std::exception_ptr /*ptr*/) { exception_flag = true; });
Async::Barrier<Http::Response> barrier(response);
barrier.wait_for(std::chrono::seconds(5));
server.shutdown(); server.shutdown();
client.shutdown(); client.shutdown();
ASSERT_FALSE(ok_flag);
ASSERT_TRUE(exception_flag);
} }
...@@ -11,7 +11,7 @@ using namespace Pistache; ...@@ -11,7 +11,7 @@ using namespace Pistache;
// @Todo: Add an easy to use fixture to inject data for parsing tests. // @Todo: Add an easy to use fixture to inject data for parsing tests.
TEST(http_parsing_test, should_parse_http_request_in_two_packets_issue_160) { TEST(http_parsing_test, should_parse_http_request_in_two_packets_issue_160) {
Http::Private::Parser<Http::Request> parser; Http::Private::Parser<Http::Request> parser(Const::DefaultMaxRequestSize);
auto feed = [&parser](const char *data) { auto feed = [&parser](const char *data) {
parser.feed(data, std::strlen(data)); parser.feed(data, std::strlen(data));
......
...@@ -54,7 +54,8 @@ TEST(stream, test_file_buffer) { ...@@ -54,7 +54,8 @@ TEST(stream, test_file_buffer) {
} }
TEST(stream, test_dyn_buffer) { TEST(stream, test_dyn_buffer) {
DynamicStreamBuf buf(128); DynamicStreamBuf buf(128, Const::MaxBuffer);
ASSERT_EQ(buf.maxSize(), Const::MaxBuffer);
{ {
std::ostream os(&buf); std::ostream os(&buf);
...@@ -72,11 +73,21 @@ TEST(stream, test_dyn_buffer) { ...@@ -72,11 +73,21 @@ TEST(stream, test_dyn_buffer) {
ASSERT_EQ(strlen(rawbuf.data().c_str()), 128u); ASSERT_EQ(strlen(rawbuf.data().c_str()), 128u);
} }
TEST(stream, test_array_buffer) {
ArrayStreamBuf<char> buffer(4);
const char *part1 = "abcd";
ASSERT_TRUE(buffer.feed(part1, strlen(part1)));
const char *part2 = "efgh";
ASSERT_FALSE(buffer.feed(part2, strlen(part2)));
}
TEST(stream, test_cursor_advance_for_array) { TEST(stream, test_cursor_advance_for_array) {
ArrayStreamBuf<char> buffer; ArrayStreamBuf<char> buffer(Const::MaxBuffer);
StreamCursor cursor{&buffer}; StreamCursor cursor{&buffer};
const char* part1 = "abcd"; const char *part1 = "abcd";
buffer.feed(part1, strlen(part1)); buffer.feed(part1, strlen(part1));
ASSERT_EQ(cursor.current(), 'a'); ASSERT_EQ(cursor.current(), 'a');
...@@ -90,7 +101,7 @@ TEST(stream, test_cursor_advance_for_array) { ...@@ -90,7 +101,7 @@ TEST(stream, test_cursor_advance_for_array) {
ASSERT_TRUE(cursor.advance(1)); ASSERT_TRUE(cursor.advance(1));
ASSERT_EQ(cursor.current(), 'c'); ASSERT_EQ(cursor.current(), 'c');
const char* part2 = "efgh"; const char *part2 = "efgh";
buffer.feed(part2, strlen(part2)); buffer.feed(part2, strlen(part2));
ASSERT_TRUE(cursor.advance(2)); ASSERT_TRUE(cursor.advance(2));
...@@ -100,11 +111,11 @@ TEST(stream, test_cursor_advance_for_array) { ...@@ -100,11 +111,11 @@ TEST(stream, test_cursor_advance_for_array) {
} }
TEST(stream, test_cursor_remaining_for_array) { TEST(stream, test_cursor_remaining_for_array) {
ArrayStreamBuf<char> buffer; ArrayStreamBuf<char> buffer(Const::MaxBuffer);
StreamCursor cursor{&buffer}; StreamCursor cursor{&buffer};
const char* data = "abcd"; const char *data = "abcd";
buffer.feed(data, strlen(data)); ASSERT_TRUE(buffer.feed(data, strlen(data)));
ASSERT_EQ(cursor.remaining(), 4u); ASSERT_EQ(cursor.remaining(), 4u);
cursor.advance(2); cursor.advance(2);
...@@ -118,11 +129,11 @@ TEST(stream, test_cursor_remaining_for_array) { ...@@ -118,11 +129,11 @@ TEST(stream, test_cursor_remaining_for_array) {
} }
TEST(stream, test_cursor_eol_eof_for_array) { TEST(stream, test_cursor_eol_eof_for_array) {
ArrayStreamBuf<char> buffer; ArrayStreamBuf<char> buffer(Const::MaxBuffer);
StreamCursor cursor{&buffer}; StreamCursor cursor{&buffer};
const char* data = "abcd\r\nefgh"; const char *data = "abcd\r\nefgh";
buffer.feed(data, strlen(data)); ASSERT_TRUE(buffer.feed(data, strlen(data)));
cursor.advance(4); cursor.advance(4);
ASSERT_TRUE(cursor.eol()); ASSERT_TRUE(cursor.eol());
...@@ -138,13 +149,13 @@ TEST(stream, test_cursor_eol_eof_for_array) { ...@@ -138,13 +149,13 @@ TEST(stream, test_cursor_eol_eof_for_array) {
} }
TEST(stream, test_cursor_offset_for_array) { TEST(stream, test_cursor_offset_for_array) {
ArrayStreamBuf<char> buffer; ArrayStreamBuf<char> buffer(Const::MaxBuffer);
StreamCursor cursor{&buffer}; StreamCursor cursor{&buffer};
const char* data = "abcdefgh"; const char *data = "abcdefgh";
buffer.feed(data, strlen(data)); ASSERT_TRUE(buffer.feed(data, strlen(data)));
size_t shift = 4u; const size_t shift = 4u;
cursor.advance(shift); cursor.advance(shift);
std::string result{cursor.offset(), strlen(data) - shift}; std::string result{cursor.offset(), strlen(data) - shift};
...@@ -152,14 +163,14 @@ TEST(stream, test_cursor_offset_for_array) { ...@@ -152,14 +163,14 @@ TEST(stream, test_cursor_offset_for_array) {
} }
TEST(stream, test_cursor_diff_for_array) { TEST(stream, test_cursor_diff_for_array) {
ArrayStreamBuf<char> buffer1; ArrayStreamBuf<char> buffer1(Const::MaxBuffer);
StreamCursor first_cursor{&buffer1}; StreamCursor first_cursor{&buffer1};
ArrayStreamBuf<char> buffer2; ArrayStreamBuf<char> buffer2(Const::MaxBuffer);
StreamCursor second_cursor{&buffer2}; StreamCursor second_cursor{&buffer2};
const char* data = "abcdefgh"; const char *data = "abcdefgh";
buffer1.feed(data, strlen(data)); ASSERT_TRUE(buffer1.feed(data, strlen(data)));
buffer2.feed(data, strlen(data)); ASSERT_TRUE(buffer2.feed(data, strlen(data)));
ASSERT_EQ(first_cursor.diff(second_cursor), 0u); ASSERT_EQ(first_cursor.diff(second_cursor), 0u);
ASSERT_EQ(second_cursor.diff(first_cursor), 0u); ASSERT_EQ(second_cursor.diff(first_cursor), 0u);
......
...@@ -130,6 +130,7 @@ TEST(streaming, from_description) { ...@@ -130,6 +130,7 @@ TEST(streaming, from_description) {
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION,
static_cast<CURL_WRITEFUNCTION_PTR>(curl_callback)); static_cast<CURL_WRITEFUNCTION_PTR>(curl_callback));
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ss); curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ss);
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
res = curl_easy_perform(curl); res = curl_easy_perform(curl);
curl_easy_cleanup(curl); curl_easy_cleanup(curl);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment