aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Support/raw_socket_stream.cpp
blob: 14e2308df4d7ed4a71c0ddb49461568d853c0fa8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
//===-- llvm/Support/raw_socket_stream.cpp - Socket streams --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains raw_ostream implementations for streams to communicate
// via UNIX sockets
//
//===----------------------------------------------------------------------===//

#include "llvm/Support/raw_socket_stream.h"
#include "llvm/Config/config.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"

#include <atomic>
#include <climits>
#include <fcntl.h>
#include <thread>

#ifndef _WIN32
#include <poll.h>
#include <sys/socket.h>
#include <sys/un.h>
#else
#include "llvm/Support/Windows/WindowsSupport.h"
// winsock2.h must be included before afunix.h. Briefly turn off clang-format to
// avoid error.
// clang-format off
#include <winsock2.h>
#include <afunix.h>
// clang-format on
#include <io.h>
#endif // _WIN32

#if defined(HAVE_UNISTD_H)
#include <unistd.h>
#endif

using namespace llvm;

#ifdef _WIN32
// Start Winsock (requesting version 2.2) for the lifetime of this object.
// A failed startup is reported as a fatal error.
WSABalancer::WSABalancer() {
  WSADATA StartupData;
  ::memset(&StartupData, 0, sizeof(StartupData));

  const int StartupResult = WSAStartup(MAKEWORD(2, 2), &StartupData);
  if (StartupResult != 0)
    llvm::report_fatal_error("WSAStartup failed");
}

// Pair the WSAStartup call made by the constructor with a WSACleanup.
WSABalancer::~WSABalancer() {
  WSACleanup();
}
#endif // _WIN32

// Fetch the error left behind by the most recent failed socket call as a
// std::error_code: WSAGetLastError() on Windows, errno everywhere else.
static std::error_code getLastSocketErrorCode() {
#ifdef _WIN32
  const int LastWinsockError = ::WSAGetLastError();
  return std::error_code(LastWinsockError, std::system_category());
#else
  return errnoAsErrorCode();
#endif
}

// Build an AF_UNIX socket address for SocketPath.
//
// The structure is zeroed first and at most sizeof(sun_path) - 1 bytes of the
// path are copied in, so sun_path is always NUL terminated; a longer path is
// truncated.
static sockaddr_un setSocketAddr(StringRef SocketPath) {
  const std::string Path = SocketPath.str();

  struct sockaddr_un Address;
  ::memset(&Address, 0, sizeof(Address));
  Address.sun_family = AF_UNIX;
  ::strncpy(Address.sun_path, Path.c_str(), sizeof(Address.sun_path) - 1);
  return Address;
}

// Create an AF_UNIX stream socket and connect it to SocketPath.
//
// Returns the connected descriptor on success. On Windows the value returned
// is the result of _open_osfhandle for the socket handle. On failure a
// StringError carrying the platform socket error code is returned and the
// socket created here is closed again, so no descriptor is leaked.
static Expected<int> getSocketFD(StringRef SocketPath) {
#ifdef _WIN32
  SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0);
  if (Socket == INVALID_SOCKET) {
#else
  int Socket = socket(AF_UNIX, SOCK_STREAM, 0);
  if (Socket == -1) {
#endif // _WIN32
    return llvm::make_error<StringError>(getLastSocketErrorCode(),
                                         "Create socket failed");
  }

  struct sockaddr_un Addr = setSocketAddr(SocketPath);
  if (::connect(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
    // Grab the error code from ::connect before closing the socket, since
    // closing may replace the last error value
    std::error_code EC = getLastSocketErrorCode();
#ifdef _WIN32
    ::closesocket(Socket);
#else
    ::close(Socket);
#endif // _WIN32
    return llvm::make_error<StringError>(EC, "Connect socket failed");
  }

#ifdef _WIN32
  return _open_osfhandle(Socket, 0);
#else
  return Socket;
#endif // _WIN32
}

// Wrap an existing listening socket descriptor, the path it is bound to, and
// the two ends of the pipe used to wake up a pending accept. The descriptors
// are copied as-is; nothing is opened or validated here.
ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath,
                                 int PipeFD[2])
    : FD(SocketFD),
      SocketPath(SocketPath),
      PipeFD{PipeFD[0], PipeFD[1]} {
}

// Move constructor: take over the socket descriptor, socket path and pipe
// descriptors held by LS.
ListeningSocket::ListeningSocket(ListeningSocket &&LS)
    : FD(LS.FD.load()), SocketPath(LS.SocketPath),
      PipeFD{LS.PipeFD[0], LS.PipeFD[1]} {

  // Reset the source to the "nothing owned" state (-1 descriptors, empty
  // path) so that it no longer refers to the resources now held by *this.
  LS.FD.store(-1);
  LS.SocketPath.clear();
  for (int &PipeEnd : LS.PipeFD)
    PipeEnd = -1;
}

