/*
 * Copyright 2016 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <folly/io/ShutdownSocketSet.h>

#include <atomic>
#include <cerrno>
#include <chrono>
#include <thread>
#include <vector>

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <folly/portability/Sockets.h>

using folly::ShutdownSocketSet;

namespace folly { namespace test {

// Global set shared by every Server in this file: the listening socket and
// each accepted client fd are registered here, and the kill tests shut them
// all down at once via shutdownAll().
ShutdownSocketSet shutdownSocketSet;

/**
 * Minimal TCP server used to exercise ShutdownSocketSet.
 *
 * The constructor binds a listening socket and starts a background thread
 * that accepts connections. The listening socket and every accepted client
 * fd are registered with the global shutdownSocketSet.
 */
class Server {
 public:
  Server();

  /// Asks the accept loop to stop; `abortive` selects how the accepted
  /// client connections are closed once the loop exits.
  void stop(bool abortive);

  /// Waits for the background accept thread to finish.
  void join();

  /// Port the listening socket is bound to (0 once the accept thread exits).
  int port() const {
    return port_;
  }

  /// Closes every accepted client fd and returns how many there were.
  int closeClients(bool abortive);

 private:
  // How the accept loop was asked to stop.
  enum StopMode {
    NO_STOP,
    ORDERLY,
    ABORTIVE,
  };

  int acceptSocket_;            // listening socket fd; -1 once closed
  int port_;                    // bound port, host byte order
  std::atomic<StopMode> stop_;  // set by stop(), polled by the accept thread
  std::thread serverThread_;    // runs the accept loop
  std::vector<int> fds_;        // accepted client fds not yet closed
};

/**
 * Creates a listening TCP socket on a kernel-chosen loopback port, registers
 * it with shutdownSocketSet, and starts the accept thread.
 *
 * The accept thread registers every accepted client fd with
 * shutdownSocketSet and records it in fds_. It leaves the loop when stop()
 * has been called (stop_ != NO_STOP) or when accept() reports the listening
 * socket as broken (EINVAL / ENOTSOCK), e.g. once shutdownSocketSet has shut
 * it down. Only in the stop() case does it close the clients itself; in both
 * cases it then closes the listening socket and resets acceptSocket_/port_.
 */
Server::Server()
  : acceptSocket_(-1),
    port_(0),
    stop_(NO_STOP) {
  acceptSocket_ = socket(PF_INET, SOCK_STREAM, 0);
  CHECK_ERR(acceptSocket_);
  shutdownSocketSet.add(acceptSocket_);

  // Zero the whole struct (sin_zero and any platform-specific fields).
  sockaddr_in addr{};
  addr.sin_family = AF_INET;
  addr.sin_port = 0;  // let the kernel pick a free port
  // Clients only connect via 127.0.0.1, so don't listen on every interface.
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  CHECK_ERR(bind(acceptSocket_,
                 reinterpret_cast<const sockaddr*>(&addr),
                 sizeof(addr)));

  CHECK_ERR(listen(acceptSocket_, 10));

  // Read back the port the kernel assigned.
  socklen_t addrLen = sizeof(addr);
  CHECK_ERR(getsockname(acceptSocket_,
                        reinterpret_cast<sockaddr*>(&addr),
                        &addrLen));

  port_ = ntohs(addr.sin_port);

  serverThread_ = std::thread([this] {
    while (stop_ == NO_STOP) {
      sockaddr_in peer{};
      socklen_t peerLen = sizeof(peer);
      int fd = accept(acceptSocket_,
                      reinterpret_cast<sockaddr*>(&peer),
                      &peerLen);
      if (fd == -1) {
        if (errno == EINTR) {
          continue;
        }
        if (errno == EINVAL || errno == ENOTSOCK) {  // socket broken
          break;
        }
      }
      CHECK_ERR(fd);  // any other accept() failure is fatal for the test
      shutdownSocketSet.add(fd);
      fds_.push_back(fd);
    }

    // Clients are closed here only after an explicit stop(); otherwise they
    // stay open so the caller can still count them with closeClients().
    const StopMode mode = stop_.load();
    if (mode != NO_STOP) {
      closeClients(mode == ABORTIVE);
    }

    shutdownSocketSet.close(acceptSocket_);
    acceptSocket_ = -1;
    port_ = 0;
  });
}

/**
 * Closes, through shutdownSocketSet, every client fd accepted so far and
 * forgets them.
 *
 * @param abortive  if true, first set SO_LINGER with a zero timeout on each
 *                  fd so that closing it resets the connection rather than
 *                  shutting it down in an orderly way.
 * @return the number of fds that were closed.
 */
