diff options
author | Ian Lance Taylor <iant@golang.org> | 2020-10-12 09:46:38 -0700 |
---|---|---|
committer | Ian Lance Taylor <iant@golang.org> | 2020-10-12 09:46:38 -0700 |
commit | 9cd320ea6572c577cdf17ce1f9ea5230b166af6d (patch) | |
tree | d1c8e7c2e09a91ed75f0e5476c648c2e745aa2de /libgo/go/net | |
parent | 4854d721be78358e59367982bdd94461b4be3c5a (diff) | |
parent | 3175d40fc52fb8eb3c3b18cc343d773da24434fb (diff) | |
download | gcc-9cd320ea6572c577cdf17ce1f9ea5230b166af6d.zip gcc-9cd320ea6572c577cdf17ce1f9ea5230b166af6d.tar.gz gcc-9cd320ea6572c577cdf17ce1f9ea5230b166af6d.tar.bz2 |
Merge from trunk revision 3175d40fc52fb8eb3c3b18cc343d773da24434fb.
Diffstat (limited to 'libgo/go/net')
69 files changed, 2456 insertions, 963 deletions
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index d8be1c2..13a312a 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -7,7 +7,6 @@ package net import ( "context" "internal/nettrace" - "internal/poll" "syscall" "time" ) @@ -141,7 +140,7 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er } timeRemaining := deadline.Sub(now) if timeRemaining <= 0 { - return time.Time{}, poll.ErrTimeout + return time.Time{}, errTimeout } // Tentatively allocate equal time to each remaining address. timeout := timeRemaining / time.Duration(addrsRemaining) diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index 493cdfc..0158248 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -9,7 +9,6 @@ package net import ( "bufio" "context" - "internal/poll" "internal/testenv" "io" "os" @@ -441,6 +440,14 @@ func TestDialParallelSpuriousConnection(t *testing.T) { t.Skip("both IPv4 and IPv6 are required") } + var readDeadline time.Time + if td, ok := t.Deadline(); ok { + const arbitraryCleanupMargin = 1 * time.Second + readDeadline = td.Add(-arbitraryCleanupMargin) + } else { + readDeadline = time.Now().Add(5 * time.Second) + } + var wg sync.WaitGroup wg.Add(2) handler := func(dss *dualStackServer, ln Listener) { @@ -450,7 +457,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) { t.Fatal(err) } // The client should close itself, without sending data. - c.SetReadDeadline(time.Now().Add(1 * time.Second)) + c.SetReadDeadline(readDeadline) var b [1]byte if _, err := c.Read(b[:]); err != io.EOF { t.Errorf("got %v; want %v", err, io.EOF) @@ -532,8 +539,8 @@ func TestDialerPartialDeadline(t *testing.T) { {now, noDeadline, 1, noDeadline, nil}, // Step the clock forward and cross the deadline. {now.Add(-1 * time.Millisecond), now, 1, now, nil}, - {now.Add(0 * time.Millisecond), now, 1, noDeadline, poll.ErrTimeout}, - {now.Add(1 * time.Millisecond), now, 1, noDeadline, poll.ErrTimeout}, + {now.Add(0 * time.Millisecond), now, 1, noDeadline, errTimeout}, + {now.Add(1 * time.Millisecond), now, 1, noDeadline, errTimeout}, } for i, tt := range testCases { deadline, err := partialDeadline(tt.now, tt.deadline, tt.addrs) @@ -983,7 +990,7 @@ func TestDialerControl(t *testing.T) { // except that it won't skip testing on non-mobile builders. func mustHaveExternalNetwork(t *testing.T) { t.Helper() - mobile := runtime.GOOS == "android" || runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") + mobile := runtime.GOOS == "android" || runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" if testenv.Builder() == "" || mobile { testenv.MustHaveExternalNetwork(t) } diff --git a/libgo/go/net/dnsclient_test.go b/libgo/go/net/dnsclient_test.go index 3ab2b83..f3ed62d 100644 --- a/libgo/go/net/dnsclient_test.go +++ b/libgo/go/net/dnsclient_test.go @@ -42,7 +42,7 @@ func testUniformity(t *testing.T, size int, margin float64) { rand.Seed(1) data := make([]*SRV, size) for i := 0; i < size; i++ { - data[i] = &SRV{Target: string('a' + i), Weight: 1} + data[i] = &SRV{Target: string('a' + rune(i)), Weight: 1} } checkDistribution(t, data, margin) } diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index da6baf3..5f6c870 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -40,10 +40,10 @@ var ( errInvalidDNSResponse = errors.New("invalid DNS response") errNoAnswerFromDNSServer = errors.New("no answer from DNS server") - // errServerTemporarlyMisbehaving is like errServerMisbehaving, except + // errServerTemporarilyMisbehaving is like errServerMisbehaving, except // that when it gets translated to a DNSError, the IsTemporary field // gets set to true. - errServerTemporarlyMisbehaving = errors.New("server misbehaving") + errServerTemporarilyMisbehaving = errors.New("server misbehaving") ) func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { @@ -206,7 +206,7 @@ func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { // the server is behaving incorrectly or // having temporary trouble. if h.RCode == dnsmessage.RCodeServerFailure { - return errServerTemporarlyMisbehaving + return errServerTemporarilyMisbehaving } return errServerMisbehaving } @@ -278,7 +278,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, Name: name, Server: server, } - if err == errServerTemporarlyMisbehaving { + if err == errServerTemporarilyMisbehaving { dnsErr.IsTemporary = true } if err == errNoSuchHost { diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go index e8f81e8..a89dccc 100644 --- a/libgo/go/net/dnsclient_unix_test.go +++ b/libgo/go/net/dnsclient_unix_test.go @@ -10,7 +10,6 @@ import ( "context" "errors" "fmt" - "internal/poll" "io/ioutil" "os" "path" @@ -480,7 +479,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { break default: time.Sleep(10 * time.Millisecond) - return dnsmessage.Message{}, poll.ErrTimeout + return dnsmessage.Message{}, os.ErrDeadlineExceeded } r := dnsmessage.Message{ Header: dnsmessage.Header{ @@ -993,7 +992,7 @@ func TestRetryTimeout(t *testing.T) { if s == "192.0.2.1:53" { deadline0 = deadline time.Sleep(10 * time.Millisecond) - return dnsmessage.Message{}, poll.ErrTimeout + return dnsmessage.Message{}, os.ErrDeadlineExceeded } if deadline.Equal(deadline0) { @@ -1131,7 +1130,7 @@ func TestStrictErrorsLookupIP(t *testing.T) { } makeTimeout := func() error { return &DNSError{ - Err: poll.ErrTimeout.Error(), + Err: os.ErrDeadlineExceeded.Error(), Name: name, Server: server, IsTimeout: true, @@ -1247,7 +1246,7 @@ func TestStrictErrorsLookupIP(t *testing.T) { Questions: q.Questions, }, nil case resolveTimeout: - return dnsmessage.Message{}, poll.ErrTimeout + return dnsmessage.Message{}, os.ErrDeadlineExceeded default: t.Fatal("Impossible resolveWhich") } @@ -1372,7 +1371,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { switch q.Questions[0].Name.String() { case searchX: - return dnsmessage.Message{}, poll.ErrTimeout + return dnsmessage.Message{}, os.ErrDeadlineExceeded case searchY: return mockTXTResponse(q), nil default: @@ -1387,7 +1386,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) { var wantRRs int if strict { wantErr = &DNSError{ - Err: poll.ErrTimeout.Error(), + Err: os.ErrDeadlineExceeded.Error(), Name: name, Server: server, IsTimeout: true, @@ -1415,7 +1414,7 @@ func TestDNSGoroutineRace(t *testing.T) { fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { time.Sleep(10 * time.Microsecond) - return dnsmessage.Message{}, poll.ErrTimeout + return dnsmessage.Message{}, os.ErrDeadlineExceeded }} r := Resolver{PreferGo: true, Dial: fake.DialContext} diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go index 89dcc2e..8d4a7ff 100644 --- a/libgo/go/net/error_test.go +++ b/libgo/go/net/error_test.go @@ -91,7 +91,7 @@ second: return nil } switch err := nestedErr.(type) { - case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *poll.TimeoutError, UnknownNetworkError: + case *AddrError, addrinfoErrno, *timeoutError, *DNSError, InvalidAddrError, *ParseError, *poll.DeadlineExceededError, UnknownNetworkError: return nil case *os.SyscallError: nestedErr = err.Err @@ -436,7 +436,7 @@ second: goto third } switch nestedErr { - case poll.ErrNetClosing, poll.ErrTimeout, poll.ErrNotPollable: + case poll.ErrNetClosing, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -471,14 +471,14 @@ second: return nil } switch err := nestedErr.(type) { - case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *poll.TimeoutError, UnknownNetworkError: + case *AddrError, addrinfoErrno, *timeoutError, *DNSError, InvalidAddrError, *ParseError, *poll.DeadlineExceededError, UnknownNetworkError: return nil case *os.SyscallError: nestedErr = err.Err goto third } switch nestedErr { - case errCanceled, poll.ErrNetClosing, errMissingAddress, poll.ErrTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF: + case errCanceled, poll.ErrNetClosing, errMissingAddress, errTimeout, os.ErrDeadlineExceeded, ErrWriteToConnected, io.ErrUnexpectedEOF: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -627,7 +627,7 @@ second: goto third } switch nestedErr { - case poll.ErrNetClosing, poll.ErrTimeout, poll.ErrNotPollable: + case poll.ErrNetClosing, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) diff --git a/libgo/go/net/fd_posix.go b/libgo/go/net/fd_posix.go new file mode 100644 index 0000000..b2f99bc --- /dev/null +++ b/libgo/go/net/fd_posix.go @@ -0,0 +1,100 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows + +package net + +import ( + "internal/poll" + "runtime" + "syscall" + "time" +) + +// Network file descriptor. +type netFD struct { + pfd poll.FD + + // immutable until Close + family int + sotype int + isConnected bool // handshake completed or use of association with peer + net string + laddr Addr + raddr Addr +} + +func (fd *netFD) setAddr(laddr, raddr Addr) { + fd.laddr = laddr + fd.raddr = raddr + runtime.SetFinalizer(fd, (*netFD).Close) +} + +func (fd *netFD) Close() error { + runtime.SetFinalizer(fd, nil) + return fd.pfd.Close() +} + +func (fd *netFD) shutdown(how int) error { + err := fd.pfd.Shutdown(how) + runtime.KeepAlive(fd) + return wrapSyscallError("shutdown", err) +} + +func (fd *netFD) closeRead() error { + return fd.shutdown(syscall.SHUT_RD) +} + +func (fd *netFD) closeWrite() error { + return fd.shutdown(syscall.SHUT_WR) +} + +func (fd *netFD) Read(p []byte) (n int, err error) { + n, err = fd.pfd.Read(p) + runtime.KeepAlive(fd) + return n, wrapSyscallError(readSyscallName, err) +} + +func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { + n, sa, err = fd.pfd.ReadFrom(p) + runtime.KeepAlive(fd) + return n, sa, wrapSyscallError(readFromSyscallName, err) +} + +func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { + n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob) + runtime.KeepAlive(fd) + return n, oobn, flags, sa, wrapSyscallError(readMsgSyscallName, err) +} + +func (fd *netFD) Write(p []byte) (nn int, err error) { + nn, err = fd.pfd.Write(p) + runtime.KeepAlive(fd) + return nn, wrapSyscallError(writeSyscallName, err) +} + +func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { + n, err = fd.pfd.WriteTo(p, sa) + runtime.KeepAlive(fd) + return n, wrapSyscallError(writeToSyscallName, err) +} + +func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { + n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) + runtime.KeepAlive(fd) + return n, oobn, wrapSyscallError(writeMsgSyscallName, err) +} + +func (fd *netFD) SetDeadline(t time.Time) error { + return fd.pfd.SetDeadline(t) +} + +func (fd *netFD) SetReadDeadline(t time.Time) error { + return fd.pfd.SetReadDeadline(t) +} + +func (fd *netFD) SetWriteDeadline(t time.Time) error { + return fd.pfd.SetWriteDeadline(t) +} diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index 117f5a9..ad79c06 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -12,21 +12,16 @@ import ( "os" "runtime" "syscall" - "time" ) -// Network file descriptor. -type netFD struct { - pfd poll.FD - - // immutable until Close - family int - sotype int - isConnected bool // handshake completed or use of association with peer - net string - laddr Addr - raddr Addr -} +const ( + readSyscallName = "read" + readFromSyscallName = "recvfrom" + readMsgSyscallName = "recvmsg" + writeSyscallName = "write" + writeToSyscallName = "sendto" + writeMsgSyscallName = "sendmsg" +) func newFD(sysfd, family, sotype int, net string) (*netFD, error) { ret := &netFD{ @@ -46,12 +41,6 @@ func (fd *netFD) init() error { return fd.pfd.Init(fd.net, true) } -func (fd *netFD) setAddr(laddr, raddr Addr) { - fd.laddr = laddr - fd.raddr = raddr - runtime.SetFinalizer(fd, (*netFD).Close) -} - func (fd *netFD) name() string { var ls, rs string if fd.laddr != nil { @@ -179,61 +168,6 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc } } -func (fd *netFD) Close() error { - runtime.SetFinalizer(fd, nil) - return fd.pfd.Close() -} - -func (fd *netFD) shutdown(how int) error { - err := fd.pfd.Shutdown(how) - runtime.KeepAlive(fd) - return wrapSyscallError("shutdown", err) -} - -func (fd *netFD) closeRead() error { - return fd.shutdown(syscall.SHUT_RD) -} - -func (fd *netFD) closeWrite() error { - return fd.shutdown(syscall.SHUT_WR) -} - -func (fd *netFD) Read(p []byte) (n int, err error) { - n, err = fd.pfd.Read(p) - runtime.KeepAlive(fd) - return n, wrapSyscallError("read", err) -} - -func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { - n, sa, err = fd.pfd.ReadFrom(p) - runtime.KeepAlive(fd) - return n, sa, wrapSyscallError("recvfrom", err) -} - -func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { - n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob) - runtime.KeepAlive(fd) - return n, oobn, flags, sa, wrapSyscallError("recvmsg", err) -} - -func (fd *netFD) Write(p []byte) (nn int, err error) { - nn, err = fd.pfd.Write(p) - runtime.KeepAlive(fd) - return nn, wrapSyscallError("write", err) -} - -func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { - n, err = fd.pfd.WriteTo(p, sa) - runtime.KeepAlive(fd) - return n, wrapSyscallError("sendto", err) -} - -func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { - n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) - runtime.KeepAlive(fd) - return n, oobn, wrapSyscallError("sendmsg", err) -} - func (fd *netFD) accept() (netfd *netFD, err error) { d, rsa, errcall, err := fd.pfd.Accept() if err != nil { @@ -267,15 +201,3 @@ func (fd *netFD) dup() (f *os.File, err error) { return os.NewFile(uintptr(ns), fd.name()), nil } - -func (fd *netFD) SetDeadline(t time.Time) error { - return fd.pfd.SetDeadline(t) -} - -func (fd *netFD) SetReadDeadline(t time.Time) error { - return fd.pfd.SetReadDeadline(t) -} - -func (fd *netFD) SetWriteDeadline(t time.Time) error { - return fd.pfd.SetWriteDeadline(t) -} diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go index 3cc4c7a6..030b6a1 100644 --- a/libgo/go/net/fd_windows.go +++ b/libgo/go/net/fd_windows.go @@ -10,10 +10,18 @@ import ( "os" "runtime" "syscall" - "time" "unsafe" ) +const ( + readSyscallName = "wsarecv" + readFromSyscallName = "wsarecvfrom" + readMsgSyscallName = "wsarecvmsg" + writeSyscallName = "wsasend" + writeToSyscallName = "wsasendto" + writeMsgSyscallName = "wsasendmsg" +) + // canUseConnectEx reports whether we can use the ConnectEx Windows API call // for the given network type. func canUseConnectEx(net string) bool { @@ -25,19 +33,6 @@ func canUseConnectEx(net string) bool { return false } -// Network file descriptor. -type netFD struct { - pfd poll.FD - - // immutable until Close - family int - sotype int - isConnected bool // handshake completed or use of association with peer - net string - laddr Addr - raddr Addr -} - func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) { ret := &netFD{ pfd: poll.FD{ @@ -60,12 +55,6 @@ func (fd *netFD) init() error { return err } -func (fd *netFD) setAddr(laddr, raddr Addr) { - fd.laddr = laddr - fd.raddr = raddr - runtime.SetFinalizer(fd, (*netFD).Close) -} - // Always returns nil for connected peer address result. func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) { // Do not need to call fd.writeLock here, @@ -129,43 +118,6 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall. return nil, os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.pfd.Sysfd)), int32(unsafe.Sizeof(fd.pfd.Sysfd)))) } -func (fd *netFD) Close() error { - runtime.SetFinalizer(fd, nil) - return fd.pfd.Close() -} - -func (fd *netFD) shutdown(how int) error { - err := fd.pfd.Shutdown(how) - runtime.KeepAlive(fd) - return err -} - -func (fd *netFD) closeRead() error { - return fd.shutdown(syscall.SHUT_RD) -} - -func (fd *netFD) closeWrite() error { - return fd.shutdown(syscall.SHUT_WR) -} - -func (fd *netFD) Read(buf []byte) (int, error) { - n, err := fd.pfd.Read(buf) - runtime.KeepAlive(fd) - return n, wrapSyscallError("wsarecv", err) -} - -func (fd *netFD) readFrom(buf []byte) (int, syscall.Sockaddr, error) { - n, sa, err := fd.pfd.ReadFrom(buf) - runtime.KeepAlive(fd) - return n, sa, wrapSyscallError("wsarecvfrom", err) -} - -func (fd *netFD) Write(buf []byte) (int, error) { - n, err := fd.pfd.Write(buf) - runtime.KeepAlive(fd) - return n, wrapSyscallError("wsasend", err) -} - func (c *conn) writeBuffers(v *Buffers) (int64, error) { if !c.ok() { return 0, syscall.EINVAL @@ -183,12 +135,6 @@ func (fd *netFD) writeBuffers(buf *Buffers) (int64, error) { return n, wrapSyscallError("wsasend", err) } -func (fd *netFD) writeTo(buf []byte, sa syscall.Sockaddr) (int, error) { - n, err := fd.pfd.WriteTo(buf, sa) - runtime.KeepAlive(fd) - return n, wrapSyscallError("wsasendto", err) -} - func (fd *netFD) accept() (*netFD, error) { s, rawsa, rsan, errcall, err := fd.pfd.Accept(func() (syscall.Handle, error) { return sysSocket(fd.family, fd.sotype, 0) @@ -224,33 +170,9 @@ func (fd *netFD) accept() (*netFD, error) { return netfd, nil } -func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { - n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob) - runtime.KeepAlive(fd) - return n, oobn, flags, sa, wrapSyscallError("wsarecvmsg", err) -} - -func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { - n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) - runtime.KeepAlive(fd) - return n, oobn, wrapSyscallError("wsasendmsg", err) -} - // Unimplemented functions. func (fd *netFD) dup() (*os.File, error) { // TODO: Implement this return nil, syscall.EWINDOWS } - -func (fd *netFD) SetDeadline(t time.Time) error { - return fd.pfd.SetDeadline(t) -} - -func (fd *netFD) SetReadDeadline(t time.Time) error { - return fd.pfd.SetReadDeadline(t) -} - -func (fd *netFD) SetWriteDeadline(t time.Time) error { - return fd.pfd.SetWriteDeadline(t) -} diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go index cb140f8..61de616 100644 --- a/libgo/go/net/http/cgi/child.go +++ b/libgo/go/net/http/cgi/child.go @@ -89,8 +89,6 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { r.Header.Add(strings.ReplaceAll(k[5:], "_", "-"), v) } - // TODO: cookies. parsing them isn't exported, though. - uriStr := params["REQUEST_URI"] if uriStr == "" { // Fallback to SCRIPT_NAME, PATH_INFO and QUERY_STRING. @@ -165,10 +163,12 @@ func Serve(handler http.Handler) error { } type response struct { - req *http.Request - header http.Header - bufw *bufio.Writer - headerSent bool + req *http.Request + header http.Header + code int + wroteHeader bool + wroteCGIHeader bool + bufw *bufio.Writer } func (r *response) Flush() { @@ -180,26 +180,38 @@ func (r *response) Header() http.Header { } func (r *response) Write(p []byte) (n int, err error) { - if !r.headerSent { + if !r.wroteHeader { r.WriteHeader(http.StatusOK) } + if !r.wroteCGIHeader { + r.writeCGIHeader(p) + } return r.bufw.Write(p) } func (r *response) WriteHeader(code int) { - if r.headerSent { + if r.wroteHeader { // Note: explicitly using Stderr, as Stdout is our HTTP output. fmt.Fprintf(os.Stderr, "CGI attempted to write header twice on request for %s", r.req.URL) return } - r.headerSent = true - fmt.Fprintf(r.bufw, "Status: %d %s\r\n", code, http.StatusText(code)) + r.wroteHeader = true + r.code = code +} - // Set a default Content-Type +// writeCGIHeader finalizes the header sent to the client and writes it to the output. +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. +func (r *response) writeCGIHeader(p []byte) { + if r.wroteCGIHeader { + return + } + r.wroteCGIHeader = true + fmt.Fprintf(r.bufw, "Status: %d %s\r\n", r.code, http.StatusText(r.code)) if _, hasType := r.header["Content-Type"]; !hasType { - r.header.Add("Content-Type", "text/html; charset=utf-8") + r.header.Set("Content-Type", http.DetectContentType(p)) } - r.header.Write(r.bufw) r.bufw.WriteString("\r\n") r.bufw.Flush() diff --git a/libgo/go/net/http/cgi/child_test.go b/libgo/go/net/http/cgi/child_test.go index 14e0af4..f6ecb6e 100644 --- a/libgo/go/net/http/cgi/child_test.go +++ b/libgo/go/net/http/cgi/child_test.go @@ -7,6 +7,11 @@ package cgi import ( + "bufio" + "bytes" + "net/http" + "net/http/httptest" + "strings" "testing" ) @@ -148,3 +153,67 @@ func TestRequestWithoutRemotePort(t *testing.T) { t.Errorf("RemoteAddr: got %q; want %q", g, e) } } + +type countingWriter int + +func (c *countingWriter) Write(p []byte) (int, error) { + *c += countingWriter(len(p)) + return len(p), nil +} +func (c *countingWriter) WriteString(p string) (int, error) { + *c += countingWriter(len(p)) + return len(p), nil +} + +func TestResponse(t *testing.T) { + var tests = []struct { + name string + body string + wantCT string + }{ + { + name: "no body", + wantCT: "text/plain; charset=utf-8", + }, + { + name: "html", + body: "<html><head><title>test page</title></head><body>This is a body</body></html>", + wantCT: "text/html; charset=utf-8", + }, + { + name: "text", + body: strings.Repeat("gopher", 86), + wantCT: "text/plain; charset=utf-8", + }, + { + name: "jpg", + body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024), + wantCT: "image/jpeg", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + resp := response{ + req: httptest.NewRequest("GET", "/", nil), + header: http.Header{}, + bufw: bufio.NewWriter(&buf), + } + n, err := resp.Write([]byte(tt.body)) + if err != nil { + t.Errorf("Write: unexpected %v", err) + } + if want := len(tt.body); n != want { + t.Errorf("reported short Write: got %v want %v", n, want) + } + resp.writeCGIHeader(nil) + resp.Flush() + if got := resp.Header().Get("Content-Type"); got != tt.wantCT { + t.Errorf("wrong content-type: got %q, want %q", got, tt.wantCT) + } + if !bytes.HasSuffix(buf.Bytes(), []byte(tt.body)) { + t.Errorf("body was not correctly written") + } + }) + } +} diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go index 58e9f71..863f406 100644 --- a/libgo/go/net/http/cgi/host.go +++ b/libgo/go/net/http/cgi/host.go @@ -21,6 +21,7 @@ import ( "log" "net" "net/http" + "net/textproto" "os" "os/exec" "path/filepath" @@ -28,20 +29,29 @@ import ( "runtime" "strconv" "strings" + + "golang.org/x/net/http/httpguts" ) var trailingPort = regexp.MustCompile(`:([0-9]+)$`) -var osDefaultInheritEnv = map[string][]string{ - "darwin": {"DYLD_LIBRARY_PATH"}, - "freebsd": {"LD_LIBRARY_PATH"}, - "hpux": {"LD_LIBRARY_PATH", "SHLIB_PATH"}, - "irix": {"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}, - "linux": {"LD_LIBRARY_PATH"}, - "openbsd": {"LD_LIBRARY_PATH"}, - "solaris": {"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}, - "windows": {"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, -} +var osDefaultInheritEnv = func() []string { + switch runtime.GOOS { + case "darwin": + return []string{"DYLD_LIBRARY_PATH"} + case "linux", "freebsd", "openbsd": + return []string{"LD_LIBRARY_PATH"} + case "hpux": + return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"} + case "irix": + return []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"} + case "solaris": + return []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"} + case "windows": + return []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"} + } + return nil +}() // Handler runs an executable in a subprocess with a CGI environment. type Handler struct { @@ -183,7 +193,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } - for _, e := range osDefaultInheritEnv[runtime.GOOS] { + for _, e := range osDefaultInheritEnv { if v := os.Getenv(e); v != "" { env = append(env, e+"="+v) } @@ -269,8 +279,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { continue } header, val := parts[0], parts[1] - header = strings.TrimSpace(header) - val = strings.TrimSpace(val) + if !httpguts.ValidHeaderFieldName(header) { + h.printf("cgi: invalid header name: %q", header) + continue + } + val = textproto.TrimString(val) switch { case header == "Status": if len(val) < 3 { diff --git a/libgo/go/net/http/cgi/integration_test.go b/libgo/go/net/http/cgi/integration_test.go index 32d59c0..295c3b8 100644 --- a/libgo/go/net/http/cgi/integration_test.go +++ b/libgo/go/net/http/cgi/integration_test.go @@ -16,7 +16,9 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "os" + "strings" "testing" "time" ) @@ -52,7 +54,7 @@ func TestHostingOurselves(t *testing.T) { } replay := runCgiTest(t, h, "GET /test.go?foo=bar&a=b HTTP/1.0\nHost: example.com\n\n", expectedMap) - if expected, got := "text/html; charset=utf-8", replay.Header().Get("Content-Type"); got != expected { + if expected, got := "text/plain; charset=utf-8", replay.Header().Get("Content-Type"); got != expected { t.Errorf("got a Content-Type of %q; expected %q", got, expected) } if expected, got := "X-Test-Value", replay.Header().Get("X-Test-Header"); got != expected { @@ -152,6 +154,51 @@ func TestChildOnlyHeaders(t *testing.T) { } } +func TestChildContentType(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + var tests = []struct { + name string + body string + wantCT string + }{ + { + name: "no body", + wantCT: "text/plain; charset=utf-8", + }, + { + name: "html", + body: "<html><head><title>test page</title></head><body>This is a body</body></html>", + wantCT: "text/html; charset=utf-8", + }, + { + name: "text", + body: strings.Repeat("gopher", 86), + wantCT: "text/plain; charset=utf-8", + }, + { + name: "jpg", + body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024), + wantCT: "image/jpeg", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expectedMap := map[string]string{"_body": tt.body} + req := fmt.Sprintf("GET /test.go?exact-body=%s HTTP/1.0\nHost: example.com\n\n", url.QueryEscape(tt.body)) + replay := runCgiTest(t, h, req, expectedMap) + if got := replay.Header().Get("Content-Type"); got != tt.wantCT { + t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT) + } + }) + } +} + // golang.org/issue/7198 func Test500WithNoHeaders(t *testing.T) { want500Test(t, "/immediate-disconnect") } func Test500WithNoContentType(t *testing.T) { want500Test(t, "/no-content-type") } @@ -203,6 +250,10 @@ func TestBeChildCGIProcess(t *testing.T) { if req.FormValue("no-body") == "1" { return } + if eb, ok := req.Form["exact-body"]; ok { + io.WriteString(rw, eb[0]) + return + } if req.FormValue("write-forever") == "1" { io.Copy(rw, neverEnding('a')) for { diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index a496f1c..3860d97 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -216,7 +216,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d if req.RequestURI != "" { req.closeBody() - return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests.") + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests") } // forkReq forks req into a shallow clone of ireq the first @@ -265,6 +265,25 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d } return nil, didTimeout, err } + if resp == nil { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a nil *Response with a nil error", rt) + } + if resp.Body == nil { + // The documentation on the Body field says “The http Client and Transport + // guarantee that Body is always non-nil, even on responses without a body + // or responses with a zero-length body.” Unfortunately, we didn't document + // that same constraint for arbitrary RoundTripper implementations, and + // RoundTripper implementations in the wild (mostly in tests) assume that + // they can use a nil Body to mean an empty one (similar to Request.Body). + // (See https://golang.org/issue/38095.) + // + // If the ContentLength allows the Body to be empty, fill in an empty one + // here to ensure that it is non-nil. + if resp.ContentLength > 0 && req.Method != "HEAD" { + return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength) + } + resp.Body = ioutil.NopCloser(strings.NewReader("")) + } if !deadline.IsZero() { resp.Body = &cancelTimerBody{ stop: stopTimer, diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 2b4f53f..80807fa 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -1991,3 +1991,38 @@ func testClientDoCanceledVsTimeout(t *testing.T, h2 bool) { }) } } + +type nilBodyRoundTripper struct{} + +func (nilBodyRoundTripper) RoundTrip(req *Request) (*Response, error) { + return &Response{ + StatusCode: StatusOK, + Status: StatusText(StatusOK), + Body: nil, + Request: req, + }, nil +} + +func TestClientPopulatesNilResponseBody(t *testing.T) { + c := &Client{Transport: nilBodyRoundTripper{}} + + resp, err := c.Get("http://localhost/anything") + if err != nil { + t.Fatalf("Client.Get rejected Response with nil Body: %v", err) + } + + if resp.Body == nil { + t.Fatalf("Client failed to provide a non-nil Body as documented") + } + defer func() { + if err := resp.Body.Close(); err != nil { + t.Fatalf("error from Close on substitute Response.Body: %v", err) + } + }() + + if b, err := ioutil.ReadAll(resp.Body); err != nil { + t.Errorf("read error from substitute Response.Body: %v", err) + } else if len(b) != 0 { + t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 5c572d6d..d7a8f5e 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -7,6 +7,7 @@ package http import ( "log" "net" + "net/textproto" "strconv" "strings" "time" @@ -60,11 +61,11 @@ func readSetCookies(h Header) []*Cookie { } cookies := make([]*Cookie, 0, cookieCount) for _, line := range h["Set-Cookie"] { - parts := strings.Split(strings.TrimSpace(line), ";") + parts := strings.Split(textproto.TrimString(line), ";") if len(parts) == 1 && parts[0] == "" { continue } - parts[0] = strings.TrimSpace(parts[0]) + parts[0] = textproto.TrimString(parts[0]) j := strings.Index(parts[0], "=") if j < 0 { continue @@ -83,7 +84,7 @@ func readSetCookies(h Header) []*Cookie { Raw: line, } for i := 1; i < len(parts); i++ { - parts[i] = strings.TrimSpace(parts[i]) + parts[i] = textproto.TrimString(parts[i]) if len(parts[i]) == 0 { continue } @@ -242,7 +243,7 @@ func readCookies(h Header, filter string) []*Cookie { cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";")) for _, line := range lines { - line = strings.TrimSpace(line) + line = textproto.TrimString(line) var part string for len(line) > 0 { // continue since we have rest @@ -251,7 +252,7 @@ func readCookies(h Header, filter string) []*Cookie { } else { part, line = line, "" } - part = strings.TrimSpace(part) + part = textproto.TrimString(part) if len(part) == 0 { continue } diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go index 30a6b2c..a31273b 100644 --- a/libgo/go/net/http/fcgi/child.go +++ b/libgo/go/net/http/fcgi/child.go @@ -74,10 +74,12 @@ func (r *request) parseParams() { // response implements http.ResponseWriter. type response struct { - req *request - header http.Header - w *bufWriter - wroteHeader bool + req *request + header http.Header + code int + wroteHeader bool + wroteCGIHeader bool + w *bufWriter } func newResponse(c *child, req *request) *response { @@ -92,11 +94,14 @@ func (r *response) Header() http.Header { return r.header } -func (r *response) Write(data []byte) (int, error) { +func (r *response) Write(p []byte) (n int, err error) { if !r.wroteHeader { r.WriteHeader(http.StatusOK) } - return r.w.Write(data) + if !r.wroteCGIHeader { + r.writeCGIHeader(p) + } + return r.w.Write(p) } func (r *response) WriteHeader(code int) { @@ -104,22 +109,34 @@ func (r *response) WriteHeader(code int) { return } r.wroteHeader = true + r.code = code if code == http.StatusNotModified { // Must not have body. r.header.Del("Content-Type") r.header.Del("Content-Length") r.header.Del("Transfer-Encoding") - } else if r.header.Get("Content-Type") == "" { - r.header.Set("Content-Type", "text/html; charset=utf-8") } - if r.header.Get("Date") == "" { r.header.Set("Date", time.Now().UTC().Format(http.TimeFormat)) } +} - fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) +// writeCGIHeader finalizes the header sent to the client and writes it to the output. +// p is not written by writeHeader, but is the first chunk of the body +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. +func (r *response) writeCGIHeader(p []byte) { + if r.wroteCGIHeader { + return + } + r.wroteCGIHeader = true + fmt.Fprintf(r.w, "Status: %d %s\r\n", r.code, http.StatusText(r.code)) + if _, hasType := r.header["Content-Type"]; r.code != http.StatusNotModified && !hasType { + r.header.Set("Content-Type", http.DetectContentType(p)) + } r.header.Write(r.w) r.w.WriteString("\r\n") + r.w.Flush() } func (r *response) Flush() { @@ -290,6 +307,8 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { httpReq = httpReq.WithContext(envVarCtx) c.handler.ServeHTTP(r, httpReq) } + // Make sure we serve something even if nothing was written to r + r.Write(nil) r.Close() c.mu.Lock() delete(c.requests, req.reqId) diff --git a/libgo/go/net/http/fcgi/fcgi_test.go b/libgo/go/net/http/fcgi/fcgi_test.go index e9d2b34..59246c2 100644 --- a/libgo/go/net/http/fcgi/fcgi_test.go +++ b/libgo/go/net/http/fcgi/fcgi_test.go @@ -10,6 +10,7 @@ import ( "io" "io/ioutil" "net/http" + "strings" "testing" ) @@ -344,3 +345,55 @@ func TestChildServeReadsEnvVars(t *testing.T) { <-done } } + +func TestResponseWriterSniffsContentType(t *testing.T) { + t.Skip("this test is flaky, see Issue 41167") + var tests = []struct { + name string + body string + wantCT string + }{ + { + name: "no body", + wantCT: "text/plain; charset=utf-8", + }, + { + name: "html", + body: "<html><head><title>test page</title></head><body>This is a body</body></html>", + wantCT: "text/html; charset=utf-8", + }, + { + name: "text", + body: strings.Repeat("gopher", 86), + wantCT: "text/plain; charset=utf-8", + }, + { + name: "jpg", + body: "\xFF\xD8\xFF" + strings.Repeat("B", 1024), + wantCT: "image/jpeg", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := make([]byte, len(streamFullRequestStdin)) + copy(input, streamFullRequestStdin) + rc := nopWriteCloser{bytes.NewBuffer(input)} + done := make(chan bool) + var resp *response + c := newChild(rc, http.HandlerFunc(func( + w http.ResponseWriter, + r *http.Request, + ) { + io.WriteString(w, tt.body) + resp = w.(*response) + done <- true + })) + defer c.cleanUp() + go c.serve() + <-done + if got := resp.Header().Get("Content-Type"); got != tt.wantCT { + t.Errorf("got a Content-Type of %q; expected it to start with %q", got, tt.wantCT) + } + }) + } +} diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index d214485..922706a 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -30,11 +30,13 @@ import ( // value is a filename on the native file system, not a URL, so it is separated // by filepath.Separator, which isn't necessarily '/'. // -// Note that Dir will allow access to files and directories starting with a -// period, which could expose sensitive directories like a .git directory or -// sensitive files like .htpasswd. To exclude files with a leading period, -// remove the files/directories from the server or create a custom FileSystem -// implementation. +// Note that Dir could expose sensitive files and directories. Dir will follow +// symlinks pointing out of the directory tree, which can be especially dangerous +// if serving from a directory in which users are able to create arbitrary symlinks. +// Dir will also allow access to files and directories starting with a period, +// which could expose sensitive directories like .git or sensitive files like +// .htpasswd. To exclude files with a leading period, remove the files/directories +// from the server or create a custom FileSystem implementation. // // An empty Dir is treated as ".". type Dir string @@ -411,6 +413,7 @@ func checkIfNoneMatch(w ResponseWriter, r *Request) condResult { } if buf[0] == ',' { buf = buf[1:] + continue } if buf[0] == '*' { return condFalse @@ -756,7 +759,7 @@ func parseRange(s string, size int64) ([]httpRange, error) { var ranges []httpRange noOverlap := false for _, ra := range strings.Split(s[len(b):], ",") { - ra = strings.TrimSpace(ra) + ra = textproto.TrimString(ra) if ra == "" { continue } @@ -764,7 +767,7 @@ func parseRange(s string, size int64) ([]httpRange, error) { if i < 0 { return nil, errors.New("invalid range") } - start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) + start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:]) var r httpRange if start == "" { // If no start is specified, end specifies the diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index 435e34b..c082cee 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -849,6 +849,15 @@ func TestServeContent(t *testing.T) { wantStatus: 200, wantContentType: "text/css; charset=utf-8", }, + "if_none_match_malformed": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `,`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, "range_good": { file: "testdata/style.css", serveETag: `"A"`, diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index f03dbba..779da4f 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -783,6 +783,7 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis // dialCall is an in-flight Transport dial call to a host. type http2dialCall struct { + _ http2incomparable p *http2clientConnPool done chan struct{} // closed when done res *http2ClientConn // valid after done is closed @@ -856,6 +857,7 @@ func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c * } type http2addConnCall struct { + _ http2incomparable p *http2clientConnPool done chan struct{} // closed when done err error @@ -876,12 +878,6 @@ func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { close(c.done) } -func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) { - p.mu.Lock() - p.addConnLocked(key, cc) - p.mu.Unlock() -} - // p.mu must be held func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) { for _, v := range p.conns[key] { @@ -1219,6 +1215,8 @@ var ( // flow is the flow control window's size. type http2flow struct { + _ http2incomparable + // n is the number of DATA bytes we're allowed to send. // A flow is kept both on a conn and a per-stream. n int32 @@ -3245,11 +3243,6 @@ func (s http2SettingID) String() string { return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) } -var ( - http2errInvalidHeaderFieldName = errors.New("http2: invalid header field name") - http2errInvalidHeaderFieldValue = errors.New("http2: invalid header field value") -) - // validWireHeaderFieldName reports whether v is a valid header field // name (key). See httpguts.ValidHeaderName for the base rules. // @@ -3320,6 +3313,7 @@ func (cw http2closeWaiter) Wait() { // Its buffered writer is lazily allocated as needed, to minimize // idle memory usage with many connections. type http2bufferedWriter struct { + _ http2incomparable w io.Writer // immutable bw *bufio.Writer // non-nil when data is buffered } @@ -3392,6 +3386,7 @@ func http2bodyAllowedForStatus(status int) bool { } type http2httpError struct { + _ http2incomparable msg string timeout bool } @@ -3460,6 +3455,11 @@ func http2validPseudoPath(v string) bool { return (len(v) > 0 && v[0] == '/') || v == "*" } +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type http2incomparable [0]func() + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -4147,13 +4147,10 @@ type http2stream struct { cancelCtx func() // owned by serverConn's serve loop: - bodyBytes int64 // body bytes seen so far - declBodyBytes int64 // or -1 if undeclared - flow http2flow // limits writing from Handler to client - inflow http2flow // what the client is allowed to POST/etc to us - parent *http2stream // or nil - numTrailerValues int64 - weight uint8 + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow http2flow // limits writing from Handler to client + inflow http2flow // what the client is allowed to POST/etc to us state http2streamState resetQueued bool // RST_STREAM queued for write; set by sc.resetStream gotTrailerHeader bool // HEADER frame for trailers was seen @@ -4333,6 +4330,7 @@ func (sc *http2serverConn) readFrames() { // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. type http2frameWriteResult struct { + _ http2incomparable wr http2FrameWriteRequest // what was written (or attempted) err error // result of the writeFrame call } @@ -4343,7 +4341,7 @@ type http2frameWriteResult struct { // serverConn. func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { err := wr.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wr, err} + sc.wroteFrameCh <- http2frameWriteResult{wr: wr, err: err} } func (sc *http2serverConn) closeAllStreamsOnConnClose() { @@ -4735,7 +4733,7 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { if wr.write.staysWithinBuffer(sc.bw.Available()) { sc.writingFrameAsync = false err := wr.write.writeFrame(sc) - sc.wroteFrame(http2frameWriteResult{wr, err}) + sc.wroteFrame(http2frameWriteResult{wr: wr, err: err}) } else { sc.writingFrameAsync = true go sc.writeFrameAsync(wr) @@ -5849,6 +5847,7 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { // requestBody is the Handler's Request.Body type. // Read and Close may be called concurrently. type http2requestBody struct { + _ http2incomparable stream *http2stream conn *http2serverConn closed bool // for use by Close only @@ -6592,7 +6591,7 @@ type http2Transport struct { // send in the initial settings frame. It is how many bytes // of response headers are allowed. Unlike the http2 spec, zero here // means to use a default limit (currently 10MB). If you actually - // want to advertise an ulimited value to the peer, Transport + // want to advertise an unlimited value to the peer, Transport // interprets the highest possible value here (0xffffffff or 1<<32-1) // to mean no limit. MaxHeaderListSize uint32 @@ -7416,7 +7415,7 @@ func http2commaSeparatedTrailers(req *Request) (string, error) { k = CanonicalHeaderKey(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": - return "", &http2badStringError{"invalid Trailer key", k} + return "", fmt.Errorf("invalid Trailer key %q", k) } keys = append(keys, k) } @@ -7909,13 +7908,6 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er } } -type http2badStringError struct { - what string - str string -} - -func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } - // requires cc.mu be held. func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { cc.hbuf.Reset() @@ -8131,6 +8123,7 @@ func (cc *http2ClientConn) writeHeader(name, value string) { } type http2resAndError struct { + _ http2incomparable res *Response err error } @@ -8178,6 +8171,7 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. type http2clientConnReadLoop struct { + _ http2incomparable cc *http2ClientConn closeWhenIdle bool } @@ -8407,7 +8401,9 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header") } - header := make(Header) + regularFields := f.RegularFields() + strs := make([]string, len(regularFields)) + header := make(Header, len(regularFields)) res := &Response{ Proto: "HTTP/2.0", ProtoMajor: 2, @@ -8415,7 +8411,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http StatusCode: statusCode, Status: status + " " + StatusText(statusCode), } - for _, hf := range f.RegularFields() { + for _, hf := range regularFields { key := CanonicalHeaderKey(hf.Name) if key == "Trailer" { t := res.Trailer @@ -8427,7 +8423,18 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http t[CanonicalHeaderKey(v)] = nil }) } else { - header[key] = append(header[key], hf.Value) + vv := header[key] + if vv == nil && len(strs) > 0 { + // More than likely this will be a single-element key. + // Most headers aren't multi-valued. + // Set the capacity on strs[0] to 1, so any future append + // won't extend the slice into the other strings. + vv, strs = strs[:1:1], strs[1:] + vv[0] = hf.Value + header[key] = vv + } else { + header[key] = append(vv, hf.Value) + } } } @@ -8713,8 +8720,6 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { return nil } -var http2errInvalidTrailers = errors.New("http2: invalid trailers") - func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { // TODO: check that any declared content-length matches, like // server.go's (*stream).endStream method. @@ -8945,7 +8950,6 @@ func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, var ( http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit") http2errRequestHeaderListSize = errors.New("http2: request header list larger than peer's advertised limit") - http2errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers") ) func (cc *http2ClientConn) logf(format string, args ...interface{}) { @@ -8984,6 +8988,7 @@ func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { retur // gzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read type http2gzipReader struct { + _ http2incomparable body io.ReadCloser // underlying Response.Body zr *gzip.Reader // lazily-initialized gzip reader zerr error // sticky error diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go index 89e86d8..4c5054b 100644 --- a/libgo/go/net/http/http.go +++ b/libgo/go/net/http/http.go @@ -16,6 +16,11 @@ import ( "golang.org/x/net/http/httpguts" ) +// incomparable is a zero-width, non-comparable type. Adding it to a struct +// makes that struct also non-comparable, and generally doesn't add +// any size (as long as it's first). +type incomparable [0]func() + // maxInt64 is the effective "infinite" value for the Server and // Transport's byte-limiting readers. const maxInt64 = 1<<63 - 1 diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index d0bc0fa..66e67e7 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -9,6 +9,7 @@ import ( "fmt" "io/ioutil" "net/http" + "net/textproto" "strconv" "strings" @@ -221,13 +222,13 @@ func (rw *ResponseRecorder) Result() *http.Response { // This a modified version of same function found in net/http/transfer.go. This // one just ignores an invalid header. func parseContentLength(cl string) int64 { - cl = strings.TrimSpace(cl) + cl = textproto.TrimString(cl) if cl == "" { return -1 } - n, err := strconv.ParseInt(cl, 10, 64) + n, err := strconv.ParseUint(cl, 10, 63) if err != nil { return -1 } - return n + return int64(n) } diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go index 0986554..e953489 100644 --- a/libgo/go/net/http/httptest/recorder_test.go +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -310,3 +310,39 @@ func TestRecorder(t *testing.T) { }) } } + +// issue 39017 - disallow Content-Length values such as "+3" +func TestParseContentLength(t *testing.T) { + tests := []struct { + cl string + want int64 + }{ + { + cl: "3", + want: 3, + }, + { + cl: "+3", + want: -1, + }, + { + cl: "-3", + want: -1, + }, + { + // max int64, for safe conversion before returning + cl: "9223372036854775807", + want: 9223372036854775807, + }, + { + cl: "9223372036854775808", + want: -1, + }, + } + + for _, tt := range tests { + if got := parseContentLength(tt.cl); got != tt.want { + t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want) + } + } +} diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 4d6a085..3f48fab 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -13,6 +13,7 @@ import ( "log" "net" "net/http" + "net/textproto" "net/url" "strings" "sync" @@ -25,10 +26,15 @@ import ( // sends it to another server, proxying the response back to the // client. // -// ReverseProxy automatically sets the client IP as the value of the +// ReverseProxy by default sets the client IP as the value of the // X-Forwarded-For header. +// // If an X-Forwarded-For header already exists, the client IP is -// appended to the existing values. +// appended to the existing values. As a special case, if the header +// exists in the Request.Header map but has a nil value (such as when +// set by the Director func), the X-Forwarded-For header is +// not modified. +// // To prevent IP spoofing, be sure to delete any pre-existing // X-Forwarded-For header coming from the client or // an untrusted proxy. @@ -105,6 +111,27 @@ func singleJoiningSlash(a, b string) string { return a + b } +func joinURLPath(a, b *url.URL) (path, rawpath string) { + if a.RawPath == "" && b.RawPath == "" { + return singleJoiningSlash(a.Path, b.Path), "" + } + // Same as singleJoiningSlash, but uses EscapedPath to determine + // whether a slash should be added + apath := a.EscapedPath() + bpath := b.EscapedPath() + + aslash := strings.HasSuffix(apath, "/") + bslash := strings.HasPrefix(bpath, "/") + + switch { + case aslash && bslash: + return a.Path + b.Path[1:], apath + bpath[1:] + case !aslash && !bslash: + return a.Path + "/" + b.Path, apath + "/" + bpath + } + return a.Path + b.Path, apath + bpath +} + // NewSingleHostReverseProxy returns a new ReverseProxy that routes // URLs to the scheme, host, and base path provided in target. If the // target's path is "/base" and the incoming request was for "/dir", @@ -117,7 +144,7 @@ func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { director := func(req *http.Request) { req.URL.Scheme = target.Scheme req.URL.Host = target.Host - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + req.URL.Path, req.URL.RawPath = joinURLPath(target, req.URL) if targetQuery == "" || req.URL.RawQuery == "" { req.URL.RawQuery = targetQuery + req.URL.RawQuery } else { @@ -248,10 +275,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // If we aren't the first proxy retain prior // X-Forwarded-For information as a comma+space // separated list and fold multiple headers into one. - if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + prior, ok := outreq.Header["X-Forwarded-For"] + omit := ok && prior == nil // Issue 38079: nil now means don't populate the header + if len(prior) > 0 { clientIP = strings.Join(prior, ", ") + ", " + clientIP } - outreq.Header.Set("X-Forwarded-For", clientIP) + if !omit { + outreq.Header.Set("X-Forwarded-For", clientIP) + } } res, err := transport.RoundTrip(outreq) @@ -357,7 +388,7 @@ func shouldPanicOnCopyError(req *http.Request) bool { func removeConnectionHeaders(h http.Header) { for _, f := range h["Connection"] { for _, sf := range strings.Split(f, ",") { - if sf = strings.TrimSpace(sf); sf != "" { + if sf = textproto.TrimString(sf); sf != "" { h.Del(sf) } } @@ -526,7 +557,20 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) return } - defer backConn.Close() + + backConnCloseCh := make(chan bool) + go func() { + // Ensure that the cancelation of a request closes the backend. + // See issue https://golang.org/issue/35559. + select { + case <-req.Context().Done(): + case <-backConnCloseCh: + } + backConn.Close() + }() + + defer close(backConnCloseCh) + conn, brw, err := hj.Hijack() if err != nil { p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index f58e088..764939f 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -277,6 +277,39 @@ func TestXForwardedFor(t *testing.T) { } } +// Issue 38079: don't append to X-Forwarded-For if it's present but nil +func TestXForwardedFor_Omit(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if v := r.Header.Get("X-Forwarded-For"); v != "" { + t.Errorf("got X-Forwarded-For header: %q", v) + } + w.Write([]byte("hi")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + oldDirector := proxyHandler.Director + proxyHandler.Director = func(r *http.Request) { + r.Header["X-Forwarded-For"] = nil + oldDirector(r) + } + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Host = "some-name" + getReq.Close = true + res, err := frontend.Client().Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + var proxyQueryTests = []struct { baseSuffix string // suffix to add to backend URL reqSuffix string // suffix to add to frontend's request URL @@ -386,7 +419,7 @@ func TestReverseProxyFlushIntervalHeaders(t *testing.T) { } } -func TestReverseProxyCancelation(t *testing.T) { +func TestReverseProxyCancellation(t *testing.T) { const backendResponse = "I am the backend" reqInFlight := make(chan struct{}) @@ -1158,6 +1191,137 @@ func TestReverseProxyWebSocket(t *testing.T) { } } +func TestReverseProxyWebSocketCancelation(t *testing.T) { + n := 5 + triggerCancelCh := make(chan bool, n) + nthResponse := func(i int) string { + return fmt.Sprintf("backend response #%d\n", i) + } + terminalMsg := "final message" + + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if g, ws := upgradeType(r.Header), "websocket"; g != ws { + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) + http.Error(w, "Unexpected request", 400) + return + } + conn, bufrw, err := w.(http.Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" + if _, err := io.WriteString(conn, upgradeMsg); err != nil { + t.Error(err) + return + } + if _, _, err := bufrw.ReadLine(); err != nil { + t.Errorf("Failed to read line from client: %v", err) + return + } + + for i := 0; i < n; i++ { + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Writing response #%d failed: %v", i, err) + } + return + } + bufrw.Flush() + time.Sleep(time.Second) + } + if _, err := bufrw.WriteString(terminalMsg); err != nil { + select { + case <-triggerCancelCh: + default: + t.Errorf("Failed to write terminal message: %v", err) + } + } + bufrw.Flush() + })) + defer cst.Close() + + backendURL, _ := url.Parse(cst.URL) + rproxy := NewSingleHostReverseProxy(backendURL) + rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(res *http.Response) error { + res.Header.Add("X-Modified", "true") + return nil + } + + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("X-Header", "X-Value") + ctx, cancel := context.WithCancel(req.Context()) + go func() { + <-triggerCancelCh + cancel() + }() + rproxy.ServeHTTP(rw, req.WithContext(ctx)) + }) + + frontendProxy := httptest.NewServer(handler) + defer frontendProxy.Close() + + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + res, err := frontendProxy.Client().Do(req) + if err != nil { + t.Fatalf("Dialing to frontend proxy: %v", err) + } + defer res.Body.Close() + if g, w := res.StatusCode, 101; g != w { + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) + } + + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + if g, w := upgradeType(res.Header), "websocket"; g != w { + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) + } + + rwc, ok := res.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) + } + + if got, want := res.Header.Get("X-Modified"), "true"; got != want { + t.Errorf("response X-Modified header = %q; want %q", got, want) + } + + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { + t.Fatalf("Failed to write first message: %v", err) + } + + // Read loop. + + br := bufio.NewReader(rwc) + for { + line, err := br.ReadString('\n') + switch { + case line == terminalMsg: // this case before "err == io.EOF" + t.Fatalf("The websocket request was not canceled, unfortunately!") + + case err == io.EOF: + return + + case err != nil: + t.Fatalf("Unexpected error: %v", err) + + case line == nthResponse(0): // We've gotten the first response back + // Let's trigger a cancel. + close(triggerCancelCh) + } + } +} + func TestUnannouncedTrailer(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -1202,7 +1366,7 @@ func TestSingleJoinSlash(t *testing.T) { } for _, tt := range tests { if got := singleJoiningSlash(tt.slasha, tt.slashb); got != tt.expected { - t.Errorf("singleJoiningSlash(%s,%s) want %s got %s", + t.Errorf("singleJoiningSlash(%q,%q) want %q got %q", tt.slasha, tt.slashb, tt.expected, @@ -1210,3 +1374,30 @@ func TestSingleJoinSlash(t *testing.T) { } } } + +func TestJoinURLPath(t *testing.T) { + tests := []struct { + a *url.URL + b *url.URL + wantPath string + wantRaw string + }{ + {&url.URL{Path: "/a/b"}, &url.URL{Path: "/c"}, "/a/b/c", ""}, + {&url.URL{Path: "/a/b", RawPath: "badpath"}, &url.URL{Path: "c"}, "/a/b/c", "/a/b/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b", RawPath: "/a%2Fb"}, &url.URL{Path: "/c"}, "/a/b/c", "/a%2Fb/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb%2F"}, &url.URL{Path: "c"}, "/a/b//c", "/a%2Fb%2F/c"}, + {&url.URL{Path: "/a/b/", RawPath: "/a%2Fb/"}, &url.URL{Path: "/c/d", RawPath: "/c%2Fd"}, "/a/b/c/d", "/a%2Fb/c%2Fd"}, + } + + for _, tt := range tests { + p, rp := joinURLPath(tt.a, tt.b) + if p != tt.wantPath || rp != tt.wantRaw { + t.Errorf("joinURLPath(URL(%q,%q),URL(%q,%q)) want (%q,%q) got (%q,%q)", + tt.a.Path, tt.a.RawPath, + tt.b.Path, tt.b.RawPath, + tt.wantPath, tt.wantRaw, + p, rp) + } + } +} diff --git a/libgo/go/net/http/omithttp2.go b/libgo/go/net/http/omithttp2.go index 307d93a..7e2f492 100644 --- a/libgo/go/net/http/omithttp2.go +++ b/libgo/go/net/http/omithttp2.go @@ -32,7 +32,7 @@ type http2Transport struct { func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } func (*http2Transport) CloseIdleConnections() {} -type http2erringRoundTripper struct{} +type http2erringRoundTripper struct{ err error } func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index a237f58..81df044 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -36,15 +36,17 @@ // // go tool pprof http://localhost:6060/debug/pprof/block // -// Or to collect a 5-second execution trace: -// -// wget http://localhost:6060/debug/pprof/trace?seconds=5 -// // Or to look at the holders of contended mutexes, after calling // runtime.SetMutexProfileFraction in your program: // // go tool pprof http://localhost:6060/debug/pprof/mutex // +// The package also exports a handler that serves execution trace data +// for the "go tool trace" command. To collect a 5-second execution trace: +// +// wget -O trace.out http://localhost:6060/debug/pprof/trace?seconds=5 +// go tool trace trace.out +// // To view all available profiles, open http://localhost:6060/debug/pprof/ // in your browser. // @@ -57,8 +59,10 @@ package pprof import ( "bufio" "bytes" + "context" "fmt" "html/template" + "internal/profile" "io" "log" "net/http" @@ -234,6 +238,10 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { serveError(w, http.StatusNotFound, "Unknown profile") return } + if sec := r.FormValue("seconds"); sec != "" { + name.serveDeltaProfile(w, r, p, sec) + return + } gc, _ := strconv.Atoi(r.FormValue("gc")) if name == "heap" && gc > 0 { runtime.GC() @@ -248,6 +256,94 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.WriteTo(w, debug) } +func (name handler) serveDeltaProfile(w http.ResponseWriter, r *http.Request, p *pprof.Profile, secStr string) { + sec, err := strconv.ParseInt(secStr, 10, 64) + if err != nil || sec <= 0 { + serveError(w, http.StatusBadRequest, `invalid value for "seconds" - must be a positive integer`) + return + } + if !profileSupportsDelta[name] { + serveError(w, http.StatusBadRequest, `"seconds" parameter is not supported for this profile type`) + return + } + // 'name' should be a key in profileSupportsDelta. + if durationExceedsWriteTimeout(r, float64(sec)) { + serveError(w, http.StatusBadRequest, "profile duration exceeds server's WriteTimeout") + return + } + debug, _ := strconv.Atoi(r.FormValue("debug")) + if debug != 0 { + serveError(w, http.StatusBadRequest, "seconds and debug params are incompatible") + return + } + p0, err := collectProfile(p) + if err != nil { + serveError(w, http.StatusInternalServerError, "failed to collect profile") + return + } + + t := time.NewTimer(time.Duration(sec) * time.Second) + defer t.Stop() + + select { + case <-r.Context().Done(): + err := r.Context().Err() + if err == context.DeadlineExceeded { + serveError(w, http.StatusRequestTimeout, err.Error()) + } else { // TODO: what's a good status code for cancelled requests? 400? + serveError(w, http.StatusInternalServerError, err.Error()) + } + return + case <-t.C: + } + + p1, err := collectProfile(p) + if err != nil { + serveError(w, http.StatusInternalServerError, "failed to collect profile") + return + } + ts := p1.TimeNanos + dur := p1.TimeNanos - p0.TimeNanos + + p0.Scale(-1) + + p1, err = profile.Merge([]*profile.Profile{p0, p1}) + if err != nil { + serveError(w, http.StatusInternalServerError, "failed to compute delta") + return + } + + p1.TimeNanos = ts // set since we don't know what profile.Merge set for TimeNanos. + p1.DurationNanos = dur + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s-delta"`, name)) + p1.Write(w) +} + +func collectProfile(p *pprof.Profile) (*profile.Profile, error) { + var buf bytes.Buffer + if err := p.WriteTo(&buf, 0); err != nil { + return nil, err + } + ts := time.Now().UnixNano() + p0, err := profile.Parse(&buf) + if err != nil { + return nil, err + } + p0.TimeNanos = ts + return p0, nil +} + +var profileSupportsDelta = map[handler]bool{ + "allocs": true, + "block": true, + "goroutine": true, + "heap": true, + "mutex": true, + "threadcreate": true, +} + var profileDescriptions = map[string]string{ "allocs": "A sampling of all past memory allocations", "block": "Stack traces that led to blocking on synchronization primitives", @@ -273,6 +369,9 @@ func Index(w http.ResponseWriter, r *http.Request) { } } + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Content-Type", "text/html; charset=utf-8") + type profile struct { Name string Href string diff --git a/libgo/go/net/http/pprof/pprof_test.go b/libgo/go/net/http/pprof/pprof_test.go index dbb6fef..f6f9ef5 100644 --- a/libgo/go/net/http/pprof/pprof_test.go +++ b/libgo/go/net/http/pprof/pprof_test.go @@ -6,11 +6,18 @@ package pprof import ( "bytes" + "fmt" + "internal/profile" "io/ioutil" "net/http" "net/http/httptest" + "runtime" "runtime/pprof" + "strings" + "sync" + "sync/atomic" "testing" + "time" ) // TestDescriptions checks that the profile names under runtime/pprof package @@ -40,6 +47,10 @@ func TestHandlers(t *testing.T) { {"/debug/pprof/profile?seconds=1", Profile, http.StatusOK, "application/octet-stream", `attachment; filename="profile"`, nil}, {"/debug/pprof/symbol", Symbol, http.StatusOK, "text/plain; charset=utf-8", "", nil}, {"/debug/pprof/trace", Trace, http.StatusOK, "application/octet-stream", `attachment; filename="trace"`, nil}, + {"/debug/pprof/mutex", Index, http.StatusOK, "application/octet-stream", `attachment; filename="mutex"`, nil}, + {"/debug/pprof/block?seconds=1", Index, http.StatusOK, "application/octet-stream", `attachment; filename="block-delta"`, nil}, + {"/debug/pprof/goroutine?seconds=1", Index, http.StatusOK, "application/octet-stream", `attachment; filename="goroutine-delta"`, nil}, + {"/debug/pprof/", Index, http.StatusOK, "text/html; charset=utf-8", "", []byte("Types of profiles available:")}, } for _, tc := range testCases { t.Run(tc.path, func(t *testing.T) { @@ -77,5 +88,171 @@ func TestHandlers(t *testing.T) { } }) } +} + +var Sink uint32 + +func mutexHog1(mu1, mu2 *sync.Mutex, start time.Time, dt time.Duration) { + atomic.AddUint32(&Sink, 1) + for time.Since(start) < dt { + // When using gccgo the loop of mutex operations is + // not preemptible. This can cause the loop to block a GC, + // causing the time limits in TestDeltaContentionz to fail. + // Since this loop is not very realistic, when using + // gccgo add preemption points 100 times a second. + t1 := time.Now() + for time.Since(start) < dt && time.Since(t1) < 10*time.Millisecond { + mu1.Lock() + mu2.Lock() + mu1.Unlock() + mu2.Unlock() + } + if runtime.Compiler == "gccgo" { + runtime.Gosched() + } + } +} + +// mutexHog2 is almost identical to mutexHog but we keep them separate +// in order to distinguish them with function names in the stack trace. +// We make them slightly different, using Sink, because otherwise +// gccgo -c opt will merge them. +func mutexHog2(mu1, mu2 *sync.Mutex, start time.Time, dt time.Duration) { + atomic.AddUint32(&Sink, 2) + for time.Since(start) < dt { + // See comment in mutexHog. + t1 := time.Now() + for time.Since(start) < dt && time.Since(t1) < 10*time.Millisecond { + mu1.Lock() + mu2.Lock() + mu1.Unlock() + mu2.Unlock() + } + if runtime.Compiler == "gccgo" { + runtime.Gosched() + } + } +} +// mutexHog starts multiple goroutines that runs the given hogger function for the specified duration. +// The hogger function will be given two mutexes to lock & unlock. +func mutexHog(duration time.Duration, hogger func(mu1, mu2 *sync.Mutex, start time.Time, dt time.Duration)) { + start := time.Now() + mu1 := new(sync.Mutex) + mu2 := new(sync.Mutex) + var wg sync.WaitGroup + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + hogger(mu1, mu2, start, duration) + }() + } + wg.Wait() +} + +func TestDeltaProfile(t *testing.T) { + rate := runtime.SetMutexProfileFraction(1) + defer func() { + runtime.SetMutexProfileFraction(rate) + }() + + // mutexHog1 will appear in non-delta mutex profile + // if the mutex profile works. + mutexHog(20*time.Millisecond, mutexHog1) + + // If mutexHog1 does not appear in the mutex profile, + // skip this test. Mutex profile is likely not working, + // so is the delta profile. + + p, err := query("/debug/pprof/mutex") + if err != nil { + t.Skipf("mutex profile is unsupported: %v", err) + } + + if !seen(p, "mutexHog1") { + t.Skipf("mutex profile is not working: %v", p) + } + + // causes mutexHog2 call stacks to appear in the mutex profile. + done := make(chan bool) + go func() { + for { + mutexHog(20*time.Millisecond, mutexHog2) + select { + case <-done: + done <- true + return + default: + time.Sleep(10 * time.Millisecond) + } + } + }() + defer func() { // cleanup the above goroutine. + done <- true + <-done // wait for the goroutine to exit. + }() + + for _, d := range []int{1, 4, 16, 32} { + endpoint := fmt.Sprintf("/debug/pprof/mutex?seconds=%d", d) + p, err := query(endpoint) + if err != nil { + t.Fatalf("failed to query %q: %v", endpoint, err) + } + if !seen(p, "mutexHog1") && seen(p, "mutexHog2") && p.DurationNanos > 0 { + break // pass + } + if d == 32 { + t.Errorf("want mutexHog2 but no mutexHog1 in the profile, and non-zero p.DurationNanos, got %v", p) + } + } + p, err = query("/debug/pprof/mutex") + if err != nil { + t.Fatalf("failed to query mutex profile: %v", err) + } + if !seen(p, "mutexHog1") || !seen(p, "mutexHog2") { + t.Errorf("want both mutexHog1 and mutexHog2 in the profile, got %v", p) + } +} + +var srv = httptest.NewServer(nil) + +func query(endpoint string) (*profile.Profile, error) { + url := srv.URL + endpoint + r, err := http.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to fetch %q: %v", url, err) + } + if r.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to fetch %q: %v", url, r.Status) + } + + b, err := ioutil.ReadAll(r.Body) + r.Body.Close() + if err != nil { + return nil, fmt.Errorf("failed to read and parse the result from %q: %v", url, err) + } + return profile.Parse(bytes.NewBuffer(b)) +} + +// seen returns true if the profile includes samples whose stacks include +// the specified function name (fname). +func seen(p *profile.Profile, fname string) bool { + locIDs := map[*profile.Location]bool{} + for _, loc := range p.Location { + for _, l := range loc.Line { + if strings.Contains(l.Function.Name, fname) { + locIDs[loc] = true + break + } + } + } + for _, sample := range p.Sample { + for _, loc := range sample.Location { + if locIDs[loc] { + return true + } + } + } + return false } diff --git a/libgo/go/net/http/proxy_test.go b/libgo/go/net/http/proxy_test.go index feb7047..0dd57b4 100644 --- a/libgo/go/net/http/proxy_test.go +++ b/libgo/go/net/http/proxy_test.go @@ -35,7 +35,7 @@ func TestCacheKeys(t *testing.T) { } proxy = u } - cm := connectMethod{proxy, tt.scheme, tt.addr, false} + cm := connectMethod{proxyURL: proxy, targetScheme: tt.scheme, targetAddr: tt.addr} if got := cm.key().String(); got != tt.key { t.Fatalf("{%q, %q, %q} cache key = %q; want %q", tt.proxy, tt.scheme, tt.addr, got, tt.key) } diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 88fa093..fe6b6098 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -83,12 +83,7 @@ var ( ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} ) -type badStringError struct { - what string - str string -} - -func (e *badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } // Headers that Request.Write handles itself and should be skipped. var reqWriteExcludeHeader = map[string]bool{ @@ -430,6 +425,8 @@ func (r *Request) Cookie(name string) (*Cookie, error) { // AddCookie does not attach more than one Cookie header field. That // means all cookies, if any, are written into the same line, // separated by semicolon. +// AddCookie only sanitizes c's name and value, and does not sanitize +// a Cookie header already present in the request. func (r *Request) AddCookie(c *Cookie) { s := fmt.Sprintf("%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) if c := r.Header.Get("Cookie"); c != "" { @@ -506,7 +503,7 @@ func valueOrDefault(value, def string) string { // NOTE: This is not intended to reflect the actual Go version being used. // It was changed at the time of Go 1.1 release because the former User-Agent -// had ended up on a blacklist for some intrusion detection systems. +// had ended up blocked by some intrusion detection systems. // See https://codereview.appspot.com/7532043. const defaultUserAgent = "Go-http-client/1.1" @@ -1025,14 +1022,14 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro var ok bool req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) if !ok { - return nil, &badStringError{"malformed HTTP request", s} + return nil, badStringError("malformed HTTP request", s) } if !validMethod(req.Method) { - return nil, &badStringError{"invalid method", req.Method} + return nil, badStringError("invalid method", req.Method) } rawurl := req.RequestURI if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { - return nil, &badStringError{"malformed HTTP version", req.Proto} + return nil, badStringError("malformed HTTP version", req.Proto) } // CONNECT requests are used two different ways, and neither uses a full URL: diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index cd9d796..72812f0 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -166,7 +166,7 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { return nil, err } if i := strings.IndexByte(line, ' '); i == -1 { - return nil, &badStringError{"malformed HTTP response", line} + return nil, badStringError("malformed HTTP response", line) } else { resp.Proto = line[:i] resp.Status = strings.TrimLeft(line[i+1:], " ") @@ -176,15 +176,15 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { statusCode = resp.Status[:i] } if len(statusCode) != 3 { - return nil, &badStringError{"malformed HTTP status code", statusCode} + return nil, badStringError("malformed HTTP status code", statusCode) } resp.StatusCode, err = strconv.Atoi(statusCode) if err != nil || resp.StatusCode < 0 { - return nil, &badStringError{"malformed HTTP status code", statusCode} + return nil, badStringError("malformed HTTP status code", statusCode) } var ok bool if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { - return nil, &badStringError{"malformed HTTP version", resp.Proto} + return nil, badStringError("malformed HTTP version", resp.Proto) } // Parse the response headers. diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 0c78df6..ce87260 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -734,6 +734,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { } func diff(t *testing.T, prefix string, have, want interface{}) { + t.Helper() hv := reflect.ValueOf(have).Elem() wv := reflect.ValueOf(want).Elem() if hv.Type() != wv.Type() { diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go index 4dd9965..509d229 100644 --- a/libgo/go/net/http/roundtrip_js.go +++ b/libgo/go/net/http/roundtrip_js.go @@ -102,12 +102,17 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { js.CopyBytesToJS(buf, body) opt.Set("body", buf) } - respPromise := js.Global().Call("fetch", req.URL.String(), opt) + + fetchPromise := js.Global().Call("fetch", req.URL.String(), opt) var ( - respCh = make(chan *Response, 1) - errCh = make(chan error, 1) + respCh = make(chan *Response, 1) + errCh = make(chan error, 1) + success, failure js.Func ) - success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success.Release() + failure.Release() + result := args[0] header := Header{} // https://developer.mozilla.org/en-US/docs/Web/API/Headers/entries @@ -141,35 +146,29 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } code := result.Get("status").Int() - select { - case respCh <- &Response{ + respCh <- &Response{ Status: fmt.Sprintf("%d %s", code, StatusText(code)), StatusCode: code, Header: header, ContentLength: contentLength, Body: body, Request: req, - }: - case <-req.Context().Done(): } return nil }) - defer success.Release() - failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { - err := fmt.Errorf("net/http: fetch() failed: %s", args[0].String()) - select { - case errCh <- err: - case <-req.Context().Done(): - } + failure = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success.Release() + failure.Release() + errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String()) return nil }) - defer failure.Release() - respPromise.Call("then", success, failure) + + fetchPromise.Call("then", success, failure) select { case <-req.Context().Done(): if !ac.IsUndefined() { - // Abort the Fetch request + // Abort the Fetch request. ac.Call("abort") } return nil, req.Context().Err() diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 29b9379..5f56932 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -947,7 +947,7 @@ func TestOnlyWriteTimeout(t *testing.T) { c := ts.Client() - errc := make(chan error) + errc := make(chan error, 1) go func() { res, err := c.Get(ts.URL) if err != nil { @@ -1057,16 +1057,13 @@ func TestIdentityResponse(t *testing.T) { t.Fatalf("error writing: %v", err) } - // The ReadAll will hang for a failing test, so use a Timer to - // fail explicitly. - goTimeout(t, 2*time.Second, func() { - got, _ := ioutil.ReadAll(conn) - expectedSuffix := "\r\n\r\ntoo short" - if !strings.HasSuffix(string(got), expectedSuffix) { - t.Errorf("Expected output to end with %q; got response body %q", - expectedSuffix, string(got)) - } - }) + // The ReadAll will hang for a failing test. + got, _ := ioutil.ReadAll(conn) + expectedSuffix := "\r\n\r\ntoo short" + if !strings.HasSuffix(string(got), expectedSuffix) { + t.Errorf("Expected output to end with %q; got response body %q", + expectedSuffix, string(got)) + } } func testTCPConnectionCloses(t *testing.T, req string, h Handler) { @@ -1350,37 +1347,6 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { } } -func TestIdentityResponseHeaders(t *testing.T) { - // Not parallel; changes log output. - defer afterTest(t) - log.SetOutput(ioutil.Discard) // is noisy otherwise - defer log.SetOutput(os.Stderr) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - w.Header().Set("Transfer-Encoding", "identity") - w.(Flusher).Flush() - fmt.Fprintf(w, "I am an identity response.") - })) - defer ts.Close() - - c := ts.Client() - res, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("Get error: %v", err) - } - defer res.Body.Close() - - if g, e := res.TransferEncoding, []string(nil); !reflect.DeepEqual(g, e) { - t.Errorf("expected TransferEncoding of %v; got %v", e, g) - } - if _, haveCL := res.Header["Content-Length"]; haveCL { - t.Errorf("Unexpected Content-Length") - } - if !res.Close { - t.Errorf("expected Connection: close; got %v", res.Close) - } -} - // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } @@ -1438,13 +1404,13 @@ func TestTLSHandshakeTimeout(t *testing.T) { t.Fatalf("Dial: %v", err) } defer conn.Close() - goTimeout(t, 10*time.Second, func() { - var buf [1]byte - n, err := conn.Read(buf[:]) - if err == nil || n != 0 { - t.Errorf("Read = %d, %v; want an error and no bytes", n, err) - } - }) + + var buf [1]byte + n, err := conn.Read(buf[:]) + if err == nil || n != 0 { + t.Errorf("Read = %d, %v; want an error and no bytes", n, err) + } + select { case v := <-errc: if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { @@ -1479,30 +1445,29 @@ func TestTLSServer(t *testing.T) { t.Fatalf("Dial: %v", err) } defer idleConn.Close() - goTimeout(t, 10*time.Second, func() { - if !strings.HasPrefix(ts.URL, "https://") { - t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) - return - } - client := ts.Client() - res, err := client.Get(ts.URL) - if err != nil { - t.Error(err) - return - } - if res == nil { - t.Errorf("got nil Response") - return - } - defer res.Body.Close() - if res.Header.Get("X-TLS-Set") != "true" { - t.Errorf("expected X-TLS-Set response header") - return - } - if res.Header.Get("X-TLS-HandshakeComplete") != "true" { - t.Errorf("expected X-TLS-HandshakeComplete header") - } - }) + + if !strings.HasPrefix(ts.URL, "https://") { + t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) + return + } + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if res == nil { + t.Errorf("got nil Response") + return + } + defer res.Body.Close() + if res.Header.Get("X-TLS-Set") != "true" { + t.Errorf("expected X-TLS-Set response header") + return + } + if res.Header.Get("X-TLS-HandshakeComplete") != "true" { + t.Errorf("expected X-TLS-HandshakeComplete header") + } } func TestServeTLS(t *testing.T) { @@ -3629,21 +3594,6 @@ func TestHeaderToWire(t *testing.T) { } } -// goTimeout runs f, failing t if f takes more than ns to complete. -func goTimeout(t *testing.T, d time.Duration, f func()) { - ch := make(chan bool, 2) - timer := time.AfterFunc(d, func() { - t.Errorf("Timeout expired after %v", d) - ch <- true - }) - defer timer.Stop() - go func() { - defer func() { ch <- true }() - f() - }() - <-ch -} - type errorListener struct { errs []error } @@ -4135,10 +4085,19 @@ func TestServerConnState(t *testing.T) { doRequests() - timer := time.NewTimer(5 * time.Second) + stateDelay := 5 * time.Second + if deadline, ok := t.Deadline(); ok { + // Allow an arbitrarily long delay. + // This test was observed to be flaky on the darwin-arm64-corellium builder, + // so we're increasing the deadline to see if it starts passing. + // See https://golang.org/issue/37322. + const arbitraryCleanupMargin = 1 * time.Second + stateDelay = time.Until(deadline) - arbitraryCleanupMargin + } + timer := time.NewTimer(stateDelay) select { case <-timer.C: - t.Errorf("Timed out waiting for connection to change state.") + t.Errorf("Timed out after %v waiting for connection to change state.", stateDelay) case <-complete: timer.Stop() } @@ -5167,8 +5126,14 @@ func BenchmarkClient(b *testing.B) { } done := make(chan error) + stop := make(chan struct{}) + defer close(stop) go func() { - done <- cmd.Wait() + select { + case <-stop: + return + case done <- cmd.Wait(): + } }() // Do b.N requests to the server. @@ -5984,8 +5949,11 @@ type countCloseListener struct { } func (p *countCloseListener) Close() error { - atomic.AddInt32(&p.closes, 1) - return nil + var err error + if n := atomic.AddInt32(&p.closes, 1); n == 1 && p.Listener != nil { + err = p.Listener.Close() + } + return err } // Issue 24803: don't call Listener.Close on Server.Shutdown. diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 77329b2..6f7a259 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -425,6 +425,16 @@ type response struct { wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" + // canWriteContinue is a boolean value accessed as an atomic int32 + // that says whether or not a 100 Continue header can be written + // to the connection. + // writeContinueMu must be held while writing the header. + // These two fields together synchronize the body reader + // (the expectContinueReader, which wants to write 100 Continue) + // against the main writer. + canWriteContinue atomicBool + writeContinueMu sync.Mutex + w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -515,6 +525,7 @@ type atomicBool int32 func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } // declareTrailer is called for each Trailer header when the // response header is written. It notes that a header will need to be @@ -629,6 +640,7 @@ func (srv *Server) newConn(rwc net.Conn) *conn { } type readResult struct { + _ incomparable n int err error b byte // byte read, if n == 1 @@ -877,21 +889,27 @@ type expectContinueReader struct { resp *response readCloser io.ReadCloser closed bool - sawEOF bool + sawEOF atomicBool } func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { if ecr.closed { return 0, ErrBodyReadAfterClose } - if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { - ecr.resp.wroteContinue = true - ecr.resp.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") - ecr.resp.conn.bufw.Flush() + w := ecr.resp + if !w.wroteContinue && w.canWriteContinue.isSet() && !w.conn.hijacked() { + w.wroteContinue = true + w.writeContinueMu.Lock() + if w.canWriteContinue.isSet() { + w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + w.conn.bufw.Flush() + w.canWriteContinue.setFalse() + } + w.writeContinueMu.Unlock() } n, err = ecr.readCloser.Read(p) if err == io.EOF { - ecr.sawEOF = true + ecr.sawEOF.setTrue() } return } @@ -1315,7 +1333,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // because we don't know if the next bytes on the wire will be // the body-following-the-timer or the subsequent request. // See Issue 11549. - if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF { + if ecr, ok := w.req.Body.(*expectContinueReader); ok && !ecr.sawEOF.isSet() { w.closeAfterReply = true } @@ -1565,6 +1583,17 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er } return 0, ErrHijacked } + + if w.canWriteContinue.isSet() { + // Body reader wants to write 100 Continue but hasn't yet. + // Tell it not to. The store must be done while holding the lock + // because the lock makes sure that there is not an active write + // this very moment. + w.writeContinueMu.Lock() + w.canWriteContinue.setFalse() + w.writeContinueMu.Unlock() + } + if !w.wroteHeader { w.WriteHeader(StatusOK) } @@ -1702,9 +1731,9 @@ func (c *conn) closeWriteAndWait() { time.Sleep(rstAvoidanceDelay) } -// validNextProto reports whether the proto is not a blacklisted ALPN -// protocol name. Empty and built-in protocol types are blacklisted -// and can't be overridden with alternate implementations. +// validNextProto reports whether the proto is a valid ALPN protocol name. +// Everything is valid except the empty string and built-in protocol types, +// so that those can't be overridden with alternate implementations. func validNextProto(proto string) bool { switch proto { case "", "http/1.1", "http/1.0": @@ -1876,6 +1905,7 @@ func (c *conn) serve(ctx context.Context) { if req.ProtoAtLeast(1, 1) && req.ContentLength != 0 { // Wrap the Body reader with one that replies on the connection req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + w.canWriteContinue.setTrue() } } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() @@ -2582,8 +2612,9 @@ type Server struct { // value. ConnContext func(ctx context.Context, c net.Conn) context.Context + inShutdown atomicBool // true when when server is in shutdown + disableKeepAlives int32 // accessed atomically. - inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init nextProtoErr error // result of http2.ConfigureServer if used @@ -2629,7 +2660,7 @@ func (s *Server) closeDoneChanLocked() { // Close returns any error returned from closing the Server's // underlying Listener(s). func (srv *Server) Close() error { - atomic.StoreInt32(&srv.inShutdown, 1) + srv.inShutdown.setTrue() srv.mu.Lock() defer srv.mu.Unlock() srv.closeDoneChanLocked() @@ -2671,7 +2702,7 @@ var shutdownPollInterval = 500 * time.Millisecond // Once Shutdown has been called on a server, it may not be reused; // future calls to methods such as Serve will return ErrServerClosed. func (srv *Server) Shutdown(ctx context.Context) error { - atomic.StoreInt32(&srv.inShutdown, 1) + srv.inShutdown.setTrue() srv.mu.Lock() lnerr := srv.closeListenersLocked() @@ -2684,7 +2715,7 @@ func (srv *Server) Shutdown(ctx context.Context) error { ticker := time.NewTicker(shutdownPollInterval) defer ticker.Stop() for { - if srv.closeIdleConns() { + if srv.closeIdleConns() && srv.numListeners() == 0 { return lnerr } select { @@ -2706,6 +2737,12 @@ func (srv *Server) RegisterOnShutdown(f func()) { srv.mu.Unlock() } +func (s *Server) numListeners() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.listeners) +} + // closeIdleConns closes all idle connections and reports whether the // server is quiescent. func (s *Server) closeIdleConns() bool { @@ -2738,7 +2775,6 @@ func (s *Server) closeListenersLocked() error { if cerr := (*ln).Close(); cerr != nil && err == nil { err = cerr } - delete(s.listeners, ln) } return err } @@ -3037,9 +3073,7 @@ func (s *Server) doKeepAlives() bool { } func (s *Server) shuttingDown() bool { - // TODO: replace inShutdown with the existing atomicBool type; - // see https://github.com/golang/go/issues/20239#issuecomment-381434582 - return atomic.LoadInt32(&s.inShutdown) != 0 + return s.inShutdown.isSet() } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 2e01a07..50d434b 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -310,7 +310,7 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) k = CanonicalHeaderKey(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": - return &badStringError{"invalid Trailer key", k} + return badStringError("invalid Trailer key", k) } keys = append(keys, k) } @@ -335,7 +335,7 @@ func (t *transferWriter) writeBody(w io.Writer) error { var ncopy int64 // Write body. We "unwrap" the body first if it was wrapped in a - // nopCloser. This is to ensure that we can take advantage of + // nopCloser or readTrackingBody. This is to ensure that we can take advantage of // OS-level optimizations in the event that the body is an // *os.File. if t.Body != nil { @@ -413,7 +413,10 @@ func (t *transferWriter) unwrapBody() io.Reader { if reflect.TypeOf(t.Body) == nopCloserType { return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader) } - + if r, ok := t.Body.(*readTrackingBody); ok { + r.didRead = true + return r.ReadCloser + } return t.Body } @@ -425,11 +428,11 @@ type transferReader struct { ProtoMajor int ProtoMinor int // Output - Body io.ReadCloser - ContentLength int64 - TransferEncoding []string - Close bool - Trailer Header + Body io.ReadCloser + ContentLength int64 + Chunked bool + Close bool + Trailer Header } func (t *transferReader) protoAtLeast(m, n int) bool { @@ -501,13 +504,12 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.ProtoMajor, t.ProtoMinor = 1, 1 } - // Transfer encoding, content length - err = t.fixTransferEncoding() - if err != nil { + // Transfer-Encoding: chunked, and overriding Content-Length. + if err := t.parseTransferEncoding(); err != nil { return err } - realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.TransferEncoding) + realLength, err := fixLength(isResponse, t.StatusCode, t.RequestMethod, t.Header, t.Chunked) if err != nil { return err } @@ -522,7 +524,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { } // Trailer - t.Trailer, err = fixTrailer(t.Header, t.TransferEncoding) + t.Trailer, err = fixTrailer(t.Header, t.Chunked) if err != nil { return err } @@ -532,9 +534,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // See RFC 7230, section 3.3. switch msg.(type) { case *Response: - if realLength == -1 && - !chunked(t.TransferEncoding) && - bodyAllowedForStatus(t.StatusCode) { + if realLength == -1 && !t.Chunked && bodyAllowedForStatus(t.StatusCode) { // Unbounded body. t.Close = true } @@ -543,7 +543,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // Prepare body reader. ContentLength < 0 means chunked encoding // or close connection when finished, since multipart is not supported yet switch { - case chunked(t.TransferEncoding): + case t.Chunked: if noResponseBodyExpected(t.RequestMethod) || !bodyAllowedForStatus(t.StatusCode) { t.Body = NoBody } else { @@ -569,13 +569,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { case *Request: rr.Body = t.Body rr.ContentLength = t.ContentLength - rr.TransferEncoding = t.TransferEncoding + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } rr.Close = t.Close rr.Trailer = t.Trailer case *Response: rr.Body = t.Body rr.ContentLength = t.ContentLength - rr.TransferEncoding = t.TransferEncoding + if t.Chunked { + rr.TransferEncoding = []string{"chunked"} + } rr.Close = t.Close rr.Trailer = t.Trailer } @@ -605,8 +609,8 @@ func isUnsupportedTEError(err error) bool { return ok } -// fixTransferEncoding sanitizes t.TransferEncoding, if needed. -func (t *transferReader) fixTransferEncoding() error { +// parseTransferEncoding sets t.Chunked based on the Transfer-Encoding header. +func (t *transferReader) parseTransferEncoding() error { raw, present := t.Header["Transfer-Encoding"] if !present { return nil @@ -618,56 +622,38 @@ func (t *transferReader) fixTransferEncoding() error { return nil } - encodings := strings.Split(raw[0], ",") - te := make([]string, 0, len(encodings)) - // TODO: Even though we only support "identity" and "chunked" - // encodings, the loop below is designed with foresight. One - // invariant that must be maintained is that, if present, - // chunked encoding must always come first. - for _, encoding := range encodings { - encoding = strings.ToLower(strings.TrimSpace(encoding)) - // "identity" encoding is not recorded - if encoding == "identity" { - break - } - if encoding != "chunked" { - return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", encoding)} - } - te = te[0 : len(te)+1] - te[len(te)-1] = encoding - } - if len(te) > 1 { - return &badStringError{"too many transfer encodings", strings.Join(te, ",")} - } - if len(te) > 0 { - // RFC 7230 3.3.2 says "A sender MUST NOT send a - // Content-Length header field in any message that - // contains a Transfer-Encoding header field." - // - // but also: - // "If a message is received with both a - // Transfer-Encoding and a Content-Length header - // field, the Transfer-Encoding overrides the - // Content-Length. Such a message might indicate an - // attempt to perform request smuggling (Section 9.5) - // or response splitting (Section 9.4) and ought to be - // handled as an error. A sender MUST remove the - // received Content-Length field prior to forwarding - // such a message downstream." - // - // Reportedly, these appear in the wild. - delete(t.Header, "Content-Length") - t.TransferEncoding = te - return nil + // Like nginx, we only support a single Transfer-Encoding header field, and + // only if set to "chunked". This is one of the most security sensitive + // surfaces in HTTP/1.1 due to the risk of request smuggling, so we keep it + // strict and simple. + if len(raw) != 1 { + return &unsupportedTEError{fmt.Sprintf("too many transfer encodings: %q", raw)} } + if strings.ToLower(textproto.TrimString(raw[0])) != "chunked" { + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", raw[0])} + } + + // RFC 7230 3.3.2 says "A sender MUST NOT send a Content-Length header field + // in any message that contains a Transfer-Encoding header field." + // + // but also: "If a message is received with both a Transfer-Encoding and a + // Content-Length header field, the Transfer-Encoding overrides the + // Content-Length. Such a message might indicate an attempt to perform + // request smuggling (Section 9.5) or response splitting (Section 9.4) and + // ought to be handled as an error. A sender MUST remove the received + // Content-Length field prior to forwarding such a message downstream." + // + // Reportedly, these appear in the wild. + delete(t.Header, "Content-Length") + t.Chunked = true return nil } // Determine the expected body length, using RFC 7230 Section 3.3. This // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. -func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) { +func fixLength(isResponse bool, status int, requestMethod string, header Header, chunked bool) (int64, error) { isRequest := !isResponse contentLens := header["Content-Length"] @@ -677,9 +663,9 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, // Content-Length headers if they differ in value. // If there are dups of the value, remove the dups. // See Issue 16490. - first := strings.TrimSpace(contentLens[0]) + first := textproto.TrimString(contentLens[0]) for _, ct := range contentLens[1:] { - if first != strings.TrimSpace(ct) { + if first != textproto.TrimString(ct) { return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) } } @@ -711,14 +697,14 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, } // Logic based on Transfer-Encoding - if chunked(te) { + if chunked { return -1, nil } // Logic based on Content-Length var cl string if len(contentLens) == 1 { - cl = strings.TrimSpace(contentLens[0]) + cl = textproto.TrimString(contentLens[0]) } if cl != "" { n, err := parseContentLength(cl) @@ -766,12 +752,12 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { } // Parse the trailer header -func fixTrailer(header Header, te []string) (Header, error) { +func fixTrailer(header Header, chunked bool) (Header, error) { vv, ok := header["Trailer"] if !ok { return nil, nil } - if !chunked(te) { + if !chunked { // Trailer and no chunking: // this is an invalid use case for trailer header. // Nevertheless, no error will be returned and we @@ -791,7 +777,7 @@ func fixTrailer(header Header, te []string) (Header, error) { switch key { case "Transfer-Encoding", "Trailer", "Content-Length": if err == nil { - err = &badStringError{"bad trailer key", key} + err = badStringError("bad trailer key", key) return } } @@ -1049,15 +1035,15 @@ func (bl bodyLocked) Read(p []byte) (n int, err error) { // parseContentLength trims whitespace from s and returns -1 if no value // is set, or the value if it's >= 0. func parseContentLength(cl string) (int64, error) { - cl = strings.TrimSpace(cl) + cl = textproto.TrimString(cl) if cl == "" { return -1, nil } - n, err := strconv.ParseInt(cl, 10, 64) - if err != nil || n < 0 { - return 0, &badStringError{"bad Content-Length", cl} + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return 0, badStringError("bad Content-Length", cl) } - return n, nil + return int64(n), nil } @@ -1092,6 +1078,9 @@ func isKnownInMemoryReader(r io.Reader) bool { if reflect.TypeOf(r) == nopCloserType { return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader)) } + if r, ok := r.(*readTrackingBody); ok { + return isKnownInMemoryReader(r.ReadCloser) + } return false } diff --git a/libgo/go/net/http/transfer_test.go b/libgo/go/net/http/transfer_test.go index 65009ee..185225f 100644 --- a/libgo/go/net/http/transfer_test.go +++ b/libgo/go/net/http/transfer_test.go @@ -279,7 +279,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { } } -func TestFixTransferEncoding(t *testing.T) { +func TestParseTransferEncoding(t *testing.T) { tests := []struct { hdr Header wantErr error @@ -290,7 +290,23 @@ func TestFixTransferEncoding(t *testing.T) { }, { hdr: Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}}, - wantErr: &badStringError{"too many transfer encodings", "chunked,chunked"}, + wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked, chunked" "identity" "chunked"]`}, + }, + { + hdr: Header{"Transfer-Encoding": {""}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: ""`}, + }, + { + hdr: Header{"Transfer-Encoding": {"chunked, identity"}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: "chunked, identity"`}, + }, + { + hdr: Header{"Transfer-Encoding": {"chunked", "identity"}}, + wantErr: &unsupportedTEError{`too many transfer encodings: ["chunked" "identity"]`}, + }, + { + hdr: Header{"Transfer-Encoding": {"\x0bchunked"}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: "\vchunked"`}, }, { hdr: Header{"Transfer-Encoding": {"chunked"}}, @@ -304,9 +320,45 @@ func TestFixTransferEncoding(t *testing.T) { ProtoMajor: 1, ProtoMinor: 1, } - gotErr := tr.fixTransferEncoding() + gotErr := tr.parseTransferEncoding() if !reflect.DeepEqual(gotErr, tt.wantErr) { t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr) } } } + +// issue 39017 - disallow Content-Length values such as "+3" +func TestParseContentLength(t *testing.T) { + tests := []struct { + cl string + wantErr error + }{ + { + cl: "3", + wantErr: nil, + }, + { + cl: "+3", + wantErr: badStringError("bad Content-Length", "+3"), + }, + { + cl: "-3", + wantErr: badStringError("bad Content-Length", "-3"), + }, + { + // max int64, for safe conversion before returning + cl: "9223372036854775807", + wantErr: nil, + }, + { + cl: "9223372036854775808", + wantErr: badStringError("bad Content-Length", "9223372036854775808"), + }, + } + + for _, tt := range tests { + if _, gotErr := parseContentLength(tt.cl); !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("%q:\n\tgot=%v\n\twant=%v", tt.cl, gotErr, tt.wantErr) + } + } +} diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index d0bfdb4..d37b52b 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -100,7 +100,7 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[*Request]func(error) + reqCanceler map[cancelKey]func(error) altMu sync.Mutex // guards changing altProto only altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme @@ -273,6 +273,13 @@ type Transport struct { ForceAttemptHTTP2 bool } +// A cancelKey is the key of the reqCanceler map. +// We wrap the *Request in this type since we want to use the original request, +// not any transient one created by roundTrip. +type cancelKey struct { + req *Request +} + func (t *Transport) writeBufferSize() int { if t.WriteBufferSize > 0 { return t.WriteBufferSize @@ -433,9 +440,10 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { // optional extra headers to write and stores any error to return // from roundTrip. type transportRequest struct { - *Request // original request, not to be mutated - extra Header // extra headers to write, or nil - trace *httptrace.ClientTrace // optional + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil + trace *httptrace.ClientTrace // optional + cancelKey cancelKey mu sync.Mutex // guards err err error // first setError value for mapRoundTripError to consider @@ -511,14 +519,23 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } } + origReq := req + cancelKey := cancelKey{origReq} + req = setupRewindBody(req) + if altRT := t.alternateRoundTripper(req); altRT != nil { if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { return resp, err } + var err error + req, err = rewindBody(req) + if err != nil { + return nil, err + } } if !isHTTP { req.closeBody() - return nil, &badStringError{"unsupported protocol scheme", scheme} + return nil, badStringError("unsupported protocol scheme", scheme) } if req.Method != "" && !validMethod(req.Method) { req.closeBody() @@ -538,7 +555,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } // treq gets modified by roundTrip, so we need to recreate for each retry. - treq := &transportRequest{Request: req, trace: trace} + treq := &transportRequest{Request: req, trace: trace, cancelKey: cancelKey} cm, err := t.connectMethodForRequest(treq) if err != nil { req.closeBody() @@ -551,7 +568,7 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { // to send it requests. pconn, err := t.getConn(treq, cm) if err != nil { - t.setReqCanceler(req, nil) + t.setReqCanceler(cancelKey, nil) req.closeBody() return nil, err } @@ -559,24 +576,22 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(req, nil) // not cancelable with CancelRequest + t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) } if err == nil { + resp.Request = origReq return resp, nil } // Failed. Clean up and determine whether to retry. - - _, isH2DialError := pconn.alt.(http2erringRoundTripper) - if http2isNoCachedConnError(err) || isH2DialError { + if http2isNoCachedConnError(err) { if t.removeIdleConn(pconn) { t.decConnsPerHost(pconn.cacheKey) } - } - if !pconn.shouldRetryRequest(req, err) { + } else if !pconn.shouldRetryRequest(req, err) { // Issue 16465: return underlying net.Conn.Read error from peek, // as we've historically done. if e, ok := err.(transportReadFromServerError); ok { @@ -587,18 +602,59 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { testHookRoundTripRetried() // Rewind the body if we're able to. - if req.GetBody != nil { - newReq := *req - var err error - newReq.Body, err = req.GetBody() - if err != nil { - return nil, err - } - req = &newReq + req, err = rewindBody(req) + if err != nil { + return nil, err } } } +var errCannotRewind = errors.New("net/http: cannot rewind body after connection loss") + +type readTrackingBody struct { + io.ReadCloser + didRead bool +} + +func (r *readTrackingBody) Read(data []byte) (int, error) { + r.didRead = true + return r.ReadCloser.Read(data) +} + +// setupRewindBody returns a new request with a custom body wrapper +// that can report whether the body needs rewinding. +// This lets rewindBody avoid an error result when the request +// does not have GetBody but the body hasn't been read at all yet. +func setupRewindBody(req *Request) *Request { + if req.Body == nil || req.Body == NoBody { + return req + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: req.Body} + return &newReq +} + +// rewindBody returns a new request with the body rewound. +// It returns req unmodified if the body does not need rewinding. +// rewindBody takes care of closing req.Body when appropriate +// (in all cases except when rewindBody returns req unmodified). +func rewindBody(req *Request) (rewound *Request, err error) { + if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead { + return req, nil // nothing to rewind + } + req.closeBody() + if req.GetBody == nil { + return nil, errCannotRewind + } + body, err := req.GetBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = &readTrackingBody{ReadCloser: body} + return &newReq, nil +} + // shouldRetryRequest reports whether we should retry sending a failed // HTTP request on a new connection. The non-nil input error is the // error from roundTrip. @@ -706,14 +762,14 @@ func (t *Transport) CloseIdleConnections() { // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *Request) { - t.cancelRequest(req, errRequestCanceled) + t.cancelRequest(cancelKey{req}, errRequestCanceled) } // Cancel an in-flight request, recording the error value. -func (t *Transport) cancelRequest(req *Request, err error) { +func (t *Transport) cancelRequest(key cancelKey, err error) { t.reqMu.Lock() - cancel := t.reqCanceler[req] - delete(t.reqCanceler, req) + cancel := t.reqCanceler[key] + delete(t.reqCanceler, key) t.reqMu.Unlock() if cancel != nil { cancel(err) @@ -846,7 +902,7 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { // Deliver pconn to goroutine waiting for idle connection, if any. // (They may be actively dialing, but this conn is ready first. // Chrome calls this socket late binding. - // See https://insouciant.org/tech/connection-management-in-chromium/.) + // See https://www.chromium.org/developers/design-documents/network-stack#TOC-Connection-Management.) key := pconn.cacheKey if q, ok := t.idleConnWait[key]; ok { done := false @@ -1046,16 +1102,16 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) bool { return removed } -func (t *Transport) setReqCanceler(r *Request, fn func(error)) { +func (t *Transport) setReqCanceler(key cancelKey, fn func(error)) { t.reqMu.Lock() defer t.reqMu.Unlock() if t.reqCanceler == nil { - t.reqCanceler = make(map[*Request]func(error)) + t.reqCanceler = make(map[cancelKey]func(error)) } if fn != nil { - t.reqCanceler[r] = fn + t.reqCanceler[key] = fn } else { - delete(t.reqCanceler, r) + delete(t.reqCanceler, key) } } @@ -1063,17 +1119,17 @@ func (t *Transport) setReqCanceler(r *Request, fn func(error)) { // for the request, we don't set the function and return false. // Since CancelRequest will clear the canceler, we can use the return value to detect if // the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool { +func (t *Transport) replaceReqCanceler(key cancelKey, fn func(error)) bool { t.reqMu.Lock() defer t.reqMu.Unlock() - _, ok := t.reqCanceler[r] + _, ok := t.reqCanceler[key] if !ok { return false } if fn != nil { - t.reqCanceler[r] = fn + t.reqCanceler[key] = fn } else { - delete(t.reqCanceler, r) + delete(t.reqCanceler, key) } return true } @@ -1277,12 +1333,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (pc *persi // set request canceler to some non-nil function so we // can detect whether it was cleared between now and when // we enter roundTrip - t.setReqCanceler(req, func(error) {}) + t.setReqCanceler(treq.cancelKey, func(error) {}) return pc, nil } cancelc := make(chan error, 1) - t.setReqCanceler(req, func(err error) { cancelc <- err }) + t.setReqCanceler(treq.cancelKey, func(err error) { cancelc <- err }) // Queue for permission to dial. t.queueForDial(w) @@ -1637,7 +1693,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil + alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) + if e, ok := alt.(http2erringRoundTripper); ok { + // pconn.conn was closed by next (http2configureTransport.upgradeFn). + return nil, e.err + } + return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil } } @@ -1694,6 +1755,7 @@ var _ io.ReaderFrom = (*persistConnWriter)(nil) // https://proxy.com|http https to proxy, http to anywhere after that // type connectMethod struct { + _ incomparable proxyURL *url.URL // nil for no proxy, else full proxy URL targetScheme string // "http" or "https" // If proxyURL specifies an http or https proxy, and targetScheme is http (not https), @@ -2025,7 +2087,7 @@ func (pc *persistConn) readLoop() { } if !hasBody || bodyWritable { - pc.t.setReqCanceler(rc.req, nil) + pc.t.setReqCanceler(rc.cancelKey, nil) // Put the idle conn back into the pool before we send the response // so if they process it quickly and make another request, they'll @@ -2098,7 +2160,7 @@ func (pc *persistConn) readLoop() { // reading the response body. (or for cancellation or death) select { case bodyEOF := <-waitForBodyRead: - pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool + pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool alive = alive && bodyEOF && !pc.sawEOF && @@ -2112,7 +2174,7 @@ func (pc *persistConn) readLoop() { pc.t.CancelRequest(rc.req) case <-rc.req.Context().Done(): alive = false - pc.t.cancelRequest(rc.req, rc.req.Context().Err()) + pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -2248,6 +2310,7 @@ func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWri // the concrete type for a Response.Body on the 101 Switching // Protocols response, as used by WebSockets, h2c, etc. type readWriteCloserBody struct { + _ incomparable br *bufio.Reader // used until empty io.ReadWriteCloser } @@ -2348,13 +2411,16 @@ func (pc *persistConn) wroteRequest() bool { // responseAndError is how the goroutine reading from an HTTP/1 server // communicates with the goroutine doing the RoundTrip. type responseAndError struct { + _ incomparable res *Response // else use this response (see res method) err error } type requestAndChan struct { - req *Request - ch chan responseAndError // unbuffered; always send in select on callerGone + _ incomparable + req *Request + cancelKey cancelKey + ch chan responseAndError // unbuffered; always send in select on callerGone // whether the Transport (as opposed to the user client code) // added the Accept-Encoding gzip header. If the Transport @@ -2416,7 +2482,7 @@ var ( func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() - if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { + if !pc.t.replaceReqCanceler(req.cancelKey, pc.cancelRequest) { pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } @@ -2468,7 +2534,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err defer func() { if err != nil { - pc.t.setReqCanceler(req.Request, nil) + pc.t.setReqCanceler(req.cancelKey, nil) } }() @@ -2484,6 +2550,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err resc := make(chan responseAndError) pc.reqch <- requestAndChan{ req: req.Request, + cancelKey: req.cancelKey, ch: resc, addedGzip: requestedGzip, continueCh: continueCh, @@ -2535,10 +2602,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return re.res, nil case <-cancelChan: - pc.t.CancelRequest(req.Request) + pc.t.cancelRequest(req.cancelKey, errRequestCanceled) cancelChan = nil case <-ctxDoneChan: - pc.t.cancelRequest(req.Request, req.Context().Err()) + pc.t.cancelRequest(req.cancelKey, req.Context().Err()) cancelChan = nil ctxDoneChan = nil } @@ -2685,6 +2752,7 @@ func (es *bodyEOFSignal) condfn(err error) error { // gzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read type gzipReader struct { + _ incomparable body *bodyEOFSignal // underlying HTTP/1 response body framing zr *gzip.Reader // lazily-initialized gzip reader zerr error // any error from gzip.NewReader; sticky diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 1e0334d..5c5ae3f 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -451,14 +451,23 @@ func TestTransportReadToEndReusesConn(t *testing.T) { func TestTransportMaxPerHostIdleConns(t *testing.T) { defer afterTest(t) + stop := make(chan struct{}) // stop marks the exit of main Test goroutine + defer close(stop) + resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { gotReq <- true - msg := <-resch + var msg string + select { + case <-stop: + return + case msg = <-resch: + } _, err := w.Write([]byte(msg)) if err != nil { - t.Fatalf("Write: %v", err) + t.Errorf("Write: %v", err) + return } })) defer ts.Close() @@ -472,6 +481,13 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { // Their responses will hang until we write to resch, though. donech := make(chan bool) doReq := func() { + defer func() { + select { + case <-stop: + return + case donech <- t.Failed(): + } + }() resp, err := c.Get(ts.URL) if err != nil { t.Error(err) @@ -481,7 +497,6 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("ReadAll: %v", err) return } - donech <- true } go doReq() <-gotReq @@ -842,7 +857,9 @@ func TestStressSurpriseServerCloses(t *testing.T) { // where we won the race. res.Body.Close() } - activityc <- true + if !<-activityc { // Receives false when close(activityc) is executed + return + } } }() } @@ -850,8 +867,9 @@ func TestStressSurpriseServerCloses(t *testing.T) { // Make sure all the request come back, one way or another. for i := 0; i < numClients*reqsPerClient; i++ { select { - case <-activityc: + case activityc <- true: case <-time.After(5 * time.Second): + close(activityc) t.Fatalf("presumed deadlock; no HTTP client activity seen in awhile") } } @@ -2350,6 +2368,50 @@ func TestTransportCancelRequest(t *testing.T) { } } +func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { + setParallel(t) + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + donec := make(chan bool) + req, _ := NewRequest("GET", ts.URL, body) + go func() { + defer close(donec) + c.Do(req) + }() + start := time.Now() + timeout := 10 * time.Second + for time.Since(start) < timeout { + time.Sleep(100 * time.Millisecond) + tr.CancelRequest(req) + select { + case <-donec: + return + default: + } + } + t.Errorf("Do of canceled request has not returned after %v", timeout) +} + +func TestTransportCancelRequestInDo(t *testing.T) { + testTransportCancelRequestInDo(t, nil) +} + +func TestTransportCancelRequestWithBodyInDo(t *testing.T) { + testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0})) +} + func TestTransportCancelRequestInDial(t *testing.T) { defer afterTest(t) if testing.Short() { @@ -2365,7 +2427,9 @@ func TestTransportCancelRequestInDial(t *testing.T) { tr := &Transport{ Dial: func(network, addr string) (net.Conn, error) { eventLog.Println("dial: blocking") - inDial <- true + if !<-inDial { + return nil, errors.New("main Test goroutine exited") + } <-unblockDial return nil, errors.New("nope") }, @@ -2380,8 +2444,9 @@ func TestTransportCancelRequestInDial(t *testing.T) { }() select { - case <-inDial: + case inDial <- true: case <-time.After(5 * time.Second): + close(inDial) t.Fatal("timeout; never saw blocking dial") } @@ -3494,7 +3559,8 @@ func TestRetryRequestsOnError(t *testing.T) { for i := 0; i < 3; i++ { t0 := time.Now() - res, err := c.Do(tc.req()) + req := tc.req() + res, err := c.Do(req) if err != nil { if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 { mu.Lock() @@ -3505,6 +3571,9 @@ func TestRetryRequestsOnError(t *testing.T) { t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse) } res.Body.Close() + if res.Request != req { + t.Errorf("Response.Request != original request; want identical Request") + } } mu.Lock() @@ -6179,3 +6248,48 @@ func (timeoutProto) RoundTrip(req *Request) (*Response, error) { return nil, errors.New("request was not canceled") } } + +type roundTripFunc func(r *Request) (*Response, error) + +func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } + +// Issue 32441: body is not reset after ErrSkipAltProtocol +func TestIssue32441(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + t.Error("body length is zero") + } + })) + defer ts.Close() + c := ts.Client() + c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { + // Draining body to trigger failure condition on actual request to server. + if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + t.Error("body length is zero during round trip") + } + return nil, ErrSkipAltProtocol + })) + if _, err := c.Post(ts.URL, "application/octet-stream", bytes.NewBufferString("data")); err != nil { + t.Error(err) + } +} + +// Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers +// that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. +func TestTransportRejectsSignInContentLength(t *testing.T) { + cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", "+3") + w.Write([]byte("abc")) + })) + defer cst.Close() + + c := cst.Client() + res, err := c.Get(cst.URL) + if err == nil || res != nil { + t.Fatal("Expected a non-nil error and a nil http.Response") + } + if got, want := err.Error(), `bad Content-Length "+3"`; !strings.Contains(got, want) { + t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) + } +} diff --git a/libgo/go/net/interface_aix.go b/libgo/go/net/interface_aix.go index f57c5ff..bd55386 100644 --- a/libgo/go/net/interface_aix.go +++ b/libgo/go/net/interface_aix.go @@ -33,8 +33,6 @@ const _RTAX_NETMASK = 2 const _RTAX_IFA = 5 const _RTAX_MAX = 8 -const _SIOCGIFMTU = -0x3fd796aa - func getIfList() ([]byte, error) { needed, err := syscall.Getkerninfo(_KINFO_RT_IFLIST, 0, 0, 0) if err != nil { diff --git a/libgo/go/net/interface_plan9.go b/libgo/go/net/interface_plan9.go index 1295017..31bbaca 100644 --- a/libgo/go/net/interface_plan9.go +++ b/libgo/go/net/interface_plan9.go @@ -68,8 +68,8 @@ func readInterface(i int) (*Interface, error) { } ifc.MTU = mtu - // Not a loopback device - if device != "/dev/null" { + // Not a loopback device ("/dev/null") or packet interface (e.g. "pkt2") + if stringsHasPrefix(device, netdir+"/") { deviceaddrf, err := open(device + "/addr") if err != nil { return nil, err diff --git a/libgo/go/net/interface_windows.go b/libgo/go/net/interface_windows.go index 5449432..30e90b8 100644 --- a/libgo/go/net/interface_windows.go +++ b/libgo/go/net/interface_windows.go @@ -58,7 +58,7 @@ func interfaceTable(ifindex int) ([]Interface, error) { if ifindex == 0 || ifindex == int(index) { ifi := Interface{ Index: int(index), - Name: windows.UTF16PtrToString(aa.FriendlyName, 10000), + Name: windows.UTF16PtrToString(aa.FriendlyName), } if aa.OperStatus == windows.IfOperStatusUp { ifi.Flags |= FlagUp diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index 9d1223e..c00fe8e 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -671,8 +671,8 @@ func parseIPv6(s string) (ip IP) { } // ParseIP parses s as an IP address, returning the result. -// The string s can be in dotted decimal ("192.0.2.1") -// or IPv6 ("2001:db8::68") form. +// The string s can be in IPv4 dotted decimal ("192.0.2.1"), IPv6 +// ("2001:db8::68"), or IPv4-mapped IPv6 ("::ffff:192.0.2.1") form. // If s is not a valid textual representation of an IP address, // ParseIP returns nil. func ParseIP(s string) IP { diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go index 93f0f4e..2308236 100644 --- a/libgo/go/net/ipsock_plan9.go +++ b/libgo/go/net/ipsock_plan9.go @@ -57,17 +57,17 @@ func parsePlan9Addr(s string) (ip IP, iport int, err error) { return nil, 0, &ParseError{Type: "IP address", Text: s} } } - p, _, ok := dtoi(s[i+1:]) + p, plen, ok := dtoi(s[i+1:]) if !ok { return nil, 0, &ParseError{Type: "port", Text: s} } if p < 0 || p > 0xFFFF { - return nil, 0, &AddrError{Err: "invalid port", Addr: string(p)} + return nil, 0, &AddrError{Err: "invalid port", Addr: s[i+1 : i+1+plen]} } return addr, p, nil } -func readPlan9Addr(proto, filename string) (addr Addr, err error) { +func readPlan9Addr(net, filename string) (addr Addr, err error) { var buf [128]byte f, err := os.Open(filename) @@ -83,13 +83,19 @@ func readPlan9Addr(proto, filename string) (addr Addr, err error) { if err != nil { return } - switch proto { - case "tcp": + switch net { + case "tcp4", "udp4": + if ip.Equal(IPv6zero) { + ip = ip[:IPv4len] + } + } + switch net { + case "tcp", "tcp4", "tcp6": addr = &TCPAddr{IP: ip, Port: port} - case "udp": + case "udp", "udp4", "udp6": addr = &UDPAddr{IP: ip, Port: port} default: - return nil, UnknownNetworkError(proto) + return nil, UnknownNetworkError(net) } return addr, nil } @@ -199,7 +205,11 @@ func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd * if err != nil { return nil, err } - _, err = f.WriteString("connect " + dest) + if la := plan9LocalAddr(laddr); la == "" { + err = hangupCtlWrite(ctx, proto, f, "connect "+dest) + } else { + err = hangupCtlWrite(ctx, proto, f, "connect "+dest+" "+la) + } if err != nil { f.Close() return nil, err @@ -209,7 +219,7 @@ func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd * f.Close() return nil, err } - laddr, err = readPlan9Addr(proto, netdir+"/"+proto+"/"+name+"/local") + laddr, err = readPlan9Addr(net, netdir+"/"+proto+"/"+name+"/local") if err != nil { data.Close() f.Close() @@ -229,7 +239,7 @@ func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err er f.Close() return nil, &OpError{Op: "announce", Net: net, Source: laddr, Addr: nil, Err: err} } - laddr, err = readPlan9Addr(proto, netdir+"/"+proto+"/"+name+"/local") + laddr, err = readPlan9Addr(net, netdir+"/"+proto+"/"+name+"/local") if err != nil { f.Close() return nil, err @@ -303,3 +313,53 @@ func toLocal(a Addr, net string) Addr { } return a } + +// plan9LocalAddr returns a Plan 9 local address string. +// See setladdrport at https://9p.io/sources/plan9/sys/src/9/ip/devip.c. +func plan9LocalAddr(addr Addr) string { + var ip IP + port := 0 + switch a := addr.(type) { + case *TCPAddr: + if a != nil { + ip = a.IP + port = a.Port + } + case *UDPAddr: + if a != nil { + ip = a.IP + port = a.Port + } + } + if len(ip) == 0 || ip.IsUnspecified() { + if port == 0 { + return "" + } + return itoa(port) + } + return ip.String() + "!" + itoa(port) +} + +func hangupCtlWrite(ctx context.Context, proto string, ctl *os.File, msg string) error { + if proto != "tcp" { + _, err := ctl.WriteString(msg) + return err + } + written := make(chan struct{}) + errc := make(chan error) + go func() { + select { + case <-ctx.Done(): + ctl.WriteString("hangup") + errc <- mapErr(ctx.Err()) + case <-written: + errc <- nil + } + }() + _, err := ctl.WriteString(msg) + close(written) + if e := <-errc; err == nil && e != nil { // we hung up + return e + } + return err +} diff --git a/libgo/go/net/ipsock_plan9_test.go b/libgo/go/net/ipsock_plan9_test.go new file mode 100644 index 0000000..e5fb9ff --- /dev/null +++ b/libgo/go/net/ipsock_plan9_test.go @@ -0,0 +1,29 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import "testing" + +func TestTCP4ListenZero(t *testing.T) { + l, err := Listen("tcp4", "0.0.0.0:0") + if err != nil { + t.Fatal(err) + } + defer l.Close() + if a := l.Addr(); isNotIPv4(a) { + t.Errorf("address does not contain IPv4: %v", a) + } +} + +func TestUDP4ListenZero(t *testing.T) { + c, err := ListenPacket("udp4", "0.0.0.0:0") + if err != nil { + t.Fatal(err) + } + defer c.Close() + if a := c.LocalAddr(); isNotIPv4(a) { + t.Errorf("address does not contain IPv4: %v", a) + } +} diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go index 9cebd10..5f71198 100644 --- a/libgo/go/net/lookup.go +++ b/libgo/go/net/lookup.go @@ -204,6 +204,31 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err return r.lookupIPAddr(ctx, "ip", host) } +// LookupIP looks up host for the given network using the local resolver. +// It returns a slice of that host's IP addresses of the type specified by +// network. +// network must be one of "ip", "ip4" or "ip6". +func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, error) { + afnet, _, err := parseNetwork(ctx, network, false) + if err != nil { + return nil, err + } + switch afnet { + case "ip", "ip4", "ip6": + default: + return nil, UnknownNetworkError(network) + } + addrs, err := r.internetAddrList(ctx, afnet, host) + if err != nil { + return nil, err + } + ips := make([]IP, 0, len(addrs)) + for _, addr := range addrs { + ips = append(ips, addr.(*IPAddr).IP) + } + return ips, nil +} + // onlyValuesCtx is a context that uses an underlying context // for value lookup if the underlying context hasn't yet expired. type onlyValuesCtx struct { diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go index 2bc5592..68bffca 100644 --- a/libgo/go/net/lookup_test.go +++ b/libgo/go/net/lookup_test.go @@ -74,7 +74,7 @@ func TestLookupGoogleSRV(t *testing.T) { t.Parallel() mustHaveExternalNetwork(t) - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { t.Skip("no resolv.conf on iOS") } @@ -123,7 +123,7 @@ func TestLookupGmailMX(t *testing.T) { t.Parallel() mustHaveExternalNetwork(t) - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { t.Skip("no resolv.conf on iOS") } @@ -169,7 +169,7 @@ func TestLookupGmailNS(t *testing.T) { t.Parallel() mustHaveExternalNetwork(t) - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { t.Skip("no resolv.conf on iOS") } @@ -218,7 +218,7 @@ func TestLookupGmailTXT(t *testing.T) { t.Parallel() mustHaveExternalNetwork(t) - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { t.Skip("no resolv.conf on iOS") } @@ -637,7 +637,7 @@ func TestLookupDotsWithRemoteSource(t *testing.T) { t.Skip("IPv4 is required") } - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { t.Skip("no resolv.conf on iOS") } @@ -913,6 +913,7 @@ func TestNilResolverLookup(t *testing.T) { r.LookupCNAME(ctx, "google.com") r.LookupHost(ctx, "google.com") r.LookupIPAddr(ctx, "google.com") + r.LookupIP(ctx, "ip", "google.com") r.LookupMX(ctx, "gmail.com") r.LookupNS(ctx, "google.com") r.LookupPort(ctx, "tcp", "smtp") @@ -1185,3 +1186,83 @@ func TestLookupNullByte(t *testing.T) { testenv.SkipFlakyNet(t) LookupHost("foo\x00bar") // check that it doesn't panic; it used to on Windows } + +func TestResolverLookupIP(t *testing.T) { + testenv.MustHaveExternalNetwork(t) + + v4Ok := supportsIPv4() && *testIPv4 + v6Ok := supportsIPv6() && *testIPv6 + + defer dnsWaitGroup.Wait() + + for _, impl := range []struct { + name string + fn func() func() + }{ + {"go", forceGoDNS}, + {"cgo", forceCgoDNS}, + } { + t.Run("implementation: "+impl.name, func(t *testing.T) { + fixup := impl.fn() + if fixup == nil { + t.Skip("not supported") + } + defer fixup() + + for _, network := range []string{"ip", "ip4", "ip6"} { + t.Run("network: "+network, func(t *testing.T) { + switch { + case network == "ip4" && !v4Ok: + t.Skip("IPv4 is not supported") + case network == "ip6" && !v6Ok: + t.Skip("IPv6 is not supported") + } + + // google.com has both A and AAAA records. + const host = "google.com" + ips, err := DefaultResolver.LookupIP(context.Background(), network, host) + if err != nil { + testenv.SkipFlakyNet(t) + t.Fatalf("DefaultResolver.LookupIP(%q, %q): failed with unexpected error: %v", network, host, err) + } + + var v4Addrs []IP + var v6Addrs []IP + for _, ip := range ips { + switch { + case ip.To4() != nil: + // We need to skip the test below because To16 will + // convent an IPv4 address to an IPv4-mapped IPv6 + // address. + v4Addrs = append(v4Addrs, ip) + case ip.To16() != nil: + v6Addrs = append(v6Addrs, ip) + default: + t.Fatalf("IP=%q is neither IPv4 nor IPv6", ip) + } + } + + // Check that we got the expected addresses. + if network == "ip4" || network == "ip" && v4Ok { + if len(v4Addrs) == 0 { + t.Errorf("DefaultResolver.LookupIP(%q, %q): no IPv4 addresses", network, host) + } + } + if network == "ip6" || network == "ip" && v6Ok { + if len(v6Addrs) == 0 { + t.Errorf("DefaultResolver.LookupIP(%q, %q): no IPv6 addresses", network, host) + } + } + + // Check that we didn't get any unexpected addresses. + if network == "ip6" && len(v4Addrs) > 0 { + t.Errorf("DefaultResolver.LookupIP(%q, %q): unexpected IPv4 addresses: %v", network, host, v4Addrs) + } + if network == "ip4" && len(v6Addrs) > 0 { + t.Errorf("DefaultResolver.LookupIP(%q, %q): unexpected IPv6 addresses: %v", network, host, v6Addrs) + } + }) + } + }) + } +} diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 7d5c941..bb34a08 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -234,7 +234,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { defer syscall.DnsRecordListFree(r, 1) resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) - cname := windows.UTF16PtrToString(resolved, 256) + cname := windows.UTF16PtrToString(resolved) return absDomainName([]byte(cname)), nil } @@ -278,7 +278,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { mxs := make([]*MX, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) - mxs = append(mxs, &MX{absDomainName([]byte(windows.UTF16PtrToString(v.NameExchange, 256))), v.Preference}) + mxs = append(mxs, &MX{absDomainName([]byte(windows.UTF16PtrToString(v.NameExchange))), v.Preference}) } byPref(mxs).sort() return mxs, nil @@ -319,7 +319,7 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0])) s := "" for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount:d.StringCount] { - s += windows.UTF16PtrToString(v, 1<<20) + s += windows.UTF16PtrToString(v) } txts = append(txts, s) } @@ -344,7 +344,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) ptrs := make([]string, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) - ptrs = append(ptrs, absDomainName([]byte(windows.UTF16PtrToString(v.Host, 256)))) + ptrs = append(ptrs, absDomainName([]byte(windows.UTF16PtrToString(v.Host)))) } return ptrs, nil } diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 0781310..09fb794 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -274,6 +274,12 @@ func (p *addrParser) parseAddressList() ([]*Address, error) { var list []*Address for { p.skipSpace() + + // allow skipping empty entries (RFC5322 obs-addr-list) + if p.consume(',') { + continue + } + addrs, err := p.parseAddress(true) if err != nil { return nil, err @@ -286,9 +292,17 @@ func (p *addrParser) parseAddressList() ([]*Address, error) { if p.empty() { break } - if !p.consume(',') { + if p.peek() != ',' { return nil, errors.New("mail: expected comma") } + + // Skip empty entries for obs-addr-list. + for p.consume(',') { + p.skipSpace() + } + if p.empty() { + break + } } return list, nil } diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go index acab538..67e3643 100644 --- a/libgo/go/net/mail/message_test.go +++ b/libgo/go/net/mail/message_test.go @@ -431,6 +431,33 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // RFC5322 4.4 obs-addr-list + { + ` , joe@where.test,,John <jdoe@one.test>,`, + []*Address{ + { + Name: "", + Address: "joe@where.test", + }, + { + Name: "John", + Address: "jdoe@one.test", + }, + }, + }, + { + ` , joe@where.test,,John <jdoe@one.test>,,`, + []*Address{ + { + Name: "", + Address: "joe@where.test", + }, + { + Name: "John", + Address: "jdoe@one.test", + }, + }, + }, { `Group1: <addr1@example.com>;, Group 2: addr2@example.com;, John <addr3@example.com>`, []*Address{ @@ -1053,3 +1080,22 @@ func TestAddressFormattingAndParsing(t *testing.T) { } } } + +func TestEmptyAddress(t *testing.T) { + parsed, err := ParseAddress("") + if parsed != nil || err == nil { + t.Errorf(`ParseAddress("") = %v, %v, want nil, error`, parsed, err) + } + list, err := ParseAddressList("") + if len(list) > 0 || err == nil { + t.Errorf(`ParseAddressList("") = %v, %v, want nil, error`, list, err) + } + list, err = ParseAddressList(",") + if len(list) > 0 || err == nil { + t.Errorf(`ParseAddressList("") = %v, %v, want nil, error`, list, err) + } + list, err = ParseAddressList("a@b c@d") + if len(list) > 0 || err == nil { + t.Errorf(`ParseAddressList("") = %v, %v, want nil, error`, list, err) + } +} diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index 1d7e5e7..2e61a7c 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -81,7 +81,6 @@ package net import ( "context" "errors" - "internal/poll" "io" "os" "sync" @@ -112,13 +111,13 @@ type Addr interface { // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn interface { // Read reads data from the connection. - // Read can be made to time out and return an Error with Timeout() == true - // after a fixed time limit; see SetDeadline and SetReadDeadline. + // Read can be made to time out and return an error after a fixed + // time limit; see SetDeadline and SetReadDeadline. Read(b []byte) (n int, err error) // Write writes data to the connection. - // Write can be made to time out and return an Error with Timeout() == true - // after a fixed time limit; see SetDeadline and SetWriteDeadline. + // Write can be made to time out and return an error after a fixed + // time limit; see SetDeadline and SetWriteDeadline. Write(b []byte) (n int, err error) // Close closes the connection. @@ -136,23 +135,22 @@ type Conn interface { // SetReadDeadline and SetWriteDeadline. // // A deadline is an absolute time after which I/O operations - // fail with a timeout (see type Error) instead of - // blocking. The deadline applies to all future and pending - // I/O, not just the immediately following call to Read or - // Write. After a deadline has been exceeded, the connection - // can be refreshed by setting a deadline in the future. + // fail instead of blocking. The deadline applies to all future + // and pending I/O, not just the immediately following call to + // Read or Write. After a deadline has been exceeded, the + // connection can be refreshed by setting a deadline in the future. + // + // If the deadline is exceeded a call to Read or Write or to other + // I/O methods will return an error that wraps os.ErrDeadlineExceeded. + // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). + // The error's Timeout method will return true, but note that there + // are other possible errors for which the Timeout method will + // return true even if the deadline has not been exceeded. // // An idle timeout can be implemented by repeatedly extending // the deadline after successful Read or Write calls. // // A zero value for t means I/O operations will not time out. - // - // Note that if a TCP connection has keep-alive turned on, - // which is the default unless overridden by Dialer.KeepAlive - // or ListenConfig.KeepAlive, then a keep-alive failure may - // also return a timeout error. On Unix systems a keep-alive - // failure on I/O can be detected using - // errors.Is(err, syscall.ETIMEDOUT). SetDeadline(t time.Time) error // SetReadDeadline sets the deadline for future Read calls @@ -315,15 +313,13 @@ type PacketConn interface { // It returns the number of bytes read (0 <= n <= len(p)) // and any error encountered. Callers should always process // the n > 0 bytes returned before considering the error err. - // ReadFrom can be made to time out and return - // an Error with Timeout() == true after a fixed time limit; - // see SetDeadline and SetReadDeadline. + // ReadFrom can be made to time out and return an error after a + // fixed time limit; see SetDeadline and SetReadDeadline. ReadFrom(p []byte) (n int, addr Addr, err error) // WriteTo writes a packet with payload p to addr. - // WriteTo can be made to time out and return - // an Error with Timeout() == true after a fixed time limit; - // see SetDeadline and SetWriteDeadline. + // WriteTo can be made to time out and return an Error after a + // fixed time limit; see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. WriteTo(p []byte, addr Addr) (n int, err error) @@ -339,11 +335,17 @@ type PacketConn interface { // SetReadDeadline and SetWriteDeadline. // // A deadline is an absolute time after which I/O operations - // fail with a timeout (see type Error) instead of - // blocking. The deadline applies to all future and pending - // I/O, not just the immediately following call to ReadFrom or - // WriteTo. After a deadline has been exceeded, the connection - // can be refreshed by setting a deadline in the future. + // fail instead of blocking. The deadline applies to all future + // and pending I/O, not just the immediately following call to + // Read or Write. After a deadline has been exceeded, the + // connection can be refreshed by setting a deadline in the future. + // + // If the deadline is exceeded a call to Read or Write or to other + // I/O methods will return an error that wraps os.ErrDeadlineExceeded. + // This can be tested using errors.Is(err, os.ErrDeadlineExceeded). + // The error's Timeout method will return true, but note that there + // are other possible errors for which the Timeout method will + // return true even if the deadline has not been exceeded. // // An idle timeout can be implemented by repeatedly extending // the deadline after successful ReadFrom or WriteTo calls. @@ -420,7 +422,7 @@ func mapErr(err error) error { case context.Canceled: return errCanceled case context.DeadlineExceeded: - return poll.ErrTimeout + return errTimeout default: return err } @@ -567,6 +569,21 @@ func (e InvalidAddrError) Error() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } +// errTimeout exists to return the historical "i/o timeout" string +// for context.DeadlineExceeded. See mapErr. +// It is also used when Dialer.Deadline is exceeded. +// +// TODO(iant): We could consider changing this to os.ErrDeadlineExceeded +// in the future, but note that that would conflict with the TODO +// at mapErr that suggests changing it to context.DeadlineExceeded. +var errTimeout error = &timeoutError{} + +type timeoutError struct{} + +func (e *timeoutError) Error() string { return "i/o timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + // DNSConfigError represents an error reading the machine's DNS configuration. // (No longer used; kept for compatibility.) type DNSConfigError struct { diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go index a740674..409e140 100644 --- a/libgo/go/net/net_test.go +++ b/libgo/go/net/net_test.go @@ -23,50 +23,54 @@ func TestCloseRead(t *testing.T) { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } + t.Parallel() for _, network := range []string{"tcp", "unix", "unixpacket"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } + network := network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("network %s is not testable on the current platform", network) + } + t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } - switch network { - case "unix", "unixpacket": - defer os.Remove(ln.Addr().String()) - } - defer ln.Close() + ln, err := newLocalListener(network) + if err != nil { + t.Fatal(err) + } + switch network { + case "unix", "unixpacket": + defer os.Remove(ln.Addr().String()) + } + defer ln.Close() - c, err := Dial(ln.Addr().Network(), ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - switch network { - case "unix", "unixpacket": - defer os.Remove(c.LocalAddr().String()) - } - defer c.Close() + c, err := Dial(ln.Addr().Network(), ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + switch network { + case "unix", "unixpacket": + defer os.Remove(c.LocalAddr().String()) + } + defer c.Close() - switch c := c.(type) { - case *TCPConn: - err = c.CloseRead() - case *UnixConn: - err = c.CloseRead() - } - if err != nil { - if perr := parseCloseError(err, true); perr != nil { - t.Error(perr) + switch c := c.(type) { + case *TCPConn: + err = c.CloseRead() + case *UnixConn: + err = c.CloseRead() } - t.Fatal(err) - } - var b [1]byte - n, err := c.Read(b[:]) - if n != 0 || err == nil { - t.Fatalf("got (%d, %v); want (0, error)", n, err) - } + if err != nil { + if perr := parseCloseError(err, true); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + var b [1]byte + n, err := c.Read(b[:]) + if n != 0 || err == nil { + t.Fatalf("got (%d, %v); want (0, error)", n, err) + } + }) } } @@ -76,212 +80,240 @@ func TestCloseWrite(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - handler := func(ls *localServer, ln Listener) { - c, err := ln.Accept() - if err != nil { - t.Error(err) - return - } - defer c.Close() - - var b [1]byte - n, err := c.Read(b[:]) - if n != 0 || err != io.EOF { - t.Errorf("got (%d, %v); want (0, io.EOF)", n, err) - return - } - switch c := c.(type) { - case *TCPConn: - err = c.CloseWrite() - case *UnixConn: - err = c.CloseWrite() - } - if err != nil { - if perr := parseCloseError(err, true); perr != nil { - t.Error(perr) - } - t.Error(err) - return - } - n, err = c.Write(b[:]) - if err == nil { - t.Errorf("got (%d, %v); want (any, error)", n, err) - return - } + t.Parallel() + deadline, _ := t.Deadline() + if !deadline.IsZero() { + // Leave 10% headroom on the deadline to report errors and clean up. + deadline = deadline.Add(-time.Until(deadline) / 10) } for _, network := range []string{"tcp", "unix", "unixpacket"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } + network := network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("network %s is not testable on the current platform", network) + } + t.Parallel() + + handler := func(ls *localServer, ln Listener) { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + if !deadline.IsZero() { + c.SetDeadline(deadline) + } + defer c.Close() + + var b [1]byte + n, err := c.Read(b[:]) + if n != 0 || err != io.EOF { + t.Errorf("got (%d, %v); want (0, io.EOF)", n, err) + return + } + switch c := c.(type) { + case *TCPConn: + err = c.CloseWrite() + case *UnixConn: + err = c.CloseWrite() + } + if err != nil { + if perr := parseCloseError(err, true); perr != nil { + t.Error(perr) + } + t.Error(err) + return + } + n, err = c.Write(b[:]) + if err == nil { + t.Errorf("got (%d, %v); want (any, error)", n, err) + return + } + } - ls, err := newLocalServer(network) - if err != nil { - t.Fatal(err) - } - defer ls.teardown() - if err := ls.buildup(handler); err != nil { - t.Fatal(err) - } + ls, err := newLocalServer(network) + if err != nil { + t.Fatal(err) + } + defer ls.teardown() + if err := ls.buildup(handler); err != nil { + t.Fatal(err) + } - c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - switch network { - case "unix", "unixpacket": - defer os.Remove(c.LocalAddr().String()) - } - defer c.Close() + c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if !deadline.IsZero() { + c.SetDeadline(deadline) + } + switch network { + case "unix", "unixpacket": + defer os.Remove(c.LocalAddr().String()) + } + defer c.Close() - switch c := c.(type) { - case *TCPConn: - err = c.CloseWrite() - case *UnixConn: - err = c.CloseWrite() - } - if err != nil { - if perr := parseCloseError(err, true); perr != nil { - t.Error(perr) + switch c := c.(type) { + case *TCPConn: + err = c.CloseWrite() + case *UnixConn: + err = c.CloseWrite() } - t.Fatal(err) - } - var b [1]byte - n, err := c.Read(b[:]) - if n != 0 || err != io.EOF { - t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err) - } - n, err = c.Write(b[:]) - if err == nil { - t.Fatalf("got (%d, %v); want (any, error)", n, err) - } + if err != nil { + if perr := parseCloseError(err, true); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + var b [1]byte + n, err := c.Read(b[:]) + if n != 0 || err != io.EOF { + t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err) + } + n, err = c.Write(b[:]) + if err == nil { + t.Fatalf("got (%d, %v); want (any, error)", n, err) + } + }) } } func TestConnClose(t *testing.T) { + t.Parallel() for _, network := range []string{"tcp", "unix", "unixpacket"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } + network := network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("network %s is not testable on the current platform", network) + } + t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } - switch network { - case "unix", "unixpacket": - defer os.Remove(ln.Addr().String()) - } - defer ln.Close() + ln, err := newLocalListener(network) + if err != nil { + t.Fatal(err) + } + switch network { + case "unix", "unixpacket": + defer os.Remove(ln.Addr().String()) + } + defer ln.Close() - c, err := Dial(ln.Addr().Network(), ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - switch network { - case "unix", "unixpacket": - defer os.Remove(c.LocalAddr().String()) - } - defer c.Close() + c, err := Dial(ln.Addr().Network(), ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + switch network { + case "unix", "unixpacket": + defer os.Remove(c.LocalAddr().String()) + } + defer c.Close() - if err := c.Close(); err != nil { - if perr := parseCloseError(err, false); perr != nil { - t.Error(perr) + if err := c.Close(); err != nil { + if perr := parseCloseError(err, false); perr != nil { + t.Error(perr) + } + t.Fatal(err) } - t.Fatal(err) - } - var b [1]byte - n, err := c.Read(b[:]) - if n != 0 || err == nil { - t.Fatalf("got (%d, %v); want (0, error)", n, err) - } + var b [1]byte + n, err := c.Read(b[:]) + if n != 0 || err == nil { + t.Fatalf("got (%d, %v); want (0, error)", n, err) + } + }) } } func TestListenerClose(t *testing.T) { + t.Parallel() for _, network := range []string{"tcp", "unix", "unixpacket"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } - - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } - switch network { - case "unix", "unixpacket": - defer os.Remove(ln.Addr().String()) - } + network := network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("network %s is not testable on the current platform", network) + } + t.Parallel() - dst := ln.Addr().String() - if err := ln.Close(); err != nil { - if perr := parseCloseError(err, false); perr != nil { - t.Error(perr) + ln, err := newLocalListener(network) + if err != nil { + t.Fatal(err) + } + switch network { + case "unix", "unixpacket": + defer os.Remove(ln.Addr().String()) } - t.Fatal(err) - } - c, err := ln.Accept() - if err == nil { - c.Close() - t.Fatal("should fail") - } - if network == "tcp" { - // We will have two TCP FSMs inside the - // kernel here. There's no guarantee that a - // signal comes from the far end FSM will be - // delivered immediately to the near end FSM, - // especially on the platforms that allow - // multiple consumer threads to pull pending - // established connections at the same time by - // enabling SO_REUSEPORT option such as Linux, - // DragonFly BSD. So we need to give some time - // quantum to the kernel. - // - // Note that net.inet.tcp.reuseport_ext=1 by - // default on DragonFly BSD. - time.Sleep(time.Millisecond) - - cc, err := Dial("tcp", dst) + dst := ln.Addr().String() + if err := ln.Close(); err != nil { + if perr := parseCloseError(err, false); perr != nil { + t.Error(perr) + } + t.Fatal(err) + } + c, err := ln.Accept() if err == nil { - t.Error("Dial to closed TCP listener succeeded.") - cc.Close() + c.Close() + t.Fatal("should fail") } - } + + if network == "tcp" { + // We will have two TCP FSMs inside the + // kernel here. There's no guarantee that a + // signal comes from the far end FSM will be + // delivered immediately to the near end FSM, + // especially on the platforms that allow + // multiple consumer threads to pull pending + // established connections at the same time by + // enabling SO_REUSEPORT option such as Linux, + // DragonFly BSD. So we need to give some time + // quantum to the kernel. + // + // Note that net.inet.tcp.reuseport_ext=1 by + // default on DragonFly BSD. + time.Sleep(time.Millisecond) + + cc, err := Dial("tcp", dst) + if err == nil { + t.Error("Dial to closed TCP listener succeeded.") + cc.Close() + } + } + }) } } func TestPacketConnClose(t *testing.T) { + t.Parallel() for _, network := range []string{"udp", "unixgram"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } + network := network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("network %s is not testable on the current platform", network) + } + t.Parallel() - c, err := newLocalPacketListener(network) - if err != nil { - t.Fatal(err) - } - switch network { - case "unixgram": - defer os.Remove(c.LocalAddr().String()) - } - defer c.Close() + c, err := newLocalPacketListener(network) + if err != nil { + t.Fatal(err) + } + switch network { + case "unixgram": + defer os.Remove(c.LocalAddr().String()) + } + defer c.Close() - if err := c.Close(); err != nil { - if perr := parseCloseError(err, false); perr != nil { - t.Error(perr) + if err := c.Close(); err != nil { + if perr := parseCloseError(err, false); perr != nil { + t.Error(perr) + } + t.Fatal(err) } - t.Fatal(err) - } - var b [1]byte - n, _, err := c.ReadFrom(b[:]) - if n != 0 || err == nil { - t.Fatalf("got (%d, %v); want (0, error)", n, err) - } + var b [1]byte + n, _, err := c.ReadFrom(b[:]) + if n != 0 || err == nil { + t.Fatalf("got (%d, %v); want (0, error)", n, err) + } + }) } } @@ -366,56 +398,60 @@ func TestAcceptIgnoreAbortedConnRequest(t *testing.T) { } func TestZeroByteRead(t *testing.T) { + t.Parallel() for _, network := range []string{"tcp", "unix", "unixpacket"} { - if !testableNetwork(network) { - t.Logf("skipping %s test", network) - continue - } + network := network + t.Run(network, func(t *testing.T) { + if !testableNetwork(network) { + t.Skipf("network %s is not testable on the current platform", network) + } + t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } - connc := make(chan Conn, 1) - go func() { - defer ln.Close() - c, err := ln.Accept() + ln, err := newLocalListener(network) if err != nil { - t.Error(err) + t.Fatal(err) } - connc <- c // might be nil - }() - c, err := Dial(network, ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer c.Close() - sc := <-connc - if sc == nil { - continue - } - defer sc.Close() + connc := make(chan Conn, 1) + go func() { + defer ln.Close() + c, err := ln.Accept() + if err != nil { + t.Error(err) + } + connc <- c // might be nil + }() + c, err := Dial(network, ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + sc := <-connc + if sc == nil { + return + } + defer sc.Close() - if runtime.GOOS == "windows" { - // A zero byte read on Windows caused a wait for readability first. - // Rather than change that behavior, satisfy it in this test. - // See Issue 15735. - go io.WriteString(sc, "a") - } + if runtime.GOOS == "windows" { + // A zero byte read on Windows caused a wait for readability first. + // Rather than change that behavior, satisfy it in this test. + // See Issue 15735. + go io.WriteString(sc, "a") + } - n, err := c.Read(nil) - if n != 0 || err != nil { - t.Errorf("%s: zero byte client read = %v, %v; want 0, nil", network, n, err) - } + n, err := c.Read(nil) + if n != 0 || err != nil { + t.Errorf("%s: zero byte client read = %v, %v; want 0, nil", network, n, err) + } - if runtime.GOOS == "windows" { - // Same as comment above. - go io.WriteString(c, "a") - } - n, err = sc.Read(nil) - if n != 0 || err != nil { - t.Errorf("%s: zero byte server read = %v, %v; want 0, nil", network, n, err) - } + if runtime.GOOS == "windows" { + // Same as comment above. + go io.WriteString(c, "a") + } + n, err = sc.Read(nil) + if n != 0 || err != nil { + t.Errorf("%s: zero byte server read = %v, %v; want 0, nil", network, n, err) + } + }) } } diff --git a/libgo/go/net/pipe.go b/libgo/go/net/pipe.go index 9177fc4..f174193 100644 --- a/libgo/go/net/pipe.go +++ b/libgo/go/net/pipe.go @@ -6,6 +6,7 @@ package net import ( "io" + "os" "sync" "time" ) @@ -78,12 +79,6 @@ func isClosedChan(c <-chan struct{}) bool { } } -type timeoutError struct{} - -func (timeoutError) Error() string { return "deadline exceeded" } -func (timeoutError) Timeout() bool { return true } -func (timeoutError) Temporary() bool { return true } - type pipeAddr struct{} func (pipeAddr) Network() string { return "pipe" } @@ -158,7 +153,7 @@ func (p *pipe) read(b []byte) (n int, err error) { case isClosedChan(p.remoteDone): return 0, io.EOF case isClosedChan(p.readDeadline.wait()): - return 0, timeoutError{} + return 0, os.ErrDeadlineExceeded } select { @@ -171,7 +166,7 @@ func (p *pipe) read(b []byte) (n int, err error) { case <-p.remoteDone: return 0, io.EOF case <-p.readDeadline.wait(): - return 0, timeoutError{} + return 0, os.ErrDeadlineExceeded } } @@ -190,7 +185,7 @@ func (p *pipe) write(b []byte) (n int, err error) { case isClosedChan(p.remoteDone): return 0, io.ErrClosedPipe case isClosedChan(p.writeDeadline.wait()): - return 0, timeoutError{} + return 0, os.ErrDeadlineExceeded } p.wrMu.Lock() // Ensure entirety of b is written together @@ -206,7 +201,7 @@ func (p *pipe) write(b []byte) (n int, err error) { case <-p.remoteDone: return n, io.ErrClosedPipe case <-p.writeDeadline.wait(): - return n, timeoutError{} + return n, os.ErrDeadlineExceeded } } return n, nil diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go index d35dfaa..d3bb918 100644 --- a/libgo/go/net/platform_test.go +++ b/libgo/go/net/platform_test.go @@ -54,7 +54,7 @@ func testableNetwork(network string) bool { return unixEnabledOnAIX } // iOS does not support unix, unixgram. - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { return false } case "unixpacket": @@ -81,6 +81,10 @@ func testableNetwork(network string) bool { return true } +func iOS() bool { + return runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" +} + // testableAddress reports whether address of network is testable on // the current platform configuration. func testableAddress(network, address string) bool { diff --git a/libgo/go/net/rawconn_test.go b/libgo/go/net/rawconn_test.go index 9a82f8f..a08ff89 100644 --- a/libgo/go/net/rawconn_test.go +++ b/libgo/go/net/rawconn_test.go @@ -130,7 +130,7 @@ func TestRawConnReadWrite(t *testing.T) { if perr := parseWriteError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Errorf("got %v; want timeout", err) } if _, err = readRawConn(cc, b[:]); err == nil { @@ -139,7 +139,7 @@ func TestRawConnReadWrite(t *testing.T) { if perr := parseReadError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Errorf("got %v; want timeout", err) } @@ -153,7 +153,7 @@ func TestRawConnReadWrite(t *testing.T) { if perr := parseReadError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Errorf("got %v; want timeout", err) } @@ -167,7 +167,7 @@ func TestRawConnReadWrite(t *testing.T) { if perr := parseWriteError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Errorf("got %v; want timeout", err) } }) diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go index cad2d45..25f2a00 100644 --- a/libgo/go/net/rpc/client.go +++ b/libgo/go/net/rpc/client.go @@ -31,7 +31,7 @@ type Call struct { Args interface{} // The argument to the function (*struct). Reply interface{} // The reply from the function (*struct). Error error // After completion, the error status. - Done chan *Call // Strobes when call is complete. + Done chan *Call // Receives *Call when Go is complete. } // Client represents an RPC Client. diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go index bbb8eb0..4e73edc 100644 --- a/libgo/go/net/rpc/jsonrpc/all_test.go +++ b/libgo/go/net/rpc/jsonrpc/all_test.go @@ -127,8 +127,8 @@ func TestServer(t *testing.T) { if resp.Error != nil { t.Fatalf("resp.Error: %s", resp.Error) } - if resp.Id.(string) != string(i) { - t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(i)) + if resp.Id.(string) != string(rune(i)) { + t.Fatalf("resp: bad id %q want %q", resp.Id.(string), string(rune(i))) } if resp.Result.C != 2*i+1 { t.Fatalf("resp: bad result: %d+%d=%d", i, i+1, resp.Result.C) diff --git a/libgo/go/net/sockopt_aix.go b/libgo/go/net/sockopt_aix.go index b49c4d5..7729a44 100644 --- a/libgo/go/net/sockopt_aix.go +++ b/libgo/go/net/sockopt_aix.go @@ -16,8 +16,11 @@ func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { // never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } - // Allow broadcast. - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX { + // Allow broadcast. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + } + return nil } func setDefaultListenerSockopts(s int) error { diff --git a/libgo/go/net/sockopt_bsd.go b/libgo/go/net/sockopt_bsd.go index 1aae88a..7b8b8d9 100644 --- a/libgo/go/net/sockopt_bsd.go +++ b/libgo/go/net/sockopt_bsd.go @@ -31,8 +31,11 @@ func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { // never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } - // Allow broadcast. - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX { + // Allow broadcast. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + } + return nil } func setDefaultListenerSockopts(s int) error { diff --git a/libgo/go/net/sockopt_linux.go b/libgo/go/net/sockopt_linux.go index 0f70b12..3d54429 100644 --- a/libgo/go/net/sockopt_linux.go +++ b/libgo/go/net/sockopt_linux.go @@ -16,8 +16,11 @@ func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { // never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } - // Allow broadcast. - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX { + // Allow broadcast. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + } + return nil } func setDefaultListenerSockopts(s int) error { diff --git a/libgo/go/net/sockopt_solaris.go b/libgo/go/net/sockopt_solaris.go index 0f70b12..3d54429 100644 --- a/libgo/go/net/sockopt_solaris.go +++ b/libgo/go/net/sockopt_solaris.go @@ -16,8 +16,11 @@ func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { // never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } - // Allow broadcast. - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX { + // Allow broadcast. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + } + return nil } func setDefaultListenerSockopts(s int) error { diff --git a/libgo/go/net/sockopt_windows.go b/libgo/go/net/sockopt_windows.go index 8017426..8afaf34 100644 --- a/libgo/go/net/sockopt_windows.go +++ b/libgo/go/net/sockopt_windows.go @@ -16,8 +16,10 @@ func setDefaultSockopts(s syscall.Handle, family, sotype int, ipv6only bool) err // never admit this option. syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, boolint(ipv6only)) } - // Allow broadcast. - syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1) + if (sotype == syscall.SOCK_DGRAM || sotype == syscall.SOCK_RAW) && family != syscall.AF_UNIX && family != syscall.AF_INET6 { + // Allow broadcast. + return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1)) + } return nil } diff --git a/libgo/go/net/textproto/pipeline.go b/libgo/go/net/textproto/pipeline.go index 2e28321..1928a30 100644 --- a/libgo/go/net/textproto/pipeline.go +++ b/libgo/go/net/textproto/pipeline.go @@ -72,7 +72,7 @@ func (p *Pipeline) EndResponse(id uint) { type sequencer struct { mu sync.Mutex id uint - wait map[uint]chan uint + wait map[uint]chan struct{} } // Start waits until it is time for the event numbered id to begin. @@ -84,9 +84,9 @@ func (s *sequencer) Start(id uint) { s.mu.Unlock() return } - c := make(chan uint) + c := make(chan struct{}) if s.wait == nil { - s.wait = make(map[uint]chan uint) + s.wait = make(map[uint]chan struct{}) } s.wait[id] = c s.mu.Unlock() @@ -99,12 +99,13 @@ func (s *sequencer) Start(id uint) { func (s *sequencer) End(id uint) { s.mu.Lock() if s.id != id { + s.mu.Unlock() panic("out of sync") } id++ s.id = id if s.wait == nil { - s.wait = make(map[uint]chan uint) + s.wait = make(map[uint]chan struct{}) } c, ok := s.wait[id] if ok { @@ -112,6 +113,6 @@ func (s *sequencer) End(id uint) { } s.mu.Unlock() if ok { - c <- 1 + close(c) } } diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go index a505da9..a00fd23 100644 --- a/libgo/go/net/textproto/reader.go +++ b/libgo/go/net/textproto/reader.go @@ -88,7 +88,7 @@ func (r *Reader) readLineSlice() ([]byte, error) { // The first call to ReadContinuedLine will return "Line 1 continued..." // and the second will return "Line 2". // -// A line consisting of only white space is never continued. +// Empty lines are never continued. // func (r *Reader) ReadContinuedLine() (string, error) { line, err := r.readContinuedLineSlice(noValidation) @@ -557,7 +557,7 @@ func noValidation(_ []byte) error { return nil } // contain a colon. func mustHaveFieldNameColon(line []byte) error { if bytes.IndexByte(line, ':') < 0 { - return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q" + string(line))) + return ProtocolError(fmt.Sprintf("malformed MIME header: missing colon: %q", line)) } return nil } diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index f54c956..ad14cd7 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -7,12 +7,13 @@ package net import ( + "errors" "fmt" - "internal/poll" "internal/testenv" "io" "io/ioutil" "net/internal/socktest" + "os" "runtime" "sync" "testing" @@ -148,9 +149,9 @@ var acceptTimeoutTests = []struct { }{ // Tests that accept deadlines in the past work, even if // there's incoming connections available. - {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, + {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}}, - {50 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, + {50 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}}, } func TestAcceptTimeout(t *testing.T) { @@ -194,7 +195,7 @@ func TestAcceptTimeout(t *testing.T) { if perr := parseAcceptError(err); perr != nil { t.Errorf("#%d/%d: %v", i, j, perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatalf("#%d/%d: %v", i, j, err) } } @@ -250,7 +251,7 @@ func TestAcceptTimeoutMustReturn(t *testing.T) { if perr := parseAcceptError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatal(err) } } @@ -302,9 +303,9 @@ var readTimeoutTests = []struct { }{ // Tests that read deadlines work, even if there's data ready // to be read. - {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, + {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}}, - {50 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, + {50 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}}, } func TestReadTimeout(t *testing.T) { @@ -344,7 +345,7 @@ func TestReadTimeout(t *testing.T) { if perr := parseReadError(err); perr != nil { t.Errorf("#%d/%d: %v", i, j, perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatalf("#%d/%d: %v", i, j, err) } } @@ -423,9 +424,9 @@ var readFromTimeoutTests = []struct { }{ // Tests that read deadlines work, even if there's data ready // to be read. - {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, + {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}}, - {50 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, + {50 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}}, } func TestReadFromTimeout(t *testing.T) { @@ -468,7 +469,7 @@ func TestReadFromTimeout(t *testing.T) { if perr := parseReadError(err); perr != nil { t.Errorf("#%d/%d: %v", i, j, perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatalf("#%d/%d: %v", i, j, err) } } @@ -491,9 +492,9 @@ var writeTimeoutTests = []struct { }{ // Tests that write deadlines work, even if there's buffer // space available to write. - {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, + {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}}, - {10 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, + {10 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}}, } func TestWriteTimeout(t *testing.T) { @@ -522,7 +523,7 @@ func TestWriteTimeout(t *testing.T) { if perr := parseWriteError(err); perr != nil { t.Errorf("#%d/%d: %v", i, j, perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatalf("#%d/%d: %v", i, j, err) } } @@ -605,9 +606,9 @@ var writeToTimeoutTests = []struct { }{ // Tests that write deadlines work, even if there's buffer // space available to write. - {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, + {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}}, - {10 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, + {10 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}}, } func TestWriteToTimeout(t *testing.T) { @@ -641,7 +642,7 @@ func TestWriteToTimeout(t *testing.T) { if perr := parseWriteError(err); perr != nil { t.Errorf("#%d/%d: %v", i, j, perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatalf("#%d/%d: %v", i, j, err) } } @@ -685,7 +686,7 @@ func TestReadTimeoutFluctuation(t *testing.T) { if perr := parseReadError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatal(err) } } @@ -718,7 +719,7 @@ func TestReadFromTimeoutFluctuation(t *testing.T) { if perr := parseReadError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatal(err) } } @@ -745,7 +746,7 @@ func TestWriteTimeoutFluctuation(t *testing.T) { defer c.Close() d := time.Second - if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") { + if iOS() { d = 3 * time.Second // see golang.org/issue/10775 } max := time.NewTimer(d) @@ -760,7 +761,7 @@ func TestWriteTimeoutFluctuation(t *testing.T) { if perr := parseWriteError(err); perr != nil { t.Error(perr) } - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatal(err) } } @@ -1073,3 +1074,20 @@ func TestConcurrentSetDeadline(t *testing.T) { } wg.Wait() } + +// isDeadlineExceeded reports whether err is or wraps os.ErrDeadlineExceeded. +// We also check that the error implements net.Error, and that the +// Timeout method returns true. +func isDeadlineExceeded(err error) bool { + nerr, ok := err.(Error) + if !ok { + return false + } + if !nerr.Timeout() { + return false + } + if !errors.Is(err, os.ErrDeadlineExceeded) { + return false + } + return true +} diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go index 80cccf2..4b2cfc4d 100644 --- a/libgo/go/net/unixsock_test.go +++ b/libgo/go/net/unixsock_test.go @@ -113,7 +113,7 @@ func TestUnixgramZeroBytePayload(t *testing.T) { t.Fatalf("unexpected peer address: %v", peer) } default: // Read may timeout, it depends on the platform - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatal(err) } } @@ -163,7 +163,7 @@ func TestUnixgramZeroByteBuffer(t *testing.T) { t.Fatalf("unexpected peer address: %v", peer) } default: // Read may timeout, it depends on the platform - if nerr, ok := err.(Error); !ok || !nerr.Timeout() { + if !isDeadlineExceeded(err) { t.Fatal(err) } } diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 2880e82..c93def0 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -356,15 +356,16 @@ func escape(s string, mode encoding) string { // URL's String method uses the EscapedPath method to obtain the path. See the // EscapedPath method for more details. type URL struct { - Scheme string - Opaque string // encoded opaque data - User *Userinfo // username and password information - Host string // host or host:port - Path string // path (relative paths may omit leading slash) - RawPath string // encoded path hint (see EscapedPath method) - ForceQuery bool // append a query ('?') even if RawQuery is empty - RawQuery string // encoded query values, without '?' - Fragment string // fragment for references, without '#' + Scheme string + Opaque string // encoded opaque data + User *Userinfo // username and password information + Host string // host or host:port + Path string // path (relative paths may omit leading slash) + RawPath string // encoded path hint (see EscapedPath method) + ForceQuery bool // append a query ('?') even if RawQuery is empty + RawQuery string // encoded query values, without '?' + Fragment string // fragment for references, without '#' + RawFragment string // encoded fragment hint (see EscapedFragment method) } // User returns a Userinfo containing the provided username @@ -481,7 +482,7 @@ func Parse(rawurl string) (*URL, error) { if frag == "" { return url, nil } - if url.Fragment, err = unescape(frag, encodeFragment); err != nil { + if err = url.setFragment(frag); err != nil { return nil, &Error{"parse", rawurl, err} } return url, nil @@ -697,7 +698,7 @@ func (u *URL) setPath(p string) error { // In general, code should call EscapedPath instead of // reading u.RawPath directly. func (u *URL) EscapedPath() string { - if u.RawPath != "" && validEncodedPath(u.RawPath) { + if u.RawPath != "" && validEncoded(u.RawPath, encodePath) { p, err := unescape(u.RawPath, encodePath) if err == nil && p == u.Path { return u.RawPath @@ -709,9 +710,10 @@ func (u *URL) EscapedPath() string { return escape(u.Path, encodePath) } -// validEncodedPath reports whether s is a valid encoded path. -// It must not contain any bytes that require escaping during path encoding. -func validEncodedPath(s string) bool { +// validEncoded reports whether s is a valid encoded path or fragment, +// according to mode. +// It must not contain any bytes that require escaping during encoding. +func validEncoded(s string, mode encoding) bool { for i := 0; i < len(s); i++ { // RFC 3986, Appendix A. // pchar = unreserved / pct-encoded / sub-delims / ":" / "@". @@ -726,7 +728,7 @@ func validEncodedPath(s string) bool { case '%': // ok - percent encoded, will decode default: - if shouldEscape(s[i], encodePath) { + if shouldEscape(s[i], mode) { return false } } @@ -734,6 +736,40 @@ func validEncodedPath(s string) bool { return true } +// setFragment is like setPath but for Fragment/RawFragment. +func (u *URL) setFragment(f string) error { + frag, err := unescape(f, encodeFragment) + if err != nil { + return err + } + u.Fragment = frag + if escf := escape(frag, encodeFragment); f == escf { + // Default encoding is fine. + u.RawFragment = "" + } else { + u.RawFragment = f + } + return nil +} + +// EscapedFragment returns the escaped form of u.Fragment. +// In general there are multiple possible escaped forms of any fragment. +// EscapedFragment returns u.RawFragment when it is a valid escaping of u.Fragment. +// Otherwise EscapedFragment ignores u.RawFragment and computes an escaped +// form on its own. +// The String method uses EscapedFragment to construct its result. +// In general, code should call EscapedFragment instead of +// reading u.RawFragment directly. +func (u *URL) EscapedFragment() string { + if u.RawFragment != "" && validEncoded(u.RawFragment, encodeFragment) { + f, err := unescape(u.RawFragment, encodeFragment) + if err == nil && f == u.Fragment { + return u.RawFragment + } + } + return escape(u.Fragment, encodeFragment) +} + // validOptionalPort reports whether port is either an empty string // or matches /^:\d*$/ func validOptionalPort(port string) bool { @@ -816,11 +852,25 @@ func (u *URL) String() string { } if u.Fragment != "" { buf.WriteByte('#') - buf.WriteString(escape(u.Fragment, encodeFragment)) + buf.WriteString(u.EscapedFragment()) } return buf.String() } +// Redacted is like String but replaces any password with "xxxxx". +// Only the password in u.URL is redacted. +func (u *URL) Redacted() string { + if u == nil { + return "" + } + + ru := *u + if _, has := ru.User.Password(); has { + ru.User = UserPassword(ru.User.Username(), "xxxxx") + } + return ru.String() +} + // Values maps a string key to a list of values. // It is typically used for query parameters and form values. // Unlike in the http.Header map, the keys in a Values map @@ -1016,6 +1066,7 @@ func (u *URL) ResolveReference(ref *URL) *URL { url.RawQuery = u.RawQuery if ref.Fragment == "" { url.Fragment = u.Fragment + url.RawFragment = u.RawFragment } } // The "abs_path" or "rel_path" cases. diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index 79fd3d5..92b15af 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -19,7 +19,7 @@ import ( type URLTest struct { in string - out *URL // expected parse; RawPath="" means same as Path + out *URL // expected parse roundtrip string // expected result of reserializing the URL; empty means same as "in". } @@ -54,6 +54,18 @@ var urltests = []URLTest{ }, "", }, + // fragment with hex escaping + { + "http://www.google.com/#file%20one%26two", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + Fragment: "file one&two", + RawFragment: "file%20one%26two", + }, + "", + }, // user { "ftp://webmaster@www.google.com/", @@ -261,7 +273,7 @@ var urltests = []URLTest{ "", }, { - "http://www.google.com/?q=go+language#foo%26bar", + "http://www.google.com/?q=go+language#foo&bar", &URL{ Scheme: "http", Host: "www.google.com", @@ -272,6 +284,18 @@ var urltests = []URLTest{ "http://www.google.com/?q=go+language#foo&bar", }, { + "http://www.google.com/?q=go+language#foo%26bar", + &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + RawQuery: "q=go+language", + Fragment: "foo&bar", + RawFragment: "foo%26bar", + }, + "http://www.google.com/?q=go+language#foo%26bar", + }, + { "file:///home/adg/rabbits", &URL{ Scheme: "file", @@ -601,8 +625,8 @@ func ufmt(u *URL) string { pass = p } } - return fmt.Sprintf("opaque=%q, scheme=%q, user=%#v, pass=%#v, host=%q, path=%q, rawpath=%q, rawq=%q, frag=%q, forcequery=%v", - u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawPath, u.RawQuery, u.Fragment, u.ForceQuery) + return fmt.Sprintf("opaque=%q, scheme=%q, user=%#v, pass=%#v, host=%q, path=%q, rawpath=%q, rawq=%q, frag=%q, rawfrag=%q, forcequery=%v", + u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawPath, u.RawQuery, u.Fragment, u.RawFragment, u.ForceQuery) } func BenchmarkString(b *testing.B) { @@ -765,6 +789,73 @@ func TestURLString(t *testing.T) { } } +func TestURLRedacted(t *testing.T) { + cases := []struct { + name string + url *URL + want string + }{ + { + name: "non-blank Password", + url: &URL{ + Scheme: "http", + Host: "host.tld", + Path: "this:that", + User: UserPassword("user", "password"), + }, + want: "http://user:xxxxx@host.tld/this:that", + }, + { + name: "blank Password", + url: &URL{ + Scheme: "http", + Host: "host.tld", + Path: "this:that", + User: User("user"), + }, + want: "http://user@host.tld/this:that", + }, + { + name: "nil User", + url: &URL{ + Scheme: "http", + Host: "host.tld", + Path: "this:that", + User: UserPassword("", "password"), + }, + want: "http://:xxxxx@host.tld/this:that", + }, + { + name: "blank Username, blank Password", + url: &URL{ + Scheme: "http", + Host: "host.tld", + Path: "this:that", + }, + want: "http://host.tld/this:that", + }, + { + name: "empty URL", + url: &URL{}, + want: "", + }, + { + name: "nil URL", + url: nil, + want: "", + }, + } + + for _, tt := range cases { + t := t + t.Run(tt.name, func(t *testing.T) { + if g, w := tt.url.Redacted(), tt.want; g != w { + t.Fatalf("got: %q\nwant: %q", g, w) + } + }) + } +} + type EscapeTest struct { in string out string |