// Create a UNIX domain stream socket bound to SocketPath and mark it as
// listening with a backlog of MaxBacklog, together with the pipe that is later
// used to interrupt a pending accept.
//
// Errors returned:
//  - std::errc::file_exists     a file exists at SocketPath but no socket could
//                               be connected to through it
//  - std::errc::address_in_use  another socket is already accepting
//                               connections at SocketPath
//  - the platform error code    if creating, binding or listening on the
//                               socket, or creating the pipe, fails
//
// Every failure path releases the socket created here, and once the bind has
// succeeded also removes the file at SocketPath, so a failed call leaves
// nothing behind.
Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
                                                      int MaxBacklog) {

  // Handle instances where the target socket address already exists and
  // differentiate between a preexisting file with and without a bound socket
  //
  // ::bind will return std::errc:address_in_use if a file at the socket address
  // already exists (e.g., the file was not properly unlinked due to a crash)
  // even if another socket has not yet been bound to that address
  if (llvm::sys::fs::exists(SocketPath)) {
    Expected<int> MaybeFD = getSocketFD(SocketPath);
    if (!MaybeFD) {

      // Regardless of the error, notify the caller that a file already exists
      // at the desired socket address and that there is no bound socket at that
      // address. The file must be removed before ::bind can use the address
      consumeError(MaybeFD.takeError());
      return llvm::make_error<StringError>(
          std::make_error_code(std::errc::file_exists),
          "Socket address unavailable");
    }
    ::close(*MaybeFD);

    // Notify caller that the provided socket address already has a bound socket
    return llvm::make_error<StringError>(
        std::make_error_code(std::errc::address_in_use),
        "Socket address unavailable");
  }

#ifdef _WIN32
  WSABalancer _;
  SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0);
  if (Socket == INVALID_SOCKET)
#else
  int Socket = socket(AF_UNIX, SOCK_STREAM, 0);
  if (Socket == -1)