int Server::closeClients(bool abortive) {
  // Assign fields by name: POSIX only guarantees that struct linger contains
  // l_onoff and l_linger, not their order, so positional init is unportable.
  struct linger l{};
  l.l_onoff = 1;
  l.l_linger = 0;
  for (int fd : fds_) {
    if (abortive) {
      CHECK_ERR(setsockopt(fd, SOL_SOCKET, SO_LINGER, &l, sizeof(l)));
    }
    shutdownSocketSet.close(fd);
  }
  const int n = static_cast<int>(fds_.size());
  fds_.clear();
  return n;
}

void Server::stop(bool abortive) {
  stop_ = abortive ? ABORTIVE : ORDERLY;
  shutdown(acceptSocket_, SHUT_RDWR);
}

// Blocks until the accept thread started by the constructor has exited.
void Server::join() {
  serverThread_.join();
}

/**
 * Opens a TCP socket and connects it to 127.0.0.1:`port`.
 *
 * @param port  TCP port in host byte order.
 * @return the connected socket fd; any failure aborts via CHECK_ERR.
 */
int createConnectedSocket(int port) {
  int sock = socket(PF_INET, SOCK_STREAM, 0);
  CHECK_ERR(sock);
  sockaddr_in addr{};  // zero sin_zero and any platform-specific fields
  addr.sin_family = AF_INET;
  addr.sin_port = htons(port);
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);  // 127.0.0.1
  CHECK_ERR(connect(sock,
                    reinterpret_cast<const sockaddr*>(&addr),
                    sizeof(addr)));
  return sock;
}

/**
 * Connects a client to a fresh Server, stops the server from a helper thread
 * after 200ms, and checks what the client's blocking read() observes.
 *
 * @param abortive  true: expect read() to fail with ECONNRESET;
 *                  false: expect an orderly EOF (read() returns 0).
 */
void runCloseTest(bool abortive) {
  Server server;
  const int clientFd = createConnectedSocket(server.port());

  std::thread stopper([&server, abortive] {
    // Delay so the main thread is (presumably) already blocked in read().
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    server.stop(abortive);
    server.join();
  });

  char byte;
  const int bytesRead = read(clientFd, &byte, 1);
  // Snapshot errno right after read(), before anything can overwrite it.
  const int savedErrno = errno;
  if (!abortive) {
    EXPECT_EQ(0, bytesRead);
  } else {
    EXPECT_EQ(-1, bytesRead);
    EXPECT_EQ(ECONNRESET, savedErrno);
  }

  close(clientFd);
  stopper.join();

  // stop() made the accept thread close its clients before exiting, so
  // nothing should be left to close here.
  EXPECT_EQ(0, server.closeClients(false));
}

// Server::stop(false): the client should see an orderly EOF.
TEST(ShutdownSocketSetTest, OrderlyClose) {
  runCloseTest(false);
}

// Server::stop(true): the client should see ECONNRESET.
TEST(ShutdownSocketSetTest, AbortiveClose) {
  runCloseTest(true);
}

/**
 * Connects a client to a fresh Server, then calls
 * shutdownSocketSet.shutdownAll() from a helper thread after 200ms and checks
 * what the client's blocking read() observes.
 *
 * Unlike runCloseTest, the server is never stop()ped, so its accept thread
 * exits without closing the client fds; they are counted afterwards.
 *
 * @param abortive  passed to shutdownAll(). This is only a hint, so either
 *                  ECONNRESET or an orderly EOF is accepted when true.
 */
void runKillTest(bool abortive) {
  Server server;

  int sock = createConnectedSocket(server.port());

  std::thread killer([&server, abortive] {
    std::this_thread::sleep_for(std::chrono::milliseconds(200));
    shutdownSocketSet.shutdownAll(abortive);
    server.join();
  });

  char c;
  int r = read(sock, &c, 1);
  // Capture errno immediately after read(), as runCloseTest does.
  int e = errno;

  // "abortive" is just a hint for ShutdownSocketSet, so accept both
  // behaviors
  if (abortive) {
    if (r == -1) {
      EXPECT_EQ(ECONNRESET, e);
    } else {
      EXPECT_EQ(0, r);  // (expected, actual) order
    }
  } else {
    EXPECT_EQ(0, r);
  }

  close(sock);

  killer.join();

  // NOT closed by server when it exited
  EXPECT_EQ(1, server.closeClients(false));
}

// shutdownAll(false) while the client is blocked in read(): expect EOF.
TEST(ShutdownSocketSetTest, OrderlyKill) {
  runKillTest(false);
}

// shutdownAll(true): expect ECONNRESET or EOF (abortive is only a hint).
TEST(ShutdownSocketSetTest, AbortiveKill) {
  runKillTest(true);
}

}}  // namespaces
