diff options
author | Connor Sughrue <55301806+cpsughrue@users.noreply.github.com> | 2024-04-09 23:41:18 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-09 23:41:18 -0400 |
commit | 87e6f87fe7e343eb656e9b49d30cbb065c086651 (patch) | |
tree | 8ee0b61590360583b9a8dd7f2526d856cfdaf2db | |
parent | 289a2c380e47d64a1e626259c53fc8c7d6c2be66 (diff) | |
download | llvm-87e6f87fe7e343eb656e9b49d30cbb065c086651.zip llvm-87e6f87fe7e343eb656e9b49d30cbb065c086651.tar.gz llvm-87e6f87fe7e343eb656e9b49d30cbb065c086651.tar.bz2 |
[llvm][Support] Improvements to ListeningSocket functionality and documentation (#84710)
Improvements include
* Enable `ListeningSocket::accept` to timeout after a specified amount
of time or block indefinitely
* Enable `ListeningSocket::createUnix` to handle instances where the
target socket address already exists and differentiate between
situations where the existing file does and does not already have a
bound socket
* Doxygen comments
Functionality added for the module build daemon
---------
Co-authored-by: Michael Spencer <bigcheesegs@gmail.com>
-rw-r--r-- | llvm/include/llvm/Support/raw_socket_stream.h | 86 | ||||
-rw-r--r-- | llvm/lib/Support/raw_socket_stream.cpp | 253 | ||||
-rw-r--r-- | llvm/unittests/Support/raw_socket_stream_test.cpp | 73 |
3 files changed, 339 insertions, 73 deletions
diff --git a/llvm/include/llvm/Support/raw_socket_stream.h b/llvm/include/llvm/Support/raw_socket_stream.h index c219792..bddd47e 100644 --- a/llvm/include/llvm/Support/raw_socket_stream.h +++ b/llvm/include/llvm/Support/raw_socket_stream.h @@ -17,12 +17,17 @@ #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" +#include <atomic> +#include <chrono> + namespace llvm { class raw_socket_stream; -// Make sure that calls to WSAStartup and WSACleanup are balanced. #ifdef _WIN32 +/// Ensures proper initialization and cleanup of winsock resources +/// +/// Make sure that calls to WSAStartup and WSACleanup are balanced. class WSABalancer { public: WSABalancer(); @@ -30,22 +35,87 @@ public: }; #endif // _WIN32 +/// Manages a passive (i.e., listening) UNIX domain socket +/// +/// The ListeningSocket class encapsulates a UNIX domain socket that can listen +/// and accept incoming connections. ListeningSocket is portable and supports +/// Windows builds begining with Insider Build 17063. ListeningSocket is +/// designed for server-side operations, working alongside \p raw_socket_streams +/// that function as client connections. +/// +/// Usage example: +/// \code{.cpp} +/// std::string Path = "/path/to/socket" +/// Expected<ListeningSocket> S = ListeningSocket::createUnix(Path); +/// +/// if (S) { +/// Expected<std::unique_ptr<raw_socket_stream>> connection = S->accept(); +/// if (connection) { +/// // Use the accepted raw_socket_stream for communication. +/// } +/// } +/// \endcode +/// class ListeningSocket { - int FD; - std::string SocketPath; - ListeningSocket(int SocketFD, StringRef SocketPath); + + std::atomic<int> FD; + std::string SocketPath; // Not modified after construction + + /// If a seperate thread calls ListeningSocket::shutdown, the ListeningSocket + /// file descriptor (FD) could be closed while ::poll is waiting for it to be + /// ready to perform a I/O operations. ::poll will continue to block even + /// after FD is closed so use a self-pipe mechanism to get ::poll to return + int PipeFD[2]; // Not modified after construction other then move constructor + + ListeningSocket(int SocketFD, StringRef SocketPath, int PipeFD[2]); + #ifdef _WIN32 WSABalancer _; #endif // _WIN32 public: + ~ListeningSocket(); + ListeningSocket(ListeningSocket &&LS); + ListeningSocket(const ListeningSocket &LS) = delete; + ListeningSocket &operator=(const ListeningSocket &) = delete; + + /// Closes the FD, unlinks the socket file, and writes to PipeFD. + /// + /// After the construction of the ListeningSocket, shutdown is signal safe if + /// it is called during the lifetime of the object. shutdown can be called + /// concurrently with ListeningSocket::accept as writing to PipeFD will cause + /// a blocking call to ::poll to return. + /// + /// Once shutdown is called there is no way to reinitialize ListeningSocket. + void shutdown(); + + /// Accepts an incoming connection on the listening socket. This method can + /// optionally either block until a connection is available or timeout after a + /// specified amount of time has passed. By default the method will block + /// until the socket has recieved a connection. + /// + /// \param Timeout An optional timeout duration in milliseconds. Setting + /// Timeout to -1 causes accept to block indefinitely + /// + Expected<std::unique_ptr<raw_socket_stream>> + accept(std::chrono::milliseconds Timeout = std::chrono::milliseconds(-1)); + + /// Creates a listening socket bound to the specified file system path. + /// Handles the socket creation, binding, and immediately starts listening for + /// incoming connections. + /// + /// \param SocketPath The file system path where the socket will be created + /// \param MaxBacklog The max number of connections in a socket's backlog + /// static Expected<ListeningSocket> createUnix( StringRef SocketPath, int MaxBacklog = llvm::hardware_concurrency().compute_thread_count()); - Expected<std::unique_ptr<raw_socket_stream>> accept(); - ListeningSocket(ListeningSocket &&LS); - ~ListeningSocket(); }; + +//===----------------------------------------------------------------------===// +// raw_socket_stream +//===----------------------------------------------------------------------===// + class raw_socket_stream : public raw_fd_stream { uint64_t current_pos() const override { return 0; } #ifdef _WIN32 @@ -54,7 +124,7 @@ class raw_socket_stream : public raw_fd_stream { public: raw_socket_stream(int SocketFD); - /// Create a \p raw_socket_stream connected to the Unix domain socket at \p + /// Create a \p raw_socket_stream connected to the UNIX domain socket at \p /// SocketPath. static Expected<std::unique_ptr<raw_socket_stream>> createConnectedUnix(StringRef SocketPath); diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp index afb0ed1..1dcf635 100644 --- a/llvm/lib/Support/raw_socket_stream.cpp +++ b/llvm/lib/Support/raw_socket_stream.cpp @@ -14,8 +14,14 @@ #include "llvm/Support/raw_socket_stream.h" #include "llvm/Config/config.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" + +#include <atomic> +#include <fcntl.h> +#include <thread> #ifndef _WIN32 +#include <poll.h> #include <sys/socket.h> #include <sys/un.h> #else @@ -45,7 +51,6 @@ WSABalancer::WSABalancer() { } WSABalancer::~WSABalancer() { WSACleanup(); } - #endif // _WIN32 static std::error_code getLastSocketErrorCode() { @@ -56,104 +61,231 @@ static std::error_code getLastSocketErrorCode() { #endif } -ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath) - : FD(SocketFD), SocketPath(SocketPath) {} +static sockaddr_un setSocketAddr(StringRef SocketPath) { + struct sockaddr_un Addr; + memset(&Addr, 0, sizeof(Addr)); + Addr.sun_family = AF_UNIX; + strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1); + return Addr; +} + +static Expected<int> getSocketFD(StringRef SocketPath) { +#ifdef _WIN32 + SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0); + if (Socket == INVALID_SOCKET) { +#else + int Socket = socket(AF_UNIX, SOCK_STREAM, 0); + if (Socket == -1) { +#endif // _WIN32 + return llvm::make_error<StringError>(getLastSocketErrorCode(), + "Create socket failed"); + } + + struct sockaddr_un Addr = setSocketAddr(SocketPath); + if (::connect(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) + return llvm::make_error<StringError>(getLastSocketErrorCode(), + "Connect socket failed"); + +#ifdef _WIN32 + return _open_osfhandle(Socket, 0); +#else + return Socket; +#endif // _WIN32 +} + +ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath, + int PipeFD[2]) + : FD(SocketFD), SocketPath(SocketPath), PipeFD{PipeFD[0], PipeFD[1]} {} ListeningSocket::ListeningSocket(ListeningSocket &&LS) - : FD(LS.FD), SocketPath(LS.SocketPath) { + : FD(LS.FD.load()), SocketPath(LS.SocketPath), + PipeFD{LS.PipeFD[0], LS.PipeFD[1]} { + LS.FD = -1; + LS.SocketPath.clear(); + LS.PipeFD[0] = -1; + LS.PipeFD[1] = -1; } Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath, int MaxBacklog) { + // Handle instances where the target socket address already exists and + // differentiate between a preexisting file with and without a bound socket + // + // ::bind will return std::errc:address_in_use if a file at the socket address + // already exists (e.g., the file was not properly unlinked due to a crash) + // even if another socket has not yet binded to that address + if (llvm::sys::fs::exists(SocketPath)) { + Expected<int> MaybeFD = getSocketFD(SocketPath); + if (!MaybeFD) { + + // Regardless of the error, notify the caller that a file already exists + // at the desired socket address and that there is no bound socket at that + // address. The file must be removed before ::bind can use the address + consumeError(MaybeFD.takeError()); + return llvm::make_error<StringError>( + std::make_error_code(std::errc::file_exists), + "Socket address unavailable"); + } + ::close(std::move(*MaybeFD)); + + // Notify caller that the provided socket address already has a bound socket + return llvm::make_error<StringError>( + std::make_error_code(std::errc::address_in_use), + "Socket address unavailable"); + } + #ifdef _WIN32 WSABalancer _; - SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); - if (MaybeWinsocket == INVALID_SOCKET) { + SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0); + if (Socket == INVALID_SOCKET) #else - int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); - if (MaybeWinsocket == -1) { + int Socket = socket(AF_UNIX, SOCK_STREAM, 0); + if (Socket == -1) #endif return llvm::make_error<StringError>(getLastSocketErrorCode(), "socket create failed"); - } - struct sockaddr_un Addr; - memset(&Addr, 0, sizeof(Addr)); - Addr.sun_family = AF_UNIX; - strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1); - - if (bind(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) { - std::error_code Err = getLastSocketErrorCode(); - if (Err == std::errc::address_in_use) - ::close(MaybeWinsocket); - return llvm::make_error<StringError>(Err, "Bind error"); + struct sockaddr_un Addr = setSocketAddr(SocketPath); + if (::bind(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) { + // Grab error code from call to ::bind before calling ::close + std::error_code EC = getLastSocketErrorCode(); + ::close(Socket); + return llvm::make_error<StringError>(EC, "Bind error"); } - if (listen(MaybeWinsocket, MaxBacklog) == -1) { + + // Mark socket as passive so incoming connections can be accepted + if (::listen(Socket, MaxBacklog) == -1) return llvm::make_error<StringError>(getLastSocketErrorCode(), "Listen error"); - } - int UnixSocket; + + int PipeFD[2]; #ifdef _WIN32 - UnixSocket = _open_osfhandle(MaybeWinsocket, 0); + // Reserve 1 byte for the pipe and use default textmode + if (::_pipe(PipeFD, 1, 0) == -1) #else - UnixSocket = MaybeWinsocket; + if (::pipe(PipeFD) == -1) +#endif // _WIN32 + return llvm::make_error<StringError>(getLastSocketErrorCode(), + "pipe failed"); + +#ifdef _WIN32 + return ListeningSocket{_open_osfhandle(Socket, 0), SocketPath, PipeFD}; +#else + return ListeningSocket{Socket, SocketPath, PipeFD}; #endif // _WIN32 - return ListeningSocket{UnixSocket, SocketPath}; } -Expected<std::unique_ptr<raw_socket_stream>> ListeningSocket::accept() { - int AcceptFD; +Expected<std::unique_ptr<raw_socket_stream>> +ListeningSocket::accept(std::chrono::milliseconds Timeout) { + + struct pollfd FDs[2]; + FDs[0].events = POLLIN; #ifdef _WIN32 SOCKET WinServerSock = _get_osfhandle(FD); + FDs[0].fd = WinServerSock; +#else + FDs[0].fd = FD; +#endif + FDs[1].events = POLLIN; + FDs[1].fd = PipeFD[0]; + + // Keep track of how much time has passed in case poll is interupted by a + // signal and needs to be recalled + int RemainingTime = Timeout.count(); + std::chrono::milliseconds ElapsedTime = std::chrono::milliseconds(0); + int PollStatus = -1; + + while (PollStatus == -1 && (Timeout.count() == -1 || ElapsedTime < Timeout)) { + if (Timeout.count() != -1) + RemainingTime -= ElapsedTime.count(); + + auto Start = std::chrono::steady_clock::now(); +#ifdef _WIN32 + PollStatus = WSAPoll(FDs, 2, RemainingTime); + if (PollStatus == SOCKET_ERROR) { +#else + PollStatus = ::poll(FDs, 2, RemainingTime); + if (PollStatus == -1) { +#endif + // Ignore error if caused by interupting signal + std::error_code PollErrCode = getLastSocketErrorCode(); + if (PollErrCode != std::errc::interrupted) + return llvm::make_error<StringError>(PollErrCode, "FD poll failed"); + } + + if (PollStatus == 0) + return llvm::make_error<StringError>( + std::make_error_code(std::errc::timed_out), + "No client requests within timeout window"); + + if (FDs[0].revents & POLLNVAL) + return llvm::make_error<StringError>( + std::make_error_code(std::errc::bad_file_descriptor), + "File descriptor closed by another thread"); + + if (FDs[1].revents & POLLIN) + return llvm::make_error<StringError>( + std::make_error_code(std::errc::operation_canceled), + "Accept canceled"); + + auto Stop = std::chrono::steady_clock::now(); + ElapsedTime += + std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start); + } + + int AcceptFD; +#ifdef _WIN32 SOCKET WinAcceptSock = ::accept(WinServerSock, NULL, NULL); AcceptFD = _open_osfhandle(WinAcceptSock, 0); #else AcceptFD = ::accept(FD, NULL, NULL); -#endif //_WIN32 +#endif + if (AcceptFD == -1) return llvm::make_error<StringError>(getLastSocketErrorCode(), - "Accept failed"); + "Socket accept failed"); return std::make_unique<raw_socket_stream>(AcceptFD); } -ListeningSocket::~ListeningSocket() { - if (FD == -1) +void ListeningSocket::shutdown() { + int ObservedFD = FD.load(); + + if (ObservedFD == -1) return; - ::close(FD); - unlink(SocketPath.c_str()); -} -static Expected<int> GetSocketFD(StringRef SocketPath) { -#ifdef _WIN32 - SOCKET MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); - if (MaybeWinsocket == INVALID_SOCKET) { -#else - int MaybeWinsocket = socket(AF_UNIX, SOCK_STREAM, 0); - if (MaybeWinsocket == -1) { -#endif // _WIN32 - return llvm::make_error<StringError>(getLastSocketErrorCode(), - "Create socket failed"); - } + // If FD equals ObservedFD set FD to -1; If FD doesn't equal ObservedFD then + // another thread is responsible for shutdown so return + if (!FD.compare_exchange_strong(ObservedFD, -1)) + return; - struct sockaddr_un Addr; - memset(&Addr, 0, sizeof(Addr)); - Addr.sun_family = AF_UNIX; - strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1); + ::close(ObservedFD); + ::unlink(SocketPath.c_str()); - int status = connect(MaybeWinsocket, (struct sockaddr *)&Addr, sizeof(Addr)); - if (status == -1) { - return llvm::make_error<StringError>(getLastSocketErrorCode(), - "Connect socket failed"); - } -#ifdef _WIN32 - return _open_osfhandle(MaybeWinsocket, 0); -#else - return MaybeWinsocket; -#endif // _WIN32 + // Ensure ::poll returns if shutdown is called by a seperate thread + char Byte = 'A'; + ::write(PipeFD[1], &Byte, 1); } +ListeningSocket::~ListeningSocket() { + shutdown(); + + // Close the pipe's FDs in the destructor instead of within + // ListeningSocket::shutdown to avoid unnecessary synchronization issues that + // would occur as PipeFD's values would have to be changed to -1 + // + // The move constructor sets PipeFD to -1 + if (PipeFD[0] != -1) + ::close(PipeFD[0]); + if (PipeFD[1] != -1) + ::close(PipeFD[1]); +} + +//===----------------------------------------------------------------------===// +// raw_socket_stream +//===----------------------------------------------------------------------===// + raw_socket_stream::raw_socket_stream(int SocketFD) : raw_fd_stream(SocketFD, true) {} @@ -162,11 +294,10 @@ raw_socket_stream::createConnectedUnix(StringRef SocketPath) { #ifdef _WIN32 WSABalancer _; #endif // _WIN32 - Expected<int> FD = GetSocketFD(SocketPath); + Expected<int> FD = getSocketFD(SocketPath); if (!FD) return FD.takeError(); return std::make_unique<raw_socket_stream>(*FD); } raw_socket_stream::~raw_socket_stream() {} - diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp index 6903862..a853622 100644 --- a/llvm/unittests/Support/raw_socket_stream_test.cpp +++ b/llvm/unittests/Support/raw_socket_stream_test.cpp @@ -9,6 +9,7 @@ #include <future> #include <iostream> #include <stdlib.h> +#include <thread> #ifdef _WIN32 #include "llvm/Support/Windows/WindowsSupport.h" @@ -32,10 +33,10 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) { GTEST_SKIP(); SmallString<100> SocketPath; - llvm::sys::fs::createUniquePath("test_raw_socket_stream.sock", SocketPath, - true); + llvm::sys::fs::createUniquePath("client_server_comms.sock", SocketPath, true); - char Bytes[8]; + // Make sure socket file does not exist. May still be there from the last test + std::remove(SocketPath.c_str()); Expected<ListeningSocket> MaybeServerListener = ListeningSocket::createUnix(SocketPath); @@ -58,6 +59,7 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) { Client << "01234567"; Client.flush(); + char Bytes[8]; ssize_t BytesRead = Server.read(Bytes, 8); std::string string(Bytes, 8); @@ -65,4 +67,67 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) { ASSERT_EQ(8, BytesRead); ASSERT_EQ("01234567", string); } -} // namespace
\ No newline at end of file + +TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) { + if (!hasUnixSocketSupport()) + GTEST_SKIP(); + + SmallString<100> SocketPath; + llvm::sys::fs::createUniquePath("timout_provided.sock", SocketPath, true); + + // Make sure socket file does not exist. May still be there from the last test + std::remove(SocketPath.c_str()); + + Expected<ListeningSocket> MaybeServerListener = + ListeningSocket::createUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded()); + ListeningSocket ServerListener = std::move(*MaybeServerListener); + + std::chrono::milliseconds Timeout = std::chrono::milliseconds(100); + Expected<std::unique_ptr<raw_socket_stream>> MaybeServer = + ServerListener.accept(Timeout); + + ASSERT_THAT_EXPECTED(MaybeServer, Failed()); + llvm::Error Err = MaybeServer.takeError(); + llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) { + std::error_code EC = SE.convertToErrorCode(); + ASSERT_EQ(EC, std::errc::timed_out); + }); +} + +TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) { + if (!hasUnixSocketSupport()) + GTEST_SKIP(); + + SmallString<100> SocketPath; + llvm::sys::fs::createUniquePath("fd_closed.sock", SocketPath, true); + + // Make sure socket file does not exist. May still be there from the last test + std::remove(SocketPath.c_str()); + + Expected<ListeningSocket> MaybeServerListener = + ListeningSocket::createUnix(SocketPath); + ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded()); + ListeningSocket ServerListener = std::move(*MaybeServerListener); + + // Create a separate thread to close the socket after a delay. Simulates a + // signal handler calling ServerListener::shutdown + std::thread CloseThread([&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + ServerListener.shutdown(); + }); + + Expected<std::unique_ptr<raw_socket_stream>> MaybeServer = + ServerListener.accept(); + + // Wait for the CloseThread to finish + CloseThread.join(); + + ASSERT_THAT_EXPECTED(MaybeServer, Failed()); + llvm::Error Err = MaybeServer.takeError(); + llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) { + std::error_code EC = SE.convertToErrorCode(); + ASSERT_EQ(EC, std::errc::operation_canceled); + }); +} +} // namespace |