diff options
author | Ian Lance Taylor <iant@golang.org> | 2020-12-23 09:57:37 -0800 |
---|---|---|
committer | Ian Lance Taylor <iant@golang.org> | 2020-12-30 15:13:24 -0800 |
commit | cfcbb4227fb20191e04eb8d7766ae6202f526afd (patch) | |
tree | e2effea96f6f204451779f044415c2385e45042b /libgo/go/net | |
parent | 0696141107d61483f38482b941549959a0d7f613 (diff) | |
download | gcc-cfcbb4227fb20191e04eb8d7766ae6202f526afd.zip gcc-cfcbb4227fb20191e04eb8d7766ae6202f526afd.tar.gz gcc-cfcbb4227fb20191e04eb8d7766ae6202f526afd.tar.bz2 |
libgo: update to Go1.16beta1 release
This does not yet include support for the //go:embed directive added
in this release.
* Makefile.am (check-runtime): Don't create check-runtime-dir.
(mostlyclean-local): Don't remove check-runtime-dir.
(check-go-tool, check-vet): Copy in go.mod and modules.txt.
(check-cgo-test, check-carchive-test): Add go.mod file.
* Makefile.in: Regenerate.
Reviewed-on: https://go-review.googlesource.com/c/gofrontend/+/280172
Diffstat (limited to 'libgo/go/net')
101 files changed, 2035 insertions, 761 deletions
diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go index d064d9e..b0f1b79 100644 --- a/libgo/go/net/conf.go +++ b/libgo/go/net/conf.go @@ -69,7 +69,7 @@ func initConfVal() { // Darwin pops up annoying dialog boxes if programs try to do // their own DNS requests. So always use cgo instead, which // avoids that. - if runtime.GOOS == "darwin" { + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { confVal.forceCgoLookupHost = true return } @@ -202,11 +202,6 @@ func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrde // illumos defaults to "nis [NOTFOUND=return] files" return fallbackOrder } - if c.goos == "linux" { - // glibc says the default is "dns [!UNAVAIL=return] files" - // https://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html. - return hostLookupDNSFiles - } return hostLookupFilesDNS } if nss.err != nil { diff --git a/libgo/go/net/conf_test.go b/libgo/go/net/conf_test.go index 081a274..a8e1807 100644 --- a/libgo/go/net/conf_test.go +++ b/libgo/go/net/conf_test.go @@ -7,7 +7,7 @@ package net import ( - "os" + "io/fs" "strings" "testing" ) @@ -26,7 +26,7 @@ var defaultResolvConf = &dnsConfig{ ndots: 1, timeout: 5, attempts: 2, - err: os.ErrNotExist, + err: fs.ErrNotExist, } func TestConfHostLookupOrder(t *testing.T) { @@ -106,7 +106,7 @@ func TestConfHostLookupOrder(t *testing.T) { name: "solaris_no_nsswitch", c: &conf{ goos: "solaris", - nss: &nssConf{err: os.ErrNotExist}, + nss: &nssConf{err: fs.ErrNotExist}, resolv: defaultResolvConf, }, hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}}, @@ -170,16 +170,23 @@ func TestConfHostLookupOrder(t *testing.T) { }, hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}}, }, - // glibc lacking an nsswitch.conf, per - // https://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html { name: "linux_no_nsswitch.conf", c: &conf{ goos: "linux", - nss: &nssConf{err: os.ErrNotExist}, + nss: &nssConf{err: fs.ErrNotExist}, resolv: defaultResolvConf, }, - hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}}, + }, + { + name: "linux_empty_nsswitch.conf", + c: &conf{ + goos: "linux", + nss: nssStr(""), + resolv: defaultResolvConf, + }, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}}, }, { name: "files_mdns_dns", diff --git a/libgo/go/net/conn_test.go b/libgo/go/net/conn_test.go index 6854898..771cabc 100644 --- a/libgo/go/net/conn_test.go +++ b/libgo/go/net/conn_test.go @@ -32,7 +32,7 @@ func TestConnAndListener(t *testing.T) { } defer ls.teardown() ch := make(chan error, 1) - handler := func(ls *localServer, ln Listener) { transponder(ln, ch) } + handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } if err := ls.buildup(handler); err != nil { t.Fatal(err) } diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index 0158248..57cf555 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -160,7 +160,7 @@ func dialClosedPort(t *testing.T) (actual, expected time.Duration) { // but other platforms should be instantaneous. if runtime.GOOS == "windows" { expected = 1500 * time.Millisecond - } else if runtime.GOOS == "darwin" { + } else if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { expected = 150 * time.Millisecond } else { expected = 95 * time.Millisecond @@ -990,7 +990,7 @@ func TestDialerControl(t *testing.T) { // except that it won't skip testing on non-mobile builders. func mustHaveExternalNetwork(t *testing.T) { t.Helper() - mobile := runtime.GOOS == "android" || runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" + mobile := runtime.GOOS == "android" || runtime.GOOS == "ios" if testenv.Builder() == "" || mobile { testenv.MustHaveExternalNetwork(t) } diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go index b5bb3a4..e9c7384 100644 --- a/libgo/go/net/dnsclient.go +++ b/libgo/go/net/dnsclient.go @@ -5,12 +5,25 @@ package net import ( - "math/rand" "sort" "golang.org/x/net/dns/dnsmessage" ) +// provided by runtime +func fastrand() uint32 + +func randInt() int { + x, y := fastrand(), fastrand() // 32-bit halves + u := uint(x)<<31 ^ uint(int32(y)) // full uint, even on 64-bit systems; avoid 32-bit shift on 32-bit systems + i := int(u >> 1) // clear sign bit, even on 32-bit systems + return i +} + +func randIntn(n int) int { + return randInt() % n +} + // reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP // address addr suitable for rDNS (PTR) record lookup or an error if it fails // to parse the IP address. @@ -162,7 +175,7 @@ func (addrs byPriorityWeight) shuffleByWeight() { } for sum > 0 && len(addrs) > 1 { s := 0 - n := rand.Intn(sum) + n := randIntn(sum) for i := range addrs { s += int(addrs[i].Weight) if s > n { @@ -206,7 +219,7 @@ func (s byPref) Swap(i, j int) { s[i], s[j] = s[j], s[i] } // sort reorders MX records as specified in RFC 5321. func (s byPref) sort() { for i := range s { - j := rand.Intn(i + 1) + j := randIntn(i + 1) s[i], s[j] = s[j], s[i] } sort.Sort(s) diff --git a/libgo/go/net/dnsclient_test.go b/libgo/go/net/dnsclient_test.go index f3ed62d..24cd69e 100644 --- a/libgo/go/net/dnsclient_test.go +++ b/libgo/go/net/dnsclient_test.go @@ -5,7 +5,6 @@ package net import ( - "math/rand" "testing" ) @@ -17,7 +16,7 @@ func checkDistribution(t *testing.T, data []*SRV, margin float64) { results := make(map[string]int) - count := 1000 + count := 10000 for j := 0; j < count; j++ { d := make([]*SRV, len(data)) copy(d, data) @@ -39,7 +38,6 @@ func checkDistribution(t *testing.T, data []*SRV, margin float64) { } func testUniformity(t *testing.T, size int, margin float64) { - rand.Seed(1) data := make([]*SRV, size) for i := 0; i < size; i++ { data[i] = &SRV{Target: string('a' + rune(i)), Weight: 1} @@ -55,7 +53,6 @@ func TestDNSSRVUniformity(t *testing.T) { } func testWeighting(t *testing.T, margin float64) { - rand.Seed(1) data := []*SRV{ {Target: "a", Weight: 60}, {Target: "b", Weight: 30}, diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 5f6c870..c5bfab9 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -18,7 +18,6 @@ import ( "context" "errors" "io" - "math/rand" "os" "sync" "time" @@ -47,7 +46,7 @@ var ( ) func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { - id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) + id = uint16(randInt()) b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) b.EnableCompression() if err := b.StartQuestions(); err != nil { diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go index a89dccc..f57f231 100644 --- a/libgo/go/net/dnsclient_unix_test.go +++ b/libgo/go/net/dnsclient_unix_test.go @@ -10,7 +10,6 @@ import ( "context" "errors" "fmt" - "io/ioutil" "os" "path" "reflect" @@ -235,7 +234,7 @@ type resolvConfTest struct { } func newResolvConfTest() (*resolvConfTest, error) { - dir, err := ioutil.TempDir("", "go-resolvconftest") + dir, err := os.MkdirTemp("", "go-resolvconftest") if err != nil { return nil, err } diff --git a/libgo/go/net/dnsconfig_unix_test.go b/libgo/go/net/dnsconfig_unix_test.go index 2fca329..f6edffc 100644 --- a/libgo/go/net/dnsconfig_unix_test.go +++ b/libgo/go/net/dnsconfig_unix_test.go @@ -8,6 +8,7 @@ package net import ( "errors" + "io/fs" "os" "reflect" "strings" @@ -183,7 +184,7 @@ func TestDNSReadMissingFile(t *testing.T) { conf := dnsReadConfig("a-nonexistent-file") if !os.IsNotExist(conf.err) { - t.Errorf("missing resolv.conf:\ngot: %v\nwant: %v", conf.err, os.ErrNotExist) + t.Errorf("missing resolv.conf:\ngot: %v\nwant: %v", conf.err, fs.ErrNotExist) } conf.err = nil want := &dnsConfig{ diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go index 8d4a7ff..556eb8c 100644 --- a/libgo/go/net/error_test.go +++ b/libgo/go/net/error_test.go @@ -8,10 +8,11 @@ package net import ( "context" + "errors" "fmt" "internal/poll" "io" - "io/ioutil" + "io/fs" "net/internal/socktest" "os" "runtime" @@ -96,12 +97,12 @@ second: case *os.SyscallError: nestedErr = err.Err goto third - case *os.PathError: // for Plan 9 + case *fs.PathError: // for Plan 9 nestedErr = err.Err goto third } switch nestedErr { - case errCanceled, poll.ErrNetClosing, errMissingAddress, errNoSuitableAddress, + case errCanceled, ErrClosed, errMissingAddress, errNoSuitableAddress, context.DeadlineExceeded, context.Canceled: return nil } @@ -436,7 +437,7 @@ second: goto third } switch nestedErr { - case poll.ErrNetClosing, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded: + case ErrClosed, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -478,7 +479,7 @@ second: goto third } switch nestedErr { - case errCanceled, poll.ErrNetClosing, errMissingAddress, errTimeout, os.ErrDeadlineExceeded, ErrWriteToConnected, io.ErrUnexpectedEOF: + case errCanceled, ErrClosed, errMissingAddress, errTimeout, os.ErrDeadlineExceeded, ErrWriteToConnected, io.ErrUnexpectedEOF: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -508,6 +509,10 @@ func parseCloseError(nestedErr error, isShutdown bool) error { return fmt.Errorf("error string %q does not contain expected string %q", nestedErr, want) } + if !isShutdown && !errors.Is(nestedErr, ErrClosed) { + return fmt.Errorf("errors.Is(%v, errClosed) returns false, want true", nestedErr) + } + switch err := nestedErr.(type) { case *OpError: if err := err.isValid(); err != nil { @@ -526,12 +531,12 @@ second: case *os.SyscallError: nestedErr = err.Err goto third - case *os.PathError: // for Plan 9 + case *fs.PathError: // for Plan 9 nestedErr = err.Err goto third } switch nestedErr { - case poll.ErrNetClosing: + case ErrClosed: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -541,7 +546,7 @@ third: return nil } switch nestedErr { - case os.ErrClosed: // for Plan 9 + case fs.ErrClosed: // for Plan 9 return nil } return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr) @@ -622,12 +627,12 @@ second: case *os.SyscallError: nestedErr = err.Err goto third - case *os.PathError: // for Plan 9 + case *fs.PathError: // for Plan 9 nestedErr = err.Err goto third } switch nestedErr { - case poll.ErrNetClosing, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded: + case ErrClosed, errTimeout, poll.ErrNotPollable, os.ErrDeadlineExceeded: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -701,12 +706,12 @@ second: case *os.LinkError: nestedErr = err.Err goto third - case *os.PathError: + case *fs.PathError: nestedErr = err.Err goto third } switch nestedErr { - case poll.ErrNetClosing: + case ErrClosed: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -724,7 +729,7 @@ func TestFileError(t *testing.T) { t.Skipf("not supported on %s", runtime.GOOS) } - f, err := ioutil.TempFile("", "go-nettest") + f, err := os.CreateTemp("", "go-nettest") if err != nil { t.Fatal(err) } @@ -794,7 +799,7 @@ func parseLookupPortError(nestedErr error) error { switch nestedErr.(type) { case *AddrError, *DNSError: return nil - case *os.PathError: // for Plan 9 + case *fs.PathError: // for Plan 9 return nil } return fmt.Errorf("unexpected type on 1st nested level: %T", nestedErr) diff --git a/libgo/go/net/example_test.go b/libgo/go/net/example_test.go index ef8c38f..72c7183 100644 --- a/libgo/go/net/example_test.go +++ b/libgo/go/net/example_test.go @@ -55,6 +55,27 @@ func ExampleDialer() { } } +func ExampleDialer_unix() { + // DialUnix does not take a context.Context parameter. This example shows + // how to dial a Unix socket with a Context. Note that the Context only + // applies to the dial operation; it does not apply to the connection once + // it has been established. + var d net.Dialer + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + d.LocalAddr = nil // if you have a local addr, add it here + raddr := net.UnixAddr{Name: "/path/to/unix.sock", Net: "unix"} + conn, err := d.DialContext(ctx, "unix", raddr.String()) + if err != nil { + log.Fatalf("Failed to dial: %v", err) + } + defer conn.Close() + if _, err := conn.Write([]byte("Hello, socket!")); err != nil { + log.Fatal(err) + } +} + func ExampleIPv4() { fmt.Println(net.IPv4(8, 8, 8, 8)) diff --git a/libgo/go/net/http/alpn_test.go b/libgo/go/net/http/alpn_test.go index 618bdbe..a51038c 100644 --- a/libgo/go/net/http/alpn_test.go +++ b/libgo/go/net/http/alpn_test.go @@ -11,7 +11,6 @@ import ( "crypto/x509" "fmt" "io" - "io/ioutil" . "net/http" "net/http/httptest" "strings" @@ -49,7 +48,7 @@ func TestNextProtoUpgrade(t *testing.T) { if err != nil { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -93,7 +92,7 @@ func TestNextProtoUpgrade(t *testing.T) { t.Fatal(err) } conn.Write([]byte("GET /foo\n")) - body, err := ioutil.ReadAll(conn) + body, err := io.ReadAll(conn) if err != nil { t.Fatal(err) } diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go index 61de616..0114da3 100644 --- a/libgo/go/net/http/cgi/child.go +++ b/libgo/go/net/http/cgi/child.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -32,7 +31,7 @@ func Request() (*http.Request, error) { return nil, err } if r.ContentLength > 0 { - r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + r.Body = io.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) } return r, nil } @@ -146,6 +145,9 @@ func Serve(handler http.Handler) error { if err != nil { return err } + if req.Body == nil { + req.Body = http.NoBody + } if handler == nil { handler = http.DefaultServeMux } diff --git a/libgo/go/net/http/cgi/child_test.go b/libgo/go/net/http/cgi/child_test.go index f6ecb6e..18cf789 100644 --- a/libgo/go/net/http/cgi/child_test.go +++ b/libgo/go/net/http/cgi/child_test.go @@ -154,17 +154,6 @@ func TestRequestWithoutRemotePort(t *testing.T) { } } -type countingWriter int - -func (c *countingWriter) Write(p []byte) (int, error) { - *c += countingWriter(len(p)) - return len(p), nil -} -func (c *countingWriter) WriteString(p string) (int, error) { - *c += countingWriter(len(p)) - return len(p), nil -} - func TestResponse(t *testing.T) { var tests = []struct { name string diff --git a/libgo/go/net/http/cgi/integration_test.go b/libgo/go/net/http/cgi/integration_test.go index 295c3b8..76cbca8 100644 --- a/libgo/go/net/http/cgi/integration_test.go +++ b/libgo/go/net/http/cgi/integration_test.go @@ -154,6 +154,23 @@ func TestChildOnlyHeaders(t *testing.T) { } } +// Test that a child handler does not receive a nil Request Body. +// golang.org/issue/39190 +func TestNilRequestBody(t *testing.T) { + testenv.MustHaveExec(t) + + h := &Handler{ + Path: os.Args[0], + Root: "/test.go", + Args: []string{"-test.run=TestBeChildCGIProcess"}, + } + expectedMap := map[string]string{ + "nil-request-body": "false", + } + _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\n\n", expectedMap) + _ = runCgiTest(t, h, "POST /test.go?nil-request-body=1 HTTP/1.0\nHost: example.com\nContent-Length: 0\n\n", expectedMap) +} + func TestChildContentType(t *testing.T) { testenv.MustHaveExec(t) @@ -245,6 +262,10 @@ func TestBeChildCGIProcess(t *testing.T) { os.Exit(0) } Serve(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.FormValue("nil-request-body") == "1" { + fmt.Fprintf(rw, "nil-request-body=%v\n", req.Body == nil) + return + } rw.Header().Set("X-Test-Header", "X-Test-Value") req.ParseForm() if req.FormValue("no-body") == "1" { diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 3860d97..88e2028 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -16,7 +16,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net/url" "reflect" @@ -282,7 +281,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d if resp.ContentLength > 0 && req.Method != "HEAD" { return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength) } - resp.Body = ioutil.NopCloser(strings.NewReader("")) + resp.Body = io.NopCloser(strings.NewReader("")) } if !deadline.IsZero() { resp.Body = &cancelTimerBody{ @@ -321,7 +320,7 @@ func knownRoundTripperImpl(rt RoundTripper, req *Request) bool { return true } // There's a very minor chance of a false positive with this. - // Insted of detecting our golang.org/x/net/http2.Transport, + // Instead of detecting our golang.org/x/net/http2.Transport, // it might detect a Transport type in a different http2 // package. But I know of none, and the only problem would be // some temporarily leaked goroutines if the transport didn't @@ -697,7 +696,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // fails, the Transport won't reuse it anyway. const maxBodySlurpSize = 2 << 10 if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { - io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) + io.CopyN(io.Discard, resp.Body, maxBodySlurpSize) } resp.Body.Close() diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 80807fa..d90b484 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" . "net/http" @@ -35,7 +34,7 @@ var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") }) -// pedanticReadAll works like ioutil.ReadAll but additionally +// pedanticReadAll works like io.ReadAll but additionally // verifies that r obeys the documented io.Reader contract. func pedanticReadAll(r io.Reader) (b []byte, err error) { var bufa [64]byte @@ -190,7 +189,7 @@ func TestPostFormRequestFormat(t *testing.T) { if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { t.Errorf("got ContentLength %d, want %d", g, e) } - bodyb, err := ioutil.ReadAll(tr.req.Body) + bodyb, err := io.ReadAll(tr.req.Body) if err != nil { t.Fatalf("ReadAll on req.Body: %v", err) } @@ -421,7 +420,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() - slurp, _ := ioutil.ReadAll(r.Body) + slurp, _ := io.ReadAll(r.Body) fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp) if cl := r.Header.Get("Content-Length"); r.Method == "GET" && len(slurp) == 0 && (r.ContentLength != 0 || cl != "") { fmt.Fprintf(&log.Buffer, " (but with body=%T, content-length = %v, %q)", r.Body, r.ContentLength, cl) @@ -452,7 +451,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa for _, tt := range table { content := tt.redirectBody req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) - req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil } + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(content)), nil } res, err := c.Do(req) if err != nil { @@ -522,7 +521,7 @@ func TestClientRedirectUseResponse(t *testing.T) { t.Errorf("status = %d; want %d", res.StatusCode, StatusFound) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1042,7 +1041,7 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { if res.ContentLength != tt.want { t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) } - bs, err := ioutil.ReadAll(res.Body) + bs, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1257,7 +1256,7 @@ func testClientTimeout(t *testing.T, h2 bool) { errc := make(chan error, 1) go func() { - _, err := ioutil.ReadAll(res.Body) + _, err := io.ReadAll(res.Body) errc <- err res.Body.Close() }() @@ -1348,7 +1347,7 @@ func TestClientTimeoutCancel(t *testing.T) { t.Fatal(err) } cancel() - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) if err != ExportErrRequestCanceled { t.Fatalf("error = %v; want errRequestCanceled", err) } @@ -1372,7 +1371,7 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) @@ -1450,7 +1449,7 @@ func (issue15577Tripper) RoundTrip(*Request) (*Response, error) { resp := &Response{ StatusCode: 303, Header: map[string][]string{"Location": {"http://www.example.com/"}}, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), } return resp, nil } @@ -1591,7 +1590,7 @@ func TestClientCopyHostOnRedirect(t *testing.T) { if resp.StatusCode != 200 { t.Fatal(resp.Status) } - if got, err := ioutil.ReadAll(resp.Body); err != nil || string(got) != wantBody { + if got, err := io.ReadAll(resp.Body); err != nil || string(got) != wantBody { t.Errorf("body = %q; want %q", got, wantBody) } } @@ -2020,9 +2019,66 @@ func TestClientPopulatesNilResponseBody(t *testing.T) { } }() - if b, err := ioutil.ReadAll(resp.Body); err != nil { + if b, err := io.ReadAll(resp.Body); err != nil { t.Errorf("read error from substitute Response.Body: %v", err) } else if len(b) != 0 { t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b) } } + +// Issue 40382: Client calls Close multiple times on Request.Body. +func TestClientCallsCloseOnlyOnce(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })) + defer cst.close() + + // Issue occurred non-deterministically: needed to occur after a successful + // write (into TCP buffer) but before end of body. + for i := 0; i < 50 && !t.Failed(); i++ { + body := &issue40382Body{t: t, n: 300000} + req, err := NewRequest(MethodPost, cst.ts.URL, body) + if err != nil { + t.Fatal(err) + } + resp, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + } +} + +// issue40382Body is an io.ReadCloser for TestClientCallsCloseOnlyOnce. +// Its Read reads n bytes before returning io.EOF. +// Its Close returns nil but fails the test if called more than once. +type issue40382Body struct { + t *testing.T + n int + closeCallsAtomic int32 +} + +func (b *issue40382Body) Read(p []byte) (int, error) { + switch { + case b.n == 0: + return 0, io.EOF + case b.n < len(p): + p = p[:b.n] + fallthrough + default: + for i := range p { + p[i] = 'x' + } + b.n -= len(p) + return len(p), nil + } +} + +func (b *issue40382Body) Close() error { + if atomic.AddInt32(&b.closeCallsAtomic, 1) == 2 { + b.t.Error("Body closed more than once") + } + return nil +} diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 70bcd0e..42207ac 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -15,7 +15,6 @@ import ( "fmt" "hash" "io" - "io/ioutil" "log" "net" . "net/http" @@ -53,7 +52,7 @@ func (t *clientServerTest) getURL(u string) string { t.t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.t.Fatal(err) } @@ -152,7 +151,7 @@ func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, func testChunkedResponseHeaders(t *testing.T, h2 bool) { defer afterTest(t) - log.SetOutput(ioutil.Discard) // is noisy otherwise + log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted @@ -266,11 +265,11 @@ func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) } else { t.Errorf("got %q response; want %q", res.Proto, wantProto) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() res.Body = slurpResult{ - ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)), + ReadCloser: io.NopCloser(bytes.NewReader(slurp)), body: slurp, err: err, } @@ -477,7 +476,7 @@ func test304Responses(t *testing.T, h2 bool) { if len(res.TransferEncoding) > 0 { t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Error(err) } @@ -564,7 +563,7 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { close(cancel) - rest, err := ioutil.ReadAll(res.Body) + rest, err := io.ReadAll(res.Body) all := string(firstRead) + string(rest) if all != "Hello" { t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) @@ -587,7 +586,7 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } sort.Strings(decl) - slurp, err := ioutil.ReadAll(r.Body) + slurp, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Server reading request body: %v", err) } @@ -721,7 +720,7 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { t.Fatal(err) } res.Body.Close() - data, err := ioutil.ReadAll(res.Body) + data, err := io.ReadAll(res.Body) if len(data) != 0 || err == nil { t.Fatalf("ReadAll returned %q, %v; want error", data, err) } @@ -740,7 +739,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Read in one goroutine. go func() { defer wg.Done() - data, err := ioutil.ReadAll(r.Body) + data, err := io.ReadAll(r.Body) if string(data) != reqBody { t.Errorf("Handler read %q; want %q", data, reqBody) } @@ -770,7 +769,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - data, err := ioutil.ReadAll(res.Body) + data, err := io.ReadAll(res.Body) defer res.Body.Close() if err != nil { t.Fatal(err) @@ -887,7 +886,7 @@ func testTransportUserAgent(t *testing.T, h2 bool) { t.Errorf("%d. RoundTrip = %v", i, err) continue } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("%d. read body = %v", i, err) @@ -1009,11 +1008,17 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { defer wg.Done() resp, err := c.Get(cst.ts.URL) if err != nil { - t.Errorf("Get: %v", err) - return + // Try to work around spurious connection reset on loaded system. + // See golang.org/issue/33585 and golang.org/issue/36797. + time.Sleep(10 * time.Millisecond) + resp, err = c.Get(cst.ts.URL) + if err != nil { + t.Errorf("Get: %v", err) + return + } } defer resp.Body.Close() - slurp, err := ioutil.ReadAll(resp.Body) + slurp, err := io.ReadAll(resp.Body) if err != nil { t.Error(err) } @@ -1062,7 +1067,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } @@ -1078,7 +1083,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { if err != nil { t.Fatal(err) } - if _, err := ioutil.ReadAll(res.Body); err != nil { + if _, err := io.ReadAll(res.Body); err != nil { t.Fatal(err) } if err := res.Body.Close(); err != nil { @@ -1139,7 +1144,7 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { res, err := cst.c.Do(req) var body []byte if err == nil { - body, _ = ioutil.ReadAll(res.Body) + body, _ = io.ReadAll(res.Body) res.Body.Close() } var dialed bool @@ -1196,7 +1201,7 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { } gotHeaders <- true defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if string(slurp) != msg { t.Errorf("client read %q; want %q", slurp, msg) } @@ -1361,7 +1366,7 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } res.Body.Close() @@ -1379,7 +1384,7 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { func TestBadResponseAfterReadingBody(t *testing.T) { defer afterTest(t) cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { - _, err := io.Copy(ioutil.Discard, r.Body) + _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) } @@ -1472,7 +1477,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { t.Fatal(err) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index d7a8f5e..141bc94 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -220,7 +220,7 @@ func (c *Cookie) String() string { } switch c.SameSite { case SameSiteDefaultMode: - b.WriteString("; SameSite") + // Skip, default mode is obtained by not emitting the attribute. case SameSiteNoneMode: b.WriteString("; SameSite=None") case SameSiteLaxMode: diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index 9e8196e..959713a 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -67,7 +67,7 @@ var writeSetCookiesTests = []struct { }, { &Cookie{Name: "cookie-12", Value: "samesite-default", SameSite: SameSiteDefaultMode}, - "cookie-12=samesite-default; SameSite", + "cookie-12=samesite-default", }, { &Cookie{Name: "cookie-13", Value: "samesite-lax", SameSite: SameSiteLaxMode}, @@ -283,6 +283,15 @@ var readSetCookiesTests = []struct { }}, }, { + Header{"Set-Cookie": {"samesiteinvalidisdefault=foo; SameSite=invalid"}}, + []*Cookie{{ + Name: "samesiteinvalidisdefault", + Value: "foo", + SameSite: SameSiteDefaultMode, + Raw: "samesiteinvalidisdefault=foo; SameSite=invalid", + }}, + }, + { Header{"Set-Cookie": {"samesitelax=foo; SameSite=Lax"}}, []*Cookie{{ Name: "samesitelax", diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go index 7855fea..ae9b708 100644 --- a/libgo/go/net/http/doc.go +++ b/libgo/go/net/http/doc.go @@ -21,7 +21,7 @@ The client must close the response body when finished with it: // handle error } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) // ... For control over HTTP client headers, redirect policy, and other diff --git a/libgo/go/net/http/example_filesystem_test.go b/libgo/go/net/http/example_filesystem_test.go index e1fd42d..0e81458 100644 --- a/libgo/go/net/http/example_filesystem_test.go +++ b/libgo/go/net/http/example_filesystem_test.go @@ -5,9 +5,9 @@ package http_test import ( + "io/fs" "log" "net/http" - "os" "strings" ) @@ -33,7 +33,7 @@ type dotFileHidingFile struct { // Readdir is a wrapper around the Readdir method of the embedded File // that filters out all files that start with a period in their name. -func (f dotFileHidingFile) Readdir(n int) (fis []os.FileInfo, err error) { +func (f dotFileHidingFile) Readdir(n int) (fis []fs.FileInfo, err error) { files, err := f.File.Readdir(n) for _, file := range files { // Filters out the dot files if !strings.HasPrefix(file.Name(), ".") { @@ -52,12 +52,12 @@ type dotFileHidingFileSystem struct { // Open is a wrapper around the Open method of the embedded FileSystem // that serves a 403 permission error when name has a file or directory // with whose name starts with a period in its path. -func (fs dotFileHidingFileSystem) Open(name string) (http.File, error) { +func (fsys dotFileHidingFileSystem) Open(name string) (http.File, error) { if containsDotFile(name) { // If dot file, return 403 response - return nil, os.ErrPermission + return nil, fs.ErrPermission } - file, err := fs.FileSystem.Open(name) + file, err := fsys.FileSystem.Open(name) if err != nil { return nil, err } @@ -65,7 +65,7 @@ func (fs dotFileHidingFileSystem) Open(name string) (http.File, error) { } func ExampleFileServer_dotFileHiding() { - fs := dotFileHidingFileSystem{http.Dir(".")} - http.Handle("/", http.FileServer(fs)) + fsys := dotFileHidingFileSystem{http.Dir(".")} + http.Handle("/", http.FileServer(fsys)) log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go index a783b46..c677d52 100644 --- a/libgo/go/net/http/example_test.go +++ b/libgo/go/net/http/example_test.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "io" - "io/ioutil" "log" "net/http" "os" @@ -46,7 +45,7 @@ func ExampleGet() { if err != nil { log.Fatal(err) } - robots, err := ioutil.ReadAll(res.Body) + robots, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 657ff9d..096a6d3 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -254,7 +254,7 @@ func hookSetter(dst *func()) func(func()) { } func ExportHttp2ConfigureTransport(t *Transport) error { - t2, err := http2configureTransport(t) + t2, err := http2configureTransports(t) if err != nil { return err } @@ -274,6 +274,17 @@ func (s *Server) ExportAllConnsIdle() bool { return true } +func (s *Server) ExportAllConnsByState() map[ConnState]int { + states := map[ConnState]int{} + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.activeConn { + st, _ := c.getState() + states[st] += 1 + } + return states +} + func (r *Request) WithT(t *testing.T) *Request { return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) } diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go index a31273b..e97b844 100644 --- a/libgo/go/net/http/fcgi/child.go +++ b/libgo/go/net/http/fcgi/child.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/cgi" @@ -172,9 +171,12 @@ func (c *child) serve() { defer c.cleanUp() var rec record for { + c.conn.mutex.Lock() if err := rec.read(c.conn.rwc); err != nil { + c.conn.mutex.Unlock() return } + c.conn.mutex.Unlock() if err := c.handleRecord(&rec); err != nil { return } @@ -183,7 +185,7 @@ func (c *child) serve() { var errCloseConn = errors.New("fcgi: connection should be closed") -var emptyBody = ioutil.NopCloser(strings.NewReader("")) +var emptyBody = io.NopCloser(strings.NewReader("")) // ErrRequestAborted is returned by Read when a handler attempts to read the // body of a request that has been aborted by the web server. @@ -322,7 +324,7 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { // some sort of abort request to the host, so the host // can properly cut off the client sending all the data. // For now just bound it a little and - io.CopyN(ioutil.Discard, body, 100<<20) + io.CopyN(io.Discard, body, 100<<20) body.Close() if !req.keepConn { diff --git a/libgo/go/net/http/fcgi/fcgi_test.go b/libgo/go/net/http/fcgi/fcgi_test.go index 59246c2..d3b704f 100644 --- a/libgo/go/net/http/fcgi/fcgi_test.go +++ b/libgo/go/net/http/fcgi/fcgi_test.go @@ -8,7 +8,6 @@ import ( "bytes" "errors" "io" - "io/ioutil" "net/http" "strings" "testing" @@ -243,7 +242,7 @@ func TestChildServeCleansUp(t *testing.T) { r *http.Request, ) { // block on reading body of request - _, err := io.Copy(ioutil.Discard, r.Body) + _, err := io.Copy(io.Discard, r.Body) if err != tt.err { t.Errorf("Expected %#v, got %#v", tt.err, err) } @@ -275,7 +274,7 @@ func TestMalformedParams(t *testing.T) { // end of params 1, 4, 0, 1, 0, 0, 0, 0, } - rw := rwNopCloser{bytes.NewReader(input), ioutil.Discard} + rw := rwNopCloser{bytes.NewReader(input), io.Discard} c := newChild(rw, http.DefaultServeMux) c.serve() } @@ -347,7 +346,6 @@ func TestChildServeReadsEnvVars(t *testing.T) { } func TestResponseWriterSniffsContentType(t *testing.T) { - t.Skip("this test is flaky, see Issue 41167") var tests = []struct { name string body string diff --git a/libgo/go/net/http/filetransport_test.go b/libgo/go/net/http/filetransport_test.go index 2a2f32c..b58888d 100644 --- a/libgo/go/net/http/filetransport_test.go +++ b/libgo/go/net/http/filetransport_test.go @@ -5,7 +5,7 @@ package http import ( - "io/ioutil" + "io" "os" "path/filepath" "testing" @@ -23,10 +23,10 @@ func checker(t *testing.T) func(string, error) { func TestFileTransport(t *testing.T) { check := checker(t) - dname, err := ioutil.TempDir("", "") + dname, err := os.MkdirTemp("", "") check("TempDir", err) fname := filepath.Join(dname, "foo.txt") - err = ioutil.WriteFile(fname, []byte("Bar"), 0644) + err = os.WriteFile(fname, []byte("Bar"), 0644) check("WriteFile", err) defer os.Remove(dname) defer os.Remove(fname) @@ -48,7 +48,7 @@ func TestFileTransport(t *testing.T) { if res.Body == nil { t.Fatalf("for %s, nil Body", urlstr) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() check("ReadAll "+urlstr, err) if string(slurp) != "Bar" { diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index 922706a..a28ae85 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "io/fs" "mime" "mime/multipart" "net/textproto" @@ -43,7 +44,7 @@ type Dir string // mapDirOpenError maps the provided non-nil error from opening name // to a possibly better non-nil error. In particular, it turns OS-specific errors -// about opening files in non-directories into os.ErrNotExist. See Issue 18984. +// about opening files in non-directories into fs.ErrNotExist. See Issue 18984. func mapDirOpenError(originalErr error, name string) error { if os.IsNotExist(originalErr) || os.IsPermission(originalErr) { return originalErr @@ -59,7 +60,7 @@ func mapDirOpenError(originalErr error, name string) error { return originalErr } if !fi.IsDir() { - return os.ErrNotExist + return fs.ErrNotExist } } return originalErr @@ -86,6 +87,10 @@ func (d Dir) Open(name string) (File, error) { // A FileSystem implements access to a collection of named files. // The elements in a file path are separated by slash ('/', U+002F) // characters, regardless of host operating system convention. +// See the FileServer function to convert a FileSystem to a Handler. +// +// This interface predates the fs.FS interface, which can be used instead: +// the FS adapter function converts an fs.FS to a FileSystem. type FileSystem interface { Open(name string) (File, error) } @@ -98,24 +103,56 @@ type File interface { io.Closer io.Reader io.Seeker - Readdir(count int) ([]os.FileInfo, error) - Stat() (os.FileInfo, error) + Readdir(count int) ([]fs.FileInfo, error) + Stat() (fs.FileInfo, error) +} + +type anyDirs interface { + len() int + name(i int) string + isDir(i int) bool } +type fileInfoDirs []fs.FileInfo + +func (d fileInfoDirs) len() int { return len(d) } +func (d fileInfoDirs) isDir(i int) bool { return d[i].IsDir() } +func (d fileInfoDirs) name(i int) string { return d[i].Name() } + +type dirEntryDirs []fs.DirEntry + +func (d dirEntryDirs) len() int { return len(d) } +func (d dirEntryDirs) isDir(i int) bool { return d[i].IsDir() } +func (d dirEntryDirs) name(i int) string { return d[i].Name() } + func dirList(w ResponseWriter, r *Request, f File) { - dirs, err := f.Readdir(-1) + // Prefer to use ReadDir instead of Readdir, + // because the former doesn't require calling + // Stat on every entry of a directory on Unix. + var dirs anyDirs + var err error + if d, ok := f.(fs.ReadDirFile); ok { + var list dirEntryDirs + list, err = d.ReadDir(-1) + dirs = list + } else { + var list fileInfoDirs + list, err = f.Readdir(-1) + dirs = list + } + if err != nil { logf(r, "http: error reading directory: %v", err) Error(w, "Error reading directory", StatusInternalServerError) return } - sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) + sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) }) w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "<pre>\n") - for _, d := range dirs { - name := d.Name() - if d.IsDir() { + for i, n := 0, dirs.len(); i < n; i++ { + name := dirs.name(i) + if dirs.isDir(i) { name += "/" } // name may contain '?' or '#', which must be escaped to remain @@ -706,17 +743,98 @@ type fileHandler struct { root FileSystem } +type ioFS struct { + fsys fs.FS +} + +type ioFile struct { + file fs.File +} + +func (f ioFS) Open(name string) (File, error) { + if name == "/" { + name = "." + } else { + name = strings.TrimPrefix(name, "/") + } + file, err := f.fsys.Open(name) + if err != nil { + return nil, err + } + return ioFile{file}, nil +} + +func (f ioFile) Close() error { return f.file.Close() } +func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) } +func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() } + +var errMissingSeek = errors.New("io.File missing Seek method") +var errMissingReadDir = errors.New("io.File directory missing ReadDir method") + +func (f ioFile) Seek(offset int64, whence int) (int64, error) { + s, ok := f.file.(io.Seeker) + if !ok { + return 0, errMissingSeek + } + return s.Seek(offset, whence) +} + +func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + return d.ReadDir(count) +} + +func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + var list []fs.FileInfo + for { + dirs, err := d.ReadDir(count - len(list)) + for _, dir := range dirs { + info, err := dir.Info() + if err != nil { + // Pretend it doesn't exist, like (*os.File).Readdir does. + continue + } + list = append(list, info) + } + if err != nil { + return list, err + } + if count < 0 || len(list) >= count { + break + } + } + return list, nil +} + +// FS converts fsys to a FileSystem implementation, +// for use with FileServer and NewFileTransport. +func FS(fsys fs.FS) FileSystem { + return ioFS{fsys} +} + // FileServer returns a handler that serves HTTP requests // with the contents of the file system rooted at root. // +// As a special case, the returned file server redirects any request +// ending in "/index.html" to the same path, without the final +// "index.html". +// // To use the operating system's file system implementation, // use http.Dir: // // http.Handle("/", http.FileServer(http.Dir("/tmp"))) // -// As a special case, the returned file server redirects any request -// ending in "/index.html" to the same path, without the final -// "index.html". +// To use an fs.FS implementation, use http.FS to convert it: +// +// http.Handle("/", http.FileServer(http.FS(fsys))) +// func FileServer(root FileSystem) Handler { return &fileHandler{root} } @@ -771,9 +889,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { var r httpRange if start == "" { // If no start is specified, end specifies the - // range start relative to the end of the file. + // range start relative to the end of the file, + // and we are dealing with <suffix-length> + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } i, err := strconv.ParseInt(end, 10, 64) - if err != nil { + if i < 0 || err != nil { return nil, errors.New("invalid range") } if i > size { diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index c082cee..2499051 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "io/fs" "io/ioutil" "mime" "mime/multipart" @@ -78,7 +79,7 @@ func TestServeFile(t *testing.T) { var err error - file, err := ioutil.ReadFile(testFile) + file, err := os.ReadFile(testFile) if err != nil { t.Fatal("reading file:", err) } @@ -159,7 +160,7 @@ Cases: if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) } - body, err := ioutil.ReadAll(part) + body, err := io.ReadAll(part) if err != nil { t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) continue Cases @@ -311,7 +312,7 @@ func TestFileServerEscapesNames(t *testing.T) { if err != nil { t.Fatalf("test %q: Get: %v", test.name, err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("test %q: read Body: %v", test.name, err) } @@ -359,7 +360,7 @@ func TestFileServerSortsNames(t *testing.T) { } defer res.Body.Close() - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("read Body: %v", err) } @@ -378,12 +379,12 @@ func mustRemoveAll(dir string) { func TestFileServerImplicitLeadingSlash(t *testing.T) { defer afterTest(t) - tempDir, err := ioutil.TempDir("", "") + tempDir, err := os.MkdirTemp("", "") if err != nil { t.Fatalf("TempDir: %v", err) } defer mustRemoveAll(tempDir) - if err := ioutil.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { + if err := os.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { t.Fatalf("WriteFile: %v", err) } ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) @@ -393,7 +394,7 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { if err != nil { t.Fatalf("Get %s: %v", suffix, err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("ReadAll %s: %v", suffix, err) } @@ -570,6 +571,43 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { func TestServeIndexHtml(t *testing.T) { defer afterTest(t) + + for i := 0; i < 2; i++ { + var h Handler + var name string + switch i { + case 0: + h = FileServer(Dir(".")) + name = "Dir" + case 1: + h = FileServer(FS(os.DirFS("."))) + name = "DirFS" + } + t.Run(name, func(t *testing.T) { + const want = "index.html says hello\n" + ts := httptest.NewServer(h) + defer ts.Close() + + for _, path := range []string{"/testdata/", "/testdata/index.html"} { + res, err := Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != want { + t.Errorf("for path %q got %q, want %q", path, s, want) + } + res.Body.Close() + } + }) + } +} + +func TestServeIndexHtmlFS(t *testing.T) { + defer afterTest(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -579,7 +617,7 @@ func TestServeIndexHtml(t *testing.T) { if err != nil { t.Fatal(err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatal("reading Body:", err) } @@ -629,9 +667,9 @@ func (f *fakeFileInfo) Sys() interface{} { return nil } func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } func (f *fakeFileInfo) IsDir() bool { return f.dir } func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } -func (f *fakeFileInfo) Mode() os.FileMode { +func (f *fakeFileInfo) Mode() fs.FileMode { if f.dir { - return 0755 | os.ModeDir + return 0755 | fs.ModeDir } return 0644 } @@ -644,12 +682,12 @@ type fakeFile struct { } func (f *fakeFile) Close() error { return nil } -func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil } -func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { +func (f *fakeFile) Stat() (fs.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]fs.FileInfo, error) { if !f.fi.dir { - return nil, os.ErrInvalid + return nil, fs.ErrInvalid } - var fis []os.FileInfo + var fis []fs.FileInfo limit := f.entpos + count if count <= 0 || limit > len(f.fi.ents) { @@ -668,11 +706,11 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { type fakeFS map[string]*fakeFileInfo -func (fs fakeFS) Open(name string) (File, error) { +func (fsys fakeFS) Open(name string) (File, error) { name = path.Clean(name) - f, ok := fs[name] + f, ok := fsys[name] if !ok { - return nil, os.ErrNotExist + return nil, fs.ErrNotExist } if f.err != nil { return nil, f.err @@ -707,7 +745,7 @@ func TestDirectoryIfNotModified(t *testing.T) { if err != nil { t.Fatal(err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -747,7 +785,7 @@ func TestDirectoryIfNotModified(t *testing.T) { res.Body.Close() } -func mustStat(t *testing.T, fileName string) os.FileInfo { +func mustStat(t *testing.T, fileName string) fs.FileInfo { fi, err := os.Stat(fileName) if err != nil { t.Fatal(err) @@ -1044,7 +1082,7 @@ func TestServeContent(t *testing.T) { if err != nil { t.Fatal(err) } - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) res.Body.Close() if res.StatusCode != tt.wantStatus { t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus) @@ -1081,7 +1119,7 @@ func (issue12991FS) Open(string) (File, error) { return issue12991File{}, nil } type issue12991File struct{ File } -func (issue12991File) Stat() (os.FileInfo, error) { return nil, os.ErrPermission } +func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } func (issue12991File) Close() error { return nil } func TestServeContentErrorMessages(t *testing.T) { @@ -1091,7 +1129,7 @@ func TestServeContentErrorMessages(t *testing.T) { err: errors.New("random error"), }, "/403": &fakeFileInfo{ - err: &os.PathError{Err: os.ErrPermission}, + err: &fs.PathError{Err: fs.ErrPermission}, }, } ts := httptest.NewServer(FileServer(fs)) @@ -1136,6 +1174,14 @@ func TestLinuxSendfile(t *testing.T) { t.Skipf("skipping; failed to run strace: %v", err) } + filename := fmt.Sprintf("1kb-%d", os.Getpid()) + filepath := path.Join(os.TempDir(), filename) + + if err := os.WriteFile(filepath, bytes.Repeat([]byte{'a'}, 1<<10), 0755); err != nil { + t.Fatal(err) + } + defer os.Remove(filepath) + var buf bytes.Buffer child := exec.Command("strace", "-f", "-q", os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) @@ -1146,11 +1192,11 @@ func TestLinuxSendfile(t *testing.T) { t.Skipf("skipping; failed to start straced child: %v", err) } - res, err := Get(fmt.Sprintf("http://%s/", ln.Addr())) + res, err := Get(fmt.Sprintf("http://%s/%s", ln.Addr(), filename)) if err != nil { t.Fatalf("http client error: %v", err) } - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) if err != nil { t.Fatalf("client body read error: %v", err) } @@ -1172,7 +1218,7 @@ func getBody(t *testing.T, testName string, req Request, client *Client) (*Respo if err != nil { t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err) } @@ -1192,7 +1238,7 @@ func TestLinuxSendfileChild(*testing.T) { panic(err) } mux := NewServeMux() - mux.Handle("/", FileServer(Dir("testdata"))) + mux.Handle("/", FileServer(Dir(os.TempDir()))) mux.HandleFunc("/quit", func(ResponseWriter, *Request) { os.Exit(0) }) @@ -1281,7 +1327,7 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) { // Just return back something that's a directory. return Dir(".").Open(".") } - return nil, os.ErrNotExist + return nil, fs.ErrNotExist } type panicOnSeek struct{ io.ReadSeeker } @@ -1308,3 +1354,61 @@ func Test_scanETag(t *testing.T) { } } } + +// Issue 40940: Ensure that we only accept non-negative suffix-lengths +// in "Range": "bytes=-N", and should reject "bytes=--2". +func TestServeFileRejectsInvalidSuffixLengths_h1(t *testing.T) { + testServeFileRejectsInvalidSuffixLengths(t, h1Mode) +} +func TestServeFileRejectsInvalidSuffixLengths_h2(t *testing.T) { + testServeFileRejectsInvalidSuffixLengths(t, h2Mode) +} + +func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { + defer afterTest(t) + cst := httptest.NewUnstartedServer(FileServer(Dir("testdata"))) + cst.EnableHTTP2 = h2 + cst.StartTLS() + defer cst.Close() + + tests := []struct { + r string + wantCode int + wantBody string + }{ + {"bytes=--6", 416, "invalid range\n"}, + {"bytes=--0", 416, "invalid range\n"}, + {"bytes=---0", 416, "invalid range\n"}, + {"bytes=-6", 206, "hello\n"}, + {"bytes=6-", 206, "html says hello\n"}, + {"bytes=-6-", 416, "invalid range\n"}, + {"bytes=-0", 206, ""}, + {"bytes=", 200, "index.html says hello\n"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.r, func(t *testing.T) { + req, err := NewRequest("GET", cst.URL+"/index.html", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Range", tt.r) + res, err := cst.Client().Do(req) + if err != nil { + t.Fatal(err) + } + if g, w := res.StatusCode, tt.wantCode; g != w { + t.Errorf("StatusCode mismatch: got %d want %d", g, w) + } + slurp, err := io.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if g, w := string(slurp), tt.wantBody; g != w { + t.Fatalf("Content mismatch:\nGot: %q\nWant: %q", g, w) + } + }) + } +} diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 71592e9..e13c661 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -5592,7 +5592,11 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } if bodyOpen { if vv, ok := rp.header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { + req.ContentLength = int64(cl) + } else { + req.ContentLength = 0 + } } else { req.ContentLength = -1 } @@ -5630,7 +5634,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re var trailer Header for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { - key = CanonicalHeaderKey(strings.TrimSpace(key)) + key = CanonicalHeaderKey(textproto.TrimString(key)) switch key { case "Transfer-Encoding", "Trailer", "Content-Length": // Bogus. (copy of http1 rules) @@ -5975,9 +5979,8 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { var ctype, clen string if clen = rws.snapHeader.Get("Content-Length"); clen != "" { rws.snapHeader.Del("Content-Length") - clen64, err := strconv.ParseInt(clen, 10, 64) - if err == nil && clen64 >= 0 { - rws.sentContentLen = clen64 + if cl, err := strconv.ParseUint(clen, 10, 63); err == nil { + rws.sentContentLen = int64(cl) } else { clen = "" } @@ -6607,6 +6610,19 @@ type http2Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // ReadIdleTimeout is the timeout after which a health check using ping + // frame will be carried out if no frame is received on the connection. + // Note that a ping response will is considered a received frame, so if + // there is no other traffic on the connection, the health check will + // be performed every ReadIdleTimeout interval. + // If zero, no health check is performed. + ReadIdleTimeout time.Duration + + // PingTimeout is the timeout after which the connection will be closed + // if a response to Ping is not received. + // Defaults to 15s. + PingTimeout time.Duration + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -6630,14 +6646,31 @@ func (t *http2Transport) disableCompression() bool { return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) } +func (t *http2Transport) pingTimeout() time.Duration { + if t.PingTimeout == 0 { + return 15 * time.Second + } + return t.PingTimeout + +} + // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2. // It returns an error if t1 has already been HTTP/2-enabled. +// +// Use ConfigureTransports instead to configure the HTTP/2 Transport. func http2ConfigureTransport(t1 *Transport) error { - _, err := http2configureTransport(t1) + _, err := http2ConfigureTransports(t1) return err } -func http2configureTransport(t1 *Transport) (*http2Transport, error) { +// ConfigureTransports configures a net/http HTTP/1 Transport to use HTTP/2. +// It returns a new HTTP/2 Transport for further configuration. +// It returns an error if t1 has already been HTTP/2-enabled. +func http2ConfigureTransports(t1 *Transport) (*http2Transport, error) { + return http2configureTransports(t1) +} + +func http2configureTransports(t1 *Transport) (*http2Transport, error) { connPool := new(http2clientConnPool) t2 := &http2Transport{ ConnPool: http2noDialClientConnPool{connPool}, @@ -7176,6 +7209,20 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client return cc, nil } +func (cc *http2ClientConn) healthCheck() { + pingTimeout := cc.t.pingTimeout() + // We don't need to periodically ping in the health check, because the readLoop of ClientConn will + // trigger the healthCheck again if there is no frame received. + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + err := cc.Ping(ctx) + if err != nil { + cc.closeForLostPing() + cc.t.connPool().MarkDead(cc) + return + } +} + func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -7347,14 +7394,12 @@ func (cc *http2ClientConn) sendGoAway() error { return nil } -// Close closes the client connection immediately. -// -// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. -func (cc *http2ClientConn) Close() error { +// closes the client connection immediately. In-flight requests are interrupted. +// err is sent to streams. +func (cc *http2ClientConn) closeForError(err error) error { cc.mu.Lock() defer cc.cond.Broadcast() defer cc.mu.Unlock() - err := errors.New("http2: client connection force closed via ClientConn.Close") for id, cs := range cc.streams { select { case cs.resc <- http2resAndError{err: err}: @@ -7367,6 +7412,20 @@ func (cc *http2ClientConn) Close() error { return cc.tconn.Close() } +// Close closes the client connection immediately. +// +// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead. +func (cc *http2ClientConn) Close() error { + err := errors.New("http2: client connection force closed via ClientConn.Close") + return cc.closeForError(err) +} + +// closes the client connection immediately. In-flight requests are interrupted. +func (cc *http2ClientConn) closeForLostPing() error { + err := errors.New("http2: client connection lost") + return cc.closeForError(err) +} + const http2maxAllocFrameSize = 512 << 10 // frameBuffer returns a scratch buffer suitable for writing DATA frames. @@ -7592,6 +7651,9 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe // we can keep it. bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWrite) + if hasBody && !bodyWritten { + <-bodyWriter.resc + } } if re.err != nil { cc.forgetStreamID(cs.ID) @@ -7612,6 +7674,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe } else { bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + <-bodyWriter.resc } cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), http2errTimeout @@ -7626,6 +7689,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe } else { bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + <-bodyWriter.resc } cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), ctx.Err() @@ -7640,6 +7704,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe } else { bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + <-bodyWriter.resc } cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), http2errRequestCanceled @@ -7654,6 +7719,7 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe // forgetStreamID. return nil, cs.getStartedWrite(), cs.resetErr case err := <-bodyWriter.resc: + bodyWritten = true // Prefer the read loop's response, if available. Issue 16102. select { case re := <-readLoopResCh: @@ -7664,7 +7730,6 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), err } - bodyWritten = true if d := cc.responseHeaderTimeout(); d != 0 { timer := time.NewTimer(d) defer timer.Stop() @@ -8262,8 +8327,17 @@ func (rl *http2clientConnReadLoop) run() error { rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse gotReply := false // ever saw a HEADERS reply gotSettings := false + readIdleTimeout := cc.t.ReadIdleTimeout + var t *time.Timer + if readIdleTimeout != 0 { + t = time.AfterFunc(readIdleTimeout, cc.healthCheck) + defer t.Stop() + } for { f, err := cc.fr.ReadFrame() + if t != nil { + t.Reset(readIdleTimeout) + } if err != nil { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } @@ -8475,8 +8549,8 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http if !streamEnded || isHead { res.ContentLength = -1 if clens := res.Header["Content-Length"]; len(clens) == 1 { - if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { - res.ContentLength = clen64 + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) } else { // TODO: care? unlike http/1, it won't mess up our framing, so it's // more safe smuggling-wise to ignore. @@ -8994,6 +9068,8 @@ func http2strSliceContains(ss []string, s string) bool { type http2erringRoundTripper struct{ err error } +func (rt http2erringRoundTripper) RoundTripErr() error { return rt.err } + func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err } // gzipReader wraps a response body so it can lazily @@ -9075,7 +9151,9 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade func (s http2bodyWriterState) cancel() { if s.timer != nil { - s.timer.Stop() + if s.timer.Stop() { + s.resc <- nil + } } } diff --git a/libgo/go/net/http/http_test.go b/libgo/go/net/http/http_test.go index f4ea52d..3f1d7ce 100644 --- a/libgo/go/net/http/http_test.go +++ b/libgo/go/net/http/http_test.go @@ -13,13 +13,8 @@ import ( "os/exec" "reflect" "testing" - "time" ) -func init() { - shutdownPollInterval = 5 * time.Millisecond -} - func TestForeachHeaderElement(t *testing.T) { tests := []struct { in string @@ -91,7 +86,7 @@ func TestCmdGoNoHTTPServer(t *testing.T) { } wantSym := map[string]bool{ // Verify these exist: (sanity checking this test) - "net/http.(*Client).Get": true, + "net/http.(*Client).do": true, "net/http.(*Transport).RoundTrip": true, // Verify these don't exist: diff --git a/libgo/go/net/http/httptest/example_test.go b/libgo/go/net/http/httptest/example_test.go index 54e77db..a673843 100644 --- a/libgo/go/net/http/httptest/example_test.go +++ b/libgo/go/net/http/httptest/example_test.go @@ -7,7 +7,6 @@ package httptest_test import ( "fmt" "io" - "io/ioutil" "log" "net/http" "net/http/httptest" @@ -23,7 +22,7 @@ func ExampleResponseRecorder() { handler(w, req) resp := w.Result() - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) fmt.Println(resp.StatusCode) fmt.Println(resp.Header.Get("Content-Type")) @@ -45,7 +44,7 @@ func ExampleServer() { if err != nil { log.Fatal(err) } - greeting, err := ioutil.ReadAll(res.Body) + greeting, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) @@ -67,7 +66,7 @@ func ExampleServer_hTTP2() { if err != nil { log.Fatal(err) } - greeting, err := ioutil.ReadAll(res.Body) + greeting, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) @@ -89,7 +88,7 @@ func ExampleNewTLSServer() { log.Fatal(err) } - greeting, err := ioutil.ReadAll(res.Body) + greeting, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) diff --git a/libgo/go/net/http/httptest/httptest.go b/libgo/go/net/http/httptest/httptest.go index f7202da..9bedefd 100644 --- a/libgo/go/net/http/httptest/httptest.go +++ b/libgo/go/net/http/httptest/httptest.go @@ -10,7 +10,6 @@ import ( "bytes" "crypto/tls" "io" - "io/ioutil" "net/http" "strings" ) @@ -66,7 +65,7 @@ func NewRequest(method, target string, body io.Reader) *http.Request { if rc, ok := body.(io.ReadCloser); ok { req.Body = rc } else { - req.Body = ioutil.NopCloser(body) + req.Body = io.NopCloser(body) } } diff --git a/libgo/go/net/http/httptest/httptest_test.go b/libgo/go/net/http/httptest/httptest_test.go index ef7d943..071add6 100644 --- a/libgo/go/net/http/httptest/httptest_test.go +++ b/libgo/go/net/http/httptest/httptest_test.go @@ -7,7 +7,6 @@ package httptest import ( "crypto/tls" "io" - "io/ioutil" "net/http" "net/url" "reflect" @@ -155,7 +154,7 @@ func TestNewRequest(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { got := NewRequest(tt.method, tt.uri, tt.body) - slurp, err := ioutil.ReadAll(got.Body) + slurp, err := io.ReadAll(got.Body) if err != nil { t.Errorf("ReadAll: %v", err) } diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 66e67e7..2428482 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -7,7 +7,7 @@ package httptest import ( "bytes" "fmt" - "io/ioutil" + "io" "net/http" "net/textproto" "strconv" @@ -179,7 +179,7 @@ func (rw *ResponseRecorder) Result() *http.Response { } res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) if rw.Body != nil { - res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) } else { res.Body = http.NoBody } diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go index e953489..a865e87 100644 --- a/libgo/go/net/http/httptest/recorder_test.go +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -7,7 +7,6 @@ package httptest import ( "fmt" "io" - "io/ioutil" "net/http" "testing" ) @@ -42,7 +41,7 @@ func TestRecorder(t *testing.T) { } hasResultContents := func(want string) checkFunc { return func(rec *ResponseRecorder) error { - contentBytes, err := ioutil.ReadAll(rec.Result().Body) + contentBytes, err := io.ReadAll(rec.Result().Body) if err != nil { return err } diff --git a/libgo/go/net/http/httptest/server_test.go b/libgo/go/net/http/httptest/server_test.go index 0aad15c..39568b3 100644 --- a/libgo/go/net/http/httptest/server_test.go +++ b/libgo/go/net/http/httptest/server_test.go @@ -6,7 +6,7 @@ package httptest import ( "bufio" - "io/ioutil" + "io" "net" "net/http" "testing" @@ -61,7 +61,7 @@ func testServer(t *testing.T, newServer newServerFunc) { if err != nil { t.Fatal(err) } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) @@ -81,7 +81,7 @@ func testGetAfterClose(t *testing.T, newServer newServerFunc) { if err != nil { t.Fatal(err) } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func testGetAfterClose(t *testing.T, newServer newServerFunc) { res, err = http.Get(ts.URL) if err == nil { - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) } } @@ -152,7 +152,7 @@ func testServerClient(t *testing.T, newTLSServer newServerFunc) { if err != nil { t.Fatal(err) } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index c97be06..4c9d28b 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -35,7 +34,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { if err = b.Close(); err != nil { return nil, b, err } - return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil + return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil } // dumpConn is a net.Conn which writes to Writer and reads from Reader @@ -81,7 +80,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { if !body { contentLength := outgoingLength(req) if contentLength != 0 { - req.Body = ioutil.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) + req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) dummyBody = true } } else { @@ -133,7 +132,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { if err == nil { // Ensure all the body is read; otherwise // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() } select { @@ -296,7 +295,7 @@ func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } func (failureToReadBody) Close() error { return nil } // emptyBody is an instance of empty reader. -var emptyBody = ioutil.NopCloser(strings.NewReader("")) +var emptyBody = io.NopCloser(strings.NewReader("")) // DumpResponse is like DumpRequest but dumps a response. func DumpResponse(resp *http.Response, body bool) ([]byte, error) { diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index ead56bc..7571eb0 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "net/http" "net/url" "runtime" @@ -268,7 +267,7 @@ func TestDumpRequest(t *testing.T) { } switch b := ti.Body.(type) { case []byte: - req.Body = ioutil.NopCloser(bytes.NewReader(b)) + req.Body = io.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: req.Body = b() default: @@ -363,7 +362,7 @@ var dumpResTests = []struct { Header: http.Header{ "Foo": []string{"Bar"}, }, - Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used }, body: false, // to verify we see 50, not empty or 3. want: `HTTP/1.1 200 OK @@ -379,7 +378,7 @@ Foo: Bar`, ProtoMajor: 1, ProtoMinor: 1, ContentLength: 3, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), }, body: true, want: `HTTP/1.1 200 OK @@ -396,7 +395,7 @@ foo`, ProtoMajor: 1, ProtoMinor: 1, ContentLength: -1, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), TransferEncoding: []string{"chunked"}, }, body: true, diff --git a/libgo/go/net/http/httputil/example_test.go b/libgo/go/net/http/httputil/example_test.go index 6191603..b77a243 100644 --- a/libgo/go/net/http/httputil/example_test.go +++ b/libgo/go/net/http/httputil/example_test.go @@ -6,7 +6,7 @@ package httputil_test import ( "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" @@ -39,7 +39,7 @@ func ExampleDumpRequest() { } defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { log.Fatal(err) } @@ -111,7 +111,7 @@ func ExampleReverseProxy() { log.Fatal(err) } - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { log.Fatal(err) } diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 3f48fab..4e36958 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -58,9 +58,9 @@ type ReverseProxy struct { // A negative value means to flush immediately // after each write to the client. // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. FlushInterval time.Duration // ErrorLog specifies an optional logger for errors @@ -325,7 +325,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = p.copyResponse(rw, res.Body, p.flushInterval(res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do @@ -397,7 +397,7 @@ func removeConnectionHeaders(h http.Header) { // flushInterval returns the p.FlushInterval value, conditionally // overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { resCT := res.Header.Get("Content-Type") // For Server-Sent Events responses, flush immediately. @@ -406,7 +406,11 @@ func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time return -1 // negative means immediately } - // TODO: more specific cases? e.g. res.ContentLength == -1? + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + return p.FlushInterval } @@ -545,8 +549,6 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } - copyHeader(res.Header, rw.Header()) - hj, ok := rw.(http.Hijacker) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) @@ -577,6 +579,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above if err := res.Write(brw); err != nil { p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 764939f..3acbd94 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net/http" "net/http/httptest" @@ -84,7 +83,7 @@ func TestReverseProxy(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -124,7 +123,7 @@ func TestReverseProxy(t *testing.T) { if cookie := res.Cookies()[0]; cookie.Name != "flavor" { t.Errorf("unexpected cookie %q", cookie.Name) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -218,7 +217,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { t.Fatalf("Get: %v", err) } defer res.Body.Close() - bodyBytes, err := ioutil.ReadAll(res.Body) + bodyBytes, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("reading body: %v", err) } @@ -271,7 +270,7 @@ func TestXForwardedFor(t *testing.T) { if g, e := res.StatusCode, backendStatus; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -373,7 +372,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { t.Fatalf("Get: %v", err) } defer res.Body.Close() - if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { t.Errorf("got body %q; expected %q", bodyBytes, expected) } } @@ -441,7 +440,7 @@ func TestReverseProxyCancellation(t *testing.T) { defer backend.Close() - backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + backend.Config.ErrorLog = log.New(io.Discard, "", 0) backendURL, err := url.Parse(backend.URL) if err != nil { @@ -452,7 +451,7 @@ func TestReverseProxyCancellation(t *testing.T) { // Discards errors of the form: // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -504,7 +503,7 @@ func TestNilBody(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -533,7 +532,7 @@ func TestUserAgentHeader(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -606,7 +605,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatalf("reading body: %v", err) @@ -627,7 +626,7 @@ func TestReverseProxy_Post(t *testing.T) { const backendStatus = 200 var requestBody = bytes.Repeat([]byte("a"), 1<<20) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slurp, err := ioutil.ReadAll(r.Body) + slurp, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Backend body read = %v", err) } @@ -656,7 +655,7 @@ func TestReverseProxy_Post(t *testing.T) { if g, e := res.StatusCode, backendStatus; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -672,7 +671,7 @@ func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) func TestReverseProxy_NilBody(t *testing.T) { backendURL, _ := url.Parse("http://fake.tld/") proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Body != nil { t.Error("Body != nil; want a nil Body") @@ -695,8 +694,8 @@ func TestReverseProxy_NilBody(t *testing.T) { // Issue 33142: always allocate the request headers func TestReverseProxy_AllocatedHeader(t *testing.T) { proxyHandler := new(ReverseProxy) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests - proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Header == nil { t.Error("Header == nil; want a non-nil Header") @@ -722,7 +721,7 @@ func TestReverseProxyModifyResponse(t *testing.T) { rpURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(rpURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(resp *http.Response) error { if resp.Header.Get("X-Hit-Mod") != "true" { return fmt.Errorf("tried to by-pass proxy") @@ -821,7 +820,7 @@ func TestReverseProxyErrorHandler(t *testing.T) { if rproxy.Transport == nil { rproxy.Transport = failingRoundTripper{} } - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests if tt.errorHandler != nil { rproxy.ErrorHandler = tt.errorHandler } @@ -896,7 +895,7 @@ func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { res := &http.Response{ StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), } proxy := &ReverseProxy{ Director: func(*http.Request) {}, @@ -953,7 +952,7 @@ func TestServeHTTPDeepCopy(t *testing.T) { // Issue 18327: verify we always do a deep copy of the Request.Header map // before any mutations. func TestClonesRequestHeaders(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) req, _ := http.NewRequest("GET", "http://foo.tld/", nil) req.RemoteAddr = "1.2.3.4:56789" @@ -1031,7 +1030,7 @@ func (cc *checkCloser) Read(b []byte) (int, error) { // Issue 23643: panic on body copy error func TestReverseProxy_PanicBodyError(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { out := "this call was relayed by the reverse proxy" @@ -1067,7 +1066,6 @@ func TestSelectFlushInterval(t *testing.T) { tests := []struct { name string p *ReverseProxy - req *http.Request res *http.Response want time.Duration }{ @@ -1097,10 +1095,26 @@ func TestSelectFlushInterval(t *testing.T) { p: &ReverseProxy{FlushInterval: 0}, want: -1, }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.p.flushInterval(tt.req, tt.res) + got := tt.p.flushInterval(tt.res) if got != tt.want { t.Errorf("flushLatency = %v; want %v", got, tt.want) } @@ -1133,7 +1147,7 @@ func TestReverseProxyWebSocket(t *testing.T) { backURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(backURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(res *http.Response) error { res.Header.Add("X-Modified", "true") return nil @@ -1142,6 +1156,9 @@ func TestReverseProxyWebSocket(t *testing.T) { handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("X-Header", "X-Value") rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } }) frontendProxy := httptest.NewServer(handler) @@ -1247,7 +1264,7 @@ func TestReverseProxyWebSocketCancelation(t *testing.T) { backendURL, _ := url.Parse(cst.URL) rproxy := NewSingleHostReverseProxy(backendURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(res *http.Response) error { res.Header.Add("X-Modified", "true") return nil @@ -1334,7 +1351,7 @@ func TestUnannouncedTrailer(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -1344,7 +1361,7 @@ func TestUnannouncedTrailer(t *testing.T) { t.Fatalf("Get: %v", err) } - ioutil.ReadAll(res.Body) + io.ReadAll(res.Body) if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index d067165..08152ed 100644 --- a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "strings" "testing" ) @@ -29,7 +28,7 @@ func TestChunk(t *testing.T) { } r := NewChunkedReader(&b) - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { t.Logf(`data: "%s"`, data) t.Fatalf("ReadAll from reader: %v", err) @@ -177,7 +176,7 @@ func TestChunkReadingIgnoresExtensions(t *testing.T) { "17;someext\r\n" + // token without value "world! 0123456789abcdef\r\n" + "0;someextension=sometoken\r\n" // token=token - data, err := ioutil.ReadAll(NewChunkedReader(strings.NewReader(in))) + data, err := io.ReadAll(NewChunkedReader(strings.NewReader(in))) if err != nil { t.Fatalf("ReadAll = %q, %v", data, err) } diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index 35cc809..6564627 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -6,7 +6,7 @@ package http_test import ( "fmt" - "io/ioutil" + "io" "log" "net/http" "os" @@ -17,7 +17,7 @@ import ( "time" ) -var quietLog = log.New(ioutil.Discard, "", 0) +var quietLog = log.New(io.Discard, "", 0) func TestMain(m *testing.M) { v := m.Run() diff --git a/libgo/go/net/http/omithttp2.go b/libgo/go/net/http/omithttp2.go index 7e2f492..30c6e48 100644 --- a/libgo/go/net/http/omithttp2.go +++ b/libgo/go/net/http/omithttp2.go @@ -32,10 +32,6 @@ type http2Transport struct { func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } func (*http2Transport) CloseIdleConnections() {} -type http2erringRoundTripper struct{ err error } - -func (http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } - type http2noDialH2RoundTripper struct{} func (http2noDialH2RoundTripper) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } @@ -49,7 +45,7 @@ type http2clientConnPool struct { conns map[string][]struct{} } -func http2configureTransport(*Transport) (*http2Transport, error) { panic(noHTTP2) } +func http2configureTransports(*Transport) (*http2Transport, error) { panic(noHTTP2) } func http2isNoCachedConnError(err error) bool { _, ok := err.(interface{ IsHTTP2NoCachedConnError() }) diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 81df044..5389a38 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -61,11 +61,12 @@ import ( "bytes" "context" "fmt" - "html/template" + "html" "internal/profile" "io" "log" "net/http" + "net/url" "os" "runtime" "runtime/pprof" @@ -90,17 +91,13 @@ func init() { func Cmdline(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Content-Type", "text/plain; charset=utf-8") - fmt.Fprintf(w, strings.Join(os.Args, "\x00")) + fmt.Fprint(w, strings.Join(os.Args, "\x00")) } -func sleep(w http.ResponseWriter, d time.Duration) { - var clientGone <-chan bool - if cn, ok := w.(http.CloseNotifier); ok { - clientGone = cn.CloseNotify() - } +func sleep(r *http.Request, d time.Duration) { select { case <-time.After(d): - case <-clientGone: + case <-r.Context().Done(): } } @@ -142,7 +139,7 @@ func Profile(w http.ResponseWriter, r *http.Request) { fmt.Sprintf("Could not enable CPU profiling: %s", err)) return } - sleep(w, time.Duration(sec)*time.Second) + sleep(r, time.Duration(sec)*time.Second) pprof.StopCPUProfile() } @@ -171,7 +168,7 @@ func Trace(w http.ResponseWriter, r *http.Request) { fmt.Sprintf("Could not enable tracing: %s", err)) return } - sleep(w, time.Duration(sec*float64(time.Second))) + sleep(r, time.Duration(sec*float64(time.Second))) trace.Stop() } @@ -356,6 +353,13 @@ var profileDescriptions = map[string]string{ "trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.", } +type profileEntry struct { + Name string + Href string + Desc string + Count int +} + // Index responds with the pprof-formatted profile named by the request. // For example, "/debug/pprof/heap" serves the "heap" profile. // Index responds to a request for "/debug/pprof/" with an HTML page @@ -372,17 +376,11 @@ func Index(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Content-Type", "text/html; charset=utf-8") - type profile struct { - Name string - Href string - Desc string - Count int - } - var profiles []profile + var profiles []profileEntry for _, p := range pprof.Profiles() { - profiles = append(profiles, profile{ + profiles = append(profiles, profileEntry{ Name: p.Name(), - Href: p.Name() + "?debug=1", + Href: p.Name(), Desc: profileDescriptions[p.Name()], Count: p.Count(), }) @@ -390,7 +388,7 @@ func Index(w http.ResponseWriter, r *http.Request) { // Adding other profiles exposed from within this package for _, p := range []string{"cmdline", "profile", "trace"} { - profiles = append(profiles, profile{ + profiles = append(profiles, profileEntry{ Name: p, Href: p, Desc: profileDescriptions[p], @@ -401,12 +399,14 @@ func Index(w http.ResponseWriter, r *http.Request) { return profiles[i].Name < profiles[j].Name }) - if err := indexTmpl.Execute(w, profiles); err != nil { + if err := indexTmplExecute(w, profiles); err != nil { log.Print(err) } } -var indexTmpl = template.Must(template.New("index").Parse(`<html> +func indexTmplExecute(w io.Writer, profiles []profileEntry) error { + var b bytes.Buffer + b.WriteString(`<html> <head> <title>/debug/pprof/</title> <style> @@ -422,22 +422,28 @@ var indexTmpl = template.Must(template.New("index").Parse(`<html> Types of profiles available: <table> <thead><td>Count</td><td>Profile</td></thead> -{{range .}} - <tr> - <td>{{.Count}}</td><td><a href={{.Href}}>{{.Name}}</a></td> - </tr> -{{end}} -</table> +`) + + for _, profile := range profiles { + link := &url.URL{Path: profile.Href, RawQuery: "debug=1"} + fmt.Fprintf(&b, "<tr><td>%d</td><td><a href='%s'>%s</a></td></tr>\n", profile.Count, link, html.EscapeString(profile.Name)) + } + + b.WriteString(`</table> <a href="goroutine?debug=2">full goroutine stack dump</a> <br/> <p> Profile Descriptions: <ul> -{{range .}} -<li><div class=profile-name>{{.Name}}:</div> {{.Desc}}</li> -{{end}} -</ul> +`) + for _, profile := range profiles { + fmt.Fprintf(&b, "<li><div class=profile-name>%s: </div> %s</li>\n", html.EscapeString(profile.Name), html.EscapeString(profile.Desc)) + } + b.WriteString(`</ul> </p> </body> -</html> -`)) +</html>`) + + _, err := w.Write(b.Bytes()) + return err +} diff --git a/libgo/go/net/http/pprof/pprof_test.go b/libgo/go/net/http/pprof/pprof_test.go index f6f9ef5..84757e4 100644 --- a/libgo/go/net/http/pprof/pprof_test.go +++ b/libgo/go/net/http/pprof/pprof_test.go @@ -8,7 +8,7 @@ import ( "bytes" "fmt" "internal/profile" - "io/ioutil" + "io" "net/http" "net/http/httptest" "runtime" @@ -63,7 +63,7 @@ func TestHandlers(t *testing.T) { t.Errorf("status code: got %d; want %d", got, want) } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("when reading response body, expected non-nil err; got %v", err) } @@ -227,7 +227,7 @@ func query(endpoint string) (*profile.Profile, error) { return nil, fmt.Errorf("failed to fetch %q: %v", url, r.Status) } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) r.Body.Close() if err != nil { return nil, fmt.Errorf("failed to read and parse the result from %q: %v", url, err) diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go index b227bb6..1950f49 100644 --- a/libgo/go/net/http/readrequest_test.go +++ b/libgo/go/net/http/readrequest_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "net/url" "reflect" "strings" @@ -468,7 +467,7 @@ func TestReadRequest_Bad(t *testing.T) { for _, tt := range badRequestTests { got, err := ReadRequest(bufio.NewReader(bytes.NewReader(tt.req))) if err == nil { - all, err := ioutil.ReadAll(got.Body) + all, err := io.ReadAll(got.Body) t.Errorf("%s: got unexpected request = %#v\n Body = %q, %v", tt.name, got, all, err) } } diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 54ec1c5..adba540 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -15,7 +15,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "mime" "mime/multipart" "net" @@ -175,6 +174,10 @@ type Request struct { // but will return EOF immediately when no body is present. // The Server will close the request body. The ServeHTTP // Handler does not need to. + // + // Body must allow Read to be called concurrently with Close. + // In particular, calling Close should unblock a Read waiting + // for input. Body io.ReadCloser // GetBody defines an optional func to return a new copy of @@ -540,6 +543,7 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or // extraHeaders may be nil // waitForContinue may be nil +// always closes body func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) { trace := httptrace.ContextClientTrace(r.Context()) if trace != nil && trace.WroteRequest != nil { @@ -549,6 +553,15 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF }) }() } + closed := false + defer func() { + if closed { + return + } + if closeErr := r.closeBody(); closeErr != nil && err == nil { + err = closeErr + } + }() // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. @@ -667,6 +680,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF trace.Wait100Continue() } if !waitForContinue() { + closed = true r.closeBody() return nil } @@ -679,6 +693,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF } // Write body and trailer + closed = true err = tw.writeBody(w) if err != nil { if tw.bodyReadError == err { @@ -854,7 +869,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read } rc, ok := body.(io.ReadCloser) if !ok && body != nil { - rc = ioutil.NopCloser(body) + rc = io.NopCloser(body) } // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) @@ -876,21 +891,21 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read buf := v.Bytes() req.GetBody = func() (io.ReadCloser, error) { r := bytes.NewReader(buf) - return ioutil.NopCloser(r), nil + return io.NopCloser(r), nil } case *bytes.Reader: req.ContentLength = int64(v.Len()) snapshot := *v req.GetBody = func() (io.ReadCloser, error) { r := snapshot - return ioutil.NopCloser(&r), nil + return io.NopCloser(&r), nil } case *strings.Reader: req.ContentLength = int64(v.Len()) snapshot := *v req.GetBody = func() (io.ReadCloser, error) { r := snapshot - return ioutil.NopCloser(&r), nil + return io.NopCloser(&r), nil } default: // This is where we'd set it to -1 (at least @@ -1189,7 +1204,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) { maxFormSize = int64(10 << 20) // 10 MB is a lot of text. reader = io.LimitReader(r.Body, maxFormSize+1) } - b, e := ioutil.ReadAll(reader) + b, e := io.ReadAll(reader) if e != nil { if err == nil { err = e @@ -1383,10 +1398,11 @@ func (r *Request) wantsClose() bool { return hasToken(r.Header.get("Connection"), "close") } -func (r *Request) closeBody() { - if r.Body != nil { - r.Body.Close() +func (r *Request) closeBody() error { + if r.Body == nil { + return nil } + return r.Body.Close() } func (r *Request) isReplayable() bool { diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 461d66e..29297b0 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -12,7 +12,7 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" + "math" "mime/multipart" . "net/http" "net/http/httptest" @@ -103,7 +103,7 @@ func TestParseFormUnknownContentType(t *testing.T) { req := &Request{ Method: "POST", Header: test.contentType, - Body: ioutil.NopCloser(strings.NewReader("body")), + Body: io.NopCloser(strings.NewReader("body")), } err := req.ParseForm() switch { @@ -150,7 +150,7 @@ func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {test.contentType}}, - Body: ioutil.NopCloser(new(bytes.Buffer)), + Body: io.NopCloser(new(bytes.Buffer)), } multipart, err := req.MultipartReader() if test.shouldError { @@ -187,7 +187,7 @@ binary data req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary=xxx`}}, - Body: ioutil.NopCloser(strings.NewReader(postData)), + Body: io.NopCloser(strings.NewReader(postData)), } initialFormItems := map[string]string{ @@ -231,7 +231,7 @@ func TestParseMultipartForm(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: ioutil.NopCloser(new(bytes.Buffer)), + Body: io.NopCloser(new(bytes.Buffer)), } err := req.ParseMultipartForm(25) if err == nil { @@ -245,6 +245,50 @@ func TestParseMultipartForm(t *testing.T) { } } +// Issue #40430: Test that if maxMemory for ParseMultipartForm when combined with +// the payload size and the internal leeway buffer size of 10MiB overflows, that we +// correctly return an error. +func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { + defer afterTest(t) + + payloadSize := 1 << 10 + cst := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + // The combination of: + // MaxInt64 + payloadSize + (internal spare of 10MiB) + // triggers the overflow. See issue https://golang.org/issue/40430/ + if err := req.ParseMultipartForm(math.MaxInt64); err != nil { + Error(rw, err.Error(), StatusBadRequest) + return + } + })) + defer cst.Close() + fBuf := new(bytes.Buffer) + mw := multipart.NewWriter(fBuf) + mf, err := mw.CreateFormFile("file", "myfile.txt") + if err != nil { + t.Fatal(err) + } + if _, err := mf.Write(bytes.Repeat([]byte("abc"), payloadSize)); err != nil { + t.Fatal(err) + } + if err := mw.Close(); err != nil { + t.Fatal(err) + } + req, err := NewRequest("POST", cst.URL, fBuf) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", mw.FormDataContentType()) + res, err := cst.Client().Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, w := res.StatusCode, StatusOK; g != w { + t.Fatalf("Status code mismatch: got %d, want %d", g, w) + } +} + func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) } func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) } func testRedirect(t *testing.T, h2 bool) { @@ -756,10 +800,10 @@ func (dr delayedEOFReader) Read(p []byte) (n int, err error) { } func TestIssue10884_MaxBytesEOF(t *testing.T) { - dst := ioutil.Discard + dst := io.Discard _, err := io.Copy(dst, MaxBytesReader( responseWriterJustWriter{dst}, - ioutil.NopCloser(delayedEOFReader{strings.NewReader("12345")}), + io.NopCloser(delayedEOFReader{strings.NewReader("12345")}), 5)) if err != nil { t.Fatal(err) @@ -799,7 +843,7 @@ func TestMaxBytesReaderStickyError(t *testing.T) { 2: {101, 100}, } for i, tt := range tests { - rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit) + rc := MaxBytesReader(nil, io.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit) if err := isSticky(rc); err != nil { t.Errorf("%d. error: %v", i, err) } @@ -900,7 +944,7 @@ func TestNewRequestGetBody(t *testing.T) { t.Errorf("test[%d]: GetBody = nil", i) continue } - slurp1, err := ioutil.ReadAll(req.Body) + slurp1, err := io.ReadAll(req.Body) if err != nil { t.Errorf("test[%d]: ReadAll(Body) = %v", i, err) } @@ -908,7 +952,7 @@ func TestNewRequestGetBody(t *testing.T) { if err != nil { t.Errorf("test[%d]: GetBody = %v", i, err) } - slurp2, err := ioutil.ReadAll(newBody) + slurp2, err := io.ReadAll(newBody) if err != nil { t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err) } @@ -1119,7 +1163,7 @@ func BenchmarkFileAndServer_64MB(b *testing.B) { } func benchmarkFileAndServer(b *testing.B, n int64) { - f, err := ioutil.TempFile(os.TempDir(), "go-bench-http-file-and-server") + f, err := os.CreateTemp(os.TempDir(), "go-bench-http-file-and-server") if err != nil { b.Fatalf("Failed to create temp file: %v", err) } @@ -1145,7 +1189,7 @@ func benchmarkFileAndServer(b *testing.B, n int64) { func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { defer req.Body.Close() - nc, err := io.Copy(ioutil.Discard, req.Body) + nc, err := io.Copy(io.Discard, req.Body) if err != nil { panic(err) } @@ -1172,7 +1216,7 @@ func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int6 } b.StartTimer() - req, err := NewRequest("PUT", cst.URL, ioutil.NopCloser(f)) + req, err := NewRequest("PUT", cst.URL, io.NopCloser(f)) if err != nil { b.Fatal(err) } diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index b110b57..1157bdf 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -10,11 +10,11 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/url" "strings" "testing" + "testing/iotest" "time" ) @@ -228,7 +228,7 @@ var reqWriteTests = []reqWriteTest{ ContentLength: 0, // as if unset by user }, - Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, + Body: func() io.ReadCloser { return io.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + @@ -280,7 +280,7 @@ var reqWriteTests = []reqWriteTest{ ContentLength: 0, // as if unset by user }, - Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, + Body: func() io.ReadCloser { return io.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + @@ -349,8 +349,8 @@ var reqWriteTests = []reqWriteTest{ Body: func() io.ReadCloser { err := errors.New("Custom reader error") - errReader := &errorReader{err} - return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) + errReader := iotest.ErrReader(err) + return io.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) }, WantError: errors.New("Custom reader error"), @@ -369,8 +369,8 @@ var reqWriteTests = []reqWriteTest{ Body: func() io.ReadCloser { err := errors.New("Custom reader error") - errReader := &errorReader{err} - return ioutil.NopCloser(errReader) + errReader := iotest.ErrReader(err) + return io.NopCloser(errReader) }, WantError: errors.New("Custom reader error"), @@ -587,6 +587,26 @@ var reqWriteTests = []reqWriteTest{ }, WantError: errors.New("net/http: can't write control character in Request.URL"), }, + + 26: { // Request with nil body and PATCH method. Issue #40978 + Req: Request{ + Method: "PATCH", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + Body: nil, + WantWrite: "PATCH / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 0\r\n\r\n", + WantProxy: "PATCH / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 0\r\n\r\n", + }, } func TestRequestWrite(t *testing.T) { @@ -599,7 +619,7 @@ func TestRequestWrite(t *testing.T) { } switch b := tt.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) + tt.Req.Body = io.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: tt.Req.Body = b() } @@ -695,20 +715,20 @@ func TestRequestWriteTransport(t *testing.T) { }, { method: "GET", - body: ioutil.NopCloser(strings.NewReader("")), + body: io.NopCloser(strings.NewReader("")), want: noContentLengthOrTransferEncoding, }, { method: "GET", clen: -1, - body: ioutil.NopCloser(strings.NewReader("")), + body: io.NopCloser(strings.NewReader("")), want: noContentLengthOrTransferEncoding, }, // A GET with a body, with explicit content length: { method: "GET", clen: 7, - body: ioutil.NopCloser(strings.NewReader("foobody")), + body: io.NopCloser(strings.NewReader("foobody")), want: all(matchSubstr("Content-Length: 7"), matchSubstr("foobody")), }, @@ -716,7 +736,7 @@ func TestRequestWriteTransport(t *testing.T) { { method: "GET", clen: -1, - body: ioutil.NopCloser(strings.NewReader("foobody")), + body: io.NopCloser(strings.NewReader("foobody")), want: all(matchSubstr("Transfer-Encoding: chunked"), matchSubstr("\r\n1\r\nf\r\n"), matchSubstr("oobody")), @@ -726,14 +746,14 @@ func TestRequestWriteTransport(t *testing.T) { { method: "POST", clen: -1, - body: ioutil.NopCloser(strings.NewReader("foobody")), + body: io.NopCloser(strings.NewReader("foobody")), want: all(matchSubstr("Transfer-Encoding: chunked"), matchSubstr("foobody")), }, { method: "POST", clen: -1, - body: ioutil.NopCloser(strings.NewReader("")), + body: io.NopCloser(strings.NewReader("")), want: all(matchSubstr("Transfer-Encoding: chunked")), }, // Verify that a blocking Request.Body doesn't block forever. @@ -745,7 +765,7 @@ func TestRequestWriteTransport(t *testing.T) { tt.afterReqRead = func() { pw.Close() } - tt.body = ioutil.NopCloser(pr) + tt.body = io.NopCloser(pr) }, want: matchSubstr("Transfer-Encoding: chunked"), }, @@ -916,7 +936,7 @@ func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) { } // Ensure all the body is read; otherwise // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() } dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 72812f0..b8985da 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -352,10 +352,21 @@ func (r *Response) bodyIsWritable() bool { return ok } -// isProtocolSwitch reports whether r is a response to a successful -// protocol upgrade. +// isProtocolSwitch reports whether the response code and header +// indicate a successful protocol upgrade response. func (r *Response) isProtocolSwitch() bool { - return r.StatusCode == StatusSwitchingProtocols && - r.Header.Get("Upgrade") != "" && - httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade") + return isProtocolSwitchResponse(r.StatusCode, r.Header) +} + +// isProtocolSwitchResponse reports whether the response code and +// response header indicate a successful protocol upgrade response. +func isProtocolSwitchResponse(code int, h Header) bool { + return code == StatusSwitchingProtocols && isProtocolSwitchHeader(h) +} + +// isProtocolSwitchHeader reports whether the request or response header +// is for a protocol switch. +func isProtocolSwitchHeader(h Header) bool { + return h.Get("Upgrade") != "" && + httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") } diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index ce87260..8eef654 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -12,7 +12,6 @@ import ( "fmt" "go/token" "io" - "io/ioutil" "net/http/internal" "net/url" "reflect" @@ -620,7 +619,7 @@ func TestWriteResponse(t *testing.T) { t.Errorf("#%d: %v", i, err) continue } - err = resp.Write(ioutil.Discard) + err = resp.Write(io.Discard) if err != nil { t.Errorf("#%d: %v", i, err) continue @@ -722,7 +721,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { } resp.Body.Close() - rest, err := ioutil.ReadAll(bufr) + rest, err := io.ReadAll(bufr) checkErr(err, "ReadAll on remainder") if e, g := "Next Request Here", string(rest); e != g { g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string { diff --git a/libgo/go/net/http/responsewrite_test.go b/libgo/go/net/http/responsewrite_test.go index d41d898..1cc87b9 100644 --- a/libgo/go/net/http/responsewrite_test.go +++ b/libgo/go/net/http/responsewrite_test.go @@ -6,7 +6,7 @@ package http import ( "bytes" - "io/ioutil" + "io" "strings" "testing" ) @@ -26,7 +26,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, }, @@ -42,7 +42,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, }, "HTTP/1.0 200 OK\r\n" + @@ -57,7 +57,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, Close: true, }, @@ -74,7 +74,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, Close: false, }, @@ -92,7 +92,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, TransferEncoding: []string{"chunked"}, Close: false, @@ -125,7 +125,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), ContentLength: 0, Close: false, }, @@ -141,7 +141,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), ContentLength: 0, Close: false, }, @@ -157,7 +157,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, TransferEncoding: []string{"chunked"}, Close: true, @@ -218,7 +218,7 @@ func TestResponseWrite(t *testing.T) { Request: &Request{Method: "POST"}, Header: Header{}, ContentLength: -1, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), }, "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nabcdef", }, diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go index 509d229..c6a221a 100644 --- a/libgo/go/net/http/roundtrip_js.go +++ b/libgo/go/net/http/roundtrip_js.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "strconv" "syscall/js" ) @@ -92,15 +91,17 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API // and browser support. - body, err := ioutil.ReadAll(req.Body) + body, err := io.ReadAll(req.Body) if err != nil { req.Body.Close() // RoundTrip must always close the body, including on errors. return nil, err } req.Body.Close() - buf := uint8Array.New(len(body)) - js.CopyBytesToJS(buf, body) - opt.Set("body", buf) + if len(body) != 0 { + buf := uint8Array.New(len(body)) + js.CopyBytesToJS(buf, body) + opt.Set("body", buf) + } } fetchPromise := js.Global().Call("fetch", req.URL.String(), opt) diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 5f56932..95e6bf4 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -18,7 +18,6 @@ import ( "fmt" "internal/testenv" "io" - "io/ioutil" "log" "math/rand" "net" @@ -529,7 +528,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { if err != nil { continue } - slurp, _ := ioutil.ReadAll(res.Body) + slurp, _ := io.ReadAll(res.Body) res.Body.Close() if !tt.statusOk { if got, want := res.StatusCode, 404; got != want { @@ -689,7 +688,7 @@ func testServerTimeouts(timeout time.Duration) error { if err != nil { return fmt.Errorf("http Get #1: %v", err) } - got, err := ioutil.ReadAll(r.Body) + got, err := io.ReadAll(r.Body) expected := "req=1" if string(got) != expected || err != nil { return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil", @@ -721,7 +720,7 @@ func testServerTimeouts(timeout time.Duration) error { if err != nil { return fmt.Errorf("http Get #2: %v", err) } - got, err = ioutil.ReadAll(r.Body) + got, err = io.ReadAll(r.Body) r.Body.Close() expected = "req=2" if string(got) != expected || err != nil { @@ -734,7 +733,7 @@ func testServerTimeouts(timeout time.Duration) error { return fmt.Errorf("long Dial: %v", err) } defer conn.Close() - go io.Copy(ioutil.Discard, conn) + go io.Copy(io.Discard, conn) for i := 0; i < 5; i++ { _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) if err != nil { @@ -954,7 +953,7 @@ func TestOnlyWriteTimeout(t *testing.T) { errc <- err return } - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) res.Body.Close() errc <- err }() @@ -1058,7 +1057,7 @@ func TestIdentityResponse(t *testing.T) { } // The ReadAll will hang for a failing test. - got, _ := ioutil.ReadAll(conn) + got, _ := io.ReadAll(conn) expectedSuffix := "\r\n\r\ntoo short" if !strings.HasSuffix(string(got), expectedSuffix) { t.Errorf("Expected output to end with %q; got response body %q", @@ -1099,7 +1098,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { } }() - _, err = ioutil.ReadAll(r) + _, err = io.ReadAll(r) if err != nil { t.Fatal("read error:", err) } @@ -1129,7 +1128,7 @@ func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { if err != nil { t.Fatalf("res %d: %v", i+1, err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatalf("res %d body copy: %v", i+1, err) } res.Body.Close() @@ -1235,7 +1234,7 @@ func testSetsRemoteAddr(t *testing.T, h2 bool) { if err != nil { t.Fatalf("Get error: %v", err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("ReadAll error: %v", err) } @@ -1299,7 +1298,7 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { return } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("Request %d: %v", num, err) response <- "" @@ -1381,7 +1380,7 @@ func testHeadResponses(t *testing.T, h2 bool) { if v := res.ContentLength; v != 10 { t.Errorf("Content-Length: %d; want 10", v) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Error(err) } @@ -1432,7 +1431,7 @@ func TestTLSServer(t *testing.T) { } } })) - ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ErrorLog = log.New(io.Discard, "", 0) defer ts.Close() // Connect an idle TCP connection to this server before we run @@ -1540,7 +1539,7 @@ func TestTLSServerRejectHTTPRequests(t *testing.T) { } defer conn.Close() io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n") - slurp, err := ioutil.ReadAll(conn) + slurp, err := io.ReadAll(conn) if err != nil { t.Fatal(err) } @@ -1734,7 +1733,7 @@ func TestServerExpect(t *testing.T) { // requests that would read from r.Body, which we only // conditionally want to do. if strings.Contains(r.URL.RawQuery, "readbody=true") { - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) w.Write([]byte("Hi")) } else { w.WriteHeader(StatusUnauthorized) @@ -1773,7 +1772,7 @@ func TestServerExpect(t *testing.T) { io.Closer }{ conn, - ioutil.NopCloser(nil), + io.NopCloser(nil), } if test.chunked { targ = httputil.NewChunkedWriter(conn) @@ -2072,7 +2071,7 @@ type testHandlerBodyConsumer struct { var testHandlerBodyConsumers = []testHandlerBodyConsumer{ {"nil", func(io.ReadCloser) {}}, {"close", func(r io.ReadCloser) { r.Close() }}, - {"discard", func(r io.ReadCloser) { io.Copy(ioutil.Discard, r) }}, + {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }}, } func TestRequestBodyReadErrorClosesConnection(t *testing.T) { @@ -2298,7 +2297,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { if g, e := res.StatusCode, StatusOK; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) if g, e := string(body), "hi"; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -2315,7 +2314,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { if g, e := res.StatusCode, StatusServiceUnavailable; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ = ioutil.ReadAll(res.Body) + body, _ = io.ReadAll(res.Body) if !strings.Contains(string(body), "<title>Timeout</title>") { t.Errorf("expected timeout body; got %q", string(body)) } @@ -2367,7 +2366,7 @@ func TestTimeoutHandlerRace(t *testing.T) { defer func() { <-gate }() res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) if err == nil { - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) res.Body.Close() } }() @@ -2410,7 +2409,7 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { return } defer res.Body.Close() - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) }() } wg.Wait() @@ -2441,7 +2440,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { if g, e := res.StatusCode, StatusOK; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) if g, e := string(body), "hi"; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -2458,7 +2457,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { if g, e := res.StatusCode, StatusServiceUnavailable; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ = ioutil.ReadAll(res.Body) + body, _ = io.ReadAll(res.Body) if !strings.Contains(string(body), "<title>Timeout</title>") { t.Errorf("expected timeout body; got %q", string(body)) } @@ -2630,7 +2629,7 @@ func TestRedirectContentTypeAndBody(t *testing.T) { t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want) } resp := rec.Result() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } @@ -2657,7 +2656,7 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { - all, err := ioutil.ReadAll(r.Body) + all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) } @@ -2683,7 +2682,7 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } for i := range resp { - all, err := ioutil.ReadAll(resp[i].Body) + all, err := io.ReadAll(resp[i].Body) if err != nil { t.Fatalf("req #%d: client ReadAll: %v", i, err) } @@ -2710,7 +2709,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) { defer afterTest(t) - // Unlike the other tests that set the log output to ioutil.Discard + // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: // @@ -2849,29 +2848,47 @@ func TestStripPrefix(t *testing.T) { defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) + w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo", h)) + ts := httptest.NewServer(StripPrefix("/foo/bar", h)) defer ts.Close() c := ts.Client() - res, err := c.Get(ts.URL + "/foo/bar") - if err != nil { - t.Fatal(err) - } - if g, e := res.Header.Get("X-Path"), "/bar"; g != e { - t.Errorf("test 1: got %s, want %s", g, e) - } - res.Body.Close() - - res, err = Get(ts.URL + "/bar") - if err != nil { - t.Fatal(err) - } - if g, e := res.StatusCode, 404; g != e { - t.Errorf("test 2: got status %v, want %v", g, e) + cases := []struct { + reqPath string + path string // If empty we want a 404. + rawPath string + }{ + {"/foo/bar/qux", "/qux", ""}, + {"/foo/bar%2Fqux", "/qux", "%2Fqux"}, + {"/foo%2Fbar/qux", "", ""}, // Escaped prefix does not match. + {"/bar", "", ""}, // No prefix match. + } + for _, tc := range cases { + t.Run(tc.reqPath, func(t *testing.T) { + res, err := c.Get(ts.URL + tc.reqPath) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if tc.path == "" { + if res.StatusCode != StatusNotFound { + t.Errorf("got %q, want 404 Not Found", res.Status) + } + return + } + if res.StatusCode != StatusOK { + t.Fatalf("got %q, want 200 OK", res.Status) + } + if g, w := res.Header.Get("X-Path"), tc.path; g != w { + t.Errorf("got Path %q, want %q", g, w) + } + if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w { + t.Errorf("got RawPath %q, want %q", g, w) + } + }) } - res.Body.Close() } // https://golang.org/issue/18952. @@ -2952,7 +2969,7 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { const limit = 1 << 20 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) - n, err := io.Copy(ioutil.Discard, r.Body) + n, err := io.Copy(io.Discard, r.Body) if err == nil { t.Errorf("expected error from io.Copy") } @@ -3002,7 +3019,7 @@ func TestClientWriteShutdown(t *testing.T) { donec := make(chan bool) go func() { defer close(donec) - bs, err := ioutil.ReadAll(conn) + bs, err := io.ReadAll(conn) if err != nil { t.Errorf("ReadAll: %v", err) } @@ -3323,7 +3340,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { r.Body = nil // to test that server.go doesn't use this value. gone := w.(CloseNotifier).CloseNotify() - slurp, err := ioutil.ReadAll(reqBody) + slurp, err := io.ReadAll(reqBody) if err != nil { t.Errorf("Body read: %v", err) return @@ -3625,7 +3642,7 @@ func TestAcceptMaxFds(t *testing.T) { }}} server := &Server{ Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})), - ErrorLog: log.New(ioutil.Discard, "", 0), // noisy otherwise + ErrorLog: log.New(io.Discard, "", 0), // noisy otherwise } err := server.Serve(ln) if err != io.EOF { @@ -3764,7 +3781,7 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { close(done) }() time.Sleep(25 * time.Millisecond) // give Copy a chance to break things - n, err := io.Copy(ioutil.Discard, req.Body) + n, err := io.Copy(io.Discard, req.Body) if err != nil { t.Errorf("handler Copy: %v", err) return @@ -3786,7 +3803,7 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -3911,7 +3928,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { errorf("Proxy outbound request: %v", err) return } - _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/2) + _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2) if err != nil { errorf("Proxy copy error: %v", err) return @@ -4118,7 +4135,7 @@ func TestServerConnState(t *testing.T) { ts.Close() }() - ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ErrorLog = log.New(io.Discard, "", 0) ts.Config.ConnState = func(c net.Conn, state ConnState) { if c == nil { t.Errorf("nil conn seen in state %s", state) @@ -4158,7 +4175,7 @@ func TestServerConnState(t *testing.T) { t.Errorf("Error fetching %s: %v", url, err) return } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) defer res.Body.Close() if err != nil { t.Errorf("Error reading %s: %v", url, err) @@ -4215,7 +4232,7 @@ func TestServerConnState(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } c.Close() @@ -4257,11 +4274,17 @@ func testServerEmptyBodyRace(t *testing.T, h2 bool) { defer wg.Done() res, err := cst.c.Get(cst.ts.URL) if err != nil { - t.Error(err) - return + // Try to deflake spurious "connection reset by peer" under load. + // See golang.org/issue/22540. + time.Sleep(10 * time.Millisecond) + res, err = cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + return + } } defer res.Body.Close() - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) if err != nil { t.Error(err) return @@ -4287,7 +4310,7 @@ func TestServerConnStateNew(t *testing.T) { srv.Serve(&oneConnListener{ conn: &rwTestConn{ Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"), - Writer: ioutil.Discard, + Writer: io.Discard, }, }) if !sawNew { // testing that this read isn't racy @@ -4343,7 +4366,7 @@ func TestServerFlushAndHijack(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -4531,7 +4554,7 @@ Host: foo go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) { numReq++ if r.URL.Path == "/readbody" { - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) } io.WriteString(w, "Hello world!") })) @@ -4584,7 +4607,7 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -4604,7 +4627,7 @@ func TestServerValidatesHostHeader(t *testing.T) { host string want int }{ - {"HTTP/0.9", "", 400}, + {"HTTP/0.9", "", 505}, {"HTTP/1.1", "", 400}, {"HTTP/1.1", "Host: \r\n", 200}, @@ -4636,9 +4659,9 @@ func TestServerValidatesHostHeader(t *testing.T) { {"CONNECT golang.org:443 HTTP/1.1", "", 200}, // But not other HTTP/2 stuff: - {"PRI / HTTP/2.0", "", 400}, - {"GET / HTTP/2.0", "", 400}, - {"GET / HTTP/3.0", "", 400}, + {"PRI / HTTP/2.0", "", 505}, + {"GET / HTTP/2.0", "", 505}, + {"GET / HTTP/3.0", "", 505}, } for _, tt := range tests { conn := &testConn{closec: make(chan bool, 1)} @@ -4700,7 +4723,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { } defer c.Close() io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") - slurp, err := ioutil.ReadAll(c) + slurp, err := io.ReadAll(c) if err != nil { t.Fatal(err) } @@ -4934,7 +4957,7 @@ func BenchmarkClientServer(b *testing.B) { if err != nil { b.Fatal("Get:", err) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { b.Fatal("ReadAll:", err) @@ -4985,7 +5008,7 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { b.Logf("Get: %v", err) continue } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { b.Logf("ReadAll: %v", err) @@ -5020,7 +5043,7 @@ func BenchmarkServer(b *testing.B) { if err != nil { log.Panicf("Get: %v", err) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Panicf("ReadAll: %v", err) @@ -5143,7 +5166,7 @@ func BenchmarkClient(b *testing.B) { if err != nil { b.Fatalf("Get: %v", err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { b.Fatalf("ReadAll: %v", err) @@ -5233,7 +5256,7 @@ Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } handled := 0 @@ -5262,7 +5285,7 @@ Host: golang.org conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } handled := 0 @@ -5322,7 +5345,7 @@ Host: golang.org `) conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } handled := 0 @@ -5351,7 +5374,7 @@ Host: golang.org conn.Close() }) conn := &rwTestConn{ - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } ln := &oneConnListener{conn: conn} @@ -5414,7 +5437,7 @@ func TestServerIdleTimeout(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) })) ts.Config.ReadHeaderTimeout = 1 * time.Second @@ -5429,7 +5452,7 @@ func TestServerIdleTimeout(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -5454,7 +5477,7 @@ func TestServerIdleTimeout(t *testing.T) { defer conn.Close() conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n")) time.Sleep(2 * time.Second) - if _, err := io.CopyN(ioutil.Discard, conn, 1); err == nil { + if _, err := io.CopyN(io.Discard, conn, 1); err == nil { t.Fatal("copy byte succeeded; want err") } } @@ -5465,7 +5488,7 @@ func get(t *testing.T, c *Client, url string) string { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -5519,16 +5542,23 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) } -func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) } +func TestServerShutdown_h1(t *testing.T) { + testServerShutdown(t, h1Mode) +} +func TestServerShutdown_h2(t *testing.T) { + testServerShutdown(t, h2Mode) +} func testServerShutdown(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) var doShutdown func() // set later + var doStateCount func() var shutdownRes = make(chan error, 1) + var statesRes = make(chan map[ConnState]int, 1) var gotOnShutdown = make(chan struct{}, 1) handler := HandlerFunc(func(w ResponseWriter, r *Request) { + doStateCount() go doShutdown() // Shutdown is graceful, so it should not interrupt // this in-flight response. Add a tiny sleep here to @@ -5545,6 +5575,9 @@ func testServerShutdown(t *testing.T, h2 bool) { doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) } + doStateCount = func() { + statesRes <- cst.ts.Config.ExportAllConnsByState() + } get(t, cst.c, cst.ts.URL) // calls t.Fail on failure if err := <-shutdownRes; err != nil { @@ -5556,6 +5589,10 @@ func testServerShutdown(t *testing.T, h2 bool) { t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?") } + if states := <-statesRes; states[StateActive] != 1 { + t.Errorf("connection in wrong state, %v", states) + } + res, err := cst.c.Get(cst.ts.URL) if err == nil { res.Body.Close() @@ -5701,7 +5738,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { if err != nil { return fmt.Errorf("Get: %v", err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { return fmt.Errorf("Body ReadAll: %v", err) @@ -5764,7 +5801,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - io.Copy(ioutil.Discard, cn) + io.Copy(io.Discard, cn) }() for j := 0; j < requests; j++ { @@ -5864,7 +5901,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { return } defer conn.Close() - slurp, err := ioutil.ReadAll(buf.Reader) + slurp, err := io.ReadAll(buf.Reader) if err != nil { t.Errorf("Copy: %v", err) } @@ -6398,16 +6435,73 @@ func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) { if _, err := conn.Write(http1ReqBody); err != nil { return nil, err } - return ioutil.ReadAll(conn) + return io.ReadAll(conn) } func BenchmarkResponseStatusLine(b *testing.B) { b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { - bw := bufio.NewWriter(ioutil.Discard) + bw := bufio.NewWriter(io.Discard) var buf3 [3]byte for pb.Next() { Export_writeStatusLine(bw, true, 200, buf3[:]) } }) } +func TestDisableKeepAliveUpgrade(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + setParallel(t) + defer afterTest(t) + + s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "someProto") + w.WriteHeader(StatusSwitchingProtocols) + c, _, err := w.(Hijacker).Hijack() + if err != nil { + return + } + defer c.Close() + + io.Copy(c, c) + })) + s.Config.SetKeepAlivesEnabled(false) + s.Start() + defer s.Close() + + cl := s.Client() + cl.Transport.(*Transport).DisableKeepAlives = true + + resp, err := cl.Get(s.URL) + if err != nil { + t.Fatalf("failed to perform request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != StatusSwitchingProtocols { + t.Fatalf("unexpected status code: %v", resp.StatusCode) + } + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body) + } + + _, err = rwc.Write([]byte("hello")) + if err != nil { + t.Fatalf("failed to write to body: %v", err) + } + + b := make([]byte, 5) + _, err = io.ReadFull(rwc, b) + if err != nil { + t.Fatalf("failed to read from body: %v", err) + } + + if string(b) != "hello" { + t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello") + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 6f7a259..26c43c2 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -14,8 +14,8 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" + "math/rand" "net" "net/textproto" "net/url" @@ -324,7 +324,7 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err) } } - c.setState(rwc, StateHijacked) + c.setState(rwc, StateHijacked, runHooks) return } @@ -561,51 +561,53 @@ type writerOnly struct { io.Writer } -func srcIsRegularFile(src io.Reader) (isRegular bool, err error) { - switch v := src.(type) { - case *os.File: - fi, err := v.Stat() - if err != nil { - return false, err - } - return fi.Mode().IsRegular(), nil - case *io.LimitedReader: - return srcIsRegularFile(v.R) - default: - return - } -} - // ReadFrom is here to optimize copying from an *os.File regular file -// to a *net.TCPConn with sendfile. +// to a *net.TCPConn with sendfile, or from a supported src type such +// as a *net.TCPConn on Linux with splice. func (w *response) ReadFrom(src io.Reader) (n int64, err error) { + bufp := copyBufPool.Get().(*[]byte) + buf := *bufp + defer copyBufPool.Put(bufp) + // Our underlying w.conn.rwc is usually a *TCPConn (with its - // own ReadFrom method). If not, or if our src isn't a regular - // file, just fall back to the normal copy method. + // own ReadFrom method). If not, just fall back to the normal + // copy method. rf, ok := w.conn.rwc.(io.ReaderFrom) - regFile, err := srcIsRegularFile(src) - if err != nil { - return 0, err - } - if !ok || !regFile { - bufp := copyBufPool.Get().(*[]byte) - defer copyBufPool.Put(bufp) - return io.CopyBuffer(writerOnly{w}, src, *bufp) + if !ok { + return io.CopyBuffer(writerOnly{w}, src, buf) } // sendfile path: - if !w.wroteHeader { - w.WriteHeader(StatusOK) - } + // Do not start actually writing response until src is readable. + // If body length is <= sniffLen, sendfile/splice path will do + // little anyway. This small read also satisfies sniffing the + // body in case Content-Type is missing. + nr, er := src.Read(buf[:sniffLen]) + atEOF := errors.Is(er, io.EOF) + n += int64(nr) - if w.needsSniff() { - n0, err := io.Copy(writerOnly{w}, io.LimitReader(src, sniffLen)) - n += n0 - if err != nil { - return n, err + if nr > 0 { + // Write the small amount read normally. + nw, ew := w.Write(buf[:nr]) + if ew != nil { + err = ew + } else if nr != nw { + err = io.ErrShortWrite } } + if err == nil && er != nil && !atEOF { + err = er + } + + // Do not send StatusOK in the error case where nothing has been written. + if err == nil && !w.wroteHeader { + w.WriteHeader(StatusOK) // nr == 0, no error (or EOF) + } + + if err != nil || atEOF { + return n, err + } w.w.Flush() // get rid of any previous writes w.cw.flush() // make sure Header is written; flush data to rwc @@ -888,12 +890,12 @@ func (srv *Server) initialReadLimitSize() int64 { type expectContinueReader struct { resp *response readCloser io.ReadCloser - closed bool + closed atomicBool sawEOF atomicBool } func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { - if ecr.closed { + if ecr.closed.isSet() { return 0, ErrBodyReadAfterClose } w := ecr.resp @@ -915,7 +917,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { } func (ecr *expectContinueReader) Close() error { - ecr.closed = true + ecr.closed.setTrue() return ecr.readCloser.Close() } @@ -990,7 +992,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { } if !http1ServerSupportsRequest(req) { - return nil, badRequestError("unsupported protocol version") + return nil, statusError{StatusHTTPVersionNotSupported, "unsupported protocol version"} } c.lastMethod = req.Method @@ -1371,7 +1373,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } if discard { - _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1) + _, err := io.CopyN(io.Discard, w.reqBody, maxPostHandlerReadBytes+1) switch err { case nil: // There must be even more data left over. @@ -1471,7 +1473,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { return } - if w.closeAfterReply && (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) { + // Only override the Connection header if it is not a successful + // protocol switch response and if KeepAlives are not enabled. + // See https://golang.org/issue/36381. + delConnectionHeader := w.closeAfterReply && + (!keepAlivesEnabled || !hasToken(cw.header.get("Connection"), "close")) && + !isProtocolSwitchResponse(w.status, header) + if delConnectionHeader { delHeader("Connection") if w.req.ProtoAtLeast(1, 1) { setHeader.connection = "close" @@ -1742,7 +1750,12 @@ func validNextProto(proto string) bool { return true } -func (c *conn) setState(nc net.Conn, state ConnState) { +const ( + runHooks = true + skipHooks = false +) + +func (c *conn) setState(nc net.Conn, state ConnState, runHook bool) { srv := c.server switch state { case StateNew: @@ -1755,6 +1768,9 @@ func (c *conn) setState(nc net.Conn, state ConnState) { } packedState := uint64(time.Now().Unix()<<8) | uint64(state) atomic.StoreUint64(&c.curState.atomic, packedState) + if !runHook { + return + } if hook := srv.ConnState; hook != nil { hook(nc, state) } @@ -1768,9 +1784,16 @@ func (c *conn) getState() (state ConnState, unixSec int64) { // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should // be plain text without user info or other embedded errors. -type badRequestError string +func badRequestError(e string) error { return statusError{StatusBadRequest, e} } -func (e badRequestError) Error() string { return "Bad Request: " + string(e) } +// statusError is an error used to respond to a request with an HTTP status. +// The text should be plain text without user info or other embedded errors. +type statusError struct { + code int + text string +} + +func (e statusError) Error() string { return StatusText(e.code) + ": " + e.text } // ErrAbortHandler is a sentinel panic value to abort a handler. // While any panic from ServeHTTP aborts the response to the client, @@ -1808,7 +1831,7 @@ func (c *conn) serve(ctx context.Context) { } if !c.hijacked() { c.close() - c.setState(c.rwc, StateClosed) + c.setState(c.rwc, StateClosed, runHooks) } }() @@ -1819,7 +1842,7 @@ func (c *conn) serve(ctx context.Context) { if d := c.server.WriteTimeout; d != 0 { c.rwc.SetWriteDeadline(time.Now().Add(d)) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { // If the handshake failed due to the client not speaking // TLS, assume they're speaking plaintext HTTP and write a // 400 response on the TLS conn's underlying net.Conn. @@ -1836,6 +1859,10 @@ func (c *conn) serve(ctx context.Context) { if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) { if fn := c.server.TLSNextProto[proto]; fn != nil { h := initALPNRequest{ctx, tlsConn, serverHandler{c.server}} + // Mark freshly created HTTP/2 as active and prevent any server state hooks + // from being run on these connections. This prevents closeIdleConns from + // closing such connections. See issue https://golang.org/issue/39776. + c.setState(c.rwc, StateActive, skipHooks) fn(c.server, tlsConn, h) } return @@ -1856,7 +1883,7 @@ func (c *conn) serve(ctx context.Context) { w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. - c.setState(c.rwc, StateActive) + c.setState(c.rwc, StateActive, runHooks) } if err != nil { const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" @@ -1889,11 +1916,11 @@ func (c *conn) serve(ctx context.Context) { return // don't reply default: - publicErr := "400 Bad Request" - if v, ok := err.(badRequestError); ok { - publicErr = publicErr + ": " + string(v) + if v, ok := err.(statusError); ok { + fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text) + return } - + publicErr := "400 Bad Request" fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } @@ -1939,7 +1966,7 @@ func (c *conn) serve(ctx context.Context) { } return } - c.setState(c.rwc, StateIdle) + c.setState(c.rwc, StateIdle, runHooks) c.curReq.Store((*response)(nil)) if !w.conn.server.doKeepAlives() { @@ -2067,22 +2094,26 @@ func NotFound(w ResponseWriter, r *Request) { Error(w, "404 page not found", Sta // that replies to each request with a ``404 page not found'' reply. func NotFoundHandler() Handler { return HandlerFunc(NotFound) } -// StripPrefix returns a handler that serves HTTP requests -// by removing the given prefix from the request URL's Path -// and invoking the handler h. StripPrefix handles a -// request for a path that doesn't begin with prefix by -// replying with an HTTP 404 not found error. +// StripPrefix returns a handler that serves HTTP requests by removing the +// given prefix from the request URL's Path (and RawPath if set) and invoking +// the handler h. StripPrefix handles a request for a path that doesn't begin +// with prefix by replying with an HTTP 404 not found error. The prefix must +// match exactly: if the prefix in the request contains escaped characters +// the reply is also an HTTP 404 not found error. func StripPrefix(prefix string, h Handler) Handler { if prefix == "" { return h } return HandlerFunc(func(w ResponseWriter, r *Request) { - if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + p := strings.TrimPrefix(r.URL.Path, prefix) + rp := strings.TrimPrefix(r.URL.RawPath, prefix) + if len(p) < len(r.URL.Path) && (r.URL.RawPath == "" || len(rp) < len(r.URL.RawPath)) { r2 := new(Request) *r2 = *r r2.URL = new(url.URL) *r2.URL = *r.URL r2.URL.Path = p + r2.URL.RawPath = rp h.ServeHTTP(w, r2) } else { NotFound(w, r) @@ -2672,14 +2703,14 @@ func (srv *Server) Close() error { return err } -// shutdownPollInterval is how often we poll for quiescence -// during Server.Shutdown. This is lower during tests, to -// speed up tests. +// shutdownPollIntervalMax is the max polling interval when checking +// quiescence during Server.Shutdown. Polling starts with a small +// interval and backs off to the max. // Ideally we could find a solution that doesn't involve polling, // but which also doesn't have a high runtime cost (and doesn't // involve any contentious mutexes), but that is left as an // exercise for the reader. -var shutdownPollInterval = 500 * time.Millisecond +const shutdownPollIntervalMax = 500 * time.Millisecond // Shutdown gracefully shuts down the server without interrupting any // active connections. Shutdown works by first closing all open @@ -2712,8 +2743,20 @@ func (srv *Server) Shutdown(ctx context.Context) error { } srv.mu.Unlock() - ticker := time.NewTicker(shutdownPollInterval) - defer ticker.Stop() + pollIntervalBase := time.Millisecond + nextPollInterval := func() time.Duration { + // Add 10% jitter. + interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) + // Double and clamp for next time. + pollIntervalBase *= 2 + if pollIntervalBase > shutdownPollIntervalMax { + pollIntervalBase = shutdownPollIntervalMax + } + return interval + } + + timer := time.NewTimer(nextPollInterval()) + defer timer.Stop() for { if srv.closeIdleConns() && srv.numListeners() == 0 { return lnerr @@ -2721,7 +2764,8 @@ func (srv *Server) Shutdown(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case <-ticker.C: + case <-timer.C: + timer.Reset(nextPollInterval()) } } } @@ -2970,7 +3014,7 @@ func (srv *Server) Serve(l net.Listener) error { } tempDelay = 0 c := srv.newConn(rw) - c.setState(c.rwc, StateNew) // before Serve can return + c.setState(c.rwc, StateNew, runHooks) // before Serve can return go c.serve(connCtx) } } @@ -3387,7 +3431,7 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { // (or an attack) and we abort and close the connection, // courtesy of MaxBytesReader's EOF behavior. mb := MaxBytesReader(w, r.Body, 4<<10) - io.Copy(ioutil.Discard, mb) + io.Copy(io.Discard, mb) } } diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index a1157a0..8d53503 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -8,7 +8,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "log" . "net/http" "reflect" @@ -123,7 +122,7 @@ func testServerContentType(t *testing.T, h2 bool) { if ct := resp.Header.Get("Content-Type"); ct != wantContentType { t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, wantContentType) } - data, err := ioutil.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("%v: reading body: %v", tt.desc, err) } else if !bytes.Equal(data, tt.data) { @@ -185,7 +184,7 @@ func testContentTypeWithCopy(t *testing.T, h2 bool) { if ct := resp.Header.Get("Content-Type"); ct != expected { t.Errorf("Content-Type = %q, want %q", ct, expected) } - data, err := ioutil.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("reading body: %v", err) } else if !bytes.Equal(data, []byte(input)) { @@ -216,7 +215,7 @@ func testSniffWriteSize(t *testing.T, h2 bool) { if err != nil { t.Fatalf("size %d: %v", size, err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatalf("size %d: io.Copy of body = %v", size, err) } if err := res.Body.Close(); err != nil { diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 50d434b..fbb0c39 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http/httptrace" "net/http/internal" "net/textproto" @@ -156,7 +155,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { // servers. See Issue 18257, as one example. // // The only reason we'd send such a request is if the user set the Body to a -// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't +// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't // set ContentLength, or NewRequest set it to -1 (unknown), so then we assume // there's bytes to send. // @@ -258,7 +257,7 @@ func (t *transferWriter) shouldSendContentLength() bool { return false } // Many servers expect a Content-Length for these methods - if t.Method == "POST" || t.Method == "PUT" { + if t.Method == "POST" || t.Method == "PUT" || t.Method == "PATCH" { return true } if t.ContentLength == 0 && isIdentity(t.TransferEncoding) { @@ -330,9 +329,18 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) return nil } -func (t *transferWriter) writeBody(w io.Writer) error { - var err error +// always closes t.BodyCloser +func (t *transferWriter) writeBody(w io.Writer) (err error) { var ncopy int64 + closed := false + defer func() { + if closed || t.BodyCloser == nil { + return + } + if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() // Write body. We "unwrap" the body first if it was wrapped in a // nopCloser or readTrackingBody. This is to ensure that we can take advantage of @@ -361,7 +369,7 @@ func (t *transferWriter) writeBody(w io.Writer) error { return err } var nextra int64 - nextra, err = t.doBodyCopy(ioutil.Discard, body) + nextra, err = t.doBodyCopy(io.Discard, body) ncopy += nextra } if err != nil { @@ -369,6 +377,7 @@ func (t *transferWriter) writeBody(w io.Writer) error { } } if t.BodyCloser != nil { + closed = true if err := t.BodyCloser.Close(); err != nil { return err } @@ -982,7 +991,7 @@ func (b *body) Close() error { var n int64 // Consume the body, or, which will also lead to us reading // the trailer headers after the body, if present. - n, err = io.CopyN(ioutil.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) if err == io.EOF { err = nil } @@ -993,7 +1002,7 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(ioutil.Discard, bodyLocked{b}) + _, err = io.Copy(io.Discard, bodyLocked{b}) } b.closed = true return err @@ -1065,7 +1074,7 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { return } -var nopCloserType = reflect.TypeOf(ioutil.NopCloser(nil)) +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) // isKnownInMemoryReader reports whether r is a type known to not // block on Read. Its caller uses this as an optional optimization to diff --git a/libgo/go/net/http/transfer_test.go b/libgo/go/net/http/transfer_test.go index 185225f..f0c28b2 100644 --- a/libgo/go/net/http/transfer_test.go +++ b/libgo/go/net/http/transfer_test.go @@ -10,7 +10,6 @@ import ( "crypto/rand" "fmt" "io" - "io/ioutil" "os" "reflect" "strings" @@ -81,11 +80,11 @@ func TestDetectInMemoryReaders(t *testing.T) { {bytes.NewBuffer(nil), true}, {strings.NewReader(""), true}, - {ioutil.NopCloser(pr), false}, + {io.NopCloser(pr), false}, - {ioutil.NopCloser(bytes.NewReader(nil)), true}, - {ioutil.NopCloser(bytes.NewBuffer(nil)), true}, - {ioutil.NopCloser(strings.NewReader("")), true}, + {io.NopCloser(bytes.NewReader(nil)), true}, + {io.NopCloser(bytes.NewBuffer(nil)), true}, + {io.NopCloser(strings.NewReader("")), true}, } for i, tt := range tests { got := isKnownInMemoryReader(tt.r) @@ -104,12 +103,12 @@ var _ io.ReaderFrom = (*mockTransferWriter)(nil) func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { w.CalledReader = r - return io.Copy(ioutil.Discard, r) + return io.Copy(io.Discard, r) } func (w *mockTransferWriter) Write(p []byte) (int, error) { w.WriteCalled = true - return ioutil.Discard.Write(p) + return io.Discard.Write(p) } func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { @@ -118,7 +117,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { - f, err := ioutil.TempFile("", "net-http-newfilefunc") + f, err := os.CreateTemp("", "net-http-newfilefunc") if err != nil { return nil, nil, err } @@ -166,7 +165,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { method: "PUT", bodyFunc: func() (io.Reader, func(), error) { r, cleanup, err := newFileFunc() - return ioutil.NopCloser(r), cleanup, err + return io.NopCloser(r), cleanup, err }, contentLength: nBytes, limitedReader: true, @@ -206,7 +205,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { method: "PUT", bodyFunc: func() (io.Reader, func(), error) { r, cleanup, err := newBufferFunc() - return ioutil.NopCloser(r), cleanup, err + return io.NopCloser(r), cleanup, err }, contentLength: nBytes, limitedReader: true, diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index d37b52b..6358c38 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -44,7 +44,6 @@ var DefaultTransport RoundTripper = &Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -240,8 +239,18 @@ type Transport struct { // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. + // To set the header dynamically, see GetProxyConnectHeader. ProxyConnectHeader Header + // GetProxyConnectHeader optionally specifies a func to return + // headers to send to proxyURL during a CONNECT request to the + // ip:port target. + // If it returns an error, the Transport's RoundTrip fails with + // that error. It can return (nil, nil) to not add headers. + // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is + // ignored. + GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) + // MaxResponseHeaderBytes specifies a limit on how many // response bytes are allowed in the server's response // header. @@ -313,6 +322,7 @@ func (t *Transport) Clone() *Transport { ResponseHeaderTimeout: t.ResponseHeaderTimeout, ExpectContinueTimeout: t.ExpectContinueTimeout, ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + GetProxyConnectHeader: t.GetProxyConnectHeader, MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, ForceAttemptHTTP2: t.ForceAttemptHTTP2, WriteBufferSize: t.WriteBufferSize, @@ -385,7 +395,7 @@ func (t *Transport) onceSetNextProtoDefaults() { if omitBundledHTTP2 { return } - t2, err := http2configureTransport(t) + t2, err := http2configureTransports(t) if err != nil { log.Printf("Error enabling Transport HTTP/2 support: %v", err) return @@ -613,7 +623,8 @@ var errCannotRewind = errors.New("net/http: cannot rewind body after connection type readTrackingBody struct { io.ReadCloser - didRead bool + didRead bool + didClose bool } func (r *readTrackingBody) Read(data []byte) (int, error) { @@ -621,6 +632,11 @@ func (r *readTrackingBody) Read(data []byte) (int, error) { return r.ReadCloser.Read(data) } +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + // setupRewindBody returns a new request with a custom body wrapper // that can report whether the body needs rewinding. // This lets rewindBody avoid an error result when the request @@ -639,10 +655,12 @@ func setupRewindBody(req *Request) *Request { // rewindBody takes care of closing req.Body when appropriate // (in all cases except when rewindBody returns req unmodified). func rewindBody(req *Request) (rewound *Request, err error) { - if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { return req, nil // nothing to rewind } - req.closeBody() + if !req.Body.(*readTrackingBody).didClose { + req.closeBody() + } if req.GetBody == nil { return nil, errCannotRewind } @@ -766,7 +784,8 @@ func (t *Transport) CancelRequest(req *Request) { } // Cancel an in-flight request, recording the error value. -func (t *Transport) cancelRequest(key cancelKey, err error) { +// Returns whether the request was canceled. +func (t *Transport) cancelRequest(key cancelKey, err error) bool { t.reqMu.Lock() cancel := t.reqCanceler[key] delete(t.reqCanceler, key) @@ -774,6 +793,8 @@ func (t *Transport) cancelRequest(key cancelKey, err error) { if cancel != nil { cancel(err) } + + return cancel != nil } // @@ -1484,7 +1505,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { // Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS // tunnel, this function establishes a nested TLS session inside the encrypted channel. // The remote endpoint's name may be overridden by TLSClientConfig.ServerName. -func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error { +func (pconn *persistConn) addTLS(ctx context.Context, name string, trace *httptrace.ClientTrace) error { // Initiate TLS and check remote host name against certificate. cfg := cloneTLSConfig(pconn.t.TLSClientConfig) if cfg.ServerName == "" { @@ -1506,7 +1527,7 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - err := tlsConn.Handshake() + err := tlsConn.HandshakeContext(ctx) if timer != nil { timer.Stop() } @@ -1528,6 +1549,10 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro return nil } +type erringRoundTripper interface { + RoundTripErr() error +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, @@ -1558,7 +1583,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if trace != nil && trace.TLSHandshakeStart != nil { trace.TLSHandshakeStart() } - if err := tc.Handshake(); err != nil { + if err := tc.HandshakeContext(ctx); err != nil { go pconn.conn.Close() if trace != nil && trace.TLSHandshakeDone != nil { trace.TLSHandshakeDone(tls.ConnectionState{}, err) @@ -1582,7 +1607,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil { return nil, wrapErr(err) } - if err = pconn.addTLS(firstTLSHost, trace); err != nil { + if err = pconn.addTLS(ctx, firstTLSHost, trace); err != nil { return nil, wrapErr(err) } } @@ -1619,7 +1644,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } case cm.targetScheme == "https": conn := pconn.conn - hdr := t.ProxyConnectHeader + var hdr Header + if t.GetProxyConnectHeader != nil { + var err error + hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr) + if err != nil { + conn.Close() + return nil, err + } + } else { + hdr = t.ProxyConnectHeader + } if hdr == nil { hdr = make(Header) } @@ -1686,7 +1721,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } if cm.proxyURL != nil && cm.targetScheme == "https" { - if err := pconn.addTLS(cm.tlsHost(), trace); err != nil { + if err := pconn.addTLS(ctx, cm.tlsHost(), trace); err != nil { return nil, err } } @@ -1694,9 +1729,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { alt := next(cm.targetAddr, pconn.conn.(*tls.Conn)) - if e, ok := alt.(http2erringRoundTripper); ok { - // pconn.conn was closed by next (http2configureTransport.upgradeFn). - return nil, e.err + if e, ok := alt.(erringRoundTripper); ok { + // pconn.conn was closed by next (http2configureTransports.upgradeFn). + return nil, e.RoundTripErr() } return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt}, nil } @@ -1963,6 +1998,15 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return nil } + // Wait for the writeLoop goroutine to terminate to avoid data + // races on callers who mutate the request on failure. + // + // When resc in pc.roundTrip and hence rc.ch receives a responseAndError + // with a non-nil error it implies that the persistConn is either closed + // or closing. Waiting on pc.writeLoopDone is hence safe as all callers + // close closech which in turn ensures writeLoop returns. + <-pc.writeLoopDone + // If the request was canceled, that's better than network // failures that were likely the result of tearing down the // connection. @@ -1988,7 +2032,6 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte return err } if pc.isBroken() { - <-pc.writeLoopDone if pc.nwrite == startBytesWritten { return nothingWrittenError{err} } @@ -2087,18 +2130,17 @@ func (pc *persistConn) readLoop() { } if !hasBody || bodyWritable { - pc.t.setReqCanceler(rc.cancelKey, nil) + replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // Put the idle conn back into the pool before we send the response // so if they process it quickly and make another request, they'll // get this same conn. But we use the unbuffered channel 'rc' // to guarantee that persistConn.roundTrip got out of its select // potentially waiting for this persistConn to close. - // but after alive = alive && !pc.sawEOF && pc.wroteRequest() && - tryPutIdleConn(trace) + replaced && tryPutIdleConn(trace) if bodyWritable { closeErr = errCallerOwnsConn @@ -2160,12 +2202,12 @@ func (pc *persistConn) readLoop() { // reading the response body. (or for cancellation or death) select { case bodyEOF := <-waitForBodyRead: - pc.t.setReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool + replaced := pc.t.replaceReqCanceler(rc.cancelKey, nil) // before pc might return to idle pool alive = alive && bodyEOF && !pc.sawEOF && pc.wroteRequest() && - tryPutIdleConn(trace) + replaced && tryPutIdleConn(trace) if bodyEOF { eofc <- struct{}{} } @@ -2347,7 +2389,7 @@ func (pc *persistConn) writeLoop() { // Request.Body are high priority. // Set it here before sending on the // channels below or calling - // pc.close() which tears town + // pc.close() which tears down // connections and causes other // errors. wr.req.setError(err) @@ -2356,7 +2398,6 @@ func (pc *persistConn) writeLoop() { err = pc.bw.Flush() } if err != nil { - wr.req.Request.closeBody() if pc.nwrite == startBytesWritten { err = nothingWrittenError{err} } @@ -2525,7 +2566,9 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err continueCh = make(chan struct{}, 1) } - if pc.t.DisableKeepAlives && !req.wantsClose() { + if pc.t.DisableKeepAlives && + !req.wantsClose() && + !isProtocolSwitchHeader(req.Header) { req.extraHeaders().Set("Connection", "close") } @@ -2560,6 +2603,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err var respHeaderTimer <-chan time.Time cancelChan := req.Request.Cancel ctxDoneChan := req.Context().Done() + pcClosed := pc.closech + canceled := false for { testHookWaitResLoop() select { @@ -2579,11 +2624,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err defer timer.Stop() // prevent leaks respHeaderTimer = timer.C } - case <-pc.closech: - if debugRoundTrip { - req.logf("closech recv: %T %#v", pc.closed, pc.closed) + case <-pcClosed: + pcClosed = nil + if canceled || pc.t.replaceReqCanceler(req.cancelKey, nil) { + if debugRoundTrip { + req.logf("closech recv: %T %#v", pc.closed, pc.closed) + } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) } - return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) case <-respHeaderTimer: if debugRoundTrip { req.logf("timeout waiting for response headers.") @@ -2602,10 +2650,10 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err } return re.res, nil case <-cancelChan: - pc.t.cancelRequest(req.cancelKey, errRequestCanceled) + canceled = pc.t.cancelRequest(req.cancelKey, errRequestCanceled) cancelChan = nil case <-ctxDoneChan: - pc.t.cancelRequest(req.cancelKey, req.Context().Err()) + canceled = pc.t.cancelRequest(req.cancelKey, req.Context().Err()) cancelChan = nil ctxDoneChan = nil } diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go index 92729e6..1097ffd 100644 --- a/libgo/go/net/http/transport_internal_test.go +++ b/libgo/go/net/http/transport_internal_test.go @@ -11,7 +11,6 @@ import ( "crypto/tls" "errors" "io" - "io/ioutil" "net" "net/http/internal" "strings" @@ -226,7 +225,7 @@ func TestTransportBodyAltRewind(t *testing.T) { TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { return roundTripFunc(func(r *Request) (*Response, error) { - n, _ := io.Copy(ioutil.Discard, r.Body) + n, _ := io.Copy(io.Discard, r.Body) if n == 0 { t.Error("body length is zero") } diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 5c5ae3f..28fc4ed 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -23,8 +23,8 @@ import ( "go/token" "internal/nettrace" "io" - "io/ioutil" "log" + mrand "math/rand" "net" . "net/http" "net/http/httptest" @@ -41,6 +41,7 @@ import ( "sync" "sync/atomic" "testing" + "testing/iotest" "time" "golang.org/x/net/http/httpguts" @@ -171,7 +172,7 @@ func TestTransportKeepAlives(t *testing.T) { if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) } @@ -218,7 +219,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -271,7 +272,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v", connectionClose, got, !connectionClose) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -380,7 +381,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { if err != nil { t.Error(err) } - ioutil.ReadAll(resp.Body) + io.ReadAll(resp.Body) keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { @@ -410,7 +411,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) { w.WriteHeader(200) w.(Flusher).Flush() } else { - w.Header().Set("Content-Type", strconv.Itoa(len(msg))) + w.Header().Set("Content-Length", strconv.Itoa(len(msg))) w.WriteHeader(200) } w.Write([]byte(msg)) @@ -493,7 +494,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Error(err) return } - if _, err := ioutil.ReadAll(resp.Body); err != nil { + if _, err := io.ReadAll(resp.Body); err != nil { t.Errorf("ReadAll: %v", err) return } @@ -573,7 +574,7 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } @@ -653,7 +654,7 @@ func TestTransportMaxConnsPerHost(t *testing.T) { t.Fatalf("request failed: %v", err) } defer resp.Body.Close() - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("read body failed: %v", err) } @@ -731,7 +732,7 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { t.Fatalf("%s: %v", name, res.Status) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: %v", name, err) } @@ -781,7 +782,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { condFatalf("error in req #%d, GET: %v", n, err) continue } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { condFatalf("error in req #%d, ReadAll: %v", n, err) continue @@ -901,7 +902,7 @@ func TestTransportHeadResponses(t *testing.T) { if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } - if all, err := ioutil.ReadAll(res.Body); err != nil { + if all, err := io.ReadAll(res.Body); err != nil { t.Errorf("loop %d: Body ReadAll: %v", i, err) } else if len(all) != 0 { t.Errorf("Bogus body %q", all) @@ -1004,10 +1005,10 @@ func TestRoundTripGzip(t *testing.T) { t.Errorf("%d. gzip NewReader: %v", i, err) continue } - body, err = ioutil.ReadAll(r) + body, err = io.ReadAll(r) res.Body.Close() } else { - body, err = ioutil.ReadAll(res.Body) + body, err = io.ReadAll(res.Body) } if err != nil { t.Errorf("%d. Error: %q", i, err) @@ -1088,7 +1089,7 @@ func TestTransportGzip(t *testing.T) { if err != nil { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1131,7 +1132,7 @@ func TestTransportExpect100Continue(t *testing.T) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. - if _, err := io.Copy(ioutil.Discard, req.Body); err != nil { + if _, err := io.Copy(io.Discard, req.Body); err != nil { t.Error("Failed to read Body", err) } rw.WriteHeader(StatusOK) @@ -1157,7 +1158,7 @@ func TestTransportExpect100Continue(t *testing.T) { if err != nil { log.Fatal(err) } - if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil { + if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { t.Error("Failed to read Body", err) } bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") @@ -1623,7 +1624,7 @@ func TestTransportGzipRecursive(t *testing.T) { if err != nil { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1652,7 +1653,7 @@ func TestTransportGzipShort(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) if err == nil { t.Fatal("Expect an error from reading a body.") } @@ -1699,7 +1700,7 @@ func TestTransportPersistConnLeak(t *testing.T) { res, err := c.Get(ts.URL) didReqCh <- true if err != nil { - t.Errorf("client fetch error: %v", err) + t.Logf("client fetch error: %v", err) failed <- true return } @@ -1713,17 +1714,15 @@ func TestTransportPersistConnLeak(t *testing.T) { case <-gotReqCh: // ok case <-failed: - close(unblockCh) - return + // Not great but not what we are testing: + // sometimes an overloaded system will fail to make all the connections. } } nhigh := runtime.NumGoroutine() // Tell all handlers to unblock and reply. - for i := 0; i < numReq; i++ { - unblockCh <- true - } + close(unblockCh) // Wait for all HTTP clients to be done. for i := 0; i < numReq; i++ { @@ -2003,7 +2002,7 @@ func TestIssue3644(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - bs, err := ioutil.ReadAll(res.Body) + bs, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -2028,7 +2027,7 @@ func TestIssue3595(t *testing.T) { t.Errorf("Post: %v", err) return } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("Body ReadAll: %v", err) } @@ -2100,7 +2099,7 @@ func TestTransportConcurrency(t *testing.T) { wg.Done() continue } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { t.Errorf("read error on req %s: %v", req, err) wg.Done() @@ -2167,7 +2166,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { t.Errorf("Error issuing GET: %v", err) break } - _, err = io.Copy(ioutil.Discard, sres.Body) + _, err = io.Copy(io.Discard, sres.Body) if err == nil { t.Errorf("Unexpected successful copy") break @@ -2188,7 +2187,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { }) mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { defer r.Body.Close() - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) }) ts := httptest.NewServer(mux) timeout := 100 * time.Millisecond @@ -2342,7 +2341,7 @@ func TestTransportCancelRequest(t *testing.T) { tr.CancelRequest(req) }() t0 := time.Now() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) d := time.Since(t0) if err != ExportErrRequestCanceled { @@ -2501,7 +2500,7 @@ func TestCancelRequestWithChannel(t *testing.T) { close(ch) }() t0 := time.Now() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) d := time.Since(t0) if err != ExportErrRequestCanceled { @@ -2682,7 +2681,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { Status: "200 OK", StatusCode: 200, Header: make(Header), - Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), } return res, nil } @@ -2696,7 +2695,7 @@ func TestTransportAltProto(t *testing.T) { if err != nil { t.Fatal(err) } - bodyb, err := ioutil.ReadAll(res.Body) + bodyb, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -2773,7 +2772,7 @@ func TestTransportSocketLateBinding(t *testing.T) { // let the foo response finish so we can use its // connection for /bar fooGate <- true - io.Copy(ioutil.Discard, fooRes.Body) + io.Copy(io.Discard, fooRes.Body) fooRes.Body.Close() }) @@ -2812,7 +2811,7 @@ func TestTransportReading100Continue(t *testing.T) { t.Error(err) return } - slurp, err := ioutil.ReadAll(req.Body) + slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("Server request body slurp: %v", err) return @@ -2876,7 +2875,7 @@ Content-Length: %d if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { t.Errorf("%s: response id %q != request id %q", name, idBack, id) } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: Slurp error: %v", name, err) } @@ -3155,7 +3154,7 @@ func TestIdleConnChannelLeak(t *testing.T) { func TestTransportClosesRequestBody(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) })) defer ts.Close() @@ -3262,7 +3261,7 @@ func TestTLSServerClosesConnection(t *testing.T) { t.Fatal(err) } <-closedc - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -3277,7 +3276,7 @@ func TestTLSServerClosesConnection(t *testing.T) { errs = append(errs, err) continue } - slurp, err = ioutil.ReadAll(res.Body) + slurp, err = io.ReadAll(res.Body) if err != nil { errs = append(errs, err) continue @@ -3348,7 +3347,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { sconn.c = conn sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive - go io.Copy(ioutil.Discard, conn) + go io.Copy(io.Discard, conn) })) defer ts.Close() c := ts.Client() @@ -3412,12 +3411,6 @@ func TestTransportIssue10457(t *testing.T) { } } -type errorReader struct { - err error -} - -func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } - type closerFunc func() error func (f closerFunc) Close() error { return f() } @@ -3603,7 +3596,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { defer afterTest(t) readBody := make(chan error, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - _, err := ioutil.ReadAll(r.Body) + _, err := io.ReadAll(r.Body) readBody <- err })) defer ts.Close() @@ -3614,7 +3607,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { io.Reader io.Closer }{ - io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), iotest.ErrReader(fakeErr)), closerFunc(func() error { select { case didClose <- true: @@ -3745,7 +3738,7 @@ func TestTransportDialTLSContext(t *testing.T) { if err != nil { return nil, err } - return c, c.Handshake() + return c, c.HandshakeContext(ctx) } req, err := NewRequest("GET", ts.URL, nil) @@ -3951,7 +3944,7 @@ func TestTransportResponseCancelRace(t *testing.T) { // If we do an early close, Transport just throws the connection away and // doesn't reuse it. In order to trigger the bug, it has to reuse the connection // so read the body - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } @@ -3988,7 +3981,7 @@ func TestTransportContentEncodingCaseInsensitive(t *testing.T) { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) @@ -4095,7 +4088,7 @@ func TestTransportFlushesBodyChunks(t *testing.T) { if err != nil { t.Fatal(err) } - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) // Unblock the transport's roundTrip goroutine. resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") @@ -4476,7 +4469,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { // Do nothing for the second request. return } - if _, err := ioutil.ReadAll(r.Body); err != nil { + if _, err := io.ReadAll(r.Body); err != nil { t.Error(err) } if !noHooks { @@ -4564,7 +4557,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { t.Fatal(err) } logf("got roundtrip.response") - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -5182,6 +5175,57 @@ func TestTransportProxyConnectHeader(t *testing.T) { } } +func TestTransportProxyGetConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *Request, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + // These should be ignored: + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { + return Header{ + "User-Agent": {"foo2"}, + "Other": {"bar2"}, + }, nil + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar2"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() @@ -5195,7 +5239,7 @@ func wantBody(res *Response, err error, want string) error { if err != nil { return err } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("error reading body: %v", err) } @@ -5294,7 +5338,7 @@ func TestMissingStatusNoPanic(t *testing.T) { conn, _ := ln.Accept() if conn != nil { io.WriteString(conn, raw) - ioutil.ReadAll(conn) + io.ReadAll(conn) conn.Close() } }() @@ -5312,7 +5356,7 @@ func TestMissingStatusNoPanic(t *testing.T) { t.Error("panicked, expecting an error") } if res != nil && res.Body != nil { - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) res.Body.Close() } @@ -5498,7 +5542,7 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } close(cancel) - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) if err == nil { t.Fatalf("unexpected success; read %q, nil", got) } @@ -5637,7 +5681,7 @@ func TestTransportCONNECTBidi(t *testing.T) { } func TestTransportRequestReplayable(t *testing.T) { - someBody := ioutil.NopCloser(strings.NewReader("")) + someBody := io.NopCloser(strings.NewReader("")) tests := []struct { name string req *Request @@ -5705,7 +5749,7 @@ func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { func TestTransportRequestWriteRoundTrip(t *testing.T) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { - f, err := ioutil.TempFile("", "net-http-newfilefunc") + f, err := os.CreateTemp("", "net-http-newfilefunc") if err != nil { return nil, nil, err } @@ -5798,7 +5842,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) r.Body.Close() w.WriteHeader(200) }), @@ -5850,6 +5894,7 @@ func TestTransportClone(t *testing.T) { ResponseHeaderTimeout: time.Second, ExpectContinueTimeout: time.Second, ProxyConnectHeader: Header{}, + GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, ForceAttemptHTTP2: true, TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ @@ -5933,7 +5978,7 @@ func TestTransportIgnores408(t *testing.T) { if err != nil { t.Fatal(err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -6195,7 +6240,7 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { return } defer resp.Body.Close() - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) if err != nil { errCh <- fmt.Errorf("read body failed: %v", err) } @@ -6257,7 +6302,7 @@ func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } func TestIssue32441(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } })) @@ -6265,7 +6310,7 @@ func TestIssue32441(t *testing.T) { c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. - if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero during round trip") } return nil, ErrSkipAltProtocol @@ -6293,3 +6338,152 @@ func TestTransportRejectsSignInContentLength(t *testing.T) { t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want) } } + +// dumpConn is a net.Conn which writes to Writer and reads from Reader +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + var ok bool + if r.r, ok = <-r.c; !ok { + return 0, errors.New("delegate closed") + } + } + return r.r.Read(p) +} + +func testTransportRace(req *Request) { + save := req.Body + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{pw, dr}, nil + }, + } + defer t.CloseIdleConnections() + + quitReadCh := make(chan struct{}) + // Wait for the request before replying with a dummy response: + go func() { + defer close(quitReadCh) + + req, err := ReadRequest(bufio.NewReader(pr)) + if err == nil { + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(io.Discard, req.Body) + req.Body.Close() + } + select { + case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"): + case quitReadCh <- struct{}{}: + // Ensure delegate is closed so Read doesn't block forever. + close(dr.c) + } + }() + + t.RoundTrip(req) + + // Ensure the reader returns before we reset req.Body to prevent + // a data race on req.Body. + pw.Close() + <-quitReadCh + + req.Body = save +} + +// Issue 37669 +// Test that a cancellation doesn't result in a data race due to the writeLoop +// goroutine being left running, if the caller mutates the processed Request +// upon completion. +func TestErrorWriteLoopRace(t *testing.T) { + if testing.Short() { + return + } + t.Parallel() + for i := 0; i < 1000; i++ { + delay := time.Duration(mrand.Intn(5)) * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), delay) + defer cancel() + + r := bytes.NewBuffer(make([]byte, 10000)) + req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r) + if err != nil { + t.Fatal(err) + } + + testTransportRace(req) + } +} + +// Issue 41600 +// Test that a new request which uses the connection of an active request +// cannot cause it to be canceled as well. +func TestCancelRequestWhenSharingConnection(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) { + w.Header().Add("Content-Length", "0") + })) + defer ts.Close() + + client := ts.Client() + transport := client.Transport.(*Transport) + transport.MaxIdleConns = 1 + transport.MaxConnsPerHost = 1 + + var wg sync.WaitGroup + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for ctx.Err() == nil { + reqctx, reqcancel := context.WithCancel(ctx) + go reqcancel() + req, _ := NewRequestWithContext(reqctx, "GET", ts.URL, nil) + res, err := client.Do(req) + if err == nil { + res.Body.Close() + } + } + }() + } + + for ctx.Err() == nil { + req, _ := NewRequest("GET", ts.URL, nil) + if res, err := client.Do(req); err != nil { + t.Errorf("unexpected: %p %v", req, err) + break + } else { + res.Body.Close() + } + } + + cancel() + wg.Wait() +} diff --git a/libgo/go/net/interface_solaris.go b/libgo/go/net/interface_solaris.go index 5f9367f..f8d1571 100644 --- a/libgo/go/net/interface_solaris.go +++ b/libgo/go/net/interface_solaris.go @@ -32,39 +32,21 @@ func interfaceTable(ifindex int) ([]Interface, error) { return ift, nil } -const ( - sysIFF_UP = 0x1 - sysIFF_BROADCAST = 0x2 - sysIFF_DEBUG = 0x4 - sysIFF_LOOPBACK = 0x8 - sysIFF_POINTOPOINT = 0x10 - sysIFF_NOTRAILERS = 0x20 - sysIFF_RUNNING = 0x40 - sysIFF_NOARP = 0x80 - sysIFF_PROMISC = 0x100 - sysIFF_ALLMULTI = 0x200 - sysIFF_INTELLIGENT = 0x400 - sysIFF_MULTICAST = 0x800 - sysIFF_MULTI_BCAST = 0x1000 - sysIFF_UNNUMBERED = 0x2000 - sysIFF_PRIVATE = 0x8000 -) - func linkFlags(rawFlags int) Flags { var f Flags - if rawFlags&sysIFF_UP != 0 { + if rawFlags&syscall.IFF_UP != 0 { f |= FlagUp } - if rawFlags&sysIFF_BROADCAST != 0 { + if rawFlags&syscall.IFF_BROADCAST != 0 { f |= FlagBroadcast } - if rawFlags&sysIFF_LOOPBACK != 0 { + if rawFlags&syscall.IFF_LOOPBACK != 0 { f |= FlagLoopback } - if rawFlags&sysIFF_POINTOPOINT != 0 { + if rawFlags&syscall.IFF_POINTOPOINT != 0 { f |= FlagPointToPoint } - if rawFlags&sysIFF_MULTICAST != 0 { + if rawFlags&syscall.IFF_MULTICAST != 0 { f |= FlagMulticast } return f diff --git a/libgo/go/net/interface_unix_test.go b/libgo/go/net/interface_unix_test.go index 6a2b7f1..bf41a0fb 100644 --- a/libgo/go/net/interface_unix_test.go +++ b/libgo/go/net/interface_unix_test.go @@ -46,7 +46,7 @@ func TestPointToPointInterface(t *testing.T) { if testing.Short() { t.Skip("avoid external network") } - if runtime.GOOS == "darwin" { + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { t.Skipf("not supported on %s", runtime.GOOS) } if os.Getuid() != 0 { diff --git a/libgo/go/net/internal/socktest/sys_cloexec.go b/libgo/go/net/internal/socktest/sys_cloexec.go index 7b9b8df..b13ba57 100644 --- a/libgo/go/net/internal/socktest/sys_cloexec.go +++ b/libgo/go/net/internal/socktest/sys_cloexec.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build dragonfly freebsd hurd linux netbsd openbsd +// +build dragonfly freebsd hurd illumos linux netbsd openbsd package socktest diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go index 2308236..7a4b7a6 100644 --- a/libgo/go/net/ipsock_plan9.go +++ b/libgo/go/net/ipsock_plan9.go @@ -7,6 +7,7 @@ package net import ( "context" "internal/bytealg" + "io/fs" "os" "syscall" ) @@ -164,7 +165,7 @@ func fixErr(err error) { if nonNilInterface(oe.Addr) { oe.Addr = nil } - if pe, ok := oe.Err.(*os.PathError); ok { + if pe, ok := oe.Err.(*fs.PathError); ok { if _, ok = pe.Err.(syscall.ErrorString); ok { oe.Err = pe.Err } diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go index 68bffca..32a0d37 100644 --- a/libgo/go/net/lookup_test.go +++ b/libgo/go/net/lookup_test.go @@ -511,7 +511,7 @@ func TestDNSFlood(t *testing.T) { defer dnsWaitGroup.Wait() var N = 5000 - if runtime.GOOS == "darwin" { + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { // On Darwin this test consumes kernel threads much // than other platforms for some reason. // When we monitor the number of allocated Ms by @@ -628,7 +628,7 @@ func TestLookupDotsWithLocalSource(t *testing.T) { } func TestLookupDotsWithRemoteSource(t *testing.T) { - if runtime.GOOS == "darwin" { + if runtime.GOOS == "darwin" || runtime.GOOS == "ios" { testenv.SkipFlaky(t, 27992) } mustHaveExternalNetwork(t) diff --git a/libgo/go/net/mail/example_test.go b/libgo/go/net/mail/example_test.go index c336564..d325dc7 100644 --- a/libgo/go/net/mail/example_test.go +++ b/libgo/go/net/mail/example_test.go @@ -6,7 +6,7 @@ package mail_test import ( "fmt" - "io/ioutil" + "io" "log" "net/mail" "strings" @@ -62,7 +62,7 @@ Message body fmt.Println("To:", header.Get("To")) fmt.Println("Subject:", header.Get("Subject")) - body, err := ioutil.ReadAll(m.Body) + body, err := io.ReadAll(m.Body) if err != nil { log.Fatal(err) } diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 09fb794..47bbf6c 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -112,11 +112,25 @@ func ParseDate(date string) (time.Time, error) { if ind := strings.IndexAny(p.s, "+-"); ind != -1 && len(p.s) >= ind+5 { date = p.s[:ind+5] p.s = p.s[ind+5:] - } else if ind := strings.Index(p.s, "T"); ind != -1 && len(p.s) >= ind+1 { - // The last letter T of the obsolete time zone is checked when no standard time zone is found. - // If T is misplaced, the date to parse is garbage. - date = p.s[:ind+1] - p.s = p.s[ind+1:] + } else { + ind := strings.Index(p.s, "T") + if ind == 0 { + // In this case we have the following date formats: + // * Thu, 20 Nov 1997 09:55:06 MDT + // * Thu, 20 Nov 1997 09:55:06 MDT (MDT) + // * Thu, 20 Nov 1997 09:55:06 MDT (This comment) + ind = strings.Index(p.s[1:], "T") + if ind != -1 { + ind++ + } + } + + if ind != -1 && len(p.s) >= ind+5 { + // The last letter T of the obsolete time zone is checked when no standard time zone is found. + // If T is misplaced, the date to parse is garbage. + date = p.s[:ind+1] + p.s = p.s[ind+1:] + } } if !p.skipCFWS() { return time.Time{}, errors.New("mail: misformatted parenthetical comment") diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go index 67e3643..0daa3d6 100644 --- a/libgo/go/net/mail/message_test.go +++ b/libgo/go/net/mail/message_test.go @@ -7,7 +7,6 @@ package mail import ( "bytes" "io" - "io/ioutil" "mime" "reflect" "strings" @@ -53,7 +52,7 @@ func TestParsing(t *testing.T) { t.Errorf("test #%d: Incorrectly parsed message header.\nGot:\n%+v\nWant:\n%+v", i, msg.Header, test.header) } - body, err := ioutil.ReadAll(msg.Body) + body, err := io.ReadAll(msg.Body) if err != nil { t.Errorf("test #%d: Failed reading body: %v", i, err) continue @@ -103,6 +102,18 @@ func TestDateParsing(t *testing.T) { "Fri, 21 Nov 1997 09:55:06 -0600 (MDT)", time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)), }, + { + "Thu, 20 Nov 1997 09:55:06 -0600 (MDT)", + time.Date(1997, 11, 20, 9, 55, 6, 0, time.FixedZone("", -6*60*60)), + }, + { + "Thu, 20 Nov 1997 09:55:06 MDT (MDT)", + time.Date(1997, 11, 20, 9, 55, 6, 0, time.FixedZone("MDT", 0)), + }, + { + "Fri, 21 Nov 1997 09:55:06 +1300 (TOT)", + time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", +13*60*60)), + }, } for _, test := range tests { hdr := Header{ @@ -244,6 +255,33 @@ func TestDateParsingCFWS(t *testing.T) { time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("", -6*60*60)), false, }, + // Ensure that the presence of "T" in the date + // doesn't trip out ParseDate, as per issue 39260. + { + "Tue, 26 May 2020 14:04:40 GMT", + time.Date(2020, 05, 26, 14, 04, 40, 0, time.UTC), + true, + }, + { + "Tue, 26 May 2020 14:04:40 UT", + time.Date(2020, 05, 26, 14, 04, 40, 0, time.UTC), + false, + }, + { + "Thu, 21 May 2020 14:04:40 UT", + time.Date(2020, 05, 21, 14, 04, 40, 0, time.UTC), + false, + }, + { + "Thu, 21 May 2020 14:04:40 UTC", + time.Date(2020, 05, 21, 14, 04, 40, 0, time.UTC), + true, + }, + { + "Fri, 21 Nov 1997 09:55:06 MDT (MDT)", + time.Date(1997, 11, 21, 9, 55, 6, 0, time.FixedZone("MDT", 0)), + true, + }, } for _, test := range tests { hdr := Header{ @@ -842,7 +880,7 @@ func TestAddressParser(t *testing.T) { ap := AddressParser{WordDecoder: &mime.WordDecoder{ CharsetReader: func(charset string, input io.Reader) (io.Reader, error) { - in, err := ioutil.ReadAll(input) + in, err := io.ReadAll(input) if err != nil { return nil, err } diff --git a/libgo/go/net/main_cloexec_test.go b/libgo/go/net/main_cloexec_test.go index 28974e8..d436df1 100644 --- a/libgo/go/net/main_cloexec_test.go +++ b/libgo/go/net/main_cloexec_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build dragonfly freebsd hurd linux netbsd openbsd +// +build dragonfly freebsd hurd illumos linux netbsd openbsd package net diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go index 85a269d..2d5be2e 100644 --- a/libgo/go/net/main_test.go +++ b/libgo/go/net/main_test.go @@ -133,7 +133,7 @@ func setupTestData() { {"udp6", "[" + addr + "%" + ifi.Name + "]:0", false}, }...) switch runtime.GOOS { - case "darwin", "dragonfly", "freebsd", "openbsd", "netbsd": + case "darwin", "ios", "dragonfly", "freebsd", "openbsd", "netbsd": ipv6LinkLocalUnicastTCPTests = append(ipv6LinkLocalUnicastTCPTests, []ipv6LinkLocalUnicastTest{ {"tcp", "[localhost%" + ifi.Name + "]:0", true}, {"tcp6", "[localhost%" + ifi.Name + "]:0", true}, diff --git a/libgo/go/net/mockserver_test.go b/libgo/go/net/mockserver_test.go index e085f44..867e31e 100644 --- a/libgo/go/net/mockserver_test.go +++ b/libgo/go/net/mockserver_test.go @@ -9,16 +9,15 @@ package net import ( "errors" "fmt" - "io/ioutil" "os" "sync" "testing" "time" ) -// testUnixAddr uses ioutil.TempFile to get a name that is unique. +// testUnixAddr uses os.CreateTemp to get a name that is unique. func testUnixAddr() string { - f, err := ioutil.TempFile("", "go-nettest") + f, err := os.CreateTemp("", "go-nettest") if err != nil { panic(err) } @@ -88,6 +87,7 @@ type localServer struct { lnmu sync.RWMutex Listener done chan bool // signal that indicates server stopped + cl []Conn // accepted connection list } func (ls *localServer) buildup(handler func(*localServer, Listener)) error { @@ -100,10 +100,16 @@ func (ls *localServer) buildup(handler func(*localServer, Listener)) error { func (ls *localServer) teardown() error { ls.lnmu.Lock() + defer ls.lnmu.Unlock() if ls.Listener != nil { network := ls.Listener.Addr().Network() address := ls.Listener.Addr().String() ls.Listener.Close() + for _, c := range ls.cl { + if err := c.Close(); err != nil { + return err + } + } <-ls.done ls.Listener = nil switch network { @@ -111,7 +117,6 @@ func (ls *localServer) teardown() error { os.Remove(address) } } - ls.lnmu.Unlock() return nil } @@ -204,7 +209,7 @@ func newDualStackServer() (*dualStackServer, error) { }, nil } -func transponder(ln Listener, ch chan<- error) { +func (ls *localServer) transponder(ln Listener, ch chan<- error) { defer close(ch) switch ln := ln.(type) { @@ -221,7 +226,7 @@ func transponder(ln Listener, ch chan<- error) { ch <- err return } - defer c.Close() + ls.cl = append(ls.cl, c) network := ln.Addr().Network() if c.LocalAddr().Network() != network || c.RemoteAddr().Network() != network { diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index 2e61a7c..4b4ed12 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -81,6 +81,7 @@ package net import ( "context" "errors" + "internal/poll" "io" "os" "sync" @@ -632,6 +633,17 @@ func (e *DNSError) Timeout() bool { return e.IsTimeout } // error and return a DNSError for which Temporary returns false. func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary } +// errClosed exists just so that the docs for ErrClosed don't mention +// the internal package poll. +var errClosed = poll.ErrNetClosing + +// ErrClosed is the error returned by an I/O call on a network +// connection that has already been closed, or that is closed by +// another goroutine before the I/O is completed. This may be wrapped +// in another error, and should normally be tested using +// errors.Is(err, net.ErrClosed). +var ErrClosed = errClosed + type writerOnly struct { io.Writer } diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go index d3bb918..2da23de 100644 --- a/libgo/go/net/platform_test.go +++ b/libgo/go/net/platform_test.go @@ -59,7 +59,7 @@ func testableNetwork(network string) bool { } case "unixpacket": switch runtime.GOOS { - case "aix", "android", "darwin", "plan9", "windows": + case "aix", "android", "darwin", "ios", "plan9", "windows": return false case "netbsd": // It passes on amd64 at least. 386 fails (Issue 22927). arm is unknown. @@ -82,7 +82,7 @@ func testableNetwork(network string) bool { } func iOS() bool { - return runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" + return runtime.GOOS == "ios" } // testableAddress reports whether address of network is testable on diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go index 9f6772c..6f83f52 100644 --- a/libgo/go/net/protoconn_test.go +++ b/libgo/go/net/protoconn_test.go @@ -72,7 +72,7 @@ func TestTCPConnSpecificMethods(t *testing.T) { t.Fatal(err) } ch := make(chan error, 1) - handler := func(ls *localServer, ln Listener) { transponder(ls.Listener, ch) } + handler := func(ls *localServer, ln Listener) { ls.transponder(ls.Listener, ch) } ls, err := (&streamListener{Listener: ln}).newLocalServer() if err != nil { t.Fatal(err) diff --git a/libgo/go/net/rawconn_unix_test.go b/libgo/go/net/rawconn_unix_test.go index d64dc75..21527c4 100644 --- a/libgo/go/net/rawconn_unix_test.go +++ b/libgo/go/net/rawconn_unix_test.go @@ -24,10 +24,7 @@ func readRawConn(c syscall.RawConn, b []byte) (int, error) { if err != nil { return n, err } - if operr != nil { - return n, operr - } - return n, nil + return n, operr } func writeRawConn(c syscall.RawConn, b []byte) error { @@ -42,10 +39,7 @@ func writeRawConn(c syscall.RawConn, b []byte) error { if err != nil { return err } - if operr != nil { - return operr - } - return nil + return operr } func controlRawConn(c syscall.RawConn, addr Addr) error { @@ -87,10 +81,7 @@ func controlRawConn(c syscall.RawConn, addr Addr) error { if err := c.Control(fn); err != nil { return err } - if operr != nil { - return operr - } - return nil + return operr } func controlOnConnSetup(network string, address string, c syscall.RawConn) error { @@ -120,8 +111,5 @@ func controlOnConnSetup(network string, address string, c syscall.RawConn) error if err := c.Control(fn); err != nil { return err } - if operr != nil { - return operr - } - return nil + return operr } diff --git a/libgo/go/net/rawconn_windows_test.go b/libgo/go/net/rawconn_windows_test.go index 2774c97..5febf08 100644 --- a/libgo/go/net/rawconn_windows_test.go +++ b/libgo/go/net/rawconn_windows_test.go @@ -26,10 +26,7 @@ func readRawConn(c syscall.RawConn, b []byte) (int, error) { if err != nil { return n, err } - if operr != nil { - return n, operr - } - return n, nil + return n, operr } func writeRawConn(c syscall.RawConn, b []byte) error { @@ -45,10 +42,7 @@ func writeRawConn(c syscall.RawConn, b []byte) error { if err != nil { return err } - if operr != nil { - return operr - } - return nil + return operr } func controlRawConn(c syscall.RawConn, addr Addr) error { @@ -92,10 +86,7 @@ func controlRawConn(c syscall.RawConn, addr Addr) error { if err := c.Control(fn); err != nil { return err } - if operr != nil { - return operr - } - return nil + return operr } func controlOnConnSetup(network string, address string, c syscall.RawConn) error { @@ -121,8 +112,5 @@ func controlOnConnSetup(network string, address string, c syscall.RawConn) error if err := c.Control(fn); err != nil { return err } - if operr != nil { - return operr - } - return nil + return operr } diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go index 25f2a00..60bb2cc 100644 --- a/libgo/go/net/rpc/client.go +++ b/libgo/go/net/rpc/client.go @@ -245,7 +245,6 @@ func DialHTTP(network, address string) (*Client, error) { // DialHTTPPath connects to an HTTP RPC server // at the specified network address and path. func DialHTTPPath(network, address, path string) (*Client, error) { - var err error conn, err := net.Dial(network, address) if err != nil { return nil, err diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go index 4e73edc..667f839 100644 --- a/libgo/go/net/rpc/jsonrpc/all_test.go +++ b/libgo/go/net/rpc/jsonrpc/all_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/rpc" "reflect" @@ -249,7 +248,7 @@ func TestMalformedInput(t *testing.T) { func TestMalformedOutput(t *testing.T) { cli, srv := net.Pipe() go srv.Write([]byte(`{"id":0,"result":null,"error":null}`)) - go ioutil.ReadAll(srv) + go io.ReadAll(srv) client := NewClient(cli) defer client.Close() @@ -271,7 +270,7 @@ func TestServerErrorHasNullResult(t *testing.T) { }{ Reader: strings.NewReader(`{"method": "Arith.Add", "id": "123", "params": []}`), Writer: &out, - Closer: ioutil.NopCloser(nil), + Closer: io.NopCloser(nil), }) r := new(rpc.Request) if err := sc.ReadRequestHeader(r); err != nil { diff --git a/libgo/go/net/sendfile_test.go b/libgo/go/net/sendfile_test.go index 13842a1..657a365 100644 --- a/libgo/go/net/sendfile_test.go +++ b/libgo/go/net/sendfile_test.go @@ -12,7 +12,6 @@ import ( "encoding/hex" "fmt" "io" - "io/ioutil" "os" "runtime" "sync" @@ -282,7 +281,7 @@ func TestSendfilePipe(t *testing.T) { return } defer conn.Close() - io.Copy(ioutil.Discard, conn) + io.Copy(io.Discard, conn) }() // Wait for the byte to be copied, meaning that sendfile has diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go index 2673b87..4ac5443 100644 --- a/libgo/go/net/server_test.go +++ b/libgo/go/net/server_test.go @@ -86,7 +86,7 @@ func TestTCPServer(t *testing.T) { } for i := 0; i < N; i++ { ch := tpchs[i] - handler := func(ls *localServer, ln Listener) { transponder(ln, ch) } + handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } if err := lss[i].buildup(handler); err != nil { t.Fatal(err) } @@ -178,7 +178,7 @@ func TestUnixAndUnixpacketServer(t *testing.T) { } for i := 0; i < N; i++ { ch := tpchs[i] - handler := func(ls *localServer, ln Listener) { transponder(ln, ch) } + handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } if err := lss[i].buildup(handler); err != nil { t.Fatal(err) } diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go index e4e12ae..1a6864a 100644 --- a/libgo/go/net/smtp/smtp.go +++ b/libgo/go/net/smtp/smtp.go @@ -241,7 +241,8 @@ func (c *Client) Auth(a Auth) error { // Mail issues a MAIL command to the server using the provided email address. // If the server supports the 8BITMIME extension, Mail adds the BODY=8BITMIME -// parameter. +// parameter. If the server supports the SMTPUTF8 extension, Mail adds the +// SMTPUTF8 parameter. // This initiates a mail transaction and is followed by one or more Rcpt calls. func (c *Client) Mail(from string) error { if err := validateLine(from); err != nil { @@ -255,6 +256,9 @@ func (c *Client) Mail(from string) error { if _, ok := c.ext["8BITMIME"]; ok { cmdStr += " BODY=8BITMIME" } + if _, ok := c.ext["SMTPUTF8"]; ok { + cmdStr += " SMTPUTF8" + } } _, _, err := c.cmd(250, cmdStr, from) return err diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index cfda079..5521937 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -288,6 +288,219 @@ Goodbye. QUIT ` +func TestExtensions(t *testing.T) { + fake := func(server string) (c *Client, bcmdbuf *bufio.Writer, cmdbuf *strings.Builder) { + server = strings.Join(strings.Split(server, "\n"), "\r\n") + + cmdbuf = &strings.Builder{} + bcmdbuf = bufio.NewWriter(cmdbuf) + var fake faker + fake.ReadWriter = bufio.NewReadWriter(bufio.NewReader(strings.NewReader(server)), bcmdbuf) + c = &Client{Text: textproto.NewConn(fake), localName: "localhost"} + + return c, bcmdbuf, cmdbuf + } + + t.Run("helo", func(t *testing.T) { + const ( + basicServer = `250 mx.google.com at your service +250 Sender OK +221 Goodbye +` + + basicClient = `HELO localhost +MAIL FROM:<user@gmail.com> +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.helo(); err != nil { + t.Fatalf("HELO failed: %s", err) + } + c.didHello = true + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250 SIZE 35651584 +250 Sender OK +221 Goodbye +` + + basicClient = `EHLO localhost +MAIL FROM:<user@gmail.com> +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + if ok, _ := c.Extension("8BITMIME"); ok { + t.Fatalf("Shouldn't support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); ok { + t.Fatalf("Shouldn't support SMTPUTF8") + } + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo 8bitmime", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250-SIZE 35651584 +250 8BITMIME +250 Sender OK +221 Goodbye +` + + basicClient = `EHLO localhost +MAIL FROM:<user@gmail.com> BODY=8BITMIME +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + if ok, _ := c.Extension("8BITMIME"); !ok { + t.Fatalf("Should support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); ok { + t.Fatalf("Shouldn't support SMTPUTF8") + } + if err := c.Mail("user@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo smtputf8", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250-SIZE 35651584 +250 SMTPUTF8 +250 Sender OK +221 Goodbye +` + + basicClient = `EHLO localhost +MAIL FROM:<user+📧@gmail.com> SMTPUTF8 +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + if ok, _ := c.Extension("8BITMIME"); ok { + t.Fatalf("Shouldn't support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); !ok { + t.Fatalf("Should support SMTPUTF8") + } + if err := c.Mail("user+📧@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) + + t.Run("ehlo 8bitmime smtputf8", func(t *testing.T) { + const ( + basicServer = `250-mx.google.com at your service +250-SIZE 35651584 +250-8BITMIME +250 SMTPUTF8 +250 Sender OK +221 Goodbye + ` + + basicClient = `EHLO localhost +MAIL FROM:<user+📧@gmail.com> BODY=8BITMIME SMTPUTF8 +QUIT +` + ) + + c, bcmdbuf, cmdbuf := fake(basicServer) + + if err := c.Hello("localhost"); err != nil { + t.Fatalf("EHLO failed: %s", err) + } + c.didHello = true + if ok, _ := c.Extension("8BITMIME"); !ok { + t.Fatalf("Should support 8BITMIME") + } + if ok, _ := c.Extension("SMTPUTF8"); !ok { + t.Fatalf("Should support SMTPUTF8") + } + if err := c.Mail("user+📧@gmail.com"); err != nil { + t.Fatalf("MAIL FROM failed: %s", err) + } + if err := c.Quit(); err != nil { + t.Fatalf("QUIT failed: %s", err) + } + + bcmdbuf.Flush() + actualcmds := cmdbuf.String() + client := strings.Join(strings.Split(basicClient, "\n"), "\r\n") + if client != actualcmds { + t.Fatalf("Got:\n%s\nExpected:\n%s", actualcmds, client) + } + }) +} + func TestNewClient(t *testing.T) { server := strings.Join(strings.Split(newClientServer, "\n"), "\r\n") client := strings.Join(strings.Split(newClientClient, "\n"), "\r\n") diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go index 516e557..73fb6be 100644 --- a/libgo/go/net/sock_bsd.go +++ b/libgo/go/net/sock_bsd.go @@ -17,7 +17,7 @@ func maxListenerBacklog() int { err error ) switch runtime.GOOS { - case "darwin": + case "darwin", "ios": n, err = syscall.SysctlUint32("kern.ipc.somaxconn") case "freebsd": n, err = syscall.SysctlUint32("kern.ipc.soacceptqueue") diff --git a/libgo/go/net/sock_cloexec.go b/libgo/go/net/sock_cloexec.go index b70bb4c..43be6ff 100644 --- a/libgo/go/net/sock_cloexec.go +++ b/libgo/go/net/sock_cloexec.go @@ -5,7 +5,7 @@ // This file implements sysSocket and accept for platforms that // provide a fast path for setting SetNonblock and CloseOnExec. -// +build dragonfly freebsd hurd linux netbsd openbsd +// +build dragonfly freebsd hurd illumos linux netbsd openbsd package net diff --git a/libgo/go/net/sock_linux.go b/libgo/go/net/sock_linux.go index 7bca376..9f62ed3 100644 --- a/libgo/go/net/sock_linux.go +++ b/libgo/go/net/sock_linux.go @@ -6,6 +6,63 @@ package net import "syscall" +func kernelVersion() (major int, minor int) { + var uname syscall.Utsname + if err := syscall.Uname(&uname); err != nil { + return + } + + rl := uname.Release + var values [2]int + vi := 0 + value := 0 + for _, c := range rl { + if c >= '0' && c <= '9' { + value = (value * 10) + int(c-'0') + } else { + // Note that we're assuming N.N.N here. If we see anything else we are likely to + // mis-parse it. + values[vi] = value + vi++ + if vi >= len(values) { + break + } + value = 0 + } + } + switch vi { + case 0: + return 0, 0 + case 1: + return values[0], 0 + case 2: + return values[0], values[1] + } + return +} + +// Linux stores the backlog as: +// +// - uint16 in kernel version < 4.1, +// - uint32 in kernel version >= 4.1 +// +// Truncate number to avoid wrapping. +// +// See issue 5030 and 41470. +func maxAckBacklog(n int) int { + major, minor := kernelVersion() + size := 16 + if major > 4 || (major == 4 && minor >= 1) { + size = 32 + } + + var max uint = 1<<size - 1 + if uint(n) > max { + n = int(max) + } + return n +} + func maxListenerBacklog() int { fd, err := open("/proc/sys/net/core/somaxconn") if err != nil { @@ -21,11 +78,9 @@ func maxListenerBacklog() int { if n == 0 || !ok { return syscall.SOMAXCONN } - // Linux stores the backlog in a uint16. - // Truncate number to avoid wrapping. - // See issue 5030. + if n > 1<<16-1 { - n = 1<<16 - 1 + return maxAckBacklog(n) } return n } diff --git a/libgo/go/net/sock_linux_test.go b/libgo/go/net/sock_linux_test.go new file mode 100644 index 0000000..5df0293 --- /dev/null +++ b/libgo/go/net/sock_linux_test.go @@ -0,0 +1,22 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "testing" +) + +func TestMaxAckBacklog(t *testing.T) { + n := 196602 + major, minor := kernelVersion() + backlog := maxAckBacklog(n) + expected := 1<<16 - 1 + if major > 4 || (major == 4 && minor >= 1) { + expected = n + } + if backlog != expected { + t.Fatalf(`Kernel version: "%d.%d", sk_max_ack_backlog mismatch, got %d, want %d`, major, minor, backlog, expected) + } +} diff --git a/libgo/go/net/splice_test.go b/libgo/go/net/splice_test.go index b14ab9f..cd4e01f 100644 --- a/libgo/go/net/splice_test.go +++ b/libgo/go/net/splice_test.go @@ -8,7 +8,6 @@ package net import ( "io" - "io/ioutil" "log" "os" "os/exec" @@ -202,7 +201,7 @@ func testSpliceIssue25985(t *testing.T, upNet, downNet string) { } defer fromProxy.Close() - _, err = ioutil.ReadAll(fromProxy) + _, err = io.ReadAll(fromProxy) if err != nil { t.Fatal(err) } diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go index 89aad70..967b8be 100644 --- a/libgo/go/net/sys_cloexec.go +++ b/libgo/go/net/sys_cloexec.go @@ -5,7 +5,7 @@ // This file implements sysSocket and accept for platforms that do not // provide a fast path for setting SetNonblock and CloseOnExec. -// +build aix darwin solaris +// +build aix darwin solaris,!illumos package net diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go index 920bf42..c220c0b 100644 --- a/libgo/go/net/tcpsock_test.go +++ b/libgo/go/net/tcpsock_test.go @@ -393,7 +393,7 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) { } defer ls.teardown() ch := make(chan error, 1) - handler := func(ls *localServer, ln Listener) { transponder(ln, ch) } + handler := func(ls *localServer, ln Listener) { ls.transponder(ln, ch) } if err := ls.buildup(handler); err != nil { t.Fatal(err) } @@ -652,7 +652,7 @@ func TestTCPSelfConnect(t *testing.T) { n = 1000 } switch runtime.GOOS { - case "darwin", "dragonfly", "freebsd", "netbsd", "openbsd", "plan9", "illumos", "solaris", "windows": + case "darwin", "ios", "dragonfly", "freebsd", "netbsd", "openbsd", "plan9", "illumos", "solaris", "windows": // Non-Linux systems take a long time to figure // out that there is nothing listening on localhost. n = 100 diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go index a00fd23..5c3084f 100644 --- a/libgo/go/net/textproto/reader.go +++ b/libgo/go/net/textproto/reader.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "strconv" "strings" "sync" @@ -426,7 +425,7 @@ func (r *Reader) closeDot() { // // See the documentation for the DotReader method for details about dot-encoding. func (r *Reader) ReadDotBytes() ([]byte, error) { - return ioutil.ReadAll(r.DotReader()) + return io.ReadAll(r.DotReader()) } // ReadDotLines reads a dot-encoding and returns a slice diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index ad14cd7..205aaa4 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -11,7 +11,6 @@ import ( "fmt" "internal/testenv" "io" - "io/ioutil" "net/internal/socktest" "os" "runtime" @@ -874,7 +873,7 @@ func testVariousDeadlines(t *testing.T) { if err := c.SetDeadline(t0.Add(timeout)); err != nil { t.Error(err) } - n, err := io.Copy(ioutil.Discard, c) + n, err := io.Copy(io.Discard, c) dt := time.Since(t0) c.Close() ch <- result{n, err, dt} diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index ec2bcfa..571e099 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -259,6 +259,9 @@ func ListenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { // ListenMulticastUDP is just for convenience of simple, small // applications. There are golang.org/x/net/ipv4 and // golang.org/x/net/ipv6 packages for general purpose uses. +// +// Note that ListenMulticastUDP will set the IP_MULTICAST_LOOP socket option +// to 0 under IPPROTO_IP, to disable loopback of multicast packets. func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { switch network { case "udp", "udp4", "udp6": diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go index 947381a..327eba6 100644 --- a/libgo/go/net/udpsock_test.go +++ b/libgo/go/net/udpsock_test.go @@ -327,7 +327,7 @@ func TestUDPZeroBytePayload(t *testing.T) { switch runtime.GOOS { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) - case "darwin": + case "darwin", "ios": testenv.SkipFlaky(t, 29225) } diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go index 4b2cfc4d..0b13bf6 100644 --- a/libgo/go/net/unixsock_test.go +++ b/libgo/go/net/unixsock_test.go @@ -9,7 +9,6 @@ package net import ( "bytes" "internal/testenv" - "io/ioutil" "os" "reflect" "runtime" @@ -417,7 +416,7 @@ func TestUnixUnlink(t *testing.T) { checkExists(t, "after Listen") l.Close() checkNotExists(t, "after Listener close") - if err := ioutil.WriteFile(name, []byte("hello world"), 0666); err != nil { + if err := os.WriteFile(name, []byte("hello world"), 0666); err != nil { t.Fatalf("cannot recreate socket file: %v", err) } checkExists(t, "after writing temp file") diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index c93def0..d90f5f0 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -1000,25 +1000,52 @@ func resolvePath(base, ref string) string { if full == "" { return "" } - src := strings.Split(full, "/") - dst := make([]string, 0, len(src)) - for _, elem := range src { - switch elem { - case ".": + + var ( + last string + elem string + i int + dst strings.Builder + ) + first := true + remaining := full + for i >= 0 { + i = strings.IndexByte(remaining, '/') + if i < 0 { + last, elem, remaining = remaining, remaining, "" + } else { + elem, remaining = remaining[:i], remaining[i+1:] + } + if elem == "." { + first = false // drop - case "..": - if len(dst) > 0 { - dst = dst[:len(dst)-1] + continue + } + + if elem == ".." { + str := dst.String() + index := strings.LastIndexByte(str, '/') + + dst.Reset() + if index == -1 { + first = true + } else { + dst.WriteString(str[:index]) } - default: - dst = append(dst, elem) + } else { + if !first { + dst.WriteByte('/') + } + dst.WriteString(elem) + first = false } } - if last := src[len(src)-1]; last == "." || last == ".." { - // Add final slash to the joined path. - dst = append(dst, "") + + if last == "." || last == ".." { + dst.WriteByte('/') } - return "/" + strings.TrimPrefix(strings.Join(dst, "/"), "/") + + return "/" + strings.TrimPrefix(dst.String(), "/") } // IsAbs reports whether the URL is absolute. diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index 92b15af..f02e4650 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -1114,6 +1114,14 @@ func TestResolvePath(t *testing.T) { } } +func BenchmarkResolvePath(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + for i := 0; i < b.N; i++ { + resolvePath("a/b/c", ".././d") + } +} + var resolveReferenceTests = []struct { base, rel, expected string }{ diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go index c43be84..d603b7f 100644 --- a/libgo/go/net/writev_test.go +++ b/libgo/go/net/writev_test.go @@ -11,7 +11,6 @@ import ( "fmt" "internal/poll" "io" - "io/ioutil" "reflect" "runtime" "sync" @@ -28,7 +27,7 @@ func TestBuffers_read(t *testing.T) { []byte("in "), []byte("Gopherland ... "), } - got, err := ioutil.ReadAll(&buffers) + got, err := io.ReadAll(&buffers) if err != nil { t.Fatal(err) } @@ -141,7 +140,7 @@ func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) { } return nil }, func(c *TCPConn) error { - all, err := ioutil.ReadAll(c) + all, err := io.ReadAll(c) if !bytes.Equal(all, want.Bytes()) || err != nil { return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes()) } @@ -154,7 +153,7 @@ func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) { var wantSum int switch runtime.GOOS { - case "android", "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd": + case "android", "darwin", "ios", "dragonfly", "freebsd", "illumos", "linux", "netbsd", "openbsd": var wantMinCalls int wantSum = want.Len() v := chunks diff --git a/libgo/go/net/writev_unix.go b/libgo/go/net/writev_unix.go index bf0fbf8..8b20f42 100644 --- a/libgo/go/net/writev_unix.go +++ b/libgo/go/net/writev_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd +// +build darwin dragonfly freebsd illumos linux netbsd openbsd package net |