#endif
    return llvm::make_error<StringError>(getLastSocketErrorCode(),
                                         "socket create failed");

  // Releases the native socket on the failure paths below. A Windows SOCKET
  // handle has to be released with closesocket rather than ::close
  auto CloseSocket = [Socket] {
#ifdef _WIN32
    ::closesocket(Socket);
#else
    ::close(Socket);
#endif
  };

  struct sockaddr_un Addr = setSocketAddr(SocketPath);
  if (::bind(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
    // Grab error code from call to ::bind before closing the socket
    std::error_code EC = getLastSocketErrorCode();
    CloseSocket();
    return llvm::make_error<StringError>(EC, "Bind error");
  }

  // Mark socket as passive so incoming connections can be accepted
  if (::listen(Socket, MaxBacklog) == -1) {
    // The successful ::bind created a file at the socket address; remove it
    // along with the socket so that a later call does not report file_exists
    std::error_code EC = getLastSocketErrorCode();
    CloseSocket();
    ::unlink(SocketPath.str().c_str());
    return llvm::make_error<StringError>(EC, "Listen error");
  }

  int PipeFD[2];
#ifdef _WIN32
  // Reserve 1 byte for the pipe and use default textmode
  if (::_pipe(PipeFD, 1, 0) == -1) {
#else
  if (::pipe(PipeFD) == -1) {
#endif // _WIN32
    // Same cleanup as for a failed ::listen: without the pipe no
    // ListeningSocket is returned, so nothing else would release these
    std::error_code EC = getLastSocketErrorCode();
    CloseSocket();
    ::unlink(SocketPath.str().c_str());
    return llvm::make_error<StringError>(EC, "pipe failed");
  }

#ifdef _WIN32
  return ListeningSocket{_open_osfhandle(Socket, 0), SocketPath, PipeFD};
#else
  return ListeningSocket{Socket, SocketPath, PipeFD};
#endif // _WIN32
}

// Wait for an incoming connection and return a stream for it.
//
// Timeout is the maximum time to wait for a client. A negative value waits
// without a deadline; zero checks once without blocking.
//
// Errors returned:
//  - std::errc::timed_out            no client connected before the deadline
//  - std::errc::bad_file_descriptor  poll reported the listening descriptor as
//                                    invalid (e.g. it was closed meanwhile)
//  - std::errc::operation_canceled   the wake-up pipe became readable, which
//                                    is what ListeningSocket::shutdown triggers
//  - the platform error code         if poll or accept itself fails
Expected<std::unique_ptr<raw_socket_stream>>
ListeningSocket::accept(std::chrono::milliseconds Timeout) {

  struct pollfd FDs[2];
  FDs[0].events = POLLIN;
  FDs[0].revents = 0;
#ifdef _WIN32
  SOCKET WinServerSock = _get_osfhandle(FD);
  FDs[0].fd = WinServerSock;
#else
  FDs[0].fd = FD;
#endif
  FDs[1].events = POLLIN;
  FDs[1].revents = 0;
  FDs[1].fd = PipeFD[0];

  const bool BlockIndefinitely = Timeout.count() < 0;

  // Time still left to wait. Poll may be interrupted by a signal, or may be
  // given less than the full remaining time (its timeout is an int), so it can
  // be called several times; the time spent in each call is deducted here
  std::chrono::milliseconds RemainingTime = Timeout;

  while (true) {
    int PollTimeout = -1;
    if (!BlockIndefinitely)
      PollTimeout = RemainingTime.count() > INT_MAX
                        ? INT_MAX
                        : static_cast<int>(RemainingTime.count());

    auto Start = std::chrono::steady_clock::now();
#ifdef _WIN32
    int PollStatus = WSAPoll(FDs, 2, PollTimeout);
    const bool PollFailed = PollStatus == SOCKET_ERROR;
#else
    int PollStatus = ::poll(FDs, 2, PollTimeout);
    const bool PollFailed = PollStatus == -1;
#endif
    auto Stop = std::chrono::steady_clock::now();

    if (PollFailed) {
      // Ignore error if caused by interrupting signal
      std::error_code PollErrCode = getLastSocketErrorCode();
      if (PollErrCode != std::errc::interrupted)
        return llvm::make_error<StringError>(PollErrCode, "FD poll failed");
    } else if (PollStatus > 0) {
      // At least one descriptor has an event to report
      break;
    }

    // Poll was interrupted or its time slice ran out without any event. Only
    // retry while there is time left, never fall through to a blocking accept
    if (!BlockIndefinitely) {
      RemainingTime -=
          std::chrono::duration_cast<std::chrono::milliseconds>(Stop - Start);
      if (RemainingTime <= std::chrono::milliseconds(0))
        return llvm::make_error<StringError>(
            std::make_error_code(std::errc::timed_out),
            "No client requests within timeout window");
    }
  }

  if (FDs[0].revents & POLLNVAL)
    return llvm::make_error<StringError>(
        std::make_error_code(std::errc::bad_file_descriptor),
        "File descriptor closed by another thread");

  if (FDs[1].revents & POLLIN)
    return llvm::make_error<StringError>(
        std::make_error_code(std::errc::operation_canceled),
        "Accept canceled");

  int AcceptFD;
#ifdef _WIN32
  SOCKET WinAcceptSock = ::accept(WinServerSock, nullptr, nullptr);
  // Report the failure right away, while the Winsock error is still the last
  // error, rather than handing an invalid handle to _open_osfhandle
  if (WinAcceptSock == INVALID_SOCKET)
    return llvm::make_error<StringError>(getLastSocketErrorCode(),
                                         "Socket accept failed");
  AcceptFD = _open_osfhandle(WinAcceptSock, 0);
#else
  AcceptFD = ::accept(FD, nullptr, nullptr);
#endif

  if (AcceptFD == -1)
    return llvm::make_error<StringError>(getLastSocketErrorCode(),
                                         "Socket accept failed");
  return std::make_unique<raw_socket_stream>(AcceptFD);
}

// Close the listening socket, remove its file and wake up any pending accept.
// Calling this again, or from several threads at once, performs the work only
// once.
void ListeningSocket::shutdown() {
  int ClaimedFD = FD.load();

  // Nothing to do if the socket is already gone. Otherwise try to swap FD from
  // the value just read to -1; if that fails another thread got there first and
  // is responsible for the shutdown
  if (ClaimedFD == -1 || !FD.compare_exchange_strong(ClaimedFD, -1))
    return;

  ::close(ClaimedFD);
  ::unlink(SocketPath.c_str());

  // Make the pipe readable so that a ::poll waiting on it in another thread
  // returns
  char WakeByte = 'A';
  ssize_t BytesWritten = ::write(PipeFD[1], &WakeByte, 1);

  // A failed write() is deliberately ignored
  (void)BytesWritten;
}

// Shut the socket down and release both ends of the wake-up pipe.
ListeningSocket::~ListeningSocket() {
  shutdown();

  // The pipe is closed here rather than in ListeningSocket::shutdown: closing
  // it there would mean resetting PipeFD to -1, which would then need
  // synchronization with other threads.
  //
  // An entry of -1 (as left behind by the move constructor) means there is
  // nothing to close
  for (int PipeEnd : PipeFD)
    if (PipeEnd != -1)
      ::close(PipeEnd);
}

//===----------------------------------------------------------------------===//
//  raw_socket_stream
//===----------------------------------------------------------------------===//

// Construct a stream on top of the already connected descriptor SocketFD.
// NOTE(review): the `true` passed to raw_fd_stream is presumably its
// "close the descriptor on destruction" flag; confirm against raw_fd_stream.
raw_socket_stream::raw_socket_stream(int SocketFD)
    : raw_fd_stream(SocketFD, true) {
}

// Connect to the UNIX domain socket at SocketPath and wrap the resulting
// descriptor in a raw_socket_stream. Any error from establishing the
// connection is passed on to the caller unchanged.
Expected<std::unique_ptr<raw_socket_stream>>
raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
#ifdef _WIN32
  WSABalancer _;
#endif // _WIN32
  Expected<int> MaybeFD = getSocketFD(SocketPath);
  if (MaybeFD)
    return std::make_unique<raw_socket_stream>(*MaybeFD);
  return MaybeFD.takeError();
}

// Nothing to release here beyond what the base class and members clean up.
raw_socket_stream::~raw_socket_stream() = default;