Commit c3156ea8 authored by Alan Frindell's avatar Alan Frindell Committed by Facebook GitHub Bot

coro::Socket wraps AsyncTransport

Summary: This allows coro::Socket to be used with non-socket AsyncTransports, notably Fizz.  A subsequent diff renames the class and file, and creates a shim for callers.

Reviewed By: yairgott

Differential Revision: D26610473

fbshipit-source-id: d64597ef0de3c90ab084249b20e85d97b34857a6
parent 2f7a1264
...@@ -114,8 +114,10 @@ Task<std::unique_ptr<Socket>> ServerSocket::accept() { ...@@ -114,8 +114,10 @@ Task<std::unique_ptr<Socket>> ServerSocket::accept() {
if (cb.error) { if (cb.error) {
co_yield co_error(std::move(cb.error)); co_yield co_error(std::move(cb.error));
} }
co_return std::make_unique<Socket>(AsyncSocket::newSocket( co_return std::make_unique<Socket>(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd))); socket_->getEventBase(),
AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
} }
} // namespace coro } // namespace coro
......
...@@ -33,8 +33,8 @@ namespace { ...@@ -33,8 +33,8 @@ namespace {
class CallbackBase { class CallbackBase {
public: public:
explicit CallbackBase(std::shared_ptr<folly::AsyncSocket> socket) explicit CallbackBase(folly::AsyncTransport& transport)
: socket_{std::move(socket)} {} : transport_{transport} {}
virtual ~CallbackBase() noexcept = default; virtual ~CallbackBase() noexcept = default;
...@@ -65,10 +65,10 @@ class CallbackBase { ...@@ -65,10 +65,10 @@ class CallbackBase {
protected: protected:
// we use this to notify the other side of completion // we use this to notify the other side of completion
Baton baton_; Baton baton_;
// needed to modify AsyncSocket state, e.g. cacncel callbacks // needed to modify AsyncTransport state, e.g. cacncel callbacks
const std::shared_ptr<folly::AsyncSocket> socket_; folly::AsyncTransport& transport_;
// to wrap AsyncSocket errors // to wrap AsyncTransport errors
folly::exception_wrapper error_; folly::exception_wrapper error_;
private: private:
...@@ -82,11 +82,11 @@ class CallbackBase { ...@@ -82,11 +82,11 @@ class CallbackBase {
class ConnectCallback : public CallbackBase, class ConnectCallback : public CallbackBase,
public folly::AsyncSocket::ConnectCallback { public folly::AsyncSocket::ConnectCallback {
public: public:
explicit ConnectCallback(std::shared_ptr<folly::AsyncSocket> socket) explicit ConnectCallback(folly::AsyncSocket& socket)
: CallbackBase(std::move(socket)) {} : CallbackBase(socket), socket_(socket) {}
private: private:
void cancel() noexcept override { socket_->cancelConnect(); } void cancel() noexcept override { socket_.cancelConnect(); }
void connectSuccess() noexcept override { post(); } void connectSuccess() noexcept override { post(); }
...@@ -94,14 +94,15 @@ class ConnectCallback : public CallbackBase, ...@@ -94,14 +94,15 @@ class ConnectCallback : public CallbackBase,
error_ = folly::exception_wrapper(ex); error_ = folly::exception_wrapper(ex);
post(); post();
} }
folly::AsyncSocket& socket_;
}; };
// //
// Handle data read for AsyncSocket // Handle data read for AsyncTransport
// //
class ReadCallback : public CallbackBase, class ReadCallback : public CallbackBase,
public folly::AsyncSocket::ReadCallback, public folly::AsyncTransport::ReadCallback,
public folly::HHWheelTimer::Callback { public folly::HHWheelTimer::Callback {
public: public:
// we need to pass the socket into ReadCallback so we can clear the callback // we need to pass the socket into ReadCallback so we can clear the callback
...@@ -111,27 +112,29 @@ class ReadCallback : public CallbackBase, ...@@ -111,27 +112,29 @@ class ReadCallback : public CallbackBase,
// socket to call readDataAvailable and readEOF in sequence, causing the // socket to call readDataAvailable and readEOF in sequence, causing the
// promise to be fulfilled twice (oops!) // promise to be fulfilled twice (oops!)
ReadCallback( ReadCallback(
std::shared_ptr<folly::AsyncSocket> socket, folly::HHWheelTimer& timer,
folly::AsyncTransport& transport,
folly::MutableByteRange buf, folly::MutableByteRange buf,
std::chrono::milliseconds timeout) std::chrono::milliseconds timeout)
: CallbackBase(socket), buf_{buf} { : CallbackBase(transport), buf_{buf} {
if (timeout.count() > 0) { if (timeout.count() > 0) {
socket->getEventBase()->timer().scheduleTimeout(this, timeout); timer.scheduleTimeout(this, timeout);
} }
} }
ReadCallback( ReadCallback(
std::shared_ptr<folly::AsyncSocket> socket, folly::HHWheelTimer& timer,
folly::AsyncTransport& transport,
folly::IOBufQueue* readBuf, folly::IOBufQueue* readBuf,
size_t minReadSize, size_t minReadSize,
size_t newAllocationSize, size_t newAllocationSize,
std::chrono::milliseconds timeout) std::chrono::milliseconds timeout)
: CallbackBase(socket), : CallbackBase(transport),
readBuf_(readBuf), readBuf_(readBuf),
minReadSize_(minReadSize), minReadSize_(minReadSize),
newAllocationSize_(newAllocationSize) { newAllocationSize_(newAllocationSize) {
if (timeout.count() > 0) { if (timeout.count() > 0) {
socket->getEventBase()->timer().scheduleTimeout(this, timeout); timer.scheduleTimeout(this, timeout);
} }
} }
...@@ -147,7 +150,7 @@ class ReadCallback : public CallbackBase, ...@@ -147,7 +150,7 @@ class ReadCallback : public CallbackBase,
size_t newAllocationSize_{0}; size_t newAllocationSize_{0};
void cancel() noexcept override { void cancel() noexcept override {
socket_->setReadCB(nullptr); transport_.setReadCB(nullptr);
cancelTimeout(); cancelTimeout();
} }
...@@ -176,7 +179,7 @@ class ReadCallback : public CallbackBase, ...@@ -176,7 +179,7 @@ class ReadCallback : public CallbackBase,
if (readBuf_) { if (readBuf_) {
readBuf_->postallocate(len); readBuf_->postallocate(len);
} else if (length == buf_.size()) { } else if (length == buf_.size()) {
socket_->setReadCB(nullptr); transport_.setReadCB(nullptr);
cancelTimeout(); cancelTimeout();
} }
post(); post();
...@@ -185,7 +188,7 @@ class ReadCallback : public CallbackBase, ...@@ -185,7 +188,7 @@ class ReadCallback : public CallbackBase,
void readEOF() noexcept override { void readEOF() noexcept override {
VLOG(5) << "readEOF()"; VLOG(5) << "readEOF()";
// disable callbacks // disable callbacks
socket_->setReadCB(nullptr); transport_.setReadCB(nullptr);
cancelTimeout(); cancelTimeout();
eof = true; eof = true;
post(); post();
...@@ -194,7 +197,7 @@ class ReadCallback : public CallbackBase, ...@@ -194,7 +197,7 @@ class ReadCallback : public CallbackBase,
void readErr(const folly::AsyncSocketException& ex) noexcept override { void readErr(const folly::AsyncSocketException& ex) noexcept override {
VLOG(5) << "readErr()"; VLOG(5) << "readErr()";
// disable callbacks // disable callbacks
socket_->setReadCB(nullptr); transport_.setReadCB(nullptr);
cancelTimeout(); cancelTimeout();
error_ = folly::exception_wrapper(ex); error_ = folly::exception_wrapper(ex);
post(); post();
...@@ -210,7 +213,7 @@ class ReadCallback : public CallbackBase, ...@@ -210,7 +213,7 @@ class ReadCallback : public CallbackBase,
using Error = folly::AsyncSocketException::AsyncSocketExceptionType; using Error = folly::AsyncSocketException::AsyncSocketExceptionType;
// uninstall read callback. it takes another read to bring it back. // uninstall read callback. it takes another read to bring it back.
socket_->setReadCB(nullptr); transport_.setReadCB(nullptr);
// If the timeout fires but this ReadCallback did get some data, ignore it. // If the timeout fires but this ReadCallback did get some data, ignore it.
// post() has already happend from readDataAvailable. // post() has already happend from readDataAvailable.
if (length == 0) { if (length == 0) {
...@@ -222,21 +225,21 @@ class ReadCallback : public CallbackBase, ...@@ -222,21 +225,21 @@ class ReadCallback : public CallbackBase,
}; };
// //
// Handle data write for AsyncSocket // Handle data write for AsyncTransport
// //
class WriteCallback : public CallbackBase, class WriteCallback : public CallbackBase,
public folly::AsyncSocket::WriteCallback { public folly::AsyncTransport::WriteCallback {
public: public:
explicit WriteCallback(std::shared_ptr<folly::AsyncSocket> socket) explicit WriteCallback(folly::AsyncTransport& transport)
: CallbackBase(socket) {} : CallbackBase(transport) {}
~WriteCallback() override = default; ~WriteCallback() override = default;
size_t bytesWritten{0}; size_t bytesWritten{0};
std::optional<folly::AsyncSocketException> error; std::optional<folly::AsyncSocketException> error;
private: private:
void cancel() noexcept override { socket_->closeWithReset(); } void cancel() noexcept override { transport_.closeWithReset(); }
// //
// Methods of WriteCallback // Methods of WriteCallback
// //
...@@ -264,10 +267,10 @@ Task<Socket> Socket::connect( ...@@ -264,10 +267,10 @@ Task<Socket> Socket::connect(
folly::EventBase* evb, folly::EventBase* evb,
const folly::SocketAddress& destAddr, const folly::SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout) { std::chrono::milliseconds connectTimeout) {
std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(evb); auto socket = AsyncSocket::newSocket(evb);
socket->setReadCB(nullptr); socket->setReadCB(nullptr);
ConnectCallback cb{socket}; ConnectCallback cb{*socket};
socket->connect(&cb, destAddr, connectTimeout.count()); socket->connect(&cb, destAddr, connectTimeout.count());
auto waitRet = auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token)); co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
...@@ -277,7 +280,7 @@ Task<Socket> Socket::connect( ...@@ -277,7 +280,7 @@ Task<Socket> Socket::connect(
if (cb.error()) { if (cb.error()) {
co_yield co_error(std::move(cb.error())); co_yield co_error(std::move(cb.error()));
} }
co_return Socket(socket); co_return Socket(evb, std::move(socket));
} }
Task<size_t> Socket::read( Task<size_t> Socket::read(
...@@ -288,8 +291,8 @@ Task<size_t> Socket::read( ...@@ -288,8 +291,8 @@ Task<size_t> Socket::read(
} }
VLOG(5) << "Socket::read(), expecting max len " << buf.size(); VLOG(5) << "Socket::read(), expecting max len " << buf.size();
ReadCallback cb{socket_, buf, timeout}; ReadCallback cb{eventBase_->timer(), *transport_, buf, timeout};
socket_->setReadCB(&cb); transport_->setReadCB(&cb);
auto waitRet = auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token)); co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
...@@ -299,7 +302,7 @@ Task<size_t> Socket::read( ...@@ -299,7 +302,7 @@ Task<size_t> Socket::read(
if (cb.error()) { if (cb.error()) {
co_yield co_error(std::move(cb.error())); co_yield co_error(std::move(cb.error()));
} }
socket_->setReadCB(nullptr); transport_->setReadCB(nullptr);
deferredReadEOF_ = (cb.eof && cb.length > 0); deferredReadEOF_ = (cb.eof && cb.length > 0);
co_return cb.length; co_return cb.length;
} }
...@@ -315,8 +318,14 @@ Task<size_t> Socket::read( ...@@ -315,8 +318,14 @@ Task<size_t> Socket::read(
} }
VLOG(5) << "Socket::read(), expecting minReadSize=" << minReadSize; VLOG(5) << "Socket::read(), expecting minReadSize=" << minReadSize;
ReadCallback cb{socket_, &readBuf, minReadSize, newAllocationSize, timeout}; ReadCallback cb{
socket_->setReadCB(&cb); eventBase_->timer(),
*transport_,
&readBuf,
minReadSize,
newAllocationSize,
timeout};
transport_->setReadCB(&cb);
auto waitRet = auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token)); co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) { if (waitRet.hasException()) {
...@@ -325,7 +334,7 @@ Task<size_t> Socket::read( ...@@ -325,7 +334,7 @@ Task<size_t> Socket::read(
if (cb.error()) { if (cb.error()) {
co_yield co_error(std::move(cb.error())); co_yield co_error(std::move(cb.error()));
} }
socket_->setReadCB(nullptr); transport_->setReadCB(nullptr);
deferredReadEOF_ = (cb.eof && cb.length > 0); deferredReadEOF_ = (cb.eof && cb.length > 0);
co_return cb.length; co_return cb.length;
} }
...@@ -334,9 +343,9 @@ Task<folly::Unit> Socket::write( ...@@ -334,9 +343,9 @@ Task<folly::Unit> Socket::write(
folly::ByteRange buf, folly::ByteRange buf,
std::chrono::milliseconds timeout, std::chrono::milliseconds timeout,
WriteInfo* writeInfo) { WriteInfo* writeInfo) {
socket_->setSendTimeout(timeout.count()); transport_->setSendTimeout(timeout.count());
WriteCallback cb{socket_}; WriteCallback cb{*transport_};
socket_->write(&cb, buf.begin(), buf.size()); transport_->write(&cb, buf.begin(), buf.size());
auto waitRet = auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token)); co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) { if (waitRet.hasException()) {
...@@ -359,10 +368,10 @@ Task<folly::Unit> Socket::write( ...@@ -359,10 +368,10 @@ Task<folly::Unit> Socket::write(
folly::IOBufQueue& ioBufQueue, folly::IOBufQueue& ioBufQueue,
std::chrono::milliseconds timeout, std::chrono::milliseconds timeout,
WriteInfo* writeInfo) { WriteInfo* writeInfo) {
socket_->setSendTimeout(timeout.count()); transport_->setSendTimeout(timeout.count());
WriteCallback cb{socket_}; WriteCallback cb{*transport_};
auto iovec = ioBufQueue.front()->getIov(); auto iovec = ioBufQueue.front()->getIov();
socket_->writev(&cb, iovec.data(), iovec.size()); transport_->writev(&cb, iovec.data(), iovec.size());
auto waitRet = auto waitRet =
co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token)); co_await co_awaitTry(cb.wait(co_await co_current_cancellation_token));
if (waitRet.hasException()) { if (waitRet.hasException()) {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <folly/experimental/coro/Task.h> #include <folly/experimental/coro/Task.h>
#include <folly/io/IOBufQueue.h> #include <folly/io/IOBufQueue.h>
#include <folly/io/async/AsyncSocket.h> #include <folly/io/async/AsyncSocket.h>
#include <folly/io/async/AsyncTimeout.h> #include <folly/io/async/AsyncSocketException.h>
#if FOLLY_HAS_COROUTINES #if FOLLY_HAS_COROUTINES
...@@ -71,11 +71,12 @@ class Transport { ...@@ -71,11 +71,12 @@ class Transport {
class Socket : public Transport { class Socket : public Transport {
public: public:
explicit Socket(std::shared_ptr<AsyncSocket> socket)
: socket_(std::move(socket)) {}
explicit Socket(AsyncSocket::UniquePtr socket) explicit Socket(AsyncSocket::UniquePtr socket)
: socket_(socket.release(), AsyncSocket::Destructor()) {} : eventBase_(socket->getEventBase()), transport_(std::move(socket)) {}
Socket(
folly::EventBase* eventBase, folly::AsyncTransport::UniquePtr transport)
: eventBase_(eventBase), transport_(std::move(transport)) {}
Socket(Socket&&) = default; Socket(Socket&&) = default;
Socket& operator=(Socket&&) = default; Socket& operator=(Socket&&) = default;
...@@ -84,9 +85,7 @@ class Socket : public Transport { ...@@ -84,9 +85,7 @@ class Socket : public Transport {
EventBase* evb, EventBase* evb,
const SocketAddress& destAddr, const SocketAddress& destAddr,
std::chrono::milliseconds connectTimeout); std::chrono::milliseconds connectTimeout);
virtual EventBase* getEventBase() noexcept override { virtual EventBase* getEventBase() noexcept override { return eventBase_; }
return socket_->getEventBase();
}
Task<size_t> read( Task<size_t> read(
MutableByteRange buf, std::chrono::milliseconds timeout) override; MutableByteRange buf, std::chrono::milliseconds timeout) override;
...@@ -105,42 +104,40 @@ class Socket : public Transport { ...@@ -105,42 +104,40 @@ class Socket : public Transport {
std::chrono::milliseconds timeout = std::chrono::milliseconds(0), std::chrono::milliseconds timeout = std::chrono::milliseconds(0),
WriteInfo* writeInfo = nullptr) override; WriteInfo* writeInfo = nullptr) override;
AsyncTransport* getTransport() const override { return transport_.get(); }
SocketAddress getLocalAddress() const noexcept override { SocketAddress getLocalAddress() const noexcept override {
SocketAddress addr; SocketAddress addr;
socket_->getLocalAddress(&addr); transport_->getLocalAddress(&addr);
return addr; return addr;
} }
folly::AsyncTransport* getTransport() const override { return socket_.get(); }
SocketAddress getPeerAddress() const noexcept override { SocketAddress getPeerAddress() const noexcept override {
SocketAddress addr; SocketAddress addr;
socket_->getPeerAddress(&addr); transport_->getPeerAddress(&addr);
return addr; return addr;
} }
void shutdownWrite() noexcept override { void shutdownWrite() noexcept override {
if (socket_) { if (transport_) {
socket_->shutdownWrite(); transport_->shutdownWrite();
} }
} }
void close() noexcept override { void close() noexcept override {
if (socket_) { if (transport_) {
socket_->close(); transport_->close();
} }
} }
void closeWithReset() noexcept override { void closeWithReset() noexcept override {
if (socket_) { if (transport_) {
socket_->closeWithReset(); transport_->closeWithReset();
} }
} }
std::shared_ptr<AsyncSocket> getAsyncSocket() { return socket_; }
const AsyncTransportCertificate* getPeerCertificate() const override { const AsyncTransportCertificate* getPeerCertificate() const override {
return socket_->getPeerCertificate(); return transport_->getPeerCertificate();
} }
private: private:
...@@ -148,7 +145,8 @@ class Socket : public Transport { ...@@ -148,7 +145,8 @@ class Socket : public Transport {
Socket(const Socket&) = delete; Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete; Socket& operator=(const Socket&) = delete;
std::shared_ptr<AsyncSocket> socket_; EventBase* eventBase_;
AsyncTransport::UniquePtr transport_;
bool deferredReadEOF_{false}; bool deferredReadEOF_{false};
}; };
......
...@@ -252,7 +252,9 @@ TEST_F(ServerSocketTest, WriteCancelled) { ...@@ -252,7 +252,9 @@ TEST_F(ServerSocketTest, WriteCancelled) {
run([&]() -> Task<> { run([&]() -> Task<> {
auto cs = co_await connect(); auto cs = co_await connect();
// reduce the send buffer size so the write wouldn't complete immediately // reduce the send buffer size so the write wouldn't complete immediately
EXPECT_EQ(cs.getAsyncSocket()->setSendBufSize(4096), 0); auto asyncSocket = dynamic_cast<folly::AsyncSocket*>(cs.getTransport());
CHECK(asyncSocket);
EXPECT_EQ(asyncSocket->setSendBufSize(4096), 0);
// produces blocking socket // produces blocking socket
auto ss = srv.accept(-1); auto ss = srv.accept(-1);
constexpr auto kBufSize = 65536; constexpr auto kBufSize = 65536;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment