diff options
author | Connor Sughrue <55301806+cpsughrue@users.noreply.github.com> | 2024-05-21 20:32:11 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-21 20:32:11 -0400 |
commit | 203232ffbd80e9f4631213a3876f14dde155a92d (patch) | |
tree | 5ff2674be3737497c9e6558cc9dd37a8845968f8 | |
parent | 0170bd5d111f55f45f993a749727ce2815cc0b16 (diff) | |
download | llvm-203232ffbd80e9f4631213a3876f14dde155a92d.zip llvm-203232ffbd80e9f4631213a3876f14dde155a92d.tar.gz llvm-203232ffbd80e9f4631213a3876f14dde155a92d.tar.bz2 |
[llvm][Support] ListeningSocket::accept returns operation_canceled if FD is set to -1 (#89479)
If `::poll` returns and `FD` equals -1, then `ListeningSocket::shutdown`
has been called. So, regardless of any other information that could be
gleaned from `FDs.revents` or `PollStatus`, it is appropriate to return
`std::errc::operation_canceled`. `ListeningSocket::shutdown` copies
`FD`'s value to `ObservedFD` then sets `FD` to -1 before canceling
`::poll` by calling `::close(ObservedFD)` and writing to the pipe.
-rw-r--r-- | llvm/lib/Support/raw_socket_stream.cpp | 23 | ||||
-rw-r--r-- | llvm/unittests/Support/raw_socket_stream_test.cpp | 19 |
2 files changed, 17 insertions, 25 deletions
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp index 14e2308..549d537 100644 --- a/llvm/lib/Support/raw_socket_stream.cpp +++ b/llvm/lib/Support/raw_socket_stream.cpp @@ -204,17 +204,26 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) { auto Start = std::chrono::steady_clock::now(); #ifdef _WIN32 PollStatus = WSAPoll(FDs, 2, RemainingTime); - if (PollStatus == SOCKET_ERROR) { #else PollStatus = ::poll(FDs, 2, RemainingTime); +#endif + // If FD equals -1 then ListeningSocket::shutdown has been called and it is + // appropriate to return operation_canceled + if (FD.load() == -1) + return llvm::make_error<StringError>( + std::make_error_code(std::errc::operation_canceled), + "Accept canceled"); + +#if _WIN32 + if (PollStatus == SOCKET_ERROR) { +#else if (PollStatus == -1) { #endif - // Ignore error if caused by interupting signal std::error_code PollErrCode = getLastSocketErrorCode(); + // Ignore EINTR (signal occured before any request event) and retry if (PollErrCode != std::errc::interrupted) return llvm::make_error<StringError>(PollErrCode, "FD poll failed"); } - if (PollStatus == 0) return llvm::make_error<StringError>( std::make_error_code(std::errc::timed_out), @@ -222,13 +231,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) { if (FDs[0].revents & POLLNVAL) return llvm::make_error<StringError>( - std::make_error_code(std::errc::bad_file_descriptor), - "File descriptor closed by another thread"); - - if (FDs[1].revents & POLLIN) - return llvm::make_error<StringError>( - std::make_error_code(std::errc::operation_canceled), - "Accept canceled"); + std::make_error_code(std::errc::bad_file_descriptor)); auto Stop = std::chrono::steady_clock::now(); ElapsedTime += diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp index a853622..c4e8cfb 100644 --- a/llvm/unittests/Support/raw_socket_stream_test.cpp +++ b/llvm/unittests/Support/raw_socket_stream_test.cpp @@ -7,7 +7,6 @@ #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" #include <future> -#include <iostream> #include <stdlib.h> #include <thread> @@ -86,13 +85,8 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) { std::chrono::milliseconds Timeout = std::chrono::milliseconds(100); Expected<std::unique_ptr<raw_socket_stream>> MaybeServer = ServerListener.accept(Timeout); - - ASSERT_THAT_EXPECTED(MaybeServer, Failed()); - llvm::Error Err = MaybeServer.takeError(); - llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) { - std::error_code EC = SE.convertToErrorCode(); - ASSERT_EQ(EC, std::errc::timed_out); - }); + ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()), + std::errc::timed_out); } TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) { @@ -122,12 +116,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) { // Wait for the CloseThread to finish CloseThread.join(); - - ASSERT_THAT_EXPECTED(MaybeServer, Failed()); - llvm::Error Err = MaybeServer.takeError(); - llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) { - std::error_code EC = SE.convertToErrorCode(); - ASSERT_EQ(EC, std::errc::operation_canceled); - }); + ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()), + std::errc::operation_canceled); } } // namespace |