Commit d2e690a6 authored by Yair Gottdenker's avatar Yair Gottdenker Committed by Facebook GitHub Bot

moving CoroSocket to folly/experimental/coro

Summary: This is the second attempt, the first one was D22958650. Decided to do a different diff as some affected files were moved from experimental/afrind/coro/h2proxy to proxygen/facebook/lib/experimental/coro/ which created some confusion while arc pulling

Reviewed By: yfeldblum

Differential Revision: D25432869

fbshipit-source-id: a183898302a79084d890548b9b7ecc4409f501d2
parent be76ab69
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/Baton.h>
#include <folly/io/coro/ServerSocket.h>
using namespace folly::coro;
namespace {
class AcceptCallback : public folly::AsyncServerSocket::AcceptCallback {
public:
explicit AcceptCallback(
Baton& baton, std::shared_ptr<folly::AsyncServerSocket> socket)
: baton_{baton}, socket_(std::move(socket)) {}
~AcceptCallback() override = default;
int acceptFd{-1};
folly::exception_wrapper error;
private:
// to notify the caller of the result
Baton& baton_;
// the server socket
std::shared_ptr<folly::AsyncServerSocket> socket_;
//
// AcceptCallback methods
//
void connectionAccepted(
folly::NetworkSocket fdNetworkSocket,
const folly::SocketAddress& clientAddr) noexcept override {
VLOG(5) << "Connection accepted from: " << clientAddr.describe();
// unregister handlers while in the callback
socket_->pauseAccepting();
socket_->removeAcceptCallback(this, nullptr);
acceptFd = fdNetworkSocket.toFd();
baton_.post();
}
void acceptError(folly::exception_wrapper ex) noexcept override {
VLOG(5) << "acceptError";
// unregister handlers while in the callback
socket_->pauseAccepting();
socket_->removeAcceptCallback(this, nullptr);
error = std::move(ex);
acceptFd = -1;
baton_.post();
}
void acceptStarted() noexcept override { VLOG(5) << "acceptStarted"; }
void acceptStopped() noexcept override { VLOG(5) << "acceptStopped"; }
};
} // namespace
namespace folly {
namespace coro {
ServerSocket::ServerSocket(
std::shared_ptr<AsyncServerSocket> socket,
std::optional<SocketAddress> bindAddr,
uint32_t listenQueueDepth)
: socket_{socket} {
socket_->setReusePortEnabled(true);
if (bindAddr.has_value()) {
VLOG(1) << "ServerSocket binds on IP: " << bindAddr->describe();
socket_->bind(*bindAddr);
} else {
VLOG(1) << "ServerSocket binds on any addr, random port";
socket_->bind(0);
}
socket_->listen(listenQueueDepth);
}
Task<std::unique_ptr<Socket>> ServerSocket::accept() {
VLOG(5) << "accept() called";
co_await folly::coro::co_safe_point;
Baton baton;
AcceptCallback cb(baton, socket_);
socket_->addAcceptCallback(&cb, nullptr);
socket_->startAccepting();
auto cancelToken = co_await folly::coro::co_current_cancellation_token;
CancellationCallback cancellationCallback{cancelToken, [&baton, this] {
this->socket_->stopAccepting();
baton.post();
}};
co_await baton;
co_await folly::coro::co_safe_point;
if (cb.error) {
co_yield co_error(std::move(cb.error));
}
co_return std::make_unique<Socket>(AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
}
} // namespace coro
} // namespace folly
#endif // FOLLY_HAS_COROUTINES
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/ExceptionWrapper.h>
#include <folly/Expected.h>
#include <folly/SocketAddress.h>
#include <folly/io/async/AsyncServerSocket.h>
#include <folly/io/coro/Socket.h>
#include <optional>
namespace folly {
namespace coro {
//
// This server socket will accept connections on the
// same event base as the socket itself
//
class ServerSocket {
public:
ServerSocket(
std::shared_ptr<AsyncServerSocket> socket,
std::optional<SocketAddress> bindAddr,
uint32_t listenQueueDepth);
ServerSocket(ServerSocket&&) = default;
ServerSocket& operator=(ServerSocket&&) = default;
Task<std::unique_ptr<Socket>> accept();
void close() noexcept {
if (socket_) {
socket_->stopAccepting();
}
}
const AsyncServerSocket* getAsyncServerSocket() const {
return socket_.get();
}
private:
// non-copyable
ServerSocket(const ServerSocket&) = delete;
ServerSocket& operator=(const ServerSocket&) = delete;
std::shared_ptr<AsyncServerSocket> socket_;
};
} // namespace coro
} // namespace folly
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/Baton.h>
#include <folly/io/coro/Socket.h>
#include <functional>
using namespace folly::coro;
namespace {
//
// Common base for all callbcaks
//
class CallbackBase {
public:
explicit CallbackBase(std::shared_ptr<folly::AsyncSocket> socket)
: socket_{std::move(socket)} {}
virtual ~CallbackBase() noexcept = default;
folly::exception_wrapper& error() noexcept { return error_; }
void post() noexcept { baton_.post(); }
Task<void> wait() { co_await baton_; }
Task<folly::Unit> wait(folly::CancellationToken cancelToken) {
if (cancelToken.isCancellationRequested()) {
cancel();
co_yield folly::coro::co_cancelled;
}
folly::CancellationCallback cancellationCallback{
cancelToken, [this] {
this->post();
VLOG(5) << "Cancellation was called";
}};
co_await wait();
VLOG(5) << "After baton await";
if (cancelToken.isCancellationRequested()) {
cancel();
co_yield folly::coro::co_cancelled;
}
co_return folly::unit;
}
protected:
// we use this to notify the other side of completion
Baton baton_;
// needed to modify AsyncSocket state, e.g. cacncel callbacks
const std::shared_ptr<folly::AsyncSocket> socket_;
// to wrap AsyncSocket errors
folly::exception_wrapper error_;
private:
virtual void cancel() noexcept = 0;
};
//
// Handle connect for AsyncSocket
//
class ConnectCallback : public CallbackBase,
public folly::AsyncSocket::ConnectCallback {
public:
explicit ConnectCallback(std::shared_ptr<folly::AsyncSocket> socket)
: CallbackBase(std::move(socket)) {}
private:
void cancel() noexcept override { socket_->cancelConnect(); }
void connectSuccess() noexcept override { post(); }
void connectErr(const folly::AsyncSocketException& ex) noexcept override {
error_ = folly::exception_wrapper(ex);
post();
}
};
//
// Handle data read for AsyncSocket
//
class ReadCallback : public CallbackBase,
public folly::AsyncSocket::ReadCallback,
public folly::HHWheelTimer::Callback {
public:
// we need to pass the socket into ReadCallback so we can clear the callback
// pointer in the socket, thus preventing multiple callbacks from happening
// in one run of event loop. This may happen, for example, when one fiber
// writes and immediately closes the socket - this would cause the async
// socket to call readDataAvailable and readEOF in sequence, causing the
// promise to be fulfilled twice (oops!)
ReadCallback(
std::shared_ptr<folly::AsyncSocket> socket,
folly::MutableByteRange buf,
std::chrono::milliseconds timeout)
: CallbackBase(socket), buf_{buf} {
if (timeout.count() > 0) {
socket->getEventBase()->timer().scheduleTimeout(this, timeout);
}
}
ReadCallback(
std::shared_ptr<folly::AsyncSocket> socket,
folly::IOBufQueue* readBuf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout)
: CallbackBase(socket),
readBuf_(readBuf),
minReadSize_(minReadSize),
newAllocationSize_(newAllocationSize) {
if (timeout.count() > 0) {
socket->getEventBase()->timer().scheduleTimeout(this, timeout);
}
}
// how much was read during operation
size_t length{0};
bool eof{false};
private:
// the read buffer we store to hand off to callback - obtained from user
folly::MutableByteRange buf_;
folly::IOBufQueue* readBuf_{nullptr};
size_t minReadSize_{0};
size_t newAllocationSize_{0};
void cancel() noexcept override {
socket_->setReadCB(nullptr);
cancelTimeout();
}
//
// ReadCallback methods
//
// this is called right before readDataAvailable(), always
// in the same sequence
void getReadBuffer(void** buf, size_t* len) override {
if (readBuf_) {
auto rbuf = readBuf_->preallocate(minReadSize_, newAllocationSize_);
*buf = rbuf.first;
*len = rbuf.second;
} else {
VLOG(5) << "getReadBuffer, size: " << buf_.size();
*buf = buf_.begin() + length;
*len = buf_.size() - length;
}
}
// once we get actual data, uninstall callback and clear timeout
void readDataAvailable(size_t len) noexcept override {
VLOG(5) << "readDataAvailable: " << len << " bytes";
length += len;
if (readBuf_) {
readBuf_->postallocate(len);
} else if (length == buf_.size()) {
socket_->setReadCB(nullptr);
cancelTimeout();
}
post();
}
void readEOF() noexcept override {
VLOG(5) << "readEOF()";
// disable callbacks
socket_->setReadCB(nullptr);
cancelTimeout();
eof = true;
post();
}
void readErr(const folly::AsyncSocketException& ex) noexcept override {
VLOG(5) << "readErr()";
// disable callbacks
socket_->setReadCB(nullptr);
cancelTimeout();
error_ = folly::exception_wrapper(ex);
post();
}
//
// AsyncTimeout method
//
void timeoutExpired() noexcept override {
VLOG(5) << "timeoutExpired()";
using Error = folly::AsyncSocketException::AsyncSocketExceptionType;
// uninstall read callback. it takes another read to bring it back.
socket_->setReadCB(nullptr);
// If the timeout fires but this ReadCallback did get some data, ignore it.
// post() has already happend from readDataAvailable.
if (length == 0) {
error_ = folly::exception_wrapper(folly::AsyncSocketException(
Error::TIMED_OUT, "Timed out waiting for data", errno));
post();
}
}
};
//
// Handle data write for AsyncSocket
//
class WriteCallback : public CallbackBase,
public folly::AsyncSocket::WriteCallback {
public:
explicit WriteCallback(std::shared_ptr<folly::AsyncSocket> socket)
: CallbackBase(socket) {}
~WriteCallback() override = default;
size_t bytesWritten{0};
std::optional<folly::AsyncSocketException> error;
private:
void cancel() noexcept override { socket_->closeWithReset(); }
//
// Methods of WriteCallback
//
void writeSuccess() noexcept override {
VLOG(5) << "writeSuccess";
post();
}
void writeErr(
size_t bytes, const folly::AsyncSocketException& ex) noexcept override {
VLOG(5) << "writeErr, wrote " << bytesWritten << " bytes";
bytesWritten = bytes;
error = ex;
post();
}
};
} // namespace
namespace folly {
namespace coro {
Task<Socket> Socket::connect(
folly::EventBase* evb,
const folly::SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout) {
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(evb);
socket->setReadCB(nullptr);
ConnectCallback cb{socket};
socket->connect(&cb, destAddr, connectTimeout.count());
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
co_yield co_error(std::move(waitRet.exception()));
}
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
co_return Socket(socket);
}
Task<size_t> Socket::read(
folly::MutableByteRange buf, std::chrono::milliseconds timeout) {
if (deferredReadEOF_) {
deferredReadEOF_ = false;
co_return 0;
}
VLOG(5) << "Socket::read(), expecting max len " << buf.size();
ReadCallback cb{socket_, buf, timeout};
socket_->setReadCB(&cb);
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
co_yield co_error(std::move(waitRet.exception()));
}
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
socket_->setReadCB(nullptr);
deferredReadEOF_ = (cb.eof && cb.length > 0);
co_return cb.length;
}
Task<size_t> Socket::read(
folly::IOBufQueue& readBuf,
std::size_t minReadSize,
std::size_t newAllocationSize,
std::chrono::milliseconds timeout) {
if (deferredReadEOF_) {
deferredReadEOF_ = false;
co_return 0;
}
VLOG(5) << "Socket::read(), expecting minReadSize=" << minReadSize;
ReadCallback cb{socket_, &readBuf, minReadSize, newAllocationSize, timeout};
socket_->setReadCB(&cb);
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
co_yield co_error(std::move(waitRet.exception()));
}
if (cb.error()) {
co_yield co_error(std::move(cb.error()));
}
socket_->setReadCB(nullptr);
deferredReadEOF_ = (cb.eof && cb.length > 0);
co_return cb.length;
}
Task<folly::Unit> Socket::write(
folly::ByteRange buf,
std::chrono::milliseconds timeout,
WriteInfo* writeInfo) {
socket_->setSendTimeout(timeout.count());
WriteCallback cb{socket_};
socket_->write(&cb, buf.begin(), buf.size());
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
if (writeInfo) {
writeInfo->bytesWritten = cb.bytesWritten;
}
co_yield co_error(std::move(waitRet.exception()));
}
if (cb.error) {
if (writeInfo) {
writeInfo->bytesWritten = cb.bytesWritten;
}
co_yield co_error(std::move(*cb.error));
}
co_return unit;
}
Task<folly::Unit> Socket::write(
folly::IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout,
WriteInfo* writeInfo) {
socket_->setSendTimeout(timeout.count());
WriteCallback cb{socket_};
auto iovec = ioBufQueue.front()->getIov();
socket_->writev(&cb, iovec.data(), iovec.size());
auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) {
if (writeInfo) {
writeInfo->bytesWritten = cb.bytesWritten;
}
co_yield co_error(std::move(waitRet.exception()));
}
if (cb.error) {
if (writeInfo) {
writeInfo->bytesWritten = cb.bytesWritten;
}
co_yield co_error(std::move(*cb.error));
}
co_return unit;
}
} // namespace coro
} // namespace folly
#endif // FOLLY_HAS_COROUTINES
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <folly/Range.h>
#include <folly/SocketAddress.h>
#include <folly/experimental/coro/Task.h>
#include <folly/experimental/coro/Utils.h>
#include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h>
namespace folly {
namespace coro {
class Transport {
public:
using ErrorCode = AsyncSocketException::AsyncSocketExceptionType;
// on write error, report the issue and how many bytes were written
virtual ~Transport() = default;
virtual EventBase* getEventBase() noexcept = 0;
virtual Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) = 0;
Task<size_t> read(
void* buf, size_t buflen, std::chrono::milliseconds timeout) {
return read(MutableByteRange((unsigned char*)buf, buflen), timeout);
}
virtual Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) = 0;
struct WriteInfo {
size_t bytesWritten{0};
};
virtual Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual Task<Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) = 0;
virtual SocketAddress getLocalAddress() const noexcept = 0;
virtual SocketAddress getPeerAddress() const noexcept = 0;
virtual void close() = 0;
virtual void shutdownWrite() = 0;
virtual void closeWithReset() = 0;
virtual folly::AsyncTransport* getTransport() const = 0;
virtual const AsyncTransportCertificate* getPeerCertificate() const = 0;
};
class Socket : public Transport {
public:
explicit Socket(std::shared_ptr<AsyncSocket> socket)
: socket_(std::move(socket)) {}
Socket(Socket&&) = default;
Socket& operator=(Socket&&) = default;
static Task<Socket> connect(
EventBase* evb,
const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout);
virtual EventBase* getEventBase() noexcept override {
return socket_->getEventBase();
}
Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override;
Task<size_t> read(
IOBufQueue& buf,
size_t minReadSize,
size_t newAllocationSize,
std::chrono::milliseconds timeout) override;
Task<Unit> write(
ByteRange buf,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
Task<folly::Unit> write(
IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override;
SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr;
socket_->getLocalAddress(&addr);
return addr;
}
folly::AsyncTransport* getTransport() const override { return socket_.get(); }
SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr;
socket_->getPeerAddress(&addr);
return addr;
}
void shutdownWrite() noexcept override {
if (socket_) {
socket_->shutdownWrite();
}
}
void close() noexcept override {
if (socket_) {
socket_->close();
}
}
void closeWithReset() noexcept override {
if (socket_) {
socket_->closeWithReset();
}
}
std::shared_ptr<AsyncSocket> getAsyncSocket() { return socket_; }
const AsyncTransportCertificate* getPeerCertificate() const override {
return socket_->getPeerCertificate();
}
private:
// non-copyable
Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete;
std::shared_ptr<AsyncSocket> socket_;
bool deferredReadEOF_{false};
};
} // namespace coro
} // namespace folly
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <folly/Portability.h>
#if FOLLY_HAS_COROUTINES
#include <folly/experimental/coro/BlockingWait.h>
#include <folly/experimental/coro/Collect.h>
#include <folly/io/async/test/AsyncSocketTest.h>
#include <folly/io/async/test/ScopedBoundPort.h>
#include <folly/io/coro/ServerSocket.h>
#include <folly/io/coro/Socket.h>
#include <folly/portability/GTest.h>
using namespace std::chrono_literals;
using namespace folly;
using namespace folly::coro;
class SocketTest : public testing::Test {
public:
template <typename F>
void run(F f) {
blockingWait(co_invoke(std::move(f)), &evb);
}
folly::coro::Task<> requestCancellation() {
cancelSource.requestCancellation();
co_return;
}
EventBase evb;
CancellationSource cancelSource;
};
class ServerSocketTest : public SocketTest {
public:
folly::coro::Task<Socket> connect() {
co_return co_await Socket::connect(&evb, srv.getAddress(), 0ms);
}
TestServer srv;
};
TEST_F(SocketTest, ConnectFailure) {
run([&]() -> Task<> {
ScopedBoundPort ph;
auto serverAddr = ph.getAddress();
EXPECT_THROW(
co_await Socket::connect(&evb, serverAddr, 0ms), AsyncSocketException);
});
}
TEST_F(ServerSocketTest, ConnectSuccess) {
run([&]() -> Task<> {
auto cs = co_await connect();
EXPECT_EQ(srv.getAddress(), cs.getPeerAddress());
});
}
TEST_F(ServerSocketTest, ConnectCancelled) {
run([&]() -> Task<> {
co_await folly::coro::collectAll(
// token would be cancelled while waiting on connect
[&]() -> Task<> {
EXPECT_THROW(
co_await co_withCancellation(cancelSource.getToken(), connect()),
OperationCancelled);
}(),
requestCancellation());
// token was cancelled before read was called
EXPECT_THROW(
co_await co_withCancellation(
cancelSource.getToken(),
Socket::connect(&evb, srv.getAddress(), 0ms)),
OperationCancelled);
});
}
TEST_F(ServerSocketTest, SimpleRead) {
run([&]() -> Task<> {
constexpr auto kBufSize = 65536;
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
ss->write(sndBuf.data(), sndBuf.size());
// read using coroutines
std::array<uint8_t, kBufSize> rcvBuf;
auto reader = [&rcvBuf, &cs]() -> Task<Unit> {
int totalBytes{0};
while (totalBytes < kBufSize) {
auto bytesRead = co_await cs.read(
MutableByteRange(
rcvBuf.data() + totalBytes,
(rcvBuf.data() + rcvBuf.size() - totalBytes)),
0ms);
totalBytes += bytesRead;
}
co_return unit;
};
co_await reader();
EXPECT_EQ(0, memcmp(sndBuf.data(), rcvBuf.data(), rcvBuf.size()));
});
}
TEST_F(ServerSocketTest, SimpleIOBufRead) {
run([&]() -> Task<> {
// Exactly fills a buffer mid-loop and triggers deferredReadEOF handling
constexpr auto kBufSize = 55 * 1184;
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
ss->write(sndBuf.data(), sndBuf.size());
ss->close();
// read using coroutines
IOBufQueue rcvBuf(IOBufQueue::cacheChainLength());
int totalBytes{0};
while (totalBytes < kBufSize) {
auto bytesRead = co_await cs.read(rcvBuf, 1000, 1000, 0ms);
totalBytes += bytesRead;
}
auto bytesRead = co_await cs.read(rcvBuf, 1000, 1000, 50ms);
EXPECT_EQ(bytesRead, 0); // closed
auto data = rcvBuf.move();
data->coalesce();
EXPECT_EQ(0, memcmp(sndBuf.data(), data->data(), data->length()));
});
}
TEST_F(ServerSocketTest, ReadCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
auto reader = [&cs]() -> Task<Unit> {
std::array<uint8_t, 1024> rcvBuf;
EXPECT_THROW(
co_await cs.read(
MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
0ms),
OperationCancelled);
co_return unit;
};
co_await co_withCancellation(
cancelSource.getToken(),
folly::coro::collectAll(requestCancellation(), reader()));
// token was cancelled before read was called
co_await co_withCancellation(cancelSource.getToken(), reader());
});
}
TEST_F(ServerSocketTest, ReadTimeout) {
run([&]() -> Task<> {
auto cs = co_await connect();
std::array<uint8_t, 1024> rcvBuf;
EXPECT_THROW(
co_await cs.read(
MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
50ms),
AsyncSocketException);
});
}
TEST_F(ServerSocketTest, ReadError) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
ss->closeWithReset();
std::array<uint8_t, 1024> rcvBuf;
EXPECT_THROW(
co_await cs.read(
MutableByteRange(rcvBuf.data(), (rcvBuf.data() + rcvBuf.size())),
50ms),
AsyncSocketException);
});
}
TEST_F(ServerSocketTest, SimpleWrite) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
constexpr auto kBufSize = 65536;
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
// write use co-routine
co_await cs.write(ByteRange(sndBuf.data(), sndBuf.data() + sndBuf.size()));
// read on server side
std::array<uint8_t, kBufSize> rcvBuf;
ss->readAll(rcvBuf.data(), rcvBuf.size());
EXPECT_EQ(0, memcmp(sndBuf.data(), rcvBuf.data(), rcvBuf.size()));
});
}
TEST_F(ServerSocketTest, SimpleWritev) {
run([&]() -> Task<> {
auto cs = co_await connect();
// produces blocking socket
auto ss = srv.accept(-1);
IOBufQueue sndBuf;
constexpr auto kBufSize = 65536;
std::array<uint8_t, kBufSize> bufA;
std::memset(bufA.data(), 'a', bufA.size());
std::array<uint8_t, kBufSize> bufB;
std::memset(bufB.data(), 'b', bufB.size());
sndBuf.append(bufA.data(), bufA.size());
sndBuf.append(bufB.data(), bufB.size());
// write use co-routine
co_await cs.write(sndBuf);
// read on server side
std::array<uint8_t, kBufSize> rcvBufA;
ss->readAll(rcvBufA.data(), rcvBufA.size());
EXPECT_EQ(0, memcmp(bufA.data(), rcvBufA.data(), rcvBufA.size()));
std::array<uint8_t, kBufSize> rcvBufB;
ss->readAll(rcvBufB.data(), rcvBufB.size());
EXPECT_EQ(0, memcmp(bufB.data(), rcvBufB.data(), rcvBufB.size()));
});
}
TEST_F(ServerSocketTest, WriteCancelled) {
run([&]() -> Task<> {
auto cs = co_await connect();
// reduce the send buffer size so the write wouldn't complete immediately
EXPECT_EQ(cs.getAsyncSocket()->setSendBufSize(4096), 0);
// produces blocking socket
auto ss = srv.accept(-1);
constexpr auto kBufSize = 65536;
std::array<uint8_t, kBufSize> sndBuf;
std::memset(sndBuf.data(), 'a', sndBuf.size());
// write use co-routine
auto writter = [&]() -> Task<> {
EXPECT_THROW(
co_await co_withCancellation(
cancelSource.getToken(),
cs.write(
ByteRange(sndBuf.data(), sndBuf.data() + sndBuf.size()))),
OperationCancelled);
};
co_await folly::coro::collectAll(requestCancellation(), writter());
co_await co_withCancellation(cancelSource.getToken(), writter());
});
}
TEST_F(SocketTest, SimpleAccept) {
run([&]() -> Task<> {
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
auto serverAddr = css.getAsyncServerSocket()->getAddress();
co_await folly::coro::collectAll(
css.accept(), Socket::connect(&evb, serverAddr, 0ms));
});
}
TEST_F(SocketTest, AcceptCancelled) {
run([&]() -> Task<> {
co_await folly::coro::collectAll(requestCancellation(), [&]() -> Task<> {
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
EXPECT_THROW(
co_await co_withCancellation(cancelSource.getToken(), css.accept()),
OperationCancelled);
}());
});
}
TEST_F(SocketTest, AsyncClientAndServer) {
run([&]() -> Task<> {
constexpr int kSize = 128;
ServerSocket css(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
auto serverAddr = css.getAsyncServerSocket()->getAddress();
auto cs = co_await Socket::connect(&evb, serverAddr, 0ms);
co_await folly::coro::collectAll(
[&css]() -> Task<> {
auto sock = co_await css.accept();
std::array<uint8_t, kSize> buf;
memset(buf.data(), 'a', kSize);
co_await sock->write(ByteRange(buf.begin(), buf.end()));
css.close();
}(),
[&cs]() -> Task<> {
std::array<uint8_t, kSize> buf;
// For fun, shutdown the write half -- we don't need it
cs.shutdownWrite();
auto len =
co_await cs.read(MutableByteRange(buf.begin(), buf.end()), 0ms);
cs.close();
EXPECT_TRUE(len == buf.size());
}());
});
}
#endif // FOLLY_HAS_COROUTINES
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment