diff options
author | Ian Lance Taylor <iant@golang.org> | 2022-02-11 14:53:56 -0800 |
---|---|---|
committer | Ian Lance Taylor <iant@golang.org> | 2022-02-11 15:01:19 -0800 |
commit | 8dc2499aa62f768c6395c9754b8cabc1ce25c494 (patch) | |
tree | 43d7fd2bbfd7ad8c9625a718a5e8718889351994 /libgo/go/net/http | |
parent | 9a56779dbc4e2d9c15be8d31e36f2f59be7331a8 (diff) | |
download | gcc-8dc2499aa62f768c6395c9754b8cabc1ce25c494.zip gcc-8dc2499aa62f768c6395c9754b8cabc1ce25c494.tar.gz gcc-8dc2499aa62f768c6395c9754b8cabc1ce25c494.tar.bz2 |
libgo: update to Go1.18beta2
gotools/
* Makefile.am (go_cmd_cgo_files): Add ast_go118.go
(check-go-tool): Copy golang.org/x/tools directories.
* Makefile.in: Regenerate.
Reviewed-on: https://go-review.googlesource.com/c/gofrontend/+/384695
Diffstat (limited to 'libgo/go/net/http')
43 files changed, 1986 insertions, 1041 deletions
diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go index 0114da3..bdb35a6 100644 --- a/libgo/go/net/http/cgi/child.go +++ b/libgo/go/net/http/cgi/child.go @@ -39,8 +39,8 @@ func Request() (*http.Request, error) { func envMap(env []string) map[string]string { m := make(map[string]string) for _, kv := range env { - if idx := strings.Index(kv, "="); idx != -1 { - m[kv[:idx]] = kv[idx+1:] + if k, v, ok := strings.Cut(kv, "="); ok { + m[k] = v } } return m diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go index eff67ca..95b2e13 100644 --- a/libgo/go/net/http/cgi/host.go +++ b/libgo/go/net/http/cgi/host.go @@ -273,12 +273,11 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { break } headerLines++ - parts := strings.SplitN(string(line), ":", 2) - if len(parts) < 2 { + header, val, ok := strings.Cut(string(line), ":") + if !ok { h.printf("cgi: bogus header line: %s", string(line)) continue } - header, val := parts[0], parts[1] if !httpguts.ValidHeaderFieldName(header) { h.printf("cgi: invalid header name: %q", header) continue @@ -351,7 +350,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } -func (h *Handler) printf(format string, v ...interface{}) { +func (h *Handler) printf(format string, v ...any) { if h.Logger != nil { h.Logger.Printf(format, v...) } else { diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index 9f1716b..1b72f7e 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -62,12 +62,12 @@ readlines: } linesRead++ trimmedLine := strings.TrimRight(line, "\r\n") - split := strings.SplitN(trimmedLine, "=", 2) - if len(split) != 2 { - t.Fatalf("Unexpected %d parts from invalid line number %v: %q; existing map=%v", - len(split), linesRead, line, m) + k, v, ok := strings.Cut(trimmedLine, "=") + if !ok { + t.Fatalf("Unexpected response from invalid line number %v: %q; existing map=%v", + linesRead, line, m) } - m[split[0]] = split[1] + m[k] = v } for key, expected := range expectedMap { diff --git a/libgo/go/net/http/cgi/posix_test.go b/libgo/go/net/http/cgi/posix_test.go index bc58ea9..49b9470 100644 --- a/libgo/go/net/http/cgi/posix_test.go +++ b/libgo/go/net/http/cgi/posix_test.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !plan9 -// +build !plan9 package cgi diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 4d380c6..22db96b 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -965,7 +965,6 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) { if err == nil { return n, nil } - b.stop() if err == io.EOF { return n, err } diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 01d605c..e91d526 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "errors" "fmt" + "internal/testenv" "io" "log" "net" @@ -21,6 +22,7 @@ import ( "net/http/httptest" "net/url" "reflect" + "runtime" "strconv" "strings" "sync" @@ -431,11 +433,10 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa if v := urlQuery.Get("code"); v != "" { location := ts.URL if final := urlQuery.Get("next"); final != "" { - splits := strings.Split(final, ",") - first, rest := splits[0], splits[1:] + first, rest, _ := strings.Cut(final, ",") location = fmt.Sprintf("%s?code=%s", location, first) - if len(rest) > 0 { - location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ",")) + if rest != "" { + location = fmt.Sprintf("%s&next=%s", location, rest) } } code, _ := strconv.Atoi(v) @@ -746,7 +747,7 @@ func (j *RecordingJar) Cookies(u *url.URL) []*Cookie { return nil } -func (j *RecordingJar) logf(format string, args ...interface{}) { +func (j *RecordingJar) logf(format string, args ...any) { j.mu.Lock() defer j.mu.Unlock() fmt.Fprintf(&j.log, format, args...) @@ -1206,64 +1207,80 @@ func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } func testClientTimeout(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) - testDone := make(chan struct{}) // closed in defer below - sawRoot := make(chan bool, 1) - sawSlow := make(chan bool, 1) + var ( + mu sync.Mutex + nonce string // a unique per-request string + sawSlowNonce bool // true if the handler saw /slow?nonce=<nonce> + ) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + _ = r.ParseForm() if r.URL.Path == "/" { - sawRoot <- true - Redirect(w, r, "/slow", StatusFound) + Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound) return } if r.URL.Path == "/slow" { - sawSlow <- true + mu.Lock() + if r.Form.Get("nonce") == nonce { + sawSlowNonce = true + } else { + t.Logf("mismatched nonce: received %s, want %s", r.Form.Get("nonce"), nonce) + } + mu.Unlock() + w.Write([]byte("Hello")) w.(Flusher).Flush() - <-testDone + <-r.Context().Done() return } })) defer cst.close() - defer close(testDone) // before cst.close, to unblock /slow handler - // 200ms should be long enough to get a normal request (the / - // handler), but not so long that it makes the test slow. - const timeout = 200 * time.Millisecond - cst.c.Timeout = timeout - - res, err := cst.c.Get(cst.ts.URL) - if err != nil { - if strings.Contains(err.Error(), "Client.Timeout") { - t.Skipf("host too slow to get fast resource in %v", timeout) + // Try to trigger a timeout after reading part of the response body. + // The initial timeout is emprically usually long enough on a decently fast + // machine, but if we undershoot we'll retry with exponentially longer + // timeouts until the test either passes or times out completely. + // This keeps the test reasonably fast in the typical case but allows it to + // also eventually succeed on arbitrarily slow machines. + timeout := 10 * time.Millisecond + nextNonce := 0 + for ; ; timeout *= 2 { + if timeout <= 0 { + // The only way we can feasibly hit this while the test is running is if + // the request fails without actually waiting for the timeout to occur. + t.Fatalf("timeout overflow") + } + if deadline, ok := t.Deadline(); ok && !time.Now().Add(timeout).Before(deadline) { + t.Fatalf("failed to produce expected timeout before test deadline") + } + t.Logf("attempting test with timeout %v", timeout) + cst.c.Timeout = timeout + + mu.Lock() + nonce = fmt.Sprint(nextNonce) + nextNonce++ + sawSlowNonce = false + mu.Unlock() + res, err := cst.c.Get(cst.ts.URL + "/?nonce=" + nonce) + if err != nil { + if strings.Contains(err.Error(), "Client.Timeout") { + // Timed out before handler could respond. + t.Logf("timeout before response received") + continue + } + t.Fatal(err) } - t.Fatal(err) - } - - select { - case <-sawRoot: - // good. - default: - t.Fatal("handler never got / request") - } - select { - case <-sawSlow: - // good. - default: - t.Fatal("handler never got /slow request") - } + mu.Lock() + ok := sawSlowNonce + mu.Unlock() + if !ok { + t.Fatal("handler never got /slow request, but client returned response") + } - errc := make(chan error, 1) - go func() { - _, err := io.ReadAll(res.Body) - errc <- err + _, err = io.ReadAll(res.Body) res.Body.Close() - }() - const failTime = 5 * time.Second - select { - case err := <-errc: if err == nil { t.Fatal("expected error from ReadAll") } @@ -1274,10 +1291,13 @@ func testClientTimeout(t *testing.T, h2 bool) { t.Errorf("net.Error.Timeout = false; want true") } if got := ne.Error(); !strings.Contains(got, "(Client.Timeout") { + if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") { + testenv.SkipFlaky(t, 43120) + } t.Errorf("error string = %q; missing timeout substring", got) } - case <-time.After(failTime): - t.Errorf("timeout after %v waiting for timeout of %v", failTime, timeout) + + break } } @@ -1319,6 +1339,9 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { t.Error("net.Error.Timeout = false; want true") } if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") { + if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") { + testenv.SkipFlaky(t, 43120) + } t.Errorf("error string = %q; missing timeout substring", got) } } @@ -1353,6 +1376,33 @@ func TestClientTimeoutCancel(t *testing.T) { } } +func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) } +func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) } + +// Issue 49366: if Client.Timeout is set but not hit, no error should be returned. +func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("body")) + })) + defer cst.close() + + cst.c.Timeout = 1 * time.Hour + req, _ := NewRequest("GET", cst.ts.URL, nil) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + if _, err = io.Copy(io.Discard, res.Body); err != nil { + t.Fatalf("io.Copy(io.Discard, res.Body) = %v, want nil", err) + } + if err = res.Body.Close(); err != nil { + t.Fatalf("res.Body.Close() = %v, want nil", err) + } +} + func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } func testClientRedirectEatsBody(t *testing.T, h2 bool) { @@ -2082,3 +2132,47 @@ func (b *issue40382Body) Close() error { } return nil } + +func TestProbeZeroLengthBody(t *testing.T) { + setParallel(t) + defer afterTest(t) + reqc := make(chan struct{}) + cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + close(reqc) + if _, err := io.Copy(w, r.Body); err != nil { + t.Errorf("error copying request body: %v", err) + } + })) + defer cst.close() + + bodyr, bodyw := io.Pipe() + var gotBody string + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + req, _ := NewRequest("GET", cst.ts.URL, bodyr) + res, err := cst.c.Do(req) + b, err := io.ReadAll(res.Body) + if err != nil { + t.Error(err) + } + gotBody = string(b) + }() + + select { + case <-reqc: + // Request should be sent after trying to probe the request body for 200ms. + case <-time.After(60 * time.Second): + t.Errorf("request not sent after 60s") + } + + // Write the request body and wait for the request to complete. + const content = "body" + bodyw.Write([]byte(content)) + bodyw.Close() + wg.Wait() + if gotBody != content { + t.Fatalf("server got body %q, want %q", gotBody, content) + } +} diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 42207ac..44d70f0 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -81,7 +81,7 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { +func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { if h2 { CondSkipHTTP2(t) } @@ -189,7 +189,7 @@ type h12Compare struct { ReqFunc reqFunc // optional CheckResponse func(proto string, res *Response) // optional EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize - Opts []interface{} + Opts []any } func (tt h12Compare) reqFunc() reqFunc { @@ -441,7 +441,7 @@ func TestH12_AutoGzip(t *testing.T) { func TestH12_AutoGzip_Disabled(t *testing.T) { h12Compare{ - Opts: []interface{}{ + Opts: []any{ func(tr *Transport) { tr.DisableCompression = true }, }, Handler: func(w ResponseWriter, r *Request) { @@ -1172,7 +1172,7 @@ func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, ErrAbortHandler) } -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { +func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { setParallel(t) const msg = "hello" defer afterTest(t) @@ -1522,7 +1522,7 @@ func TestBidiStreamReverseProxy(t *testing.T) { })) defer proxy.close() - bodyRes := make(chan interface{}, 1) // error or hash.Hash + bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() req, _ := NewRequest("PUT", proxy.ts.URL, pr) const size = 4 << 20 @@ -1586,3 +1586,37 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }, }.run(t) } + +func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } +func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } + +func testIdentityTransferEncoding(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + + const body = "body" + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + gotBody, _ := io.ReadAll(r.Body) + if got, want := string(gotBody), body; got != want { + t.Errorf("got request body = %q; want %q", got, want) + } + w.Header().Set("Transfer-Encoding", "identity") + w.WriteHeader(StatusOK) + w.(Flusher).Flush() + io.WriteString(w, body) + })) + defer cst.close() + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + gotBody, err := io.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if got, want := string(gotBody), body; got != want { + t.Errorf("got response body = %q; want %q", got, want) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index ca2c1c2..cb37f23 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -5,6 +5,8 @@ package http import ( + "errors" + "fmt" "log" "net" "net/http/internal/ascii" @@ -67,15 +69,14 @@ func readSetCookies(h Header) []*Cookie { continue } parts[0] = textproto.TrimString(parts[0]) - j := strings.Index(parts[0], "=") - if j < 0 { + name, value, ok := strings.Cut(parts[0], "=") + if !ok { continue } - name, value := parts[0][:j], parts[0][j+1:] if !isCookieNameValid(name) { continue } - value, ok := parseCookieValue(value, true) + value, ok = parseCookieValue(value, true) if !ok { continue } @@ -90,10 +91,7 @@ func readSetCookies(h Header) []*Cookie { continue } - attr, val := parts[i], "" - if j := strings.Index(attr, "="); j >= 0 { - attr, val = attr[:j], attr[j+1:] - } + attr, val, _ := strings.Cut(parts[i], "=") lowerAttr, isASCII := ascii.ToLower(attr) if !isASCII { continue @@ -240,6 +238,37 @@ func (c *Cookie) String() string { return b.String() } +// Valid reports whether the cookie is valid. +func (c *Cookie) Valid() error { + if c == nil { + return errors.New("http: nil Cookie") + } + if !isCookieNameValid(c.Name) { + return errors.New("http: invalid Cookie.Name") + } + if !validCookieExpires(c.Expires) { + return errors.New("http: invalid Cookie.Expires") + } + for i := 0; i < len(c.Value); i++ { + if !validCookieValueByte(c.Value[i]) { + return fmt.Errorf("http: invalid byte %q in Cookie.Value", c.Value[i]) + } + } + if len(c.Path) > 0 { + for i := 0; i < len(c.Path); i++ { + if !validCookiePathByte(c.Path[i]) { + return fmt.Errorf("http: invalid byte %q in Cookie.Path", c.Path[i]) + } + } + } + if len(c.Domain) > 0 { + if !validCookieDomain(c.Domain) { + return errors.New("http: invalid Cookie.Domain") + } + } + return nil +} + // readCookies parses all "Cookie" values from the header h and // returns the successfully parsed Cookies. // @@ -256,19 +285,12 @@ func readCookies(h Header, filter string) []*Cookie { var part string for len(line) > 0 { // continue since we have rest - if splitIndex := strings.Index(line, ";"); splitIndex > 0 { - part, line = line[:splitIndex], line[splitIndex+1:] - } else { - part, line = line, "" - } + part, line, _ = strings.Cut(line, ";") part = textproto.TrimString(part) - if len(part) == 0 { + if part == "" { continue } - name, val := part, "" - if j := strings.Index(part, "="); j >= 0 { - name, val = name[:j], name[j+1:] - } + name, val, _ := strings.Cut(part, "=") if !isCookieNameValid(name) { continue } @@ -379,7 +401,7 @@ func sanitizeCookieValue(v string) string { if len(v) == 0 { return v } - if strings.IndexByte(v, ' ') >= 0 || strings.IndexByte(v, ',') >= 0 { + if strings.ContainsAny(v, " ,") { return `"` + v + `"` } return v diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index 959713a..ccc5f98 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -360,7 +360,7 @@ var readSetCookiesTests = []struct { // Header{"Set-Cookie": {"ASP.NET_SessionId=foo; path=/; HttpOnly, .ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly"}}, } -func toJSON(v interface{}) string { +func toJSON(v any) string { b, err := json.Marshal(v) if err != nil { return fmt.Sprintf("%#v", v) @@ -529,6 +529,31 @@ func TestCookieSanitizePath(t *testing.T) { } } +func TestCookieValid(t *testing.T) { + tests := []struct { + cookie *Cookie + valid bool + }{ + {nil, false}, + {&Cookie{Name: ""}, false}, + {&Cookie{Name: "invalid-expires"}, false}, + {&Cookie{Name: "invalid-value", Value: "foo\"bar"}, false}, + {&Cookie{Name: "invalid-path", Path: "/foo;bar/"}, false}, + {&Cookie{Name: "invalid-domain", Domain: "example.com:80"}, false}, + {&Cookie{Name: "valid", Value: "foo", Path: "/bar", Domain: "example.com", Expires: time.Unix(0, 0)}, true}, + } + + for _, tt := range tests { + err := tt.cookie.Valid() + if err != nil && tt.valid { + t.Errorf("%#v.Valid() returned error %v; want nil", tt.cookie, err) + } + if err == nil && !tt.valid { + t.Errorf("%#v.Valid() returned nil; want error", tt.cookie) + } + } +} + func BenchmarkCookieString(b *testing.B) { const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600` c := &Cookie{ diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 096a6d3..a849327 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) { func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } -func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-ch - cancel() - }() +func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler { return &timeoutHandler{ handler: handler, testContext: ctx, diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index 57e731e..6caee9e 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -42,20 +42,20 @@ import ( // An empty Dir is treated as ".". type Dir string -// mapDirOpenError maps the provided non-nil error from opening name +// mapOpenError maps the provided non-nil error from opening name // to a possibly better non-nil error. In particular, it turns OS-specific errors -// about opening files in non-directories into fs.ErrNotExist. See Issue 18984. -func mapDirOpenError(originalErr error, name string) error { +// about opening files in non-directories into fs.ErrNotExist. See Issues 18984 and 49552. +func mapOpenError(originalErr error, name string, sep rune, stat func(string) (fs.FileInfo, error)) error { if errors.Is(originalErr, fs.ErrNotExist) || errors.Is(originalErr, fs.ErrPermission) { return originalErr } - parts := strings.Split(name, string(filepath.Separator)) + parts := strings.Split(name, string(sep)) for i := range parts { if parts[i] == "" { continue } - fi, err := os.Stat(strings.Join(parts[:i+1], string(filepath.Separator))) + fi, err := stat(strings.Join(parts[:i+1], string(sep))) if err != nil { return originalErr } @@ -79,7 +79,7 @@ func (d Dir) Open(name string) (File, error) { fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))) f, err := os.Open(fullName) if err != nil { - return nil, mapDirOpenError(err, fullName) + return nil, mapOpenError(err, fullName, filepath.Separator, os.Stat) } return f, nil } @@ -759,7 +759,9 @@ func (f ioFS) Open(name string) (File, error) { } file, err := f.fsys.Open(name) if err != nil { - return nil, err + return nil, mapOpenError(err, name, '/', func(path string) (fs.FileInfo, error) { + return fs.Stat(f.fsys, path) + }) } return ioFile{file}, nil } @@ -881,11 +883,11 @@ func parseRange(s string, size int64) ([]httpRange, error) { if ra == "" { continue } - i := strings.Index(ra, "-") - if i < 0 { + start, end, ok := strings.Cut(ra, "-") + if !ok { return nil, errors.New("invalid range") } - start, end := textproto.TrimString(ra[:i]), textproto.TrimString(ra[i+1:]) + start, end = textproto.TrimString(start), textproto.TrimString(end) var r httpRange if start == "" { // If no start is specified, end specifies the diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index b42ade1..d627dfd 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -658,7 +658,7 @@ type fakeFileInfo struct { } func (f *fakeFileInfo) Name() string { return f.basename } -func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) Sys() any { return nil } func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } func (f *fakeFileInfo) IsDir() bool { return f.dir } func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } @@ -1244,10 +1244,19 @@ func TestLinuxSendfileChild(*testing.T) { } } -// Issue 18984: tests that requests for paths beyond files return not-found errors +// Issues 18984, 49552: tests that requests for paths beyond files return not-found errors func TestFileServerNotDirError(t *testing.T) { defer afterTest(t) - ts := httptest.NewServer(FileServer(Dir("testdata"))) + t.Run("Dir", func(t *testing.T) { + testFileServerNotDirError(t, func(path string) FileSystem { return Dir(path) }) + }) + t.Run("FS", func(t *testing.T) { + testFileServerNotDirError(t, func(path string) FileSystem { return FS(os.DirFS(path)) }) + }) +} + +func testFileServerNotDirError(t *testing.T, newfs func(string) FileSystem) { + ts := httptest.NewServer(FileServer(newfs("testdata"))) defer ts.Close() res, err := Get(ts.URL + "/index.html/not-a-file") @@ -1259,9 +1268,9 @@ func TestFileServerNotDirError(t *testing.T) { t.Errorf("StatusCode = %v; want 404", res.StatusCode) } - test := func(name string, dir Dir) { + test := func(name string, fsys FileSystem) { t.Run(name, func(t *testing.T) { - _, err = dir.Open("/index.html/not-a-file") + _, err = fsys.Open("/index.html/not-a-file") if err == nil { t.Fatal("err == nil; want != nil") } @@ -1270,7 +1279,7 @@ func TestFileServerNotDirError(t *testing.T) { errors.Is(err, fs.ErrNotExist)) } - _, err = dir.Open("/index.html/not-a-dir/not-a-file") + _, err = fsys.Open("/index.html/not-a-dir/not-a-file") if err == nil { t.Fatal("err == nil; want != nil") } @@ -1286,8 +1295,8 @@ func TestFileServerNotDirError(t *testing.T) { t.Fatal("get abs path:", err) } - test("RelativePath", Dir("testdata")) - test("AbsolutePath", Dir(absPath)) + test("RelativePath", newfs("testdata")) + test("AbsolutePath", newfs(absPath)) } func TestFileServerCleanPath(t *testing.T) { diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 8958a9e..bb82f24 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -53,6 +53,10 @@ import ( "golang.org/x/net/idna" ) +// The HTTP protocols are defined in terms of ASCII, not Unicode. This file +// contains helper functions which may use Unicode-aware functions which would +// otherwise be unsafe and could introduce vulnerabilities if used improperly. + // asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t // are equal, ASCII-case-insensitively. func http2asciiEqualFold(s, t string) bool { @@ -733,6 +737,12 @@ func http2isBadCipher(cipher uint16) bool { // ClientConnPool manages a pool of HTTP/2 client connections. type http2ClientConnPool interface { + // GetClientConn returns a specific HTTP/2 connection (usually + // a TLS-TCP connection) to an HTTP/2 server. On success, the + // returned ClientConn accounts for the upcoming RoundTrip + // call, so the caller should not omit it. If the caller needs + // to, ClientConn.RoundTrip can be called with a bogus + // new(http.Request) to release the stream reservation. GetClientConn(req *Request, addr string) (*http2ClientConn, error) MarkDead(*http2ClientConn) } @@ -759,7 +769,7 @@ type http2clientConnPool struct { conns map[string][]*http2ClientConn // key is host:port dialing map[string]*http2dialCall // currently in-flight dials keys map[*http2ClientConn][]string - addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeede calls + addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls } func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { @@ -771,28 +781,8 @@ const ( http2noDialOnMiss = false ) -// shouldTraceGetConn reports whether getClientConn should call any -// ClientTrace.GetConn hook associated with the http.Request. -// -// This complexity is needed to avoid double calls of the GetConn hook -// during the back-and-forth between net/http and x/net/http2 (when the -// net/http.Transport is upgraded to also speak http2), as well as support -// the case where x/net/http2 is being used directly. -func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool { - // If our Transport wasn't made via ConfigureTransport, always - // trace the GetConn hook if provided, because that means the - // http2 package is being used directly and it's the one - // dialing, as opposed to net/http. - if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok { - return true - } - // Otherwise, only use the GetConn hook if this connection has - // been used previously for other requests. For fresh - // connections, the net/http package does the dialing. - return !st.freshConn -} - func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + // TODO(dneil): Dial a new connection when t.DisableKeepAlives is set? if http2isConnectionCloseRequest(req) && dialOnMiss { // It gets its own connection. http2traceGetConn(req, addr) @@ -806,10 +796,14 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis for { p.mu.Lock() for _, cc := range p.conns[addr] { - if st := cc.idleState(); st.canTakeNewRequest { - if p.shouldTraceGetConn(st) { + if cc.ReserveNewRequest() { + // When a connection is presented to us by the net/http package, + // the GetConn hook has already been called. + // Don't call it a second time here. + if !cc.getConnCalled { http2traceGetConn(req, addr) } + cc.getConnCalled = false p.mu.Unlock() return cc, nil } @@ -825,7 +819,13 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis if http2shouldRetryDial(call, req) { continue } - return call.res, call.err + cc, err := call.res, call.err + if err != nil { + return nil, err + } + if cc.ReserveNewRequest() { + return cc, nil + } } } @@ -922,6 +922,7 @@ func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) { if err != nil { c.err = err } else { + cc.getConnCalled = true // already called by the net/http package p.addConnLocked(key, cc) } delete(p.addConnCalls, key) @@ -1208,6 +1209,13 @@ func (e http2ErrCode) String() string { return fmt.Sprintf("unknown error code 0x%x", uint32(e)) } +func (e http2ErrCode) stringToken() string { + if s, ok := http2errCodeName[e]; ok { + return s + } + return fmt.Sprintf("ERR_UNKNOWN_%d", uint32(e)) +} + // ConnectionError is an error that results in the termination of the // entire connection. type http2ConnectionError http2ErrCode @@ -1224,6 +1232,11 @@ type http2StreamError struct { Cause error // optional additional detail } +// errFromPeer is a sentinel error value for StreamError.Cause to +// indicate that the StreamError was sent from the peer over the wire +// and wasn't locally generated in the Transport. +var http2errFromPeer = errors.New("received from peer") + func http2streamError(id uint32, code http2ErrCode) http2StreamError { return http2StreamError{StreamID: id, Code: code} } @@ -1438,7 +1451,7 @@ var http2flagName = map[http2FrameType]map[http2Flags]string{ // a frameParser parses a frame given its FrameHeader and payload // bytes. The length of payload will always equal fh.Length (which // might be 0). -type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) +type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) var http2frameParsers = map[http2FrameType]http2frameParser{ http2FrameData: http2parseDataFrame, @@ -1583,6 +1596,11 @@ type http2Framer struct { lastFrame http2Frame errDetail error + // countError is a non-nil func that's called on a frame parse + // error with some unique error path token. It's initialized + // from Transport.CountError or Server.CountError. + countError func(errToken string) + // lastHeaderStream is non-zero if the last frame was an // unfinished HEADERS/CONTINUATION. lastHeaderStream uint32 @@ -1745,6 +1763,7 @@ func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ w: w, r: r, + countError: func(string) {}, logReads: http2logFrameReads, logWrites: http2logFrameWrites, debugReadLoggerf: log.Printf, @@ -1819,7 +1838,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { if _, err := io.ReadFull(fr.r, payload); err != nil { return nil, err } - f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, payload) + f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, fr.countError, payload) if err != nil { if ce, ok := err.(http2connError); ok { return nil, fr.connError(ce.Code, ce.Reason) @@ -1907,13 +1926,14 @@ func (f *http2DataFrame) Data() []byte { return f.data } -func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { if fh.StreamID == 0 { // DATA frames MUST be associated with a stream. If a // DATA frame is received whose stream identifier // field is 0x0, the recipient MUST respond with a // connection error (Section 5.4.1) of type // PROTOCOL_ERROR. + countError("frame_data_stream_0") return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} } f := fc.getDataFrame() @@ -1924,6 +1944,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byt var err error payload, padSize, err = http2readByte(payload) if err != nil { + countError("frame_data_pad_byte_short") return nil, err } } @@ -1932,6 +1953,7 @@ func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byt // length of the frame payload, the recipient MUST // treat this as a connection error. // Filed: https://github.com/http2/http2-spec/issues/610 + countError("frame_data_pad_too_big") return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} } f.data = payload[:len(payload)-int(padSize)] @@ -2014,7 +2036,7 @@ type http2SettingsFrame struct { p []byte } -func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { // When this (ACK 0x1) bit is set, the payload of the // SETTINGS frame MUST be empty. Receipt of a @@ -2022,6 +2044,7 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) // field value other than 0 MUST be treated as a // connection error (Section 5.4.1) of type // FRAME_SIZE_ERROR. + countError("frame_settings_ack_with_length") return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID != 0 { @@ -2032,14 +2055,17 @@ func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) // field is anything other than 0x0, the endpoint MUST // respond with a connection error (Section 5.4.1) of // type PROTOCOL_ERROR. + countError("frame_settings_has_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } if len(p)%6 != 0 { + countError("frame_settings_mod_6") // Expecting even number of 6 byte settings. return nil, http2ConnectionError(http2ErrCodeFrameSize) } f := &http2SettingsFrame{http2FrameHeader: fh, p: p} if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { + countError("frame_settings_window_size_too_big") // Values above the maximum flow control window size of 2^31 - 1 MUST // be treated as a connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR. @@ -2151,11 +2177,13 @@ type http2PingFrame struct { func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } -func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { if len(payload) != 8 { + countError("frame_ping_length") return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID != 0 { + countError("frame_ping_has_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } f := &http2PingFrame{http2FrameHeader: fh} @@ -2191,11 +2219,13 @@ func (f *http2GoAwayFrame) DebugData() []byte { return f.debugData } -func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if fh.StreamID != 0 { + countError("frame_goaway_has_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } if len(p) < 8 { + countError("frame_goaway_short") return nil, http2ConnectionError(http2ErrCodeFrameSize) } return &http2GoAwayFrame{ @@ -2231,7 +2261,7 @@ func (f *http2UnknownFrame) Payload() []byte { return f.p } -func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { return &http2UnknownFrame{fh, p}, nil } @@ -2242,8 +2272,9 @@ type http2WindowUpdateFrame struct { Increment uint32 // never read with high bit set } -func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if len(p) != 4 { + countError("frame_windowupdate_bad_len") return nil, http2ConnectionError(http2ErrCodeFrameSize) } inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit @@ -2255,8 +2286,10 @@ func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []by // control window MUST be treated as a connection // error (Section 5.4.1). if fh.StreamID == 0 { + countError("frame_windowupdate_zero_inc_conn") return nil, http2ConnectionError(http2ErrCodeProtocol) } + countError("frame_windowupdate_zero_inc_stream") return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) } return &http2WindowUpdateFrame{ @@ -2307,7 +2340,7 @@ func (f *http2HeadersFrame) HasPriority() bool { return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) } -func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { +func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { hf := &http2HeadersFrame{ http2FrameHeader: fh, } @@ -2316,11 +2349,13 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) ( // is received whose stream identifier field is 0x0, the recipient MUST // respond with a connection error (Section 5.4.1) of type // PROTOCOL_ERROR. + countError("frame_headers_zero_stream") return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} } var padLength uint8 if fh.Flags.Has(http2FlagHeadersPadded) { if p, padLength, err = http2readByte(p); err != nil { + countError("frame_headers_pad_short") return } } @@ -2328,16 +2363,19 @@ func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) ( var v uint32 p, v, err = http2readUint32(p) if err != nil { + countError("frame_headers_prio_short") return nil, err } hf.Priority.StreamDep = v & 0x7fffffff hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set p, hf.Priority.Weight, err = http2readByte(p) if err != nil { + countError("frame_headers_prio_weight_short") return nil, err } } - if len(p)-int(padLength) <= 0 { + if len(p)-int(padLength) < 0 { + countError("frame_headers_pad_too_big") return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) } hf.headerFragBuf = p[:len(p)-int(padLength)] @@ -2444,11 +2482,13 @@ func (p http2PriorityParam) IsZero() bool { return p == http2PriorityParam{} } -func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), payload []byte) (http2Frame, error) { if fh.StreamID == 0 { + countError("frame_priority_zero_stream") return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} } if len(payload) != 5 { + countError("frame_priority_bad_length") return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} } v := binary.BigEndian.Uint32(payload[:4]) @@ -2491,11 +2531,13 @@ type http2RSTStreamFrame struct { ErrCode http2ErrCode } -func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if len(p) != 4 { + countError("frame_rststream_bad_len") return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID == 0 { + countError("frame_rststream_zero_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil @@ -2521,8 +2563,9 @@ type http2ContinuationFrame struct { headerFragBuf []byte } -func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (http2Frame, error) { if fh.StreamID == 0 { + countError("frame_continuation_zero_stream") return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} } return &http2ContinuationFrame{fh, p}, nil @@ -2571,7 +2614,7 @@ func (f *http2PushPromiseFrame) HeadersEnded() bool { return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) } -func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { +func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, countError func(string), p []byte) (_ http2Frame, err error) { pp := &http2PushPromiseFrame{ http2FrameHeader: fh, } @@ -2582,6 +2625,7 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ // with. If the stream identifier field specifies the value // 0x0, a recipient MUST respond with a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. + countError("frame_pushpromise_zero_stream") return nil, http2ConnectionError(http2ErrCodeProtocol) } // The PUSH_PROMISE frame includes optional padding. @@ -2589,18 +2633,21 @@ func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ var padLength uint8 if fh.Flags.Has(http2FlagPushPromisePadded) { if p, padLength, err = http2readByte(p); err != nil { + countError("frame_pushpromise_pad_short") return } } p, pp.PromiseID, err = http2readUint32(p) if err != nil { + countError("frame_pushpromise_promiseid_short") return } pp.PromiseID = pp.PromiseID & (1<<31 - 1) if int(padLength) > len(p) { // like the DATA frame, error out if padding is longer than the body. + countError("frame_pushpromise_pad_too_big") return nil, http2ConnectionError(http2ErrCodeProtocol) } pp.headerFragBuf = p[:len(p)-int(padLength)] @@ -3570,6 +3617,17 @@ type http2pipeBuffer interface { io.Reader } +// setBuffer initializes the pipe buffer. +// It has no effect if the pipe is already closed. +func (p *http2pipe) setBuffer(b http2pipeBuffer) { + p.mu.Lock() + defer p.mu.Unlock() + if p.err != nil || p.breakErr != nil { + return + } + p.b = b +} + func (p *http2pipe) Len() int { p.mu.Lock() defer p.mu.Unlock() @@ -3786,6 +3844,12 @@ type http2Server struct { // If nil, a default scheduler is chosen. NewWriteScheduler func() http2WriteScheduler + // CountError, if non-nil, is called on HTTP/2 server errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + // Internal state. This is a pointer (rather than embedded directly) // so that we don't embed a Mutex in this struct, which will make the // struct non-copyable, which might break some callers. @@ -3915,16 +3979,12 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { s.TLSConfig.PreferServerCipherSuites = true - haveNPN := false - for _, p := range s.TLSConfig.NextProtos { - if p == http2NextProtoTLS { - haveNPN = true - break - } - } - if !haveNPN { + if !http2strSliceContains(s.TLSConfig.NextProtos, http2NextProtoTLS) { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) } + if !http2strSliceContains(s.TLSConfig.NextProtos, "http/1.1") { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1") + } if s.TLSNextProto == nil { s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} @@ -4065,6 +4125,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) fr := http2NewFramer(sc.bw, c) + if s.CountError != nil { + fr.countError = s.CountError + } fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.SetMaxReadFrameSize(s.maxReadFrameSize()) @@ -4373,7 +4436,15 @@ func (sc *http2serverConn) canonicalHeader(v string) string { sc.canonHeader = make(map[string]string) } cv = CanonicalHeaderKey(v) - sc.canonHeader[v] = cv + // maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of + // entries in the canonHeader cache. This should be larger than the number + // of unique, uncommon header keys likely to be sent by the peer, while not + // so high as to permit unreaasonable memory usage if the peer sends an unbounded + // number of unique header keys. + const maxCachedCanonicalHeaders = 32 + if len(sc.canonHeader) < maxCachedCanonicalHeaders { + sc.canonHeader[v] = cv + } return cv } @@ -4479,7 +4550,7 @@ func (sc *http2serverConn) serve() { }) sc.unackedSettings++ - // Each connection starts with intialWindowSize inflow tokens. + // Each connection starts with initialWindowSize inflow tokens. // If a higher value is configured, we add more tokens. if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { sc.sendWindowUpdate(nil, int(diff)) @@ -5064,7 +5135,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { // First frame received must be SETTINGS. if !sc.sawFirstSettings { if _, ok := f.(*http2SettingsFrame); !ok { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("first_settings", http2ConnectionError(http2ErrCodeProtocol)) } sc.sawFirstSettings = true } @@ -5089,7 +5160,7 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { case *http2PushPromiseFrame: // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("push_promise", http2ConnectionError(http2ErrCodeProtocol)) default: sc.vlogf("http2: server ignoring frame: %v", f.Header()) return nil @@ -5109,7 +5180,7 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { // identifier field value other than 0x0, the recipient MUST // respond with a connection error (Section 5.4.1) of type // PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("ping_on_stream", http2ConnectionError(http2ErrCodeProtocol)) } if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { return nil @@ -5128,7 +5199,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error // or PRIORITY on a stream in this state MUST be // treated as a connection error (Section 5.4.1) of // type PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("stream_idle", http2ConnectionError(http2ErrCodeProtocol)) } if st == nil { // "WINDOW_UPDATE can be sent by a peer that has sent a @@ -5139,7 +5210,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error return nil } if !st.flow.add(int32(f.Increment)) { - return http2streamError(f.StreamID, http2ErrCodeFlowControl) + return sc.countError("bad_flow", http2streamError(f.StreamID, http2ErrCodeFlowControl)) } default: // connection-level flow control if !sc.flow.add(int32(f.Increment)) { @@ -5160,7 +5231,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { // identifying an idle stream is received, the // recipient MUST treat this as a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("reset_idle_stream", http2ConnectionError(http2ErrCodeProtocol)) } if st != nil { st.cancelCtx() @@ -5212,7 +5283,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { // Why is the peer ACKing settings we never sent? // The spec doesn't mention this case, but // hang up on them anyway. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("ack_mystery", http2ConnectionError(http2ErrCodeProtocol)) } return nil } @@ -5220,7 +5291,7 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { // This isn't actually in the spec, but hang up on // suspiciously large settings frames or those with // duplicate entries. - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("settings_big_or_dups", http2ConnectionError(http2ErrCodeProtocol)) } if err := f.ForeachSetting(sc.processSetting); err != nil { return err @@ -5287,7 +5358,7 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { // control window to exceed the maximum size as a // connection error (Section 5.4.1) of type // FLOW_CONTROL_ERROR." - return http2ConnectionError(http2ErrCodeFlowControl) + return sc.countError("setting_win_size", http2ConnectionError(http2ErrCodeFlowControl)) } } return nil @@ -5320,7 +5391,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // or PRIORITY on a stream in this state MUST be // treated as a connection error (Section 5.4.1) of // type PROTOCOL_ERROR." - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("data_on_idle", http2ConnectionError(http2ErrCodeProtocol)) } // "If a DATA frame is received whose stream is not in "open" @@ -5337,7 +5408,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // and return any flow control bytes since we're not going // to consume them. if sc.inflow.available() < int32(f.Length) { - return http2streamError(id, http2ErrCodeFlowControl) + return sc.countError("data_flow", http2streamError(id, http2ErrCodeFlowControl)) } // Deduct the flow control from inflow, since we're // going to immediately add it back in @@ -5350,7 +5421,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // Already have a stream error in flight. Don't send another. return nil } - return http2streamError(id, http2ErrCodeStreamClosed) + return sc.countError("closed", http2streamError(id, http2ErrCodeStreamClosed)) } if st.body == nil { panic("internal error: should have a body in this state") @@ -5362,12 +5433,12 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // value of a content-length header field does not equal the sum of the // DATA frame payload lengths that form the body. - return http2streamError(id, http2ErrCodeProtocol) + return sc.countError("send_too_much", http2streamError(id, http2ErrCodeProtocol)) } if f.Length > 0 { // Check whether the client has flow control quota. if st.inflow.available() < int32(f.Length) { - return http2streamError(id, http2ErrCodeFlowControl) + return sc.countError("flow_on_data_length", http2streamError(id, http2ErrCodeFlowControl)) } st.inflow.take(int32(f.Length)) @@ -5375,7 +5446,7 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { wrote, err := st.body.Write(data) if err != nil { sc.sendWindowUpdate(nil, int(f.Length)-wrote) - return http2streamError(id, http2ErrCodeStreamClosed) + return sc.countError("body_write_err", http2streamError(id, http2ErrCodeStreamClosed)) } if wrote != len(data) { panic("internal error: bad Writer") @@ -5461,7 +5532,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // stream identifier MUST respond with a connection error // (Section 5.4.1) of type PROTOCOL_ERROR. if id%2 != 1 { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("headers_even", http2ConnectionError(http2ErrCodeProtocol)) } // A HEADERS frame can be used to create a new stream or // send a trailer for an open one. If we already have a stream @@ -5478,7 +5549,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // this state, it MUST respond with a stream error (Section 5.4.2) of // type STREAM_CLOSED. if st.state == http2stateHalfClosedRemote { - return http2streamError(id, http2ErrCodeStreamClosed) + return sc.countError("headers_half_closed", http2streamError(id, http2ErrCodeStreamClosed)) } return st.processTrailerHeaders(f) } @@ -5489,7 +5560,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { // receives an unexpected stream identifier MUST respond with // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. if id <= sc.maxClientStreamID { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("stream_went_down", http2ConnectionError(http2ErrCodeProtocol)) } sc.maxClientStreamID = id @@ -5506,14 +5577,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if sc.curClientStreams+1 > sc.advMaxStreams { if sc.unackedSettings == 0 { // They should know better. - return http2streamError(id, http2ErrCodeProtocol) + return sc.countError("over_max_streams", http2streamError(id, http2ErrCodeProtocol)) } // Assume it's a network race, where they just haven't // received our last SETTINGS update. But actually // this can't happen yet, because we don't yet provide // a way for users to adjust server parameters at // runtime. - return http2streamError(id, http2ErrCodeRefusedStream) + return sc.countError("over_max_streams_race", http2streamError(id, http2ErrCodeRefusedStream)) } initialState := http2stateOpen @@ -5523,7 +5594,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { st := sc.newStream(id, 0, initialState) if f.HasPriority() { - if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + if err := sc.checkPriority(f.StreamID, f.Priority); err != nil { return err } sc.writeSched.AdjustStream(st.id, f.Priority) @@ -5567,15 +5638,15 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { sc := st.sc sc.serveG.check() if st.gotTrailerHeader { - return http2ConnectionError(http2ErrCodeProtocol) + return sc.countError("dup_trailers", http2ConnectionError(http2ErrCodeProtocol)) } st.gotTrailerHeader = true if !f.StreamEnded() { - return http2streamError(st.id, http2ErrCodeProtocol) + return sc.countError("trailers_not_ended", http2streamError(st.id, http2ErrCodeProtocol)) } if len(f.PseudoFields()) > 0 { - return http2streamError(st.id, http2ErrCodeProtocol) + return sc.countError("trailers_pseudo", http2streamError(st.id, http2ErrCodeProtocol)) } if st.trailer != nil { for _, hf := range f.RegularFields() { @@ -5584,7 +5655,7 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { // TODO: send more details to the peer somehow. But http2 has // no way to send debug data at a stream level. Discuss with // HTTP folk. - return http2streamError(st.id, http2ErrCodeProtocol) + return sc.countError("trailers_bogus", http2streamError(st.id, http2ErrCodeProtocol)) } st.trailer[key] = append(st.trailer[key], hf.Value) } @@ -5593,13 +5664,13 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func http2checkPriority(streamID uint32, p http2PriorityParam) error { +func (sc *http2serverConn) checkPriority(streamID uint32, p http2PriorityParam) error { if streamID == p.StreamDep { // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." // Section 5.3.3 says that a stream can depend on one of its dependencies, // so it's only self-dependencies that are forbidden. - return http2streamError(streamID, http2ErrCodeProtocol) + return sc.countError("priority", http2streamError(streamID, http2ErrCodeProtocol)) } return nil } @@ -5608,7 +5679,7 @@ func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { if sc.inGoAway { return nil } - if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + if err := sc.checkPriority(f.StreamID, f.http2PriorityParam); err != nil { return err } sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) @@ -5665,7 +5736,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead isConnect := rp.method == "CONNECT" if isConnect { if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, sc.countError("bad_connect", http2streamError(f.StreamID, http2ErrCodeProtocol)) } } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: @@ -5678,13 +5749,13 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead // "All HTTP/2 requests MUST include exactly one valid // value for the :method, :scheme, and :path // pseudo-header fields" - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, sc.countError("bad_path_method", http2streamError(f.StreamID, http2ErrCodeProtocol)) } bodyOpen := !f.StreamEnded() if rp.method == "HEAD" && bodyOpen { // HEAD requests can't have bodies - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, sc.countError("head_body", http2streamError(f.StreamID, http2ErrCodeProtocol)) } rp.header = make(Header) @@ -5767,7 +5838,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re var err error url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) + return nil, nil, sc.countError("bad_path", http2streamError(st.id, http2ErrCodeProtocol)) } requestURI = rp.path } @@ -6651,6 +6722,34 @@ func http2h1ServerKeepAlivesDisabled(hs *Server) bool { return false } +func (sc *http2serverConn) countError(name string, err error) error { + if sc == nil || sc.srv == nil { + return err + } + f := sc.srv.CountError + if f == nil { + return err + } + var typ string + var code http2ErrCode + switch e := err.(type) { + case http2ConnectionError: + typ = "conn" + code = http2ErrCode(e) + case http2StreamError: + typ = "stream" + code = http2ErrCode(e.Code) + default: + return err + } + codeStr := http2errCodeName[code] + if codeStr == "" { + codeStr = strconv.Itoa(int(code)) + } + f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name)) + return err +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -6666,6 +6765,15 @@ const ( http2transportDefaultStreamMinRefresh = 4 << 10 http2defaultUserAgent = "Go-http-client/2.0" + + // initialMaxConcurrentStreams is a connections maxConcurrentStreams until + // it's received servers initial SETTINGS frame, which corresponds with the + // spec's minimum recommended value. + http2initialMaxConcurrentStreams = 100 + + // defaultMaxConcurrentStreams is a connections default maxConcurrentStreams + // if the server doesn't include one in its initial SETTINGS frame. + http2defaultMaxConcurrentStreams = 1000 ) // Transport is an HTTP/2 Transport. @@ -6736,6 +6844,17 @@ type http2Transport struct { // Defaults to 15s. PingTimeout time.Duration + // WriteByteTimeout is the timeout after which the connection will be + // closed no data can be written to it. The timeout begins when data is + // available to write, and is extended whenever any bytes are written. + WriteByteTimeout time.Duration + + // CountError, if non-nil, is called on HTTP/2 transport errors. + // It's intended to increment a metric for monitoring, such + // as an expvar or Prometheus metric. + // The errType consists of only ASCII word characters. + CountError func(errType string) + // t1, if non-nil, is the standard library Transport using // this transport. Its settings are used (but not its // RoundTrip method, etc). @@ -6842,11 +6961,12 @@ func (t *http2Transport) initConnPool() { // ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. type http2ClientConn struct { - t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls - tlsState *tls.ConnectionState // nil only for specialized impls - reused uint32 // whether conn is being reused; atomic - singleUse bool // whether being used for a single http.Request + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic + singleUse bool // whether being used for a single http.Request + getConnCalled bool // used by clientConnPool // readLoop goroutine fields: readerDone chan struct{} // closed on error @@ -6859,87 +6979,94 @@ type http2ClientConn struct { cond *sync.Cond // hold mu; broadcast on flow/closed changes flow http2flow // our conn-level flow control quota (cs.flow is per stream) inflow http2flow // peer's conn-level flow control + doNotReuse bool // whether conn is marked to not be reused for any future requests closing bool closed bool + seenSettings bool // true if we've seen a settings frame, false otherwise wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated + streamsReserved int // incr by ReserveNewRequest; decr on RoundTrip nextStreamID uint32 pendingRequests int // requests blocked and waiting to be sent because len(streams) == maxConcurrentStreams pings map[[8]byte]chan struct{} // in flight ping data to notification channel - bw *bufio.Writer br *bufio.Reader - fr *http2Framer lastActive time.Time lastIdle time.Time // time last idle - // Settings from peer: (also guarded by mu) + // Settings from peer: (also guarded by wmu) maxFrameSize uint32 maxConcurrentStreams uint32 peerMaxHeaderListSize uint64 initialWindowSize uint32 - hbuf bytes.Buffer // HPACK encoder writes into this - henc *hpack.Encoder - freeBuf [][]byte + // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. + // Write to reqHeaderMu to lock it, read from it to unlock. + // Lock reqmu BEFORE mu or wmu. + reqHeaderMu chan struct{} - wmu sync.Mutex // held while writing; acquire AFTER mu if holding both - werr error // first write error that has occurred + // wmu is held while writing. + // Acquire BEFORE mu when holding both, to avoid blocking mu on network writes. + // Only acquire both at the same time when changing peer settings. + wmu sync.Mutex + bw *bufio.Writer + fr *http2Framer + werr error // first write error that has occurred + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder } // clientStream is the state for a single HTTP/2 stream. One of these // is created for each Transport.RoundTrip call. type http2clientStream struct { - cc *http2ClientConn - req *Request + cc *http2ClientConn + + // Fields of Request that we may access even after the response body is closed. + ctx context.Context + reqCancel <-chan struct{} + trace *httptrace.ClientTrace // or nil ID uint32 - resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload - startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool - on100 func() // optional code to run if get a 100 continue response + isHead bool + + abortOnce sync.Once + abort chan struct{} // closed to signal stream should end immediately + abortErr error // set if abort is closed + + peerClosed chan struct{} // closed when the peer sends an END_STREAM flag + donec chan struct{} // closed after the stream is in the closed state + on100 chan struct{} // buffered; written to if a 100 is received + + respHeaderRecv chan struct{} // closed when headers are received + res *Response // set if respHeaderRecv is closed flow http2flow // guarded by cc.mu inflow http2flow // guarded by cc.mu bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read - stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu - didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu - peerReset chan struct{} // closed on peer reset - resetErr error // populated before peerReset is closed + reqBody io.ReadCloser + reqBodyContentLength int64 // -1 means unknown + reqBodyClosed bool // body has been closed; guarded by cc.mu - done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu + // owned by writeRequest: + sentEndStream bool // sent an END_STREAM flag to the peer + sentHeaders bool // owned by clientConnReadLoop: firstByte bool // got the first response byte pastHeaders bool // got first MetaHeadersFrame (actual headers) pastTrailers bool // got optional second MetaHeadersFrame (trailers) num1xx uint8 // number of 1xx responses seen + readClosed bool // peer sent an END_STREAM flag + readAborted bool // read loop reset the stream trailer Header // accumulated trailers resTrailer *Header // client's Response.Trailer } -// awaitRequestCancel waits for the user to cancel a request or for the done -// channel to be signaled. A non-nil error is returned only if the request was -// canceled. -func http2awaitRequestCancel(req *Request, done <-chan struct{}) error { - ctx := req.Context() - if req.Cancel == nil && ctx.Done() == nil { - return nil - } - select { - case <-req.Cancel: - return http2errRequestCanceled - case <-ctx.Done(): - return ctx.Err() - case <-done: - return nil - } -} - var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error // get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func, @@ -6951,73 +7078,65 @@ func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) e return http2traceGot1xxResponseFunc(cs.trace) } -// awaitRequestCancel waits for the user to cancel a request, its context to -// expire, or for the request to be done (any way it might be removed from the -// cc.streams map: peer reset, successful completion, TCP connection breakage, -// etc). If the request is canceled, then cs will be canceled and closed. -func (cs *http2clientStream) awaitRequestCancel(req *Request) { - if err := http2awaitRequestCancel(req, cs.done); err != nil { - cs.cancelStream() - cs.bufPipe.CloseWithError(err) - } +func (cs *http2clientStream) abortStream(err error) { + cs.cc.mu.Lock() + defer cs.cc.mu.Unlock() + cs.abortStreamLocked(err) } -func (cs *http2clientStream) cancelStream() { - cc := cs.cc - cc.mu.Lock() - didReset := cs.didReset - cs.didReset = true - cc.mu.Unlock() - - if !didReset { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - cc.forgetStreamID(cs.ID) +func (cs *http2clientStream) abortStreamLocked(err error) { + cs.abortOnce.Do(func() { + cs.abortErr = err + close(cs.abort) + }) + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true } -} - -// checkResetOrDone reports any error sent in a RST_STREAM frame by the -// server, or errStreamClosed if the stream is complete. -func (cs *http2clientStream) checkResetOrDone() error { - select { - case <-cs.peerReset: - return cs.resetErr - case <-cs.done: - return http2errStreamClosed - default: - return nil + // TODO(dneil): Clean up tests where cs.cc.cond is nil. + if cs.cc.cond != nil { + // Wake up writeRequestBody if it is waiting on flow control. + cs.cc.cond.Broadcast() } } -func (cs *http2clientStream) getStartedWrite() bool { +func (cs *http2clientStream) abortRequestBodyWrite() { cc := cs.cc cc.mu.Lock() defer cc.mu.Unlock() - return cs.startedWrite -} - -func (cs *http2clientStream) abortRequestBodyWrite(err error) { - if err == nil { - panic("nil error") + if cs.reqBody != nil && !cs.reqBodyClosed { + cs.reqBody.Close() + cs.reqBodyClosed = true + cc.cond.Broadcast() } - cc := cs.cc - cc.mu.Lock() - cs.stopReqBody = err - cc.cond.Broadcast() - cc.mu.Unlock() } type http2stickyErrWriter struct { - w io.Writer - err *error + conn net.Conn + timeout time.Duration + err *error } func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) { if *sew.err != nil { return 0, *sew.err } - n, err = sew.w.Write(p) - *sew.err = err - return + for { + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout)) + } + nn, err := sew.conn.Write(p[n:]) + n += nn + if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) { + // Keep extending the deadline so long as we're making progress. + continue + } + if sew.timeout != 0 { + sew.conn.SetWriteDeadline(time.Time{}) + } + *sew.err = err + return n, err + } } // noCachedConnError is the concrete type of ErrNoCachedConn, which @@ -7091,9 +7210,9 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) http2traceGotConn(req, cc, reused) - res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req) + res, err := cc.RoundTrip(req) if err != nil && retry <= 6 { - if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil { + if req, err = http2shouldRetryRequest(req, err); err == nil { // After the first retry, do exponential backoff with 10% jitter. if retry == 0 { continue @@ -7104,7 +7223,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res case <-time.After(time.Second * time.Duration(backoff)): continue case <-req.Context().Done(): - return nil, req.Context().Err() + err = req.Context().Err() } } } @@ -7135,7 +7254,7 @@ var ( // response headers. It is always called with a non-nil error. // It returns either a request to retry (either the same request, or a // modified clone), or an error if the request can't be replayed. -func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Request, error) { +func http2shouldRetryRequest(req *Request, err error) (*Request, error) { if !http2canRetryError(err) { return nil, err } @@ -7148,7 +7267,6 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req // If the request body can be reset back to its original // state via the optional req.GetBody, do that. if req.GetBody != nil { - // TODO: consider a req.Body.Close here? or audit that all caller paths do? body, err := req.GetBody() if err != nil { return nil, err @@ -7160,10 +7278,8 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req // The Request.Body can't reset back to the beginning, but we // don't seem to have started to read from it yet, so reuse - // the request directly. The "afterBodyWrite" means the - // bodyWrite process has started, which becomes true before - // the first Read. - if !afterBodyWrite { + // the request directly. + if err == http2errClientConnUnusable { return req, nil } @@ -7175,6 +7291,10 @@ func http2canRetryError(err error) bool { return true } if se, ok := err.(http2StreamError); ok { + if se.Code == http2ErrCodeProtocol && se.Cause == http2errFromPeer { + // See golang/go#47635, golang/go#42777 + return true + } return se.Code == http2ErrCodeRefusedStream } return false @@ -7249,14 +7369,15 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client tconn: c, readerDone: make(chan struct{}), nextStreamID: 1, - maxFrameSize: 16 << 10, // spec default - initialWindowSize: 65535, // spec default - maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. - peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: http2initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings. + peerMaxHeaderListSize: 0xffffffffffffffff, // "infinite", per spec. Use 2^64-1 instead. streams: make(map[uint32]*http2clientStream), singleUse: singleUse, wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), + reqHeaderMu: make(chan struct{}, 1), } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d @@ -7271,9 +7392,16 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client // TODO: adjust this writer size to account for frame size + // MTU + crypto/tls record padding. - cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) + cc.bw = bufio.NewWriter(http2stickyErrWriter{ + conn: c, + timeout: t.WriteByteTimeout, + err: &cc.werr, + }) cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) + if t.CountError != nil { + cc.fr.countError = t.CountError + } cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() @@ -7326,6 +7454,13 @@ func (cc *http2ClientConn) healthCheck() { } } +// SetDoNotReuse marks cc as not reusable for future HTTP requests. +func (cc *http2ClientConn) SetDoNotReuse() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.doNotReuse = true +} + func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() @@ -7343,27 +7478,94 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { last := f.LastStreamID for streamID, cs := range cc.streams { if streamID > last { - select { - case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}: - default: - } + cs.abortStreamLocked(http2errClientConnGotGoAway) } } } // CanTakeNewRequest reports whether the connection can take a new request, // meaning it has not been closed or received or sent a GOAWAY. +// +// If the caller is going to immediately make a new request on this +// connection, use ReserveNewRequest instead. func (cc *http2ClientConn) CanTakeNewRequest() bool { cc.mu.Lock() defer cc.mu.Unlock() return cc.canTakeNewRequestLocked() } +// ReserveNewRequest is like CanTakeNewRequest but also reserves a +// concurrent stream in cc. The reservation is decremented on the +// next call to RoundTrip. +func (cc *http2ClientConn) ReserveNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + if st := cc.idleStateLocked(); !st.canTakeNewRequest { + return false + } + cc.streamsReserved++ + return true +} + +// ClientConnState describes the state of a ClientConn. +type http2ClientConnState struct { + // Closed is whether the connection is closed. + Closed bool + + // Closing is whether the connection is in the process of + // closing. It may be closing due to shutdown, being a + // single-use connection, being marked as DoNotReuse, or + // having received a GOAWAY frame. + Closing bool + + // StreamsActive is how many streams are active. + StreamsActive int + + // StreamsReserved is how many streams have been reserved via + // ClientConn.ReserveNewRequest. + StreamsReserved int + + // StreamsPending is how many requests have been sent in excess + // of the peer's advertised MaxConcurrentStreams setting and + // are waiting for other streams to complete. + StreamsPending int + + // MaxConcurrentStreams is how many concurrent streams the + // peer advertised as acceptable. Zero means no SETTINGS + // frame has been received yet. + MaxConcurrentStreams uint32 + + // LastIdle, if non-zero, is when the connection last + // transitioned to idle state. + LastIdle time.Time +} + +// State returns a snapshot of cc's state. +func (cc *http2ClientConn) State() http2ClientConnState { + cc.wmu.Lock() + maxConcurrent := cc.maxConcurrentStreams + if !cc.seenSettings { + maxConcurrent = 0 + } + cc.wmu.Unlock() + + cc.mu.Lock() + defer cc.mu.Unlock() + return http2ClientConnState{ + Closed: cc.closed, + Closing: cc.closing || cc.singleUse || cc.doNotReuse || cc.goAway != nil, + StreamsActive: len(cc.streams), + StreamsReserved: cc.streamsReserved, + StreamsPending: cc.pendingRequests, + LastIdle: cc.lastIdle, + MaxConcurrentStreams: maxConcurrent, + } +} + // clientConnIdleState describes the suitability of a client // connection to initiate a new RoundTrip request. type http2clientConnIdleState struct { canTakeNewRequest bool - freshConn bool // whether it's unused by any previous request } func (cc *http2ClientConn) idleState() http2clientConnIdleState { @@ -7384,13 +7586,13 @@ func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) { // writing it. maxConcurrentOkay = true } else { - maxConcurrentOkay = int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) + maxConcurrentOkay = int64(len(cc.streams)+cc.streamsReserved+1) <= int64(cc.maxConcurrentStreams) } st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing && maxConcurrentOkay && + !cc.doNotReuse && int64(cc.nextStreamID)+2*int64(cc.pendingRequests) < math.MaxInt32 && !cc.tooIdleLocked() - st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest return } @@ -7421,7 +7623,7 @@ func (cc *http2ClientConn) onIdleTimeout() { func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() - if len(cc.streams) > 0 { + if len(cc.streams) > 0 || cc.streamsReserved > 0 { cc.mu.Unlock() return } @@ -7436,9 +7638,15 @@ func (cc *http2ClientConn) closeIfIdle() { cc.tconn.Close() } +func (cc *http2ClientConn) isDoNotReuseAndIdle() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.doNotReuse && len(cc.streams) == 0 +} + var http2shutdownEnterWaitStateHook = func() {} -// Shutdown gracefully close the client connection, waiting for running streams to complete. +// Shutdown gracefully closes the client connection, waiting for running streams to complete. func (cc *http2ClientConn) Shutdown(ctx context.Context) error { if err := cc.sendGoAway(); err != nil { return err @@ -7477,15 +7685,18 @@ func (cc *http2ClientConn) Shutdown(ctx context.Context) error { func (cc *http2ClientConn) sendGoAway() error { cc.mu.Lock() - defer cc.mu.Unlock() - cc.wmu.Lock() - defer cc.wmu.Unlock() - if cc.closing { + closing := cc.closing + cc.closing = true + maxStreamID := cc.nextStreamID + cc.mu.Unlock() + if closing { // GOAWAY sent already return nil } + + cc.wmu.Lock() + defer cc.wmu.Unlock() // Send a graceful shutdown frame to server - maxStreamID := cc.nextStreamID if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil { return err } @@ -7493,7 +7704,6 @@ func (cc *http2ClientConn) sendGoAway() error { return err } // Prevent new requests - cc.closing = true return nil } @@ -7501,17 +7711,12 @@ func (cc *http2ClientConn) sendGoAway() error { // err is sent to streams. func (cc *http2ClientConn) closeForError(err error) error { cc.mu.Lock() + cc.closed = true + for _, cs := range cc.streams { + cs.abortStreamLocked(err) + } defer cc.cond.Broadcast() defer cc.mu.Unlock() - for id, cs := range cc.streams { - select { - case cs.resc <- http2resAndError{err: err}: - default: - } - cs.bufPipe.CloseWithError(err) - delete(cc.streams, id) - } - cc.closed = true return cc.tconn.Close() } @@ -7526,47 +7731,10 @@ func (cc *http2ClientConn) Close() error { // closes the client connection immediately. In-flight requests are interrupted. func (cc *http2ClientConn) closeForLostPing() error { err := errors.New("http2: client connection lost") - return cc.closeForError(err) -} - -const http2maxAllocFrameSize = 512 << 10 - -// frameBuffer returns a scratch buffer suitable for writing DATA frames. -// They're capped at the min of the peer's max frame size or 512KB -// (kinda arbitrarily), but definitely capped so we don't allocate 4GB -// bufers. -func (cc *http2ClientConn) frameScratchBuffer() []byte { - cc.mu.Lock() - size := cc.maxFrameSize - if size > http2maxAllocFrameSize { - size = http2maxAllocFrameSize - } - for i, buf := range cc.freeBuf { - if len(buf) >= int(size) { - cc.freeBuf[i] = nil - cc.mu.Unlock() - return buf[:size] - } - } - cc.mu.Unlock() - return make([]byte, size) -} - -func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { - cc.mu.Lock() - defer cc.mu.Unlock() - const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate. - if len(cc.freeBuf) < maxBufs { - cc.freeBuf = append(cc.freeBuf, buf) - return - } - for i, old := range cc.freeBuf { - if old == nil { - cc.freeBuf[i] = buf - return - } + if f := cc.t.CountError; f != nil { + f("conn_close_lost_ping") } - // forget about it. + return cc.closeForError(err) } // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not @@ -7630,41 +7798,158 @@ func http2actualContentLength(req *Request) int64 { return -1 } +func (cc *http2ClientConn) decrStreamReservations() { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.decrStreamReservationsLocked() +} + +func (cc *http2ClientConn) decrStreamReservationsLocked() { + if cc.streamsReserved > 0 { + cc.streamsReserved-- + } +} + func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { - resp, _, err := cc.roundTrip(req) - return resp, err + ctx := req.Context() + cs := &http2clientStream{ + cc: cc, + ctx: ctx, + reqCancel: req.Cancel, + isHead: req.Method == "HEAD", + reqBody: req.Body, + reqBodyContentLength: http2actualContentLength(req), + trace: httptrace.ContextClientTrace(ctx), + peerClosed: make(chan struct{}), + abort: make(chan struct{}), + respHeaderRecv: make(chan struct{}), + donec: make(chan struct{}), + } + go cs.doRequest(req) + + waitDone := func() error { + select { + case <-cs.donec: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + } + } + + handleResponseHeaders := func() (*Response, error) { + res := cs.res + if res.StatusCode > 299 { + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. + cs.abortRequestBodyWrite() + } + res.Request = req + res.TLS = cc.tlsState + if res.Body == http2noBody && http2actualContentLength(req) == 0 { + // If there isn't a request or response body still being + // written, then wait for the stream to be closed before + // RoundTrip returns. + if err := waitDone(); err != nil { + return nil, err + } + } + return res, nil + } + + for { + select { + case <-cs.respHeaderRecv: + return handleResponseHeaders() + case <-cs.abort: + select { + case <-cs.respHeaderRecv: + // If both cs.respHeaderRecv and cs.abort are signaling, + // pick respHeaderRecv. The server probably wrote the + // response and immediately reset the stream. + // golang.org/issue/49645 + return handleResponseHeaders() + default: + waitDone() + return nil, cs.abortErr + } + case <-ctx.Done(): + err := ctx.Err() + cs.abortStream(err) + return nil, err + case <-cs.reqCancel: + cs.abortStream(http2errRequestCanceled) + return nil, http2errRequestCanceled + } + } } -func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterReqBodyWrite bool, err error) { +// doRequest runs for the duration of the request lifetime. +// +// It sends the request and performs post-request cleanup (closing Request.Body, etc.). +func (cs *http2clientStream) doRequest(req *Request) { + err := cs.writeRequest(req) + cs.cleanupWriteRequest(err) +} + +// writeRequest sends a request. +// +// It returns nil after the request is written, the response read, +// and the request stream is half-closed by the peer. +// +// It returns non-nil if the request ends otherwise. +// If the returned error is StreamError, the error Code may be used in resetting the stream. +func (cs *http2clientStream) writeRequest(req *Request) (err error) { + cc := cs.cc + ctx := cs.ctx + if err := http2checkConnHeaders(req); err != nil { - return nil, false, err - } - if cc.idleTimer != nil { - cc.idleTimer.Stop() + return err } - trailers, err := http2commaSeparatedTrailers(req) - if err != nil { - return nil, false, err + // Acquire the new-request lock by writing to reqHeaderMu. + // This lock guards the critical section covering allocating a new stream ID + // (requires mu) and creating the stream (requires wmu). + if cc.reqHeaderMu == nil { + panic("RoundTrip on uninitialized ClientConn") // for tests + } + select { + case cc.reqHeaderMu <- struct{}{}: + case <-cs.reqCancel: + return http2errRequestCanceled + case <-ctx.Done(): + return ctx.Err() } - hasTrailers := trailers != "" cc.mu.Lock() - if err := cc.awaitOpenSlotForRequest(req); err != nil { + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + cc.decrStreamReservationsLocked() + if err := cc.awaitOpenSlotForStreamLocked(cs); err != nil { cc.mu.Unlock() - return nil, false, err + <-cc.reqHeaderMu + return err } - - body := req.Body - contentLen := http2actualContentLength(req) - hasBody := contentLen != 0 + cc.addStreamLocked(cs) // assigns stream ID + if http2isConnectionCloseRequest(req) { + cc.doNotReuse = true + } + cc.mu.Unlock() // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? - var requestedGzip bool if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && - req.Method != "HEAD" { + !cs.isHead { // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: https://zlib.net/zlib_faq.html#faq39 @@ -7677,210 +7962,224 @@ func (cc *http2ClientConn) roundTrip(req *Request) (res *Response, gotErrAfterRe // We don't request gzip if the request is for a range, since // auto-decoding a portion of a gzipped document will just fail // anyway. See https://golang.org/issue/8923 - requestedGzip = true + cs.requestedGzip = true } - // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is - // sent by writeRequestBody below, along with any Trailers, - // again in form HEADERS{1}, CONTINUATION{0,}) - hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) - if err != nil { - cc.mu.Unlock() - return nil, false, err + continueTimeout := cc.t.expectContinueTimeout() + if continueTimeout != 0 { + if !httpguts.HeaderValuesContainsToken(req.Header["Expect"], "100-continue") { + continueTimeout = 0 + } else { + cs.on100 = make(chan struct{}, 1) + } } - cs := cc.newStream() - cs.req = req - cs.trace = httptrace.ContextClientTrace(req.Context()) - cs.requestedGzip = requestedGzip - bodyWriter := cc.t.getBodyWriterState(cs, body) - cs.on100 = bodyWriter.on100 + // Past this point (where we send request headers), it is possible for + // RoundTrip to return successfully. Since the RoundTrip contract permits + // the caller to "mutate or reuse" the Request after closing the Response's Body, + // we must take care when referencing the Request from here on. + err = cs.encodeAndWriteHeaders(req) + <-cc.reqHeaderMu + if err != nil { + return err + } - defer func() { - cc.wmu.Lock() - werr := cc.werr - cc.wmu.Unlock() - if werr != nil { - cc.Close() + hasBody := cs.reqBodyContentLength != 0 + if !hasBody { + cs.sentEndStream = true + } else { + if continueTimeout != 0 { + http2traceWait100Continue(cs.trace) + timer := time.NewTimer(continueTimeout) + select { + case <-timer.C: + err = nil + case <-cs.on100: + err = nil + case <-cs.abort: + err = cs.abortErr + case <-ctx.Done(): + err = ctx.Err() + case <-cs.reqCancel: + err = http2errRequestCanceled + } + timer.Stop() + if err != nil { + http2traceWroteRequest(cs.trace, err) + return err + } } - }() - - cc.wmu.Lock() - endStream := !hasBody && !hasTrailers - werr := cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) - cc.wmu.Unlock() - http2traceWroteHeaders(cs.trace) - cc.mu.Unlock() - if werr != nil { - if hasBody { - req.Body.Close() // per RoundTripper contract - bodyWriter.cancel() + if err = cs.writeRequestBody(req); err != nil { + if err != http2errStopReqBodyWrite { + http2traceWroteRequest(cs.trace, err) + return err + } + } else { + cs.sentEndStream = true } - cc.forgetStreamID(cs.ID) - // Don't bother sending a RST_STREAM (our write already failed; - // no need to keep writing) - http2traceWroteRequest(cs.trace, werr) - return nil, false, werr } + http2traceWroteRequest(cs.trace, err) + var respHeaderTimer <-chan time.Time - if hasBody { - bodyWriter.scheduleBodyWrite() - } else { - http2traceWroteRequest(cs.trace, nil) - if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) - defer timer.Stop() - respHeaderTimer = timer.C + var respHeaderRecv chan struct{} + if d := cc.responseHeaderTimeout(); d != 0 { + timer := time.NewTimer(d) + defer timer.Stop() + respHeaderTimer = timer.C + respHeaderRecv = cs.respHeaderRecv + } + // Wait until the peer half-closes its end of the stream, + // or until the request is aborted (via context, error, or otherwise), + // whichever comes first. + for { + select { + case <-cs.peerClosed: + return nil + case <-respHeaderTimer: + return http2errTimeout + case <-respHeaderRecv: + respHeaderRecv = nil + respHeaderTimer = nil // keep waiting for END_STREAM + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled } } +} - readLoopResCh := cs.resc - bodyWritten := false - ctx := req.Context() +func (cs *http2clientStream) encodeAndWriteHeaders(req *Request) error { + cc := cs.cc + ctx := cs.ctx - handleReadLoopResponse := func(re http2resAndError) (*Response, bool, error) { - res := re.res - if re.err != nil || res.StatusCode > 299 { - // On error or status code 3xx, 4xx, 5xx, etc abort any - // ongoing write, assuming that the server doesn't care - // about our request body. If the server replied with 1xx or - // 2xx, however, then assume the server DOES potentially - // want our body (e.g. full-duplex streaming: - // golang.org/issue/13444). If it turns out the server - // doesn't, they'll RST_STREAM us soon enough. This is a - // heuristic to avoid adding knobs to Transport. Hopefully - // we can keep it. - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWrite) - if hasBody && !bodyWritten { - <-bodyWriter.resc - } - } - if re.err != nil { - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), re.err - } - res.Request = req - res.TLS = cc.tlsState - return res, false, nil + cc.wmu.Lock() + defer cc.wmu.Unlock() + + // If the request was canceled while waiting for cc.mu, just quit. + select { + case <-cs.abort: + return cs.abortErr + case <-ctx.Done(): + return ctx.Err() + case <-cs.reqCancel: + return http2errRequestCanceled + default: } - for { + // Encode headers. + // + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is + // sent by writeRequestBody below, along with any Trailers, + // again in form HEADERS{1}, CONTINUATION{0,}) + trailers, err := http2commaSeparatedTrailers(req) + if err != nil { + return err + } + hasTrailers := trailers != "" + contentLen := http2actualContentLength(req) + hasBody := contentLen != 0 + hdrs, err := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + if err != nil { + return err + } + + // Write the request. + endStream := !hasBody && !hasTrailers + cs.sentHeaders = true + err = cc.writeHeaders(cs.ID, endStream, int(cc.maxFrameSize), hdrs) + http2traceWroteHeaders(cs.trace) + return err +} + +// cleanupWriteRequest performs post-request tasks. +// +// If err (the result of writeRequest) is non-nil and the stream is not closed, +// cleanupWriteRequest will send a reset to the peer. +func (cs *http2clientStream) cleanupWriteRequest(err error) { + cc := cs.cc + + if cs.ID == 0 { + // We were canceled before creating the stream, so return our reservation. + cc.decrStreamReservations() + } + + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cs.reqBodyClosed = true + cc.mu.Unlock() + if !bodyClosed && cs.reqBody != nil { + cs.reqBody.Close() + } + + if err != nil && cs.sentEndStream { + // If the connection is closed immediately after the response is read, + // we may be aborted before finishing up here. If the stream was closed + // cleanly on both sides, there is no error. select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - case <-respHeaderTimer: - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), http2errTimeout - case <-ctx.Done(): - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), ctx.Err() - case <-req.Cancel: - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + case <-cs.peerClosed: + err = nil + default: + } + } + if err != nil { + cs.abortStream(err) // possibly redundant, but harmless + if cs.sentHeaders { + if se, ok := err.(http2StreamError); ok { + if se.Cause != http2errFromPeer { + cc.writeStreamReset(cs.ID, se.Code, err) + } } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), http2errRequestCanceled - case <-cs.peerReset: - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - // processResetStream already removed the - // stream from the streams map; no need for - // forgetStreamID. - return nil, cs.getStartedWrite(), cs.resetErr - case err := <-bodyWriter.resc: - bodyWritten = true - // Prefer the read loop's response, if available. Issue 16102. - select { - case re := <-readLoopResCh: - return handleReadLoopResponse(re) - default: - } - if err != nil { - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), err - } - if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) - defer timer.Stop() - respHeaderTimer = timer.C + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) } } + cs.bufPipe.CloseWithError(err) // no-op if already closed + } else { + if cs.sentHeaders && !cs.sentEndStream { + cc.writeStreamReset(cs.ID, http2ErrCodeNo, nil) + } + cs.bufPipe.CloseWithError(http2errRequestCanceled) + } + if cs.ID != 0 { + cc.forgetStreamID(cs.ID) + } + + cc.wmu.Lock() + werr := cc.werr + cc.wmu.Unlock() + if werr != nil { + cc.Close() } + + close(cs.donec) } -// awaitOpenSlotForRequest waits until len(streams) < maxConcurrentStreams. +// awaitOpenSlotForStream waits until len(streams) < maxConcurrentStreams. // Must hold cc.mu. -func (cc *http2ClientConn) awaitOpenSlotForRequest(req *Request) error { - var waitingForConn chan struct{} - var waitingForConnErr error // guarded by cc.mu +func (cc *http2ClientConn) awaitOpenSlotForStreamLocked(cs *http2clientStream) error { for { cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { - if waitingForConn != nil { - close(waitingForConn) - } return http2errClientConnUnusable } cc.lastIdle = time.Time{} - if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) { - if waitingForConn != nil { - close(waitingForConn) - } + if int64(len(cc.streams)) < int64(cc.maxConcurrentStreams) { return nil } - // Unfortunately, we cannot wait on a condition variable and channel at - // the same time, so instead, we spin up a goroutine to check if the - // request is canceled while we wait for a slot to open in the connection. - if waitingForConn == nil { - waitingForConn = make(chan struct{}) - go func() { - if err := http2awaitRequestCancel(req, waitingForConn); err != nil { - cc.mu.Lock() - waitingForConnErr = err - cc.cond.Broadcast() - cc.mu.Unlock() - } - }() - } cc.pendingRequests++ cc.cond.Wait() cc.pendingRequests-- - if waitingForConnErr != nil { - return waitingForConnErr + select { + case <-cs.abort: + return cs.abortErr + default: } } } @@ -7907,10 +8206,6 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, maxFram cc.fr.WriteContinuation(streamID, endHeaders, chunk) } } - // TODO(bradfitz): this Flush could potentially block (as - // could the WriteHeaders call(s) above), which means they - // wouldn't respond to Request.Cancel being readable. That's - // rare, but this should probably be in a goroutine. cc.bw.Flush() return cc.werr } @@ -7926,32 +8221,59 @@ var ( http2errReqBodyTooLong = errors.New("http2: request body larger than specified content length") ) -func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { +// frameScratchBufferLen returns the length of a buffer to use for +// outgoing request bodies to read/write to/from. +// +// It returns max(1, min(peer's advertised max frame size, +// Request.ContentLength+1, 512KB)). +func (cs *http2clientStream) frameScratchBufferLen(maxFrameSize int) int { + const max = 512 << 10 + n := int64(maxFrameSize) + if n > max { + n = max + } + if cl := cs.reqBodyContentLength; cl != -1 && cl+1 < n { + // Add an extra byte past the declared content-length to + // give the caller's Request.Body io.Reader a chance to + // give us more bytes than they declared, so we can catch it + // early. + n = cl + 1 + } + if n < 1 { + return 1 + } + return int(n) // doesn't truncate; max is 512K +} + +var http2bufPool sync.Pool // of *[]byte + +func (cs *http2clientStream) writeRequestBody(req *Request) (err error) { cc := cs.cc + body := cs.reqBody sentEnd := false // whether we sent the final DATA frame w/ END_STREAM - buf := cc.frameScratchBuffer() - defer cc.putFrameScratchBuffer(buf) - - defer func() { - http2traceWroteRequest(cs.trace, err) - // TODO: write h12Compare test showing whether - // Request.Body is closed by the Transport, - // and in multiple cases: server replies <=299 and >299 - // while still writing request body - cerr := bodyCloser.Close() - if err == nil { - err = cerr - } - }() - req := cs.req hasTrailers := req.Trailer != nil - remainLen := http2actualContentLength(req) + remainLen := cs.reqBodyContentLength hasContentLen := remainLen != -1 + cc.mu.Lock() + maxFrameSize := int(cc.maxFrameSize) + cc.mu.Unlock() + + // Scratch buffer for reading into & writing from. + scratchLen := cs.frameScratchBufferLen(maxFrameSize) + var buf []byte + if bp, ok := http2bufPool.Get().(*[]byte); ok && len(*bp) >= scratchLen { + defer http2bufPool.Put(bp) + buf = *bp + } else { + buf = make([]byte, scratchLen) + defer http2bufPool.Put(&buf) + } + var sawEOF bool for !sawEOF { - n, err := body.Read(buf[:len(buf)-1]) + n, err := body.Read(buf[:len(buf)]) if hasContentLen { remainLen -= int64(n) if remainLen == 0 && err == nil { @@ -7962,35 +8284,36 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos // to send the END_STREAM bit early, double-check that we're actually // at EOF. Subsequent reads should return (0, EOF) at this point. // If either value is different, we return an error in one of two ways below. + var scratch [1]byte var n1 int - n1, err = body.Read(buf[n:]) + n1, err = body.Read(scratch[:]) remainLen -= int64(n1) } if remainLen < 0 { err = http2errReqBodyTooLong - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) return err } } - if err == io.EOF { - sawEOF = true - err = nil - } else if err != nil { - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, err) - return err + if err != nil { + cc.mu.Lock() + bodyClosed := cs.reqBodyClosed + cc.mu.Unlock() + switch { + case bodyClosed: + return http2errStopReqBodyWrite + case err == io.EOF: + sawEOF = true + err = nil + default: + return err + } } remain := buf[:n] for len(remain) > 0 && err == nil { var allowed int32 allowed, err = cs.awaitFlowControl(len(remain)) - switch { - case err == http2errStopReqBodyWrite: - return err - case err == http2errStopReqBodyWriteAndCancel: - cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) - return err - case err != nil: + if err != nil { return err } cc.wmu.Lock() @@ -8021,24 +8344,26 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos return nil } - var trls []byte - if hasTrailers { - cc.mu.Lock() - trls, err = cc.encodeTrailers(req) - cc.mu.Unlock() - if err != nil { - cc.writeStreamReset(cs.ID, http2ErrCodeInternal, err) - cc.forgetStreamID(cs.ID) - return err - } - } - + // Since the RoundTrip contract permits the caller to "mutate or reuse" + // a request after the Response's Body is closed, verify that this hasn't + // happened before accessing the trailers. cc.mu.Lock() - maxFrameSize := int(cc.maxFrameSize) + trailer := req.Trailer + err = cs.abortErr cc.mu.Unlock() + if err != nil { + return err + } cc.wmu.Lock() defer cc.wmu.Unlock() + var trls []byte + if len(trailer) > 0 { + trls, err = cc.encodeTrailers(trailer) + if err != nil { + return err + } + } // Two ways to send END_STREAM: either with trailers, or // with an empty DATA frame. @@ -8059,17 +8384,24 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos // if the stream is dead. func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) { cc := cs.cc + ctx := cs.ctx cc.mu.Lock() defer cc.mu.Unlock() for { if cc.closed { return 0, http2errClientConnClosed } - if cs.stopReqBody != nil { - return 0, cs.stopReqBody + if cs.reqBodyClosed { + return 0, http2errStopReqBodyWrite } - if err := cs.checkResetOrDone(); err != nil { - return 0, err + select { + case <-cs.abort: + return 0, cs.abortErr + case <-ctx.Done(): + return 0, ctx.Err() + case <-cs.reqCancel: + return 0, http2errRequestCanceled + default: } if a := cs.flow.available(); a > 0 { take := a @@ -8087,9 +8419,14 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er } } -// requires cc.mu be held. +var http2errNilRequestURL = errors.New("http2: Request.URI is nil") + +// requires cc.wmu be held. func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { cc.hbuf.Reset() + if req.URL == nil { + return nil, http2errNilRequestURL + } host := req.Host if host == "" { @@ -8275,12 +8612,12 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { } } -// requires cc.mu be held. -func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) { +// requires cc.wmu be held. +func (cc *http2ClientConn) encodeTrailers(trailer Header) ([]byte, error) { cc.hbuf.Reset() hlSize := uint64(0) - for k, vv := range req.Trailer { + for k, vv := range trailer { for _, v := range vv { hf := hpack.HeaderField{Name: k, Value: v} hlSize += uint64(hf.Size()) @@ -8290,7 +8627,7 @@ func (cc *http2ClientConn) encodeTrailers(req *Request) ([]byte, error) { return nil, http2errRequestHeaderListSize } - for k, vv := range req.Trailer { + for k, vv := range trailer { lowKey, ascii := http2asciiToLower(k) if !ascii { // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header @@ -8320,51 +8657,51 @@ type http2resAndError struct { } // requires cc.mu be held. -func (cc *http2ClientConn) newStream() *http2clientStream { - cs := &http2clientStream{ - cc: cc, - ID: cc.nextStreamID, - resc: make(chan http2resAndError, 1), - peerReset: make(chan struct{}), - done: make(chan struct{}), - } +func (cc *http2ClientConn) addStreamLocked(cs *http2clientStream) { cs.flow.add(int32(cc.initialWindowSize)) cs.flow.setConnFlow(&cc.flow) cs.inflow.add(http2transportDefaultStreamFlow) cs.inflow.setConnFlow(&cc.inflow) + cs.ID = cc.nextStreamID cc.nextStreamID += 2 cc.streams[cs.ID] = cs - return cs + if cs.ID == 0 { + panic("assigned stream ID 0") + } } func (cc *http2ClientConn) forgetStreamID(id uint32) { - cc.streamByID(id, true) -} - -func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream { cc.mu.Lock() - defer cc.mu.Unlock() - cs := cc.streams[id] - if andRemove && cs != nil && !cc.closed { - cc.lastActive = time.Now() - delete(cc.streams, id) - if len(cc.streams) == 0 && cc.idleTimer != nil { - cc.idleTimer.Reset(cc.idleTimeout) - cc.lastIdle = time.Now() - } - close(cs.done) - // Wake up checkResetOrDone via clientStream.awaitFlowControl and - // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + slen := len(cc.streams) + delete(cc.streams, id) + if len(cc.streams) != slen-1 { + panic("forgetting unknown stream id") + } + cc.lastActive = time.Now() + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + cc.lastIdle = time.Now() + } + // Wake up writeRequestBody via clientStream.awaitFlowControl and + // wake up RoundTrip if there is a pending request. + cc.cond.Broadcast() + + closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() + if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, cc.nextStreamID-2) + } + cc.closed = true + defer cc.tconn.Close() } - return cs + + cc.mu.Unlock() } // clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop. type http2clientConnReadLoop struct { - _ http2incomparable - cc *http2ClientConn - closeWhenIdle bool + _ http2incomparable + cc *http2ClientConn } // readLoop runs in its own goroutine and reads and dispatches frames. @@ -8424,23 +8761,49 @@ func (rl *http2clientConnReadLoop) cleanup() { } else if err == io.EOF { err = io.ErrUnexpectedEOF } + cc.closed = true for _, cs := range cc.streams { - cs.bufPipe.CloseWithError(err) // no-op if already closed select { - case cs.resc <- http2resAndError{err: err}: + case <-cs.peerClosed: + // The server closed the stream before closing the conn, + // so no need to interrupt it. default: + cs.abortStreamLocked(err) } - close(cs.done) } - cc.closed = true cc.cond.Broadcast() cc.mu.Unlock() } +// countReadFrameError calls Transport.CountError with a string +// representing err. +func (cc *http2ClientConn) countReadFrameError(err error) { + f := cc.t.CountError + if f == nil || err == nil { + return + } + if ce, ok := err.(http2ConnectionError); ok { + errCode := http2ErrCode(ce) + f(fmt.Sprintf("read_frame_conn_error_%s", errCode.stringToken())) + return + } + if errors.Is(err, io.EOF) { + f("read_frame_eof") + return + } + if errors.Is(err, io.ErrUnexpectedEOF) { + f("read_frame_unexpected_eof") + return + } + if errors.Is(err, http2ErrFrameTooLarge) { + f("read_frame_too_large") + return + } + f("read_frame_other") +} + func (rl *http2clientConnReadLoop) run() error { cc := rl.cc - rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse - gotReply := false // ever saw a HEADERS reply gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout var t *time.Timer @@ -8457,9 +8820,7 @@ func (rl *http2clientConnReadLoop) run() error { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } if se, ok := err.(http2StreamError); ok { - if cs := cc.streamByID(se.StreamID, false); cs != nil { - cs.cc.writeStreamReset(cs.ID, se.Code, err) - cs.cc.forgetStreamID(cs.ID) + if cs := rl.streamByID(se.StreamID); cs != nil { if se.Cause == nil { se.Cause = cc.fr.errDetail } @@ -8467,6 +8828,7 @@ func (rl *http2clientConnReadLoop) run() error { } continue } else if err != nil { + cc.countReadFrameError(err) return err } if http2VerboseLogs { @@ -8479,22 +8841,16 @@ func (rl *http2clientConnReadLoop) run() error { } gotSettings = true } - maybeIdle := false // whether frame might transition us to idle switch f := f.(type) { case *http2MetaHeadersFrame: err = rl.processHeaders(f) - maybeIdle = true - gotReply = true case *http2DataFrame: err = rl.processData(f) - maybeIdle = true case *http2GoAwayFrame: err = rl.processGoAway(f) - maybeIdle = true case *http2RSTStreamFrame: err = rl.processResetStream(f) - maybeIdle = true case *http2SettingsFrame: err = rl.processSettings(f) case *http2PushPromiseFrame: @@ -8512,38 +8868,24 @@ func (rl *http2clientConnReadLoop) run() error { } return err } - if rl.closeWhenIdle && gotReply && maybeIdle { - cc.closeIfIdle() - } } } func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { - cc := rl.cc - cs := cc.streamByID(f.StreamID, false) + cs := rl.streamByID(f.StreamID) if cs == nil { // We'd get here if we canceled a request while the // server had its response still in flight. So if this // was just something we canceled, ignore it. return nil } - if f.StreamEnded() { - // Issue 20521: If the stream has ended, streamByID() causes - // clientStream.done to be closed, which causes the request's bodyWriter - // to be closed with an errStreamClosed, which may be received by - // clientConn.RoundTrip before the result of processing these headers. - // Deferring stream closure allows the header processing to occur first. - // clientConn.RoundTrip may still receive the bodyWriter error first, but - // the fix for issue 16102 prioritises any response. - // - // Issue 22413: If there is no request body, we should close the - // stream before writing to cs.resc so that the stream is closed - // immediately once RoundTrip returns. - if cs.req.Body != nil { - defer cc.forgetStreamID(f.StreamID) - } else { - cc.forgetStreamID(f.StreamID) - } + if cs.readClosed { + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: errors.New("protocol error: headers after END_STREAM"), + }) + return nil } if !cs.firstByte { if cs.trace != nil { @@ -8567,9 +8909,11 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro return err } // Any other error type is a stream error. - cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) - cc.forgetStreamID(cs.ID) - cs.resc <- http2resAndError{err: err} + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + Cause: err, + }) return nil // return nil from process* funcs to keep conn alive } if res == nil { @@ -8577,7 +8921,11 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro return nil } cs.resTrailer = &res.Trailer - cs.resc <- http2resAndError{res: res} + cs.res = res + close(cs.respHeaderRecv) + if f.StreamEnded() { + rl.endStream(cs) + } return nil } @@ -8639,6 +8987,9 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http } if statusCode >= 100 && statusCode <= 199 { + if f.StreamEnded() { + return nil, errors.New("1xx informational response with END_STREAM flag") + } cs.num1xx++ const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http if cs.num1xx > max1xxResponses { @@ -8651,42 +9002,49 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http } if statusCode == 100 { http2traceGot100Continue(cs.trace) - if cs.on100 != nil { - cs.on100() // forces any write delay timer to fire + select { + case cs.on100 <- struct{}{}: + default: } } cs.pastHeaders = false // do it all again return nil, nil } - streamEnded := f.StreamEnded() - isHead := cs.req.Method == "HEAD" - if !streamEnded || isHead { - res.ContentLength = -1 - if clens := res.Header["Content-Length"]; len(clens) == 1 { - if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { - res.ContentLength = int64(cl) - } else { - // TODO: care? unlike http/1, it won't mess up our framing, so it's - // more safe smuggling-wise to ignore. - } - } else if len(clens) > 1 { + res.ContentLength = -1 + if clens := res.Header["Content-Length"]; len(clens) == 1 { + if cl, err := strconv.ParseUint(clens[0], 10, 63); err == nil { + res.ContentLength = int64(cl) + } else { // TODO: care? unlike http/1, it won't mess up our framing, so it's // more safe smuggling-wise to ignore. } + } else if len(clens) > 1 { + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. + } else if f.StreamEnded() && !cs.isHead { + res.ContentLength = 0 } - if streamEnded || isHead { + if cs.isHead { res.Body = http2noBody return res, nil } - cs.bufPipe = http2pipe{b: &http2dataBuffer{expected: res.ContentLength}} + if f.StreamEnded() { + if res.ContentLength > 0 { + res.Body = http2missingBody{} + } else { + res.Body = http2noBody + } + return res, nil + } + + cs.bufPipe.setBuffer(&http2dataBuffer{expected: res.ContentLength}) cs.bytesRemain = res.ContentLength res.Body = http2transportResponseBody{cs} - go cs.awaitRequestCancel(cs.req) - if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + if cs.requestedGzip && http2asciiEqualFold(res.Header.Get("Content-Encoding"), "gzip") { res.Header.Del("Content-Encoding") res.Header.Del("Content-Length") res.ContentLength = -1 @@ -8725,8 +9083,7 @@ func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *htt } // transportResponseBody is the concrete type of Transport.RoundTrip's -// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.body. -// On Close it sends RST_STREAM if EOF wasn't already seen. +// Response.Body. It is an io.ReadCloser. type http2transportResponseBody struct { cs *http2clientStream } @@ -8744,7 +9101,7 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { n = int(cs.bytesRemain) if err == nil { err = errors.New("net/http: server replied with more than declared Content-Length; truncated") - cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, err) + cs.abortStream(err) } cs.readErr = err return int(cs.bytesRemain), err @@ -8762,8 +9119,6 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { } cc.mu.Lock() - defer cc.mu.Unlock() - var connAdd, streamAdd int32 // Check the conn-level first, before the stream-level. if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { @@ -8780,6 +9135,8 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { cs.inflow.add(streamAdd) } } + cc.mu.Unlock() + if connAdd != 0 || streamAdd != 0 { cc.wmu.Lock() defer cc.wmu.Unlock() @@ -8800,34 +9157,45 @@ func (b http2transportResponseBody) Close() error { cs := b.cs cc := cs.cc - serverSentStreamEnd := cs.bufPipe.Err() == io.EOF unread := cs.bufPipe.Len() - - if unread > 0 || !serverSentStreamEnd { + if unread > 0 { cc.mu.Lock() - cc.wmu.Lock() - if !serverSentStreamEnd { - cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel) - cs.didReset = true - } // Return connection-level flow control. if unread > 0 { cc.inflow.add(int32(unread)) + } + cc.mu.Unlock() + + // TODO(dneil): Acquiring this mutex can block indefinitely. + // Move flow control return to a goroutine? + cc.wmu.Lock() + // Return connection-level flow control. + if unread > 0 { cc.fr.WriteWindowUpdate(0, uint32(unread)) } cc.bw.Flush() cc.wmu.Unlock() - cc.mu.Unlock() } cs.bufPipe.BreakWithError(http2errClosedResponseBody) - cc.forgetStreamID(cs.ID) + cs.abortStream(http2errClosedResponseBody) + + select { + case <-cs.donec: + case <-cs.ctx.Done(): + // See golang/go#49366: The net/http package can cancel the + // request context after the response body is fully read. + // Don't treat this as an error. + return nil + case <-cs.reqCancel: + return http2errRequestCanceled + } return nil } func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, f.StreamEnded()) + cs := rl.streamByID(f.StreamID) data := f.Data() if cs == nil { cc.mu.Lock() @@ -8856,6 +9224,14 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { } return nil } + if cs.readClosed { + cc.logf("protocol error: received DATA after END_STREAM") + rl.endStreamError(cs, http2StreamError{ + StreamID: f.StreamID, + Code: http2ErrCodeProtocol, + }) + return nil + } if !cs.firstByte { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, http2StreamError{ @@ -8865,7 +9241,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { return nil } if f.Length > 0 { - if cs.req.Method == "HEAD" && len(data) > 0 { + if cs.isHead && len(data) > 0 { cc.logf("protocol error: received DATA on a HEAD request") rl.endStreamError(cs, http2StreamError{ StreamID: f.StreamID, @@ -8887,30 +9263,39 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { if pad := int(f.Length) - len(data); pad > 0 { refund += pad } - // Return len(data) now if the stream is already closed, - // since data will never be read. - didReset := cs.didReset - if didReset { - refund += len(data) + + didReset := false + var err error + if len(data) > 0 { + if _, err = cs.bufPipe.Write(data); err != nil { + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset = true + refund += len(data) + } } + if refund > 0 { cc.inflow.add(int32(refund)) + if !didReset { + cs.inflow.add(int32(refund)) + } + } + cc.mu.Unlock() + + if refund > 0 { cc.wmu.Lock() cc.fr.WriteWindowUpdate(0, uint32(refund)) if !didReset { - cs.inflow.add(int32(refund)) cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) } cc.bw.Flush() cc.wmu.Unlock() } - cc.mu.Unlock() - if len(data) > 0 && !didReset { - if _, err := cs.bufPipe.Write(data); err != nil { - rl.endStreamError(cs, err) - return err - } + if err != nil { + rl.endStreamError(cs, err) + return nil } } @@ -8923,24 +9308,32 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { // TODO: check that any declared content-length matches, like // server.go's (*stream).endStream method. - rl.endStreamError(cs, nil) + if !cs.readClosed { + cs.readClosed = true + // Close cs.bufPipe and cs.peerClosed with cc.mu held to avoid a + // race condition: The caller can read io.EOF from Response.Body + // and close the body before we close cs.peerClosed, causing + // cleanupWriteRequest to send a RST_STREAM. + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs.bufPipe.closeWithErrorAndCode(io.EOF, cs.copyTrailers) + close(cs.peerClosed) + } } func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { - var code func() - if err == nil { - err = io.EOF - code = cs.copyTrailers - } - if http2isConnectionCloseRequest(cs.req) { - rl.closeWhenIdle = true - } - cs.bufPipe.closeWithErrorAndCode(err, code) + cs.readAborted = true + cs.abortStream(err) +} - select { - case cs.resc <- http2resAndError{err: err}: - default: +func (rl *http2clientConnReadLoop) streamByID(id uint32) *http2clientStream { + rl.cc.mu.Lock() + defer rl.cc.mu.Unlock() + cs := rl.cc.streams[id] + if cs != nil && !cs.readAborted { + return cs } + return nil } func (cs *http2clientStream) copyTrailers() { @@ -8959,6 +9352,10 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { if f.ErrCode != 0 { // TODO: deal with GOAWAY more. particularly the error code cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) + if fn := cc.t.CountError; fn != nil { + fn("recv_goaway_" + f.ErrCode.stringToken()) + } + } cc.setGoAway(f) return nil @@ -8966,6 +9363,23 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error { cc := rl.cc + // Locking both mu and wmu here allows frame encoding to read settings with only wmu held. + // Acquiring wmu when f.IsAck() is unnecessary, but convenient and mostly harmless. + cc.wmu.Lock() + defer cc.wmu.Unlock() + + if err := rl.processSettingsNoWrite(f); err != nil { + return err + } + if !f.IsAck() { + cc.fr.WriteSettingsAck() + cc.bw.Flush() + } + return nil +} + +func (rl *http2clientConnReadLoop) processSettingsNoWrite(f *http2SettingsFrame) error { + cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() @@ -8977,12 +9391,14 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error return http2ConnectionError(http2ErrCodeProtocol) } + var seenMaxConcurrentStreams bool err := f.ForeachSetting(func(s http2Setting) error { switch s.ID { case http2SettingMaxFrameSize: cc.maxFrameSize = s.Val case http2SettingMaxConcurrentStreams: cc.maxConcurrentStreams = s.Val + seenMaxConcurrentStreams = true case http2SettingMaxHeaderListSize: cc.peerMaxHeaderListSize = uint64(s.Val) case http2SettingInitialWindowSize: @@ -9014,17 +9430,23 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error return err } - cc.wmu.Lock() - defer cc.wmu.Unlock() + if !cc.seenSettings { + if !seenMaxConcurrentStreams { + // This was the servers initial SETTINGS frame and it + // didn't contain a MAX_CONCURRENT_STREAMS field so + // increase the number of concurrent streams this + // connection can establish to our default. + cc.maxConcurrentStreams = http2defaultMaxConcurrentStreams + } + cc.seenSettings = true + } - cc.fr.WriteSettingsAck() - cc.bw.Flush() - return cc.werr + return nil } func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { cc := rl.cc - cs := cc.streamByID(f.StreamID, false) + cs := rl.streamByID(f.StreamID) if f.StreamID != 0 && cs == nil { return nil } @@ -9044,24 +9466,22 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame } func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { - cs := rl.cc.streamByID(f.StreamID, true) + cs := rl.streamByID(f.StreamID) if cs == nil { - // TODO: return error if server tries to RST_STEAM an idle stream + // TODO: return error if server tries to RST_STREAM an idle stream return nil } - select { - case <-cs.peerReset: - // Already reset. - // This is the only goroutine - // which closes this, so there - // isn't a race. - default: - err := http2streamError(cs.ID, f.ErrCode) - cs.resetErr = err - close(cs.peerReset) - cs.bufPipe.CloseWithError(err) - cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl + serr := http2streamError(cs.ID, f.ErrCode) + serr.Cause = http2errFromPeer + if f.ErrCode == http2ErrCodeProtocol { + rl.cc.SetDoNotReuse() } + if fn := cs.cc.t.CountError; fn != nil { + fn("recv_rststream_" + f.ErrCode.stringToken()) + } + cs.abortStream(serr) + + cs.bufPipe.CloseWithError(serr) return nil } @@ -9083,19 +9503,24 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - cc.wmu.Lock() - if err := cc.fr.WritePing(false, p); err != nil { - cc.wmu.Unlock() - return err - } - if err := cc.bw.Flush(); err != nil { - cc.wmu.Unlock() - return err - } - cc.wmu.Unlock() + errc := make(chan error, 1) + go func() { + cc.wmu.Lock() + defer cc.wmu.Unlock() + if err := cc.fr.WritePing(false, p); err != nil { + errc <- err + return + } + if err := cc.bw.Flush(); err != nil { + errc <- err + return + } + }() select { case <-c: return nil + case err := <-errc: + return err case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: @@ -9172,6 +9597,12 @@ func (t *http2Transport) logf(format string, args ...interface{}) { var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil)) +type http2missingBody struct{} + +func (http2missingBody) Close() error { return nil } + +func (http2missingBody) Read([]byte) (int, error) { return 0, io.ErrUnexpectedEOF } + func http2strSliceContains(ss []string, s string) bool { for _, v := range ss { if v == s { @@ -9218,87 +9649,6 @@ type http2errorReader struct{ err error } func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } -// bodyWriterState encapsulates various state around the Transport's writing -// of the request body, particularly regarding doing delayed writes of the body -// when the request contains "Expect: 100-continue". -type http2bodyWriterState struct { - cs *http2clientStream - timer *time.Timer // if non-nil, we're doing a delayed write - fnonce *sync.Once // to call fn with - fn func() // the code to run in the goroutine, writing the body - resc chan error // result of fn's execution - delay time.Duration // how long we should delay a delayed write for -} - -func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) { - s.cs = cs - if body == nil { - return - } - resc := make(chan error, 1) - s.resc = resc - s.fn = func() { - cs.cc.mu.Lock() - cs.startedWrite = true - cs.cc.mu.Unlock() - resc <- cs.writeRequestBody(body, cs.req.Body) - } - s.delay = t.expectContinueTimeout() - if s.delay == 0 || - !httpguts.HeaderValuesContainsToken( - cs.req.Header["Expect"], - "100-continue") { - return - } - s.fnonce = new(sync.Once) - - // Arm the timer with a very large duration, which we'll - // intentionally lower later. It has to be large now because - // we need a handle to it before writing the headers, but the - // s.delay value is defined to not start until after the - // request headers were written. - const hugeDuration = 365 * 24 * time.Hour - s.timer = time.AfterFunc(hugeDuration, func() { - s.fnonce.Do(s.fn) - }) - return -} - -func (s http2bodyWriterState) cancel() { - if s.timer != nil { - if s.timer.Stop() { - s.resc <- nil - } - } -} - -func (s http2bodyWriterState) on100() { - if s.timer == nil { - // If we didn't do a delayed write, ignore the server's - // bogus 100 continue response. - return - } - s.timer.Stop() - go func() { s.fnonce.Do(s.fn) }() -} - -// scheduleBodyWrite starts writing the body, either immediately (in -// the common case) or after the delay timeout. It should not be -// called until after the headers have been written. -func (s http2bodyWriterState) scheduleBodyWrite() { - if s.timer == nil { - // We're not doing a delayed write (see - // getBodyWriterState), so just start the writing - // goroutine immediately. - go s.fn() - return - } - http2traceWait100Continue(s.cs.trace) - if s.timer.Stop() { - s.timer.Reset(s.delay) - } -} - // isConnectionCloseRequest reports whether req should use its own // connection for a single request and then close the connection. func http2isConnectionCloseRequest(req *Request) bool { @@ -9775,7 +10125,8 @@ type http2WriteScheduler interface { // Pop dequeues the next frame to write. Returns false if no frames can // be written. Frames with a given wr.StreamID() are Pop'd in the same - // order they are Push'd. No frames should be discarded except by CloseStream. + // order they are Push'd, except RST_STREAM frames. No frames should be + // discarded except by CloseStream. Pop() (wr http2FrameWriteRequest, ok bool) } @@ -9795,6 +10146,7 @@ type http2FrameWriteRequest struct { // stream is the stream on which this frame will be written. // nil for non-stream frames like PING and SETTINGS. + // nil for RST_STREAM streams, which use the StreamError.StreamID field instead. stream *http2stream // done, if non-nil, must be a buffered channel with space for @@ -10474,11 +10826,11 @@ func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http } func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { - id := wr.StreamID() - if id == 0 { + if wr.isControl() { ws.zero.push(wr) return } + id := wr.StreamID() q, ok := ws.sq[id] if !ok { q = ws.queuePool.get() @@ -10488,7 +10840,7 @@ func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { } func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { - // Control frames first. + // Control and RST_STREAM frames first. if !ws.zero.empty() { return ws.zero.shift(), true } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 4c72dcb..6487e50 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -13,6 +13,8 @@ import ( "strings" "sync" "time" + + "golang.org/x/net/http/httpguts" ) // A Header represents the key-value pairs in an HTTP header. @@ -155,7 +157,7 @@ func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kv func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } var headerSorterPool = sync.Pool{ - New: func() interface{} { return new(headerSorter) }, + New: func() any { return new(headerSorter) }, } // sortedKeyValues returns h's keys sorted in the returned kvs @@ -192,6 +194,13 @@ func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptra kvs, sorter := h.sortedKeyValues(exclude) var formattedVals []string for _, kv := range kvs { + if !httpguts.ValidHeaderFieldName(kv.key) { + // This could be an error. In the common case of + // writing response headers, however, we have no good + // way to provide the error back to the server + // handler, so just drop invalid headers instead. + continue + } for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) v = textproto.TrimString(v) diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index ad8ab9b..575493b 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -89,6 +89,19 @@ var headerWriteTests = []struct { "k4: 4a\r\nk4: 4b\r\nk6: 6a\r\nk6: 6b\r\n" + "k7: 7a\r\nk7: 7b\r\nk8: 8a\r\nk8: 8b\r\nk9: 9a\r\nk9: 9b\r\n", }, + // Tests invalid characters in headers. + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "NewlineInValue": {"1\r\nBar: 2"}, + "NewlineInKey\r\n": {"1"}, + "Colon:InKey": {"1"}, + "Evil: 1\r\nSmuggledValue": {"1"}, + }, + nil, + "Content-Type: text/html; charset=UTF-8\r\n" + + "NewlineInValue: 1 Bar: 2\r\n", + }, } func TestHeaderWrite(t *testing.T) { diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go index 5777c91..6af30f7 100644 --- a/libgo/go/net/http/httptrace/trace.go +++ b/libgo/go/net/http/httptrace/trace.go @@ -50,7 +50,7 @@ func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { } } if trace.DNSDone != nil { - nt.DNSDone = func(netIPs []interface{}, coalesced bool, err error) { + nt.DNSDone = func(netIPs []any, coalesced bool, err error) { addrs := make([]net.IPAddr, len(netIPs)) for i, ip := range netIPs { addrs[i] = ip.(net.IPAddr) diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index 2948f27..d7baecd 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -292,7 +292,7 @@ func DumpRequest(req *http.Request, body bool) ([]byte, error) { // can detect that the lack of body was intentional. var errNoBody = errors.New("sentinel error value") -// failureToReadBody is a io.ReadCloser that just returns errNoBody on +// failureToReadBody is an io.ReadCloser that just returns errNoBody on // Read. It's swapped in when we don't actually want to consume // the body, but need a non-nil one, and want to distinguish the // error from reading the dummy body. diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 366cc82..5df2ee8 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -31,7 +31,7 @@ type dumpTest struct { Req *http.Request GetReq func() *http.Request - Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + Body any // optional []byte or func() io.ReadCloser to populate Req.Body WantDump string WantDumpOut string diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 8b63368..319e2a3 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "log" + "mime" "net" "net/http" "net/http/internal/ascii" @@ -412,7 +413,7 @@ func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { // For Server-Sent Events responses, flush immediately. // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream - if resCT == "text/event-stream" { + if baseCT, _, _ := mime.ParseMediaType(resCT); baseCT == "text/event-stream" { return -1 // negative means immediately } @@ -483,7 +484,7 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int } } -func (p *ReverseProxy) logf(format string, args ...interface{}) { +func (p *ReverseProxy) logf(format string, args ...any) { if p.ErrorLog != nil { p.ErrorLog.Printf(format, args...) } else { diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 4b6ad77..90e8903 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -1195,6 +1195,26 @@ func TestSelectFlushInterval(t *testing.T) { want: -1, }, { + name: "server-sent events with media-type parameters overrides non-zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "server-sent events with media-type parameters overrides zero", + res: &http.Response{ + Header: http.Header{ + "Content-Type": {"text/event-stream;charset=utf-8"}, + }, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, + { name: "Content-Length: -1, overrides non-zero", res: &http.Response{ ContentLength: -1, diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go index f06e572..37a72e9 100644 --- a/libgo/go/net/http/internal/chunked.go +++ b/libgo/go/net/http/internal/chunked.go @@ -81,6 +81,11 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { cr.err = errors.New("malformed chunked encoding") break } + } else { + if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF + } + break } cr.checkEnd = false } @@ -109,6 +114,8 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // bytes to verify they are "\r\n". if cr.n == 0 && cr.err == nil { cr.checkEnd = true + } else if cr.err == io.EOF { + cr.err = io.ErrUnexpectedEOF } } return n, cr.err @@ -152,6 +159,8 @@ func isASCIISpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' } +var semi = []byte(";") + // removeChunkExtension removes any chunk-extension from p. // For example, // "0" => "0" @@ -159,14 +168,11 @@ func isASCIISpace(b byte) bool { // "0;token=val" => "0" // `0;token="quoted string"` => "0" func removeChunkExtension(p []byte) ([]byte, error) { - semi := bytes.IndexByte(p, ';') - if semi == -1 { - return p, nil - } + p, _, _ = bytes.Cut(p, semi) // TODO: care about exact syntax of chunk extensions? We're // ignoring and stripping them anyway. For now just never // return an error. - return p[:semi], nil + return p, nil } // NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index 08152ed..5e29a78 100644 --- a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -11,6 +11,7 @@ import ( "io" "strings" "testing" + "testing/iotest" ) func TestChunk(t *testing.T) { @@ -211,3 +212,30 @@ func TestChunkReadPartial(t *testing.T) { } } + +// Issue 48861: ChunkedReader should report incomplete chunks +func TestIncompleteChunk(t *testing.T) { + const valid = "4\r\nabcd\r\n" + "5\r\nabc\r\n\r\n" + "0\r\n" + + for i := 0; i < len(valid); i++ { + incomplete := valid[:i] + r := NewChunkedReader(strings.NewReader(incomplete)) + if _, err := io.ReadAll(r); err != io.ErrUnexpectedEOF { + t.Errorf("expected io.ErrUnexpectedEOF for %q, got %v", incomplete, err) + } + } + + r := NewChunkedReader(strings.NewReader(valid)) + if _, err := io.ReadAll(r); err != nil { + t.Errorf("unexpected error for %q: %v", valid, err) + } +} + +func TestChunkEndReadError(t *testing.T) { + readErr := fmt.Errorf("chunk end read error") + + r := NewChunkedReader(io.MultiReader(strings.NewReader("4\r\nabcd"), iotest.ErrReader(readErr))) + if _, err := io.ReadAll(r); err != readErr { + t.Errorf("expected %v, got %v", readErr, err) + } +} diff --git a/libgo/go/net/http/internal/testcert/testcert.go b/libgo/go/net/http/internal/testcert/testcert.go index 5f94704..d510e79 100644 --- a/libgo/go/net/http/internal/testcert/testcert.go +++ b/libgo/go/net/http/internal/testcert/testcert.go @@ -10,37 +10,56 @@ import "strings" // LocalhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. // generated from src/crypto/tls: -// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h +// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS +MIIDOTCCAiGgAwIBAgIQSRJrEpBGFc7tNb1fb5pKFzANBgkqhkiG9w0BAQsFADAS MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw -MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB -iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 -iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul -rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO -BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw -AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA -AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 -tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs -h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM -fblo6RBxUQ== +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA6Gba5tHV1dAKouAaXO3/ebDUU4rvwCUg/CNaJ2PT5xLD4N1Vcb8r +bFSW2HXKq+MPfVdwIKR/1DczEoAGf/JWQTW7EgzlXrCd3rlajEX2D73faWJekD0U +aUgz5vtrTXZ90BQL7WvRICd7FlEZ6FPOcPlumiyNmzUqtwGhO+9ad1W5BqJaRI6P +YfouNkwR6Na4TzSj5BrqUfP0FwDizKSJ0XXmh8g8G9mtwxOSN3Ru1QFc61Xyeluk +POGKBV/q6RBNklTNe0gI8usUMlYyoC7ytppNMW7X2vodAelSu25jgx2anj9fDVZu +h7AXF5+4nJS4AAt0n1lNY7nGSsdZas8PbQIDAQABo4GIMIGFMA4GA1UdDwEB/wQE +AwICpDATBgNVHSUEDDAKBggrBgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MB0GA1Ud +DgQWBBStsdjh3/JCXXYlQryOrL4Sh7BW5TAuBgNVHREEJzAlggtleGFtcGxlLmNv +bYcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkqhkiG9w0BAQsFAAOCAQEAxWGI +5NhpF3nwwy/4yB4i/CwwSpLrWUa70NyhvprUBC50PxiXav1TeDzwzLx/o5HyNwsv +cxv3HdkLW59i/0SlJSrNnWdfZ19oTcS+6PtLoVyISgtyN6DpkKpdG1cOkW3Cy2P2 ++tK/tKHRP1Y/Ra0RiDpOAmqn0gCOFGz8+lqDIor/T7MTpibL3IxqWfPrvfVRHL3B +grw/ZQTTIVjjh4JBSW3WyWgNo/ikC1lrVxzl4iPUGptxT36Cr7Zk2Bsg0XqwbOvK +5d+NTDREkSnUbie4GeutujmX3Dsx88UiV6UY/4lHJa6I5leHUNOHahRbpbWeOfs/ +WkBKOclmOV2xlTVuPw== -----END CERTIFICATE-----`) // LocalhostKey is the private key for LocalhostCert. var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY----- -MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 -SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB -l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB -AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet -3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb -uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H -qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp -jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY -fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U -fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU -y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX -qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo -f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDoZtrm0dXV0Aqi +4Bpc7f95sNRTiu/AJSD8I1onY9PnEsPg3VVxvytsVJbYdcqr4w99V3AgpH/UNzMS +gAZ/8lZBNbsSDOVesJ3euVqMRfYPvd9pYl6QPRRpSDPm+2tNdn3QFAvta9EgJ3sW +URnoU85w+W6aLI2bNSq3AaE771p3VbkGolpEjo9h+i42TBHo1rhPNKPkGupR8/QX +AOLMpInRdeaHyDwb2a3DE5I3dG7VAVzrVfJ6W6Q84YoFX+rpEE2SVM17SAjy6xQy +VjKgLvK2mk0xbtfa+h0B6VK7bmODHZqeP18NVm6HsBcXn7iclLgAC3SfWU1jucZK +x1lqzw9tAgMBAAECggEABWzxS1Y2wckblnXY57Z+sl6YdmLV+gxj2r8Qib7g4ZIk +lIlWR1OJNfw7kU4eryib4fc6nOh6O4AWZyYqAK6tqNQSS/eVG0LQTLTTEldHyVJL +dvBe+MsUQOj4nTndZW+QvFzbcm2D8lY5n2nBSxU5ypVoKZ1EqQzytFcLZpTN7d89 +EPj0qDyrV4NZlWAwL1AygCwnlwhMQjXEalVF1ylXwU3QzyZ/6MgvF6d3SSUlh+sq +XefuyigXw484cQQgbzopv6niMOmGP3of+yV4JQqUSb3IDmmT68XjGd2Dkxl4iPki +6ZwXf3CCi+c+i/zVEcufgZ3SLf8D99kUGE7v7fZ6AQKBgQD1ZX3RAla9hIhxCf+O +3D+I1j2LMrdjAh0ZKKqwMR4JnHX3mjQI6LwqIctPWTU8wYFECSh9klEclSdCa64s +uI/GNpcqPXejd0cAAdqHEEeG5sHMDt0oFSurL4lyud0GtZvwlzLuwEweuDtvT9cJ +Wfvl86uyO36IW8JdvUprYDctrQKBgQDycZ697qutBieZlGkHpnYWUAeImVA878sJ +w44NuXHvMxBPz+lbJGAg8Cn8fcxNAPqHIraK+kx3po8cZGQywKHUWsxi23ozHoxo ++bGqeQb9U661TnfdDspIXia+xilZt3mm5BPzOUuRqlh4Y9SOBpSWRmEhyw76w4ZP +OPxjWYAgwQKBgA/FehSYxeJgRjSdo+MWnK66tjHgDJE8bYpUZsP0JC4R9DL5oiaA +brd2fI6Y+SbyeNBallObt8LSgzdtnEAbjIH8uDJqyOmknNePRvAvR6mP4xyuR+Bv +m+Lgp0DMWTw5J9CKpydZDItc49T/mJ5tPhdFVd+am0NAQnmr1MCZ6nHxAoGABS3Y +LkaC9FdFUUqSU8+Chkd/YbOkuyiENdkvl6t2e52jo5DVc1T7mLiIrRQi4SI8N9bN +/3oJWCT+uaSLX2ouCtNFunblzWHBrhxnZzTeqVq4SLc8aESAnbslKL4i8/+vYZlN +s8xtiNcSvL+lMsOBORSXzpj/4Ot8WwTkn1qyGgECgYBKNTypzAHeLE6yVadFp3nQ +Ckq9yzvP/ib05rvgbvrne00YeOxqJ9gtTrzgh7koqJyX1L4NwdkEza4ilDWpucn0 +xiUZS4SoaJq6ZvcBYS62Yr1t8n09iG47YL8ibgtmH3L+svaotvpVxVK+d7BLevA/ +ZboOWVe3icTy64BT3OQhmg== -----END RSA TESTING KEY-----`)) func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index 6564627..27872b4 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -31,11 +31,8 @@ func interestingGoroutines() (gs []string) { buf := make([]byte, 2<<20) buf = buf[:runtime.Stack(buf, true)] for _, g := range strings.Split(string(buf), "\n\n") { - sl := strings.SplitN(g, "\n", 2) - if len(sl) != 2 { - continue - } - stack := strings.TrimSpace(sl[1]) + _, stack, _ := strings.Cut(g, "\n") + stack = strings.TrimSpace(stack) if stack == "" || strings.Contains(stack, "testing.(*M).before.func1") || strings.Contains(stack, "os/signal.signal_recv") || @@ -46,7 +43,7 @@ func interestingGoroutines() (gs []string) { // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) strings.Contains(stack, "runtime.goexit") || strings.Contains(stack, "created by runtime.gc") || - strings.Contains(stack, "net/http_test.interestingGoroutines") || + strings.Contains(stack, "interestingGoroutines") || strings.Contains(stack, "runtime.MHeap_Scavenger") { continue } diff --git a/libgo/go/net/http/omithttp2.go b/libgo/go/net/http/omithttp2.go index 79599d0..3316f55 100644 --- a/libgo/go/net/http/omithttp2.go +++ b/libgo/go/net/http/omithttp2.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build nethttpomithttp2 -// +build nethttpomithttp2 package http @@ -27,7 +26,7 @@ const http2NextProtoTLS = "h2" type http2Transport struct { MaxHeaderListSize uint32 - ConnPool interface{} + ConnPool any } func (*http2Transport) RoundTrip(*Request) (*Response, error) { panic(noHTTP2) } @@ -57,9 +56,9 @@ type http2Server struct { NewWriteScheduler func() http2WriteScheduler } -type http2WriteScheduler interface{} +type http2WriteScheduler any -func http2NewPriorityWriteScheduler(interface{}) http2WriteScheduler { panic(noHTTP2) } +func http2NewPriorityWriteScheduler(any) http2WriteScheduler { panic(noHTTP2) } func http2ConfigureServer(s *Server, conf *http2Server) error { panic(noHTTP2) } diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 888ea35..dc855c8 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -44,7 +44,7 @@ // The package also exports a handler that serves execution trace data // for the "go tool trace" command. To collect a 5-second execution trace: // -// wget -O trace.out http://localhost:6060/debug/pprof/trace?seconds=5 +// curl -o trace.out http://localhost:6060/debug/pprof/trace?seconds=5 // go tool trace trace.out // // To view all available profiles, open http://localhost:6060/debug/pprof/ diff --git a/libgo/go/net/http/pprof/pprof_test.go b/libgo/go/net/http/pprof/pprof_test.go index 84757e4..1a4d653 100644 --- a/libgo/go/net/http/pprof/pprof_test.go +++ b/libgo/go/net/http/pprof/pprof_test.go @@ -8,6 +8,7 @@ import ( "bytes" "fmt" "internal/profile" + "internal/testenv" "io" "net/http" "net/http/httptest" @@ -152,6 +153,10 @@ func mutexHog(duration time.Duration, hogger func(mu1, mu2 *sync.Mutex, start ti } func TestDeltaProfile(t *testing.T) { + if runtime.GOOS == "openbsd" && runtime.GOARCH == "arm" { + testenv.SkipFlaky(t, 50218) + } + rate := runtime.SetMutexProfileFraction(1) defer func() { runtime.SetMutexProfileFraction(rate) diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 09cb0c7..76c2317 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -779,11 +779,10 @@ func removeZone(host string) string { return host[:j] + host[i:] } -// ParseHTTPVersion parses an HTTP version string. +// ParseHTTPVersion parses an HTTP version string according to RFC 7230, section 2.6. // "HTTP/1.0" returns (1, 0, true). Note that strings without // a minor version, such as "HTTP/2", are not valid. func ParseHTTPVersion(vers string) (major, minor int, ok bool) { - const Big = 1000000 // arbitrary upper bound switch vers { case "HTTP/1.1": return 1, 1, true @@ -793,19 +792,21 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) { if !strings.HasPrefix(vers, "HTTP/") { return 0, 0, false } - dot := strings.Index(vers, ".") - if dot < 0 { + if len(vers) != len("HTTP/X.Y") { return 0, 0, false } - major, err := strconv.Atoi(vers[5:dot]) - if err != nil || major < 0 || major > Big { + if vers[6] != '.' { return 0, 0, false } - minor, err = strconv.Atoi(vers[dot+1:]) - if err != nil || minor < 0 || minor > Big { + maj, err := strconv.ParseUint(vers[5:6], 10, 0) + if err != nil { return 0, 0, false } - return major, minor, true + min, err := strconv.ParseUint(vers[7:8], 10, 0) + if err != nil { + return 0, 0, false + } + return int(maj), int(min), true } func validMethod(method string) bool { @@ -939,7 +940,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read func (r *Request) BasicAuth() (username, password string, ok bool) { auth := r.Header.Get("Authorization") if auth == "" { - return + return "", "", false } return parseBasicAuth(auth) } @@ -950,18 +951,18 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { const prefix = "Basic " // Case insensitive prefix match. See Issue 22736. if len(auth) < len(prefix) || !ascii.EqualFold(auth[:len(prefix)], prefix) { - return + return "", "", false } c, err := base64.StdEncoding.DecodeString(auth[len(prefix):]) if err != nil { - return + return "", "", false } cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - return + username, password, ok = strings.Cut(cs, ":") + if !ok { + return "", "", false } - return cs[:s], cs[s+1:], true + return username, password, true } // SetBasicAuth sets the request's Authorization header to use HTTP @@ -979,13 +980,12 @@ func (r *Request) SetBasicAuth(username, password string) { // parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { - s1 := strings.Index(line, " ") - s2 := strings.Index(line[s1+1:], " ") - if s1 < 0 || s2 < 0 { - return + method, rest, ok1 := strings.Cut(line, " ") + requestURI, proto, ok2 := strings.Cut(rest, " ") + if !ok1 || !ok2 { + return "", "", "", false } - s2 += s1 + 1 - return line[:s1], line[s1+1 : s2], line[s2+1:], true + return method, requestURI, proto, true } var textprotoReaderPool sync.Pool diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 4e0c4ba..4363e11 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -639,10 +639,10 @@ var parseHTTPVersionTests = []struct { major, minor int ok bool }{ + {"HTTP/0.0", 0, 0, true}, {"HTTP/0.9", 0, 9, true}, {"HTTP/1.0", 1, 0, true}, {"HTTP/1.1", 1, 1, true}, - {"HTTP/3.14", 3, 14, true}, {"HTTP", 0, 0, false}, {"HTTP/one.one", 0, 0, false}, @@ -651,6 +651,12 @@ var parseHTTPVersionTests = []struct { {"HTTP/0,-1", 0, 0, false}, {"HTTP/", 0, 0, false}, {"HTTP/1,1", 0, 0, false}, + {"HTTP/+1.1", 0, 0, false}, + {"HTTP/1.+1", 0, 0, false}, + {"HTTP/0000000001.1", 0, 0, false}, + {"HTTP/1.0000000001", 0, 0, false}, + {"HTTP/3.14", 0, 0, false}, + {"HTTP/12.3", 0, 0, false}, } func TestParseHTTPVersion(t *testing.T) { diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index 1157bdf..bdc1e3c 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -20,7 +20,7 @@ import ( type reqWriteTest struct { Req Request - Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + Body any // optional []byte or func() io.ReadCloser to populate Req.Body // Any of these three may be empty to skip that test. WantWrite string // Request.Write diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index b8985da..297394e 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -165,16 +165,14 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { } return nil, err } - if i := strings.IndexByte(line, ' '); i == -1 { + proto, status, ok := strings.Cut(line, " ") + if !ok { return nil, badStringError("malformed HTTP response", line) - } else { - resp.Proto = line[:i] - resp.Status = strings.TrimLeft(line[i+1:], " ") - } - statusCode := resp.Status - if i := strings.IndexByte(resp.Status, ' '); i != -1 { - statusCode = resp.Status[:i] } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") if len(statusCode) != 3 { return nil, badStringError("malformed HTTP status code", statusCode) } @@ -182,7 +180,6 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { if err != nil || resp.StatusCode < 0 { return nil, badStringError("malformed HTTP status code", statusCode) } - var ok bool if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { return nil, badStringError("malformed HTTP version", resp.Proto) } diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 8eef654..5a735b0 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -646,8 +646,8 @@ type readerAndCloser struct { func TestReadResponseCloseInMiddle(t *testing.T) { t.Parallel() for _, test := range readResponseCloseInMiddleTests { - fatalf := func(format string, args ...interface{}) { - args = append([]interface{}{test.chunked, test.compressed}, args...) + fatalf := func(format string, args ...any) { + args = append([]any{test.chunked, test.compressed}, args...) t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) } checkErr := func(err error, msg string) { @@ -732,7 +732,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { } } -func diff(t *testing.T, prefix string, have, want interface{}) { +func diff(t *testing.T, prefix string, have, want any) { t.Helper() hv := reflect.ValueOf(have).Elem() wv := reflect.ValueOf(want).Elem() @@ -849,10 +849,10 @@ func TestReadResponseErrors(t *testing.T) { type testCase struct { name string // optional, defaults to in in string - wantErr interface{} // nil, err value, or string substring + wantErr any // nil, err value, or string substring } - status := func(s string, wantErr interface{}) testCase { + status := func(s string, wantErr any) testCase { if wantErr == true { wantErr = "malformed HTTP status code" } @@ -863,7 +863,7 @@ func TestReadResponseErrors(t *testing.T) { } } - version := func(s string, wantErr interface{}) testCase { + version := func(s string, wantErr any) testCase { if wantErr == true { wantErr = "malformed HTTP version" } @@ -874,7 +874,7 @@ func TestReadResponseErrors(t *testing.T) { } } - contentLength := func(status, body string, wantErr interface{}) testCase { + contentLength := func(status, body string, wantErr any) testCase { return testCase{ name: fmt.Sprintf("status %q %q", status, body), in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body), @@ -947,7 +947,7 @@ func TestReadResponseErrors(t *testing.T) { // wantErr can be nil, an error value to match exactly, or type string to // match a substring. -func matchErr(err error, wantErr interface{}) error { +func matchErr(err error, wantErr any) error { if err == nil { if wantErr == nil { return nil diff --git a/libgo/go/net/http/roundtrip.go b/libgo/go/net/http/roundtrip.go index eef7c79..c4c5d3b 100644 --- a/libgo/go/net/http/roundtrip.go +++ b/libgo/go/net/http/roundtrip.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build !js || !wasm -// +build !js !wasm package http diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go index 74c83a9..01c0600 100644 --- a/libgo/go/net/http/roundtrip_js.go +++ b/libgo/go/net/http/roundtrip_js.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build js && wasm -// +build js,wasm package http @@ -41,11 +40,19 @@ const jsFetchCreds = "js.fetch:credentials" // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters const jsFetchRedirect = "js.fetch:redirect" -var useFakeNetwork = js.Global().Get("fetch").IsUndefined() +// jsFetchMissing will be true if the Fetch API is not present in +// the browser globals. +var jsFetchMissing = js.Global().Get("fetch").IsUndefined() // RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *Request) (*Response, error) { - if useFakeNetwork { + // The Transport has a documented contract that states that if the DialContext or + // DialTLSContext functions are set, they will be used to set up the connections. + // If they aren't set then the documented contract is to use Dial or DialTLS, even + // though they are deprecated. Therefore, if any of these are set, we should obey + // the contract and dial using the regular round-trip instead. Otherwise, we'll try + // to fall back on the Fetch API, unless it's not available. + if t.Dial != nil || t.DialContext != nil || t.DialTLS != nil || t.DialTLSContext != nil || jsFetchMissing { return t.roundTrip(req) } @@ -111,7 +118,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { errCh = make(chan error, 1) success, failure js.Func ) - success = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() @@ -131,8 +138,24 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } contentLength := int64(0) - if cl, err := strconv.ParseInt(header.Get("Content-Length"), 10, 64); err == nil { + clHeader := header.Get("Content-Length") + switch { + case clHeader != "": + cl, err := strconv.ParseInt(clHeader, 10, 64) + if err != nil { + errCh <- fmt.Errorf("net/http: ill-formed Content-Length header: %v", err) + return nil + } + if cl < 0 { + // Content-Length values less than 0 are invalid. + // See: https://datatracker.ietf.org/doc/html/rfc2616/#section-14.13 + errCh <- fmt.Errorf("net/http: invalid Content-Length header: %q", clHeader) + return nil + } contentLength = cl + default: + // If the response length is not declared, set it to -1. + contentLength = -1 } b := result.Get("body") @@ -159,7 +182,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return nil }) - failure = js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure = js.FuncOf(func(this js.Value, args []js.Value) any { success.Release() failure.Release() errCh <- fmt.Errorf("net/http: fetch() failed: %s", args[0].Get("message").String()) @@ -200,7 +223,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) { bCh = make(chan []byte, 1) errCh = make(chan error, 1) ) - success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success := js.FuncOf(func(this js.Value, args []js.Value) any { result := args[0] if result.Get("done").Bool() { errCh <- io.EOF @@ -212,7 +235,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) { return nil }) defer success.Release() - failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure := js.FuncOf(func(this js.Value, args []js.Value) any { // Assumes it's a TypeError. See // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError // for more information on this type. See @@ -266,7 +289,7 @@ func (r *arrayReader) Read(p []byte) (n int, err error) { bCh = make(chan []byte, 1) errCh = make(chan error, 1) ) - success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + success := js.FuncOf(func(this js.Value, args []js.Value) any { // Wrap the input ArrayBuffer with a Uint8Array uint8arrayWrapper := uint8Array.New(args[0]) value := make([]byte, uint8arrayWrapper.Get("byteLength").Int()) @@ -275,7 +298,7 @@ func (r *arrayReader) Read(p []byte) (n int, err error) { return nil }) defer success.Release() - failure := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + failure := js.FuncOf(func(this js.Value, args []js.Value) any { // Assumes it's a TypeError. See // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError // for more information on this type. diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 6394da3..fb18cb2 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -23,6 +23,7 @@ import ( "net" . "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/httputil" "net/http/internal" "net/http/internal/testcert" @@ -2146,7 +2147,7 @@ func TestInvalidTrailerClosesConnection(t *testing.T) { // Read and Write. type slowTestConn struct { // over multiple calls to Read, time.Durations are slept, strings are read. - script []interface{} + script []any closec chan bool mu sync.Mutex // guards rd/wd @@ -2238,7 +2239,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := &slowTestConn{ - script: []interface{}{ + script: []any{ "POST /public HTTP/1.1\r\n" + "Host: test\r\n" + "Content-Length: 10000\r\n" + @@ -2273,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { } } +// cancelableTimeoutContext overwrites the error message to DeadlineExceeded +type cancelableTimeoutContext struct { + context.Context +} + +func (c cancelableTimeoutContext) Err() error { + if c.Context.Err() != nil { + return context.DeadlineExceeded + } + return nil +} + func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { @@ -2285,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h2, h) defer cst.close() // Succeed without timing out: @@ -2307,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2428,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { _, werr := w.Write([]byte("hi")) writeErrors <- werr }) - timeout := make(chan time.Time, 1) // write to this to force timeouts - cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout)) + ctx, cancel := context.WithCancel(context.Background()) + h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) + cst := newClientServerTest(t, h1Mode, h) defer cst.close() // Succeed without timing out: @@ -2450,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { } // Times out: - timeout <- time.Time{} + cancel() + res, err = cst.c.Get(cst.ts.URL) if err != nil { t.Error(err) @@ -2500,6 +2517,47 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } +func TestTimeoutHandlerContextCanceled(t *testing.T) { + setParallel(t) + defer afterTest(t) + writeErrors := make(chan error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/plain") + var err error + // The request context has already been canceled, but + // retry the write for a while to give the timeout handler + // a chance to notice. + for i := 0; i < 100; i++ { + _, err = w.Write([]byte("a")) + if err != nil { + break + } + time.Sleep(1 * time.Millisecond) + } + writeErrors <- err + }) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + h := NewTestTimeoutHandler(sayHi, ctx) + cst := newClientServerTest(t, h1Mode, h) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := io.ReadAll(res.Body) + if g, e := string(body), ""; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g, e := <-writeErrors, context.Canceled; g != e { + t.Errorf("got unexpected Write in handler: %v, want %g", g, e) + } +} + // https://golang.org/issue/15948 func TestTimeoutHandlerEmptyResponse(t *testing.T) { setParallel(t) @@ -2708,7 +2766,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) { +func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { defer afterTest(t) // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three @@ -3017,22 +3075,14 @@ func TestClientWriteShutdown(t *testing.T) { if err != nil { t.Fatalf("CloseWrite: %v", err) } - donec := make(chan bool) - go func() { - defer close(donec) - bs, err := io.ReadAll(conn) - if err != nil { - t.Errorf("ReadAll: %v", err) - } - got := string(bs) - if got != "" { - t.Errorf("read %q from server; want nothing", got) - } - }() - select { - case <-donec: - case <-time.After(10 * time.Second): - t.Fatalf("timeout") + + bs, err := io.ReadAll(conn) + if err != nil { + t.Errorf("ReadAll: %v", err) + } + got := string(bs) + if got != "" { + t.Errorf("read %q from server; want nothing", got) } } @@ -3884,7 +3934,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // this test fails, it hangs. This helps debugging and I've // added this enough times "temporarily". It now gets added // full time. - errorf := func(format string, args ...interface{}) { + errorf := func(format string, args ...any) { v := fmt.Sprintf(format, args...) println(v) t.Error(v) @@ -3893,10 +3943,10 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { unblockBackend := make(chan bool) backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() - didCopy := make(chan interface{}) + didCopy := make(chan any) go func() { n, err := io.CopyN(rw, req.Body, bodySize) - didCopy <- []interface{}{n, err} + didCopy <- []any{n, err} }() isGone := false Loop: @@ -4888,7 +4938,7 @@ func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) - ch := make(chan interface{}, 1) + ch := make(chan any, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) @@ -5689,22 +5739,37 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { } // Not parallel: messes with global variable. (http2goAwayTimeout) defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "%v", r.RemoteAddr) - })) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) - a := cst.getURL(cst.ts.URL) - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("test server has active conns") - } - b := cst.getURL(cst.ts.URL) - if a == b { - t.Errorf("got same connection between first and second requests") - } - if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { - t.Fatalf("test server has active conns") + for try := 0; try < 2; try++ { + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("request %v: test server has active conns", try) + } + conns := 0 + var info httptrace.GotConnInfo + ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + GotConn: func(v httptrace.GotConnInfo) { + conns++ + info = v + }, + }) + req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if conns != 1 { + t.Fatalf("request %v: got %v conns, want 1", try, conns) + } + if info.Reused || info.WasIdle { + t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle) + } } } @@ -5933,11 +5998,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { t.Fatal(err) } - select { - case <-done: - case <-time.After(2 * time.Second): - t.Error("timeout") - } + <-done } // Issue 18319: test that the Server validates the request method. @@ -6232,7 +6293,7 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // setting contentEncoding as an interface instead of a string // directly, so as to differentiate between 3 states: // unset, empty string "" and set string "foo/bar". - contentEncoding interface{} + contentEncoding any wantContentType string } @@ -6490,7 +6551,7 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { - t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body) + t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body) } _, err = rwc.Write([]byte("hello")) @@ -6609,3 +6670,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon } } } + +func TestMaxBytesHandler(t *testing.T) { + setParallel(t) + defer afterTest(t) + + for _, maxSize := range []int64{100, 1_000, 1_000_000} { + for _, requestSize := range []int64{100, 1_000, 1_000_000} { + t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), + func(t *testing.T) { + testMaxBytesHandler(t, maxSize, requestSize) + }) + } + } +} + +func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { + var ( + handlerN int64 + handlerErr error + ) + echo := HandlerFunc(func(w ResponseWriter, r *Request) { + var buf bytes.Buffer + handlerN, handlerErr = io.Copy(&buf, r.Body) + io.Copy(w, &buf) + }) + + ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + defer ts.Close() + + c := ts.Client() + var buf strings.Builder + body := strings.NewReader(strings.Repeat("a", int(requestSize))) + res, err := c.Post(ts.URL, "text/plain", body) + if err != nil { + t.Errorf("unexpected connection error: %v", err) + } else { + _, err = io.Copy(&buf, res.Body) + res.Body.Close() + if err != nil { + t.Errorf("unexpected read error: %v", err) + } + } + if handlerN > maxSize { + t.Errorf("expected max request body %d; got %d", maxSize, handlerN) + } + if requestSize > maxSize && handlerErr == nil { + t.Error("expected error on handler side; got nil") + } + if requestSize <= maxSize { + if handlerErr != nil { + t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr) + } + if handlerN != requestSize { + t.Errorf("expected request of size %d; got %d", requestSize, handlerN) + } + } + if buf.Len() != int(handlerN) { + t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len()) + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index ce39933..f5cdc3a 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "math/rand" @@ -20,7 +21,6 @@ import ( "net/textproto" "net/url" urlpkg "net/url" - "os" "path" "runtime" "sort" @@ -494,8 +494,8 @@ type response struct { // prior to the headers being written. If the set of trailers is fixed // or known before the header is written, the normal Go trailers mechanism // is preferred: -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +// https://pkg.go.dev/net/http#ResponseWriter +// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers const TrailerPrefix = "Trailer:" // finalTrailers is called after the Handler exits and returns a non-nil @@ -798,7 +798,7 @@ var ( ) var copyBufPool = sync.Pool{ - New: func() interface{} { + New: func() any { b := make([]byte, 32*1024) return &b }, @@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } +// tlsHandshakeTimeout returns the time limit permitted for the TLS +// handshake, or zero for unlimited. +// +// It returns the minimum of any positive ReadHeaderTimeout, +// ReadTimeout, or WriteTimeout. +func (srv *Server) tlsHandshakeTimeout() time.Duration { + var ret time.Duration + for _, v := range [...]time.Duration{ + srv.ReadHeaderTimeout, + srv.ReadTimeout, + srv.WriteTimeout, + } { + if v <= 0 { + continue + } + if ret == 0 || v < ret { + ret = v + } + } + return ret +} + // wrapper around io.ReadCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -1409,11 +1431,11 @@ func (cw *chunkWriter) writeHeader(p []byte) { hasCL = false } - if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) { - // do nothing - } else if code == StatusNoContent { + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent { + // Response has no body. delHeader("Transfer-Encoding") } else if hasCL { + // Content-Length has been provided, so no chunking is to be done. delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no @@ -1424,6 +1446,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { if hasTE && te == "identity" { cw.chunking = false w.closeAfterReply = true + delHeader("Transfer-Encoding") } else { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -1799,6 +1822,7 @@ func isCommonNetReadError(err error) bool { func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) + var inFlightResponse *response defer func() { if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 @@ -1806,18 +1830,25 @@ func (c *conn) serve(ctx context.Context) { buf = buf[:runtime.Stack(buf, false)] c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } + if inFlightResponse != nil { + inFlightResponse.cancelCtx() + } if !c.hijacked() { + if inFlightResponse != nil { + inFlightResponse.conn.r.abortPendingRead() + inFlightResponse.reqBody.Close() + } c.close() c.setState(c.rwc, StateClosed, runHooks) } }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { - if d := c.server.ReadTimeout; d > 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) - } - if d := c.server.WriteTimeout; d > 0 { - c.rwc.SetWriteDeadline(time.Now().Add(d)) + tlsTO := c.server.tlsHandshakeTimeout() + if tlsTO > 0 { + dl := time.Now().Add(tlsTO) + c.rwc.SetReadDeadline(dl) + c.rwc.SetWriteDeadline(dl) } if err := tlsConn.HandshakeContext(ctx); err != nil { // If the handshake failed due to the client not speaking @@ -1831,6 +1862,11 @@ func (c *conn) serve(ctx context.Context) { c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) return } + // Restore Conn-level deadlines. + if tlsTO > 0 { + c.rwc.SetReadDeadline(time.Time{}) + c.rwc.SetWriteDeadline(time.Time{}) + } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) { @@ -1931,7 +1967,9 @@ func (c *conn) serve(ctx context.Context) { // in parallel even if their responses need to be serialized. // But we're not going to implement HTTP pipelining because it // was never deployed in the wild and the answer is HTTP/2. + inFlightResponse = w serverHandler{c.server}.ServeHTTP(w, w.req) + inFlightResponse = nil w.cancelCtx() if c.hijacked() { return @@ -2277,7 +2315,7 @@ func cleanPath(p string) string { // stripHostPort returns h without any trailing ":<port>". func stripHostPort(h string) string { // If no port on host, return unchanged - if strings.IndexByte(h, ':') == -1 { + if !strings.Contains(h, ":") { return h } host, _, err := net.SplitHostPort(h) @@ -3157,7 +3195,7 @@ func (srv *Server) SetKeepAlivesEnabled(v bool) { // TODO: Issue 26303: close HTTP/2 conns as soon as they become idle. } -func (s *Server) logf(format string, args ...interface{}) { +func (s *Server) logf(format string, args ...any) { if s.ErrorLog != nil { s.ErrorLog.Printf(format, args...) } else { @@ -3168,7 +3206,7 @@ func (s *Server) logf(format string, args ...interface{}) { // logf prints to the ErrorLog of the *Server associated with request r // via ServerContextKey. If there's no associated server, or if ErrorLog // is nil, logging is done via the log package's standard logger. -func logf(r *Request, format string, args ...interface{}) { +func logf(r *Request, format string, args ...any) { s, _ := r.Context().Value(ServerContextKey).(*Server) if s != nil && s.ErrorLog != nil { s.ErrorLog.Printf(format, args...) @@ -3264,7 +3302,7 @@ func (srv *Server) onceSetNextProtoDefaults_Serve() { // configured otherwise. (by setting srv.TLSNextProto non-nil) // It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*). func (srv *Server) onceSetNextProtoDefaults() { - if omitBundledHTTP2 || strings.Contains(os.Getenv("GODEBUG"), "http2server=0") { + if omitBundledHTTP2 || godebug.Get("http2server") == "0" { return } // Enable HTTP/2 by default if the user hasn't otherwise @@ -3331,7 +3369,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { h: make(Header), req: r, } - panicChan := make(chan interface{}, 1) + panicChan := make(chan any, 1) go func() { defer func() { if p := recover(); p != nil { @@ -3359,9 +3397,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { case <-ctx.Done(): tw.mu.Lock() defer tw.mu.Unlock() - w.WriteHeader(StatusServiceUnavailable) - io.WriteString(w, h.errorBody()) - tw.timedOut = true + switch err := ctx.Err(); err { + case context.DeadlineExceeded: + w.WriteHeader(StatusServiceUnavailable) + io.WriteString(w, h.errorBody()) + tw.err = ErrHandlerTimeout + default: + w.WriteHeader(StatusServiceUnavailable) + tw.err = err + } } } @@ -3372,7 +3416,7 @@ type timeoutWriter struct { req *Request mu sync.Mutex - timedOut bool + err error wroteHeader bool code int } @@ -3392,8 +3436,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h } func (tw *timeoutWriter) Write(p []byte) (int, error) { tw.mu.Lock() defer tw.mu.Unlock() - if tw.timedOut { - return 0, ErrHandlerTimeout + if tw.err != nil { + return 0, tw.err } if !tw.wroteHeader { tw.writeHeaderLocked(StatusOK) @@ -3405,7 +3449,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) { checkWriteHeaderCode(code) switch { - case tw.timedOut: + case tw.err != nil: return case tw.wroteHeader: if tw.req != nil { @@ -3572,3 +3616,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { } return false } + +// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader. +func MaxBytesHandler(h Handler, n int64) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + r2 := *r + r2.Body = MaxBytesReader(w, r.Body, n) + h.ServeHTTP(w, &r2) + }) +} diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go index 0132f3b..d17c5c1 100644 --- a/libgo/go/net/http/server_test.go +++ b/libgo/go/net/http/server_test.go @@ -9,8 +9,61 @@ package http import ( "fmt" "testing" + "time" ) +func TestServerTLSHandshakeTimeout(t *testing.T) { + tests := []struct { + s *Server + want time.Duration + }{ + { + s: &Server{}, + want: 0, + }, + { + s: &Server{ + ReadTimeout: -1, + }, + want: 0, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + }, + want: 5 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: -1, + }, + want: 5 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 4 * time.Second, + }, + want: 4 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + ReadHeaderTimeout: 2 * time.Second, + WriteTimeout: 4 * time.Second, + }, + want: 2 * time.Second, + }, + } + for i, tt := range tests { + got := tt.s.tlsHandshakeTimeout() + if got != tt.want { + t.Errorf("%d. got %v; want %v", i, got, tt.want) + } + } +} + func BenchmarkServerMatch(b *testing.B) { fn := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "OK") diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 85c2e5a..6d51178 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -73,7 +73,7 @@ type transferWriter struct { ByteReadCh chan readResult // non-nil if probeRequestBody called } -func newTransferWriter(r interface{}) (t *transferWriter, err error) { +func newTransferWriter(r any) (t *transferWriter, err error) { t = &transferWriter{} // Extract relevant fields @@ -212,6 +212,7 @@ func (t *transferWriter) probeRequestBody() { rres.b = buf[0] } t.ByteReadCh <- rres + close(t.ByteReadCh) }(t.Body) timer := time.NewTimer(200 * time.Millisecond) select { @@ -480,7 +481,7 @@ func suppressedHeaders(status int) []string { } // msg is *Request or *Response. -func readTransfer(msg interface{}, r *bufio.Reader) (err error) { +func readTransfer(msg any, r *bufio.Reader) (err error) { t := &transferReader{RequestMethod: "GET"} // Unify input @@ -808,7 +809,7 @@ func fixTrailer(header Header, chunked bool) (Header, error) { // and then reads the trailer if necessary. type body struct { src io.Reader - hdr interface{} // non-nil (Response or Request) value means read trailer + hdr any // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? doEarlyClose bool // whether Close should stop early @@ -1029,7 +1030,7 @@ func (b *body) registerOnHitEOF(fn func()) { b.onHitEOF = fn } -// bodyLocked is a io.Reader reading from a *body when its mutex is +// bodyLocked is an io.Reader reading from a *body when its mutex is // already held. type bodyLocked struct { b *body @@ -1072,6 +1073,9 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { if n == 1 { p[0] = rres.b } + if err == nil { + err = io.EOF + } return } diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 309194e..5fe3e6e 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -17,6 +17,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "net" @@ -24,7 +25,6 @@ import ( "net/http/internal/ascii" "net/textproto" "net/url" - "os" "reflect" "strings" "sync" @@ -42,10 +42,10 @@ import ( // $no_proxy) environment variables. var DefaultTransport RoundTripper = &Transport{ Proxy: ProxyFromEnvironment, - DialContext: (&net.Dialer{ + DialContext: defaultTransportDialContext(&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - }).DialContext, + }), ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -360,7 +360,7 @@ func (t *Transport) hasCustomTLSDialer() bool { // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { t.tlsNextProtoWasNil = (t.TLSNextProto == nil) - if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") { + if godebug.Get("http2client") == "0" { return } @@ -1715,12 +1715,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers return nil, err } if resp.StatusCode != 200 { - f := strings.SplitN(resp.Status, " ", 2) + _, text, ok := strings.Cut(resp.Status, " ") conn.Close() - if len(f) < 2 { + if !ok { return nil, errors.New("unknown status code") } - return nil, errors.New(f[1]) + return nil, errors.New(text) } } @@ -2481,7 +2481,7 @@ type requestAndChan struct { callerGone <-chan struct{} // closed when roundTrip caller has returned } -// A writeRequest is sent by the readLoop's goroutine to the +// A writeRequest is sent by the caller's goroutine to the // writeLoop's goroutine to write a request while the read loop // concurrently waits on both the write response and the server's // reply. @@ -2668,8 +2668,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // a t.Logf func. See export_test.go's Request.WithT method. type tLogKey struct{} -func (tr *transportRequest) logf(format string, args ...interface{}) { - if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { +func (tr *transportRequest) logf(format string, args ...any) { + if logf, ok := tr.Request.Context().Value(tLogKey{}).(func(string, ...any)); ok { logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) } } diff --git a/libgo/go/net/http/transport_default_js.go b/libgo/go/net/http/transport_default_js.go new file mode 100644 index 0000000..c07d35e --- /dev/null +++ b/libgo/go/net/http/transport_default_js.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build js && wasm +// +build js,wasm + +package http + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return nil +} diff --git a/libgo/go/net/http/transport_default_other.go b/libgo/go/net/http/transport_default_other.go new file mode 100644 index 0000000..8a2f1cc --- /dev/null +++ b/libgo/go/net/http/transport_default_other.go @@ -0,0 +1,17 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !(js && wasm) +// +build !js !wasm + +package http + +import ( + "context" + "net" +) + +func defaultTransportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { + return dialer.DialContext +} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 7e14749..fed092b 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -776,7 +776,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { c := ts.Client() fetch := func(n, retries int) string { - condFatalf := func(format string, arg ...interface{}) { + condFatalf := func(format string, arg ...any) { if retries <= 0 { t.Fatalf(format, arg...) } @@ -3518,7 +3518,7 @@ func TestRetryRequestsOnError(t *testing.T) { mu sync.Mutex logbuf bytes.Buffer ) - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&logbuf, format, args...) @@ -4495,7 +4495,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { var mu sync.Mutex // guards buf var buf bytes.Buffer - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) @@ -4654,7 +4654,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { func TestTransportEventTraceTLSVerify(t *testing.T) { var mu sync.Mutex var buf bytes.Buffer - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) @@ -4740,7 +4740,7 @@ func TestTransportEventTraceRealDNS(t *testing.T) { var mu sync.Mutex // guards buf var buf bytes.Buffer - logf := func(format string, args ...interface{}) { + logf := func(format string, args ...any) { mu.Lock() defer mu.Unlock() fmt.Fprintf(&buf, format, args...) @@ -6516,3 +6516,32 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { close(r2c) wg.Wait() } + +func TestHandlerAbortRacesBodyRead(t *testing.T) { + setParallel(t) + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + go io.Copy(io.Discard, req.Body) + panic(ErrAbortHandler) + })) + defer ts.Close() + + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + const reqLen = 6 * 1024 * 1024 + req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + resp, _ := ts.Client().Transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } + } + }() + } + wg.Wait() +} diff --git a/libgo/go/net/http/triv.go b/libgo/go/net/http/triv.go index 4dc6240..11b19ab 100644 --- a/libgo/go/net/http/triv.go +++ b/libgo/go/net/http/triv.go @@ -3,7 +3,6 @@ // license that can be found in the LICENSE file. //go:build ignore -// +build ignore package main |