diff options
Diffstat (limited to 'libgo/go/net')
206 files changed, 8428 insertions, 2215 deletions
diff --git a/libgo/go/net/addrselect.go b/libgo/go/net/addrselect.go index 4603c55..e910181 100644 --- a/libgo/go/net/addrselect.go +++ b/libgo/go/net/addrselect.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris // Minimal RFC 6724 address selection. diff --git a/libgo/go/net/addrselect_test.go b/libgo/go/net/addrselect_test.go index 18784fe..a958e2e 100644 --- a/libgo/go/net/addrselect_test.go +++ b/libgo/go/net/addrselect_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/cgo_aix.go b/libgo/go/net/cgo_aix.go index 577649f..6ee0f09 100644 --- a/libgo/go/net/cgo_aix.go +++ b/libgo/go/net/cgo_aix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo -// +build cgo,!netgo package net diff --git a/libgo/go/net/cgo_android.go b/libgo/go/net/cgo_android.go index 4b1a2e3..5ab8b5f 100644 --- a/libgo/go/net/cgo_android.go +++ b/libgo/go/net/cgo_android.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo -// +build cgo,!netgo package net diff --git a/libgo/go/net/cgo_bsd.go b/libgo/go/net/cgo_bsd.go index 1268c89..830e589 100644 --- a/libgo/go/net/cgo_bsd.go +++ b/libgo/go/net/cgo_bsd.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (darwin || dragonfly || freebsd) -// +build cgo -// +build !netgo -// +build darwin dragonfly freebsd package net diff --git a/libgo/go/net/cgo_linux.go b/libgo/go/net/cgo_linux.go index 4b45dad..5d67699 100644 --- a/libgo/go/net/cgo_linux.go +++ b/libgo/go/net/cgo_linux.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !android && cgo && !netgo -// +build !android,cgo,!netgo package net diff --git a/libgo/go/net/cgo_netbsd.go b/libgo/go/net/cgo_netbsd.go index e23899d..4778811 100644 --- a/libgo/go/net/cgo_netbsd.go +++ b/libgo/go/net/cgo_netbsd.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo -// +build cgo,!netgo package net diff --git a/libgo/go/net/cgo_openbsd.go b/libgo/go/net/cgo_openbsd.go index 3714793..03392e8 100644 --- a/libgo/go/net/cgo_openbsd.go +++ b/libgo/go/net/cgo_openbsd.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo -// +build cgo,!netgo package net diff --git a/libgo/go/net/cgo_resnew.go b/libgo/go/net/cgo_resnew.go index 6611fd7..985faee 100644 --- a/libgo/go/net/cgo_resnew.go +++ b/libgo/go/net/cgo_resnew.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (aix || darwin || hurd || (linux && !android) || netbsd || solaris) -// +build cgo -// +build !netgo -// +build aix darwin hurd linux,!android netbsd solaris package net diff --git a/libgo/go/net/cgo_resold.go b/libgo/go/net/cgo_resold.go index 33f664c..b65e020 100644 --- a/libgo/go/net/cgo_resold.go +++ b/libgo/go/net/cgo_resold.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (android || freebsd || dragonfly || openbsd) -// +build cgo -// +build !netgo -// +build android freebsd dragonfly openbsd package net diff --git a/libgo/go/net/cgo_socknew.go b/libgo/go/net/cgo_socknew.go index 84b40c9..2c3ab63 100644 --- a/libgo/go/net/cgo_socknew.go +++ b/libgo/go/net/cgo_socknew.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (android || linux || solaris) -// +build cgo -// +build !netgo -// +build android linux solaris package net diff --git a/libgo/go/net/cgo_sockold.go b/libgo/go/net/cgo_sockold.go index 703b41b..461ecb4 100644 --- a/libgo/go/net/cgo_sockold.go +++ b/libgo/go/net/cgo_sockold.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (aix || darwin || dragonfly || freebsd || hurd || netbsd || openbsd) -// +build cgo -// +build !netgo -// +build aix darwin dragonfly freebsd hurd netbsd openbsd package net diff --git a/libgo/go/net/cgo_solaris.go b/libgo/go/net/cgo_solaris.go index 95d5db5..95a23cf 100644 --- a/libgo/go/net/cgo_solaris.go +++ b/libgo/go/net/cgo_solaris.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo -// +build cgo,!netgo package net diff --git a/libgo/go/net/cgo_stub.go b/libgo/go/net/cgo_stub.go index 039e4be..cc84ca4 100644 --- a/libgo/go/net/cgo_stub.go +++ b/libgo/go/net/cgo_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !cgo || netgo -// +build !cgo netgo package net diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go index 462bf12..26b3da3 100644 --- a/libgo/go/net/cgo_unix.go +++ b/libgo/go/net/cgo_unix.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris) -// +build cgo -// +build !netgo -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net @@ -352,7 +349,7 @@ func cgoLookupAddrPTR(addr string, sa *syscall.RawSockaddr, salen syscall.Sockle break } } - return []string{absDomainName(b)}, nil + return []string{absDomainName(string(b))}, nil } func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *syscall.RawSockaddr, salen syscall.Socklen_t) { diff --git a/libgo/go/net/cgo_unix_test.go b/libgo/go/net/cgo_unix_test.go index 98b3b4a..5264fcd 100644 --- a/libgo/go/net/cgo_unix_test.go +++ b/libgo/go/net/cgo_unix_test.go @@ -3,9 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo && (aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris) -// +build cgo -// +build !netgo -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/cgo_windows.go b/libgo/go/net/cgo_windows.go index 1fd1f297..6bb6cbb 100644 --- a/libgo/go/net/cgo_windows.go +++ b/libgo/go/net/cgo_windows.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build cgo && !netgo -// +build cgo,!netgo package net diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go index fe7ebf1..6edecaf 100644 --- a/libgo/go/net/conf.go +++ b/libgo/go/net/conf.go @@ -3,12 +3,12 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net import ( "internal/bytealg" + "internal/godebug" "os" "runtime" "sync" @@ -287,7 +287,7 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde // cgo+2 // same, but debug level 2 // etc. func goDebugNetDNS() (dnsMode string, debugLevel int) { - goDebug := goDebugString("netdns") + goDebug := godebug.Get("netdns") parsePart := func(s string) { if s == "" { return diff --git a/libgo/go/net/conf_netcgo.go b/libgo/go/net/conf_netcgo.go index c705152..3447a87 100644 --- a/libgo/go/net/conf_netcgo.go +++ b/libgo/go/net/conf_netcgo.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build netcgo -// +build netcgo package net diff --git a/libgo/go/net/conf_test.go b/libgo/go/net/conf_test.go index f5e4d86..8c2d3ce 100644 --- a/libgo/go/net/conf_test.go +++ b/libgo/go/net/conf_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/conn_test.go b/libgo/go/net/conn_test.go index 45e271c..d168dda 100644 --- a/libgo/go/net/conn_test.go +++ b/libgo/go/net/conn_test.go @@ -6,7 +6,6 @@ // tag. //go:build !js -// +build !js package net @@ -18,7 +17,7 @@ import ( // someTimeout is used just to test that net.Conn implementations // don't explode when their SetFooDeadline methods are called. // It isn't actually used for testing timeouts. -const someTimeout = 10 * time.Second +const someTimeout = 1 * time.Hour func TestConnAndListener(t *testing.T) { for i, network := range []string{"tcp", "unix", "unixpacket"} { @@ -27,10 +26,7 @@ func TestConnAndListener(t *testing.T) { continue } - ls, err := newLocalServer(network) - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, network) defer ls.teardown() ch := make(chan error, 1) handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index 723038c..b9aead0 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -60,10 +59,7 @@ func TestProhibitionaryDialArg(t *testing.T) { } func TestDialLocal(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() _, port, err := SplitHostPort(ln.Addr().String()) if err != nil { @@ -433,14 +429,15 @@ func TestDialParallelSpuriousConnection(t *testing.T) { readDeadline = time.Now().Add(5 * time.Second) } - var wg sync.WaitGroup - wg.Add(2) + var closed sync.WaitGroup + closed.Add(2) handler := func(dss *dualStackServer, ln Listener) { // Accept one connection per address. c, err := ln.Accept() if err != nil { t.Fatal(err) } + // The client should close itself, without sending data. c.SetReadDeadline(readDeadline) var b [1]byte @@ -448,7 +445,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) { t.Errorf("got %v; want %v", err, io.EOF) } c.Close() - wg.Done() + closed.Done() } dss, err := newDualStackServer() if err != nil { @@ -461,12 +458,16 @@ func TestDialParallelSpuriousConnection(t *testing.T) { const fallbackDelay = 100 * time.Millisecond + var dialing sync.WaitGroup + dialing.Add(2) origTestHookDialTCP := testHookDialTCP defer func() { testHookDialTCP = origTestHookDialTCP }() testHookDialTCP = func(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) { - // Sleep long enough for Happy Eyeballs to kick in, and inhibit cancellation. + // Wait until Happy Eyeballs kicks in and both connections are dialing, + // and inhibit cancellation. // This forces dialParallel to juggle two successful connections. - time.Sleep(fallbackDelay * 2) + dialing.Done() + dialing.Wait() // Now ignore the provided context (which will be canceled) and use a // different one to make sure this completes with a valid connection, @@ -500,7 +501,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) { c.Close() // The server should've seen both connections. - wg.Wait() + closed.Wait() } func TestDialerPartialDeadline(t *testing.T) { @@ -538,6 +539,9 @@ func TestDialerPartialDeadline(t *testing.T) { } } +// isEADDRINUSE reports whether err is syscall.EADDRINUSE. +var isEADDRINUSE = func(err error) bool { return false } + func TestDialerLocalAddr(t *testing.T) { if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") @@ -593,7 +597,9 @@ func TestDialerLocalAddr(t *testing.T) { {"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}}, } + issue34264Index := -1 if supportsIPv4map() { + issue34264Index = len(tests) tests = append(tests, test{ "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil, }) @@ -615,20 +621,16 @@ func TestDialerLocalAddr(t *testing.T) { c.Close() } } - var err error var lss [2]*localServer for i, network := range []string{"tcp4", "tcp6"} { - lss[i], err = newLocalServer(network) - if err != nil { - t.Fatal(err) - } + lss[i] = newLocalServer(t, network) defer lss[i].teardown() if err := lss[i].buildup(handler); err != nil { t.Fatal(err) } } - for _, tt := range tests { + for i, tt := range tests { d := &Dialer{LocalAddr: tt.laddr} var addr string ip := ParseIP(tt.raddr) @@ -640,7 +642,15 @@ func TestDialerLocalAddr(t *testing.T) { } c, err := d.Dial(tt.network, addr) if err == nil && tt.error != nil || err != nil && tt.error == nil { - t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error) + if i == issue34264Index && runtime.GOOS == "freebsd" && isEADDRINUSE(err) { + // https://golang.org/issue/34264: FreeBSD through at least version 12.2 + // has been observed to fail with EADDRINUSE when dialing from an IPv6 + // local address to an IPv4 remote address. + t.Logf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error) + t.Logf("(spurious EADDRINUSE ignored on freebsd: see https://golang.org/issue/34264)") + } else { + t.Errorf("%s %v->%s: got %v; want %v", tt.network, tt.laddr, tt.raddr, err, tt.error) + } } if err != nil { if perr := parseDialError(err); perr != nil { @@ -713,10 +723,7 @@ func TestDialerKeepAlive(t *testing.T) { c.Close() } } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -814,10 +821,7 @@ func TestCancelAfterDial(t *testing.T) { t.Skip("avoiding time.Sleep") } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") var wg sync.WaitGroup wg.Add(1) @@ -920,11 +924,7 @@ func TestDialerControl(t *testing.T) { if !testableNetwork(network) { continue } - ln, err := newLocalListener(network) - if err != nil { - t.Error(err) - continue - } + ln := newLocalListener(t, network) defer ln.Close() d := Dialer{Control: controlOnConnSetup} c, err := d.Dial(network, ln.Addr().String()) @@ -940,11 +940,7 @@ func TestDialerControl(t *testing.T) { if !testableNetwork(network) { continue } - c1, err := newLocalPacketListener(network) - if err != nil { - t.Error(err) - continue - } + c1 := newLocalPacketListener(t, network) if network == "unixgram" { defer os.Remove(c1.LocalAddr().String()) } @@ -980,10 +976,7 @@ func (contextWithNonZeroDeadline) Deadline() (time.Time, bool) { } func TestDialWithNonZeroDeadline(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() _, port, err := SplitHostPort(ln.Addr().String()) if err != nil { diff --git a/libgo/go/net/dial_unix_test.go b/libgo/go/net/dial_unix_test.go index 4b9bc27..45d032c 100644 --- a/libgo/go/net/dial_unix_test.go +++ b/libgo/go/net/dial_unix_test.go @@ -3,17 +3,23 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net import ( "context" + "errors" "syscall" "testing" "time" ) +func init() { + isEADDRINUSE = func(err error) bool { + return errors.Is(err, syscall.EADDRINUSE) + } +} + // Issue 16523 func TestDialContextCancelRace(t *testing.T) { oldConnectFunc := connectFunc @@ -25,10 +31,7 @@ func TestDialContextCancelRace(t *testing.T) { testHookCanceledDial = oldTestHookCanceledDial }() - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") listenerDone := make(chan struct{}) go func() { defer close(listenerDone) diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go index 1bbe396..a779c37 100644 --- a/libgo/go/net/dnsclient.go +++ b/libgo/go/net/dnsclient.go @@ -5,6 +5,7 @@ package net import ( + "internal/bytealg" "internal/itoa" "sort" @@ -75,6 +76,11 @@ func equalASCIIName(x, y dnsmessage.Name) bool { // (currently restricted to hostname-compatible "preferred name" LDH labels and // SRV-like "underscore labels"; see golang.org/issue/12421). func isDomainName(s string) bool { + // The root domain name is valid. See golang.org/issue/45715. + if s == "." { + return true + } + // See RFC 1035, RFC 3696. // Presentation format has dots before every label except the first, and the // terminal empty label is optional here because we assume fully-qualified @@ -136,18 +142,11 @@ func isDomainName(s string) bool { // It's hard to tell so we settle on the heuristic that names without dots // (like "localhost" or "myhost") do not get trailing dots, but any other // names do. -func absDomainName(b []byte) string { - hasDots := false - for _, x := range b { - if x == '.' { - hasDots = true - break - } - } - if hasDots && b[len(b)-1] != '.' { - b = append(b, '.') +func absDomainName(s string) string { + if bytealg.IndexByteString(s, '.') != -1 && s[len(s)-1] != '.' { + s += "." } - return string(b) + return s } // An SRV represents a single DNS SRV record. diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index a326319..3278791e 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris // DNS client: see RFC 1035. // Has to be linked into package net for Dial. diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go index ce1a4f3..e34c0a5 100644 --- a/libgo/go/net/dnsclient_unix_test.go +++ b/libgo/go/net/dnsclient_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net @@ -2121,3 +2120,44 @@ func TestNullMX(t *testing.T) { t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0]) } } + +func TestRootNS(t *testing.T) { + // See https://golang.org/issue/45715. + fake := fakeDNSServer{ + rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: q.Questions, + Answers: []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeNS, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.NSResource{ + NS: dnsmessage.MustNewName("i.root-servers.net."), + }, + }, + }, + } + return r, nil + }, + } + r := Resolver{PreferGo: true, Dial: fake.DialContext} + rrset, err := r.LookupNS(context.Background(), ".") + if err != nil { + t.Fatalf("LookupNS: %v", err) + } + if want := []*NS{&NS{Host: "i.root-servers.net."}}; !reflect.DeepEqual(rrset, want) { + records := []string{} + for _, rr := range rrset { + records = append(records, fmt.Sprintf("%v", rr)) + } + t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0]) + } +} diff --git a/libgo/go/net/dnsconfig_unix.go b/libgo/go/net/dnsconfig_unix.go index 4b11602..37f3cce 100644 --- a/libgo/go/net/dnsconfig_unix.go +++ b/libgo/go/net/dnsconfig_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris // Read system DNS config from /etc/resolv.conf diff --git a/libgo/go/net/dnsconfig_unix_test.go b/libgo/go/net/dnsconfig_unix_test.go index 59e21d6..652a68f 100644 --- a/libgo/go/net/dnsconfig_unix_test.go +++ b/libgo/go/net/dnsconfig_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go index d851bf7..28b7c68 100644 --- a/libgo/go/net/dnsname_test.go +++ b/libgo/go/net/dnsname_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net diff --git a/libgo/go/net/error_plan9_test.go b/libgo/go/net/error_plan9_test.go index d7c7f14..1270af1 100644 --- a/libgo/go/net/error_plan9_test.go +++ b/libgo/go/net/error_plan9_test.go @@ -17,3 +17,7 @@ func isPlatformError(err error) bool { _, ok := err.(syscall.ErrorString) return ok } + +func isENOBUFS(err error) bool { + return false // ENOBUFS is Unix-specific +} diff --git a/libgo/go/net/error_posix.go b/libgo/go/net/error_posix.go index 017f2cb..94c73cc 100644 --- a/libgo/go/net/error_posix.go +++ b/libgo/go/net/error_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/error_posix_test.go b/libgo/go/net/error_posix_test.go index ea52a45..081176f 100644 --- a/libgo/go/net/error_posix_test.go +++ b/libgo/go/net/error_posix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !plan9 -// +build !plan9 package net diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go index c304390..4a191673 100644 --- a/libgo/go/net/error_test.go +++ b/libgo/go/net/error_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -554,10 +553,7 @@ third: } func TestCloseError(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) if err != nil { @@ -665,10 +661,7 @@ func TestAcceptError(t *testing.T) { c.Close() } } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") if err := ls.buildup(handler); err != nil { ls.teardown() t.Fatal(err) @@ -774,10 +767,7 @@ func TestFileError(t *testing.T) { t.Error("should fail") } - ln, err = newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln = newLocalListener(t, "tcp") for i := 0; i < 3; i++ { f, err := ln.(*TCPListener).File() diff --git a/libgo/go/net/error_unix.go b/libgo/go/net/error_unix.go index 3de4e76..775e4a0 100644 --- a/libgo/go/net/error_unix.go +++ b/libgo/go/net/error_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || js || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd js linux netbsd openbsd solaris package net diff --git a/libgo/go/net/error_unix_test.go b/libgo/go/net/error_unix_test.go index 533a45e..291a723 100644 --- a/libgo/go/net/error_unix_test.go +++ b/libgo/go/net/error_unix_test.go @@ -3,11 +3,11 @@ // license that can be found in the LICENSE file. //go:build !plan9 && !windows -// +build !plan9,!windows package net import ( + "errors" "os" "syscall" ) @@ -33,3 +33,7 @@ func samePlatformError(err, want error) bool { } return err == want } + +func isENOBUFS(err error) bool { + return errors.Is(err, syscall.ENOBUFS) +} diff --git a/libgo/go/net/error_windows_test.go b/libgo/go/net/error_windows_test.go index 834a9de..25825f9 100644 --- a/libgo/go/net/error_windows_test.go +++ b/libgo/go/net/error_windows_test.go @@ -4,7 +4,10 @@ package net -import "syscall" +import ( + "errors" + "syscall" +) var ( errTimedout = syscall.ETIMEDOUT @@ -17,3 +20,10 @@ func isPlatformError(err error) bool { _, ok := err.(syscall.Errno) return ok } + +func isENOBUFS(err error) bool { + // syscall.ENOBUFS is a completely made-up value on Windows: we don't expect + // a real system call to ever actually return it. However, since it is already + // defined in the syscall package we may as well check for it. + return errors.Is(err, syscall.ENOBUFS) +} diff --git a/libgo/go/net/example_test.go b/libgo/go/net/example_test.go index 72c7183..2c045d7 100644 --- a/libgo/go/net/example_test.go +++ b/libgo/go/net/example_test.go @@ -124,6 +124,176 @@ func ExampleIP_DefaultMask() { // ffffff00 } +func ExampleIP_Equal() { + ipv4DNS := net.ParseIP("8.8.8.8") + ipv4Lo := net.ParseIP("127.0.0.1") + ipv6DNS := net.ParseIP("0:0:0:0:0:FFFF:0808:0808") + + fmt.Println(ipv4DNS.Equal(ipv4DNS)) + fmt.Println(ipv4DNS.Equal(ipv4Lo)) + fmt.Println(ipv4DNS.Equal(ipv6DNS)) + + // Output: + // true + // false + // true +} + +func ExampleIP_IsGlobalUnicast() { + ipv6Global := net.ParseIP("2000::") + ipv6UniqLocal := net.ParseIP("2000::") + ipv6Multi := net.ParseIP("FF00::") + + ipv4Private := net.ParseIP("10.255.0.0") + ipv4Public := net.ParseIP("8.8.8.8") + ipv4Broadcast := net.ParseIP("255.255.255.255") + + fmt.Println(ipv6Global.IsGlobalUnicast()) + fmt.Println(ipv6UniqLocal.IsGlobalUnicast()) + fmt.Println(ipv6Multi.IsGlobalUnicast()) + + fmt.Println(ipv4Private.IsGlobalUnicast()) + fmt.Println(ipv4Public.IsGlobalUnicast()) + fmt.Println(ipv4Broadcast.IsGlobalUnicast()) + + // Output: + // true + // true + // false + // true + // true + // false +} + +func ExampleIP_IsInterfaceLocalMulticast() { + ipv6InterfaceLocalMulti := net.ParseIP("ff01::1") + ipv6Global := net.ParseIP("2000::") + ipv4 := net.ParseIP("255.0.0.0") + + fmt.Println(ipv6InterfaceLocalMulti.IsInterfaceLocalMulticast()) + fmt.Println(ipv6Global.IsInterfaceLocalMulticast()) + fmt.Println(ipv4.IsInterfaceLocalMulticast()) + + // Output: + // true + // false + // false +} + +func ExampleIP_IsLinkLocalMulticast() { + ipv6LinkLocalMulti := net.ParseIP("ff02::2") + ipv6LinkLocalUni := net.ParseIP("fe80::") + ipv4LinkLocalMulti := net.ParseIP("224.0.0.0") + ipv4LinkLocalUni := net.ParseIP("169.254.0.0") + + fmt.Println(ipv6LinkLocalMulti.IsLinkLocalMulticast()) + fmt.Println(ipv6LinkLocalUni.IsLinkLocalMulticast()) + fmt.Println(ipv4LinkLocalMulti.IsLinkLocalMulticast()) + fmt.Println(ipv4LinkLocalUni.IsLinkLocalMulticast()) + + // Output: + // true + // false + // true + // false +} + +func ExampleIP_IsLinkLocalUnicast() { + ipv6LinkLocalUni := net.ParseIP("fe80::") + ipv6Global := net.ParseIP("2000::") + ipv4LinkLocalUni := net.ParseIP("169.254.0.0") + ipv4LinkLocalMulti := net.ParseIP("224.0.0.0") + + fmt.Println(ipv6LinkLocalUni.IsLinkLocalUnicast()) + fmt.Println(ipv6Global.IsLinkLocalUnicast()) + fmt.Println(ipv4LinkLocalUni.IsLinkLocalUnicast()) + fmt.Println(ipv4LinkLocalMulti.IsLinkLocalUnicast()) + + // Output: + // true + // false + // true + // false +} + +func ExampleIP_IsLoopback() { + ipv6Lo := net.ParseIP("::1") + ipv6 := net.ParseIP("ff02::1") + ipv4Lo := net.ParseIP("127.0.0.0") + ipv4 := net.ParseIP("128.0.0.0") + + fmt.Println(ipv6Lo.IsLoopback()) + fmt.Println(ipv6.IsLoopback()) + fmt.Println(ipv4Lo.IsLoopback()) + fmt.Println(ipv4.IsLoopback()) + + // Output: + // true + // false + // true + // false +} + +func ExampleIP_IsMulticast() { + ipv6Multi := net.ParseIP("FF00::") + ipv6LinkLocalMulti := net.ParseIP("ff02::1") + ipv6Lo := net.ParseIP("::1") + ipv4Multi := net.ParseIP("239.0.0.0") + ipv4LinkLocalMulti := net.ParseIP("224.0.0.0") + ipv4Lo := net.ParseIP("127.0.0.0") + + fmt.Println(ipv6Multi.IsMulticast()) + fmt.Println(ipv6LinkLocalMulti.IsMulticast()) + fmt.Println(ipv6Lo.IsMulticast()) + fmt.Println(ipv4Multi.IsMulticast()) + fmt.Println(ipv4LinkLocalMulti.IsMulticast()) + fmt.Println(ipv4Lo.IsMulticast()) + + // Output: + // true + // true + // false + // true + // true + // false +} + +func ExampleIP_IsPrivate() { + ipv6Private := net.ParseIP("fc00::") + ipv6Public := net.ParseIP("fe00::") + ipv4Private := net.ParseIP("10.255.0.0") + ipv4Public := net.ParseIP("11.0.0.0") + + fmt.Println(ipv6Private.IsPrivate()) + fmt.Println(ipv6Public.IsPrivate()) + fmt.Println(ipv4Private.IsPrivate()) + fmt.Println(ipv4Public.IsPrivate()) + + // Output: + // true + // false + // true + // false +} + +func ExampleIP_IsUnspecified() { + ipv6Unspecified := net.ParseIP("::") + ipv6Specified := net.ParseIP("fe00::") + ipv4Unspecified := net.ParseIP("0.0.0.0") + ipv4Specified := net.ParseIP("8.8.8.8") + + fmt.Println(ipv6Unspecified.IsUnspecified()) + fmt.Println(ipv6Specified.IsUnspecified()) + fmt.Println(ipv4Unspecified.IsUnspecified()) + fmt.Println(ipv4Specified.IsUnspecified()) + + // Output: + // true + // false + // true + // false +} + func ExampleIP_Mask() { ipv4Addr := net.ParseIP("192.0.2.1") // This mask corresponds to a /24 subnet for IPv4. @@ -140,6 +310,42 @@ func ExampleIP_Mask() { // 2001:db8:: } +func ExampleIP_String() { + ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + ipv4 := net.IPv4(10, 255, 0, 0) + + fmt.Println(ipv6.String()) + fmt.Println(ipv4.String()) + + // Output: + // fc00:: + // 10.255.0.0 +} + +func ExampleIP_To16() { + ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + ipv4 := net.IPv4(10, 255, 0, 0) + + fmt.Println(ipv6.To16()) + fmt.Println(ipv4.To16()) + + // Output: + // fc00:: + // 10.255.0.0 +} + +func ExampleIP_to4() { + ipv6 := net.IP{0xfc, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + ipv4 := net.IPv4(10, 255, 0, 0) + + fmt.Println(ipv6.To4()) + fmt.Println(ipv4.To4()) + + // Output: + // <nil> + // 10.255.0.0 +} + func ExampleCIDRMask() { // This mask corresponds to a /31 subnet for IPv4. fmt.Println(net.CIDRMask(31, 32)) diff --git a/libgo/go/net/external_test.go b/libgo/go/net/external_test.go index b8753cc..3a97011 100644 --- a/libgo/go/net/external_test.go +++ b/libgo/go/net/external_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net diff --git a/libgo/go/net/fcntl_libc_test.go b/libgo/go/net/fcntl_libc_test.go index 02511c5..f59a1aa 100644 --- a/libgo/go/net/fcntl_libc_test.go +++ b/libgo/go/net/fcntl_libc_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || solaris -// +build aix darwin solaris package net diff --git a/libgo/go/net/fcntl_syscall_test.go b/libgo/go/net/fcntl_syscall_test.go index 59ba1a1..58cacc4 100644 --- a/libgo/go/net/fcntl_syscall_test.go +++ b/libgo/go/net/fcntl_syscall_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build dragonfly || freebsd || linux || netbsd || openbsd -// +build dragonfly freebsd linux netbsd openbsd package net diff --git a/libgo/go/net/fd_posix.go b/libgo/go/net/fd_posix.go index a0f1f5a..466ccce 100644 --- a/libgo/go/net/fd_posix.go +++ b/libgo/go/net/fd_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows package net @@ -63,6 +62,17 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { runtime.KeepAlive(fd) return n, sa, wrapSyscallError(readFromSyscallName, err) } +func (fd *netFD) readFromInet4(p []byte, from *syscall.SockaddrInet4) (n int, err error) { + n, err = fd.pfd.ReadFromInet4(p, from) + runtime.KeepAlive(fd) + return n, wrapSyscallError(readFromSyscallName, err) +} + +func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, err error) { + n, err = fd.pfd.ReadFromInet6(p, from) + runtime.KeepAlive(fd) + return n, wrapSyscallError(readFromSyscallName, err) +} func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags) @@ -70,6 +80,18 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err) } +func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { + n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa) + runtime.KeepAlive(fd) + return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) +} + +func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { + n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa) + runtime.KeepAlive(fd) + return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) +} + func (fd *netFD) Write(p []byte) (nn int, err error) { nn, err = fd.pfd.Write(p) runtime.KeepAlive(fd) @@ -82,12 +104,36 @@ func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { return n, wrapSyscallError(writeToSyscallName, err) } +func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { + n, err = fd.pfd.WriteToInet4(p, sa) + runtime.KeepAlive(fd) + return n, wrapSyscallError(writeToSyscallName, err) +} + +func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { + n, err = fd.pfd.WriteToInet6(p, sa) + runtime.KeepAlive(fd) + return n, wrapSyscallError(writeToSyscallName, err) +} + func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) } +func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { + n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa) + runtime.KeepAlive(fd) + return n, oobn, wrapSyscallError(writeMsgSyscallName, err) +} + +func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { + n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa) + runtime.KeepAlive(fd) + return n, oobn, wrapSyscallError(writeMsgSyscallName, err) +} + func (fd *netFD) SetDeadline(t time.Time) error { return fd.pfd.SetDeadline(t) } diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index e2db165..394e1c7 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net @@ -92,12 +91,12 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc } // Start the "interrupter" goroutine, if this context might be canceled. - // (The background context cannot) // // The interrupter goroutine waits for the context to be done and // interrupts the dial (by altering the fd's write deadline, which // wakes up waitWrite). - if ctx != context.Background() { + ctxDone := ctx.Done() + if ctxDone != nil { // Wait for the interrupter goroutine to exit before returning // from connect. done := make(chan struct{}) @@ -117,7 +116,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc }() go func() { select { - case <-ctx.Done(): + case <-ctxDone: // Force the runtime's poller to immediately give up // waiting for writability, unblocking waitWrite // below. @@ -141,7 +140,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc // details. if err := fd.pfd.WaitWrite(); err != nil { select { - case <-ctx.Done(): + case <-ctxDone: return nil, mapErr(ctx.Err()) default: } diff --git a/libgo/go/net/file_stub.go b/libgo/go/net/file_stub.go index 9f988fe..91df926 100644 --- a/libgo/go/net/file_stub.go +++ b/libgo/go/net/file_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package net diff --git a/libgo/go/net/file_test.go b/libgo/go/net/file_test.go index a70ef1b..ea2a218 100644 --- a/libgo/go/net/file_test.go +++ b/libgo/go/net/file_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -45,10 +44,7 @@ func TestFileConn(t *testing.T) { var network, address string switch tt.network { case "udp": - c, err := newLocalPacketListener(tt.network) - if err != nil { - t.Fatal(err) - } + c := newLocalPacketListener(t, tt.network) defer c.Close() network = c.LocalAddr().Network() address = c.LocalAddr().String() @@ -62,10 +58,7 @@ func TestFileConn(t *testing.T) { var b [1]byte c.Read(b[:]) } - ls, err := newLocalServer(tt.network) - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, tt.network) defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -149,17 +142,17 @@ func TestFileListener(t *testing.T) { continue } - ln1, err := newLocalListener(tt.network) - if err != nil { - t.Fatal(err) - } + ln1 := newLocalListener(t, tt.network) switch tt.network { case "unix", "unixpacket": defer os.Remove(ln1.Addr().String()) } addr := ln1.Addr() - var f *os.File + var ( + f *os.File + err error + ) switch ln1 := ln1.(type) { case *TCPListener: f, err = ln1.File() @@ -241,17 +234,17 @@ func TestFilePacketConn(t *testing.T) { continue } - c1, err := newLocalPacketListener(tt.network) - if err != nil { - t.Fatal(err) - } + c1 := newLocalPacketListener(t, tt.network) switch tt.network { case "unixgram": defer os.Remove(c1.LocalAddr().String()) } addr := c1.LocalAddr() - var f *os.File + var ( + f *os.File + err error + ) switch c1 := c1.(type) { case *UDPConn: f, err = c1.File() @@ -315,10 +308,7 @@ func TestFileCloseRace(t *testing.T) { c.Read(b[:]) } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) diff --git a/libgo/go/net/file_unix.go b/libgo/go/net/file_unix.go index d36a881..afb1d98 100644 --- a/libgo/go/net/file_unix.go +++ b/libgo/go/net/file_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/hook_unix.go b/libgo/go/net/hook_unix.go index 618c6c2..5629476 100644 --- a/libgo/go/net/hook_unix.go +++ b/libgo/go/net/hook_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris package net diff --git a/libgo/go/net/hosts.go b/libgo/go/net/hosts.go index 5c560f3..e604031 100644 --- a/libgo/go/net/hosts.go +++ b/libgo/go/net/hosts.go @@ -82,10 +82,10 @@ func readHosts() { continue } for i := 1; i < len(f); i++ { - name := absDomainName([]byte(f[i])) + name := absDomainName(f[i]) h := []byte(f[i]) lowerASCIIBytes(h) - key := absDomainName(h) + key := absDomainName(string(h)) hs[key] = append(hs[key], addr) is[addr] = append(is[addr], name) } @@ -106,11 +106,12 @@ func lookupStaticHost(host string) []string { defer hosts.Unlock() readHosts() if len(hosts.byName) != 0 { - // TODO(jbd,bradfitz): avoid this alloc if host is already all lowercase? - // or linear scan the byName map if it's small enough? - lowerHost := []byte(host) - lowerASCIIBytes(lowerHost) - if ips, ok := hosts.byName[absDomainName(lowerHost)]; ok { + if hasUpperCase(host) { + lowerHost := []byte(host) + lowerASCIIBytes(lowerHost) + host = string(lowerHost) + } + if ips, ok := hosts.byName[absDomainName(host)]; ok { ipsCp := make([]string, len(ips)) copy(ipsCp, ips) return ipsCp diff --git a/libgo/go/net/hosts_test.go b/libgo/go/net/hosts_test.go index 19c4399..7291914 100644 --- a/libgo/go/net/hosts_test.go +++ b/libgo/go/net/hosts_test.go @@ -70,7 +70,7 @@ func TestLookupStaticHost(t *testing.T) { } func testStaticHost(t *testing.T, hostsPath string, ent staticHostEntry) { - ins := []string{ent.in, absDomainName([]byte(ent.in)), strings.ToLower(ent.in), strings.ToUpper(ent.in)} + ins := []string{ent.in, absDomainName(ent.in), strings.ToLower(ent.in), strings.ToUpper(ent.in)} for _, in := range ins { addrs := lookupStaticHost(in) if !reflect.DeepEqual(addrs, ent.out) { @@ -141,7 +141,7 @@ func TestLookupStaticAddr(t *testing.T) { func testStaticAddr(t *testing.T, hostsPath string, ent staticHostEntry) { hosts := lookupStaticAddr(ent.in) for i := range ent.out { - ent.out[i] = absDomainName([]byte(ent.out[i])) + ent.out[i] = absDomainName(ent.out[i]) } if !reflect.DeepEqual(hosts, ent.out) { t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", hostsPath, ent.in, hosts, ent.out) diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go index 0114da3..bdb35a6 100644 --- a/libgo/go/net/http/cgi/child.go +++ b/libgo/go/net/http/cgi/child.go @@ -39,8 +39,8 @@ func Request() (*http.Request, error) { func envMap(env []string) map[string]string { m := make(map[string]string) for _, kv := range env { - if idx := strings.Index(kv, "="); idx != -1 { - m[kv[:idx]] = kv[idx+1:] + if k, v, ok := strings.Cut(kv, "="); ok { + m[k] = v } } return m diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go index eff67ca..95b2e13 100644 --- a/libgo/go/net/http/cgi/host.go +++ b/libgo/go/net/http/cgi/host.go @@ -273,12 +273,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { break } headerLines++ - parts := strings.SplitN(string(line), ":", 2) - if len(parts) < 2 { + header, val, ok := strings.Cut(string(line), ":") + if !ok { h.printf("cgi: bogus header line: %s", string(line)) continue } - header, val := parts[0], parts[1] if !httpguts.ValidHeaderFieldName(header) { h.printf("cgi: invalid header name: %q", header) continue @@ -351,7 +350,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } -func (h *Handler) printf(format string, v ...interface{}) { +func (h *Handler) printf(format string, v ...any) { if h.Logger != nil { h.Logger.Printf(format, v...) } else { diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index 9f1716b..1b72f7e 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -62,12 +62,12 @@ readlines: } linesRead++ trimmedLine := strings.TrimRight(line, "\r\n") - split := strings.SplitN(trimmedLine, "=", 2) - if len(split) != 2 { - t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", - len(split), linesRead, line, m) + k, v, ok := strings.Cut(trimmedLine, "=") + if !ok { + t.Fatalf("Unexpected response from invalid line number %v: %q; existing map=%v", + linesRead, line, m) } - m[split[0]] = split[1] + m[k] = v } for key, expected := range expectedMap { diff --git a/libgo/go/net/http/cgi/posix_test.go b/libgo/go/net/http/cgi/posix_test.go index bc58ea9..49b9470 100644 --- a/libgo/go/net/http/cgi/posix_test.go +++ b/libgo/go/net/http/cgi/posix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !plan9 -// +build !plan9 package cgi diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 4d380c6..22db96b 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -965,7 +965,6 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) { if err == nil { return n, nil } - b.stop() if err == io.EOF { return n, err } diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 01d605c..e91d526 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "errors" "fmt" + "internal/testenv" "io" "log" "net" @@ -21,6 +22,7 @@ import ( "net/http/httptest" "net/url" "reflect" + "runtime" "strconv" "strings" "sync" @@ -431,11 +433,10 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa if v := urlQuery.Get("code"); v != "" { location := ts.URL if final := urlQuery.Get("next"); final != "" { - splits := strings.Split(final, ",") - first, rest := splits[0], splits[1:] + first, rest, _ := strings.Cut(final, ",") location = fmt.Sprintf("%s?code=%s", location, first) - if len(rest) > 0 { - location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ",")) + if rest != "" { + location = fmt.Sprintf("%s&next=%s", location, rest) } } code, _ := strconv.Atoi(v) @@ -746,7 +747,7 @@ func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { return nil } -func (j *RecordingJar) logf(format string, args ...interface{}) { +func (j *RecordingJar) logf(format string, args ...any) { j.mu.Lock() defer j.mu.Unlock() fmt.Fprintf(&j.log, format, args...) @@ -1206,64 +1207,80 @@ func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } func testClientTimeout(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) - testDone := make(chan struct{}) // closed in defer below - sawRoot := make(chan bool, 1) - sawSlow := make(chan bool, 1) + var ( + mu sync.Mutex + nonce string // a unique per-request string + sawSlowNonce bool // true if the handler saw /slow?nonce=<nonce> + ) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + _ = r.ParseForm() if r.URL.Path == "/" { - sawRoot <- true - Redirect(w, r, "/slow", StatusFound) + Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound) return } if r.URL.Path == "/slow" { - sawSlow <- true + mu.Lock() + if r.Form.Get("nonce") == nonce { + sawSlowNonce = true + } else { + t.Logf("mismatched nonce: received %s, want %s", r.Form.Get("nonce"), nonce) + } + mu.Unlock() + w.Write([]byte("Hello")) w.(Flusher).Flush() - <-testDone + <-r.Context().Done() return } })) defer cst.close() - defer close(testDone) // before cst.close, to unblock /slow handler - // 200ms should be long enough to get a normal request (the / - // handler), but not so long that it makes the test slow. - const timeout = 200 * time.Millisecond - cst.c.Timeout = timeout - - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - if strings.Contains(err.Error(), "Client.Timeout") { - t.Skipf("host too slow to get fast resource in %v", timeout) + // Try to trigger a timeout after reading part of the response body. + // The initial timeout is emprically usually long enough on a decently fast + // machine, but if we undershoot we'll retry with exponentially longer + // timeouts until the test either passes or times out completely. + // This keeps the test reasonably fast in the typical case but allows it to + // also eventually succeed on arbitrarily slow machines. + timeout := 10 * time.Millisecond + nextNonce := 0 + for ; ; timeout *= 2 { + if timeout <= 0 { + // The only way we can feasibly hit this while the test is running is if + // the request fails without actually waiting for the timeout to occur. + t.Fatalf("timeout overflow") + } + if deadline, ok := t.Deadline(); ok && !time.Now().Add(timeout).Before(deadline) { + t.Fatalf("failed to produce expected timeout before test deadline") + } + t.Logf("attempting test with timeout %v", timeout) + cst.c.Timeout = timeout + + mu.Lock() + nonce = fmt.Sprint(nextNonce) + nextNonce++ + sawSlowNonce = false + mu.Unlock() + res, err := cst.c.Get(cst.ts.URL + "/?nonce=" + nonce) + if err != nil { + if strings.Contains(err.Error(), "Client.Timeout") { + // Timed out before handler could respond. + t.Logf("timeout before response received") + continue + } + t.Fatal(err) } - t.Fatal(err) - } - - select { - case <-sawRoot: - // good. - default: - t.Fatal("handler never got / request") - } - select { - case <-sawSlow: - // good. - default: - t.Fatal("handler never got /slow request") - } + mu.Lock() + ok := sawSlowNonce + mu.Unlock() + if !ok { + t.Fatal("handler never got /slow request, but client returned response") + } - errc := make(chan error, 1) - go func() { - _, err := io.ReadAll(res.Body) - errc <- err + _, err = io.ReadAll(res.Body) res.Body.Close() - }() - const failTime = 5 * time.Second - select { - case err := <-errc: if err == nil { t.Fatal("expected error from ReadAll") } @@ -1274,10 +1291,13 @@ func testClientTimeout(t *testing.T, h2 bool) { t.Errorf("net.Error.Timeout = false; want true") } if got := ne.Error(); !strings.Contains(got, "(Client.Timeout") { + if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") { + testenv.SkipFlaky(t, 43120) + } t.Errorf("error string = %q; missing timeout substring", got) } - case <-time.After(failTime): - t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + + break } } @@ -1319,6 +1339,9 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { t.Error("net.Error.Timeout = false; want true") } if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") { + if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") { + testenv.SkipFlaky(t, 43120) + } t.Errorf("error string = %q; missing timeout substring", got) } } @@ -1353,6 +1376,33 @@ func TestClientTimeoutCancel(t *testing.T) { } } +func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) } +func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) } + +// Issue 49366: if Client.Timeout is set but not hit, no error should be returned. +func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("body")) + })) + defer cst.close() + + cst.c.Timeout = 1 * time.Hour + req, _ := NewRequest("GET", cst.ts.URL, nil) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(io.Discard, res.Body); err != nil { + t.Fatalf("io.Copy(io.Discard, res.Body) = %v, want nil", err) + } + if err = res.Body.Close(); err != nil { + t.Fatalf("res.Body.Close() = %v, want nil", err) + } +} + func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } func testClientRedirectEatsBody(t *testing.T, h2 bool) { @@ -2082,3 +2132,47 @@ func (b *issue40382Body) Close() error { } return nil } + +func TestProbeZeroLengthBody(t *testing.T) { + setParallel(t) + defer afterTest(t) + reqc := make(chan struct{}) + cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + close(reqc) + if _, err := io.Copy(w, r.Body); err != nil { + t.Errorf("error copying request body: %v", err) + } + })) + defer cst.close() + + bodyr, bodyw := io.Pipe() + var gotBody string + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + req, _ := NewRequest("GET", cst.ts.URL, bodyr) + res, err := cst.c.Do(req) + b, err := io.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + gotBody = string(b) + }() + + select { + case <-reqc: + // Request should be sent after trying to probe the request body for 200ms. + case <-time.After(60 * time.Second): + t.Errorf("request not sent after 60s") + } + + // Write the request body and wait for the request to complete. + const content = "body" + bodyw.Write([]byte(content)) + bodyw.Close() + wg.Wait() + if gotBody != content { + t.Fatalf("server got body %q, want %q", gotBody, content) + } +} diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 42207ac..44d70f0 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -81,7 +81,7 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { +func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { if h2 { CondSkipHTTP2(t) } @@ -189,7 +189,7 @@ type h12Compare struct { ReqFunc reqFunc // optional CheckResponse func(proto string, res *Response) // optional EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize - Opts []interface{} + Opts []any } func (tt h12Compare) reqFunc() reqFunc { @@ -441,7 +441,7 @@ func TestH12_AutoGzip(t *testing.T) { func TestH12_AutoGzip_Disabled(t *testing.T) { h12Compare{ - Opts: []interface{}{ + Opts: []any{ func(tr *Transport) { tr.DisableCompression = true }, }, Handler: func(w ResponseWriter, r *Request) { @@ -1172,7 +1172,7 @@ func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, ErrAbortHandler) } -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { +func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { setParallel(t) const msg = "hello" defer afterTest(t) @@ -1522,7 +1522,7 @@ func TestBidiStreamReverseProxy(t *testing.T) { })) defer proxy.close() - bodyRes := make(chan interface{}, 1) // error or hash.Hash + bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() req, _ := NewRequest("PUT", proxy.ts.URL, pr) const size = 4 << 20 @@ -1586,3 +1586,37 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }, }.run(t) } + +func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } +func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } + +func testIdentityTransferEncoding(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + + const body = "body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + gotBody, _ := io.ReadAll(r.Body) + if got, want := string(gotBody), body; got != want { + t.Errorf("got request body = %q; want %q", got, want) + } + w.Header().Set("Transfer-Encoding", "identity") + w.WriteHeader(StatusOK) + w.(Flusher).Flush() + io.WriteString(w, body) + })) + defer cst.close() + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + gotBody, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(gotBody), body; got != want { + t.Errorf("got response body = %q; want %q", got, want) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index ca2c1c2..cb37f23 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -5,6 +5,8 @@ package http import ( + "errors" + "fmt" "log" "net" "net/http/internal/ascii" @@ -67,15 +69,14 @@ func readSetCookies(h Header) []*Cookie { continue } parts[0] = textproto.TrimString(parts[0]) - j := strings.Index(parts[0], "=") - if j < 0 { + name, value, ok := strings.Cut(parts[0], "=") + if !ok { continue } - name, value := parts[0][:j], parts[0][j+1:] if !isCookieNameValid(name) { continue } - value, ok := parseCookieValue(value, true) + value, ok = parseCookieValue(value, true) if !ok { continue } @@ -90,10 +91,7 @@ func readSetCookies(h Header) []*Cookie { continue } - attr, val := parts[i], "" - if j := strings.Index(attr, "="); j >= 0 { - attr, val = attr[:j], attr[j+1:] - } + attr, val, _ := strings.Cut(parts[i], "=") lowerAttr, isASCII := ascii.ToLower(attr) if !isASCII { continue @@ -240,6 +238,37 @@ func (c *Cookie) String() string { return b.String() } +// Valid reports whether the cookie is valid. +func (c *Cookie) Valid() error { + if c == nil { + return errors.New("http: nil Cookie") + } + if !isCookieNameValid(c.Name) { + return errors.New("http: invalid Cookie.Name") + } + if !validCookieExpires(c.Expires) { + return errors.New("http: invalid Cookie.Expires") + } + for i := 0; i < len(c.Value); i++ { + if !validCookieValueByte(c.Value[i]) { + return fmt.Errorf("http: invalid byte %q in Cookie.Value", c.Value[i]) + } + } + if len(c.Path) > 0 { + for i := 0; i < len(c.Path); i++ { + if !validCookiePathByte(c.Path[i]) { + return fmt.Errorf("http: invalid byte %q in Cookie.Path", c.Path[i]) + } + } + } + if len(c.Domain) > 0 { + if !validCookieDomain(c.Domain) { + return errors.New("http: invalid Cookie.Domain") + } + } + return nil +} + // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // @@ -256,19 +285,12 @@ func readCookies(h Header, filter string) []*Cookie { var part string for len(line) > 0 { // continue since we have rest - if splitIndex := strings.Index(line, ";"); splitIndex > 0 { - part, line = line[:splitIndex], line[splitIndex+1:] - } else { - part, line = line, "" - } + part, line, _ = strings.Cut(line, ";") part = textproto.TrimString(part) - if len(part) == 0 { + if part == "" { continue } - name, val := part, "" - if j := strings.Index(part, "="); j >= 0 { - name, val = name[:j], name[j+1:] - } + name, val, _ := strings.Cut(part, "=") if !isCookieNameValid(name) { continue } @@ -379,7 +401,7 @@ func sanitizeCookieValue(v string) string { if len(v) == 0 { return v } - if strings.IndexByte(v, ' ') >= 0 || strings.IndexByte(v, ',') >= 0 { + if strings.ContainsAny(v, " ,") { return `"` + v + `"` } return v diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index 959713a..ccc5f98 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -360,7 +360,7 @@ var readSetCookiesTests = []struct { // Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}}, } -func toJSON(v interface{}) string { +func toJSON(v any) string { b, err := json.Marshal(v) if err != nil { return fmt.Sprintf("%#v", v) @@ -529,6 +529,31 @@ func TestCookieSanitizePath(t *testing.T) { } } +func TestCookieValid(t *testing.T) { + tests := []struct { + cookie *Cookie + valid bool + }{ + {nil, false}, + {&Cookie{Name: ""}, false}, + {&Cookie{Name: "invalid-expires"}, false}, + {&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false}, + {&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false}, + {&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false}, + {&Cookie{Name: "valid", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true}, + } + + for _, tt := range tests { + err := tt.cookie.Valid() + if err != nil && tt.valid { + t.Errorf("%#v.Valid() returned error %v; want nil", tt.cookie, err) + } + if err == nil && !tt.valid { + t.Errorf("%#v.Valid() returned nil; want error", tt.cookie) + } + } +} + func BenchmarkCookieString(b *testing.B) { const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600` c := &Cookie{ diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 096a6d3..a849327 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) { func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } -func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-ch - cancel() - }() +func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler { return &timeoutHandler{ handler: handler, testContext: ctx, diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index 57e731e..6caee9e 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -42,20 +42,20 @@ import ( // An empty Dir is treated as ".". type Dir string -// mapDirOpenError maps the provided non-nil error from opening name +// mapOpenError maps the provided non-nil error from opening name // to a possibly better non-nil error. In particular, it turns OS-specific errors -// about opening files in non-directories into fs.ErrNotExist. See Issue 18984. -func mapDirOpenError(originalErr error, name string) error { +// about opening files in non-directories into fs.ErrNotExist. See Issues 18984 and 49552. +func mapOpenError(originalErr error, name string, sep rune, stat func(string) (fs.FileInfo, error)) error { if errors.Is(originalErr, fs.ErrNotExist) || errors.Is(originalErr, fs.ErrPermission) { return originalErr } - parts := strings.Split(name, string(filepath.Separator)) + parts := strings.Split(name, string(sep)) for i := range parts { if parts[i] == "" { continue } - fi, err := os.Stat(strings.Join(parts[:i+1], string(filepath.Separator))) + fi, err := stat(strings.Join(parts[:i+1], string(sep))) if err != nil { return originalErr } @@ -79,7 +79,7 @@ func (d Dir) Open(name string) (File, error) { fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))) f, err := os.Open(fullName) if err != nil { - return nil, mapDirOpenError(err, fullName) + return nil, mapOpenError(err, fullName, filepath.Separator, os.Stat) } return f, nil } @@ -759,7 +759,9 @@ func (f ioFS) Open(name string) (File, error) { } file, err := f.fsys.Open(name) if err != nil { - return nil, err + return nil, mapOpenError(err, name, '/', func(path string) (fs.FileInfo, error) { + return fs.Stat(f.fsys, path) + }) } return ioFile{file}, nil } @@ -881,11 +883,11 @@ func parseRange(s string, size int64) ([]httpRange, error) { if ra == "" { continue } - i := strings.Index(ra, "-") - if i < 0 { + start, end, ok := strings.Cut(ra, "-") + if !ok { return nil, errors.New("invalid range") } - start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:]) + start, end = textproto.TrimString(start), textproto.TrimString(end) var r httpRange if start == "" { // If no start is specified, end specifies the diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index b42ade1..d627dfd 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -658,7 +658,7 @@ type fakeFileInfo struct { } func (f *fakeFileInfo) Name() string { return f.basename } -func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) Sys() any { return nil } func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } func (f *fakeFileInfo) IsDir() bool { return f.dir } func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } @@ -1244,10 +1244,19 @@ func TestLinuxSendfileChild(*testing.T) { } } -// Issue 18984: tests that requests for paths beyond files return not-found errors +// Issues 18984, 49552: tests that requests for paths beyond files return not-found errors func TestFileServerNotDirError(t *testing.T) { defer afterTest(t) - ts := httptest.NewServer(FileServer(Dir("testdata"))) + t.Run("Dir", func(t *testing.T) { + testFileServerNotDirError(t, func(path string) FileSystem { return Dir(path) }) + }) + t.Run("FS", func(t *testing.T) { + testFileServerNotDirError(t, func(path string) FileSystem { return FS(os.DirFS(path)) }) + }) +} + +func testFileServerNotDirError(t *testing.T, newfs func(string) FileSystem) { + ts := httptest.NewServer(FileServer(newfs("testdata"))) defer ts.Close() res, err := Get(ts.URL + "/index.html/not-a-file") @@ -1259,9 +1268,9 @@ func TestFileServerNotDirError(t *testing.T) { t.Errorf("StatusCode = %v; want 404", res.StatusCode) } - test := func(name string, dir Dir) { + test := func(name string, fsys FileSystem) { t.Run(name, func(t *testing.T) { - _, err = dir.Open("/index.html/not-a-file") + _, err = fsys.Open("/index.html/not-a-file") if err == nil { t.Fatal("err == nil; want != nil") } @@ -1270,7 +1279,7 @@ func TestFileServerNotDirError(t *testing.T) { errors.Is(err, fs.ErrNotExist)) } - _, err = dir.Open("/index.html/not-a-dir/not-a-file") + _, err = fsys.Open("/index.html/not-a-dir/not-a-file") if err == nil { t.Fatal("err == nil; want != nil") } @@ -1286,8 +1295,8 @@ func TestFileServerNotDirError(t *testing.T) { t.Fatal("get abs path:", err) } - test("RelativePath", Dir("testdata")) - test("AbsolutePath", Dir(absPath)) + test("RelativePath", newfs("testdata")) + test("AbsolutePath", newfs(absPath)) } func TestFileServerCleanPath(t *testing.T) { diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 8958a9e..bb82f24 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -53,6 +53,10 @@ import ( "golang.org/x/net/idna" ) +// The HTTP protocols are defined in terms of ASCII, not Unicode. This file +// contains helper functions which may use Unicode-aware functions which would +// otherwise be unsafe and could introduce vulnerabilities if used improperly. + // asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t // are equal, ASCII-case-insensitively. func http2asciiEqualFold(s, t string) bool { @@ -733,6 +737,12 @@ func http2isBadCipher(cipher uint16) bool { // ClientConnPool manages a pool of HTTP/2 client connections. type http2ClientConnPool interface { + // GetClientConn returns a specific HTTP/2 connection (usually + // a TLS-TCP connection) to an HTTP/2 server. On success, the + // returned ClientConn accounts for the upcoming RoundTrip + // call, so the caller should not omit it. If the caller needs + // to, ClientConn.RoundTrip can be called with a bogus + // new(http.Request) to release the stream reservation. GetClientConn(req *Request, addr string) (*http2ClientConn, error) MarkDead(*http2ClientConn) } @@ -759,7 +769,7 @@ type http2clientConnPool struct { conns map[string][]*http2ClientConn // key is host:port dialing map[string]*http2dialCall // currently in-flight dials keys map[*http2ClientConn][]string - addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls } func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { @@ -771,28 +781,8 @@ const ( http2noDialOnMiss = false ) -// shouldTraceGetConn reports whether getClientConn should call any -// ClientTrace.GetConn hook associated with the http.Request. -// -// This complexity is needed to avoid double calls of the GetConn hook -// during the back-and-forth between net/http and x/net/http2 (when the -// net/http.Transport is upgraded to also speak http2), as well as support -// the case where x/net/http2 is being used directly. -func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool { - // If our Transport wasn't made via ConfigureTransport, always - // trace the GetConn hook if provided, because that means the - // http2 package is being used directly and it's the one - // dialing, as opposed to net/http. - if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok { - return true - } - // Otherwise, only use the GetConn hook if this connection has - // been used previously for other requests. For fresh - // connections, the net/http package does the dialing. - return !st.freshConn -} - func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? if http2isConnectionCloseRequest(req) && dialOnMiss { // It gets its own connection. http2traceGetConn(req, addr) @@ -806,10 +796,14 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis for { p.mu.Lock() for _, cc := range p.conns[addr] { - if st := cc.idleState(); st.canTakeNewRequest { - if p.shouldTraceGetConn(st) { + if cc.ReserveNewRequest() { + // When a connection is presented to us by the net/http package, + // the GetConn hook has already been called. + // Don't call it a second time here. + if !cc.getConnCalled { http2traceGetConn(req, addr) } + cc.getConnCalled = false p.mu.Unlock() return cc, nil } @@ -825,7 +819,13 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis if http2shouldRetryDial(call, req) { continue } - return call.res, call.err + cc, err := call.res, call.err + if err != nil { + return nil, err + } + if cc.ReserveNewRequest() { + return cc, nil + } } } @@ -922,6 +922,7 @@ func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { if err != nil { c.err = err } else { + cc.getConnCalled = true // already called by the net/http package p.addConnLocked(key, cc) } delete(p.addConnCalls, key) @@ -1208,6 +1209,13 @@ func (e http2ErrCode) String() string { return fmt.Sprintf("unknown error code 0x%x", uint32(e)) } +func (e http2ErrCode) stringToken() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) +} + // ConnectionError is an error that results in the termination of the // entire connection. type http2ConnectionError http2ErrCode @@ -1224,6 +1232,11 @@ type http2StreamError struct { Cause error // optional additional detail } +// errFromPeer is a sentinel error value for StreamError.Cause to +// indicate that the StreamError was sent from the peer over the wire +// and wasn't locally generated in the Transport. +var http2errFromPeer = errors.New("received from peer") + func http2streamError(id uint32, code http2ErrCode) http2StreamError { return http2StreamError{StreamID: id, Code: code} } @@ -1438,7 +1451,7 @@ var http2flagName = map[http2FrameType]map[http2Flags]string{ // a frameParser parses a frame given its FrameHeader and payload // bytes. The length of payload will always equal fh.Length (which // might be 0). -type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) +type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) var http2frameParsers = map[http2FrameType]http2frameParser{ http2FrameData: http2parseDataFrame, @@ -1583,6 +1596,11 @@ type http2Framer struct { lastFrame http2Frame errDetail error + // countError is a non-nil func that's called on a frame parse + // error with some unique error path token. It's initialized + // from Transport.CountError or Server.CountError. + countError func(errToken string) + // lastHeaderStream is non-zero if the last frame was an // unfinished HEADERS/CONTINUATION. lastHeaderStream uint32 @@ -1745,6 +1763,7 @@ func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ w: w, r: r, + countError: func(string) {}, logReads: http2logFrameReads, logWrites: http2logFrameWrites, debugReadLoggerf: log.Printf, @@ -1819,7 +1838,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { if _, err := io.ReadFull(fr.r, payload); err != nil { return nil, err } - f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, payload) + f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) if err != nil { if ce, ok := err.(http2connError); ok { return nil, fr.connError(ce.Code, ce.Reason) @@ -1907,13 +1926,14 @@ func (f *http2DataFrame) Data() []byte { return f.data } -func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { if fh.StreamID == 0 { // DATA frames MUST be associated with a stream. If a // DATA frame is received whose stream identifier // field is 0x0, the recipient MUST respond with a // connection error (Section 5.4.1) of type // PROTOCOL_ERROR. + countError("frame_data_stream_0") return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} } f := fc.getDataFrame() @@ -1924,6 +1944,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byt var err error payload, padSize, err = http2readByte(payload) if err != nil { + countError("frame_data_pad_byte_short") return nil, err } } @@ -1932,6 +1953,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byt // length of the frame payload, the recipient MUST // treat this as a connection error. // Filed: https://github.com/http2/http2-spec/issues/610 + countError("frame_data_pad_too_big") return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} } f.data = payload[:len(payload)-int(padSize)] @@ -2014,7 +2036,7 @@ type http2SettingsFrame struct { p []byte } -func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { // When this (ACK 0x1) bit is set, the payload of the // SETTINGS frame MUST be empty. Receipt of a @@ -2022,6 +2044,7 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) // field value other than 0 MUST be treated as a // connection error (Section 5.4.1) of type // FRAME_SIZE_ERROR. + countError("frame_settings_ack_with_length") return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID != 0 { @@ -2032,14 +2055,17 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) // field is anything other than 0x0, the endpoint MUST // respond with a connection error (Section 5.4.1) of // type PROTOCOL_ERROR. + countError("frame_settings_has_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } if len(p)%6 != 0 { + countError("frame_settings_mod_6") // Expecting even number of 6 byte settings. return nil, http2ConnectionError(http2ErrCodeFrameSize) } f := &http2SettingsFrame{http2FrameHeader: fh, p: p} if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + countError("frame_settings_window_size_too_big") // Values above the maximum flow control window size of 2^31 - 1 MUST // be treated as a connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR. @@ -2151,11 +2177,13 @@ type http2PingFrame struct { func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } -func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { if len(payload) != 8 { + countError("frame_ping_length") return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID != 0 { + countError("frame_ping_has_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } f := &http2PingFrame{http2FrameHeader: fh} @@ -2191,11 +2219,13 @@ func (f *http2GoAwayFrame) DebugData() []byte { return f.debugData } -func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if fh.StreamID != 0 { + countError("frame_goaway_has_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } if len(p) < 8 { + countError("frame_goaway_short") return nil, http2ConnectionError(http2ErrCodeFrameSize) } return &http2GoAwayFrame{ @@ -2231,7 +2261,7 @@ func (f *http2UnknownFrame) Payload() []byte { return f.p } -func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { return &http2UnknownFrame{fh, p}, nil } @@ -2242,8 +2272,9 @@ type http2WindowUpdateFrame struct { Increment uint32 // never read with high bit set } -func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if len(p) != 4 { + countError("frame_windowupdate_bad_len") return nil, http2ConnectionError(http2ErrCodeFrameSize) } inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit @@ -2255,8 +2286,10 @@ func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []by // control window MUST be treated as a connection // error (Section 5.4.1). if fh.StreamID == 0 { + countError("frame_windowupdate_zero_inc_conn") return nil, http2ConnectionError(http2ErrCodeProtocol) } + countError("frame_windowupdate_zero_inc_stream") return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) } return &http2WindowUpdateFrame{ @@ -2307,7 +2340,7 @@ func (f *http2HeadersFrame) HasPriority() bool { return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) } -func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { +func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { hf := &http2HeadersFrame{ http2FrameHeader: fh, } @@ -2316,11 +2349,13 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) ( // is received whose stream identifier field is 0x0, the recipient MUST // respond with a connection error (Section 5.4.1) of type // PROTOCOL_ERROR. + countError("frame_headers_zero_stream") return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} } var padLength uint8 if fh.Flags.Has(http2FlagHeadersPadded) { if p, padLength, err = http2readByte(p); err != nil { + countError("frame_headers_pad_short") return } } @@ -2328,16 +2363,19 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) ( var v uint32 p, v, err = http2readUint32(p) if err != nil { + countError("frame_headers_prio_short") return nil, err } hf.Priority.StreamDep = v & 0x7fffffff hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set p, hf.Priority.Weight, err = http2readByte(p) if err != nil { + countError("frame_headers_prio_weight_short") return nil, err } } - if len(p)-int(padLength) <= 0 { + if len(p)-int(padLength) < 0 { + countError("frame_headers_pad_too_big") return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) } hf.headerFragBuf = p[:len(p)-int(padLength)] @@ -2444,11 +2482,13 @@ func (p http2PriorityParam) IsZero() bool { return p == http2PriorityParam{} } -func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { if fh.StreamID == 0 { + countError("frame_priority_zero_stream") return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} } if len(payload) != 5 { + countError("frame_priority_bad_length") return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} } v := binary.BigEndian.Uint32(payload[:4]) @@ -2491,11 +2531,13 @@ type http2RSTStreamFrame struct { ErrCode http2ErrCode } -func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if len(p) != 4 { + countError("frame_rststream_bad_len") return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID == 0 { + countError("frame_rststream_zero_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil @@ -2521,8 +2563,9 @@ type http2ContinuationFrame struct { headerFragBuf []byte } -func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if fh.StreamID == 0 { + countError("frame_continuation_zero_stream") return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} } return &http2ContinuationFrame{fh, p}, nil @@ -2571,7 +2614,7 @@ func (f *http2PushPromiseFrame) HeadersEnded() bool { return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) } -func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { +func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { pp := &http2PushPromiseFrame{ http2FrameHeader: fh, } @@ -2582,6 +2625,7 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ // with. If the stream identifier field specifies the value // 0x0, a recipient MUST respond with a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. + countError("frame_pushpromise_zero_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } // The PUSH_PROMISE frame includes optional padding. @@ -2589,18 +2633,21 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ var padLength uint8 if fh.Flags.Has(http2FlagPushPromisePadded) { if p, padLength, err = http2readByte(p); err != nil { + countError("frame_pushpromise_pad_short") return } } p, pp.PromiseID, err = http2readUint32(p) if err != nil { + countError("frame_pushpromise_promiseid_short") return } pp.PromiseID = pp.PromiseID & (1<<31 - 1) if int(padLength) > len(p) { // like the DATA frame, error out if padding is longer than the body. + countError("frame_pushpromise_pad_too_big") return nil, http2ConnectionError(http2ErrCodeProtocol) } pp.headerFragBuf = p[:len(p)-int(padLength)] @@ -3570,6 +3617,17 @@ type http2pipeBuffer interface { io.Reader } +// setBuffer initializes the pipe buffer. +// It has no effect if the pipe is already closed. +func (p *http2pipe) setBuffer(b http2pipeBuffer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil || p.breakErr != nil { + return + } + p.b = b +} + func (p *http2pipe) Len() int { p.mu.Lock() defer p.mu.Unlock() @@ -3786,6 +3844,12 @@ type http2Server struct { // If nil, a default scheduler is chosen. NewWriteScheduler func() http2WriteScheduler + // CountError, if non-nil, is called on HTTP/2 server errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + // Internal state. This is a pointer (rather than embedded directly) // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. @@ -3915,16 +3979,12 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { s.TLSConfig.PreferServerCipherSuites = true - haveNPN := false - for _, p := range s.TLSConfig.NextProtos { - if p == http2NextProtoTLS { - haveNPN = true - break - } - } - if !haveNPN { + if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) } + if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1") + } if s.TLSNextProto == nil { s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} @@ -4065,6 +4125,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) fr := http2NewFramer(sc.bw, c) + if s.CountError != nil { + fr.countError = s.CountError + } fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.SetMaxReadFrameSize(s.maxReadFrameSize()) @@ -4373,7 +4436,15 @@ func (sc *http2serverConn) canonicalHeader(v string) string { sc.canonHeader = make(map[string]string) } cv = CanonicalHeaderKey(v) - sc.canonHeader[v] = cv + // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of + // entries in the canonHeader cache. This should be larger than the number + // of unique, uncommon header keys likely to be sent by the peer, while not + // so high as to permit unreaasonable memory usage if the peer sends an unbounded + // number of unique header keys. + const maxCachedCanonicalHeaders = 32 + if len(sc.canonHeader) < maxCachedCanonicalHeaders { + sc.canonHeader[v] = cv + } return cv } @@ -4479,7 +4550,7 @@ func (sc *http2serverConn) serve() { }) sc.unackedSettings++ - // Each connection starts with intialWindowSize inflow tokens. + // Each connection starts with initialWindowSize inflow tokens. // If a higher value is configured, we add more tokens. if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { sc.sendWindowUpdate(nil, int(diff)) @@ -5064,7 +5135,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { // First frame received must be SETTINGS. if !sc.sawFirstSettings { if _, ok := f.(*http2SettingsFrame); !ok { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol)) } sc.sawFirstSettings = true } @@ -5089,7 +5160,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { case *http2PushPromiseFrame: // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol)) default: sc.vlogf("http2: server ignoring frame: %v", f.Header()) return nil @@ -5109,7 +5180,7 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { // identifier field value other than 0x0, the recipient MUST // respond with a connection error (Section 5.4.1) of type // PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) } if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { return nil @@ -5128,7 +5199,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error // or PRIORITY on a stream in this state MUST be // treated as a connection error (Section 5.4.1) of // type PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol)) } if st == nil { // "WINDOW_UPDATE can be sent by a peer that has sent a @@ -5139,7 +5210,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error return nil } if !st.flow.add(int32(f.Increment)) { - return http2streamError(f.StreamID, http2ErrCodeFlowControl) + return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl)) } default: // connection-level flow control if !sc.flow.add(int32(f.Increment)) { @@ -5160,7 +5231,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { // identifying an idle stream is received, the // recipient MUST treat this as a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol)) } if st != nil { st.cancelCtx() @@ -5212,7 +5283,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { // Why is the peer ACKing settings we never sent? // The spec doesn't mention this case, but // hang up on them anyway. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol)) } return nil } @@ -5220,7 +5291,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { // This isn't actually in the spec, but hang up on // suspiciously large settings frames or those with // duplicate entries. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol)) } if err := f.ForeachSetting(sc.processSetting); err != nil { return err @@ -5287,7 +5358,7 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { // control window to exceed the maximum size as a // connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR." - return http2ConnectionError(http2ErrCodeFlowControl) + return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl)) } } return nil @@ -5320,7 +5391,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // or PRIORITY on a stream in this state MUST be // treated as a connection error (Section 5.4.1) of // type PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol)) } // "If a DATA frame is received whose stream is not in "open" @@ -5337,7 +5408,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // and return any flow control bytes since we're not going // to consume them. if sc.inflow.available() < int32(f.Length) { - return http2streamError(id, http2ErrCodeFlowControl) + return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) } // Deduct the flow control from inflow, since we're // going to immediately add it back in @@ -5350,7 +5421,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // Already have a stream error in flight. Don't send another. return nil } - return http2streamError(id, http2ErrCodeStreamClosed) + return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed)) } if st.body == nil { panic("internal error: should have a body in this state") @@ -5362,12 +5433,12 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // value of a content-length header field does not equal the sum of the // DATA frame payload lengths that form the body. - return http2streamError(id, http2ErrCodeProtocol) + return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol)) } if f.Length > 0 { // Check whether the client has flow control quota. if st.inflow.available() < int32(f.Length) { - return http2streamError(id, http2ErrCodeFlowControl) + return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) } st.inflow.take(int32(f.Length)) @@ -5375,7 +5446,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { wrote, err := st.body.Write(data) if err != nil { sc.sendWindowUpdate(nil, int(f.Length)-wrote) - return http2streamError(id, http2ErrCodeStreamClosed) + return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) } if wrote != len(data) { panic("internal error: bad Writer") @@ -5461,7 +5532,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // stream identifier MUST respond with a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. if id%2 != 1 { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol)) } // A HEADERS frame can be used to create a new stream or // send a trailer for an open one. If we already have a stream @@ -5478,7 +5549,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // this state, it MUST respond with a stream error (Section 5.4.2) of // type STREAM_CLOSED. if st.state == http2stateHalfClosedRemote { - return http2streamError(id, http2ErrCodeStreamClosed) + return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed)) } return st.processTrailerHeaders(f) } @@ -5489,7 +5560,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // receives an unexpected stream identifier MUST respond with // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. if id <= sc.maxClientStreamID { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol)) } sc.maxClientStreamID = id @@ -5506,14 +5577,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if sc.curClientStreams+1 > sc.advMaxStreams { if sc.unackedSettings == 0 { // They should know better. - return http2streamError(id, http2ErrCodeProtocol) + return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol)) } // Assume it's a network race, where they just haven't // received our last SETTINGS update. But actually // this can't happen yet, because we don't yet provide // a way for users to adjust server parameters at // runtime. - return http2streamError(id, http2ErrCodeRefusedStream) + return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream)) } initialState := http2stateOpen @@ -5523,7 +5594,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { st := sc.newStream(id, 0, initialState) if f.HasPriority() { - if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { return err } sc.writeSched.AdjustStream(st.id, f.Priority) @@ -5567,15 +5638,15 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { sc := st.sc sc.serveG.check() if st.gotTrailerHeader { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol)) } st.gotTrailerHeader = true if !f.StreamEnded() { - return http2streamError(st.id, http2ErrCodeProtocol) + return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol)) } if len(f.PseudoFields()) > 0 { - return http2streamError(st.id, http2ErrCodeProtocol) + return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol)) } if st.trailer != nil { for _, hf := range f.RegularFields() { @@ -5584,7 +5655,7 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { // TODO: send more details to the peer somehow. But http2 has // no way to send debug data at a stream level. Discuss with // HTTP folk. - return http2streamError(st.id, http2ErrCodeProtocol) + return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol)) } st.trailer[key] = append(st.trailer[key], hf.Value) } @@ -5593,13 +5664,13 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func http2checkPriority(streamID uint32, p http2PriorityParam) error { +func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error { if streamID == p.StreamDep { // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." // Section 5.3.3 says that a stream can depend on one of its dependencies, // so it's only self-dependencies that are forbidden. - return http2streamError(streamID, http2ErrCodeProtocol) + return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol)) } return nil } @@ -5608,7 +5679,7 @@ func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { if sc.inGoAway { return nil } - if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { return err } sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) @@ -5665,7 +5736,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead isConnect := rp.method == "CONNECT" if isConnect { if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol)) } } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: @@ -5678,13 +5749,13 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead // "All HTTP/2 requests MUST include exactly one valid // value for the :method, :scheme, and :path // pseudo-header fields" - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol)) } bodyOpen := !f.StreamEnded() if rp.method == "HEAD" && bodyOpen { // HEAD requests can't have bodies - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol)) } rp.header = make(Header) @@ -5767,7 +5838,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re var err error url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) + return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) } requestURI = rp.path } @@ -6651,6 +6722,34 @@ func http2h1ServerKeepAlivesDisabled(hs *Server) bool { return false } +func (sc *http2serverConn) countError(name string, err error) error { + if sc == nil || sc.srv == nil { + return err + } + f := sc.srv.CountError + if f == nil { + return err + } + var typ string + var code http2ErrCode + switch e := err.(type) { + case http2ConnectionError: + typ = "conn" + code = http2ErrCode(e) + case http2StreamError: + typ = "stream" + code = http2ErrCode(e.Code) + default: + return err + } + codeStr := http2errCodeName[code] + if codeStr == "" { + codeStr = strconv.Itoa(int(code)) + } + f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name)) + return err +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -6666,6 +6765,15 @@ const ( http2transportDefaultStreamMinRefresh = 4 << 10 http2defaultUserAgent = "Go-http-client/2.0" + + // initialMaxConcurrentStreams is a connections maxConcurrentStreams until + // it's received servers initial SETTINGS frame, which corresponds with the + // spec's minimum recommended value. + http2initialMaxConcurrentStreams = 100 + + // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams + // if the server doesn't include one in its initial SETTINGS frame. + http2defaultMaxConcurrentStreams = 1000 ) // Transport is an HTTP/2 Transport. @@ -6736,6 +6844,17 @@ type http2Transport struct { // Defaults to 15s. PingTimeout time.Duration + // WriteByteTimeout is the timeout after which the connection will be + // closed no data can be written to it. The timeout begins when data is + // available to write, and is extended whenever any bytes are written. + WriteByteTimeout time.Duration + + // CountError, if non-nil, is called on HTTP/2 transport errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -6842,11 +6961,12 @@ func (t *http2Transport) initConnPool() { // ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. type http2ClientConn struct { - t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls - tlsState *tls.ConnectionState // nil only for specialized impls - reused uint32 // whether conn is being reused; atomic - singleUse bool // whether being used for a single http.Request + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + getConnCalled bool // used by clientConnPool // readLoop goroutine fields: readerDone chan struct{} // closed on error @@ -6859,87 +6979,94 @@ type http2ClientConn struct { cond *sync.Cond // hold mu; broadcast on flow/closed changes flow http2flow // our conn-level flow control quota (cs.flow is per stream) inflow http2flow // peer's conn-level flow control + doNotReuse bool // whether conn is marked to not be reused for any future requests closing bool closed bool + seenSettings bool // true if we've seen a settings frame, false otherwise wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated + streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip nextStreamID uint32 pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams pings map[[8]byte]chan struct{} // in flight ping data to notification channel - bw *bufio.Writer br *bufio.Reader - fr *http2Framer lastActive time.Time lastIdle time.Time // time last idle - // Settings from peer: (also guarded by mu) + // Settings from peer: (also guarded by wmu) maxFrameSize uint32 maxConcurrentStreams uint32 peerMaxHeaderListSize uint64 initialWindowSize uint32 - hbuf bytes.Buffer // HPACK encoder writes into this - henc *hpack.Encoder - freeBuf [][]byte + // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. + // Write to reqHeaderMu to lock it, read from it to unlock. + // Lock reqmu BEFORE mu or wmu. + reqHeaderMu chan struct{} - wmu sync.Mutex // held while writing; acquire AFTER mu if holding both - werr error // first write error that has occurred + // wmu is held while writing. + // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. + // Only acquire both at the same time when changing peer settings. + wmu sync.Mutex + bw *bufio.Writer + fr *http2Framer + werr error // first write error that has occurred + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder } // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. type http2clientStream struct { - cc *http2ClientConn - req *Request + cc *http2ClientConn + + // Fields of Request that we may access even after the response body is closed. + ctx context.Context + reqCancel <-chan struct{} + trace *httptrace.ClientTrace // or nil ID uint32 - resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload - startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool - on100 func() // optional code to run if get a 100 continue response + isHead bool + + abortOnce sync.Once + abort chan struct{} // closed to signal stream should end immediately + abortErr error // set if abort is closed + + peerClosed chan struct{} // closed when the peer sends an END_STREAM flag + donec chan struct{} // closed after the stream is in the closed state + on100 chan struct{} // buffered; written to if a 100 is received + + respHeaderRecv chan struct{} // closed when headers are received + res *Response // set if respHeaderRecv is closed flow http2flow // guarded by cc.mu inflow http2flow // guarded by cc.mu bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read - stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu - didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu - peerReset chan struct{} // closed on peer reset - resetErr error // populated before peerReset is closed + reqBody io.ReadCloser + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed bool // body has been closed; guarded by cc.mu - done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu + // owned by writeRequest: + sentEndStream bool // sent an END_STREAM flag to the peer + sentHeaders bool // owned by clientConnReadLoop: firstByte bool // got the first response byte pastHeaders bool // got first MetaHeadersFrame (actual headers) pastTrailers bool // got optional second MetaHeadersFrame (trailers) num1xx uint8 // number of 1xx responses seen + readClosed bool // peer sent an END_STREAM flag + readAborted bool // read loop reset the stream trailer Header // accumulated trailers resTrailer *Header // client's Response.Trailer } -// awaitRequestCancel waits for the user to cancel a request or for the done -// channel to be signaled. A non-nil error is returned only if the request was -// canceled. -func http2awaitRequestCancel(req *Request, done <-chan struct{}) error { - ctx := req.Context() - if req.Cancel == nil && ctx.Done() == nil { - return nil - } - select { - case <-req.Cancel: - return http2errRequestCanceled - case <-ctx.Done(): - return ctx.Err() - case <-done: - return nil - } -} - var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error // get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, @@ -6951,73 +7078,65 @@ func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) e return http2traceGot1xxResponseFunc(cs.trace) } -// awaitRequestCancel waits for the user to cancel a request, its context to -// expire, or for the request to be done (any way it might be removed from the -// cc.streams map: peer reset, successful completion, TCP connection breakage, -// etc). If the request is canceled, then cs will be canceled and closed. -func (cs *http2clientStream) awaitRequestCancel(req *Request) { - if err := http2awaitRequestCancel(req, cs.done); err != nil { - cs.cancelStream() - cs.bufPipe.CloseWithError(err) - } +func (cs *http2clientStream) abortStream(err error) { + cs.cc.mu.Lock() + defer cs.cc.mu.Unlock() + cs.abortStreamLocked(err) } -func (cs *http2clientStream) cancelStream() { - cc := cs.cc - cc.mu.Lock() - didReset := cs.didReset - cs.didReset = true - cc.mu.Unlock() - - if !didReset { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - cc.forgetStreamID(cs.ID) +func (cs *http2clientStream) abortStreamLocked(err error) { + cs.abortOnce.Do(func() { + cs.abortErr = err + close(cs.abort) + }) + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true } -} - -// checkResetOrDone reports any error sent in a RST_STREAM frame by the -// server, or errStreamClosed if the stream is complete. -func (cs *http2clientStream) checkResetOrDone() error { - select { - case <-cs.peerReset: - return cs.resetErr - case <-cs.done: - return http2errStreamClosed - default: - return nil + // TODO(dneil): Clean up tests where cs.cc.cond is nil. + if cs.cc.cond != nil { + // Wake up writeRequestBody if it is waiting on flow control. + cs.cc.cond.Broadcast() } } -func (cs *http2clientStream) getStartedWrite() bool { +func (cs *http2clientStream) abortRequestBodyWrite() { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() - return cs.startedWrite -} - -func (cs *http2clientStream) abortRequestBodyWrite(err error) { - if err == nil { - panic("nil error") + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true + cc.cond.Broadcast() } - cc := cs.cc - cc.mu.Lock() - cs.stopReqBody = err - cc.cond.Broadcast() - cc.mu.Unlock() } type http2stickyErrWriter struct { - w io.Writer - err *error + conn net.Conn + timeout time.Duration + err *error } func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { if *sew.err != nil { return 0, *sew.err } - n, err = sew.w.Write(p) - *sew.err = err - return + for { + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout)) + } + nn, err := sew.conn.Write(p[n:]) + n += nn + if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) { + // Keep extending the deadline so long as we're making progress. + continue + } + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Time{}) + } + *sew.err = err + return n, err + } } // noCachedConnError is the concrete type of ErrNoCachedConn, which @@ -7091,9 +7210,9 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) http2traceGotConn(req, cc, reused) - res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req) + res, err := cc.RoundTrip(req) if err != nil && retry <= 6 { - if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil { + if req, err = http2shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { continue @@ -7104,7 +7223,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res case <-time.After(time.Second * time.Duration(backoff)): continue case <-req.Context().Done(): - return nil, req.Context().Err() + err = req.Context().Err() } } } @@ -7135,7 +7254,7 @@ var ( // response headers. It is always called with a non-nil error. // It returns either a request to retry (either the same request, or a // modified clone), or an error if the request can't be replayed. -func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Request, error) { +func http2shouldRetryRequest(req *Request, err error) (*Request, error) { if !http2canRetryError(err) { return nil, err } @@ -7148,7 +7267,6 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req // If the request body can be reset back to its original // state via the optional req.GetBody, do that. if req.GetBody != nil { - // TODO: consider a req.Body.Close here? or audit that all caller paths do? body, err := req.GetBody() if err != nil { return nil, err @@ -7160,10 +7278,8 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req // The Request.Body can't reset back to the beginning, but we // don't seem to have started to read from it yet, so reuse - // the request directly. The "afterBodyWrite" means the - // bodyWrite process has started, which becomes true before - // the first Read. - if !afterBodyWrite { + // the request directly. + if err == http2errClientConnUnusable { return req, nil } @@ -7175,6 +7291,10 @@ func http2canRetryError(err error) bool { return true } if se, ok := err.(http2StreamError); ok { + if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer { + // See golang/go#47635, golang/go#42777 + return true + } return se.Code == http2ErrCodeRefusedStream } return false @@ -7249,14 +7369,15 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client tconn: c, readerDone: make(chan struct{}), nextStreamID: 1, - maxFrameSize: 16 << 10, // spec default - initialWindowSize: 65535, // spec default - maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. - peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. streams: make(map[uint32]*http2clientStream), singleUse: singleUse, wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), + reqHeaderMu: make(chan struct{}, 1), } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d @@ -7271,9 +7392,16 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. - cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) + cc.bw = bufio.NewWriter(http2stickyErrWriter{ + conn: c, + timeout: t.WriteByteTimeout, + err: &cc.werr, + }) cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) + if t.CountError != nil { + cc.fr.countError = t.CountError + } cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() @@ -7326,6 +7454,13 @@ func (cc *http2ClientConn) healthCheck() { } } +// SetDoNotReuse marks cc as not reusable for future HTTP requests. +func (cc *http2ClientConn) SetDoNotReuse() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.doNotReuse = true +} + func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -7343,27 +7478,94 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - select { - case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}: - default: - } + cs.abortStreamLocked(http2errClientConnGotGoAway) } } } // CanTakeNewRequest reports whether the connection can take a new request, // meaning it has not been closed or received or sent a GOAWAY. +// +// If the caller is going to immediately make a new request on this +// connection, use ReserveNewRequest instead. func (cc *http2ClientConn) CanTakeNewRequest() bool { cc.mu.Lock() defer cc.mu.Unlock() return cc.canTakeNewRequestLocked() } +// ReserveNewRequest is like CanTakeNewRequest but also reserves a +// concurrent stream in cc. The reservation is decremented on the +// next call to RoundTrip. +func (cc *http2ClientConn) ReserveNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + if st := cc.idleStateLocked(); !st.canTakeNewRequest { + return false + } + cc.streamsReserved++ + return true +} + +// ClientConnState describes the state of a ClientConn. +type http2ClientConnState struct { + // Closed is whether the connection is closed. + Closed bool + + // Closing is whether the connection is in the process of + // closing. It may be closing due to shutdown, being a + // single-use connection, being marked as DoNotReuse, or + // having received a GOAWAY frame. + Closing bool + + // StreamsActive is how many streams are active. + StreamsActive int + + // StreamsReserved is how many streams have been reserved via + // ClientConn.ReserveNewRequest. + StreamsReserved int + + // StreamsPending is how many requests have been sent in excess + // of the peer's advertised MaxConcurrentStreams setting and + // are waiting for other streams to complete. + StreamsPending int + + // MaxConcurrentStreams is how many concurrent streams the + // peer advertised as acceptable. Zero means no SETTINGS + // frame has been received yet. + MaxConcurrentStreams uint32 + + // LastIdle, if non-zero, is when the connection last + // transitioned to idle state. + LastIdle time.Time +} + +// State returns a snapshot of cc's state. +func (cc *http2ClientConn) State() http2ClientConnState { + cc.wmu.Lock() + maxConcurrent := cc.maxConcurrentStreams + if !cc.seenSettings { + maxConcurrent = 0 + } + cc.wmu.Unlock() + + cc.mu.Lock() + defer cc.mu.Unlock() + return http2ClientConnState{ + Closed: cc.closed, + Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil, + StreamsActive: len(cc.streams), + StreamsReserved: cc.streamsReserved, + StreamsPending: cc.pendingRequests, + LastIdle: cc.lastIdle, + MaxConcurrentStreams: maxConcurrent, + } +} + // clientConnIdleState describes the suitability of a client // connection to initiate a new RoundTrip request. type http2clientConnIdleState struct { canTakeNewRequest bool - freshConn bool // whether it's unused by any previous request } func (cc *http2ClientConn) idleState() http2clientConnIdleState { @@ -7384,13 +7586,13 @@ func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { // writing it. maxConcurrentOkay = true } else { - maxConcurrentOkay = int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) + maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams) } st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && + !cc.doNotReuse && int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && !cc.tooIdleLocked() - st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest return } @@ -7421,7 +7623,7 @@ func (cc *http2ClientConn) onIdleTimeout() { func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() - if len(cc.streams) > 0 { + if len(cc.streams) > 0 || cc.streamsReserved > 0 { cc.mu.Unlock() return } @@ -7436,9 +7638,15 @@ func (cc *http2ClientConn) closeIfIdle() { cc.tconn.Close() } +func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.doNotReuse && len(cc.streams) == 0 +} + var http2shutdownEnterWaitStateHook = func() {} -// Shutdown gracefully close the client connection, waiting for running streams to complete. +// Shutdown gracefully closes the client connection, waiting for running streams to complete. func (cc *http2ClientConn) Shutdown(ctx context.Context) error { if err := cc.sendGoAway(); err != nil { return err @@ -7477,15 +7685,18 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { func (cc *http2ClientConn) sendGoAway() error { cc.mu.Lock() - defer cc.mu.Unlock() - cc.wmu.Lock() - defer cc.wmu.Unlock() - if cc.closing { + closing := cc.closing + cc.closing = true + maxStreamID := cc.nextStreamID + cc.mu.Unlock() + if closing { // GOAWAY sent already return nil } + + cc.wmu.Lock() + defer cc.wmu.Unlock() // Send a graceful shutdown frame to server - maxStreamID := cc.nextStreamID if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { return err } @@ -7493,7 +7704,6 @@ func (cc *http2ClientConn) sendGoAway() error { return err } // Prevent new requests - cc.closing = true return nil } @@ -7501,17 +7711,12 @@ func (cc *http2ClientConn) sendGoAway() error { // err is sent to streams. func (cc *http2ClientConn) closeForError(err error) error { cc.mu.Lock() + cc.closed = true + for _, cs := range cc.streams { + cs.abortStreamLocked(err) + } defer cc.cond.Broadcast() defer cc.mu.Unlock() - for id, cs := range cc.streams { - select { - case cs.resc <- http2resAndError{err: err}: - default: - } - cs.bufPipe.CloseWithError(err) - delete(cc.streams, id) - } - cc.closed = true return cc.tconn.Close() } @@ -7526,47 +7731,10 @@ func (cc *http2ClientConn) Close() error { // closes the client connection immediately. In-flight requests are interrupted. func (cc *http2ClientConn) closeForLostPing() error { err := errors.New("http2: client connection lost") - return cc.closeForError(err) -} - -const http2maxAllocFrameSize = 512 << 10 - -// frameBuffer returns a scratch buffer suitable for writing DATA frames. -// They're capped at the min of the peer's max frame size or 512KB -// (kinda arbitrarily), but definitely capped so we don't allocate 4GB -// bufers. -func (cc *http2ClientConn) frameScratchBuffer() []byte { - cc.mu.Lock() - size := cc.maxFrameSize - if size > http2maxAllocFrameSize { - size = http2maxAllocFrameSize - } - for i, buf := range cc.freeBuf { - if len(buf) >= int(size) { - cc.freeBuf[i] = nil - cc.mu.Unlock() - return buf[:size] - } - } - cc.mu.Unlock() - return make([]byte, size) -} - -func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { - cc.mu.Lock() - defer cc.mu.Unlock() - const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. - if len(cc.freeBuf) < maxBufs { - cc.freeBuf = append(cc.freeBuf, buf) - return - } - for i, old := range cc.freeBuf { - if old == nil { - cc.freeBuf[i] = buf - return - } + if f := cc.t.CountError; f != nil { + f("conn_close_lost_ping") } - // forget about it. + return cc.closeForError(err) } // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not @@ -7630,41 +7798,158 @@ func http2actualContentLength(req *Request) int64 { return -1 } +func (cc *http2ClientConn) decrStreamReservations() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.decrStreamReservationsLocked() +} + +func (cc *http2ClientConn) decrStreamReservationsLocked() { + if cc.streamsReserved > 0 { + cc.streamsReserved-- + } +} + func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { - resp, _, err := cc.roundTrip(req) - return resp, err + ctx := req.Context() + cs := &http2clientStream{ + cc: cc, + ctx: ctx, + reqCancel: req.Cancel, + isHead: req.Method == "HEAD", + reqBody: req.Body, + reqBodyContentLength: http2actualContentLength(req), + trace: httptrace.ContextClientTrace(ctx), + peerClosed: make(chan struct{}), + abort: make(chan struct{}), + respHeaderRecv: make(chan struct{}), + donec: make(chan struct{}), + } + go cs.doRequest(req) + + waitDone := func() error { + select { + case <-cs.donec: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + } + } + + handleResponseHeaders := func() (*Response, error) { + res := cs.res + if res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. + cs.abortRequestBodyWrite() + } + res.Request = req + res.TLS = cc.tlsState + if res.Body == http2noBody && http2actualContentLength(req) == 0 { + // If there isn't a request or response body still being + // written, then wait for the stream to be closed before + // RoundTrip returns. + if err := waitDone(); err != nil { + return nil, err + } + } + return res, nil + } + + for { + select { + case <-cs.respHeaderRecv: + return handleResponseHeaders() + case <-cs.abort: + select { + case <-cs.respHeaderRecv: + // If both cs.respHeaderRecv and cs.abort are signaling, + // pick respHeaderRecv. The server probably wrote the + // response and immediately reset the stream. + // golang.org/issue/49645 + return handleResponseHeaders() + default: + waitDone() + return nil, cs.abortErr + } + case <-ctx.Done(): + err := ctx.Err() + cs.abortStream(err) + return nil, err + case <-cs.reqCancel: + cs.abortStream(http2errRequestCanceled) + return nil, http2errRequestCanceled + } + } } -func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterReqBodyWrite bool, err error) { +// doRequest runs for the duration of the request lifetime. +// +// It sends the request and performs post-request cleanup (closing Request.Body, etc.). +func (cs *http2clientStream) doRequest(req *Request) { + err := cs.writeRequest(req) + cs.cleanupWriteRequest(err) +} + +// writeRequest sends a request. +// +// It returns nil after the request is written, the response read, +// and the request stream is half-closed by the peer. +// +// It returns non-nil if the request ends otherwise. +// If the returned error is StreamError, the error Code may be used in resetting the stream. +func (cs *http2clientStream) writeRequest(req *Request) (err error) { + cc := cs.cc + ctx := cs.ctx + if err := http2checkConnHeaders(req); err != nil { - return nil, false, err - } - if cc.idleTimer != nil { - cc.idleTimer.Stop() + return err } - trailers, err := http2commaSeparatedTrailers(req) - if err != nil { - return nil, false, err + // Acquire the new-request lock by writing to reqHeaderMu. + // This lock guards the critical section covering allocating a new stream ID + // (requires mu) and creating the stream (requires wmu). + if cc.reqHeaderMu == nil { + panic("RoundTrip on uninitialized ClientConn") // for tests + } + select { + case cc.reqHeaderMu <- struct{}{}: + case <-cs.reqCancel: + return http2errRequestCanceled + case <-ctx.Done(): + return ctx.Err() } - hasTrailers := trailers != "" cc.mu.Lock() - if err := cc.awaitOpenSlotForRequest(req); err != nil { + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + cc.decrStreamReservationsLocked() + if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil { cc.mu.Unlock() - return nil, false, err + <-cc.reqHeaderMu + return err } - - body := req.Body - contentLen := http2actualContentLength(req) - hasBody := contentLen != 0 + cc.addStreamLocked(cs) // assigns stream ID + if http2isConnectionCloseRequest(req) { + cc.doNotReuse = true + } + cc.mu.Unlock() // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - var requestedGzip bool if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && - req.Method != "HEAD" { + !cs.isHead { // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: https://zlib.net/zlib_faq.html#faq39 @@ -7677,210 +7962,224 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe // We don't request gzip if the request is for a range, since // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 - requestedGzip = true + cs.requestedGzip = true } - // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is - // sent by writeRequestBody below, along with any Trailers, - // again in form HEADERS{1}, CONTINUATION{0,}) - hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) - if err != nil { - cc.mu.Unlock() - return nil, false, err + continueTimeout := cc.t.expectContinueTimeout() + if continueTimeout != 0 { + if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { + continueTimeout = 0 + } else { + cs.on100 = make(chan struct{}, 1) + } } - cs := cc.newStream() - cs.req = req - cs.trace = httptrace.ContextClientTrace(req.Context()) - cs.requestedGzip = requestedGzip - bodyWriter := cc.t.getBodyWriterState(cs, body) - cs.on100 = bodyWriter.on100 + // Past this point (where we send request headers), it is possible for + // RoundTrip to return successfully. Since the RoundTrip contract permits + // the caller to "mutate or reuse" the Request after closing the Response's Body, + // we must take care when referencing the Request from here on. + err = cs.encodeAndWriteHeaders(req) + <-cc.reqHeaderMu + if err != nil { + return err + } - defer func() { - cc.wmu.Lock() - werr := cc.werr - cc.wmu.Unlock() - if werr != nil { - cc.Close() + hasBody := cs.reqBodyContentLength != 0 + if !hasBody { + cs.sentEndStream = true + } else { + if continueTimeout != 0 { + http2traceWait100Continue(cs.trace) + timer := time.NewTimer(continueTimeout) + select { + case <-timer.C: + err = nil + case <-cs.on100: + err = nil + case <-cs.abort: + err = cs.abortErr + case <-ctx.Done(): + err = ctx.Err() + case <-cs.reqCancel: + err = http2errRequestCanceled + } + timer.Stop() + if err != nil { + http2traceWroteRequest(cs.trace, err) + return err + } } - }() - - cc.wmu.Lock() - endStream := !hasBody && !hasTrailers - werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) - cc.wmu.Unlock() - http2traceWroteHeaders(cs.trace) - cc.mu.Unlock() - if werr != nil { - if hasBody { - req.Body.Close() // per RoundTripper contract - bodyWriter.cancel() + if err = cs.writeRequestBody(req); err != nil { + if err != http2errStopReqBodyWrite { + http2traceWroteRequest(cs.trace, err) + return err + } + } else { + cs.sentEndStream = true } - cc.forgetStreamID(cs.ID) - // Don't bother sending a RST_STREAM (our write already failed; - // no need to keep writing) - http2traceWroteRequest(cs.trace, werr) - return nil, false, werr } + http2traceWroteRequest(cs.trace, err) + var respHeaderTimer <-chan time.Time - if hasBody { - bodyWriter.scheduleBodyWrite() - } else { - http2traceWroteRequest(cs.trace, nil) - if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) - defer timer.Stop() - respHeaderTimer = timer.C + var respHeaderRecv chan struct{} + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + respHeaderRecv = cs.respHeaderRecv + } + // Wait until the peer half-closes its end of the stream, + // or until the request is aborted (via context, error, or otherwise), + // whichever comes first. + for { + select { + case <-cs.peerClosed: + return nil + case <-respHeaderTimer: + return http2errTimeout + case <-respHeaderRecv: + respHeaderRecv = nil + respHeaderTimer = nil // keep waiting for END_STREAM + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled } } +} - readLoopResCh := cs.resc - bodyWritten := false - ctx := req.Context() +func (cs *http2clientStream) encodeAndWriteHeaders(req *Request) error { + cc := cs.cc + ctx := cs.ctx - handleReadLoopResponse := func(re http2resAndError) (*Response, bool, error) { - res := re.res - if re.err != nil || res.StatusCode > 299 { - // On error or status code 3xx, 4xx, 5xx, etc abort any - // ongoing write, assuming that the server doesn't care - // about our request body. If the server replied with 1xx or - // 2xx, however, then assume the server DOES potentially - // want our body (e.g. full-duplex streaming: - // golang.org/issue/13444). If it turns out the server - // doesn't, they'll RST_STREAM us soon enough. This is a - // heuristic to avoid adding knobs to Transport. Hopefully - // we can keep it. - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWrite) - if hasBody && !bodyWritten { - <-bodyWriter.resc - } - } - if re.err != nil { - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), re.err - } - res.Request = req - res.TLS = cc.tlsState - return res, false, nil + cc.wmu.Lock() + defer cc.wmu.Unlock() + + // If the request was canceled while waiting for cc.mu, just quit. + select { + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + default: } - for { + // Encode headers. + // + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is + // sent by writeRequestBody below, along with any Trailers, + // again in form HEADERS{1}, CONTINUATION{0,}) + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return err + } + hasTrailers := trailers != "" + contentLen := http2actualContentLength(req) + hasBody := contentLen != 0 + hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + if err != nil { + return err + } + + // Write the request. + endStream := !hasBody && !hasTrailers + cs.sentHeaders = true + err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) + http2traceWroteHeaders(cs.trace) + return err +} + +// cleanupWriteRequest performs post-request tasks. +// +// If err (the result of writeRequest) is non-nil and the stream is not closed, +// cleanupWriteRequest will send a reset to the peer. +func (cs *http2clientStream) cleanupWriteRequest(err error) { + cc := cs.cc + + if cs.ID == 0 { + // We were canceled before creating the stream, so return our reservation. + cc.decrStreamReservations() + } + + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cs.reqBodyClosed = true + cc.mu.Unlock() + if !bodyClosed && cs.reqBody != nil { + cs.reqBody.Close() + } + + if err != nil && cs.sentEndStream { + // If the connection is closed immediately after the response is read, + // we may be aborted before finishing up here. If the stream was closed + // cleanly on both sides, there is no error. select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - case <-respHeaderTimer: - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), http2errTimeout - case <-ctx.Done(): - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), ctx.Err() - case <-req.Cancel: - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + case <-cs.peerClosed: + err = nil + default: + } + } + if err != nil { + cs.abortStream(err) // possibly redundant, but harmless + if cs.sentHeaders { + if se, ok := err.(http2StreamError); ok { + if se.Cause != http2errFromPeer { + cc.writeStreamReset(cs.ID, se.Code, err) + } } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), http2errRequestCanceled - case <-cs.peerReset: - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - // processResetStream already removed the - // stream from the streams map; no need for - // forgetStreamID. - return nil, cs.getStartedWrite(), cs.resetErr - case err := <-bodyWriter.resc: - bodyWritten = true - // Prefer the read loop's response, if available. Issue 16102. - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - if err != nil { - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), err - } - if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) - defer timer.Stop() - respHeaderTimer = timer.C + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) } } + cs.bufPipe.CloseWithError(err) // no-op if already closed + } else { + if cs.sentHeaders && !cs.sentEndStream { + cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) + } + cs.bufPipe.CloseWithError(http2errRequestCanceled) + } + if cs.ID != 0 { + cc.forgetStreamID(cs.ID) + } + + cc.wmu.Lock() + werr := cc.werr + cc.wmu.Unlock() + if werr != nil { + cc.Close() } + + close(cs.donec) } -// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams. +// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. // Must hold cc.mu. -func (cc *http2ClientConn) awaitOpenSlotForRequest(req *Request) error { - var waitingForConn chan struct{} - var waitingForConnErr error // guarded by cc.mu +func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { for { cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { - if waitingForConn != nil { - close(waitingForConn) - } return http2errClientConnUnusable } cc.lastIdle = time.Time{} - if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) { - if waitingForConn != nil { - close(waitingForConn) - } + if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) { return nil } - // Unfortunately, we cannot wait on a condition variable and channel at - // the same time, so instead, we spin up a goroutine to check if the - // request is canceled while we wait for a slot to open in the connection. - if waitingForConn == nil { - waitingForConn = make(chan struct{}) - go func() { - if err := http2awaitRequestCancel(req, waitingForConn); err != nil { - cc.mu.Lock() - waitingForConnErr = err - cc.cond.Broadcast() - cc.mu.Unlock() - } - }() - } cc.pendingRequests++ cc.cond.Wait() cc.pendingRequests-- - if waitingForConnErr != nil { - return waitingForConnErr + select { + case <-cs.abort: + return cs.abortErr + default: } } } @@ -7907,10 +8206,6 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFram cc.fr.WriteContinuation(streamID, endHeaders, chunk) } } - // TODO(bradfitz): this Flush could potentially block (as - // could the WriteHeaders call(s) above), which means they - // wouldn't respond to Request.Cancel being readable. That's - // rare, but this should probably be in a goroutine. cc.bw.Flush() return cc.werr } @@ -7926,32 +8221,59 @@ var ( http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") ) -func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { +// frameScratchBufferLen returns the length of a buffer to use for +// outgoing request bodies to read/write to/from. +// +// It returns max(1, min(peer's advertised max frame size, +// Request.ContentLength+1, 512KB)). +func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { + const max = 512 << 10 + n := int64(maxFrameSize) + if n > max { + n = max + } + if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n { + // Add an extra byte past the declared content-length to + // give the caller's Request.Body io.Reader a chance to + // give us more bytes than they declared, so we can catch it + // early. + n = cl + 1 + } + if n < 1 { + return 1 + } + return int(n) // doesn't truncate; max is 512K +} + +var http2bufPool sync.Pool // of *[]byte + +func (cs *http2clientStream) writeRequestBody(req *Request) (err error) { cc := cs.cc + body := cs.reqBody sentEnd := false // whether we sent the final DATA frame w/ END_STREAM - buf := cc.frameScratchBuffer() - defer cc.putFrameScratchBuffer(buf) - - defer func() { - http2traceWroteRequest(cs.trace, err) - // TODO: write h12Compare test showing whether - // Request.Body is closed by the Transport, - // and in multiple cases: server replies <=299 and >299 - // while still writing request body - cerr := bodyCloser.Close() - if err == nil { - err = cerr - } - }() - req := cs.req hasTrailers := req.Trailer != nil - remainLen := http2actualContentLength(req) + remainLen := cs.reqBodyContentLength hasContentLen := remainLen != -1 + cc.mu.Lock() + maxFrameSize := int(cc.maxFrameSize) + cc.mu.Unlock() + + // Scratch buffer for reading into & writing from. + scratchLen := cs.frameScratchBufferLen(maxFrameSize) + var buf []byte + if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer http2bufPool.Put(bp) + buf = *bp + } else { + buf = make([]byte, scratchLen) + defer http2bufPool.Put(&buf) + } + var sawEOF bool for !sawEOF { - n, err := body.Read(buf[:len(buf)-1]) + n, err := body.Read(buf[:len(buf)]) if hasContentLen { remainLen -= int64(n) if remainLen == 0 && err == nil { @@ -7962,35 +8284,36 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos // to send the END_STREAM bit early, double-check that we're actually // at EOF. Subsequent reads should return (0, EOF) at this point. // If either value is different, we return an error in one of two ways below. + var scratch [1]byte var n1 int - n1, err = body.Read(buf[n:]) + n1, err = body.Read(scratch[:]) remainLen -= int64(n1) } if remainLen < 0 { err = http2errReqBodyTooLong - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) return err } } - if err == io.EOF { - sawEOF = true - err = nil - } else if err != nil { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) - return err + if err != nil { + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cc.mu.Unlock() + switch { + case bodyClosed: + return http2errStopReqBodyWrite + case err == io.EOF: + sawEOF = true + err = nil + default: + return err + } } remain := buf[:n] for len(remain) > 0 && err == nil { var allowed int32 allowed, err = cs.awaitFlowControl(len(remain)) - switch { - case err == http2errStopReqBodyWrite: - return err - case err == http2errStopReqBodyWriteAndCancel: - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - return err - case err != nil: + if err != nil { return err } cc.wmu.Lock() @@ -8021,24 +8344,26 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos return nil } - var trls []byte - if hasTrailers { - cc.mu.Lock() - trls, err = cc.encodeTrailers(req) - cc.mu.Unlock() - if err != nil { - cc.writeStreamReset(cs.ID, http2ErrCodeInternal, err) - cc.forgetStreamID(cs.ID) - return err - } - } - + // Since the RoundTrip contract permits the caller to "mutate or reuse" + // a request after the Response's Body is closed, verify that this hasn't + // happened before accessing the trailers. cc.mu.Lock() - maxFrameSize := int(cc.maxFrameSize) + trailer := req.Trailer + err = cs.abortErr cc.mu.Unlock() + if err != nil { + return err + } cc.wmu.Lock() defer cc.wmu.Unlock() + var trls []byte + if len(trailer) > 0 { + trls, err = cc.encodeTrailers(trailer) + if err != nil { + return err + } + } // Two ways to send END_STREAM: either with trailers, or // with an empty DATA frame. @@ -8059,17 +8384,24 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos // if the stream is dead. func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { cc := cs.cc + ctx := cs.ctx cc.mu.Lock() defer cc.mu.Unlock() for { if cc.closed { return 0, http2errClientConnClosed } - if cs.stopReqBody != nil { - return 0, cs.stopReqBody + if cs.reqBodyClosed { + return 0, http2errStopReqBodyWrite } - if err := cs.checkResetOrDone(); err != nil { - return 0, err + select { + case <-cs.abort: + return 0, cs.abortErr + case <-ctx.Done(): + return 0, ctx.Err() + case <-cs.reqCancel: + return 0, http2errRequestCanceled + default: } if a := cs.flow.available(); a > 0 { take := a @@ -8087,9 +8419,14 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er } } -// requires cc.mu be held. +var http2errNilRequestURL = errors.New("http2: Request.URI is nil") + +// requires cc.wmu be held. func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { cc.hbuf.Reset() + if req.URL == nil { + return nil, http2errNilRequestURL + } host := req.Host if host == "" { @@ -8275,12 +8612,12 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { } } -// requires cc.mu be held. -func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) { +// requires cc.wmu be held. +func (cc *http2ClientConn) encodeTrailers(trailer Header) ([]byte, error) { cc.hbuf.Reset() hlSize := uint64(0) - for k, vv := range req.Trailer { + for k, vv := range trailer { for _, v := range vv { hf := hpack.HeaderField{Name: k, Value: v} hlSize += uint64(hf.Size()) @@ -8290,7 +8627,7 @@ func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) { return nil, http2errRequestHeaderListSize } - for k, vv := range req.Trailer { + for k, vv := range trailer { lowKey, ascii := http2asciiToLower(k) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header @@ -8320,51 +8657,51 @@ type http2resAndError struct { } // requires cc.mu be held. -func (cc *http2ClientConn) newStream() *http2clientStream { - cs := &http2clientStream{ - cc: cc, - ID: cc.nextStreamID, - resc: make(chan http2resAndError, 1), - peerReset: make(chan struct{}), - done: make(chan struct{}), - } +func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) cs.inflow.add(http2transportDefaultStreamFlow) cs.inflow.setConnFlow(&cc.inflow) + cs.ID = cc.nextStreamID cc.nextStreamID += 2 cc.streams[cs.ID] = cs - return cs + if cs.ID == 0 { + panic("assigned stream ID 0") + } } func (cc *http2ClientConn) forgetStreamID(id uint32) { - cc.streamByID(id, true) -} - -func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream { cc.mu.Lock() - defer cc.mu.Unlock() - cs := cc.streams[id] - if andRemove && cs != nil && !cc.closed { - cc.lastActive = time.Now() - delete(cc.streams, id) - if len(cc.streams) == 0 && cc.idleTimer != nil { - cc.idleTimer.Reset(cc.idleTimeout) - cc.lastIdle = time.Now() - } - close(cs.done) - // Wake up checkResetOrDone via clientStream.awaitFlowControl and - // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + slen := len(cc.streams) + delete(cc.streams, id) + if len(cc.streams) != slen-1 { + panic("forgetting unknown stream id") + } + cc.lastActive = time.Now() + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + cc.lastIdle = time.Now() + } + // Wake up writeRequestBody via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() + + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) + } + cc.closed = true + defer cc.tconn.Close() } - return cs + + cc.mu.Unlock() } // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. type http2clientConnReadLoop struct { - _ http2incomparable - cc *http2ClientConn - closeWhenIdle bool + _ http2incomparable + cc *http2ClientConn } // readLoop runs in its own goroutine and reads and dispatches frames. @@ -8424,23 +8761,49 @@ func (rl *http2clientConnReadLoop) cleanup() { } else if err == io.EOF { err = io.ErrUnexpectedEOF } + cc.closed = true for _, cs := range cc.streams { - cs.bufPipe.CloseWithError(err) // no-op if already closed select { - case cs.resc <- http2resAndError{err: err}: + case <-cs.peerClosed: + // The server closed the stream before closing the conn, + // so no need to interrupt it. default: + cs.abortStreamLocked(err) } - close(cs.done) } - cc.closed = true cc.cond.Broadcast() cc.mu.Unlock() } +// countReadFrameError calls Transport.CountError with a string +// representing err. +func (cc *http2ClientConn) countReadFrameError(err error) { + f := cc.t.CountError + if f == nil || err == nil { + return + } + if ce, ok := err.(http2ConnectionError); ok { + errCode := http2ErrCode(ce) + f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken())) + return + } + if errors.Is(err, io.EOF) { + f("read_frame_eof") + return + } + if errors.Is(err, io.ErrUnexpectedEOF) { + f("read_frame_unexpected_eof") + return + } + if errors.Is(err, http2ErrFrameTooLarge) { + f("read_frame_too_large") + return + } + f("read_frame_other") +} + func (rl *http2clientConnReadLoop) run() error { cc := rl.cc - rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse - gotReply := false // ever saw a HEADERS reply gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout var t *time.Timer @@ -8457,9 +8820,7 @@ func (rl *http2clientConnReadLoop) run() error { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } if se, ok := err.(http2StreamError); ok { - if cs := cc.streamByID(se.StreamID, false); cs != nil { - cs.cc.writeStreamReset(cs.ID, se.Code, err) - cs.cc.forgetStreamID(cs.ID) + if cs := rl.streamByID(se.StreamID); cs != nil { if se.Cause == nil { se.Cause = cc.fr.errDetail } @@ -8467,6 +8828,7 @@ func (rl *http2clientConnReadLoop) run() error { } continue } else if err != nil { + cc.countReadFrameError(err) return err } if http2VerboseLogs { @@ -8479,22 +8841,16 @@ func (rl *http2clientConnReadLoop) run() error { } gotSettings = true } - maybeIdle := false // whether frame might transition us to idle switch f := f.(type) { case *http2MetaHeadersFrame: err = rl.processHeaders(f) - maybeIdle = true - gotReply = true case *http2DataFrame: err = rl.processData(f) - maybeIdle = true case *http2GoAwayFrame: err = rl.processGoAway(f) - maybeIdle = true case *http2RSTStreamFrame: err = rl.processResetStream(f) - maybeIdle = true case *http2SettingsFrame: err = rl.processSettings(f) case *http2PushPromiseFrame: @@ -8512,38 +8868,24 @@ func (rl *http2clientConnReadLoop) run() error { } return err } - if rl.closeWhenIdle && gotReply && maybeIdle { - cc.closeIfIdle() - } } } func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { - cc := rl.cc - cs := cc.streamByID(f.StreamID, false) + cs := rl.streamByID(f.StreamID) if cs == nil { // We'd get here if we canceled a request while the // server had its response still in flight. So if this // was just something we canceled, ignore it. return nil } - if f.StreamEnded() { - // Issue 20521: If the stream has ended, streamByID() causes - // clientStream.done to be closed, which causes the request's bodyWriter - // to be closed with an errStreamClosed, which may be received by - // clientConn.RoundTrip before the result of processing these headers. - // Deferring stream closure allows the header processing to occur first. - // clientConn.RoundTrip may still receive the bodyWriter error first, but - // the fix for issue 16102 prioritises any response. - // - // Issue 22413: If there is no request body, we should close the - // stream before writing to cs.resc so that the stream is closed - // immediately once RoundTrip returns. - if cs.req.Body != nil { - defer cc.forgetStreamID(f.StreamID) - } else { - cc.forgetStreamID(f.StreamID) - } + if cs.readClosed { + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: errors.New("protocol error: headers after END_STREAM"), + }) + return nil } if !cs.firstByte { if cs.trace != nil { @@ -8567,9 +8909,11 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro return err } // Any other error type is a stream error. - cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) - cc.forgetStreamID(cs.ID) - cs.resc <- http2resAndError{err: err} + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: err, + }) return nil // return nil from process* funcs to keep conn alive } if res == nil { @@ -8577,7 +8921,11 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro return nil } cs.resTrailer = &res.Trailer - cs.resc <- http2resAndError{res: res} + cs.res = res + close(cs.respHeaderRecv) + if f.StreamEnded() { + rl.endStream(cs) + } return nil } @@ -8639,6 +8987,9 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http } if statusCode >= 100 && statusCode <= 199 { + if f.StreamEnded() { + return nil, errors.New("1xx informational response with END_STREAM flag") + } cs.num1xx++ const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http if cs.num1xx > max1xxResponses { @@ -8651,42 +9002,49 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http } if statusCode == 100 { http2traceGot100Continue(cs.trace) - if cs.on100 != nil { - cs.on100() // forces any write delay timer to fire + select { + case cs.on100 <- struct{}{}: + default: } } cs.pastHeaders = false // do it all again return nil, nil } - streamEnded := f.StreamEnded() - isHead := cs.req.Method == "HEAD" - if !streamEnded || isHead { - res.ContentLength = -1 - if clens := res.Header["Content-Length"]; len(clens) == 1 { - if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { - res.ContentLength = int64(cl) - } else { - // TODO: care? unlike http/1, it won't mess up our framing, so it's - // more safe smuggling-wise to ignore. - } - } else if len(clens) > 1 { + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) + } else { // TODO: care? unlike http/1, it won't mess up our framing, so it's // more safe smuggling-wise to ignore. } + } else if len(clens) > 1 { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } else if f.StreamEnded() && !cs.isHead { + res.ContentLength = 0 } - if streamEnded || isHead { + if cs.isHead { res.Body = http2noBody return res, nil } - cs.bufPipe = http2pipe{b: &http2dataBuffer{expected: res.ContentLength}} + if f.StreamEnded() { + if res.ContentLength > 0 { + res.Body = http2missingBody{} + } else { + res.Body = http2noBody + } + return res, nil + } + + cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength}) cs.bytesRemain = res.ContentLength res.Body = http2transportResponseBody{cs} - go cs.awaitRequestCancel(cs.req) - if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 @@ -8725,8 +9083,7 @@ func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *htt } // transportResponseBody is the concrete type of Transport.RoundTrip's -// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body. -// On Close it sends RST_STREAM if EOF wasn't already seen. +// Response.Body. It is an io.ReadCloser. type http2transportResponseBody struct { cs *http2clientStream } @@ -8744,7 +9101,7 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { n = int(cs.bytesRemain) if err == nil { err = errors.New("net/http: server replied with more than declared Content-Length; truncated") - cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, err) + cs.abortStream(err) } cs.readErr = err return int(cs.bytesRemain), err @@ -8762,8 +9119,6 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { } cc.mu.Lock() - defer cc.mu.Unlock() - var connAdd, streamAdd int32 // Check the conn-level first, before the stream-level. if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { @@ -8780,6 +9135,8 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { cs.inflow.add(streamAdd) } } + cc.mu.Unlock() + if connAdd != 0 || streamAdd != 0 { cc.wmu.Lock() defer cc.wmu.Unlock() @@ -8800,34 +9157,45 @@ func (b http2transportResponseBody) Close() error { cs := b.cs cc := cs.cc - serverSentStreamEnd := cs.bufPipe.Err() == io.EOF unread := cs.bufPipe.Len() - - if unread > 0 || !serverSentStreamEnd { + if unread > 0 { cc.mu.Lock() - cc.wmu.Lock() - if !serverSentStreamEnd { - cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel) - cs.didReset = true - } // Return connection-level flow control. if unread > 0 { cc.inflow.add(int32(unread)) + } + cc.mu.Unlock() + + // TODO(dneil): Acquiring this mutex can block indefinitely. + // Move flow control return to a goroutine? + cc.wmu.Lock() + // Return connection-level flow control. + if unread > 0 { cc.fr.WriteWindowUpdate(0, uint32(unread)) } cc.bw.Flush() cc.wmu.Unlock() - cc.mu.Unlock() } cs.bufPipe.BreakWithError(http2errClosedResponseBody) - cc.forgetStreamID(cs.ID) + cs.abortStream(http2errClosedResponseBody) + + select { + case <-cs.donec: + case <-cs.ctx.Done(): + // See golang/go#49366: The net/http package can cancel the + // request context after the response body is fully read. + // Don't treat this as an error. + return nil + case <-cs.reqCancel: + return http2errRequestCanceled + } return nil } func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, f.StreamEnded()) + cs := rl.streamByID(f.StreamID) data := f.Data() if cs == nil { cc.mu.Lock() @@ -8856,6 +9224,14 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { } return nil } + if cs.readClosed { + cc.logf("protocol error: received DATA after END_STREAM") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } if !cs.firstByte { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, http2StreamError{ @@ -8865,7 +9241,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { return nil } if f.Length > 0 { - if cs.req.Method == "HEAD" && len(data) > 0 { + if cs.isHead && len(data) > 0 { cc.logf("protocol error: received DATA on a HEAD request") rl.endStreamError(cs, http2StreamError{ StreamID: f.StreamID, @@ -8887,30 +9263,39 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { if pad := int(f.Length) - len(data); pad > 0 { refund += pad } - // Return len(data) now if the stream is already closed, - // since data will never be read. - didReset := cs.didReset - if didReset { - refund += len(data) + + didReset := false + var err error + if len(data) > 0 { + if _, err = cs.bufPipe.Write(data); err != nil { + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset = true + refund += len(data) + } } + if refund > 0 { cc.inflow.add(int32(refund)) + if !didReset { + cs.inflow.add(int32(refund)) + } + } + cc.mu.Unlock() + + if refund > 0 { cc.wmu.Lock() cc.fr.WriteWindowUpdate(0, uint32(refund)) if !didReset { - cs.inflow.add(int32(refund)) cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) } cc.bw.Flush() cc.wmu.Unlock() } - cc.mu.Unlock() - if len(data) > 0 && !didReset { - if _, err := cs.bufPipe.Write(data); err != nil { - rl.endStreamError(cs, err) - return err - } + if err != nil { + rl.endStreamError(cs, err) + return nil } } @@ -8923,24 +9308,32 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { // TODO: check that any declared content-length matches, like // server.go's (*stream).endStream method. - rl.endStreamError(cs, nil) + if !cs.readClosed { + cs.readClosed = true + // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a + // race condition: The caller can read io.EOF from Response.Body + // and close the body before we close cs.peerClosed, causing + // cleanupWriteRequest to send a RST_STREAM. + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers) + close(cs.peerClosed) + } } func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { - var code func() - if err == nil { - err = io.EOF - code = cs.copyTrailers - } - if http2isConnectionCloseRequest(cs.req) { - rl.closeWhenIdle = true - } - cs.bufPipe.closeWithErrorAndCode(err, code) + cs.readAborted = true + cs.abortStream(err) +} - select { - case cs.resc <- http2resAndError{err: err}: - default: +func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs := rl.cc.streams[id] + if cs != nil && !cs.readAborted { + return cs } + return nil } func (cs *http2clientStream) copyTrailers() { @@ -8959,6 +9352,10 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { if f.ErrCode != 0 { // TODO: deal with GOAWAY more. particularly the error code cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + if fn := cc.t.CountError; fn != nil { + fn("recv_goaway_" + f.ErrCode.stringToken()) + } + } cc.setGoAway(f) return nil @@ -8966,6 +9363,23 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { cc := rl.cc + // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. + // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. + cc.wmu.Lock() + defer cc.wmu.Unlock() + + if err := rl.processSettingsNoWrite(f); err != nil { + return err + } + if !f.IsAck() { + cc.fr.WriteSettingsAck() + cc.bw.Flush() + } + return nil +} + +func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error { + cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() @@ -8977,12 +9391,14 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error return http2ConnectionError(http2ErrCodeProtocol) } + var seenMaxConcurrentStreams bool err := f.ForeachSetting(func(s http2Setting) error { switch s.ID { case http2SettingMaxFrameSize: cc.maxFrameSize = s.Val case http2SettingMaxConcurrentStreams: cc.maxConcurrentStreams = s.Val + seenMaxConcurrentStreams = true case http2SettingMaxHeaderListSize: cc.peerMaxHeaderListSize = uint64(s.Val) case http2SettingInitialWindowSize: @@ -9014,17 +9430,23 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error return err } - cc.wmu.Lock() - defer cc.wmu.Unlock() + if !cc.seenSettings { + if !seenMaxConcurrentStreams { + // This was the servers initial SETTINGS frame and it + // didn't contain a MAX_CONCURRENT_STREAMS field so + // increase the number of concurrent streams this + // connection can establish to our default. + cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams + } + cc.seenSettings = true + } - cc.fr.WriteSettingsAck() - cc.bw.Flush() - return cc.werr + return nil } func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, false) + cs := rl.streamByID(f.StreamID) if f.StreamID != 0 && cs == nil { return nil } @@ -9044,24 +9466,22 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame } func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { - cs := rl.cc.streamByID(f.StreamID, true) + cs := rl.streamByID(f.StreamID) if cs == nil { - // TODO: return error if server tries to RST_STEAM an idle stream + // TODO: return error if server tries to RST_STREAM an idle stream return nil } - select { - case <-cs.peerReset: - // Already reset. - // This is the only goroutine - // which closes this, so there - // isn't a race. - default: - err := http2streamError(cs.ID, f.ErrCode) - cs.resetErr = err - close(cs.peerReset) - cs.bufPipe.CloseWithError(err) - cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl + serr := http2streamError(cs.ID, f.ErrCode) + serr.Cause = http2errFromPeer + if f.ErrCode == http2ErrCodeProtocol { + rl.cc.SetDoNotReuse() } + if fn := cs.cc.t.CountError; fn != nil { + fn("recv_rststream_" + f.ErrCode.stringToken()) + } + cs.abortStream(serr) + + cs.bufPipe.CloseWithError(serr) return nil } @@ -9083,19 +9503,24 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - cc.wmu.Lock() - if err := cc.fr.WritePing(false, p); err != nil { - cc.wmu.Unlock() - return err - } - if err := cc.bw.Flush(); err != nil { - cc.wmu.Unlock() - return err - } - cc.wmu.Unlock() + errc := make(chan error, 1) + go func() { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(false, p); err != nil { + errc <- err + return + } + if err := cc.bw.Flush(); err != nil { + errc <- err + return + } + }() select { case <-c: return nil + case err := <-errc: + return err case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: @@ -9172,6 +9597,12 @@ func (t *http2Transport) logf(format string, args ...interface{}) { var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) +type http2missingBody struct{} + +func (http2missingBody) Close() error { return nil } + +func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } + func http2strSliceContains(ss []string, s string) bool { for _, v := range ss { if v == s { @@ -9218,87 +9649,6 @@ type http2errorReader struct{ err error } func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } -// bodyWriterState encapsulates various state around the Transport's writing -// of the request body, particularly regarding doing delayed writes of the body -// when the request contains "Expect: 100-continue". -type http2bodyWriterState struct { - cs *http2clientStream - timer *time.Timer // if non-nil, we're doing a delayed write - fnonce *sync.Once // to call fn with - fn func() // the code to run in the goroutine, writing the body - resc chan error // result of fn's execution - delay time.Duration // how long we should delay a delayed write for -} - -func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) { - s.cs = cs - if body == nil { - return - } - resc := make(chan error, 1) - s.resc = resc - s.fn = func() { - cs.cc.mu.Lock() - cs.startedWrite = true - cs.cc.mu.Unlock() - resc <- cs.writeRequestBody(body, cs.req.Body) - } - s.delay = t.expectContinueTimeout() - if s.delay == 0 || - !httpguts.HeaderValuesContainsToken( - cs.req.Header["Expect"], - "100-continue") { - return - } - s.fnonce = new(sync.Once) - - // Arm the timer with a very large duration, which we'll - // intentionally lower later. It has to be large now because - // we need a handle to it before writing the headers, but the - // s.delay value is defined to not start until after the - // request headers were written. - const hugeDuration = 365 * 24 * time.Hour - s.timer = time.AfterFunc(hugeDuration, func() { - s.fnonce.Do(s.fn) - }) - return -} - -func (s http2bodyWriterState) cancel() { - if s.timer != nil { - if s.timer.Stop() { - s.resc <- nil - } - } -} - -func (s http2bodyWriterState) on100() { - if s.timer == nil { - // If we didn't do a delayed write, ignore the server's - // bogus 100 continue response. - return - } - s.timer.Stop() - go func() { s.fnonce.Do(s.fn) }() -} - -// scheduleBodyWrite starts writing the body, either immediately (in -// the common case) or after the delay timeout. It should not be -// called until after the headers have been written. -func (s http2bodyWriterState) scheduleBodyWrite() { - if s.timer == nil { - // We're not doing a delayed write (see - // getBodyWriterState), so just start the writing - // goroutine immediately. - go s.fn() - return - } - http2traceWait100Continue(s.cs.trace) - if s.timer.Stop() { - s.timer.Reset(s.delay) - } -} - // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. func http2isConnectionCloseRequest(req *Request) bool { @@ -9775,7 +10125,8 @@ type http2WriteScheduler interface { // Pop dequeues the next frame to write. Returns false if no frames can // be written. Frames with a given wr.StreamID() are Pop'd in the same - // order they are Push'd. No frames should be discarded except by CloseStream. + // order they are Push'd, except RST_STREAM frames. No frames should be + // discarded except by CloseStream. Pop() (wr http2FrameWriteRequest, ok bool) } @@ -9795,6 +10146,7 @@ type http2FrameWriteRequest struct { // stream is the stream on which this frame will be written. // nil for non-stream frames like PING and SETTINGS. + // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. stream *http2stream // done, if non-nil, must be a buffered channel with space for @@ -10474,11 +10826,11 @@ func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http } func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { - id := wr.StreamID() - if id == 0 { + if wr.isControl() { ws.zero.push(wr) return } + id := wr.StreamID() q, ok := ws.sq[id] if !ok { q = ws.queuePool.get() @@ -10488,7 +10840,7 @@ func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { } func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { - // Control frames first. + // Control and RST_STREAM frames first. if !ws.zero.empty() { return ws.zero.shift(), true } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 4c72dcb..6487e50 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -13,6 +13,8 @@ import ( "strings" "sync" "time" + + "golang.org/x/net/http/httpguts" ) // A Header represents the key-value pairs in an HTTP header. @@ -155,7 +157,7 @@ func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kv func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } var headerSorterPool = sync.Pool{ - New: func() interface{} { return new(headerSorter) }, + New: func() any { return new(headerSorter) }, } // sortedKeyValues returns h's keys sorted in the returned kvs @@ -192,6 +194,13 @@ func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptra kvs, sorter := h.sortedKeyValues(exclude) var formattedVals []string for _, kv := range kvs { + if !httpguts.ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) v = textproto.TrimString(v) diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index ad8ab9b..575493b 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -89,6 +89,19 @@ var headerWriteTests = []struct { "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", }, + // Tests invalid characters in headers. + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "NewlineInValue": {"1\r\nBar: 2"}, + "NewlineInKey\r\n": {"1"}, + "Colon:InKey": {"1"}, + "Evil: 1\r\nSmuggledValue": {"1"}, + }, + nil, + "Content-Type: text/html; charset=UTF-8\r\n" + + "NewlineInValue: 1 Bar: 2\r\n", + }, } func TestHeaderWrite(t *testing.T) { diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go index 5777c91..6af30f7 100644 --- a/libgo/go/net/http/httptrace/trace.go +++ b/libgo/go/net/http/httptrace/trace.go @@ -50,7 +50,7 @@ func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { } } if trace.DNSDone != nil { - nt.DNSDone = func(netIPs []interface{}, coalesced bool, err error) { + nt.DNSDone = func(netIPs []any, coalesced bool, err error) { addrs := make([]net.IPAddr, len(netIPs)) for i, ip := range netIPs { addrs[i] = ip.(net.IPAddr) diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index 2948f27..d7baecd 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -292,7 +292,7 @@ func DumpRequest(req *http.Request, body bool) ([]byte, error) { // can detect that the lack of body was intentional. var errNoBody = errors.New("sentinel error value") -// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// failureToReadBody is an io.ReadCloser that just returns errNoBody on // Read. It's swapped in when we don't actually want to consume // the body, but need a non-nil one, and want to distinguish the // error from reading the dummy body. diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 366cc82..5df2ee8 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -31,7 +31,7 @@ type dumpTest struct { Req *http.Request GetReq func() *http.Request - Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + Body any // optional []byte or func() io.ReadCloser to populate Req.Body WantDump string WantDumpOut string diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 8b63368..319e2a3 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "mime" "net" "net/http" "net/http/internal/ascii" @@ -412,7 +413,7 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { // For Server-Sent Events responses, flush immediately. // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream - if resCT == "text/event-stream" { + if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" { return -1 // negative means immediately } @@ -483,7 +484,7 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int } } -func (p *ReverseProxy) logf(format string, args ...interface{}) { +func (p *ReverseProxy) logf(format string, args ...any) { if p.ErrorLog != nil { p.ErrorLog.Printf(format, args...) } else { diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 4b6ad77..90e8903 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -1195,6 +1195,26 @@ func TestSelectFlushInterval(t *testing.T) { want: -1, }, { + name: "server-sent events with media-type parameters overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { name: "Content-Length: -1, overrides non-zero", res: &http.Response{ ContentLength: -1, diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go index f06e572..37a72e9 100644 --- a/libgo/go/net/http/internal/chunked.go +++ b/libgo/go/net/http/internal/chunked.go @@ -81,6 +81,11 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.err = errors.New("malformed chunked encoding") break } + } else { + if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + break } cr.checkEnd = false } @@ -109,6 +114,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // bytes to verify they are "\r\n". if cr.n == 0 && cr.err == nil { cr.checkEnd = true + } else if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF } } return n, cr.err @@ -152,6 +159,8 @@ func isASCIISpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' } +var semi = []byte(";") + // removeChunkExtension removes any chunk-extension from p. // For example, // "0" => "0" @@ -159,14 +168,11 @@ func isASCIISpace(b byte) bool { // "0;token=val" => "0" // `0;token="quoted string"` => "0" func removeChunkExtension(p []byte) ([]byte, error) { - semi := bytes.IndexByte(p, ';') - if semi == -1 { - return p, nil - } + p, _, _ = bytes.Cut(p, semi) // TODO: care about exact syntax of chunk extensions? We're // ignoring and stripping them anyway. For now just never // return an error. - return p[:semi], nil + return p, nil } // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index 08152ed..5e29a78 100644 --- a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -11,6 +11,7 @@ import ( "io" "strings" "testing" + "testing/iotest" ) func TestChunk(t *testing.T) { @@ -211,3 +212,30 @@ func TestChunkReadPartial(t *testing.T) { } } + +// Issue 48861: ChunkedReader should report incomplete chunks +func TestIncompleteChunk(t *testing.T) { + const valid = "4\r\nabcd\r\n" + "5\r\nabc\r\n\r\n" + "0\r\n" + + for i := 0; i < len(valid); i++ { + incomplete := valid[:i] + r := NewChunkedReader(strings.NewReader(incomplete)) + if _, err := io.ReadAll(r); err != io.ErrUnexpectedEOF { + t.Errorf("expected io.ErrUnexpectedEOF for %q, got %v", incomplete, err) + } + } + + r := NewChunkedReader(strings.NewReader(valid)) + if _, err := io.ReadAll(r); err != nil { + t.Errorf("unexpected error for %q: %v", valid, err) + } +} + +func TestChunkEndReadError(t *testing.T) { + readErr := fmt.Errorf("chunk end read error") + + r := NewChunkedReader(io.MultiReader(strings.NewReader("4\r\nabcd"), iotest.ErrReader(readErr))) + if _, err := io.ReadAll(r); err != readErr { + t.Errorf("expected %v, got %v", readErr, err) + } +} diff --git a/libgo/go/net/http/internal/testcert/testcert.go b/libgo/go/net/http/internal/testcert/testcert.go index 5f94704..d510e79 100644 --- a/libgo/go/net/http/internal/testcert/testcert.go +++ b/libgo/go/net/http/internal/testcert/testcert.go @@ -10,37 +10,56 @@ import "strings" // LocalhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. // generated from src/crypto/tls: -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MIIDOTCCAiGgAwIBAgIQSRJrEpBGFc7tNb1fb5pKFzANBgkqhkiG9w0BAQsFADAS MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw -MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB -iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 -iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul -rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO -BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw -AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA -AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 -tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs -h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM -fblo6RBxUQ== +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA6Gba5tHV1dAKouAaXO3/ebDUU4rvwCUg/CNaJ2PT5xLD4N1Vcb8r +bFSW2HXKq+MPfVdwIKR/1DczEoAGf/JWQTW7EgzlXrCd3rlajEX2D73faWJekD0U +aUgz5vtrTXZ90BQL7WvRICd7FlEZ6FPOcPlumiyNmzUqtwGhO+9ad1W5BqJaRI6P +YfouNkwR6Na4TzSj5BrqUfP0FwDizKSJ0XXmh8g8G9mtwxOSN3Ru1QFc61Xyeluk +POGKBV/q6RBNklTNe0gI8usUMlYyoC7ytppNMW7X2vodAelSu25jgx2anj9fDVZu +h7AXF5+4nJS4AAt0n1lNY7nGSsdZas8PbQIDAQABo4GIMIGFMA4GA1UdDwEB/wQE +AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud +DgQWBBStsdjh3/JCXXYlQryOrL4Sh7BW5TAuBgNVHREEJzAlggtleGFtcGxlLmNv +bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAxWGI +5NhpF3nwwy/4yB4i/CwwSpLrWUa70NyhvprUBC50PxiXav1TeDzwzLx/o5HyNwsv +cxv3HdkLW59i/0SlJSrNnWdfZ19oTcS+6PtLoVyISgtyN6DpkKpdG1cOkW3Cy2P2 ++tK/tKHRP1Y/Ra0RiDpOAmqn0gCOFGz8+lqDIor/T7MTpibL3IxqWfPrvfVRHL3B +grw/ZQTTIVjjh4JBSW3WyWgNo/ikC1lrVxzl4iPUGptxT36Cr7Zk2Bsg0XqwbOvK +5d+NTDREkSnUbie4GeutujmX3Dsx88UiV6UY/4lHJa6I5leHUNOHahRbpbWeOfs/ +WkBKOclmOV2xlTVuPw== -----END CERTIFICATE-----`) // LocalhostKey is the private key for LocalhostCert. var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY----- -MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 -SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB -l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB -AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet -3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb -uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H -qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp -jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY -fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U -fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU -y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX -qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo -f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDoZtrm0dXV0Aqi +4Bpc7f95sNRTiu/AJSD8I1onY9PnEsPg3VVxvytsVJbYdcqr4w99V3AgpH/UNzMS +gAZ/8lZBNbsSDOVesJ3euVqMRfYPvd9pYl6QPRRpSDPm+2tNdn3QFAvta9EgJ3sW +URnoU85w+W6aLI2bNSq3AaE771p3VbkGolpEjo9h+i42TBHo1rhPNKPkGupR8/QX +AOLMpInRdeaHyDwb2a3DE5I3dG7VAVzrVfJ6W6Q84YoFX+rpEE2SVM17SAjy6xQy +VjKgLvK2mk0xbtfa+h0B6VK7bmODHZqeP18NVm6HsBcXn7iclLgAC3SfWU1jucZK +x1lqzw9tAgMBAAECggEABWzxS1Y2wckblnXY57Z+sl6YdmLV+gxj2r8Qib7g4ZIk +lIlWR1OJNfw7kU4eryib4fc6nOh6O4AWZyYqAK6tqNQSS/eVG0LQTLTTEldHyVJL +dvBe+MsUQOj4nTndZW+QvFzbcm2D8lY5n2nBSxU5ypVoKZ1EqQzytFcLZpTN7d89 +EPj0qDyrV4NZlWAwL1AygCwnlwhMQjXEalVF1ylXwU3QzyZ/6MgvF6d3SSUlh+sq +XefuyigXw484cQQgbzopv6niMOmGP3of+yV4JQqUSb3IDmmT68XjGd2Dkxl4iPki +6ZwXf3CCi+c+i/zVEcufgZ3SLf8D99kUGE7v7fZ6AQKBgQD1ZX3RAla9hIhxCf+O +3D+I1j2LMrdjAh0ZKKqwMR4JnHX3mjQI6LwqIctPWTU8wYFECSh9klEclSdCa64s +uI/GNpcqPXejd0cAAdqHEEeG5sHMDt0oFSurL4lyud0GtZvwlzLuwEweuDtvT9cJ +Wfvl86uyO36IW8JdvUprYDctrQKBgQDycZ697qutBieZlGkHpnYWUAeImVA878sJ +w44NuXHvMxBPz+lbJGAg8Cn8fcxNAPqHIraK+kx3po8cZGQywKHUWsxi23ozHoxo ++bGqeQb9U661TnfdDspIXia+xilZt3mm5BPzOUuRqlh4Y9SOBpSWRmEhyw76w4ZP +OPxjWYAgwQKBgA/FehSYxeJgRjSdo+MWnK66tjHgDJE8bYpUZsP0JC4R9DL5oiaA +brd2fI6Y+SbyeNBallObt8LSgzdtnEAbjIH8uDJqyOmknNePRvAvR6mP4xyuR+Bv +m+Lgp0DMWTw5J9CKpydZDItc49T/mJ5tPhdFVd+am0NAQnmr1MCZ6nHxAoGABS3Y +LkaC9FdFUUqSU8+Chkd/YbOkuyiENdkvl6t2e52jo5DVc1T7mLiIrRQi4SI8N9bN +/3oJWCT+uaSLX2ouCtNFunblzWHBrhxnZzTeqVq4SLc8aESAnbslKL4i8/+vYZlN +s8xtiNcSvL+lMsOBORSXzpj/4Ot8WwTkn1qyGgECgYBKNTypzAHeLE6yVadFp3nQ +Ckq9yzvP/ib05rvgbvrne00YeOxqJ9gtTrzgh7koqJyX1L4NwdkEza4ilDWpucn0 +xiUZS4SoaJq6ZvcBYS62Yr1t8n09iG47YL8ibgtmH3L+svaotvpVxVK+d7BLevA/ +ZboOWVe3icTy64BT3OQhmg== -----END RSA TESTING KEY-----`)) func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index 6564627..27872b4 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -31,11 +31,8 @@ func interestingGoroutines() (gs []string) { buf := make([]byte, 2<<20) buf = buf[:runtime.Stack(buf, true)] for _, g := range strings.Split(string(buf), "\n\n") { - sl := strings.SplitN(g, "\n", 2) - if len(sl) != 2 { - continue - } - stack := strings.TrimSpace(sl[1]) + _, stack, _ := strings.Cut(g, "\n") + stack = strings.TrimSpace(stack) if stack == "" || strings.Contains(stack, "testing.(*M).before.func1") || strings.Contains(stack, "os/signal.signal_recv") || @@ -46,7 +43,7 @@ func interestingGoroutines() (gs []string) { // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") || - strings.Contains(stack, "net/http_test.interestingGoroutines") || + strings.Contains(stack, "interestingGoroutines") || strings.Contains(stack, "runtime.MHeap_Scavenger") { continue } diff --git a/libgo/go/net/http/omithttp2.go b/libgo/go/net/http/omithttp2.go index 79599d0..3316f55 100644 --- a/libgo/go/net/http/omithttp2.go +++ b/libgo/go/net/http/omithttp2.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build nethttpomithttp2 -// +build nethttpomithttp2 package http @@ -27,7 +26,7 @@ const http2NextProtoTLS = "h2" type http2Transport struct { MaxHeaderListSize uint32 - ConnPool interface{} + ConnPool any } func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } @@ -57,9 +56,9 @@ type http2Server struct { NewWriteScheduler func() http2WriteScheduler } -type http2WriteScheduler interface{} +type http2WriteScheduler any -func http2NewPriorityWriteScheduler(interface{}) http2WriteScheduler { panic(noHTTP2) } +func http2NewPriorityWriteScheduler(any) http2WriteScheduler { panic(noHTTP2) } func http2ConfigureServer(s *Server, conf *http2Server) error { panic(noHTTP2) } diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 888ea35..dc855c8 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -44,7 +44,7 @@ // The package also exports a handler that serves execution trace data // for the "go tool trace" command. To collect a 5-second execution trace: // -// wget -O trace.out http://localhost:6060/debug/pprof/trace?seconds=5 +// curl -o trace.out http://localhost:6060/debug/pprof/trace?seconds=5 // go tool trace trace.out // // To view all available profiles, open http://localhost:6060/debug/pprof/ diff --git a/libgo/go/net/http/pprof/pprof_test.go b/libgo/go/net/http/pprof/pprof_test.go index 84757e4..1a4d653 100644 --- a/libgo/go/net/http/pprof/pprof_test.go +++ b/libgo/go/net/http/pprof/pprof_test.go @@ -8,6 +8,7 @@ import ( "bytes" "fmt" "internal/profile" + "internal/testenv" "io" "net/http" "net/http/httptest" @@ -152,6 +153,10 @@ func mutexHog(duration time.Duration, hogger func(mu1, mu2 *sync.Mutex, start ti } func TestDeltaProfile(t *testing.T) { + if runtime.GOOS == "openbsd" && runtime.GOARCH == "arm" { + testenv.SkipFlaky(t, 50218) + } + rate := runtime.SetMutexProfileFraction(1) defer func() { runtime.SetMutexProfileFraction(rate) diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 09cb0c7..76c2317 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -779,11 +779,10 @@ func removeZone(host string) string { return host[:j] + host[i:] } -// ParseHTTPVersion parses an HTTP version string. +// ParseHTTPVersion parses an HTTP version string according to RFC 7230, section 2.6. // "HTTP/1.0" returns (1, 0, true). Note that strings without // a minor version, such as "HTTP/2", are not valid. func ParseHTTPVersion(vers string) (major, minor int, ok bool) { - const Big = 1000000 // arbitrary upper bound switch vers { case "HTTP/1.1": return 1, 1, true @@ -793,19 +792,21 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) { if !strings.HasPrefix(vers, "HTTP/") { return 0, 0, false } - dot := strings.Index(vers, ".") - if dot < 0 { + if len(vers) != len("HTTP/X.Y") { return 0, 0, false } - major, err := strconv.Atoi(vers[5:dot]) - if err != nil || major < 0 || major > Big { + if vers[6] != '.' { return 0, 0, false } - minor, err = strconv.Atoi(vers[dot+1:]) - if err != nil || minor < 0 || minor > Big { + maj, err := strconv.ParseUint(vers[5:6], 10, 0) + if err != nil { return 0, 0, false } - return major, minor, true + min, err := strconv.ParseUint(vers[7:8], 10, 0) + if err != nil { + return 0, 0, false + } + return int(maj), int(min), true } func validMethod(method string) bool { @@ -939,7 +940,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read func (r *Request) BasicAuth() (username, password string, ok bool) { auth := r.Header.Get("Authorization") if auth == "" { - return + return "", "", false } return parseBasicAuth(auth) } @@ -950,18 +951,18 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { const prefix = "Basic " // Case insensitive prefix match. See Issue 22736. if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) { - return + return "", "", false } c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) if err != nil { - return + return "", "", false } cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - return + username, password, ok = strings.Cut(cs, ":") + if !ok { + return "", "", false } - return cs[:s], cs[s+1:], true + return username, password, true } // SetBasicAuth sets the request's Authorization header to use HTTP @@ -979,13 +980,12 @@ func (r *Request) SetBasicAuth(username, password string) { // parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { - s1 := strings.Index(line, " ") - s2 := strings.Index(line[s1+1:], " ") - if s1 < 0 || s2 < 0 { - return + method, rest, ok1 := strings.Cut(line, " ") + requestURI, proto, ok2 := strings.Cut(rest, " ") + if !ok1 || !ok2 { + return "", "", "", false } - s2 += s1 + 1 - return line[:s1], line[s1+1 : s2], line[s2+1:], true + return method, requestURI, proto, true } var textprotoReaderPool sync.Pool diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 4e0c4ba..4363e11 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -639,10 +639,10 @@ var parseHTTPVersionTests = []struct { major, minor int ok bool }{ + {"HTTP/0.0", 0, 0, true}, {"HTTP/0.9", 0, 9, true}, {"HTTP/1.0", 1, 0, true}, {"HTTP/1.1", 1, 1, true}, - {"HTTP/3.14", 3, 14, true}, {"HTTP", 0, 0, false}, {"HTTP/one.one", 0, 0, false}, @@ -651,6 +651,12 @@ var parseHTTPVersionTests = []struct { {"HTTP/0,-1", 0, 0, false}, {"HTTP/", 0, 0, false}, {"HTTP/1,1", 0, 0, false}, + {"HTTP/+1.1", 0, 0, false}, + {"HTTP/1.+1", 0, 0, false}, + {"HTTP/0000000001.1", 0, 0, false}, + {"HTTP/1.0000000001", 0, 0, false}, + {"HTTP/3.14", 0, 0, false}, + {"HTTP/12.3", 0, 0, false}, } func TestParseHTTPVersion(t *testing.T) { diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index 1157bdf..bdc1e3c 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -20,7 +20,7 @@ import ( type reqWriteTest struct { Req Request - Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + Body any // optional []byte or func() io.ReadCloser to populate Req.Body // Any of these three may be empty to skip that test. WantWrite string // Request.Write diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index b8985da..297394e 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -165,16 +165,14 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { } return nil, err } - if i := strings.IndexByte(line, ' '); i == -1 { + proto, status, ok := strings.Cut(line, " ") + if !ok { return nil, badStringError("malformed HTTP response", line) - } else { - resp.Proto = line[:i] - resp.Status = strings.TrimLeft(line[i+1:], " ") - } - statusCode := resp.Status - if i := strings.IndexByte(resp.Status, ' '); i != -1 { - statusCode = resp.Status[:i] } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") if len(statusCode) != 3 { return nil, badStringError("malformed HTTP status code", statusCode) } @@ -182,7 +180,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { if err != nil || resp.StatusCode < 0 { return nil, badStringError("malformed HTTP status code", statusCode) } - var ok bool if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { return nil, badStringError("malformed HTTP version", resp.Proto) } diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 8eef654..5a735b0 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -646,8 +646,8 @@ type readerAndCloser struct { func TestReadResponseCloseInMiddle(t *testing.T) { t.Parallel() for _, test := range readResponseCloseInMiddleTests { - fatalf := func(format string, args ...interface{}) { - args = append([]interface{}{test.chunked, test.compressed}, args...) + fatalf := func(format string, args ...any) { + args = append([]any{test.chunked, test.compressed}, args...) t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) } checkErr := func(err error, msg string) { @@ -732,7 +732,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { } } -func diff(t *testing.T, prefix string, have, want interface{}) { +func diff(t *testing.T, prefix string, have, want any) { t.Helper() hv := reflect.ValueOf(have).Elem() wv := reflect.ValueOf(want).Elem() @@ -849,10 +849,10 @@ func TestReadResponseErrors(t *testing.T) { type testCase struct { name string // optional, defaults to in in string - wantErr interface{} // nil, err value, or string substring + wantErr any // nil, err value, or string substring } - status := func(s string, wantErr interface{}) testCase { + status := func(s string, wantErr any) testCase { if wantErr == true { wantErr = "malformed HTTP status code" } @@ -863,7 +863,7 @@ func TestReadResponseErrors(t *testing.T) { } } - version := func(s string, wantErr interface{}) testCase { + version := func(s string, wantErr any) testCase { if wantErr == true { wantErr = "malformed HTTP version" } @@ -874,7 +874,7 @@ func TestReadResponseErrors(t *testing.T) { } } - contentLength := func(status, body string, wantErr interface{}) testCase { + contentLength := func(status, body string, wantErr any) testCase { return testCase{ name: fmt.Sprintf("status %q %q", status, body), in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body), @@ -947,7 +947,7 @@ func TestReadResponseErrors(t *testing.T) { // wantErr can be nil, an error value to match exactly, or type string to // match a substring. -func matchErr(err error, wantErr interface{}) error { +func matchErr(err error, wantErr any) error { if err == nil { if wantErr == nil { return nil diff --git a/libgo/go/net/http/roundtrip.go b/libgo/go/net/http/roundtrip.go index eef7c79..c4c5d3b 100644 --- a/libgo/go/net/http/roundtrip.go +++ b/libgo/go/net/http/roundtrip.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js || !wasm -// +build !js !wasm package http diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go index 74c83a9..01c0600 100644 --- a/libgo/go/net/http/roundtrip_js.go +++ b/libgo/go/net/http/roundtrip_js.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package http @@ -41,11 +40,19 @@ const jsFetchCreds = "js.fetch:credentials" // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters const jsFetchRedirect = "js.fetch:redirect" -var useFakeNetwork = js.Global().Get("fetch").IsUndefined() +// jsFetchMissing will be true if the Fetch API is not present in +// the browser globals. +var jsFetchMissing = js.Global().Get("fetch").IsUndefined() // RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *Request) (*Response, error) { - if useFakeNetwork { + // The Transport has a documented contract that states that if the DialContext or + // DialTLSContext functions are set, they will be used to set up the connections. + // If they aren't set then the documented contract is to use Dial or DialTLS, even + // though they are deprecated. Therefore, if any of these are set, we should obey + // the contract and dial using the regular round-trip instead. Otherwise, we'll try + // to fall back on the Fetch API, unless it's not available. + if t.Dial != nil || t.DialContext != nil || t.DialTLS != nil || t.DialTLSContext != nil || jsFetchMissing { return t.roundTrip(req) } @@ -111,7 +118,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { errCh = make(chan error, 1) success, failure js.Func ) - success = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() @@ -131,8 +138,24 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } contentLength := int64(0) - if cl, err := strconv.ParseInt(header.Get("Content-Length"), 10, 64); err == nil { + clHeader := header.Get("Content-Length") + switch { + case clHeader != "": + cl, err := strconv.ParseInt(clHeader, 10, 64) + if err != nil { + errCh <- fmt.Errorf("net/http: ill-formed Content-Length header: %v", err) + return nil + } + if cl < 0 { + // Content-Length values less than 0 are invalid. + // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 + errCh <- fmt.Errorf("net/http: invalid Content-Length header: %q", clHeader) + return nil + } contentLength = cl + default: + // If the response length is not declared, set it to -1. + contentLength = -1 } b := result.Get("body") @@ -159,7 +182,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return nil }) - failure = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String()) @@ -200,7 +223,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) { bCh = make(chan []byte, 1) errCh = make(chan error, 1) ) - success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success := js.FuncOf(func(this js.Value, args []js.Value) any { result := args[0] if result.Get("done").Bool() { errCh <- io.EOF @@ -212,7 +235,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) { return nil }) defer success.Release() - failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure := js.FuncOf(func(this js.Value, args []js.Value) any { // Assumes it's a TypeError. See // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError // for more information on this type. See @@ -266,7 +289,7 @@ func (r *arrayReader) Read(p []byte) (n int, err error) { bCh = make(chan []byte, 1) errCh = make(chan error, 1) ) - success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success := js.FuncOf(func(this js.Value, args []js.Value) any { // Wrap the input ArrayBuffer with a Uint8Array uint8arrayWrapper := uint8Array.New(args[0]) value := make([]byte, uint8arrayWrapper.Get("byteLength").Int()) @@ -275,7 +298,7 @@ func (r *arrayReader) Read(p []byte) (n int, err error) { return nil }) defer success.Release() - failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure := js.FuncOf(func(this js.Value, args []js.Value) any { // Assumes it's a TypeError. See // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError // for more information on this type. diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 6394da3..fb18cb2 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -23,6 +23,7 @@ import ( "net" . "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/httputil" "net/http/internal" "net/http/internal/testcert" @@ -2146,7 +2147,7 @@ func TestInvalidTrailerClosesConnection(t *testing.T) { // Read and Write. type slowTestConn struct { // over multiple calls to Read, time.Durations are slept, strings are read. - script []interface{} + script []any closec chan bool mu sync.Mutex // guards rd/wd @@ -2238,7 +2239,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := &slowTestConn{ - script: []interface{}{ + script: []any{ "POST /public HTTP/1.1\r\n" + "Host: test\r\n" + "Content-Length: 10000\r\n" + @@ -2273,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { } } +// cancelableTimeoutContext overwrites the error message to DeadlineExceeded +type cancelableTimeoutContext struct { + context.Context +} + +func (c cancelableTimeoutContext) Err() error { + if c.Context.Err() != nil { + return context.DeadlineExceeded + } + return nil +} + func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { @@ -2285,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h2, h) defer cst.close() // Succeed without timing out: @@ -2307,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2428,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h1Mode, h) defer cst.close() // Succeed without timing out: @@ -2450,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2500,6 +2517,47 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } +func TestTimeoutHandlerContextCanceled(t *testing.T) { + setParallel(t) + defer afterTest(t) + writeErrors := make(chan error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/plain") + var err error + // The request context has already been canceled, but + // retry the write for a while to give the timeout handler + // a chance to notice. + for i := 0; i < 100; i++ { + _, err = w.Write([]byte("a")) + if err != nil { + break + } + time.Sleep(1 * time.Millisecond) + } + writeErrors <- err + }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + h := NewTestTimeoutHandler(sayHi, ctx) + cst := newClientServerTest(t, h1Mode, h) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := io.ReadAll(res.Body) + if g, e := string(body), ""; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := <-writeErrors, context.Canceled; g != e { + t.Errorf("got unexpected Write in handler: %v, want %g", g, e) + } +} + // https://golang.org/issue/15948 func TestTimeoutHandlerEmptyResponse(t *testing.T) { setParallel(t) @@ -2708,7 +2766,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) { +func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { defer afterTest(t) // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three @@ -3017,22 +3075,14 @@ func TestClientWriteShutdown(t *testing.T) { if err != nil { t.Fatalf("CloseWrite: %v", err) } - donec := make(chan bool) - go func() { - defer close(donec) - bs, err := io.ReadAll(conn) - if err != nil { - t.Errorf("ReadAll: %v", err) - } - got := string(bs) - if got != "" { - t.Errorf("read %q from server; want nothing", got) - } - }() - select { - case <-donec: - case <-time.After(10 * time.Second): - t.Fatalf("timeout") + + bs, err := io.ReadAll(conn) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + got := string(bs) + if got != "" { + t.Errorf("read %q from server; want nothing", got) } } @@ -3884,7 +3934,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // this test fails, it hangs. This helps debugging and I've // added this enough times "temporarily". It now gets added // full time. - errorf := func(format string, args ...interface{}) { + errorf := func(format string, args ...any) { v := fmt.Sprintf(format, args...) println(v) t.Error(v) @@ -3893,10 +3943,10 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { unblockBackend := make(chan bool) backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() - didCopy := make(chan interface{}) + didCopy := make(chan any) go func() { n, err := io.CopyN(rw, req.Body, bodySize) - didCopy <- []interface{}{n, err} + didCopy <- []any{n, err} }() isGone := false Loop: @@ -4888,7 +4938,7 @@ func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) - ch := make(chan interface{}, 1) + ch := make(chan any, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) @@ -5689,22 +5739,37 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { } // Not parallel: messes with global variable. (http2goAwayTimeout) defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "%v", r.RemoteAddr) - })) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) - a := cst.getURL(cst.ts.URL) - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("test server has active conns") - } - b := cst.getURL(cst.ts.URL) - if a == b { - t.Errorf("got same connection between first and second requests") - } - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("test server has active conns") + for try := 0; try < 2; try++ { + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("request %v: test server has active conns", try) + } + conns := 0 + var info httptrace.GotConnInfo + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(v httptrace.GotConnInfo) { + conns++ + info = v + }, + }) + req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if conns != 1 { + t.Fatalf("request %v: got %v conns, want 1", try, conns) + } + if info.Reused || info.WasIdle { + t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle) + } } } @@ -5933,11 +5998,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { t.Fatal(err) } - select { - case <-done: - case <-time.After(2 * time.Second): - t.Error("timeout") - } + <-done } // Issue 18319: test that the Server validates the request method. @@ -6232,7 +6293,7 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // setting contentEncoding as an interface instead of a string // directly, so as to differentiate between 3 states: // unset, empty string "" and set string "foo/bar". - contentEncoding interface{} + contentEncoding any wantContentType string } @@ -6490,7 +6551,7 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { - t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body) + t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body) } _, err = rwc.Write([]byte("hello")) @@ -6609,3 +6670,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon } } } + +func TestMaxBytesHandler(t *testing.T) { + setParallel(t) + defer afterTest(t) + + for _, maxSize := range []int64{100, 1_000, 1_000_000} { + for _, requestSize := range []int64{100, 1_000, 1_000_000} { + t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), + func(t *testing.T) { + testMaxBytesHandler(t, maxSize, requestSize) + }) + } + } +} + +func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { + var ( + handlerN int64 + handlerErr error + ) + echo := HandlerFunc(func(w ResponseWriter, r *Request) { + var buf bytes.Buffer + handlerN, handlerErr = io.Copy(&buf, r.Body) + io.Copy(w, &buf) + }) + + ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + defer ts.Close() + + c := ts.Client() + var buf strings.Builder + body := strings.NewReader(strings.Repeat("a", int(requestSize))) + res, err := c.Post(ts.URL, "text/plain", body) + if err != nil { + t.Errorf("unexpected connection error: %v", err) + } else { + _, err = io.Copy(&buf, res.Body) + res.Body.Close() + if err != nil { + t.Errorf("unexpected read error: %v", err) + } + } + if handlerN > maxSize { + t.Errorf("expected max request body %d; got %d", maxSize, handlerN) + } + if requestSize > maxSize && handlerErr == nil { + t.Error("expected error on handler side; got nil") + } + if requestSize <= maxSize { + if handlerErr != nil { + t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr) + } + if handlerN != requestSize { + t.Errorf("expected request of size %d; got %d", requestSize, handlerN) + } + } + if buf.Len() != int(handlerN) { + t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index ce39933..f5cdc3a 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "math/rand" @@ -20,7 +21,6 @@ import ( "net/textproto" "net/url" urlpkg "net/url" - "os" "path" "runtime" "sort" @@ -494,8 +494,8 @@ type response struct { // prior to the headers being written. If the set of trailers is fixed // or known before the header is written, the normal Go trailers mechanism // is preferred: -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +// https://pkg.go.dev/net/http#ResponseWriter +// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers const TrailerPrefix = "Trailer:" // finalTrailers is called after the Handler exits and returns a non-nil @@ -798,7 +798,7 @@ var ( ) var copyBufPool = sync.Pool{ - New: func() interface{} { + New: func() any { b := make([]byte, 32*1024) return &b }, @@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } +// tlsHandshakeTimeout returns the time limit permitted for the TLS +// handshake, or zero for unlimited. +// +// It returns the minimum of any positive ReadHeaderTimeout, +// ReadTimeout, or WriteTimeout. +func (srv *Server) tlsHandshakeTimeout() time.Duration { + var ret time.Duration + for _, v := range [...]time.Duration{ + srv.ReadHeaderTimeout, + srv.ReadTimeout, + srv.WriteTimeout, + } { + if v <= 0 { + continue + } + if ret == 0 || v < ret { + ret = v + } + } + return ret +} + // wrapper around io.ReadCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -1409,11 +1431,11 @@ func (cw *chunkWriter) writeHeader(p []byte) { hasCL = false } - if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) { - // do nothing - } else if code == StatusNoContent { + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent { + // Response has no body. delHeader("Transfer-Encoding") } else if hasCL { + // Content-Length has been provided, so no chunking is to be done. delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no @@ -1424,6 +1446,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { if hasTE && te == "identity" { cw.chunking = false w.closeAfterReply = true + delHeader("Transfer-Encoding") } else { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -1799,6 +1822,7 @@ func isCommonNetReadError(err error) bool { func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) + var inFlightResponse *response defer func() { if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 @@ -1806,18 +1830,25 @@ func (c *conn) serve(ctx context.Context) { buf = buf[:runtime.Stack(buf, false)] c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } + if inFlightResponse != nil { + inFlightResponse.cancelCtx() + } if !c.hijacked() { + if inFlightResponse != nil { + inFlightResponse.conn.r.abortPendingRead() + inFlightResponse.reqBody.Close() + } c.close() c.setState(c.rwc, StateClosed, runHooks) } }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { - if d := c.server.ReadTimeout; d > 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) - } - if d := c.server.WriteTimeout; d > 0 { - c.rwc.SetWriteDeadline(time.Now().Add(d)) + tlsTO := c.server.tlsHandshakeTimeout() + if tlsTO > 0 { + dl := time.Now().Add(tlsTO) + c.rwc.SetReadDeadline(dl) + c.rwc.SetWriteDeadline(dl) } if err := tlsConn.HandshakeContext(ctx); err != nil { // If the handshake failed due to the client not speaking @@ -1831,6 +1862,11 @@ func (c *conn) serve(ctx context.Context) { c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) return } + // Restore Conn-level deadlines. + if tlsTO > 0 { + c.rwc.SetReadDeadline(time.Time{}) + c.rwc.SetWriteDeadline(time.Time{}) + } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) { @@ -1931,7 +1967,9 @@ func (c *conn) serve(ctx context.Context) { // in parallel even if their responses need to be serialized. // But we're not going to implement HTTP pipelining because it // was never deployed in the wild and the answer is HTTP/2. + inFlightResponse = w serverHandler{c.server}.ServeHTTP(w, w.req) + inFlightResponse = nil w.cancelCtx() if c.hijacked() { return @@ -2277,7 +2315,7 @@ func cleanPath(p string) string { // stripHostPort returns h without any trailing ":<port>". func stripHostPort(h string) string { // If no port on host, return unchanged - if strings.IndexByte(h, ':') == -1 { + if !strings.Contains(h, ":") { return h } host, _, err := net.SplitHostPort(h) @@ -3157,7 +3195,7 @@ func (srv *Server) SetKeepAlivesEnabled(v bool) { // TODO: Issue 26303: close HTTP/2 conns as soon as they become idle. } -func (s *Server) logf(format string, args ...interface{}) { +func (s *Server) logf(format string, args ...any) { if s.ErrorLog != nil { s.ErrorLog.Printf(format, args...) } else { @@ -3168,7 +3206,7 @@ func (s *Server) logf(format string, args ...interface{}) { // logf prints to the ErrorLog of the *Server associated with request r // via ServerContextKey. If there's no associated server, or if ErrorLog // is nil, logging is done via the log package's standard logger. -func logf(r *Request, format string, args ...interface{}) { +func logf(r *Request, format string, args ...any) { s, _ := r.Context().Value(ServerContextKey).(*Server) if s != nil && s.ErrorLog != nil { s.ErrorLog.Printf(format, args...) @@ -3264,7 +3302,7 @@ func (srv *Server) onceSetNextProtoDefaults_Serve() { // configured otherwise. (by setting srv.TLSNextProto non-nil) // It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*). func (srv *Server) onceSetNextProtoDefaults() { - if omitBundledHTTP2 || strings.Contains(os.Getenv("GODEBUG"), "http2server=0") { + if omitBundledHTTP2 || godebug.Get("http2server") == "0" { return } // Enable HTTP/2 by default if the user hasn't otherwise @@ -3331,7 +3369,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { h: make(Header), req: r, } - panicChan := make(chan interface{}, 1) + panicChan := make(chan any, 1) go func() { defer func() { if p := recover(); p != nil { @@ -3359,9 +3397,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { case <-ctx.Done(): tw.mu.Lock() defer tw.mu.Unlock() - w.WriteHeader(StatusServiceUnavailable) - io.WriteString(w, h.errorBody()) - tw.timedOut = true + switch err := ctx.Err(); err { + case context.DeadlineExceeded: + w.WriteHeader(StatusServiceUnavailable) + io.WriteString(w, h.errorBody()) + tw.err = ErrHandlerTimeout + default: + w.WriteHeader(StatusServiceUnavailable) + tw.err = err + } } } @@ -3372,7 +3416,7 @@ type timeoutWriter struct { req *Request mu sync.Mutex - timedOut bool + err error wroteHeader bool code int } @@ -3392,8 +3436,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h } func (tw *timeoutWriter) Write(p []byte) (int, error) { tw.mu.Lock() defer tw.mu.Unlock() - if tw.timedOut { - return 0, ErrHandlerTimeout + if tw.err != nil { + return 0, tw.err } if !tw.wroteHeader { tw.writeHeaderLocked(StatusOK) @@ -3405,7 +3449,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) { checkWriteHeaderCode(code) switch { - case tw.timedOut: + case tw.err != nil: return case tw.wroteHeader: if tw.req != nil { @@ -3572,3 +3616,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { } return false } + +// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader. +func MaxBytesHandler(h Handler, n int64) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + r2 := *r + r2.Body = MaxBytesReader(w, r.Body, n) + h.ServeHTTP(w, &r2) + }) +} diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go index 0132f3b..d17c5c1 100644 --- a/libgo/go/net/http/server_test.go +++ b/libgo/go/net/http/server_test.go @@ -9,8 +9,61 @@ package http import ( "fmt" "testing" + "time" ) +func TestServerTLSHandshakeTimeout(t *testing.T) { + tests := []struct { + s *Server + want time.Duration + }{ + { + s: &Server{}, + want: 0, + }, + { + s: &Server{ + ReadTimeout: -1, + }, + want: 0, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + }, + want: 5 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: -1, + }, + want: 5 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 4 * time.Second, + }, + want: 4 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + ReadHeaderTimeout: 2 * time.Second, + WriteTimeout: 4 * time.Second, + }, + want: 2 * time.Second, + }, + } + for i, tt := range tests { + got := tt.s.tlsHandshakeTimeout() + if got != tt.want { + t.Errorf("%d. got %v; want %v", i, got, tt.want) + } + } +} + func BenchmarkServerMatch(b *testing.B) { fn := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "OK") diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 85c2e5a..6d51178 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -73,7 +73,7 @@ type transferWriter struct { ByteReadCh chan readResult // non-nil if probeRequestBody called } -func newTransferWriter(r interface{}) (t *transferWriter, err error) { +func newTransferWriter(r any) (t *transferWriter, err error) { t = &transferWriter{} // Extract relevant fields @@ -212,6 +212,7 @@ func (t *transferWriter) probeRequestBody() { rres.b = buf[0] } t.ByteReadCh <- rres + close(t.ByteReadCh) }(t.Body) timer := time.NewTimer(200 * time.Millisecond) select { @@ -480,7 +481,7 @@ func suppressedHeaders(status int) []string { } // msg is *Request or *Response. -func readTransfer(msg interface{}, r *bufio.Reader) (err error) { +func readTransfer(msg any, r *bufio.Reader) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -808,7 +809,7 @@ func fixTrailer(header Header, chunked bool) (Header, error) { // and then reads the trailer if necessary. type body struct { src io.Reader - hdr interface{} // non-nil (Response or Request) value means read trailer + hdr any // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early @@ -1029,7 +1030,7 @@ func (b *body) registerOnHitEOF(fn func()) { b.onHitEOF = fn } -// bodyLocked is a io.Reader reading from a *body when its mutex is +// bodyLocked is an io.Reader reading from a *body when its mutex is // already held. type bodyLocked struct { b *body @@ -1072,6 +1073,9 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { if n == 1 { p[0] = rres.b } + if err == nil { + err = io.EOF + } return } diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 309194e..5fe3e6e 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -17,6 +17,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "net" @@ -24,7 +25,6 @@ import ( "net/http/internal/ascii" "net/textproto" "net/url" - "os" "reflect" "strings" "sync" @@ -42,10 +42,10 @@ import ( // $no_proxy) environment variables. var DefaultTransport RoundTripper = &Transport{ Proxy: ProxyFromEnvironment, - DialContext: (&net.Dialer{ + DialContext: defaultTransportDialContext(&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - }).DialContext, + }), ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -360,7 +360,7 @@ func (t *Transport) hasCustomTLSDialer() bool { // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { t.tlsNextProtoWasNil = (t.TLSNextProto == nil) - if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") { + if godebug.Get("http2client") == "0" { return } @@ -1715,12 +1715,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } if resp.StatusCode != 200 { - f := strings.SplitN(resp.Status, " ", 2) + _, text, ok := strings.Cut(resp.Status, " ") conn.Close() - if len(f) < 2 { + if !ok { return nil, errors.New("unknown status code") } - return nil, errors.New(f[1]) + return nil, errors.New(text) } } @@ -2481,7 +2481,7 @@ type requestAndChan struct { callerGone <-chan struct{} // closed when roundTrip caller has returned } -// A writeRequest is sent by the readLoop's goroutine to the +// A writeRequest is sent by the caller's goroutine to the // writeLoop's goroutine to write a request while the read loop // concurrently waits on both the write response and the server's // reply. @@ -2668,8 +2668,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // a t.Logf func. See export_test.go's Request.WithT method. type tLogKey struct{} -func (tr *transportRequest) logf(format string, args ...interface{}) { - if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { +func (tr *transportRequest) logf(format string, args ...any) { + if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok { logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) } } diff --git a/libgo/go/net/http/transport_default_js.go b/libgo/go/net/http/transport_default_js.go new file mode 100644 index 0000000..c07d35e --- /dev/null +++ b/libgo/go/net/http/transport_default_js.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js && wasm +// +build js,wasm + +package http + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return nil +} diff --git a/libgo/go/net/http/transport_default_other.go b/libgo/go/net/http/transport_default_other.go new file mode 100644 index 0000000..8a2f1cc --- /dev/null +++ b/libgo/go/net/http/transport_default_other.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !(js && wasm) +// +build !js !wasm + +package http + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return dialer.DialContext +} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 7e14749..fed092b 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -776,7 +776,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { c := ts.Client() fetch := func(n, retries int) string { - condFatalf := func(format string, arg ...interface{}) { + condFatalf := func(format string, arg ...any) { if retries <= 0 { t.Fatalf(format, arg...) } @@ -3518,7 +3518,7 @@ func TestRetryRequestsOnError(t *testing.T) { mu sync.Mutex logbuf bytes.Buffer ) - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&logbuf, format, args...) @@ -4495,7 +4495,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { var mu sync.Mutex // guards buf var buf bytes.Buffer - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) @@ -4654,7 +4654,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { func TestTransportEventTraceTLSVerify(t *testing.T) { var mu sync.Mutex var buf bytes.Buffer - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) @@ -4740,7 +4740,7 @@ func TestTransportEventTraceRealDNS(t *testing.T) { var mu sync.Mutex // guards buf var buf bytes.Buffer - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) @@ -6516,3 +6516,32 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { close(r2c) wg.Wait() } + +func TestHandlerAbortRacesBodyRead(t *testing.T) { + setParallel(t) + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + go io.Copy(io.Discard, req.Body) + panic(ErrAbortHandler) + })) + defer ts.Close() + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + const reqLen = 6 * 1024 * 1024 + req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + resp, _ := ts.Client().Transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + } + }() + } + wg.Wait() +} diff --git a/libgo/go/net/http/triv.go b/libgo/go/net/http/triv.go index 4dc6240..11b19ab 100644 --- a/libgo/go/net/http/triv.go +++ b/libgo/go/net/http/triv.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build ignore -// +build ignore package main diff --git a/libgo/go/net/interface_aix.go b/libgo/go/net/interface_aix.go index bd55386..ba15a67 100644 --- a/libgo/go/net/interface_aix.go +++ b/libgo/go/net/interface_aix.go @@ -78,7 +78,7 @@ func interfaceTable(ifindex int) ([]Interface, error) { // Retrieve MTU ifr := &ifreq{} copy(ifr.Name[:], ifi.Name) - err = unix.Ioctl(sock, syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(ifr))) + err = unix.Ioctl(sock, syscall.SIOCGIFMTU, unsafe.Pointer(ifr)) if err != nil { return nil, err } diff --git a/libgo/go/net/interface_bsd.go b/libgo/go/net/interface_bsd.go index 7578b1a..db7bc75 100644 --- a/libgo/go/net/interface_bsd.go +++ b/libgo/go/net/interface_bsd.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/interface_bsd_test.go b/libgo/go/net/interface_bsd_test.go index 8d0d9c3..ce59962 100644 --- a/libgo/go/net/interface_bsd_test.go +++ b/libgo/go/net/interface_bsd_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/interface_bsdvar.go b/libgo/go/net/interface_bsdvar.go index 6230e0b..e9bea3d 100644 --- a/libgo/go/net/interface_bsdvar.go +++ b/libgo/go/net/interface_bsdvar.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build dragonfly || netbsd || openbsd -// +build dragonfly netbsd openbsd package net diff --git a/libgo/go/net/interface_freebsd.go b/libgo/go/net/interface_freebsd.go index 2b51fcb..8536bd3 100644 --- a/libgo/go/net/interface_freebsd.go +++ b/libgo/go/net/interface_freebsd.go @@ -11,16 +11,11 @@ import ( ) func interfaceMessages(ifindex int) ([]route.Message, error) { - typ := route.RIBType(syscall.NET_RT_IFLISTL) - rib, err := route.FetchRIB(syscall.AF_UNSPEC, typ, ifindex) + rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeInterface, ifindex) if err != nil { - typ = route.RIBType(syscall.NET_RT_IFLIST) - rib, err = route.FetchRIB(syscall.AF_UNSPEC, typ, ifindex) - if err != nil { - return nil, err - } + return nil, err } - return route.ParseRIB(typ, rib) + return route.ParseRIB(route.RIBTypeInterface, rib) } // interfaceMulticastAddrTable returns addresses for a specific diff --git a/libgo/go/net/interface_stub.go b/libgo/go/net/interface_stub.go index 1075e36..fadd8b2 100644 --- a/libgo/go/net/interface_stub.go +++ b/libgo/go/net/interface_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build hurd || (js && wasm) -// +build hurd js,wasm package net diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go index 754db36..f6c9868 100644 --- a/libgo/go/net/interface_test.go +++ b/libgo/go/net/interface_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net diff --git a/libgo/go/net/interface_unix_test.go b/libgo/go/net/interface_unix_test.go index 0d69fa5..92ec13a 100644 --- a/libgo/go/net/interface_unix_test.go +++ b/libgo/go/net/interface_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd -// +build darwin dragonfly freebsd linux netbsd openbsd package net diff --git a/libgo/go/net/internal/socktest/main_test.go b/libgo/go/net/internal/socktest/main_test.go index 8af85d3..c7c8d16 100644 --- a/libgo/go/net/internal/socktest/main_test.go +++ b/libgo/go/net/internal/socktest/main_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 -// +build !js,!plan9 package socktest_test diff --git a/libgo/go/net/internal/socktest/main_unix_test.go b/libgo/go/net/internal/socktest/main_unix_test.go index 6aa8875..7d21f6f 100644 --- a/libgo/go/net/internal/socktest/main_unix_test.go +++ b/libgo/go/net/internal/socktest/main_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 && !windows -// +build !js,!plan9,!windows package socktest_test diff --git a/libgo/go/net/internal/socktest/switch_posix.go b/libgo/go/net/internal/socktest/switch_posix.go index cda74e8..fcad4ce 100644 --- a/libgo/go/net/internal/socktest/switch_posix.go +++ b/libgo/go/net/internal/socktest/switch_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !plan9 -// +build !plan9 package socktest diff --git a/libgo/go/net/internal/socktest/switch_stub.go b/libgo/go/net/internal/socktest/switch_stub.go index 5aa2ece..8a2fc35 100644 --- a/libgo/go/net/internal/socktest/switch_stub.go +++ b/libgo/go/net/internal/socktest/switch_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build plan9 -// +build plan9 package socktest diff --git a/libgo/go/net/internal/socktest/switch_unix.go b/libgo/go/net/internal/socktest/switch_unix.go index be9ef6d..83df596 100644 --- a/libgo/go/net/internal/socktest/switch_unix.go +++ b/libgo/go/net/internal/socktest/switch_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris package socktest diff --git a/libgo/go/net/internal/socktest/sys_cloexec.go b/libgo/go/net/internal/socktest/sys_cloexec.go index 5e95896..c2d9d4b 100644 --- a/libgo/go/net/internal/socktest/sys_cloexec.go +++ b/libgo/go/net/internal/socktest/sys_cloexec.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build dragonfly || freebsd || hurd || illumos || linux || netbsd || openbsd -// +build dragonfly freebsd hurd illumos linux netbsd openbsd package socktest diff --git a/libgo/go/net/internal/socktest/sys_unix.go b/libgo/go/net/internal/socktest/sys_unix.go index 39f3dbc..0cb4693 100644 --- a/libgo/go/net/internal/socktest/sys_unix.go +++ b/libgo/go/net/internal/socktest/sys_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris package socktest diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index 38e1aa2..54c5288 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -308,7 +308,7 @@ func ubtoa(dst []byte, start int, v byte) int { // It returns one of 4 forms: // - "<nil>", if ip has length 0 // - dotted decimal ("192.0.2.1"), if ip is an IPv4 or IP4-mapped IPv6 address -// - IPv6 ("2001:db8::1"), if ip is a valid IPv6 address +// - IPv6 conforming to RFC 5952 ("2001:db8::1"), if ip is a valid IPv6 address // - the hexadecimal form of ip, without punctuation, if no other cases apply func (ip IP) String() string { p := ip @@ -545,6 +545,9 @@ func (n *IPNet) Network() string { return "ip+net" } // character and a mask expressed as hexadecimal form with no // punctuation like "198.51.100.0/c000ff00". func (n *IPNet) String() string { + if n == nil { + return "<nil>" + } nn, m := networkNumberAndMask(n) if nn == nil || m == nil { return "<nil>" diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go index 5bbda60..8f1590c 100644 --- a/libgo/go/net/ip_test.go +++ b/libgo/go/net/ip_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -408,6 +407,7 @@ var ipNetStringTests = []struct { {&IPNet{IP: IPv4(192, 168, 1, 0), Mask: IPv4Mask(255, 0, 255, 0)}, "192.168.1.0/ff00ff00"}, {&IPNet{IP: ParseIP("2001:db8::"), Mask: CIDRMask(55, 128)}, "2001:db8::/55"}, {&IPNet{IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("8000:f123:0:cafe::"))}, "2001:db8::/8000f1230000cafe0000000000000000"}, + {nil, "<nil>"}, } func TestIPNetString(t *testing.T) { @@ -719,7 +719,7 @@ var ipAddrScopeTests = []struct { {IP.IsPrivate, IP{0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, } -func name(f interface{}) string { +func name(f any) string { return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() } diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go index ffc437c..04f8e10 100644 --- a/libgo/go/net/iprawsock_posix.go +++ b/libgo/go/net/iprawsock_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/iprawsock_test.go b/libgo/go/net/iprawsock_test.go index a96448e..ca5ab48 100644 --- a/libgo/go/net/iprawsock_test.go +++ b/libgo/go/net/iprawsock_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go index cdd191a..cec7eb7 100644 --- a/libgo/go/net/ipsock_posix.go +++ b/libgo/go/net/ipsock_posix.go @@ -3,13 +3,13 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net import ( "context" "internal/poll" + "net/netip" "runtime" "syscall" ) @@ -142,42 +142,87 @@ func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, soty return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn) } +func ipToSockaddrInet4(ip IP, port int) (syscall.SockaddrInet4, error) { + if len(ip) == 0 { + ip = IPv4zero + } + ip4 := ip.To4() + if ip4 == nil { + return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: ip.String()} + } + sa := syscall.SockaddrInet4{Port: port} + copy(sa.Addr[:], ip4) + return sa, nil +} + +func ipToSockaddrInet6(ip IP, port int, zone string) (syscall.SockaddrInet6, error) { + // In general, an IP wildcard address, which is either + // "0.0.0.0" or "::", means the entire IP addressing + // space. For some historical reason, it is used to + // specify "any available address" on some operations + // of IP node. + // + // When the IP node supports IPv4-mapped IPv6 address, + // we allow a listener to listen to the wildcard + // address of both IP addressing spaces by specifying + // IPv6 wildcard address. + if len(ip) == 0 || ip.Equal(IPv4zero) { + ip = IPv6zero + } + // We accept any IPv6 address including IPv4-mapped + // IPv6 address. + ip6 := ip.To16() + if ip6 == nil { + return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: ip.String()} + } + sa := syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))} + copy(sa.Addr[:], ip6) + return sa, nil +} + func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) { switch family { case syscall.AF_INET: - if len(ip) == 0 { - ip = IPv4zero - } - ip4 := ip.To4() - if ip4 == nil { - return nil, &AddrError{Err: "non-IPv4 address", Addr: ip.String()} + sa, err := ipToSockaddrInet4(ip, port) + if err != nil { + return nil, err } - sa := &syscall.SockaddrInet4{Port: port} - copy(sa.Addr[:], ip4) - return sa, nil + return &sa, nil case syscall.AF_INET6: - // In general, an IP wildcard address, which is either - // "0.0.0.0" or "::", means the entire IP addressing - // space. For some historical reason, it is used to - // specify "any available address" on some operations - // of IP node. - // - // When the IP node supports IPv4-mapped IPv6 address, - // we allow a listener to listen to the wildcard - // address of both IP addressing spaces by specifying - // IPv6 wildcard address. - if len(ip) == 0 || ip.Equal(IPv4zero) { - ip = IPv6zero - } - // We accept any IPv6 address including IPv4-mapped - // IPv6 address. - ip6 := ip.To16() - if ip6 == nil { - return nil, &AddrError{Err: "non-IPv6 address", Addr: ip.String()} + sa, err := ipToSockaddrInet6(ip, port, zone) + if err != nil { + return nil, err } - sa := &syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))} - copy(sa.Addr[:], ip6) - return sa, nil + return &sa, nil } return nil, &AddrError{Err: "invalid address family", Addr: ip.String()} } + +func addrPortToSockaddrInet4(ap netip.AddrPort) (syscall.SockaddrInet4, error) { + // ipToSockaddrInet4 has special handling here for zero length slices. + // We do not, because netip has no concept of a generic zero IP address. + addr := ap.Addr() + if !addr.Is4() { + return syscall.SockaddrInet4{}, &AddrError{Err: "non-IPv4 address", Addr: addr.String()} + } + sa := syscall.SockaddrInet4{ + Addr: addr.As4(), + Port: int(ap.Port()), + } + return sa, nil +} + +func addrPortToSockaddrInet6(ap netip.AddrPort) (syscall.SockaddrInet6, error) { + // ipToSockaddrInet6 has special handling here for zero length slices. + // We do not, because netip has no concept of a generic zero IP address. + addr := ap.Addr() + if !addr.Is6() { + return syscall.SockaddrInet6{}, &AddrError{Err: "non-IPv6 address", Addr: addr.String()} + } + sa := syscall.SockaddrInet6{ + Addr: addr.As16(), + Port: int(ap.Port()), + ZoneId: uint32(zoneCache.index(addr.Zone())), + } + return sa, nil +} diff --git a/libgo/go/net/listen_test.go b/libgo/go/net/listen_test.go index b1dce29..59c0112 100644 --- a/libgo/go/net/listen_test.go +++ b/libgo/go/net/listen_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 -// +build !js,!plan9 package net @@ -380,7 +379,7 @@ func differentWildcardAddr(i, j string) bool { return true } -func checkFirstListener(network string, ln interface{}) error { +func checkFirstListener(network string, ln any) error { switch network { case "tcp": fd := ln.(*TCPListener).fd @@ -535,8 +534,6 @@ func TestIPv4MulticastListener(t *testing.T) { switch runtime.GOOS { case "android", "plan9": t.Skipf("not supported on %s", runtime.GOOS) - case "solaris", "illumos": - t.Skipf("not supported on solaris or illumos, see golang.org/issue/7399") } if !supportsIPv4() { t.Skip("IPv4 is not supported") @@ -610,8 +607,6 @@ func TestIPv6MulticastListener(t *testing.T) { switch runtime.GOOS { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) - case "solaris", "illumos": - t.Skipf("not supported on solaris or illumos, see issue 7399") } if !supportsIPv6() { t.Skip("IPv6 is not supported") @@ -702,10 +697,7 @@ func multicastRIBContains(ip IP) (bool, error) { // Issue 21856. func TestClosingListener(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") addr := ln.Addr() go func() { @@ -743,15 +735,13 @@ func TestListenConfigControl(t *testing.T) { if !testableNetwork(network) { continue } - ln, err := newLocalListener(network) - if err != nil { - t.Error(err) - continue - } + ln := newLocalListener(t, network) address := ln.Addr().String() + // TODO: This is racy. The selected address could be reused in between + // this Close and the subsequent Listen. ln.Close() lc := ListenConfig{Control: controlOnConnSetup} - ln, err = lc.Listen(context.Background(), network, address) + ln, err := lc.Listen(context.Background(), network, address) if err != nil { t.Error(err) continue @@ -764,18 +754,16 @@ func TestListenConfigControl(t *testing.T) { if !testableNetwork(network) { continue } - c, err := newLocalPacketListener(network) - if err != nil { - t.Error(err) - continue - } + c := newLocalPacketListener(t, network) address := c.LocalAddr().String() + // TODO: This is racy. The selected address could be reused in between + // this Close and the subsequent ListenPacket. c.Close() if network == "unixgram" { os.Remove(address) } lc := ListenConfig{Control: controlOnConnSetup} - c, err = lc.ListenPacket(context.Background(), network, address) + c, err := lc.ListenPacket(context.Background(), network, address) if err != nil { t.Error(err) continue diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go index d350ef7..c7b8dc6 100644 --- a/libgo/go/net/lookup.go +++ b/libgo/go/net/lookup.go @@ -8,6 +8,7 @@ import ( "context" "internal/nettrace" "internal/singleflight" + "net/netip" "sync" ) @@ -232,6 +233,28 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, er return ips, nil } +// LookupNetIP looks up host using the local resolver. +// It returns a slice of that host's IP addresses of the type specified by +// network. +// The network must be one of "ip", "ip4" or "ip6". +func (r *Resolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + // TODO(bradfitz): make this efficient, making the internal net package + // type throughout be netip.Addr and only converting to the net.IP slice + // version at the edge. But for now (2021-10-20), this is a wrapper around + // the old way. + ips, err := r.LookupIP(ctx, network, host) + if err != nil { + return nil, err + } + ret := make([]netip.Addr, 0, len(ips)) + for _, ip := range ips { + if a, ok := netip.AddrFromSlice(ip); ok { + ret = append(ret, a) + } + } + return ret, nil +} + // onlyValuesCtx is a context that uses an underlying context // for value lookup if the underlying context hasn't yet expired. type onlyValuesCtx struct { @@ -242,7 +265,7 @@ type onlyValuesCtx struct { var _ context.Context = (*onlyValuesCtx)(nil) // Value performs a lookup if the original context hasn't expired. -func (ovc *onlyValuesCtx) Value(key interface{}) interface{} { +func (ovc *onlyValuesCtx) Value(key any) any { select { case <-ovc.lookupValues.Done(): return nil @@ -291,7 +314,7 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP lookupKey := network + "\000" + host dnsWaitGroup.Add(1) - ch, called := r.getLookupGroup().DoChan(lookupKey, func() (interface{}, error) { + ch, called := r.getLookupGroup().DoChan(lookupKey, func() (any, error) { defer dnsWaitGroup.Done() return testHookLookupIP(lookupGroupCtx, resolverFunc, network, host) }) @@ -316,24 +339,45 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP lookupGroupCancel() }() } - err := mapErr(ctx.Err()) + ctxErr := ctx.Err() + err := &DNSError{ + Err: mapErr(ctxErr).Error(), + Name: host, + IsTimeout: ctxErr == context.DeadlineExceeded, + } if trace != nil && trace.DNSDone != nil { trace.DNSDone(nil, false, err) } return nil, err case r := <-ch: lookupGroupCancel() + err := r.Err + if err != nil { + if _, ok := err.(*DNSError); !ok { + isTimeout := false + if err == context.DeadlineExceeded { + isTimeout = true + } else if terr, ok := err.(timeout); ok { + isTimeout = terr.Timeout() + } + err = &DNSError{ + Err: err.Error(), + Name: host, + IsTimeout: isTimeout, + } + } + } if trace != nil && trace.DNSDone != nil { addrs, _ := r.Val.([]IPAddr) - trace.DNSDone(ipAddrsEface(addrs), r.Shared, r.Err) + trace.DNSDone(ipAddrsEface(addrs), r.Shared, err) } - return lookupIPReturn(r.Val, r.Err, r.Shared) + return lookupIPReturn(r.Val, err, r.Shared) } } // lookupIPReturn turns the return values from singleflight.Do into // the return values from LookupIP. -func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) { +func lookupIPReturn(addrsi any, err error, shared bool) ([]IPAddr, error) { if err != nil { return nil, err } @@ -347,8 +391,8 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error } // ipAddrsEface returns an empty interface slice of addrs. -func ipAddrsEface(addrs []IPAddr) []interface{} { - s := make([]interface{}, len(addrs)) +func ipAddrsEface(addrs []IPAddr) []any { + s := make([]any, len(addrs)) for i, v := range addrs { s[i] = v } @@ -442,7 +486,7 @@ func (r *Resolver) LookupCNAME(ctx context.Context, host string) (string, error) // The returned service names are validated to be properly // formatted presentation-format domain names. If the response contains // invalid names, those records are filtered out and an error -// will be returned alongside the the remaining results, if any. +// will be returned alongside the remaining results, if any. func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { return DefaultResolver.LookupSRV(context.Background(), service, proto, name) } @@ -460,7 +504,7 @@ func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err // The returned service names are validated to be properly // formatted presentation-format domain names. If the response contains // invalid names, those records are filtered out and an error -// will be returned alongside the the remaining results, if any. +// will be returned alongside the remaining results, if any. func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { cname, addrs, err := r.lookupSRV(ctx, service, proto, name) if err != nil { @@ -490,7 +534,7 @@ func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) ( // The returned mail server names are validated to be properly // formatted presentation-format domain names. If the response contains // invalid names, those records are filtered out and an error -// will be returned alongside the the remaining results, if any. +// will be returned alongside the remaining results, if any. // // LookupMX uses context.Background internally; to specify the context, use // Resolver.LookupMX. @@ -503,7 +547,7 @@ func LookupMX(name string) ([]*MX, error) { // The returned mail server names are validated to be properly // formatted presentation-format domain names. If the response contains // invalid names, those records are filtered out and an error -// will be returned alongside the the remaining results, if any. +// will be returned alongside the remaining results, if any. func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) { records, err := r.lookupMX(ctx, name) if err != nil { @@ -514,9 +558,7 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) { if mx == nil { continue } - // Bypass the hostname validity check for targets which contain only a dot, - // as this is used to represent a 'Null' MX record. - if mx.Host != "." && !isDomainName(mx.Host) { + if !isDomainName(mx.Host) { continue } filteredMX = append(filteredMX, mx) @@ -532,7 +574,7 @@ func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) { // The returned name server names are validated to be properly // formatted presentation-format domain names. If the response contains // invalid names, those records are filtered out and an error -// will be returned alongside the the remaining results, if any. +// will be returned alongside the remaining results, if any. // // LookupNS uses context.Background internally; to specify the context, use // Resolver.LookupNS. @@ -545,7 +587,7 @@ func LookupNS(name string) ([]*NS, error) { // The returned name server names are validated to be properly // formatted presentation-format domain names. If the response contains // invalid names, those records are filtered out and an error -// will be returned alongside the the remaining results, if any. +// will be returned alongside the remaining results, if any. func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) { records, err := r.lookupNS(ctx, name) if err != nil { @@ -585,7 +627,7 @@ func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) // // The returned names are validated to be properly formatted presentation-format // domain names. If the response contains invalid names, those records are filtered -// out and an error will be returned alongside the the remaining results, if any. +// out and an error will be returned alongside the remaining results, if any. // // When using the host C library resolver, at most one result will be // returned. To bypass the host resolver, use a custom Resolver. @@ -601,7 +643,7 @@ func LookupAddr(addr string) (names []string, err error) { // // The returned names are validated to be properly formatted presentation-format // domain names. If the response contains invalid names, those records are filtered -// out and an error will be returned alongside the the remaining results, if any. +// out and an error will be returned alongside the remaining results, if any. func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error) { names, err := r.lookupAddr(ctx, addr) if err != nil { @@ -620,6 +662,6 @@ func (r *Resolver) LookupAddr(ctx context.Context, addr string) ([]string, error } // errMalformedDNSRecordsDetail is the DNSError detail which is returned when a Resolver.Lookup... -// method recieves DNS records which contain invalid DNS names. This may be returned alongside +// method receives DNS records which contain invalid DNS names. This may be returned alongside // results which have had the malformed records filtered out. var errMalformedDNSRecordsDetail = "DNS response contained records which contain invalid names" diff --git a/libgo/go/net/lookup_fake.go b/libgo/go/net/lookup_fake.go index f4fcaed..c27eae4 100644 --- a/libgo/go/net/lookup_fake.go +++ b/libgo/go/net/lookup_fake.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package net diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go index 75c18b3..d43a03b 100644 --- a/libgo/go/net/lookup_plan9.go +++ b/libgo/go/net/lookup_plan9.go @@ -262,8 +262,8 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cn if !(portOk && priorityOk && weightOk) { continue } - addrs = append(addrs, &SRV{absDomainName([]byte(f[5])), uint16(port), uint16(priority), uint16(weight)}) - cname = absDomainName([]byte(f[0])) + addrs = append(addrs, &SRV{absDomainName(f[5]), uint16(port), uint16(priority), uint16(weight)}) + cname = absDomainName(f[0]) } byPriorityWeight(addrs).sort() return @@ -280,7 +280,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error continue } if pref, _, ok := dtoi(f[2]); ok { - mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)}) + mx = append(mx, &MX{absDomainName(f[3]), uint16(pref)}) } } byPref(mx).sort() @@ -297,7 +297,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error if len(f) < 3 { continue } - ns = append(ns, &NS{absDomainName([]byte(f[2]))}) + ns = append(ns, &NS{absDomainName(f[2])}) } return } @@ -329,7 +329,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, er if len(f) < 3 { continue } - name = append(name, absDomainName([]byte(f[2]))) + name = append(name, absDomainName(f[2])) } return } diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go index 3faaf00..063d650 100644 --- a/libgo/go/net/lookup_test.go +++ b/libgo/go/net/lookup_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -353,6 +352,7 @@ var lookupCNAMETests = []struct { func TestLookupCNAME(t *testing.T) { mustHaveExternalNetwork(t) + testenv.SkipFlakyNet(t) if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") @@ -391,6 +391,7 @@ var lookupGoogleHostTests = []struct { func TestLookupGoogleHost(t *testing.T) { mustHaveExternalNetwork(t) + testenv.SkipFlakyNet(t) if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") @@ -443,6 +444,7 @@ var lookupGoogleIPTests = []struct { func TestLookupGoogleIP(t *testing.T) { mustHaveExternalNetwork(t) + testenv.SkipFlakyNet(t) if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") @@ -633,6 +635,7 @@ func TestLookupDotsWithRemoteSource(t *testing.T) { testenv.SkipFlaky(t, 27992) } mustHaveExternalNetwork(t) + testenv.SkipFlakyNet(t) if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") @@ -657,7 +660,6 @@ func TestLookupDotsWithRemoteSource(t *testing.T) { func testDots(t *testing.T, mode string) { names, err := LookupAddr("8.8.8.8") // Google dns server if err != nil { - testenv.SkipFlakyNet(t) t.Errorf("LookupAddr(8.8.8.8): %v (mode=%v)", err, mode) } else { for _, name := range names { @@ -670,7 +672,6 @@ func testDots(t *testing.T, mode string) { cname, err := LookupCNAME("www.mit.edu") if err != nil { - testenv.SkipFlakyNet(t) t.Errorf("LookupCNAME(www.mit.edu, mode=%v): %v", mode, err) } else if !strings.HasSuffix(cname, ".") { t.Errorf("LookupCNAME(www.mit.edu) = %v, want cname ending in . with trailing dot (mode=%v)", cname, mode) @@ -678,7 +679,6 @@ func testDots(t *testing.T, mode string) { mxs, err := LookupMX("google.com") if err != nil { - testenv.SkipFlakyNet(t) t.Errorf("LookupMX(google.com): %v (mode=%v)", err, mode) } else { for _, mx := range mxs { @@ -691,7 +691,6 @@ func testDots(t *testing.T, mode string) { nss, err := LookupNS("google.com") if err != nil { - testenv.SkipFlakyNet(t) t.Errorf("LookupNS(google.com): %v (mode=%v)", err, mode) } else { for _, ns := range nss { @@ -704,7 +703,6 @@ func testDots(t *testing.T, mode string) { cname, srvs, err := LookupSRV("xmpp-server", "tcp", "google.com") if err != nil { - testenv.SkipFlakyNet(t) t.Errorf("LookupSRV(xmpp-server, tcp, google.com): %v (mode=%v)", err, mode) } else { if !hasSuffixFold(cname, ".google.com.") { @@ -890,7 +888,7 @@ func TestLookupContextCancel(t *testing.T) { ctx, ctxCancel := context.WithCancel(context.Background()) ctxCancel() _, err := DefaultResolver.LookupIPAddr(ctx, "google.com") - if err != errCanceled { + if err.(*DNSError).Err != errCanceled.Error() { testenv.SkipFlakyNet(t) t.Fatal(err) } @@ -926,6 +924,9 @@ func TestNilResolverLookup(t *testing.T) { // canceled lookups (see golang.org/issue/24178 for details). func TestLookupHostCancel(t *testing.T) { mustHaveExternalNetwork(t) + testenv.SkipFlakyNet(t) + t.Parallel() // Executes 600ms worth of sequential sleeps. + const ( google = "www.google.com" invalidDomain = "invalid.invalid" // RFC 2606 reserves .invalid @@ -944,9 +945,15 @@ func TestLookupHostCancel(t *testing.T) { if err == nil { t.Fatalf("LookupHost(%q): returns %v, but should fail", invalidDomain, addr) } - if !strings.Contains(err.Error(), "canceled") { - t.Fatalf("LookupHost(%q): failed with unexpected error: %v", invalidDomain, err) - } + + // Don't verify what the actual error is. + // We know that it must be non-nil because the domain is invalid, + // but we don't have any guarantee that LookupHost actually bothers + // to check for cancellation on the fast path. + // (For example, it could use a local cache to avoid blocking entirely.) + + // The lookup may deduplicate in-flight requests, so give it time to settle + // in between. time.Sleep(time.Millisecond * 1) } @@ -1050,7 +1057,7 @@ func TestLookupIPAddrPreservesContextValues(t *testing.T) { defer func() { testHookLookupIP = origTestHookLookupIP }() keyValues := []struct { - key, value interface{} + key, value any }{ {"key-1", 12}, {384, "value2"}, @@ -1267,3 +1274,71 @@ func TestResolverLookupIP(t *testing.T) { }) } } + +// A context timeout should still return a DNSError. +func TestDNSTimeout(t *testing.T) { + origTestHookLookupIP := testHookLookupIP + defer func() { testHookLookupIP = origTestHookLookupIP }() + defer dnsWaitGroup.Wait() + + timeoutHookGo := make(chan bool, 1) + timeoutHook := func(ctx context.Context, fn func(context.Context, string, string) ([]IPAddr, error), network, host string) ([]IPAddr, error) { + <-timeoutHookGo + return nil, context.DeadlineExceeded + } + testHookLookupIP = timeoutHook + + checkErr := func(err error) { + t.Helper() + if err == nil { + t.Error("expected an error") + } else if dnserr, ok := err.(*DNSError); !ok { + t.Errorf("got error type %T, want %T", err, (*DNSError)(nil)) + } else if !dnserr.IsTimeout { + t.Errorf("got error %#v, want IsTimeout == true", dnserr) + } else if isTimeout := dnserr.Timeout(); !isTimeout { + t.Errorf("got err.Timeout() == %t, want true", isTimeout) + } + } + + // Single lookup. + timeoutHookGo <- true + _, err := LookupIP("golang.org") + checkErr(err) + + // Double lookup. + var err1, err2 error + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, err1 = LookupIP("golang1.org") + }() + go func() { + defer wg.Done() + _, err2 = LookupIP("golang1.org") + }() + close(timeoutHookGo) + wg.Wait() + checkErr(err1) + checkErr(err2) + + // Double lookup with context. + timeoutHookGo = make(chan bool) + ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) + wg.Add(2) + go func() { + defer wg.Done() + _, err1 = DefaultResolver.LookupIPAddr(ctx, "golang2.org") + }() + go func() { + defer wg.Done() + _, err2 = DefaultResolver.LookupIPAddr(ctx, "golang2.org") + }() + time.Sleep(10 * time.Nanosecond) + close(timeoutHookGo) + wg.Wait() + checkErr(err1) + checkErr(err2) + cancel() +} diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go index 05f49b0..0d25f22 100644 --- a/libgo/go/net/lookup_unix.go +++ b/libgo/go/net/lookup_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index bb34a08..27e5f86 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -226,7 +226,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { // windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS { // if there are no aliases, the canonical name is the input name - return absDomainName([]byte(name)), nil + return absDomainName(name), nil } if e != nil { return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name} @@ -235,7 +235,7 @@ func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r) cname := windows.UTF16PtrToString(resolved) - return absDomainName([]byte(cname)), nil + return absDomainName(cname), nil } func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { @@ -258,10 +258,10 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st srvs := make([]*SRV, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) { v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0])) - srvs = append(srvs, &SRV{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]))), v.Port, v.Priority, v.Weight}) + srvs = append(srvs, &SRV{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight}) } byPriorityWeight(srvs).sort() - return absDomainName([]byte(target)), srvs, nil + return absDomainName(target), srvs, nil } func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { @@ -278,7 +278,7 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { mxs := make([]*MX, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) { v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0])) - mxs = append(mxs, &MX{absDomainName([]byte(windows.UTF16PtrToString(v.NameExchange))), v.Preference}) + mxs = append(mxs, &MX{absDomainName(windows.UTF16PtrToString(v.NameExchange)), v.Preference}) } byPref(mxs).sort() return mxs, nil @@ -298,7 +298,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { nss := make([]*NS, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) - nss = append(nss, &NS{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))}) + nss = append(nss, &NS{absDomainName(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))}) } return nss, nil } @@ -344,7 +344,7 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) ptrs := make([]string, 0, 10) for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) { v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0])) - ptrs = append(ptrs, absDomainName([]byte(windows.UTF16PtrToString(v.Host)))) + ptrs = append(ptrs, absDomainName(windows.UTF16PtrToString(v.Host))) } return ptrs, nil } diff --git a/libgo/go/net/lookup_windows_test.go b/libgo/go/net/lookup_windows_test.go index aa95501..9254733 100644 --- a/libgo/go/net/lookup_windows_test.go +++ b/libgo/go/net/lookup_windows_test.go @@ -21,7 +21,7 @@ import ( var nslookupTestServers = []string{"mail.golang.com", "gmail.com"} var lookupTestIPs = []string{"8.8.8.8", "1.1.1.1"} -func toJson(v interface{}) string { +func toJson(v any) string { data, _ := json.Marshal(v) return string(data) } @@ -220,14 +220,14 @@ func nslookupMX(name string) (mx []*MX, err error) { rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+mail exchanger\s*=\s*([0-9]+)\s*([a-z0-9.\-]+)$`) for _, ans := range rx.FindAllStringSubmatch(r, -1) { pref, _, _ := dtoi(ans[2]) - mx = append(mx, &MX{absDomainName([]byte(ans[3])), uint16(pref)}) + mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)}) } // windows nslookup syntax // gmail.com MX preference = 30, mail exchanger = alt3.gmail-smtp-in.l.google.com rx = regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+MX preference\s*=\s*([0-9]+)\s*,\s*mail exchanger\s*=\s*([a-z0-9.\-]+)$`) for _, ans := range rx.FindAllStringSubmatch(r, -1) { pref, _, _ := dtoi(ans[2]) - mx = append(mx, &MX{absDomainName([]byte(ans[3])), uint16(pref)}) + mx = append(mx, &MX{absDomainName(ans[3]), uint16(pref)}) } return } @@ -241,7 +241,7 @@ func nslookupNS(name string) (ns []*NS, err error) { // golang.org nameserver = ns1.google.com. rx := regexp.MustCompile(`(?m)^([a-z0-9.\-]+)\s+nameserver\s*=\s*([a-z0-9.\-]+)$`) for _, ans := range rx.FindAllStringSubmatch(r, -1) { - ns = append(ns, &NS{absDomainName([]byte(ans[2]))}) + ns = append(ns, &NS{absDomainName(ans[2])}) } return } @@ -258,7 +258,7 @@ func nslookupCNAME(name string) (cname string, err error) { for _, ans := range rx.FindAllStringSubmatch(r, -1) { last = ans[2] } - return absDomainName([]byte(last)), nil + return absDomainName(last), nil } func nslookupTXT(name string) (txt []string, err error) { @@ -299,7 +299,7 @@ func lookupPTR(name string) (ptr []string, err error) { ptr = make([]string, 0, 10) rx := regexp.MustCompile(`(?m)^Pinging\s+([a-zA-Z0-9.\-]+)\s+\[.*$`) for _, ans := range rx.FindAllStringSubmatch(r, -1) { - ptr = append(ptr, absDomainName([]byte(ans[1]))) + ptr = append(ptr, absDomainName(ans[1])) } return } diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 47bbf6c..985b6fc 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -35,7 +35,7 @@ var debug = debugT(false) type debugT bool -func (d debugT) Printf(format string, args ...interface{}) { +func (d debugT) Printf(format string, args ...any) { if d { log.Printf(format, args...) } @@ -100,7 +100,7 @@ func ParseDate(date string) (time.Time, error) { dateLayoutsBuildOnce.Do(buildDateLayouts) // CR and LF must match and are tolerated anywhere in the date field. date = strings.ReplaceAll(date, "\r\n", "") - if strings.Index(date, "\r") != -1 { + if strings.Contains(date, "\r") { return time.Time{}, errors.New("mail: header has a CR without LF") } // Re-using some addrParser methods which support obsolete text, i.e. non-printable ASCII diff --git a/libgo/go/net/main_cloexec_test.go b/libgo/go/net/main_cloexec_test.go index 03f7d63..06f0671 100644 --- a/libgo/go/net/main_cloexec_test.go +++ b/libgo/go/net/main_cloexec_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build dragonfly || freebsd || hurd || illumos || linux || netbsd || openbsd -// +build dragonfly freebsd hurd illumos linux netbsd openbsd package net diff --git a/libgo/go/net/main_conf_test.go b/libgo/go/net/main_conf_test.go index 645b267..41b78ed 100644 --- a/libgo/go/net/main_conf_test.go +++ b/libgo/go/net/main_conf_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 && !windows -// +build !js,!plan9,!windows package net diff --git a/libgo/go/net/main_noconf_test.go b/libgo/go/net/main_noconf_test.go index bcea630..ab050fa 100644 --- a/libgo/go/net/main_noconf_test.go +++ b/libgo/go/net/main_noconf_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build (js && wasm) || plan9 || windows -// +build js,wasm plan9 windows package net diff --git a/libgo/go/net/main_posix_test.go b/libgo/go/net/main_posix_test.go index c9ab25a..8899aa9 100644 --- a/libgo/go/net/main_posix_test.go +++ b/libgo/go/net/main_posix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 -// +build !js,!plan9 package net @@ -18,9 +17,9 @@ func enableSocketConnect() { } func disableSocketConnect(network string) { - ss := strings.Split(network, ":") + net, _, _ := strings.Cut(network, ":") sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) { - switch ss[0] { + switch net { case "tcp4": if so.Cookie.Family() == syscall.AF_INET && so.Cookie.Type() == syscall.SOCK_STREAM { return nil, syscall.EHOSTUNREACH diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go index dc17d3f..1ee8c2e 100644 --- a/libgo/go/net/main_test.go +++ b/libgo/go/net/main_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -174,11 +173,8 @@ func runningGoroutines() []string { b := make([]byte, 2<<20) b = b[:runtime.Stack(b, true)] for _, s := range strings.Split(string(b), "\n\n") { - ss := strings.SplitN(s, "\n", 2) - if len(ss) != 2 { - continue - } - stack := strings.TrimSpace(ss[1]) + _, stack, _ := strings.Cut(s, "\n") + stack = strings.TrimSpace(stack) if !strings.Contains(stack, "created by net") { continue } diff --git a/libgo/go/net/main_unix_test.go b/libgo/go/net/main_unix_test.go index 367cefc..402da4d 100644 --- a/libgo/go/net/main_unix_test.go +++ b/libgo/go/net/main_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/mockserver_test.go b/libgo/go/net/mockserver_test.go index b50a1e5..186bd33 100644 --- a/libgo/go/net/mockserver_test.go +++ b/libgo/go/net/mockserver_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -11,46 +10,67 @@ import ( "errors" "fmt" "os" + "path/filepath" "sync" "testing" "time" ) -// testUnixAddr uses os.CreateTemp to get a name that is unique. -func testUnixAddr() string { - f, err := os.CreateTemp("", "go-nettest") +// testUnixAddr uses os.MkdirTemp to get a name that is unique. +func testUnixAddr(t testing.TB) string { + // Pass an empty pattern to get a directory name that is as short as possible. + // If we end up with a name longer than the sun_path field in the sockaddr_un + // struct, we won't be able to make the syscall to open the socket. + d, err := os.MkdirTemp("", "") if err != nil { - panic(err) + t.Fatal(err) } - addr := f.Name() - f.Close() - os.Remove(addr) - return addr + t.Cleanup(func() { + if err := os.RemoveAll(d); err != nil { + t.Error(err) + } + }) + return filepath.Join(d, "sock") } -func newLocalListener(network string) (Listener, error) { +func newLocalListener(t testing.TB, network string) Listener { + listen := func(net, addr string) Listener { + ln, err := Listen(net, addr) + if err != nil { + t.Helper() + t.Fatal(err) + } + return ln + } + switch network { case "tcp": if supportsIPv4() { + if !supportsIPv6() { + return listen("tcp4", "127.0.0.1:0") + } if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil { - return ln, nil + return ln } } if supportsIPv6() { - return Listen("tcp6", "[::1]:0") + return listen("tcp6", "[::1]:0") } case "tcp4": if supportsIPv4() { - return Listen("tcp4", "127.0.0.1:0") + return listen("tcp4", "127.0.0.1:0") } case "tcp6": if supportsIPv6() { - return Listen("tcp6", "[::1]:0") + return listen("tcp6", "[::1]:0") } case "unix", "unixpacket": - return Listen(network, testUnixAddr()) + return listen(network, testUnixAddr(t)) } - return nil, fmt.Errorf("%s is not supported", network) + + t.Helper() + t.Fatalf("%s is not supported", network) + return nil } func newDualStackListener() (lns []*TCPListener, err error) { @@ -121,12 +141,10 @@ func (ls *localServer) teardown() error { return nil } -func newLocalServer(network string) (*localServer, error) { - ln, err := newLocalListener(network) - if err != nil { - return nil, err - } - return &localServer{Listener: ln, done: make(chan bool)}, nil +func newLocalServer(t testing.TB, network string) *localServer { + t.Helper() + ln := newLocalListener(t, network) + return &localServer{Listener: ln, done: make(chan bool)} } type streamListener struct { @@ -135,8 +153,8 @@ type streamListener struct { done chan bool // signal that indicates server stopped } -func (sl *streamListener) newLocalServer() (*localServer, error) { - return &localServer{Listener: sl.Listener, done: make(chan bool)}, nil +func (sl *streamListener) newLocalServer() *localServer { + return &localServer{Listener: sl.Listener, done: make(chan bool)} } type dualStackServer struct { @@ -288,75 +306,39 @@ func transceiver(c Conn, wb []byte, ch chan<- error) { } } -func timeoutReceiver(c Conn, d, min, max time.Duration, ch chan<- error) { - var err error - defer func() { ch <- err }() - - t0 := time.Now() - if err = c.SetReadDeadline(time.Now().Add(d)); err != nil { - return - } - b := make([]byte, 256) - var n int - n, err = c.Read(b) - t1 := time.Now() - if n != 0 || err == nil || !err.(Error).Timeout() { - err = fmt.Errorf("Read did not return (0, timeout): (%d, %v)", n, err) - return - } - if dt := t1.Sub(t0); min > dt || dt > max && !testing.Short() { - err = fmt.Errorf("Read took %s; expected %s", dt, d) - return - } -} - -func timeoutTransmitter(c Conn, d, min, max time.Duration, ch chan<- error) { - var err error - defer func() { ch <- err }() - - t0 := time.Now() - if err = c.SetWriteDeadline(time.Now().Add(d)); err != nil { - return - } - var n int - for { - n, err = c.Write([]byte("TIMEOUT TRANSMITTER")) +func newLocalPacketListener(t testing.TB, network string) PacketConn { + listenPacket := func(net, addr string) PacketConn { + c, err := ListenPacket(net, addr) if err != nil { - break + t.Helper() + t.Fatal(err) } + return c } - t1 := time.Now() - if err == nil || !err.(Error).Timeout() { - err = fmt.Errorf("Write did not return (any, timeout): (%d, %v)", n, err) - return - } - if dt := t1.Sub(t0); min > dt || dt > max && !testing.Short() { - err = fmt.Errorf("Write took %s; expected %s", dt, d) - return - } -} -func newLocalPacketListener(network string) (PacketConn, error) { switch network { case "udp": if supportsIPv4() { - return ListenPacket("udp4", "127.0.0.1:0") + return listenPacket("udp4", "127.0.0.1:0") } if supportsIPv6() { - return ListenPacket("udp6", "[::1]:0") + return listenPacket("udp6", "[::1]:0") } case "udp4": if supportsIPv4() { - return ListenPacket("udp4", "127.0.0.1:0") + return listenPacket("udp4", "127.0.0.1:0") } case "udp6": if supportsIPv6() { - return ListenPacket("udp6", "[::1]:0") + return listenPacket("udp6", "[::1]:0") } case "unixgram": - return ListenPacket(network, testUnixAddr()) + return listenPacket(network, testUnixAddr(t)) } - return nil, fmt.Errorf("%s is not supported", network) + + t.Helper() + t.Fatalf("%s is not supported", network) + return nil } func newDualStackPacketListener() (cs []*UDPConn, err error) { @@ -421,20 +403,18 @@ func (ls *localPacketServer) teardown() error { return nil } -func newLocalPacketServer(network string) (*localPacketServer, error) { - c, err := newLocalPacketListener(network) - if err != nil { - return nil, err - } - return &localPacketServer{PacketConn: c, done: make(chan bool)}, nil +func newLocalPacketServer(t testing.TB, network string) *localPacketServer { + t.Helper() + c := newLocalPacketListener(t, network) + return &localPacketServer{PacketConn: c, done: make(chan bool)} } type packetListener struct { PacketConn } -func (pl *packetListener) newLocalServer() (*localPacketServer, error) { - return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}, nil +func (pl *packetListener) newLocalServer() *localPacketServer { + return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)} } func packetTransponder(c PacketConn, ch chan<- error) { @@ -505,25 +485,3 @@ func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) { ch <- fmt.Errorf("read %d; want %d", n, len(wb)) } } - -func timeoutPacketReceiver(c PacketConn, d, min, max time.Duration, ch chan<- error) { - var err error - defer func() { ch <- err }() - - t0 := time.Now() - if err = c.SetReadDeadline(time.Now().Add(d)); err != nil { - return - } - b := make([]byte, 256) - var n int - n, _, err = c.ReadFrom(b) - t1 := time.Now() - if n != 0 || err == nil || !err.(Error).Timeout() { - err = fmt.Errorf("ReadFrom did not return (0, timeout): (%d, %v)", n, err) - return - } - if dt := t1.Sub(t0); min > dt || dt > max && !testing.Short() { - err = fmt.Errorf("ReadFrom took %s; expected %s", dt, d) - return - } -} diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index a7c65ff..77e54a9 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -125,10 +125,10 @@ type Conn interface { // Any blocked Read or Write operations will be unblocked and return errors. Close() error - // LocalAddr returns the local network address. + // LocalAddr returns the local network address, if known. LocalAddr() Addr - // RemoteAddr returns the remote network address. + // RemoteAddr returns the remote network address, if known. RemoteAddr() Addr // SetDeadline sets the read and write deadlines associated @@ -328,7 +328,7 @@ type PacketConn interface { // Any blocked ReadFrom or WriteTo operations will be unblocked and return errors. Close() error - // LocalAddr returns the local network address. + // LocalAddr returns the local network address, if known. LocalAddr() Addr // SetDeadline sets the read and write deadlines associated @@ -396,8 +396,12 @@ type Listener interface { // An Error represents a network error. type Error interface { error - Timeout() bool // Is the error a timeout? - Temporary() bool // Is the error temporary? + Timeout() bool // Is the error a timeout? + + // Deprecated: Temporary errors are not well-defined. + // Most "temporary" errors are timeouts, and the few exceptions are surprising. + // Do not use this method. + Temporary() bool } // Various errors contained in OpError. diff --git a/libgo/go/net/net_fake.go b/libgo/go/net/net_fake.go index 74fc1da..ee5644c 100644 --- a/libgo/go/net/net_fake.go +++ b/libgo/go/net/net_fake.go @@ -5,7 +5,6 @@ // Fake networking for js/wasm. It is intended to allow tests of other package to pass. //go:build js && wasm -// +build js,wasm package net @@ -266,16 +265,48 @@ func sysSocket(family, sotype, proto int) (int, error) { func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return 0, nil, syscall.ENOSYS + +} +func (fd *netFD) readFromInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { + return 0, syscall.ENOSYS +} + +func (fd *netFD) readFromInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { + return 0, syscall.ENOSYS } func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { return 0, 0, 0, nil, syscall.ENOSYS } +func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { + return 0, 0, 0, syscall.ENOSYS +} + +func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { + return 0, 0, 0, syscall.ENOSYS +} + +func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { + return 0, 0, syscall.ENOSYS +} + +func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { + return 0, 0, syscall.ENOSYS +} + func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { return 0, syscall.ENOSYS } +func (fd *netFD) writeToInet4(p []byte, sa *syscall.SockaddrInet4) (n int, err error) { + return 0, syscall.ENOSYS +} + +func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err error) { + return 0, syscall.ENOSYS +} + func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { return 0, 0, syscall.ENOSYS } diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go index 6e7be4d..7b16991 100644 --- a/libgo/go/net/net_test.go +++ b/libgo/go/net/net_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -34,10 +33,7 @@ func TestCloseRead(t *testing.T) { } t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, network) switch network { case "unix", "unixpacket": defer os.Remove(ln.Addr().String()) @@ -133,10 +129,7 @@ func TestCloseWrite(t *testing.T) { } } - ls, err := newLocalServer(network) - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, network) defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -190,10 +183,7 @@ func TestConnClose(t *testing.T) { } t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, network) switch network { case "unix", "unixpacket": defer os.Remove(ln.Addr().String()) @@ -235,16 +225,12 @@ func TestListenerClose(t *testing.T) { } t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, network) switch network { case "unix", "unixpacket": defer os.Remove(ln.Addr().String()) } - dst := ln.Addr().String() if err := ln.Close(); err != nil { if perr := parseCloseError(err, false); perr != nil { t.Error(perr) @@ -257,28 +243,12 @@ func TestListenerClose(t *testing.T) { t.Fatal("should fail") } - if network == "tcp" { - // We will have two TCP FSMs inside the - // kernel here. There's no guarantee that a - // signal comes from the far end FSM will be - // delivered immediately to the near end FSM, - // especially on the platforms that allow - // multiple consumer threads to pull pending - // established connections at the same time by - // enabling SO_REUSEPORT option such as Linux, - // DragonFly BSD. So we need to give some time - // quantum to the kernel. - // - // Note that net.inet.tcp.reuseport_ext=1 by - // default on DragonFly BSD. - time.Sleep(time.Millisecond) - - cc, err := Dial("tcp", dst) - if err == nil { - t.Error("Dial to closed TCP listener succeeded.") - cc.Close() - } - } + // Note: we cannot ensure that a subsequent Dial does not succeed, because + // we do not in general have any guarantee that ln.Addr is not immediately + // reused. (TCP sockets enter a TIME_WAIT state when closed, but that only + // applies to existing connections for the port โ it does not prevent the + // port itself from being used for entirely new connections in the + // meantime.) }) } } @@ -293,10 +263,7 @@ func TestPacketConnClose(t *testing.T) { } t.Parallel() - c, err := newLocalPacketListener(network) - if err != nil { - t.Fatal(err) - } + c := newLocalPacketListener(t, network) switch network { case "unixgram": defer os.Remove(c.LocalAddr().String()) @@ -321,18 +288,17 @@ func TestPacketConnClose(t *testing.T) { func TestListenCloseListen(t *testing.T) { const maxTries = 10 for tries := 0; tries < maxTries; tries++ { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") addr := ln.Addr().String() + // TODO: This is racy. The selected address could be reused in between this + // Close and the subsequent Listen. if err := ln.Close(); err != nil { if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Fatal(err) } - ln, err = Listen("tcp", addr) + ln, err := Listen("tcp", addr) if err == nil { // Success. (This test didn't always make it here earlier.) ln.Close() @@ -378,10 +344,7 @@ func TestAcceptIgnoreAbortedConnRequest(t *testing.T) { } c.Close() } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -408,10 +371,7 @@ func TestZeroByteRead(t *testing.T) { } t.Parallel() - ln, err := newLocalListener(network) - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, network) connc := make(chan Conn, 1) go func() { defer ln.Close() @@ -460,10 +420,7 @@ func TestZeroByteRead(t *testing.T) { // runs peer1 and peer2 concurrently. withTCPConnPair returns when // both have completed. func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 2) go func() { diff --git a/libgo/go/net/netip/export_test.go b/libgo/go/net/netip/export_test.go new file mode 100644 index 0000000..59971fa --- /dev/null +++ b/libgo/go/net/netip/export_test.go @@ -0,0 +1,30 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import "internal/intern" + +var ( + Z0 = z0 + Z4 = z4 + Z6noz = z6noz +) + +type Uint128 = uint128 + +func Mk128(hi, lo uint64) Uint128 { + return uint128{hi, lo} +} + +func MkAddr(u Uint128, z *intern.Value) Addr { + return Addr{u, z} +} + +func IPv4(a, b, c, d uint8) Addr { return AddrFrom4([4]byte{a, b, c, d}) } + +var TestAppendToMarshal = testAppendToMarshal + +func (a Addr) IsZero() bool { return a.isZero() } +func (p Prefix) IsZero() bool { return p.isZero() } diff --git a/libgo/go/net/netip/fuzz_test.go b/libgo/go/net/netip/fuzz_test.go new file mode 100644 index 0000000..4edbcf6 --- /dev/null +++ b/libgo/go/net/netip/fuzz_test.go @@ -0,0 +1,353 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build ignore_due_to_generics + +package netip_test + +import ( + "bytes" + "encoding" + "fmt" + "net" + . "net/netip" + "reflect" + "strings" + "testing" +) + +var corpus = []string{ + // Basic zero IPv4 address. + "0.0.0.0", + // Basic non-zero IPv4 address. + "192.168.140.255", + // IPv4 address in windows-style "print all the digits" form. + "010.000.015.001", + // IPv4 address with a silly amount of leading zeros. + "000001.00000002.00000003.000000004", + // 4-in-6 with octet with leading zero + "::ffff:1.2.03.4", + // Basic zero IPv6 address. + "::", + // Localhost IPv6. + "::1", + // Fully expanded IPv6 address. + "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", + // IPv6 with elided fields in the middle. + "fd7a:115c::626b:430b", + // IPv6 with elided fields at the end. + "fd7a:115c:a1e0:ab12:4843:cd96::", + // IPv6 with single elided field at the end. + "fd7a:115c:a1e0:ab12:4843:cd96:626b::", + "fd7a:115c:a1e0:ab12:4843:cd96:626b:0", + // IPv6 with single elided field in the middle. + "fd7a:115c:a1e0::4843:cd96:626b:430b", + "fd7a:115c:a1e0:0:4843:cd96:626b:430b", + // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6) + "::ffff:192.168.140.255", + "::ffff:192.168.140.255", + // IPv6 with a zone specifier. + "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0", + // IPv6 with dotted decimal and zone specifier. + "1:2::ffff:192.168.140.255%eth1", + "1:2::ffff:c0a8:8cff%eth1", + // IPv6 with capital letters. + "FD9E:1A04:F01D::1", + "fd9e:1a04:f01d::1", + // Empty string. + "", + // Garbage non-IP. + "bad", + // Single number. Some parsers accept this as an IPv4 address in + // big-endian uint32 form, but we don't. + "1234", + // IPv4 with a zone specifier. + "1.2.3.4%eth0", + // IPv4 field must have at least one digit. + ".1.2.3", + "1.2.3.", + "1..2.3", + // IPv4 address too long. + "1.2.3.4.5", + // IPv4 in dotted octal form. + "0300.0250.0214.0377", + // IPv4 in dotted hex form. + "0xc0.0xa8.0x8c.0xff", + // IPv4 in class B form. + "192.168.12345", + // IPv4 in class B form, with a small enough number to be + // parseable as a regular dotted decimal field. + "127.0.1", + // IPv4 in class A form. + "192.1234567", + // IPv4 in class A form, with a small enough number to be + // parseable as a regular dotted decimal field. + "127.1", + // IPv4 field has value >255. + "192.168.300.1", + // IPv4 with too many fields. + "192.168.0.1.5.6", + // IPv6 with not enough fields. + "1:2:3:4:5:6:7", + // IPv6 with too many fields. + "1:2:3:4:5:6:7:8:9", + // IPv6 with 8 fields and a :: expander. + "1:2:3:4::5:6:7:8", + // IPv6 with a field bigger than 2b. + "fe801::1", + // IPv6 with non-hex values in field. + "fe80:tail:scal:e::", + // IPv6 with a zone delimiter but no zone. + "fe80::1%", + // IPv6 with a zone specifier of zero. + "::ffff:0:0%0", + // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4. + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255", + // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4. + "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255", + // IPv6 with invalid embedded IPv4. + "::ffff:192.168.140.bad", + // IPv6 with multiple ellipsis ::. + "fe80::1::1", + // IPv6 with invalid non hex/colon character. + "fe80:1?:1", + // IPv6 with truncated bytes after single colon. + "fe80:", + // AddrPort strings. + "1.2.3.4:51820", + "[fd7a:115c:a1e0:ab12:4843:cd96:626b:430b]:80", + "[::ffff:c000:0280]:65535", + "[::ffff:c000:0280%eth0]:1", + // Prefix strings. + "1.2.3.4/24", + "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118", + "::ffff:c000:0280/96", + "::ffff:c000:0280%eth0/37", +} + +func FuzzParse(f *testing.F) { + for _, seed := range corpus { + f.Add(seed) + } + + f.Fuzz(func(t *testing.T, s string) { + ip, _ := ParseAddr(s) + checkStringParseRoundTrip(t, ip, ParseAddr) + checkEncoding(t, ip) + + // Check that we match the net's IP parser, modulo zones. + if !strings.Contains(s, "%") { + stdip := net.ParseIP(s) + if !ip.IsValid() != (stdip == nil) { + t.Errorf("ParseAddr zero != net.ParseIP nil: ip=%q stdip=%q", ip, stdip) + } + + if ip.IsValid() && !ip.Is4In6() { + buf, err := ip.MarshalText() + if err != nil { + t.Fatal(err) + } + buf2, err := stdip.MarshalText() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, buf2) { + t.Errorf("Addr.MarshalText() != net.IP.MarshalText(): ip=%q stdip=%q", ip, stdip) + } + if ip.String() != stdip.String() { + t.Errorf("Addr.String() != net.IP.String(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsGlobalUnicast() != stdip.IsGlobalUnicast() { + t.Errorf("Addr.IsGlobalUnicast() != net.IP.IsGlobalUnicast(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsInterfaceLocalMulticast() != stdip.IsInterfaceLocalMulticast() { + t.Errorf("Addr.IsInterfaceLocalMulticast() != net.IP.IsInterfaceLocalMulticast(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsLinkLocalMulticast() != stdip.IsLinkLocalMulticast() { + t.Errorf("Addr.IsLinkLocalMulticast() != net.IP.IsLinkLocalMulticast(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsLinkLocalUnicast() != stdip.IsLinkLocalUnicast() { + t.Errorf("Addr.IsLinkLocalUnicast() != net.IP.IsLinkLocalUnicast(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsLoopback() != stdip.IsLoopback() { + t.Errorf("Addr.IsLoopback() != net.IP.IsLoopback(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsMulticast() != stdip.IsMulticast() { + t.Errorf("Addr.IsMulticast() != net.IP.IsMulticast(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsPrivate() != stdip.IsPrivate() { + t.Errorf("Addr.IsPrivate() != net.IP.IsPrivate(): ip=%q stdip=%q", ip, stdip) + } + if ip.IsUnspecified() != stdip.IsUnspecified() { + t.Errorf("Addr.IsUnspecified() != net.IP.IsUnspecified(): ip=%q stdip=%q", ip, stdip) + } + } + } + + // Check that .Next().Prev() and .Prev().Next() preserve the IP. + if ip.IsValid() && ip.Next().IsValid() && ip.Next().Prev() != ip { + t.Errorf(".Next.Prev did not round trip: ip=%q .next=%q .next.prev=%q", ip, ip.Next(), ip.Next().Prev()) + } + if ip.IsValid() && ip.Prev().IsValid() && ip.Prev().Next() != ip { + t.Errorf(".Prev.Next did not round trip: ip=%q .prev=%q .prev.next=%q", ip, ip.Prev(), ip.Prev().Next()) + } + + port, err := ParseAddrPort(s) + if err == nil { + checkStringParseRoundTrip(t, port, ParseAddrPort) + checkEncoding(t, port) + } + port = AddrPortFrom(ip, 80) + checkStringParseRoundTrip(t, port, ParseAddrPort) + checkEncoding(t, port) + + ipp, err := ParsePrefix(s) + if err == nil { + checkStringParseRoundTrip(t, ipp, ParsePrefix) + checkEncoding(t, ipp) + } + ipp = PrefixFrom(ip, 8) + checkStringParseRoundTrip(t, ipp, ParsePrefix) + checkEncoding(t, ipp) + }) +} + +// checkTextMarshaler checks that x's MarshalText and UnmarshalText functions round trip correctly. +func checkTextMarshaler(t *testing.T, x encoding.TextMarshaler) { + buf, err := x.MarshalText() + if err != nil { + t.Fatal(err) + } + y := reflect.New(reflect.TypeOf(x)).Interface().(encoding.TextUnmarshaler) + err = y.UnmarshalText(buf) + if err != nil { + t.Logf("(%v).MarshalText() = %q", x, buf) + t.Fatalf("(%T).UnmarshalText(%q) = %v", y, buf, err) + } + e := reflect.ValueOf(y).Elem().Interface() + if !reflect.DeepEqual(x, e) { + t.Logf("(%v).MarshalText() = %q", x, buf) + t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y) + t.Fatalf("MarshalText/UnmarshalText failed to round trip: %#v != %#v", x, e) + } + buf2, err := y.(encoding.TextMarshaler).MarshalText() + if err != nil { + t.Logf("(%v).MarshalText() = %q", x, buf) + t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y) + t.Fatalf("failed to MarshalText a second time: %v", err) + } + if !bytes.Equal(buf, buf2) { + t.Logf("(%v).MarshalText() = %q", x, buf) + t.Logf("(%T).UnmarshalText(%q) = %v", y, buf, y) + t.Logf("(%v).MarshalText() = %q", y, buf2) + t.Fatalf("second MarshalText differs from first: %q != %q", buf, buf2) + } +} + +// checkBinaryMarshaler checks that x's MarshalText and UnmarshalText functions round trip correctly. +func checkBinaryMarshaler(t *testing.T, x encoding.BinaryMarshaler) { + buf, err := x.MarshalBinary() + if err != nil { + t.Fatal(err) + } + y := reflect.New(reflect.TypeOf(x)).Interface().(encoding.BinaryUnmarshaler) + err = y.UnmarshalBinary(buf) + if err != nil { + t.Logf("(%v).MarshalBinary() = %q", x, buf) + t.Fatalf("(%T).UnmarshalBinary(%q) = %v", y, buf, err) + } + e := reflect.ValueOf(y).Elem().Interface() + if !reflect.DeepEqual(x, e) { + t.Logf("(%v).MarshalBinary() = %q", x, buf) + t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y) + t.Fatalf("MarshalBinary/UnmarshalBinary failed to round trip: %#v != %#v", x, e) + } + buf2, err := y.(encoding.BinaryMarshaler).MarshalBinary() + if err != nil { + t.Logf("(%v).MarshalBinary() = %q", x, buf) + t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y) + t.Fatalf("failed to MarshalBinary a second time: %v", err) + } + if !bytes.Equal(buf, buf2) { + t.Logf("(%v).MarshalBinary() = %q", x, buf) + t.Logf("(%T).UnmarshalBinary(%q) = %v", y, buf, y) + t.Logf("(%v).MarshalBinary() = %q", y, buf2) + t.Fatalf("second MarshalBinary differs from first: %q != %q", buf, buf2) + } +} + +func checkTextMarshalMatchesString(t *testing.T, x netipType) { + buf, err := x.MarshalText() + if err != nil { + t.Fatal(err) + } + str := x.String() + if string(buf) != str { + t.Fatalf("%v: MarshalText = %q, String = %q", x, buf, str) + } +} + +type appendMarshaler interface { + encoding.TextMarshaler + AppendTo([]byte) []byte +} + +// checkTextMarshalMatchesAppendTo checks that x's MarshalText matches x's AppendTo. +func checkTextMarshalMatchesAppendTo(t *testing.T, x appendMarshaler) { + buf, err := x.MarshalText() + if err != nil { + t.Fatal(err) + } + + buf2 := make([]byte, 0, len(buf)) + buf2 = x.AppendTo(buf2) + if !bytes.Equal(buf, buf2) { + t.Fatalf("%v: MarshalText = %q, AppendTo = %q", x, buf, buf2) + } +} + +type netipType interface { + encoding.BinaryMarshaler + encoding.TextMarshaler + fmt.Stringer + IsValid() bool +} + +type netipTypeCmp interface { + comparable + netipType +} + +// checkStringParseRoundTrip checks that x's String method and the provided parse function can round trip correctly. +func checkStringParseRoundTrip[P netipTypeCmp](t *testing.T, x P, parse func(string) (P, error)) { + if !x.IsValid() { + // Ignore invalid values. + return + } + + s := x.String() + y, err := parse(s) + if err != nil { + t.Fatalf("s=%q err=%v", s, err) + } + if x != y { + t.Fatalf("%T round trip identity failure: s=%q x=%#v y=%#v", x, s, x, y) + } + s2 := y.String() + if s != s2 { + t.Fatalf("%T String round trip identity failure: s=%#v s2=%#v", x, s, s2) + } +} + +func checkEncoding(t *testing.T, x netipType) { + if x.IsValid() { + checkTextMarshaler(t, x) + checkBinaryMarshaler(t, x) + checkTextMarshalMatchesString(t, x) + } + + if am, ok := x.(appendMarshaler); ok { + checkTextMarshalMatchesAppendTo(t, am) + } +} diff --git a/libgo/go/net/netip/inlining_test.go b/libgo/go/net/netip/inlining_test.go new file mode 100644 index 0000000..107fe1f --- /dev/null +++ b/libgo/go/net/netip/inlining_test.go @@ -0,0 +1,110 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import ( + "internal/testenv" + "os/exec" + "path/filepath" + "regexp" + "runtime" + "strings" + "testing" +) + +func TestInlining(t *testing.T) { + testenv.MustHaveGoBuild(t) + t.Parallel() + var exe string + if runtime.GOOS == "windows" { + exe = ".exe" + } + out, err := exec.Command( + filepath.Join(runtime.GOROOT(), "bin", "go"+exe), + "build", + "--gcflags=-m", + "net/netip").CombinedOutput() + if err != nil { + t.Fatalf("go build: %v, %s", err, out) + } + got := map[string]bool{} + regexp.MustCompile(` can inline (\S+)`).ReplaceAllFunc(out, func(match []byte) []byte { + got[strings.TrimPrefix(string(match), " can inline ")] = true + return nil + }) + wantInlinable := []string{ + "(*uint128).halves", + "Addr.BitLen", + "Addr.hasZone", + "Addr.Is4", + "Addr.Is4In6", + "Addr.Is6", + "Addr.IsLoopback", + "Addr.IsMulticast", + "Addr.IsInterfaceLocalMulticast", + "Addr.IsValid", + "Addr.IsUnspecified", + "Addr.Less", + "Addr.lessOrEq", + "Addr.Unmap", + "Addr.Zone", + "Addr.v4", + "Addr.v6", + "Addr.v6u16", + "Addr.withoutZone", + "AddrPortFrom", + "AddrPort.Addr", + "AddrPort.Port", + "AddrPort.IsValid", + "Prefix.IsSingleIP", + "Prefix.Masked", + "Prefix.IsValid", + "PrefixFrom", + "Prefix.Addr", + "Prefix.Bits", + "AddrFrom4", + "IPv6LinkLocalAllNodes", + "IPv6Unspecified", + "MustParseAddr", + "MustParseAddrPort", + "MustParsePrefix", + "appendDecimal", + "appendHex", + "uint128.addOne", + "uint128.and", + "uint128.bitsClearedFrom", + "uint128.bitsSetFrom", + "uint128.isZero", + "uint128.not", + "uint128.or", + "uint128.subOne", + "uint128.xor", + } + switch runtime.GOARCH { + case "amd64", "arm64": + // These don't inline on 32-bit. + wantInlinable = append(wantInlinable, + "u64CommonPrefixLen", + "uint128.commonPrefixLen", + "Addr.Next", + "Addr.Prev", + ) + } + + for _, want := range wantInlinable { + if !got[want] { + t.Errorf("%q is no longer inlinable", want) + continue + } + delete(got, want) + } + for sym := range got { + if strings.Contains(sym, ".func") { + continue + } + t.Logf("not in expected set, but also inlinable: %q", sym) + + } +} diff --git a/libgo/go/net/netip/leaf_alts.go b/libgo/go/net/netip/leaf_alts.go new file mode 100644 index 0000000..70513ab --- /dev/null +++ b/libgo/go/net/netip/leaf_alts.go @@ -0,0 +1,54 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Stuff that exists in std, but we can't use due to being a dependency +// of net, for go/build deps_test policy reasons. + +package netip + +func stringsLastIndexByte(s string, b byte) int { + for i := len(s) - 1; i >= 0; i-- { + if s[i] == b { + return i + } + } + return -1 +} + +func beUint64(b []byte) uint64 { + _ = b[7] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 | + uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56 +} + +func bePutUint64(b []byte, v uint64) { + _ = b[7] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} + +func bePutUint32(b []byte, v uint32) { + _ = b[3] // early bounds check to guarantee safety of writes below + b[0] = byte(v >> 24) + b[1] = byte(v >> 16) + b[2] = byte(v >> 8) + b[3] = byte(v) +} + +func leUint16(b []byte) uint16 { + _ = b[1] // bounds check hint to compiler; see golang.org/issue/14808 + return uint16(b[0]) | uint16(b[1])<<8 +} + +func lePutUint16(b []byte, v uint16) { + _ = b[1] // early bounds check to guarantee safety of writes below + b[0] = byte(v) + b[1] = byte(v >> 8) +} diff --git a/libgo/go/net/netip/netip.go b/libgo/go/net/netip/netip.go new file mode 100644 index 0000000..591d38a --- /dev/null +++ b/libgo/go/net/netip/netip.go @@ -0,0 +1,1498 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package netip defines an IP address type that's a small value type. +// Building on that Addr type, the package also defines AddrPort (an +// IP address and a port), and Prefix (an IP address and a bit length +// prefix). +// +// Compared to the net.IP type, this package's Addr type takes less +// memory, is immutable, and is comparable (supports == and being a +// map key). +package netip + +import ( + "errors" + "math" + "strconv" + + "internal/bytealg" + "internal/intern" + "internal/itoa" +) + +// Sizes: (64-bit) +// net.IP: 24 byte slice header + {4, 16} = 28 to 40 bytes +// net.IPAddr: 40 byte slice header + {4, 16} = 44 to 56 bytes + zone length +// netip.Addr: 24 bytes (zone is per-name singleton, shared across all users) + +// Addr represents an IPv4 or IPv6 address (with or without a scoped +// addressing zone), similar to net.IP or net.IPAddr. +// +// Unlike net.IP or net.IPAddr, Addr is a comparable value +// type (it supports == and can be a map key) and is immutable. +// +// The zero Addr is not a valid IP address. +// Addr{} is distinct from both 0.0.0.0 and ::. +type Addr struct { + // addr is the hi and lo bits of an IPv6 address. If z==z4, + // hi and lo contain the IPv4-mapped IPv6 address. + // + // hi and lo are constructed by interpreting a 16-byte IPv6 + // address as a big-endian 128-bit number. The most significant + // bits of that number go into hi, the rest into lo. + // + // For example, 0011:2233:4455:6677:8899:aabb:ccdd:eeff is stored as: + // addr.hi = 0x0011223344556677 + // addr.lo = 0x8899aabbccddeeff + // + // We store IPs like this, rather than as [16]byte, because it + // turns most operations on IPs into arithmetic and bit-twiddling + // operations on 64-bit registers, which is much faster than + // bytewise processing. + addr uint128 + + // z is a combination of the address family and the IPv6 zone. + // + // nil means invalid IP address (for a zero Addr). + // z4 means an IPv4 address. + // z6noz means an IPv6 address without a zone. + // + // Otherwise it's the interned zone name string. + z *intern.Value +} + +// z0, z4, and z6noz are sentinel IP.z values. +// See the IP type's field docs. +var ( + z0 = (*intern.Value)(nil) + z4 = new(intern.Value) + z6noz = new(intern.Value) +) + +// IPv6LinkLocalAllNodes returns the IPv6 link-local all nodes multicast +// address ff02::1. +func IPv6LinkLocalAllNodes() Addr { return AddrFrom16([16]byte{0: 0xff, 1: 0x02, 15: 0x01}) } + +// IPv6Unspecified returns the IPv6 unspecified address "::". +func IPv6Unspecified() Addr { return Addr{z: z6noz} } + +// IPv4Unspecified returns the IPv4 unspecified address "0.0.0.0". +func IPv4Unspecified() Addr { return AddrFrom4([4]byte{}) } + +// AddrFrom4 returns the address of the IPv4 address given by the bytes in addr. +func AddrFrom4(addr [4]byte) Addr { + return Addr{ + addr: uint128{0, 0xffff00000000 | uint64(addr[0])<<24 | uint64(addr[1])<<16 | uint64(addr[2])<<8 | uint64(addr[3])}, + z: z4, + } +} + +// AddrFrom16 returns the IPv6 address given by the bytes in addr. +// An IPv6-mapped IPv4 address is left as an IPv6 address. +// (Use Unmap to convert them if needed.) +func AddrFrom16(addr [16]byte) Addr { + return Addr{ + addr: uint128{ + beUint64(addr[:8]), + beUint64(addr[8:]), + }, + z: z6noz, + } +} + +// ipv6Slice is like IPv6Raw, but operates on a 16-byte slice. Assumes +// slice is 16 bytes, caller must enforce this. +func ipv6Slice(addr []byte) Addr { + return Addr{ + addr: uint128{ + beUint64(addr[:8]), + beUint64(addr[8:]), + }, + z: z6noz, + } +} + +// ParseAddr parses s as an IP address, returning the result. The string +// s can be in dotted decimal ("192.0.2.1"), IPv6 ("2001:db8::68"), +// or IPv6 with a scoped addressing zone ("fe80::1cc0:3e8c:119f:c2e1%ens18"). +func ParseAddr(s string) (Addr, error) { + for i := 0; i < len(s); i++ { + switch s[i] { + case '.': + return parseIPv4(s) + case ':': + return parseIPv6(s) + case '%': + // Assume that this was trying to be an IPv6 address with + // a zone specifier, but the address is missing. + return Addr{}, parseAddrError{in: s, msg: "missing IPv6 address"} + } + } + return Addr{}, parseAddrError{in: s, msg: "unable to parse IP"} +} + +// MustParseAddr calls ParseAddr(s) and panics on error. +// It is intended for use in tests with hard-coded strings. +func MustParseAddr(s string) Addr { + ip, err := ParseAddr(s) + if err != nil { + panic(err) + } + return ip +} + +type parseAddrError struct { + in string // the string given to ParseAddr + msg string // an explanation of the parse failure + at string // optionally, the unparsed portion of in at which the error occurred. +} + +func (err parseAddrError) Error() string { + q := strconv.Quote + if err.at != "" { + return "ParseAddr(" + q(err.in) + "): " + err.msg + " (at " + q(err.at) + ")" + } + return "ParseAddr(" + q(err.in) + "): " + err.msg +} + +// parseIPv4 parses s as an IPv4 address (in form "192.168.0.1"). +func parseIPv4(s string) (ip Addr, err error) { + var fields [4]uint8 + var val, pos int + var digLen int // number of digits in current octet + for i := 0; i < len(s); i++ { + if s[i] >= '0' && s[i] <= '9' { + if digLen == 1 && val == 0 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 field has octet with leading zero"} + } + val = val*10 + int(s[i]) - '0' + digLen++ + if val > 255 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 field has value >255"} + } + } else if s[i] == '.' { + // .1.2.3 + // 1.2.3. + // 1..2.3 + if i == 0 || i == len(s)-1 || s[i-1] == '.' { + return Addr{}, parseAddrError{in: s, msg: "IPv4 field must have at least one digit", at: s[i:]} + } + // 1.2.3.4.5 + if pos == 3 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 address too long"} + } + fields[pos] = uint8(val) + pos++ + val = 0 + digLen = 0 + } else { + return Addr{}, parseAddrError{in: s, msg: "unexpected character", at: s[i:]} + } + } + if pos < 3 { + return Addr{}, parseAddrError{in: s, msg: "IPv4 address too short"} + } + fields[3] = uint8(val) + return AddrFrom4(fields), nil +} + +// parseIPv6 parses s as an IPv6 address (in form "2001:db8::68"). +func parseIPv6(in string) (Addr, error) { + s := in + + // Split off the zone right from the start. Yes it's a second scan + // of the string, but trying to handle it inline makes a bunch of + // other inner loop conditionals more expensive, and it ends up + // being slower. + zone := "" + i := bytealg.IndexByteString(s, '%') + if i != -1 { + s, zone = s[:i], s[i+1:] + if zone == "" { + // Not allowed to have an empty zone if explicitly specified. + return Addr{}, parseAddrError{in: in, msg: "zone must be a non-empty string"} + } + } + + var ip [16]byte + ellipsis := -1 // position of ellipsis in ip + + // Might have leading ellipsis + if len(s) >= 2 && s[0] == ':' && s[1] == ':' { + ellipsis = 0 + s = s[2:] + // Might be only ellipsis + if len(s) == 0 { + return IPv6Unspecified().WithZone(zone), nil + } + } + + // Loop, parsing hex numbers followed by colon. + i = 0 + for i < 16 { + // Hex number. Similar to parseIPv4, inlining the hex number + // parsing yields a significant performance increase. + off := 0 + acc := uint32(0) + for ; off < len(s); off++ { + c := s[off] + if c >= '0' && c <= '9' { + acc = (acc << 4) + uint32(c-'0') + } else if c >= 'a' && c <= 'f' { + acc = (acc << 4) + uint32(c-'a'+10) + } else if c >= 'A' && c <= 'F' { + acc = (acc << 4) + uint32(c-'A'+10) + } else { + break + } + if acc > math.MaxUint16 { + // Overflow, fail. + return Addr{}, parseAddrError{in: in, msg: "IPv6 field has value >=2^16", at: s} + } + } + if off == 0 { + // No digits found, fail. + return Addr{}, parseAddrError{in: in, msg: "each colon-separated field must have at least one digit", at: s} + } + + // If followed by dot, might be in trailing IPv4. + if off < len(s) && s[off] == '.' { + if ellipsis < 0 && i != 12 { + // Not the right place. + return Addr{}, parseAddrError{in: in, msg: "embedded IPv4 address must replace the final 2 fields of the address", at: s} + } + if i+4 > 16 { + // Not enough room. + return Addr{}, parseAddrError{in: in, msg: "too many hex fields to fit an embedded IPv4 at the end of the address", at: s} + } + // TODO: could make this a bit faster by having a helper + // that parses to a [4]byte, and have both parseIPv4 and + // parseIPv6 use it. + ip4, err := parseIPv4(s) + if err != nil { + return Addr{}, parseAddrError{in: in, msg: err.Error(), at: s} + } + ip[i] = ip4.v4(0) + ip[i+1] = ip4.v4(1) + ip[i+2] = ip4.v4(2) + ip[i+3] = ip4.v4(3) + s = "" + i += 4 + break + } + + // Save this 16-bit chunk. + ip[i] = byte(acc >> 8) + ip[i+1] = byte(acc) + i += 2 + + // Stop at end of string. + s = s[off:] + if len(s) == 0 { + break + } + + // Otherwise must be followed by colon and more. + if s[0] != ':' { + return Addr{}, parseAddrError{in: in, msg: "unexpected character, want colon", at: s} + } else if len(s) == 1 { + return Addr{}, parseAddrError{in: in, msg: "colon must be followed by more characters", at: s} + } + s = s[1:] + + // Look for ellipsis. + if s[0] == ':' { + if ellipsis >= 0 { // already have one + return Addr{}, parseAddrError{in: in, msg: "multiple :: in address", at: s} + } + ellipsis = i + s = s[1:] + if len(s) == 0 { // can be at end + break + } + } + } + + // Must have used entire string. + if len(s) != 0 { + return Addr{}, parseAddrError{in: in, msg: "trailing garbage after address", at: s} + } + + // If didn't parse enough, expand ellipsis. + if i < 16 { + if ellipsis < 0 { + return Addr{}, parseAddrError{in: in, msg: "address string too short"} + } + n := 16 - i + for j := i - 1; j >= ellipsis; j-- { + ip[j+n] = ip[j] + } + for j := ellipsis + n - 1; j >= ellipsis; j-- { + ip[j] = 0 + } + } else if ellipsis >= 0 { + // Ellipsis must represent at least one 0 group. + return Addr{}, parseAddrError{in: in, msg: "the :: must expand to at least one field of zeros"} + } + return AddrFrom16(ip).WithZone(zone), nil +} + +// AddrFromSlice parses the 4- or 16-byte byte slice as an IPv4 or IPv6 address. +// Note that a net.IP can be passed directly as the []byte argument. +// If slice's length is not 4 or 16, AddrFromSlice returns Addr{}, false. +func AddrFromSlice(slice []byte) (ip Addr, ok bool) { + switch len(slice) { + case 4: + return AddrFrom4(*(*[4]byte)(slice)), true + case 16: + return ipv6Slice(slice), true + } + return Addr{}, false +} + +// v4 returns the i'th byte of ip. If ip is not an IPv4, v4 returns +// unspecified garbage. +func (ip Addr) v4(i uint8) uint8 { + return uint8(ip.addr.lo >> ((3 - i) * 8)) +} + +// v6 returns the i'th byte of ip. If ip is an IPv4 address, this +// accesses the IPv4-mapped IPv6 address form of the IP. +func (ip Addr) v6(i uint8) uint8 { + return uint8(*(ip.addr.halves()[(i/8)%2]) >> ((7 - i%8) * 8)) +} + +// v6u16 returns the i'th 16-bit word of ip. If ip is an IPv4 address, +// this accesses the IPv4-mapped IPv6 address form of the IP. +func (ip Addr) v6u16(i uint8) uint16 { + return uint16(*(ip.addr.halves()[(i/4)%2]) >> ((3 - i%4) * 16)) +} + +// isZero reports whether ip is the zero value of the IP type. +// The zero value is not a valid IP address of any type. +// +// Note that "0.0.0.0" and "::" are not the zero value. Use IsUnspecified to +// check for these values instead. +func (ip Addr) isZero() bool { + // Faster than comparing ip == Addr{}, but effectively equivalent, + // as there's no way to make an IP with a nil z from this package. + return ip.z == z0 +} + +// IsValid reports whether the Addr is an initialized address (not the zero Addr). +// +// Note that "0.0.0.0" and "::" are both valid values. +func (ip Addr) IsValid() bool { return ip.z != z0 } + +// BitLen returns the number of bits in the IP address: +// 128 for IPv6, 32 for IPv4, and 0 for the zero Addr. +// +// Note that IPv4-mapped IPv6 addresses are considered IPv6 addresses +// and therefore have bit length 128. +func (ip Addr) BitLen() int { + switch ip.z { + case z0: + return 0 + case z4: + return 32 + } + return 128 +} + +// Zone returns ip's IPv6 scoped addressing zone, if any. +func (ip Addr) Zone() string { + if ip.z == nil { + return "" + } + zone, _ := ip.z.Get().(string) + return zone +} + +// Compare returns an integer comparing two IPs. +// The result will be 0 if ip == ip2, -1 if ip < ip2, and +1 if ip > ip2. +// The definition of "less than" is the same as the Less method. +func (ip Addr) Compare(ip2 Addr) int { + f1, f2 := ip.BitLen(), ip2.BitLen() + if f1 < f2 { + return -1 + } + if f1 > f2 { + return 1 + } + hi1, hi2 := ip.addr.hi, ip2.addr.hi + if hi1 < hi2 { + return -1 + } + if hi1 > hi2 { + return 1 + } + lo1, lo2 := ip.addr.lo, ip2.addr.lo + if lo1 < lo2 { + return -1 + } + if lo1 > lo2 { + return 1 + } + if ip.Is6() { + za, zb := ip.Zone(), ip2.Zone() + if za < zb { + return -1 + } + if za > zb { + return 1 + } + } + return 0 +} + +// Less reports whether ip sorts before ip2. +// IP addresses sort first by length, then their address. +// IPv6 addresses with zones sort just after the same address without a zone. +func (ip Addr) Less(ip2 Addr) bool { return ip.Compare(ip2) == -1 } + +func (ip Addr) lessOrEq(ip2 Addr) bool { return ip.Compare(ip2) <= 0 } + +// Is4 reports whether ip is an IPv4 address. +// +// It returns false for IP4-mapped IPv6 addresses. See IP.Unmap. +func (ip Addr) Is4() bool { + return ip.z == z4 +} + +// Is4In6 reports whether ip is an IPv4-mapped IPv6 address. +func (ip Addr) Is4In6() bool { + return ip.Is6() && ip.addr.hi == 0 && ip.addr.lo>>32 == 0xffff +} + +// Is6 reports whether ip is an IPv6 address, including IPv4-mapped +// IPv6 addresses. +func (ip Addr) Is6() bool { + return ip.z != z0 && ip.z != z4 +} + +// Unmap returns ip with any IPv4-mapped IPv6 address prefix removed. +// +// That is, if ip is an IPv6 address wrapping an IPv4 adddress, it +// returns the wrapped IPv4 address. Otherwise it returns ip unmodified. +func (ip Addr) Unmap() Addr { + if ip.Is4In6() { + ip.z = z4 + } + return ip +} + +// WithZone returns an IP that's the same as ip but with the provided +// zone. If zone is empty, the zone is removed. If ip is an IPv4 +// address, WithZone is a no-op and returns ip unchanged. +func (ip Addr) WithZone(zone string) Addr { + if !ip.Is6() { + return ip + } + if zone == "" { + ip.z = z6noz + return ip + } + ip.z = intern.GetByString(zone) + return ip +} + +// withoutZone unconditionally strips the zone from IP. +// It's similar to WithZone, but small enough to be inlinable. +func (ip Addr) withoutZone() Addr { + if !ip.Is6() { + return ip + } + ip.z = z6noz + return ip +} + +// hasZone reports whether IP has an IPv6 zone. +func (ip Addr) hasZone() bool { + return ip.z != z0 && ip.z != z4 && ip.z != z6noz +} + +// IsLinkLocalUnicast reports whether ip is a link-local unicast address. +func (ip Addr) IsLinkLocalUnicast() bool { + // Dynamic Configuration of IPv4 Link-Local Addresses + // https://datatracker.ietf.org/doc/html/rfc3927#section-2.1 + if ip.Is4() { + return ip.v4(0) == 169 && ip.v4(1) == 254 + } + // IP Version 6 Addressing Architecture (2.4 Address Type Identification) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4 + if ip.Is6() { + return ip.v6u16(0)&0xffc0 == 0xfe80 + } + return false // zero value +} + +// IsLoopback reports whether ip is a loopback address. +func (ip Addr) IsLoopback() bool { + // Requirements for Internet Hosts -- Communication Layers (3.2.1.3 Addressing) + // https://datatracker.ietf.org/doc/html/rfc1122#section-3.2.1.3 + if ip.Is4() { + return ip.v4(0) == 127 + } + // IP Version 6 Addressing Architecture (2.4 Address Type Identification) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4 + if ip.Is6() { + return ip.addr.hi == 0 && ip.addr.lo == 1 + } + return false // zero value +} + +// IsMulticast reports whether ip is a multicast address. +func (ip Addr) IsMulticast() bool { + // Host Extensions for IP Multicasting (4. HOST GROUP ADDRESSES) + // https://datatracker.ietf.org/doc/html/rfc1112#section-4 + if ip.Is4() { + return ip.v4(0)&0xf0 == 0xe0 + } + // IP Version 6 Addressing Architecture (2.4 Address Type Identification) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.4 + if ip.Is6() { + return ip.addr.hi>>(64-8) == 0xff // ip.v6(0) == 0xff + } + return false // zero value +} + +// IsInterfaceLocalMulticast reports whether ip is an IPv6 interface-local +// multicast address. +func (ip Addr) IsInterfaceLocalMulticast() bool { + // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1 + if ip.Is6() { + return ip.v6u16(0)&0xff0f == 0xff01 + } + return false // zero value +} + +// IsLinkLocalMulticast reports whether ip is a link-local multicast address. +func (ip Addr) IsLinkLocalMulticast() bool { + // IPv4 Multicast Guidelines (4. Local Network Control Block (224.0.0/24)) + // https://datatracker.ietf.org/doc/html/rfc5771#section-4 + if ip.Is4() { + return ip.v4(0) == 224 && ip.v4(1) == 0 && ip.v4(2) == 0 + } + // IPv6 Addressing Architecture (2.7.1. Pre-Defined Multicast Addresses) + // https://datatracker.ietf.org/doc/html/rfc4291#section-2.7.1 + if ip.Is6() { + return ip.v6u16(0)&0xff0f == 0xff02 + } + return false // zero value +} + +// IsGlobalUnicast reports whether ip is a global unicast address. +// +// It returns true for IPv6 addresses which fall outside of the current +// IANA-allocated 2000::/3 global unicast space, with the exception of the +// link-local address space. It also returns true even if ip is in the IPv4 +// private address space or IPv6 unique local address space. +// It returns false for the zero Addr. +// +// For reference, see RFC 1122, RFC 4291, and RFC 4632. +func (ip Addr) IsGlobalUnicast() bool { + if ip.z == z0 { + // Invalid or zero-value. + return false + } + + // Match package net's IsGlobalUnicast logic. Notably private IPv4 addresses + // and ULA IPv6 addresses are still considered "global unicast". + if ip.Is4() && (ip == IPv4Unspecified() || ip == AddrFrom4([4]byte{255, 255, 255, 255})) { + return false + } + + return ip != IPv6Unspecified() && + !ip.IsLoopback() && + !ip.IsMulticast() && + !ip.IsLinkLocalUnicast() +} + +// IsPrivate reports whether ip is a private address, according to RFC 1918 +// (IPv4 addresses) and RFC 4193 (IPv6 addresses). That is, it reports whether +// ip is in 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, or fc00::/7. This is the +// same as net.IP.IsPrivate. +func (ip Addr) IsPrivate() bool { + // Match the stdlib's IsPrivate logic. + if ip.Is4() { + // RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as + // private IPv4 address subnets. + return ip.v4(0) == 10 || + (ip.v4(0) == 172 && ip.v4(1)&0xf0 == 16) || + (ip.v4(0) == 192 && ip.v4(1) == 168) + } + + if ip.Is6() { + // RFC 4193 allocates fc00::/7 as the unique local unicast IPv6 address + // subnet. + return ip.v6(0)&0xfe == 0xfc + } + + return false // zero value +} + +// IsUnspecified reports whether ip is an unspecified address, either the IPv4 +// address "0.0.0.0" or the IPv6 address "::". +// +// Note that the zero Addr is not an unspecified address. +func (ip Addr) IsUnspecified() bool { + return ip == IPv4Unspecified() || ip == IPv6Unspecified() +} + +// Prefix keeps only the top b bits of IP, producing a Prefix +// of the specified length. +// If ip is a zero Addr, Prefix always returns a zero Prefix and a nil error. +// Otherwise, if bits is less than zero or greater than ip.BitLen(), +// Prefix returns an error. +func (ip Addr) Prefix(b int) (Prefix, error) { + if b < 0 { + return Prefix{}, errors.New("negative Prefix bits") + } + effectiveBits := b + switch ip.z { + case z0: + return Prefix{}, nil + case z4: + if b > 32 { + return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv4") + } + effectiveBits += 96 + default: + if b > 128 { + return Prefix{}, errors.New("prefix length " + itoa.Itoa(b) + " too large for IPv6") + } + } + ip.addr = ip.addr.and(mask6(effectiveBits)) + return PrefixFrom(ip, b), nil +} + +const ( + netIPv4len = 4 + netIPv6len = 16 +) + +// As16 returns the IP address in its 16-byte representation. +// IPv4 addresses are returned in their v6-mapped form. +// IPv6 addresses with zones are returned without their zone (use the +// Zone method to get it). +// The ip zero value returns all zeroes. +func (ip Addr) As16() (a16 [16]byte) { + bePutUint64(a16[:8], ip.addr.hi) + bePutUint64(a16[8:], ip.addr.lo) + return a16 +} + +// As4 returns an IPv4 or IPv4-in-IPv6 address in its 4-byte representation. +// If ip is the zero Addr or an IPv6 address, As4 panics. +// Note that 0.0.0.0 is not the zero Addr. +func (ip Addr) As4() (a4 [4]byte) { + if ip.z == z4 || ip.Is4In6() { + bePutUint32(a4[:], uint32(ip.addr.lo)) + return a4 + } + if ip.z == z0 { + panic("As4 called on IP zero value") + } + panic("As4 called on IPv6 address") +} + +// AsSlice returns an IPv4 or IPv6 address in its respective 4-byte or 16-byte representation. +func (ip Addr) AsSlice() []byte { + switch ip.z { + case z0: + return nil + case z4: + var ret [4]byte + bePutUint32(ret[:], uint32(ip.addr.lo)) + return ret[:] + default: + var ret [16]byte + bePutUint64(ret[:8], ip.addr.hi) + bePutUint64(ret[8:], ip.addr.lo) + return ret[:] + } +} + +// Next returns the address following ip. +// If there is none, it returns the zero Addr. +func (ip Addr) Next() Addr { + ip.addr = ip.addr.addOne() + if ip.Is4() { + if uint32(ip.addr.lo) == 0 { + // Overflowed. + return Addr{} + } + } else { + if ip.addr.isZero() { + // Overflowed + return Addr{} + } + } + return ip +} + +// Prev returns the IP before ip. +// If there is none, it returns the IP zero value. +func (ip Addr) Prev() Addr { + if ip.Is4() { + if uint32(ip.addr.lo) == 0 { + return Addr{} + } + } else if ip.addr.isZero() { + return Addr{} + } + ip.addr = ip.addr.subOne() + return ip +} + +// String returns the string form of the IP address ip. +// It returns one of 5 forms: +// +// - "invalid IP", if ip is the zero Addr +// - IPv4 dotted decimal ("192.0.2.1") +// - IPv6 ("2001:db8::1") +// - "::ffff:1.2.3.4" (if Is4In6) +// - IPv6 with zone ("fe80:db8::1%eth0") +// +// Note that unlike package net's IP.String method, +// IP4-mapped IPv6 addresses format with a "::ffff:" +// prefix before the dotted quad. +func (ip Addr) String() string { + switch ip.z { + case z0: + return "invalid IP" + case z4: + return ip.string4() + default: + if ip.Is4In6() { + // TODO(bradfitz): this could alloc less. + if z := ip.Zone(); z != "" { + return "::ffff:" + ip.Unmap().String() + "%" + z + } else { + return "::ffff:" + ip.Unmap().String() + } + } + return ip.string6() + } +} + +// AppendTo appends a text encoding of ip, +// as generated by MarshalText, +// to b and returns the extended buffer. +func (ip Addr) AppendTo(b []byte) []byte { + switch ip.z { + case z0: + return b + case z4: + return ip.appendTo4(b) + default: + if ip.Is4In6() { + b = append(b, "::ffff:"...) + b = ip.Unmap().appendTo4(b) + if z := ip.Zone(); z != "" { + b = append(b, '%') + b = append(b, z...) + } + return b + } + return ip.appendTo6(b) + } +} + +// digits is a string of the hex digits from 0 to f. It's used in +// appendDecimal and appendHex to format IP addresses. +const digits = "0123456789abcdef" + +// appendDecimal appends the decimal string representation of x to b. +func appendDecimal(b []byte, x uint8) []byte { + // Using this function rather than strconv.AppendUint makes IPv4 + // string building 2x faster. + + if x >= 100 { + b = append(b, digits[x/100]) + } + if x >= 10 { + b = append(b, digits[x/10%10]) + } + return append(b, digits[x%10]) +} + +// appendHex appends the hex string representation of x to b. +func appendHex(b []byte, x uint16) []byte { + // Using this function rather than strconv.AppendUint makes IPv6 + // string building 2x faster. + + if x >= 0x1000 { + b = append(b, digits[x>>12]) + } + if x >= 0x100 { + b = append(b, digits[x>>8&0xf]) + } + if x >= 0x10 { + b = append(b, digits[x>>4&0xf]) + } + return append(b, digits[x&0xf]) +} + +// appendHexPad appends the fully padded hex string representation of x to b. +func appendHexPad(b []byte, x uint16) []byte { + return append(b, digits[x>>12], digits[x>>8&0xf], digits[x>>4&0xf], digits[x&0xf]) +} + +func (ip Addr) string4() string { + const max = len("255.255.255.255") + ret := make([]byte, 0, max) + ret = ip.appendTo4(ret) + return string(ret) +} + +func (ip Addr) appendTo4(ret []byte) []byte { + ret = appendDecimal(ret, ip.v4(0)) + ret = append(ret, '.') + ret = appendDecimal(ret, ip.v4(1)) + ret = append(ret, '.') + ret = appendDecimal(ret, ip.v4(2)) + ret = append(ret, '.') + ret = appendDecimal(ret, ip.v4(3)) + return ret +} + +// string6 formats ip in IPv6 textual representation. It follows the +// guidelines in section 4 of RFC 5952 +// (https://tools.ietf.org/html/rfc5952#section-4): no unnecessary +// zeros, use :: to elide the longest run of zeros, and don't use :: +// to compact a single zero field. +func (ip Addr) string6() string { + // Use a zone with a "plausibly long" name, so that most zone-ful + // IP addresses won't require additional allocation. + // + // The compiler does a cool optimization here, where ret ends up + // stack-allocated and so the only allocation this function does + // is to construct the returned string. As such, it's okay to be a + // bit greedy here, size-wise. + const max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0") + ret := make([]byte, 0, max) + ret = ip.appendTo6(ret) + return string(ret) +} + +func (ip Addr) appendTo6(ret []byte) []byte { + zeroStart, zeroEnd := uint8(255), uint8(255) + for i := uint8(0); i < 8; i++ { + j := i + for j < 8 && ip.v6u16(j) == 0 { + j++ + } + if l := j - i; l >= 2 && l > zeroEnd-zeroStart { + zeroStart, zeroEnd = i, j + } + } + + for i := uint8(0); i < 8; i++ { + if i == zeroStart { + ret = append(ret, ':', ':') + i = zeroEnd + if i >= 8 { + break + } + } else if i > 0 { + ret = append(ret, ':') + } + + ret = appendHex(ret, ip.v6u16(i)) + } + + if ip.z != z6noz { + ret = append(ret, '%') + ret = append(ret, ip.Zone()...) + } + return ret +} + +// StringExpanded is like String but IPv6 addresses are expanded with leading +// zeroes and no "::" compression. For example, "2001:db8::1" becomes +// "2001:0db8:0000:0000:0000:0000:0000:0001". +func (ip Addr) StringExpanded() string { + switch ip.z { + case z0, z4: + return ip.String() + } + + const size = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + ret := make([]byte, 0, size) + for i := uint8(0); i < 8; i++ { + if i > 0 { + ret = append(ret, ':') + } + + ret = appendHexPad(ret, ip.v6u16(i)) + } + + if ip.z != z6noz { + // The addition of a zone will cause a second allocation, but when there + // is no zone the ret slice will be stack allocated. + ret = append(ret, '%') + ret = append(ret, ip.Zone()...) + } + return string(ret) +} + +// MarshalText implements the encoding.TextMarshaler interface, +// The encoding is the same as returned by String, with one exception: +// If ip is the zero Addr, the encoding is the empty string. +func (ip Addr) MarshalText() ([]byte, error) { + switch ip.z { + case z0: + return []byte(""), nil + case z4: + max := len("255.255.255.255") + b := make([]byte, 0, max) + return ip.appendTo4(b), nil + default: + max := len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0") + b := make([]byte, 0, max) + if ip.Is4In6() { + b = append(b, "::ffff:"...) + b = ip.Unmap().appendTo4(b) + if z := ip.Zone(); z != "" { + b = append(b, '%') + b = append(b, z...) + } + return b, nil + } + return ip.appendTo6(b), nil + } + +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The IP address is expected in a form accepted by ParseAddr. +// +// If text is empty, UnmarshalText sets *ip to the zero Addr and +// returns no error. +func (ip *Addr) UnmarshalText(text []byte) error { + if len(text) == 0 { + *ip = Addr{} + return nil + } + var err error + *ip, err = ParseAddr(string(text)) + return err +} + +func (ip Addr) marshalBinaryWithTrailingBytes(trailingBytes int) []byte { + var b []byte + switch ip.z { + case z0: + b = make([]byte, trailingBytes) + case z4: + b = make([]byte, 4+trailingBytes) + bePutUint32(b, uint32(ip.addr.lo)) + default: + z := ip.Zone() + b = make([]byte, 16+len(z)+trailingBytes) + bePutUint64(b[:8], ip.addr.hi) + bePutUint64(b[8:], ip.addr.lo) + copy(b[16:], z) + } + return b +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +// It returns a zero-length slice for the zero Addr, +// the 4-byte form for an IPv4 address, +// and the 16-byte form with zone appended for an IPv6 address. +func (ip Addr) MarshalBinary() ([]byte, error) { + return ip.marshalBinaryWithTrailingBytes(0), nil +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It expects data in the form generated by MarshalBinary. +func (ip *Addr) UnmarshalBinary(b []byte) error { + n := len(b) + switch { + case n == 0: + *ip = Addr{} + return nil + case n == 4: + *ip = AddrFrom4(*(*[4]byte)(b)) + return nil + case n == 16: + *ip = ipv6Slice(b) + return nil + case n > 16: + *ip = ipv6Slice(b[:16]).WithZone(string(b[16:])) + return nil + } + return errors.New("unexpected slice size") +} + +// AddrPort is an IP and a port number. +type AddrPort struct { + ip Addr + port uint16 +} + +// AddrPortFrom returns an AddrPort with the provided IP and port. +// It does not allocate. +func AddrPortFrom(ip Addr, port uint16) AddrPort { return AddrPort{ip: ip, port: port} } + +// Addr returns p's IP address. +func (p AddrPort) Addr() Addr { return p.ip } + +// Port returns p's port. +func (p AddrPort) Port() uint16 { return p.port } + +// splitAddrPort splits s into an IP address string and a port +// string. It splits strings shaped like "foo:bar" or "[foo]:bar", +// without further validating the substrings. v6 indicates whether the +// ip string should parse as an IPv6 address or an IPv4 address, in +// order for s to be a valid ip:port string. +func splitAddrPort(s string) (ip, port string, v6 bool, err error) { + i := stringsLastIndexByte(s, ':') + if i == -1 { + return "", "", false, errors.New("not an ip:port") + } + + ip, port = s[:i], s[i+1:] + if len(ip) == 0 { + return "", "", false, errors.New("no IP") + } + if len(port) == 0 { + return "", "", false, errors.New("no port") + } + if ip[0] == '[' { + if len(ip) < 2 || ip[len(ip)-1] != ']' { + return "", "", false, errors.New("missing ]") + } + ip = ip[1 : len(ip)-1] + v6 = true + } + + return ip, port, v6, nil +} + +// ParseAddrPort parses s as an AddrPort. +// +// It doesn't do any name resolution: both the address and the port +// must be numeric. +func ParseAddrPort(s string) (AddrPort, error) { + var ipp AddrPort + ip, port, v6, err := splitAddrPort(s) + if err != nil { + return ipp, err + } + port16, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return ipp, errors.New("invalid port " + strconv.Quote(port) + " parsing " + strconv.Quote(s)) + } + ipp.port = uint16(port16) + ipp.ip, err = ParseAddr(ip) + if err != nil { + return AddrPort{}, err + } + if v6 && ipp.ip.Is4() { + return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", square brackets can only be used with IPv6 addresses") + } else if !v6 && ipp.ip.Is6() { + return AddrPort{}, errors.New("invalid ip:port " + strconv.Quote(s) + ", IPv6 addresses must be surrounded by square brackets") + } + return ipp, nil +} + +// MustParseAddrPort calls ParseAddrPort(s) and panics on error. +// It is intended for use in tests with hard-coded strings. +func MustParseAddrPort(s string) AddrPort { + ip, err := ParseAddrPort(s) + if err != nil { + panic(err) + } + return ip +} + +// isZero reports whether p is the zero AddrPort. +func (p AddrPort) isZero() bool { return p == AddrPort{} } + +// IsValid reports whether p.IP() is valid. +// All ports are valid, including zero. +func (p AddrPort) IsValid() bool { return p.ip.IsValid() } + +func (p AddrPort) String() string { + switch p.ip.z { + case z0: + return "invalid AddrPort" + case z4: + a := p.ip.As4() + buf := make([]byte, 0, 21) + for i := range a { + buf = strconv.AppendUint(buf, uint64(a[i]), 10) + buf = append(buf, "...:"[i]) + } + buf = strconv.AppendUint(buf, uint64(p.port), 10) + return string(buf) + default: + // TODO: this could be more efficient allocation-wise: + return joinHostPort(p.ip.String(), itoa.Itoa(int(p.port))) + } +} + +func joinHostPort(host, port string) string { + // We assume that host is a literal IPv6 address if host has + // colons. + if bytealg.IndexByteString(host, ':') >= 0 { + return "[" + host + "]:" + port + } + return host + ":" + port +} + +// AppendTo appends a text encoding of p, +// as generated by MarshalText, +// to b and returns the extended buffer. +func (p AddrPort) AppendTo(b []byte) []byte { + switch p.ip.z { + case z0: + return b + case z4: + b = p.ip.appendTo4(b) + default: + if p.ip.Is4In6() { + b = append(b, "[::ffff:"...) + b = p.ip.Unmap().appendTo4(b) + if z := p.ip.Zone(); z != "" { + b = append(b, '%') + b = append(b, z...) + } + } else { + b = append(b, '[') + b = p.ip.appendTo6(b) + } + b = append(b, ']') + } + b = append(b, ':') + b = strconv.AppendInt(b, int64(p.port), 10) + return b +} + +// MarshalText implements the encoding.TextMarshaler interface. The +// encoding is the same as returned by String, with one exception: if +// p.Addr() is the zero Addr, the encoding is the empty string. +func (p AddrPort) MarshalText() ([]byte, error) { + var max int + switch p.ip.z { + case z0: + case z4: + max = len("255.255.255.255:65535") + default: + max = len("[ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0]:65535") + } + b := make([]byte, 0, max) + b = p.AppendTo(b) + return b, nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler +// interface. The AddrPort is expected in a form +// generated by MarshalText or accepted by ParseAddrPort. +func (p *AddrPort) UnmarshalText(text []byte) error { + if len(text) == 0 { + *p = AddrPort{} + return nil + } + var err error + *p, err = ParseAddrPort(string(text)) + return err +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +// It returns Addr.MarshalBinary with an additional two bytes appended +// containing the port in little-endian. +func (p AddrPort) MarshalBinary() ([]byte, error) { + b := p.Addr().marshalBinaryWithTrailingBytes(2) + lePutUint16(b[len(b)-2:], p.Port()) + return b, nil +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It expects data in the form generated by MarshalBinary. +func (p *AddrPort) UnmarshalBinary(b []byte) error { + if len(b) < 2 { + return errors.New("unexpected slice size") + } + var addr Addr + err := addr.UnmarshalBinary(b[:len(b)-2]) + if err != nil { + return err + } + *p = AddrPortFrom(addr, leUint16(b[len(b)-2:])) + return nil +} + +// Prefix is an IP address prefix (CIDR) representing an IP network. +// +// The first Bits() of Addr() are specified. The remaining bits match any address. +// The range of Bits() is [0,32] for IPv4 or [0,128] for IPv6. +type Prefix struct { + ip Addr + + // bits is logically a uint8 (storing [0,128]) but also + // encodes an "invalid" bit, currently represented by the + // invalidPrefixBits sentinel value. It could be packed into + // the uint8 more with more complicated expressions in the + // accessors, but the extra byte (in padding anyway) doesn't + // hurt and simplifies code below. + bits int16 +} + +// invalidPrefixBits is the Prefix.bits value used when PrefixFrom is +// outside the range of a uint8. It's returned as the int -1 in the +// public API. +const invalidPrefixBits = -1 + +// PrefixFrom returns a Prefix with the provided IP address and bit +// prefix length. +// +// It does not allocate. Unlike Addr.Prefix, PrefixFrom does not mask +// off the host bits of ip. +// +// If bits is less than zero or greater than ip.BitLen, Prefix.Bits +// will return an invalid value -1. +func PrefixFrom(ip Addr, bits int) Prefix { + if bits < 0 || bits > ip.BitLen() { + bits = invalidPrefixBits + } + b16 := int16(bits) + return Prefix{ + ip: ip.withoutZone(), + bits: b16, + } +} + +// Addr returns p's IP address. +func (p Prefix) Addr() Addr { return p.ip } + +// Bits returns p's prefix length. +// +// It reports -1 if invalid. +func (p Prefix) Bits() int { return int(p.bits) } + +// IsValid reports whether p.Bits() has a valid range for p.IP(). +// If p.Addr() is the zero Addr, IsValid returns false. +// Note that if p is the zero Prefix, then p.IsValid() == false. +func (p Prefix) IsValid() bool { return !p.ip.isZero() && p.bits >= 0 && int(p.bits) <= p.ip.BitLen() } + +func (p Prefix) isZero() bool { return p == Prefix{} } + +// IsSingleIP reports whether p contains exactly one IP. +func (p Prefix) IsSingleIP() bool { return p.bits != 0 && int(p.bits) == p.ip.BitLen() } + +// ParsePrefix parses s as an IP address prefix. +// The string can be in the form "192.168.1.0/24" or "2001::db8::/32", +// the CIDR notation defined in RFC 4632 and RFC 4291. +// +// Note that masked address bits are not zeroed. Use Masked for that. +func ParsePrefix(s string) (Prefix, error) { + i := stringsLastIndexByte(s, '/') + if i < 0 { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): no '/'") + } + ip, err := ParseAddr(s[:i]) + if err != nil { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + "): " + err.Error()) + } + bitsStr := s[i+1:] + bits, err := strconv.Atoi(bitsStr) + if err != nil { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + ": bad bits after slash: " + strconv.Quote(bitsStr)) + } + maxBits := 32 + if ip.Is6() { + maxBits = 128 + } + if bits < 0 || bits > maxBits { + return Prefix{}, errors.New("netip.ParsePrefix(" + strconv.Quote(s) + ": prefix length out of range") + } + return PrefixFrom(ip, bits), nil +} + +// MustParsePrefix calls ParsePrefix(s) and panics on error. +// It is intended for use in tests with hard-coded strings. +func MustParsePrefix(s string) Prefix { + ip, err := ParsePrefix(s) + if err != nil { + panic(err) + } + return ip +} + +// Masked returns p in its canonical form, with all but the high +// p.Bits() bits of p.Addr() masked off. +// +// If p is zero or otherwise invalid, Masked returns the zero Prefix. +func (p Prefix) Masked() Prefix { + if m, err := p.ip.Prefix(int(p.bits)); err == nil { + return m + } + return Prefix{} +} + +// Contains reports whether the network p includes ip. +// +// An IPv4 address will not match an IPv6 prefix. +// A v6-mapped IPv6 address will not match an IPv4 prefix. +// A zero-value IP will not match any prefix. +// If ip has an IPv6 zone, Contains returns false, +// because Prefixes strip zones. +func (p Prefix) Contains(ip Addr) bool { + if !p.IsValid() || ip.hasZone() { + return false + } + if f1, f2 := p.ip.BitLen(), ip.BitLen(); f1 == 0 || f2 == 0 || f1 != f2 { + return false + } + if ip.Is4() { + // xor the IP addresses together; mismatched bits are now ones. + // Shift away the number of bits we don't care about. + // Shifts in Go are more efficient if the compiler can prove + // that the shift amount is smaller than the width of the shifted type (64 here). + // We know that p.bits is in the range 0..32 because p is Valid; + // the compiler doesn't know that, so mask with 63 to help it. + // Now truncate to 32 bits, because this is IPv4. + // If all the bits we care about are equal, the result will be zero. + return uint32((ip.addr.lo^p.ip.addr.lo)>>((32-p.bits)&63)) == 0 + } else { + // xor the IP addresses together. + // Mask away the bits we don't care about. + // If all the bits we care about are equal, the result will be zero. + return ip.addr.xor(p.ip.addr).and(mask6(int(p.bits))).isZero() + } +} + +// Overlaps reports whether p and o contain any IP addresses in common. +// +// If p and o are of different address families or either have a zero +// IP, it reports false. Like the Contains method, a prefix with a +// v6-mapped IPv4 IP is still treated as an IPv6 mask. +func (p Prefix) Overlaps(o Prefix) bool { + if !p.IsValid() || !o.IsValid() { + return false + } + if p == o { + return true + } + if p.ip.Is4() != o.ip.Is4() { + return false + } + var minBits int16 + if p.bits < o.bits { + minBits = p.bits + } else { + minBits = o.bits + } + if minBits == 0 { + return true + } + // One of these Prefix calls might look redundant, but we don't require + // that p and o values are normalized (via Prefix.Masked) first, + // so the Prefix call on the one that's already minBits serves to zero + // out any remaining bits in IP. + var err error + if p, err = p.ip.Prefix(int(minBits)); err != nil { + return false + } + if o, err = o.ip.Prefix(int(minBits)); err != nil { + return false + } + return p.ip == o.ip +} + +// AppendTo appends a text encoding of p, +// as generated by MarshalText, +// to b and returns the extended buffer. +func (p Prefix) AppendTo(b []byte) []byte { + if p.isZero() { + return b + } + if !p.IsValid() { + return append(b, "invalid Prefix"...) + } + + // p.ip is non-nil, because p is valid. + if p.ip.z == z4 { + b = p.ip.appendTo4(b) + } else { + if p.ip.Is4In6() { + b = append(b, "::ffff:"...) + b = p.ip.Unmap().appendTo4(b) + } else { + b = p.ip.appendTo6(b) + } + } + + b = append(b, '/') + b = appendDecimal(b, uint8(p.bits)) + return b +} + +// MarshalText implements the encoding.TextMarshaler interface, +// The encoding is the same as returned by String, with one exception: +// If p is the zero value, the encoding is the empty string. +func (p Prefix) MarshalText() ([]byte, error) { + var max int + switch p.ip.z { + case z0: + case z4: + max = len("255.255.255.255/32") + default: + max = len("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff%enp5s0/128") + } + b := make([]byte, 0, max) + b = p.AppendTo(b) + return b, nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// The IP address is expected in a form accepted by ParsePrefix +// or generated by MarshalText. +func (p *Prefix) UnmarshalText(text []byte) error { + if len(text) == 0 { + *p = Prefix{} + return nil + } + var err error + *p, err = ParsePrefix(string(text)) + return err +} + +// MarshalBinary implements the encoding.BinaryMarshaler interface. +// It returns Addr.MarshalBinary with an additional byte appended +// containing the prefix bits. +func (p Prefix) MarshalBinary() ([]byte, error) { + b := p.Addr().withoutZone().marshalBinaryWithTrailingBytes(1) + b[len(b)-1] = uint8(p.Bits()) + return b, nil +} + +// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. +// It expects data in the form generated by MarshalBinary. +func (p *Prefix) UnmarshalBinary(b []byte) error { + if len(b) < 1 { + return errors.New("unexpected slice size") + } + var addr Addr + err := addr.UnmarshalBinary(b[:len(b)-1]) + if err != nil { + return err + } + *p = PrefixFrom(addr, int(b[len(b)-1])) + return nil +} + +// String returns the CIDR notation of p: "<ip>/<bits>". +func (p Prefix) String() string { + if !p.IsValid() { + return "invalid Prefix" + } + return p.ip.String() + "/" + itoa.Itoa(int(p.bits)) +} diff --git a/libgo/go/net/netip/netip_pkg_test.go b/libgo/go/net/netip/netip_pkg_test.go new file mode 100644 index 0000000..f5cd9ee --- /dev/null +++ b/libgo/go/net/netip/netip_pkg_test.go @@ -0,0 +1,359 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import ( + "bytes" + "encoding" + "encoding/json" + "strings" + "testing" +) + +var ( + mustPrefix = MustParsePrefix + mustIP = MustParseAddr +) + +func TestPrefixValid(t *testing.T) { + v4 := MustParseAddr("1.2.3.4") + v6 := MustParseAddr("::1") + tests := []struct { + ipp Prefix + want bool + }{ + {Prefix{v4, -2}, false}, + {Prefix{v4, -1}, false}, + {Prefix{v4, 0}, true}, + {Prefix{v4, 32}, true}, + {Prefix{v4, 33}, false}, + + {Prefix{v6, -2}, false}, + {Prefix{v6, -1}, false}, + {Prefix{v6, 0}, true}, + {Prefix{v6, 32}, true}, + {Prefix{v6, 128}, true}, + {Prefix{v6, 129}, false}, + + {Prefix{Addr{}, -2}, false}, + {Prefix{Addr{}, -1}, false}, + {Prefix{Addr{}, 0}, false}, + {Prefix{Addr{}, 32}, false}, + {Prefix{Addr{}, 128}, false}, + } + for _, tt := range tests { + got := tt.ipp.IsValid() + if got != tt.want { + t.Errorf("(%v).IsValid() = %v want %v", tt.ipp, got, tt.want) + } + } +} + +var nextPrevTests = []struct { + ip Addr + next Addr + prev Addr +}{ + {mustIP("10.0.0.1"), mustIP("10.0.0.2"), mustIP("10.0.0.0")}, + {mustIP("10.0.0.255"), mustIP("10.0.1.0"), mustIP("10.0.0.254")}, + {mustIP("127.0.0.1"), mustIP("127.0.0.2"), mustIP("127.0.0.0")}, + {mustIP("254.255.255.255"), mustIP("255.0.0.0"), mustIP("254.255.255.254")}, + {mustIP("255.255.255.255"), Addr{}, mustIP("255.255.255.254")}, + {mustIP("0.0.0.0"), mustIP("0.0.0.1"), Addr{}}, + {mustIP("::"), mustIP("::1"), Addr{}}, + {mustIP("::%x"), mustIP("::1%x"), Addr{}}, + {mustIP("::1"), mustIP("::2"), mustIP("::")}, + {mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"), Addr{}, mustIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffe")}, +} + +func TestIPNextPrev(t *testing.T) { + doNextPrev(t) + + for _, ip := range []Addr{ + mustIP("0.0.0.0"), + mustIP("::"), + } { + got := ip.Prev() + if !got.isZero() { + t.Errorf("IP(%v).Prev = %v; want zero", ip, got) + } + } + + var allFF [16]byte + for i := range allFF { + allFF[i] = 0xff + } + + for _, ip := range []Addr{ + mustIP("255.255.255.255"), + AddrFrom16(allFF), + } { + got := ip.Next() + if !got.isZero() { + t.Errorf("IP(%v).Next = %v; want zero", ip, got) + } + } +} + +func BenchmarkIPNextPrev(b *testing.B) { + for i := 0; i < b.N; i++ { + doNextPrev(b) + } +} + +func doNextPrev(t testing.TB) { + for _, tt := range nextPrevTests { + gnext, gprev := tt.ip.Next(), tt.ip.Prev() + if gnext != tt.next { + t.Errorf("IP(%v).Next = %v; want %v", tt.ip, gnext, tt.next) + } + if gprev != tt.prev { + t.Errorf("IP(%v).Prev = %v; want %v", tt.ip, gprev, tt.prev) + } + if !tt.ip.Next().isZero() && tt.ip.Next().Prev() != tt.ip { + t.Errorf("IP(%v).Next.Prev = %v; want %v", tt.ip, tt.ip.Next().Prev(), tt.ip) + } + if !tt.ip.Prev().isZero() && tt.ip.Prev().Next() != tt.ip { + t.Errorf("IP(%v).Prev.Next = %v; want %v", tt.ip, tt.ip.Prev().Next(), tt.ip) + } + } +} + +func TestIPBitLen(t *testing.T) { + tests := []struct { + ip Addr + want int + }{ + {Addr{}, 0}, + {mustIP("0.0.0.0"), 32}, + {mustIP("10.0.0.1"), 32}, + {mustIP("::"), 128}, + {mustIP("fed0::1"), 128}, + {mustIP("::ffff:10.0.0.1"), 128}, + } + for _, tt := range tests { + got := tt.ip.BitLen() + if got != tt.want { + t.Errorf("BitLen(%v) = %d; want %d", tt.ip, got, tt.want) + } + } +} + +func TestPrefixContains(t *testing.T) { + tests := []struct { + ipp Prefix + ip Addr + want bool + }{ + {mustPrefix("9.8.7.6/0"), mustIP("9.8.7.6"), true}, + {mustPrefix("9.8.7.6/16"), mustIP("9.8.7.6"), true}, + {mustPrefix("9.8.7.6/16"), mustIP("9.8.6.4"), true}, + {mustPrefix("9.8.7.6/16"), mustIP("9.9.7.6"), false}, + {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.6"), true}, + {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false}, + {mustPrefix("9.8.7.6/32"), mustIP("9.8.7.7"), false}, + {mustPrefix("::1/0"), mustIP("::1"), true}, + {mustPrefix("::1/0"), mustIP("::2"), true}, + {mustPrefix("::1/127"), mustIP("::1"), true}, + {mustPrefix("::1/127"), mustIP("::2"), false}, + {mustPrefix("::1/128"), mustIP("::1"), true}, + {mustPrefix("::1/127"), mustIP("::2"), false}, + // zones support + {mustPrefix("::1%a/128"), mustIP("::1"), true}, // prefix zones are stripped... + {mustPrefix("::1%a/128"), mustIP("::1%a"), false}, // but ip zones are not + // invalid IP + {mustPrefix("::1/0"), Addr{}, false}, + {mustPrefix("1.2.3.4/0"), Addr{}, false}, + // invalid Prefix + {Prefix{mustIP("::1"), 129}, mustIP("::1"), false}, + {Prefix{mustIP("1.2.3.4"), 33}, mustIP("1.2.3.4"), false}, + {Prefix{Addr{}, 0}, mustIP("1.2.3.4"), false}, + {Prefix{Addr{}, 32}, mustIP("1.2.3.4"), false}, + {Prefix{Addr{}, 128}, mustIP("::1"), false}, + // wrong IP family + {mustPrefix("::1/0"), mustIP("1.2.3.4"), false}, + {mustPrefix("1.2.3.4/0"), mustIP("::1"), false}, + } + for _, tt := range tests { + got := tt.ipp.Contains(tt.ip) + if got != tt.want { + t.Errorf("(%v).Contains(%v) = %v want %v", tt.ipp, tt.ip, got, tt.want) + } + } +} + +func TestParseIPError(t *testing.T) { + tests := []struct { + ip string + errstr string + }{ + { + ip: "localhost", + }, + { + ip: "500.0.0.1", + errstr: "field has value >255", + }, + { + ip: "::gggg%eth0", + errstr: "must have at least one digit", + }, + { + ip: "fe80::1cc0:3e8c:119f:c2e1%", + errstr: "zone must be a non-empty string", + }, + { + ip: "%eth0", + errstr: "missing IPv6 address", + }, + } + for _, test := range tests { + t.Run(test.ip, func(t *testing.T) { + _, err := ParseAddr(test.ip) + if err == nil { + t.Fatal("no error") + } + if _, ok := err.(parseAddrError); !ok { + t.Errorf("error type is %T, want parseIPError", err) + } + if test.errstr == "" { + test.errstr = "unable to parse IP" + } + if got := err.Error(); !strings.Contains(got, test.errstr) { + t.Errorf("error is missing substring %q: %s", test.errstr, got) + } + }) + } +} + +func TestParseAddrPort(t *testing.T) { + tests := []struct { + in string + want AddrPort + wantErr bool + }{ + {in: "1.2.3.4:1234", want: AddrPort{mustIP("1.2.3.4"), 1234}}, + {in: "1.1.1.1:123456", wantErr: true}, + {in: "1.1.1.1:-123", wantErr: true}, + {in: "[::1]:1234", want: AddrPort{mustIP("::1"), 1234}}, + {in: "[1.2.3.4]:1234", wantErr: true}, + {in: "fe80::1:1234", wantErr: true}, + {in: ":0", wantErr: true}, // if we need to parse this form, there should be a separate function that explicitly allows it + } + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + got, err := ParseAddrPort(test.in) + if err != nil { + if test.wantErr { + return + } + t.Fatal(err) + } + if got != test.want { + t.Errorf("got %v; want %v", got, test.want) + } + if got.String() != test.in { + t.Errorf("String = %q; want %q", got.String(), test.in) + } + }) + + t.Run(test.in+"/AppendTo", func(t *testing.T) { + got, err := ParseAddrPort(test.in) + if err == nil { + testAppendToMarshal(t, got) + } + }) + + // TextMarshal and TextUnmarshal mostly behave like + // ParseAddrPort and String. Divergent behavior are handled in + // TestAddrPortMarshalUnmarshal. + t.Run(test.in+"/Marshal", func(t *testing.T) { + var got AddrPort + jsin := `"` + test.in + `"` + err := json.Unmarshal([]byte(jsin), &got) + if err != nil { + if test.wantErr { + return + } + t.Fatal(err) + } + if got != test.want { + t.Errorf("got %v; want %v", got, test.want) + } + gotb, err := json.Marshal(got) + if err != nil { + t.Fatal(err) + } + if string(gotb) != jsin { + t.Errorf("Marshal = %q; want %q", string(gotb), jsin) + } + }) + } +} + +func TestAddrPortMarshalUnmarshal(t *testing.T) { + tests := []struct { + in string + want AddrPort + }{ + {"", AddrPort{}}, + } + + for _, test := range tests { + t.Run(test.in, func(t *testing.T) { + orig := `"` + test.in + `"` + + var ipp AddrPort + if err := json.Unmarshal([]byte(orig), &ipp); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + ippb, err := json.Marshal(ipp) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + back := string(ippb) + if orig != back { + t.Errorf("Marshal = %q; want %q", back, orig) + } + + testAppendToMarshal(t, ipp) + }) + } +} + +type appendMarshaler interface { + encoding.TextMarshaler + AppendTo([]byte) []byte +} + +// testAppendToMarshal tests that x's AppendTo and MarshalText methods yield the same results. +// x's MarshalText method must not return an error. +func testAppendToMarshal(t *testing.T, x appendMarshaler) { + t.Helper() + m, err := x.MarshalText() + if err != nil { + t.Fatalf("(%v).MarshalText: %v", x, err) + } + a := make([]byte, 0, len(m)) + a = x.AppendTo(a) + if !bytes.Equal(m, a) { + t.Errorf("(%v).MarshalText = %q, (%v).AppendTo = %q", x, m, x, a) + } +} + +func TestIPv6Accessor(t *testing.T) { + var a [16]byte + for i := range a { + a[i] = uint8(i) + 1 + } + ip := AddrFrom16(a) + for i := range a { + if got, want := ip.v6(uint8(i)), uint8(i)+1; got != want { + t.Errorf("v6(%v) = %v; want %v", i, got, want) + } + } +} diff --git a/libgo/go/net/netip/netip_test.go b/libgo/go/net/netip/netip_test.go new file mode 100644 index 0000000..d988864 --- /dev/null +++ b/libgo/go/net/netip/netip_test.go @@ -0,0 +1,1974 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip_test + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "internal/intern" + "net" + . "net/netip" + "reflect" + "sort" + "strings" + "testing" +) + +var long = flag.Bool("long", false, "run long tests") + +type uint128 = Uint128 + +var ( + mustPrefix = MustParsePrefix + mustIP = MustParseAddr + mustIPPort = MustParseAddrPort +) + +func TestParseAddr(t *testing.T) { + var validIPs = []struct { + in string + ip Addr // output of ParseAddr() + str string // output of String(). If "", use in. + wantErr string + }{ + // Basic zero IPv4 address. + { + in: "0.0.0.0", + ip: MkAddr(Mk128(0, 0xffff00000000), Z4), + }, + // Basic non-zero IPv4 address. + { + in: "192.168.140.255", + ip: MkAddr(Mk128(0, 0xffffc0a88cff), Z4), + }, + // IPv4 address in windows-style "print all the digits" form. + { + in: "010.000.015.001", + wantErr: `ParseAddr("010.000.015.001"): IPv4 field has octet with leading zero`, + }, + // IPv4 address with a silly amount of leading zeros. + { + in: "000001.00000002.00000003.000000004", + wantErr: `ParseAddr("000001.00000002.00000003.000000004"): IPv4 field has octet with leading zero`, + }, + // 4-in-6 with octet with leading zero + { + in: "::ffff:1.2.03.4", + wantErr: `ParseAddr("::ffff:1.2.03.4"): ParseAddr("1.2.03.4"): IPv4 field has octet with leading zero (at "1.2.03.4")`, + }, + // Basic zero IPv6 address. + { + in: "::", + ip: MkAddr(Mk128(0, 0), Z6noz), + }, + // Localhost IPv6. + { + in: "::1", + ip: MkAddr(Mk128(0, 1), Z6noz), + }, + // Fully expanded IPv6 address. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), Z6noz), + }, + // IPv6 with elided fields in the middle. + { + in: "fd7a:115c::626b:430b", + ip: MkAddr(Mk128(0xfd7a115c00000000, 0x00000000626b430b), Z6noz), + }, + // IPv6 with elided fields at the end. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96::", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd9600000000), Z6noz), + }, + // IPv6 with single elided field at the end. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96:626b::", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b0000), Z6noz), + str: "fd7a:115c:a1e0:ab12:4843:cd96:626b:0", + }, + // IPv6 with single elided field in the middle. + { + in: "fd7a:115c:a1e0::4843:cd96:626b:430b", + ip: MkAddr(Mk128(0xfd7a115ca1e00000, 0x4843cd96626b430b), Z6noz), + str: "fd7a:115c:a1e0:0:4843:cd96:626b:430b", + }, + // IPv6 with the trailing 32 bits written as IPv4 dotted decimal. (4in6) + { + in: "::ffff:192.168.140.255", + ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), Z6noz), + str: "::ffff:192.168.140.255", + }, + // IPv6 with a zone specifier. + { + in: "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b%eth0", + ip: MkAddr(Mk128(0xfd7a115ca1e0ab12, 0x4843cd96626b430b), intern.Get("eth0")), + }, + // IPv6 with dotted decimal and zone specifier. + { + in: "1:2::ffff:192.168.140.255%eth1", + ip: MkAddr(Mk128(0x0001000200000000, 0x0000ffffc0a88cff), intern.Get("eth1")), + str: "1:2::ffff:c0a8:8cff%eth1", + }, + // 4-in-6 with zone + { + in: "::ffff:192.168.140.255%eth1", + ip: MkAddr(Mk128(0, 0x0000ffffc0a88cff), intern.Get("eth1")), + str: "::ffff:192.168.140.255%eth1", + }, + // IPv6 with capital letters. + { + in: "FD9E:1A04:F01D::1", + ip: MkAddr(Mk128(0xfd9e1a04f01d0000, 0x1), Z6noz), + str: "fd9e:1a04:f01d::1", + }, + } + + for _, test := range validIPs { + t.Run(test.in, func(t *testing.T) { + got, err := ParseAddr(test.in) + if err != nil { + if err.Error() == test.wantErr { + return + } + t.Fatal(err) + } + if test.wantErr != "" { + t.Fatalf("wanted error %q; got none", test.wantErr) + } + if got != test.ip { + t.Errorf("got %#v, want %#v", got, test.ip) + } + + // Check that ParseAddr is a pure function. + got2, err := ParseAddr(test.in) + if err != nil { + t.Fatal(err) + } + if got != got2 { + t.Errorf("ParseAddr(%q) got 2 different results: %#v, %#v", test.in, got, got2) + } + + // Check that ParseAddr(ip.String()) is the identity function. + s := got.String() + got3, err := ParseAddr(s) + if err != nil { + t.Fatal(err) + } + if got != got3 { + t.Errorf("ParseAddr(%q) != ParseAddr(ParseIP(%q).String()). Got %#v, want %#v", test.in, test.in, got3, got) + } + + // Check that the slow-but-readable parser produces the same result. + slow, err := parseIPSlow(test.in) + if err != nil { + t.Fatal(err) + } + if got != slow { + t.Errorf("ParseAddr(%q) = %#v, parseIPSlow(%q) = %#v", test.in, got, test.in, slow) + } + + // Check that the parsed IP formats as expected. + s = got.String() + wants := test.str + if wants == "" { + wants = test.in + } + if s != wants { + t.Errorf("ParseAddr(%q).String() got %q, want %q", test.in, s, wants) + } + + // Check that AppendTo matches MarshalText. + TestAppendToMarshal(t, got) + + // Check that MarshalText/UnmarshalText work similarly to + // ParseAddr/String (see TestIPMarshalUnmarshal for + // marshal-specific behavior that's not common with + // ParseAddr/String). + js := `"` + test.in + `"` + var jsgot Addr + if err := json.Unmarshal([]byte(js), &jsgot); err != nil { + t.Fatal(err) + } + if jsgot != got { + t.Errorf("json.Unmarshal(%q) = %#v, want %#v", test.in, jsgot, got) + } + jsb, err := json.Marshal(jsgot) + if err != nil { + t.Fatal(err) + } + jswant := `"` + wants + `"` + jsback := string(jsb) + if jsback != jswant { + t.Errorf("Marshal(Unmarshal(%q)) = %s, want %s", test.in, jsback, jswant) + } + }) + } + + var invalidIPs = []string{ + // Empty string + "", + // Garbage non-IP + "bad", + // Single number. Some parsers accept this as an IPv4 address in + // big-endian uint32 form, but we don't. + "1234", + // IPv4 with a zone specifier + "1.2.3.4%eth0", + // IPv4 field must have at least one digit + ".1.2.3", + "1.2.3.", + "1..2.3", + // IPv4 address too long + "1.2.3.4.5", + // IPv4 in dotted octal form + "0300.0250.0214.0377", + // IPv4 in dotted hex form + "0xc0.0xa8.0x8c.0xff", + // IPv4 in class B form + "192.168.12345", + // IPv4 in class B form, with a small enough number to be + // parseable as a regular dotted decimal field. + "127.0.1", + // IPv4 in class A form + "192.1234567", + // IPv4 in class A form, with a small enough number to be + // parseable as a regular dotted decimal field. + "127.1", + // IPv4 field has value >255 + "192.168.300.1", + // IPv4 with too many fields + "192.168.0.1.5.6", + // IPv6 with not enough fields + "1:2:3:4:5:6:7", + // IPv6 with too many fields + "1:2:3:4:5:6:7:8:9", + // IPv6 with 8 fields and a :: expander + "1:2:3:4::5:6:7:8", + // IPv6 with a field bigger than 2b + "fe801::1", + // IPv6 with non-hex values in field + "fe80:tail:scal:e::", + // IPv6 with a zone delimiter but no zone. + "fe80::1%", + // IPv6 (without ellipsis) with too many fields for trailing embedded IPv4. + "ffff:ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255", + // IPv6 (with ellipsis) with too many fields for trailing embedded IPv4. + "ffff::ffff:ffff:ffff:ffff:ffff:ffff:192.168.140.255", + // IPv6 with invalid embedded IPv4. + "::ffff:192.168.140.bad", + // IPv6 with multiple ellipsis ::. + "fe80::1::1", + // IPv6 with invalid non hex/colon character. + "fe80:1?:1", + // IPv6 with truncated bytes after single colon. + "fe80:", + } + + for _, s := range invalidIPs { + t.Run(s, func(t *testing.T) { + got, err := ParseAddr(s) + if err == nil { + t.Errorf("ParseAddr(%q) = %#v, want error", s, got) + } + + slow, err := parseIPSlow(s) + if err == nil { + t.Errorf("parseIPSlow(%q) = %#v, want error", s, slow) + } + + std := net.ParseIP(s) + if std != nil { + t.Errorf("net.ParseIP(%q) = %#v, want error", s, std) + } + + if s == "" { + // Don't test unmarshaling of "" here, do it in + // IPMarshalUnmarshal. + return + } + var jsgot Addr + js := []byte(`"` + s + `"`) + if err := json.Unmarshal(js, &jsgot); err == nil { + t.Errorf("json.Unmarshal(%q) = %#v, want error", s, jsgot) + } + }) + } +} + +func TestIPv4Constructors(t *testing.T) { + if AddrFrom4([4]byte{1, 2, 3, 4}) != MustParseAddr("1.2.3.4") { + t.Errorf("don't match") + } +} + +func TestAddrMarshalUnmarshalBinary(t *testing.T) { + tests := []struct { + ip string + wantSize int + }{ + {"", 0}, // zero IP + {"1.2.3.4", 4}, + {"fd7a:115c:a1e0:ab12:4843:cd96:626b:430b", 16}, + {"::ffff:c000:0280", 16}, + {"::ffff:c000:0280%eth0", 20}, + } + for _, tc := range tests { + var ip Addr + if len(tc.ip) > 0 { + ip = mustIP(tc.ip) + } + b, err := ip.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) != tc.wantSize { + t.Fatalf("%q encoded to size %d; want %d", tc.ip, len(b), tc.wantSize) + } + var ip2 Addr + if err := ip2.UnmarshalBinary(b); err != nil { + t.Fatal(err) + } + if ip != ip2 { + t.Fatalf("got %v; want %v", ip2, ip) + } + } + + // Cannot unmarshal from unexpected IP length. + for _, n := range []int{3, 5} { + var ip2 Addr + if err := ip2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil { + t.Fatalf("unmarshaled from unexpected IP length %d", n) + } + } +} + +func TestAddrPortMarshalTextString(t *testing.T) { + tests := []struct { + in AddrPort + want string + }{ + {mustIPPort("1.2.3.4:80"), "1.2.3.4:80"}, + {mustIPPort("[1::CAFE]:80"), "[1::cafe]:80"}, + {mustIPPort("[1::CAFE%en0]:80"), "[1::cafe%en0]:80"}, + {mustIPPort("[::FFFF:192.168.140.255]:80"), "[::ffff:192.168.140.255]:80"}, + {mustIPPort("[::FFFF:192.168.140.255%en0]:80"), "[::ffff:192.168.140.255%en0]:80"}, + } + for i, tt := range tests { + if got := tt.in.String(); got != tt.want { + t.Errorf("%d. for (%v, %v) String = %q; want %q", i, tt.in.Addr(), tt.in.Port(), got, tt.want) + } + mt, err := tt.in.MarshalText() + if err != nil { + t.Errorf("%d. for (%v, %v) MarshalText error: %v", i, tt.in.Addr(), tt.in.Port(), err) + continue + } + if string(mt) != tt.want { + t.Errorf("%d. for (%v, %v) MarshalText = %q; want %q", i, tt.in.Addr(), tt.in.Port(), mt, tt.want) + } + } +} + +func TestAddrPortMarshalUnmarshalBinary(t *testing.T) { + tests := []struct { + ipport string + wantSize int + }{ + {"1.2.3.4:51820", 4 + 2}, + {"[fd7a:115c:a1e0:ab12:4843:cd96:626b:430b]:80", 16 + 2}, + {"[::ffff:c000:0280]:65535", 16 + 2}, + {"[::ffff:c000:0280%eth0]:1", 20 + 2}, + } + for _, tc := range tests { + var ipport AddrPort + if len(tc.ipport) > 0 { + ipport = mustIPPort(tc.ipport) + } + b, err := ipport.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) != tc.wantSize { + t.Fatalf("%q encoded to size %d; want %d", tc.ipport, len(b), tc.wantSize) + } + var ipport2 AddrPort + if err := ipport2.UnmarshalBinary(b); err != nil { + t.Fatal(err) + } + if ipport != ipport2 { + t.Fatalf("got %v; want %v", ipport2, ipport) + } + } + + // Cannot unmarshal from unexpected lengths. + for _, n := range []int{3, 7} { + var ipport2 AddrPort + if err := ipport2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil { + t.Fatalf("unmarshaled from unexpected length %d", n) + } + } +} + +func TestPrefixMarshalTextString(t *testing.T) { + tests := []struct { + in Prefix + want string + }{ + {mustPrefix("1.2.3.4/24"), "1.2.3.4/24"}, + {mustPrefix("fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"), "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"}, + {mustPrefix("::ffff:c000:0280/96"), "::ffff:192.0.2.128/96"}, + {mustPrefix("::ffff:c000:0280%eth0/37"), "::ffff:192.0.2.128/37"}, // Zone should be stripped + {mustPrefix("::ffff:192.168.140.255/8"), "::ffff:192.168.140.255/8"}, + } + for i, tt := range tests { + if got := tt.in.String(); got != tt.want { + t.Errorf("%d. for %v String = %q; want %q", i, tt.in, got, tt.want) + } + mt, err := tt.in.MarshalText() + if err != nil { + t.Errorf("%d. for %v MarshalText error: %v", i, tt.in, err) + continue + } + if string(mt) != tt.want { + t.Errorf("%d. for %v MarshalText = %q; want %q", i, tt.in, mt, tt.want) + } + } +} + +func TestPrefixMarshalUnmarshalBinary(t *testing.T) { + type testCase struct { + prefix Prefix + wantSize int + } + tests := []testCase{ + {mustPrefix("1.2.3.4/24"), 4 + 1}, + {mustPrefix("fd7a:115c:a1e0:ab12:4843:cd96:626b:430b/118"), 16 + 1}, + {mustPrefix("::ffff:c000:0280/96"), 16 + 1}, + {mustPrefix("::ffff:c000:0280%eth0/37"), 16 + 1}, // Zone should be stripped + } + tests = append(tests, + testCase{PrefixFrom(tests[0].prefix.Addr(), 33), tests[0].wantSize}, + testCase{PrefixFrom(tests[1].prefix.Addr(), 129), tests[1].wantSize}) + for _, tc := range tests { + prefix := tc.prefix + b, err := prefix.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) != tc.wantSize { + t.Fatalf("%q encoded to size %d; want %d", tc.prefix, len(b), tc.wantSize) + } + var prefix2 Prefix + if err := prefix2.UnmarshalBinary(b); err != nil { + t.Fatal(err) + } + if prefix != prefix2 { + t.Fatalf("got %v; want %v", prefix2, prefix) + } + } + + // Cannot unmarshal from unexpected lengths. + for _, n := range []int{3, 6} { + var prefix2 Prefix + if err := prefix2.UnmarshalBinary(bytes.Repeat([]byte{1}, n)); err == nil { + t.Fatalf("unmarshaled from unexpected length %d", n) + } + } +} + +func TestAddrMarshalUnmarshal(t *testing.T) { + // This only tests the cases where Marshal/Unmarshal diverges from + // the behavior of ParseAddr/String. For the rest of the test cases, + // see TestParseAddr above. + orig := `""` + var ip Addr + if err := json.Unmarshal([]byte(orig), &ip); err != nil { + t.Fatalf("Unmarshal(%q) got error %v", orig, err) + } + if ip != (Addr{}) { + t.Errorf("Unmarshal(%q) is not the zero Addr", orig) + } + + jsb, err := json.Marshal(ip) + if err != nil { + t.Fatalf("Marshal(%v) got error %v", ip, err) + } + back := string(jsb) + if back != orig { + t.Errorf("Marshal(Unmarshal(%q)) got %q, want %q", orig, back, orig) + } +} + +func TestAddrFrom16(t *testing.T) { + tests := []struct { + name string + in [16]byte + want Addr + }{ + { + name: "v6-raw", + in: [...]byte{15: 1}, + want: MkAddr(Mk128(0, 1), Z6noz), + }, + { + name: "v4-raw", + in: [...]byte{10: 0xff, 11: 0xff, 12: 1, 13: 2, 14: 3, 15: 4}, + want: MkAddr(Mk128(0, 0xffff01020304), Z6noz), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := AddrFrom16(tt.in) + if got != tt.want { + t.Errorf("got %#v; want %#v", got, tt.want) + } + }) + } +} + +func TestIPProperties(t *testing.T) { + var ( + nilIP Addr + + unicast4 = mustIP("192.0.2.1") + unicast6 = mustIP("2001:db8::1") + unicastZone6 = mustIP("2001:db8::1%eth0") + unicast6Unassigned = mustIP("4000::1") // not in 2000::/3. + + multicast4 = mustIP("224.0.0.1") + multicast6 = mustIP("ff02::1") + multicastZone6 = mustIP("ff02::1%eth0") + + llu4 = mustIP("169.254.0.1") + llu6 = mustIP("fe80::1") + llu6Last = mustIP("febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff") + lluZone6 = mustIP("fe80::1%eth0") + + loopback4 = mustIP("127.0.0.1") + loopback6 = mustIP("::1") + + ilm6 = mustIP("ff01::1") + ilmZone6 = mustIP("ff01::1%eth0") + + private4a = mustIP("10.0.0.1") + private4b = mustIP("172.16.0.1") + private4c = mustIP("192.168.1.1") + private6 = mustIP("fd00::1") + + unspecified4 = AddrFrom4([4]byte{}) + unspecified6 = IPv6Unspecified() + ) + + tests := []struct { + name string + ip Addr + globalUnicast bool + interfaceLocalMulticast bool + linkLocalMulticast bool + linkLocalUnicast bool + loopback bool + multicast bool + private bool + unspecified bool + }{ + { + name: "nil", + ip: nilIP, + }, + { + name: "unicast v4Addr", + ip: unicast4, + globalUnicast: true, + }, + { + name: "unicast v6Addr", + ip: unicast6, + globalUnicast: true, + }, + { + name: "unicast v6AddrZone", + ip: unicastZone6, + globalUnicast: true, + }, + { + name: "unicast v6Addr unassigned", + ip: unicast6Unassigned, + globalUnicast: true, + }, + { + name: "multicast v4Addr", + ip: multicast4, + linkLocalMulticast: true, + multicast: true, + }, + { + name: "multicast v6Addr", + ip: multicast6, + linkLocalMulticast: true, + multicast: true, + }, + { + name: "multicast v6AddrZone", + ip: multicastZone6, + linkLocalMulticast: true, + multicast: true, + }, + { + name: "link-local unicast v4Addr", + ip: llu4, + linkLocalUnicast: true, + }, + { + name: "link-local unicast v6Addr", + ip: llu6, + linkLocalUnicast: true, + }, + { + name: "link-local unicast v6Addr upper bound", + ip: llu6Last, + linkLocalUnicast: true, + }, + { + name: "link-local unicast v6AddrZone", + ip: lluZone6, + linkLocalUnicast: true, + }, + { + name: "loopback v4Addr", + ip: loopback4, + loopback: true, + }, + { + name: "loopback v6Addr", + ip: loopback6, + loopback: true, + }, + { + name: "interface-local multicast v6Addr", + ip: ilm6, + interfaceLocalMulticast: true, + multicast: true, + }, + { + name: "interface-local multicast v6AddrZone", + ip: ilmZone6, + interfaceLocalMulticast: true, + multicast: true, + }, + { + name: "private v4Addr 10/8", + ip: private4a, + globalUnicast: true, + private: true, + }, + { + name: "private v4Addr 172.16/12", + ip: private4b, + globalUnicast: true, + private: true, + }, + { + name: "private v4Addr 192.168/16", + ip: private4c, + globalUnicast: true, + private: true, + }, + { + name: "private v6Addr", + ip: private6, + globalUnicast: true, + private: true, + }, + { + name: "unspecified v4Addr", + ip: unspecified4, + unspecified: true, + }, + { + name: "unspecified v6Addr", + ip: unspecified6, + unspecified: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gu := tt.ip.IsGlobalUnicast() + if gu != tt.globalUnicast { + t.Errorf("IsGlobalUnicast(%v) = %v; want %v", tt.ip, gu, tt.globalUnicast) + } + + ilm := tt.ip.IsInterfaceLocalMulticast() + if ilm != tt.interfaceLocalMulticast { + t.Errorf("IsInterfaceLocalMulticast(%v) = %v; want %v", tt.ip, ilm, tt.interfaceLocalMulticast) + } + + llu := tt.ip.IsLinkLocalUnicast() + if llu != tt.linkLocalUnicast { + t.Errorf("IsLinkLocalUnicast(%v) = %v; want %v", tt.ip, llu, tt.linkLocalUnicast) + } + + llm := tt.ip.IsLinkLocalMulticast() + if llm != tt.linkLocalMulticast { + t.Errorf("IsLinkLocalMulticast(%v) = %v; want %v", tt.ip, llm, tt.linkLocalMulticast) + } + + lo := tt.ip.IsLoopback() + if lo != tt.loopback { + t.Errorf("IsLoopback(%v) = %v; want %v", tt.ip, lo, tt.loopback) + } + + multicast := tt.ip.IsMulticast() + if multicast != tt.multicast { + t.Errorf("IsMulticast(%v) = %v; want %v", tt.ip, multicast, tt.multicast) + } + + private := tt.ip.IsPrivate() + if private != tt.private { + t.Errorf("IsPrivate(%v) = %v; want %v", tt.ip, private, tt.private) + } + + unspecified := tt.ip.IsUnspecified() + if unspecified != tt.unspecified { + t.Errorf("IsUnspecified(%v) = %v; want %v", tt.ip, unspecified, tt.unspecified) + } + }) + } +} + +func TestAddrWellKnown(t *testing.T) { + tests := []struct { + name string + ip Addr + std net.IP + }{ + { + name: "IPv6 link-local all nodes", + ip: IPv6LinkLocalAllNodes(), + std: net.IPv6linklocalallnodes, + }, + { + name: "IPv6 unspecified", + ip: IPv6Unspecified(), + std: net.IPv6unspecified, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + want := tt.std.String() + got := tt.ip.String() + + if got != want { + t.Fatalf("got %s, want %s", got, want) + } + }) + } +} + +func TestLessCompare(t *testing.T) { + tests := []struct { + a, b Addr + want bool + }{ + {Addr{}, Addr{}, false}, + {Addr{}, mustIP("1.2.3.4"), true}, + {mustIP("1.2.3.4"), Addr{}, false}, + + {mustIP("1.2.3.4"), mustIP("0102:0304::0"), true}, + {mustIP("0102:0304::0"), mustIP("1.2.3.4"), false}, + {mustIP("1.2.3.4"), mustIP("1.2.3.4"), false}, + + {mustIP("::1"), mustIP("::2"), true}, + {mustIP("::1"), mustIP("::1%foo"), true}, + {mustIP("::1%foo"), mustIP("::2"), true}, + {mustIP("::2"), mustIP("::3"), true}, + + {mustIP("::"), mustIP("0.0.0.0"), false}, + {mustIP("0.0.0.0"), mustIP("::"), true}, + + {mustIP("::1%a"), mustIP("::1%b"), true}, + {mustIP("::1%a"), mustIP("::1%a"), false}, + {mustIP("::1%b"), mustIP("::1%a"), false}, + } + for _, tt := range tests { + got := tt.a.Less(tt.b) + if got != tt.want { + t.Errorf("Less(%q, %q) = %v; want %v", tt.a, tt.b, got, tt.want) + } + cmp := tt.a.Compare(tt.b) + if got && cmp != -1 { + t.Errorf("Less(%q, %q) = true, but Compare = %v (not -1)", tt.a, tt.b, cmp) + } + if cmp < -1 || cmp > 1 { + t.Errorf("bogus Compare return value %v", cmp) + } + if cmp == 0 && tt.a != tt.b { + t.Errorf("Compare(%q, %q) = 0; but not equal", tt.a, tt.b) + } + if cmp == 1 && !tt.b.Less(tt.a) { + t.Errorf("Compare(%q, %q) = 1; but b.Less(a) isn't true", tt.a, tt.b) + } + + // Also check inverse. + if got == tt.want && got { + got2 := tt.b.Less(tt.a) + if got2 { + t.Errorf("Less(%q, %q) was correctly %v, but so was Less(%q, %q)", tt.a, tt.b, got, tt.b, tt.a) + } + } + } + + // And just sort. + values := []Addr{ + mustIP("::1"), + mustIP("::2"), + Addr{}, + mustIP("1.2.3.4"), + mustIP("8.8.8.8"), + mustIP("::1%foo"), + } + sort.Slice(values, func(i, j int) bool { return values[i].Less(values[j]) }) + got := fmt.Sprintf("%s", values) + want := `[invalid IP 1.2.3.4 8.8.8.8 ::1 ::1%foo ::2]` + if got != want { + t.Errorf("unexpected sort\n got: %s\nwant: %s\n", got, want) + } +} + +func TestIPStringExpanded(t *testing.T) { + tests := []struct { + ip Addr + s string + }{ + { + ip: Addr{}, + s: "invalid IP", + }, + { + ip: mustIP("192.0.2.1"), + s: "192.0.2.1", + }, + { + ip: mustIP("::ffff:192.0.2.1"), + s: "0000:0000:0000:0000:0000:ffff:c000:0201", + }, + { + ip: mustIP("2001:db8::1"), + s: "2001:0db8:0000:0000:0000:0000:0000:0001", + }, + { + ip: mustIP("2001:db8::1%eth0"), + s: "2001:0db8:0000:0000:0000:0000:0000:0001%eth0", + }, + } + + for _, tt := range tests { + t.Run(tt.ip.String(), func(t *testing.T) { + want := tt.s + got := tt.ip.StringExpanded() + + if got != want { + t.Fatalf("got %s, want %s", got, want) + } + }) + } +} + +func TestPrefixMasking(t *testing.T) { + type subtest struct { + ip Addr + bits uint8 + p Prefix + ok bool + } + + // makeIPv6 produces a set of IPv6 subtests with an optional zone identifier. + makeIPv6 := func(zone string) []subtest { + if zone != "" { + zone = "%" + zone + } + + return []subtest{ + { + ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)), + bits: 255, + }, + { + ip: mustIP(fmt.Sprintf("2001:db8::1%s", zone)), + bits: 32, + p: mustPrefix(fmt.Sprintf("2001:db8::%s/32", zone)), + ok: true, + }, + { + ip: mustIP(fmt.Sprintf("fe80::dead:beef:dead:beef%s", zone)), + bits: 96, + p: mustPrefix(fmt.Sprintf("fe80::dead:beef:0:0%s/96", zone)), + ok: true, + }, + { + ip: mustIP(fmt.Sprintf("aaaa::%s", zone)), + bits: 4, + p: mustPrefix(fmt.Sprintf("a000::%s/4", zone)), + ok: true, + }, + { + ip: mustIP(fmt.Sprintf("::%s", zone)), + bits: 63, + p: mustPrefix(fmt.Sprintf("::%s/63", zone)), + ok: true, + }, + } + } + + tests := []struct { + family string + subtests []subtest + }{ + { + family: "nil", + subtests: []subtest{ + { + bits: 255, + ok: true, + }, + { + bits: 16, + ok: true, + }, + }, + }, + { + family: "IPv4", + subtests: []subtest{ + { + ip: mustIP("192.0.2.0"), + bits: 255, + }, + { + ip: mustIP("192.0.2.0"), + bits: 16, + p: mustPrefix("192.0.0.0/16"), + ok: true, + }, + { + ip: mustIP("255.255.255.255"), + bits: 20, + p: mustPrefix("255.255.240.0/20"), + ok: true, + }, + { + // Partially masking one byte that contains both + // 1s and 0s on either side of the mask limit. + ip: mustIP("100.98.156.66"), + bits: 10, + p: mustPrefix("100.64.0.0/10"), + ok: true, + }, + }, + }, + { + family: "IPv6", + subtests: makeIPv6(""), + }, + { + family: "IPv6 zone", + subtests: makeIPv6("eth0"), + }, + } + + for _, tt := range tests { + t.Run(tt.family, func(t *testing.T) { + for _, st := range tt.subtests { + t.Run(st.p.String(), func(t *testing.T) { + // Ensure st.ip is not mutated. + orig := st.ip.String() + + p, err := st.ip.Prefix(int(st.bits)) + if st.ok && err != nil { + t.Fatalf("failed to produce prefix: %v", err) + } + if !st.ok && err == nil { + t.Fatal("expected an error, but none occurred") + } + if err != nil { + t.Logf("err: %v", err) + return + } + + if !reflect.DeepEqual(p, st.p) { + t.Errorf("prefix = %q, want %q", p, st.p) + } + + if got := st.ip.String(); got != orig { + t.Errorf("IP was mutated: %q, want %q", got, orig) + } + }) + } + }) + } +} + +func TestPrefixMarshalUnmarshal(t *testing.T) { + tests := []string{ + "", + "1.2.3.4/32", + "0.0.0.0/0", + "::/0", + "::1/128", + "2001:db8::/32", + } + + for _, s := range tests { + t.Run(s, func(t *testing.T) { + // Ensure that JSON (and by extension, text) marshaling is + // sane by entering quoted input. + orig := `"` + s + `"` + + var p Prefix + if err := json.Unmarshal([]byte(orig), &p); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + pb, err := json.Marshal(p) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + back := string(pb) + if orig != back { + t.Errorf("Marshal = %q; want %q", back, orig) + } + }) + } +} + +func TestPrefixMarshalUnmarshalZone(t *testing.T) { + orig := `"fe80::1cc0:3e8c:119f:c2e1%ens18/128"` + unzoned := `"fe80::1cc0:3e8c:119f:c2e1/128"` + + var p Prefix + if err := json.Unmarshal([]byte(orig), &p); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + pb, err := json.Marshal(p) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + back := string(pb) + if back != unzoned { + t.Errorf("Marshal = %q; want %q", back, unzoned) + } +} + +func TestPrefixUnmarshalTextNonZero(t *testing.T) { + ip := mustPrefix("fe80::/64") + if err := ip.UnmarshalText([]byte("xxx")); err == nil { + t.Fatal("unmarshaled into non-empty Prefix") + } +} + +func TestIs4AndIs6(t *testing.T) { + tests := []struct { + ip Addr + is4 bool + is6 bool + }{ + {Addr{}, false, false}, + {mustIP("1.2.3.4"), true, false}, + {mustIP("127.0.0.2"), true, false}, + {mustIP("::1"), false, true}, + {mustIP("::ffff:192.0.2.128"), false, true}, + {mustIP("::fffe:c000:0280"), false, true}, + {mustIP("::1%eth0"), false, true}, + } + for _, tt := range tests { + got4 := tt.ip.Is4() + if got4 != tt.is4 { + t.Errorf("Is4(%q) = %v; want %v", tt.ip, got4, tt.is4) + } + + got6 := tt.ip.Is6() + if got6 != tt.is6 { + t.Errorf("Is6(%q) = %v; want %v", tt.ip, got6, tt.is6) + } + } +} + +func TestIs4In6(t *testing.T) { + tests := []struct { + ip Addr + want bool + wantUnmap Addr + }{ + {Addr{}, false, Addr{}}, + {mustIP("::ffff:c000:0280"), true, mustIP("192.0.2.128")}, + {mustIP("::ffff:192.0.2.128"), true, mustIP("192.0.2.128")}, + {mustIP("::ffff:192.0.2.128%eth0"), true, mustIP("192.0.2.128")}, + {mustIP("::fffe:c000:0280"), false, mustIP("::fffe:c000:0280")}, + {mustIP("::ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("::ffff:7f01:0203"), true, mustIP("127.1.2.3")}, + {mustIP("0:0:0:0:0000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("0:0:0:0:000000:ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("0:0:0:0::ffff:127.1.2.3"), true, mustIP("127.1.2.3")}, + {mustIP("::1"), false, mustIP("::1")}, + {mustIP("1.2.3.4"), false, mustIP("1.2.3.4")}, + } + for _, tt := range tests { + got := tt.ip.Is4In6() + if got != tt.want { + t.Errorf("Is4In6(%q) = %v; want %v", tt.ip, got, tt.want) + } + u := tt.ip.Unmap() + if u != tt.wantUnmap { + t.Errorf("Unmap(%q) = %v; want %v", tt.ip, u, tt.wantUnmap) + } + } +} + +func TestPrefixMasked(t *testing.T) { + tests := []struct { + prefix Prefix + masked Prefix + }{ + { + prefix: mustPrefix("192.168.0.255/24"), + masked: mustPrefix("192.168.0.0/24"), + }, + { + prefix: mustPrefix("2100::/3"), + masked: mustPrefix("2000::/3"), + }, + { + prefix: PrefixFrom(mustIP("2000::"), 129), + masked: Prefix{}, + }, + { + prefix: PrefixFrom(mustIP("1.2.3.4"), 33), + masked: Prefix{}, + }, + } + for _, test := range tests { + t.Run(test.prefix.String(), func(t *testing.T) { + got := test.prefix.Masked() + if got != test.masked { + t.Errorf("Masked=%s, want %s", got, test.masked) + } + }) + } +} + +func TestPrefix(t *testing.T) { + tests := []struct { + prefix string + ip Addr + bits int + str string + contains []Addr + notContains []Addr + }{ + { + prefix: "192.168.0.0/24", + ip: mustIP("192.168.0.0"), + bits: 24, + contains: mustIPs("192.168.0.1", "192.168.0.55"), + notContains: mustIPs("192.168.1.1", "1.1.1.1"), + }, + { + prefix: "192.168.1.1/32", + ip: mustIP("192.168.1.1"), + bits: 32, + contains: mustIPs("192.168.1.1"), + notContains: mustIPs("192.168.1.2"), + }, + { + prefix: "100.64.0.0/10", // CGNAT range; prefix not multiple of 8 + ip: mustIP("100.64.0.0"), + bits: 10, + contains: mustIPs("100.64.0.0", "100.64.0.1", "100.81.251.94", "100.100.100.100", "100.127.255.254", "100.127.255.255"), + notContains: mustIPs("100.63.255.255", "100.128.0.0"), + }, + { + prefix: "2001:db8::/96", + ip: mustIP("2001:db8::"), + bits: 96, + contains: mustIPs("2001:db8::aaaa:bbbb", "2001:db8::1"), + notContains: mustIPs("2001:db8::1:aaaa:bbbb", "2001:db9::"), + }, + { + prefix: "0.0.0.0/0", + ip: mustIP("0.0.0.0"), + bits: 0, + contains: mustIPs("192.168.0.1", "1.1.1.1"), + notContains: append(mustIPs("2001:db8::1"), Addr{}), + }, + { + prefix: "::/0", + ip: mustIP("::"), + bits: 0, + contains: mustIPs("::1", "2001:db8::1"), + notContains: mustIPs("192.0.2.1"), + }, + { + prefix: "2000::/3", + ip: mustIP("2000::"), + bits: 3, + contains: mustIPs("2001:db8::1"), + notContains: mustIPs("fe80::1"), + }, + { + prefix: "::%0/00/80", + ip: mustIP("::"), + bits: 80, + str: "::/80", + contains: mustIPs("::"), + notContains: mustIPs("ff::%0/00", "ff::%1/23", "::%0/00", "::%1/23"), + }, + } + for _, test := range tests { + t.Run(test.prefix, func(t *testing.T) { + prefix, err := ParsePrefix(test.prefix) + if err != nil { + t.Fatal(err) + } + if prefix.Addr() != test.ip { + t.Errorf("IP=%s, want %s", prefix.Addr(), test.ip) + } + if prefix.Bits() != test.bits { + t.Errorf("bits=%d, want %d", prefix.Bits(), test.bits) + } + for _, ip := range test.contains { + if !prefix.Contains(ip) { + t.Errorf("does not contain %s", ip) + } + } + for _, ip := range test.notContains { + if prefix.Contains(ip) { + t.Errorf("contains %s", ip) + } + } + want := test.str + if want == "" { + want = test.prefix + } + if got := prefix.String(); got != want { + t.Errorf("prefix.String()=%q, want %q", got, want) + } + + TestAppendToMarshal(t, prefix) + }) + } +} + +func TestPrefixFromInvalidBits(t *testing.T) { + v4 := MustParseAddr("1.2.3.4") + v6 := MustParseAddr("66::66") + tests := []struct { + ip Addr + in, want int + }{ + {v4, 0, 0}, + {v6, 0, 0}, + {v4, 1, 1}, + {v4, 33, -1}, + {v6, 33, 33}, + {v6, 127, 127}, + {v6, 128, 128}, + {v4, 254, -1}, + {v4, 255, -1}, + {v4, -1, -1}, + {v6, -1, -1}, + {v4, -5, -1}, + {v6, -5, -1}, + } + for _, tt := range tests { + p := PrefixFrom(tt.ip, tt.in) + if got := p.Bits(); got != tt.want { + t.Errorf("for (%v, %v), Bits out = %v; want %v", tt.ip, tt.in, got, tt.want) + } + } +} + +func TestParsePrefixAllocs(t *testing.T) { + tests := []struct { + ip string + slash string + }{ + {"192.168.1.0", "/24"}, + {"aaaa:bbbb:cccc::", "/24"}, + } + for _, test := range tests { + prefix := test.ip + test.slash + t.Run(prefix, func(t *testing.T) { + ipAllocs := int(testing.AllocsPerRun(5, func() { + ParseAddr(test.ip) + })) + prefixAllocs := int(testing.AllocsPerRun(5, func() { + ParsePrefix(prefix) + })) + if got := prefixAllocs - ipAllocs; got != 0 { + t.Errorf("allocs=%d, want 0", got) + } + }) + } +} + +func TestParsePrefixError(t *testing.T) { + tests := []struct { + prefix string + errstr string + }{ + { + prefix: "192.168.0.0", + errstr: "no '/'", + }, + { + prefix: "1.257.1.1/24", + errstr: "value >255", + }, + { + prefix: "1.1.1.0/q", + errstr: "bad bits", + }, + { + prefix: "1.1.1.0/-1", + errstr: "out of range", + }, + { + prefix: "1.1.1.0/33", + errstr: "out of range", + }, + { + prefix: "2001::/129", + errstr: "out of range", + }, + } + for _, test := range tests { + t.Run(test.prefix, func(t *testing.T) { + _, err := ParsePrefix(test.prefix) + if err == nil { + t.Fatal("no error") + } + if got := err.Error(); !strings.Contains(got, test.errstr) { + t.Errorf("error is missing substring %q: %s", test.errstr, got) + } + }) + } +} + +func TestPrefixIsSingleIP(t *testing.T) { + tests := []struct { + ipp Prefix + want bool + }{ + {ipp: mustPrefix("127.0.0.1/32"), want: true}, + {ipp: mustPrefix("127.0.0.1/31"), want: false}, + {ipp: mustPrefix("127.0.0.1/0"), want: false}, + {ipp: mustPrefix("::1/128"), want: true}, + {ipp: mustPrefix("::1/127"), want: false}, + {ipp: mustPrefix("::1/0"), want: false}, + {ipp: Prefix{}, want: false}, + } + for _, tt := range tests { + got := tt.ipp.IsSingleIP() + if got != tt.want { + t.Errorf("IsSingleIP(%v) = %v want %v", tt.ipp, got, tt.want) + } + } +} + +func mustIPs(strs ...string) []Addr { + var res []Addr + for _, s := range strs { + res = append(res, mustIP(s)) + } + return res +} + +func BenchmarkBinaryMarshalRoundTrip(b *testing.B) { + b.ReportAllocs() + tests := []struct { + name string + ip string + }{ + {"ipv4", "1.2.3.4"}, + {"ipv6", "2001:db8::1"}, + {"ipv6+zone", "2001:db8::1%eth0"}, + } + for _, tc := range tests { + b.Run(tc.name, func(b *testing.B) { + ip := mustIP(tc.ip) + for i := 0; i < b.N; i++ { + bt, err := ip.MarshalBinary() + if err != nil { + b.Fatal(err) + } + var ip2 Addr + if err := ip2.UnmarshalBinary(bt); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkStdIPv4(b *testing.B) { + b.ReportAllocs() + ips := []net.IP{} + for i := 0; i < b.N; i++ { + ip := net.IPv4(8, 8, 8, 8) + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkIPv4(b *testing.B) { + b.ReportAllocs() + ips := []Addr{} + for i := 0; i < b.N; i++ { + ip := IPv4(8, 8, 8, 8) + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +// ip4i was one of the possible representations of IP that came up in +// discussions, inlining IPv4 addresses, but having an "overflow" +// interface for IPv6 or IPv6 + zone. This is here for benchmarking. +type ip4i struct { + ip4 [4]byte + flags1 byte + flags2 byte + flags3 byte + flags4 byte + ipv6 any +} + +func newip4i_v4(a, b, c, d byte) ip4i { + return ip4i{ip4: [4]byte{a, b, c, d}} +} + +// BenchmarkIPv4_inline benchmarks the candidate representation, ip4i. +func BenchmarkIPv4_inline(b *testing.B) { + b.ReportAllocs() + ips := []ip4i{} + for i := 0; i < b.N; i++ { + ip := newip4i_v4(8, 8, 8, 8) + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkStdIPv6(b *testing.B) { + b.ReportAllocs() + ips := []net.IP{} + for i := 0; i < b.N; i++ { + ip := net.ParseIP("2001:db8::1") + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkIPv6(b *testing.B) { + b.ReportAllocs() + ips := []Addr{} + for i := 0; i < b.N; i++ { + ip := mustIP("2001:db8::1") + ips = ips[:0] + for i := 0; i < 100; i++ { + ips = append(ips, ip) + } + } +} + +func BenchmarkIPv4Contains(b *testing.B) { + b.ReportAllocs() + prefix := PrefixFrom(IPv4(192, 168, 1, 0), 24) + ip := IPv4(192, 168, 1, 1) + for i := 0; i < b.N; i++ { + prefix.Contains(ip) + } +} + +func BenchmarkIPv6Contains(b *testing.B) { + b.ReportAllocs() + prefix := MustParsePrefix("::1/128") + ip := MustParseAddr("::1") + for i := 0; i < b.N; i++ { + prefix.Contains(ip) + } +} + +var parseBenchInputs = []struct { + name string + ip string +}{ + {"v4", "192.168.1.1"}, + {"v6", "fd7a:115c:a1e0:ab12:4843:cd96:626b:430b"}, + {"v6_ellipsis", "fd7a:115c::626b:430b"}, + {"v6_v4", "::ffff:192.168.140.255"}, + {"v6_zone", "1:2::ffff:192.168.140.255%eth1"}, +} + +func BenchmarkParseAddr(b *testing.B) { + sinkInternValue = intern.Get("eth1") // Pin to not benchmark the intern package + for _, test := range parseBenchInputs { + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkIP, _ = ParseAddr(test.ip) + } + }) + } +} + +func BenchmarkStdParseIP(b *testing.B) { + for _, test := range parseBenchInputs { + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkStdIP = net.ParseIP(test.ip) + } + }) + } +} + +func BenchmarkIPString(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkString = ip.String() + } + }) + } +} + +func BenchmarkIPStringExpanded(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkString = ip.StringExpanded() + } + }) + } +} + +func BenchmarkIPMarshalText(b *testing.B) { + b.ReportAllocs() + ip := MustParseAddr("66.55.44.33") + for i := 0; i < b.N; i++ { + sinkBytes, _ = ip.MarshalText() + } +} + +func BenchmarkAddrPortString(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + ipp := AddrPortFrom(ip, 60000) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkString = ipp.String() + } + }) + } +} + +func BenchmarkAddrPortMarshalText(b *testing.B) { + for _, test := range parseBenchInputs { + ip := MustParseAddr(test.ip) + ipp := AddrPortFrom(ip, 60000) + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + sinkBytes, _ = ipp.MarshalText() + } + }) + } +} + +func BenchmarkPrefixMasking(b *testing.B) { + tests := []struct { + name string + ip Addr + bits int + }{ + { + name: "IPv4 /32", + ip: IPv4(192, 0, 2, 0), + bits: 32, + }, + { + name: "IPv4 /17", + ip: IPv4(192, 0, 2, 0), + bits: 17, + }, + { + name: "IPv4 /0", + ip: IPv4(192, 0, 2, 0), + bits: 0, + }, + { + name: "IPv6 /128", + ip: mustIP("2001:db8::1"), + bits: 128, + }, + { + name: "IPv6 /65", + ip: mustIP("2001:db8::1"), + bits: 65, + }, + { + name: "IPv6 /0", + ip: mustIP("2001:db8::1"), + bits: 0, + }, + { + name: "IPv6 zone /128", + ip: mustIP("2001:db8::1%eth0"), + bits: 128, + }, + { + name: "IPv6 zone /65", + ip: mustIP("2001:db8::1%eth0"), + bits: 65, + }, + { + name: "IPv6 zone /0", + ip: mustIP("2001:db8::1%eth0"), + bits: 0, + }, + } + + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + sinkPrefix, _ = tt.ip.Prefix(tt.bits) + } + }) + } +} + +func BenchmarkPrefixMarshalText(b *testing.B) { + b.ReportAllocs() + ipp := MustParsePrefix("66.55.44.33/22") + for i := 0; i < b.N; i++ { + sinkBytes, _ = ipp.MarshalText() + } +} + +func BenchmarkParseAddrPort(b *testing.B) { + for _, test := range parseBenchInputs { + var ipp string + if strings.HasPrefix(test.name, "v6") { + ipp = fmt.Sprintf("[%s]:1234", test.ip) + } else { + ipp = fmt.Sprintf("%s:1234", test.ip) + } + b.Run(test.name, func(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + sinkAddrPort, _ = ParseAddrPort(ipp) + } + }) + } +} + +func TestAs4(t *testing.T) { + tests := []struct { + ip Addr + want [4]byte + wantPanic bool + }{ + { + ip: mustIP("1.2.3.4"), + want: [4]byte{1, 2, 3, 4}, + }, + { + ip: AddrFrom16(mustIP("1.2.3.4").As16()), // IPv4-in-IPv6 + want: [4]byte{1, 2, 3, 4}, + }, + { + ip: mustIP("0.0.0.0"), + want: [4]byte{0, 0, 0, 0}, + }, + { + ip: Addr{}, + wantPanic: true, + }, + { + ip: mustIP("::1"), + wantPanic: true, + }, + } + as4 := func(ip Addr) (v [4]byte, gotPanic bool) { + defer func() { + if recover() != nil { + gotPanic = true + return + } + }() + v = ip.As4() + return + } + for i, tt := range tests { + got, gotPanic := as4(tt.ip) + if gotPanic != tt.wantPanic { + t.Errorf("%d. panic on %v = %v; want %v", i, tt.ip, gotPanic, tt.wantPanic) + continue + } + if got != tt.want { + t.Errorf("%d. %v = %v; want %v", i, tt.ip, got, tt.want) + } + } +} + +func TestPrefixOverlaps(t *testing.T) { + pfx := mustPrefix + tests := []struct { + a, b Prefix + want bool + }{ + {Prefix{}, pfx("1.2.0.0/16"), false}, // first zero + {pfx("1.2.0.0/16"), Prefix{}, false}, // second zero + {pfx("::0/3"), pfx("0.0.0.0/3"), false}, // different families + + {pfx("1.2.0.0/16"), pfx("1.2.0.0/16"), true}, // equal + + {pfx("1.2.0.0/16"), pfx("1.2.3.0/24"), true}, + {pfx("1.2.3.0/24"), pfx("1.2.0.0/16"), true}, + + {pfx("1.2.0.0/16"), pfx("1.2.3.0/32"), true}, + {pfx("1.2.3.0/32"), pfx("1.2.0.0/16"), true}, + + // Match /0 either order + {pfx("1.2.3.0/32"), pfx("0.0.0.0/0"), true}, + {pfx("0.0.0.0/0"), pfx("1.2.3.0/32"), true}, + + {pfx("1.2.3.0/32"), pfx("5.5.5.5/0"), true}, // normalization not required; /0 means true + + // IPv6 overlapping + {pfx("5::1/128"), pfx("5::0/8"), true}, + {pfx("5::0/8"), pfx("5::1/128"), true}, + + // IPv6 not overlapping + {pfx("1::1/128"), pfx("2::2/128"), false}, + {pfx("0100::0/8"), pfx("::1/128"), false}, + + // v6-mapped v4 should not overlap with IPv4. + {PrefixFrom(AddrFrom16(mustIP("1.2.0.0").As16()), 16), pfx("1.2.3.0/24"), false}, + + // Invalid prefixes + {PrefixFrom(mustIP("1.2.3.4"), 33), pfx("1.2.3.0/24"), false}, + {PrefixFrom(mustIP("2000::"), 129), pfx("2000::/64"), false}, + } + for i, tt := range tests { + if got := tt.a.Overlaps(tt.b); got != tt.want { + t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.a, tt.b, got, tt.want) + } + // Overlaps is commutative + if got := tt.b.Overlaps(tt.a); got != tt.want { + t.Errorf("%d. (%v).Overlaps(%v) = %v; want %v", i, tt.b, tt.a, got, tt.want) + } + } +} + +// Sink variables are here to force the compiler to not elide +// seemingly useless work in benchmarks and allocation tests. If you +// were to just `_ = foo()` within a test function, the compiler could +// correctly deduce that foo() does nothing and doesn't need to be +// called. By writing results to a global variable, we hide that fact +// from the compiler and force it to keep the code under test. +var ( + sinkIP Addr + sinkStdIP net.IP + sinkAddrPort AddrPort + sinkPrefix Prefix + sinkPrefixSlice []Prefix + sinkInternValue *intern.Value + sinkIP16 [16]byte + sinkIP4 [4]byte + sinkBool bool + sinkString string + sinkBytes []byte + sinkUDPAddr = &net.UDPAddr{IP: make(net.IP, 0, 16)} +) + +func TestNoAllocs(t *testing.T) { + // Wrappers that panic on error, to prove that our alloc-free + // methods are returning successfully. + panicIP := func(ip Addr, err error) Addr { + if err != nil { + panic(err) + } + return ip + } + panicPfx := func(pfx Prefix, err error) Prefix { + if err != nil { + panic(err) + } + return pfx + } + panicIPP := func(ipp AddrPort, err error) AddrPort { + if err != nil { + panic(err) + } + return ipp + } + test := func(name string, f func()) { + t.Run(name, func(t *testing.T) { + n := testing.AllocsPerRun(1000, f) + if n != 0 { + t.Fatalf("allocs = %d; want 0", int(n)) + } + }) + } + + // IP constructors + test("IPv4", func() { sinkIP = IPv4(1, 2, 3, 4) }) + test("AddrFrom4", func() { sinkIP = AddrFrom4([4]byte{1, 2, 3, 4}) }) + test("AddrFrom16", func() { sinkIP = AddrFrom16([16]byte{}) }) + test("ParseAddr/4", func() { sinkIP = panicIP(ParseAddr("1.2.3.4")) }) + test("ParseAddr/6", func() { sinkIP = panicIP(ParseAddr("::1")) }) + test("MustParseAddr", func() { sinkIP = MustParseAddr("1.2.3.4") }) + test("IPv6LinkLocalAllNodes", func() { sinkIP = IPv6LinkLocalAllNodes() }) + test("IPv6Unspecified", func() { sinkIP = IPv6Unspecified() }) + + // IP methods + test("IP.IsZero", func() { sinkBool = MustParseAddr("1.2.3.4").IsZero() }) + test("IP.BitLen", func() { sinkBool = MustParseAddr("1.2.3.4").BitLen() == 8 }) + test("IP.Zone/4", func() { sinkBool = MustParseAddr("1.2.3.4").Zone() == "" }) + test("IP.Zone/6", func() { sinkBool = MustParseAddr("fe80::1").Zone() == "" }) + test("IP.Zone/6zone", func() { sinkBool = MustParseAddr("fe80::1%zone").Zone() == "" }) + test("IP.Compare", func() { + a := MustParseAddr("1.2.3.4") + b := MustParseAddr("2.3.4.5") + sinkBool = a.Compare(b) == 0 + }) + test("IP.Less", func() { + a := MustParseAddr("1.2.3.4") + b := MustParseAddr("2.3.4.5") + sinkBool = a.Less(b) + }) + test("IP.Is4", func() { sinkBool = MustParseAddr("1.2.3.4").Is4() }) + test("IP.Is6", func() { sinkBool = MustParseAddr("fe80::1").Is6() }) + test("IP.Is4In6", func() { sinkBool = MustParseAddr("fe80::1").Is4In6() }) + test("IP.Unmap", func() { sinkIP = MustParseAddr("ffff::2.3.4.5").Unmap() }) + test("IP.WithZone", func() { sinkIP = MustParseAddr("fe80::1").WithZone("") }) + test("IP.IsGlobalUnicast", func() { sinkBool = MustParseAddr("2001:db8::1").IsGlobalUnicast() }) + test("IP.IsInterfaceLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsInterfaceLocalMulticast() }) + test("IP.IsLinkLocalMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalMulticast() }) + test("IP.IsLinkLocalUnicast", func() { sinkBool = MustParseAddr("fe80::1").IsLinkLocalUnicast() }) + test("IP.IsLoopback", func() { sinkBool = MustParseAddr("fe80::1").IsLoopback() }) + test("IP.IsMulticast", func() { sinkBool = MustParseAddr("fe80::1").IsMulticast() }) + test("IP.IsPrivate", func() { sinkBool = MustParseAddr("fd00::1").IsPrivate() }) + test("IP.IsUnspecified", func() { sinkBool = IPv6Unspecified().IsUnspecified() }) + test("IP.Prefix/4", func() { sinkPrefix = panicPfx(MustParseAddr("1.2.3.4").Prefix(20)) }) + test("IP.Prefix/6", func() { sinkPrefix = panicPfx(MustParseAddr("fe80::1").Prefix(64)) }) + test("IP.As16", func() { sinkIP16 = MustParseAddr("1.2.3.4").As16() }) + test("IP.As4", func() { sinkIP4 = MustParseAddr("1.2.3.4").As4() }) + test("IP.Next", func() { sinkIP = MustParseAddr("1.2.3.4").Next() }) + test("IP.Prev", func() { sinkIP = MustParseAddr("1.2.3.4").Prev() }) + + // AddrPort constructors + test("AddrPortFrom", func() { sinkAddrPort = AddrPortFrom(IPv4(1, 2, 3, 4), 22) }) + test("ParseAddrPort", func() { sinkAddrPort = panicIPP(ParseAddrPort("[::1]:1234")) }) + test("MustParseAddrPort", func() { sinkAddrPort = MustParseAddrPort("[::1]:1234") }) + + // Prefix constructors + test("PrefixFrom", func() { sinkPrefix = PrefixFrom(IPv4(1, 2, 3, 4), 32) }) + test("ParsePrefix/4", func() { sinkPrefix = panicPfx(ParsePrefix("1.2.3.4/20")) }) + test("ParsePrefix/6", func() { sinkPrefix = panicPfx(ParsePrefix("fe80::1/64")) }) + test("MustParsePrefix", func() { sinkPrefix = MustParsePrefix("1.2.3.4/20") }) + + // Prefix methods + test("Prefix.Contains", func() { sinkBool = MustParsePrefix("1.2.3.0/24").Contains(MustParseAddr("1.2.3.4")) }) + test("Prefix.Overlaps", func() { + a, b := MustParsePrefix("1.2.3.0/24"), MustParsePrefix("1.2.0.0/16") + sinkBool = a.Overlaps(b) + }) + test("Prefix.IsZero", func() { sinkBool = MustParsePrefix("1.2.0.0/16").IsZero() }) + test("Prefix.IsSingleIP", func() { sinkBool = MustParsePrefix("1.2.3.4/32").IsSingleIP() }) + test("IPPRefix.Masked", func() { sinkPrefix = MustParsePrefix("1.2.3.4/16").Masked() }) +} + +func TestPrefixString(t *testing.T) { + tests := []struct { + ipp Prefix + want string + }{ + {Prefix{}, "invalid Prefix"}, + {PrefixFrom(Addr{}, 8), "invalid Prefix"}, + {PrefixFrom(MustParseAddr("1.2.3.4"), 88), "invalid Prefix"}, + } + + for _, tt := range tests { + if got := tt.ipp.String(); got != tt.want { + t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want) + } + } +} + +func TestInvalidAddrPortString(t *testing.T) { + tests := []struct { + ipp AddrPort + want string + }{ + {AddrPort{}, "invalid AddrPort"}, + {AddrPortFrom(Addr{}, 80), "invalid AddrPort"}, + } + + for _, tt := range tests { + if got := tt.ipp.String(); got != tt.want { + t.Errorf("(%#v).String() = %q want %q", tt.ipp, got, tt.want) + } + } +} + +func TestAsSlice(t *testing.T) { + tests := []struct { + in Addr + want []byte + }{ + {in: Addr{}, want: nil}, + {in: mustIP("1.2.3.4"), want: []byte{1, 2, 3, 4}}, + {in: mustIP("ffff::1"), want: []byte{0xff, 0xff, 15: 1}}, + } + + for _, test := range tests { + got := test.in.AsSlice() + if !bytes.Equal(got, test.want) { + t.Errorf("%v.AsSlice() = %v want %v", test.in, got, test.want) + } + } +} + +var sink16 [16]byte + +func BenchmarkAs16(b *testing.B) { + addr := MustParseAddr("1::10") + for i := 0; i < b.N; i++ { + sink16 = addr.As16() + } +} diff --git a/libgo/go/net/netip/slow_test.go b/libgo/go/net/netip/slow_test.go new file mode 100644 index 0000000..5b46a39 --- /dev/null +++ b/libgo/go/net/netip/slow_test.go @@ -0,0 +1,190 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip_test + +import ( + "fmt" + . "net/netip" + "strconv" + "strings" +) + +// zeros is a slice of eight stringified zeros. It's used in +// parseIPSlow to construct slices of specific amounts of zero fields, +// from 1 to 8. +var zeros = []string{"0", "0", "0", "0", "0", "0", "0", "0"} + +// parseIPSlow is like ParseIP, but aims for readability above +// speed. It's the reference implementation for correctness checking +// and against which we measure optimized parsers. +// +// parseIPSlow understands the following forms of IP addresses: +// - Regular IPv4: 1.2.3.4 +// - IPv4 with many leading zeros: 0000001.0000002.0000003.0000004 +// - Regular IPv6: 1111:2222:3333:4444:5555:6666:7777:8888 +// - IPv6 with many leading zeros: 00000001:0000002:0000003:0000004:0000005:0000006:0000007:0000008 +// - IPv6 with zero blocks elided: 1111:2222::7777:8888 +// - IPv6 with trailing 32 bits expressed as IPv4: 1111:2222:3333:4444:5555:6666:77.77.88.88 +// +// It does not process the following IP address forms, which have been +// varyingly accepted by some programs due to an under-specification +// of the shapes of IPv4 addresses: +// +// - IPv4 as a single 32-bit uint: 4660 (same as "1.2.3.4") +// - IPv4 with octal numbers: 0300.0250.0.01 (same as "192.168.0.1") +// - IPv4 with hex numbers: 0xc0.0xa8.0x0.0x1 (same as "192.168.0.1") +// - IPv4 in "class-B style": 1.2.52 (same as "1.2.3.4") +// - IPv4 in "class-A style": 1.564 (same as "1.2.3.4") +func parseIPSlow(s string) (Addr, error) { + // Identify and strip out the zone, if any. There should be 0 or 1 + // '%' in the string. + var zone string + fs := strings.Split(s, "%") + switch len(fs) { + case 1: + // No zone, that's fine. + case 2: + s, zone = fs[0], fs[1] + if zone == "" { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): no zone after zone specifier", s) + } + default: + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): too many zone specifiers", s) // TODO: less specific? + } + + // IPv4 by itself is easy to do in a helper. + if strings.Count(s, ":") == 0 { + if zone != "" { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): IPv4 addresses cannot have a zone", s) + } + return parseIPv4Slow(s) + } + + normal, err := normalizeIPv6Slow(s) + if err != nil { + return Addr{}, err + } + + // At this point, we've normalized the address back into 8 hex + // fields of 16 bits each. Parse that. + fs = strings.Split(normal, ":") + if len(fs) != 8 { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): wrong size address", s) + } + var ret [16]byte + for i, f := range fs { + a, b, err := parseWord(f) + if err != nil { + return Addr{}, err + } + ret[i*2] = a + ret[i*2+1] = b + } + + return AddrFrom16(ret).WithZone(zone), nil +} + +// normalizeIPv6Slow expands s, which is assumed to be an IPv6 +// address, to its canonical text form. +// +// The canonical form of an IPv6 address is 8 colon-separated fields, +// where each field should be a hex value from 0 to ffff. This +// function does not verify the contents of each field. +// +// This function performs two transformations: +// - The last 32 bits of an IPv6 address may be represented in +// IPv4-style dotted quad form, as in 1:2:3:4:5:6:7.8.9.10. That +// address is transformed to its hex equivalent, +// e.g. 1:2:3:4:5:6:708:90a. +// - An address may contain one "::", which expands into as many +// 16-bit blocks of zeros as needed to make the address its correct +// full size. For example, fe80::1:2 expands to fe80:0:0:0:0:0:1:2. +// +// Both short forms may be present in a single address, +// e.g. fe80::1.2.3.4. +func normalizeIPv6Slow(orig string) (string, error) { + s := orig + + // Find and convert an IPv4 address in the final field, if any. + i := strings.LastIndex(s, ":") + if i == -1 { + return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig) + } + if strings.Contains(s[i+1:], ".") { + ip, err := parseIPv4Slow(s[i+1:]) + if err != nil { + return "", err + } + a4 := ip.As4() + s = fmt.Sprintf("%s:%02x%02x:%02x%02x", s[:i], a4[0], a4[1], a4[2], a4[3]) + } + + // Find and expand a ::, if any. + fs := strings.Split(s, "::") + switch len(fs) { + case 1: + // No ::, nothing to do. + case 2: + lhs, rhs := fs[0], fs[1] + // Found a ::, figure out how many zero blocks need to be + // inserted. + nblocks := strings.Count(lhs, ":") + strings.Count(rhs, ":") + if lhs != "" { + nblocks++ + } + if rhs != "" { + nblocks++ + } + if nblocks > 7 { + return "", fmt.Errorf("netaddr.ParseIP(%q): address too long", orig) + } + fs = nil + // Either side of the :: can be empty. We don't want empty + // fields to feature in the final normalized address. + if lhs != "" { + fs = append(fs, lhs) + } + fs = append(fs, zeros[:8-nblocks]...) + if rhs != "" { + fs = append(fs, rhs) + } + s = strings.Join(fs, ":") + default: + // Too many :: + return "", fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", orig) + } + + return s, nil +} + +// parseIPv4Slow parses and returns an IPv4 address in dotted quad +// form, e.g. "192.168.0.1". It is slow but easy to read, and the +// reference implementation against which we compare faster +// implementations for correctness. +func parseIPv4Slow(s string) (Addr, error) { + fs := strings.Split(s, ".") + if len(fs) != 4 { + return Addr{}, fmt.Errorf("netaddr.ParseIP(%q): invalid IP address", s) + } + var ret [4]byte + for i := range ret { + val, err := strconv.ParseUint(fs[i], 10, 8) + if err != nil { + return Addr{}, err + } + ret[i] = uint8(val) + } + return AddrFrom4([4]byte{ret[0], ret[1], ret[2], ret[3]}), nil +} + +// parseWord converts a 16-bit hex string into its corresponding +// two-byte value. +func parseWord(s string) (byte, byte, error) { + ret, err := strconv.ParseUint(s, 16, 16) + if err != nil { + return 0, 0, err + } + return uint8(ret >> 8), uint8(ret), nil +} diff --git a/libgo/go/net/netip/uint128.go b/libgo/go/net/netip/uint128.go new file mode 100644 index 0000000..738939d --- /dev/null +++ b/libgo/go/net/netip/uint128.go @@ -0,0 +1,92 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import "math/bits" + +// uint128 represents a uint128 using two uint64s. +// +// When the methods below mention a bit number, bit 0 is the most +// significant bit (in hi) and bit 127 is the lowest (lo&1). +type uint128 struct { + hi uint64 + lo uint64 +} + +// mask6 returns a uint128 bitmask with the topmost n bits of a +// 128-bit number. +func mask6(n int) uint128 { + return uint128{^(^uint64(0) >> n), ^uint64(0) << (128 - n)} +} + +// isZero reports whether u == 0. +// +// It's faster than u == (uint128{}) because the compiler (as of Go +// 1.15/1.16b1) doesn't do this trick and instead inserts a branch in +// its eq alg's generated code. +func (u uint128) isZero() bool { return u.hi|u.lo == 0 } + +// and returns the bitwise AND of u and m (u&m). +func (u uint128) and(m uint128) uint128 { + return uint128{u.hi & m.hi, u.lo & m.lo} +} + +// xor returns the bitwise XOR of u and m (u^m). +func (u uint128) xor(m uint128) uint128 { + return uint128{u.hi ^ m.hi, u.lo ^ m.lo} +} + +// or returns the bitwise OR of u and m (u|m). +func (u uint128) or(m uint128) uint128 { + return uint128{u.hi | m.hi, u.lo | m.lo} +} + +// not returns the bitwise NOT of u. +func (u uint128) not() uint128 { + return uint128{^u.hi, ^u.lo} +} + +// subOne returns u - 1. +func (u uint128) subOne() uint128 { + lo, borrow := bits.Sub64(u.lo, 1, 0) + return uint128{u.hi - borrow, lo} +} + +// addOne returns u + 1. +func (u uint128) addOne() uint128 { + lo, carry := bits.Add64(u.lo, 1, 0) + return uint128{u.hi + carry, lo} +} + +func u64CommonPrefixLen(a, b uint64) uint8 { + return uint8(bits.LeadingZeros64(a ^ b)) +} + +func (u uint128) commonPrefixLen(v uint128) (n uint8) { + if n = u64CommonPrefixLen(u.hi, v.hi); n == 64 { + n += u64CommonPrefixLen(u.lo, v.lo) + } + return +} + +// halves returns the two uint64 halves of the uint128. +// +// Logically, think of it as returning two uint64s. +// It only returns pointers for inlining reasons on 32-bit platforms. +func (u *uint128) halves() [2]*uint64 { + return [2]*uint64{&u.hi, &u.lo} +} + +// bitsSetFrom returns a copy of u with the given bit +// and all subsequent ones set. +func (u uint128) bitsSetFrom(bit uint8) uint128 { + return u.or(mask6(int(bit)).not()) +} + +// bitsClearedFrom returns a copy of u with the given bit +// and all subsequent ones cleared. +func (u uint128) bitsClearedFrom(bit uint8) uint128 { + return u.and(mask6(int(bit))) +} diff --git a/libgo/go/net/netip/uint128_test.go b/libgo/go/net/netip/uint128_test.go new file mode 100644 index 0000000..dd1ae0e --- /dev/null +++ b/libgo/go/net/netip/uint128_test.go @@ -0,0 +1,89 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netip + +import ( + "testing" +) + +func TestUint128AddSub(t *testing.T) { + const add1 = 1 + const sub1 = -1 + tests := []struct { + in uint128 + op int // +1 or -1 to add vs subtract + want uint128 + }{ + {uint128{0, 0}, add1, uint128{0, 1}}, + {uint128{0, 1}, add1, uint128{0, 2}}, + {uint128{1, 0}, add1, uint128{1, 1}}, + {uint128{0, ^uint64(0)}, add1, uint128{1, 0}}, + {uint128{^uint64(0), ^uint64(0)}, add1, uint128{0, 0}}, + + {uint128{0, 0}, sub1, uint128{^uint64(0), ^uint64(0)}}, + {uint128{0, 1}, sub1, uint128{0, 0}}, + {uint128{0, 2}, sub1, uint128{0, 1}}, + {uint128{1, 0}, sub1, uint128{0, ^uint64(0)}}, + {uint128{1, 1}, sub1, uint128{1, 0}}, + } + for _, tt := range tests { + var got uint128 + switch tt.op { + case add1: + got = tt.in.addOne() + case sub1: + got = tt.in.subOne() + default: + panic("bogus op") + } + if got != tt.want { + t.Errorf("%v add %d = %v; want %v", tt.in, tt.op, got, tt.want) + } + } +} + +func TestBitsSetFrom(t *testing.T) { + tests := []struct { + bit uint8 + want uint128 + }{ + {0, uint128{^uint64(0), ^uint64(0)}}, + {1, uint128{^uint64(0) >> 1, ^uint64(0)}}, + {63, uint128{1, ^uint64(0)}}, + {64, uint128{0, ^uint64(0)}}, + {65, uint128{0, ^uint64(0) >> 1}}, + {127, uint128{0, 1}}, + {128, uint128{0, 0}}, + } + for _, tt := range tests { + var zero uint128 + got := zero.bitsSetFrom(tt.bit) + if got != tt.want { + t.Errorf("0.bitsSetFrom(%d) = %064b want %064b", tt.bit, got, tt.want) + } + } +} + +func TestBitsClearedFrom(t *testing.T) { + tests := []struct { + bit uint8 + want uint128 + }{ + {0, uint128{0, 0}}, + {1, uint128{1 << 63, 0}}, + {63, uint128{^uint64(0) &^ 1, 0}}, + {64, uint128{^uint64(0), 0}}, + {65, uint128{^uint64(0), 1 << 63}}, + {127, uint128{^uint64(0), ^uint64(0) &^ 1}}, + {128, uint128{^uint64(0), ^uint64(0)}}, + } + for _, tt := range tests { + ones := uint128{^uint64(0), ^uint64(0)} + got := ones.bitsClearedFrom(tt.bit) + if got != tt.want { + t.Errorf("ones.bitsClearedFrom(%d) = %064b want %064b", tt.bit, got, tt.want) + } + } +} diff --git a/libgo/go/net/nss.go b/libgo/go/net/nss.go index c12ee75..3e5274d 100644 --- a/libgo/go/net/nss.go +++ b/libgo/go/net/nss.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/nss_test.go b/libgo/go/net/nss_test.go index 948b8d3..b9a23ab 100644 --- a/libgo/go/net/nss_test.go +++ b/libgo/go/net/nss_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/packetconn_test.go b/libgo/go/net/packetconn_test.go index aeb9845..fa160df 100644 --- a/libgo/go/net/packetconn_test.go +++ b/libgo/go/net/packetconn_test.go @@ -6,14 +6,12 @@ // tag. //go:build !js -// +build !js package net import ( "os" "testing" - "time" ) // The full stack test cases for IPConn have been moved to the @@ -29,16 +27,16 @@ func packetConnTestData(t *testing.T, network string) ([]byte, func()) { return []byte("PACKETCONN TEST"), nil } -var packetConnTests = []struct { - net string - addr1 string - addr2 string -}{ - {"udp", "127.0.0.1:0", "127.0.0.1:0"}, - {"unixgram", testUnixAddr(), testUnixAddr()}, -} - func TestPacketConn(t *testing.T) { + var packetConnTests = []struct { + net string + addr1 string + addr2 string + }{ + {"udp", "127.0.0.1:0", "127.0.0.1:0"}, + {"unixgram", testUnixAddr(t), testUnixAddr(t)}, + } + closer := func(c PacketConn, net, addr1, addr2 string) { c.Close() switch net { @@ -61,9 +59,6 @@ func TestPacketConn(t *testing.T) { } defer closer(c1, tt.net, tt.addr1, tt.addr2) c1.LocalAddr() - c1.SetDeadline(time.Now().Add(500 * time.Millisecond)) - c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - c1.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) c2, err := ListenPacket(tt.net, tt.addr2) if err != nil { @@ -71,9 +66,6 @@ func TestPacketConn(t *testing.T) { } defer closer(c2, tt.net, tt.addr1, tt.addr2) c2.LocalAddr() - c2.SetDeadline(time.Now().Add(500 * time.Millisecond)) - c2.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - c2.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) rb2 := make([]byte, 128) if _, err := c1.WriteTo(wb, c2.LocalAddr()); err != nil { @@ -93,6 +85,15 @@ func TestPacketConn(t *testing.T) { } func TestConnAndPacketConn(t *testing.T) { + var packetConnTests = []struct { + net string + addr1 string + addr2 string + }{ + {"udp", "127.0.0.1:0", "127.0.0.1:0"}, + {"unixgram", testUnixAddr(t), testUnixAddr(t)}, + } + closer := func(c PacketConn, net, addr1, addr2 string) { c.Close() switch net { @@ -116,9 +117,6 @@ func TestConnAndPacketConn(t *testing.T) { } defer closer(c1, tt.net, tt.addr1, tt.addr2) c1.LocalAddr() - c1.SetDeadline(time.Now().Add(500 * time.Millisecond)) - c1.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - c1.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) c2, err := Dial(tt.net, c1.LocalAddr().String()) if err != nil { @@ -127,9 +125,6 @@ func TestConnAndPacketConn(t *testing.T) { defer c2.Close() c2.LocalAddr() c2.RemoteAddr() - c2.SetDeadline(time.Now().Add(500 * time.Millisecond)) - c2.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) - c2.SetWriteDeadline(time.Now().Add(500 * time.Millisecond)) if _, err := c2.Write(wb); err != nil { t.Fatal(err) diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go index 6c230ab..ee2890f 100644 --- a/libgo/go/net/parse.go +++ b/libgo/go/net/parse.go @@ -208,6 +208,16 @@ func last(s string, b byte) int { return i } +// hasUpperCase tells whether the given string contains at least one upper-case. +func hasUpperCase(s string) bool { + for i := range s { + if 'A' <= s[i] && s[i] <= 'Z' { + return true + } + } + return false +} + // lowerASCIIBytes makes x ASCII lowercase in-place. func lowerASCIIBytes(x []byte) { for i, b := range x { @@ -331,26 +341,3 @@ func readFull(r io.Reader) (all []byte, err error) { } } } - -// goDebugString returns the value of the named GODEBUG key. -// GODEBUG is of the form "key=val,key2=val2" -func goDebugString(key string) string { - s := os.Getenv("GODEBUG") - for i := 0; i < len(s)-len(key)-1; i++ { - if i > 0 && s[i-1] != ',' { - continue - } - afterKey := s[i+len(key):] - if afterKey[0] != '=' || s[i:i+len(key)] != key { - continue - } - val := afterKey[1:] - for i, b := range val { - if b == ',' { - return val[:i] - } - } - return val - } - return "" -} diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go index c5f8bfd..97716d7 100644 --- a/libgo/go/net/parse_test.go +++ b/libgo/go/net/parse_test.go @@ -51,33 +51,6 @@ func TestReadLine(t *testing.T) { } } -func TestGoDebugString(t *testing.T) { - defer os.Setenv("GODEBUG", os.Getenv("GODEBUG")) - tests := []struct { - godebug string - key string - want string - }{ - {"", "foo", ""}, - {"foo=", "foo", ""}, - {"foo=bar", "foo", "bar"}, - {"foo=bar,", "foo", "bar"}, - {"foo,foo=bar,", "foo", "bar"}, - {"foo1=bar,foo=bar,", "foo", "bar"}, - {"foo=bar,foo=bar,", "foo", "bar"}, - {"foo=", "foo", ""}, - {"foo", "foo", ""}, - {",foo", "foo", ""}, - {"foo=bar,baz", "loooooooong", ""}, - } - for _, tt := range tests { - os.Setenv("GODEBUG", tt.godebug) - if got := goDebugString(tt.key); got != tt.want { - t.Errorf("for %q, goDebugString(%q) = %q; want %q", tt.godebug, tt.key, got, tt.want) - } - } -} - func TestDtoi(t *testing.T) { for _, tt := range []struct { in string diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go index 2da23de..c522ba2 100644 --- a/libgo/go/net/platform_test.go +++ b/libgo/go/net/platform_test.go @@ -34,8 +34,8 @@ func init() { // testableNetwork reports whether network is testable on the current // platform configuration. func testableNetwork(network string) bool { - ss := strings.Split(network, ":") - switch ss[0] { + net, _, _ := strings.Cut(network, ":") + switch net { case "ip+nopriv": case "ip", "ip4", "ip6": switch runtime.GOOS { @@ -68,7 +68,7 @@ func testableNetwork(network string) bool { } } } - switch ss[0] { + switch net { case "tcp4", "udp4", "ip4": if !supportsIPv4() { return false @@ -88,7 +88,7 @@ func iOS() bool { // testableAddress reports whether address of network is testable on // the current platform configuration. func testableAddress(network, address string) bool { - switch ss := strings.Split(network, ":"); ss[0] { + switch net, _, _ := strings.Cut(network, ":"); net { case "unix", "unixgram", "unixpacket": // Abstract unix domain sockets, a Linux-ism. if address[0] == '@' && runtime.GOOS != "linux" { @@ -107,7 +107,7 @@ func testableListenArgs(network, address, client string) bool { var err error var addr Addr - switch ss := strings.Split(network, ":"); ss[0] { + switch net, _, _ := strings.Cut(network, ":"); net { case "tcp", "tcp4", "tcp6": addr, err = ResolveTCPAddr("tcp", address) case "udp", "udp4", "udp6": @@ -173,7 +173,7 @@ func testableListenArgs(network, address, client string) bool { return true } -func condFatalf(t *testing.T, network string, format string, args ...interface{}) { +func condFatalf(t *testing.T, network string, format string, args ...any) { t.Helper() // A few APIs like File and Read/WriteMsg{UDP,IP} are not // fully implemented yet on Plan 9 and Windows. diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go index 07b4cbb..3527f1f 100644 --- a/libgo/go/net/port_unix.go +++ b/libgo/go/net/port_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris // Read system port mappings from /etc/services diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go index fc9b386..e4198a3 100644 --- a/libgo/go/net/protoconn_test.go +++ b/libgo/go/net/protoconn_test.go @@ -6,7 +6,6 @@ // tag. //go:build !js -// +build !js package net @@ -74,10 +73,7 @@ func TestTCPConnSpecificMethods(t *testing.T) { } ch := make(chan error, 1) handler := func(ls *localServer, ln Listener) { ls.transponder(ls.Listener, ch) } - ls, err := (&streamListener{Listener: ln}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&streamListener{Listener: ln}).newLocalServer() defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -208,7 +204,7 @@ func TestUnixListenerSpecificMethods(t *testing.T) { t.Skip("unix test") } - addr := testUnixAddr() + addr := testUnixAddr(t) la, err := ResolveUnixAddr("unix", addr) if err != nil { t.Fatal(err) @@ -249,7 +245,7 @@ func TestUnixConnSpecificMethods(t *testing.T) { t.Skip("unixgram test") } - addr1, addr2, addr3 := testUnixAddr(), testUnixAddr(), testUnixAddr() + addr1, addr2, addr3 := testUnixAddr(t), testUnixAddr(t), testUnixAddr(t) a1, err := ResolveUnixAddr("unixgram", addr1) if err != nil { diff --git a/libgo/go/net/rawconn_stub_test.go b/libgo/go/net/rawconn_stub_test.go index 975aa8d..ff3d829 100644 --- a/libgo/go/net/rawconn_stub_test.go +++ b/libgo/go/net/rawconn_stub_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build (js && wasm) || plan9 -// +build js,wasm plan9 package net diff --git a/libgo/go/net/rawconn_test.go b/libgo/go/net/rawconn_test.go index 3ef7af3..d1ef79d 100644 --- a/libgo/go/net/rawconn_test.go +++ b/libgo/go/net/rawconn_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -65,10 +64,7 @@ func TestRawConnReadWrite(t *testing.T) { return } } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -103,10 +99,7 @@ func TestRawConnReadWrite(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) @@ -181,10 +174,7 @@ func TestRawConnControl(t *testing.T) { } t.Run("TCP", func(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() cc1, err := ln.(*TCPListener).SyscallConn() diff --git a/libgo/go/net/rawconn_unix_test.go b/libgo/go/net/rawconn_unix_test.go index 77df4f8..7069d01 100644 --- a/libgo/go/net/rawconn_unix_test.go +++ b/libgo/go/net/rawconn_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go index 60bb2cc..42d1351 100644 --- a/libgo/go/net/rpc/client.go +++ b/libgo/go/net/rpc/client.go @@ -27,11 +27,11 @@ var ErrShutdown = errors.New("connection is shut down") // Call represents an active RPC. type Call struct { - ServiceMethod string // The name of the service and method to call. - Args interface{} // The argument to the function (*struct). - Reply interface{} // The reply from the function (*struct). - Error error // After completion, the error status. - Done chan *Call // Receives *Call when Go is complete. + ServiceMethod string // The name of the service and method to call. + Args any // The argument to the function (*struct). + Reply any // The reply from the function (*struct). + Error error // After completion, the error status. + Done chan *Call // Receives *Call when Go is complete. } // Client represents an RPC Client. @@ -61,9 +61,9 @@ type Client struct { // discarded. // See NewClient's comment for information about concurrent access. type ClientCodec interface { - WriteRequest(*Request, interface{}) error + WriteRequest(*Request, any) error ReadResponseHeader(*Response) error - ReadResponseBody(interface{}) error + ReadResponseBody(any) error Close() error } @@ -214,7 +214,7 @@ type gobClientCodec struct { encBuf *bufio.Writer } -func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) (err error) { +func (c *gobClientCodec) WriteRequest(r *Request, body any) (err error) { if err = c.enc.Encode(r); err != nil { return } @@ -228,7 +228,7 @@ func (c *gobClientCodec) ReadResponseHeader(r *Response) error { return c.dec.Decode(r) } -func (c *gobClientCodec) ReadResponseBody(body interface{}) error { +func (c *gobClientCodec) ReadResponseBody(body any) error { return c.dec.Decode(body) } @@ -295,7 +295,7 @@ func (client *Client) Close() error { // the invocation. The done channel will signal when the call is complete by returning // the same Call object. If done is nil, Go will allocate a new channel. // If non-nil, done must be buffered or Go will deliberately crash. -func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { +func (client *Client) Go(serviceMethod string, args any, reply any, done chan *Call) *Call { call := new(Call) call.ServiceMethod = serviceMethod call.Args = args @@ -317,7 +317,7 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface } // Call invokes the named function, waits for it to complete, and returns its error status. -func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error { +func (client *Client) Call(serviceMethod string, args any, reply any) error { call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done return call.Error } diff --git a/libgo/go/net/rpc/client_test.go b/libgo/go/net/rpc/client_test.go index 03225e3..ffc12fa 100644 --- a/libgo/go/net/rpc/client_test.go +++ b/libgo/go/net/rpc/client_test.go @@ -17,8 +17,8 @@ type shutdownCodec struct { closed bool } -func (c *shutdownCodec) WriteRequest(*Request, interface{}) error { return nil } -func (c *shutdownCodec) ReadResponseBody(interface{}) error { return nil } +func (c *shutdownCodec) WriteRequest(*Request, any) error { return nil } +func (c *shutdownCodec) ReadResponseBody(any) error { return nil } func (c *shutdownCodec) ReadResponseHeader(*Response) error { c.responded <- 1 return errors.New("shutdownCodec ReadResponseHeader") @@ -57,8 +57,8 @@ func TestGobError(t *testing.T) { if err == nil { t.Fatal("no error") } - if !strings.Contains(err.(error).Error(), "reading body EOF") { - t.Fatal("expected `reading body EOF', got", err) + if !strings.Contains(err.(error).Error(), "reading body unexpected EOF") { + t.Fatal("expected `reading body unexpected EOF', got", err) } }() Register(new(S)) diff --git a/libgo/go/net/rpc/debug.go b/libgo/go/net/rpc/debug.go index a1d799f..9e499fd 100644 --- a/libgo/go/net/rpc/debug.go +++ b/libgo/go/net/rpc/debug.go @@ -72,7 +72,7 @@ type debugHTTP struct { func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Build a sorted version of the data. var services serviceArray - server.serviceMap.Range(func(snamei, svci interface{}) bool { + server.serviceMap.Range(func(snamei, svci any) bool { svc := svci.(*service) ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))} for mname, method := range svc.method { diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go index 667f839..f4e1278 100644 --- a/libgo/go/net/rpc/jsonrpc/all_test.go +++ b/libgo/go/net/rpc/jsonrpc/all_test.go @@ -28,9 +28,9 @@ type Reply struct { type Arith int type ArithAddResp struct { - Id interface{} `json:"id"` - Result Reply `json:"result"` - Error interface{} `json:"error"` + Id any `json:"id"` + Result Reply `json:"result"` + Error any `json:"error"` } func (t *Arith) Add(args *Args, reply *Reply) error { diff --git a/libgo/go/net/rpc/jsonrpc/client.go b/libgo/go/net/rpc/jsonrpc/client.go index e6359be..c473017 100644 --- a/libgo/go/net/rpc/jsonrpc/client.go +++ b/libgo/go/net/rpc/jsonrpc/client.go @@ -44,12 +44,12 @@ func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec { } type clientRequest struct { - Method string `json:"method"` - Params [1]interface{} `json:"params"` - Id uint64 `json:"id"` + Method string `json:"method"` + Params [1]any `json:"params"` + Id uint64 `json:"id"` } -func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error { +func (c *clientCodec) WriteRequest(r *rpc.Request, param any) error { c.mutex.Lock() c.pending[r.Seq] = r.ServiceMethod c.mutex.Unlock() @@ -62,7 +62,7 @@ func (c *clientCodec) WriteRequest(r *rpc.Request, param interface{}) error { type clientResponse struct { Id uint64 `json:"id"` Result *json.RawMessage `json:"result"` - Error interface{} `json:"error"` + Error any `json:"error"` } func (r *clientResponse) reset() { @@ -97,7 +97,7 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error { return nil } -func (c *clientCodec) ReadResponseBody(x interface{}) error { +func (c *clientCodec) ReadResponseBody(x any) error { if x == nil { return nil } diff --git a/libgo/go/net/rpc/jsonrpc/server.go b/libgo/go/net/rpc/jsonrpc/server.go index 40e4e6f..3ee4ddf 100644 --- a/libgo/go/net/rpc/jsonrpc/server.go +++ b/libgo/go/net/rpc/jsonrpc/server.go @@ -57,8 +57,8 @@ func (r *serverRequest) reset() { type serverResponse struct { Id *json.RawMessage `json:"id"` - Result interface{} `json:"result"` - Error interface{} `json:"error"` + Result any `json:"result"` + Error any `json:"error"` } func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error { @@ -81,7 +81,7 @@ func (c *serverCodec) ReadRequestHeader(r *rpc.Request) error { return nil } -func (c *serverCodec) ReadRequestBody(x interface{}) error { +func (c *serverCodec) ReadRequestBody(x any) error { if x == nil { return nil } @@ -92,14 +92,14 @@ func (c *serverCodec) ReadRequestBody(x interface{}) error { // RPC params is struct. // Unmarshal into array containing struct for now. // Should think about making RPC more general. - var params [1]interface{} + var params [1]any params[0] = x return json.Unmarshal(*c.req.Params, ¶ms) } var null = json.RawMessage([]byte("null")) -func (c *serverCodec) WriteResponse(r *rpc.Response, x interface{}) error { +func (c *serverCodec) WriteResponse(r *rpc.Response, x any) error { c.mutex.Lock() b, ok := c.pending[r.Seq] if !ok { diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go index 074c5b9..d5207a4 100644 --- a/libgo/go/net/rpc/server.go +++ b/libgo/go/net/rpc/server.go @@ -203,7 +203,7 @@ var DefaultServer = NewServer() // Is this type exported or a builtin? func isExportedOrBuiltinType(t reflect.Type) bool { - for t.Kind() == reflect.Ptr { + for t.Kind() == reflect.Pointer { t = t.Elem() } // PkgPath will be non-empty even for an exported type, @@ -221,17 +221,21 @@ func isExportedOrBuiltinType(t reflect.Type) bool { // no suitable methods. It also logs the error using package log. // The client accesses each method using a string of the form "Type.Method", // where Type is the receiver's concrete type. -func (server *Server) Register(rcvr interface{}) error { +func (server *Server) Register(rcvr any) error { return server.register(rcvr, "", false) } // RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. -func (server *Server) RegisterName(name string, rcvr interface{}) error { +func (server *Server) RegisterName(name string, rcvr any) error { return server.register(rcvr, name, true) } -func (server *Server) register(rcvr interface{}, name string, useName bool) error { +// logRegisterError specifies whether to log problems during method registration. +// To debug registration, recompile the package with this set to true. +const logRegisterError = false + +func (server *Server) register(rcvr any, name string, useName bool) error { s := new(service) s.typ = reflect.TypeOf(rcvr) s.rcvr = reflect.ValueOf(rcvr) @@ -252,13 +256,13 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro s.name = sname // Install the methods - s.method = suitableMethods(s.typ, true) + s.method = suitableMethods(s.typ, logRegisterError) if len(s.method) == 0 { str := "" // To help the user, see if a pointer receiver would work. - method := suitableMethods(reflect.PtrTo(s.typ), false) + method := suitableMethods(reflect.PointerTo(s.typ), false) if len(method) != 0 { str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)" } else { @@ -274,9 +278,9 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro return nil } -// suitableMethods returns suitable Rpc methods of typ, it will report -// error using log if reportErr is true. -func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { +// suitableMethods returns suitable Rpc methods of typ. It will log +// errors if logErr is true. +func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType { methods := make(map[string]*methodType) for m := 0; m < typ.NumMethod(); m++ { method := typ.Method(m) @@ -288,7 +292,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { } // Method needs three ins: receiver, *args, *reply. if mtype.NumIn() != 3 { - if reportErr { + if logErr { log.Printf("rpc.Register: method %q has %d input parameters; needs exactly three\n", mname, mtype.NumIn()) } continue @@ -296,36 +300,36 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { // First arg need not be a pointer. argType := mtype.In(1) if !isExportedOrBuiltinType(argType) { - if reportErr { + if logErr { log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType) } continue } // Second arg must be a pointer. replyType := mtype.In(2) - if replyType.Kind() != reflect.Ptr { - if reportErr { + if replyType.Kind() != reflect.Pointer { + if logErr { log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType) } continue } // Reply type must be exported. if !isExportedOrBuiltinType(replyType) { - if reportErr { + if logErr { log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType) } continue } // Method needs one out. if mtype.NumOut() != 1 { - if reportErr { + if logErr { log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut()) } continue } // The return type of the method must be error. if returnType := mtype.Out(0); returnType != typeOfError { - if reportErr { + if logErr { log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType) } continue @@ -340,7 +344,7 @@ func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType { // contains an error when it is used. var invalidRequest = struct{}{} -func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) { +func (server *Server) sendResponse(sending *sync.Mutex, req *Request, reply any, codec ServerCodec, errmsg string) { resp := server.getResponse() // Encode the response header resp.ServiceMethod = req.ServiceMethod @@ -397,11 +401,11 @@ func (c *gobServerCodec) ReadRequestHeader(r *Request) error { return c.dec.Decode(r) } -func (c *gobServerCodec) ReadRequestBody(body interface{}) error { +func (c *gobServerCodec) ReadRequestBody(body any) error { return c.dec.Decode(body) } -func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) (err error) { +func (c *gobServerCodec) WriteResponse(r *Response, body any) (err error) { if err = c.enc.Encode(r); err != nil { if c.encBuf.Flush() == nil { // Gob couldn't encode the header. Should not happen, so if it does, @@ -552,7 +556,7 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m // Decode the argument value. argIsValue := false // if true, need to indirect before calling. - if mtype.ArgType.Kind() == reflect.Ptr { + if mtype.ArgType.Kind() == reflect.Pointer { argv = reflect.New(mtype.ArgType.Elem()) } else { argv = reflect.New(mtype.ArgType) @@ -632,11 +636,11 @@ func (server *Server) Accept(lis net.Listener) { } // Register publishes the receiver's methods in the DefaultServer. -func Register(rcvr interface{}) error { return DefaultServer.Register(rcvr) } +func Register(rcvr any) error { return DefaultServer.Register(rcvr) } // RegisterName is like Register but uses the provided name for the type // instead of the receiver's concrete type. -func RegisterName(name string, rcvr interface{}) error { +func RegisterName(name string, rcvr any) error { return DefaultServer.RegisterName(name, rcvr) } @@ -650,8 +654,8 @@ func RegisterName(name string, rcvr interface{}) error { // See NewClient's comment for information about concurrent access. type ServerCodec interface { ReadRequestHeader(*Request) error - ReadRequestBody(interface{}) error - WriteResponse(*Response, interface{}) error + ReadRequestBody(any) error + WriteResponse(*Response, any) error // Close can be called multiple times and must be idempotent. Close() error diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go index e5d7fe0..dc5f5de 100644 --- a/libgo/go/net/rpc/server_test.go +++ b/libgo/go/net/rpc/server_test.go @@ -427,7 +427,7 @@ func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { return nil } -func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { +func (codec *CodecEmulator) ReadRequestBody(argv any) error { if codec.args == nil { return io.ErrUnexpectedEOF } @@ -435,7 +435,7 @@ func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { return nil } -func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { +func (codec *CodecEmulator) WriteResponse(resp *Response, reply any) error { if resp.Error != "" { codec.err = errors.New(resp.Error) } else { @@ -521,7 +521,7 @@ func TestRegistrationError(t *testing.T) { type WriteFailCodec int -func (WriteFailCodec) WriteRequest(*Request, interface{}) error { +func (WriteFailCodec) WriteRequest(*Request, any) error { // the panic caused by this error used to not unlock a lock. return errors.New("fail") } @@ -530,7 +530,7 @@ func (WriteFailCodec) ReadResponseHeader(*Response) error { select {} } -func (WriteFailCodec) ReadResponseBody(interface{}) error { +func (WriteFailCodec) ReadResponseBody(any) error { select {} } diff --git a/libgo/go/net/sendfile_stub.go b/libgo/go/net/sendfile_stub.go index 5753bc0..7428da3 100644 --- a/libgo/go/net/sendfile_stub.go +++ b/libgo/go/net/sendfile_stub.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build aix || darwin || (js && wasm) || netbsd || openbsd -// +build aix darwin js,wasm netbsd openbsd +//go:build aix || (js && wasm) || netbsd || openbsd || ios package net diff --git a/libgo/go/net/sendfile_test.go b/libgo/go/net/sendfile_test.go index 54e51fa..6edfb67 100644 --- a/libgo/go/net/sendfile_test.go +++ b/libgo/go/net/sendfile_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -28,10 +27,7 @@ const ( ) func TestSendfile(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) @@ -98,10 +94,7 @@ func TestSendfile(t *testing.T) { } func TestSendfileParts(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) @@ -156,10 +149,7 @@ func TestSendfileParts(t *testing.T) { } func TestSendfileSeeked(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() const seekTo = 65 << 10 @@ -226,10 +216,7 @@ func TestSendfilePipe(t *testing.T) { t.Parallel() - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() r, w, err := os.Pipe() @@ -318,10 +305,7 @@ func TestSendfilePipe(t *testing.T) { // Issue 43822: tests that returns EOF when conn write timeout. func TestSendfileOnWriteTimeoutExceeded(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) diff --git a/libgo/go/net/sendfile_unix_alt.go b/libgo/go/net/sendfile_unix_alt.go index 54667d6..f99af92 100644 --- a/libgo/go/net/sendfile_unix_alt.go +++ b/libgo/go/net/sendfile_unix_alt.go @@ -2,8 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build dragonfly || freebsd || solaris -// +build dragonfly freebsd solaris +//go:build (darwin && !ios) || dragonfly || freebsd || solaris package net diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go index 7cbf152..6796d79 100644 --- a/libgo/go/net/server_test.go +++ b/libgo/go/net/server_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -78,10 +77,7 @@ func TestTCPServer(t *testing.T) { } }() for i := 0; i < N; i++ { - ls, err := (&streamListener{Listener: ln}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&streamListener{Listener: ln}).newLocalServer() lss = append(lss, ls) tpchs = append(tpchs, make(chan error, 1)) } @@ -126,19 +122,19 @@ func TestTCPServer(t *testing.T) { } } -var unixAndUnixpacketServerTests = []struct { - network, address string -}{ - {"unix", testUnixAddr()}, - {"unix", "@nettest/go/unix"}, - - {"unixpacket", testUnixAddr()}, - {"unixpacket", "@nettest/go/unixpacket"}, -} - // TestUnixAndUnixpacketServer tests concurrent accept-read-write // servers func TestUnixAndUnixpacketServer(t *testing.T) { + var unixAndUnixpacketServerTests = []struct { + network, address string + }{ + {"unix", testUnixAddr(t)}, + {"unix", "@nettest/go/unix"}, + + {"unixpacket", testUnixAddr(t)}, + {"unixpacket", "@nettest/go/unixpacket"}, + } + const N = 3 for i, tt := range unixAndUnixpacketServerTests { @@ -163,10 +159,7 @@ func TestUnixAndUnixpacketServer(t *testing.T) { } }() for i := 0; i < N; i++ { - ls, err := (&streamListener{Listener: ln}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&streamListener{Listener: ln}).newLocalServer() lss = append(lss, ls) tpchs = append(tpchs, make(chan error, 1)) } @@ -188,7 +181,11 @@ func TestUnixAndUnixpacketServer(t *testing.T) { } t.Fatal(err) } - defer os.Remove(c.LocalAddr().String()) + + if addr := c.LocalAddr(); addr != nil { + t.Logf("connected %s->%s", addr, lss[i].Listener.Addr()) + } + defer c.Close() trchs = append(trchs, make(chan error, 1)) go transceiver(c, []byte("UNIX AND UNIXPACKET SERVER TEST"), trchs[i]) @@ -267,10 +264,7 @@ func TestUDPServer(t *testing.T) { t.Fatal(err) } - ls, err := (&packetListener{PacketConn: c1}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&packetListener{PacketConn: c1}).newLocalServer() defer ls.teardown() tpch := make(chan error, 1) handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) } @@ -319,18 +313,18 @@ func TestUDPServer(t *testing.T) { } } -var unixgramServerTests = []struct { - saddr string // server endpoint - caddr string // client endpoint - dial bool // test with Dial -}{ - {saddr: testUnixAddr(), caddr: testUnixAddr()}, - {saddr: testUnixAddr(), caddr: testUnixAddr(), dial: true}, - - {saddr: "@nettest/go/unixgram/server", caddr: "@nettest/go/unixgram/client"}, -} - func TestUnixgramServer(t *testing.T) { + var unixgramServerTests = []struct { + saddr string // server endpoint + caddr string // client endpoint + dial bool // test with Dial + }{ + {saddr: testUnixAddr(t), caddr: testUnixAddr(t)}, + {saddr: testUnixAddr(t), caddr: testUnixAddr(t), dial: true}, + + {saddr: "@nettest/go/unixgram/server", caddr: "@nettest/go/unixgram/client"}, + } + for i, tt := range unixgramServerTests { if !testableListenArgs("unixgram", tt.saddr, "") { t.Logf("skipping %s test", "unixgram "+tt.saddr+"<-"+tt.caddr) @@ -345,10 +339,7 @@ func TestUnixgramServer(t *testing.T) { t.Fatal(err) } - ls, err := (&packetListener{PacketConn: c1}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&packetListener{PacketConn: c1}).newLocalServer() defer ls.teardown() tpch := make(chan error, 1) handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, tpch) } diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go index 1a6864a..c1f00a0 100644 --- a/libgo/go/net/smtp/smtp.go +++ b/libgo/go/net/smtp/smtp.go @@ -105,7 +105,7 @@ func (c *Client) Hello(localName string) error { } // cmd is a convenience function that sends a command and returns the response -func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { +func (c *Client) cmd(expectCode int, format string, args ...any) (int, string, error) { id, err := c.Text.Cmd(format, args...) if err != nil { return 0, "", err @@ -136,12 +136,8 @@ func (c *Client) ehlo() error { if len(extList) > 1 { extList = extList[1:] for _, line := range extList { - args := strings.SplitN(line, " ", 2) - if len(args) > 1 { - ext[args[0]] = args[1] - } else { - ext[args[0]] = "" - } + k, v, _ := strings.Cut(line, " ") + ext[k] = v } } if mechs, ok := ext["AUTH"]; ok { diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index 5521937..0f758f4 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -948,7 +948,7 @@ QUIT ` func TestTLSClient(t *testing.T) { - if (runtime.GOOS == "freebsd" && runtime.GOARCH == "amd64") || runtime.GOOS == "js" { + if runtime.GOOS == "freebsd" || runtime.GOOS == "js" { testenv.SkipFlaky(t, 19229) } ln := newLocalListener(t) diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go index 4c883ad..27daf72 100644 --- a/libgo/go/net/sock_bsd.go +++ b/libgo/go/net/sock_bsd.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/sock_cloexec.go b/libgo/go/net/sock_cloexec.go index cb57bb4..6321dbc 100644 --- a/libgo/go/net/sock_cloexec.go +++ b/libgo/go/net/sock_cloexec.go @@ -6,7 +6,6 @@ // setting SetNonblock and CloseOnExec. //go:build dragonfly || freebsd || hurd || illumos || linux || netbsd || openbsd -// +build dragonfly freebsd hurd illumos linux netbsd openbsd package net diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go index 8c09b0b..fbdec81 100644 --- a/libgo/go/net/sock_posix.go +++ b/libgo/go/net/sock_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/sock_stub.go b/libgo/go/net/sock_stub.go index 1e5032e..e5883d02 100644 --- a/libgo/go/net/sock_stub.go +++ b/libgo/go/net/sock_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || hurd || (js && wasm) || solaris -// +build aix hurd js,wasm solaris package net diff --git a/libgo/go/net/sockaddr_posix.go b/libgo/go/net/sockaddr_posix.go index 618d85f..050eac7 100644 --- a/libgo/go/net/sockaddr_posix.go +++ b/libgo/go/net/sockaddr_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/sockopt_bsd.go b/libgo/go/net/sockopt_bsd.go index e52fa88..8934e4c 100644 --- a/libgo/go/net/sockopt_bsd.go +++ b/libgo/go/net/sockopt_bsd.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || netbsd || openbsd -// +build darwin dragonfly freebsd netbsd openbsd package net diff --git a/libgo/go/net/sockopt_posix.go b/libgo/go/net/sockopt_posix.go index 3478872..1d92668 100644 --- a/libgo/go/net/sockopt_posix.go +++ b/libgo/go/net/sockopt_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/sockopt_stub.go b/libgo/go/net/sockopt_stub.go index 99b5277..98e2371 100644 --- a/libgo/go/net/sockopt_stub.go +++ b/libgo/go/net/sockopt_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package net diff --git a/libgo/go/net/sockoptip_bsdvar.go b/libgo/go/net/sockoptip_bsdvar.go index 8b0b5d2..696fa30 100644 --- a/libgo/go/net/sockoptip_bsdvar.go +++ b/libgo/go/net/sockoptip_bsdvar.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd hurd netbsd openbsd solaris package net diff --git a/libgo/go/net/sockoptip_posix.go b/libgo/go/net/sockoptip_posix.go index a063e79..3d47afd 100644 --- a/libgo/go/net/sockoptip_posix.go +++ b/libgo/go/net/sockoptip_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/sockoptip_stub.go b/libgo/go/net/sockoptip_stub.go index 4175922..2c993eb 100644 --- a/libgo/go/net/sockoptip_stub.go +++ b/libgo/go/net/sockoptip_stub.go @@ -3,38 +3,31 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package net import "syscall" func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { - // See golang.org/issue/7399. return syscall.ENOPROTOOPT } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - // See golang.org/issue/7399. return syscall.ENOPROTOOPT } func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { - // See golang.org/issue/7399. return syscall.ENOPROTOOPT } func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { - // See golang.org/issue/7399. return syscall.ENOPROTOOPT } func setIPv6MulticastLoopback(fd *netFD, v bool) error { - // See golang.org/issue/7399. return syscall.ENOPROTOOPT } func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { - // See golang.org/issue/7399. return syscall.ENOPROTOOPT } diff --git a/libgo/go/net/splice_stub.go b/libgo/go/net/splice_stub.go index ce2e904..3cdadb1 100644 --- a/libgo/go/net/splice_stub.go +++ b/libgo/go/net/splice_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !linux -// +build !linux package net diff --git a/libgo/go/net/splice_test.go b/libgo/go/net/splice_test.go index d5f6367..5ad9fcd 100644 --- a/libgo/go/net/splice_test.go +++ b/libgo/go/net/splice_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build linux -// +build linux package net @@ -47,20 +46,14 @@ type spliceTestCase struct { } func (tc spliceTestCase) test(t *testing.T) { - clientUp, serverUp, err := spliceTestSocketPair(tc.upNet) - if err != nil { - t.Fatal(err) - } + clientUp, serverUp := spliceTestSocketPair(t, tc.upNet) defer serverUp.Close() cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize) if err != nil { t.Fatal(err) } defer cleanup() - clientDown, serverDown, err := spliceTestSocketPair(tc.downNet) - if err != nil { - t.Fatal(err) - } + clientDown, serverDown := spliceTestSocketPair(t, tc.downNet) defer serverDown.Close() cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize) if err != nil { @@ -104,15 +97,9 @@ func (tc spliceTestCase) test(t *testing.T) { } func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { - clientUp, serverUp, err := spliceTestSocketPair(upNet) - if err != nil { - t.Fatal(err) - } + clientUp, serverUp := spliceTestSocketPair(t, upNet) defer clientUp.Close() - clientDown, serverDown, err := spliceTestSocketPair(downNet) - if err != nil { - t.Fatal(err) - } + clientDown, serverDown := spliceTestSocketPair(t, downNet) defer clientDown.Close() serverUp.Close() @@ -141,7 +128,7 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { }() buf := make([]byte, 3) - _, err = io.ReadFull(clientDown, buf) + _, err := io.ReadFull(clientDown, buf) if err != nil { t.Errorf("clientDown: %v", err) } @@ -151,15 +138,9 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { } func testSpliceIssue25985(t *testing.T, upNet, downNet string) { - front, err := newLocalListener(upNet) - if err != nil { - t.Fatal(err) - } + front := newLocalListener(t, upNet) defer front.Close() - back, err := newLocalListener(downNet) - if err != nil { - t.Fatal(err) - } + back := newLocalListener(t, downNet) defer back.Close() var wg sync.WaitGroup @@ -211,16 +192,10 @@ func testSpliceIssue25985(t *testing.T, upNet, downNet string) { } func testSpliceNoUnixpacket(t *testing.T) { - clientUp, serverUp, err := spliceTestSocketPair("unixpacket") - if err != nil { - t.Fatal(err) - } + clientUp, serverUp := spliceTestSocketPair(t, "unixpacket") defer clientUp.Close() defer serverUp.Close() - clientDown, serverDown, err := spliceTestSocketPair("tcp") - if err != nil { - t.Fatal(err) - } + clientDown, serverDown := spliceTestSocketPair(t, "tcp") defer clientDown.Close() defer serverDown.Close() // If splice called poll.Splice here, we'd get err == syscall.EINVAL @@ -238,7 +213,7 @@ func testSpliceNoUnixpacket(t *testing.T) { } func testSpliceNoUnixgram(t *testing.T) { - addr, err := ResolveUnixAddr("unixgram", testUnixAddr()) + addr, err := ResolveUnixAddr("unixgram", testUnixAddr(t)) if err != nil { t.Fatal(err) } @@ -248,10 +223,7 @@ func testSpliceNoUnixgram(t *testing.T) { t.Fatal(err) } defer up.Close() - clientDown, serverDown, err := spliceTestSocketPair("tcp") - if err != nil { - t.Fatal(err) - } + clientDown, serverDown := spliceTestSocketPair(t, "tcp") defer clientDown.Close() defer serverDown.Close() // Analogous to testSpliceNoUnixpacket. @@ -285,10 +257,7 @@ func (tc spliceTestCase) bench(b *testing.B) { // To benchmark the genericReadFrom code path, set this to false. useSplice := true - clientUp, serverUp, err := spliceTestSocketPair(tc.upNet) - if err != nil { - b.Fatal(err) - } + clientUp, serverUp := spliceTestSocketPair(b, tc.upNet) defer serverUp.Close() cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N) @@ -297,10 +266,7 @@ func (tc spliceTestCase) bench(b *testing.B) { } defer cleanup() - clientDown, serverDown, err := spliceTestSocketPair(tc.downNet) - if err != nil { - b.Fatal(err) - } + clientDown, serverDown := spliceTestSocketPair(b, tc.downNet) defer serverDown.Close() cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N) @@ -328,11 +294,9 @@ func (tc spliceTestCase) bench(b *testing.B) { } } -func spliceTestSocketPair(net string) (client, server Conn, err error) { - ln, err := newLocalListener(net) - if err != nil { - return nil, nil, err - } +func spliceTestSocketPair(t testing.TB, net string) (client, server Conn) { + t.Helper() + ln := newLocalListener(t, net) defer ln.Close() var cerr, serr error acceptDone := make(chan struct{}) @@ -346,15 +310,15 @@ func spliceTestSocketPair(net string) (client, server Conn, err error) { if server != nil { server.Close() } - return nil, nil, cerr + t.Fatal(cerr) } if serr != nil { if client != nil { client.Close() } - return nil, nil, serr + t.Fatal(serr) } - return client, server, nil + return client, server } func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) { diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go index a32483e..26eac55 100644 --- a/libgo/go/net/sys_cloexec.go +++ b/libgo/go/net/sys_cloexec.go @@ -6,7 +6,6 @@ // for setting SetNonblock and CloseOnExec. //go:build aix || darwin || (solaris && !illumos) -// +build aix darwin solaris,!illumos package net diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index 19a90143..6bad0e8 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -8,6 +8,7 @@ import ( "context" "internal/itoa" "io" + "net/netip" "os" "syscall" "time" @@ -23,6 +24,20 @@ type TCPAddr struct { Zone string // IPv6 scoped addressing zone } +// AddrPort returns the TCPAddr a as a netip.AddrPort. +// +// If a.Port does not fit in a uint16, it's silently truncated. +// +// If a is nil, a zero value is returned. +func (a *TCPAddr) AddrPort() netip.AddrPort { + if a == nil { + return netip.AddrPort{} + } + na, _ := netip.AddrFromSlice(a.IP) + na = na.WithZone(a.Zone) + return netip.AddrPortFrom(na, uint16(a.Port)) +} + // Network returns the address's network name, "tcp". func (a *TCPAddr) Network() string { return "tcp" } @@ -81,6 +96,17 @@ func ResolveTCPAddr(network, address string) (*TCPAddr, error) { return addrs.forResolve(network, address).(*TCPAddr), nil } +// TCPAddrFromAddrPort returns addr as a TCPAddr. If addr.IsValid() is false, +// then the returned TCPAddr will contain a nil IP field, indicating an +// address family-agnostic unspecified address. +func TCPAddrFromAddrPort(addr netip.AddrPort) *TCPAddr { + return &TCPAddr{ + IP: addr.Addr().AsSlice(), + Zone: addr.Addr().Zone(), + Port: int(addr.Port()), + } +} + // TCPConn is an implementation of the Conn interface for TCP network // connections. type TCPConn struct { diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go index 9fd7822..8237909 100644 --- a/libgo/go/net/tcpsock_posix.go +++ b/libgo/go/net/tcpsock_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go index 884c5cb..5cff961 100644 --- a/libgo/go/net/tcpsock_test.go +++ b/libgo/go/net/tcpsock_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -388,10 +387,7 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) { t.Log(err) continue } - ls, err := (&streamListener{Listener: ln}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&streamListener{Listener: ln}).newLocalServer() defer ls.teardown() ch := make(chan error, 1) handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } @@ -632,10 +628,7 @@ func TestTCPSelfConnect(t *testing.T) { t.Skip("known-broken test on windows") } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") var d Dialer c, err := d.Dial(ln.Addr().Network(), ln.Addr().String()) if err != nil { @@ -682,10 +675,7 @@ func TestTCPBig(t *testing.T) { for _, writev := range []bool{false, true} { t.Run(fmt.Sprintf("writev=%v", writev), func(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() x := int(1 << 30) @@ -729,10 +719,7 @@ func TestTCPBig(t *testing.T) { } func TestCopyPipeIntoTCP(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() errc := make(chan error, 1) @@ -800,10 +787,7 @@ func TestCopyPipeIntoTCP(t *testing.T) { } func BenchmarkSetReadDeadline(b *testing.B) { - ln, err := newLocalListener("tcp") - if err != nil { - b.Fatal(err) - } + ln := newLocalListener(b, "tcp") defer ln.Close() var serv Conn done := make(chan error) diff --git a/libgo/go/net/tcpsock_unix_test.go b/libgo/go/net/tcpsock_unix_test.go index 41bd229..b14670b 100644 --- a/libgo/go/net/tcpsock_unix_test.go +++ b/libgo/go/net/tcpsock_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 && !windows -// +build !js,!plan9,!windows package net @@ -23,10 +22,7 @@ func TestTCPSpuriousConnSetupCompletion(t *testing.T) { t.Skip("skipping in short mode") } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") var wg sync.WaitGroup wg.Add(1) go func(ln Listener) { diff --git a/libgo/go/net/tcpsockopt_posix.go b/libgo/go/net/tcpsockopt_posix.go index 4c99ab8..ad54d1b 100644 --- a/libgo/go/net/tcpsockopt_posix.go +++ b/libgo/go/net/tcpsockopt_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/tcpsockopt_stub.go b/libgo/go/net/tcpsockopt_stub.go index 028d5fd..0fe9182 100644 --- a/libgo/go/net/tcpsockopt_stub.go +++ b/libgo/go/net/tcpsockopt_stub.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package net diff --git a/libgo/go/net/tcpsockopt_unix.go b/libgo/go/net/tcpsockopt_unix.go index cc0662a..edcab44 100644 --- a/libgo/go/net/tcpsockopt_unix.go +++ b/libgo/go/net/tcpsockopt_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || freebsd || hurd || linux || netbsd -// +build aix freebsd hurd linux netbsd package net diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go index 5c3084f..157c59b 100644 --- a/libgo/go/net/textproto/reader.go +++ b/libgo/go/net/textproto/reader.go @@ -460,6 +460,8 @@ func (r *Reader) ReadDotLines() ([]string, error) { return v, err } +var colon = []byte(":") + // ReadMIMEHeader reads a MIME-style header from r. // The header is a sequence of possibly continued Key: Value lines // ending in a blank line. @@ -508,11 +510,11 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } // Key ends at first colon. - i := bytes.IndexByte(kv, ':') - if i < 0 { + k, v, ok := bytes.Cut(kv, colon) + if !ok { return m, ProtocolError("malformed MIME header line: " + string(kv)) } - key := canonicalMIMEHeaderKey(kv[:i]) + key := canonicalMIMEHeaderKey(k) // As per RFC 7230 field-name is a token, tokens consist of one or more chars. // We could return a ProtocolError here, but better to be liberal in what we @@ -522,11 +524,7 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { } // Skip initial spaces in value. - i++ // skip colon - for i < len(kv) && (kv[i] == ' ' || kv[i] == '\t') { - i++ - } - value := string(kv[i:]) + value := strings.TrimLeft(string(v), " \t") vv := m[key] if vv == nil && len(strs) > 0 { @@ -561,6 +559,8 @@ func mustHaveFieldNameColon(line []byte) error { return nil } +var nl = []byte("\n") + // upcomingHeaderNewlines returns an approximation of the number of newlines // that will be in this header. If it gets confused, it returns 0. func (r *Reader) upcomingHeaderNewlines() (n int) { @@ -571,17 +571,7 @@ func (r *Reader) upcomingHeaderNewlines() (n int) { return } peek, _ := r.R.Peek(s) - for len(peek) > 0 { - i := bytes.IndexByte(peek, '\n') - if i < 3 { - // Not present (-1) or found within the next few bytes, - // implying we're at the end ("\r\n\r\n" or "\n\n") - return - } - n++ - peek = peek[i+1:] - } - return + return bytes.Count(peek, nl) } // CanonicalMIMEHeaderKey returns the canonical format of the diff --git a/libgo/go/net/textproto/textproto.go b/libgo/go/net/textproto/textproto.go index 8fd781e..cc1a847 100644 --- a/libgo/go/net/textproto/textproto.go +++ b/libgo/go/net/textproto/textproto.go @@ -111,7 +111,7 @@ func Dial(network, addr string) (*Conn, error) { // } // return c.ReadCodeLine(250) // -func (c *Conn) Cmd(format string, args ...interface{}) (id uint, err error) { +func (c *Conn) Cmd(format string, args ...any) (id uint, err error) { id = c.Next() c.StartRequest(id) err = c.PrintfLine(format, args...) diff --git a/libgo/go/net/textproto/writer.go b/libgo/go/net/textproto/writer.go index 33c146c..2ece3f5 100644 --- a/libgo/go/net/textproto/writer.go +++ b/libgo/go/net/textproto/writer.go @@ -26,7 +26,7 @@ var crnl = []byte{'\r', '\n'} var dotcrnl = []byte{'.', '\r', '\n'} // PrintfLine writes the formatted output followed by \r\n. -func (w *Writer) PrintfLine(format string, args ...interface{}) error { +func (w *Writer) PrintfLine(format string, args ...any) error { w.closeDot() fmt.Fprintf(w.W, format, args...) w.W.Write(crnl) diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index e1cf146..d1cfbf8 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -93,53 +92,35 @@ func TestDialTimeout(t *testing.T) { } } -var dialTimeoutMaxDurationTests = []struct { - timeout time.Duration - delta time.Duration // for deadline -}{ - // Large timeouts that will overflow an int64 unix nanos. - {1<<63 - 1, 0}, - {0, 1<<63 - 1}, -} - func TestDialTimeoutMaxDuration(t *testing.T) { - if runtime.GOOS == "openbsd" { - testenv.SkipFlaky(t, 15157) - } - - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } - defer ln.Close() + ln := newLocalListener(t, "tcp") + defer func() { + if err := ln.Close(); err != nil { + t.Error(err) + } + }() - for i, tt := range dialTimeoutMaxDurationTests { - ch := make(chan error) - max := time.NewTimer(250 * time.Millisecond) - defer max.Stop() - go func() { + for _, tt := range []struct { + timeout time.Duration + delta time.Duration // for deadline + }{ + // Large timeouts that will overflow an int64 unix nanos. + {1<<63 - 1, 0}, + {0, 1<<63 - 1}, + } { + t.Run(fmt.Sprintf("timeout=%s/delta=%s", tt.timeout, tt.delta), func(t *testing.T) { d := Dialer{Timeout: tt.timeout} if tt.delta != 0 { d.Deadline = time.Now().Add(tt.delta) } c, err := d.Dial(ln.Addr().Network(), ln.Addr().String()) - if err == nil { - c.Close() - } - ch <- err - }() - - select { - case <-max.C: - t.Fatalf("#%d: Dial didn't return in an expected time", i) - case err := <-ch: - if perr := parseDialError(err); perr != nil { - t.Error(perr) - } if err != nil { - t.Errorf("#%d: %v", i, err) + t.Fatal(err) } - } + if err := c.Close(); err != nil { + t.Error(err) + } + }) } } @@ -163,10 +144,7 @@ func TestAcceptTimeout(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() var wg sync.WaitGroup @@ -219,10 +197,7 @@ func TestAcceptTimeoutMustReturn(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() max := time.NewTimer(time.Second) @@ -265,10 +240,7 @@ func TestAcceptTimeoutMustNotReturn(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() max := time.NewTimer(100 * time.Millisecond) @@ -318,10 +290,7 @@ func TestReadTimeout(t *testing.T) { c.Write([]byte("READ TIMEOUT TEST")) defer c.Close() } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -370,10 +339,7 @@ func TestReadTimeoutMustNotReturn(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) @@ -437,10 +403,7 @@ func TestReadFromTimeout(t *testing.T) { c.WriteTo([]byte("READFROM TIMEOUT TEST"), dst) } } - ls, err := newLocalPacketServer("udp") - if err != nil { - t.Fatal(err) - } + ls := newLocalPacketServer(t, "udp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -500,10 +463,7 @@ var writeTimeoutTests = []struct { func TestWriteTimeout(t *testing.T) { t.Parallel() - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() for i, tt := range writeTimeoutTests { @@ -548,10 +508,7 @@ func TestWriteTimeoutMustNotReturn(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) @@ -600,24 +557,10 @@ func TestWriteTimeoutMustNotReturn(t *testing.T) { } } -var writeToTimeoutTests = []struct { - timeout time.Duration - xerrs [2]error // expected errors in transition -}{ - // Tests that write deadlines work, even if there's buffer - // space available to write. - {-5 * time.Second, [2]error{os.ErrDeadlineExceeded, os.ErrDeadlineExceeded}}, - - {10 * time.Millisecond, [2]error{nil, os.ErrDeadlineExceeded}}, -} - func TestWriteToTimeout(t *testing.T) { t.Parallel() - c1, err := newLocalPacketListener("udp") - if err != nil { - t.Fatal(err) - } + c1 := newLocalPacketListener(t, "udp") defer c1.Close() host, _, err := SplitHostPort(c1.LocalAddr().String()) @@ -625,47 +568,116 @@ func TestWriteToTimeout(t *testing.T) { t.Fatal(err) } - for i, tt := range writeToTimeoutTests { - c2, err := ListenPacket(c1.LocalAddr().Network(), JoinHostPort(host, "0")) - if err != nil { - t.Fatal(err) - } - defer c2.Close() + timeouts := []time.Duration{ + -5 * time.Second, + 10 * time.Millisecond, + } - if err := c2.SetWriteDeadline(time.Now().Add(tt.timeout)); err != nil { - t.Fatalf("#%d: %v", i, err) - } - for j, xerr := range tt.xerrs { - for { + for _, timeout := range timeouts { + t.Run(fmt.Sprint(timeout), func(t *testing.T) { + c2, err := ListenPacket(c1.LocalAddr().Network(), JoinHostPort(host, "0")) + if err != nil { + t.Fatal(err) + } + defer c2.Close() + + if err := c2.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + t.Fatalf("SetWriteDeadline: %v", err) + } + backoff := 1 * time.Millisecond + nDeadlineExceeded := 0 + for j := 0; nDeadlineExceeded < 2; j++ { n, err := c2.WriteTo([]byte("WRITETO TIMEOUT TEST"), c1.LocalAddr()) - if xerr != nil { - if perr := parseWriteError(err); perr != nil { - t.Errorf("#%d/%d: %v", i, j, perr) - } - if !isDeadlineExceeded(err) { - t.Fatalf("#%d/%d: %v", i, j, err) - } + t.Logf("#%d: WriteTo: %d, %v", j, n, err) + if err == nil && timeout >= 0 && nDeadlineExceeded == 0 { + // If the timeout is nonnegative, some number of WriteTo calls may + // succeed before the timeout takes effect. + t.Logf("WriteTo succeeded; sleeping %v", timeout/3) + time.Sleep(timeout / 3) + continue } - if err == nil { - time.Sleep(tt.timeout / 3) + if isENOBUFS(err) { + t.Logf("WriteTo: %v", err) + // We're looking for a deadline exceeded error, but if the kernel's + // network buffers are saturated we may see ENOBUFS instead (see + // https://go.dev/issue/49930). Give it some time to unsaturate. + time.Sleep(backoff) + backoff *= 2 continue } + if perr := parseWriteError(err); perr != nil { + t.Errorf("failed to parse error: %v", perr) + } + if !isDeadlineExceeded(err) { + t.Errorf("error is not 'deadline exceeded'") + } if n != 0 { - t.Fatalf("#%d/%d: wrote %d; want 0", i, j, n) + t.Errorf("unexpectedly wrote %d bytes", n) } - break + if !t.Failed() { + t.Logf("WriteTo timed out as expected") + } + nDeadlineExceeded++ } - } + }) } } -func TestReadTimeoutFluctuation(t *testing.T) { - t.Parallel() +const ( + // minDynamicTimeout is the minimum timeout to attempt for + // tests that automatically increase timeouts until success. + // + // Lower values may allow tests to succeed more quickly if the value is close + // to the true minimum, but may require more iterations (and waste more time + // and CPU power on failed attempts) if the timeout is too low. + minDynamicTimeout = 1 * time.Millisecond + + // maxDynamicTimeout is the maximum timeout to attempt for + // tests that automatically increase timeouts until succeess. + // + // This should be a strict upper bound on the latency required to hit a + // timeout accurately, even on a slow or heavily-loaded machine. If a test + // would increase the timeout beyond this value, the test fails. + maxDynamicTimeout = 4 * time.Second +) - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) +// timeoutUpperBound returns the maximum time that we expect a timeout of +// duration d to take to return the caller. +func timeoutUpperBound(d time.Duration) time.Duration { + switch runtime.GOOS { + case "openbsd", "netbsd": + // NetBSD and OpenBSD seem to be unable to reliably hit deadlines even when + // the absolute durations are long. + // In https://build.golang.org/log/c34f8685d020b98377dd4988cd38f0c5bd72267e, + // we observed that an openbsd-amd64-68 builder took 4.090948779s for a + // 2.983020682s timeout (37.1% overhead). + // (See https://go.dev/issue/50189 for further detail.) + // Give them lots of slop to compensate. + return d * 3 / 2 + } + // Other platforms seem to hit their deadlines more reliably, + // at least when they are long enough to cover scheduling jitter. + return d * 11 / 10 +} + +// nextTimeout returns the next timeout to try after an operation took the given +// actual duration with a timeout shorter than that duration. +func nextTimeout(actual time.Duration) (next time.Duration, ok bool) { + if actual >= maxDynamicTimeout { + return maxDynamicTimeout, false + } + // Since the previous attempt took actual, we can't expect to beat that + // duration by any significant margin. Try the next attempt with an arbitrary + // factor above that, so that our growth curve is at least exponential. + next = actual * 5 / 4 + if next > maxDynamicTimeout { + return maxDynamicTimeout, true } + return next, true +} + +func TestReadTimeoutFluctuation(t *testing.T) { + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) @@ -674,31 +686,54 @@ func TestReadTimeoutFluctuation(t *testing.T) { } defer c.Close() - max := time.NewTimer(time.Second) - defer max.Stop() - ch := make(chan error) - go timeoutReceiver(c, 100*time.Millisecond, 50*time.Millisecond, 250*time.Millisecond, ch) + d := minDynamicTimeout + b := make([]byte, 256) + for { + t.Logf("SetReadDeadline(+%v)", d) + t0 := time.Now() + deadline := t0.Add(d) + if err = c.SetReadDeadline(deadline); err != nil { + t.Fatalf("SetReadDeadline(%v): %v", deadline, err) + } + var n int + n, err = c.Read(b) + t1 := time.Now() - select { - case <-max.C: - t.Fatal("Read took over 1s; expected 0.1s") - case err := <-ch: + if n != 0 || err == nil || !err.(Error).Timeout() { + t.Errorf("Read did not return (0, timeout): (%d, %v)", n, err) + } if perr := parseReadError(err); perr != nil { t.Error(perr) } if !isDeadlineExceeded(err) { - t.Fatal(err) + t.Errorf("Read error is not DeadlineExceeded: %v", err) } + + actual := t1.Sub(t0) + if t1.Before(deadline) { + t.Errorf("Read took %s; expected at least %s", actual, d) + } + if t.Failed() { + return + } + if want := timeoutUpperBound(d); actual > want { + next, ok := nextTimeout(actual) + if !ok { + t.Fatalf("Read took %s; expected at most %v", actual, want) + } + // Maybe this machine is too slow to reliably schedule goroutines within + // the requested duration. Increase the timeout and try again. + t.Logf("Read took %s (expected %s); trying with longer timeout", actual, d) + d = next + continue + } + + break } } func TestReadFromTimeoutFluctuation(t *testing.T) { - t.Parallel() - - c1, err := newLocalPacketListener("udp") - if err != nil { - t.Fatal(err) - } + c1 := newLocalPacketListener(t, "udp") defer c1.Close() c2, err := Dial(c1.LocalAddr().Network(), c1.LocalAddr().String()) @@ -707,36 +742,59 @@ func TestReadFromTimeoutFluctuation(t *testing.T) { } defer c2.Close() - max := time.NewTimer(time.Second) - defer max.Stop() - ch := make(chan error) - go timeoutPacketReceiver(c2.(PacketConn), 100*time.Millisecond, 50*time.Millisecond, 250*time.Millisecond, ch) + d := minDynamicTimeout + b := make([]byte, 256) + for { + t.Logf("SetReadDeadline(+%v)", d) + t0 := time.Now() + deadline := t0.Add(d) + if err = c2.SetReadDeadline(deadline); err != nil { + t.Fatalf("SetReadDeadline(%v): %v", deadline, err) + } + var n int + n, _, err = c2.(PacketConn).ReadFrom(b) + t1 := time.Now() - select { - case <-max.C: - t.Fatal("ReadFrom took over 1s; expected 0.1s") - case err := <-ch: + if n != 0 || err == nil || !err.(Error).Timeout() { + t.Errorf("ReadFrom did not return (0, timeout): (%d, %v)", n, err) + } if perr := parseReadError(err); perr != nil { t.Error(perr) } if !isDeadlineExceeded(err) { - t.Fatal(err) + t.Errorf("ReadFrom error is not DeadlineExceeded: %v", err) + } + + actual := t1.Sub(t0) + if t1.Before(deadline) { + t.Errorf("ReadFrom took %s; expected at least %s", actual, d) + } + if t.Failed() { + return + } + if want := timeoutUpperBound(d); actual > want { + next, ok := nextTimeout(actual) + if !ok { + t.Fatalf("ReadFrom took %s; expected at most %s", actual, want) + } + // Maybe this machine is too slow to reliably schedule goroutines within + // the requested duration. Increase the timeout and try again. + t.Logf("ReadFrom took %s (expected %s); trying with longer timeout", actual, d) + d = next + continue } + + break } } func TestWriteTimeoutFluctuation(t *testing.T) { - t.Parallel() - switch runtime.GOOS { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) @@ -745,25 +803,67 @@ func TestWriteTimeoutFluctuation(t *testing.T) { } defer c.Close() - d := time.Second - if iOS() { - d = 3 * time.Second // see golang.org/issue/10775 - } - max := time.NewTimer(d) - defer max.Stop() - ch := make(chan error) - go timeoutTransmitter(c, 100*time.Millisecond, 50*time.Millisecond, 250*time.Millisecond, ch) + d := minDynamicTimeout + for { + t.Logf("SetWriteDeadline(+%v)", d) + t0 := time.Now() + deadline := t0.Add(d) + if err = c.SetWriteDeadline(deadline); err != nil { + t.Fatalf("SetWriteDeadline(%v): %v", deadline, err) + } + var n int64 + for { + var dn int + dn, err = c.Write([]byte("TIMEOUT TRANSMITTER")) + n += int64(dn) + if err != nil { + break + } + } + t1 := time.Now() - select { - case <-max.C: - t.Fatalf("Write took over %v; expected 0.1s", d) - case err := <-ch: + if err == nil || !err.(Error).Timeout() { + t.Fatalf("Write did not return (any, timeout): (%d, %v)", n, err) + } if perr := parseWriteError(err); perr != nil { t.Error(perr) } if !isDeadlineExceeded(err) { - t.Fatal(err) + t.Errorf("Write error is not DeadlineExceeded: %v", err) + } + + actual := t1.Sub(t0) + if t1.Before(deadline) { + t.Errorf("Write took %s; expected at least %s", actual, d) } + if t.Failed() { + return + } + if want := timeoutUpperBound(d); actual > want { + if n > 0 { + // SetWriteDeadline specifies a time โafter which I/O operations fail + // instead of blockingโ. However, the kernel's send buffer is not yet + // full, we may be able to write some arbitrary (but finite) number of + // bytes to it without blocking. + t.Logf("Wrote %d bytes into send buffer; retrying until buffer is full", n) + if d <= maxDynamicTimeout/2 { + // We don't know how long the actual write loop would have taken if + // the buffer were full, so just guess and double the duration so that + // the next attempt can make twice as much progress toward filling it. + d *= 2 + } + } else if next, ok := nextTimeout(actual); !ok { + t.Fatalf("Write took %s; expected at most %s", actual, want) + } else { + // Maybe this machine is too slow to reliably schedule goroutines within + // the requested duration. Increase the timeout and try again. + t.Logf("Write took %s (expected %s); trying with longer timeout", actual, d) + d = next + } + continue + } + + break } } @@ -819,10 +919,7 @@ func testVariousDeadlines(t *testing.T) { c.Close() } } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -860,35 +957,23 @@ func testVariousDeadlines(t *testing.T) { name := fmt.Sprintf("%v %d/%d", timeout, run, numRuns) t.Log(name) - tooSlow := time.NewTimer(5 * time.Second) - defer tooSlow.Stop() - c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) if err != nil { t.Fatal(err) } - ch := make(chan result, 1) - go func() { - t0 := time.Now() - if err := c.SetDeadline(t0.Add(timeout)); err != nil { - t.Error(err) - } - n, err := io.Copy(io.Discard, c) - dt := time.Since(t0) - c.Close() - ch <- result{n, err, dt} - }() + t0 := time.Now() + if err := c.SetDeadline(t0.Add(timeout)); err != nil { + t.Error(err) + } + n, err := io.Copy(io.Discard, c) + dt := time.Since(t0) + c.Close() - select { - case res := <-ch: - if nerr, ok := res.err.(Error); ok && nerr.Timeout() { - t.Logf("%v: good timeout after %v; %d bytes", name, res.d, res.n) - } else { - t.Fatalf("%v: Copy = %d, %v; want timeout", name, res.n, res.err) - } - case <-tooSlow.C: - t.Fatalf("%v: client stuck in Dial+Copy", name) + if nerr, ok := err.(Error); ok && nerr.Timeout() { + t.Logf("%v: good timeout after %v; %d bytes", name, dt, n) + } else { + t.Fatalf("%v: Copy = %d, %v; want timeout", name, n, err) } } } @@ -954,10 +1039,7 @@ func TestReadWriteProlongedTimeout(t *testing.T) { }() wg.Wait() } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } + ls := newLocalServer(t, "tcp") defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -984,10 +1066,7 @@ func TestReadWriteDeadlineRace(t *testing.T) { N = 50 } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() c, err := Dial(ln.Addr().Network(), ln.Addr().String()) @@ -1037,10 +1116,7 @@ func TestReadWriteDeadlineRace(t *testing.T) { // Issue 35367. func TestConcurrentSetDeadline(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() const goroutines = 8 @@ -1049,6 +1125,7 @@ func TestConcurrentSetDeadline(t *testing.T) { var c [conns]Conn for i := 0; i < conns; i++ { + var err error c[i], err = Dial(ln.Addr().Network(), ln.Addr().String()) if err != nil { t.Fatal(err) diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 70f2ce2..6d29a39 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -7,6 +7,7 @@ package net import ( "context" "internal/itoa" + "net/netip" "syscall" ) @@ -26,6 +27,20 @@ type UDPAddr struct { Zone string // IPv6 scoped addressing zone } +// AddrPort returns the UDPAddr a as a netip.AddrPort. +// +// If a.Port does not fit in a uint16, it's silently truncated. +// +// If a is nil, a zero value is returned. +func (a *UDPAddr) AddrPort() netip.AddrPort { + if a == nil { + return netip.AddrPort{} + } + na, _ := netip.AddrFromSlice(a.IP) + na = na.WithZone(a.Zone) + return netip.AddrPortFrom(na, uint16(a.Port)) +} + // Network returns the address's network name, "udp". func (a *UDPAddr) Network() string { return "udp" } @@ -84,6 +99,24 @@ func ResolveUDPAddr(network, address string) (*UDPAddr, error) { return addrs.forResolve(network, address).(*UDPAddr), nil } +// UDPAddrFromAddrPort returns addr as a UDPAddr. If addr.IsValid() is false, +// then the returned UDPAddr will contain a nil IP field, indicating an +// address family-agnostic unspecified address. +func UDPAddrFromAddrPort(addr netip.AddrPort) *UDPAddr { + return &UDPAddr{ + IP: addr.Addr().AsSlice(), + Zone: addr.Addr().Zone(), + Port: int(addr.Port()), + } +} + +// An addrPortUDPAddr is a netip.AddrPort-based UDP address that satisfies the Addr interface. +type addrPortUDPAddr struct { + netip.AddrPort +} + +func (addrPortUDPAddr) Network() string { return "udp" } + // UDPConn is the implementation of the Conn and PacketConn interfaces // for UDP network connections. type UDPConn struct { @@ -130,6 +163,18 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { return n, addr, err } +// ReadFromUDPAddrPort acts like ReadFrom but returns a netip.AddrPort. +func (c *UDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + if !c.ok() { + return 0, netip.AddrPort{}, syscall.EINVAL + } + n, addr, err = c.readFromAddrPort(b) + if err != nil { + err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return n, addr, err +} + // ReadMsgUDP reads a message from c, copying the payload into b and // the associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags @@ -138,8 +183,18 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { // The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be // used to manipulate IP-level socket options in oob. func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { + var ap netip.AddrPort + n, oobn, flags, ap, err = c.ReadMsgUDPAddrPort(b, oob) + if ap.IsValid() { + addr = UDPAddrFromAddrPort(ap) + } + return +} + +// ReadMsgUDPAddrPort is like ReadMsgUDP but returns an netip.AddrPort instead of a UDPAddr. +func (c *UDPConn) ReadMsgUDPAddrPort(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { if !c.ok() { - return 0, 0, 0, nil, syscall.EINVAL + return 0, 0, 0, netip.AddrPort{}, syscall.EINVAL } n, oobn, flags, addr, err = c.readMsg(b, oob) if err != nil { @@ -160,6 +215,18 @@ func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { return n, err } +// WriteToUDPAddrPort acts like WriteTo but takes a netip.AddrPort. +func (c *UDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + n, err := c.writeToAddrPort(b, addr) + if err != nil { + err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err} + } + return n, err +} + // WriteTo implements the PacketConn WriteTo method. func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) { if !c.ok() { @@ -195,6 +262,18 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er return } +// WriteMsgUDPAddrPort is like WriteMsgUDP but takes a netip.AddrPort instead of a UDPAddr. +func (c *UDPConn) WriteMsgUDPAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) { + if !c.ok() { + return 0, 0, syscall.EINVAL + } + n, oobn, err = c.writeMsgAddrPort(b, oob, addr) + if err != nil { + err = &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: addrPortUDPAddr{addr}, Err: err} + } + return +} + func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // DialUDP acts like Dial for UDP networks. diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go index 1df293d..732a3b0 100644 --- a/libgo/go/net/udpsock_plan9.go +++ b/libgo/go/net/udpsock_plan9.go @@ -7,6 +7,7 @@ package net import ( "context" "errors" + "net/netip" "os" "syscall" ) @@ -28,8 +29,27 @@ func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) { return n, addr, nil } -func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { - return 0, 0, 0, nil, syscall.EPLAN9 +func (c *UDPConn) readFromAddrPort(b []byte) (int, netip.AddrPort, error) { + // TODO: optimize. The equivalent code on posix is alloc-free. + buf := make([]byte, udpHeaderSize+len(b)) + m, err := c.fd.Read(buf) + if err != nil { + return 0, netip.AddrPort{}, err + } + if m < udpHeaderSize { + return 0, netip.AddrPort{}, errors.New("short read reading UDP header") + } + buf = buf[:m] + + h, buf := unmarshalUDPHeader(buf) + n := copy(b, buf) + ip, _ := netip.AddrFromSlice(h.raddr) + addr := netip.AddrPortFrom(ip, h.rport) + return n, addr, nil +} + +func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { + return 0, 0, 0, netip.AddrPort{}, syscall.EPLAN9 } func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) { @@ -52,10 +72,18 @@ func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) { return len(b), nil } +func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) { + return c.writeTo(b, UDPAddrFromAddrPort(addr)) // TODO: optimize instead of allocating +} + func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) { return 0, 0, syscall.EPLAN9 } +func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) { + return 0, 0, syscall.EPLAN9 +} + func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) { fd, err := dialPlan9(ctx, sd.network, laddr, raddr) if err != nil { diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go index a4c6da2..a435658 100644 --- a/libgo/go/net/udpsock_posix.go +++ b/libgo/go/net/udpsock_posix.go @@ -3,12 +3,12 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net import ( "context" + "net/netip" "syscall" ) @@ -44,27 +44,68 @@ func (a *UDPAddr) toLocal(net string) sockaddr { } func (c *UDPConn) readFrom(b []byte, addr *UDPAddr) (int, *UDPAddr, error) { - n, sa, err := c.fd.readFrom(b) - switch sa := sa.(type) { - case *syscall.SockaddrInet4: - *addr = UDPAddr{IP: sa.Addr[0:], Port: sa.Port} - case *syscall.SockaddrInet6: - *addr = UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))} - default: + var n int + var err error + switch c.fd.family { + case syscall.AF_INET: + var from syscall.SockaddrInet4 + n, err = c.fd.readFromInet4(b, &from) + if err == nil { + ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 4 bytes + *addr = UDPAddr{IP: ip[:], Port: from.Port} + } + case syscall.AF_INET6: + var from syscall.SockaddrInet6 + n, err = c.fd.readFromInet6(b, &from) + if err == nil { + ip := from.Addr // copy from.Addr; ip escapes, so this line allocates 16 bytes + *addr = UDPAddr{IP: ip[:], Port: from.Port, Zone: zoneCache.name(int(from.ZoneId))} + } + } + if err != nil { // No sockaddr, so don't return UDPAddr. addr = nil } return n, addr, err } -func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { - var sa syscall.Sockaddr - n, oobn, flags, sa, err = c.fd.readMsg(b, oob, 0) - switch sa := sa.(type) { - case *syscall.SockaddrInet4: - addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} - case *syscall.SockaddrInet6: - addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))} +func (c *UDPConn) readFromAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + var ip netip.Addr + var port int + switch c.fd.family { + case syscall.AF_INET: + var from syscall.SockaddrInet4 + n, err = c.fd.readFromInet4(b, &from) + if err == nil { + ip = netip.AddrFrom4(from.Addr) + port = from.Port + } + case syscall.AF_INET6: + var from syscall.SockaddrInet6 + n, err = c.fd.readFromInet6(b, &from) + if err == nil { + ip = netip.AddrFrom16(from.Addr).WithZone(zoneCache.name(int(from.ZoneId))) + port = from.Port + } + } + if err == nil { + addr = netip.AddrPortFrom(ip, uint16(port)) + } + return n, addr, err +} + +func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr netip.AddrPort, err error) { + switch c.fd.family { + case syscall.AF_INET: + var sa syscall.SockaddrInet4 + n, oobn, flags, err = c.fd.readMsgInet4(b, oob, 0, &sa) + ip := netip.AddrFrom4(sa.Addr) + addr = netip.AddrPortFrom(ip, uint16(sa.Port)) + case syscall.AF_INET6: + var sa syscall.SockaddrInet6 + n, oobn, flags, err = c.fd.readMsgInet6(b, oob, 0, &sa) + ip := netip.AddrFrom16(sa.Addr).WithZone(zoneCache.name(int(sa.ZoneId))) + addr = netip.AddrPortFrom(ip, uint16(sa.Port)) } return } @@ -76,11 +117,49 @@ func (c *UDPConn) writeTo(b []byte, addr *UDPAddr) (int, error) { if addr == nil { return 0, errMissingAddress } - sa, err := addr.sockaddr(c.fd.family) - if err != nil { - return 0, err + + switch c.fd.family { + case syscall.AF_INET: + sa, err := ipToSockaddrInet4(addr.IP, addr.Port) + if err != nil { + return 0, err + } + return c.fd.writeToInet4(b, &sa) + case syscall.AF_INET6: + sa, err := ipToSockaddrInet6(addr.IP, addr.Port, addr.Zone) + if err != nil { + return 0, err + } + return c.fd.writeToInet6(b, &sa) + default: + return 0, &AddrError{Err: "invalid address family", Addr: addr.IP.String()} + } +} + +func (c *UDPConn) writeToAddrPort(b []byte, addr netip.AddrPort) (int, error) { + if c.fd.isConnected { + return 0, ErrWriteToConnected + } + if !addr.IsValid() { + return 0, errMissingAddress + } + + switch c.fd.family { + case syscall.AF_INET: + sa, err := addrPortToSockaddrInet4(addr) + if err != nil { + return 0, err + } + return c.fd.writeToInet4(b, &sa) + case syscall.AF_INET6: + sa, err := addrPortToSockaddrInet6(addr) + if err != nil { + return 0, err + } + return c.fd.writeToInet6(b, &sa) + default: + return 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()} } - return c.fd.writeTo(b, sa) } func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) { @@ -97,6 +176,32 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error return c.fd.writeMsg(b, oob, sa) } +func (c *UDPConn) writeMsgAddrPort(b, oob []byte, addr netip.AddrPort) (n, oobn int, err error) { + if c.fd.isConnected && addr.IsValid() { + return 0, 0, ErrWriteToConnected + } + if !c.fd.isConnected && !addr.IsValid() { + return 0, 0, errMissingAddress + } + + switch c.fd.family { + case syscall.AF_INET: + sa, err := addrPortToSockaddrInet4(addr) + if err != nil { + return 0, 0, err + } + return c.fd.writeMsgInet4(b, oob, &sa) + case syscall.AF_INET6: + sa, err := addrPortToSockaddrInet6(addr) + if err != nil { + return 0, 0, err + } + return c.fd.writeMsgInet6(b, oob, &sa) + default: + return 0, 0, &AddrError{Err: "invalid address family", Addr: addr.Addr().String()} + } +} + func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) { fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control) if err != nil { diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go index 0e8c351..21f5af5 100644 --- a/libgo/go/net/udpsock_test.go +++ b/libgo/go/net/udpsock_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -286,10 +285,7 @@ func TestIPv6LinkLocalUnicastUDP(t *testing.T) { t.Log(err) continue } - ls, err := (&packetListener{PacketConn: c1}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&packetListener{PacketConn: c1}).newLocalServer() defer ls.teardown() ch := make(chan error, 1) handler := func(ls *localPacketServer, c PacketConn) { packetTransponder(c, ch) } @@ -334,10 +330,7 @@ func TestUDPZeroBytePayload(t *testing.T) { testenv.SkipFlaky(t, 29225) } - c, err := newLocalPacketListener("udp") - if err != nil { - t.Fatal(err) - } + c := newLocalPacketListener(t, "udp") defer c.Close() for _, genericRead := range []bool{false, true} { @@ -370,10 +363,7 @@ func TestUDPZeroByteBuffer(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - c, err := newLocalPacketListener("udp") - if err != nil { - t.Fatal(err) - } + c := newLocalPacketListener(t, "udp") defer c.Close() b := []byte("UDP ZERO BYTE BUFFER TEST") @@ -407,10 +397,7 @@ func TestUDPReadSizeError(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - c1, err := newLocalPacketListener("udp") - if err != nil { - t.Fatal(err) - } + c1 := newLocalPacketListener(t, "udp") defer c1.Close() c2, err := Dial("udp", c1.LocalAddr().String()) @@ -475,11 +462,100 @@ func TestUDPReadTimeout(t *testing.T) { } } +func TestAllocs(t *testing.T) { + switch runtime.GOOS { + case "plan9": + // Plan9 wasn't optimized. + t.Skipf("skipping on %v", runtime.GOOS) + } + builder := os.Getenv("GO_BUILDER_NAME") + switch builder { + case "linux-amd64-noopt": + // Optimizations are required to remove the allocs. + t.Skipf("skipping on %v", builder) + } + conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + addr := conn.LocalAddr() + addrPort := addr.(*UDPAddr).AddrPort() + buf := make([]byte, 8) + + allocs := testing.AllocsPerRun(1000, func() { + _, _, err := conn.WriteMsgUDPAddrPort(buf, nil, addrPort) + if err != nil { + t.Fatal(err) + } + _, _, _, _, err = conn.ReadMsgUDPAddrPort(buf, nil) + if err != nil { + t.Fatal(err) + } + }) + if got := int(allocs); got != 0 { + t.Errorf("WriteMsgUDPAddrPort/ReadMsgUDPAddrPort allocated %d objects", got) + } + + allocs = testing.AllocsPerRun(1000, func() { + _, err := conn.WriteToUDPAddrPort(buf, addrPort) + if err != nil { + t.Fatal(err) + } + _, _, err = conn.ReadFromUDPAddrPort(buf) + if err != nil { + t.Fatal(err) + } + }) + if got := int(allocs); got != 0 { + t.Errorf("WriteToUDPAddrPort/ReadFromUDPAddrPort allocated %d objects", got) + } + + allocs = testing.AllocsPerRun(1000, func() { + _, err := conn.WriteTo(buf, addr) + if err != nil { + t.Fatal(err) + } + _, _, err = conn.ReadFromUDP(buf) + if err != nil { + t.Fatal(err) + } + }) + if got := int(allocs); got != 1 { + if runtime.Compiler != "gccgo" { + t.Errorf("WriteTo/ReadFromUDP allocated %d objects", got) + } + } +} + +func BenchmarkReadWriteMsgUDPAddrPort(b *testing.B) { + conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + addr := conn.LocalAddr().(*UDPAddr).AddrPort() + buf := make([]byte, 8) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, _, err := conn.WriteMsgUDPAddrPort(buf, nil, addr) + if err != nil { + b.Fatal(err) + } + _, _, _, _, err = conn.ReadMsgUDPAddrPort(buf, nil) + if err != nil { + b.Fatal(err) + } + } +} + func BenchmarkWriteToReadFromUDP(b *testing.B) { conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) if err != nil { b.Fatal(err) } + defer conn.Close() addr := conn.LocalAddr() buf := make([]byte, 8) b.ResetTimer() @@ -495,3 +571,61 @@ func BenchmarkWriteToReadFromUDP(b *testing.B) { } } } + +func BenchmarkWriteToReadFromUDPAddrPort(b *testing.B) { + conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + addr := conn.LocalAddr().(*UDPAddr).AddrPort() + buf := make([]byte, 8) + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + _, err := conn.WriteToUDPAddrPort(buf, addr) + if err != nil { + b.Fatal(err) + } + _, _, err = conn.ReadFromUDPAddrPort(buf) + if err != nil { + b.Fatal(err) + } + } +} + +func TestUDPIPVersionReadMsg(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping on %v", runtime.GOOS) + } + conn, err := ListenUDP("udp4", &UDPAddr{IP: IPv4(127, 0, 0, 1)}) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + daddr := conn.LocalAddr().(*UDPAddr).AddrPort() + buf := make([]byte, 8) + _, err = conn.WriteToUDPAddrPort(buf, daddr) + if err != nil { + t.Fatal(err) + } + _, _, _, saddr, err := conn.ReadMsgUDPAddrPort(buf, nil) + if err != nil { + t.Fatal(err) + } + if !saddr.Addr().Is4() { + t.Error("returned AddrPort is not IPv4") + } + _, err = conn.WriteToUDPAddrPort(buf, daddr) + if err != nil { + t.Fatal(err) + } + _, _, _, soldaddr, err := conn.ReadMsgUDP(buf, nil) + if err != nil { + t.Fatal(err) + } + if len(soldaddr.IP) != 4 { + t.Error("returned UDPAddr is not IPv4") + } +} diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go index af075af..927b533 100644 --- a/libgo/go/net/unixsock_posix.go +++ b/libgo/go/net/unixsock_posix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || hurd || (js && wasm) || linux || netbsd || openbsd || solaris || windows -// +build aix darwin dragonfly freebsd hurd js,wasm linux netbsd openbsd solaris windows package net diff --git a/libgo/go/net/unixsock_readmsg_cloexec.go b/libgo/go/net/unixsock_readmsg_cloexec.go index 716484c..fa4fd7d 100644 --- a/libgo/go/net/unixsock_readmsg_cloexec.go +++ b/libgo/go/net/unixsock_readmsg_cloexec.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || freebsd || solaris -// +build aix darwin freebsd solaris package net diff --git a/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go b/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go index bb851b8..6b0de87 100644 --- a/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go +++ b/libgo/go/net/unixsock_readmsg_cmsg_cloexec.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build dragonfly || linux || netbsd || openbsd -// +build dragonfly linux netbsd openbsd package net diff --git a/libgo/go/net/unixsock_readmsg_other.go b/libgo/go/net/unixsock_readmsg_other.go index 3290761..b3d19fe 100644 --- a/libgo/go/net/unixsock_readmsg_other.go +++ b/libgo/go/net/unixsock_readmsg_other.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build (js && wasm) || windows -// +build js,wasm windows package net diff --git a/libgo/go/net/unixsock_readmsg_test.go b/libgo/go/net/unixsock_readmsg_test.go index a4d2fca..c3bfbf9 100644 --- a/libgo/go/net/unixsock_readmsg_test.go +++ b/libgo/go/net/unixsock_readmsg_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris -// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go index 71092e8..2fc9580 100644 --- a/libgo/go/net/unixsock_test.go +++ b/libgo/go/net/unixsock_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js && !plan9 && !windows -// +build !js,!plan9,!windows package net @@ -26,7 +25,7 @@ func TestReadUnixgramWithUnnamedSocket(t *testing.T) { testenv.SkipFlaky(t, 15157) } - addr := testUnixAddr() + addr := testUnixAddr(t) la, err := ResolveUnixAddr("unixgram", addr) if err != nil { t.Fatal(err) @@ -77,10 +76,7 @@ func TestUnixgramZeroBytePayload(t *testing.T) { t.Skip("unixgram test") } - c1, err := newLocalPacketListener("unixgram") - if err != nil { - t.Fatal(err) - } + c1 := newLocalPacketListener(t, "unixgram") defer os.Remove(c1.LocalAddr().String()) defer c1.Close() @@ -127,10 +123,7 @@ func TestUnixgramZeroByteBuffer(t *testing.T) { // issue 4352: Recvfrom failed with "address family not // supported by protocol family" if zero-length buffer provided - c1, err := newLocalPacketListener("unixgram") - if err != nil { - t.Fatal(err) - } + c1 := newLocalPacketListener(t, "unixgram") defer os.Remove(c1.LocalAddr().String()) defer c1.Close() @@ -175,7 +168,7 @@ func TestUnixgramWrite(t *testing.T) { t.Skip("unixgram test") } - addr := testUnixAddr() + addr := testUnixAddr(t) laddr, err := ResolveUnixAddr("unixgram", addr) if err != nil { t.Fatal(err) @@ -220,7 +213,7 @@ func testUnixgramWriteConn(t *testing.T, raddr *UnixAddr) { } func testUnixgramWritePacketConn(t *testing.T, raddr *UnixAddr) { - addr := testUnixAddr() + addr := testUnixAddr(t) c, err := ListenPacket("unixgram", addr) if err != nil { t.Fatal(err) @@ -249,9 +242,9 @@ func TestUnixConnLocalAndRemoteNames(t *testing.T) { } handler := func(ls *localServer, ln Listener) {} - for _, laddr := range []string{"", testUnixAddr()} { + for _, laddr := range []string{"", testUnixAddr(t)} { laddr := laddr - taddr := testUnixAddr() + taddr := testUnixAddr(t) ta, err := ResolveUnixAddr("unix", taddr) if err != nil { t.Fatal(err) @@ -260,10 +253,7 @@ func TestUnixConnLocalAndRemoteNames(t *testing.T) { if err != nil { t.Fatal(err) } - ls, err := (&streamListener{Listener: ln}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&streamListener{Listener: ln}).newLocalServer() defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) @@ -311,9 +301,9 @@ func TestUnixgramConnLocalAndRemoteNames(t *testing.T) { t.Skip("unixgram test") } - for _, laddr := range []string{"", testUnixAddr()} { + for _, laddr := range []string{"", testUnixAddr(t)} { laddr := laddr - taddr := testUnixAddr() + taddr := testUnixAddr(t) ta, err := ResolveUnixAddr("unixgram", taddr) if err != nil { t.Fatal(err) @@ -369,7 +359,7 @@ func TestUnixUnlink(t *testing.T) { if !testableNetwork("unix") { t.Skip("unix test") } - name := testUnixAddr() + name := testUnixAddr(t) listen := func(t *testing.T) *UnixListener { l, err := Listen("unix", name) diff --git a/libgo/go/net/unixsock_windows_test.go b/libgo/go/net/unixsock_windows_test.go index 29244f6..d541d89 100644 --- a/libgo/go/net/unixsock_windows_test.go +++ b/libgo/go/net/unixsock_windows_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build windows -// +build windows package net @@ -46,9 +45,9 @@ func TestUnixConnLocalWindows(t *testing.T) { } handler := func(ls *localServer, ln Listener) {} - for _, laddr := range []string{"", testUnixAddr()} { + for _, laddr := range []string{"", testUnixAddr(t)} { laddr := laddr - taddr := testUnixAddr() + taddr := testUnixAddr(t) ta, err := ResolveUnixAddr("unix", taddr) if err != nil { t.Fatal(err) @@ -57,10 +56,7 @@ func TestUnixConnLocalWindows(t *testing.T) { if err != nil { t.Fatal(err) } - ls, err := (&streamListener{Listener: ln}).newLocalServer() - if err != nil { - t.Fatal(err) - } + ls := (&streamListener{Listener: ln}).newLocalServer() defer ls.teardown() if err := ls.buildup(handler); err != nil { t.Fatal(err) diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 20de0f6..f31aa08 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -452,20 +452,6 @@ func getScheme(rawURL string) (scheme, path string, err error) { return "", rawURL, nil } -// split slices s into two substrings separated by the first occurrence of -// sep. If cutc is true then sep is excluded from the second substring. -// If sep does not occur in s then s and the empty string is returned. -func split(s string, sep byte, cutc bool) (string, string) { - i := strings.IndexByte(s, sep) - if i < 0 { - return s, "" - } - if cutc { - return s[:i], s[i+1:] - } - return s[:i], s[i:] -} - // Parse parses a raw url into a URL structure. // // The url may be relative (a path, without a host) or absolute @@ -474,7 +460,7 @@ func split(s string, sep byte, cutc bool) (string, string) { // error, due to parsing ambiguities. func Parse(rawURL string) (*URL, error) { // Cut off #frag - u, frag := split(rawURL, '#', true) + u, frag, _ := strings.Cut(rawURL, "#") url, err := parse(u, false) if err != nil { return nil, &Error{"parse", u, err} @@ -534,7 +520,7 @@ func parse(rawURL string, viaRequest bool) (*URL, error) { url.ForceQuery = true rest = rest[:len(rest)-1] } else { - rest, url.RawQuery = split(rest, '?', true) + rest, url.RawQuery, _ = strings.Cut(rest, "?") } if !strings.HasPrefix(rest, "/") { @@ -553,9 +539,7 @@ func parse(rawURL string, viaRequest bool) (*URL, error) { // RFC 3986, ยง3.3: // In addition, a URI reference (Section 4.1) may be a relative-path reference, // in which case the first path segment cannot contain a colon (":") character. - colon := strings.Index(rest, ":") - slash := strings.Index(rest, "/") - if colon >= 0 && (slash < 0 || colon < slash) { + if segment, _, _ := strings.Cut(rest, "/"); strings.Contains(segment, ":") { // First path segment has colon. Not allowed in relative URL. return nil, errors.New("first path segment in URL cannot contain colon") } @@ -563,7 +547,10 @@ func parse(rawURL string, viaRequest bool) (*URL, error) { if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { var authority string - authority, rest = split(rest[2:], '/', false) + authority, rest = rest[2:], "" + if i := strings.Index(authority, "/"); i >= 0 { + authority, rest = authority[:i], authority[i:] + } url.User, url.Host, err = parseAuthority(authority) if err != nil { return nil, err @@ -602,7 +589,7 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) { } user = User(userinfo) } else { - username, password := split(userinfo, ':', true) + username, password, _ := strings.Cut(userinfo, ":") if username, err = unescape(username, encodeUserPassword); err != nil { return nil, "", err } @@ -840,7 +827,7 @@ func (u *URL) String() string { // it would be mistaken for a scheme name. Such a segment must be // preceded by a dot-segment (e.g., "./this:that") to make a relative- // path reference. - if i := strings.IndexByte(path, ':'); i > -1 && strings.IndexByte(path[:i], '/') == -1 { + if segment, _, _ := strings.Cut(path, "/"); strings.Contains(segment, ":") { buf.WriteString("./") } } @@ -933,12 +920,8 @@ func ParseQuery(query string) (Values, error) { func parseQuery(m Values, query string) (err error) { for query != "" { - key := query - if i := strings.IndexAny(key, "&"); i >= 0 { - key, query = key[:i], key[i+1:] - } else { - query = "" - } + var key string + key, query, _ = strings.Cut(query, "&") if strings.Contains(key, ";") { err = fmt.Errorf("invalid semicolon separator in query") continue @@ -946,10 +929,7 @@ func parseQuery(m Values, query string) (err error) { if key == "" { continue } - value := "" - if i := strings.Index(key, "="); i >= 0 { - key, value = key[:i], key[i+1:] - } + key, value, _ := strings.Cut(key, "=") key, err1 := QueryUnescape(key) if err1 != nil { if err == nil { @@ -1013,22 +993,16 @@ func resolvePath(base, ref string) string { } var ( - last string elem string - i int dst strings.Builder ) first := true remaining := full // We want to return a leading '/', so write it now. dst.WriteByte('/') - for i >= 0 { - i = strings.IndexByte(remaining, '/') - if i < 0 { - last, elem, remaining = remaining, remaining, "" - } else { - elem, remaining = remaining[:i], remaining[i+1:] - } + found := true + for found { + elem, remaining, found = strings.Cut(remaining, "/") if elem == "." { first = false // drop @@ -1056,7 +1030,7 @@ func resolvePath(base, ref string) string { } } - if last == "." || last == ".." { + if elem == "." || elem == ".." { dst.WriteByte('/') } @@ -1109,7 +1083,7 @@ func (u *URL) ResolveReference(ref *URL) *URL { url.Path = "" return &url } - if ref.Path == "" && ref.RawQuery == "" { + if ref.Path == "" && !ref.ForceQuery && ref.RawQuery == "" { url.RawQuery = u.RawQuery if ref.Fragment == "" { url.Fragment = u.Fragment diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index 63c8e69..664757b 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -618,7 +618,7 @@ var urltests = []URLTest{ // more useful string for debugging than fmt's struct printer func ufmt(u *URL) string { - var user, pass interface{} + var user, pass any if u.User != nil { user = u.User.Username() if p, ok := u.User.Password(); ok { @@ -1172,7 +1172,7 @@ var resolveReferenceTests = []struct { {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot/"}, // Remove any dot-segments prior to forming the target URI. - // http://tools.ietf.org/html/rfc3986#section-5.2.4 + // https://datatracker.ietf.org/doc/html/rfc3986#section-5.2.4 {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/baz"}, // Triple dot isn't special @@ -1192,7 +1192,7 @@ var resolveReferenceTests = []struct { {"http://foo.com/foo%2dbar/", "./baz-quux", "http://foo.com/foo%2dbar/baz-quux"}, // RFC 3986: Normal Examples - // http://tools.ietf.org/html/rfc3986#section-5.4.1 + // https://datatracker.ietf.org/doc/html/rfc3986#section-5.4.1 {"http://a/b/c/d;p?q", "g:h", "g:h"}, {"http://a/b/c/d;p?q", "g", "http://a/b/c/g"}, {"http://a/b/c/d;p?q", "./g", "http://a/b/c/g"}, @@ -1218,7 +1218,7 @@ var resolveReferenceTests = []struct { {"http://a/b/c/d;p?q", "../../g", "http://a/g"}, // RFC 3986: Abnormal Examples - // http://tools.ietf.org/html/rfc3986#section-5.4.2 + // https://datatracker.ietf.org/doc/html/rfc3986#section-5.4.2 {"http://a/b/c/d;p?q", "../../../g", "http://a/g"}, {"http://a/b/c/d;p?q", "../../../../g", "http://a/g"}, {"http://a/b/c/d;p?q", "/./g", "http://a/g"}, @@ -1244,6 +1244,9 @@ var resolveReferenceTests = []struct { {"https://a/b/c/d;p?q", "//g/d/e/f?y#s", "https://g/d/e/f?y#s"}, {"https://a/b/c/d;p#s", "?y", "https://a/b/c/d;p?y"}, {"https://a/b/c/d;p?q#s", "?y", "https://a/b/c/d;p?y"}, + + // Empty path and query but with ForceQuery (issue 46033). + {"https://a/b/c/d;p?q#s", "?", "https://a/b/c/d;p?"}, } func TestResolveReference(t *testing.T) { @@ -2059,12 +2062,3 @@ func BenchmarkPathUnescape(b *testing.B) { }) } } - -var sink string - -func BenchmarkSplit(b *testing.B) { - url := "http://www.google.com/?q=go+language#foo%26bar" - for i := 0; i < b.N; i++ { - sink, sink = split(url, '#', true) - } -} diff --git a/libgo/go/net/write_unix_test.go b/libgo/go/net/write_unix_test.go index f79f2d0..23e8bef 100644 --- a/libgo/go/net/write_unix_test.go +++ b/libgo/go/net/write_unix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris -// +build darwin dragonfly freebsd linux netbsd openbsd solaris package net diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go index bf40ca2..18795a4 100644 --- a/libgo/go/net/writev_test.go +++ b/libgo/go/net/writev_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js -// +build !js package net @@ -187,10 +186,7 @@ func TestWritevError(t *testing.T) { t.Skipf("skipping the test: windows does not have problem sending large chunks of data") } - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } + ln := newLocalListener(t, "tcp") defer ln.Close() ch := make(chan Conn, 1) diff --git a/libgo/go/net/writev_unix.go b/libgo/go/net/writev_unix.go index a0fedc2..51ab29d 100644 --- a/libgo/go/net/writev_unix.go +++ b/libgo/go/net/writev_unix.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build darwin || dragonfly || freebsd || illumos || linux || netbsd || openbsd -// +build darwin dragonfly freebsd illumos linux netbsd openbsd package net |