diff options
author | David Benjamin <davidben@google.com> | 2018-05-05 00:42:23 -0400 |
---|---|---|
committer | CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org> | 2018-05-07 15:44:08 +0000 |
commit | e7ca8a5d78396388570df4d91058f6e170e8647f (patch) | |
tree | cb6d353ef64a0c7b83e154d61e72d51423c6b825 /tool | |
parent | e30fac6371c450833757021d6303d47f66e395f8 (diff) | |
download | boringssl-e7ca8a5d78396388570df4d91058f6e170e8647f.zip boringssl-e7ca8a5d78396388570df4d91058f6e170e8647f.tar.gz boringssl-e7ca8a5d78396388570df4d91058f6e170e8647f.tar.bz2 |
Fix bssl client/server's error-handling.
Rather than printing the SSL_ERROR_* constants, print the actual error.
This should be a bit more understandable. Debugging this also uncovered
some other issues on Windows:
- We were mixing up C runtime and Winsock errors, which are separate in
Windows.
- The thread local implementation interferes with WSAGetLastError due to
a quirk of TlsGetValue. This could affect other Windows consumers.
(Chromium uses a custom BIO, so it isn't affected.)
- SocketSetNonBlocking also interferes with WSAGetLastError.
- Listen for FD_CLOSE along with FD_READ. Connection close does not
signal FD_READ. (The select loop only barely works on Windows anyway
due to issues with stdin and line buffering, but if we take stdin out
of the equation, FD_CLOSE can be tested.)
Change-Id: If991259915acc96606a314fbe795fe6ea1e295e8
Reviewed-on: https://boringssl-review.googlesource.com/28125
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Diffstat (limited to 'tool')
-rw-r--r-- | tool/client.cc | 19 | ||||
-rw-r--r-- | tool/server.cc | 6 | ||||
-rw-r--r-- | tool/transport_common.cc | 80 | ||||
-rw-r--r-- | tool/transport_common.h | 5 |
4 files changed, 75 insertions, 35 deletions
diff --git a/tool/client.cc b/tool/client.cc index bdb5de7..037e10c 100644 --- a/tool/client.cc +++ b/tool/client.cc @@ -181,7 +181,7 @@ static int NewSessionCallback(SSL *ssl, SSL_SESSION *session) { if (!PEM_write_bio_SSL_SESSION(session_out.get(), session) || BIO_flush(session_out.get()) <= 0) { fprintf(stderr, "Error while saving session:\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return 0; } } @@ -221,8 +221,7 @@ static bool WaitForSession(SSL *ssl, int sock) { if (ssl_err == SSL_ERROR_WANT_READ) { continue; } - fprintf(stderr, "Error while reading: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret); return false; } } @@ -267,14 +266,14 @@ static bool DoConnection(SSL_CTX *ctx, "rb")); if (!in) { fprintf(stderr, "Error reading session\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } bssl::UniquePtr<SSL_SESSION> session(PEM_read_bio_SSL_SESSION(in.get(), nullptr, nullptr, nullptr)); if (!session) { fprintf(stderr, "Error reading session\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } SSL_set_session(ssl.get(), session.get()); @@ -294,8 +293,7 @@ static bool DoConnection(SSL_CTX *ctx, int ret = SSL_connect(ssl.get()); if (ret != 1) { int ssl_err = SSL_get_error(ssl.get(), ret); - fprintf(stderr, "Error while connecting: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while connecting", ssl_err, ret); return false; } @@ -315,8 +313,7 @@ static bool DoConnection(SSL_CTX *ctx, int ssl_ret = SSL_write(ssl.get(), early_data.data(), ed_size); if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl.get(), ssl_ret); - fprintf(stderr, "Error while writing: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret); return false; } else if (ssl_ret != ed_size) { fprintf(stderr, "Short write from SSL_write.\n"); @@ -500,7 +497,7 @@ bool Client(const std::vector<std::string> &args) { if (!session_out) { fprintf(stderr, "Error while opening %s:\n", args_map["-session-out"].c_str()); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } } @@ -513,7 +510,7 @@ bool Client(const std::vector<std::string> &args) { if (!SSL_CTX_load_verify_locations( ctx.get(), args_map["-root-certs"].c_str(), nullptr)) { fprintf(stderr, "Failed to load root certificates.\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); + ERR_print_errors_fp(stderr); return false; } SSL_CTX_set_verify(ctx.get(), SSL_VERIFY_PEER, nullptr); diff --git a/tool/server.cc b/tool/server.cc index 23a47e9..7a4e53b 100644 --- a/tool/server.cc +++ b/tool/server.cc @@ -185,8 +185,7 @@ static bool HandleWWW(SSL *ssl) { SSL_read(ssl, request + request_len, sizeof(request) - request_len); if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl, ssl_ret); - fprintf(stderr, "Error while reading: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret); return false; } request_len += static_cast<size_t>(ssl_ret); @@ -342,8 +341,7 @@ bool Server(const std::vector<std::string> &args) { int ret = SSL_accept(ssl.get()); if (ret != 1) { int ssl_err = SSL_get_error(ssl.get(), ret); - fprintf(stderr, "Error while connecting: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while connecting", ssl_err, ret); result = false; continue; } diff --git a/tool/transport_common.cc b/tool/transport_common.cc index 55f2059..dcb8e0d 100644 --- a/tool/transport_common.cc +++ b/tool/transport_common.cc @@ -91,6 +91,33 @@ static void SplitHostPort(std::string *out_hostname, std::string *out_port, } } +static std::string GetLastSocketErrorString() { +#if defined(OPENSSL_WINDOWS) + int error = WSAGetLastError(); + char *buffer; + DWORD len = FormatMessageA( + FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_ALLOCATE_BUFFER, 0, error, 0, + reinterpret_cast<char *>(&buffer), 0, nullptr); + if (len == 0) { + char buf[256]; + snprintf(buf, sizeof(buf), "unknown error (0x%x)", error); + return buf; + } + std::string ret(buffer, len); + LocalFree(buffer); + return ret; +#else + return strerror(errno); +#endif +} + +static void PrintSocketError(const char *function) { + // On Windows, |perror| and |errno| are part of the C runtime, while sockets + // are separate, so we must print errors manually. + std::string error = GetLastSocketErrorString(); + fprintf(stderr, "%s: %s\n", function, error.c_str()); +} + // Connect sets |*out_sock| to be a socket connected to the destination given // in |hostname_and_port|, which should be of the form "www.example.com:123". // It returns true on success and false otherwise. @@ -121,7 +148,7 @@ bool Connect(int *out_sock, const std::string &hostname_and_port) { *out_sock = socket(result->ai_family, result->ai_socktype, result->ai_protocol); if (*out_sock < 0) { - perror("socket"); + PrintSocketError("socket"); goto out; } @@ -145,7 +172,7 @@ bool Connect(int *out_sock, const std::string &hostname_and_port) { } if (connect(*out_sock, result->ai_addr, result->ai_addrlen) != 0) { - perror("connect"); + PrintSocketError("connect"); goto out; } ok = true; @@ -188,18 +215,18 @@ bool Listener::Init(const std::string &port) { server_sock_ = socket(addr.sin6_family, SOCK_STREAM, 0); if (server_sock_ < 0) { - perror("socket"); + PrintSocketError("socket"); return false; } if (setsockopt(server_sock_, SOL_SOCKET, SO_REUSEADDR, (const char *)&enable, sizeof(enable)) < 0) { - perror("setsockopt"); + PrintSocketError("setsockopt"); return false; } if (bind(server_sock_, (struct sockaddr *)&addr, sizeof(addr)) != 0) { - perror("connect"); + PrintSocketError("connect"); return false; } @@ -350,7 +377,7 @@ static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready, #else WSAEVENT socket_handle = WSACreateEvent(); if (socket_handle == WSA_INVALID_EVENT || - WSAEventSelect(sock, socket_handle, FD_READ) != 0) { + WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) { WSACloseEvent(socket_handle); return false; } @@ -379,11 +406,26 @@ static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready, #endif } -// PrintErrorCallback is a callback function from OpenSSL's -// |ERR_print_errors_cb| that writes errors to a given |FILE*|. -int PrintErrorCallback(const char *str, size_t len, void *ctx) { - fwrite(str, len, 1, reinterpret_cast<FILE*>(ctx)); - return 1; +void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) { + switch (ssl_err) { + case SSL_ERROR_SSL: + fprintf(file, "%s: %s\n", msg, ERR_reason_error_string(ERR_peek_error())); + break; + case SSL_ERROR_SYSCALL: + if (ret == 0) { + fprintf(file, "%s: peer closed connection\n", msg); + } else { + std::string error = GetLastSocketErrorString(); + fprintf(file, "%s: %s\n", msg, error.c_str()); + } + break; + case SSL_ERROR_ZERO_RETURN: + fprintf(file, "%s: received close_notify\n", msg); + break; + default: + fprintf(file, "%s: unknown error type (%d)\n", msg, ssl_err); + } + ERR_print_errors_fp(file); } bool TransferData(SSL *ssl, int sock) { @@ -427,19 +469,20 @@ bool TransferData(SSL *ssl, int sock) { } #endif int ssl_ret = SSL_write(ssl, buffer, n); - if (!SocketSetNonBlocking(sock, true)) { - return false; - } - if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl, ssl_ret); - fprintf(stderr, "Error while writing: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret); return false; } else if (ssl_ret != n) { fprintf(stderr, "Short write from SSL_write.\n"); return false; } + + // Note we handle errors before restoring the non-blocking state. On + // Windows, |SocketSetNonBlocking| internally clears the last error. + if (!SocketSetNonBlocking(sock, true)) { + return false; + } } if (socket_ready) { @@ -451,8 +494,7 @@ bool TransferData(SSL *ssl, int sock) { if (ssl_err == SSL_ERROR_WANT_READ) { continue; } - fprintf(stderr, "Error while reading: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + PrintSSLError(stderr, "Error while reading", ssl_err, ssl_ret); return false; } else if (ssl_ret == 0) { return true; diff --git a/tool/transport_common.h b/tool/transport_common.h index 492416a..7d45d1c 100644 --- a/tool/transport_common.h +++ b/tool/transport_common.h @@ -53,7 +53,10 @@ void PrintConnectionInfo(BIO *bio, const SSL *ssl); bool SocketSetNonBlocking(int sock, bool is_non_blocking); -int PrintErrorCallback(const char *str, size_t len, void *ctx); +// PrintSSLError prints information about the most recent SSL error to stderr. +// |ssl_err| must be the output of |SSL_get_error| and the |SSL| object must be +// connected to socket from |Connect|. +void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret); bool TransferData(SSL *ssl, int sock); |