diff options
author | Ian Lance Taylor <ian@gcc.gnu.org> | 2012-10-03 05:27:36 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2012-10-03 05:27:36 +0000 |
commit | bd2e46c8255fad4e75e589b3286ead560e910b39 (patch) | |
tree | 4f194bdb2e9edcc69ef2ab0dfb4aab15ca259267 /libgo/go/net | |
parent | bed6238ce677ba18a672a58bc077cec6de47f8d3 (diff) | |
download | gcc-bd2e46c8255fad4e75e589b3286ead560e910b39.zip gcc-bd2e46c8255fad4e75e589b3286ead560e910b39.tar.gz gcc-bd2e46c8255fad4e75e589b3286ead560e910b39.tar.bz2 |
libgo: Update to Go 1.0.3.
From-SVN: r192025
Diffstat (limited to 'libgo/go/net')
26 files changed, 926 insertions, 205 deletions
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index 10ca5fa..5191239 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -173,7 +173,7 @@ func (a stringAddr) String() string { return a.addr } // Listen announces on the local network address laddr. // The network string net must be a stream-oriented network: -// "tcp", "tcp4", "tcp6", or "unix", or "unixpacket". +// "tcp", "tcp4", "tcp6", "unix" or "unixpacket". func Listen(net, laddr string) (Listener, error) { afnet, a, err := resolveNetAddr("listen", net, laddr) if err != nil { diff --git a/libgo/go/net/fd.go b/libgo/go/net/fd.go index 76c953b..ff4f4f8 100644 --- a/libgo/go/net/fd.go +++ b/libgo/go/net/fd.go @@ -645,10 +645,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e } func (fd *netFD) dup() (f *os.File, err error) { + syscall.ForkLock.RLock() ns, err := syscall.Dup(fd.sysfd) if err != nil { + syscall.ForkLock.RUnlock() return nil, &OpError{"dup", fd.net, fd.laddr, err} } + syscall.CloseOnExec(ns) + syscall.ForkLock.RUnlock() // We want blocking mode for the new fd, hence the double negative. if err = syscall.SetNonblock(ns, false); err != nil { diff --git a/libgo/go/net/file.go b/libgo/go/net/file.go index fc6c6fa..837326e 100644 --- a/libgo/go/net/file.go +++ b/libgo/go/net/file.go @@ -12,13 +12,18 @@ import ( ) func newFileFD(f *os.File) (*netFD, error) { + syscall.ForkLock.RLock() fd, err := syscall.Dup(int(f.Fd())) if err != nil { + syscall.ForkLock.RUnlock() return nil, os.NewSyscallError("dup", err) } + syscall.CloseOnExec(fd) + syscall.ForkLock.RUnlock() - proto, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) + sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE) if err != nil { + closesocket(fd) return nil, os.NewSyscallError("getsockopt", err) } @@ -31,24 +36,24 @@ func newFileFD(f *os.File) (*netFD, error) { return nil, syscall.EINVAL case *syscall.SockaddrInet4: family = syscall.AF_INET - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { toAddr = sockaddrToUDP - } else if proto == syscall.SOCK_RAW { + } else if sotype == syscall.SOCK_RAW { toAddr = sockaddrToIP } case *syscall.SockaddrInet6: family = syscall.AF_INET6 - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { toAddr = sockaddrToUDP - } else if proto == syscall.SOCK_RAW { + } else if sotype == syscall.SOCK_RAW { toAddr = sockaddrToIP } case *syscall.SockaddrUnix: family = syscall.AF_UNIX toAddr = sockaddrToUnix - if proto == syscall.SOCK_DGRAM { + if sotype == syscall.SOCK_DGRAM { toAddr = sockaddrToUnixgram - } else if proto == syscall.SOCK_SEQPACKET { + } else if sotype == syscall.SOCK_SEQPACKET { toAddr = sockaddrToUnixpacket } } @@ -56,8 +61,9 @@ func newFileFD(f *os.File) (*netFD, error) { sa, _ = syscall.Getpeername(fd) raddr := toAddr(sa) - netfd, err := newFD(fd, family, proto, laddr.Network()) + netfd, err := newFD(fd, family, sotype, laddr.Network()) if err != nil { + closesocket(fd) return nil, err } netfd.setAddr(laddr, raddr) diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 54564e0..8944142 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -14,6 +14,7 @@ import ( "errors" "fmt" "io" + "log" "net/url" "strings" ) @@ -35,7 +36,8 @@ type Client struct { // following an HTTP redirect. The arguments req and via // are the upcoming request and the requests made already, // oldest first. If CheckRedirect returns an error, the client - // returns that error instead of issue the Request req. + // returns that error (wrapped in a url.Error) instead of + // issuing the Request req. // // If CheckRedirect is nil, the Client uses its default policy, // which is to stop after 10 consecutive requests. @@ -87,9 +89,13 @@ type readClose struct { // Do sends an HTTP request and returns an HTTP response, following // policy (e.g. redirects, cookies, auth) as configured on the client. // -// A non-nil response always contains a non-nil resp.Body. +// An error is returned if caused by client policy (such as +// CheckRedirect), or if there was an HTTP protocol error. +// A non-2xx response doesn't cause an error. // -// Callers should close resp.Body when done reading from it. If +// When err is nil, resp always contains a non-nil resp.Body. +// +// Callers should close res.Body when done reading from it. If // resp.Body is not closed, the Client's underlying RoundTripper // (typically Transport) may not be able to re-use a persistent TCP // connection to the server for a subsequent "keep-alive" request. @@ -102,7 +108,8 @@ func (c *Client) Do(req *Request) (resp *Response, err error) { return send(req, c.Transport) } -// send issues an HTTP request. Caller should close resp.Body when done reading from it. +// send issues an HTTP request. +// Caller should close resp.Body when done reading from it. func send(req *Request, t RoundTripper) (resp *Response, err error) { if t == nil { t = DefaultTransport @@ -130,7 +137,14 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) { if u := req.URL.User; u != nil { req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(u.String()))) } - return t.RoundTrip(req) + resp, err = t.RoundTrip(req) + if err != nil { + if resp != nil { + log.Printf("RoundTripper returned a response & error; ignoring response") + } + return nil, err + } + return resp, nil } // True if the specified HTTP status code is one for which the Get utility should @@ -151,10 +165,15 @@ func shouldRedirect(statusCode int) bool { // 303 (See Other) // 307 (Temporary Redirect) // -// Caller should close r.Body when done reading from it. +// An error is returned if there were too many redirects or if there +// was an HTTP protocol error. A non-2xx response doesn't cause an +// error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. // // Get is a wrapper around DefaultClient.Get. -func Get(url string) (r *Response, err error) { +func Get(url string) (resp *Response, err error) { return DefaultClient.Get(url) } @@ -167,8 +186,13 @@ func Get(url string) (r *Response, err error) { // 303 (See Other) // 307 (Temporary Redirect) // -// Caller should close r.Body when done reading from it. -func (c *Client) Get(url string) (r *Response, err error) { +// An error is returned if the Client's CheckRedirect function fails +// or if there was an HTTP protocol error. A non-2xx response doesn't +// cause an error. +// +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) Get(url string) (resp *Response, err error) { req, err := NewRequest("GET", url, nil) if err != nil { return nil, err @@ -176,7 +200,7 @@ func (c *Client) Get(url string) (r *Response, err error) { return c.doFollowingRedirects(req) } -func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { +func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. var base *url.URL @@ -224,17 +248,17 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { req.AddCookie(cookie) } urlStr = req.URL.String() - if r, err = send(req, c.Transport); err != nil { + if resp, err = send(req, c.Transport); err != nil { break } - if c := r.Cookies(); len(c) > 0 { + if c := resp.Cookies(); len(c) > 0 { jar.SetCookies(req.URL, c) } - if shouldRedirect(r.StatusCode) { - r.Body.Close() - if urlStr = r.Header.Get("Location"); urlStr == "" { - err = errors.New(fmt.Sprintf("%d response missing Location header", r.StatusCode)) + if shouldRedirect(resp.StatusCode) { + resp.Body.Close() + if urlStr = resp.Header.Get("Location"); urlStr == "" { + err = errors.New(fmt.Sprintf("%d response missing Location header", resp.StatusCode)) break } base = req.URL @@ -244,13 +268,16 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err error) { return } + if resp != nil { + resp.Body.Close() + } + method := ireq.Method - err = &url.Error{ + return nil, &url.Error{ Op: method[0:1] + strings.ToLower(method[1:]), URL: urlStr, Err: err, } - return } func defaultCheckRedirect(req *Request, via []*Request) error { @@ -262,17 +289,17 @@ func defaultCheckRedirect(req *Request, via []*Request) error { // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading from it. +// Caller should close resp.Body when done reading from it. // // Post is a wrapper around DefaultClient.Post -func Post(url string, bodyType string, body io.Reader) (r *Response, err error) { +func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { return DefaultClient.Post(url, bodyType, body) } // Post issues a POST to the specified URL. // -// Caller should close r.Body when done reading from it. -func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, err error) { +// Caller should close resp.Body when done reading from it. +func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { return nil, err @@ -283,28 +310,30 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (r *Response, req.AddCookie(cookie) } } - r, err = send(req, c.Transport) + resp, err = send(req, c.Transport) if err == nil && c.Jar != nil { - c.Jar.SetCookies(req.URL, r.Cookies()) + c.Jar.SetCookies(req.URL, resp.Cookies()) } - return r, err + return } -// PostForm issues a POST to the specified URL, -// with data's keys and values urlencoded as the request body. +// PostForm issues a POST to the specified URL, with data's keys and +// values URL-encoded as the request body. // -// Caller should close r.Body when done reading from it. +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm -func PostForm(url string, data url.Values) (r *Response, err error) { +func PostForm(url string, data url.Values) (resp *Response, err error) { return DefaultClient.PostForm(url, data) } // PostForm issues a POST to the specified URL, // with data's keys and values urlencoded as the request body. // -// Caller should close r.Body when done reading from it. -func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) { +// When err is nil, resp always contains a non-nil resp.Body. +// Caller should close resp.Body when done reading from it. +func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } @@ -318,7 +347,7 @@ func (c *Client) PostForm(url string, data url.Values) (r *Response, err error) // 307 (Temporary Redirect) // // Head is a wrapper around DefaultClient.Head -func Head(url string) (r *Response, err error) { +func Head(url string) (resp *Response, err error) { return DefaultClient.Head(url) } @@ -330,7 +359,7 @@ func Head(url string) (r *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) -func (c *Client) Head(url string) (r *Response, err error) { +func (c *Client) Head(url string) (resp *Response, err error) { req, err := NewRequest("HEAD", url, nil) if err != nil { return nil, err diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 9b4261b..09fcc1c 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -8,6 +8,7 @@ package http_test import ( "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -231,9 +232,8 @@ func TestRedirects(t *testing.T) { checkErr = errors.New("no redirects allowed") res, err = c.Get(ts.URL) - finalUrl = res.Request.URL.String() - if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { - t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr { + t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err) } } @@ -465,3 +465,49 @@ func TestClientErrorWithRequestURI(t *testing.T) { t.Errorf("wanted error mentioning RequestURI; got error: %v", err) } } + +func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { + certs := x509.NewCertPool() + for _, c := range ts.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return &Transport{ + TLSClientConfig: &tls.Config{RootCAs: certs}, + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != "127.0.0.1" { + t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + } + })) + defer ts.Close() + + c := &Client{Transport: newTLSTransport(t, ts)} + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = "badserver" + c := &Client{Transport: trans} + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } +} diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go index ec81440..22073ea 100644 --- a/libgo/go/net/http/example_test.go +++ b/libgo/go/net/http/example_test.go @@ -43,10 +43,10 @@ func ExampleGet() { log.Fatal(err) } robots, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { log.Fatal(err) } - res.Body.Close() fmt.Printf("%s", robots) } diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 13640ca8..313c6af 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -11,8 +11,8 @@ import "time" func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) - t.lk.Lock() - defer t.lk.Unlock() + t.idleLk.Lock() + defer t.idleLk.Unlock() if t.idleConn == nil { return } @@ -23,8 +23,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { } func (t *Transport) IdleConnCountForTesting(cacheKey string) int { - t.lk.Lock() - defer t.lk.Unlock() + t.idleLk.Lock() + defer t.idleLk.Unlock() if t.idleConn == nil { return 0 } diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index f35dd32..208d6ca 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "mime" + "mime/multipart" + "net/textproto" "os" "path" "path/filepath" @@ -26,7 +28,8 @@ import ( type Dir string func (d Dir) Open(name string) (File, error) { - if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 { + if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 || + strings.Contains(name, "\x00") { return nil, errors.New("http: invalid character in file path") } dir := string(d) @@ -123,8 +126,9 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, code := StatusOK // If Content-Type isn't set, use the file's extension to find it. - if w.Header().Get("Content-Type") == "" { - ctype := mime.TypeByExtension(filepath.Ext(name)) + ctype := w.Header().Get("Content-Type") + if ctype == "" { + ctype = mime.TypeByExtension(filepath.Ext(name)) if ctype == "" { // read a chunk to decide between utf-8 text and binary var buf [1024]byte @@ -141,18 +145,34 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } // handle Content-Range header. - // TODO(adg): handle multiple ranges sendSize := size + var sendContent io.Reader = content if size >= 0 { ranges, err := parseRange(r.Header.Get("Range"), size) - if err == nil && len(ranges) > 1 { - err = errors.New("multiple ranges not supported") - } if err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } - if len(ranges) == 1 { + if sumRangesSize(ranges) >= size { + // The total number of bytes in all the ranges + // is larger than the size of the file by + // itself, so this is probably an attack, or a + // dumb client. Ignore the range request. + ranges = nil + } + switch { + case len(ranges) == 1: + // RFC 2616, Section 14.16: + // "When an HTTP message includes the content of a single + // range (for example, a response to a request for a + // single range, or to a request for a set of ranges + // that overlap without any holes), this content is + // transmitted with a Content-Range header, and a + // Content-Length header showing the number of bytes + // actually transferred. + // ... + // A response to a request for a single range MUST NOT + // be sent using the multipart/byteranges media type." ra := ranges[0] if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) @@ -160,7 +180,41 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } sendSize = ra.length code = StatusPartialContent - w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", ra.start, ra.start+ra.length-1, size)) + w.Header().Set("Content-Range", ra.contentRange(size)) + case len(ranges) > 1: + for _, ra := range ranges { + if ra.start > size { + Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) + return + } + } + sendSize = rangesMIMESize(ranges, ctype, size) + code = StatusPartialContent + + pr, pw := io.Pipe() + mw := multipart.NewWriter(pw) + w.Header().Set("Content-Type", "multipart/byteranges; boundary="+mw.Boundary()) + sendContent = pr + defer pr.Close() // cause writing goroutine to fail and exit if CopyN doesn't finish. + go func() { + for _, ra := range ranges { + part, err := mw.CreatePart(ra.mimeHeader(ctype, size)) + if err != nil { + pw.CloseWithError(err) + return + } + if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + pw.CloseWithError(err) + return + } + if _, err := io.CopyN(part, content, ra.length); err != nil { + pw.CloseWithError(err) + return + } + } + mw.Close() + pw.Close() + }() } w.Header().Set("Accept-Ranges", "bytes") @@ -172,11 +226,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, w.WriteHeader(code) if r.Method != "HEAD" { - if sendSize == -1 { - io.Copy(w, content) - } else { - io.CopyN(w, content, sendSize) - } + io.CopyN(w, sendContent, sendSize) } } @@ -243,9 +293,6 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec // use contents of index.html for directory, if present if d.IsDir() { - if checkLastModified(w, r, d.ModTime()) { - return - } index := name + indexPage ff, err := fs.Open(index) if err == nil { @@ -259,11 +306,16 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec } } + // Still a directory? (we didn't find an index.html file) if d.IsDir() { + if checkLastModified(w, r, d.ModTime()) { + return + } dirList(w, f) return } + // serverContent will check modification time serveContent(w, r, d.Name(), d.ModTime(), d.Size(), f) } @@ -312,6 +364,17 @@ type httpRange struct { start, length int64 } +func (r httpRange) contentRange(size int64) string { + return fmt.Sprintf("bytes %d-%d/%d", r.start, r.start+r.length-1, size) +} + +func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHeader { + return textproto.MIMEHeader{ + "Content-Range": {r.contentRange(size)}, + "Content-Type": {contentType}, + } +} + // parseRange parses a Range header string as per RFC 2616. func parseRange(s string, size int64) ([]httpRange, error) { if s == "" { @@ -323,11 +386,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { } var ranges []httpRange for _, ra := range strings.Split(s[len(b):], ",") { + ra = strings.TrimSpace(ra) + if ra == "" { + continue + } i := strings.Index(ra, "-") if i < 0 { return nil, errors.New("invalid range") } - start, end := ra[:i], ra[i+1:] + start, end := strings.TrimSpace(ra[:i]), strings.TrimSpace(ra[i+1:]) var r httpRange if start == "" { // If no start is specified, end specifies the @@ -365,3 +432,32 @@ func parseRange(s string, size int64) ([]httpRange, error) { } return ranges, nil } + +// countingWriter counts how many bytes have been written to it. +type countingWriter int64 + +func (w *countingWriter) Write(p []byte) (n int, err error) { + *w += countingWriter(len(p)) + return len(p), nil +} + +// rangesMIMESize returns the nunber of bytes it takes to encode the +// provided ranges as a multipart response. +func rangesMIMESize(ranges []httpRange, contentType string, contentSize int64) (encSize int64) { + var w countingWriter + mw := multipart.NewWriter(&w) + for _, ra := range ranges { + mw.CreatePart(ra.mimeHeader(contentType, contentSize)) + encSize += ra.length + } + mw.Close() + encSize += int64(w) + return +} + +func sumRangesSize(ranges []httpRange) (size int64) { + for _, ra := range ranges { + size += ra.length + } + return +} diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index ffba6a7..17329fb 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -10,12 +10,15 @@ import ( "fmt" "io" "io/ioutil" + "mime" + "mime/multipart" "net" . "net/http" "net/http/httptest" "net/url" "os" "os/exec" + "path" "path/filepath" "regexp" "runtime" @@ -25,21 +28,29 @@ import ( ) const ( - testFile = "testdata/file" - testFileLength = 11 + testFile = "testdata/file" + testFileLen = 11 ) +type wantRange struct { + start, end int64 // range [start,end) +} + var ServeFileRangeTests = []struct { - start, end int - r string - code int + r string + code int + ranges []wantRange }{ - {0, testFileLength, "", StatusOK}, - {0, 5, "0-4", StatusPartialContent}, - {2, testFileLength, "2-", StatusPartialContent}, - {testFileLength - 5, testFileLength, "-5", StatusPartialContent}, - {3, 8, "3-7", StatusPartialContent}, - {0, 0, "20-", StatusRequestedRangeNotSatisfiable}, + {r: "", code: StatusOK}, + {r: "bytes=0-4", code: StatusPartialContent, ranges: []wantRange{{0, 5}}}, + {r: "bytes=2-", code: StatusPartialContent, ranges: []wantRange{{2, testFileLen}}}, + {r: "bytes=-5", code: StatusPartialContent, ranges: []wantRange{{testFileLen - 5, testFileLen}}}, + {r: "bytes=3-7", code: StatusPartialContent, ranges: []wantRange{{3, 8}}}, + {r: "bytes=20-", code: StatusRequestedRangeNotSatisfiable}, + {r: "bytes=0-0,-2", code: StatusPartialContent, ranges: []wantRange{{0, 1}, {testFileLen - 2, testFileLen}}}, + {r: "bytes=0-1,5-8", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, 9}}}, + {r: "bytes=0-1,5-", code: StatusPartialContent, ranges: []wantRange{{0, 2}, {5, testFileLen}}}, + {r: "bytes=0-,1-,2-,3-,4-", code: StatusOK}, // ignore wasteful range request } func TestServeFile(t *testing.T) { @@ -65,33 +76,81 @@ func TestServeFile(t *testing.T) { // straight GET _, body := getBody(t, "straight get", req) - if !equal(body, file) { + if !bytes.Equal(body, file) { t.Fatalf("body mismatch: got %q, want %q", body, file) } // Range tests - for i, rt := range ServeFileRangeTests { - req.Header.Set("Range", "bytes="+rt.r) - if rt.r == "" { - req.Header["Range"] = nil + for _, rt := range ServeFileRangeTests { + if rt.r != "" { + req.Header.Set("Range", rt.r) } - r, body := getBody(t, fmt.Sprintf("test %d", i), req) - if r.StatusCode != rt.code { - t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, r.StatusCode, rt.code) + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) + if resp.StatusCode != rt.code { + t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) } if rt.code == StatusRequestedRangeNotSatisfiable { continue } - h := fmt.Sprintf("bytes %d-%d/%d", rt.start, rt.end-1, testFileLength) - if rt.r == "" { - h = "" + wantContentRange := "" + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + } + cr := resp.Header.Get("Content-Range") + if cr != wantContentRange { + t.Errorf("range=%q: Content-Range = %q, want %q", rt.r, cr, wantContentRange) } - cr := r.Header.Get("Content-Range") - if cr != h { - t.Errorf("header mismatch: range=%q: got %q, want %q", rt.r, cr, h) + ct := resp.Header.Get("Content-Type") + if len(rt.ranges) == 1 { + rng := rt.ranges[0] + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + if strings.HasPrefix(ct, "multipart/byteranges") { + t.Errorf("range=%q content-type = %q; unexpected multipart/byteranges", rt.r) + } } - if !equal(body, file[rt.start:rt.end]) { - t.Errorf("body mismatch: range=%q: got %q, want %q", rt.r, body, file[rt.start:rt.end]) + if len(rt.ranges) > 1 { + typ, params, err := mime.ParseMediaType(ct) + if err != nil { + t.Errorf("range=%q content-type = %q; %v", rt.r, ct, err) + continue + } + if typ != "multipart/byteranges" { + t.Errorf("range=%q content-type = %q; want multipart/byteranges", rt.r) + continue + } + if params["boundary"] == "" { + t.Errorf("range=%q content-type = %q; lacks boundary", rt.r, ct) + } + if g, w := resp.ContentLength, int64(len(body)); g != w { + t.Errorf("range=%q Content-Length = %d; want %d", rt.r, g, w) + } + mr := multipart.NewReader(bytes.NewReader(body), params["boundary"]) + for ri, rng := range rt.ranges { + part, err := mr.NextPart() + if err != nil { + t.Fatalf("range=%q, reading part index %d: %v", rt.r, ri, err) + } + body, err := ioutil.ReadAll(part) + if err != nil { + t.Fatalf("range=%q, reading part index %d body: %v", rt.r, ri, err) + } + wantContentRange = fmt.Sprintf("bytes %d-%d/%d", rng.start, rng.end-1, testFileLen) + wantBody := file[rng.start:rng.end] + if !bytes.Equal(body, wantBody) { + t.Errorf("range=%q: body = %q, want %q", rt.r, body, wantBody) + } + if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { + t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) + } + } + _, err = mr.NextPart() + if err != io.EOF { + t.Errorf("range=%q; expected final error io.EOF; got %v", err) + } } } } @@ -276,6 +335,11 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { + if runtime.GOOS == "windows" { + // TODO(brainman): find out why this test is broken + t.Logf("Temporarily skipping test on Windows; see http://golang.org/issue/3917") + return + } ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -325,6 +389,139 @@ func TestServeIndexHtml(t *testing.T) { } } +func TestFileServerZeroByte(t *testing.T) { + ts := httptest.NewServer(FileServer(Dir("."))) + defer ts.Close() + + res, err := Get(ts.URL + "/..\x00") + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if res.StatusCode == 200 { + t.Errorf("got status 200; want an error. Body is:\n%s", string(b)) + } +} + +type fakeFileInfo struct { + dir bool + basename string + modtime time.Time + ents []*fakeFileInfo + contents string +} + +func (f *fakeFileInfo) Name() string { return f.basename } +func (f *fakeFileInfo) Sys() interface{} { return nil } +func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } +func (f *fakeFileInfo) IsDir() bool { return f.dir } +func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } +func (f *fakeFileInfo) Mode() os.FileMode { + if f.dir { + return 0755 | os.ModeDir + } + return 0644 +} + +type fakeFile struct { + io.ReadSeeker + fi *fakeFileInfo + path string // as opened +} + +func (f *fakeFile) Close() error { return nil } +func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { + if !f.fi.dir { + return nil, os.ErrInvalid + } + var fis []os.FileInfo + for _, fi := range f.fi.ents { + fis = append(fis, fi) + } + return fis, nil +} + +type fakeFS map[string]*fakeFileInfo + +func (fs fakeFS) Open(name string) (File, error) { + name = path.Clean(name) + f, ok := fs[name] + if !ok { + println("fake filesystem didn't find file", name) + return nil, os.ErrNotExist + } + return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil +} + +func TestDirectoryIfNotModified(t *testing.T) { + const indexContents = "I am a fake index.html file" + fileMod := time.Unix(1000000000, 0).UTC() + fileModStr := fileMod.Format(TimeFormat) + dirMod := time.Unix(123, 0).UTC() + indexFile := &fakeFileInfo{ + basename: "index.html", + modtime: fileMod, + contents: indexContents, + } + fs := fakeFS{ + "/": &fakeFileInfo{ + dir: true, + modtime: dirMod, + ents: []*fakeFileInfo{indexFile}, + }, + "/index.html": indexFile, + } + + ts := httptest.NewServer(FileServer(fs)) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(b) != indexContents { + t.Fatalf("Got body %q; want %q", b, indexContents) + } + res.Body.Close() + + lastMod := res.Header.Get("Last-Modified") + if lastMod != fileModStr { + t.Fatalf("initial Last-Modified = %q; want %q", lastMod, fileModStr) + } + + req, _ := NewRequest("GET", ts.URL, nil) + req.Header.Set("If-Modified-Since", lastMod) + + res, err = DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 304 { + t.Fatalf("Code after If-Modified-Since request = %v; want 304", res.StatusCode) + } + res.Body.Close() + + // Advance the index.html file's modtime, but not the directory's. + indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) + + res, err = DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != 200 { + t.Fatalf("Code after second If-Modified-Since request = %v; want 200; res is %#v", res.StatusCode, res) + } + res.Body.Close() +} + func TestServeContent(t *testing.T) { type req struct { name string @@ -464,15 +661,3 @@ func TestLinuxSendfileChild(*testing.T) { panic(err) } } - -func equal(a, b []byte) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index b107c31..6be94f9 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -76,3 +76,43 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { // the rest are converted to lowercase. For example, the // canonical key for "accept-encoding" is "Accept-Encoding". func CanonicalHeaderKey(s string) string { return textproto.CanonicalMIMEHeaderKey(s) } + +// hasToken returns whether token appears with v, ASCII +// case-insensitive, with space or comma boundaries. +// token must be all lowercase. +// v may contain mixed cased. +func hasToken(v, token string) bool { + if len(token) > len(v) || token == "" { + return false + } + if v == token { + return true + } + for sp := 0; sp <= len(v)-len(token); sp++ { + // Check that first character is good. + // The token is ASCII, so checking only a single byte + // is sufficient. We skip this potential starting + // position if both the first byte and its potential + // ASCII uppercase equivalent (b|0x20) don't match. + // False positives ('^' => '~') are caught by EqualFold. + if b := v[sp]; b != token[0] && b|0x20 != token[0] { + continue + } + // Check that start pos is on a valid token boundary. + if sp > 0 && !isTokenBoundary(v[sp-1]) { + continue + } + // Check that end pos is on a valid token boundary. + if endPos := sp + len(token); endPos != len(v) && !isTokenBoundary(v[endPos]) { + continue + } + if strings.EqualFold(v[sp:sp+len(token)], token) { + return true + } + } + return false +} + +func isTokenBoundary(b byte) bool { + return b == ' ' || b == ',' || b == '\t' +} diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index 57cf0c9..165600e 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX +MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL -8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw -DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG -CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7 -+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM= ------END CERTIFICATE----- -`) +8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB +/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC +CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB +nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw +Pg== +-----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index 892ef4e..0fb2eeb 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -89,7 +89,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { t := &http.Transport{ Dial: func(net, addr string) (net.Conn, error) { - return &dumpConn{io.MultiWriter(pw, &buf), dr}, nil + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil }, } diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 06fcde1..7a9f465 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -14,6 +14,14 @@ // To use pprof, link this package into your program: // import _ "net/http/pprof" // +// If your application is not already running an http server, you +// need to start one. Add "net/http" and "log" to your imports and +// the following code to your main function: +// +// go func() { +// log.Println(http.ListenAndServe("localhost:6060", nil)) +// }() +// // Then use the pprof tool to look at the heap profile: // // go tool pprof http://localhost:6060/debug/pprof/heap diff --git a/libgo/go/net/http/range_test.go b/libgo/go/net/http/range_test.go index 5274a81..ef911af 100644 --- a/libgo/go/net/http/range_test.go +++ b/libgo/go/net/http/range_test.go @@ -14,15 +14,34 @@ var ParseRangeTests = []struct { r []httpRange }{ {"", 0, nil}, + {"", 1000, nil}, {"foo", 0, nil}, {"bytes=", 0, nil}, + {"bytes=7", 10, nil}, + {"bytes= 7 ", 10, nil}, + {"bytes=1-", 0, nil}, {"bytes=5-4", 10, nil}, {"bytes=0-2,5-4", 10, nil}, + {"bytes=2-5,4-3", 10, nil}, + {"bytes=--5,4--3", 10, nil}, + {"bytes=A-", 10, nil}, + {"bytes=A- ", 10, nil}, + {"bytes=A-Z", 10, nil}, + {"bytes= -Z", 10, nil}, + {"bytes=5-Z", 10, nil}, + {"bytes=Ran-dom, garbage", 10, nil}, + {"bytes=0x01-0x02", 10, nil}, + {"bytes= ", 10, nil}, + {"bytes= , , , ", 10, nil}, + {"bytes=0-9", 10, []httpRange{{0, 10}}}, {"bytes=0-", 10, []httpRange{{0, 10}}}, {"bytes=5-", 10, []httpRange{{5, 5}}}, {"bytes=0-20", 10, []httpRange{{0, 10}}}, {"bytes=15-,0-5", 10, nil}, + {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}}, + {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}}, + {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}}, {"bytes=-5", 10, []httpRange{{5, 5}}}, {"bytes=-15", 10, []httpRange{{0, 10}}}, {"bytes=0-499", 10000, []httpRange{{0, 500}}}, @@ -32,6 +51,9 @@ var ParseRangeTests = []struct { {"bytes=0-0,-1", 10000, []httpRange{{0, 1}, {9999, 1}}}, {"bytes=500-600,601-999", 10000, []httpRange{{500, 101}, {601, 399}}}, {"bytes=500-700,601-999", 10000, []httpRange{{500, 201}, {601, 399}}}, + + // Match Apache laxity: + {"bytes= 1 -2 , 4- 5, 7 - 8 , ,,", 11, []httpRange{{1, 2}, {4, 2}, {7, 2}}}, } func TestParseRange(t *testing.T) { diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index b6a6b4c..c9d7393 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -386,17 +386,18 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { } r := bufio.NewReader(conn) - _, err = ReadResponse(r, &Request{Method: "GET"}) + res, err := ReadResponse(r, &Request{Method: "GET"}) if err != nil { t.Fatal("ReadResponse error:", err) } - success := make(chan bool) + didReadAll := make(chan bool, 1) go func() { select { case <-time.After(5 * time.Second): - t.Fatal("body not closed after 5s") - case <-success: + t.Error("body not closed after 5s") + return + case <-didReadAll: } }() @@ -404,8 +405,11 @@ func testTcpConnectionCloses(t *testing.T, req string, h Handler) { if err != nil { t.Fatal("read error:", err) } + didReadAll <- true - success <- true + if !res.Close { + t.Errorf("Response.Close = false; want true") + } } // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. @@ -1108,6 +1112,38 @@ func TestServerBufferedChunking(t *testing.T) { } } +// TestContentLengthZero tests that for both an HTTP/1.0 and HTTP/1.1 +// request (both keep-alive), when a Handler never writes any +// response, the net/http package adds a "Content-Length: 0" response +// header. +func TestContentLengthZero(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) + defer ts.Close() + + for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version) + if err != nil { + t.Fatalf("error writing: %v", err) + } + req, _ := NewRequest("GET", "/", nil) + res, err := ReadResponse(bufio.NewReader(conn), req) + if err != nil { + t.Fatalf("error reading response: %v", err) + } + if te := res.TransferEncoding; len(te) > 0 { + t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te) + } + if cl := res.ContentLength; cl != 0 { + t.Errorf("For version %q, Content-Length = %v; want 0", version, cl) + } + conn.Close() + } +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 0572b4a..b74b762 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -390,6 +390,11 @@ func (w *response) WriteHeader(code int) { if !w.req.ProtoAtLeast(1, 0) { return } + + if w.closeAfterReply && !hasToken(w.header.Get("Connection"), "close") { + w.header.Set("Connection", "close") + } + proto := "HTTP/1.0" if w.req.ProtoAtLeast(1, 1) { proto = "HTTP/1.1" @@ -508,8 +513,16 @@ func (w *response) Write(data []byte) (n int, err error) { } func (w *response) finishRequest() { - // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length - // back, we can make this a keep-alive response ... + // If the handler never wrote any bytes and never sent a Content-Length + // response header, set the length explicitly to zero. This helps + // HTTP/1.0 clients keep their "keep-alive" connections alive, and for + // HTTP/1.1 clients is just as good as the alternative: sending a + // chunked response and immediately sending the zero-length EOF chunk. + if w.written == 0 && w.header.Get("Content-Length") == "" { + w.header.Set("Content-Length", "0") + } + // If this was an HTTP/1.0 request with keep-alive and we sent a + // Content-Length back, we can make this a keep-alive response ... if w.req.wantsHttp10KeepAlive() { sentLength := w.header.Get("Content-Length") != "" if sentLength && w.header.Get("Connection") == "keep-alive" { @@ -817,13 +830,13 @@ func RedirectHandler(url string, code int) Handler { // patterns and calls the handler for the pattern that // most closely matches the URL. // -// Patterns named fixed, rooted paths, like "/favicon.ico", +// Patterns name fixed, rooted paths, like "/favicon.ico", // or rooted subtrees, like "/images/" (note the trailing slash). // Longer patterns take precedence over shorter ones, so that // if there are handlers registered for both "/images/" // and "/images/thumbnails/", the latter handler will be // called for paths beginning "/images/thumbnails/" and the -// former will receiver requests for any other paths in the +// former will receive requests for any other paths in the // "/images/" subtree. // // Patterns may optionally begin with a host name, restricting matches to @@ -917,11 +930,13 @@ func (mux *ServeMux) handler(r *Request) Handler { // ServeHTTP dispatches the request to the handler whose // pattern most closely matches the request URL. func (mux *ServeMux) ServeHTTP(w ResponseWriter, r *Request) { - // Clean path to canonical form and redirect. - if p := cleanPath(r.URL.Path); p != r.URL.Path { - w.Header().Set("Location", p) - w.WriteHeader(StatusMovedPermanently) - return + if r.Method != "CONNECT" { + // Clean path to canonical form and redirect. + if p := cleanPath(r.URL.Path); p != r.URL.Path { + w.Header().Set("Location", p) + w.WriteHeader(StatusMovedPermanently) + return + } } mux.handler(r).ServeHTTP(w, r) } diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 6efe191..6131d0d 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -41,8 +41,9 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - lk sync.Mutex + idleLk sync.Mutex idleConn map[string][]*persistConn + altLk sync.RWMutex altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // TODO: tunable on global max cached connections @@ -131,12 +132,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.lk.Lock() + t.altLk.RLock() var rt RoundTripper if t.altProto != nil { rt = t.altProto[req.URL.Scheme] } - t.lk.Unlock() + t.altLk.RUnlock() if rt == nil { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } @@ -170,8 +171,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { if scheme == "http" || scheme == "https" { panic("protocol " + scheme + " already registered") } - t.lk.Lock() - defer t.lk.Unlock() + t.altLk.Lock() + defer t.altLk.Unlock() if t.altProto == nil { t.altProto = make(map[string]RoundTripper) } @@ -186,17 +187,18 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.lk.Lock() - defer t.lk.Unlock() - if t.idleConn == nil { + t.idleLk.Lock() + m := t.idleConn + t.idleConn = nil + t.idleLk.Unlock() + if m == nil { return } - for _, conns := range t.idleConn { + for _, conns := range m { for _, pconn := range conns { pconn.close() } } - t.idleConn = make(map[string][]*persistConn) } // @@ -242,8 +244,6 @@ func (cm *connectMethod) proxyAuth() string { // If pconn is no longer needed or not in a good state, putIdleConn // returns false. func (t *Transport) putIdleConn(pconn *persistConn) bool { - t.lk.Lock() - defer t.lk.Unlock() if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 { pconn.close() return false @@ -256,21 +256,27 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { if max == 0 { max = DefaultMaxIdleConnsPerHost } + t.idleLk.Lock() + if t.idleConn == nil { + t.idleConn = make(map[string][]*persistConn) + } if len(t.idleConn[key]) >= max { + t.idleLk.Unlock() pconn.close() return false } t.idleConn[key] = append(t.idleConn[key], pconn) + t.idleLk.Unlock() return true } func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { - t.lk.Lock() - defer t.lk.Unlock() + key := cm.String() + t.idleLk.Lock() + defer t.idleLk.Unlock() if t.idleConn == nil { - t.idleConn = make(map[string][]*persistConn) + return nil } - key := cm.String() for { pconns, ok := t.idleConn[key] if !ok { @@ -365,7 +371,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. - conn = tls.Client(conn, t.TLSClientConfig) + cfg := t.TLSClientConfig + if cfg == nil || cfg.ServerName == "" { + host, _, _ := net.SplitHostPort(cm.addr()) + if cfg == nil { + cfg = &tls.Config{ServerName: host} + } else { + clone := *cfg // shallow clone + clone.ServerName = host + cfg = &clone + } + } + conn = tls.Client(conn, cfg) if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err } @@ -484,6 +501,7 @@ type persistConn struct { t *Transport cacheKey string // its connectMethod.String() conn net.Conn + closed bool // whether conn has been closed br *bufio.Reader // from conn bw *bufio.Writer // to conn reqch chan requestAndChan // written by roundTrip(); read by readLoop() @@ -501,8 +519,9 @@ type persistConn struct { func (pc *persistConn) isBroken() bool { pc.lk.Lock() - defer pc.lk.Unlock() - return pc.broken + b := pc.broken + pc.lk.Unlock() + return b } var remoteSideClosedFunc func(error) bool // or nil to use default @@ -571,29 +590,32 @@ func (pc *persistConn) readLoop() { hasBody := resp != nil && resp.ContentLength != 0 var waitForBodyRead chan bool - if alive { - if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool) - resp.Body.(*bodyEOFSignal).fn = func() { - if !pc.t.putIdleConn(pc) { - alive = false - } - waitForBodyRead <- true - } - } else { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - - if !pc.t.putIdleConn(pc) { + if hasBody { + lastbody = resp.Body + waitForBodyRead = make(chan bool) + resp.Body.(*bodyEOFSignal).fn = func() { + if alive && !pc.t.putIdleConn(pc) { alive = false } + if !alive { + pc.close() + } + waitForBodyRead <- true + } + } + + if alive && !hasBody { + // When there's no response body, we immediately + // reuse the TCP connection (putIdleConn), but + // we need to prevent ClientConn.Read from + // closing the Response.Body on the next + // loop, otherwise it might close the body + // before the client code has had a chance to + // read it (even though it'll just be 0, EOF). + lastbody = nil + + if !pc.t.putIdleConn(pc) { + alive = false } } @@ -604,6 +626,10 @@ func (pc *persistConn) readLoop() { if waitForBodyRead != nil { <-waitForBodyRead } + + if !alive { + pc.close() + } } } @@ -669,7 +695,10 @@ func (pc *persistConn) close() { func (pc *persistConn) closeLocked() { pc.broken = true - pc.conn.Close() + if !pc.closed { + pc.conn.Close() + pc.closed = true + } pc.mutateHeaderFunc = nil } diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index a9e401d..e676bf6 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/ioutil" + "net" . "net/http" "net/http/httptest" "net/url" @@ -20,6 +21,7 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" "time" ) @@ -35,6 +37,68 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte(r.RemoteAddr)) }) +// testCloseConn is a net.Conn tracked by a testConnSet. +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (c *testCloseConn) Close() error { + c.set.remove(c) + return c.Conn.Close() +} + +// testConnSet tracks a set of TCP connections and whether they've +// been closed. +type testConnSet struct { + t *testing.T + closed map[net.Conn]bool + list []net.Conn // in order created + mutex sync.Mutex +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.closed[c] = false + tcs.list = append(tcs.list, c) +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.closed[c] = true +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + t: t, + closed: make(map[net.Conn]bool), + } + dial := func(n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) check(t *testing.T) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + + for i, c := range tcs.list { + if !tcs.closed[c] { + t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list)) + } + } +} + // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { @@ -72,8 +136,12 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial(t) + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -92,8 +160,8 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } - body, err := ioutil.ReadAll(res.Body) defer res.Body.Close() + body, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -107,15 +175,23 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() } + + connSet.check(t) } func TestTransportConnectionCloseOnRequest(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial(t) + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -149,7 +225,11 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() } + + connSet.check(t) } func TestTransportIdleCacheKeys(t *testing.T) { @@ -724,6 +804,35 @@ func TestTransportIdleConnCrash(t *testing.T) { <-didreq } +// Test that the transport doesn't close the TCP connection early, +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made +// the old race easier to trigger. +func TestIssue3644(t *testing.T) { + const numFoos = 5000 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Connection", "close") + for i := 0; i < numFoos; i++ { + w.Write([]byte("foo ")) + } + })) + defer ts.Close() + tr := &Transport{} + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + bs, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if len(bs) != numFoos*len("foo ") { + t.Errorf("unexpected response length") + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index b23213e..ae21b3c 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -6,7 +6,7 @@ package net -// IPAddr represents the address of a IP end point. +// IPAddr represents the address of an IP end point. type IPAddr struct { IP IP } @@ -21,7 +21,7 @@ func (a *IPAddr) String() string { return a.IP.String() } -// ResolveIPAddr parses addr as a IP address and resolves domain +// ResolveIPAddr parses addr as an IP address and resolves domain // names to numeric addresses on the network net, which must be // "ip", "ip4" or "ip6". A literal IPv6 host address must be // enclosed in square brackets, as in "[::]". diff --git a/libgo/go/net/iprawsock_plan9.go b/libgo/go/net/iprawsock_plan9.go index 43719fc..ea3321b 100644 --- a/libgo/go/net/iprawsock_plan9.go +++ b/libgo/go/net/iprawsock_plan9.go @@ -59,7 +59,7 @@ func (c *IPConn) RemoteAddr() Addr { // IP-specific methods. -// ReadFromIP reads a IP packet from c, copying the payload into b. +// ReadFromIP reads an IP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. // @@ -75,7 +75,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { return 0, nil, syscall.EPLAN9 } -// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// WriteToIP writes an IP packet to addr via c, copying the payload from b. // // WriteToIP can be made to time out and return // an error with Timeout() == true after a fixed time limit; diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go index 9fc7ecd..dda81dd 100644 --- a/libgo/go/net/iprawsock_posix.go +++ b/libgo/go/net/iprawsock_posix.go @@ -146,7 +146,7 @@ func (c *IPConn) SetWriteBuffer(bytes int) error { // IP-specific methods. -// ReadFromIP reads a IP packet from c, copying the payload into b. +// ReadFromIP reads an IP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. // @@ -184,7 +184,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { return n, uaddr.toAddr(), err } -// WriteToIP writes a IP packet to addr via c, copying the payload from b. +// WriteToIP writes an IP packet to addr via c, copying the payload from b. // // WriteToIP can be made to time out and return // an error with Timeout() == true after a fixed time limit; diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index b610ccf..93cc4d1 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -47,7 +47,8 @@ type Message struct { } // ReadMessage reads a message from r. -// The headers are parsed, and the body of the message will be reading from r. +// The headers are parsed, and the body of the message will be available +// for reading from r. func ReadMessage(r io.Reader) (msg *Message, err error) { tp := textproto.NewReader(bufio.NewReader(r)) diff --git a/libgo/go/net/net_posix.go b/libgo/go/net/net_posix.go new file mode 100644 index 0000000..3bcc54f --- /dev/null +++ b/libgo/go/net/net_posix.go @@ -0,0 +1,110 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd windows + +// Base posix socket functions. + +package net + +import ( + "os" + "syscall" + "time" +) + +type conn struct { + fd *netFD +} + +func (c *conn) ok() bool { return c != nil && c.fd != nil } + +// Implementation of the Conn interface - see Conn for documentation. + +// Read implements the Conn Read method. +func (c *conn) Read(b []byte) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + return c.fd.Read(b) +} + +// Write implements the Conn Write method. +func (c *conn) Write(b []byte) (int, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + return c.fd.Write(b) +} + +// LocalAddr returns the local network address. +func (c *conn) LocalAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.laddr +} + +// RemoteAddr returns the remote network address. +func (c *conn) RemoteAddr() Addr { + if !c.ok() { + return nil + } + return c.fd.raddr +} + +// SetDeadline implements the Conn SetDeadline method. +func (c *conn) SetDeadline(t time.Time) error { + if !c.ok() { + return syscall.EINVAL + } + return setDeadline(c.fd, t) +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (c *conn) SetReadDeadline(t time.Time) error { + if !c.ok() { + return syscall.EINVAL + } + return setReadDeadline(c.fd, t) +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (c *conn) SetWriteDeadline(t time.Time) error { + if !c.ok() { + return syscall.EINVAL + } + return setWriteDeadline(c.fd, t) +} + +// SetReadBuffer sets the size of the operating system's +// receive buffer associated with the connection. +func (c *conn) SetReadBuffer(bytes int) error { + if !c.ok() { + return syscall.EINVAL + } + return setReadBuffer(c.fd, bytes) +} + +// SetWriteBuffer sets the size of the operating system's +// transmit buffer associated with the connection. +func (c *conn) SetWriteBuffer(bytes int) error { + if !c.ok() { + return syscall.EINVAL + } + return setWriteBuffer(c.fd, bytes) +} + +// File returns a copy of the underlying os.File, set to blocking mode. +// It is the caller's responsibility to close f when finished. +// Closing c does not affect f, and closing f does not affect c. +func (c *conn) File() (f *os.File, err error) { return c.fd.dup() } + +// Close closes the connection. +func (c *conn) Close() error { + if !c.ok() { + return syscall.EINVAL + } + return c.fd.Close() +} diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go index e6c7441..adc29d5 100644 --- a/libgo/go/net/rpc/jsonrpc/all_test.go +++ b/libgo/go/net/rpc/jsonrpc/all_test.go @@ -108,7 +108,7 @@ func TestClient(t *testing.T) { t.Errorf("Add: expected no error but got string %q", err.Error()) } if reply.C != args.A+args.B { - t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) + t.Errorf("Add: got %d expected %d", reply.C, args.A+args.B) } args = &Args{7, 8} @@ -118,7 +118,7 @@ func TestClient(t *testing.T) { t.Errorf("Mul: expected no error but got string %q", err.Error()) } if reply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) + t.Errorf("Mul: got %d expected %d", reply.C, args.A*args.B) } // Out of order. @@ -133,7 +133,7 @@ func TestClient(t *testing.T) { t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) } if addReply.C != args.A+args.B { - t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) + t.Errorf("Add: got %d expected %d", addReply.C, args.A+args.B) } mulCall = <-mulCall.Done @@ -141,7 +141,7 @@ func TestClient(t *testing.T) { t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) } if mulReply.C != args.A*args.B { - t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) + t.Errorf("Mul: got %d expected %d", mulReply.C, args.A*args.B) } // Error test diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go index 1680e2f..e528220 100644 --- a/libgo/go/net/rpc/server.go +++ b/libgo/go/net/rpc/server.go @@ -24,12 +24,13 @@ where T, T1 and T2 can be marshaled by encoding/gob. These requirements apply even if a different codec is used. - (In future, these requirements may soften for custom codecs.) + (In the future, these requirements may soften for custom codecs.) The method's first argument represents the arguments provided by the caller; the second argument represents the result parameters to be returned to the caller. The method's return value, if non-nil, is passed back as a string that the client - sees as if created by errors.New. + sees as if created by errors.New. If an error is returned, the reply parameter + will not be sent back to the client. The server may handle requests on a single connection by calling ServeConn. More typically it will create a network listener and call Accept or, for an HTTP @@ -181,7 +182,7 @@ type Response struct { // Server represents an RPC Server. type Server struct { - mu sync.Mutex // protects the serviceMap + mu sync.RWMutex // protects the serviceMap serviceMap map[string]*service reqLock sync.Mutex // protects freeReq freeReq *Request @@ -538,9 +539,9 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt return } // Look up the request. - server.mu.Lock() + server.mu.RLock() service = server.serviceMap[serviceMethod[0]] - server.mu.Unlock() + server.mu.RUnlock() if service == nil { err = errors.New("rpc: can't find service " + req.ServiceMethod) return diff --git a/libgo/go/net/sockopt.go b/libgo/go/net/sockopt.go index 0cd1926..b139c42 100644 --- a/libgo/go/net/sockopt.go +++ b/libgo/go/net/sockopt.go @@ -144,22 +144,6 @@ func setDeadline(fd *netFD, t time.Time) error { return setWriteDeadline(fd, t) } -func setReuseAddr(fd *netFD, reuse bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, boolint(reuse))) -} - -func setDontRoute(fd *netFD, dontroute bool) error { - if err := fd.incref(false); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_DONTROUTE, boolint(dontroute))) -} - func setKeepAlive(fd *netFD, keepalive bool) error { if err := fd.incref(false); err != nil { return err |