diff options
Diffstat (limited to 'libgo/go/net/http/transfer.go')
-rw-r--r-- | libgo/go/net/http/transfer.go | 63 |
1 files changed, 29 insertions, 34 deletions
diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 25b34ad..53569bc 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -194,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { ncopy, err = io.Copy(w, t.Body) } else { ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) - nextra, err := io.Copy(ioutil.Discard, t.Body) if err != nil { return err } + var nextra int64 + nextra, err = io.Copy(ioutil.Discard, t.Body) ncopy += nextra } if err != nil { @@ -208,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { } } - if t.ContentLength != -1 && t.ContentLength != ncopy { + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", t.ContentLength, ncopy) } @@ -326,9 +327,14 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} - case realLength >= 0: - // TODO: limit the Content-Length. This is an easy DoS vector. + if noBodyExpected(t.RequestMethod) { + t.Body = &body{Reader: eofReader, closing: t.Close} + } else { + t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = &body{Reader: eofReader, closing: t.Close} + case realLength > 0: t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header @@ -337,7 +343,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{Reader: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + t.Body = &body{Reader: eofReader, closing: t.Close} } } @@ -449,13 +455,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - // Logic based on media type. The purpose of the following code is just - // to detect whether the unsupported "multipart/byteranges" is being - // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header.get("Content-Type")), "multipart/byteranges") { - return -1, ErrNotSupported - } - // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -614,30 +613,26 @@ func (b *body) Close() error { if b.closed { return nil } - defer func() { - b.closed = true - }() - if b.hdr == nil && b.closing { + var err error + switch { + case b.hdr == nil && b.closing: // no trailer and closing the connection next. // no point in reading to EOF. - return nil - } - - // In a server request, don't continue reading from the client - // if we've already hit the maximum body size set by the - // handler. If this is set, that also means the TCP connection - // is about to be closed, so getting to the next HTTP request - // in the stream is not necessary. - if b.res != nil && b.res.requestBodyLimitHit { - return nil - } - - // Fully consume the body, which will also lead to us reading - // the trailer headers after the body, if present. - if _, err := io.Copy(ioutil.Discard, b); err != nil { - return err + case b.res != nil && b.res.requestBodyLimitHit: + // In a server request, don't continue reading from the client + // if we've already hit the maximum body size set by the + // handler. If this is set, that also means the TCP connection + // is about to be closed, so getting to the next HTTP request + // in the stream is not necessary. + case b.Reader == eofReader: + // Nothing to read. No need to io.Copy from it. + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(ioutil.Discard, b) } - return nil + b.closed = true + return err } // parseContentLength trims whitespace from s and returns -1 if no value |