Commit cb274c8e authored by octal

Improved the client architecture and fixed a data-race in async

parent e703fc45
......@@ -15,28 +15,32 @@ int main() {
Http::Experimental::Client client("http://supnetwork.org:9080");
auto opts = Http::Experimental::Client::options()
.threads(1)
.maxConnections(20);
.maxConnections(64);
using namespace Net::Http;
constexpr size_t Requests = 5000;
constexpr size_t Requests = 10000;
std::atomic<int> responsesReceived(0);
client.init(opts);
for (int i = 0; i < Requests; ++i) {
client.get(client
.request("/ping")
.cookie(Cookie("FOO", "bar")), std::chrono::milliseconds(1000))
.cookie(Cookie("FOO", "bar")))
.then([&](const Http::Response& response) {
responsesReceived.fetch_add(1);
//std::cout << "code = " << response.code() << std::endl;
// std::cout << "body = " << response.body() << std::endl;
//std::cout << "body = " << response.body() << std::endl;
}, Async::NoExcept);
const auto count = i + 1;
if (count % 10 == 0)
std::cout << "Sent " << count << " requests" << std::endl;
}
std::cout << "Sent " << Requests << " requests" << std::endl;
for (;;) {
std::this_thread::sleep_for(std::chrono::seconds(1));
std::cout << "Received " << responsesReceived.load() << " responses" << std::endl;
auto count = responsesReceived.load();
std::cout << "Received " << count << " responses" << std::endl;
if (count == Requests) break;
}
client.shutdown();
}
......@@ -12,6 +12,7 @@
#include <memory>
#include <atomic>
#include <vector>
#include <mutex>
#include "optional.h"
#include "typeid.h"
......@@ -170,6 +171,19 @@ namespace Async {
State state;
std::exception_ptr exc;
/*
* We need this lock because a Promise might be resolved or rejected from a thread A
* while a continuation to the same Promise (Core) is being attached at the same time
* from a thread B. In that case, we need to serialize both operations to avoid a
* race condition.
*
* Since we use a lock, progress is blocking, but I don't expect this to be a major
* bottleneck as I don't expect much contention on the lock.
* If it ends up being a bottleneck, try improving it by experimenting with a lock-free
* scheme
*/
std::mutex mtx;
std::vector<std::shared_ptr<Request>> requests;
TypeId id;
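To make the race concrete, here is a minimal sketch (illustrative only, not part of the commit) using just the Async API exercised elsewhere in this diff: thread A resolves the promise while thread B attaches a continuation to the same Core, which is exactly the pair of operations the new mutex serializes.

#include "async.h"
#include <memory>
#include <thread>

void race_sketch() {
    std::unique_ptr<Async::Resolver> res;
    Async::Promise<int> promise([&](Async::Resolver& resolve, Async::Rejection&) {
        // Stash the resolver so another thread can resolve the promise later
        res.reset(new Async::Resolver(std::move(resolve)));
    });

    // Thread A: resolve the promise (writes the Core state and walks Core::requests)
    std::thread resolver([&]() { (*res)(42); });

    // Thread B (this thread): attach a continuation (reads the state and appends to
    // Core::requests). Without Core::mtx, both operations touch the Core concurrently.
    promise.then([](int) { }, Async::NoExcept);

    resolver.join();
}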
......@@ -536,10 +550,13 @@ namespace Async {
* than runtime. However, since types are erased, this looks like
* a difficult task
*/
if (core_->isVoid())
if (core_->isVoid()) {
throw Error("Attempt to resolve a void promise with arguments");
}
std::unique_lock<std::mutex> guard(core_->mtx);
core_->construct<Type>(std::forward<Arg>(arg));
for (const auto& req: core_->requests) {
req->resolve(core_);
}
......@@ -554,6 +571,7 @@ namespace Async {
if (!core_->isVoid())
throw Error("Attempt ro resolve a non-void promise with no argument");
std::unique_lock<std::mutex> guard(core_->mtx);
core_->state = State::Fulfilled;
for (const auto& req: core_->requests) {
req->resolve(core_);
......@@ -578,6 +596,7 @@ namespace Async {
if (core_->state != State::Pending)
throw Error("Attempt to reject a fulfilled promise");
std::unique_lock<std::mutex> guard(core_->mtx);
core_->exc = std::make_exception_ptr(exc);
core_->state = State::Rejected;
for (const auto& req: core_->requests) {
......@@ -716,6 +735,7 @@ namespace Async {
std::forward<ResolveFunc>(resolveFunc),
std::forward<RejectFunc>(rejectFunc)));
std::unique_lock<std::mutex> guard(core_->mtx);
if (isFulfilled()) {
req->resolve(core_);
}
......@@ -723,9 +743,9 @@ namespace Async {
req->reject(core_);
}
core_->requests.push_back(req);
return promise;
}
private:
......
......@@ -24,12 +24,21 @@ namespace Experimental {
class ConnectionPool;
class Transport;
struct Connection {
struct Connection : public std::enable_shared_from_this<Connection> {
friend class ConnectionPool;
typedef std::function<void()> OnDone;
Connection()
: fd(-1)
, connectionState_(NotConnected)
, inflightCount(0)
, responsesReceived(0)
{
state_.store(static_cast<uint32_t>(State::Idle));
}
struct RequestData {
RequestData(
......@@ -48,9 +57,9 @@ struct Connection {
Async::Resolver resolve;
Async::Rejection reject;
Http::Request request;
std::string host;
std::chrono::milliseconds timeout;
Http::Request request;
OnDone onDone;
};
......@@ -65,13 +74,6 @@ struct Connection {
Connected
};
Connection()
: fd(-1)
, connectionState_(NotConnected)
{
state_.store(static_cast<uint32_t>(State::Idle));
}
void connect(Net::Address addr);
void close();
bool isConnected() const;
......@@ -94,15 +96,45 @@ struct Connection {
Fd fd;
void handleResponsePacket(const char* buffer, size_t totalBytes);
void handleTimeout();
std::string dump() const;
private:
std::atomic<int> inflightCount;
std::atomic<int> responsesReceived;
struct sockaddr_in saddr;
void processRequestQueue();
struct RequestEntry {
RequestEntry(
Async::Resolver resolve, Async::Rejection reject,
std::shared_ptr<TimerPool::Entry> timer,
OnDone onDone)
: resolve(std::move(resolve))
, reject(std::move(reject))
, timer(std::move(timer))
, onDone(std::move(onDone))
{ }
Async::Resolver resolve;
Async::Rejection reject;
std::shared_ptr<TimerPool::Entry> timer;
OnDone onDone;
};
std::atomic<uint32_t> state_;
ConnectionState connectionState_;
std::shared_ptr<Transport> transport_;
Queue<RequestData> requestsQueue;
std::deque<RequestEntry> inflightRequests;
Net::TimerPool timerPool_;
Private::Parser<Http::Response> parser_;
};
struct ConnectionPool {
......@@ -129,15 +161,12 @@ public:
void registerPoller(Polling::Epoll& poller);
Async::Promise<void>
asyncConnect(Fd fd, const struct sockaddr* address, socklen_t addr_len);
asyncConnect(const std::shared_ptr<Connection>& connection, const struct sockaddr* address, socklen_t addr_len);
void asyncSendRequest(
Fd fd,
Async::Promise<ssize_t> asyncSendRequest(
const std::shared_ptr<Connection>& connection,
std::shared_ptr<TimerPool::Entry> timer,
const Buffer& buffer,
Async::Resolver resolve,
Async::Rejection reject,
OnResponseParsed onParsed);
const Buffer& buffer);
private:
......@@ -146,79 +175,61 @@ private:
Retry
};
struct PendingConnection {
PendingConnection(
struct ConnectionEntry {
ConnectionEntry(
Async::Resolver resolve, Async::Rejection reject,
Fd fd, const struct sockaddr* addr, socklen_t addr_len)
std::shared_ptr<Connection> connection, const struct sockaddr* addr, socklen_t addr_len)
: resolve(std::move(resolve))
, reject(std::move(reject))
, fd(fd)
, connection(std::move(connection))
, addr(addr)
, addr_len(addr_len)
{ }
Async::Resolver resolve;
Async::Rejection reject;
Fd fd;
std::shared_ptr<Connection> connection;
const struct sockaddr* addr;
socklen_t addr_len;
};
struct InflightRequest {
InflightRequest(
struct RequestEntry {
RequestEntry(
Async::Resolver resolve, Async::Rejection reject,
Fd fd,
std::shared_ptr<Connection> connection,
std::shared_ptr<TimerPool::Entry> timer,
const Buffer& buffer,
OnResponseParsed onParsed = nullptr)
: resolve_(std::move(resolve))
const Buffer& buffer)
: resolve(std::move(resolve))
, reject(std::move(reject))
, fd(fd)
, connection(std::move(connection))
, timer(std::move(timer))
, buffer(buffer)
, onParsed(onParsed)
{
}
void feed(const char* buffer, size_t totalBytes) {
if (!parser)
parser.reset(new Private::Parser<Http::Response>());
parser->feed(buffer, totalBytes);
}
void resolve(Http::Response response) {
if (onParsed)
onParsed();
resolve_(std::move(response));
}
Async::Resolver resolve_;
Async::Resolver resolve;
Async::Rejection reject;
Fd fd;
std::shared_ptr<Connection> connection;
std::shared_ptr<TimerPool::Entry> timer;
Buffer buffer;
OnResponseParsed onParsed;
std::shared_ptr<Private::Parser<Http::Response>> parser;
};
PollableQueue<InflightRequest> requestsQueue;
PollableQueue<PendingConnection> connectionsQueue;
PollableQueue<RequestEntry> requestsQueue;
PollableQueue<ConnectionEntry> connectionsQueue;
std::unordered_map<Fd, PendingConnection> pendingConnections;
std::unordered_map<Fd, std::deque<InflightRequest>> inflightRequests;
std::unordered_map<Fd, Fd> timeouts;
std::unordered_map<Fd, ConnectionEntry> connections;
std::unordered_map<Fd, RequestEntry> requests;
std::unordered_map<Fd, std::shared_ptr<Connection>> timeouts;
void asyncSendRequestImpl(InflightRequest& req, WriteStatus status = FirstTry);
void asyncSendRequestImpl(const RequestEntry& req, WriteStatus status = FirstTry);
void handleRequestsQueue();
void handleConnectionQueue();
void handleIncoming(Fd fd);
void handleResponsePacket(Fd fd, const char* buffer, size_t totalBytes);
void handleTimeout(Fd fd);
void handleIncoming(const std::shared_ptr<Connection>& connection);
void handleResponsePacket(const std::shared_ptr<Connection>& connection, const char* buffer, size_t totalBytes);
void handleTimeout(const std::shared_ptr<Connection>& connection);
};
......@@ -274,6 +285,7 @@ public:
void shutdown();
private:
Io::ServiceGroup io_;
std::string url_;
std::string host_;
......
......@@ -16,6 +16,7 @@
#include <atomic>
#include <unistd.h>
#include "os.h"
#include "io.h"
namespace Net {
......@@ -35,6 +36,7 @@ public:
Entry()
: fd(-1)
, registered(false)
{
state.store(static_cast<uint32_t>(State::Idle));
}
......@@ -53,11 +55,19 @@ public:
}
void disarm();
void registerIo(Io::Service* io) {
if (!registered) {
io->registerFd(fd, Polling::NotifyOn::Read);
}
registered = true;
}
private:
void armMs(std::chrono::milliseconds value);
enum class State : uint32_t { Idle, Used };
std::atomic<uint32_t> state;
bool registered;
};
std::shared_ptr<Entry> pickTimer();
......
......@@ -19,6 +19,16 @@ namespace Experimental {
static constexpr const char* UA = "pistache/0.1";
struct ExceptionPrinter {
void operator()(std::exception_ptr exc) {
try {
std::rethrow_exception(exc);
} catch (const std::exception& e) {
std::cout << "Got exception: " << e.what() << std::endl;
}
}
};
namespace {
#define OUT(...) \
do { \
......@@ -98,13 +108,13 @@ Transport::onReady(const Io::FdSet& fds) {
else if (entry.isReadable()) {
auto tag = entry.getTag();
auto fd = tag.value();
auto reqIt = inflightRequests.find(fd);
if (reqIt != std::end(inflightRequests))
handleIncoming(fd);
auto reqIt = connections.find(fd);
if (reqIt != std::end(connections))
handleIncoming(reqIt->second.connection);
else {
auto timerIt = timeouts.find(fd);
if (timerIt != std::end(timeouts))
handleTimeout(fd);
handleTimeout(timerIt->second);
else {
throw std::runtime_error("Unknown fd");
}
......@@ -114,28 +124,15 @@ Transport::onReady(const Io::FdSet& fds) {
auto tag = entry.getTag();
auto fd = tag.value();
auto connIt = pendingConnections.find(fd);
if (connIt != std::end(pendingConnections)) {
auto connIt = connections.find(fd);
if (connIt != std::end(connections)) {
auto& conn = connIt->second;
conn.resolve();
pendingConnections.erase(fd);
// We are connected, we can start reading data now
io()->modifyFd(conn.connection->fd, NotifyOn::Read);
continue;
}
#if 0
auto writeIt = toWrite.find(fd);
if (writeIt != std::end(toWrite)) {
/* @Bug: should not need modifyFd, investigate why I can't use
* registerFd
*/
io()->modifyFd(fd, NotifyOn::Read, Polling::Mode::Edge);
auto& write = writeIt->second;
asyncWriteImpl(fd, write, Retry);
continue;
}
#endif
throw std::runtime_error("Unknown fd");
}
}
......@@ -148,40 +145,38 @@ Transport::registerPoller(Polling::Epoll& poller) {
}
Async::Promise<void>
Transport::asyncConnect(Fd fd, const struct sockaddr* address, socklen_t addr_len)
Transport::asyncConnect(const std::shared_ptr<Connection>& connection, const struct sockaddr* address, socklen_t addr_len)
{
return Async::Promise<void>([=](Async::Resolver& resolve, Async::Rejection& reject) {
PendingConnection conn(std::move(resolve), std::move(reject), fd, address, addr_len);
auto *entry = connectionsQueue.allocEntry(std::move(conn));
connectionsQueue.push(entry);
ConnectionEntry entry(std::move(resolve), std::move(reject), connection, address, addr_len);
auto *e = connectionsQueue.allocEntry(std::move(entry));
connectionsQueue.push(e);
});
}
void
Async::Promise<ssize_t>
Transport::asyncSendRequest(
Fd fd,
const std::shared_ptr<Connection>& connection,
std::shared_ptr<TimerPool::Entry> timer,
const Buffer& buffer,
Async::Resolver resolve,
Async::Rejection reject,
OnResponseParsed onParsed) {
const Buffer& buffer) {
if (std::this_thread::get_id() != io()->thread()) {
InflightRequest req(std::move(resolve), std::move(reject), fd, std::move(timer), buffer.detach(), std::move(onParsed));
auto detached = buffer.detach();
auto *e = requestsQueue.allocEntry(std::move(req));
requestsQueue.push(e);
} else {
InflightRequest req(std::move(resolve), std::move(reject), fd, std::move(timer), buffer, std::move(onParsed));
return Async::Promise<ssize_t>([&](Async::Resolver& resolve, Async::Rejection& reject) {
if (std::this_thread::get_id() != io()->thread()) {
RequestEntry req(std::move(resolve), std::move(reject), connection, std::move(timer), buffer.detach());
auto *e = requestsQueue.allocEntry(std::move(req));
requestsQueue.push(e);
} else {
RequestEntry req(std::move(resolve), std::move(reject), connection, std::move(timer), buffer);
asyncSendRequestImpl(req);
}
asyncSendRequestImpl(req);
}
});
}
void
Transport::asyncSendRequestImpl(
InflightRequest& req, WriteStatus status)
const RequestEntry& req, WriteStatus status)
{
auto buffer = req.buffer;
......@@ -189,7 +184,9 @@ Transport::asyncSendRequestImpl(
if (buffer.isOwned) delete[] buffer.data;
};
auto fd = req.fd;
auto conn = req.connection;
auto fd = conn->fd;
ssize_t totalWritten = 0;
for (;;) {
......@@ -202,7 +199,7 @@ Transport::asyncSendRequestImpl(
if (status == FirstTry) {
throw std::runtime_error("Unimplemented, fix me!");
}
io()->modifyFd(fd, NotifyOn::Read | NotifyOn::Write, Polling::Mode::Edge);
io()->modifyFd(fd, NotifyOn::Write, Polling::Mode::Edge);
}
else {
cleanUp();
......@@ -214,14 +211,12 @@ Transport::asyncSendRequestImpl(
totalWritten += bytesWritten;
if (totalWritten == len) {
cleanUp();
auto& queue = inflightRequests[fd];
auto timer = req.timer;
if (timer) {
auto timerFd = timer->fd;
timeouts[timerFd] = fd;
io()->registerFd(timerFd, NotifyOn::Read, Polling::Mode::Edge);
if (req.timer) {
timeouts.insert(
std::make_pair(req.timer->fd, conn));
req.timer->registerIo(io());
}
queue.push_back(std::move(req));
req.resolve(totalWritten);
break;
}
}
......@@ -246,23 +241,25 @@ Transport::handleConnectionQueue() {
auto entry = connectionsQueue.popSafe();
if (!entry) break;
auto &conn = entry->data();
int res = ::connect(conn.fd, conn.addr, conn.addr_len);
auto &data = entry->data();
const auto& conn = data.connection;
int res = ::connect(conn->fd, data.addr, data.addr_len);
if (res == -1) {
if (errno == EINPROGRESS) {
io()->registerFdOneShot(conn.fd, NotifyOn::Write);
pendingConnections.insert(
std::make_pair(conn.fd, std::move(conn)));
io()->registerFdOneShot(conn->fd, NotifyOn::Write);
}
else {
conn.reject(Error::system("Failed to connect"));
data.reject(Error::system("Failed to connect"));
continue;
}
}
connections.insert(std::make_pair(conn->fd, std::move(data)));
}
}
void
Transport::handleIncoming(Fd fd) {
Transport::handleIncoming(const std::shared_ptr<Connection>& connection) {
char buffer[Const::MaxBuffer];
memset(buffer, 0, sizeof buffer);
......@@ -271,11 +268,11 @@ Transport::handleIncoming(Fd fd) {
ssize_t bytes;
bytes = recv(fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0);
bytes = recv(connection->fd, buffer + totalBytes, Const::MaxBuffer - totalBytes, 0);
if (bytes == -1) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
if (totalBytes > 0) {
handleResponsePacket(fd, buffer, totalBytes);
handleResponsePacket(connection, buffer, totalBytes);
}
} else {
if (errno == ECONNRESET) {
......@@ -301,35 +298,13 @@ Transport::handleIncoming(Fd fd) {
}
void
Transport::handleResponsePacket(Fd fd, const char* buffer, size_t totalBytes) {
auto it = inflightRequests.find(fd);
if (it == std::end(inflightRequests))
throw std::runtime_error("Received response for a non-inflight request");
auto &queue = it->second;
auto &req = queue.front();
req.feed(buffer, totalBytes);
if (req.parser->parse() == Private::State::Done) {
req.timer->disarm();
req.resolve(std::move(req.parser->response));
queue.pop_front();
}
Transport::handleResponsePacket(const std::shared_ptr<Connection>& connection, const char* buffer, size_t totalBytes) {
connection->handleResponsePacket(buffer, totalBytes);
}
void
Transport::handleTimeout(Fd fd) {
auto timerIt = timeouts.find(fd);
auto reqFd = timerIt->second;
auto reqIt = inflightRequests.find(reqFd);
if (reqIt == std::end(inflightRequests))
throw std::runtime_error("Internal condition violation, received timeout for a non-inflight request");
auto& queue = reqIt->second;
auto& req = queue.front();
req.reject(std::runtime_error("Timeout"));
queue.pop_front();
timeouts.erase(fd);
Transport::handleTimeout(const std::shared_ptr<Connection>& connection) {
connection->handleTimeout();
}
void
......@@ -363,14 +338,15 @@ Connection::connect(Net::Address addr)
make_non_blocking(sfd);
connectionState_ = Connecting;
fd = sfd;
transport_->asyncConnect(sfd, addr->ai_addr, addr->ai_addrlen)
transport_->asyncConnect(shared_from_this(), addr->ai_addr, addr->ai_addrlen)
.then([=]() {
socklen_t len = sizeof(saddr);
getsockname(sfd, (struct sockaddr *)&saddr, &len);
connectionState_ = Connected;
fd = sfd;
transport_->io()->modifyFd(fd, NotifyOn::Read);
processRequestQueue();
}, Async::Throw);
}, ExceptionPrinter());
break;
}
......@@ -379,6 +355,14 @@ Connection::connect(Net::Address addr)
throw std::runtime_error("Failed to connect");
}
std::string
Connection::dump() const {
std::ostringstream oss;
oss << "Connection(fd = " << fd << ", src_port = ";
oss << ntohs(saddr.sin_port) << ")";
return oss.str();
}
bool
Connection::isConnected() const {
return connectionState_ == Connected;
......@@ -402,23 +386,94 @@ Connection::hasTransport() const {
return transport_ != nullptr;
}
void
Connection::handleResponsePacket(const char* buffer, size_t bytes) {
parser_.feed(buffer, bytes);
if (parser_.parse() == Private::State::Done) {
auto req = std::move(inflightRequests.front());
inflightRequests.pop_front();
if (req.timer) {
req.timer->disarm();
timerPool_.releaseTimer(req.timer);
}
req.resolve(std::move(parser_.response));
req.onDone();
}
}
void
Connection::handleTimeout() {
auto req = std::move(inflightRequests.front());
inflightRequests.pop_front();
timerPool_.releaseTimer(req.timer);
req.onDone();
/* @API: create a TimeoutException */
req.reject(std::runtime_error("Timeout"));
}
Async::Promise<Response>
Connection::perform(
const Http::Request& request,
std::string host,
std::chrono::milliseconds timeout,
OnDone onDone) {
Connection::OnDone onDone) {
return Async::Promise<Response>([=](Async::Resolver& resolve, Async::Rejection& reject) {
if (!isConnected()) {
auto* entry = requestsQueue.allocEntry(
RequestData(std::move(resolve), std::move(reject), request, std::move(host), timeout, std::move(onDone)));
requestsQueue.push(entry);
auto* entry = requestsQueue.allocEntry(
RequestData(
std::move(resolve),
std::move(reject),
request,
std::move(host),
timeout,
std::move(onDone)));
requestsQueue.push(entry);
} else {
performImpl(request, std::move(host), timeout, std::move(resolve), std::move(reject), std::move(onDone));
}
});
}
/**
* This class is used to emulate the generalized lambda capture feature from C++14
* whereby a given object can be moved inside a lambda, directly from the capture-list
*
* So instead, it uses the exact same semantics as auto_ptr (don't beat me for that),
* meaning that it moves the value on copy
*/
template<typename T>
struct MoveOnCopy {
MoveOnCopy(T val)
: val(std::move(val))
{ }
MoveOnCopy(MoveOnCopy& other)
: val(std::move(other.val))
{ }
MoveOnCopy& operator=(MoveOnCopy& other) {
val = std::move(other.val);
return *this;
}
MoveOnCopy(MoveOnCopy&& other) = default;
MoveOnCopy& operator=(MoveOnCopy&& other) = default;
operator T&&() {
return std::move(val);
}
T val;
};
template<typename T>
MoveOnCopy<T> make_copy_mover(T arg) {
return MoveOnCopy<T>(std::move(arg));
}
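A hypothetical usage sketch (not part of this commit) of the wrapper above: capturing a move-only std::unique_ptr by value in a C++11 lambda through make_copy_mover, which is the same trick applied to resolve/reject in performImpl below.

#include <iostream>
#include <memory>

void move_on_copy_sketch() {
    std::unique_ptr<int> data(new int(42));

    // C++11 has no [data = std::move(data)] capture, so wrap the value: the
    // wrapper's copy constructor moves, letting a by-copy capture take ownership.
    auto mover = make_copy_mover(std::move(data));

    auto fn = [=]() mutable {
        // operator T&&() hands the wrapped value back out as an rvalue
        std::unique_ptr<int> owned(std::move(mover));
        std::cout << *owned << std::endl;
    };
    fn();
}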
void
Connection::performImpl(
const Http::Request& request,
......@@ -426,7 +481,7 @@ Connection::performImpl(
std::chrono::milliseconds timeout,
Async::Resolver resolve,
Async::Rejection reject,
OnDone onDone) {
Connection::OnDone onDone) {
DynamicStreamBuf buf(128);
......@@ -440,7 +495,30 @@ Connection::performImpl(
timer->arm(timeout);
}
transport_->asyncSendRequest(fd, timer, buffer, std::move(resolve), std::move(reject), std::move(onDone));
// Move the resolver and rejecter inside the lambda
auto resolveMover = make_copy_mover(std::move(resolve));
auto rejectMover = make_copy_mover(std::move(reject));
/*
* @Incomplete: currently, if the promise is rejected in asyncSendRequest,
* it will abort the current execution (NoExcept). Instead, it should reject
* the original promise from the request. The problem is that we currently cannot
* do that, since we transferred ownership of the original rejecter to the
* mover so that it can be moved inside the continuation lambda.
*
* Since Resolver and Rejection objects are not copyable by default, we could
* implement a clone() member function so that it is explicit that a copy operation
* originates from the user.
*
* The reason the Resolver and Rejection copy constructors are disabled is to avoid
* double-resolve and double-reject of a promise
*/
transport_->asyncSendRequest(shared_from_this(), timer, buffer).then(
[=](ssize_t bytes) mutable {
inflightRequests.push_back(RequestEntry(std::move(resolveMover), std::move(rejectMover), std::move(timer), std::move(onDone)));
}
, Async::NoExcept);
}
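The clone() idea floated in the @Incomplete note above, illustrated with a standalone toy class rather than Pistache's actual Resolver/Rejection (whose internals this diff does not show): accidental copies stay forbidden to guard against double-reject, while a deliberate copy is spelled out as clone().

#include <iostream>
#include <memory>

// Toy stand-in for a shared promise core (hypothetical, for illustration only).
struct ToyCore { bool rejected = false; };

class ToyRejection {
public:
    explicit ToyRejection(std::shared_ptr<ToyCore> core) : core_(std::move(core)) { }

    ToyRejection(const ToyRejection&) = delete;            // prevents accidental double-reject
    ToyRejection& operator=(const ToyRejection&) = delete;
    ToyRejection(ToyRejection&&) = default;

    // Explicit, user-originated copy: the intent is visible at the call site.
    ToyRejection clone() const { return ToyRejection(core_); }

    void operator()() { core_->rejected = true; }

private:
    std::shared_ptr<ToyCore> core_;
};

int main() {
    auto core = std::make_shared<ToyCore>();
    ToyRejection reject(core);

    ToyRejection kept = reject.clone();       // kept for error forwarding
    ToyRejection moved = std::move(reject);   // original handed off to a continuation elsewhere
    (void)moved;

    kept();                                   // reject through the explicit clone
    std::cout << std::boolalpha << core->rejected << std::endl;   // prints: true
}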
void
......
......@@ -2,6 +2,9 @@
#include "async.h"
#include <thread>
#include <algorithm>
#include <deque>
#include <mutex>
#include <condition_variable>
Async::Promise<int> doAsync(int N)
{
......@@ -291,3 +294,139 @@ TEST(async_test, rethrow_test) {
ASSERT_TRUE(p2.isRejected());
}
template<typename T>
struct MessageQueue {
public:
template<typename U>
void push(U&& arg) {
std::unique_lock<std::mutex> guard(mtx);
q.push_back(std::forward<U>(arg));
cv.notify_one();
}
T pop() {
std::unique_lock<std::mutex> lock(mtx);
cv.wait(lock, [=]() { return !q.empty(); });
T out = std::move(q.front());
q.pop_front();
return out;
}
bool tryPop(T& out, std::chrono::milliseconds timeout) {
std::unique_lock<std::mutex> lock(mtx);
if (!cv.wait_for(lock, timeout, [=]() { return !q.empty(); }))
return false;
out = std::move(q.front());
q.pop_front();
return true;
}
private:
std::deque<T> q;
std::mutex mtx;
std::condition_variable cv;
};
struct Worker {
public:
~Worker() {
thread->join();
}
void start() {
shutdown.store(false);
thread.reset(new std::thread([=]() { run(); }));
}
void stop() {
shutdown.store(true);
}
Async::Promise<int> doWork(int seq) {
return Async::Promise<int>([=](Async::Resolver& resolve, Async::Rejection& reject) {
queue.push(new WorkRequest(std::move(resolve), std::move(reject), seq));
});
}
private:
void run() {
while (!shutdown) {
WorkRequest *request;
if (queue.tryPop(request, std::chrono::milliseconds(200))) {
request->resolve(request->seq);
delete request;
}
}
}
struct WorkRequest {
WorkRequest(Async::Resolver resolve, Async::Rejection reject, int seq)
: resolve(std::move(resolve))
, reject(std::move(reject))
, seq(seq)
{
}
int seq;
Async::Resolver resolve;
Async::Rejection reject;
};
std::atomic<bool> shutdown;
MessageQueue<WorkRequest*> queue;
std::random_device rd;
std::unique_ptr<std::thread> thread;
};
TEST(async_test, stress_multithreaded_test) {
static constexpr size_t OpsPerThread = 100000;
static constexpr size_t Workers = 6;
static constexpr size_t Ops = OpsPerThread * Workers;
std::cout << "Starting stress testing promises, hang on, this test might take some time to complete" << std::endl;
std::cout << "=================================================" << std::endl;
std::cout << "Parameters for the test: " << std::endl;
std::cout << "Workers -> " << Workers << std::endl;
std::cout << "OpsPerThread -> " << OpsPerThread << std::endl;
std::cout << "Total Ops -> " << Ops << std::endl;
std::cout << "=================================================" << std::endl;
std::cout << std::endl << std::endl;
std::vector<std::unique_ptr<Worker>> workers;
for (size_t i = 0; i < Workers; ++i) {
std::unique_ptr<Worker> wrk(new Worker);
wrk->start();
workers.push_back(std::move(wrk));
}
std::vector<Async::Promise<int>> promises;
std::atomic<int> resolved(0);
size_t wrkIndex = 0;
for (size_t i = 0; i < Ops; ++i) {
auto &wrk = workers[wrkIndex];
wrk->doWork(i).then([&](int seq) {
++resolved;
}, Async::NoExcept);
wrkIndex = (wrkIndex + 1) % Workers;
}
for (;;) {
auto r = resolved.load();
std::cout << r << " promises resolved" << std::endl;
if (r == Ops) break;
std::this_thread::sleep_for(std::chrono::milliseconds(500));
}
std::cout << "Stopping worker" << std::endl;
for (auto& wrk: workers) {
wrk->stop();
}
}