aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/net/http
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/http')
-rw-r--r--libgo/go/net/http/cgi/child.go2
-rw-r--r--libgo/go/net/http/client.go15
-rw-r--r--libgo/go/net/http/client_test.go10
-rw-r--r--libgo/go/net/http/clientserver_test.go2
-rw-r--r--libgo/go/net/http/clone.go64
-rw-r--r--libgo/go/net/http/cookie.go40
-rw-r--r--libgo/go/net/http/cookie_test.go42
-rw-r--r--libgo/go/net/http/export_test.go16
-rw-r--r--libgo/go/net/http/fs.go2
-rw-r--r--libgo/go/net/http/h2_bundle.go71
-rw-r--r--libgo/go/net/http/header.go15
-rw-r--r--libgo/go/net/http/http.go4
-rw-r--r--libgo/go/net/http/httptest/recorder.go33
-rw-r--r--libgo/go/net/http/httputil/dump_test.go98
-rw-r--r--libgo/go/net/http/httputil/persist.go4
-rw-r--r--libgo/go/net/http/httputil/reverseproxy.go27
-rw-r--r--libgo/go/net/http/httputil/reverseproxy_test.go37
-rw-r--r--libgo/go/net/http/internal/testcert.go8
-rw-r--r--libgo/go/net/http/request.go80
-rw-r--r--libgo/go/net/http/request_test.go136
-rw-r--r--libgo/go/net/http/response.go5
-rw-r--r--libgo/go/net/http/response_test.go4
-rw-r--r--libgo/go/net/http/roundtrip_js.go47
-rw-r--r--libgo/go/net/http/serve_test.go151
-rw-r--r--libgo/go/net/http/server.go139
-rw-r--r--libgo/go/net/http/sniff.go117
-rw-r--r--libgo/go/net/http/sniff_test.go12
-rw-r--r--libgo/go/net/http/status.go2
-rw-r--r--libgo/go/net/http/transfer.go73
-rw-r--r--libgo/go/net/http/transfer_test.go220
-rw-r--r--libgo/go/net/http/transport.go172
-rw-r--r--libgo/go/net/http/transport_test.go407
32 files changed, 1715 insertions, 340 deletions
diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go
index 10325c2..cb140f8 100644
--- a/libgo/go/net/http/cgi/child.go
+++ b/libgo/go/net/http/cgi/child.go
@@ -102,7 +102,7 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
}
// There's apparently a de-facto standard for this.
- // https://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
+ // https://web.archive.org/web/20170105004655/http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
r.TLS = &tls.ConnectionState{HandshakeComplete: true}
}
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go
index 921f86b..65a9d51 100644
--- a/libgo/go/net/http/client.go
+++ b/libgo/go/net/http/client.go
@@ -100,7 +100,7 @@ type Client struct {
// For compatibility, the Client will also use the deprecated
// CancelRequest method on Transport if found. New
// RoundTripper implementations should use the Request's Context
- // for cancelation instead of implementing CancelRequest.
+ // for cancellation instead of implementing CancelRequest.
Timeout time.Duration
}
@@ -238,7 +238,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d
username := u.Username()
password, _ := u.Password()
forkReq()
- req.Header = ireq.Header.clone()
+ req.Header = ireq.Header.Clone()
req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
}
@@ -643,7 +643,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) {
reqBodyClosed = true
if !deadline.IsZero() && didTimeout() {
err = &httpError{
- // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancelation/
+ // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancellation/
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true,
}
@@ -668,7 +668,7 @@ func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) {
// The headers to copy are from the very initial request.
// We use a closured callback to keep a reference to these original headers.
var (
- ireqhdr = ireq.Header.clone()
+ ireqhdr = ireq.Header.Clone()
icookies map[string][]*Cookie
)
if c.Jar != nil && ireq.Header.Get("Cookie") != "" {
@@ -870,7 +870,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
}
if b.reqDidTimeout() {
err = &httpError{
- // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancelation/
+ // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancellation/
err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true,
}
@@ -926,10 +926,9 @@ func isDomainOrSubdomain(sub, parent string) bool {
}
func stripPassword(u *url.URL) string {
- pass, passSet := u.User.Password()
+ _, passSet := u.User.Password()
if passSet {
- return strings.Replace(u.String(), pass+"@", "***@", 1)
+ return strings.Replace(u.String(), u.User.String()+"@", u.User.Username()+":***@", 1)
}
-
return u.String()
}
diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go
index 1c59ce7..de490bc 100644
--- a/libgo/go/net/http/client_test.go
+++ b/libgo/go/net/http/client_test.go
@@ -317,8 +317,7 @@ func TestClientRedirectContext(t *testing.T) {
return errors.New("redirected request's context never expired after root request canceled")
}
}
- req, _ := NewRequest("GET", ts.URL, nil)
- req = req.WithContext(ctx)
+ req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
_, err := c.Do(req)
ue, ok := err.(*url.Error)
if !ok {
@@ -1185,6 +1184,11 @@ func TestStripPasswordFromError(t *testing.T) {
in: "http://user:password@dummy.faketld/password",
out: "Get http://user:***@dummy.faketld/password: dummy impl",
},
+ {
+ desc: "Strip escaped password",
+ in: "http://user:pa%2Fssword@dummy.faketld/",
+ out: "Get http://user:***@dummy.faketld/: dummy impl",
+ },
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
@@ -1288,7 +1292,7 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
donec := make(chan bool, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
<-donec
- }))
+ }), optQuietLog)
defer cst.close()
// Note that we use a channel send here and not a close.
// The race detector doesn't know that we're waiting for a timeout
diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go
index 4da218b..ee48fb0 100644
--- a/libgo/go/net/http/clientserver_test.go
+++ b/libgo/go/net/http/clientserver_test.go
@@ -560,7 +560,7 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) {
if all != "Hello" {
t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
}
- if !reflect.DeepEqual(err, ExportErrRequestCanceled) {
+ if err != ExportErrRequestCanceled {
t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
}
}
diff --git a/libgo/go/net/http/clone.go b/libgo/go/net/http/clone.go
new file mode 100644
index 0000000..5f2784d
--- /dev/null
+++ b/libgo/go/net/http/clone.go
@@ -0,0 +1,64 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+ "mime/multipart"
+ "net/textproto"
+ "net/url"
+)
+
+func cloneURLValues(v url.Values) url.Values {
+ if v == nil {
+ return nil
+ }
+ // http.Header and url.Values have the same representation, so temporarily
+ // treat it like http.Header, which does have a clone:
+ return url.Values(Header(v).Clone())
+}
+
+func cloneURL(u *url.URL) *url.URL {
+ if u == nil {
+ return nil
+ }
+ u2 := new(url.URL)
+ *u2 = *u
+ if u.User != nil {
+ u2.User = new(url.Userinfo)
+ *u2.User = *u.User
+ }
+ return u2
+}
+
+func cloneMultipartForm(f *multipart.Form) *multipart.Form {
+ if f == nil {
+ return nil
+ }
+ f2 := &multipart.Form{
+ Value: (map[string][]string)(Header(f.Value).Clone()),
+ }
+ if f.File != nil {
+ m := make(map[string][]*multipart.FileHeader)
+ for k, vv := range f.File {
+ vv2 := make([]*multipart.FileHeader, len(vv))
+ for i, v := range vv {
+ vv2[i] = cloneMultipartFileHeader(v)
+ }
+ m[k] = vv2
+ }
+ f2.File = m
+ }
+ return f2
+}
+
+func cloneMultipartFileHeader(fh *multipart.FileHeader) *multipart.FileHeader {
+ if fh == nil {
+ return nil
+ }
+ fh2 := new(multipart.FileHeader)
+ *fh2 = *fh
+ fh2.Header = textproto.MIMEHeader(Header(fh.Header).Clone())
+ return fh2
+}
diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go
index 63f6221..91ff544 100644
--- a/libgo/go/net/http/cookie.go
+++ b/libgo/go/net/http/cookie.go
@@ -48,6 +48,7 @@ const (
SameSiteDefaultMode SameSite = iota + 1
SameSiteLaxMode
SameSiteStrictMode
+ SameSiteNoneMode
)
// readSetCookies parses all "Set-Cookie" values from
@@ -105,6 +106,8 @@ func readSetCookies(h Header) []*Cookie {
c.SameSite = SameSiteLaxMode
case "strict":
c.SameSite = SameSiteStrictMode
+ case "none":
+ c.SameSite = SameSiteNoneMode
default:
c.SameSite = SameSiteDefaultMode
}
@@ -168,8 +171,12 @@ func (c *Cookie) String() string {
if c == nil || !isCookieNameValid(c.Name) {
return ""
}
+ // extraCookieLength derived from typical length of cookie attributes
+ // see RFC 6265 Sec 4.1.
+ const extraCookieLength = 110
var b strings.Builder
- b.WriteString(sanitizeCookieName(c.Name))
+ b.Grow(len(c.Name) + len(c.Value) + len(c.Domain) + len(c.Path) + extraCookieLength)
+ b.WriteString(c.Name)
b.WriteRune('=')
b.WriteString(sanitizeCookieValue(c.Value))
@@ -213,6 +220,8 @@ func (c *Cookie) String() string {
switch c.SameSite {
case SameSiteDefaultMode:
b.WriteString("; SameSite")
+ case SameSiteNoneMode:
+ b.WriteString("; SameSite=None")
case SameSiteLaxMode:
b.WriteString("; SameSite=Lax")
case SameSiteStrictMode:
@@ -226,25 +235,28 @@ func (c *Cookie) String() string {
//
// if filter isn't empty, only cookies of that name are returned
func readCookies(h Header, filter string) []*Cookie {
- lines, ok := h["Cookie"]
- if !ok {
+ lines := h["Cookie"]
+ if len(lines) == 0 {
return []*Cookie{}
}
- cookies := []*Cookie{}
+ cookies := make([]*Cookie, 0, len(lines)+strings.Count(lines[0], ";"))
for _, line := range lines {
- parts := strings.Split(strings.TrimSpace(line), ";")
- if len(parts) == 1 && parts[0] == "" {
- continue
- }
- // Per-line attributes
- for i := 0; i < len(parts); i++ {
- parts[i] = strings.TrimSpace(parts[i])
- if len(parts[i]) == 0 {
+ line = strings.TrimSpace(line)
+
+ var part string
+ for len(line) > 0 { // continue since we have rest
+ if splitIndex := strings.Index(line, ";"); splitIndex > 0 {
+ part, line = line[:splitIndex], line[splitIndex+1:]
+ } else {
+ part, line = line, ""
+ }
+ part = strings.TrimSpace(part)
+ if len(part) == 0 {
continue
}
- name, val := parts[i], ""
- if j := strings.Index(name, "="); j >= 0 {
+ name, val := part, ""
+ if j := strings.Index(part, "="); j >= 0 {
name, val = name[:j], name[j+1:]
}
if !isCookieNameValid(name) {
diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go
index 022adaa..9e8196e 100644
--- a/libgo/go/net/http/cookie_test.go
+++ b/libgo/go/net/http/cookie_test.go
@@ -77,6 +77,10 @@ var writeSetCookiesTests = []struct {
&Cookie{Name: "cookie-14", Value: "samesite-strict", SameSite: SameSiteStrictMode},
"cookie-14=samesite-strict; SameSite=Strict",
},
+ {
+ &Cookie{Name: "cookie-15", Value: "samesite-none", SameSite: SameSiteNoneMode},
+ "cookie-15=samesite-none; SameSite=None",
+ },
// The "special" cookies have values containing commas or spaces which
// are disallowed by RFC 6265 but are common in the wild.
{
@@ -127,6 +131,22 @@ var writeSetCookiesTests = []struct {
&Cookie{Name: "\t"},
``,
},
+ {
+ &Cookie{Name: "\r"},
+ ``,
+ },
+ {
+ &Cookie{Name: "a\nb", Value: "v"},
+ ``,
+ },
+ {
+ &Cookie{Name: "a\nb", Value: "v"},
+ ``,
+ },
+ {
+ &Cookie{Name: "a\rb", Value: "v"},
+ ``,
+ },
}
func TestWriteSetCookies(t *testing.T) {
@@ -280,6 +300,15 @@ var readSetCookiesTests = []struct {
Raw: "samesitestrict=foo; SameSite=Strict",
}},
},
+ {
+ Header{"Set-Cookie": {"samesitenone=foo; SameSite=None"}},
+ []*Cookie{{
+ Name: "samesitenone",
+ Value: "foo",
+ SameSite: SameSiteNoneMode,
+ Raw: "samesitenone=foo; SameSite=None",
+ }},
+ },
// Make sure we can properly read back the Set-Cookie headers we create
// for values containing spaces or commas:
{
@@ -385,6 +414,19 @@ var readCookiesTests = []struct {
{Name: "c2", Value: "v2"},
},
},
+ {
+ Header{"Cookie": {`Cookie-1="v$1"; c2=v2;`}},
+ "",
+ []*Cookie{
+ {Name: "Cookie-1", Value: "v$1"},
+ {Name: "c2", Value: "v2"},
+ },
+ },
+ {
+ Header{"Cookie": {``}},
+ "",
+ []*Cookie{},
+ },
}
func TestReadCookies(t *testing.T) {
diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go
index b6965c2..f0dfa8c 100644
--- a/libgo/go/net/http/export_test.go
+++ b/libgo/go/net/http/export_test.go
@@ -33,6 +33,7 @@ var (
ExportHttp2ConfigureServer = http2ConfigureServer
Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect
Export_writeStatusLine = writeStatusLine
+ Export_is408Message = is408Message
)
const MaxWriteWaitBeforeConnReuse = maxWriteWaitBeforeConnReuse
@@ -244,3 +245,18 @@ func ExportSetH2GoawayTimeout(d time.Duration) (restore func()) {
}
func (r *Request) ExportIsReplayable() bool { return r.isReplayable() }
+
+// ExportCloseTransportConnsAbruptly closes all idle connections from
+// tr in an abrupt way, just reaching into the underlying Conns and
+// closing them, without telling the Transport or its persistConns
+// that it's doing so. This is to simulate the server closing connections
+// on the Transport.
+func ExportCloseTransportConnsAbruptly(tr *Transport) {
+ tr.idleMu.Lock()
+ for _, pcs := range tr.idleConn {
+ for _, pc := range pcs {
+ pc.conn.Close()
+ }
+ }
+ tr.idleMu.Unlock()
+}
diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go
index db44d6b..41d46dc 100644
--- a/libgo/go/net/http/fs.go
+++ b/libgo/go/net/http/fs.go
@@ -63,6 +63,8 @@ func mapDirOpenError(originalErr error, name string) error {
return originalErr
}
+// Open implements FileSystem using os.Open, opening files for reading rooted
+// and relative to the directory d.
func (d Dir) Open(name string) (File, error) {
if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) {
return nil, errors.New("http: invalid character in file path")
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go
index a848d68..2efa0ef 100644
--- a/libgo/go/net/http/h2_bundle.go
+++ b/libgo/go/net/http/h2_bundle.go
@@ -1,5 +1,5 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
-//go:generate bundle -o h2_bundle.go -prefix http2 -underscore golang.org/x/net/http2
+//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2
// Package http2 implements the HTTP/2 protocol.
//
@@ -42,11 +42,12 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
- "internal/x/net/http/httpguts"
- "internal/x/net/http2/hpack"
- "internal/x/net/idna"
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/http2/hpack"
+ "golang.org/x/net/idna"
)
// A list of the possible cipher suite ids. Taken from
@@ -1885,7 +1886,7 @@ func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) er
return f.WriteDataPadded(streamID, endStream, data, nil)
}
-// WriteData writes a DATA frame with optional padding.
+// WriteDataPadded writes a DATA frame with optional padding.
//
// If pad is nil, the padding bit is not sent.
// The length of pad must not exceed 255 bytes.
@@ -3831,7 +3832,20 @@ func http2ConfigureServer(s *Server, conf *http2Server) error {
if http2testHookOnConn != nil {
http2testHookOnConn()
}
+ // The TLSNextProto interface predates contexts, so
+ // the net/http package passes down its per-connection
+ // base context via an exported but unadvertised
+ // method on the Handler. This is for internal
+ // net/http<=>http2 use only.
+ var ctx context.Context
+ type baseContexter interface {
+ BaseContext() context.Context
+ }
+ if bc, ok := h.(baseContexter); ok {
+ ctx = bc.BaseContext()
+ }
conf.ServeConn(c, &http2ServeConnOpts{
+ Context: ctx,
Handler: h,
BaseConfig: hs,
})
@@ -3842,6 +3856,10 @@ func http2ConfigureServer(s *Server, conf *http2Server) error {
// ServeConnOpts are options for the Server.ServeConn method.
type http2ServeConnOpts struct {
+ // Context is the base context to use.
+ // If nil, context.Background is used.
+ Context context.Context
+
// BaseConfig optionally sets the base configuration
// for values. If nil, defaults are used.
BaseConfig *Server
@@ -3852,6 +3870,13 @@ type http2ServeConnOpts struct {
Handler Handler
}
+func (o *http2ServeConnOpts) context() context.Context {
+ if o.Context != nil {
+ return o.Context
+ }
+ return context.Background()
+}
+
func (o *http2ServeConnOpts) baseConfig() *Server {
if o != nil && o.BaseConfig != nil {
return o.BaseConfig
@@ -3997,7 +4022,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
}
func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx context.Context, cancel func()) {
- ctx, cancel = context.WithCancel(context.Background())
+ ctx, cancel = context.WithCancel(opts.context())
ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
if hs := opts.baseConfig(); hs != nil {
ctx = context.WithValue(ctx, ServerContextKey, hs)
@@ -5870,7 +5895,16 @@ type http2chunkWriter struct{ rws *http2responseWriterState }
func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }
-func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 }
+func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
+
+func (rws *http2responseWriterState) hasNonemptyTrailers() bool {
+ for _, trailer := range rws.trailers {
+ if _, ok := rws.handlerHeader[trailer]; ok {
+ return true
+ }
+ }
+ return false
+}
// declareTrailer is called for each Trailer header when the
// response header is written. It notes that a header will need to be
@@ -5970,7 +6004,10 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
rws.promoteUndeclaredTrailers()
}
- endStream := rws.handlerDone && !rws.hasTrailers()
+ // only send trailers if they have actually been defined by the
+ // server handler.
+ hasNonemptyTrailers := rws.hasNonemptyTrailers()
+ endStream := rws.handlerDone && !hasNonemptyTrailers
if len(p) > 0 || endStream {
// only send a 0 byte DATA frame if we're ending the stream.
if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
@@ -5979,7 +6016,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
}
}
- if rws.handlerDone && rws.hasTrailers() {
+ if rws.handlerDone && hasNonemptyTrailers {
err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
streamID: rws.stream.id,
h: rws.handlerHeader,
@@ -6621,6 +6658,7 @@ type http2ClientConn struct {
t *http2Transport
tconn net.Conn // usually *tls.Conn, except specialized impls
tlsState *tls.ConnectionState // nil only for specialized impls
+ reused uint32 // whether conn is being reused; atomic
singleUse bool // whether being used for a single http.Request
// readLoop goroutine fields:
@@ -6863,7 +6901,8 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
return nil, err
}
- http2traceGotConn(req, cc)
+ reused := !atomic.CompareAndSwapUint32(&cc.reused, 0, 1)
+ http2traceGotConn(req, cc, reused)
res, gotErrAfterReqBodyWrite, err := cc.roundTrip(req)
if err != nil && retry <= 6 {
if req, err = http2shouldRetryRequest(req, err, gotErrAfterReqBodyWrite); err == nil {
@@ -7849,7 +7888,11 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
f(":authority", host)
- f(":method", req.Method)
+ m := req.Method
+ if m == "" {
+ m = MethodGet
+ }
+ f(":method", m)
if req.Method != "CONNECT" {
f(":path", path)
f(":scheme", req.URL.Scheme)
@@ -8993,15 +9036,15 @@ func http2traceGetConn(req *Request, hostPort string) {
trace.GetConn(hostPort)
}
-func http2traceGotConn(req *Request, cc *http2ClientConn) {
+func http2traceGotConn(req *Request, cc *http2ClientConn, reused bool) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GotConn == nil {
return
}
ci := httptrace.GotConnInfo{Conn: cc.tconn}
+ ci.Reused = reused
cc.mu.Lock()
- ci.Reused = cc.nextStreamID > 1
- ci.WasIdle = len(cc.streams) == 0 && ci.Reused
+ ci.WasIdle = len(cc.streams) == 0 && reused
if ci.WasIdle && !cc.lastActive.IsZero() {
ci.IdleTime = time.Now().Sub(cc.lastActive)
}
diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go
index b699e7e..1e1ed98 100644
--- a/libgo/go/net/http/header.go
+++ b/libgo/go/net/http/header.go
@@ -78,12 +78,19 @@ func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error {
return h.writeSubset(w, nil, trace)
}
-func (h Header) clone() Header {
+// Clone returns a copy of h.
+func (h Header) Clone() Header {
+ // Find total number of values.
+ nv := 0
+ for _, vv := range h {
+ nv += len(vv)
+ }
+ sv := make([]string, nv) // shared backing array for headers' values
h2 := make(Header, len(h))
for k, vv := range h {
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- h2[k] = vv2
+ n := copy(sv, vv)
+ h2[k] = sv[:n:n]
+ sv = sv[n:]
}
return h2
}
diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go
index e5d59e1..3510fe6 100644
--- a/libgo/go/net/http/http.go
+++ b/libgo/go/net/http/http.go
@@ -11,7 +11,7 @@ import (
"time"
"unicode/utf8"
- "internal/x/net/http/httpguts"
+ "golang.org/x/net/http/httpguts"
)
// maxInt64 is the effective "infinite" value for the Server and
@@ -19,7 +19,7 @@ import (
const maxInt64 = 1<<63 - 1
// aLongTimeAgo is a non-zero time, far in the past, used for
-// immediate cancelation of network operations.
+// immediate cancellation of network operations.
var aLongTimeAgo = time.Unix(1, 0)
// TODO(bradfitz): move common stuff here. The other files have accumulated
diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go
index f2c3c07..d0bc0fa 100644
--- a/libgo/go/net/http/httptest/recorder.go
+++ b/libgo/go/net/http/httptest/recorder.go
@@ -12,7 +12,7 @@ import (
"strconv"
"strings"
- "internal/x/net/http/httpguts"
+ "golang.org/x/net/http/httpguts"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
@@ -59,7 +59,10 @@ func NewRecorder() *ResponseRecorder {
// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
const DefaultRemoteAddr = "1.2.3.4"
-// Header returns the response headers.
+// Header implements http.ResponseWriter. It returns the response
+// headers to mutate within a handler. To test the headers that were
+// written after a handler completes, use the Result method and see
+// the returned Response value's Header.
func (rw *ResponseRecorder) Header() http.Header {
m := rw.HeaderMap
if m == nil {
@@ -98,7 +101,8 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
rw.WriteHeader(200)
}
-// Write always succeeds and writes to rw.Body, if not nil.
+// Write implements http.ResponseWriter. The data in buf is written to
+// rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
rw.writeHeader(buf, "")
if rw.Body != nil {
@@ -107,7 +111,8 @@ func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
return len(buf), nil
}
-// WriteString always succeeds and writes to rw.Body, if not nil.
+// WriteString implements io.StringWriter. The data in str is written
+// to rw.Body, if not nil.
func (rw *ResponseRecorder) WriteString(str string) (int, error) {
rw.writeHeader(nil, str)
if rw.Body != nil {
@@ -116,8 +121,7 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) {
return len(str), nil
}
-// WriteHeader sets rw.Code. After it is called, changing rw.Header
-// will not affect rw.HeaderMap.
+// WriteHeader implements http.ResponseWriter.
func (rw *ResponseRecorder) WriteHeader(code int) {
if rw.wroteHeader {
return
@@ -127,20 +131,11 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}
- rw.snapHeader = cloneHeader(rw.HeaderMap)
+ rw.snapHeader = rw.HeaderMap.Clone()
}
-func cloneHeader(h http.Header) http.Header {
- h2 := make(http.Header, len(h))
- for k, vv := range h {
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- h2[k] = vv2
- }
- return h2
-}
-
-// Flush sets rw.Flushed to true.
+// Flush implements http.Flusher. To test whether Flush was
+// called, see rw.Flushed.
func (rw *ResponseRecorder) Flush() {
if !rw.wroteHeader {
rw.WriteHeader(200)
@@ -168,7 +163,7 @@ func (rw *ResponseRecorder) Result() *http.Response {
return rw.result
}
if rw.snapHeader == nil {
- rw.snapHeader = cloneHeader(rw.HeaderMap)
+ rw.snapHeader = rw.HeaderMap.Clone()
}
res := &http.Response{
Proto: "HTTP/1.1",
diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go
index 63312dd..97954ca 100644
--- a/libgo/go/net/http/httputil/dump_test.go
+++ b/libgo/go/net/http/httputil/dump_test.go
@@ -18,7 +18,10 @@ import (
)
type dumpTest struct {
- Req http.Request
+ // Either Req or GetReq can be set/nil but not both.
+ Req *http.Request
+ GetReq func() *http.Request
+
Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body
WantDump string
@@ -29,7 +32,7 @@ type dumpTest struct {
var dumpTests = []dumpTest{
// HTTP/1.1 => chunked coding; body; empty trailer
{
- Req: http.Request{
+ Req: &http.Request{
Method: "GET",
URL: &url.URL{
Scheme: "http",
@@ -52,7 +55,7 @@ var dumpTests = []dumpTest{
// Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
// and doesn't add a User-Agent.
{
- Req: http.Request{
+ Req: &http.Request{
Method: "GET",
URL: mustParseURL("/foo"),
ProtoMajor: 1,
@@ -67,7 +70,7 @@ var dumpTests = []dumpTest{
},
{
- Req: *mustNewRequest("GET", "http://example.com/foo", nil),
+ Req: mustNewRequest("GET", "http://example.com/foo", nil),
WantDumpOut: "GET /foo HTTP/1.1\r\n" +
"Host: example.com\r\n" +
@@ -79,8 +82,7 @@ var dumpTests = []dumpTest{
// with a bytes.Buffer and hang with all goroutines not
// runnable.
{
- Req: *mustNewRequest("GET", "https://example.com/foo", nil),
-
+ Req: mustNewRequest("GET", "https://example.com/foo", nil),
WantDumpOut: "GET /foo HTTP/1.1\r\n" +
"Host: example.com\r\n" +
"User-Agent: Go-http-client/1.1\r\n" +
@@ -89,7 +91,7 @@ var dumpTests = []dumpTest{
// Request with Body, but Dump requested without it.
{
- Req: http.Request{
+ Req: &http.Request{
Method: "POST",
URL: &url.URL{
Scheme: "http",
@@ -114,7 +116,7 @@ var dumpTests = []dumpTest{
// Request with Body > 8196 (default buffer size)
{
- Req: http.Request{
+ Req: &http.Request{
Method: "POST",
URL: &url.URL{
Scheme: "http",
@@ -145,8 +147,10 @@ var dumpTests = []dumpTest{
},
{
- Req: *mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" +
- "User-Agent: blah\r\n\r\n"),
+ GetReq: func() *http.Request {
+ return mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n")
+ },
NoBody: true,
WantDump: "GET http://foo.com/ HTTP/1.1\r\n" +
"User-Agent: blah\r\n\r\n",
@@ -154,22 +158,25 @@ var dumpTests = []dumpTest{
// Issue #7215. DumpRequest should return the "Content-Length" when set
{
- Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
- "Host: passport.myhost.com\r\n" +
- "Content-Length: 3\r\n" +
- "\r\nkey1=name1&key2=name2"),
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 3\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
"Host: passport.myhost.com\r\n" +
"Content-Length: 3\r\n" +
"\r\nkey",
},
-
// Issue #7215. DumpRequest should return the "Content-Length" in ReadRequest
{
- Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
- "Host: passport.myhost.com\r\n" +
- "Content-Length: 0\r\n" +
- "\r\nkey1=name1&key2=name2"),
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "Content-Length: 0\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
"Host: passport.myhost.com\r\n" +
"Content-Length: 0\r\n\r\n",
@@ -177,9 +184,11 @@ var dumpTests = []dumpTest{
// Issue #7215. DumpRequest should not return the "Content-Length" if unset
{
- Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
- "Host: passport.myhost.com\r\n" +
- "\r\nkey1=name1&key2=name2"),
+ GetReq: func() *http.Request {
+ return mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" +
+ "Host: passport.myhost.com\r\n" +
+ "\r\nkey1=name1&key2=name2")
+ },
WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
"Host: passport.myhost.com\r\n\r\n",
},
@@ -187,8 +196,7 @@ var dumpTests = []dumpTest{
// Issue 18506: make drainBody recognize NoBody. Otherwise
// this was turning into a chunked request.
{
- Req: *mustNewRequest("POST", "http://example.com/foo", http.NoBody),
-
+ Req: mustNewRequest("POST", "http://example.com/foo", http.NoBody),
WantDumpOut: "POST /foo HTTP/1.1\r\n" +
"Host: example.com\r\n" +
"User-Agent: Go-http-client/1.1\r\n" +
@@ -200,28 +208,40 @@ var dumpTests = []dumpTest{
func TestDumpRequest(t *testing.T) {
numg0 := runtime.NumGoroutine()
for i, tt := range dumpTests {
- setBody := func() {
- if tt.Body == nil {
- return
+ if tt.Req != nil && tt.GetReq != nil || tt.Req == nil && tt.GetReq == nil {
+ t.Errorf("#%d: either .Req(%p) or .GetReq(%p) can be set/nil but not both", i, tt.Req, tt.GetReq)
+ continue
+ }
+
+ freshReq := func(ti dumpTest) *http.Request {
+ req := ti.Req
+ if req == nil {
+ req = ti.GetReq()
}
- switch b := tt.Body.(type) {
+
+ if req.Header == nil {
+ req.Header = make(http.Header)
+ }
+
+ if ti.Body == nil {
+ return req
+ }
+ switch b := ti.Body.(type) {
case []byte:
- tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b))
+ req.Body = ioutil.NopCloser(bytes.NewReader(b))
case func() io.ReadCloser:
- tt.Req.Body = b()
+ req.Body = b()
default:
- t.Fatalf("Test %d: unsupported Body of %T", i, tt.Body)
+ t.Fatalf("Test %d: unsupported Body of %T", i, ti.Body)
}
- }
- if tt.Req.Header == nil {
- tt.Req.Header = make(http.Header)
+ return req
}
if tt.WantDump != "" {
- setBody()
- dump, err := DumpRequest(&tt.Req, !tt.NoBody)
+ req := freshReq(tt)
+ dump, err := DumpRequest(req, !tt.NoBody)
if err != nil {
- t.Errorf("DumpRequest #%d: %s", i, err)
+ t.Errorf("DumpRequest #%d: %s\nWantDump:\n%s", i, err, tt.WantDump)
continue
}
if string(dump) != tt.WantDump {
@@ -231,8 +251,8 @@ func TestDumpRequest(t *testing.T) {
}
if tt.WantDumpOut != "" {
- setBody()
- dump, err := DumpRequestOut(&tt.Req, !tt.NoBody)
+ req := freshReq(tt)
+ dump, err := DumpRequestOut(req, !tt.NoBody)
if err != nil {
t.Errorf("DumpRequestOut #%d: %s", i, err)
continue
diff --git a/libgo/go/net/http/httputil/persist.go b/libgo/go/net/http/httputil/persist.go
index cbedf25..84b116d 100644
--- a/libgo/go/net/http/httputil/persist.go
+++ b/libgo/go/net/http/httputil/persist.go
@@ -292,8 +292,8 @@ func (cc *ClientConn) Close() error {
}
// Write writes a request. An ErrPersistEOF error is returned if the connection
-// has been closed in an HTTP keepalive sense. If req.Close equals true, the
-// keepalive connection is logically closed after this request and the opposing
+// has been closed in an HTTP keep-alive sense. If req.Close equals true, the
+// keep-alive connection is logically closed after this request and the opposing
// server is informed. An ErrUnexpectedEOF indicates the remote closed the
// underlying TCP connection, which is usually considered as graceful close.
func (cc *ClientConn) Write(req *http.Request) error {
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index 4b165d6..1d7b0ef 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -18,7 +18,7 @@ import (
"sync"
"time"
- "internal/x/net/http/httpguts"
+ "golang.org/x/net/http/httpguts"
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
@@ -51,8 +51,7 @@ type ReverseProxy struct {
// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
- // If nil, logging goes to os.Stderr via the log package's
- // standard logger.
+ // If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
// BufferPool optionally specifies a buffer pool to
@@ -132,16 +131,6 @@ func copyHeader(dst, src http.Header) {
}
}
-func cloneHeader(h http.Header) http.Header {
- h2 := make(http.Header, len(h))
- for k, vv := range h {
- vv2 := make([]string, len(vv))
- copy(vv2, vv)
- h2[k] = vv2
- }
- return h2
-}
-
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
@@ -206,13 +195,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}()
}
- outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay
+ outreq := req.Clone(ctx)
if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
- outreq.Header = cloneHeader(req.Header)
-
p.Director(outreq)
outreq.Close = false
@@ -357,10 +344,10 @@ func shouldPanicOnCopyError(req *http.Request) bool {
// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
// See RFC 7230, section 6.1
func removeConnectionHeaders(h http.Header) {
- if c := h.Get("Connection"); c != "" {
- for _, f := range strings.Split(c, ",") {
- if f = strings.TrimSpace(f); f != "" {
- h.Del(f)
+ for _, f := range h["Connection"] {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = strings.TrimSpace(sf); sf != "" {
+ h.Del(sf)
}
}
}
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go
index 367ba73..e8cb814 100644
--- a/libgo/go/net/http/httputil/reverseproxy_test.go
+++ b/libgo/go/net/http/httputil/reverseproxy_test.go
@@ -20,6 +20,7 @@ import (
"net/url"
"os"
"reflect"
+ "sort"
"strconv"
"strings"
"sync"
@@ -160,13 +161,17 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
const someConnHeader = "X-Some-Conn-Header"
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
if c := r.Header.Get(fakeConnectionToken); c != "" {
t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
}
if c := r.Header.Get(someConnHeader); c != "" {
t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
}
- w.Header().Set("Connection", someConnHeader+", "+fakeConnectionToken)
+ w.Header().Add("Connection", "Upgrade, "+fakeConnectionToken)
+ w.Header().Add("Connection", someConnHeader)
w.Header().Set(someConnHeader, "should be deleted")
w.Header().Set(fakeConnectionToken, "should be deleted")
io.WriteString(w, backendResponse)
@@ -179,15 +184,34 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
proxyHandler := NewSingleHostReverseProxy(backendURL)
frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxyHandler.ServeHTTP(w, r)
- if c := r.Header.Get(someConnHeader); c != "original value" {
- t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "original value")
+ if c := r.Header.Get(someConnHeader); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", someConnHeader, c, "should be deleted")
+ }
+ if c := r.Header.Get(fakeConnectionToken); c != "should be deleted" {
+ t.Errorf("handler modified header %q = %q; want %q", fakeConnectionToken, c, "should be deleted")
+ }
+ c := r.Header["Connection"]
+ var cf []string
+ for _, f := range c {
+ for _, sf := range strings.Split(f, ",") {
+ if sf = strings.TrimSpace(sf); sf != "" {
+ cf = append(cf, sf)
+ }
+ }
+ }
+ sort.Strings(cf)
+ expectedValues := []string{"Upgrade", someConnHeader, fakeConnectionToken}
+ sort.Strings(expectedValues)
+ if !reflect.DeepEqual(cf, expectedValues) {
+ t.Errorf("handler modified header %q = %q; want %q", "Connection", cf, expectedValues)
}
}))
defer frontend.Close()
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
- getReq.Header.Set("Connection", someConnHeader+", "+fakeConnectionToken)
- getReq.Header.Set(someConnHeader, "original value")
+ getReq.Header.Add("Connection", "Upgrade, "+fakeConnectionToken)
+ getReq.Header.Add("Connection", someConnHeader)
+ getReq.Header.Set(someConnHeader, "should be deleted")
getReq.Header.Set(fakeConnectionToken, "should be deleted")
res, err := frontend.Client().Do(getReq)
if err != nil {
@@ -201,6 +225,9 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
if got, want := string(bodyBytes), backendResponse; got != want {
t.Errorf("got body %q; want %q", got, want)
}
+ if c := res.Header.Get("Connection"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Connection", c)
+ }
if c := res.Header.Get(someConnHeader); c != "" {
t.Errorf("handler got header %q = %q; want empty", someConnHeader, c)
}
diff --git a/libgo/go/net/http/internal/testcert.go b/libgo/go/net/http/internal/testcert.go
index 4078909..2284a83 100644
--- a/libgo/go/net/http/internal/testcert.go
+++ b/libgo/go/net/http/internal/testcert.go
@@ -4,6 +4,8 @@
package internal
+import "strings"
+
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
// generated from src/crypto/tls:
@@ -24,7 +26,7 @@ fblo6RBxUQ==
-----END CERTIFICATE-----`)
// LocalhostKey is the private key for localhostCert.
-var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+var LocalhostKey = []byte(testingKey(`-----BEGIN RSA TESTING KEY-----
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
@@ -38,4 +40,6 @@ fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
------END RSA PRIVATE KEY-----`)
+-----END RSA TESTING KEY-----`))
+
+func testingKey(s string) string { return strings.ReplaceAll(s, "TESTING KEY", "PRIVATE KEY") }
diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go
index dcad2b6..fa63175 100644
--- a/libgo/go/net/http/request.go
+++ b/libgo/go/net/http/request.go
@@ -26,7 +26,7 @@ import (
"strings"
"sync"
- "internal/x/net/idna"
+ "golang.org/x/net/idna"
)
const (
@@ -304,7 +304,7 @@ type Request struct {
//
// For server requests, this field is not applicable.
//
- // Deprecated: Use the Context and WithContext methods
+ // Deprecated: Set the Request's context with NewRequestWithContext
// instead. If a Request's Cancel field and context are both
// set, it is undefined whether Cancel is respected.
Cancel <-chan struct{}
@@ -327,7 +327,7 @@ type Request struct {
// The returned context is always non-nil; it defaults to the
// background context.
//
-// For outgoing client requests, the context controls cancelation.
+// For outgoing client requests, the context controls cancellation.
//
// For incoming server requests, the context is canceled when the
// client's connection closes, the request is canceled (with HTTP/2),
@@ -345,6 +345,11 @@ func (r *Request) Context() context.Context {
// For outgoing client request, the context controls the entire
// lifetime of a request and its response: obtaining a connection,
// sending the request, and reading the response headers and body.
+//
+// To create a new request with a context, use NewRequestWithContext.
+// To change the context of a request (such as an incoming) you then
+// also want to modify to send back out, use Request.Clone. Between
+// those two uses, it's rare to need WithContext.
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
@@ -352,16 +357,38 @@ func (r *Request) WithContext(ctx context.Context) *Request {
r2 := new(Request)
*r2 = *r
r2.ctx = ctx
+ r2.URL = cloneURL(r.URL) // legacy behavior; TODO: try to remove. Issue 23544
+ return r2
+}
- // Deep copy the URL because it isn't
- // a map and the URL is mutable by users
- // of WithContext.
- if r.URL != nil {
- r2URL := new(url.URL)
- *r2URL = *r.URL
- r2.URL = r2URL
+// Clone returns a deep copy of r with its context changed to ctx.
+// The provided ctx must be non-nil.
+//
+// For an outgoing client request, the context controls the entire
+// lifetime of a request and its response: obtaining a connection,
+// sending the request, and reading the response headers and body.
+func (r *Request) Clone(ctx context.Context) *Request {
+ if ctx == nil {
+ panic("nil context")
}
-
+ r2 := new(Request)
+ *r2 = *r
+ r2.ctx = ctx
+ r2.URL = cloneURL(r.URL)
+ if r.Header != nil {
+ r2.Header = r.Header.Clone()
+ }
+ if r.Trailer != nil {
+ r2.Trailer = r.Trailer.Clone()
+ }
+ if s := r.TransferEncoding; s != nil {
+ s2 := make([]string, len(s))
+ copy(s2, s)
+ r2.TransferEncoding = s
+ }
+ r2.Form = cloneURLValues(r.Form)
+ r2.PostForm = cloneURLValues(r.PostForm)
+ r2.MultipartForm = cloneMultipartForm(r.MultipartForm)
return r2
}
@@ -781,25 +808,34 @@ func validMethod(method string) bool {
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
-// NewRequest returns a new Request given a method, URL, and optional body.
+// NewRequest wraps NewRequestWithContext using the background context.
+func NewRequest(method, url string, body io.Reader) (*Request, error) {
+ return NewRequestWithContext(context.Background(), method, url, body)
+}
+
+// NewRequestWithContext returns a new Request given a method, URL, and
+// optional body.
//
// If the provided body is also an io.Closer, the returned
// Request.Body is set to body and will be closed by the Client
// methods Do, Post, and PostForm, and Transport.RoundTrip.
//
-// NewRequest returns a Request suitable for use with Client.Do or
-// Transport.RoundTrip. To create a request for use with testing a
-// Server Handler, either use the NewRequest function in the
+// NewRequestWithContext returns a Request suitable for use with
+// Client.Do or Transport.RoundTrip. To create a request for use with
+// testing a Server Handler, either use the NewRequest function in the
// net/http/httptest package, use ReadRequest, or manually update the
-// Request fields. See the Request type's documentation for the
-// difference between inbound and outbound request fields.
+// Request fields. For an outgoing client request, the context
+// controls the entire lifetime of a request and its response:
+// obtaining a connection, sending the request, and reading the
+// response headers and body. See the Request type's documentation for
+// the difference between inbound and outbound request fields.
//
// If body is of type *bytes.Buffer, *bytes.Reader, or
// *strings.Reader, the returned request's ContentLength is set to its
// exact value (instead of -1), GetBody is populated (so 307 and 308
// redirects can replay the body), and Body is set to NoBody if the
// ContentLength is 0.
-func NewRequest(method, url string, body io.Reader) (*Request, error) {
+func NewRequestWithContext(ctx context.Context, method, url string, body io.Reader) (*Request, error) {
if method == "" {
// We document that "" means "GET" for Request.Method, and people have
// relied on that from NewRequest, so keep that working.
@@ -809,6 +845,9 @@ func NewRequest(method, url string, body io.Reader) (*Request, error) {
if !validMethod(method) {
return nil, fmt.Errorf("net/http: invalid method %q", method)
}
+ if ctx == nil {
+ return nil, errors.New("net/http: nil Context")
+ }
u, err := parseURL(url) // Just url.Parse (url is shadowed for godoc).
if err != nil {
return nil, err
@@ -820,6 +859,7 @@ func NewRequest(method, url string, body io.Reader) (*Request, error) {
// The host's colon:port should be normalized. See Issue 14836.
u.Host = removeEmptyPort(u.Host)
req := &Request{
+ ctx: ctx,
Method: method,
URL: u,
Proto: "HTTP/1.1",
@@ -912,6 +952,10 @@ func parseBasicAuth(auth string) (username, password string, ok bool) {
//
// With HTTP Basic Authentication the provided username and password
// are not encrypted.
+//
+// Some protocols may impose additional requirements on pre-escaping the
+// username and password. For instance, when used with OAuth2, both arguments
+// must be URL encoded first with url.QueryEscape.
func (r *Request) SetBasicAuth(username, password string) {
r.Header.Set("Authorization", "Basic "+basicAuth(username, password))
}
diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go
index e800557..b072f95 100644
--- a/libgo/go/net/http/request_test.go
+++ b/libgo/go/net/http/request_test.go
@@ -8,12 +8,14 @@ import (
"bufio"
"bytes"
"context"
+ "crypto/rand"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
. "net/http"
+ "net/http/httptest"
"net/url"
"os"
"reflect"
@@ -133,30 +135,31 @@ func TestParseFormInitializeOnError(t *testing.T) {
}
func TestMultipartReader(t *testing.T) {
- req := &Request{
- Method: "POST",
- Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}},
- Body: ioutil.NopCloser(new(bytes.Buffer)),
- }
- multipart, err := req.MultipartReader()
- if multipart == nil {
- t.Errorf("expected multipart; error: %v", err)
- }
-
- req = &Request{
- Method: "POST",
- Header: Header{"Content-Type": {`multipart/mixed; boundary="foo123"`}},
- Body: ioutil.NopCloser(new(bytes.Buffer)),
- }
- multipart, err = req.MultipartReader()
- if multipart == nil {
- t.Errorf("expected multipart; error: %v", err)
+ tests := []struct {
+ shouldError bool
+ contentType string
+ }{
+ {false, `multipart/form-data; boundary="foo123"`},
+ {false, `multipart/mixed; boundary="foo123"`},
+ {true, `text/plain`},
}
- req.Header = Header{"Content-Type": {"text/plain"}}
- multipart, err = req.MultipartReader()
- if multipart != nil {
- t.Error("unexpected multipart for text/plain")
+ for i, test := range tests {
+ req := &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {test.contentType}},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
+ }
+ multipart, err := req.MultipartReader()
+ if test.shouldError {
+ if err == nil || multipart != nil {
+ t.Errorf("test %d: unexpectedly got nil-error (%v) or non-nil-multipart (%v)", i, err, multipart)
+ }
+ continue
+ }
+ if err != nil || multipart == nil {
+ t.Errorf("test %d: unexpectedly got error (%v) or nil-multipart (%v)", i, err, multipart)
+ }
}
}
@@ -1046,3 +1049,92 @@ func BenchmarkReadRequestWrk(b *testing.B) {
Host: localhost:8080
`)
}
+
+const (
+ withTLS = true
+ noTLS = false
+)
+
+func BenchmarkFileAndServer_1KB(b *testing.B) {
+ benchmarkFileAndServer(b, 1<<10)
+}
+
+func BenchmarkFileAndServer_16MB(b *testing.B) {
+ benchmarkFileAndServer(b, 1<<24)
+}
+
+func BenchmarkFileAndServer_64MB(b *testing.B) {
+ benchmarkFileAndServer(b, 1<<26)
+}
+
+func benchmarkFileAndServer(b *testing.B, n int64) {
+ f, err := ioutil.TempFile(os.TempDir(), "go-bench-http-file-and-server")
+ if err != nil {
+ b.Fatalf("Failed to create temp file: %v", err)
+ }
+
+ defer func() {
+ f.Close()
+ os.RemoveAll(f.Name())
+ }()
+
+ if _, err := io.CopyN(f, rand.Reader, n); err != nil {
+ b.Fatalf("Failed to copy %d bytes: %v", n, err)
+ }
+
+ b.Run("NoTLS", func(b *testing.B) {
+ runFileAndServerBenchmarks(b, noTLS, f, n)
+ })
+
+ b.Run("TLS", func(b *testing.B) {
+ runFileAndServerBenchmarks(b, withTLS, f, n)
+ })
+}
+
+func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) {
+ handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
+ defer req.Body.Close()
+ nc, err := io.Copy(ioutil.Discard, req.Body)
+ if err != nil {
+ panic(err)
+ }
+
+ if nc != n {
+ panic(fmt.Errorf("Copied %d Wanted %d bytes", nc, n))
+ }
+ })
+
+ var cst *httptest.Server
+ if tlsOption == withTLS {
+ cst = httptest.NewTLSServer(handler)
+ } else {
+ cst = httptest.NewServer(handler)
+ }
+
+ defer cst.Close()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ // Perform some setup.
+ b.StopTimer()
+ if _, err := f.Seek(0, 0); err != nil {
+ b.Fatalf("Failed to seek back to file: %v", err)
+ }
+
+ b.StartTimer()
+ req, err := NewRequest("PUT", cst.URL, ioutil.NopCloser(f))
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ req.ContentLength = n
+ // Prevent mime sniffing by setting the Content-Type.
+ req.Header.Set("Content-Type", "application/octet-stream")
+ res, err := cst.Client().Do(req)
+ if err != nil {
+ b.Fatalf("Failed to make request to backend: %v", err)
+ }
+
+ res.Body.Close()
+ b.SetBytes(n)
+ }
+}
diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go
index f906ce8..2065a25 100644
--- a/libgo/go/net/http/response.go
+++ b/libgo/go/net/http/response.go
@@ -12,12 +12,13 @@ import (
"crypto/tls"
"errors"
"fmt"
- "internal/x/net/http/httpguts"
"io"
"net/textproto"
"net/url"
"strconv"
"strings"
+
+ "golang.org/x/net/http/httpguts"
)
var respExcludeHeader = map[string]bool{
@@ -66,7 +67,7 @@ type Response struct {
// with a "chunked" Transfer-Encoding.
//
// As of Go 1.12, the Body will be also implement io.Writer
- // on a successful "101 Switching Protocols" responses,
+ // on a successful "101 Switching Protocols" response,
// as used by WebSockets and HTTP/2's "h2c" mode.
Body io.ReadCloser
diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go
index c46f13f..ee7f0d0 100644
--- a/libgo/go/net/http/response_test.go
+++ b/libgo/go/net/http/response_test.go
@@ -10,7 +10,7 @@ import (
"compress/gzip"
"crypto/rand"
"fmt"
- "go/ast"
+ "go/token"
"io"
"io/ioutil"
"net/http/internal"
@@ -736,7 +736,7 @@ func diff(t *testing.T, prefix string, have, want interface{}) {
}
for i := 0; i < hv.NumField(); i++ {
name := hv.Type().Field(i).Name
- if !ast.IsExported(name) {
+ if !token.IsExported(name) {
continue
}
hf := hv.Field(i).Interface()
diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go
index 1e38b90..6331351 100644
--- a/libgo/go/net/http/roundtrip_js.go
+++ b/libgo/go/net/http/roundtrip_js.go
@@ -11,12 +11,12 @@ import (
"fmt"
"io"
"io/ioutil"
- "os"
"strconv"
- "strings"
"syscall/js"
)
+var uint8Array = js.Global().Get("Uint8Array")
+
// jsFetchMode is a Request.Header map key that, if present,
// signals that the map entry is actually an option to the Fetch API mode setting.
// Valid values are: "cors", "no-cors", "same-origin", "navigate"
@@ -33,9 +33,19 @@ const jsFetchMode = "js.fetch:mode"
// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
const jsFetchCreds = "js.fetch:credentials"
+// jsFetchRedirect is a Request.Header map key that, if present,
+// signals that the map entry is actually an option to the Fetch API redirect setting.
+// Valid values are: "follow", "error", "manual"
+// The default is "follow".
+//
+// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
+const jsFetchRedirect = "js.fetch:redirect"
+
+var useFakeNetwork = js.Global().Get("fetch") == js.Undefined()
+
// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API.
func (t *Transport) RoundTrip(req *Request) (*Response, error) {
- if useFakeNetwork() {
+ if useFakeNetwork {
return t.roundTrip(req)
}
@@ -60,6 +70,10 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
opt.Set("mode", h)
req.Header.Del(jsFetchMode)
}
+ if h := req.Header.Get(jsFetchRedirect); h != "" {
+ opt.Set("redirect", h)
+ req.Header.Del(jsFetchRedirect)
+ }
if ac != js.Undefined() {
opt.Set("signal", ac.Get("signal"))
}
@@ -84,9 +98,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
return nil, err
}
req.Body.Close()
- a := js.TypedArrayOf(body)
- defer a.Release()
- opt.Set("body", a)
+ buf := uint8Array.New(len(body))
+ js.CopyBytesToJS(buf, body)
+ opt.Set("body", buf)
}
respPromise := js.Global().Call("fetch", req.URL.String(), opt)
var (
@@ -126,10 +140,11 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
body = &arrayReader{arrayPromise: result.Call("arrayBuffer")}
}
+ code := result.Get("status").Int()
select {
case respCh <- &Response{
- Status: result.Get("status").String() + " " + StatusText(result.Get("status").Int()),
- StatusCode: result.Get("status").Int(),
+ Status: fmt.Sprintf("%d %s", code, StatusText(code)),
+ StatusCode: code,
Header: header,
ContentLength: contentLength,
Body: body,
@@ -167,12 +182,6 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
var errClosed = errors.New("net/http: reader is closed")
-// useFakeNetwork is used to determine whether the request is made
-// by a test and should be made to use the fake in-memory network.
-func useFakeNetwork() bool {
- return len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test")
-}
-
// streamReader implements an io.ReadCloser wrapper for ReadableStream.
// See https://fetch.spec.whatwg.org/#readablestream for more information.
type streamReader struct {
@@ -197,9 +206,7 @@ func (r *streamReader) Read(p []byte) (n int, err error) {
return nil
}
value := make([]byte, result.Get("value").Get("byteLength").Int())
- a := js.TypedArrayOf(value)
- a.Call("set", result.Get("value"))
- a.Release()
+ js.CopyBytesToGo(value, result.Get("value"))
bCh <- value
return nil
})
@@ -260,11 +267,9 @@ func (r *arrayReader) Read(p []byte) (n int, err error) {
)
success := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
// Wrap the input ArrayBuffer with a Uint8Array
- uint8arrayWrapper := js.Global().Get("Uint8Array").New(args[0])
+ uint8arrayWrapper := uint8Array.New(args[0])
value := make([]byte, uint8arrayWrapper.Get("byteLength").Int())
- a := js.TypedArrayOf(value)
- a.Call("set", uint8arrayWrapper)
- a.Release()
+ js.CopyBytesToGo(value, uint8arrayWrapper)
bCh <- value
return nil
})
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go
index 6eb0088..e7ed15c 100644
--- a/libgo/go/net/http/serve_test.go
+++ b/libgo/go/net/http/serve_test.go
@@ -4273,7 +4273,7 @@ func testServerEmptyBodyRace(t *testing.T, h2 bool) {
var n int32
cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
atomic.AddInt32(&n, 1)
- }))
+ }), optQuietLog)
defer cst.close()
var wg sync.WaitGroup
const reqs = 20
@@ -4697,6 +4697,10 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
conn, br, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
defer conn.Close()
if r.Method != "PRI" || r.RequestURI != "*" {
t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
@@ -5745,8 +5749,12 @@ func TestServerDuplicateBackgroundRead(t *testing.T) {
setParallel(t)
defer afterTest(t)
- const goroutines = 5
- const requests = 2000
+ goroutines := 5
+ requests := 2000
+ if testing.Short() {
+ goroutines = 3
+ requests = 100
+ }
hts := httptest.NewServer(HandlerFunc(NotFound))
defer hts.Close()
@@ -6021,6 +6029,143 @@ func TestStripPortFromHost(t *testing.T) {
}
}
+func TestServerContexts(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ type baseKey struct{}
+ type connKey struct{}
+ ch := make(chan context.Context, 1)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ ch <- r.Context()
+ }))
+ ts.Config.BaseContext = func(ln net.Listener) context.Context {
+ if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
+ t.Errorf("unexpected onceClose listener type %T", ln)
+ }
+ return context.WithValue(context.Background(), baseKey{}, "base")
+ }
+ ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
+ }
+ return context.WithValue(ctx, connKey{}, "conn")
+ }
+ ts.Start()
+ defer ts.Close()
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ ctx := <-ch
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("base context key = %#v; want %q", got, want)
+ }
+ if got, want := ctx.Value(connKey{}), "conn"; got != want {
+ t.Errorf("conn context key = %#v; want %q", got, want)
+ }
+}
+
+func TestServerContextsHTTP2(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ type baseKey struct{}
+ type connKey struct{}
+ ch := make(chan context.Context, 1)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ if r.ProtoMajor != 2 {
+ t.Errorf("unexpected HTTP/1.x request")
+ }
+ ch <- r.Context()
+ }))
+ ts.Config.BaseContext = func(ln net.Listener) context.Context {
+ if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
+ t.Errorf("unexpected onceClose listener type %T", ln)
+ }
+ return context.WithValue(context.Background(), baseKey{}, "base")
+ }
+ ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
+ }
+ return context.WithValue(ctx, connKey{}, "conn")
+ }
+ ts.TLS = &tls.Config{
+ NextProtos: []string{"h2", "http/1.1"},
+ }
+ ts.StartTLS()
+ defer ts.Close()
+ ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true
+ res, err := ts.Client().Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ ctx := <-ch
+ if got, want := ctx.Value(baseKey{}), "base"; got != want {
+ t.Errorf("base context key = %#v; want %q", got, want)
+ }
+ if got, want := ctx.Value(connKey{}), "conn"; got != want {
+ t.Errorf("conn context key = %#v; want %q", got, want)
+ }
+}
+
+// Issue 30710: ensure that as per the spec, a server responds
+// with 501 Not Implemented for unsupported transfer-encodings.
+func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
+ cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Write([]byte("Hello, World!"))
+ }))
+ defer cst.Close()
+
+ serverURL, err := url.Parse(cst.URL)
+ if err != nil {
+ t.Fatalf("Failed to parse server URL: %v", err)
+ }
+
+ unsupportedTEs := []string{
+ "fugazi",
+ "foo-bar",
+ "unknown",
+ }
+
+ for _, badTE := range unsupportedTEs {
+ http1ReqBody := fmt.Sprintf(""+
+ "POST / HTTP/1.1\r\nConnection: close\r\n"+
+ "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
+
+ gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
+ if err != nil {
+ t.Errorf("%q. unexpected error: %v", badTE, err)
+ continue
+ }
+
+ wantBody := fmt.Sprintf("" +
+ "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
+ "Connection: close\r\n\r\nUnsupported transfer encoding")
+
+ if string(gotBody) != wantBody {
+ t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
+ }
+ }
+}
+
+// fetchWireResponse is a helper for dialing to host,
+// sending http1ReqBody as the payload and retrieving
+// the response as it was sent on the wire.
+func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
+ conn, err := net.Dial("tcp", host)
+ if err != nil {
+ return nil, err
+ }
+ defer conn.Close()
+
+ if _, err := conn.Write(http1ReqBody); err != nil {
+ return nil, err
+ }
+ return ioutil.ReadAll(conn)
+}
+
func BenchmarkResponseStatusLine(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go
index cbf5673..74569bf 100644
--- a/libgo/go/net/http/server.go
+++ b/libgo/go/net/http/server.go
@@ -29,7 +29,7 @@ import (
"sync/atomic"
"time"
- "internal/x/net/http/httpguts"
+ "golang.org/x/net/http/httpguts"
)
// Errors used by the HTTP server.
@@ -749,10 +749,8 @@ func (cr *connReader) handleReadError(_ error) {
// may be called from multiple goroutines.
func (cr *connReader) closeNotify() {
res, _ := cr.conn.curReq.Load().(*response)
- if res != nil {
- if atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) {
- res.closeNotifyCh <- true
- }
+ if res != nil && atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) {
+ res.closeNotifyCh <- true
}
}
@@ -1060,7 +1058,7 @@ func (w *response) Header() Header {
// Accessing the header between logically writing it
// and physically writing it means we need to allocate
// a clone to snapshot the logically written state.
- w.cw.header = w.handlerHeader.clone()
+ w.cw.header = w.handlerHeader.Clone()
}
w.calledHeader = true
return w.handlerHeader
@@ -1134,7 +1132,7 @@ func (w *response) WriteHeader(code int) {
w.status = code
if w.calledHeader && w.cw.header == nil {
- w.cw.header = w.handlerHeader.clone()
+ w.cw.header = w.handlerHeader.Clone()
}
if cl := w.handlerHeader.get("Content-Length"); cl != "" {
@@ -1803,7 +1801,7 @@ func (c *conn) serve(ctx context.Context) {
*c.tlsState = tlsConn.ConnectionState()
if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) {
if fn := c.server.TLSNextProto[proto]; fn != nil {
- h := initNPNRequest{tlsConn, serverHandler{c.server}}
+ h := initNPNRequest{ctx, tlsConn, serverHandler{c.server}}
fn(c.server, tlsConn, h)
}
return
@@ -1829,7 +1827,8 @@ func (c *conn) serve(ctx context.Context) {
if err != nil {
const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
- if err == errTooLarge {
+ switch {
+ case err == errTooLarge:
// Their HTTP client may or may not be
// able to read this if we're
// responding to them and hanging up
@@ -1839,18 +1838,31 @@ func (c *conn) serve(ctx context.Context) {
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
c.closeWriteAndWait()
return
- }
- if isCommonNetReadError(err) {
+
+ case isUnsupportedTEError(err):
+ // Respond as per RFC 7230 Section 3.3.1 which says,
+ // A server that receives a request message with a
+ // transfer coding it does not understand SHOULD
+ // respond with 501 (Unimplemented).
+ code := StatusNotImplemented
+
+ // We purposefully aren't echoing back the transfer-encoding's value,
+ // so as to mitigate the risk of cross side scripting by an attacker.
+ fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s%sUnsupported transfer encoding", code, StatusText(code), errorHeaders)
+ return
+
+ case isCommonNetReadError(err):
return // don't reply
- }
- publicErr := "400 Bad Request"
- if v, ok := err.(badRequestError); ok {
- publicErr = publicErr + ": " + string(v)
- }
+ default:
+ publicErr := "400 Bad Request"
+ if v, ok := err.(badRequestError); ok {
+ publicErr = publicErr + ": " + string(v)
+ }
- fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
- return
+ fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
+ return
+ }
}
// Expect 100 Continue support
@@ -2505,7 +2517,9 @@ type Server struct {
// ReadHeaderTimeout is the amount of time allowed to read
// request headers. The connection's read deadline is reset
// after reading the headers and the Handler can decide what
- // is considered too slow for the body.
+ // is considered too slow for the body. If ReadHeaderTimeout
+ // is zero, the value of ReadTimeout is used. If both are
+ // zero, there is no timeout.
ReadHeaderTimeout time.Duration
// WriteTimeout is the maximum duration before timing out
@@ -2517,7 +2531,7 @@ type Server struct {
// IdleTimeout is the maximum amount of time to wait for the
// next request when keep-alives are enabled. If IdleTimeout
// is zero, the value of ReadTimeout is used. If both are
- // zero, ReadHeaderTimeout is used.
+ // zero, there is no timeout.
IdleTimeout time.Duration
// MaxHeaderBytes controls the maximum number of bytes the
@@ -2549,6 +2563,20 @@ type Server struct {
// If nil, logging is done via the log package's standard logger.
ErrorLog *log.Logger
+ // BaseContext optionally specifies a function that returns
+ // the base context for incoming requests on this server.
+ // The provided Listener is the specific Listener that's
+ // about to start accepting requests.
+ // If BaseContext is nil, the default is context.Background().
+ // If non-nil, it must return a non-nil context.
+ BaseContext func(net.Listener) context.Context
+
+ // ConnContext optionally specifies a function that modifies
+ // the context used for a new connection c. The provided ctx
+ // is derived from the base context and has a ServerContextKey
+ // value.
+ ConnContext func(ctx context.Context, c net.Conn) context.Context
+
disableKeepAlives int32 // accessed atomically.
inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
nextProtoOnce sync.Once // guards setupHTTP2_* init
@@ -2799,7 +2827,7 @@ func (srv *Server) ListenAndServe() error {
if err != nil {
return err
}
- return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
+ return srv.Serve(ln)
}
var testHookServerServe func(*Server, net.Listener) // used if non-nil
@@ -2845,6 +2873,7 @@ func (srv *Server) Serve(l net.Listener) error {
fn(srv, l) // call hook with unwrapped listener
}
+ origListener := l
l = &onceCloseListener{Listener: l}
defer l.Close()
@@ -2857,8 +2886,16 @@ func (srv *Server) Serve(l net.Listener) error {
}
defer srv.trackListener(&l, false)
- var tempDelay time.Duration // how long to sleep on accept failure
- baseCtx := context.Background() // base is always background, per Issue 16220
+ var tempDelay time.Duration // how long to sleep on accept failure
+
+ baseCtx := context.Background()
+ if srv.BaseContext != nil {
+ baseCtx = srv.BaseContext(origListener)
+ if baseCtx == nil {
+ panic("BaseContext returned a nil context")
+ }
+ }
+
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
for {
rw, e := l.Accept()
@@ -2883,6 +2920,12 @@ func (srv *Server) Serve(l net.Listener) error {
}
return e
}
+ if cc := srv.ConnContext; cc != nil {
+ ctx = cc(ctx, rw)
+ if ctx == nil {
+ panic("ConnContext returned nil")
+ }
+ }
tempDelay = 0
c := srv.newConn(rw)
c.setState(c.rwc, StateNew) // before Serve can return
@@ -3083,7 +3126,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
defer ln.Close()
- return srv.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, certFile, keyFile)
+ return srv.ServeTLS(ln, certFile, keyFile)
}
// setupHTTP2_ServeTLS conditionally configures HTTP/2 on
@@ -3228,6 +3271,25 @@ type timeoutWriter struct {
code int
}
+var _ Pusher = (*timeoutWriter)(nil)
+var _ Flusher = (*timeoutWriter)(nil)
+
+// Push implements the Pusher interface.
+func (tw *timeoutWriter) Push(target string, opts *PushOptions) error {
+ if pusher, ok := tw.w.(Pusher); ok {
+ return pusher.Push(target, opts)
+ }
+ return ErrNotSupported
+}
+
+// Flush implements the Flusher interface.
+func (tw *timeoutWriter) Flush() {
+ f, ok := tw.w.(Flusher)
+ if ok {
+ f.Flush()
+ }
+}
+
func (tw *timeoutWriter) Header() Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
@@ -3257,24 +3319,6 @@ func (tw *timeoutWriter) writeHeader(code int) {
tw.code = code
}
-// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
-// connections. It's used by ListenAndServe and ListenAndServeTLS so
-// dead TCP connections (e.g. closing laptop mid-download) eventually
-// go away.
-type tcpKeepAliveListener struct {
- *net.TCPListener
-}
-
-func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
- tc, err := ln.AcceptTCP()
- if err != nil {
- return nil, err
- }
- tc.SetKeepAlive(true)
- tc.SetKeepAlivePeriod(3 * time.Minute)
- return tc, nil
-}
-
// onceCloseListener wraps a net.Listener, protecting it from
// multiple Close calls.
type onceCloseListener struct {
@@ -3310,10 +3354,17 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
// uninitialized fields in its *Request. Such partially-initialized
// Requests come from NPN protocol handlers.
type initNPNRequest struct {
- c *tls.Conn
- h serverHandler
+ ctx context.Context
+ c *tls.Conn
+ h serverHandler
}
+// BaseContext is an exported but unadvertised http.Handler method
+// recognized by x/net/http2 to pass down a context; the TLSNextProto
+// API predates context support so we shoehorn through the only
+// interface we have available.
+func (h initNPNRequest) BaseContext() context.Context { return h.ctx }
+
func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
if req.TLS == nil {
req.TLS = &tls.ConnectionState{}
diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go
index c1494ab..67a7151 100644
--- a/libgo/go/net/http/sniff.go
+++ b/libgo/go/net/http/sniff.go
@@ -37,6 +37,8 @@ func DetectContentType(data []byte) string {
return "application/octet-stream" // fallback
}
+// isWS reports whether the provided byte is a whitespace byte (0xWS)
+// as defined in https://mimesniff.spec.whatwg.org/#terminology.
func isWS(b byte) bool {
switch b {
case '\t', '\n', '\x0c', '\r', ' ':
@@ -45,6 +47,16 @@ func isWS(b byte) bool {
return false
}
+// isTT reports whether the provided byte is a tag-terminating byte (0xTT)
+// as defined in https://mimesniff.spec.whatwg.org/#terminology.
+func isTT(b byte) bool {
+ switch b {
+ case ' ', '>':
+ return true
+ }
+ return false
+}
+
type sniffSig interface {
// match returns the MIME type of the data, or "" if unknown.
match(data []byte, firstNonWS int) string
@@ -69,33 +81,57 @@ var sniffSignatures = []sniffSig{
htmlSig("<BR"),
htmlSig("<P"),
htmlSig("<!--"),
-
- &maskedSig{mask: []byte("\xFF\xFF\xFF\xFF\xFF"), pat: []byte("<?xml"), skipWS: true, ct: "text/xml; charset=utf-8"},
-
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\xFF"),
+ pat: []byte("<?xml"),
+ skipWS: true,
+ ct: "text/xml; charset=utf-8"},
&exactSig{[]byte("%PDF-"), "application/pdf"},
&exactSig{[]byte("%!PS-Adobe-"), "application/postscript"},
// UTF BOMs.
- &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFE\xFF\x00\x00"), ct: "text/plain; charset=utf-16be"},
- &maskedSig{mask: []byte("\xFF\xFF\x00\x00"), pat: []byte("\xFF\xFE\x00\x00"), ct: "text/plain; charset=utf-16le"},
- &maskedSig{mask: []byte("\xFF\xFF\xFF\x00"), pat: []byte("\xEF\xBB\xBF\x00"), ct: "text/plain; charset=utf-8"},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\x00\x00"),
+ pat: []byte("\xFE\xFF\x00\x00"),
+ ct: "text/plain; charset=utf-16be",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\x00\x00"),
+ pat: []byte("\xFF\xFE\x00\x00"),
+ ct: "text/plain; charset=utf-16le",
+ },
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\x00"),
+ pat: []byte("\xEF\xBB\xBF\x00"),
+ ct: "text/plain; charset=utf-8",
+ },
+ // Image types
+ // For posterity, we originally returned "image/vnd.microsoft.icon" from
+ // https://tools.ietf.org/html/draft-ietf-websec-mime-sniff-03#section-7
+ // https://codereview.appspot.com/4746042
+ // but that has since been replaced with "image/x-icon" in Section 6.2
+ // of https://mimesniff.spec.whatwg.org/#matching-an-image-type-pattern
+ &exactSig{[]byte("\x00\x00\x01\x00"), "image/x-icon"},
+ &exactSig{[]byte("\x00\x00\x02\x00"), "image/x-icon"},
+ &exactSig{[]byte("BM"), "image/bmp"},
&exactSig{[]byte("GIF87a"), "image/gif"},
&exactSig{[]byte("GIF89a"), "image/gif"},
- &exactSig{[]byte("\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"), "image/png"},
- &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
- &exactSig{[]byte("BM"), "image/bmp"},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF\xFF\xFF"),
pat: []byte("RIFF\x00\x00\x00\x00WEBPVP"),
ct: "image/webp",
},
- &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"},
+ &exactSig{[]byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"},
+ &exactSig{[]byte("\xFF\xD8\xFF"), "image/jpeg"},
+ // Audio and Video types
+ // Enforce the pattern match ordering as prescribed in
+ // https://mimesniff.spec.whatwg.org/#matching-an-audio-or-video-type-pattern
&maskedSig{
- mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
- pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
- ct: "audio/wave",
+ mask: []byte("\xFF\xFF\xFF\xFF"),
+ pat: []byte(".snd"),
+ ct: "audio/basic",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
@@ -103,9 +139,9 @@ var sniffSignatures = []sniffSig{
ct: "audio/aiff",
},
&maskedSig{
- mask: []byte("\xFF\xFF\xFF\xFF"),
- pat: []byte(".snd"),
- ct: "audio/basic",
+ mask: []byte("\xFF\xFF\xFF"),
+ pat: []byte("ID3"),
+ ct: "audio/mpeg",
},
&maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\xFF"),
@@ -118,20 +154,24 @@ var sniffSignatures = []sniffSig{
ct: "audio/midi",
},
&maskedSig{
- mask: []byte("\xFF\xFF\xFF"),
- pat: []byte("ID3"),
- ct: "audio/mpeg",
- },
- &maskedSig{
mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
pat: []byte("RIFF\x00\x00\x00\x00AVI "),
ct: "video/avi",
},
+ &maskedSig{
+ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"),
+ pat: []byte("RIFF\x00\x00\x00\x00WAVE"),
+ ct: "audio/wave",
+ },
+ // 6.2.0.2. video/mp4
+ mp4Sig{},
+ // 6.2.0.3. video/webm
+ &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
- // Fonts
+ // Font types
&maskedSig{
// 34 NULL bytes followed by the string "LP"
- pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x4C\x50"),
+ pat: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00LP"),
// 34 NULL bytes followed by \xF\xF
mask: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF"),
ct: "application/vnd.ms-fontobject",
@@ -142,15 +182,20 @@ var sniffSignatures = []sniffSig{
&exactSig{[]byte("wOFF"), "font/woff"},
&exactSig{[]byte("wOF2"), "font/woff2"},
- &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
- &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"},
- &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"},
+ // Archive types
&exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
+ &exactSig{[]byte("PK\x03\x04"), "application/zip"},
+ // RAR's signatures are incorrectly defined by the MIME spec as per
+ // https://github.com/whatwg/mimesniff/issues/63
+ // However, RAR Labs correctly defines it at:
+ // https://www.rarlab.com/technote.htm#rarsign
+ // so we use the definition from RAR Labs.
+ // TODO: do whatever the spec ends up doing.
+ &exactSig{[]byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"}, // RAR v1.5-v4.0
+ &exactSig{[]byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"}, // RAR v5+
&exactSig{[]byte("\x00\x61\x73\x6D"), "application/wasm"},
- mp4Sig{},
-
textSig{}, // should be last
}
@@ -182,12 +227,12 @@ func (m *maskedSig) match(data []byte, firstNonWS int) string {
if len(m.pat) != len(m.mask) {
return ""
}
- if len(data) < len(m.mask) {
+ if len(data) < len(m.pat) {
return ""
}
- for i, mask := range m.mask {
- db := data[i] & mask
- if db != m.pat[i] {
+ for i, pb := range m.pat {
+ maskedData := data[i] & m.mask[i]
+ if maskedData != pb {
return ""
}
}
@@ -210,8 +255,8 @@ func (h htmlSig) match(data []byte, firstNonWS int) string {
return ""
}
}
- // Next byte must be space or right angle bracket.
- if db := data[len(h)]; db != ' ' && db != '>' {
+ // Next byte must be a tag-terminating byte(0xTT).
+ if !isTT(data[len(h)]) {
return ""
}
return "text/html; charset=utf-8"
@@ -229,7 +274,7 @@ func (mp4Sig) match(data []byte, firstNonWS int) string {
return ""
}
boxSize := int(binary.BigEndian.Uint32(data[:4]))
- if boxSize%4 != 0 || len(data) < boxSize {
+ if len(data) < boxSize || boxSize%4 != 0 {
return ""
}
if !bytes.Equal(data[4:8], mp4ftype) {
@@ -237,7 +282,7 @@ func (mp4Sig) match(data []byte, firstNonWS int) string {
}
for st := 8; st < boxSize; st += 4 {
if st == 12 {
- // minor version number
+ // Ignores the four bytes that correspond to the version number of the "major brand".
continue
}
if bytes.Equal(data[st:st+3], mp4) {
diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go
index b4d3c9f..a1157a0 100644
--- a/libgo/go/net/http/sniff_test.go
+++ b/libgo/go/net/http/sniff_test.go
@@ -36,8 +36,14 @@ var sniffTests = []struct {
{"XML", []byte("\n<?xml!"), "text/xml; charset=utf-8"},
// Image types.
+ {"Windows icon", []byte("\x00\x00\x01\x00"), "image/x-icon"},
+ {"Windows cursor", []byte("\x00\x00\x02\x00"), "image/x-icon"},
+ {"BMP image", []byte("BM..."), "image/bmp"},
{"GIF 87a", []byte(`GIF87a`), "image/gif"},
{"GIF 89a", []byte(`GIF89a...`), "image/gif"},
+ {"WEBP image", []byte("RIFF\x00\x00\x00\x00WEBPVP"), "image/webp"},
+ {"PNG image", []byte("\x89PNG\x0D\x0A\x1A\x0A"), "image/png"},
+ {"JPEG image", []byte("\xFF\xD8\xFF"), "image/jpeg"},
// Audio types.
{"MIDI audio", []byte("MThd\x00\x00\x00\x06\x00\x01"), "audio/midi"},
@@ -66,6 +72,12 @@ var sniffTests = []struct {
{"woff sample I", []byte("\x77\x4f\x46\x46\x00\x01\x00\x00\x00\x00\x30\x54\x00\x0d\x00\x00"), "font/woff"},
{"woff2 sample", []byte("\x77\x4f\x46\x32\x00\x01\x00\x00\x00"), "font/woff2"},
{"wasm sample", []byte("\x00\x61\x73\x6d\x01\x00"), "application/wasm"},
+
+ // Archive types
+ {"RAR v1.5-v4.0", []byte("Rar!\x1A\x07\x00"), "application/x-rar-compressed"},
+ {"RAR v5+", []byte("Rar!\x1A\x07\x01\x00"), "application/x-rar-compressed"},
+ {"Incorrect RAR v1.5-v4.0", []byte("Rar \x1A\x07\x00"), "application/octet-stream"},
+ {"Incorrect RAR v5+", []byte("Rar \x1A\x07\x01\x00"), "application/octet-stream"},
}
func TestDetectContentType(t *testing.T) {
diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go
index 086f3d1..286315f 100644
--- a/libgo/go/net/http/status.go
+++ b/libgo/go/net/http/status.go
@@ -10,6 +10,7 @@ const (
StatusContinue = 100 // RFC 7231, 6.2.1
StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2
StatusProcessing = 102 // RFC 2518, 10.1
+ StatusEarlyHints = 103 // RFC 8297
StatusOK = 200 // RFC 7231, 6.3.1
StatusCreated = 201 // RFC 7231, 6.3.2
@@ -79,6 +80,7 @@ var statusText = map[int]string{
StatusContinue: "Continue",
StatusSwitchingProtocols: "Switching Protocols",
StatusProcessing: "Processing",
+ StatusEarlyHints: "Early Hints",
StatusOK: "OK",
StatusCreated: "Created",
diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go
index e8a93e9..2e01a07 100644
--- a/libgo/go/net/http/transfer.go
+++ b/libgo/go/net/http/transfer.go
@@ -21,7 +21,7 @@ import (
"sync"
"time"
- "internal/x/net/http/httpguts"
+ "golang.org/x/net/http/httpguts"
)
// ErrLineTooLong is returned when reading request or response bodies
@@ -53,19 +53,6 @@ func (br *byteReader) Read(p []byte) (n int, err error) {
return 1, io.EOF
}
-// transferBodyReader is an io.Reader that reads from tw.Body
-// and records any non-EOF error in tw.bodyReadError.
-// It is exactly 1 pointer wide to avoid allocations into interfaces.
-type transferBodyReader struct{ tw *transferWriter }
-
-func (br transferBodyReader) Read(p []byte) (n int, err error) {
- n, err = br.tw.Body.Read(p)
- if err != nil && err != io.EOF {
- br.tw.bodyReadError = err
- }
- return
-}
-
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
@@ -347,15 +334,18 @@ func (t *transferWriter) writeBody(w io.Writer) error {
var err error
var ncopy int64
- // Write body
+ // Write body. We "unwrap" the body first if it was wrapped in a
+ // nopCloser. This is to ensure that we can take advantage of
+ // OS-level optimizations in the event that the body is an
+ // *os.File.
if t.Body != nil {
- var body = transferBodyReader{t}
+ var body = t.unwrapBody()
if chunked(t.TransferEncoding) {
if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
w = &internal.FlushAfterChunkWriter{Writer: bw}
}
cw := internal.NewChunkedWriter(w)
- _, err = io.Copy(cw, body)
+ _, err = t.doBodyCopy(cw, body)
if err == nil {
err = cw.Close()
}
@@ -364,14 +354,14 @@ func (t *transferWriter) writeBody(w io.Writer) error {
if t.Method == "CONNECT" {
dst = bufioFlushWriter{dst}
}
- ncopy, err = io.Copy(dst, body)
+ ncopy, err = t.doBodyCopy(dst, body)
} else {
- ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength))
+ ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength))
if err != nil {
return err
}
var nextra int64
- nextra, err = io.Copy(ioutil.Discard, body)
+ nextra, err = t.doBodyCopy(ioutil.Discard, body)
ncopy += nextra
}
if err != nil {
@@ -402,6 +392,31 @@ func (t *transferWriter) writeBody(w io.Writer) error {
return err
}
+// doBodyCopy wraps a copy operation, with any resulting error also
+// being saved in bodyReadError.
+//
+// This function is only intended for use in writeBody.
+func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) {
+ n, err = io.Copy(dst, src)
+ if err != nil && err != io.EOF {
+ t.bodyReadError = err
+ }
+ return
+}
+
+// unwrapBodyReader unwraps the body's inner reader if it's a
+// nopCloser. This is to ensure that body writes sourced from local
+// files (*os.File types) are properly optimized.
+//
+// This function is only intended for use in writeBody.
+func (t *transferWriter) unwrapBody() io.Reader {
+ if reflect.TypeOf(t.Body) == nopCloserType {
+ return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader)
+ }
+
+ return t.Body
+}
+
type transferReader struct {
// Input
Header Header
@@ -574,6 +589,22 @@ func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
// Checks whether the encoding is explicitly "identity".
func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
+// unsupportedTEError reports unsupported transfer-encodings.
+type unsupportedTEError struct {
+ err string
+}
+
+func (uste *unsupportedTEError) Error() string {
+ return uste.err
+}
+
+// isUnsupportedTEError checks if the error is of type
+// unsupportedTEError. It is usually invoked with a non-nil err.
+func isUnsupportedTEError(err error) bool {
+ _, ok := err.(*unsupportedTEError)
+ return ok
+}
+
// fixTransferEncoding sanitizes t.TransferEncoding, if needed.
func (t *transferReader) fixTransferEncoding() error {
raw, present := t.Header["Transfer-Encoding"]
@@ -600,7 +631,7 @@ func (t *transferReader) fixTransferEncoding() error {
break
}
if encoding != "chunked" {
- return &badStringError{"unsupported transfer encoding", encoding}
+ return &unsupportedTEError{fmt.Sprintf("unsupported transfer encoding: %q", encoding)}
}
te = te[0 : len(te)+1]
te[len(te)-1] = encoding
diff --git a/libgo/go/net/http/transfer_test.go b/libgo/go/net/http/transfer_test.go
index 993ea4e..65009ee 100644
--- a/libgo/go/net/http/transfer_test.go
+++ b/libgo/go/net/http/transfer_test.go
@@ -7,8 +7,12 @@ package http
import (
"bufio"
"bytes"
+ "crypto/rand"
+ "fmt"
"io"
"io/ioutil"
+ "os"
+ "reflect"
"strings"
"testing"
)
@@ -90,3 +94,219 @@ func TestDetectInMemoryReaders(t *testing.T) {
}
}
}
+
+type mockTransferWriter struct {
+ CalledReader io.Reader
+ WriteCalled bool
+}
+
+var _ io.ReaderFrom = (*mockTransferWriter)(nil)
+
+func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) {
+ w.CalledReader = r
+ return io.Copy(ioutil.Discard, r)
+}
+
+func (w *mockTransferWriter) Write(p []byte) (int, error) {
+ w.WriteCalled = true
+ return ioutil.Discard.Write(p)
+}
+
+func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
+ fileType := reflect.TypeOf(&os.File{})
+ bufferType := reflect.TypeOf(&bytes.Buffer{})
+
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := ioutil.TempFile("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ bodyFunc func() (io.Reader, func(), error)
+ method string
+ contentLength int64
+ transferEncoding []string
+ limitedReader bool
+ expectedReader reflect.Type
+ expectedWrite bool
+ }{
+ {
+ name: "file, non-chunked, size set",
+ bodyFunc: newFileFunc,
+ method: "PUT",
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, size set, nopCloser wrapped",
+ method: "PUT",
+ bodyFunc: func() (io.Reader, func(), error) {
+ r, cleanup, err := newFileFunc()
+ return ioutil.NopCloser(r), cleanup, err
+ },
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, negative size",
+ method: "PUT",
+ bodyFunc: newFileFunc,
+ contentLength: -1,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, CONNECT, negative size",
+ method: "CONNECT",
+ bodyFunc: newFileFunc,
+ contentLength: -1,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, chunked",
+ method: "PUT",
+ bodyFunc: newFileFunc,
+ transferEncoding: []string{"chunked"},
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, non-chunked, size set",
+ bodyFunc: newBufferFunc,
+ method: "PUT",
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: bufferType,
+ },
+ {
+ name: "buffer, non-chunked, size set, nopCloser wrapped",
+ method: "PUT",
+ bodyFunc: func() (io.Reader, func(), error) {
+ r, cleanup, err := newBufferFunc()
+ return ioutil.NopCloser(r), cleanup, err
+ },
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: bufferType,
+ },
+ {
+ name: "buffer, non-chunked, negative size",
+ method: "PUT",
+ bodyFunc: newBufferFunc,
+ contentLength: -1,
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, non-chunked, CONNECT, negative size",
+ method: "CONNECT",
+ bodyFunc: newBufferFunc,
+ contentLength: -1,
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, chunked",
+ method: "PUT",
+ bodyFunc: newBufferFunc,
+ transferEncoding: []string{"chunked"},
+ expectedWrite: true,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ body, cleanup, err := tc.bodyFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ mw := &mockTransferWriter{}
+ tw := &transferWriter{
+ Body: body,
+ ContentLength: tc.contentLength,
+ TransferEncoding: tc.transferEncoding,
+ }
+
+ if err := tw.writeBody(mw); err != nil {
+ t.Fatal(err)
+ }
+
+ if tc.expectedReader != nil {
+ if mw.CalledReader == nil {
+ t.Fatal("did not call ReadFrom")
+ }
+
+ var actualReader reflect.Type
+ lr, ok := mw.CalledReader.(*io.LimitedReader)
+ if ok && tc.limitedReader {
+ actualReader = reflect.TypeOf(lr.R)
+ } else {
+ actualReader = reflect.TypeOf(mw.CalledReader)
+ }
+
+ if tc.expectedReader != actualReader {
+ t.Fatalf("got reader %T want %T", actualReader, tc.expectedReader)
+ }
+ }
+
+ if tc.expectedWrite && !mw.WriteCalled {
+ t.Fatal("did not invoke Write")
+ }
+ })
+ }
+}
+
+func TestFixTransferEncoding(t *testing.T) {
+ tests := []struct {
+ hdr Header
+ wantErr error
+ }{
+ {
+ hdr: Header{"Transfer-Encoding": {"fugazi"}},
+ wantErr: &unsupportedTEError{`unsupported transfer encoding: "fugazi"`},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"chunked, chunked", "identity", "chunked"}},
+ wantErr: &badStringError{"too many transfer encodings", "chunked,chunked"},
+ },
+ {
+ hdr: Header{"Transfer-Encoding": {"chunked"}},
+ wantErr: nil,
+ },
+ }
+
+ for i, tt := range tests {
+ tr := &transferReader{
+ Header: tt.hdr,
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ }
+ gotErr := tr.fixTransferEncoding()
+ if !reflect.DeepEqual(gotErr, tt.wantErr) {
+ t.Errorf("%d.\ngot error:\n%v\nwant error:\n%v\n\n", i, gotErr, tt.wantErr)
+ }
+ }
+}
diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go
index a8c5efe..26f642a 100644
--- a/libgo/go/net/http/transport.go
+++ b/libgo/go/net/http/transport.go
@@ -30,8 +30,8 @@ import (
"sync/atomic"
"time"
- "internal/x/net/http/httpguts"
- "internal/x/net/http/httpproxy"
+ "golang.org/x/net/http/httpguts"
+ "golang.org/x/net/http/httpproxy"
)
// DefaultTransport is the default implementation of Transport and is
@@ -46,6 +46,7 @@ var DefaultTransport RoundTripper = &Transport{
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
+ ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
@@ -253,10 +254,76 @@ type Transport struct {
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
+ // WriteBufferSize specifies the size of the write buffer used
+ // when writing to the transport.
+ // If zero, a default (currently 4KB) is used.
+ WriteBufferSize int
+
+ // ReadBufferSize specifies the size of the read buffer used
+ // when reading from the transport.
+ // If zero, a default (currently 4KB) is used.
+ ReadBufferSize int
+
// nextProtoOnce guards initialization of TLSNextProto and
// h2transport (via onceSetNextProtoDefaults)
- nextProtoOnce sync.Once
- h2transport h2Transport // non-nil if http2 wired up
+ nextProtoOnce sync.Once
+ h2transport h2Transport // non-nil if http2 wired up
+ tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired
+
+ // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero
+ // Dial, DialTLS, or DialContext func or TLSClientConfig is provided.
+ // By default, use of any those fields conservatively disables HTTP/2.
+ // To use a custom dialer or TLS config and still attempt HTTP/2
+ // upgrades, set this to true.
+ ForceAttemptHTTP2 bool
+}
+
+func (t *Transport) writeBufferSize() int {
+ if t.WriteBufferSize > 0 {
+ return t.WriteBufferSize
+ }
+ return 4 << 10
+}
+
+func (t *Transport) readBufferSize() int {
+ if t.ReadBufferSize > 0 {
+ return t.ReadBufferSize
+ }
+ return 4 << 10
+}
+
+// Clone returns a deep copy of t's exported fields.
+func (t *Transport) Clone() *Transport {
+ t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+ t2 := &Transport{
+ Proxy: t.Proxy,
+ DialContext: t.DialContext,
+ Dial: t.Dial,
+ DialTLS: t.DialTLS,
+ TLSClientConfig: t.TLSClientConfig.Clone(),
+ TLSHandshakeTimeout: t.TLSHandshakeTimeout,
+ DisableKeepAlives: t.DisableKeepAlives,
+ DisableCompression: t.DisableCompression,
+ MaxIdleConns: t.MaxIdleConns,
+ MaxIdleConnsPerHost: t.MaxIdleConnsPerHost,
+ MaxConnsPerHost: t.MaxConnsPerHost,
+ IdleConnTimeout: t.IdleConnTimeout,
+ ResponseHeaderTimeout: t.ResponseHeaderTimeout,
+ ExpectContinueTimeout: t.ExpectContinueTimeout,
+ ProxyConnectHeader: t.ProxyConnectHeader.Clone(),
+ MaxResponseHeaderBytes: t.MaxResponseHeaderBytes,
+ ForceAttemptHTTP2: t.ForceAttemptHTTP2,
+ WriteBufferSize: t.WriteBufferSize,
+ ReadBufferSize: t.ReadBufferSize,
+ }
+ if !t.tlsNextProtoWasNil {
+ npm := map[string]func(authority string, c *tls.Conn) RoundTripper{}
+ for k, v := range t.TLSNextProto {
+ npm[k] = v
+ }
+ t2.TLSNextProto = npm
+ }
+ return t2
}
// h2Transport is the interface we expect to be able to call from
@@ -272,6 +339,7 @@ type h2Transport interface {
// onceSetNextProtoDefaults initializes TLSNextProto.
// It must be called via t.nextProtoOnce.Do.
func (t *Transport) onceSetNextProtoDefaults() {
+ t.tlsNextProtoWasNil = (t.TLSNextProto == nil)
if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") {
return
}
@@ -296,12 +364,13 @@ func (t *Transport) onceSetNextProtoDefaults() {
// Transport.
return
}
- if t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil {
+ if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) {
// Be conservative and don't automatically enable
// http2 if they've specified a custom TLS config or
// custom dialers. Let them opt-in themselves via
// http2.ConfigureTransport so we don't surprise them
// by modifying their tls.Config. Issue 14275.
+ // However, if ForceAttemptHTTP2 is true, it overrides the above checks.
return
}
t2, err := http2configureTransport(t)
@@ -474,8 +543,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
- t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
- t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+ t.putOrCloseIdleConn(pconn)
+ t.setReqCanceler(req, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
@@ -483,7 +552,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
if err == nil {
return resp, nil
}
- if !pconn.shouldRetryRequest(req, err) {
+ if http2isNoCachedConnError(err) {
+ t.removeIdleConn(pconn)
+ t.decHostConnCount(cm.key()) // clean up the persistent connection
+ } else if !pconn.shouldRetryRequest(req, err) {
// Issue 16465: return underlying net.Conn.Read error from peek,
// as we've historically done.
if e, ok := err.(transportReadFromServerError); ok {
@@ -717,6 +789,8 @@ type transportReadFromServerError struct {
err error
}
+func (e transportReadFromServerError) Unwrap() error { return e.err }
+
func (e transportReadFromServerError) Error() string {
return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err)
}
@@ -746,9 +820,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if pconn.isBroken() {
return errConnBroken
}
- if pconn.alt != nil {
- return errNotCachingH2Conn
- }
pconn.markReused()
key := pconn.cacheKey
@@ -797,7 +868,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if pconn.idleTimer != nil {
pconn.idleTimer.Reset(t.IdleConnTimeout)
} else {
- pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+ // idleTimer does not apply to HTTP/2
+ if pconn.alt == nil {
+ pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
+ }
}
}
pconn.idleAt = time.Now()
@@ -1031,7 +1105,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
t.decHostConnCount(cmKey)
select {
case <-req.Cancel:
- // It was an error due to cancelation, so prioritize that
+ // It was an error due to cancellation, so prioritize that
// error value. (Issue 16049)
return nil, errRequestCanceledConn
case <-req.Context().Done():
@@ -1042,7 +1116,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
}
return nil, err
default:
- // It wasn't an error due to cancelation, so
+ // It wasn't an error due to cancellation, so
// return the original error message:
return nil, v.err
}
@@ -1345,15 +1419,16 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
- return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
+ return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
}
}
if t.MaxConnsPerHost > 0 {
pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey}
}
- pconn.br = bufio.NewReader(pconn)
- pconn.bw = bufio.NewWriter(persistConnWriter{pconn})
+ pconn.br = bufio.NewReaderSize(pconn, t.readBufferSize())
+ pconn.bw = bufio.NewWriterSize(persistConnWriter{pconn}, t.writeBufferSize())
+
go pconn.readLoop()
go pconn.writeLoop()
return pconn, nil
@@ -1375,6 +1450,17 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
return
}
+// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if
+// the Conn implements io.ReaderFrom, it can take advantage of optimizations
+// such as sendfile.
+func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) {
+ n, err = io.Copy(w.pc.conn, r)
+ w.pc.nwrite += n
+ return
+}
+
+var _ io.ReaderFrom = (*persistConnWriter)(nil)
+
// connectMethod is the map key (in its String form) for keeping persistent
// TCP connections alive for subsequent HTTP requests.
//
@@ -1538,7 +1624,7 @@ func (pc *persistConn) isBroken() bool {
}
// canceled returns non-nil if the connection was closed due to
-// CancelRequest or due to context cancelation.
+// CancelRequest or due to context cancellation.
func (pc *persistConn) canceled() error {
pc.mu.Lock()
defer pc.mu.Unlock()
@@ -1794,7 +1880,7 @@ func (pc *persistConn) readLoop() {
// Before looping back to the top of this function and peeking on
// the bufio.Reader, wait for the caller goroutine to finish
- // reading the response body. (or for cancelation or death)
+ // reading the response body. (or for cancellation or death)
select {
case bodyEOF := <-waitForBodyRead:
pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
@@ -1826,7 +1912,12 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) {
}
if n := pc.br.Buffered(); n > 0 {
buf, _ := pc.br.Peek(n)
- log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr)
+ if is408Message(buf) {
+ pc.closeLocked(errServerClosedIdle)
+ return
+ } else {
+ log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr)
+ }
}
if peekErr == io.EOF {
// common case.
@@ -1836,6 +1927,19 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) {
}
}
+// is408Message reports whether buf has the prefix of an
+// HTTP 408 Request Timeout response.
+// See golang.org/issue/32310.
+func is408Message(buf []byte) bool {
+ if len(buf) < len("HTTP/1.x 408") {
+ return false
+ }
+ if string(buf[:7]) != "HTTP/1." {
+ return false
+ }
+ return string(buf[8:12]) == " 408"
+}
+
// readResponse reads an HTTP response (or two, in the case of "Expect:
// 100-continue") from the server. It returns the final non-100 one.
// trace is optional.
@@ -2072,8 +2176,21 @@ func (e *httpError) Error() string { return e.err }
func (e *httpError) Timeout() bool { return e.timeout }
func (e *httpError) Temporary() bool { return true }
+func (e *httpError) Is(target error) bool {
+ switch target {
+ case os.ErrTimeout:
+ return e.timeout
+ case os.ErrTemporary:
+ return true
+ }
+ return false
+}
+
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
-var errRequestCanceled = errors.New("net/http: request canceled")
+
+// errRequestCanceled is set to be identical to the one from h2 to facilitate
+// testing.
+var errRequestCanceled = http2errRequestCanceled
var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
func nop() {}
@@ -2258,13 +2375,8 @@ func (pc *persistConn) closeLocked(err error) {
if pc.closed == nil {
pc.closed = err
if pc.alt != nil {
- // Do nothing; can only get here via getConn's
- // handlePendingDial's putOrCloseIdleConn when
- // it turns out the abandoned connection in
- // flight ended up negotiating an alternate
- // protocol. We don't use the connection
- // freelist for http2. That's done by the
- // alternate protocol's RoundTripper.
+ // Clean up any host connection counting.
+ pc.t.decHostConnCount(pc.cacheKey)
} else {
if err != errCallerOwnsConn {
pc.conn.Close()
@@ -2408,6 +2520,10 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
+func (tlsHandshakeTimeoutError) Is(target error) bool {
+ return target == os.ErrTimeout || target == os.ErrTemporary
+}
+
// fakeLocker is a sync.Locker which does nothing. It's used to guard
// test-only fields when not under test, to avoid runtime atomic
// overhead.
diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go
index 6e07584..2b58e1d 100644
--- a/libgo/go/net/http/transport_test.go
+++ b/libgo/go/net/http/transport_test.go
@@ -20,6 +20,7 @@ import (
"encoding/binary"
"errors"
"fmt"
+ "go/token"
"internal/nettrace"
"io"
"io/ioutil"
@@ -42,7 +43,7 @@ import (
"testing"
"time"
- "internal/x/net/http/httpguts"
+ "golang.org/x/net/http/httpguts"
)
// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
@@ -588,6 +589,106 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
<-reqComplete
}
+func TestTransportMaxConnsPerHost(t *testing.T) {
+ defer afterTest(t)
+
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ })
+
+ testMaxConns := func(scheme string, ts *httptest.Server) {
+ defer ts.Close()
+
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ tr.MaxConnsPerHost = 1
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
+ }
+
+ connCh := make(chan net.Conn, 1)
+ var dialCnt, gotConnCnt, tlsHandshakeCnt int32
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ atomic.AddInt32(&dialCnt, 1)
+ c, err := net.Dial(network, addr)
+ connCh <- c
+ return c, err
+ }
+
+ doReq := func() {
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ if !connInfo.Reused {
+ atomic.AddInt32(&gotConnCnt, 1)
+ }
+ },
+ TLSHandshakeStart: func() {
+ atomic.AddInt32(&tlsHandshakeCnt, 1)
+ },
+ }
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatalf("request failed: %v", err)
+ }
+ defer resp.Body.Close()
+ _, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Fatalf("read body failed: %v", err)
+ }
+ }
+
+ wg := sync.WaitGroup{}
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ doReq()
+ }()
+ }
+ wg.Wait()
+
+ expected := int32(tr.MaxConnsPerHost)
+ if dialCnt != expected {
+ t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
+ }
+
+ (<-connCh).Close()
+ tr.CloseIdleConnections()
+
+ doReq()
+ expected++
+ if dialCnt != expected {
+ t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
+ }
+ if gotConnCnt != expected {
+ t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
+ }
+ if ts.TLS != nil && tlsHandshakeCnt != expected {
+ t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
+ }
+ }
+
+ testMaxConns("http", httptest.NewServer(h))
+ testMaxConns("https", httptest.NewTLSServer(h))
+
+ ts := httptest.NewUnstartedServer(h)
+ ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
+ ts.StartTLS()
+ testMaxConns("http2", ts)
+}
+
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
setParallel(t)
defer afterTest(t)
@@ -636,6 +737,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) {
}
}
+// Test that the Transport notices when a server hangs up on its
+// unexpectedly (a keep-alive connection is closed).
func TestTransportServerClosingUnexpectedly(t *testing.T) {
setParallel(t)
defer afterTest(t)
@@ -672,13 +775,14 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) {
body1 := fetch(1, 0)
body2 := fetch(2, 0)
- ts.CloseClientConnections() // surprise!
-
- // This test has an expected race. Sleeping for 25 ms prevents
- // it on most fast machines, causing the next fetch() call to
- // succeed quickly. But if we do get errors, fetch() will retry 5
- // times with some delays between.
- time.Sleep(25 * time.Millisecond)
+ // Close all the idle connections in a way that's similar to
+ // the server hanging up on us. We don't use
+ // httptest.Server.CloseClientConnections because it's
+ // best-effort and stops blocking after 5 seconds. On a loaded
+ // machine running many tests concurrently it's possible for
+ // that method to be async and cause the body3 fetch below to
+ // run on an old connection. This function is synchronous.
+ ExportCloseTransportConnsAbruptly(c.Transport.(*Transport))
body3 := fetch(3, 5)
@@ -865,6 +969,10 @@ func TestRoundTripGzip(t *testing.T) {
req.Header.Set("Accept-Encoding", test.accept)
}
res, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip: %v", i, err)
+ continue
+ }
var body []byte
if test.compressed {
var r *gzip.Reader
@@ -2110,7 +2218,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
}
} else {
if err == nil || !strings.Contains(err.Error(), "canceled") {
- t.Errorf("Do error = %v; want cancelation", err)
+ t.Errorf("Do error = %v; want cancellation", err)
}
}
}
@@ -3589,6 +3697,13 @@ func TestTransportAutomaticHTTP2(t *testing.T) {
testTransportAutoHTTP(t, &Transport{}, true)
}
+func TestTransportAutomaticHTTP2_DialerAndTLSConfigSupportsHTTP2AndTLSConfig(t *testing.T) {
+ testTransportAutoHTTP(t, &Transport{
+ ForceAttemptHTTP2: true,
+ TLSClientConfig: new(tls.Config),
+ }, true)
+}
+
// golang.org/issue/14391: also check DefaultTransport
func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) {
testTransportAutoHTTP(t, DefaultTransport.(*Transport), true)
@@ -3619,6 +3734,13 @@ func TestTransportAutomaticHTTP2_Dial(t *testing.T) {
}, false)
}
+func TestTransportAutomaticHTTP2_DialContext(t *testing.T) {
+ var d net.Dialer
+ testTransportAutoHTTP(t, &Transport{
+ DialContext: d.DialContext,
+ }, false)
+}
+
func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) {
testTransportAutoHTTP(t, &Transport{
DialTLS: func(network, addr string) (net.Conn, error) {
@@ -5059,3 +5181,270 @@ func TestTransportRequestReplayable(t *testing.T) {
})
}
}
+
+// testMockTCPConn is a mock TCP connection used to test that
+// ReadFrom is called when sending the request body.
+type testMockTCPConn struct {
+ *net.TCPConn
+
+ ReadFromCalled bool
+}
+
+func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
+ c.ReadFromCalled = true
+ return c.TCPConn.ReadFrom(r)
+}
+
+func TestTransportRequestWriteRoundTrip(t *testing.T) {
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := ioutil.TempFile("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ readerFunc func() (io.Reader, func(), error)
+ contentLength int64
+ expectedReadFrom bool
+ }{
+ {
+ name: "file, length",
+ readerFunc: newFileFunc,
+ contentLength: nBytes,
+ expectedReadFrom: true,
+ },
+ {
+ name: "file, no length",
+ readerFunc: newFileFunc,
+ },
+ {
+ name: "file, negative length",
+ readerFunc: newFileFunc,
+ contentLength: -1,
+ },
+ {
+ name: "buffer",
+ contentLength: nBytes,
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, no length",
+ readerFunc: newBufferFunc,
+ },
+ {
+ name: "buffer, length -1",
+ contentLength: -1,
+ readerFunc: newBufferFunc,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ r, cleanup, err := tc.readerFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ tConn := &testMockTCPConn{}
+ trFunc := func(tr *Transport) {
+ tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+ var d net.Dialer
+ conn, err := d.DialContext(ctx, network, addr)
+ if err != nil {
+ return nil, err
+ }
+
+ tcpConn, ok := conn.(*net.TCPConn)
+ if !ok {
+ return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
+ }
+
+ tConn.TCPConn = tcpConn
+ return tConn, nil
+ }
+ }
+
+ cst := newClientServerTest(
+ t,
+ h1Mode,
+ HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(ioutil.Discard, r.Body)
+ r.Body.Close()
+ w.WriteHeader(200)
+ }),
+ trFunc,
+ )
+ defer cst.close()
+
+ req, err := NewRequest("PUT", cst.ts.URL, r)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = tc.contentLength
+ req.Header.Set("Content-Type", "application/octet-stream")
+ resp, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != 200 {
+ t.Fatalf("status code = %d; want 200", resp.StatusCode)
+ }
+
+ if !tConn.ReadFromCalled && tc.expectedReadFrom {
+ t.Fatalf("did not call ReadFrom")
+ }
+
+ if tConn.ReadFromCalled && !tc.expectedReadFrom {
+ t.Fatalf("ReadFrom was unexpectedly invoked")
+ }
+ })
+ }
+}
+
+func TestTransportClone(t *testing.T) {
+ tr := &Transport{
+ Proxy: func(*Request) (*url.URL, error) { panic("") },
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
+ Dial: func(network, addr string) (net.Conn, error) { panic("") },
+ DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
+ TLSClientConfig: new(tls.Config),
+ TLSHandshakeTimeout: time.Second,
+ DisableKeepAlives: true,
+ DisableCompression: true,
+ MaxIdleConns: 1,
+ MaxIdleConnsPerHost: 1,
+ MaxConnsPerHost: 1,
+ IdleConnTimeout: time.Second,
+ ResponseHeaderTimeout: time.Second,
+ ExpectContinueTimeout: time.Second,
+ ProxyConnectHeader: Header{},
+ MaxResponseHeaderBytes: 1,
+ ForceAttemptHTTP2: true,
+ TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
+ },
+ ReadBufferSize: 1,
+ WriteBufferSize: 1,
+ }
+ tr2 := tr.Clone()
+ rv := reflect.ValueOf(tr2).Elem()
+ rt := rv.Type()
+ for i := 0; i < rt.NumField(); i++ {
+ sf := rt.Field(i)
+ if !token.IsExported(sf.Name) {
+ continue
+ }
+ if rv.Field(i).IsZero() {
+ t.Errorf("cloned field t2.%s is zero", sf.Name)
+ }
+ }
+
+ if _, ok := tr2.TLSNextProto["foo"]; !ok {
+ t.Errorf("cloned Transport lacked TLSNextProto 'foo' key")
+ }
+
+ // But test that a nil TLSNextProto is kept nil:
+ tr = new(Transport)
+ tr2 = tr.Clone()
+ if tr2.TLSNextProto != nil {
+ t.Errorf("Transport.TLSNextProto unexpected non-nil")
+ }
+}
+
+func TestIs408(t *testing.T) {
+ tests := []struct {
+ in string
+ want bool
+ }{
+ {"HTTP/1.0 408", true},
+ {"HTTP/1.1 408", true},
+ {"HTTP/1.8 408", true},
+ {"HTTP/2.0 408", false}, // maybe h2c would do this? but false for now.
+ {"HTTP/1.1 408 ", true},
+ {"HTTP/1.1 40", false},
+ {"http/1.0 408", false},
+ {"HTTP/1-1 408", false},
+ }
+ for _, tt := range tests {
+ if got := Export_is408Message([]byte(tt.in)); got != tt.want {
+ t.Errorf("is408Message(%q) = %v; want %v", tt.in, got, tt.want)
+ }
+ }
+}
+
+func TestTransportIgnores408(t *testing.T) {
+ // Not parallel. Relies on mutating the log package's global Output.
+ defer log.SetOutput(log.Writer())
+
+ var logout bytes.Buffer
+ log.SetOutput(&logout)
+
+ defer afterTest(t)
+ const target = "backend:443"
+
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ nc, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer nc.Close()
+ nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"))
+ nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail
+ }))
+ defer cst.close()
+ req, err := NewRequest("GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "ok" {
+ t.Fatalf("got %q; want ok", slurp)
+ }
+
+ t0 := time.Now()
+ for i := 0; i < 50; i++ {
+ time.Sleep(time.Duration(i) * 5 * time.Millisecond)
+ if cst.tr.IdleConnKeyCountForTesting() == 0 {
+ if got := logout.String(); got != "" {
+ t.Fatalf("expected no log output; got: %s", got)
+ }
+ return
+ }
+ }
+ t.Fatalf("timeout after %v waiting for Transport connections to die off", time.Since(t0))
+}