Unverified Commit 9a65f409 authored by Ian Roddis's avatar Ian Roddis Committed by GitHub

Rest::Router add client disconnect handlers (#852)

* - Adding Router::addDisconnectHandler to handle client disconnections.
- Adding code to call handlers on client disconnection

* - Adding test case for router addDisconnectHandler and associated functionality.

* Changing to 2-space indentation and other formating changes to conform to LLVM style

* More format changes
Co-authored-by: default avatarIan Roddis <gitlab@ie2r.com>
parent ae146fa7
...@@ -67,6 +67,9 @@ struct Route { ...@@ -67,6 +67,9 @@ struct Route {
typedef std::function<bool(Http::Request& req, Http::ResponseWriter& resp)> Middleware; typedef std::function<bool(Http::Request& req, Http::ResponseWriter& resp)> Middleware;
typedef std::function<void(const std::shared_ptr<Tcp::Peer> &peer)>
DisconnectHandler;
explicit Route(Route::Handler handler) : handler_(std::move(handler)) {} explicit Route(Route::Handler handler) : handler_(std::move(handler)) {}
template <typename... Args> void invokeHandler(Args &&... args) const { template <typename... Args> void invokeHandler(Args &&... args) const {
...@@ -223,10 +226,13 @@ public: ...@@ -223,10 +226,13 @@ public:
void addMiddleware(Route::Middleware middleware); void addMiddleware(Route::Middleware middleware);
void addNotFoundHandler(Route::Handler handler); void addNotFoundHandler(Route::Handler handler);
void addDisconnectHandler(Route::DisconnectHandler handler);
inline bool hasNotFoundHandler() { return notFoundHandler != nullptr; } inline bool hasNotFoundHandler() { return notFoundHandler != nullptr; }
void invokeNotFoundHandler(const Http::Request &req, void invokeNotFoundHandler(const Http::Request &req,
Http::ResponseWriter resp) const; Http::ResponseWriter resp) const;
void disconnectPeer(const std::shared_ptr<Tcp::Peer> &peer);
Route::Status route(const Http::Request &request, Route::Status route(const Http::Request &request,
Http::ResponseWriter response); Http::ResponseWriter response);
...@@ -239,6 +245,8 @@ private: ...@@ -239,6 +245,8 @@ private:
std::vector<Route::Middleware> middlewares; std::vector<Route::Middleware> middlewares;
std::vector<Route::DisconnectHandler> disconnectHandlers;
Route::Handler notFoundHandler; Route::Handler notFoundHandler;
}; };
...@@ -266,6 +274,8 @@ public: ...@@ -266,6 +274,8 @@ public:
void onRequest(const Http::Request &req, void onRequest(const Http::Request &req,
Http::ResponseWriter response) override; Http::ResponseWriter response) override;
void onDisconnection(const std::shared_ptr<Tcp::Peer> &peer) override;
private: private:
std::shared_ptr<Rest::Router> router; std::shared_ptr<Rest::Router> router;
}; };
......
...@@ -306,6 +306,10 @@ void RouterHandler::onRequest(const Http::Request &req, ...@@ -306,6 +306,10 @@ void RouterHandler::onRequest(const Http::Request &req,
router->route(req, std::move(response)); router->route(req, std::move(response));
} }
void RouterHandler::onDisconnection(const std::shared_ptr<Tcp::Peer> &peer) {
router->disconnectPeer(peer);
}
} // namespace Private } // namespace Private
Router Router::fromDescription(const Rest::Description &desc) { Router Router::fromDescription(const Rest::Description &desc) {
...@@ -384,6 +388,10 @@ void Router::addMiddleware(Route::Middleware middleware) { ...@@ -384,6 +388,10 @@ void Router::addMiddleware(Route::Middleware middleware) {
middlewares.push_back(std::move(middleware)); middlewares.push_back(std::move(middleware));
} }
void Router::addDisconnectHandler(Route::DisconnectHandler handler) {
disconnectHandlers.push_back(std::move(handler));
}
void Router::addNotFoundHandler(Route::Handler handler) { void Router::addNotFoundHandler(Route::Handler handler) {
notFoundHandler = std::move(handler); notFoundHandler = std::move(handler);
} }
...@@ -480,6 +488,12 @@ void Router::addRoute(Http::Method method, const std::string &resource, ...@@ -480,6 +488,12 @@ void Router::addRoute(Http::Method method, const std::string &resource,
r.addRoute(path, handler, ptr); r.addRoute(path, handler, ptr);
} }
void Router::disconnectPeer(const std::shared_ptr<Tcp::Peer> &peer) {
for (const auto &handler : disconnectHandlers) {
handler(peer);
}
}
namespace Routes { namespace Routes {
void Get(Router &router, const std::string &resource, Route::Handler handler) { void Get(Router &router, const std::string &resource, Route::Handler handler) {
......
...@@ -379,3 +379,67 @@ TEST(segment_tree_node_test, test_resource_sanitize) { ...@@ -379,3 +379,67 @@ TEST(segment_tree_node_test, test_resource_sanitize) {
ASSERT_EQ(SegmentTreeNode::sanitizeResource("/path//to/bar"), "path/to/bar"); ASSERT_EQ(SegmentTreeNode::sanitizeResource("/path//to/bar"), "path/to/bar");
ASSERT_EQ(SegmentTreeNode::sanitizeResource("/path/to///////:place"), "path/to/:place"); ASSERT_EQ(SegmentTreeNode::sanitizeResource("/path/to///////:place"), "path/to/:place");
} }
namespace {
class WaitHelper {
public:
void increment() {
std::lock_guard<std::mutex> lock(counterLock_);
++counter_;
cv_.notify_one();
}
template <typename Duration>
bool wait(const size_t count, const Duration timeout) {
std::unique_lock<std::mutex> lock(counterLock_);
return cv_.wait_for(lock, timeout,
[this, count]() { return counter_ >= count; });
}
private:
size_t counter_ = 0;
std::mutex counterLock_;
std::condition_variable cv_;
};
TEST(router_test, test_client_disconnects) {
Address addr(Ipv4::any(), 0);
auto endpoint = std::make_shared<Http::Endpoint>(addr);
auto opts = Http::Endpoint::options().threads(1).maxRequestSize(4096);
endpoint->init(opts);
int count_found = 0;
WaitHelper count_disconnect;
Rest::Router router;
Routes::Head(router, "/moogle",
[&count_found](const Pistache::Rest::Request &,
Pistache::Http::ResponseWriter response) {
count_found++;
response.send(Pistache::Http::Code::Ok);
return Pistache::Rest::Route::Result::Ok;
});
router.addDisconnectHandler(
[&count_disconnect](const std::shared_ptr<Tcp::Peer> &) {
count_disconnect.increment();
});
endpoint->setHandler(router.handler());
endpoint->serveThreaded();
const auto bound_port = endpoint->getPort();
{
httplib::Client client("localhost", bound_port);
count_found = 0;
client.Head("/moogle");
ASSERT_EQ(count_found, 1);
}
const bool result = count_disconnect.wait(1, std::chrono::seconds(2));
endpoint->shutdown();
ASSERT_EQ(result, 1);
}
} // namespace
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment