diff options
Diffstat (limited to 'libgo/go/net/http')
32 files changed, 1715 insertions, 340 deletions
diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go index 10325c2..cb140f8 100644 --- a/libgo/go/net/http/cgi/child.go +++ b/libgo/go/net/http/cgi/child.go @@ -102,7 +102,7 @@ func RequestFromMap(params map[string]string) (*http.Request, error) { } // There's apparently a de-facto standard for this. - // https://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + // https://web.archive.org/web/20170105004655/http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { r.TLS = &tls.ConnectionState{HandshakeComplete: true} } diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 921f86b..65a9d51 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -100,7 +100,7 @@ type Client struct { // For compatibility, the Client will also use the deprecated // CancelRequest method on Transport if found. New // RoundTripper implementations should use the Request's Context - // for cancelation instead of implementing CancelRequest. + // for cancellation instead of implementing CancelRequest. Timeout time.Duration } @@ -238,7 +238,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d username := u.Username() password, _ := u.Password() forkReq() - req.Header = ireq.Header.clone() + req.Header = ireq.Header.Clone() req.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } @@ -643,7 +643,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { reqBodyClosed = true if !deadline.IsZero() && didTimeout() { err = &httpError{ - // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancelation/ + // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancellation/ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", timeout: true, } @@ -668,7 +668,7 @@ func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) { // The headers to copy are from the very initial request. // We use a closured callback to keep a reference to these original headers. var ( - ireqhdr = ireq.Header.clone() + ireqhdr = ireq.Header.Clone() icookies map[string][]*Cookie ) if c.Jar != nil && ireq.Header.Get("Cookie") != "" { @@ -870,7 +870,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) { } if b.reqDidTimeout() { err = &httpError{ - // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancelation/ + // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancellation/ err: err.Error() + " (Client.Timeout exceeded while reading body)", timeout: true, } @@ -926,10 +926,9 @@ func isDomainOrSubdomain(sub, parent string) bool { } func stripPassword(u *url.URL) string { - pass, passSet := u.User.Password() + _, passSet := u.User.Password() if passSet { - return strings.Replace(u.String(), pass+"@", "***@", 1) + return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1) } - return u.String() } diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 1c59ce7..de490bc 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -317,8 +317,7 @@ func TestClientRedirectContext(t *testing.T) { return errors.New("redirected request's context never expired after root request canceled") } } - req, _ := NewRequest("GET", ts.URL, nil) - req = req.WithContext(ctx) + req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil) _, err := c.Do(req) ue, ok := err.(*url.Error) if !ok { @@ -1185,6 +1184,11 @@ func TestStripPasswordFromError(t *testing.T) { in: "http://user:password@dummy.faketld/password", out: "Get http://user:***@dummy.faketld/password: dummy impl", }, + { + desc: "Strip escaped password", + in: "http://user:pa%2Fssword@dummy.faketld/", + out: "Get http://user:***@dummy.faketld/: dummy impl", + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { @@ -1288,7 +1292,7 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { donec := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec - })) + }), optQuietLog) defer cst.close() // Note that we use a channel send here and not a close. // The race detector doesn't know that we're waiting for a timeout diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 4da218b..ee48fb0 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -560,7 +560,7 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { if all != "Hello" { t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) } - if !reflect.DeepEqual(err, ExportErrRequestCanceled) { + if err != ExportErrRequestCanceled { t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled) } } diff --git a/libgo/go/net/http/clone.go b/libgo/go/net/http/clone.go new file mode 100644 index 0000000..5f2784d --- /dev/null +++ b/libgo/go/net/http/clone.go @@ -0,0 +1,64 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "mime/multipart" + "net/textproto" + "net/url" +) + +func cloneURLValues(v url.Values) url.Values { + if v == nil { + return nil + } + // http.Header and url.Values have the same representation, so temporarily + // treat it like http.Header, which does have a clone: + return url.Values(Header(v).Clone()) +} + +func cloneURL(u *url.URL) *url.URL { + if u == nil { + return nil + } + u2 := new(url.URL) + *u2 = *u + if u.User != nil { + u2.User = new(url.Userinfo) + *u2.User = *u.User + } + return u2 +} + +func cloneMultipartForm(f *multipart.Form) *multipart.Form { + if f == nil { + return nil + } + f2 := &multipart.Form{ + Value: (map[string][]string)(Header(f.Value).Clone()), + } + if f.File != nil { + m := make(map[string][]*multipart.FileHeader) + for k, vv := range f.File { + vv2 := make([]*multipart.FileHeader, len(vv)) + for i, v := range vv { + vv2[i] = cloneMultipartFileHeader(v) + } + m[k] = vv2 + } + f2.File = m + } + return f2 +} + +func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader { + if fh == nil { + return nil + } + fh2 := new(multipart.FileHeader) + *fh2 = *fh + fh2.Header = textproto.MIMEHeader(Header(fh.Header).Clone()) + return fh2 +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 63f6221..91ff544 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -48,6 +48,7 @@ const ( SameSiteDefaultMode SameSite = iota + 1 SameSiteLaxMode SameSiteStrictMode + SameSiteNoneMode ) // readSetCookies parses all "Set-Cookie" values from @@ -105,6 +106,8 @@ func readSetCookies(h Header) []*Cookie { c.SameSite = SameSiteLaxMode case "strict": c.SameSite = SameSiteStrictMode + case "none": + c.SameSite = SameSiteNoneMode default: c.SameSite = SameSiteDefaultMode } @@ -168,8 +171,12 @@ func (c *Cookie) String() string { if c == nil || !isCookieNameValid(c.Name) { return "" } + // extraCookieLength derived from typical length of cookie attributes + // see RFC 6265 Sec 4.1. + const extraCookieLength = 110 var b strings.Builder - b.WriteString(sanitizeCookieName(c.Name)) + b.Grow(len(c.Name) + len(c.Value) + len(c.Domain) + len(c.Path) + extraCookieLength) + b.WriteString(c.Name) b.WriteRune('=') b.WriteString(sanitizeCookieValue(c.Value)) @@ -213,6 +220,8 @@ func (c *Cookie) String() string { switch c.SameSite { case SameSiteDefaultMode: b.WriteString("; SameSite") + case SameSiteNoneMode: + b.WriteString("; SameSite=None") case SameSiteLaxMode: b.WriteString("; SameSite=Lax") case SameSiteStrictMode: @@ -226,25 +235,28 @@ func (c *Cookie) String() string { // // if filter isn't empty, only cookies of that name are returned func readCookies(h Header, filter string) []*Cookie { - lines, ok := h["Cookie"] - if !ok { + lines := h["Cookie"] + if len(lines) == 0 { return []*Cookie{} } - cookies := []*Cookie{} + cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";")) for _, line := range lines { - parts := strings.Split(strings.TrimSpace(line), ";") - if len(parts) == 1 && parts[0] == "" { - continue - } - // Per-line attributes - for i := 0; i < len(parts); i++ { - parts[i] = strings.TrimSpace(parts[i]) - if len(parts[i]) == 0 { + line = strings.TrimSpace(line) + + var part string + for len(line) > 0 { // continue since we have rest + if splitIndex := strings.Index(line, ";"); splitIndex > 0 { + part, line = line[:splitIndex], line[splitIndex+1:] + } else { + part, line = line, "" + } + part = strings.TrimSpace(part) + if len(part) == 0 { continue } - name, val := parts[i], "" - if j := strings.Index(name, "="); j >= 0 { + name, val := part, "" + if j := strings.Index(part, "="); j >= 0 { name, val = name[:j], name[j+1:] } if !isCookieNameValid(name) { diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index 022adaa..9e8196e 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -77,6 +77,10 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "cookie-14", Value: "samesite-strict", SameSite: SameSiteStrictMode}, "cookie-14=samesite-strict; SameSite=Strict", }, + { + &Cookie{Name: "cookie-15", Value: "samesite-none", SameSite: SameSiteNoneMode}, + "cookie-15=samesite-none; SameSite=None", + }, // The "special" cookies have values containing commas or spaces which // are disallowed by RFC 6265 but are common in the wild. { @@ -127,6 +131,22 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "\t"}, ``, }, + { + &Cookie{Name: "\r"}, + ``, + }, + { + &Cookie{Name: "a\nb", Value: "v"}, + ``, + }, + { + &Cookie{Name: "a\nb", Value: "v"}, + ``, + }, + { + &Cookie{Name: "a\rb", Value: "v"}, + ``, + }, } func TestWriteSetCookies(t *testing.T) { @@ -280,6 +300,15 @@ var readSetCookiesTests = []struct { Raw: "samesitestrict=foo; SameSite=Strict", }}, }, + { + Header{"Set-Cookie": {"samesitenone=foo; SameSite=None"}}, + []*Cookie{{ + Name: "samesitenone", + Value: "foo", + SameSite: SameSiteNoneMode, + Raw: "samesitenone=foo; SameSite=None", + }}, + }, // Make sure we can properly read back the Set-Cookie headers we create // for values containing spaces or commas: { @@ -385,6 +414,19 @@ var readCookiesTests = []struct { {Name: "c2", Value: "v2"}, }, }, + { + Header{"Cookie": {`Cookie-1="v$1"; c2=v2;`}}, + "", + []*Cookie{ + {Name: "Cookie-1", Value: "v$1"}, + {Name: "c2", Value: "v2"}, + }, + }, + { + Header{"Cookie": {``}}, + "", + []*Cookie{}, + }, } func TestReadCookies(t *testing.T) { diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index b6965c2..f0dfa8c 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -33,6 +33,7 @@ var ( ExportHttp2ConfigureServer = http2ConfigureServer Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect Export_writeStatusLine = writeStatusLine + Export_is408Message = is408Message ) const MaxWriteWaitBeforeConnReuse = maxWriteWaitBeforeConnReuse @@ -244,3 +245,18 @@ func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) { } func (r *Request) ExportIsReplayable() bool { return r.isReplayable() } + +// ExportCloseTransportConnsAbruptly closes all idle connections from +// tr in an abrupt way, just reaching into the underlying Conns and +// closing them, without telling the Transport or its persistConns +// that it's doing so. This is to simulate the server closing connections +// on the Transport. +func ExportCloseTransportConnsAbruptly(tr *Transport) { + tr.idleMu.Lock() + for _, pcs := range tr.idleConn { + for _, pc := range pcs { + pc.conn.Close() + } + } + tr.idleMu.Unlock() +} diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index db44d6b..41d46dc 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -63,6 +63,8 @@ func mapDirOpenError(originalErr error, name string) error { return originalErr } +// Open implements FileSystem using os.Open, opening files for reading rooted +// and relative to the directory d. func (d Dir) Open(name string) (File, error) { if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) { return nil, errors.New("http: invalid character in file path") diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index a848d68..2efa0ef 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -1,5 +1,5 @@ // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. -//go:generate bundle -o h2_bundle.go -prefix http2 -underscore golang.org/x/net/http2 +//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2 // Package http2 implements the HTTP/2 protocol. // @@ -42,11 +42,12 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" - "internal/x/net/http/httpguts" - "internal/x/net/http2/hpack" - "internal/x/net/idna" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http2/hpack" + "golang.org/x/net/idna" ) // A list of the possible cipher suite ids. Taken from @@ -1885,7 +1886,7 @@ func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) er return f.WriteDataPadded(streamID, endStream, data, nil) } -// WriteData writes a DATA frame with optional padding. +// WriteDataPadded writes a DATA frame with optional padding. // // If pad is nil, the padding bit is not sent. // The length of pad must not exceed 255 bytes. @@ -3831,7 +3832,20 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { if http2testHookOnConn != nil { http2testHookOnConn() } + // The TLSNextProto interface predates contexts, so + // the net/http package passes down its per-connection + // base context via an exported but unadvertised + // method on the Handler. This is for internal + // net/http<=>http2 use only. + var ctx context.Context + type baseContexter interface { + BaseContext() context.Context + } + if bc, ok := h.(baseContexter); ok { + ctx = bc.BaseContext() + } conf.ServeConn(c, &http2ServeConnOpts{ + Context: ctx, Handler: h, BaseConfig: hs, }) @@ -3842,6 +3856,10 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { // ServeConnOpts are options for the Server.ServeConn method. type http2ServeConnOpts struct { + // Context is the base context to use. + // If nil, context.Background is used. + Context context.Context + // BaseConfig optionally sets the base configuration // for values. If nil, defaults are used. BaseConfig *Server @@ -3852,6 +3870,13 @@ type http2ServeConnOpts struct { Handler Handler } +func (o *http2ServeConnOpts) context() context.Context { + if o.Context != nil { + return o.Context + } + return context.Background() +} + func (o *http2ServeConnOpts) baseConfig() *Server { if o != nil && o.BaseConfig != nil { return o.BaseConfig @@ -3997,7 +4022,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { } func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) { - ctx, cancel = context.WithCancel(context.Background()) + ctx, cancel = context.WithCancel(opts.context()) ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) if hs := opts.baseConfig(); hs != nil { ctx = context.WithValue(ctx, ServerContextKey, hs) @@ -5870,7 +5895,16 @@ type http2chunkWriter struct{ rws *http2responseWriterState } func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } -func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 } +func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } + +func (rws *http2responseWriterState) hasNonemptyTrailers() bool { + for _, trailer := range rws.trailers { + if _, ok := rws.handlerHeader[trailer]; ok { + return true + } + } + return false +} // declareTrailer is called for each Trailer header when the // response header is written. It notes that a header will need to be @@ -5970,7 +6004,10 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { rws.promoteUndeclaredTrailers() } - endStream := rws.handlerDone && !rws.hasTrailers() + // only send trailers if they have actually been defined by the + // server handler. + hasNonemptyTrailers := rws.hasNonemptyTrailers() + endStream := rws.handlerDone && !hasNonemptyTrailers if len(p) > 0 || endStream { // only send a 0 byte DATA frame if we're ending the stream. if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { @@ -5979,7 +6016,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { } } - if rws.handlerDone && rws.hasTrailers() { + if rws.handlerDone && hasNonemptyTrailers { err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{ streamID: rws.stream.id, h: rws.handlerHeader, @@ -6621,6 +6658,7 @@ type http2ClientConn struct { t *http2Transport tconn net.Conn // usually *tls.Conn, except specialized impls tlsState *tls.ConnectionState // nil only for specialized impls + reused uint32 // whether conn is being reused; atomic singleUse bool // whether being used for a single http.Request // readLoop goroutine fields: @@ -6863,7 +6901,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) return nil, err } - http2traceGotConn(req, cc) + reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1) + http2traceGotConn(req, cc, reused) res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req) if err != nil && retry <= 6 { if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil { @@ -7849,7 +7888,11 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail // followed by the query production (see Sections 3.3 and 3.4 of // [RFC3986]). f(":authority", host) - f(":method", req.Method) + m := req.Method + if m == "" { + m = MethodGet + } + f(":method", m) if req.Method != "CONNECT" { f(":path", path) f(":scheme", req.URL.Scheme) @@ -8993,15 +9036,15 @@ func http2traceGetConn(req *Request, hostPort string) { trace.GetConn(hostPort) } -func http2traceGotConn(req *Request, cc *http2ClientConn) { +func http2traceGotConn(req *Request, cc *http2ClientConn, reused bool) { trace := httptrace.ContextClientTrace(req.Context()) if trace == nil || trace.GotConn == nil { return } ci := httptrace.GotConnInfo{Conn: cc.tconn} + ci.Reused = reused cc.mu.Lock() - ci.Reused = cc.nextStreamID > 1 - ci.WasIdle = len(cc.streams) == 0 && ci.Reused + ci.WasIdle = len(cc.streams) == 0 && reused if ci.WasIdle && !cc.lastActive.IsZero() { ci.IdleTime = time.Now().Sub(cc.lastActive) } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index b699e7e..1e1ed98 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -78,12 +78,19 @@ func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error { return h.writeSubset(w, nil, trace) } -func (h Header) clone() Header { +// Clone returns a copy of h. +func (h Header) Clone() Header { + // Find total number of values. + nv := 0 + for _, vv := range h { + nv += len(vv) + } + sv := make([]string, nv) // shared backing array for headers' values h2 := make(Header, len(h)) for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 + n := copy(sv, vv) + h2[k] = sv[:n:n] + sv = sv[n:] } return h2 } diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go index e5d59e1..3510fe6 100644 --- a/libgo/go/net/http/http.go +++ b/libgo/go/net/http/http.go @@ -11,7 +11,7 @@ import ( "time" "unicode/utf8" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // maxInt64 is the effective "infinite" value for the Server and @@ -19,7 +19,7 @@ import ( const maxInt64 = 1<<63 - 1 // aLongTimeAgo is a non-zero time, far in the past, used for -// immediate cancelation of network operations. +// immediate cancellation of network operations. var aLongTimeAgo = time.Unix(1, 0) // TODO(bradfitz): move common stuff here. The other files have accumulated diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index f2c3c07..d0bc0fa 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -12,7 +12,7 @@ import ( "strconv" "strings" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // ResponseRecorder is an implementation of http.ResponseWriter that @@ -59,7 +59,10 @@ func NewRecorder() *ResponseRecorder { // an explicit DefaultRemoteAddr isn't set on ResponseRecorder. const DefaultRemoteAddr = "1.2.3.4" -// Header returns the response headers. +// Header implements http.ResponseWriter. It returns the response +// headers to mutate within a handler. To test the headers that were +// written after a handler completes, use the Result method and see +// the returned Response value's Header. func (rw *ResponseRecorder) Header() http.Header { m := rw.HeaderMap if m == nil { @@ -98,7 +101,8 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { rw.WriteHeader(200) } -// Write always succeeds and writes to rw.Body, if not nil. +// Write implements http.ResponseWriter. The data in buf is written to +// rw.Body, if not nil. func (rw *ResponseRecorder) Write(buf []byte) (int, error) { rw.writeHeader(buf, "") if rw.Body != nil { @@ -107,7 +111,8 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) { return len(buf), nil } -// WriteString always succeeds and writes to rw.Body, if not nil. +// WriteString implements io.StringWriter. The data in str is written +// to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { rw.writeHeader(nil, str) if rw.Body != nil { @@ -116,8 +121,7 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) { return len(str), nil } -// WriteHeader sets rw.Code. After it is called, changing rw.Header -// will not affect rw.HeaderMap. +// WriteHeader implements http.ResponseWriter. func (rw *ResponseRecorder) WriteHeader(code int) { if rw.wroteHeader { return @@ -127,20 +131,11 @@ func (rw *ResponseRecorder) WriteHeader(code int) { if rw.HeaderMap == nil { rw.HeaderMap = make(http.Header) } - rw.snapHeader = cloneHeader(rw.HeaderMap) + rw.snapHeader = rw.HeaderMap.Clone() } -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 - } - return h2 -} - -// Flush sets rw.Flushed to true. +// Flush implements http.Flusher. To test whether Flush was +// called, see rw.Flushed. func (rw *ResponseRecorder) Flush() { if !rw.wroteHeader { rw.WriteHeader(200) @@ -168,7 +163,7 @@ func (rw *ResponseRecorder) Result() *http.Response { return rw.result } if rw.snapHeader == nil { - rw.snapHeader = cloneHeader(rw.HeaderMap) + rw.snapHeader = rw.HeaderMap.Clone() } res := &http.Response{ Proto: "HTTP/1.1", diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 63312dd..97954ca 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -18,7 +18,10 @@ import ( ) type dumpTest struct { - Req http.Request + // Either Req or GetReq can be set/nil but not both. + Req *http.Request + GetReq func() *http.Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body WantDump string @@ -29,7 +32,7 @@ type dumpTest struct { var dumpTests = []dumpTest{ // HTTP/1.1 => chunked coding; body; empty trailer { - Req: http.Request{ + Req: &http.Request{ Method: "GET", URL: &url.URL{ Scheme: "http", @@ -52,7 +55,7 @@ var dumpTests = []dumpTest{ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, // and doesn't add a User-Agent. { - Req: http.Request{ + Req: &http.Request{ Method: "GET", URL: mustParseURL("/foo"), ProtoMajor: 1, @@ -67,7 +70,7 @@ var dumpTests = []dumpTest{ }, { - Req: *mustNewRequest("GET", "http://example.com/foo", nil), + Req: mustNewRequest("GET", "http://example.com/foo", nil), WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + @@ -79,8 +82,7 @@ var dumpTests = []dumpTest{ // with a bytes.Buffer and hang with all goroutines not // runnable. { - Req: *mustNewRequest("GET", "https://example.com/foo", nil), - + Req: mustNewRequest("GET", "https://example.com/foo", nil), WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go-http-client/1.1\r\n" + @@ -89,7 +91,7 @@ var dumpTests = []dumpTest{ // Request with Body, but Dump requested without it. { - Req: http.Request{ + Req: &http.Request{ Method: "POST", URL: &url.URL{ Scheme: "http", @@ -114,7 +116,7 @@ var dumpTests = []dumpTest{ // Request with Body > 8196 (default buffer size) { - Req: http.Request{ + Req: &http.Request{ Method: "POST", URL: &url.URL{ Scheme: "http", @@ -145,8 +147,10 @@ var dumpTests = []dumpTest{ }, { - Req: *mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" + - "User-Agent: blah\r\n\r\n"), + GetReq: func() *http.Request { + return mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" + + "User-Agent: blah\r\n\r\n") + }, NoBody: true, WantDump: "GET http://foo.com/ HTTP/1.1\r\n" + "User-Agent: blah\r\n\r\n", @@ -154,22 +158,25 @@ var dumpTests = []dumpTest{ // Issue #7215. DumpRequest should return the "Content-Length" when set { - Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + - "Host: passport.myhost.com\r\n" + - "Content-Length: 3\r\n" + - "\r\nkey1=name1&key2=name2"), + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey1=name1&key2=name2") + }, WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + "Host: passport.myhost.com\r\n" + "Content-Length: 3\r\n" + "\r\nkey", }, - // Issue #7215. DumpRequest should return the "Content-Length" in ReadRequest { - Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + - "Host: passport.myhost.com\r\n" + - "Content-Length: 0\r\n" + - "\r\nkey1=name1&key2=name2"), + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n" + + "\r\nkey1=name1&key2=name2") + }, WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + "Host: passport.myhost.com\r\n" + "Content-Length: 0\r\n\r\n", @@ -177,9 +184,11 @@ var dumpTests = []dumpTest{ // Issue #7215. DumpRequest should not return the "Content-Length" if unset { - Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + - "Host: passport.myhost.com\r\n" + - "\r\nkey1=name1&key2=name2"), + GetReq: func() *http.Request { + return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "\r\nkey1=name1&key2=name2") + }, WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + "Host: passport.myhost.com\r\n\r\n", }, @@ -187,8 +196,7 @@ var dumpTests = []dumpTest{ // Issue 18506: make drainBody recognize NoBody. Otherwise // this was turning into a chunked request. { - Req: *mustNewRequest("POST", "http://example.com/foo", http.NoBody), - + Req: mustNewRequest("POST", "http://example.com/foo", http.NoBody), WantDumpOut: "POST /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go-http-client/1.1\r\n" + @@ -200,28 +208,40 @@ var dumpTests = []dumpTest{ func TestDumpRequest(t *testing.T) { numg0 := runtime.NumGoroutine() for i, tt := range dumpTests { - setBody := func() { - if tt.Body == nil { - return + if tt.Req != nil && tt.GetReq != nil || tt.Req == nil && tt.GetReq == nil { + t.Errorf("#%d: either .Req(%p) or .GetReq(%p) can be set/nil but not both", i, tt.Req, tt.GetReq) + continue + } + + freshReq := func(ti dumpTest) *http.Request { + req := ti.Req + if req == nil { + req = ti.GetReq() } - switch b := tt.Body.(type) { + + if req.Header == nil { + req.Header = make(http.Header) + } + + if ti.Body == nil { + return req + } + switch b := ti.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) + req.Body = ioutil.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: - tt.Req.Body = b() + req.Body = b() default: - t.Fatalf("Test %d: unsupported Body of %T", i, tt.Body) + t.Fatalf("Test %d: unsupported Body of %T", i, ti.Body) } - } - if tt.Req.Header == nil { - tt.Req.Header = make(http.Header) + return req } if tt.WantDump != "" { - setBody() - dump, err := DumpRequest(&tt.Req, !tt.NoBody) + req := freshReq(tt) + dump, err := DumpRequest(req, !tt.NoBody) if err != nil { - t.Errorf("DumpRequest #%d: %s", i, err) + t.Errorf("DumpRequest #%d: %s\nWantDump:\n%s", i, err, tt.WantDump) continue } if string(dump) != tt.WantDump { @@ -231,8 +251,8 @@ func TestDumpRequest(t *testing.T) { } if tt.WantDumpOut != "" { - setBody() - dump, err := DumpRequestOut(&tt.Req, !tt.NoBody) + req := freshReq(tt) + dump, err := DumpRequestOut(req, !tt.NoBody) if err != nil { t.Errorf("DumpRequestOut #%d: %s", i, err) continue diff --git a/libgo/go/net/http/httputil/persist.go b/libgo/go/net/http/httputil/persist.go index cbedf25..84b116d 100644 --- a/libgo/go/net/http/httputil/persist.go +++ b/libgo/go/net/http/httputil/persist.go @@ -292,8 +292,8 @@ func (cc *ClientConn) Close() error { } // Write writes a request. An ErrPersistEOF error is returned if the connection -// has been closed in an HTTP keepalive sense. If req.Close equals true, the -// keepalive connection is logically closed after this request and the opposing +// has been closed in an HTTP keep-alive sense. If req.Close equals true, the +// keep-alive connection is logically closed after this request and the opposing // server is informed. An ErrUnexpectedEOF indicates the remote closed the // underlying TCP connection, which is usually considered as graceful close. func (cc *ClientConn) Write(req *http.Request) error { diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 4b165d6..1d7b0ef 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -18,7 +18,7 @@ import ( "sync" "time" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // ReverseProxy is an HTTP Handler that takes an incoming request and @@ -51,8 +51,7 @@ type ReverseProxy struct { // ErrorLog specifies an optional logger for errors // that occur when attempting to proxy the request. - // If nil, logging goes to os.Stderr via the log package's - // standard logger. + // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger // BufferPool optionally specifies a buffer pool to @@ -132,16 +131,6 @@ func copyHeader(dst, src http.Header) { } } -func cloneHeader(h http.Header) http.Header { - h2 := make(http.Header, len(h)) - for k, vv := range h { - vv2 := make([]string, len(vv)) - copy(vv2, vv) - h2[k] = vv2 - } - return h2 -} - // Hop-by-hop headers. These are removed when sent to the backend. // As of RFC 7230, hop-by-hop headers are required to appear in the // Connection header field. These are the headers defined by the @@ -206,13 +195,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { }() } - outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay + outreq := req.Clone(ctx) if req.ContentLength == 0 { outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } - outreq.Header = cloneHeader(req.Header) - p.Director(outreq) outreq.Close = false @@ -357,10 +344,10 @@ func shouldPanicOnCopyError(req *http.Request) bool { // removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. // See RFC 7230, section 6.1 func removeConnectionHeaders(h http.Header) { - if c := h.Get("Connection"); c != "" { - for _, f := range strings.Split(c, ",") { - if f = strings.TrimSpace(f); f != "" { - h.Del(f) + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + h.Del(sf) } } } diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 367ba73..e8cb814 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -20,6 +20,7 @@ import ( "net/url" "os" "reflect" + "sort" "strconv" "strings" "sync" @@ -160,13 +161,17 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { const someConnHeader = "X-Some-Conn-Header" backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } if c := r.Header.Get(fakeConnectionToken); c != "" { t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) } if c := r.Header.Get(someConnHeader); c != "" { t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) } - w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken) + w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Add("Connection", someConnHeader) w.Header().Set(someConnHeader, "should be deleted") w.Header().Set(fakeConnectionToken, "should be deleted") io.WriteString(w, backendResponse) @@ -179,15 +184,34 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { proxyHandler := NewSingleHostReverseProxy(backendURL) frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxyHandler.ServeHTTP(w, r) - if c := r.Header.Get(someConnHeader); c != "original value" { - t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value") + if c := r.Header.Get(someConnHeader); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted") + } + if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" { + t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted") + } + c := r.Header["Connection"] + var cf []string + for _, f := range c { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + cf = append(cf, sf) + } + } + } + sort.Strings(cf) + expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken} + sort.Strings(expectedValues) + if !reflect.DeepEqual(cf, expectedValues) { + t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues) } })) defer frontend.Close() getReq, _ := http.NewRequest("GET", frontend.URL, nil) - getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken) - getReq.Header.Set(someConnHeader, "original value") + getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Add("Connection", someConnHeader) + getReq.Header.Set(someConnHeader, "should be deleted") getReq.Header.Set(fakeConnectionToken, "should be deleted") res, err := frontend.Client().Do(getReq) if err != nil { @@ -201,6 +225,9 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { if got, want := string(bodyBytes), backendResponse; got != want { t.Errorf("got body %q; want %q", got, want) } + if c := res.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } if c := res.Header.Get(someConnHeader); c != "" { t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) } diff --git a/libgo/go/net/http/internal/testcert.go b/libgo/go/net/http/internal/testcert.go index 4078909..2284a83 100644 --- a/libgo/go/net/http/internal/testcert.go +++ b/libgo/go/net/http/internal/testcert.go @@ -4,6 +4,8 @@ package internal +import "strings" + // LocalhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. // generated from src/crypto/tls: @@ -24,7 +26,7 @@ fblo6RBxUQ== -----END CERTIFICATE-----`) // LocalhostKey is the private key for localhostCert. -var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- +var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY----- MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB @@ -38,4 +40,6 @@ fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== ------END RSA PRIVATE KEY-----`) +-----END RSA TESTING KEY-----`)) + +func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") } diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index dcad2b6..fa63175 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -26,7 +26,7 @@ import ( "strings" "sync" - "internal/x/net/idna" + "golang.org/x/net/idna" ) const ( @@ -304,7 +304,7 @@ type Request struct { // // For server requests, this field is not applicable. // - // Deprecated: Use the Context and WithContext methods + // Deprecated: Set the Request's context with NewRequestWithContext // instead. If a Request's Cancel field and context are both // set, it is undefined whether Cancel is respected. Cancel <-chan struct{} @@ -327,7 +327,7 @@ type Request struct { // The returned context is always non-nil; it defaults to the // background context. // -// For outgoing client requests, the context controls cancelation. +// For outgoing client requests, the context controls cancellation. // // For incoming server requests, the context is canceled when the // client's connection closes, the request is canceled (with HTTP/2), @@ -345,6 +345,11 @@ func (r *Request) Context() context.Context { // For outgoing client request, the context controls the entire // lifetime of a request and its response: obtaining a connection, // sending the request, and reading the response headers and body. +// +// To create a new request with a context, use NewRequestWithContext. +// To change the context of a request (such as an incoming) you then +// also want to modify to send back out, use Request.Clone. Between +// those two uses, it's rare to need WithContext. func (r *Request) WithContext(ctx context.Context) *Request { if ctx == nil { panic("nil context") @@ -352,16 +357,38 @@ func (r *Request) WithContext(ctx context.Context) *Request { r2 := new(Request) *r2 = *r r2.ctx = ctx + r2.URL = cloneURL(r.URL) // legacy behavior; TODO: try to remove. Issue 23544 + return r2 +} - // Deep copy the URL because it isn't - // a map and the URL is mutable by users - // of WithContext. - if r.URL != nil { - r2URL := new(url.URL) - *r2URL = *r.URL - r2.URL = r2URL +// Clone returns a deep copy of r with its context changed to ctx. +// The provided ctx must be non-nil. +// +// For an outgoing client request, the context controls the entire +// lifetime of a request and its response: obtaining a connection, +// sending the request, and reading the response headers and body. +func (r *Request) Clone(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") } - + r2 := new(Request) + *r2 = *r + r2.ctx = ctx + r2.URL = cloneURL(r.URL) + if r.Header != nil { + r2.Header = r.Header.Clone() + } + if r.Trailer != nil { + r2.Trailer = r.Trailer.Clone() + } + if s := r.TransferEncoding; s != nil { + s2 := make([]string, len(s)) + copy(s2, s) + r2.TransferEncoding = s + } + r2.Form = cloneURLValues(r.Form) + r2.PostForm = cloneURLValues(r.PostForm) + r2.MultipartForm = cloneMultipartForm(r.MultipartForm) return r2 } @@ -781,25 +808,34 @@ func validMethod(method string) bool { return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 } -// NewRequest returns a new Request given a method, URL, and optional body. +// NewRequest wraps NewRequestWithContext using the background context. +func NewRequest(method, url string, body io.Reader) (*Request, error) { + return NewRequestWithContext(context.Background(), method, url, body) +} + +// NewRequestWithContext returns a new Request given a method, URL, and +// optional body. // // If the provided body is also an io.Closer, the returned // Request.Body is set to body and will be closed by the Client // methods Do, Post, and PostForm, and Transport.RoundTrip. // -// NewRequest returns a Request suitable for use with Client.Do or -// Transport.RoundTrip. To create a request for use with testing a -// Server Handler, either use the NewRequest function in the +// NewRequestWithContext returns a Request suitable for use with +// Client.Do or Transport.RoundTrip. To create a request for use with +// testing a Server Handler, either use the NewRequest function in the // net/http/httptest package, use ReadRequest, or manually update the -// Request fields. See the Request type's documentation for the -// difference between inbound and outbound request fields. +// Request fields. For an outgoing client request, the context +// controls the entire lifetime of a request and its response: +// obtaining a connection, sending the request, and reading the +// response headers and body. See the Request type's documentation for +// the difference between inbound and outbound request fields. // // If body is of type *bytes.Buffer, *bytes.Reader, or // *strings.Reader, the returned request's ContentLength is set to its // exact value (instead of -1), GetBody is populated (so 307 and 308 // redirects can replay the body), and Body is set to NoBody if the // ContentLength is 0. -func NewRequest(method, url string, body io.Reader) (*Request, error) { +func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) { if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -809,6 +845,9 @@ func NewRequest(method, url string, body io.Reader) (*Request, error) { if !validMethod(method) { return nil, fmt.Errorf("net/http: invalid method %q", method) } + if ctx == nil { + return nil, errors.New("net/http: nil Context") + } u, err := parseURL(url) // Just url.Parse (url is shadowed for godoc). if err != nil { return nil, err @@ -820,6 +859,7 @@ func NewRequest(method, url string, body io.Reader) (*Request, error) { // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) req := &Request{ + ctx: ctx, Method: method, URL: u, Proto: "HTTP/1.1", @@ -912,6 +952,10 @@ func parseBasicAuth(auth string) (username, password string, ok bool) { // // With HTTP Basic Authentication the provided username and password // are not encrypted. +// +// Some protocols may impose additional requirements on pre-escaping the +// username and password. For instance, when used with OAuth2, both arguments +// must be URL encoded first with url.QueryEscape. func (r *Request) SetBasicAuth(username, password string) { r.Header.Set("Authorization", "Basic "+basicAuth(username, password)) } diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index e800557..b072f95 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -8,12 +8,14 @@ import ( "bufio" "bytes" "context" + "crypto/rand" "encoding/base64" "fmt" "io" "io/ioutil" "mime/multipart" . "net/http" + "net/http/httptest" "net/url" "os" "reflect" @@ -133,30 +135,31 @@ func TestParseFormInitializeOnError(t *testing.T) { } func TestMultipartReader(t *testing.T) { - req := &Request{ - Method: "POST", - Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: ioutil.NopCloser(new(bytes.Buffer)), - } - multipart, err := req.MultipartReader() - if multipart == nil { - t.Errorf("expected multipart; error: %v", err) - } - - req = &Request{ - Method: "POST", - Header: Header{"Content-Type": {`multipart/mixed; boundary="foo123"`}}, - Body: ioutil.NopCloser(new(bytes.Buffer)), - } - multipart, err = req.MultipartReader() - if multipart == nil { - t.Errorf("expected multipart; error: %v", err) + tests := []struct { + shouldError bool + contentType string + }{ + {false, `multipart/form-data; boundary="foo123"`}, + {false, `multipart/mixed; boundary="foo123"`}, + {true, `text/plain`}, } - req.Header = Header{"Content-Type": {"text/plain"}} - multipart, err = req.MultipartReader() - if multipart != nil { - t.Error("unexpected multipart for text/plain") + for i, test := range tests { + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {test.contentType}}, + Body: ioutil.NopCloser(new(bytes.Buffer)), + } + multipart, err := req.MultipartReader() + if test.shouldError { + if err == nil || multipart != nil { + t.Errorf("test %d: unexpectedly got nil-error (%v) or non-nil-multipart (%v)", i, err, multipart) + } + continue + } + if err != nil || multipart == nil { + t.Errorf("test %d: unexpectedly got error (%v) or nil-multipart (%v)", i, err, multipart) + } } } @@ -1046,3 +1049,92 @@ func BenchmarkReadRequestWrk(b *testing.B) { Host: localhost:8080 `) } + +const ( + withTLS = true + noTLS = false +) + +func BenchmarkFileAndServer_1KB(b *testing.B) { + benchmarkFileAndServer(b, 1<<10) +} + +func BenchmarkFileAndServer_16MB(b *testing.B) { + benchmarkFileAndServer(b, 1<<24) +} + +func BenchmarkFileAndServer_64MB(b *testing.B) { + benchmarkFileAndServer(b, 1<<26) +} + +func benchmarkFileAndServer(b *testing.B, n int64) { + f, err := ioutil.TempFile(os.TempDir(), "go-bench-http-file-and-server") + if err != nil { + b.Fatalf("Failed to create temp file: %v", err) + } + + defer func() { + f.Close() + os.RemoveAll(f.Name()) + }() + + if _, err := io.CopyN(f, rand.Reader, n); err != nil { + b.Fatalf("Failed to copy %d bytes: %v", n, err) + } + + b.Run("NoTLS", func(b *testing.B) { + runFileAndServerBenchmarks(b, noTLS, f, n) + }) + + b.Run("TLS", func(b *testing.B) { + runFileAndServerBenchmarks(b, withTLS, f, n) + }) +} + +func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) { + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { + defer req.Body.Close() + nc, err := io.Copy(ioutil.Discard, req.Body) + if err != nil { + panic(err) + } + + if nc != n { + panic(fmt.Errorf("Copied %d Wanted %d bytes", nc, n)) + } + }) + + var cst *httptest.Server + if tlsOption == withTLS { + cst = httptest.NewTLSServer(handler) + } else { + cst = httptest.NewServer(handler) + } + + defer cst.Close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Perform some setup. + b.StopTimer() + if _, err := f.Seek(0, 0); err != nil { + b.Fatalf("Failed to seek back to file: %v", err) + } + + b.StartTimer() + req, err := NewRequest("PUT", cst.URL, ioutil.NopCloser(f)) + if err != nil { + b.Fatal(err) + } + + req.ContentLength = n + // Prevent mime sniffing by setting the Content-Type. + req.Header.Set("Content-Type", "application/octet-stream") + res, err := cst.Client().Do(req) + if err != nil { + b.Fatalf("Failed to make request to backend: %v", err) + } + + res.Body.Close() + b.SetBytes(n) + } +} diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index f906ce8..2065a25 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -12,12 +12,13 @@ import ( "crypto/tls" "errors" "fmt" - "internal/x/net/http/httpguts" "io" "net/textproto" "net/url" "strconv" "strings" + + "golang.org/x/net/http/httpguts" ) var respExcludeHeader = map[string]bool{ @@ -66,7 +67,7 @@ type Response struct { // with a "chunked" Transfer-Encoding. // // As of Go 1.12, the Body will be also implement io.Writer - // on a successful "101 Switching Protocols" responses, + // on a successful "101 Switching Protocols" response, // as used by WebSockets and HTTP/2's "h2c" mode. Body io.ReadCloser diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index c46f13f..ee7f0d0 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -10,7 +10,7 @@ import ( "compress/gzip" "crypto/rand" "fmt" - "go/ast" + "go/token" "io" "io/ioutil" "net/http/internal" @@ -736,7 +736,7 @@ func diff(t *testing.T, prefix string, have, want interface{}) { } for i := 0; i < hv.NumField(); i++ { name := hv.Type().Field(i).Name - if !ast.IsExported(name) { + if !token.IsExported(name) { continue } hf := hv.Field(i).Interface() diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go index 1e38b90..6331351 100644 --- a/libgo/go/net/http/roundtrip_js.go +++ b/libgo/go/net/http/roundtrip_js.go @@ -11,12 +11,12 @@ import ( "fmt" "io" "io/ioutil" - "os" "strconv" - "strings" "syscall/js" ) +var uint8Array = js.Global().Get("Uint8Array") + // jsFetchMode is a Request.Header map key that, if present, // signals that the map entry is actually an option to the Fetch API mode setting. // Valid values are: "cors", "no-cors", "same-origin", "navigate" @@ -33,9 +33,19 @@ const jsFetchMode = "js.fetch:mode" // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters const jsFetchCreds = "js.fetch:credentials" +// jsFetchRedirect is a Request.Header map key that, if present, +// signals that the map entry is actually an option to the Fetch API redirect setting. +// Valid values are: "follow", "error", "manual" +// The default is "follow". +// +// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters +const jsFetchRedirect = "js.fetch:redirect" + +var useFakeNetwork = js.Global().Get("fetch") == js.Undefined() + // RoundTrip implements the RoundTripper interface using the WHATWG Fetch API. func (t *Transport) RoundTrip(req *Request) (*Response, error) { - if useFakeNetwork() { + if useFakeNetwork { return t.roundTrip(req) } @@ -60,6 +70,10 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { opt.Set("mode", h) req.Header.Del(jsFetchMode) } + if h := req.Header.Get(jsFetchRedirect); h != "" { + opt.Set("redirect", h) + req.Header.Del(jsFetchRedirect) + } if ac != js.Undefined() { opt.Set("signal", ac.Get("signal")) } @@ -84,9 +98,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return nil, err } req.Body.Close() - a := js.TypedArrayOf(body) - defer a.Release() - opt.Set("body", a) + buf := uint8Array.New(len(body)) + js.CopyBytesToJS(buf, body) + opt.Set("body", buf) } respPromise := js.Global().Call("fetch", req.URL.String(), opt) var ( @@ -126,10 +140,11 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { body = &arrayReader{arrayPromise: result.Call("arrayBuffer")} } + code := result.Get("status").Int() select { case respCh <- &Response{ - Status: result.Get("status").String() + " " + StatusText(result.Get("status").Int()), - StatusCode: result.Get("status").Int(), + Status: fmt.Sprintf("%d %s", code, StatusText(code)), + StatusCode: code, Header: header, ContentLength: contentLength, Body: body, @@ -167,12 +182,6 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { var errClosed = errors.New("net/http: reader is closed") -// useFakeNetwork is used to determine whether the request is made -// by a test and should be made to use the fake in-memory network. -func useFakeNetwork() bool { - return len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test") -} - // streamReader implements an io.ReadCloser wrapper for ReadableStream. // See https://fetch.spec.whatwg.org/#readablestream for more information. type streamReader struct { @@ -197,9 +206,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) { return nil } value := make([]byte, result.Get("value").Get("byteLength").Int()) - a := js.TypedArrayOf(value) - a.Call("set", result.Get("value")) - a.Release() + js.CopyBytesToGo(value, result.Get("value")) bCh <- value return nil }) @@ -260,11 +267,9 @@ func (r *arrayReader) Read(p []byte) (n int, err error) { ) success := js.FuncOf(func(this js.Value, args []js.Value) interface{} { // Wrap the input ArrayBuffer with a Uint8Array - uint8arrayWrapper := js.Global().Get("Uint8Array").New(args[0]) + uint8arrayWrapper := uint8Array.New(args[0]) value := make([]byte, uint8arrayWrapper.Get("byteLength").Int()) - a := js.TypedArrayOf(value) - a.Call("set", uint8arrayWrapper) - a.Release() + js.CopyBytesToGo(value, uint8arrayWrapper) bCh <- value return nil }) diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 6eb0088..e7ed15c 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -4273,7 +4273,7 @@ func testServerEmptyBodyRace(t *testing.T, h2 bool) { var n int32 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) - })) + }), optQuietLog) defer cst.close() var wg sync.WaitGroup const reqs = 20 @@ -4697,6 +4697,10 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { conn, br, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } defer conn.Close() if r.Method != "PRI" || r.RequestURI != "*" { t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI) @@ -5745,8 +5749,12 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { setParallel(t) defer afterTest(t) - const goroutines = 5 - const requests = 2000 + goroutines := 5 + requests := 2000 + if testing.Short() { + goroutines = 3 + requests = 100 + } hts := httptest.NewServer(HandlerFunc(NotFound)) defer hts.Close() @@ -6021,6 +6029,143 @@ func TestStripPortFromHost(t *testing.T) { } } +func TestServerContexts(t *testing.T) { + setParallel(t) + defer afterTest(t) + type baseKey struct{} + type connKey struct{} + ch := make(chan context.Context, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ch <- r.Context() + })) + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") + } + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") + } + ts.Start() + defer ts.Close() + res, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + ctx := <-ch + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("base context key = %#v; want %q", got, want) + } + if got, want := ctx.Value(connKey{}), "conn"; got != want { + t.Errorf("conn context key = %#v; want %q", got, want) + } +} + +func TestServerContextsHTTP2(t *testing.T) { + setParallel(t) + defer afterTest(t) + type baseKey struct{} + type connKey struct{} + ch := make(chan context.Context, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + if r.ProtoMajor != 2 { + t.Errorf("unexpected HTTP/1.x request") + } + ch <- r.Context() + })) + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") + } + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") + } + ts.TLS = &tls.Config{ + NextProtos: []string{"h2", "http/1.1"}, + } + ts.StartTLS() + defer ts.Close() + ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true + res, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + ctx := <-ch + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("base context key = %#v; want %q", got, want) + } + if got, want := ctx.Value(connKey{}), "conn"; got != want { + t.Errorf("conn context key = %#v; want %q", got, want) + } +} + +// Issue 30710: ensure that as per the spec, a server responds +// with 501 Not Implemented for unsupported transfer-encodings. +func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { + cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write([]byte("Hello, World!")) + })) + defer cst.Close() + + serverURL, err := url.Parse(cst.URL) + if err != nil { + t.Fatalf("Failed to parse server URL: %v", err) + } + + unsupportedTEs := []string{ + "fugazi", + "foo-bar", + "unknown", + } + + for _, badTE := range unsupportedTEs { + http1ReqBody := fmt.Sprintf(""+ + "POST / HTTP/1.1\r\nConnection: close\r\n"+ + "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE) + + gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody)) + if err != nil { + t.Errorf("%q. unexpected error: %v", badTE, err) + continue + } + + wantBody := fmt.Sprintf("" + + "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" + + "Connection: close\r\n\r\nUnsupported transfer encoding") + + if string(gotBody) != wantBody { + t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody) + } + } +} + +// fetchWireResponse is a helper for dialing to host, +// sending http1ReqBody as the payload and retrieving +// the response as it was sent on the wire. +func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + defer conn.Close() + + if _, err := conn.Write(http1ReqBody); err != nil { + return nil, err + } + return ioutil.ReadAll(conn) +} + func BenchmarkResponseStatusLine(b *testing.B) { b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index cbf5673..74569bf 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -29,7 +29,7 @@ import ( "sync/atomic" "time" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // Errors used by the HTTP server. @@ -749,10 +749,8 @@ func (cr *connReader) handleReadError(_ error) { // may be called from multiple goroutines. func (cr *connReader) closeNotify() { res, _ := cr.conn.curReq.Load().(*response) - if res != nil { - if atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { - res.closeNotifyCh <- true - } + if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res.closeNotifyCh <- true } } @@ -1060,7 +1058,7 @@ func (w *response) Header() Header { // Accessing the header between logically writing it // and physically writing it means we need to allocate // a clone to snapshot the logically written state. - w.cw.header = w.handlerHeader.clone() + w.cw.header = w.handlerHeader.Clone() } w.calledHeader = true return w.handlerHeader @@ -1134,7 +1132,7 @@ func (w *response) WriteHeader(code int) { w.status = code if w.calledHeader && w.cw.header == nil { - w.cw.header = w.handlerHeader.clone() + w.cw.header = w.handlerHeader.Clone() } if cl := w.handlerHeader.get("Content-Length"); cl != "" { @@ -1803,7 +1801,7 @@ func (c *conn) serve(ctx context.Context) { *c.tlsState = tlsConn.ConnectionState() if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) { if fn := c.server.TLSNextProto[proto]; fn != nil { - h := initNPNRequest{tlsConn, serverHandler{c.server}} + h := initNPNRequest{ctx, tlsConn, serverHandler{c.server}} fn(c.server, tlsConn, h) } return @@ -1829,7 +1827,8 @@ func (c *conn) serve(ctx context.Context) { if err != nil { const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" - if err == errTooLarge { + switch { + case err == errTooLarge: // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up @@ -1839,18 +1838,31 @@ func (c *conn) serve(ctx context.Context) { fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) c.closeWriteAndWait() return - } - if isCommonNetReadError(err) { + + case isUnsupportedTEError(err): + // Respond as per RFC 7230 Section 3.3.1 which says, + // A server that receives a request message with a + // transfer coding it does not understand SHOULD + // respond with 501 (Unimplemented). + code := StatusNotImplemented + + // We purposefully aren't echoing back the transfer-encoding's value, + // so as to mitigate the risk of cross side scripting by an attacker. + fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, StatusText(code), errorHeaders) + return + + case isCommonNetReadError(err): return // don't reply - } - publicErr := "400 Bad Request" - if v, ok := err.(badRequestError); ok { - publicErr = publicErr + ": " + string(v) - } + default: + publicErr := "400 Bad Request" + if v, ok := err.(badRequestError); ok { + publicErr = publicErr + ": " + string(v) + } - fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) - return + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) + return + } } // Expect 100 Continue support @@ -2505,7 +2517,9 @@ type Server struct { // ReadHeaderTimeout is the amount of time allowed to read // request headers. The connection's read deadline is reset // after reading the headers and the Handler can decide what - // is considered too slow for the body. + // is considered too slow for the body. If ReadHeaderTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. ReadHeaderTimeout time.Duration // WriteTimeout is the maximum duration before timing out @@ -2517,7 +2531,7 @@ type Server struct { // IdleTimeout is the maximum amount of time to wait for the // next request when keep-alives are enabled. If IdleTimeout // is zero, the value of ReadTimeout is used. If both are - // zero, ReadHeaderTimeout is used. + // zero, there is no timeout. IdleTimeout time.Duration // MaxHeaderBytes controls the maximum number of bytes the @@ -2549,6 +2563,20 @@ type Server struct { // If nil, logging is done via the log package's standard logger. ErrorLog *log.Logger + // BaseContext optionally specifies a function that returns + // the base context for incoming requests on this server. + // The provided Listener is the specific Listener that's + // about to start accepting requests. + // If BaseContext is nil, the default is context.Background(). + // If non-nil, it must return a non-nil context. + BaseContext func(net.Listener) context.Context + + // ConnContext optionally specifies a function that modifies + // the context used for a new connection c. The provided ctx + // is derived from the base context and has a ServerContextKey + // value. + ConnContext func(ctx context.Context, c net.Conn) context.Context + disableKeepAlives int32 // accessed atomically. inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init @@ -2799,7 +2827,7 @@ func (srv *Server) ListenAndServe() error { if err != nil { return err } - return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) + return srv.Serve(ln) } var testHookServerServe func(*Server, net.Listener) // used if non-nil @@ -2845,6 +2873,7 @@ func (srv *Server) Serve(l net.Listener) error { fn(srv, l) // call hook with unwrapped listener } + origListener := l l = &onceCloseListener{Listener: l} defer l.Close() @@ -2857,8 +2886,16 @@ func (srv *Server) Serve(l net.Listener) error { } defer srv.trackListener(&l, false) - var tempDelay time.Duration // how long to sleep on accept failure - baseCtx := context.Background() // base is always background, per Issue 16220 + var tempDelay time.Duration // how long to sleep on accept failure + + baseCtx := context.Background() + if srv.BaseContext != nil { + baseCtx = srv.BaseContext(origListener) + if baseCtx == nil { + panic("BaseContext returned a nil context") + } + } + ctx := context.WithValue(baseCtx, ServerContextKey, srv) for { rw, e := l.Accept() @@ -2883,6 +2920,12 @@ func (srv *Server) Serve(l net.Listener) error { } return e } + if cc := srv.ConnContext; cc != nil { + ctx = cc(ctx, rw) + if ctx == nil { + panic("ConnContext returned nil") + } + } tempDelay = 0 c := srv.newConn(rw) c.setState(c.rwc, StateNew) // before Serve can return @@ -3083,7 +3126,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { defer ln.Close() - return srv.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, certFile, keyFile) + return srv.ServeTLS(ln, certFile, keyFile) } // setupHTTP2_ServeTLS conditionally configures HTTP/2 on @@ -3228,6 +3271,25 @@ type timeoutWriter struct { code int } +var _ Pusher = (*timeoutWriter)(nil) +var _ Flusher = (*timeoutWriter)(nil) + +// Push implements the Pusher interface. +func (tw *timeoutWriter) Push(target string, opts *PushOptions) error { + if pusher, ok := tw.w.(Pusher); ok { + return pusher.Push(target, opts) + } + return ErrNotSupported +} + +// Flush implements the Flusher interface. +func (tw *timeoutWriter) Flush() { + f, ok := tw.w.(Flusher) + if ok { + f.Flush() + } +} + func (tw *timeoutWriter) Header() Header { return tw.h } func (tw *timeoutWriter) Write(p []byte) (int, error) { @@ -3257,24 +3319,6 @@ func (tw *timeoutWriter) writeHeader(code int) { tw.code = code } -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener -} - -func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { - tc, err := ln.AcceptTCP() - if err != nil { - return nil, err - } - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(3 * time.Minute) - return tc, nil -} - // onceCloseListener wraps a net.Listener, protecting it from // multiple Close calls. type onceCloseListener struct { @@ -3310,10 +3354,17 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { // uninitialized fields in its *Request. Such partially-initialized // Requests come from NPN protocol handlers. type initNPNRequest struct { - c *tls.Conn - h serverHandler + ctx context.Context + c *tls.Conn + h serverHandler } +// BaseContext is an exported but unadvertised http.Handler method +// recognized by x/net/http2 to pass down a context; the TLSNextProto +// API predates context support so we shoehorn through the only +// interface we have available. +func (h initNPNRequest) BaseContext() context.Context { return h.ctx } + func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { if req.TLS == nil { req.TLS = &tls.ConnectionState{} diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go index c1494ab..67a7151 100644 --- a/libgo/go/net/http/sniff.go +++ b/libgo/go/net/http/sniff.go @@ -37,6 +37,8 @@ func DetectContentType(data []byte) string { return "application/octet-stream" // fallback } +// isWS reports whether the provided byte is a whitespace byte (0xWS) +// as defined in https://mimesniff.spec.whatwg.org/#terminology. func isWS(b byte) bool { switch b { case '\t', '\n', '\x0c', '\r', ' ': @@ -45,6 +47,16 @@ func isWS(b byte) bool { return false } +// isTT reports whether the provided byte is a tag-terminating byte (0xTT) +// as defined in https://mimesniff.spec.whatwg.org/#terminology. +func isTT(b byte) bool { + switch b { + case ' ', '>': + return true + } + return false +} + type sniffSig interface { // match returns the MIME type of the data, or "" if unknown. match(data []byte, firstNonWS int) string @@ -69,33 +81,57 @@ var sniffSignatures = []sniffSig{ htmlSig("<BR"), htmlSig("<P"), htmlSig("<!--"), - - &maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"}, - + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\xFF"), + pat: []byte("<?xml"), + skipWS: true, + ct: "text/xml; charset=utf-8"}, &exactSig{[]byte("%PDF-"), "application/pdf"}, &exactSig{[]byte("%!PS-Adobe-"), "application/postscript"}, // UTF BOMs. - &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"}, - &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"}, - &maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"}, + &maskedSig{ + mask: []byte("\xFF\xFF\x00\x00"), + pat: []byte("\xFE\xFF\x00\x00"), + ct: "text/plain; charset=utf-16be", + }, + &maskedSig{ + mask: []byte("\xFF\xFF\x00\x00"), + pat: []byte("\xFF\xFE\x00\x00"), + ct: "text/plain; charset=utf-16le", + }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\x00"), + pat: []byte("\xEF\xBB\xBF\x00"), + ct: "text/plain; charset=utf-8", + }, + // Image types + // For posterity, we originally returned "image/vnd.microsoft.icon" from + // https://tools.ietf.org/html/draft-ietf-websec-mime-sniff-03#section-7 + // https://codereview.appspot.com/4746042 + // but that has since been replaced with "image/x-icon" in Section 6.2 + // of https://mimesniff.spec.whatwg.org/#matching-an-image-type-pattern + &exactSig{[]byte("\x00\x00\x01\x00"), "image/x-icon"}, + &exactSig{[]byte("\x00\x00\x02\x00"), "image/x-icon"}, + &exactSig{[]byte("BM"), "image/bmp"}, &exactSig{[]byte("GIF87a"), "image/gif"}, &exactSig{[]byte("GIF89a"), "image/gif"}, - &exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"}, - &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"}, - &exactSig{[]byte("BM"), "image/bmp"}, &maskedSig{ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"), pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"), ct: "image/webp", }, - &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"}, + &exactSig{[]byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"}, + &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"}, + // Audio and Video types + // Enforce the pattern match ordering as prescribed in + // https://mimesniff.spec.whatwg.org/#matching-an-audio-or-video-type-pattern &maskedSig{ - mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), - pat: []byte("RIFF\x00\x00\x00\x00WAVE"), - ct: "audio/wave", + mask: []byte("\xFF\xFF\xFF\xFF"), + pat: []byte(".snd"), + ct: "audio/basic", }, &maskedSig{ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), @@ -103,9 +139,9 @@ var sniffSignatures = []sniffSig{ ct: "audio/aiff", }, &maskedSig{ - mask: []byte("\xFF\xFF\xFF\xFF"), - pat: []byte(".snd"), - ct: "audio/basic", + mask: []byte("\xFF\xFF\xFF"), + pat: []byte("ID3"), + ct: "audio/mpeg", }, &maskedSig{ mask: []byte("\xFF\xFF\xFF\xFF\xFF"), @@ -118,20 +154,24 @@ var sniffSignatures = []sniffSig{ ct: "audio/midi", }, &maskedSig{ - mask: []byte("\xFF\xFF\xFF"), - pat: []byte("ID3"), - ct: "audio/mpeg", - }, - &maskedSig{ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), pat: []byte("RIFF\x00\x00\x00\x00AVI "), ct: "video/avi", }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), + pat: []byte("RIFF\x00\x00\x00\x00WAVE"), + ct: "audio/wave", + }, + // 6.2.0.2. video/mp4 + mp4Sig{}, + // 6.2.0.3. video/webm + &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"}, - // Fonts + // Font types &maskedSig{ // 34 NULL bytes followed by the string "LP" - pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x4C\x50"), + pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00LP"), // 34 NULL bytes followed by \xF\xF mask: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF"), ct: "application/vnd.ms-fontobject", @@ -142,15 +182,20 @@ var sniffSignatures = []sniffSig{ &exactSig{[]byte("wOFF"), "font/woff"}, &exactSig{[]byte("wOF2"), "font/woff2"}, - &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"}, - &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"}, - &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"}, + // Archive types &exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"}, + &exactSig{[]byte("PK\x03\x04"), "application/zip"}, + // RAR's signatures are incorrectly defined by the MIME spec as per + // https://github.com/whatwg/mimesniff/issues/63 + // However, RAR Labs correctly defines it at: + // https://www.rarlab.com/technote.htm#rarsign + // so we use the definition from RAR Labs. + // TODO: do whatever the spec ends up doing. + &exactSig{[]byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"}, // RAR v1.5-v4.0 + &exactSig{[]byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"}, // RAR v5+ &exactSig{[]byte("\x00\x61\x73\x6D"), "application/wasm"}, - mp4Sig{}, - textSig{}, // should be last } @@ -182,12 +227,12 @@ func (m *maskedSig) match(data []byte, firstNonWS int) string { if len(m.pat) != len(m.mask) { return "" } - if len(data) < len(m.mask) { + if len(data) < len(m.pat) { return "" } - for i, mask := range m.mask { - db := data[i] & mask - if db != m.pat[i] { + for i, pb := range m.pat { + maskedData := data[i] & m.mask[i] + if maskedData != pb { return "" } } @@ -210,8 +255,8 @@ func (h htmlSig) match(data []byte, firstNonWS int) string { return "" } } - // Next byte must be space or right angle bracket. - if db := data[len(h)]; db != ' ' && db != '>' { + // Next byte must be a tag-terminating byte(0xTT). + if !isTT(data[len(h)]) { return "" } return "text/html; charset=utf-8" @@ -229,7 +274,7 @@ func (mp4Sig) match(data []byte, firstNonWS int) string { return "" } boxSize := int(binary.BigEndian.Uint32(data[:4])) - if boxSize%4 != 0 || len(data) < boxSize { + if len(data) < boxSize || boxSize%4 != 0 { return "" } if !bytes.Equal(data[4:8], mp4ftype) { @@ -237,7 +282,7 @@ func (mp4Sig) match(data []byte, firstNonWS int) string { } for st := 8; st < boxSize; st += 4 { if st == 12 { - // minor version number + // Ignores the four bytes that correspond to the version number of the "major brand". continue } if bytes.Equal(data[st:st+3], mp4) { diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index b4d3c9f..a1157a0 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -36,8 +36,14 @@ var sniffTests = []struct { {"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"}, // Image types. + {"Windows icon", []byte("\x00\x00\x01\x00"), "image/x-icon"}, + {"Windows cursor", []byte("\x00\x00\x02\x00"), "image/x-icon"}, + {"BMP image", []byte("BM..."), "image/bmp"}, {"GIF 87a", []byte(`GIF87a`), "image/gif"}, {"GIF 89a", []byte(`GIF89a...`), "image/gif"}, + {"WEBP image", []byte("RIFF\x00\x00\x00\x00WEBPVP"), "image/webp"}, + {"PNG image", []byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"}, + {"JPEG image", []byte("\xFF\xD8\xFF"), "image/jpeg"}, // Audio types. {"MIDI audio", []byte("MThd\x00\x00\x00\x06\x00\x01"), "audio/midi"}, @@ -66,6 +72,12 @@ var sniffTests = []struct { {"woff sample I", []byte("\x77\x4f\x46\x46\x00\x01\x00\x00\x00\x00\x30\x54\x00\x0d\x00\x00"), "font/woff"}, {"woff2 sample", []byte("\x77\x4f\x46\x32\x00\x01\x00\x00\x00"), "font/woff2"}, {"wasm sample", []byte("\x00\x61\x73\x6d\x01\x00"), "application/wasm"}, + + // Archive types + {"RAR v1.5-v4.0", []byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"}, + {"RAR v5+", []byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"}, + {"Incorrect RAR v1.5-v4.0", []byte("Rar \x1A\x07\x00"), "application/octet-stream"}, + {"Incorrect RAR v5+", []byte("Rar \x1A\x07\x01\x00"), "application/octet-stream"}, } func TestDetectContentType(t *testing.T) { diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go index 086f3d1..286315f 100644 --- a/libgo/go/net/http/status.go +++ b/libgo/go/net/http/status.go @@ -10,6 +10,7 @@ const ( StatusContinue = 100 // RFC 7231, 6.2.1 StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2 StatusProcessing = 102 // RFC 2518, 10.1 + StatusEarlyHints = 103 // RFC 8297 StatusOK = 200 // RFC 7231, 6.3.1 StatusCreated = 201 // RFC 7231, 6.3.2 @@ -79,6 +80,7 @@ var statusText = map[int]string{ StatusContinue: "Continue", StatusSwitchingProtocols: "Switching Protocols", StatusProcessing: "Processing", + StatusEarlyHints: "Early Hints", StatusOK: "OK", StatusCreated: "Created", diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index e8a93e9..2e01a07 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -21,7 +21,7 @@ import ( "sync" "time" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // ErrLineTooLong is returned when reading request or response bodies @@ -53,19 +53,6 @@ func (br *byteReader) Read(p []byte) (n int, err error) { return 1, io.EOF } -// transferBodyReader is an io.Reader that reads from tw.Body -// and records any non-EOF error in tw.bodyReadError. -// It is exactly 1 pointer wide to avoid allocations into interfaces. -type transferBodyReader struct{ tw *transferWriter } - -func (br transferBodyReader) Read(p []byte) (n int, err error) { - n, err = br.tw.Body.Read(p) - if err != nil && err != io.EOF { - br.tw.bodyReadError = err - } - return -} - // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -347,15 +334,18 @@ func (t *transferWriter) writeBody(w io.Writer) error { var err error var ncopy int64 - // Write body + // Write body. We "unwrap" the body first if it was wrapped in a + // nopCloser. This is to ensure that we can take advantage of + // OS-level optimizations in the event that the body is an + // *os.File. if t.Body != nil { - var body = transferBodyReader{t} + var body = t.unwrapBody() if chunked(t.TransferEncoding) { if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { w = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(w) - _, err = io.Copy(cw, body) + _, err = t.doBodyCopy(cw, body) if err == nil { err = cw.Close() } @@ -364,14 +354,14 @@ func (t *transferWriter) writeBody(w io.Writer) error { if t.Method == "CONNECT" { dst = bufioFlushWriter{dst} } - ncopy, err = io.Copy(dst, body) + ncopy, err = t.doBodyCopy(dst, body) } else { - ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength)) + ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength)) if err != nil { return err } var nextra int64 - nextra, err = io.Copy(ioutil.Discard, body) + nextra, err = t.doBodyCopy(ioutil.Discard, body) ncopy += nextra } if err != nil { @@ -402,6 +392,31 @@ func (t *transferWriter) writeBody(w io.Writer) error { return err } +// doBodyCopy wraps a copy operation, with any resulting error also +// being saved in bodyReadError. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) { + n, err = io.Copy(dst, src) + if err != nil && err != io.EOF { + t.bodyReadError = err + } + return +} + +// unwrapBodyReader unwraps the body's inner reader if it's a +// nopCloser. This is to ensure that body writes sourced from local +// files (*os.File types) are properly optimized. +// +// This function is only intended for use in writeBody. +func (t *transferWriter) unwrapBody() io.Reader { + if reflect.TypeOf(t.Body) == nopCloserType { + return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader) + } + + return t.Body +} + type transferReader struct { // Input Header Header @@ -574,6 +589,22 @@ func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" } // Checks whether the encoding is explicitly "identity". func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" } +// unsupportedTEError reports unsupported transfer-encodings. +type unsupportedTEError struct { + err string +} + +func (uste *unsupportedTEError) Error() string { + return uste.err +} + +// isUnsupportedTEError checks if the error is of type +// unsupportedTEError. It is usually invoked with a non-nil err. +func isUnsupportedTEError(err error) bool { + _, ok := err.(*unsupportedTEError) + return ok +} + // fixTransferEncoding sanitizes t.TransferEncoding, if needed. func (t *transferReader) fixTransferEncoding() error { raw, present := t.Header["Transfer-Encoding"] @@ -600,7 +631,7 @@ func (t *transferReader) fixTransferEncoding() error { break } if encoding != "chunked" { - return &badStringError{"unsupported transfer encoding", encoding} + return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", encoding)} } te = te[0 : len(te)+1] te[len(te)-1] = encoding diff --git a/libgo/go/net/http/transfer_test.go b/libgo/go/net/http/transfer_test.go index 993ea4e..65009ee 100644 --- a/libgo/go/net/http/transfer_test.go +++ b/libgo/go/net/http/transfer_test.go @@ -7,8 +7,12 @@ package http import ( "bufio" "bytes" + "crypto/rand" + "fmt" "io" "io/ioutil" + "os" + "reflect" "strings" "testing" ) @@ -90,3 +94,219 @@ func TestDetectInMemoryReaders(t *testing.T) { } } } + +type mockTransferWriter struct { + CalledReader io.Reader + WriteCalled bool +} + +var _ io.ReaderFrom = (*mockTransferWriter)(nil) + +func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { + w.CalledReader = r + return io.Copy(ioutil.Discard, r) +} + +func (w *mockTransferWriter) Write(p []byte) (int, error) { + w.WriteCalled = true + return ioutil.Discard.Write(p) +} + +func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { + fileType := reflect.TypeOf(&os.File{}) + bufferType := reflect.TypeOf(&bytes.Buffer{}) + + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := ioutil.TempFile("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + bodyFunc func() (io.Reader, func(), error) + method string + contentLength int64 + transferEncoding []string + limitedReader bool + expectedReader reflect.Type + expectedWrite bool + }{ + { + name: "file, non-chunked, size set", + bodyFunc: newFileFunc, + method: "PUT", + contentLength: nBytes, + limitedReader: true, + expectedReader: fileType, + }, + { + name: "file, non-chunked, size set, nopCloser wrapped", + method: "PUT", + bodyFunc: func() (io.Reader, func(), error) { + r, cleanup, err := newFileFunc() + return ioutil.NopCloser(r), cleanup, err + }, + contentLength: nBytes, + limitedReader: true, + expectedReader: fileType, + }, + { + name: "file, non-chunked, negative size", + method: "PUT", + bodyFunc: newFileFunc, + contentLength: -1, + expectedReader: fileType, + }, + { + name: "file, non-chunked, CONNECT, negative size", + method: "CONNECT", + bodyFunc: newFileFunc, + contentLength: -1, + expectedReader: fileType, + }, + { + name: "file, chunked", + method: "PUT", + bodyFunc: newFileFunc, + transferEncoding: []string{"chunked"}, + expectedWrite: true, + }, + { + name: "buffer, non-chunked, size set", + bodyFunc: newBufferFunc, + method: "PUT", + contentLength: nBytes, + limitedReader: true, + expectedReader: bufferType, + }, + { + name: "buffer, non-chunked, size set, nopCloser wrapped", + method: "PUT", + bodyFunc: func() (io.Reader, func(), error) { + r, cleanup, err := newBufferFunc() + return ioutil.NopCloser(r), cleanup, err + }, + contentLength: nBytes, + limitedReader: true, + expectedReader: bufferType, + }, + { + name: "buffer, non-chunked, negative size", + method: "PUT", + bodyFunc: newBufferFunc, + contentLength: -1, + expectedWrite: true, + }, + { + name: "buffer, non-chunked, CONNECT, negative size", + method: "CONNECT", + bodyFunc: newBufferFunc, + contentLength: -1, + expectedWrite: true, + }, + { + name: "buffer, chunked", + method: "PUT", + bodyFunc: newBufferFunc, + transferEncoding: []string{"chunked"}, + expectedWrite: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + body, cleanup, err := tc.bodyFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + mw := &mockTransferWriter{} + tw := &transferWriter{ + Body: body, + ContentLength: tc.contentLength, + TransferEncoding: tc.transferEncoding, + } + + if err := tw.writeBody(mw); err != nil { + t.Fatal(err) + } + + if tc.expectedReader != nil { + if mw.CalledReader == nil { + t.Fatal("did not call ReadFrom") + } + + var actualReader reflect.Type + lr, ok := mw.CalledReader.(*io.LimitedReader) + if ok && tc.limitedReader { + actualReader = reflect.TypeOf(lr.R) + } else { + actualReader = reflect.TypeOf(mw.CalledReader) + } + + if tc.expectedReader != actualReader { + t.Fatalf("got reader %T want %T", actualReader, tc.expectedReader) + } + } + + if tc.expectedWrite && !mw.WriteCalled { + t.Fatal("did not invoke Write") + } + }) + } +} + +func TestFixTransferEncoding(t *testing.T) { + tests := []struct { + hdr Header + wantErr error + }{ + { + hdr: Header{"Transfer-Encoding": {"fugazi"}}, + wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`}, + }, + { + hdr: Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}}, + wantErr: &badStringError{"too many transfer encodings", "chunked,chunked"}, + }, + { + hdr: Header{"Transfer-Encoding": {"chunked"}}, + wantErr: nil, + }, + } + + for i, tt := range tests { + tr := &transferReader{ + Header: tt.hdr, + ProtoMajor: 1, + ProtoMinor: 1, + } + gotErr := tr.fixTransferEncoding() + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr) + } + } +} diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index a8c5efe..26f642a 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -30,8 +30,8 @@ import ( "sync/atomic" "time" - "internal/x/net/http/httpguts" - "internal/x/net/http/httpproxy" + "golang.org/x/net/http/httpguts" + "golang.org/x/net/http/httpproxy" ) // DefaultTransport is the default implementation of Transport and is @@ -46,6 +46,7 @@ var DefaultTransport RoundTripper = &Transport{ KeepAlive: 30 * time.Second, DualStack: true, }).DialContext, + ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, @@ -253,10 +254,76 @@ type Transport struct { // Zero means to use a default limit. MaxResponseHeaderBytes int64 + // WriteBufferSize specifies the size of the write buffer used + // when writing to the transport. + // If zero, a default (currently 4KB) is used. + WriteBufferSize int + + // ReadBufferSize specifies the size of the read buffer used + // when reading from the transport. + // If zero, a default (currently 4KB) is used. + ReadBufferSize int + // nextProtoOnce guards initialization of TLSNextProto and // h2transport (via onceSetNextProtoDefaults) - nextProtoOnce sync.Once - h2transport h2Transport // non-nil if http2 wired up + nextProtoOnce sync.Once + h2transport h2Transport // non-nil if http2 wired up + tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired + + // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero + // Dial, DialTLS, or DialContext func or TLSClientConfig is provided. + // By default, use of any those fields conservatively disables HTTP/2. + // To use a custom dialer or TLS config and still attempt HTTP/2 + // upgrades, set this to true. + ForceAttemptHTTP2 bool +} + +func (t *Transport) writeBufferSize() int { + if t.WriteBufferSize > 0 { + return t.WriteBufferSize + } + return 4 << 10 +} + +func (t *Transport) readBufferSize() int { + if t.ReadBufferSize > 0 { + return t.ReadBufferSize + } + return 4 << 10 +} + +// Clone returns a deep copy of t's exported fields. +func (t *Transport) Clone() *Transport { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t2 := &Transport{ + Proxy: t.Proxy, + DialContext: t.DialContext, + Dial: t.Dial, + DialTLS: t.DialTLS, + TLSClientConfig: t.TLSClientConfig.Clone(), + TLSHandshakeTimeout: t.TLSHandshakeTimeout, + DisableKeepAlives: t.DisableKeepAlives, + DisableCompression: t.DisableCompression, + MaxIdleConns: t.MaxIdleConns, + MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, + MaxConnsPerHost: t.MaxConnsPerHost, + IdleConnTimeout: t.IdleConnTimeout, + ResponseHeaderTimeout: t.ResponseHeaderTimeout, + ExpectContinueTimeout: t.ExpectContinueTimeout, + ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, + ForceAttemptHTTP2: t.ForceAttemptHTTP2, + WriteBufferSize: t.WriteBufferSize, + ReadBufferSize: t.ReadBufferSize, + } + if !t.tlsNextProtoWasNil { + npm := map[string]func(authority string, c *tls.Conn) RoundTripper{} + for k, v := range t.TLSNextProto { + npm[k] = v + } + t2.TLSNextProto = npm + } + return t2 } // h2Transport is the interface we expect to be able to call from @@ -272,6 +339,7 @@ type h2Transport interface { // onceSetNextProtoDefaults initializes TLSNextProto. // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { + t.tlsNextProtoWasNil = (t.TLSNextProto == nil) if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") { return } @@ -296,12 +364,13 @@ func (t *Transport) onceSetNextProtoDefaults() { // Transport. return } - if t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil { + if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) { // Be conservative and don't automatically enable // http2 if they've specified a custom TLS config or // custom dialers. Let them opt-in themselves via // http2.ConfigureTransport so we don't surprise them // by modifying their tls.Config. Issue 14275. + // However, if ForceAttemptHTTP2 is true, it overrides the above checks. return } t2, err := http2configureTransport(t) @@ -474,8 +543,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host - t.setReqCanceler(req, nil) // not cancelable with CancelRequest + t.putOrCloseIdleConn(pconn) + t.setReqCanceler(req, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) @@ -483,7 +552,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { if err == nil { return resp, nil } - if !pconn.shouldRetryRequest(req, err) { + if http2isNoCachedConnError(err) { + t.removeIdleConn(pconn) + t.decHostConnCount(cm.key()) // clean up the persistent connection + } else if !pconn.shouldRetryRequest(req, err) { // Issue 16465: return underlying net.Conn.Read error from peek, // as we've historically done. if e, ok := err.(transportReadFromServerError); ok { @@ -717,6 +789,8 @@ type transportReadFromServerError struct { err error } +func (e transportReadFromServerError) Unwrap() error { return e.err } + func (e transportReadFromServerError) Error() string { return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err) } @@ -746,9 +820,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if pconn.isBroken() { return errConnBroken } - if pconn.alt != nil { - return errNotCachingH2Conn - } pconn.markReused() key := pconn.cacheKey @@ -797,7 +868,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if pconn.idleTimer != nil { pconn.idleTimer.Reset(t.IdleConnTimeout) } else { - pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle) + // idleTimer does not apply to HTTP/2 + if pconn.alt == nil { + pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle) + } } } pconn.idleAt = time.Now() @@ -1031,7 +1105,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC t.decHostConnCount(cmKey) select { case <-req.Cancel: - // It was an error due to cancelation, so prioritize that + // It was an error due to cancellation, so prioritize that // error value. (Issue 16049) return nil, errRequestCanceledConn case <-req.Context().Done(): @@ -1042,7 +1116,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC } return nil, err default: - // It wasn't an error due to cancelation, so + // It wasn't an error due to cancellation, so // return the original error message: return nil, v.err } @@ -1345,15 +1419,16 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" { if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok { - return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil + return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil } } if t.MaxConnsPerHost > 0 { pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey} } - pconn.br = bufio.NewReader(pconn) - pconn.bw = bufio.NewWriter(persistConnWriter{pconn}) + pconn.br = bufio.NewReaderSize(pconn, t.readBufferSize()) + pconn.bw = bufio.NewWriterSize(persistConnWriter{pconn}, t.writeBufferSize()) + go pconn.readLoop() go pconn.writeLoop() return pconn, nil @@ -1375,6 +1450,17 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) { return } +// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if +// the Conn implements io.ReaderFrom, it can take advantage of optimizations +// such as sendfile. +func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) { + n, err = io.Copy(w.pc.conn, r) + w.pc.nwrite += n + return +} + +var _ io.ReaderFrom = (*persistConnWriter)(nil) + // connectMethod is the map key (in its String form) for keeping persistent // TCP connections alive for subsequent HTTP requests. // @@ -1538,7 +1624,7 @@ func (pc *persistConn) isBroken() bool { } // canceled returns non-nil if the connection was closed due to -// CancelRequest or due to context cancelation. +// CancelRequest or due to context cancellation. func (pc *persistConn) canceled() error { pc.mu.Lock() defer pc.mu.Unlock() @@ -1794,7 +1880,7 @@ func (pc *persistConn) readLoop() { // Before looping back to the top of this function and peeking on // the bufio.Reader, wait for the caller goroutine to finish - // reading the response body. (or for cancelation or death) + // reading the response body. (or for cancellation or death) select { case bodyEOF := <-waitForBodyRead: pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool @@ -1826,7 +1912,12 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { } if n := pc.br.Buffered(); n > 0 { buf, _ := pc.br.Peek(n) - log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr) + if is408Message(buf) { + pc.closeLocked(errServerClosedIdle) + return + } else { + log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr) + } } if peekErr == io.EOF { // common case. @@ -1836,6 +1927,19 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { } } +// is408Message reports whether buf has the prefix of an +// HTTP 408 Request Timeout response. +// See golang.org/issue/32310. +func is408Message(buf []byte) bool { + if len(buf) < len("HTTP/1.x 408") { + return false + } + if string(buf[:7]) != "HTTP/1." { + return false + } + return string(buf[8:12]) == " 408" +} + // readResponse reads an HTTP response (or two, in the case of "Expect: // 100-continue") from the server. It returns the final non-100 one. // trace is optional. @@ -2072,8 +2176,21 @@ func (e *httpError) Error() string { return e.err } func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } +func (e *httpError) Is(target error) bool { + switch target { + case os.ErrTimeout: + return e.timeout + case os.ErrTemporary: + return true + } + return false +} + var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} -var errRequestCanceled = errors.New("net/http: request canceled") + +// errRequestCanceled is set to be identical to the one from h2 to facilitate +// testing. +var errRequestCanceled = http2errRequestCanceled var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? func nop() {} @@ -2258,13 +2375,8 @@ func (pc *persistConn) closeLocked(err error) { if pc.closed == nil { pc.closed = err if pc.alt != nil { - // Do nothing; can only get here via getConn's - // handlePendingDial's putOrCloseIdleConn when - // it turns out the abandoned connection in - // flight ended up negotiating an alternate - // protocol. We don't use the connection - // freelist for http2. That's done by the - // alternate protocol's RoundTripper. + // Clean up any host connection counting. + pc.t.decHostConnCount(pc.cacheKey) } else { if err != errCallerOwnsConn { pc.conn.Close() @@ -2408,6 +2520,10 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true } func (tlsHandshakeTimeoutError) Temporary() bool { return true } func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } +func (tlsHandshakeTimeoutError) Is(target error) bool { + return target == os.ErrTimeout || target == os.ErrTemporary +} + // fakeLocker is a sync.Locker which does nothing. It's used to guard // test-only fields when not under test, to avoid runtime atomic // overhead. diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 6e07584..2b58e1d 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "errors" "fmt" + "go/token" "internal/nettrace" "io" "io/ioutil" @@ -42,7 +43,7 @@ import ( "testing" "time" - "internal/x/net/http/httpguts" + "golang.org/x/net/http/httpguts" ) // TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close @@ -588,6 +589,106 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { <-reqComplete } +func TestTransportMaxConnsPerHost(t *testing.T) { + defer afterTest(t) + + h := HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + }) + + testMaxConns := func(scheme string, ts *httptest.Server) { + defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatalf("ExportHttp2ConfigureTransport: %v", err) + } + + connCh := make(chan net.Conn, 1) + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) + connCh <- c + return c, err + } + + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, + } + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) + } + } + + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() + + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("Too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt) + } + + (<-connCh).Close() + tr.CloseIdleConnections() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("Too many dials (%s): %d", scheme, dialCnt) + } + if gotConnCnt != expected { + t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt) + } + } + + testMaxConns("http", httptest.NewServer(h)) + testMaxConns("https", httptest.NewTLSServer(h)) + + ts := httptest.NewUnstartedServer(h) + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + testMaxConns("http2", ts) +} + func TestTransportRemovesDeadIdleConnections(t *testing.T) { setParallel(t) defer afterTest(t) @@ -636,6 +737,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { } } +// Test that the Transport notices when a server hangs up on its +// unexpectedly (a keep-alive connection is closed). func TestTransportServerClosingUnexpectedly(t *testing.T) { setParallel(t) defer afterTest(t) @@ -672,13 +775,14 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { body1 := fetch(1, 0) body2 := fetch(2, 0) - ts.CloseClientConnections() // surprise! - - // This test has an expected race. Sleeping for 25 ms prevents - // it on most fast machines, causing the next fetch() call to - // succeed quickly. But if we do get errors, fetch() will retry 5 - // times with some delays between. - time.Sleep(25 * time.Millisecond) + // Close all the idle connections in a way that's similar to + // the server hanging up on us. We don't use + // httptest.Server.CloseClientConnections because it's + // best-effort and stops blocking after 5 seconds. On a loaded + // machine running many tests concurrently it's possible for + // that method to be async and cause the body3 fetch below to + // run on an old connection. This function is synchronous. + ExportCloseTransportConnsAbruptly(c.Transport.(*Transport)) body3 := fetch(3, 5) @@ -865,6 +969,10 @@ func TestRoundTripGzip(t *testing.T) { req.Header.Set("Accept-Encoding", test.accept) } res, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("%d. RoundTrip: %v", i, err) + continue + } var body []byte if test.compressed { var r *gzip.Reader @@ -2110,7 +2218,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { } } else { if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancelation", err) + t.Errorf("Do error = %v; want cancellation", err) } } } @@ -3589,6 +3697,13 @@ func TestTransportAutomaticHTTP2(t *testing.T) { testTransportAutoHTTP(t, &Transport{}, true) } +func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + ForceAttemptHTTP2: true, + TLSClientConfig: new(tls.Config), + }, true) +} + // golang.org/issue/14391: also check DefaultTransport func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) @@ -3619,6 +3734,13 @@ func TestTransportAutomaticHTTP2_Dial(t *testing.T) { }, false) } +func TestTransportAutomaticHTTP2_DialContext(t *testing.T) { + var d net.Dialer + testTransportAutoHTTP(t, &Transport{ + DialContext: d.DialContext, + }, false) +} + func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { testTransportAutoHTTP(t, &Transport{ DialTLS: func(network, addr string) (net.Conn, error) { @@ -5059,3 +5181,270 @@ func TestTransportRequestReplayable(t *testing.T) { }) } } + +// testMockTCPConn is a mock TCP connection used to test that +// ReadFrom is called when sending the request body. +type testMockTCPConn struct { + *net.TCPConn + + ReadFromCalled bool +} + +func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { + c.ReadFromCalled = true + return c.TCPConn.ReadFrom(r) +} + +func TestTransportRequestWriteRoundTrip(t *testing.T) { + nBytes := int64(1 << 10) + newFileFunc := func() (r io.Reader, done func(), err error) { + f, err := ioutil.TempFile("", "net-http-newfilefunc") + if err != nil { + return nil, nil, err + } + + // Write some bytes to the file to enable reading. + if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil { + return nil, nil, fmt.Errorf("failed to write data to file: %v", err) + } + if _, err := f.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("failed to seek to front: %v", err) + } + + done = func() { + f.Close() + os.Remove(f.Name()) + } + + return f, done, nil + } + + newBufferFunc := func() (io.Reader, func(), error) { + return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil + } + + cases := []struct { + name string + readerFunc func() (io.Reader, func(), error) + contentLength int64 + expectedReadFrom bool + }{ + { + name: "file, length", + readerFunc: newFileFunc, + contentLength: nBytes, + expectedReadFrom: true, + }, + { + name: "file, no length", + readerFunc: newFileFunc, + }, + { + name: "file, negative length", + readerFunc: newFileFunc, + contentLength: -1, + }, + { + name: "buffer", + contentLength: nBytes, + readerFunc: newBufferFunc, + }, + { + name: "buffer, no length", + readerFunc: newBufferFunc, + }, + { + name: "buffer, length -1", + contentLength: -1, + readerFunc: newBufferFunc, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + r, cleanup, err := tc.readerFunc() + if err != nil { + t.Fatal(err) + } + defer cleanup() + + tConn := &testMockTCPConn{} + trFunc := func(tr *Transport) { + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr) + } + + tConn.TCPConn = tcpConn + return tConn, nil + } + } + + cst := newClientServerTest( + t, + h1Mode, + HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + w.WriteHeader(200) + }), + trFunc, + ) + defer cst.close() + + req, err := NewRequest("PUT", cst.ts.URL, r) + if err != nil { + t.Fatal(err) + } + req.ContentLength = tc.contentLength + req.Header.Set("Content-Type", "application/octet-stream") + resp, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("status code = %d; want 200", resp.StatusCode) + } + + if !tConn.ReadFromCalled && tc.expectedReadFrom { + t.Fatalf("did not call ReadFrom") + } + + if tConn.ReadFromCalled && !tc.expectedReadFrom { + t.Fatalf("ReadFrom was unexpectedly invoked") + } + }) + } +} + +func TestTransportClone(t *testing.T) { + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { panic("") }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + Dial: func(network, addr string) (net.Conn, error) { panic("") }, + DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, + TLSClientConfig: new(tls.Config), + TLSHandshakeTimeout: time.Second, + DisableKeepAlives: true, + DisableCompression: true, + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + ResponseHeaderTimeout: time.Second, + ExpectContinueTimeout: time.Second, + ProxyConnectHeader: Header{}, + MaxResponseHeaderBytes: 1, + ForceAttemptHTTP2: true, + TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, + }, + ReadBufferSize: 1, + WriteBufferSize: 1, + } + tr2 := tr.Clone() + rv := reflect.ValueOf(tr2).Elem() + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !token.IsExported(sf.Name) { + continue + } + if rv.Field(i).IsZero() { + t.Errorf("cloned field t2.%s is zero", sf.Name) + } + } + + if _, ok := tr2.TLSNextProto["foo"]; !ok { + t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") + } + + // But test that a nil TLSNextProto is kept nil: + tr = new(Transport) + tr2 = tr.Clone() + if tr2.TLSNextProto != nil { + t.Errorf("Transport.TLSNextProto unexpected non-nil") + } +} + +func TestIs408(t *testing.T) { + tests := []struct { + in string + want bool + }{ + {"HTTP/1.0 408", true}, + {"HTTP/1.1 408", true}, + {"HTTP/1.8 408", true}, + {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now. + {"HTTP/1.1 408 ", true}, + {"HTTP/1.1 40", false}, + {"http/1.0 408", false}, + {"HTTP/1-1 408", false}, + } + for _, tt := range tests { + if got := Export_is408Message([]byte(tt.in)); got != tt.want { + t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want) + } + } +} + +func TestTransportIgnores408(t *testing.T) { + // Not parallel. Relies on mutating the log package's global Output. + defer log.SetOutput(log.Writer()) + + var logout bytes.Buffer + log.SetOutput(&logout) + + defer afterTest(t) + const target = "backend:443" + + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + nc, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Error(err) + return + } + defer nc.Close() + nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) + nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail + })) + defer cst.close() + req, err := NewRequest("GET", cst.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if err != nil { + t.Fatal(err) + } + if string(slurp) != "ok" { + t.Fatalf("got %q; want ok", slurp) + } + + t0 := time.Now() + for i := 0; i < 50; i++ { + time.Sleep(time.Duration(i) * 5 * time.Millisecond) + if cst.tr.IdleConnKeyCountForTesting() == 0 { + if got := logout.String(); got != "" { + t.Fatalf("expected no log output; got: %s", got) + } + return + } + } + t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0)) +} |