Commit 28317f82 authored by octal's avatar octal

Now parsing Connection header. Also keeping the right http version in the response

parent 8ef8328b
......@@ -393,10 +393,20 @@ public:
return body_;
}
Version version() const {
return version_;
}
protected:
Response()
: Message()
{ }
Response(Version version)
: Message()
{
version_ = version;
}
};
class ResponseWriter : public Response {
......@@ -523,7 +533,7 @@ private:
{ }
ResponseWriter(Tcp::Transport* transport, Request request, Handler* handler)
: Response()
: Response(request.version())
, buf_(DefaultStreamSize)
, transport_(transport)
, timeout_(transport, handler, std::move(request))
......
......@@ -111,6 +111,12 @@ enum class Version {
Http11 // HTTP/1.1
};
enum class ConnectionControl {
Close,
KeepAlive,
Ext
};
enum class Expectation {
Continue,
Ext
......@@ -170,6 +176,7 @@ private:
const char* methodString(Method method);
const char* codeString(Code code);
std::ostream& operator<<(std::ostream& os, Version version);
std::ostream& operator<<(std::ostream& os, Method method);
std::ostream& operator<<(std::ostream& os, Code code);
......
......@@ -182,6 +182,27 @@ private:
std::vector<Http::CacheDirective> directives_;
};
class Connection : public Header {
public:
NAME("Connection")
Connection()
: control_(ConnectionControl::KeepAlive)
{ }
explicit Connection(ConnectionControl control)
: control_(control)
{ }
void parseRaw(const char* str, size_t len);
void write(std::ostream& os) const;
ConnectionControl control() const { return control_; }
private:
ConnectionControl control_;
};
class EncodingHeader : public Header {
public:
......
......@@ -39,7 +39,7 @@ writeHeader(Stream& stream, Args&& ...args) {
}
namespace {
bool writeStatusLine(Code code, DynamicStreamBuf& buf) {
bool writeStatusLine(Version version, Code code, DynamicStreamBuf& buf) {
#define OUT(...) \
do { \
__VA_ARGS__; \
......@@ -48,7 +48,7 @@ namespace {
std::ostream os(&buf);
OUT(os << "HTTP/1.1 ");
OUT(os << version << " ");
OUT(os << static_cast<int>(code));
OUT(os << ' ');
OUT(os << code);
......@@ -600,7 +600,7 @@ ResponseStream::ResponseStream(
, transport_(transport)
, timeout_(std::move(timeout))
{
if (!writeStatusLine(code_, buf_))
if (!writeStatusLine(version_, code_, buf_))
throw Error("Response exceeded buffer size");
if (!writeCookies(cookies_, buf_)) {
......@@ -609,6 +609,12 @@ ResponseStream::ResponseStream(
if (writeHeaders(headers_, buf_)) {
std::ostream os(&buf_);
/* @Todo @Major:
* Correctly handle non-keep alive requests
* Do not put Keep-Alive if version == Http::11 and request.keepAlive == true
*/
if (!writeHeader<Header::Connection>(os, ConnectionControl::KeepAlive));
throw Error("Response exceeded buffer size");
if (!writeHeader<Header::TransferEncoding>(os, Header::Encoding::Chunked))
throw Error("Response exceeded buffer size");
os << crlf;
......@@ -653,10 +659,15 @@ ResponseWriter::putOnWire(const char* data, size_t len)
} \
} while (0);
OUT(writeStatusLine(code_, buf_));
OUT(writeStatusLine(version_, code_, buf_));
OUT(writeHeaders(headers_, buf_));
OUT(writeCookies(cookies_, buf_));
/* @Todo @Major:
* Correctly handle non-keep alive requests
* Do not put Keep-Alive if version == Http::11 and request.keepAlive == true
*/
OUT(writeHeader<Header::Connection>(os, ConnectionControl::KeepAlive));
OUT(writeHeader<Header::ContentLength>(os, len));
OUT(os << crlf);
......@@ -718,7 +729,7 @@ serveFile(ResponseWriter& response, const char* fileName, const Mime::MediaType&
headers.add<Header::ContentType>(contentType);
};
OUT(writeStatusLine(Http::Code::Ok, *buf));
OUT(writeStatusLine(response.version(), Http::Code::Ok, *buf));
if (contentType.isValid()) {
setContentType(contentType);
} else {
......@@ -749,8 +760,8 @@ serveFile(ResponseWriter& response, const char* fileName, const Mime::MediaType&
void
Handler::onInput(const char* buffer, size_t len, const std::shared_ptr<Tcp::Peer>& peer) {
auto& parser = getParser(peer);
try {
auto& parser = getParser(peer);
if (!parser.feed(buffer, len)) {
parser.reset();
throw HttpError(Code::Request_Entity_Too_Large, "Request exceeded maximum buffer size");
......@@ -768,16 +779,16 @@ Handler::onInput(const char* buffer, size_t len, const std::shared_ptr<Tcp::Peer
parser.reset();
}
} catch (const HttpError &err) {
ResponseWriter response(transport());
ResponseWriter response(transport(), parser.request, this);
response.associatePeer(peer);
response.send(static_cast<Code>(err.code()), err.reason());
getParser(peer).reset();
parser.reset();
}
catch (const std::exception& e) {
ResponseWriter response(transport());
ResponseWriter response(transport(), parser.request, this);
response.associatePeer(peer);
response.send(Code::Internal_Server_Error, e.what());
getParser(peer).reset();
parser.reset();
}
}
......
......@@ -125,6 +125,16 @@ FullDate::write(std::ostream& os, Type type) const
}
}
const char *versionString(Version version) {
switch (version) {
case Version::Http10:
return "HTTP/1.0";
case Version::Http11:
return "HTTP/1.1";
}
unreachable();
}
const char* methodString(Method method)
{
......@@ -152,6 +162,11 @@ const char* codeString(Code code)
return "";
}
std::ostream& operator<<(std::ostream& os, Version version) {
os << versionString(version);
return os;
}
std::ostream& operator<<(std::ostream& os, Method method) {
os << methodString(method);
return os;
......
......@@ -254,6 +254,38 @@ CacheControl::addDirectives(const std::vector<Http::CacheDirective>& directives)
std::copy(std::begin(directives), std::end(directives), std::back_inserter(directives_));
}
void
Connection::parseRaw(const char* str, size_t len) {
char *p = const_cast<char *>(str);
RawStreamBuf<> buf(p, p + len);
StreamCursor cursor(&buf);
#define STR(str) \
str, sizeof(str) - 1
if (match_string(STR("close"), cursor)) {
control_ = ConnectionControl::Close;
}
else if (match_string(STR("keep-alive"), cursor)) {
control_ = ConnectionControl::KeepAlive;
}
else {
control_ = ConnectionControl::Ext;
}
}
void
Connection::write(std::ostream& os) const {
switch (control_) {
case ConnectionControl::Close:
os << "Close";
break;
case ConnectionControl::KeepAlive:
os << "Keep-Alive";
break;
}
}
void
ContentLength::parse(const std::string& data) {
try {
......
......@@ -23,6 +23,7 @@ namespace {
RegisterHeader(Accept);
RegisterHeader(Allow);
RegisterHeader(CacheControl);
RegisterHeader(Connection);
RegisterHeader(ContentEncoding);
RegisterHeader(TransferEncoding);
RegisterHeader(ContentLength);
......
......@@ -167,6 +167,32 @@ TEST(headers_test, content_length) {
ASSERT_EQ(cl.value(), 3495);
}
TEST(headers_test, connection) {
Header::Connection connection;
constexpr struct Test {
const char *data;
ConnectionControl expected;
} tests[] = {
{ "close", ConnectionControl::Close },
{ "clOse", ConnectionControl::Close },
{ "Close", ConnectionControl::Close },
{ "CLOSE", ConnectionControl::Close },
{ "keep-alive", ConnectionControl::KeepAlive },
{ "Keep-Alive", ConnectionControl::KeepAlive },
{ "kEEp-alIvE", ConnectionControl::KeepAlive },
{ "KEEP-ALIVE", ConnectionControl::KeepAlive }
};
for (auto test: tests) {
Header::Connection connection;
connection.parse(test.data);
ASSERT_EQ(connection.control(), test.expected);
}
}
TEST(headers_test, date_test) {
/* RFC-1123 */
Header::Date d1;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment