From 203232ffbd80e9f4631213a3876f14dde155a92d Mon Sep 17 00:00:00 2001 From: Connor Sughrue <55301806+cpsughrue@users.noreply.github.com> Date: Tue, 21 May 2024 20:32:11 -0400 Subject: [llvm][Support] ListeningSocket::accept returns operation_canceled if FD is set to -1 (#89479) If `::poll` returns and `FD` equals -1, then `ListeningSocket::shutdown` has been called. So, regardless of any other information that could be gleaned from `FDs.revents` or `PollStatus`, it is appropriate to return `std::errc::operation_canceled`. `ListeningSocket::shutdown` copies `FD`'s value to `ObservedFD` then sets `FD` to -1 before canceling `::poll` by calling `::close(ObservedFD)` and writing to the pipe. --- llvm/lib/Support/raw_socket_stream.cpp | 23 +++++++++++++---------- llvm/unittests/Support/raw_socket_stream_test.cpp | 19 ++++--------------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp index 14e2308..549d537 100644 --- a/llvm/lib/Support/raw_socket_stream.cpp +++ b/llvm/lib/Support/raw_socket_stream.cpp @@ -204,17 +204,26 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) { auto Start = std::chrono::steady_clock::now(); #ifdef _WIN32 PollStatus = WSAPoll(FDs, 2, RemainingTime); - if (PollStatus == SOCKET_ERROR) { #else PollStatus = ::poll(FDs, 2, RemainingTime); +#endif + // If FD equals -1 then ListeningSocket::shutdown has been called and it is + // appropriate to return operation_canceled + if (FD.load() == -1) + return llvm::make_error( + std::make_error_code(std::errc::operation_canceled), + "Accept canceled"); + +#if _WIN32 + if (PollStatus == SOCKET_ERROR) { +#else if (PollStatus == -1) { #endif - // Ignore error if caused by interupting signal std::error_code PollErrCode = getLastSocketErrorCode(); + // Ignore EINTR (signal occured before any request event) and retry if (PollErrCode != std::errc::interrupted) return llvm::make_error(PollErrCode, "FD poll failed"); } - if (PollStatus == 0) return llvm::make_error( std::make_error_code(std::errc::timed_out), @@ -222,13 +231,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) { if (FDs[0].revents & POLLNVAL) return llvm::make_error( - std::make_error_code(std::errc::bad_file_descriptor), - "File descriptor closed by another thread"); - - if (FDs[1].revents & POLLIN) - return llvm::make_error( - std::make_error_code(std::errc::operation_canceled), - "Accept canceled"); + std::make_error_code(std::errc::bad_file_descriptor)); auto Stop = std::chrono::steady_clock::now(); ElapsedTime += diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp index a853622..c4e8cfb 100644 --- a/llvm/unittests/Support/raw_socket_stream_test.cpp +++ b/llvm/unittests/Support/raw_socket_stream_test.cpp @@ -7,7 +7,6 @@ #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" #include -#include #include #include @@ -86,13 +85,8 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) { std::chrono::milliseconds Timeout = std::chrono::milliseconds(100); Expected> MaybeServer = ServerListener.accept(Timeout); - - ASSERT_THAT_EXPECTED(MaybeServer, Failed()); - llvm::Error Err = MaybeServer.takeError(); - llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) { - std::error_code EC = SE.convertToErrorCode(); - ASSERT_EQ(EC, std::errc::timed_out); - }); + ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()), + std::errc::timed_out); } TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) { @@ -122,12 +116,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) { // Wait for the CloseThread to finish CloseThread.join(); - - ASSERT_THAT_EXPECTED(MaybeServer, Failed()); - llvm::Error Err = MaybeServer.takeError(); - llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) { - std::error_code EC = SE.convertToErrorCode(); - ASSERT_EQ(EC, std::errc::operation_canceled); - }); + ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()), + std::errc::operation_canceled); } } // namespace -- cgit v1.1