diff options
author | Ian Lance Taylor <ian@gcc.gnu.org> | 2011-10-26 23:57:58 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2011-10-26 23:57:58 +0000 |
commit | d8f412571f8768df2d3239e72392dfeabbad1559 (patch) | |
tree | 19d182df05ead7ff8ba7ee00a7d57555e1383fdf /libgo/go/http | |
parent | e0c39d66d4f0607177b1cf8995dda56a667e07b3 (diff) | |
download | gcc-d8f412571f8768df2d3239e72392dfeabbad1559.zip gcc-d8f412571f8768df2d3239e72392dfeabbad1559.tar.gz gcc-d8f412571f8768df2d3239e72392dfeabbad1559.tar.bz2 |
Update Go library to last weekly.
From-SVN: r180552
Diffstat (limited to 'libgo/go/http')
34 files changed, 1261 insertions, 1943 deletions
diff --git a/libgo/go/http/cgi/child.go b/libgo/go/http/cgi/child.go index 8d0eca8..bf14c04 100644 --- a/libgo/go/http/cgi/child.go +++ b/libgo/go/http/cgi/child.go @@ -93,20 +93,20 @@ func RequestFromMap(params map[string]string) (*http.Request, os.Error) { if r.Host != "" { // Hostname is provided, so we can reasonably construct a URL, // even if we have to assume 'http' for the scheme. - r.RawURL = "http://" + r.Host + params["REQUEST_URI"] - url, err := url.Parse(r.RawURL) + rawurl := "http://" + r.Host + params["REQUEST_URI"] + url, err := url.Parse(rawurl) if err != nil { - return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + r.RawURL) + return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + rawurl) } r.URL = url } // Fallback logic if we don't have a Host header or the URL // failed to parse if r.URL == nil { - r.RawURL = params["REQUEST_URI"] - url, err := url.Parse(r.RawURL) + uriStr := params["REQUEST_URI"] + url, err := url.Parse(uriStr) if err != nil { - return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + r.RawURL) + return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + uriStr) } r.URL = url } diff --git a/libgo/go/http/cgi/child_test.go b/libgo/go/http/cgi/child_test.go index eee043b..ec53ab8 100644 --- a/libgo/go/http/cgi/child_test.go +++ b/libgo/go/http/cgi/child_test.go @@ -49,9 +49,6 @@ func TestRequest(t *testing.T) { if g, e := req.Header.Get("Foo-Bar"), "baz"; e != g { t.Errorf("expected Foo-Bar %q; got %q", e, g) } - if g, e := req.RawURL, "http://example.com/path?a=b"; e != g { - t.Errorf("expected RawURL %q; got %q", e, g) - } if g, e := req.URL.String(), "http://example.com/path?a=b"; e != g { t.Errorf("expected URL %q; got %q", e, g) } @@ -81,9 +78,6 @@ func TestRequestWithoutHost(t *testing.T) { if err != nil { t.Fatalf("RequestFromMap: %v", err) } - if g, e := req.RawURL, "/path?a=b"; e != g { - t.Errorf("expected RawURL %q; got %q", e, g) - } if req.URL == nil { t.Fatalf("unexpected nil URL") } diff --git a/libgo/go/http/cgi/host.go b/libgo/go/http/cgi/host.go index f7de89f..9ea4c9d 100644 --- a/libgo/go/http/cgi/host.go +++ b/libgo/go/http/cgi/host.go @@ -32,13 +32,14 @@ import ( var trailingPort = regexp.MustCompile(`:([0-9]+)$`) var osDefaultInheritEnv = map[string][]string{ - "darwin": []string{"DYLD_LIBRARY_PATH"}, - "freebsd": []string{"LD_LIBRARY_PATH"}, - "hpux": []string{"LD_LIBRARY_PATH", "SHLIB_PATH"}, - "irix": []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}, - "linux": []string{"LD_LIBRARY_PATH"}, - "solaris": []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}, - "windows": []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, + "darwin": {"DYLD_LIBRARY_PATH"}, + "freebsd": {"LD_LIBRARY_PATH"}, + "hpux": {"LD_LIBRARY_PATH", "SHLIB_PATH"}, + "irix": {"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"}, + "linux": {"LD_LIBRARY_PATH"}, + "openbsd": {"LD_LIBRARY_PATH"}, + "solaris": {"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"}, + "windows": {"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"}, } // Handler runs an executable in a subprocess with a CGI environment. @@ -68,6 +69,31 @@ type Handler struct { PathLocationHandler http.Handler } +// removeLeadingDuplicates remove leading duplicate in environments. +// It's possible to override environment like following. +// cgi.Handler{ +// ... +// Env: []string{"SCRIPT_FILENAME=foo.php"}, +// } +func removeLeadingDuplicates(env []string) (ret []string) { + n := len(env) + for i := 0; i < n; i++ { + e := env[i] + s := strings.SplitN(e, "=", 2)[0] + found := false + for j := i + 1; j < n; j++ { + if s == strings.SplitN(env[j], "=", 2)[0] { + found = true + break + } + } + if !found { + ret = append(ret, e) + } + } + return +} + func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { root := h.Root if root == "" { @@ -149,6 +175,8 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } } + env = removeLeadingDuplicates(env) + var cwd, path string if h.Dir != "" { path = h.Path @@ -294,7 +322,6 @@ func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Reque newReq := &http.Request{ Method: "GET", URL: url, - RawURL: path, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, diff --git a/libgo/go/http/cgi/host_test.go b/libgo/go/http/cgi/host_test.go index ff46631..6c0f1a0 100644 --- a/libgo/go/http/cgi/host_test.go +++ b/libgo/go/http/cgi/host_test.go @@ -451,3 +451,32 @@ func TestDirWindows(t *testing.T) { } runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) } + +func TestEnvOverride(t *testing.T) { + cgifile, _ := filepath.Abs("testdata/test.cgi") + + var perl string + var err os.Error + perl, err = exec.LookPath("perl") + if err != nil { + return + } + perl, _ = filepath.Abs(perl) + + cwd, _ := os.Getwd() + h := &Handler{ + Path: perl, + Root: "/test.cgi", + Dir: cwd, + Args: []string{cgifile}, + Env: []string{ + "SCRIPT_FILENAME=" + cgifile, + "REQUEST_URI=/foo/bar"}, + } + expectedMap := map[string]string{ + "cwd": cwd, + "env-SCRIPT_FILENAME": cgifile, + "env-REQUEST_URI": "/foo/bar", + } + runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/libgo/go/http/chunked.go b/libgo/go/http/chunked.go index 6c23e69..eff9ae2 100644 --- a/libgo/go/http/chunked.go +++ b/libgo/go/http/chunked.go @@ -5,11 +5,11 @@ package http import ( + "bufio" "io" "log" "os" "strconv" - "bufio" ) // NewChunkedWriter returns a new writer that translates writes into HTTP diff --git a/libgo/go/http/client.go b/libgo/go/http/client.go index 44b3443..3fa4a05 100644 --- a/libgo/go/http/client.go +++ b/libgo/go/http/client.go @@ -56,9 +56,10 @@ type RoundTripper interface { // higher-level protocol details such as redirects, // authentication, or cookies. // - // RoundTrip may modify the request. The request Headers field is - // guaranteed to be initialized. - RoundTrip(req *Request) (resp *Response, err os.Error) + // RoundTrip should not modify the request, except for + // consuming the Body. The request's URL and Header fields + // are guaranteed to be initialized. + RoundTrip(*Request) (*Response, os.Error) } // Given a string of the form "host", "host:port", or "[ipv6::address]:port", @@ -76,7 +77,12 @@ type readClose struct { // Do sends an HTTP request and returns an HTTP response, following // policy (e.g. redirects, cookies, auth) as configured on the client. // -// Callers should close resp.Body when done reading from it. +// A non-nil response always contains a non-nil resp.Body. +// +// Callers should close resp.Body when done reading from it. If +// resp.Body is not closed, the Client's underlying RoundTripper +// (typically Transport) may not be able to re-use a persistent TCP +// connection to the server for a subsequent "keep-alive" request. // // Generally Get, Post, or PostForm will be used instead of Do. func (c *Client) Do(req *Request) (resp *Response, err os.Error) { @@ -91,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) { if t == nil { t = DefaultTransport if t == nil { - err = os.NewError("no http.Client.Transport or http.DefaultTransport") + err = os.NewError("http: no Client.Transport or DefaultTransport") return } } + if req.URL == nil { + return nil, os.NewError("http: nil Request.URL") + } + // Most the callers of send (Get, Post, et al) don't need // Headers, leaving it uninitialized. We guarantee to the // Transport that this has been initialized, though. @@ -105,9 +115,6 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) { info := req.URL.RawUserinfo if len(info) > 0 { - if req.Header == nil { - req.Header = make(Header) - } req.Header.Set("Authorization", "Basic "+base64.URLEncoding.EncodeToString([]byte(info))) } return t.RoundTrip(req) @@ -166,6 +173,10 @@ func (c *Client) doFollowingRedirects(ireq *Request) (r *Response, err os.Error) } var via []*Request + if ireq.URL == nil { + return nil, os.NewError("http: nil Request.URL") + } + req := ireq urlStr := "" // next relative or absolute URL to fetch (after first request) for redirect := 0; ; redirect++ { diff --git a/libgo/go/http/client_test.go b/libgo/go/http/client_test.go index 8efb1d9..0ad6cd7 100644 --- a/libgo/go/http/client_test.go +++ b/libgo/go/http/client_test.go @@ -132,7 +132,9 @@ func TestPostFormRequestFormat(t *testing.T) { if tr.req.Close { t.Error("got Close true, want false") } - expectedBody := "bar=baz&foo=bar&foo=bar2" + // Depending on map iteration, body can be either of these. + expectedBody := "foo=bar&foo=bar2&bar=baz" + expectedBody1 := "bar=baz&foo=bar&foo=bar2" if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { t.Errorf("got ContentLength %d, want %d", g, e) } @@ -140,8 +142,8 @@ func TestPostFormRequestFormat(t *testing.T) { if err != nil { t.Fatalf("ReadAll on req.Body: %v", err) } - if g := string(bodyb); g != expectedBody { - t.Errorf("got body %q, want %q", g, expectedBody) + if g := string(bodyb); g != expectedBody && g != expectedBody1 { + t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1) } } diff --git a/libgo/go/http/cookie.go b/libgo/go/http/cookie.go index fe70431..6935014 100644 --- a/libgo/go/http/cookie.go +++ b/libgo/go/http/cookie.go @@ -207,17 +207,16 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } +var cookieNameSanitizer = strings.NewReplacer("\n", "-", "\r", "-") + func sanitizeName(n string) string { - n = strings.Replace(n, "\n", "-", -1) - n = strings.Replace(n, "\r", "-", -1) - return n + return cookieNameSanitizer.Replace(n) } +var cookieValueSanitizer = strings.NewReplacer("\n", " ", "\r", " ", ";", " ") + func sanitizeValue(v string) string { - v = strings.Replace(v, "\n", " ", -1) - v = strings.Replace(v, "\r", " ", -1) - v = strings.Replace(v, ";", " ", -1) - return v + return cookieValueSanitizer.Replace(v) } func unquoteCookieValue(v string) string { diff --git a/libgo/go/http/cookie_test.go b/libgo/go/http/cookie_test.go index d7aeda0..5de6aab 100644 --- a/libgo/go/http/cookie_test.go +++ b/libgo/go/http/cookie_test.go @@ -124,7 +124,7 @@ var readSetCookiesTests = []struct { Path: "/", Domain: ".google.ch", HttpOnly: true, - Expires: time.Time{Year: 2011, Month: 11, Day: 23, Hour: 1, Minute: 5, Second: 3, Weekday: 3, ZoneOffset: 0, Zone: "GMT"}, + Expires: time.Time{Year: 2011, Month: 11, Day: 23, Hour: 1, Minute: 5, Second: 3, ZoneOffset: 0, Zone: "GMT"}, RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", }}, diff --git a/libgo/go/http/dump.go b/libgo/go/http/dump.go index 358980f..f78df57 100644 --- a/libgo/go/http/dump.go +++ b/libgo/go/http/dump.go @@ -44,7 +44,7 @@ func DumpRequest(req *Request, body bool) (dump []byte, err os.Error) { return } } - err = req.Write(&b) + err = req.dumpWrite(&b) req.Body = save if err != nil { return diff --git a/libgo/go/http/fcgi/child.go b/libgo/go/http/fcgi/child.go index 1971882..61dd3fb 100644 --- a/libgo/go/http/fcgi/child.go +++ b/libgo/go/http/fcgi/child.go @@ -194,7 +194,7 @@ func (c *child) serve() { case typeData: // If the filter role is implemented, read the data stream here. case typeAbortRequest: - requests[rec.h.Id] = nil, false + delete(requests, rec.h.Id) c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) if !req.keepConn { // connection will close upon return diff --git a/libgo/go/http/fcgi/fcgi_test.go b/libgo/go/http/fcgi/fcgi_test.go index 16a6243..5c8e46b 100644 --- a/libgo/go/http/fcgi/fcgi_test.go +++ b/libgo/go/http/fcgi/fcgi_test.go @@ -53,13 +53,13 @@ var streamTests = []struct { {"two records", typeStdin, 300, make([]byte, 66000), bytes.Join([][]byte{ // header for the first record - []byte{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + {1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, make([]byte, 65536), // header for the second - []byte{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + {1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, make([]byte, 472), // header for the empty record - []byte{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, + {1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, }, nil), }, diff --git a/libgo/go/http/filetransport.go b/libgo/go/http/filetransport.go new file mode 100644 index 0000000..78f3aa2 --- /dev/null +++ b/libgo/go/http/filetransport.go @@ -0,0 +1,124 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "fmt" + "io" + "os" +) + +// fileTransport implements RoundTripper for the 'file' protocol. +type fileTransport struct { + fh fileHandler +} + +// NewFileTransport returns a new RoundTripper, serving the provided +// FileSystem. The returned RoundTripper ignores the URL host in its +// incoming requests, as well as most other properties of the +// request. +// +// The typical use case for NewFileTransport is to register the "file" +// protocol with a Transport, as in: +// +// t := &http.Transport{} +// t.RegisterProtocol("file", http.NewFileTransport(http.Dir("/"))) +// c := &http.Client{Transport: t} +// res, err := c.Get("file:///etc/passwd") +// ... +func NewFileTransport(fs FileSystem) RoundTripper { + return fileTransport{fileHandler{fs}} +} + +func (t fileTransport) RoundTrip(req *Request) (resp *Response, err os.Error) { + // We start ServeHTTP in a goroutine, which may take a long + // time if the file is large. The newPopulateResponseWriter + // call returns a channel which either ServeHTTP or finish() + // sends our *Response on, once the *Response itself has been + // populated (even if the body itself is still being + // written to the res.Body, a pipe) + rw, resc := newPopulateResponseWriter() + go func() { + t.fh.ServeHTTP(rw, req) + rw.finish() + }() + return <-resc, nil +} + +func newPopulateResponseWriter() (*populateResponse, <-chan *Response) { + pr, pw := io.Pipe() + rw := &populateResponse{ + ch: make(chan *Response), + pw: pw, + res: &Response{ + Proto: "HTTP/1.0", + ProtoMajor: 1, + Header: make(Header), + Close: true, + Body: pr, + }, + } + return rw, rw.ch +} + +// populateResponse is a ResponseWriter that populates the *Response +// in res, and writes its body to a pipe connected to the response +// body. Once writes begin or finish() is called, the response is sent +// on ch. +type populateResponse struct { + res *Response + ch chan *Response + wroteHeader bool + hasContent bool + sentResponse bool + pw *io.PipeWriter +} + +func (pr *populateResponse) finish() { + if !pr.wroteHeader { + pr.WriteHeader(500) + } + if !pr.sentResponse { + pr.sendResponse() + } + pr.pw.Close() +} + +func (pr *populateResponse) sendResponse() { + if pr.sentResponse { + return + } + pr.sentResponse = true + + if pr.hasContent { + pr.res.ContentLength = -1 + } + pr.ch <- pr.res +} + +func (pr *populateResponse) Header() Header { + return pr.res.Header +} + +func (pr *populateResponse) WriteHeader(code int) { + if pr.wroteHeader { + return + } + pr.wroteHeader = true + + pr.res.StatusCode = code + pr.res.Status = fmt.Sprintf("%d %s", code, StatusText(code)) +} + +func (pr *populateResponse) Write(p []byte) (n int, err os.Error) { + if !pr.wroteHeader { + pr.WriteHeader(StatusOK) + } + pr.hasContent = true + if !pr.sentResponse { + pr.sendResponse() + } + return pr.pw.Write(p) +} diff --git a/libgo/go/http/filetransport_test.go b/libgo/go/http/filetransport_test.go new file mode 100644 index 0000000..2634243 --- /dev/null +++ b/libgo/go/http/filetransport_test.go @@ -0,0 +1,63 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "http" + "io/ioutil" + "path/filepath" + "os" + "testing" +) + +func checker(t *testing.T) func(string, os.Error) { + return func(call string, err os.Error) { + if err == nil { + return + } + t.Fatalf("%s: %v", call, err) + } +} + +func TestFileTransport(t *testing.T) { + check := checker(t) + + dname, err := ioutil.TempDir("", "") + check("TempDir", err) + fname := filepath.Join(dname, "foo.txt") + err = ioutil.WriteFile(fname, []byte("Bar"), 0644) + check("WriteFile", err) + + tr := &http.Transport{} + tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname))) + c := &http.Client{Transport: tr} + + fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} + for _, urlstr := range fooURLs { + res, err := c.Get(urlstr) + check("Get "+urlstr, err) + if res.StatusCode != 200 { + t.Errorf("for %s, StatusCode = %d, want 200", urlstr, res.StatusCode) + } + if res.ContentLength != -1 { + t.Errorf("for %s, ContentLength = %d, want -1", urlstr, res.ContentLength) + } + if res.Body == nil { + t.Fatalf("for %s, nil Body", urlstr) + } + slurp, err := ioutil.ReadAll(res.Body) + check("ReadAll "+urlstr, err) + if string(slurp) != "Bar" { + t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar") + } + } + + const badURL = "file://../no-exist.txt" + res, err := c.Get(badURL) + check("Get "+badURL, err) + if res.StatusCode != 404 { + t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) + } +} diff --git a/libgo/go/http/fs.go b/libgo/go/http/fs.go index 2c7c636..6d71665 100644 --- a/libgo/go/http/fs.go +++ b/libgo/go/http/fs.go @@ -219,7 +219,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec w.WriteHeader(code) if r.Method != "HEAD" { - io.Copyn(w, f, size) + io.CopyN(w, f, size) } } diff --git a/libgo/go/http/header.go b/libgo/go/http/header.go index 08b0771..aaaa92a 100644 --- a/libgo/go/http/header.go +++ b/libgo/go/http/header.go @@ -47,6 +47,8 @@ func (h Header) Write(w io.Writer) os.Error { return h.WriteSubset(w, nil) } +var headerNewlineToSpace = strings.NewReplacer("\n", " ", "\r", " ") + // WriteSubset writes a header in wire format. // If exclude is not nil, keys where exclude[key] == true are not written. func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) os.Error { @@ -59,8 +61,7 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) os.Error { sort.Strings(keys) for _, k := range keys { for _, v := range h[k] { - v = strings.Replace(v, "\n", " ", -1) - v = strings.Replace(v, "\r", " ", -1) + v = headerNewlineToSpace.Replace(v) v = strings.TrimSpace(v) if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { return err diff --git a/libgo/go/http/httptest/server.go b/libgo/go/http/httptest/server.go index 2ec36d0..43a48eb 100644 --- a/libgo/go/http/httptest/server.go +++ b/libgo/go/http/httptest/server.go @@ -23,6 +23,10 @@ type Server struct { URL string // base URL of form http://ipaddr:port with no trailing slash Listener net.Listener TLS *tls.Config // nil if not using using TLS + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server } // historyListener keeps track of all connections that it's ever @@ -41,6 +45,13 @@ func (hs *historyListener) Accept() (c net.Conn, err os.Error) { } func newLocalListener() net.Listener { + if *serve != "" { + l, err := net.Listen("tcp", *serve) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err)) + } + return l + } l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { if l, err = net.Listen("tcp6", "[::1]:0"); err != nil { @@ -59,51 +70,66 @@ var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer // NewServer starts and returns a new Server. // The caller should call Close when finished, to shut it down. func NewServer(handler http.Handler) *Server { - ts := new(Server) - var l net.Listener - if *serve != "" { - var err os.Error - l, err = net.Listen("tcp", *serve) - if err != nil { - panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err)) - } - } else { - l = newLocalListener() + ts := NewUnstartedServer(handler) + ts.Start() + return ts +} + +// NewUnstartedServer returns a new Server but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, } - ts.Listener = &historyListener{l, make([]net.Conn, 0)} - ts.URL = "http://" + l.Addr().String() - server := &http.Server{Handler: handler} - go server.Serve(ts.Listener) +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + s.Listener = &historyListener{s.Listener, make([]net.Conn, 0)} + s.URL = "http://" + s.Listener.Addr().String() + go s.Config.Serve(s.Listener) if *serve != "" { - fmt.Println(os.Stderr, "httptest: serving on", ts.URL) + fmt.Println(os.Stderr, "httptest: serving on", s.URL) select {} } - return ts } -// NewTLSServer starts and returns a new Server using TLS. -// The caller should call Close when finished, to shut it down. -func NewTLSServer(handler http.Handler) *Server { - l := newLocalListener() - ts := new(Server) - +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS() { + if s.URL != "" { + panic("Server already started") + } cert, err := tls.X509KeyPair(localhostCert, localhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } - ts.TLS = &tls.Config{ + s.TLS = &tls.Config{ Rand: rand.Reader, Time: time.Seconds, NextProtos: []string{"http/1.1"}, Certificates: []tls.Certificate{cert}, } - tlsListener := tls.NewListener(l, ts.TLS) + tlsListener := tls.NewListener(s.Listener, s.TLS) + + s.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} + s.URL = "https://" + s.Listener.Addr().String() + go s.Config.Serve(s.Listener) +} - ts.Listener = &historyListener{tlsListener, make([]net.Conn, 0)} - ts.URL = "https://" + l.Addr().String() - server := &http.Server{Handler: handler} - go server.Serve(ts.Listener) +// NewTLSServer starts and returns a new Server using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS() return ts } diff --git a/libgo/go/http/persist.go b/libgo/go/http/persist.go index 78bf905..f73e6c6 100644 --- a/libgo/go/http/persist.go +++ b/libgo/go/http/persist.go @@ -165,7 +165,7 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { // Retrieve the pipeline ID of this request/response pair sc.lk.Lock() id, ok := sc.pipereq[req] - sc.pipereq[req] = 0, false + delete(sc.pipereq, req) if !ok { sc.lk.Unlock() return ErrPipeline @@ -353,7 +353,7 @@ func (cc *ClientConn) readUsing(req *Request, readRes func(*bufio.Reader, *Reque // Retrieve the pipeline ID of this request/response pair cc.lk.Lock() id, ok := cc.pipereq[req] - cc.pipereq[req] = 0, false + delete(cc.pipereq, req) if !ok { cc.lk.Unlock() return nil, ErrPipeline diff --git a/libgo/go/http/readrequest_test.go b/libgo/go/http/readrequest_test.go index f6dc99e..6d9042a 100644 --- a/libgo/go/http/readrequest_test.go +++ b/libgo/go/http/readrequest_test.go @@ -40,7 +40,6 @@ var reqTests = []reqTest{ &Request{ Method: "GET", - RawURL: "http://www.techcrunch.com/", URL: &url.URL{ Raw: "http://www.techcrunch.com/", Scheme: "http", @@ -83,7 +82,6 @@ var reqTests = []reqTest{ &Request{ Method: "GET", - RawURL: "/", URL: &url.URL{ Raw: "/", Path: "/", @@ -110,7 +108,6 @@ var reqTests = []reqTest{ &Request{ Method: "GET", - RawURL: "//user@host/is/actually/a/path/", URL: &url.URL{ Raw: "//user@host/is/actually/a/path/", Scheme: "", diff --git a/libgo/go/http/request.go b/libgo/go/http/request.go index ed41fa4..02317e0 100644 --- a/libgo/go/http/request.go +++ b/libgo/go/http/request.go @@ -64,18 +64,24 @@ func (e *badStringError) String() string { return fmt.Sprintf("%s %q", e.what, e // Headers that Request.Write handles itself and should be skipped. var reqWriteExcludeHeader = map[string]bool{ - "Host": true, + "Host": true, // not in Header map anyway "User-Agent": true, "Content-Length": true, "Transfer-Encoding": true, "Trailer": true, } +var reqWriteExcludeHeaderDump = map[string]bool{ + "Host": true, // not in Header map anyway + "Content-Length": true, + "Transfer-Encoding": true, + "Trailer": true, +} + // A Request represents a parsed HTTP request header. type Request struct { - Method string // GET, POST, PUT, etc. - RawURL string // The raw URL given in the request. - URL *url.URL // Parsed URL. + Method string // GET, POST, PUT, etc. + URL *url.URL // The protocol version for incoming requests. // Outgoing requests always use HTTP/1.1. @@ -234,8 +240,8 @@ func (r *Request) multipartReader() (*multipart.Reader, os.Error) { if v == "" { return nil, ErrNotMultipart } - d, params := mime.ParseMediaType(v) - if d != "multipart/form-data" { + d, params, err := mime.ParseMediaType(v) + if err != nil || d != "multipart/form-data" { return nil, ErrNotMultipart } boundary, ok := params["boundary"] @@ -258,7 +264,7 @@ const defaultUserAgent = "Go http package" // Write writes an HTTP/1.1 request -- header and body -- in wire format. // This method consults the following fields of req: // Host -// RawURL, if non-empty, or else URL +// URL // Method (defaults to "GET") // Header // ContentLength @@ -269,19 +275,66 @@ const defaultUserAgent = "Go http package" // hasn't been set to "identity", Write adds "Transfer-Encoding: // chunked" to the header. Body is closed after it is sent. func (req *Request) Write(w io.Writer) os.Error { - return req.write(w, false) + return req.write(w, false, nil) } // WriteProxy is like Write but writes the request in the form -// expected by an HTTP proxy. It includes the scheme and host -// name in the URI instead of using a separate Host: header line. -// If req.RawURL is non-empty, WriteProxy uses it unchanged -// instead of URL but still omits the Host: header. +// expected by an HTTP proxy. In particular, WriteProxy writes the +// initial Request-URI line of the request with an absolute URI, per +// section 5.1.2 of RFC 2616, including the scheme and host. In +// either case, WriteProxy also writes a Host header, using either +// req.Host or req.URL.Host. func (req *Request) WriteProxy(w io.Writer) os.Error { - return req.write(w, true) + return req.write(w, true, nil) +} + +func (req *Request) dumpWrite(w io.Writer) os.Error { + // TODO(bradfitz): RawPath here? + urlStr := valueOrDefault(req.URL.EncodedPath(), "/") + if req.URL.RawQuery != "" { + urlStr += "?" + req.URL.RawQuery + } + + bw := bufio.NewWriter(w) + fmt.Fprintf(bw, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"), urlStr, + req.ProtoMajor, req.ProtoMinor) + + host := req.Host + if host == "" && req.URL != nil { + host = req.URL.Host + } + if host != "" { + fmt.Fprintf(bw, "Host: %s\r\n", host) + } + + // Process Body,ContentLength,Close,Trailer + tw, err := newTransferWriter(req) + if err != nil { + return err + } + err = tw.WriteHeader(bw) + if err != nil { + return err + } + + err = req.Header.WriteSubset(bw, reqWriteExcludeHeaderDump) + if err != nil { + return err + } + + io.WriteString(bw, "\r\n") + + // Write body and trailer + err = tw.WriteBody(bw) + if err != nil { + return err + } + bw.Flush() + return nil } -func (req *Request) write(w io.Writer, usingProxy bool) os.Error { +// extraHeaders may be nil +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error { host := req.Host if host == "" { if req.URL == nil { @@ -290,9 +343,12 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { host = req.URL.Host } - urlStr := req.RawURL + urlStr := req.URL.RawPath + if strings.HasPrefix(urlStr, "?") { + urlStr = "/" + urlStr // Issue 2344 + } if urlStr == "" { - urlStr = valueOrDefault(req.URL.EncodedPath(), "/") + urlStr = valueOrDefault(req.URL.RawPath, valueOrDefault(req.URL.EncodedPath(), "/")) if req.URL.RawQuery != "" { urlStr += "?" + req.URL.RawQuery } @@ -303,6 +359,7 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { urlStr = req.URL.Scheme + "://" + host + urlStr } } + // TODO(bradfitz): escape at least newlines in urlStr? bw := bufio.NewWriter(w) fmt.Fprintf(bw, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), urlStr) @@ -338,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { return err } + if extraHeaders != nil { + err = extraHeaders.Write(bw) + if err != nil { + return err + } + } + io.WriteString(bw, "\r\n") // Write body and trailer @@ -542,13 +606,14 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { if f = strings.SplitN(s, " ", 3); len(f) < 3 { return nil, &badStringError{"malformed HTTP request", s} } - req.Method, req.RawURL, req.Proto = f[0], f[1], f[2] + var rawurl string + req.Method, rawurl, req.Proto = f[0], f[1], f[2] var ok bool if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { return nil, &badStringError{"malformed HTTP version", req.Proto} } - if req.URL, err = url.ParseRequest(req.RawURL); err != nil { + if req.URL, err = url.ParseRequest(rawurl); err != nil { return nil, err } @@ -608,27 +673,77 @@ func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { return req, nil } -// ParseForm parses the raw query. -// For POST requests, it also parses the request body as a form. +// MaxBytesReader is similar to io.LimitReader but is intended for +// limiting the size of incoming request bodies. In contrast to +// io.LimitReader, MaxBytesReader's result is a ReadCloser, returns a +// non-EOF error for a Read beyond the limit, and Closes the +// underlying reader when its Close method is called. +// +// MaxBytesReader prevents clients from accidentally or maliciously +// sending a large request and wasting server resources. +func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + return &maxBytesReader{w: w, r: r, n: n} +} + +type maxBytesReader struct { + w ResponseWriter + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + stopped bool +} + +func (l *maxBytesReader) Read(p []byte) (n int, err os.Error) { + if l.n <= 0 { + if !l.stopped { + l.stopped = true + if res, ok := l.w.(*response); ok { + res.requestTooLarge() + } + } + return 0, os.NewError("http: request body too large") + } + if int64(len(p)) > l.n { + p = p[:l.n] + } + n, err = l.r.Read(p) + l.n -= int64(n) + return +} + +func (l *maxBytesReader) Close() os.Error { + return l.r.Close() +} + +// ParseForm parses the raw query from the URL. +// +// For POST or PUT requests, it also parses the request body as a form. +// If the request Body's size has not already been limited by MaxBytesReader, +// the size is capped at 10MB. +// // ParseMultipartForm calls ParseForm automatically. // It is idempotent. func (r *Request) ParseForm() (err os.Error) { if r.Form != nil { return } - if r.URL != nil { r.Form, err = url.ParseQuery(r.URL.RawQuery) } - if r.Method == "POST" { + if r.Method == "POST" || r.Method == "PUT" { if r.Body == nil { return os.NewError("missing form body") } ct := r.Header.Get("Content-Type") - switch strings.SplitN(ct, ";", 2)[0] { - case "text/plain", "application/x-www-form-urlencoded", "": - const maxFormSize = int64(10 << 20) // 10 MB is a lot of text. - b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1)) + ct, _, err := mime.ParseMediaType(ct) + switch { + case ct == "text/plain" || ct == "application/x-www-form-urlencoded" || ct == "": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := ioutil.ReadAll(reader) if e != nil { if err == nil { err = e @@ -652,8 +767,13 @@ func (r *Request) ParseForm() (err os.Error) { r.Form.Add(k, value) } } - case "multipart/form-data": - // handled by ParseMultipartForm + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestRequestMultipartCallOrder. default: return &badStringError{"unknown Content-Type", ct} } diff --git a/libgo/go/http/request_test.go b/libgo/go/http/request_test.go index 869cd57..175d6f1 100644 --- a/libgo/go/http/request_test.go +++ b/libgo/go/http/request_test.go @@ -20,57 +20,6 @@ import ( "url" ) -type stringMultimap map[string][]string - -type parseTest struct { - query string - out stringMultimap -} - -var parseTests = []parseTest{ - { - query: "a=1&b=2", - out: stringMultimap{"a": []string{"1"}, "b": []string{"2"}}, - }, - { - query: "a=1&a=2&a=banana", - out: stringMultimap{"a": []string{"1", "2", "banana"}}, - }, - { - query: "ascii=%3Ckey%3A+0x90%3E", - out: stringMultimap{"ascii": []string{"<key: 0x90>"}}, - }, -} - -func TestParseForm(t *testing.T) { - for i, test := range parseTests { - form, err := url.ParseQuery(test.query) - if err != nil { - t.Errorf("test %d: Unexpected error: %v", i, err) - continue - } - if len(form) != len(test.out) { - t.Errorf("test %d: len(form) = %d, want %d", i, len(form), len(test.out)) - } - for k, evs := range test.out { - vs, ok := form[k] - if !ok { - t.Errorf("test %d: Missing key %q", i, k) - continue - } - if len(vs) != len(evs) { - t.Errorf("test %d: len(form[%q]) = %d, want %d", i, k, len(vs), len(evs)) - continue - } - for j, ev := range evs { - if v := vs[j]; v != ev { - t.Errorf("test %d: form[%q][%d] = %q, want %q", i, k, j, v, ev) - } - } - } - } -} - func TestQuery(t *testing.T) { req := &Request{Method: "GET"} req.URL, _ = url.Parse("http://www.google.com/search?q=foo&q=bar") diff --git a/libgo/go/http/requestwrite_test.go b/libgo/go/http/requestwrite_test.go index 458f0bd..194f6dd 100644 --- a/libgo/go/http/requestwrite_test.go +++ b/libgo/go/http/requestwrite_test.go @@ -16,18 +16,22 @@ import ( ) type reqWriteTest struct { - Req Request - Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body - Raw string - RawProxy string + Req Request + Body interface{} // optional []byte or func() io.ReadCloser to populate Req.Body + + // Any of these three may be empty to skip that test. + WantWrite string // Request.Write + WantProxy string // Request.WriteProxy + WantDump string // DumpRequest + + WantError os.Error // wanted error from Request.Write } var reqWriteTests = []reqWriteTest{ // HTTP/1.1 => chunked coding; no body; no trailer { - Request{ + Req: Request{ Method: "GET", - RawURL: "http://www.techcrunch.com/", URL: &url.URL{ Raw: "http://www.techcrunch.com/", Scheme: "http", @@ -57,9 +61,7 @@ var reqWriteTests = []reqWriteTest{ Form: map[string][]string{}, }, - nil, - - "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + WantWrite: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + "Host: www.techcrunch.com\r\n" + "User-Agent: Fake\r\n" + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + @@ -69,7 +71,7 @@ var reqWriteTests = []reqWriteTest{ "Keep-Alive: 300\r\n" + "Proxy-Connection: keep-alive\r\n\r\n", - "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + WantProxy: "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + "Host: www.techcrunch.com\r\n" + "User-Agent: Fake\r\n" + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + @@ -81,7 +83,7 @@ var reqWriteTests = []reqWriteTest{ }, // HTTP/1.1 => chunked coding; body; empty trailer { - Request{ + Req: Request{ Method: "GET", URL: &url.URL{ Scheme: "http", @@ -94,23 +96,28 @@ var reqWriteTests = []reqWriteTest{ TransferEncoding: []string{"chunked"}, }, - []byte("abcdef"), + Body: []byte("abcdef"), - "GET /search HTTP/1.1\r\n" + + WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), - "GET http://www.google.com/search HTTP/1.1\r\n" + + WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), + + WantDump: "GET /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + chunk("abcdef") + chunk(""), }, // HTTP/1.1 POST => chunked coding; body; empty trailer { - Request{ + Req: Request{ Method: "POST", URL: &url.URL{ Scheme: "http", @@ -124,16 +131,16 @@ var reqWriteTests = []reqWriteTest{ TransferEncoding: []string{"chunked"}, }, - []byte("abcdef"), + Body: []byte("abcdef"), - "POST /search HTTP/1.1\r\n" + + WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), - "POST http://www.google.com/search HTTP/1.1\r\n" + + WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + @@ -143,7 +150,7 @@ var reqWriteTests = []reqWriteTest{ // HTTP/1.1 POST with Content-Length, no chunking { - Request{ + Req: Request{ Method: "POST", URL: &url.URL{ Scheme: "http", @@ -157,9 +164,9 @@ var reqWriteTests = []reqWriteTest{ ContentLength: 6, }, - []byte("abcdef"), + Body: []byte("abcdef"), - "POST /search HTTP/1.1\r\n" + + WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + @@ -167,7 +174,7 @@ var reqWriteTests = []reqWriteTest{ "\r\n" + "abcdef", - "POST http://www.google.com/search HTTP/1.1\r\n" + + WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "Connection: close\r\n" + @@ -178,9 +185,9 @@ var reqWriteTests = []reqWriteTest{ // HTTP/1.1 POST with Content-Length in headers { - Request{ + Req: Request{ Method: "POST", - RawURL: "http://example.com/", + URL: mustParseURL("http://example.com/"), Host: "example.com", Header: Header{ "Content-Length": []string{"10"}, // ignored @@ -188,16 +195,16 @@ var reqWriteTests = []reqWriteTest{ ContentLength: 6, }, - []byte("abcdef"), + Body: []byte("abcdef"), - "POST http://example.com/ HTTP/1.1\r\n" + + WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go http package\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", - "POST http://example.com/ HTTP/1.1\r\n" + + WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go http package\r\n" + "Content-Length: 6\r\n" + @@ -207,21 +214,13 @@ var reqWriteTests = []reqWriteTest{ // default to HTTP/1.1 { - Request{ + Req: Request{ Method: "GET", - RawURL: "/search", + URL: mustParseURL("/search"), Host: "www.google.com", }, - nil, - - "GET /search HTTP/1.1\r\n" + - "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + - "\r\n", - - // Looks weird but RawURL overrides what WriteProxy would choose. - "GET /search HTTP/1.1\r\n" + + WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + "User-Agent: Go http package\r\n" + "\r\n", @@ -229,53 +228,125 @@ var reqWriteTests = []reqWriteTest{ // Request with a 0 ContentLength and a 0 byte body. { - Request{ + Req: Request{ Method: "POST", - RawURL: "/", + URL: mustParseURL("/"), Host: "example.com", ProtoMajor: 1, ProtoMinor: 1, ContentLength: 0, // as if unset by user }, - func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, + Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, - "POST / HTTP/1.1\r\n" + + // RFC 2616 Section 14.13 says Content-Length should be specified + // unless body is prohibited by the request method. + // Also, nginx expects it for POST and PUT. + WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go http package\r\n" + + "Content-Length: 0\r\n" + "\r\n", - "POST / HTTP/1.1\r\n" + + WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go http package\r\n" + + "Content-Length: 0\r\n" + "\r\n", }, // Request with a 0 ContentLength and a 1 byte body. { - Request{ + Req: Request{ Method: "POST", - RawURL: "/", + URL: mustParseURL("/"), Host: "example.com", ProtoMajor: 1, ProtoMinor: 1, ContentLength: 0, // as if unset by user }, - func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, + Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, - "POST / HTTP/1.1\r\n" + + WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), - "POST / HTTP/1.1\r\n" + + WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go http package\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), }, + + // Request with a ContentLength of 10 but a 5 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 10, // but we're going to send only 5 bytes + }, + Body: []byte("12345"), + WantError: os.NewError("http: Request.ContentLength=10 with Body length 5"), + }, + + // Request with a ContentLength of 4 but an 8 byte body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 4, // but we're going to try to send 8 bytes + }, + Body: []byte("12345678"), + WantError: os.NewError("http: Request.ContentLength=4 with Body length 8"), + }, + + // Request with a 5 ContentLength and nil body. + { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 5, // but we'll omit the body + }, + WantError: os.NewError("http: Request.ContentLength=5 with nil Body"), + }, + + // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, + // and doesn't add a User-Agent. + { + Req: Request{ + Method: "GET", + URL: mustParseURL("/foo"), + ProtoMajor: 1, + ProtoMinor: 0, + Header: Header{ + "X-Foo": []string{"X-Bar"}, + }, + }, + + // We can dump it: + WantDump: "GET /foo HTTP/1.0\r\n" + + "X-Foo: X-Bar\r\n\r\n", + + // .. but we can't call Request.Write on it, due to its lack of Host header. + // TODO(bradfitz): there might be an argument to allow this, but for now I'd + // rather let HTTP/1.0 continue to die. + WantWrite: "GET /foo HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go http package\r\n" + + "X-Foo: X-Bar\r\n\r\n", + }, } func TestRequestWrite(t *testing.T) { @@ -283,6 +354,9 @@ func TestRequestWrite(t *testing.T) { tt := &reqWriteTests[i] setBody := func() { + if tt.Body == nil { + return + } switch b := tt.Body.(type) { case []byte: tt.Req.Body = ioutil.NopCloser(bytes.NewBuffer(b)) @@ -290,37 +364,55 @@ func TestRequestWrite(t *testing.T) { tt.Req.Body = b() } } - if tt.Body != nil { - setBody() - } + setBody() if tt.Req.Header == nil { tt.Req.Header = make(Header) } + var braw bytes.Buffer err := tt.Req.Write(&braw) - if err != nil { - t.Errorf("error writing #%d: %s", i, err) + if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.WantError); g != e { + t.Errorf("writing #%d, err = %q, want %q", i, g, e) continue } - sraw := braw.String() - if sraw != tt.Raw { - t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.Raw, sraw) + if err != nil { continue } - if tt.Body != nil { - setBody() + if tt.WantWrite != "" { + sraw := braw.String() + if sraw != tt.WantWrite { + t.Errorf("Test %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantWrite, sraw) + continue + } } - var praw bytes.Buffer - err = tt.Req.WriteProxy(&praw) - if err != nil { - t.Errorf("error writing #%d: %s", i, err) - continue + + if tt.WantProxy != "" { + setBody() + var praw bytes.Buffer + err = tt.Req.WriteProxy(&praw) + if err != nil { + t.Errorf("WriteProxy #%d: %s", i, err) + continue + } + sraw := praw.String() + if sraw != tt.WantProxy { + t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantProxy, sraw) + continue + } } - sraw = praw.String() - if sraw != tt.RawProxy { - t.Errorf("Test Proxy %d, expecting:\n%s\nGot:\n%s\n", i, tt.RawProxy, sraw) - continue + + if tt.WantDump != "" { + setBody() + dump, err := DumpRequest(&tt.Req, true) + if err != nil { + t.Errorf("DumpRequest #%d: %s", i, err) + continue + } + if string(dump) != tt.WantDump { + t.Errorf("DumpRequest %d, expecting:\n%s\nGot:\n%s\n", i, tt.WantDump, string(dump)) + continue + } } } } @@ -368,3 +460,11 @@ func TestRequestWriteClosesBody(t *testing.T) { func chunk(s string) string { return fmt.Sprintf("%x\r\n%s\r\n", len(s), s) } + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(fmt.Sprintf("Error parsing URL %q: %v", s, err)) + } + return u +} diff --git a/libgo/go/http/response.go b/libgo/go/http/response.go index 915327a..56c65b53 100644 --- a/libgo/go/http/response.go +++ b/libgo/go/http/response.go @@ -13,6 +13,7 @@ import ( "os" "strconv" "strings" + "url" ) var respExcludeHeader = map[string]bool{ @@ -41,6 +42,10 @@ type Response struct { Header Header // Body represents the response body. + // + // The http Client and Transport guarantee that Body is always + // non-nil, even on responses without a body or responses with + // a zero-lengthed body. Body io.ReadCloser // ContentLength records the length of the associated content. The @@ -73,6 +78,23 @@ func (r *Response) Cookies() []*Cookie { return readSetCookies(r.Header) } +var ErrNoLocation = os.NewError("http: no Location header in response") + +// Location returns the URL of the response's "Location" header, +// if present. Relative redirects are resolved relative to +// the Response's Request. ErrNoLocation is returned if no +// Location header is present. +func (r *Response) Location() (*url.URL, os.Error) { + lv := r.Header.Get("Location") + if lv == "" { + return nil, ErrNoLocation + } + if r.Request != nil && r.Request.URL != nil { + return r.Request.URL.Parse(lv) + } + return url.Parse(lv) +} + // ReadResponse reads and returns an HTTP response from r. The // req parameter specifies the Request that corresponds to // this Response. Clients must call resp.Body.Close when finished diff --git a/libgo/go/http/response_test.go b/libgo/go/http/response_test.go index 1d4a234..86494bf 100644 --- a/libgo/go/http/response_test.go +++ b/libgo/go/http/response_test.go @@ -15,6 +15,7 @@ import ( "io/ioutil" "reflect" "testing" + "url" ) type respTest struct { @@ -395,3 +396,52 @@ func diff(t *testing.T, prefix string, have, want interface{}) { } } } + +type responseLocationTest struct { + location string // Response's Location header or "" + requrl string // Response.Request.URL or "" + want string + wantErr os.Error +} + +var responseLocationTests = []responseLocationTest{ + {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil}, + {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil}, + {"", "http://bar.com/baz", "", ErrNoLocation}, +} + +func TestLocationResponse(t *testing.T) { + for i, tt := range responseLocationTests { + res := new(Response) + res.Header = make(Header) + res.Header.Set("Location", tt.location) + if tt.requrl != "" { + res.Request = &Request{} + var err os.Error + res.Request.URL, err = url.Parse(tt.requrl) + if err != nil { + t.Fatalf("bad test URL %q: %v", tt.requrl, err) + } + } + + got, err := res.Location() + if tt.wantErr != nil { + if err == nil { + t.Errorf("%d. err=nil; want %q", i, tt.wantErr) + continue + } + if g, e := err.String(), tt.wantErr.String(); g != e { + t.Errorf("%d. err=%q; want %q", i, g, e) + continue + } + continue + } + if err != nil { + t.Errorf("%d. err=%q", i, err) + continue + } + if g, e := got.String(), tt.want; g != e { + t.Errorf("%d. Location=%q; want %q", i, g, e) + } + } +} diff --git a/libgo/go/http/serve_test.go b/libgo/go/http/serve_test.go index ac04033..2ff66d5 100644 --- a/libgo/go/http/serve_test.go +++ b/libgo/go/http/serve_test.go @@ -9,14 +9,15 @@ package http_test import ( "bufio" "bytes" + "crypto/tls" "fmt" . "http" "http/httptest" "io" "io/ioutil" "log" - "os" "net" + "os" "reflect" "strings" "syscall" @@ -356,18 +357,17 @@ func TestIdentityResponse(t *testing.T) { if err != nil { t.Fatalf("error writing: %v", err) } - // The next ReadAll will hang for a failing test, so use a Timer instead - // to fail more traditionally - timer := time.AfterFunc(2e9, func() { - t.Fatalf("Timeout expired in ReadAll.") + + // The ReadAll will hang for a failing test, so use a Timer to + // fail explicitly. + goTimeout(t, 2e9, func() { + got, _ := ioutil.ReadAll(conn) + expectedSuffix := "\r\n\r\ntoo short" + if !strings.HasSuffix(string(got), expectedSuffix) { + t.Errorf("Expected output to end with %q; got response body %q", + expectedSuffix, string(got)) + } }) - defer timer.Stop() - got, _ := ioutil.ReadAll(conn) - expectedSuffix := "\r\n\r\ntoo short" - if !strings.HasSuffix(string(got), expectedSuffix) { - t.Fatalf("Expected output to end with %q; got response body %q", - expectedSuffix, string(got)) - } } func testTcpConnectionCloses(t *testing.T, req string, h Handler) { @@ -535,6 +535,25 @@ func TestHeadResponses(t *testing.T) { } } +func TestTLSHandshakeTimeout(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.ReadTimeout = 250e6 + ts.StartTLS() + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + goTimeout(t, 10e9, func() { + var buf [1]byte + n, err := conn.Read(buf[:]) + if err == nil || n != 0 { + t.Errorf("Read = %d, %v; want an error and no bytes", n, err) + } + }) +} + func TestTLSServer(t *testing.T) { ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { @@ -545,23 +564,46 @@ func TestTLSServer(t *testing.T) { } })) defer ts.Close() - if !strings.HasPrefix(ts.URL, "https://") { - t.Fatalf("expected test TLS server to start with https://, got %q", ts.URL) - } - res, err := Get(ts.URL) + + // Connect an idle TCP connection to this server before we run + // our real tests. This idle connection used to block forever + // in the TLS handshake, preventing future connections from + // being accepted. It may prevent future accidental blocking + // in newConn. + idleConn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { - t.Fatal(err) - } - if res == nil { - t.Fatalf("got nil Response") - } - defer res.Body.Close() - if res.Header.Get("X-TLS-Set") != "true" { - t.Errorf("expected X-TLS-Set response header") - } - if res.Header.Get("X-TLS-HandshakeComplete") != "true" { - t.Errorf("expected X-TLS-HandshakeComplete header") + t.Fatalf("Dial: %v", err) } + defer idleConn.Close() + goTimeout(t, 10e9, func() { + if !strings.HasPrefix(ts.URL, "https://") { + t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) + return + } + noVerifyTransport := &Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + client := &Client{Transport: noVerifyTransport} + res, err := client.Get(ts.URL) + if err != nil { + t.Error(err) + return + } + if res == nil { + t.Errorf("got nil Response") + return + } + defer res.Body.Close() + if res.Header.Get("X-TLS-Set") != "true" { + t.Errorf("expected X-TLS-Set response header") + return + } + if res.Header.Get("X-TLS-HandshakeComplete") != "true" { + t.Errorf("expected X-TLS-HandshakeComplete header") + } + }) } type serverExpectTest struct { @@ -646,9 +688,11 @@ func TestServerExpect(t *testing.T) { } } -func TestServerConsumesRequestBody(t *testing.T) { +// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server +// should consume client request bodies that a handler didn't read. +func TestServerUnreadRequestBodyLittle(t *testing.T) { conn := new(testConn) - body := strings.Repeat("x", 1<<20) + body := strings.Repeat("x", 100<<10) conn.readBuf.Write([]byte(fmt.Sprintf( "POST / HTTP/1.1\r\n"+ "Host: test\r\n"+ @@ -660,14 +704,49 @@ func TestServerConsumesRequestBody(t *testing.T) { ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + defer close(done) if conn.readBuf.Len() < len(body)/2 { - t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) } rw.WriteHeader(200) if g, e := conn.readBuf.Len(), 0; g != e { t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } - done <- true + if c := rw.Header().Get("Connection"); c != "" { + t.Errorf(`Connection header = %q; want ""`, c) + } + })) + <-done +} + +// Over a ~256KB (maxPostHandlerReadBytes) threshold, the server +// should ignore client request bodies that a handler didn't read +// and close the connection. +func TestServerUnreadRequestBodyLarge(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 1<<20) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n", len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + defer close(done) + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + if c := rw.Header().Get("Connection"); c != "close" { + t.Errorf(`Connection header = %q; want "close"`, c) + } })) <-done } @@ -785,6 +864,14 @@ func TestZeroLengthPostAndResponse(t *testing.T) { } func TestHandlerPanic(t *testing.T) { + testHandlerPanic(t, false) +} + +func TestHandlerPanicWithHijack(t *testing.T) { + testHandlerPanic(t, true) +} + +func testHandlerPanic(t *testing.T, withHijack bool) { // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -805,7 +892,14 @@ func TestHandlerPanic(t *testing.T) { log.SetOutput(pw) defer log.SetOutput(os.Stderr) - ts := httptest.NewServer(HandlerFunc(func(ResponseWriter, *Request) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if withHijack { + rwc, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Logf("unexpected error: %v", err) + } + defer rwc.Close() + } panic("intentional death for testing") })) defer ts.Close() @@ -891,9 +985,110 @@ func TestRequestLimit(t *testing.T) { // we do support it (at least currently), so we expect a response below. t.Fatalf("Do: %v", err) } - if res.StatusCode != 400 { - t.Fatalf("expected 400 response status; got: %d %s", res.StatusCode, res.Status) + if res.StatusCode != 413 { + t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) + } +} + +type neverEnding byte + +func (b neverEnding) Read(p []byte) (n int, err os.Error) { + for i := range p { + p[i] = byte(b) + } + return len(p), nil +} + +type countReader struct { + r io.Reader + n *int64 +} + +func (cr countReader) Read(p []byte) (n int, err os.Error) { + n, err = cr.r.Read(p) + *cr.n += int64(n) + return +} + +func TestRequestBodyLimit(t *testing.T) { + const limit = 1 << 20 + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + r.Body = MaxBytesReader(w, r.Body, limit) + n, err := io.Copy(ioutil.Discard, r.Body) + if err == nil { + t.Errorf("expected error from io.Copy") + } + if n != limit { + t.Errorf("io.Copy = %d, want %d", n, limit) + } + })) + defer ts.Close() + + nWritten := int64(0) + req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), &nWritten}, limit*200)) + + // Send the POST, but don't care it succeeds or not. The + // remote side is going to reply and then close the TCP + // connection, and HTTP doesn't really define if that's + // allowed or not. Some HTTP clients will get the response + // and some (like ours, currently) will complain that the + // request write failed, without reading the response. + // + // But that's okay, since what we're really testing is that + // the remote side hung up on us before we wrote too much. + _, _ = DefaultClient.Do(req) + + if nWritten > limit*100 { + t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", + limit, nWritten) + } +} + +// TestClientWriteShutdown tests that if the client shuts down the write +// side of their TCP connection, the server doesn't send a 400 Bad Request. +func TestClientWriteShutdown(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) } + err = conn.(*net.TCPConn).CloseWrite() + if err != nil { + t.Fatalf("Dial: %v", err) + } + donec := make(chan bool) + go func() { + defer close(donec) + bs, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + got := string(bs) + if got != "" { + t.Errorf("read %q from server; want nothing", got) + } + }() + select { + case <-donec: + case <-time.After(10e9): + t.Fatalf("timeout") + } +} + +// goTimeout runs f, failing t if f takes more than ns to complete. +func goTimeout(t *testing.T, ns int64, f func()) { + ch := make(chan bool, 2) + timer := time.AfterFunc(ns, func() { + t.Errorf("Timeout expired after %d ns", ns) + ch <- true + }) + defer timer.Stop() + go func() { + defer func() { ch <- true }() + f() + }() + <-ch } type errorListener struct { diff --git a/libgo/go/http/server.go b/libgo/go/http/server.go index b634e27..9792c60 100644 --- a/libgo/go/http/server.go +++ b/libgo/go/http/server.go @@ -16,6 +16,7 @@ import ( "crypto/tls" "fmt" "io" + "io/ioutil" "log" "net" "os" @@ -122,27 +123,46 @@ type response struct { // "Connection: keep-alive" response header and a // Content-Length. closeAfterReply bool + + // requestBodyLimitHit is set by requestTooLarge when + // maxBytesReader hits its max size. It is checked in + // WriteHeader, to make sure we don't consume the the + // remaining request body to try to advance to the next HTTP + // request. Instead, when this is set, we stop doing + // subsequent requests on this connection and stop reading + // input from it. + requestBodyLimitHit bool +} + +// requestTooLarge is called by maxBytesReader when too much input has +// been read from the client. +func (w *response) requestTooLarge() { + w.closeAfterReply = true + w.requestBodyLimitHit = true + if !w.wroteHeader { + w.Header().Set("Connection", "close") + } } type writerOnly struct { io.Writer } -func (r *response) ReadFrom(src io.Reader) (n int64, err os.Error) { - // Flush before checking r.chunking, as Flush will call +func (w *response) ReadFrom(src io.Reader) (n int64, err os.Error) { + // Flush before checking w.chunking, as Flush will call // WriteHeader if it hasn't been called yet, and WriteHeader - // is what sets r.chunking. - r.Flush() - if !r.chunking && r.bodyAllowed() && !r.needSniff { - if rf, ok := r.conn.rwc.(io.ReaderFrom); ok { + // is what sets w.chunking. + w.Flush() + if !w.chunking && w.bodyAllowed() && !w.needSniff { + if rf, ok := w.conn.rwc.(io.ReaderFrom); ok { n, err = rf.ReadFrom(src) - r.written += n + w.written += n return } } // Fall back to default io.Copy implementation. - // Use wrapper to hide r.ReadFrom from io.Copy. - return io.Copy(writerOnly{r}, src) + // Use wrapper to hide w.ReadFrom from io.Copy. + return io.Copy(writerOnly{w}, src) } // noLimit is an effective infinite upper bound for io.LimitedReader @@ -159,13 +179,6 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err os.Error) { br := bufio.NewReader(c.lr) bw := bufio.NewWriter(rwc) c.buf = bufio.NewReadWriter(br, bw) - - if tlsConn, ok := rwc.(*tls.Conn); ok { - tlsConn.Handshake() - c.tlsState = new(tls.ConnectionState) - *c.tlsState = tlsConn.ConnectionState() - } - return c, nil } @@ -245,6 +258,17 @@ func (w *response) Header() Header { return w.header } +// maxPostHandlerReadBytes is the max number of Request.Body bytes not +// consumed by a handler that the server will read from the a client +// in order to keep a connection alive. If there are more bytes than +// this then the server to be paranoid instead sends a "Connection: +// close" response. +// +// This number is approximately what a typical machine's TCP buffer +// size is anyway. (if we have the bytes on the machine, we might as +// well read them) +const maxPostHandlerReadBytes = 256 << 10 + func (w *response) WriteHeader(code int) { if w.conn.hijacked { log.Print("http: response.WriteHeader on hijacked connection") @@ -254,18 +278,54 @@ func (w *response) WriteHeader(code int) { log.Print("http: multiple response.WriteHeader calls") return } + w.wroteHeader = true + w.status = code + + // Check for a explicit (and valid) Content-Length header. + var hasCL bool + var contentLength int64 + if clenStr := w.header.Get("Content-Length"); clenStr != "" { + var err os.Error + contentLength, err = strconv.Atoi64(clenStr) + if err == nil { + hasCL = true + } else { + log.Printf("http: invalid Content-Length of %q sent", clenStr) + w.header.Del("Content-Length") + } + } + + if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { + _, connectionHeaderSet := w.header["Connection"] + if !connectionHeaderSet { + w.header.Set("Connection", "keep-alive") + } + } else if !w.req.ProtoAtLeast(1, 1) { + // Client did not ask to keep connection alive. + w.closeAfterReply = true + } + + if w.header.Get("Connection") == "close" { + w.closeAfterReply = true + } // Per RFC 2616, we should consume the request body before - // replying, if the handler hasn't already done so. - if w.req.ContentLength != 0 { + // replying, if the handler hasn't already done so. But we + // don't want to do an unbounded amount of reading here for + // DoS reasons, so we only try up to a threshold. + if w.req.ContentLength != 0 && !w.closeAfterReply { ecr, isExpecter := w.req.Body.(*expectContinueReader) if !isExpecter || ecr.resp.wroteContinue { - w.req.Body.Close() + n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) + if n >= maxPostHandlerReadBytes { + w.requestTooLarge() + w.header.Set("Connection", "close") + } else { + w.req.Body.Close() + } } } - w.wroteHeader = true - w.status = code if code == StatusNotModified { // Must not have body. for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { @@ -288,20 +348,6 @@ func (w *response) WriteHeader(code int) { w.Header().Set("Date", time.UTC().Format(TimeFormat)) } - // Check for a explicit (and valid) Content-Length header. - var hasCL bool - var contentLength int64 - if clenStr := w.header.Get("Content-Length"); clenStr != "" { - var err os.Error - contentLength, err = strconv.Atoi64(clenStr) - if err == nil { - hasCL = true - } else { - log.Printf("http: invalid Content-Length of %q sent", clenStr) - w.header.Del("Content-Length") - } - } - te := w.header.Get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { @@ -334,20 +380,6 @@ func (w *response) WriteHeader(code int) { w.header.Del("Transfer-Encoding") // in case already set } - if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { - _, connectionHeaderSet := w.header["Connection"] - if !connectionHeaderSet { - w.header.Set("Connection", "keep-alive") - } - } else if !w.req.ProtoAtLeast(1, 1) { - // Client did not ask to keep connection alive. - w.closeAfterReply = true - } - - if w.header.Get("Connection") == "close" { - w.closeAfterReply = true - } - // Cannot use Content-Length with non-identity Transfer-Encoding. if w.chunking { w.header.Del("Content-Length") @@ -472,55 +504,6 @@ func (w *response) Write(data []byte) (n int, err os.Error) { return m + n, err } -// If this is an error reply (4xx or 5xx) -// and the handler wrote some data explaining the error, -// some browsers (i.e., Chrome, Internet Explorer) -// will show their own error instead unless the error is -// long enough. The minimum lengths used in those -// browsers are in the 256-512 range. -// Pad to 1024 bytes. -func errorKludge(w *response) { - const min = 1024 - - // Is this an error? - if kind := w.status / 100; kind != 4 && kind != 5 { - return - } - - // Did the handler supply any info? Enough? - if w.written == 0 || w.written >= min { - return - } - - // Is it a broken browser? - var msg string - switch agent := w.req.UserAgent(); { - case strings.Contains(agent, "MSIE"): - msg = "Internet Explorer" - case strings.Contains(agent, "Chrome/"): - msg = "Chrome" - default: - return - } - msg += " would ignore this error page if this text weren't here.\n" - - // Is it text? ("Content-Type" is always in the map) - baseType := strings.SplitN(w.header.Get("Content-Type"), ";", 2)[0] - switch baseType { - case "text/html": - io.WriteString(w, "<!-- ") - for w.written < min { - io.WriteString(w, msg) - } - io.WriteString(w, " -->") - case "text/plain": - io.WriteString(w, "\n") - for w.written < min { - io.WriteString(w, msg) - } - } -} - func (w *response) finishRequest() { // If this was an HTTP/1.0 request with keep-alive and we sent a Content-Length // back, we can make this a keep-alive response ... @@ -536,14 +519,17 @@ func (w *response) finishRequest() { if w.needSniff { w.sniff() } - errorKludge(w) if w.chunking { io.WriteString(w.conn.buf, "0\r\n") // trailer key/value pairs, followed by blank line io.WriteString(w.conn.buf, "\r\n") } w.conn.buf.Flush() - w.req.Body.Close() + // Close the body, unless we're about to close the whole TCP connection + // anyway. + if !w.closeAfterReply { + w.req.Body.Close() + } if w.req.MultipartForm != nil { w.req.MultipartForm.RemoveAll() } @@ -581,7 +567,9 @@ func (c *conn) serve() { if err == nil { return } - c.rwc.Close() + if c.rwc != nil { // may be nil if connection hijacked + c.rwc.Close() + } var buf bytes.Buffer fmt.Fprintf(&buf, "http: panic serving %v: %v\n", c.remoteAddr, err) @@ -589,17 +577,32 @@ func (c *conn) serve() { log.Print(buf.String()) }() + if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if err := tlsConn.Handshake(); err != nil { + c.close() + return + } + c.tlsState = new(tls.ConnectionState) + *c.tlsState = tlsConn.ConnectionState() + } + for { w, err := c.readRequest() if err != nil { + msg := "400 Bad Request" if err == errTooLarge { // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. - fmt.Fprintf(c.rwc, "HTTP/1.1 400 Request Too Large\r\n\r\n") + msg = "413 Request Entity Too Large" + } else if err == io.ErrUnexpectedEOF { + break // Don't reply + } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + break // Don't reply } + fmt.Fprintf(c.rwc, "HTTP/1.1 %s\r\n\r\n", msg) break } @@ -774,13 +777,16 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } } +var htmlReplacer = strings.NewReplacer( + "&", "&", + "<", "<", + ">", ">", + `"`, """, + "'", "'", +) + func htmlEscape(s string) string { - s = strings.Replace(s, "&", "&", -1) - s = strings.Replace(s, "<", "<", -1) - s = strings.Replace(s, ">", ">", -1) - s = strings.Replace(s, "\"", """, -1) - s = strings.Replace(s, "'", "'", -1) - return s + return htmlReplacer.Replace(s) } // Redirect to a fixed URL diff --git a/libgo/go/http/spdy/read.go b/libgo/go/http/spdy/read.go deleted file mode 100644 index c6b6ab3..0000000 --- a/libgo/go/http/spdy/read.go +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package spdy - -import ( - "compress/zlib" - "encoding/binary" - "http" - "io" - "os" - "strings" -) - -func (frame *SynStreamFrame) read(h ControlFrameHeader, f *Framer) os.Error { - return f.readSynStreamFrame(h, frame) -} - -func (frame *SynReplyFrame) read(h ControlFrameHeader, f *Framer) os.Error { - return f.readSynReplyFrame(h, frame) -} - -func (frame *RstStreamFrame) read(h ControlFrameHeader, f *Framer) os.Error { - frame.CFHeader = h - if err := binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { - return err - } - if err := binary.Read(f.r, binary.BigEndian, &frame.Status); err != nil { - return err - } - return nil -} - -func (frame *SettingsFrame) read(h ControlFrameHeader, f *Framer) os.Error { - frame.CFHeader = h - var numSettings uint32 - if err := binary.Read(f.r, binary.BigEndian, &numSettings); err != nil { - return err - } - frame.FlagIdValues = make([]SettingsFlagIdValue, numSettings) - for i := uint32(0); i < numSettings; i++ { - if err := binary.Read(f.r, binary.BigEndian, &frame.FlagIdValues[i].Id); err != nil { - return err - } - frame.FlagIdValues[i].Flag = SettingsFlag((frame.FlagIdValues[i].Id & 0xff000000) >> 24) - frame.FlagIdValues[i].Id &= 0xffffff - if err := binary.Read(f.r, binary.BigEndian, &frame.FlagIdValues[i].Value); err != nil { - return err - } - } - return nil -} - -func (frame *NoopFrame) read(h ControlFrameHeader, f *Framer) os.Error { - frame.CFHeader = h - return nil -} - -func (frame *PingFrame) read(h ControlFrameHeader, f *Framer) os.Error { - frame.CFHeader = h - if err := binary.Read(f.r, binary.BigEndian, &frame.Id); err != nil { - return err - } - return nil -} - -func (frame *GoAwayFrame) read(h ControlFrameHeader, f *Framer) os.Error { - frame.CFHeader = h - if err := binary.Read(f.r, binary.BigEndian, &frame.LastGoodStreamId); err != nil { - return err - } - return nil -} - -func (frame *HeadersFrame) read(h ControlFrameHeader, f *Framer) os.Error { - return f.readHeadersFrame(h, frame) -} - -func newControlFrame(frameType ControlFrameType) (controlFrame, os.Error) { - ctor, ok := cframeCtor[frameType] - if !ok { - return nil, &Error{Err: InvalidControlFrame} - } - return ctor(), nil -} - -var cframeCtor = map[ControlFrameType]func() controlFrame{ - TypeSynStream: func() controlFrame { return new(SynStreamFrame) }, - TypeSynReply: func() controlFrame { return new(SynReplyFrame) }, - TypeRstStream: func() controlFrame { return new(RstStreamFrame) }, - TypeSettings: func() controlFrame { return new(SettingsFrame) }, - TypeNoop: func() controlFrame { return new(NoopFrame) }, - TypePing: func() controlFrame { return new(PingFrame) }, - TypeGoAway: func() controlFrame { return new(GoAwayFrame) }, - TypeHeaders: func() controlFrame { return new(HeadersFrame) }, - // TODO(willchan): Add TypeWindowUpdate -} - -func (f *Framer) uncorkHeaderDecompressor(payloadSize int64) os.Error { - if f.headerDecompressor != nil { - f.headerReader.N = payloadSize - return nil - } - f.headerReader = io.LimitedReader{R: f.r, N: payloadSize} - decompressor, err := zlib.NewReaderDict(&f.headerReader, []byte(HeaderDictionary)) - if err != nil { - return err - } - f.headerDecompressor = decompressor - return nil -} - -// ReadFrame reads SPDY encoded data and returns a decompressed Frame. -func (f *Framer) ReadFrame() (Frame, os.Error) { - var firstWord uint32 - if err := binary.Read(f.r, binary.BigEndian, &firstWord); err != nil { - return nil, err - } - if (firstWord & 0x80000000) != 0 { - frameType := ControlFrameType(firstWord & 0xffff) - version := uint16(0x7fff & (firstWord >> 16)) - return f.parseControlFrame(version, frameType) - } - return f.parseDataFrame(firstWord & 0x7fffffff) -} - -func (f *Framer) parseControlFrame(version uint16, frameType ControlFrameType) (Frame, os.Error) { - var length uint32 - if err := binary.Read(f.r, binary.BigEndian, &length); err != nil { - return nil, err - } - flags := ControlFlags((length & 0xff000000) >> 24) - length &= 0xffffff - header := ControlFrameHeader{version, frameType, flags, length} - cframe, err := newControlFrame(frameType) - if err != nil { - return nil, err - } - if err = cframe.read(header, f); err != nil { - return nil, err - } - return cframe, nil -} - -func parseHeaderValueBlock(r io.Reader, streamId uint32) (http.Header, os.Error) { - var numHeaders uint16 - if err := binary.Read(r, binary.BigEndian, &numHeaders); err != nil { - return nil, err - } - var e os.Error - h := make(http.Header, int(numHeaders)) - for i := 0; i < int(numHeaders); i++ { - var length uint16 - if err := binary.Read(r, binary.BigEndian, &length); err != nil { - return nil, err - } - nameBytes := make([]byte, length) - if _, err := io.ReadFull(r, nameBytes); err != nil { - return nil, err - } - name := string(nameBytes) - if name != strings.ToLower(name) { - e = &Error{UnlowercasedHeaderName, streamId} - name = strings.ToLower(name) - } - if h[name] != nil { - e = &Error{DuplicateHeaders, streamId} - } - if err := binary.Read(r, binary.BigEndian, &length); err != nil { - return nil, err - } - value := make([]byte, length) - if _, err := io.ReadFull(r, value); err != nil { - return nil, err - } - valueList := strings.Split(string(value), "\x00") - for _, v := range valueList { - h.Add(name, v) - } - } - if e != nil { - return h, e - } - return h, nil -} - -func (f *Framer) readSynStreamFrame(h ControlFrameHeader, frame *SynStreamFrame) os.Error { - frame.CFHeader = h - var err os.Error - if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { - return err - } - if err = binary.Read(f.r, binary.BigEndian, &frame.AssociatedToStreamId); err != nil { - return err - } - if err = binary.Read(f.r, binary.BigEndian, &frame.Priority); err != nil { - return err - } - frame.Priority >>= 14 - - reader := f.r - if !f.headerCompressionDisabled { - f.uncorkHeaderDecompressor(int64(h.length - 10)) - reader = f.headerDecompressor - } - - frame.Headers, err = parseHeaderValueBlock(reader, frame.StreamId) - if !f.headerCompressionDisabled && ((err == os.EOF && f.headerReader.N == 0) || f.headerReader.N != 0) { - err = &Error{WrongCompressedPayloadSize, 0} - } - if err != nil { - return err - } - // Remove this condition when we bump Version to 3. - if Version >= 3 { - for h, _ := range frame.Headers { - if invalidReqHeaders[h] { - return &Error{InvalidHeaderPresent, frame.StreamId} - } - } - } - return nil -} - -func (f *Framer) readSynReplyFrame(h ControlFrameHeader, frame *SynReplyFrame) os.Error { - frame.CFHeader = h - var err os.Error - if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { - return err - } - var unused uint16 - if err = binary.Read(f.r, binary.BigEndian, &unused); err != nil { - return err - } - reader := f.r - if !f.headerCompressionDisabled { - f.uncorkHeaderDecompressor(int64(h.length - 6)) - reader = f.headerDecompressor - } - frame.Headers, err = parseHeaderValueBlock(reader, frame.StreamId) - if !f.headerCompressionDisabled && ((err == os.EOF && f.headerReader.N == 0) || f.headerReader.N != 0) { - err = &Error{WrongCompressedPayloadSize, 0} - } - if err != nil { - return err - } - // Remove this condition when we bump Version to 3. - if Version >= 3 { - for h, _ := range frame.Headers { - if invalidRespHeaders[h] { - return &Error{InvalidHeaderPresent, frame.StreamId} - } - } - } - return nil -} - -func (f *Framer) readHeadersFrame(h ControlFrameHeader, frame *HeadersFrame) os.Error { - frame.CFHeader = h - var err os.Error - if err = binary.Read(f.r, binary.BigEndian, &frame.StreamId); err != nil { - return err - } - var unused uint16 - if err = binary.Read(f.r, binary.BigEndian, &unused); err != nil { - return err - } - reader := f.r - if !f.headerCompressionDisabled { - f.uncorkHeaderDecompressor(int64(h.length - 6)) - reader = f.headerDecompressor - } - frame.Headers, err = parseHeaderValueBlock(reader, frame.StreamId) - if !f.headerCompressionDisabled && ((err == os.EOF && f.headerReader.N == 0) || f.headerReader.N != 0) { - err = &Error{WrongCompressedPayloadSize, 0} - } - if err != nil { - return err - } - - // Remove this condition when we bump Version to 3. - if Version >= 3 { - var invalidHeaders map[string]bool - if frame.StreamId%2 == 0 { - invalidHeaders = invalidReqHeaders - } else { - invalidHeaders = invalidRespHeaders - } - for h, _ := range frame.Headers { - if invalidHeaders[h] { - return &Error{InvalidHeaderPresent, frame.StreamId} - } - } - } - return nil -} - -func (f *Framer) parseDataFrame(streamId uint32) (*DataFrame, os.Error) { - var length uint32 - if err := binary.Read(f.r, binary.BigEndian, &length); err != nil { - return nil, err - } - var frame DataFrame - frame.StreamId = streamId - frame.Flags = DataFlags(length >> 24) - length &= 0xffffff - frame.Data = make([]byte, length) - if _, err := io.ReadFull(f.r, frame.Data); err != nil { - return nil, err - } - return &frame, nil -} diff --git a/libgo/go/http/spdy/spdy_test.go b/libgo/go/http/spdy/spdy_test.go deleted file mode 100644 index cb91e02..0000000 --- a/libgo/go/http/spdy/spdy_test.go +++ /dev/null @@ -1,497 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package spdy - -import ( - "bytes" - "http" - "io" - "reflect" - "testing" -) - -func TestHeaderParsing(t *testing.T) { - headers := http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - } - var headerValueBlockBuf bytes.Buffer - writeHeaderValueBlock(&headerValueBlockBuf, headers) - - const bogusStreamId = 1 - newHeaders, err := parseHeaderValueBlock(&headerValueBlockBuf, bogusStreamId) - if err != nil { - t.Fatal("parseHeaderValueBlock:", err) - } - - if !reflect.DeepEqual(headers, newHeaders) { - t.Fatal("got: ", newHeaders, "\nwant: ", headers) - } -} - -func TestCreateParseSynStreamFrame(t *testing.T) { - buffer := new(bytes.Buffer) - framer := &Framer{ - headerCompressionDisabled: true, - w: buffer, - headerBuf: new(bytes.Buffer), - r: buffer, - } - synStreamFrame := SynStreamFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeSynStream, - }, - Headers: http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - }, - } - if err := framer.WriteFrame(&synStreamFrame); err != nil { - t.Fatal("WriteFrame without compression:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame without compression:", err) - } - parsedSynStreamFrame, ok := frame.(*SynStreamFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { - t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) - } - - // Test again with compression - buffer.Reset() - framer, err = NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - if err := framer.WriteFrame(&synStreamFrame); err != nil { - t.Fatal("WriteFrame with compression:", err) - } - frame, err = framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame with compression:", err) - } - parsedSynStreamFrame, ok = frame.(*SynStreamFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { - t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) - } -} - -func TestCreateParseSynReplyFrame(t *testing.T) { - buffer := new(bytes.Buffer) - framer := &Framer{ - headerCompressionDisabled: true, - w: buffer, - headerBuf: new(bytes.Buffer), - r: buffer, - } - synReplyFrame := SynReplyFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeSynReply, - }, - Headers: http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - }, - } - if err := framer.WriteFrame(&synReplyFrame); err != nil { - t.Fatal("WriteFrame without compression:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame without compression:", err) - } - parsedSynReplyFrame, ok := frame.(*SynReplyFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(synReplyFrame, *parsedSynReplyFrame) { - t.Fatal("got: ", *parsedSynReplyFrame, "\nwant: ", synReplyFrame) - } - - // Test again with compression - buffer.Reset() - framer, err = NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - if err := framer.WriteFrame(&synReplyFrame); err != nil { - t.Fatal("WriteFrame with compression:", err) - } - frame, err = framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame with compression:", err) - } - parsedSynReplyFrame, ok = frame.(*SynReplyFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(synReplyFrame, *parsedSynReplyFrame) { - t.Fatal("got: ", *parsedSynReplyFrame, "\nwant: ", synReplyFrame) - } -} - -func TestCreateParseRstStream(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - rstStreamFrame := RstStreamFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeRstStream, - }, - StreamId: 1, - Status: InvalidStream, - } - if err := framer.WriteFrame(&rstStreamFrame); err != nil { - t.Fatal("WriteFrame:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame:", err) - } - parsedRstStreamFrame, ok := frame.(*RstStreamFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(rstStreamFrame, *parsedRstStreamFrame) { - t.Fatal("got: ", *parsedRstStreamFrame, "\nwant: ", rstStreamFrame) - } -} - -func TestCreateParseSettings(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - settingsFrame := SettingsFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeSettings, - }, - FlagIdValues: []SettingsFlagIdValue{ - {FlagSettingsPersistValue, SettingsCurrentCwnd, 10}, - {FlagSettingsPersisted, SettingsUploadBandwidth, 1}, - }, - } - if err := framer.WriteFrame(&settingsFrame); err != nil { - t.Fatal("WriteFrame:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame:", err) - } - parsedSettingsFrame, ok := frame.(*SettingsFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(settingsFrame, *parsedSettingsFrame) { - t.Fatal("got: ", *parsedSettingsFrame, "\nwant: ", settingsFrame) - } -} - -func TestCreateParseNoop(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - noopFrame := NoopFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeNoop, - }, - } - if err := framer.WriteFrame(&noopFrame); err != nil { - t.Fatal("WriteFrame:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame:", err) - } - parsedNoopFrame, ok := frame.(*NoopFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(noopFrame, *parsedNoopFrame) { - t.Fatal("got: ", *parsedNoopFrame, "\nwant: ", noopFrame) - } -} - -func TestCreateParsePing(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - pingFrame := PingFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypePing, - }, - Id: 31337, - } - if err := framer.WriteFrame(&pingFrame); err != nil { - t.Fatal("WriteFrame:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame:", err) - } - parsedPingFrame, ok := frame.(*PingFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(pingFrame, *parsedPingFrame) { - t.Fatal("got: ", *parsedPingFrame, "\nwant: ", pingFrame) - } -} - -func TestCreateParseGoAway(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - goAwayFrame := GoAwayFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeGoAway, - }, - LastGoodStreamId: 31337, - } - if err := framer.WriteFrame(&goAwayFrame); err != nil { - t.Fatal("WriteFrame:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame:", err) - } - parsedGoAwayFrame, ok := frame.(*GoAwayFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(goAwayFrame, *parsedGoAwayFrame) { - t.Fatal("got: ", *parsedGoAwayFrame, "\nwant: ", goAwayFrame) - } -} - -func TestCreateParseHeadersFrame(t *testing.T) { - buffer := new(bytes.Buffer) - framer := &Framer{ - headerCompressionDisabled: true, - w: buffer, - headerBuf: new(bytes.Buffer), - r: buffer, - } - headersFrame := HeadersFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeHeaders, - }, - } - headersFrame.Headers = http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - } - if err := framer.WriteFrame(&headersFrame); err != nil { - t.Fatal("WriteFrame without compression:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame without compression:", err) - } - parsedHeadersFrame, ok := frame.(*HeadersFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { - t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) - } - - // Test again with compression - buffer.Reset() - framer, err = NewFramer(buffer, buffer) - if err := framer.WriteFrame(&headersFrame); err != nil { - t.Fatal("WriteFrame with compression:", err) - } - frame, err = framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame with compression:", err) - } - parsedHeadersFrame, ok = frame.(*HeadersFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { - t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) - } -} - -func TestCreateParseDataFrame(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - dataFrame := DataFrame{ - StreamId: 1, - Data: []byte{'h', 'e', 'l', 'l', 'o'}, - } - if err := framer.WriteFrame(&dataFrame); err != nil { - t.Fatal("WriteFrame:", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame:", err) - } - parsedDataFrame, ok := frame.(*DataFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(dataFrame, *parsedDataFrame) { - t.Fatal("got: ", *parsedDataFrame, "\nwant: ", dataFrame) - } -} - -func TestCompressionContextAcrossFrames(t *testing.T) { - buffer := new(bytes.Buffer) - framer, err := NewFramer(buffer, buffer) - if err != nil { - t.Fatal("Failed to create new framer:", err) - } - headersFrame := HeadersFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeHeaders, - }, - Headers: http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - }, - } - if err := framer.WriteFrame(&headersFrame); err != nil { - t.Fatal("WriteFrame (HEADERS):", err) - } - synStreamFrame := SynStreamFrame{ControlFrameHeader{Version, TypeSynStream, 0, 0}, 0, 0, 0, nil} - synStreamFrame.Headers = http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - } - if err := framer.WriteFrame(&synStreamFrame); err != nil { - t.Fatal("WriteFrame (SYN_STREAM):", err) - } - frame, err := framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame (HEADERS):", err, buffer.Bytes()) - } - parsedHeadersFrame, ok := frame.(*HeadersFrame) - if !ok { - t.Fatalf("expected HeadersFrame; got %T %v", frame, frame) - } - if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { - t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) - } - frame, err = framer.ReadFrame() - if err != nil { - t.Fatal("ReadFrame (SYN_STREAM):", err, buffer.Bytes()) - } - parsedSynStreamFrame, ok := frame.(*SynStreamFrame) - if !ok { - t.Fatalf("expected SynStreamFrame; got %T %v", frame, frame) - } - if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { - t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) - } -} - -func TestMultipleSPDYFrames(t *testing.T) { - // Initialize the framers. - pr1, pw1 := io.Pipe() - pr2, pw2 := io.Pipe() - writer, err := NewFramer(pw1, pr2) - if err != nil { - t.Fatal("Failed to create writer:", err) - } - reader, err := NewFramer(pw2, pr1) - if err != nil { - t.Fatal("Failed to create reader:", err) - } - - // Set up the frames we're actually transferring. - headersFrame := HeadersFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeHeaders, - }, - Headers: http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - }, - } - synStreamFrame := SynStreamFrame{ - CFHeader: ControlFrameHeader{ - version: Version, - frameType: TypeSynStream, - }, - Headers: http.Header{ - "Url": []string{"http://www.google.com/"}, - "Method": []string{"get"}, - "Version": []string{"http/1.1"}, - }, - } - - // Start the goroutines to write the frames. - go func() { - if err := writer.WriteFrame(&headersFrame); err != nil { - t.Fatal("WriteFrame (HEADERS): ", err) - } - if err := writer.WriteFrame(&synStreamFrame); err != nil { - t.Fatal("WriteFrame (SYN_STREAM): ", err) - } - }() - - // Read the frames and verify they look as expected. - frame, err := reader.ReadFrame() - if err != nil { - t.Fatal("ReadFrame (HEADERS): ", err) - } - parsedHeadersFrame, ok := frame.(*HeadersFrame) - if !ok { - t.Fatal("Parsed incorrect frame type:", frame) - } - if !reflect.DeepEqual(headersFrame, *parsedHeadersFrame) { - t.Fatal("got: ", *parsedHeadersFrame, "\nwant: ", headersFrame) - } - frame, err = reader.ReadFrame() - if err != nil { - t.Fatal("ReadFrame (SYN_STREAM):", err) - } - parsedSynStreamFrame, ok := frame.(*SynStreamFrame) - if !ok { - t.Fatal("Parsed incorrect frame type.") - } - if !reflect.DeepEqual(synStreamFrame, *parsedSynStreamFrame) { - t.Fatal("got: ", *parsedSynStreamFrame, "\nwant: ", synStreamFrame) - } -} diff --git a/libgo/go/http/spdy/types.go b/libgo/go/http/spdy/types.go deleted file mode 100644 index 41cafb1..0000000 --- a/libgo/go/http/spdy/types.go +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package spdy - -import ( - "bytes" - "compress/zlib" - "http" - "io" - "os" -) - -// Data Frame Format -// +----------------------------------+ -// |0| Stream-ID (31bits) | -// +----------------------------------+ -// | flags (8) | Length (24 bits) | -// +----------------------------------+ -// | Data | -// +----------------------------------+ -// -// Control Frame Format -// +----------------------------------+ -// |1| Version(15bits) | Type(16bits) | -// +----------------------------------+ -// | flags (8) | Length (24 bits) | -// +----------------------------------+ -// | Data | -// +----------------------------------+ -// -// Control Frame: SYN_STREAM -// +----------------------------------+ -// |1|000000000000001|0000000000000001| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | >= 12 -// +----------------------------------+ -// |X| Stream-ID(31bits) | -// +----------------------------------+ -// |X|Associated-To-Stream-ID (31bits)| -// +----------------------------------+ -// |Pri| unused | Length (16bits)| -// +----------------------------------+ -// -// Control Frame: SYN_REPLY -// +----------------------------------+ -// |1|000000000000001|0000000000000010| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | >= 8 -// +----------------------------------+ -// |X| Stream-ID(31bits) | -// +----------------------------------+ -// | unused (16 bits)| Length (16bits)| -// +----------------------------------+ -// -// Control Frame: RST_STREAM -// +----------------------------------+ -// |1|000000000000001|0000000000000011| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | >= 4 -// +----------------------------------+ -// |X| Stream-ID(31bits) | -// +----------------------------------+ -// | Status code (32 bits) | -// +----------------------------------+ -// -// Control Frame: SETTINGS -// +----------------------------------+ -// |1|000000000000001|0000000000000100| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | -// +----------------------------------+ -// | # of entries (32) | -// +----------------------------------+ -// -// Control Frame: NOOP -// +----------------------------------+ -// |1|000000000000001|0000000000000101| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | = 0 -// +----------------------------------+ -// -// Control Frame: PING -// +----------------------------------+ -// |1|000000000000001|0000000000000110| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | = 4 -// +----------------------------------+ -// | Unique id (32 bits) | -// +----------------------------------+ -// -// Control Frame: GOAWAY -// +----------------------------------+ -// |1|000000000000001|0000000000000111| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | = 4 -// +----------------------------------+ -// |X| Last-accepted-stream-id | -// +----------------------------------+ -// -// Control Frame: HEADERS -// +----------------------------------+ -// |1|000000000000001|0000000000001000| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | >= 8 -// +----------------------------------+ -// |X| Stream-ID (31 bits) | -// +----------------------------------+ -// | unused (16 bits)| Length (16bits)| -// +----------------------------------+ -// -// Control Frame: WINDOW_UPDATE -// +----------------------------------+ -// |1|000000000000001|0000000000001001| -// +----------------------------------+ -// | flags (8) | Length (24 bits) | = 8 -// +----------------------------------+ -// |X| Stream-ID (31 bits) | -// +----------------------------------+ -// | Delta-Window-Size (32 bits) | -// +----------------------------------+ - -// Version is the protocol version number that this package implements. -const Version = 2 - -// ControlFrameType stores the type field in a control frame header. -type ControlFrameType uint16 - -// Control frame type constants -const ( - TypeSynStream ControlFrameType = 0x0001 - TypeSynReply = 0x0002 - TypeRstStream = 0x0003 - TypeSettings = 0x0004 - TypeNoop = 0x0005 - TypePing = 0x0006 - TypeGoAway = 0x0007 - TypeHeaders = 0x0008 - TypeWindowUpdate = 0x0009 -) - -// ControlFlags are the flags that can be set on a control frame. -type ControlFlags uint8 - -const ( - ControlFlagFin ControlFlags = 0x01 -) - -// DataFlags are the flags that can be set on a data frame. -type DataFlags uint8 - -const ( - DataFlagFin DataFlags = 0x01 - DataFlagCompressed = 0x02 -) - -// MaxDataLength is the maximum number of bytes that can be stored in one frame. -const MaxDataLength = 1<<24 - 1 - -// Frame is a single SPDY frame in its unpacked in-memory representation. Use -// Framer to read and write it. -type Frame interface { - write(f *Framer) os.Error -} - -// ControlFrameHeader contains all the fields in a control frame header, -// in its unpacked in-memory representation. -type ControlFrameHeader struct { - // Note, high bit is the "Control" bit. - version uint16 - frameType ControlFrameType - Flags ControlFlags - length uint32 -} - -type controlFrame interface { - Frame - read(h ControlFrameHeader, f *Framer) os.Error -} - -// SynStreamFrame is the unpacked, in-memory representation of a SYN_STREAM -// frame. -type SynStreamFrame struct { - CFHeader ControlFrameHeader - StreamId uint32 - AssociatedToStreamId uint32 - // Note, only 2 highest bits currently used - // Rest of Priority is unused. - Priority uint16 - Headers http.Header -} - -// SynReplyFrame is the unpacked, in-memory representation of a SYN_REPLY frame. -type SynReplyFrame struct { - CFHeader ControlFrameHeader - StreamId uint32 - Headers http.Header -} - -// StatusCode represents the status that led to a RST_STREAM -type StatusCode uint32 - -const ( - ProtocolError StatusCode = 1 - InvalidStream = 2 - RefusedStream = 3 - UnsupportedVersion = 4 - Cancel = 5 - InternalError = 6 - FlowControlError = 7 -) - -// RstStreamFrame is the unpacked, in-memory representation of a RST_STREAM -// frame. -type RstStreamFrame struct { - CFHeader ControlFrameHeader - StreamId uint32 - Status StatusCode -} - -// SettingsFlag represents a flag in a SETTINGS frame. -type SettingsFlag uint8 - -const ( - FlagSettingsPersistValue SettingsFlag = 0x1 - FlagSettingsPersisted = 0x2 -) - -// SettingsFlag represents the id of an id/value pair in a SETTINGS frame. -type SettingsId uint32 - -const ( - SettingsUploadBandwidth SettingsId = 1 - SettingsDownloadBandwidth = 2 - SettingsRoundTripTime = 3 - SettingsMaxConcurrentStreams = 4 - SettingsCurrentCwnd = 5 -) - -// SettingsFlagIdValue is the unpacked, in-memory representation of the -// combined flag/id/value for a setting in a SETTINGS frame. -type SettingsFlagIdValue struct { - Flag SettingsFlag - Id SettingsId - Value uint32 -} - -// SettingsFrame is the unpacked, in-memory representation of a SPDY -// SETTINGS frame. -type SettingsFrame struct { - CFHeader ControlFrameHeader - FlagIdValues []SettingsFlagIdValue -} - -// NoopFrame is the unpacked, in-memory representation of a NOOP frame. -type NoopFrame struct { - CFHeader ControlFrameHeader -} - -// PingFrame is the unpacked, in-memory representation of a PING frame. -type PingFrame struct { - CFHeader ControlFrameHeader - Id uint32 -} - -// GoAwayFrame is the unpacked, in-memory representation of a GOAWAY frame. -type GoAwayFrame struct { - CFHeader ControlFrameHeader - LastGoodStreamId uint32 -} - -// HeadersFrame is the unpacked, in-memory representation of a HEADERS frame. -type HeadersFrame struct { - CFHeader ControlFrameHeader - StreamId uint32 - Headers http.Header -} - -// DataFrame is the unpacked, in-memory representation of a DATA frame. -type DataFrame struct { - // Note, high bit is the "Control" bit. Should be 0 for data frames. - StreamId uint32 - Flags DataFlags - Data []byte -} - -// HeaderDictionary is the dictionary sent to the zlib compressor/decompressor. -// Even though the specification states there is no null byte at the end, Chrome sends it. -const HeaderDictionary = "optionsgetheadpostputdeletetrace" + - "acceptaccept-charsetaccept-encodingaccept-languageauthorizationexpectfromhost" + - "if-modified-sinceif-matchif-none-matchif-rangeif-unmodifiedsince" + - "max-forwardsproxy-authorizationrangerefererteuser-agent" + - "100101200201202203204205206300301302303304305306307400401402403404405406407408409410411412413414415416417500501502503504505" + - "accept-rangesageetaglocationproxy-authenticatepublicretry-after" + - "servervarywarningwww-authenticateallowcontent-basecontent-encodingcache-control" + - "connectiondatetrailertransfer-encodingupgradeviawarning" + - "content-languagecontent-lengthcontent-locationcontent-md5content-rangecontent-typeetagexpireslast-modifiedset-cookie" + - "MondayTuesdayWednesdayThursdayFridaySaturdaySunday" + - "JanFebMarAprMayJunJulAugSepOctNovDec" + - "chunkedtext/htmlimage/pngimage/jpgimage/gifapplication/xmlapplication/xhtmltext/plainpublicmax-age" + - "charset=iso-8859-1utf-8gzipdeflateHTTP/1.1statusversionurl\x00" - -// A SPDY specific error. -type ErrorCode string - -const ( - UnlowercasedHeaderName ErrorCode = "header was not lowercased" - DuplicateHeaders ErrorCode = "multiple headers with same name" - WrongCompressedPayloadSize ErrorCode = "compressed payload size was incorrect" - UnknownFrameType ErrorCode = "unknown frame type" - InvalidControlFrame ErrorCode = "invalid control frame" - InvalidDataFrame ErrorCode = "invalid data frame" - InvalidHeaderPresent ErrorCode = "frame contained invalid header" -) - -// Error contains both the type of error and additional values. StreamId is 0 -// if Error is not associated with a stream. -type Error struct { - Err ErrorCode - StreamId uint32 -} - -func (e *Error) String() string { - return string(e.Err) -} - -var invalidReqHeaders = map[string]bool{ - "Connection": true, - "Keep-Alive": true, - "Proxy-Connection": true, - "Transfer-Encoding": true, -} - -var invalidRespHeaders = map[string]bool{ - "Connection": true, - "Keep-Alive": true, - "Transfer-Encoding": true, -} - -// Framer handles serializing/deserializing SPDY frames, including compressing/ -// decompressing payloads. -type Framer struct { - headerCompressionDisabled bool - w io.Writer - headerBuf *bytes.Buffer - headerCompressor *zlib.Writer - r io.Reader - headerReader io.LimitedReader - headerDecompressor io.ReadCloser -} - -// NewFramer allocates a new Framer for a given SPDY connection, repesented by -// a io.Writer and io.Reader. Note that Framer will read and write individual fields -// from/to the Reader and Writer, so the caller should pass in an appropriately -// buffered implementation to optimize performance. -func NewFramer(w io.Writer, r io.Reader) (*Framer, os.Error) { - compressBuf := new(bytes.Buffer) - compressor, err := zlib.NewWriterDict(compressBuf, zlib.BestCompression, []byte(HeaderDictionary)) - if err != nil { - return nil, err - } - framer := &Framer{ - w: w, - headerBuf: compressBuf, - headerCompressor: compressor, - r: r, - } - return framer, nil -} diff --git a/libgo/go/http/spdy/write.go b/libgo/go/http/spdy/write.go deleted file mode 100644 index 7d40bbe..0000000 --- a/libgo/go/http/spdy/write.go +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package spdy - -import ( - "encoding/binary" - "http" - "io" - "os" - "strings" -) - -func (frame *SynStreamFrame) write(f *Framer) os.Error { - return f.writeSynStreamFrame(frame) -} - -func (frame *SynReplyFrame) write(f *Framer) os.Error { - return f.writeSynReplyFrame(frame) -} - -func (frame *RstStreamFrame) write(f *Framer) (err os.Error) { - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeRstStream - frame.CFHeader.length = 8 - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, frame.Status); err != nil { - return - } - return -} - -func (frame *SettingsFrame) write(f *Framer) (err os.Error) { - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeSettings - frame.CFHeader.length = uint32(len(frame.FlagIdValues)*8 + 4) - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, uint32(len(frame.FlagIdValues))); err != nil { - return - } - for _, flagIdValue := range frame.FlagIdValues { - flagId := (uint32(flagIdValue.Flag) << 24) | uint32(flagIdValue.Id) - if err = binary.Write(f.w, binary.BigEndian, flagId); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, flagIdValue.Value); err != nil { - return - } - } - return -} - -func (frame *NoopFrame) write(f *Framer) os.Error { - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeNoop - - // Serialize frame to Writer - return writeControlFrameHeader(f.w, frame.CFHeader) -} - -func (frame *PingFrame) write(f *Framer) (err os.Error) { - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypePing - frame.CFHeader.length = 4 - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, frame.Id); err != nil { - return - } - return -} - -func (frame *GoAwayFrame) write(f *Framer) (err os.Error) { - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeGoAway - frame.CFHeader.length = 4 - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, frame.LastGoodStreamId); err != nil { - return - } - return nil -} - -func (frame *HeadersFrame) write(f *Framer) os.Error { - return f.writeHeadersFrame(frame) -} - -func (frame *DataFrame) write(f *Framer) os.Error { - return f.writeDataFrame(frame) -} - -// WriteFrame writes a frame. -func (f *Framer) WriteFrame(frame Frame) os.Error { - return frame.write(f) -} - -func writeControlFrameHeader(w io.Writer, h ControlFrameHeader) os.Error { - if err := binary.Write(w, binary.BigEndian, 0x8000|h.version); err != nil { - return err - } - if err := binary.Write(w, binary.BigEndian, h.frameType); err != nil { - return err - } - flagsAndLength := (uint32(h.Flags) << 24) | h.length - if err := binary.Write(w, binary.BigEndian, flagsAndLength); err != nil { - return err - } - return nil -} - -func writeHeaderValueBlock(w io.Writer, h http.Header) (n int, err os.Error) { - n = 0 - if err = binary.Write(w, binary.BigEndian, uint16(len(h))); err != nil { - return - } - n += 2 - for name, values := range h { - if err = binary.Write(w, binary.BigEndian, uint16(len(name))); err != nil { - return - } - n += 2 - name = strings.ToLower(name) - if _, err = io.WriteString(w, name); err != nil { - return - } - n += len(name) - v := strings.Join(values, "\x00") - if err = binary.Write(w, binary.BigEndian, uint16(len(v))); err != nil { - return - } - n += 2 - if _, err = io.WriteString(w, v); err != nil { - return - } - n += len(v) - } - return -} - -func (f *Framer) writeSynStreamFrame(frame *SynStreamFrame) (err os.Error) { - // Marshal the headers. - var writer io.Writer = f.headerBuf - if !f.headerCompressionDisabled { - writer = f.headerCompressor - } - if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { - return - } - if !f.headerCompressionDisabled { - f.headerCompressor.Flush() - } - - // Set ControlFrameHeader - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeSynStream - frame.CFHeader.length = uint32(len(f.headerBuf.Bytes()) + 10) - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return err - } - if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { - return err - } - if err = binary.Write(f.w, binary.BigEndian, frame.AssociatedToStreamId); err != nil { - return err - } - if err = binary.Write(f.w, binary.BigEndian, frame.Priority<<14); err != nil { - return err - } - if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { - return err - } - f.headerBuf.Reset() - return nil -} - -func (f *Framer) writeSynReplyFrame(frame *SynReplyFrame) (err os.Error) { - // Marshal the headers. - var writer io.Writer = f.headerBuf - if !f.headerCompressionDisabled { - writer = f.headerCompressor - } - if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { - return - } - if !f.headerCompressionDisabled { - f.headerCompressor.Flush() - } - - // Set ControlFrameHeader - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeSynReply - frame.CFHeader.length = uint32(len(f.headerBuf.Bytes()) + 6) - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, uint16(0)); err != nil { - return - } - if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { - return - } - f.headerBuf.Reset() - return -} - -func (f *Framer) writeHeadersFrame(frame *HeadersFrame) (err os.Error) { - // Marshal the headers. - var writer io.Writer = f.headerBuf - if !f.headerCompressionDisabled { - writer = f.headerCompressor - } - if _, err = writeHeaderValueBlock(writer, frame.Headers); err != nil { - return - } - if !f.headerCompressionDisabled { - f.headerCompressor.Flush() - } - - // Set ControlFrameHeader - frame.CFHeader.version = Version - frame.CFHeader.frameType = TypeHeaders - frame.CFHeader.length = uint32(len(f.headerBuf.Bytes()) + 6) - - // Serialize frame to Writer - if err = writeControlFrameHeader(f.w, frame.CFHeader); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { - return - } - if err = binary.Write(f.w, binary.BigEndian, uint16(0)); err != nil { - return - } - if _, err = f.w.Write(f.headerBuf.Bytes()); err != nil { - return - } - f.headerBuf.Reset() - return -} - -func (f *Framer) writeDataFrame(frame *DataFrame) (err os.Error) { - // Validate DataFrame - if frame.StreamId&0x80000000 != 0 || len(frame.Data) >= 0x0f000000 { - return &Error{InvalidDataFrame, frame.StreamId} - } - - // Serialize frame to Writer - if err = binary.Write(f.w, binary.BigEndian, frame.StreamId); err != nil { - return - } - flagsAndLength := (uint32(frame.Flags) << 24) | uint32(len(frame.Data)) - if err = binary.Write(f.w, binary.BigEndian, flagsAndLength); err != nil { - return - } - if _, err = f.w.Write(frame.Data); err != nil { - return - } - - return nil -} diff --git a/libgo/go/http/transfer.go b/libgo/go/http/transfer.go index b65d99a..868a114 100644 --- a/libgo/go/http/transfer.go +++ b/libgo/go/http/transfer.go @@ -7,6 +7,7 @@ package http import ( "bytes" "bufio" + "fmt" "io" "io/ioutil" "os" @@ -18,10 +19,11 @@ import ( // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. type transferWriter struct { + Method string Body io.Reader BodyCloser io.Closer ResponseToHEAD bool - ContentLength int64 + ContentLength int64 // -1 means unknown, 0 means exactly none Close bool TransferEncoding []string Trailer Header @@ -34,6 +36,10 @@ func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { atLeastHTTP11 := false switch rr := r.(type) { case *Request: + if rr.ContentLength != 0 && rr.Body == nil { + return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) + } + t.Method = rr.Method t.Body = rr.Body t.BodyCloser = rr.Body t.ContentLength = rr.ContentLength @@ -64,6 +70,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err os.Error) { } } case *Response: + t.Method = rr.Request.Method t.Body = rr.Body t.BodyCloser = rr.Body t.ContentLength = rr.ContentLength @@ -105,6 +112,27 @@ func noBodyExpected(requestMethod string) bool { return requestMethod == "HEAD" } +func (t *transferWriter) shouldSendContentLength() bool { + if chunked(t.TransferEncoding) { + return false + } + if t.ContentLength > 0 { + return true + } + if t.ResponseToHEAD { + return true + } + // Many servers expect a Content-Length for these methods + if t.Method == "POST" || t.Method == "PUT" { + return true + } + if t.ContentLength == 0 && isIdentity(t.TransferEncoding) { + return true + } + + return false +} + func (t *transferWriter) WriteHeader(w io.Writer) (err os.Error) { if t.Close { _, err = io.WriteString(w, "Connection: close\r\n") @@ -116,14 +144,14 @@ func (t *transferWriter) WriteHeader(w io.Writer) (err os.Error) { // Write Content-Length and/or Transfer-Encoding whose values are a // function of the sanitized field triple (Body, ContentLength, // TransferEncoding) - if chunked(t.TransferEncoding) { - _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n") + if t.shouldSendContentLength() { + io.WriteString(w, "Content-Length: ") + _, err = io.WriteString(w, strconv.Itoa64(t.ContentLength)+"\r\n") if err != nil { return } - } else if t.ContentLength > 0 || t.ResponseToHEAD || (t.ContentLength == 0 && isIdentity(t.TransferEncoding)) { - io.WriteString(w, "Content-Length: ") - _, err = io.WriteString(w, strconv.Itoa64(t.ContentLength)+"\r\n") + } else if chunked(t.TransferEncoding) { + _, err = io.WriteString(w, "Transfer-Encoding: chunked\r\n") if err != nil { return } @@ -154,6 +182,8 @@ func (t *transferWriter) WriteHeader(w io.Writer) (err os.Error) { } func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) { + var ncopy int64 + // Write body if t.Body != nil { if chunked(t.TransferEncoding) { @@ -163,9 +193,14 @@ func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) { err = cw.Close() } } else if t.ContentLength == -1 { - _, err = io.Copy(w, t.Body) + ncopy, err = io.Copy(w, t.Body) } else { - _, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) + ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) + nextra, err := io.Copy(ioutil.Discard, t.Body) + if err != nil { + return err + } + ncopy += nextra } if err != nil { return err @@ -175,6 +210,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err os.Error) { } } + if t.ContentLength != -1 && t.ContentLength != ncopy { + return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", + t.ContentLength, ncopy) + } + // TODO(petar): Place trailer writer code here. if chunked(t.TransferEncoding) { // Last chunk, empty trailer @@ -326,7 +366,7 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, os.Erro return nil, nil } - header["Transfer-Encoding"] = nil, false + delete(header, "Transfer-Encoding") // Head responses have no bodies, so the transfer encoding // should be ignored. @@ -359,7 +399,7 @@ func fixTransferEncoding(requestMethod string, header Header) ([]string, os.Erro // Chunked encoding trumps Content-Length. See RFC 2616 // Section 4.4. Currently len(te) > 0 implies chunked // encoding. - header["Content-Length"] = nil, false + delete(header, "Content-Length") return te, nil } @@ -478,6 +518,8 @@ type body struct { r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? closed bool + + res *response // response writer for server requests, else nil } // ErrBodyReadAfterClose is returned when reading a Request Body after @@ -506,6 +548,15 @@ func (b *body) Close() os.Error { return nil } + // In a server request, don't continue reading from the client + // if we've already hit the maximum body size set by the + // handler. If this is set, that also means the TCP connection + // is about to be closed, so getting to the next HTTP request + // in the stream is not necessary. + if b.res != nil && b.res.requestBodyLimitHit { + return nil + } + if _, err := io.Copy(ioutil.Discard, b); err != nil { return err } diff --git a/libgo/go/http/transport.go b/libgo/go/http/transport.go index 4302ffa..0914af7 100644 --- a/libgo/go/http/transport.go +++ b/libgo/go/http/transport.go @@ -54,6 +54,10 @@ type Transport struct { // If Dial is nil, net.Dial is used. Dial func(net, addr string) (c net.Conn, err os.Error) + // TLSClientConfig specifies the TLS configuration to use with + // tls.Client. If nil, the default configuration is used. + TLSClientConfig *tls.Config + DisableKeepAlives bool DisableCompression bool @@ -96,12 +100,27 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) { } } +// transportRequest is a wrapper around a *Request that adds +// optional extra headers to write. +type transportRequest struct { + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil +} + +func (tr *transportRequest) extraHeaders() Header { + if tr.extra == nil { + tr.extra = make(Header) + } + return tr.extra +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { if req.URL == nil { - if req.URL, err = url.Parse(req.RawURL); err != nil { - return - } + return nil, os.NewError("http: nil Request.URL") + } + if req.Header == nil { + return nil, os.NewError("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { t.lk.Lock() @@ -115,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { } return rt.RoundTrip(req) } - - cm, err := t.connectMethodForRequest(req) + treq := &transportRequest{Request: req} + cm, err := t.connectMethodForRequest(treq) if err != nil { return nil, err } @@ -130,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { return nil, err } - return pconn.roundTrip(req) + return pconn.roundTrip(treq) } // RegisterProtocol registers a new protocol with scheme. @@ -183,14 +202,14 @@ func getenvEitherCase(k string) string { return os.Getenv(strings.ToLower(k)) } -func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { +func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) { cm := &connectMethod{ - targetScheme: req.URL.Scheme, - targetAddr: canonicalAddr(req.URL), + targetScheme: treq.URL.Scheme, + targetAddr: canonicalAddr(treq.URL), } if t.Proxy != nil { var err os.Error - cm.proxyURL, err = t.Proxy(req) + cm.proxyURL, err = t.Proxy(treq.Request) if err != nil { return nil, err } @@ -247,7 +266,7 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { } if len(pconns) == 1 { pconn = pconns[0] - t.idleConn[key] = nil, false + delete(t.idleConn, key) } else { // 2 or more cached connections; pop last // TODO: queue? @@ -293,25 +312,21 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { conn: conn, reqch: make(chan requestAndChan, 50), } - newClientConnFunc := NewClientConn switch { case cm.proxyURL == nil: // Do nothing. case cm.targetScheme == "http": - newClientConnFunc = NewProxyClientConn + pconn.isProxy = true if pa != "" { - pconn.mutateRequestFunc = func(req *Request) { - if req.Header == nil { - req.Header = make(Header) - } - req.Header.Set("Proxy-Authorization", pa) + pconn.mutateHeaderFunc = func(h Header) { + h.Set("Proxy-Authorization", pa) } } case cm.targetScheme == "https": connectReq := &Request{ Method: "CONNECT", - RawURL: cm.targetAddr, + URL: &url.URL{RawPath: cm.targetAddr}, Host: cm.targetAddr, Header: make(Header), } @@ -338,7 +353,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. - conn = tls.Client(conn, nil) + conn = tls.Client(conn, t.TLSClientConfig) if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err } @@ -349,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { } pconn.br = bufio.NewReader(pconn.conn) - pconn.cc = newClientConnFunc(conn, pconn.br) + pconn.cc = NewClientConn(conn, pconn.br) go pconn.readLoop() return pconn, nil } @@ -445,30 +460,21 @@ func (cm *connectMethod) tlsHost() string { return h } -type readResult struct { - res *Response // either res or err will be set - err os.Error -} - -type writeRequest struct { - // Set by client (in pc.roundTrip) - req *Request - resch chan *readResult - - // Set by writeLoop if an error writing headers. - writeErr os.Error -} - // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { - t *Transport - cacheKey string // its connectMethod.String() - conn net.Conn - cc *ClientConn - br *bufio.Reader - reqch chan requestAndChan // written by roundTrip(); read by readLoop() - mutateRequestFunc func(*Request) // nil or func to modify each outbound request + t *Transport + cacheKey string // its connectMethod.String() + conn net.Conn + cc *ClientConn + br *bufio.Reader + reqch chan requestAndChan // written by roundTrip(); read by readLoop() + isProxy bool + + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) lk sync.Mutex // guards numExpectedResponses and broken numExpectedResponses int @@ -487,12 +493,24 @@ func (pc *persistConn) expectingResponse() bool { return pc.numExpectedResponses > 0 } +var remoteSideClosedFunc func(os.Error) bool // or nil to use default + +func remoteSideClosed(err os.Error) bool { + if err == os.EOF || err == os.EINVAL { + return true + } + if remoteSideClosedFunc != nil { + return remoteSideClosedFunc(err) + } + return false +} + func (pc *persistConn) readLoop() { alive := true for alive { pb, err := pc.br.Peek(1) if err != nil { - if (err == os.EOF || err == os.EINVAL) && !pc.expectingResponse() { + if remoteSideClosed(err) && !pc.expectingResponse() { // Remote side closed on us. (We probably hit their // max idle timeout) pc.close() @@ -512,9 +530,6 @@ func (pc *persistConn) readLoop() { if err != nil || resp.ContentLength == 0 { return resp, err } - if rc.addedGzip { - forReq.Header.Del("Accept-Encoding") - } if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -590,9 +605,9 @@ type requestAndChan struct { addedGzip bool } -func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { - if pc.mutateRequestFunc != nil { - pc.mutateRequestFunc(req) +func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) { + if pc.mutateHeaderFunc != nil { + pc.mutateHeaderFunc(req.extraHeaders()) } // Ask for a compressed version if the caller didn't set their @@ -602,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { requestedGzip := false if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { // Request gzip only, not deflate. Deflate is ambiguous and - // as universally supported anyway. + // not as universally supported anyway. // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 requestedGzip = true - req.Header.Set("Accept-Encoding", "gzip") + req.extraHeaders().Set("Accept-Encoding", "gzip") } pc.lk.Lock() pc.numExpectedResponses++ pc.lk.Unlock() - err = pc.cc.Write(req) + pc.cc.writeReq = func(r *Request, w io.Writer) os.Error { + return r.write(w, pc.isProxy, req.extra) + } + + err = pc.cc.Write(req.Request) if err != nil { pc.close() return } ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req, ch, requestedGzip} + pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} re := <-ch pc.lk.Lock() pc.numExpectedResponses-- @@ -634,7 +653,7 @@ func (pc *persistConn) close() { pc.broken = true pc.cc.Close() pc.conn.Close() - pc.mutateRequestFunc = nil + pc.mutateHeaderFunc = nil } var portMap = map[string]string{ diff --git a/libgo/go/http/transport_test.go b/libgo/go/http/transport_test.go index eafde7f..f3162b9 100644 --- a/libgo/go/http/transport_test.go +++ b/libgo/go/http/transport_test.go @@ -78,7 +78,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { fetch := func(n int) string { req := new(Request) var err os.Error - req.URL, err = url.Parse(ts.URL + fmt.Sprintf("?close=%v", connectionClose)) + req.URL, err = url.Parse(ts.URL + fmt.Sprintf("/?close=%v", connectionClose)) if err != nil { t.Fatalf("URL parse error: %v", err) } @@ -362,32 +362,6 @@ func TestTransportHeadChunkedResponse(t *testing.T) { } } -func TestTransportNilURL(t *testing.T) { - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "Hi") - })) - defer ts.Close() - - req := new(Request) - req.URL = nil // what we're actually testing - req.Method = "GET" - req.RawURL = ts.URL - req.Proto = "HTTP/1.1" - req.ProtoMajor = 1 - req.ProtoMinor = 1 - req.Header = make(Header) - - tr := &Transport{} - res, err := tr.RoundTrip(req) - if err != nil { - t.Fatalf("unexpected RoundTrip error: %v", err) - } - body, err := ioutil.ReadAll(res.Body) - if g, e := string(body), "Hi"; g != e { - t.Fatalf("Expected response body of %q; got %q", e, g) - } -} - var roundTripTests = []struct { accept string expectAccept string @@ -398,7 +372,8 @@ var roundTripTests = []struct { // Requests with other accept-encoding should pass through unmodified {"foo", "foo", false}, // Requests with accept-encoding == gzip should be passed through - {"gzip", "gzip", true}} + {"gzip", "gzip", true}, +} // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { @@ -406,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) { ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { - t.Errorf("Accept-Encoding = %q, want %q", accept, expect) + t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", + req.FormValue("testnum"), accept, expect) } if accept == "gzip" { rw.Header().Set("Content-Encoding", "gzip") @@ -422,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) { for i, test := range roundTripTests { // Test basic request (no accept-encoding) - req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil) - req.Header.Set("Accept-Encoding", test.accept) + req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) + if test.accept != "" { + req.Header.Set("Accept-Encoding", test.accept) + } res, err := DefaultTransport.RoundTrip(req) var body []byte if test.compressed { @@ -435,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) { } if err != nil { t.Errorf("%d. Error: %q", i, err) - } else { - if g, e := string(body), responseBody; g != e { - t.Errorf("%d. body = %q; want %q", i, g, e) - } - if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { - t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e) - } - if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { - t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) - } + continue + } + if g, e := string(body), responseBody; g != e { + t.Errorf("%d. body = %q; want %q", i, g, e) + } + if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { + t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) + } + if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { + t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) } } @@ -474,7 +452,7 @@ func TestTransportGzip(t *testing.T) { gz, _ := gzip.NewWriter(w) gz.Write([]byte(testString)) if req.FormValue("body") == "large" { - io.Copyn(gz, rand.Reader, nRandBytes) + io.CopyN(gz, rand.Reader, nRandBytes) } gz.Close() })) @@ -484,7 +462,7 @@ func TestTransportGzip(t *testing.T) { c := &Client{Transport: &Transport{}} // First fetch something large, but only read some of it. - res, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) + res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) if err != nil { t.Fatalf("large get: %v", err) } @@ -504,7 +482,7 @@ func TestTransportGzip(t *testing.T) { } // Then something small. - res, err = c.Get(ts.URL + "?chunked=" + chunked) + res, err = c.Get(ts.URL + "/?chunked=" + chunked) if err != nil { t.Fatal(err) } diff --git a/libgo/go/http/transport_windows.go b/libgo/go/http/transport_windows.go new file mode 100644 index 0000000..1ae7d83 --- /dev/null +++ b/libgo/go/http/transport_windows.go @@ -0,0 +1,21 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "os" + "net" +) + +func init() { + remoteSideClosedFunc = func(err os.Error) (out bool) { + op, ok := err.(*net.OpError) + if ok && op.Op == "WSARecv" && op.Net == "tcp" && op.Error == os.Errno(10058) { + // TODO(bradfitz): find the symbol for 10058 + return true + } + return false + } +} |