Commit cb274c8e authored by octal's avatar octal

Improved the client architecture and fixed a data-race in async

parent e703fc45
...@@ -15,28 +15,32 @@ int main() { ...@@ -15,28 +15,32 @@ int main() {
Http::Experimental::Client client("http://supnetwork.org:9080"); Http::Experimental::Client client("http://supnetwork.org:9080");
auto opts = Http::Experimental::Client::options() auto opts = Http::Experimental::Client::options()
.threads(1) .threads(1)
.maxConnections(20); .maxConnections(64);
using namespace Net::Http; using namespace Net::Http;
constexpr size_t Requests = 5000; constexpr size_t Requests = 10000;
std::atomic<int> responsesReceived(0); std::atomic<int> responsesReceived(0);
client.init(opts); client.init(opts);
for (int i = 0; i < Requests; ++i) { for (int i = 0; i < Requests; ++i) {
client.get(client client.get(client
.request("/ping") .request("/ping")
.cookie(Cookie("FOO", "bar")), std::chrono::milliseconds(1000)) .cookie(Cookie("FOO", "bar")))
.then([&](const Http::Response& response) { .then([&](const Http::Response& response) {
responsesReceived.fetch_add(1); responsesReceived.fetch_add(1);
//std::cout << "code = " << response.code() << std::endl; //std::cout << "code = " << response.code() << std::endl;
// std::cout << "body = " << response.body() << std::endl; //std::cout << "body = " << response.body() << std::endl;
}, Async::NoExcept); }, Async::NoExcept);
const auto count = i + 1;
if (count % 10 == 0)
std::cout << "Sent " << count << " requests" << std::endl;
} }
std::cout << "Sent " << Requests << " requests" << std::endl;
for (;;) { for (;;) {
std::this_thread::sleep_for(std::chrono::seconds(1)); std::this_thread::sleep_for(std::chrono::seconds(1));
std::cout << "Received " << responsesReceived.load() << " responses" << std::endl; auto count = responsesReceived.load();
std::cout << "Received " << count << " responses" << std::endl;
if (count == Requests) break;
} }
client.shutdown(); client.shutdown();
} }
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <memory> #include <memory>
#include <atomic> #include <atomic>
#include <vector> #include <vector>
#include <mutex>
#include "optional.h" #include "optional.h"
#include "typeid.h" #include "typeid.h"
...@@ -170,6 +171,19 @@ namespace Async { ...@@ -170,6 +171,19 @@ namespace Async {
State state; State state;
std::exception_ptr exc; std::exception_ptr exc;
/*
* We need this lock because a Promise might be resolved or rejected from a thread A
* while a continuation to the same Promise (Core) might be attached at the same from
* a thread B. If that's the case, then we need to serialize operations so that we
* avoid a race-condition.
*
* Since we have a lock, we have a blocking progress guarantee but I don't expect this
* to be a major bottleneck as I don't expect major contention on the lock
* If it ends up being a bottlenick, try @improving it by experimenting with a lock-free
* scheme
*/
std::mutex mtx;
std::vector<std::shared_ptr<Request>> requests; std::vector<std::shared_ptr<Request>> requests;
TypeId id; TypeId id;
...@@ -536,10 +550,13 @@ namespace Async { ...@@ -536,10 +550,13 @@ namespace Async {
* than runtime. However, since types are erased, this looks like * than runtime. However, since types are erased, this looks like
* a difficult task * a difficult task
*/ */
if (core_->isVoid()) if (core_->isVoid()) {
throw Error("Attempt to resolve a void promise with arguments"); throw Error("Attempt to resolve a void promise with arguments");
}
std::unique_lock<std::mutex> guard(core_->mtx);
core_->construct<Type>(std::forward<Arg>(arg)); core_->construct<Type>(std::forward<Arg>(arg));
for (const auto& req: core_->requests) { for (const auto& req: core_->requests) {
req->resolve(core_); req->resolve(core_);
} }
...@@ -554,6 +571,7 @@ namespace Async { ...@@ -554,6 +571,7 @@ namespace Async {
if (!core_->isVoid()) if (!core_->isVoid())
throw Error("Attempt ro resolve a non-void promise with no argument"); throw Error("Attempt ro resolve a non-void promise with no argument");
std::unique_lock<std::mutex> guard(core_->mtx);
core_->state = State::Fulfilled; core_->state = State::Fulfilled;
for (const auto& req: core_->requests) { for (const auto& req: core_->requests) {
req->resolve(core_); req->resolve(core_);
...@@ -578,6 +596,7 @@ namespace Async { ...@@ -578,6 +596,7 @@ namespace Async {
if (core_->state != State::Pending) if (core_->state != State::Pending)
throw Error("Attempt to reject a fulfilled promise"); throw Error("Attempt to reject a fulfilled promise");
std::unique_lock<std::mutex> guard(core_->mtx);
core_->exc = std::make_exception_ptr(exc); core_->exc = std::make_exception_ptr(exc);
core_->state = State::Rejected; core_->state = State::Rejected;
for (const auto& req: core_->requests) { for (const auto& req: core_->requests) {
...@@ -716,6 +735,7 @@ namespace Async { ...@@ -716,6 +735,7 @@ namespace Async {
std::forward<ResolveFunc>(resolveFunc), std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc))); std::forward<RejectFunc>(rejectFunc)));
std::unique_lock<std::mutex> guard(core_->mtx);
if (isFulfilled()) { if (isFulfilled()) {
req->resolve(core_); req->resolve(core_);
} }
...@@ -723,9 +743,9 @@ namespace Async { ...@@ -723,9 +743,9 @@ namespace Async {
req->reject(core_); req->reject(core_);
} }
core_->requests.push_back(req); core_->requests.push_back(req);
return promise; return promise;
} }
private: private:
......
...@@ -24,12 +24,21 @@ namespace Experimental { ...@@ -24,12 +24,21 @@ namespace Experimental {
class ConnectionPool; class ConnectionPool;
class Transport; class Transport;
struct Connection { struct Connection : public std::enable_shared_from_this<Connection> {
friend class ConnectionPool; friend class ConnectionPool;
typedef std::function<void()> OnDone; typedef std::function<void()> OnDone;
Connection()
: fd(-1)
, connectionState_(NotConnected)
, inflightCount(0)
, responsesReceived(0)
{
state_.store(static_cast<uint32_t>(State::Idle));
}
struct RequestData { struct RequestData {
RequestData( RequestData(
...@@ -48,9 +57,9 @@ struct Connection { ...@@ -48,9 +57,9 @@ struct Connection {
Async::Resolver resolve; Async::Resolver resolve;
Async::Rejection reject; Async::Rejection reject;
Http::Request request;
std::string host; std::string host;
std::chrono::milliseconds timeout; std::chrono::milliseconds timeout;
Http::Request request;
OnDone onDone; OnDone onDone;
}; };
...@@ -65,13 +74,6 @@ struct Connection { ...@@ -65,13 +74,6 @@ struct Connection {
Connected Connected
}; };
Connection()
: fd(-1)
, connectionState_(NotConnected)
{
state_.store(static_cast<uint32_t>(State::Idle));
}
void connect(Net::Address addr); void connect(Net::Address addr);
void close(); void close();
bool isConnected() const; bool isConnected() const;
...@@ -94,15 +96,45 @@ struct Connection { ...@@ -94,15 +96,45 @@ struct Connection {
Fd fd; Fd fd;
void handleResponsePacket(const char* buffer, size_t totalBytes);
void handleTimeout();
std::string dump() const;
private: private:
std::atomic<int> inflightCount;
std::atomic<int> responsesReceived;
struct sockaddr_in saddr;
void processRequestQueue(); void processRequestQueue();
struct RequestEntry {
RequestEntry(
Async::Resolver resolve, Async::Rejection reject,
std::shared_ptr<TimerPool::Entry> timer,
OnDone onDone)
: resolve(std::move(resolve))
, reject(std::move(reject))
, timer(std::move(timer))
, onDone(std::move(onDone))
{ }
Async::Resolver resolve;
Async::Rejection reject;
std::shared_ptr<TimerPool::Entry> timer;
OnDone onDone;
};
std::atomic<uint32_t> state_; std::atomic<uint32_t> state_;
ConnectionState connectionState_; ConnectionState connectionState_;
std::shared_ptr<Transport> transport_; std::shared_ptr<Transport> transport_;
Queue<RequestData> requestsQueue; Queue<RequestData> requestsQueue;
std::deque<RequestEntry> inflightRequests;
Net::TimerPool timerPool_; Net::TimerPool timerPool_;
Private::Parser<Http::Response> parser_;
}; };
struct ConnectionPool { struct ConnectionPool {
...@@ -129,15 +161,12 @@ public: ...@@ -129,15 +161,12 @@ public:
void registerPoller(Polling::Epoll& poller); void registerPoller(Polling::Epoll& poller);
Async::Promise<void> Async::Promise<void>
asyncConnect(Fd fd, const struct sockaddr* address, socklen_t addr_len); asyncConnect(const std::shared_ptr<Connection>& connection, const struct sockaddr* address, socklen_t addr_len);
void asyncSendRequest( Async::Promise<ssize_t> asyncSendRequest(
Fd fd, const std::shared_ptr<Connection>& connection,
std::shared_ptr<TimerPool::Entry> timer, std::shared_ptr<TimerPool::Entry> timer,
const Buffer& buffer, const Buffer& buffer);
Async::Resolver resolve,
Async::Rejection reject,
OnResponseParsed onParsed);
private: private:
...@@ -146,79 +175,61 @@ private: ...@@ -146,79 +175,61 @@ private:
Retry Retry
}; };
struct PendingConnection { struct ConnectionEntry {
PendingConnection( ConnectionEntry(
Async::Resolver resolve, Async::Rejection reject, Async::Resolver resolve, Async::Rejection reject,
Fd fd, const struct sockaddr* addr, socklen_t addr_len) std::shared_ptr<Connection> connection, const struct sockaddr* addr, socklen_t addr_len)
: resolve(std::move(resolve)) : resolve(std::move(resolve))
, reject(std::move(reject)) , reject(std::move(reject))
, fd(fd) , connection(std::move(connection))
, addr(addr) , addr(addr)
, addr_len(addr_len) , addr_len(addr_len)
{ } { }
Async::Resolver resolve; Async::Resolver resolve;
Async::Rejection reject; Async::Rejection reject;
Fd fd; std::shared_ptr<Connection> connection;
const struct sockaddr* addr; const struct sockaddr* addr;
socklen_t addr_len; socklen_t addr_len;
}; };
struct InflightRequest { struct RequestEntry {
InflightRequest( RequestEntry(
Async::Resolver resolve, Async::Rejection reject, Async::Resolver resolve, Async::Rejection reject,
Fd fd, std::shared_ptr<Connection> connection,
std::shared_ptr<TimerPool::Entry> timer, std::shared_ptr<TimerPool::Entry> timer,
const Buffer& buffer, const Buffer& buffer)
OnResponseParsed onParsed = nullptr) : resolve(std::move(resolve))
: resolve_(std::move(resolve))
, reject(std::move(reject)) , reject(std::move(reject))
, fd(fd) , connection(std::move(connection))
, timer(std::move(timer)) , timer(std::move(timer))
, buffer(buffer) , buffer(buffer)
, onParsed(onParsed)
{ {
} }
void feed(const char* buffer, size_t totalBytes) { Async::Resolver resolve;
if (!parser)
parser.reset(new Private::Parser<Http::Response>());
parser->feed(buffer, totalBytes);
}
void resolve(Http::Response response) {
if (onParsed)
onParsed();
resolve_(std::move(response));
}
Async::Resolver resolve_;
Async::Rejection reject; Async::Rejection reject;
Fd fd; std::shared_ptr<Connection> connection;
std::shared_ptr<TimerPool::Entry> timer; std::shared_ptr<TimerPool::Entry> timer;
Buffer buffer; Buffer buffer;
OnResponseParsed onParsed;
std::shared_ptr<Private::Parser<Http::Response>> parser;
}; };
PollableQueue<InflightRequest> requestsQueue; PollableQueue<RequestEntry> requestsQueue;
PollableQueue<PendingConnection> connectionsQueue; PollableQueue<ConnectionEntry> connectionsQueue;
std::unordered_map<Fd, PendingConnection> pendingConnections; std::unordered_map<Fd, ConnectionEntry> connections;
std::unordered_map<Fd, std::deque<InflightRequest>> inflightRequests; std::unordered_map<Fd, RequestEntry> requests;
std::unordered_map<Fd, Fd> timeouts; std::unordered_map<Fd, std::shared_ptr<Connection>> timeouts;
void asyncSendRequestImpl(InflightRequest& req, WriteStatus status = FirstTry); void asyncSendRequestImpl(const RequestEntry& req, WriteStatus status = FirstTry);
void handleRequestsQueue(); void handleRequestsQueue();
void handleConnectionQueue(); void handleConnectionQueue();
void handleIncoming(Fd fd); void handleIncoming(const std::shared_ptr<Connection>& connection);
void handleResponsePacket(Fd fd, const char* buffer, size_t totalBytes); void handleResponsePacket(const std::shared_ptr<Connection>& connection, const char* buffer, size_t totalBytes);
void handleTimeout(Fd fd); void handleTimeout(const std::shared_ptr<Connection>& connection);
}; };
...@@ -274,6 +285,7 @@ public: ...@@ -274,6 +285,7 @@ public:
void shutdown(); void shutdown();
private: private:
Io::ServiceGroup io_; Io::ServiceGroup io_;
std::string url_; std::string url_;
std::string host_; std::string host_;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <atomic> #include <atomic>
#include <unistd.h> #include <unistd.h>
#include "os.h" #include "os.h"
#include "io.h"
namespace Net { namespace Net {
...@@ -35,6 +36,7 @@ public: ...@@ -35,6 +36,7 @@ public:
Entry() Entry()
: fd(-1) : fd(-1)
, registered(false)
{ {
state.store(static_cast<uint32_t>(State::Idle)); state.store(static_cast<uint32_t>(State::Idle));
} }
...@@ -53,11 +55,19 @@ public: ...@@ -53,11 +55,19 @@ public:
} }
void disarm(); void disarm();
void registerIo(Io::Service* io) {
if (!registered) {
io->registerFd(fd, Polling::NotifyOn::Read);
}
registered = true;
}
private: private:
void armMs(std::chrono::milliseconds value); void armMs(std::chrono::milliseconds value);
enum class State : uint32_t { Idle, Used }; enum class State : uint32_t { Idle, Used };
std::atomic<uint32_t> state; std::atomic<uint32_t> state;
bool registered;
}; };
std::shared_ptr<Entry> pickTimer(); std::shared_ptr<Entry> pickTimer();
......
...@@ -19,6 +19,16 @@ namespace Experimental { ...@@ -19,6 +19,16 @@ namespace Experimental {
static constexpr const char* UA = "pistache/0.1"; static constexpr const char* UA = "pistache/0.1";
struct ExceptionPrinter {
void operator()(std::exception_ptr exc) {
try {
std::rethrow_exception(exc);
} catch (const std::exception& e) {
std::cout << "Got exception: " << e.what() << std::endl;
}
}
};
namespace { namespace {
#define OUT(...) \ #define OUT(...) \
do { \ do { \
...@@ -98,13 +108,13 @@ Transport::onReady(const Io::FdSet& fds) { ...@@ -98,13 +108,13 @@ Transport::onReady(const Io::FdSet& fds) {
else if (entry.isReadable()) { else if (entry.isReadable()) {
auto tag = entry.getTag(); auto tag = entry.getTag();
auto fd = tag.value(); auto fd = tag.value();
auto reqIt = inflightRequests.find(fd); auto reqIt = connections.find(fd);
if (reqIt != std::end(inflightRequests)) if (reqIt != std::end(connections))
handleIncoming(fd); handleIncoming(reqIt->second.connection);
else { else {
auto timerIt = timeouts.find(fd); auto timerIt = timeouts.find(fd);
if (timerIt != std::end(timeouts)) if (timerIt != std::end(timeouts))
handleTimeout(fd); handleTimeout(timerIt->second);
else { else {
throw std::runtime_error("Unknown fd"); throw std::runtime_error("Unknown fd");
} }
...@@ -114,28 +124,15 @@ Transport::onReady(const Io::FdSet& fds) { ...@@ -114,28 +124,15 @@ Transport::onReady(const Io::FdSet& fds) {
auto tag = entry.getTag(); auto tag = entry.getTag();
auto fd = tag.value(); auto fd = tag.value();
auto connIt = pendingConnections.find(fd); auto connIt = connections.find(fd);
if (connIt != std::end(pendingConnections)) { if (connIt != std::end(connections)) {
auto& conn = connIt->second; auto& conn = connIt->second;
conn.resolve(); conn.resolve();
pendingConnections.erase(fd); // We are connected, we can start reading data now
io()->modifyFd(conn.connection->fd, NotifyOn::Read);
continue; continue;
} }
#if 0
auto writeIt = toWrite.find(fd);
if (writeIt != std::end(toWrite)) {
/* @Bug: should not need modifyFd, investigate why I can't use
* registerFd
*/
io()->modifyFd(fd, NotifyOn::Read, Polling::Mode::Edge);
auto& write = writeIt->second;
asyncWriteImpl(fd, write, Retry);
continue;
}
#endif
throw std::runtime_error("Unknown fd"); throw std::runtime_error("Unknown fd");
} }
} }
...@@ -148,40 +145,38 @@ Transport::registerPoller(Polling::Epoll& poller) { ...@@ -148,40 +145,38 @@ Transport::registerPoller(Polling::Epoll& poller) {
} }
Async::Promise<void> Async::Promise<void>
Transport::asyncConnect(Fd fd, const struct sockaddr* address, socklen_t addr_len) Transport::asyncConnect(const std::shared_ptr<Connection>& connection, const struct sockaddr* address, socklen_t addr_len)
{ {
return Async::Promise<void>([=](Async::Resolver& resolve, Async::Rejection& reject) { return Async::Promise<void>([=](Async::Resolver& resolve, Async::Rejection& reject) {
PendingConnection conn(std::move(resolve), std::move(reject), fd, address, addr_len); ConnectionEntry entry(std::move(resolve), std::move(reject), connection, address, addr_len);
auto *entry = connectionsQueue.allocEntry(std::move(conn)); auto *e = connectionsQueue.allocEntry(std::move(entry));
connectionsQueue.push(entry); connectionsQueue.push(e);
}); });
} }
void Async::Promise<ssize_t>
Transport::asyncSendRequest( Transport::asyncSendRequest(
Fd fd, const std::shared_ptr<Connection>& connection,
std::shared_ptr<TimerPool::Entry> timer, std::shared_ptr<TimerPool::Entry> timer,
const Buffer& buffer, const Buffer& buffer) {
Async::Resolver resolve,
Async::Rejection reject,
OnResponseParsed onParsed) {
if (std::this_thread::get_id() != io()->thread()) { return Async::Promise<ssize_t>([&](Async::Resolver& resolve, Async::Rejection& reject) {
InflightRequest req(std::move(resolve), std::move(reject), fd, std::move(timer), buffer.detach(), std::move(onParsed)); if (std::this_thread::get_id() != io()->thread()) {
auto detached = buffer.detach(); RequestEntry req(std::move(resolve), std::move(reject), connection, std::move(timer), buffer.detach());
auto *e = requestsQueue.allocEntry(std::move(req)); auto *e = requestsQueue.allocEntry(std::move(req));
requestsQueue.push(e); requestsQueue.push(e);
} else { } else {
InflightRequest req(std::move(resolve), std::move(reject), fd, std::move(timer), buffer, std::move(onParsed)); RequestEntry req(std::move(resolve), std::move(reject), connection, std::move(timer), buffer);
asyncSendRequestImpl(req); asyncSendRequestImpl(req);
} }
});
} }
void void
Transport::asyncSendRequestImpl( Transport::asyncSendRequestImpl(
InflightRequest& req, WriteStatus status) const RequestEntry& req, WriteStatus status)
{ {
auto buffer = req.buffer; auto buffer = req.buffer;
...@@ -189,7 +184,9 @@ Transport::asyncSendRequestImpl( ...@@ -189,7 +184,9 @@ Transport::asyncSendRequestImpl(
if (buffer.isOwned) delete[] buffer.data; if (buffer.isOwned) delete[] buffer.data;
}; };
auto fd = req.fd; auto conn = req.connection;
auto fd = conn->fd;
ssize_t totalWritten = 0; ssize_t totalWritten = 0;
for (;;) { for (;;) {
...@@ -202,7 +199,7 @@ Transport::asyncSendRequestImpl( ...@@ -202,7 +199,7 @@ Transport::asyncSendRequestImpl(
if (status == FirstTry) { if (status == FirstTry) {
throw std::runtime_error("Unimplemented, fix me!"); throw std::runtime_error("Unimplemented, fix me!");
} }
io()->modifyFd(fd, NotifyOn::Read | NotifyOn::Write, Polling::Mode::Edge); io()->modifyFd(fd, NotifyOn::Write, Polling::Mode::Edge);
} }
else { else {
cleanUp(); cleanUp();
...@@ -214,14 +211,12 @@ Transport::asyncSendRequestImpl( ...@@ -214,14 +211,12 @@ Transport::asyncSendRequestImpl(
totalWritten += bytesWritten; totalWritten += bytesWritten;
if (totalWritten == len) { if (totalWritten == len) {
cleanUp(); cleanUp();
auto& queue = inflightRequests[fd]; if (req.timer) {
auto timer = req.timer; timeouts.insert(
if (timer) { std::make_pair(req.timer->fd, conn));
auto timerFd = timer->fd; req.timer->registerIo(io());
timeouts[timerFd] = fd;
io()->registerFd(timerFd, NotifyOn::Read, Polling::Mode::Edge);
} }
queue.push_back(std::move(req)); req.resolve(totalWritten);
break; break;
} }
} }
...@@ -246,23 +241,25 @@ Transport::handleConnectionQueue() { ...@@ -246,23 +241,25 @@ Transport::handleConnectionQueue() {
auto entry = connectionsQueue.popSafe(); auto entry = connectionsQueue.popSafe();
if (!entry) break; if (!entry) break;
auto &conn = entry->data();
int res = ::connect(conn.fd, conn.addr, conn.addr_len); auto &data = entry->data();
const auto& conn = data.connection;
int res = ::connect(conn->fd, data.addr, data.addr_len);
if (res == -1) { if (res == -1) {
if (errno == EINPROGRESS) { if (errno == EINPROGRESS) {
io()->registerFdOneShot(conn.fd, NotifyOn::Write); io()->registerFdOneShot(conn->fd, NotifyOn::Write);
pendingConnections.insert(
std::make_pair(conn.fd, std::move(conn)));
} }
else { else {
conn.reject(Error::system("Failed to connect")); data.reject(Error::system("Failed to connect"));
continue;
} }
} }
connections.insert(std::make_pair(conn->fd, std::move(data)));
} }
} }
void void
Transport::handleIncoming(Fd fd) { Transport::handleIncoming(const std::shared_ptr<Connection>& connection) {
char buffer[Const::MaxBuffer]; char buffer[Const::MaxBuffer];
memset(buffer, 0, sizeof buffer); memset(buffer, 0, sizeof buffer);
...@@ -271,11 +268,11 @@ Transport::handleIncoming(Fd fd) { ...@@ -271,11 +268,11 @@ Transport::handleIncoming(Fd fd) {
ssize_t bytes; ssize_t bytes;
bytes = recv(fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0); bytes = recv(connection->fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0);
if (bytes == -1) { if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) { if (errno == EAGAIN || errno == EWOULDBLOCK) {
if (totalBytes > 0) { if (totalBytes > 0) {
handleResponsePacket(fd, buffer, totalBytes); handleResponsePacket(connection, buffer, totalBytes);
} }
} else { } else {
if (errno == ECONNRESET) { if (errno == ECONNRESET) {
...@@ -301,35 +298,13 @@ Transport::handleIncoming(Fd fd) { ...@@ -301,35 +298,13 @@ Transport::handleIncoming(Fd fd) {
} }
void void
Transport::handleResponsePacket(Fd fd, const char* buffer, size_t totalBytes) { Transport::handleResponsePacket(const std::shared_ptr<Connection>& connection, const char* buffer, size_t totalBytes) {
auto it = inflightRequests.find(fd); connection->handleResponsePacket(buffer, totalBytes);
if (it == std::end(inflightRequests))
throw std::runtime_error("Received response for a non-inflight request");
auto &queue = it->second;
auto &req = queue.front();
req.feed(buffer, totalBytes);
if (req.parser->parse() == Private::State::Done) {
req.timer->disarm();
req.resolve(std::move(req.parser->response));
queue.pop_front();
}
} }
void void
Transport::handleTimeout(Fd fd) { Transport::handleTimeout(const std::shared_ptr<Connection>& connection) {
auto timerIt = timeouts.find(fd); connection->handleTimeout();
auto reqFd = timerIt->second;
auto reqIt = inflightRequests.find(reqFd);
if (reqIt == std::end(inflightRequests))
throw std::runtime_error("Internal condition violation, received timeout for a non-inflight request");
auto& queue = reqIt->second;
auto& req = queue.front();
req.reject(std::runtime_error("Timeout"));
queue.pop_front();
timeouts.erase(fd);
} }
void void
...@@ -363,14 +338,15 @@ Connection::connect(Net::Address addr) ...@@ -363,14 +338,15 @@ Connection::connect(Net::Address addr)
make_non_blocking(sfd); make_non_blocking(sfd);
connectionState_ = Connecting; connectionState_ = Connecting;
fd = sfd;
transport_->asyncConnect(sfd, addr->ai_addr, addr->ai_addrlen) transport_->asyncConnect(shared_from_this(), addr->ai_addr, addr->ai_addrlen)
.then([=]() { .then([=]() {
socklen_t len = sizeof(saddr);
getsockname(sfd, (struct sockaddr *)&saddr, &len);
connectionState_ = Connected; connectionState_ = Connected;
fd = sfd;
transport_->io()->modifyFd(fd, NotifyOn::Read);
processRequestQueue(); processRequestQueue();
}, Async::Throw); }, ExceptionPrinter());
break; break;
} }
...@@ -379,6 +355,14 @@ Connection::connect(Net::Address addr) ...@@ -379,6 +355,14 @@ Connection::connect(Net::Address addr)
throw std::runtime_error("Failed to connect"); throw std::runtime_error("Failed to connect");
} }
std::string
Connection::dump() const {
std::ostringstream oss;
oss << "Connection(fd = " << fd << ", src_port = ";
oss << ntohs(saddr.sin_port) << ")";
return oss.str();
}
bool bool
Connection::isConnected() const { Connection::isConnected() const {
return connectionState_ == Connected; return connectionState_ == Connected;
...@@ -402,23 +386,94 @@ Connection::hasTransport() const { ...@@ -402,23 +386,94 @@ Connection::hasTransport() const {
return transport_ != nullptr; return transport_ != nullptr;
} }
void
Connection::handleResponsePacket(const char* buffer, size_t bytes) {
parser_.feed(buffer, bytes);
if (parser_.parse() == Private::State::Done) {
auto req = std::move(inflightRequests.front());
inflightRequests.pop_back();
if (req.timer) {
req.timer->disarm();
timerPool_.releaseTimer(req.timer);
}
req.resolve(std::move(parser_.response));
req.onDone();
}
}
void
Connection::handleTimeout() {
auto req = std::move(inflightRequests.front());
inflightRequests.pop_back();
timerPool_.releaseTimer(req.timer);
req.onDone();
/* @API: create a TimeoutException */
req.reject(std::runtime_error("Timeout"));
}
Async::Promise<Response> Async::Promise<Response>
Connection::perform( Connection::perform(
const Http::Request& request, const Http::Request& request,
std::string host, std::string host,
std::chrono::milliseconds timeout, std::chrono::milliseconds timeout,
OnDone onDone) { Connection::OnDone onDone) {
return Async::Promise<Response>([=](Async::Resolver& resolve, Async::Rejection& reject) { return Async::Promise<Response>([=](Async::Resolver& resolve, Async::Rejection& reject) {
if (!isConnected()) { if (!isConnected()) {
auto* entry = requestsQueue.allocEntry( auto* entry = requestsQueue.allocEntry(
RequestData(std::move(resolve), std::move(reject), request, std::move(host), timeout, std::move(onDone))); RequestData(
requestsQueue.push(entry); std::move(resolve),
std::move(reject),
request,
std::move(host),
timeout,
std::move(onDone)));
requestsQueue.push(entry);
} else { } else {
performImpl(request, std::move(host), timeout, std::move(resolve), std::move(reject), std::move(onDone)); performImpl(request, std::move(host), timeout, std::move(resolve), std::move(reject), std::move(onDone));
} }
}); });
} }
/**
* This class is used to emulate the generalized lambda capture feature from C++14
* whereby a given object can be moved inside a lambda, directly from the capture-list
*
* So instead, it will use the exact same semantic than auto_ptr (don't beat me for that),
* meaning that it will move the value on copy
*/
template<typename T>
struct MoveOnCopy {
MoveOnCopy(T val)
: val(std::move(val))
{ }
MoveOnCopy(MoveOnCopy& other)
: val(std::move(other.val))
{ }
MoveOnCopy& operator=(MoveOnCopy& other) {
val = std::move(other.val);
}
MoveOnCopy(MoveOnCopy&& other) = default;
MoveOnCopy& operator=(MoveOnCopy&& other) = default;
operator T&&() {
return std::move(val);
}
T val;
};
template<typename T>
MoveOnCopy<T> make_copy_mover(T arg) {
return MoveOnCopy<T>(std::move(arg));
}
void void
Connection::performImpl( Connection::performImpl(
const Http::Request& request, const Http::Request& request,
...@@ -426,7 +481,7 @@ Connection::performImpl( ...@@ -426,7 +481,7 @@ Connection::performImpl(
std::chrono::milliseconds timeout, std::chrono::milliseconds timeout,
Async::Resolver resolve, Async::Resolver resolve,
Async::Rejection reject, Async::Rejection reject,
OnDone onDone) { Connection::OnDone onDone) {
DynamicStreamBuf buf(128); DynamicStreamBuf buf(128);
...@@ -440,7 +495,30 @@ Connection::performImpl( ...@@ -440,7 +495,30 @@ Connection::performImpl(
timer->arm(timeout); timer->arm(timeout);
} }
transport_->asyncSendRequest(fd, timer, buffer, std::move(resolve), std::move(reject), std::move(onDone)); // Move the resolver and rejecter inside the lambda
auto resolveMover = make_copy_mover(std::move(resolve));
auto rejectMover = make_copy_mover(std::move(reject));
/*
* @Incomplete: currently, if the promise is rejected in asyncSendRequest,
* it will abort the current execution (NoExcept). Instead, it should reject
* the original promise from the request. The thing is that we currently can not
* do that since we transfered the ownership of the original rejecter to the
* mover so that it will be moved inside the continuation lambda.
*
* Since Resolver and Rejection objects are not copyable by default, we could
* implement a clone() member function so that it's explicit that a copy operation
* has been originated from the user
*
* The reason why Resolver and Rejection copy constructors is disabled is to avoid
* double-resolve and double-reject of a promise
*/
transport_->asyncSendRequest(shared_from_this(), timer, buffer).then(
[=](ssize_t bytes) mutable {
inflightRequests.push_back(RequestEntry(std::move(resolveMover), std::move(rejectMover), std::move(timer), std::move(onDone)));
}
, Async::NoExcept);
} }
void void
......
...@@ -2,6 +2,9 @@ ...@@ -2,6 +2,9 @@
#include "async.h" #include "async.h"
#include <thread> #include <thread>
#include <algorithm> #include <algorithm>
#include <deque>
#include <mutex>
#include <condition_variable>
Async::Promise<int> doAsync(int N) Async::Promise<int> doAsync(int N)
{ {
...@@ -291,3 +294,139 @@ TEST(async_test, rethrow_test) { ...@@ -291,3 +294,139 @@ TEST(async_test, rethrow_test) {
ASSERT_TRUE(p2.isRejected()); ASSERT_TRUE(p2.isRejected());
} }
template<typename T>
struct MessageQueue {
public:
template<typename U>
void push(U&& arg) {
std::unique_lock<std::mutex> guard(mtx);
q.push_back(std::forward<U>(arg));
cv.notify_one();
}
T pop() {
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [=]() { return !q.empty(); });
T out = std::move(q.front());
q.pop_front();
return out;
}
bool tryPop(T& out, std::chrono::milliseconds timeout) {
std::unique_lock<std::mutex> lock(mtx);
if (!cv.wait_for(lock, timeout, [=]() { return !q.empty(); }))
return false;
out = std::move(q.front());
q.pop_front();
return true;
}
private:
std::deque<T> q;
std::mutex mtx;
std::condition_variable cv;
};
struct Worker {
public:
~Worker() {
thread->join();
}
void start() {
shutdown.store(false);
thread.reset(new std::thread([=]() { run(); }));
}
void stop() {
shutdown.store(true);
}
Async::Promise<int> doWork(int seq) {
return Async::Promise<int>([=](Async::Resolver& resolve, Async::Rejection& reject) {
queue.push(new WorkRequest(std::move(resolve), std::move(reject), seq));
});
}
private:
void run() {
while (!shutdown) {
WorkRequest *request;
if (queue.tryPop(request, std::chrono::milliseconds(200))) {
request->resolve(request->seq);
delete request;
}
}
}
struct WorkRequest {
WorkRequest(Async::Resolver resolve, Async::Rejection reject, int seq)
: resolve(std::move(resolve))
, reject(std::move(reject))
, seq(seq)
{
}
int seq;
Async::Resolver resolve;
Async::Rejection reject;
};
std::atomic<bool> shutdown;
MessageQueue<WorkRequest*> queue;
std::random_device rd;
std::unique_ptr<std::thread> thread;
};
TEST(async_test, stress_multithreaded_test) {
static constexpr size_t OpsPerThread = 100000;
static constexpr size_t Workers = 6;
static constexpr size_t Ops = OpsPerThread * Workers;
std::cout << "Starting stress testing promises, hang on, this test might take some time to complete" << std::endl;
std::cout << "=================================================" << std::endl;
std::cout << "Parameters for the test: " << std::endl;
std::cout << "Workers -> " << Workers << std::endl;
std::cout << "OpsPerThread -> " << OpsPerThread << std::endl;
std::cout << "Total Ops -> " << Ops << std::endl;
std::cout << "=================================================" << std::endl;
std::cout << std::endl << std::endl;
std::vector<std::unique_ptr<Worker>> workers;
for (size_t i = 0; i < Workers; ++i) {
std::unique_ptr<Worker> wrk(new Worker);
wrk->start();
workers.push_back(std::move(wrk));
}
std::vector<Async::Promise<int>> promises;
std::atomic<int> resolved(0);
size_t wrkIndex = 0;
for (size_t i = 0; i < Ops; ++i) {
auto &wrk = workers[wrkIndex];
wrk->doWork(i).then([&](int seq) {
++resolved;
}, Async::NoExcept);
wrkIndex = (wrkIndex + 1) % Workers;
}
for (;;) {
auto r = resolved.load();
std::cout << r << " promises resolved" << std::endl;
if (r == Ops) break;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
std::cout << "Stopping worker" << std::endl;
for (auto& wrk: workers) {
wrk->stop();
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment