Commit 6faab0f1 authored by octal's avatar octal

Introduced a new Timeout class that can be used to trigger a timeout.

Also fixed a bunch of issues:
    - Flags<T> now works correctly and correctly zero-initializes the
      value in the default constructor
    - Fixed a nasty bug in Poller that went unnoticed because the Tag
      constructor was not marked explicit
parent ec6ee3b9
...@@ -18,7 +18,11 @@ struct ExceptionPrinter { ...@@ -18,7 +18,11 @@ struct ExceptionPrinter {
}; };
class MyHandler : public Net::Http::Handler { class MyHandler : public Net::Http::Handler {
void onRequest(const Net::Http::Request& req, Net::Http::Response response) { void onRequest(
const Net::Http::Request& req,
Net::Http::Response response,
Net::Http::Timeout timeout) {
if (req.resource() == "/ping") { if (req.resource() == "/ping") {
if (req.method() == Net::Http::Method::Get) { if (req.method() == Net::Http::Method::Get) {
...@@ -28,12 +32,10 @@ class MyHandler : public Net::Http::Handler { ...@@ -28,12 +32,10 @@ class MyHandler : public Net::Http::Handler {
.add<Header::Server>("lys") .add<Header::Server>("lys")
.add<Header::ContentType>(MIME(Text, Plain)); .add<Header::ContentType>(MIME(Text, Plain));
auto w = response.beginWrite(Net::Http::Code::Ok); std::ostream os(response.rdbuf());
std::ostream os(w);
os << "PONG"; os << "PONG";
w.send().then([](ssize_t bytes) { response.send(Net::Http::Code::Ok).then([](ssize_t bytes) {
std::cout << "Sent total of " << bytes << " bytes" << std::endl; std::cout << "Sent total of " << bytes << " bytes" << std::endl;
}, Async::IgnoreException); }, Async::IgnoreException);
...@@ -48,13 +50,15 @@ class MyHandler : public Net::Http::Handler { ...@@ -48,13 +50,15 @@ class MyHandler : public Net::Http::Handler {
throw std::runtime_error("Exception thrown in the handler"); throw std::runtime_error("Exception thrown in the handler");
} }
else if (req.resource() == "/timeout") { else if (req.resource() == "/timeout") {
response timeout.arm(std::chrono::seconds(5));
.timeoutAfter(std::chrono::seconds(1))
.then([=](Net::Http::Response *response) {
response->send(Net::Http::Code::Bad_Request, "Timeout occured");
}, Async::NoExcept);
} }
} }
void onTimeout(const Net::Http::Request& req, Net::Http::Response response) {
response
.send(Net::Http::Code::Request_Timeout, "Timeout")
.then([=](ssize_t) { }, ExceptionPrinter());
}
}; };
int main(int argc, char *argv[]) { int main(int argc, char *argv[]) {
......
...@@ -440,7 +440,7 @@ namespace Async { ...@@ -440,7 +440,7 @@ namespace Async {
{ } { }
template<typename Arg> template<typename Arg>
bool operator()(Arg&& arg) { bool operator()(Arg&& arg) const {
typedef typename std::remove_reference<Arg>::type Type; typedef typename std::remove_reference<Arg>::type Type;
if (core_->state != State::Pending) if (core_->state != State::Pending)
...@@ -461,7 +461,7 @@ namespace Async { ...@@ -461,7 +461,7 @@ namespace Async {
return true; return true;
} }
bool operator()() { bool operator()() const {
if (core_->state != State::Pending) if (core_->state != State::Pending)
throw Error("Attempt to resolve a fulfilled promise"); throw Error("Attempt to resolve a fulfilled promise");
...@@ -488,7 +488,7 @@ namespace Async { ...@@ -488,7 +488,7 @@ namespace Async {
template<typename Exc> template<typename Exc>
bool operator()(Exc exc) { bool operator()(Exc exc) const {
if (core_->state != State::Pending) if (core_->state != State::Pending)
throw Error("Attempt to reject a fulfilled promise"); throw Error("Attempt to reject a fulfilled promise");
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include <climits>
#include <iostream> #include <iostream>
// Looks like gcc 4.6 does not implement std::underlying_type // Looks like gcc 4.6 does not implement std::underlying_type
...@@ -28,15 +29,31 @@ namespace detail { ...@@ -28,15 +29,31 @@ namespace detail {
template<typename T> struct UnderlyingType { template<typename T> struct UnderlyingType {
typedef typename TypeStorage<sizeof(T)>::Type Type; typedef typename TypeStorage<sizeof(T)>::Type Type;
}; };
template<typename Enum>
struct HasNone {
template<typename U>
static auto test(U *) -> decltype(U::None, std::true_type());
template<typename U>
static auto test(...) -> std::false_type;
static constexpr bool value =
std::is_same<decltype(test<Enum>(0)), std::true_type>::value;
};
} }
template<typename T> template<typename T>
class Flags { class Flags {
public: public:
static_assert(std::is_enum<T>::value, "Flags only works with enumerations");
typedef typename detail::UnderlyingType<T>::Type Type; typedef typename detail::UnderlyingType<T>::Type Type;
Flags() { } static_assert(std::is_enum<T>::value, "Flags only works with enumerations");
static_assert(detail::HasNone<T>::value, "The enumartion needs a None value");
static_assert(static_cast<Type>(T::None) == 0, "None should be 0");
Flags() : val(T::None) {
}
Flags(T val) : val(val) Flags(T val) : val(val)
{ {
...@@ -83,13 +100,12 @@ public: ...@@ -83,13 +100,12 @@ public:
#undef DEFINE_BITWISE_OP #undef DEFINE_BITWISE_OP
bool hasFlag(T flag) const { bool hasFlag(T flag) const {
return static_cast<T>( return static_cast<Type>(val) & static_cast<Type>(flag);
static_cast<Type>(val) & static_cast<Type>(flag)
) == flag;
} }
Flags<T>& setFlag(T flag) { Flags<T>& setFlag(T flag) {
return *this &= flag; *this |= flag;
return *this;
} }
Flags<T>& toggleFlag(T flag) { Flags<T>& toggleFlag(T flag) {
...@@ -115,3 +131,15 @@ private: ...@@ -115,3 +131,15 @@ private:
#define DECLARE_FLAGS_OPERATORS(T) \ #define DECLARE_FLAGS_OPERATORS(T) \
DEFINE_BITWISE_OP(&, T) \ DEFINE_BITWISE_OP(&, T) \
DEFINE_BITWISE_OP(|, T) DEFINE_BITWISE_OP(|, T)
template<typename T>
std::ostream& operator<<(std::ostream& os, Flags<T> flags) {
typedef typename detail::UnderlyingType<T>::Type UnderlyingType;
auto val = static_cast<UnderlyingType>(static_cast<T>(flags));
for (ssize_t i = (sizeof(UnderlyingType) * CHAR_BIT) - 1; i >= 0; --i) {
os << ((val >> i) & 0x1);
}
return os;
}
...@@ -348,67 +348,60 @@ Request::query() const { ...@@ -348,67 +348,60 @@ Request::query() const {
return query_; return query_;
} }
Async::Promise<Response *> std::shared_ptr<Tcp::Peer>
Response::timeoutAfter(std::chrono::milliseconds timeout) Request::peer() const {
{ auto p = peer_.lock();
#if 0
Async::Promise<uint64_t> promise([=](Async::Resolver& resolve, Async::Rejection& reject) { if (!p) throw std::runtime_error("Failed to retrieve peer: Broken pipe");
peer()->io()->setTimeout(timeout, resolve, reject);
}); return p;
promise.then([=](uint64_t) {
send(Code::Bad_Request, "A timeout occured");
}, Async::NoExcept);
return promise;
#endif
Async::Promise<uint64_t> promise([=](Async::Resolver& resolve, Async::Rejection& reject) {
peer()->io()->setTimeout(timeout, resolve, reject);
});
return promise.then([=](uint64_t) {
auto p = Async::Promise<Response *>::resolved(this);
return p;
}, Async::NoExcept);
} }
Async::Promise<ssize_t> Async::Promise<ssize_t>
ResponseWriter::send() const Response::putOnWire() const
{ {
auto body = stream_.buffer(); try {
auto body = stream_.buffer();
NetworkStream stream(512 + body.len); NetworkStream stream(512 + body.len);
std::ostream os(&stream); std::ostream os(&stream);
#define OUT(...) \ #define OUT(...) \
do { \ do { \
__VA_ARGS__; \ __VA_ARGS__; \
if (!os) { \ if (!os) { \
return Async::Promise<ssize_t>::rejected(Error("Response exceeded buffer size")); \ return Async::Promise<ssize_t>::rejected(Error("Response exceeded buffer size")); \
} \ } \
} while (0); } while (0);
OUT(os << "HTTP/1.1 "); OUT(os << "HTTP/1.1 ");
OUT(os << static_cast<int>(code_)); OUT(os << static_cast<int>(code_));
OUT(os << ' '); OUT(os << ' ');
OUT(os << code_); OUT(os << code_);
OUT(os << crlf);
for (const auto& header: headers_.list()) {
OUT(os << header->name() << ": ");
OUT(header->write(os));
OUT(os << crlf); OUT(os << crlf);
}
OUT(writeHeader<Header::ContentLength>(os, body.len)); for (const auto& header: headers_.list()) {
OUT(os << crlf); OUT(os << header->name() << ": ");
OUT(header->write(os));
OUT(os << crlf);
}
OUT(os.write(static_cast<const char *>(body.data), body.len)); OUT(writeHeader<Header::ContentLength>(os, body.len));
OUT(os << crlf);
OUT(os.write(static_cast<const char *>(body.data), body.len));
auto buf = stream.buffer();
if (io_) {
io_->disarmTimer();
}
auto buf = stream.buffer(); return peer()->send(buf.data, buf.len);
return peer()->send(buf.data, buf.len); } catch (const std::runtime_error& e) {
return Async::Promise<ssize_t>::rejected(e);
}
} }
void void
...@@ -422,19 +415,23 @@ Handler::onInput(const char* buffer, size_t len, const std::shared_ptr<Tcp::Peer ...@@ -422,19 +415,23 @@ Handler::onInput(const char* buffer, size_t len, const std::shared_ptr<Tcp::Peer
auto state = parser.parse(); auto state = parser.parse();
if (state == Private::State::Done) { if (state == Private::State::Done) {
Response response; Response response(io());
response.associatePeer(peer); response.associatePeer(peer);
onRequest(parser.request, std::move(response));
Timeout timeout(io(), this, peer, parser.request);
parser.request.associatePeer(peer);
onRequest(parser.request, std::move(response), std::move(timeout));
parser.reset(); parser.reset();
} }
} catch (const HttpError &err) { } catch (const HttpError &err) {
Response response; Response response(io());
response.associatePeer(peer); response.associatePeer(peer);
response.send(static_cast<Code>(err.code()), err.reason()); response.send(static_cast<Code>(err.code()), err.reason());
getParser(peer).reset(); getParser(peer).reset();
} }
catch (const std::exception& e) { catch (const std::exception& e) {
Response response; Response response(io());
response.associatePeer(peer); response.associatePeer(peer);
response.send(Code::Internal_Server_Error, e.what()); response.send(Code::Internal_Server_Error, e.what());
getParser(peer).reset(); getParser(peer).reset();
...@@ -450,6 +447,20 @@ void ...@@ -450,6 +447,20 @@ void
Handler::onDisconnection(const shared_ptr<Tcp::Peer>& peer) { Handler::onDisconnection(const shared_ptr<Tcp::Peer>& peer) {
} }
void
Handler::onTimeout(const Request& request, Response response) {
}
void
Timeout::onTimeout(uint64_t numWakeup) {
if (!peer.lock()) return;
Response response(io);
response.associatePeer(peer);
handler->onTimeout(request, std::move(response));
}
Private::Parser& Private::Parser&
Handler::getParser(const std::shared_ptr<Tcp::Peer>& peer) const { Handler::getParser(const std::shared_ptr<Tcp::Peer>& peer) const {
return *peer->getData<Private::Parser>(ParserData); return *peer->getData<Private::Parser>(ParserData);
......
...@@ -71,7 +71,7 @@ public: ...@@ -71,7 +71,7 @@ public:
friend class Private::BodyStep; friend class Private::BodyStep;
friend class Private::Parser; friend class Private::Parser;
Request(); friend class Handler;
Version version() const; Version version() const;
Method method() const; Method method() const;
...@@ -82,88 +82,33 @@ public: ...@@ -82,88 +82,33 @@ public:
const Header::Collection& headers() const; const Header::Collection& headers() const;
const Uri::Query& query() const; const Uri::Query& query() const;
private: std::shared_ptr<Tcp::Peer> peer() const;
Method method_;
std::string resource_;
Uri::Query query_;
};
class Handler;
class ResponseWriter : private Message {
public:
ResponseWriter(const ResponseWriter& other) = delete;
ResponseWriter& operator=(const ResponseWriter& other) = delete;
ResponseWriter(ResponseWriter&& other)
: Message(std::move(other))
, peer_(std::move(other.peer_))
, stream_(std::move(other.stream_))
{ }
ResponseWriter& operator=(ResponseWriter&& other) {
Message::operator=(std::move(other));
peer_ = std::move(other.peer_);
stream_ = std::move(other.stream_);
return *this;
}
friend class Response;
const Header::Collection& headers() const {
return headers_;
}
Code code() const {
return code_;
}
void setCode(Code code) {
code_ = code;
}
std::streambuf* rdbuf() {
return &stream_;
}
operator std::streambuf*() {
return &stream_;
}
Async::Promise<ssize_t> send() const;
private: private:
ResponseWriter(Message&& other, size_t size, std::weak_ptr<Tcp::Peer> peer) Request();
: Message(std::move(other))
, stream_(size)
, peer_(peer)
{
}
ResponseWriter(const Message& other, size_t size, std::weak_ptr<Tcp::Peer> peer)
: Message(other)
, stream_(size)
, peer_(peer)
{
}
std::shared_ptr<Tcp::Peer> peer() const { void associatePeer(const std::shared_ptr<Tcp::Peer>& peer) {
if (peer_.expired()) { if (peer_.use_count() > 0)
throw std::runtime_error("Broken pipe"); throw std::runtime_error("A peer was already associated to the response");
}
return peer_.lock(); peer_ = peer;
} }
NetworkStream stream_; Method method_;
std::string resource_;
Uri::Query query_;
std::weak_ptr<Tcp::Peer> peer_; std::weak_ptr<Tcp::Peer> peer_;
}; };
class Handler;
class Timeout;
// 6. Response // 6. Response
class Response : private Message { class Response : private Message {
public: public:
friend class Handler; friend class Handler;
friend class Timeout;
static constexpr size_t DefaultStreamSize = 512; static constexpr size_t DefaultStreamSize = 512;
...@@ -173,9 +118,13 @@ public: ...@@ -173,9 +118,13 @@ public:
Response(Response&& other) Response(Response&& other)
: Message(std::move(other)) : Message(std::move(other))
, peer_(other.peer_) , peer_(other.peer_)
, stream_(512)
, io_(other.io_)
{ } { }
Response& operator=(Response&& other) { Response& operator=(Response&& other) {
Message::operator=(std::move(other));
peer_ = other.peer_; peer_ = other.peer_;
io_ = other.io_;
return *this; return *this;
} }
...@@ -202,7 +151,8 @@ public: ...@@ -202,7 +151,8 @@ public:
} }
Async::Promise<ssize_t> send(Code code) { Async::Promise<ssize_t> send(Code code) {
return send(code, ""); code_ = code;
return putOnWire();
} }
Async::Promise<ssize_t> send( Async::Promise<ssize_t> send(
Code code, Code code,
...@@ -219,44 +169,92 @@ public: ...@@ -219,44 +169,92 @@ public:
headers_.add(std::make_shared<Header::ContentType>(mime)); headers_.add(std::make_shared<Header::ContentType>(mime));
} }
ResponseWriter w(*this, body.size(), peer_); std::ostream os(&stream_);
std::ostream os(w);
os << body; os << body;
if (!os) if (!os)
return Async::Promise<ssize_t>::rejected(Error("Response exceeded buffer size")); return Async::Promise<ssize_t>::rejected(Error("Response exceeded buffer size"));
return w.send();
}
/* @Revisit: not sure about the name yet */ return putOnWire();
ResponseWriter
beginWrite(Code code = Code::Ok, size_t size = DefaultStreamSize) {
code_ = code;
return ResponseWriter(std::move(*this), size, peer_);
} }
Async::Promise<Response *> timeoutAfter(std::chrono::milliseconds timeout);
std::streambuf *rdbuf() {
return &stream_;
}
private: private:
Response() Response(Tcp::IoWorker* io)
: Message() : Message()
, stream_(512)
, io_(io)
{ } { }
std::shared_ptr<Tcp::Peer> peer() const { std::shared_ptr<Tcp::Peer> peer() const {
if (peer_.expired()) { if (peer_.expired())
throw std::runtime_error("Broken pipe"); throw std::runtime_error("Write failed: Broken pipe");
}
return peer_.lock(); return peer_.lock();
} }
void associatePeer(const std::shared_ptr<Tcp::Peer>& peer) { template<typename Ptr>
void associatePeer(const Ptr& peer) {
if (peer_.use_count() > 0) if (peer_.use_count() > 0)
throw std::runtime_error("A peer was already associated to the response"); throw std::runtime_error("A peer was already associated to the response");
peer_ = peer; peer_ = peer;
} }
Async::Promise<ssize_t> putOnWire() const;
std::weak_ptr<Tcp::Peer> peer_; std::weak_ptr<Tcp::Peer> peer_;
NetworkStream stream_;
Tcp::IoWorker *io_;
};
class Timeout {
public:
friend class Handler;
template<typename Duration>
void arm(Duration duration) {
Async::Promise<uint64_t> p([=](Async::Resolver& resolve, Async::Rejection& reject) {
io->armTimer(duration, resolve, reject);
});
p.then(
[=](uint64_t numWakeup) {
this->onTimeout(numWakeup);
},
[=](std::exception_ptr exc) {
std::rethrow_exception(exc);
});
}
void disarm() {
io->disarmTimer();
}
private:
Timeout(Tcp::IoWorker* io,
Handler* handler,
const std::shared_ptr<Tcp::Peer>& peer,
const Request& request)
: io(io)
, handler(handler)
, peer(peer)
, request(request)
{ }
void onTimeout(uint64_t numWakeup);
Handler* handler;
std::weak_ptr<Tcp::Peer> peer;
Request request;
Tcp::IoWorker *io;
}; };
namespace Private { namespace Private {
...@@ -358,7 +356,9 @@ public: ...@@ -358,7 +356,9 @@ public:
void onConnection(const std::shared_ptr<Tcp::Peer>& peer); void onConnection(const std::shared_ptr<Tcp::Peer>& peer);
void onDisconnection(const std::shared_ptr<Tcp::Peer>& peer); void onDisconnection(const std::shared_ptr<Tcp::Peer>& peer);
virtual void onRequest(const Request& request, Response response) = 0; virtual void onRequest(const Request& request, Response response, Timeout timeout) = 0;
virtual void onTimeout(const Request& request, Response response);
private: private:
Private::Parser& getParser(const std::shared_ptr<Tcp::Peer>& peer) const; Private::Parser& getParser(const std::shared_ptr<Tcp::Peer>& peer) const;
......
...@@ -49,6 +49,8 @@ IoWorker::~IoWorker() { ...@@ -49,6 +49,8 @@ IoWorker::~IoWorker() {
void void
IoWorker::start(const std::shared_ptr<Handler>& handler, Flags<Options> options) { IoWorker::start(const std::shared_ptr<Handler>& handler, Flags<Options> options) {
handler_ = handler; handler_ = handler;
handler_->io_ = this;
options_ = options; options_ = options;
thread.reset(new std::thread([this]() { thread.reset(new std::thread([this]() {
...@@ -74,7 +76,7 @@ IoWorker::pin(const CpuSet& set) { ...@@ -74,7 +76,7 @@ IoWorker::pin(const CpuSet& set) {
} }
void void
IoWorker::setTimeoutMs( IoWorker::armTimer(
std::chrono::milliseconds value, std::chrono::milliseconds value,
Async::Resolver resolve, Async::Rejection reject) Async::Resolver resolve, Async::Rejection reject)
{ {
...@@ -99,8 +101,23 @@ IoWorker::setTimeoutMs( ...@@ -99,8 +101,23 @@ IoWorker::setTimeoutMs(
return; return;
} }
timeout = Some(Timeout(value, std::move(resolve), std::move(reject))); timer = Some(Timer(value, std::move(resolve), std::move(reject)));
}
void
IoWorker::disarmTimer()
{
itimerspec spec;
spec.it_value.tv_sec = spec.it_value.tv_nsec = 0;
spec.it_interval.tv_sec = spec.it_interval.tv_nsec = 0;
int res = timerfd_settime(timerFd, 0, &spec, 0);
if (res == -1)
throw Error::system("Could not set timer time");
timer = None();
} }
void void
...@@ -147,8 +164,7 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) { ...@@ -147,8 +164,7 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
} }
} else { } else {
if (errno == ECONNRESET) { if (errno == ECONNRESET) {
handler_->onDisconnection(peer); handlePeerDisconnection(peer);
close(fd);
} }
else { else {
throw std::runtime_error(strerror(errno)); throw std::runtime_error(strerror(errno));
...@@ -157,8 +173,7 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) { ...@@ -157,8 +173,7 @@ IoWorker::handleIncoming(const std::shared_ptr<Peer>& peer) {
break; break;
} }
else if (bytes == 0) { else if (bytes == 0) {
handler_->onDisconnection(peer); handlePeerDisconnection(peer);
close(fd);
break; break;
} }
...@@ -185,9 +200,26 @@ IoWorker::handleNewPeer(const std::shared_ptr<Peer>& peer) ...@@ -185,9 +200,26 @@ IoWorker::handleNewPeer(const std::shared_ptr<Peer>& peer)
peer->io_ = this; peer->io_ = this;
handler_->onConnection(peer); handler_->onConnection(peer);
poller.addFd(fd, NotifyOn::Read, Polling::Tag(fd), Polling::Mode::Edge); poller.addFd(fd, NotifyOn::Read | NotifyOn::Shutdown, Polling::Tag(fd), Polling::Mode::Edge);
} }
void
IoWorker::handlePeerDisconnection(const std::shared_ptr<Peer>& peer)
{
handler_->onDisconnection(peer);
int fd = peer->fd();
{
std::unique_lock<std::mutex> guard(peersMutex);
auto it = peers.find(fd);
if (it == std::end(peers))
throw std::runtime_error("Could not find peer to erase");
peers.erase(it);
}
close(fd);
}
void void
IoWorker::run() { IoWorker::run() {
...@@ -226,6 +258,9 @@ IoWorker::run() { ...@@ -226,6 +258,9 @@ IoWorker::run() {
handleIncoming(peer); handleIncoming(peer);
} }
} }
else if (event.flags.hasFlag(NotifyOn::Shutdown)) {
handlePeerDisconnection(getPeer(event.tag));
}
else if (event.flags.hasFlag(NotifyOn::Write)) { else if (event.flags.hasFlag(NotifyOn::Write)) {
auto fd = event.tag.value(); auto fd = event.tag.value();
auto it = toWrite.find(fd); auto it = toWrite.find(fd);
...@@ -257,24 +292,24 @@ IoWorker::run() { ...@@ -257,24 +292,24 @@ IoWorker::run() {
void void
IoWorker::handleTimeout() { IoWorker::handleTimeout() {
auto& entry = timeout.unsafeGet(); optionally_do(timer, [=](const Timer& entry) {
uint64_t numWakeups;
uint64_t numWakeups; int res = ::read(timerFd, &numWakeups, sizeof numWakeups);
int res = ::read(timerFd, &numWakeups, sizeof numWakeups); if (res == -1) {
if (res == -1) { if (errno == EAGAIN || errno == EWOULDBLOCK)
if (errno == EAGAIN || errno == EWOULDBLOCK) return;
return; else
else entry.reject(Error::system("Could not read timerfd"));
entry.reject(Error::system("Could not read timerfd")); } else {
} else { if (res != sizeof(numWakeups)) {
if (res != sizeof(numWakeups)) { entry.reject(Error("Read invalid number of bytes for timer fd: "
entry.reject(Error("Read invalid number of bytes for timer fd: " + std::to_string(timerFd)));
+ std::to_string(timerFd))); }
} else {
else { entry.resolve(numWakeups);
entry.resolve(numWakeups); }
} }
} });
} }
Async::Promise<ssize_t> Async::Promise<ssize_t>
......
...@@ -40,12 +40,14 @@ public: ...@@ -40,12 +40,14 @@ public:
void shutdown(); void shutdown();
template<typename Duration> template<typename Duration>
void setTimeout(Duration timeout, Async::Resolver resolve, Async::Rejection reject) {
setTimeoutMs(std::chrono::duration_cast<std::chrono::milliseconds>(timeout), void armTimer(Duration timeout, Async::Resolver resolve, Async::Rejection reject) {
armTimer(std::chrono::duration_cast<std::chrono::milliseconds>(timeout),
std::move(resolve), std::move(resolve),
std::move(reject)); std::move(reject));
} }
void disarmTimer();
private: private:
struct OnHoldWrite { struct OnHoldWrite {
...@@ -65,10 +67,10 @@ private: ...@@ -65,10 +67,10 @@ private:
}; };
void void
setTimeoutMs(std::chrono::milliseconds value, Async::Resolver, Async::Rejection reject); armTimer(std::chrono::milliseconds value, Async::Resolver, Async::Rejection reject);
struct Timeout { struct Timer {
Timeout(std::chrono::milliseconds value, Timer(std::chrono::milliseconds value,
Async::Resolver resolve, Async::Resolver resolve,
Async::Rejection reject) Async::Rejection reject)
: value(value) : value(value)
...@@ -88,7 +90,7 @@ private: ...@@ -88,7 +90,7 @@ private:
std::unordered_map<Fd, std::shared_ptr<Peer>> peers; std::unordered_map<Fd, std::shared_ptr<Peer>> peers;
std::unordered_map<Fd, OnHoldWrite> toWrite; std::unordered_map<Fd, OnHoldWrite> toWrite;
Optional<Timeout> timeout; Optional<Timer> timer;
Fd timerFd; Fd timerFd;
std::shared_ptr<Handler> handler_; std::shared_ptr<Handler> handler_;
...@@ -101,6 +103,8 @@ private: ...@@ -101,6 +103,8 @@ private:
Async::Promise<ssize_t> asyncWrite(Fd fd, const void *buf, size_t len); Async::Promise<ssize_t> asyncWrite(Fd fd, const void *buf, size_t len);
void handlePeerDisconnection(const std::shared_ptr<Peer>& peer);
void handleIncoming(const std::shared_ptr<Peer>& peer); void handleIncoming(const std::shared_ptr<Peer>& peer);
void handleTimeout(); void handleTimeout();
void run(); void run();
......
...@@ -195,7 +195,7 @@ namespace Polling { ...@@ -195,7 +195,7 @@ namespace Polling {
Event event(tag); Event event(tag);
event.flags = toNotifyOn(ev->events); event.flags = toNotifyOn(ev->events);
events.push_back(tag); events.push_back(event);
} }
} }
...@@ -212,6 +212,8 @@ namespace Polling { ...@@ -212,6 +212,8 @@ namespace Polling {
events |= EPOLLOUT; events |= EPOLLOUT;
if (interest.hasFlag(NotifyOn::Hangup)) if (interest.hasFlag(NotifyOn::Hangup))
events |= EPOLLHUP; events |= EPOLLHUP;
if (interest.hasFlag(NotifyOn::Shutdown))
events |= EPOLLRDHUP;
return events; return events;
} }
...@@ -226,6 +228,9 @@ namespace Polling { ...@@ -226,6 +228,9 @@ namespace Polling {
flags.setFlag(NotifyOn::Write); flags.setFlag(NotifyOn::Write);
if (events & EPOLLHUP) if (events & EPOLLHUP)
flags.setFlag(NotifyOn::Hangup); flags.setFlag(NotifyOn::Hangup);
if (events & EPOLLRDHUP) {
flags.setFlag(NotifyOn::Shutdown);
}
return flags; return flags;
} }
......
...@@ -52,9 +52,12 @@ enum class Mode { ...@@ -52,9 +52,12 @@ enum class Mode {
}; };
enum class NotifyOn { enum class NotifyOn {
Read, None = 0,
Write,
Hangup Read = 1,
Write = Read << 1,
Hangup = Write << 1,
Shutdown = Hangup << 1
}; };
DECLARE_FLAGS_OPERATORS(NotifyOn); DECLARE_FLAGS_OPERATORS(NotifyOn);
...@@ -79,7 +82,7 @@ inline constexpr bool operator==(Tag lhs, Tag rhs) { ...@@ -79,7 +82,7 @@ inline constexpr bool operator==(Tag lhs, Tag rhs) {
} }
struct Event { struct Event {
Event(Tag tag) : explicit Event(Tag tag) :
tag(tag) tag(tag)
{ } { }
......
...@@ -52,10 +52,6 @@ public: ...@@ -52,10 +52,6 @@ public:
Async::Promise<ssize_t> send(const void* buf, size_t len); Async::Promise<ssize_t> send(const void* buf, size_t len);
IoWorker *io() {
return io_;
}
private: private:
IoWorker* io_; IoWorker* io_;
......
...@@ -28,8 +28,12 @@ enum class Options : uint64_t { ...@@ -28,8 +28,12 @@ enum class Options : uint64_t {
DECLARE_FLAGS_OPERATORS(Options) DECLARE_FLAGS_OPERATORS(Options)
class IoWorker;
class Handler { class Handler {
public: public:
friend class IoWorker;
Handler(); Handler();
~Handler(); ~Handler();
...@@ -38,6 +42,11 @@ public: ...@@ -38,6 +42,11 @@ public:
virtual void onConnection(const std::shared_ptr<Tcp::Peer>& peer); virtual void onConnection(const std::shared_ptr<Tcp::Peer>& peer);
virtual void onDisconnection(const std::shared_ptr<Tcp::Peer>& peer); virtual void onDisconnection(const std::shared_ptr<Tcp::Peer>& peer);
protected:
IoWorker *io() { return io_; }
private:
IoWorker *io_;
}; };
} // namespace Tcp } // namespace Tcp
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment