aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorConnor Sughrue <55301806+cpsughrue@users.noreply.github.com>2024-05-21 20:32:11 -0400
committerGitHub <noreply@github.com>2024-05-21 20:32:11 -0400
commit203232ffbd80e9f4631213a3876f14dde155a92d (patch)
tree5ff2674be3737497c9e6558cc9dd37a8845968f8
parent0170bd5d111f55f45f993a749727ce2815cc0b16 (diff)
downloadllvm-203232ffbd80e9f4631213a3876f14dde155a92d.zip
llvm-203232ffbd80e9f4631213a3876f14dde155a92d.tar.gz
llvm-203232ffbd80e9f4631213a3876f14dde155a92d.tar.bz2
[llvm][Support] ListeningSocket::accept returns operation_canceled if FD is set to -1 (#89479)
If `::poll` returns and `FD` equals -1, then `ListeningSocket::shutdown` has been called. So, regardless of any other information that could be gleaned from `FDs.revents` or `PollStatus`, it is appropriate to return `std::errc::operation_canceled`. `ListeningSocket::shutdown` copies `FD`'s value to `ObservedFD` then sets `FD` to -1 before canceling `::poll` by calling `::close(ObservedFD)` and writing to the pipe.
-rw-r--r--llvm/lib/Support/raw_socket_stream.cpp23
-rw-r--r--llvm/unittests/Support/raw_socket_stream_test.cpp19
2 files changed, 17 insertions, 25 deletions
diff --git a/llvm/lib/Support/raw_socket_stream.cpp b/llvm/lib/Support/raw_socket_stream.cpp
index 14e2308..549d537 100644
--- a/llvm/lib/Support/raw_socket_stream.cpp
+++ b/llvm/lib/Support/raw_socket_stream.cpp
@@ -204,17 +204,26 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
auto Start = std::chrono::steady_clock::now();
#ifdef _WIN32
PollStatus = WSAPoll(FDs, 2, RemainingTime);
- if (PollStatus == SOCKET_ERROR) {
#else
PollStatus = ::poll(FDs, 2, RemainingTime);
+#endif
+ // If FD equals -1 then ListeningSocket::shutdown has been called and it is
+ // appropriate to return operation_canceled
+ if (FD.load() == -1)
+ return llvm::make_error<StringError>(
+ std::make_error_code(std::errc::operation_canceled),
+ "Accept canceled");
+
+#if _WIN32
+ if (PollStatus == SOCKET_ERROR) {
+#else
if (PollStatus == -1) {
#endif
- // Ignore error if caused by interupting signal
std::error_code PollErrCode = getLastSocketErrorCode();
+ // Ignore EINTR (signal occured before any request event) and retry
if (PollErrCode != std::errc::interrupted)
return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
}
-
if (PollStatus == 0)
return llvm::make_error<StringError>(
std::make_error_code(std::errc::timed_out),
@@ -222,13 +231,7 @@ ListeningSocket::accept(std::chrono::milliseconds Timeout) {
if (FDs[0].revents & POLLNVAL)
return llvm::make_error<StringError>(
- std::make_error_code(std::errc::bad_file_descriptor),
- "File descriptor closed by another thread");
-
- if (FDs[1].revents & POLLIN)
- return llvm::make_error<StringError>(
- std::make_error_code(std::errc::operation_canceled),
- "Accept canceled");
+ std::make_error_code(std::errc::bad_file_descriptor));
auto Stop = std::chrono::steady_clock::now();
ElapsedTime +=
diff --git a/llvm/unittests/Support/raw_socket_stream_test.cpp b/llvm/unittests/Support/raw_socket_stream_test.cpp
index a853622..c4e8cfb 100644
--- a/llvm/unittests/Support/raw_socket_stream_test.cpp
+++ b/llvm/unittests/Support/raw_socket_stream_test.cpp
@@ -7,7 +7,6 @@
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
#include <future>
-#include <iostream>
#include <stdlib.h>
#include <thread>
@@ -86,13 +85,8 @@ TEST(raw_socket_streamTest, TIMEOUT_PROVIDED) {
std::chrono::milliseconds Timeout = std::chrono::milliseconds(100);
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept(Timeout);
-
- ASSERT_THAT_EXPECTED(MaybeServer, Failed());
- llvm::Error Err = MaybeServer.takeError();
- llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
- std::error_code EC = SE.convertToErrorCode();
- ASSERT_EQ(EC, std::errc::timed_out);
- });
+ ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
+ std::errc::timed_out);
}
TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
@@ -122,12 +116,7 @@ TEST(raw_socket_streamTest, FILE_DESCRIPTOR_CLOSED) {
// Wait for the CloseThread to finish
CloseThread.join();
-
- ASSERT_THAT_EXPECTED(MaybeServer, Failed());
- llvm::Error Err = MaybeServer.takeError();
- llvm::handleAllErrors(std::move(Err), [&](const llvm::StringError &SE) {
- std::error_code EC = SE.convertToErrorCode();
- ASSERT_EQ(EC, std::errc::operation_canceled);
- });
+ ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
+ std::errc::operation_canceled);
}
} // namespace