diff options
author | Ian Lance Taylor <ian@gcc.gnu.org> | 2013-07-16 06:54:42 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2013-07-16 06:54:42 +0000 |
commit | be47d6eceffd2c5dbbc1566d5eea490527fb2bd4 (patch) | |
tree | 0e8fda573576bb4181dba29d0e88380a8c38fafd /libgo/go/encoding | |
parent | efb30cdeb003fd7c585ee0d7657340086abcbd9e (diff) | |
download | gcc-be47d6eceffd2c5dbbc1566d5eea490527fb2bd4.zip gcc-be47d6eceffd2c5dbbc1566d5eea490527fb2bd4.tar.gz gcc-be47d6eceffd2c5dbbc1566d5eea490527fb2bd4.tar.bz2 |
libgo: Update to Go 1.1.1.
From-SVN: r200974
Diffstat (limited to 'libgo/go/encoding')
27 files changed, 1129 insertions, 216 deletions
diff --git a/libgo/go/encoding/ascii85/ascii85.go b/libgo/go/encoding/ascii85/ascii85.go index 7050227..e2afc58 100644 --- a/libgo/go/encoding/ascii85/ascii85.go +++ b/libgo/go/encoding/ascii85/ascii85.go @@ -296,5 +296,4 @@ func (d *decoder) Read(p []byte) (n int, err error) { nn, d.readErr = d.r.Read(d.buf[d.nbuf:]) d.nbuf += nn } - panic("unreachable") } diff --git a/libgo/go/encoding/asn1/marshal.go b/libgo/go/encoding/asn1/marshal.go index 0c216fd..adaf80d 100644 --- a/libgo/go/encoding/asn1/marshal.go +++ b/libgo/go/encoding/asn1/marshal.go @@ -460,7 +460,6 @@ func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameter default: return marshalUTF8String(out, v.String()) } - return } return StructuralError{"unknown Go type"} diff --git a/libgo/go/encoding/base32/base32.go b/libgo/go/encoding/base32/base32.go index dbefc48..fe17b73 100644 --- a/libgo/go/encoding/base32/base32.go +++ b/libgo/go/encoding/base32/base32.go @@ -6,8 +6,10 @@ package base32 import ( + "bytes" "io" "strconv" + "strings" ) /* @@ -48,6 +50,13 @@ var StdEncoding = NewEncoding(encodeStd) // It is typically used in DNS. var HexEncoding = NewEncoding(encodeHex) +var removeNewlinesMapper = func(r rune) rune { + if r == '\r' || r == '\n' { + return -1 + } + return r +} + /* * Encoder */ @@ -228,40 +237,47 @@ func (e CorruptInputError) Error() string { // decode is like Decode but returns an additional 'end' value, which // indicates if end-of-message padding was encountered and thus any -// additional data is an error. +// additional data is an error. This method assumes that src has been +// stripped of all supported whitespace ('\r' and '\n'). func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { - osrc := src + olen := len(src) for len(src) > 0 && !end { // Decode quantum using the base32 alphabet var dbuf [8]byte dlen := 8 - // do the top bytes contain any data? for j := 0; j < 8; { if len(src) == 0 { - return n, false, CorruptInputError(len(osrc) - len(src) - j) + return n, false, CorruptInputError(olen - len(src) - j) } in := src[0] src = src[1:] - if in == '\r' || in == '\n' { - // Ignore this character. - continue - } if in == '=' && j >= 2 && len(src) < 8 { - // We've reached the end and there's - // padding, the rest should be padded - for k := 0; k < 8-j-1; k++ { + // We've reached the end and there's padding + if len(src)+j < 8-1 { + // not enough padding + return n, false, CorruptInputError(olen) + } + for k := 0; k < 8-1-j; k++ { if len(src) > k && src[k] != '=' { - return n, false, CorruptInputError(len(osrc) - len(src) + k - 1) + // incorrect padding + return n, false, CorruptInputError(olen - len(src) + k - 1) } } - dlen = j - end = true + dlen, end = j, true + // 7, 5 and 2 are not valid padding lengths, and so 1, 3 and 6 are not + // valid dlen values. See RFC 4648 Section 6 "Base 32 Encoding" listing + // the five valid padding lengths, and Section 9 "Illustrations and + // Examples" for an illustration for how the the 1st, 3rd and 6th base32 + // src bytes do not yield enough information to decode a dst byte. + if dlen == 1 || dlen == 3 || dlen == 6 { + return n, false, CorruptInputError(olen - len(src) - 1) + } break } dbuf[j] = enc.decodeMap[in] if dbuf[j] == 0xFF { - return n, false, CorruptInputError(len(osrc) - len(src) - 1) + return n, false, CorruptInputError(olen - len(src) - 1) } j++ } @@ -269,16 +285,16 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { // Pack 8x 5-bit source blocks into 5 byte destination // quantum switch dlen { - case 7, 8: + case 8: dst[4] = dbuf[6]<<5 | dbuf[7] fallthrough - case 6, 5: + case 7: dst[3] = dbuf[4]<<7 | dbuf[5]<<2 | dbuf[6]>>3 fallthrough - case 4: + case 5: dst[2] = dbuf[3]<<4 | dbuf[4]>>1 fallthrough - case 3: + case 4: dst[1] = dbuf[1]<<6 | dbuf[2]<<1 | dbuf[3]>>4 fallthrough case 2: @@ -288,11 +304,11 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { switch dlen { case 2: n += 1 - case 3, 4: + case 4: n += 2 case 5: n += 3 - case 6, 7: + case 7: n += 4 case 8: n += 5 @@ -307,12 +323,14 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { // number of bytes successfully written and CorruptInputError. // New line characters (\r and \n) are ignored. func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { + src = bytes.Map(removeNewlinesMapper, src) n, _, err = enc.decode(dst, src) return } // DecodeString returns the bytes represented by the base32 string s. func (enc *Encoding) DecodeString(s string) ([]byte, error) { + s = strings.Map(removeNewlinesMapper, s) dbuf := make([]byte, enc.DecodedLen(len(s))) n, err := enc.Decode(dbuf, []byte(s)) return dbuf[:n], err @@ -377,9 +395,34 @@ func (d *decoder) Read(p []byte) (n int, err error) { return n, d.err } +type newlineFilteringReader struct { + wrapped io.Reader +} + +func (r *newlineFilteringReader) Read(p []byte) (int, error) { + n, err := r.wrapped.Read(p) + for n > 0 { + offset := 0 + for i, b := range p[0:n] { + if b != '\r' && b != '\n' { + if i != offset { + p[offset] = b + } + offset++ + } + } + if offset > 0 { + return offset, err + } + // Previous buffer entirely whitespace, read again + n, err = r.wrapped.Read(p) + } + return n, err +} + // NewDecoder constructs a new base32 stream decoder. func NewDecoder(enc *Encoding, r io.Reader) io.Reader { - return &decoder{enc: enc, r: r} + return &decoder{enc: enc, r: &newlineFilteringReader{r}} } // DecodedLen returns the maximum length in bytes of the decoded data diff --git a/libgo/go/encoding/base32/base32_test.go b/libgo/go/encoding/base32/base32_test.go index 98365e1..63298d1 100644 --- a/libgo/go/encoding/base32/base32_test.go +++ b/libgo/go/encoding/base32/base32_test.go @@ -8,6 +8,7 @@ import ( "bytes" "io" "io/ioutil" + "strings" "testing" ) @@ -137,27 +138,48 @@ func TestDecoderBuffering(t *testing.T) { } func TestDecodeCorrupt(t *testing.T) { - type corrupt struct { - e string - p int - } - examples := []corrupt{ + testCases := []struct { + input string + offset int // -1 means no corruption. + }{ + {"", -1}, {"!!!!", 0}, {"x===", 0}, {"AA=A====", 2}, {"AAA=AAAA", 3}, {"MMMMMMMMM", 8}, {"MMMMMM", 0}, + {"A=", 1}, + {"AA=", 3}, + {"AA==", 4}, + {"AA===", 5}, + {"AAAA=", 5}, + {"AAAA==", 6}, + {"AAAAA=", 6}, + {"AAAAA==", 7}, + {"A=======", 1}, + {"AA======", -1}, + {"AAA=====", 3}, + {"AAAA====", -1}, + {"AAAAA===", -1}, + {"AAAAAA==", 6}, + {"AAAAAAA=", -1}, + {"AAAAAAAA", -1}, } - - for _, e := range examples { - dbuf := make([]byte, StdEncoding.DecodedLen(len(e.e))) - _, err := StdEncoding.Decode(dbuf, []byte(e.e)) + for _, tc := range testCases { + dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input))) + _, err := StdEncoding.Decode(dbuf, []byte(tc.input)) + if tc.offset == -1 { + if err != nil { + t.Error("Decoder wrongly detected coruption in", tc.input) + } + continue + } switch err := err.(type) { case CorruptInputError: - testEqual(t, "Corruption in %q at offset %v, want %v", e.e, int(err), e.p) + testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset) default: - t.Error("Decoder failed to detect corruption in", e) + t.Error("Decoder failed to detect corruption in", tc) } } } @@ -195,9 +217,21 @@ func TestBig(t *testing.T) { } } +func testStringEncoding(t *testing.T, expected string, examples []string) { + for _, e := range examples { + buf, err := StdEncoding.DecodeString(e) + if err != nil { + t.Errorf("Decode(%q) failed: %v", e, err) + continue + } + if s := string(buf); s != expected { + t.Errorf("Decode(%q) = %q, want %q", e, s, expected) + } + } +} + func TestNewLineCharacters(t *testing.T) { // Each of these should decode to the string "sure", without errors. - const expected = "sure" examples := []string{ "ON2XEZI=", "ON2XEZI=\r", @@ -209,14 +243,44 @@ func TestNewLineCharacters(t *testing.T) { "ON2XEZ\nI=", "ON2XEZI\n=", } - for _, e := range examples { - buf, err := StdEncoding.DecodeString(e) - if err != nil { - t.Errorf("Decode(%q) failed: %v", e, err) - continue - } - if s := string(buf); s != expected { - t.Errorf("Decode(%q) = %q, want %q", e, s, expected) - } + testStringEncoding(t, "sure", examples) + + // Each of these should decode to the string "foobar", without errors. + examples = []string{ + "MZXW6YTBOI======", + "MZXW6YTBOI=\r\n=====", + } + testStringEncoding(t, "foobar", examples) +} + +func TestDecoderIssue4779(t *testing.T) { + encoded := `JRXXEZLNEBUXA43VNUQGI33MN5ZCA43JOQQGC3LFOQWCAY3PNZZWKY3UMV2HK4 +RAMFSGS4DJONUWG2LOM4QGK3DJOQWCA43FMQQGI3YKMVUXK43NN5SCA5DFNVYG64RANFXGG2LENFSH +K3TUEB2XIIDMMFRG64TFEBSXIIDEN5WG64TFEBWWCZ3OMEQGC3DJOF2WCLRAKV2CAZLONFWQUYLEEB +WWS3TJNUQHMZLONFQW2LBAOF2WS4ZANZXXG5DSOVSCAZLYMVZGG2LUMF2GS33OEB2WY3DBNVRW6IDM +MFRG64TJOMQG42LTNEQHK5AKMFWGS4LVNFYCAZLYEBSWCIDDN5WW233EN4QGG33OONSXC5LBOQXCAR +DVNFZSAYLVORSSA2LSOVZGKIDEN5WG64RANFXAU4TFOBZGK2DFNZSGK4TJOQQGS3RAOZXWY5LQORQX +IZJAOZSWY2LUEBSXG43FEBRWS3DMOVWSAZDPNRXXEZJAMV2SAZTVM5UWC5BANZ2WY3DBBJYGC4TJMF +2HK4ROEBCXQY3FOB2GK5LSEBZWS3TUEBXWGY3BMVRWC5BAMN2XA2LEMF2GC5BANZXW4IDQOJXWSZDF +NZ2CYIDTOVXHIIDJNYFGG5LMOBQSA4LVNEQG6ZTGNFRWSYJAMRSXGZLSOVXHIIDNN5WGY2LUEBQW42 +LNEBUWIIDFON2CA3DBMJXXE5LNFY== +====` + encodedShort := strings.Replace(encoded, "\n", "", -1) + + dec := NewDecoder(StdEncoding, bytes.NewBufferString(encoded)) + res1, err := ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + dec = NewDecoder(StdEncoding, bytes.NewBufferString(encodedShort)) + var res2 []byte + res2, err = ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + if !bytes.Equal(res1, res2) { + t.Error("Decoded results not equal") } } diff --git a/libgo/go/encoding/base64/base64.go b/libgo/go/encoding/base64/base64.go index e66672a..85e398f 100644 --- a/libgo/go/encoding/base64/base64.go +++ b/libgo/go/encoding/base64/base64.go @@ -6,8 +6,10 @@ package base64 import ( + "bytes" "io" "strconv" + "strings" ) /* @@ -49,6 +51,13 @@ var StdEncoding = NewEncoding(encodeStd) // It is typically used in URLs and file names. var URLEncoding = NewEncoding(encodeURL) +var removeNewlinesMapper = func(r rune) rune { + if r == '\r' || r == '\n' { + return -1 + } + return r +} + /* * Encoder */ @@ -208,9 +217,10 @@ func (e CorruptInputError) Error() string { // decode is like Decode but returns an additional 'end' value, which // indicates if end-of-message padding was encountered and thus any -// additional data is an error. +// additional data is an error. This method assumes that src has been +// stripped of all supported whitespace ('\r' and '\n'). func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { - osrc := src + olen := len(src) for len(src) > 0 && !end { // Decode quantum using the base64 alphabet var dbuf [4]byte @@ -218,32 +228,26 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { for j := 0; j < 4; { if len(src) == 0 { - return n, false, CorruptInputError(len(osrc) - len(src) - j) + return n, false, CorruptInputError(olen - len(src) - j) } in := src[0] src = src[1:] - if in == '\r' || in == '\n' { - // Ignore this character. - continue - } if in == '=' && j >= 2 && len(src) < 4 { - // We've reached the end and there's - // padding - if len(src) == 0 && j == 2 { + // We've reached the end and there's padding + if len(src)+j < 4-1 { // not enough padding - return n, false, CorruptInputError(len(osrc)) + return n, false, CorruptInputError(olen) } if len(src) > 0 && src[0] != '=' { // incorrect padding - return n, false, CorruptInputError(len(osrc) - len(src) - 1) + return n, false, CorruptInputError(olen - len(src) - 1) } - dlen = j - end = true + dlen, end = j, true break } dbuf[j] = enc.decodeMap[in] if dbuf[j] == 0xFF { - return n, false, CorruptInputError(len(osrc) - len(src) - 1) + return n, false, CorruptInputError(olen - len(src) - 1) } j++ } @@ -273,12 +277,14 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { // number of bytes successfully written and CorruptInputError. // New line characters (\r and \n) are ignored. func (enc *Encoding) Decode(dst, src []byte) (n int, err error) { + src = bytes.Map(removeNewlinesMapper, src) n, _, err = enc.decode(dst, src) return } // DecodeString returns the bytes represented by the base64 string s. func (enc *Encoding) DecodeString(s string) ([]byte, error) { + s = strings.Map(removeNewlinesMapper, s) dbuf := make([]byte, enc.DecodedLen(len(s))) n, err := enc.Decode(dbuf, []byte(s)) return dbuf[:n], err @@ -343,9 +349,34 @@ func (d *decoder) Read(p []byte) (n int, err error) { return n, d.err } +type newlineFilteringReader struct { + wrapped io.Reader +} + +func (r *newlineFilteringReader) Read(p []byte) (int, error) { + n, err := r.wrapped.Read(p) + for n > 0 { + offset := 0 + for i, b := range p[0:n] { + if b != '\r' && b != '\n' { + if i != offset { + p[offset] = b + } + offset++ + } + } + if offset > 0 { + return offset, err + } + // Previous buffer entirely whitespace, read again + n, err = r.wrapped.Read(p) + } + return n, err +} + // NewDecoder constructs a new base64 stream decoder. func NewDecoder(enc *Encoding, r io.Reader) io.Reader { - return &decoder{enc: enc, r: r} + return &decoder{enc: enc, r: &newlineFilteringReader{r}} } // DecodedLen returns the maximum length in bytes of the decoded data diff --git a/libgo/go/encoding/base64/base64_test.go b/libgo/go/encoding/base64/base64_test.go index f9b863c..579591a 100644 --- a/libgo/go/encoding/base64/base64_test.go +++ b/libgo/go/encoding/base64/base64_test.go @@ -9,6 +9,7 @@ import ( "errors" "io" "io/ioutil" + "strings" "testing" "time" ) @@ -142,11 +143,11 @@ func TestDecoderBuffering(t *testing.T) { } func TestDecodeCorrupt(t *testing.T) { - type corrupt struct { - e string - p int - } - examples := []corrupt{ + testCases := []struct { + input string + offset int // -1 means no corruption. + }{ + {"", -1}, {"!!!!", 0}, {"x===", 1}, {"AA=A", 2}, @@ -154,18 +155,27 @@ func TestDecodeCorrupt(t *testing.T) { {"AAAAA", 4}, {"AAAAAA", 4}, {"A=", 1}, + {"A==", 1}, {"AA=", 3}, + {"AA==", -1}, + {"AAA=", -1}, + {"AAAA", -1}, {"AAAAAA=", 7}, } - - for _, e := range examples { - dbuf := make([]byte, StdEncoding.DecodedLen(len(e.e))) - _, err := StdEncoding.Decode(dbuf, []byte(e.e)) + for _, tc := range testCases { + dbuf := make([]byte, StdEncoding.DecodedLen(len(tc.input))) + _, err := StdEncoding.Decode(dbuf, []byte(tc.input)) + if tc.offset == -1 { + if err != nil { + t.Error("Decoder wrongly detected coruption in", tc.input) + } + continue + } switch err := err.(type) { case CorruptInputError: - testEqual(t, "Corruption in %q at offset %v, want %v", e.e, int(err), e.p) + testEqual(t, "Corruption in %q at offset %v, want %v", tc.input, int(err), tc.offset) default: - t.Error("Decoder failed to detect corruption in", e) + t.Error("Decoder failed to detect corruption in", tc) } } } @@ -216,6 +226,8 @@ func TestNewLineCharacters(t *testing.T) { "c3V\nyZ\rQ==", "c3VyZ\nQ==", "c3VyZQ\n==", + "c3VyZQ=\n=", + "c3VyZQ=\r\n\r\n=", } for _, e := range examples { buf, err := StdEncoding.DecodeString(e) @@ -257,6 +269,7 @@ func TestDecoderIssue3577(t *testing.T) { wantErr := errors.New("my error") next <- nextRead{5, nil} next <- nextRead{10, wantErr} + next <- nextRead{0, wantErr} d := NewDecoder(StdEncoding, &faultInjectReader{ source: "VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==", // twas brillig... nextc: next, @@ -275,3 +288,40 @@ func TestDecoderIssue3577(t *testing.T) { t.Errorf("timeout; Decoder blocked without returning an error") } } + +func TestDecoderIssue4779(t *testing.T) { + encoded := `CP/EAT8AAAEF +AQEBAQEBAAAAAAAAAAMAAQIEBQYHCAkKCwEAAQUBAQEBAQEAAAAAAAAAAQACAwQFBgcICQoLEAAB +BAEDAgQCBQcGCAUDDDMBAAIRAwQhEjEFQVFhEyJxgTIGFJGhsUIjJBVSwWIzNHKC0UMHJZJT8OHx +Y3M1FqKygyZEk1RkRcKjdDYX0lXiZfKzhMPTdePzRieUpIW0lcTU5PSltcXV5fVWZnaGlqa2xtbm +9jdHV2d3h5ent8fX5/cRAAICAQIEBAMEBQYHBwYFNQEAAhEDITESBEFRYXEiEwUygZEUobFCI8FS +0fAzJGLhcoKSQ1MVY3M08SUGFqKygwcmNcLSRJNUoxdkRVU2dGXi8rOEw9N14/NGlKSFtJXE1OT0 +pbXF1eX1VmZ2hpamtsbW5vYnN0dXZ3eHl6e3x//aAAwDAQACEQMRAD8A9VSSSSUpJJJJSkkkJ+Tj +1kiy1jCJJDnAcCTykpKkuQ6p/jN6FgmxlNduXawwAzaGH+V6jn/R/wCt71zdn+N/qL3kVYFNYB4N +ji6PDVjWpKp9TSXnvTf8bFNjg3qOEa2n6VlLpj/rT/pf567DpX1i6L1hs9Py67X8mqdtg/rUWbbf ++gkp0kkkklKSSSSUpJJJJT//0PVUkkklKVLq3WMDpGI7KzrNjADtYNXvI/Mqr/Pd/q9W3vaxjnvM +NaCXE9gNSvGPrf8AWS3qmba5jjsJhoB0DAf0NDf6sevf+/lf8Hj0JJATfWT6/dV6oXU1uOLQeKKn +EQP+Hubtfe/+R7Mf/g7f5xcocp++Z11JMCJPgFBxOg7/AOuqDx8I/ikpkXkmSdU8mJIJA/O8EMAy +j+mSARB/17pKVXYWHXjsj7yIex0PadzXMO1zT5KHoNA3HT8ietoGhgjsfA+CSnvvqh/jJtqsrwOv +2b6NGNzXfTYexzJ+nU7/ALkf4P8Awv6P9KvTQQ4AgyDqCF85Pho3CTB7eHwXoH+LT65uZbX9X+o2 +bqbPb06551Y4 +` + encodedShort := strings.Replace(encoded, "\n", "", -1) + + dec := NewDecoder(StdEncoding, bytes.NewBufferString(encoded)) + res1, err := ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + dec = NewDecoder(StdEncoding, bytes.NewBufferString(encodedShort)) + var res2 []byte + res2, err = ioutil.ReadAll(dec) + if err != nil { + t.Errorf("ReadAll failed: %v", err) + } + + if !bytes.Equal(res1, res2) { + t.Error("Decoded results not equal") + } +} diff --git a/libgo/go/encoding/binary/binary.go b/libgo/go/encoding/binary/binary.go index 04d5723..edbac19 100644 --- a/libgo/go/encoding/binary/binary.go +++ b/libgo/go/encoding/binary/binary.go @@ -167,9 +167,9 @@ func Read(r io.Reader, order ByteOrder, data interface{}) error { default: return errors.New("binary.Read: invalid type " + d.Type().String()) } - size := dataSize(v) - if size < 0 { - return errors.New("binary.Read: invalid type " + v.Type().String()) + size, err := dataSize(v) + if err != nil { + return errors.New("binary.Read: " + err.Error()) } d := &decoder{order: order, buf: make([]byte, size)} if _, err := io.ReadFull(r, d.buf); err != nil { @@ -247,64 +247,68 @@ func Write(w io.Writer, order ByteOrder, data interface{}) error { // Fallback to reflect-based encoding. v := reflect.Indirect(reflect.ValueOf(data)) - size := dataSize(v) - if size < 0 { - return errors.New("binary.Write: invalid type " + v.Type().String()) + size, err := dataSize(v) + if err != nil { + return errors.New("binary.Write: " + err.Error()) } buf := make([]byte, size) e := &encoder{order: order, buf: buf} e.value(v) - _, err := w.Write(buf) + _, err = w.Write(buf) return err } // Size returns how many bytes Write would generate to encode the value v, which // must be a fixed-size value or a slice of fixed-size values, or a pointer to such data. func Size(v interface{}) int { - return dataSize(reflect.Indirect(reflect.ValueOf(v))) + n, err := dataSize(reflect.Indirect(reflect.ValueOf(v))) + if err != nil { + return -1 + } + return n } // dataSize returns the number of bytes the actual data represented by v occupies in memory. // For compound structures, it sums the sizes of the elements. Thus, for instance, for a slice // it returns the length of the slice times the element size and does not count the memory // occupied by the header. -func dataSize(v reflect.Value) int { +func dataSize(v reflect.Value) (int, error) { if v.Kind() == reflect.Slice { - elem := sizeof(v.Type().Elem()) - if elem < 0 { - return -1 + elem, err := sizeof(v.Type().Elem()) + if err != nil { + return 0, err } - return v.Len() * elem + return v.Len() * elem, nil } return sizeof(v.Type()) } -func sizeof(t reflect.Type) int { +func sizeof(t reflect.Type) (int, error) { switch t.Kind() { case reflect.Array: - n := sizeof(t.Elem()) - if n < 0 { - return -1 + n, err := sizeof(t.Elem()) + if err != nil { + return 0, err } - return t.Len() * n + return t.Len() * n, nil case reflect.Struct: sum := 0 for i, n := 0, t.NumField(); i < n; i++ { - s := sizeof(t.Field(i).Type) - if s < 0 { - return -1 + s, err := sizeof(t.Field(i).Type) + if err != nil { + return 0, err } sum += s } - return sum + return sum, nil case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: - return int(t.Size()) + return int(t.Size()), nil } - return -1 + return 0, errors.New("invalid type " + t.String()) } type coder struct { @@ -514,11 +518,12 @@ func (e *encoder) value(v reflect.Value) { } func (d *decoder) skip(v reflect.Value) { - d.buf = d.buf[dataSize(v):] + n, _ := dataSize(v) + d.buf = d.buf[n:] } func (e *encoder) skip(v reflect.Value) { - n := dataSize(v) + n, _ := dataSize(v) for i := range e.buf[0:n] { e.buf[i] = 0 } diff --git a/libgo/go/encoding/binary/binary_test.go b/libgo/go/encoding/binary/binary_test.go index cfad8d3..056f099 100644 --- a/libgo/go/encoding/binary/binary_test.go +++ b/libgo/go/encoding/binary/binary_test.go @@ -9,6 +9,7 @@ import ( "io" "math" "reflect" + "strings" "testing" ) @@ -149,8 +150,14 @@ func TestWriteT(t *testing.T) { tv := reflect.Indirect(reflect.ValueOf(ts)) for i, n := 0, tv.NumField(); i < n; i++ { + typ := tv.Field(i).Type().String() + if typ == "[4]int" { + typ = "int" // the problem is int, not the [4] + } if err := Write(buf, BigEndian, tv.Field(i).Interface()); err == nil { t.Errorf("WriteT.%v: have err == nil, want non-nil", tv.Field(i).Type()) + } else if !strings.Contains(err.Error(), typ) { + t.Errorf("WriteT: have err == %q, want it to mention %s", err, typ) } } } @@ -238,7 +245,7 @@ func BenchmarkReadStruct(b *testing.B) { bsr := &byteSliceReader{} var buf bytes.Buffer Write(&buf, BigEndian, &s) - n := dataSize(reflect.ValueOf(s)) + n, _ := dataSize(reflect.ValueOf(s)) b.SetBytes(int64(n)) t := s b.ResetTimer() diff --git a/libgo/go/encoding/binary/varint.go b/libgo/go/encoding/binary/varint.go index 7035529..3a2dfa3c 100644 --- a/libgo/go/encoding/binary/varint.go +++ b/libgo/go/encoding/binary/varint.go @@ -120,7 +120,6 @@ func ReadUvarint(r io.ByteReader) (uint64, error) { x |= uint64(b&0x7f) << s s += 7 } - panic("unreachable") } // ReadVarint reads an encoded signed integer from r and returns it as an int64. diff --git a/libgo/go/encoding/csv/reader.go b/libgo/go/encoding/csv/reader.go index db4d988..b099caf 100644 --- a/libgo/go/encoding/csv/reader.go +++ b/libgo/go/encoding/csv/reader.go @@ -171,7 +171,6 @@ func (r *Reader) ReadAll() (records [][]string, err error) { } records = append(records, record) } - panic("unreachable") } // readRune reads one rune from r, folding \r\n to \n and keeping track @@ -213,7 +212,6 @@ func (r *Reader) skip(delim rune) error { return nil } } - panic("unreachable") } // parseRecord reads and parses a single csv record from r. @@ -250,7 +248,6 @@ func (r *Reader) parseRecord() (fields []string, err error) { return nil, err } } - panic("unreachable") } // parseField parses the next field in the record. The read field is diff --git a/libgo/go/encoding/gob/codec_test.go b/libgo/go/encoding/gob/codec_test.go index 482212b..9e38e31 100644 --- a/libgo/go/encoding/gob/codec_test.go +++ b/libgo/go/encoding/gob/codec_test.go @@ -1191,10 +1191,8 @@ func TestInterface(t *testing.T) { if v1 != nil || v2 != nil { t.Errorf("item %d inconsistent nils", i) } - continue - if v1.Square() != v2.Square() { - t.Errorf("item %d inconsistent values: %v %v", i, v1, v2) - } + } else if v1.Square() != v2.Square() { + t.Errorf("item %d inconsistent values: %v %v", i, v1, v2) } } } diff --git a/libgo/go/encoding/gob/decode.go b/libgo/go/encoding/gob/decode.go index a80d9f9..7cc7565 100644 --- a/libgo/go/encoding/gob/decode.go +++ b/libgo/go/encoding/gob/decode.go @@ -1066,7 +1066,6 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re case reflect.Struct: return true } - return true } // typeString returns a human-readable description of the type identified by remoteId. diff --git a/libgo/go/encoding/gob/gobencdec_test.go b/libgo/go/encoding/gob/gobencdec_test.go index 18f4450..ddcd80b 100644 --- a/libgo/go/encoding/gob/gobencdec_test.go +++ b/libgo/go/encoding/gob/gobencdec_test.go @@ -1,4 +1,4 @@ -// Copyright 20011 The Go Authors. All rights reserved. +// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -348,7 +348,7 @@ func TestGobEncoderFieldsOfDifferentType(t *testing.T) { t.Fatal("decode error:", err) } if y.G.s != "XYZ" { - t.Fatalf("expected `XYZ` got %c", y.G.s) + t.Fatalf("expected `XYZ` got %q", y.G.s) } } diff --git a/libgo/go/encoding/gob/timing_test.go b/libgo/go/encoding/gob/timing_test.go index 9a0e51d..f589675 100644 --- a/libgo/go/encoding/gob/timing_test.go +++ b/libgo/go/encoding/gob/timing_test.go @@ -50,49 +50,51 @@ func BenchmarkEndToEndByteBuffer(b *testing.B) { } func TestCountEncodeMallocs(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + + const N = 1000 + var buf bytes.Buffer enc := NewEncoder(&buf) bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - const count = 1000 - for i := 0; i < count; i++ { + + allocs := testing.AllocsPerRun(N, func() { err := enc.Encode(bench) if err != nil { t.Fatal("encode:", err) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - fmt.Printf("mallocs per encode of type Bench: %d\n", mallocs/count) + }) + fmt.Printf("mallocs per encode of type Bench: %v\n", allocs) } func TestCountDecodeMallocs(t *testing.T) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + + const N = 1000 + var buf bytes.Buffer enc := NewEncoder(&buf) bench := &Bench{7, 3.2, "now is the time", []byte("for all good men")} - const count = 1000 - for i := 0; i < count; i++ { + + // Fill the buffer with enough to decode + testing.AllocsPerRun(N, func() { err := enc.Encode(bench) if err != nil { t.Fatal("encode:", err) } - } + }) + dec := NewDecoder(&buf) - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - for i := 0; i < count; i++ { + allocs := testing.AllocsPerRun(N, func() { *bench = Bench{} err := dec.Decode(&bench) if err != nil { t.Fatal("decode:", err) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - fmt.Printf("mallocs per decode of type Bench: %d\n", mallocs/count) + }) + fmt.Printf("mallocs per decode of type Bench: %v\n", allocs) } diff --git a/libgo/go/encoding/gob/type.go b/libgo/go/encoding/gob/type.go index ea0db4e..7fa0b499 100644 --- a/libgo/go/encoding/gob/type.go +++ b/libgo/go/encoding/gob/type.go @@ -526,7 +526,6 @@ func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, err default: return nil, errors.New("gob NewTypeObject can't handle type: " + rt.String()) } - return nil, nil } // isExported reports whether this is an exported - upper case - name. diff --git a/libgo/go/encoding/json/decode.go b/libgo/go/encoding/json/decode.go index 95e9120..62ac294 100644 --- a/libgo/go/encoding/json/decode.go +++ b/libgo/go/encoding/json/decode.go @@ -33,6 +33,10 @@ import ( // the value pointed at by the pointer. If the pointer is nil, Unmarshal // allocates a new value for it to point to. // +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by Marshal (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. +// // To unmarshal JSON into an interface value, Unmarshal unmarshals // the JSON into the concrete value contained in the interface value. // If the interface value is nil, that is, has no concrete value stored in it, @@ -51,17 +55,22 @@ import ( // If no more serious errors are encountered, Unmarshal returns // an UnmarshalTypeError describing the earliest such error. // +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +// func Unmarshal(data []byte, v interface{}) error { - d := new(decodeState).init(data) - - // Quick check for well-formedness. + // Check for well-formedness. // Avoids filling out half a data structure // before discovering a JSON syntax error. + var d decodeState err := checkValid(data, &d.scan) if err != nil { return err } + d.init(data) return d.unmarshal(v) } @@ -252,6 +261,16 @@ func (d *decodeState) value(v reflect.Value) { } d.scan.step(&d.scan, '"') d.scan.step(&d.scan, '"') + + n := len(d.scan.parseState) + if n > 0 && d.scan.parseState[n-1] == parseObjectKey { + // d.scan thinks we just read an object key; finish the object + d.scan.step(&d.scan, ':') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '"') + d.scan.step(&d.scan, '}') + } + return } @@ -730,6 +749,7 @@ func (d *decodeState) valueInterface() interface{} { switch d.scanWhile(scanSkipSpace) { default: d.error(errPhase) + panic("unreachable") case scanBeginArray: return d.arrayInterface() case scanBeginObject: @@ -737,12 +757,11 @@ func (d *decodeState) valueInterface() interface{} { case scanBeginLiteral: return d.literalInterface() } - panic("unreachable") } // arrayInterface is like array but returns []interface{}. func (d *decodeState) arrayInterface() []interface{} { - var v []interface{} + var v = make([]interface{}, 0) for { // Look ahead for ] - can only happen on first iteration. op := d.scanWhile(scanSkipSpace) @@ -849,7 +868,6 @@ func (d *decodeState) literalInterface() interface{} { } return n } - panic("unreachable") } // getu4 decodes \uXXXX from the beginning of s, returning the hex value, diff --git a/libgo/go/encoding/json/decode_test.go b/libgo/go/encoding/json/decode_test.go index a91c6da..f845f69 100644 --- a/libgo/go/encoding/json/decode_test.go +++ b/libgo/go/encoding/json/decode_test.go @@ -11,6 +11,7 @@ import ( "reflect" "strings" "testing" + "time" ) type T struct { @@ -29,7 +30,7 @@ type V struct { F3 Number } -// ifaceNumAsFloat64/ifaceNumAsNumber are used to test unmarshalling with and +// ifaceNumAsFloat64/ifaceNumAsNumber are used to test unmarshaling with and // without UseNumber var ifaceNumAsFloat64 = map[string]interface{}{ "k1": float64(1), @@ -239,6 +240,12 @@ var unmarshalTests = []unmarshalTest{ {in: `[1, 2, 3]`, ptr: new([1]int), out: [1]int{1}}, {in: `[1, 2, 3]`, ptr: new([5]int), out: [5]int{1, 2, 3, 0, 0}}, + // empty array to interface test + {in: `[]`, ptr: new([]interface{}), out: []interface{}{}}, + {in: `null`, ptr: new([]interface{}), out: []interface{}(nil)}, + {in: `{"T":[]}`, ptr: new(map[string]interface{}), out: map[string]interface{}{"T": []interface{}{}}}, + {in: `{"T":null}`, ptr: new(map[string]interface{}), out: map[string]interface{}{"T": interface{}(nil)}}, + // composite tests {in: allValueIndent, ptr: new(All), out: allValue}, {in: allValueCompact, ptr: new(All), out: allValue}, @@ -323,6 +330,43 @@ var unmarshalTests = []unmarshalTest{ ptr: new(S10), out: S10{S13: S13{S8: S8{S9: S9{Y: 2}}}}, }, + + // invalid UTF-8 is coerced to valid UTF-8. + { + in: "\"hello\xffworld\"", + ptr: new(string), + out: "hello\ufffdworld", + }, + { + in: "\"hello\xc2\xc2world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\xc2\xffworld\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\\ud800world\"", + ptr: new(string), + out: "hello\ufffdworld", + }, + { + in: "\"hello\\ud800\\ud800world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\\ud800\\ud800world\"", + ptr: new(string), + out: "hello\ufffd\ufffdworld", + }, + { + in: "\"hello\xed\xa0\x80\xed\xb0\x80world\"", + ptr: new(string), + out: "hello\ufffd\ufffd\ufffd\ufffd\ufffd\ufffdworld", + }, } func TestMarshal(t *testing.T) { @@ -1107,3 +1151,43 @@ func TestUnmarshalUnexported(t *testing.T) { t.Errorf("got %q, want %q", out, want) } } + +// Time3339 is a time.Time which encodes to and from JSON +// as an RFC 3339 time in UTC. +type Time3339 time.Time + +func (t *Time3339) UnmarshalJSON(b []byte) error { + if len(b) < 2 || b[0] != '"' || b[len(b)-1] != '"' { + return fmt.Errorf("types: failed to unmarshal non-string value %q as an RFC 3339 time", b) + } + tm, err := time.Parse(time.RFC3339, string(b[1:len(b)-1])) + if err != nil { + return err + } + *t = Time3339(tm) + return nil +} + +func TestUnmarshalJSONLiteralError(t *testing.T) { + var t3 Time3339 + err := Unmarshal([]byte(`"0000-00-00T00:00:00Z"`), &t3) + if err == nil { + t.Fatalf("expected error; got time %v", time.Time(t3)) + } + if !strings.Contains(err.Error(), "range") { + t.Errorf("got err = %v; want out of range error", err) + } +} + +// Test that extra object elements in an array do not result in a +// "data changing underfoot" error. +// Issue 3717 +func TestSkipArrayObjects(t *testing.T) { + json := `[{}]` + var dest [0]interface{} + + err := Unmarshal([]byte(json), &dest) + if err != nil { + t.Errorf("got error %q, want nil", err) + } +} diff --git a/libgo/go/encoding/json/encode.go b/libgo/go/encoding/json/encode.go index fb57f1d..85727ba 100644 --- a/libgo/go/encoding/json/encode.go +++ b/libgo/go/encoding/json/encode.go @@ -3,7 +3,8 @@ // license that can be found in the LICENSE file. // Package json implements encoding and decoding of JSON objects as defined in -// RFC 4627. +// RFC 4627. The mapping between JSON objects and Go values is described +// in the documentation for the Marshal and Unmarshal functions. // // See "JSON and Go" for an introduction to this package: // http://golang.org/doc/articles/json_and_go.html @@ -38,8 +39,8 @@ import ( // // Floating point, integer, and Number values encode as JSON numbers. // -// String values encode as JSON strings, with each invalid UTF-8 sequence -// replaced by the encoding of the Unicode replacement character U+FFFD. +// String values encode as JSON strings. InvalidUTF8Error will be returned +// if an invalid UTF-8 sequence is encountered. // The angle brackets "<" and ">" are escaped to "\u003c" and "\u003e" // to keep some browsers from misinterpreting JSON output as HTML. // @@ -86,9 +87,21 @@ import ( // underscores and slashes. // // Anonymous struct fields are usually marshaled as if their inner exported fields -// were fields in the outer struct, subject to the usual Go visibility rules. +// were fields in the outer struct, subject to the usual Go visibility rules amended +// as described in the next paragraph. // An anonymous struct field with a name given in its JSON tag is treated as -// having that name instead of as anonymous. +// having that name, rather than being anonymous. +// +// The Go visibility rules for struct fields are amended for JSON when +// deciding which field to marshal or unmarshal. If there are +// multiple fields at the same level, and that level is the least +// nested (and would therefore be the nesting level selected by the +// usual Go rules), the following extra rules apply: +// +// 1) Of those fields, if any are JSON-tagged, only tagged fields are considered, +// even if there are multiple untagged fields that would otherwise conflict. +// 2) If there is exactly one field (tagged or not according to the first rule), that is selected. +// 3) Otherwise there are multiple fields, and all are ignored; no error occurs. // // Handling of anonymous struct fields is new in Go 1.1. // Prior to Go 1.1, anonymous struct fields were ignored. To force ignoring of @@ -187,8 +200,10 @@ func (e *UnsupportedValueError) Error() string { return "json: unsupported value: " + e.Str } +// An InvalidUTF8Error is returned by Marshal when attempting +// to encode a string value with invalid UTF-8 sequences. type InvalidUTF8Error struct { - S string + S string // the whole string value that caused the error } func (e *InvalidUTF8Error) Error() string { @@ -654,27 +669,78 @@ func typeFields(t reflect.Type) []field { sort.Sort(byName(fields)) - // Remove fields with annihilating name collisions - // and also fields shadowed by fields with explicit JSON tags. - name := "" + // Delete all fields that are hidden by the Go rules for embedded fields, + // except that fields with JSON tags are promoted. + + // The fields are sorted in primary order of name, secondary order + // of field index length. Loop over names; for each name, delete + // hidden fields by choosing the one dominant field that survives. out := fields[:0] - for _, f := range fields { - if f.name != name { - name = f.name - out = append(out, f) + for advance, i := 0, 0; i < len(fields); i += advance { + // One iteration per name. + // Find the sequence of fields with the name of this first field. + fi := fields[i] + name := fi.name + for advance = 1; i+advance < len(fields); advance++ { + fj := fields[i+advance] + if fj.name != name { + break + } + } + if advance == 1 { // Only one field with this name + out = append(out, fi) continue } - if n := len(out); n > 0 && out[n-1].name == name && (!out[n-1].tag || f.tag) { - out = out[:n-1] + dominant, ok := dominantField(fields[i : i+advance]) + if ok { + out = append(out, dominant) } } - fields = out + fields = out sort.Sort(byIndex(fields)) return fields } +// dominantField looks through the fields, all of which are known to +// have the same name, to find the single field that dominates the +// others using Go's embedding rules, modified by the presence of +// JSON tags. If there are multiple top-level fields, the boolean +// will be false: This condition is an error in Go and we skip all +// the fields. +func dominantField(fields []field) (field, bool) { + // The fields are sorted in increasing index-length order. The winner + // must therefore be one with the shortest index length. Drop all + // longer entries, which is easy: just truncate the slice. + length := len(fields[0].index) + tagged := -1 // Index of first tagged field. + for i, f := range fields { + if len(f.index) > length { + fields = fields[:i] + break + } + if f.tag { + if tagged >= 0 { + // Multiple tagged fields at the same level: conflict. + // Return no field. + return field{}, false + } + tagged = i + } + } + if tagged >= 0 { + return fields[tagged], true + } + // All remaining fields have the same length. If there's more than one, + // we have a conflict (two fields named "X" at the same level) and we + // return no field. + if len(fields) > 1 { + return field{}, false + } + return fields[0], true +} + var fieldCache struct { sync.RWMutex m map[reflect.Type][]field diff --git a/libgo/go/encoding/json/encode_test.go b/libgo/go/encoding/json/encode_test.go index be74c99..5be0a99 100644 --- a/libgo/go/encoding/json/encode_test.go +++ b/libgo/go/encoding/json/encode_test.go @@ -206,3 +206,107 @@ func TestAnonymousNonstruct(t *testing.T) { t.Errorf("got %q, want %q", got, want) } } + +type BugA struct { + S string +} + +type BugB struct { + BugA + S string +} + +type BugC struct { + S string +} + +// Legal Go: We never use the repeated embedded field (S). +type BugX struct { + A int + BugA + BugB +} + +// Issue 5245. +func TestEmbeddedBug(t *testing.T) { + v := BugB{ + BugA{"A"}, + "B", + } + b, err := Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"B"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } + // Now check that the duplicate field, S, does not appear. + x := BugX{ + A: 23, + } + b, err = Marshal(x) + if err != nil { + t.Fatal("Marshal:", err) + } + want = `{"A":23}` + got = string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +type BugD struct { // Same as BugA after tagging. + XXX string `json:"S"` +} + +// BugD's tagged S field should dominate BugA's. +type BugY struct { + BugA + BugD +} + +// Test that a field with a tag dominates untagged fields. +func TestTaggedFieldDominates(t *testing.T) { + v := BugY{ + BugA{"BugA"}, + BugD{"BugD"}, + } + b, err := Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{"S":"BugD"}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} + +// There are no tags here, so S should not appear. +type BugZ struct { + BugA + BugC + BugY // Contains a tagged S field through BugD; should not dominate. +} + +func TestDuplicatedFieldDisappears(t *testing.T) { + v := BugZ{ + BugA{"BugA"}, + BugC{"BugC"}, + BugY{ + BugA{"nested BugA"}, + BugD{"nested BugD"}, + }, + } + b, err := Marshal(v) + if err != nil { + t.Fatal("Marshal:", err) + } + want := `{}` + got := string(b) + if got != want { + t.Fatalf("Marshal: got %s want %s", got, want) + } +} diff --git a/libgo/go/encoding/json/scanner_test.go b/libgo/go/encoding/json/scanner_test.go index adb3571..77d3455 100644 --- a/libgo/go/encoding/json/scanner_test.go +++ b/libgo/go/encoding/json/scanner_test.go @@ -277,9 +277,6 @@ func genArray(n int) []interface{} { if f > n { f = n } - if n > 0 && f == 0 { - f = 1 - } x := make([]interface{}, f) for i := range x { x[i] = genValue(((i+1)*n)/f - (i*n)/f) diff --git a/libgo/go/encoding/xml/marshal.go b/libgo/go/encoding/xml/marshal.go index aacb50c..47b0017 100644 --- a/libgo/go/encoding/xml/marshal.go +++ b/libgo/go/encoding/xml/marshal.go @@ -81,8 +81,7 @@ func Marshal(v interface{}) ([]byte, error) { func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) { var b bytes.Buffer enc := NewEncoder(&b) - enc.prefix = prefix - enc.indent = indent + enc.Indent(prefix, indent) if err := enc.Encode(v); err != nil { return nil, err } @@ -99,6 +98,14 @@ func NewEncoder(w io.Writer) *Encoder { return &Encoder{printer{Writer: bufio.NewWriter(w)}} } +// Indent sets the encoder to generate XML in which each element +// begins on a new indented line that starts with prefix and is followed by +// one or more copies of indent according to the nesting depth. +func (enc *Encoder) Indent(prefix, indent string) { + enc.prefix = prefix + enc.indent = indent +} + // Encode writes the XML encoding of v to the stream. // // See the documentation for Marshal for details about the conversion @@ -113,10 +120,76 @@ func (enc *Encoder) Encode(v interface{}) error { type printer struct { *bufio.Writer + seq int indent string prefix string depth int indentedIn bool + putNewline bool + attrNS map[string]string // map prefix -> name space + attrPrefix map[string]string // map name space -> prefix +} + +// createAttrPrefix finds the name space prefix attribute to use for the given name space, +// defining a new prefix if necessary. It returns the prefix and whether it is new. +func (p *printer) createAttrPrefix(url string) (prefix string, isNew bool) { + if prefix = p.attrPrefix[url]; prefix != "" { + return prefix, false + } + + // The "http://www.w3.org/XML/1998/namespace" name space is predefined as "xml" + // and must be referred to that way. + // (The "http://www.w3.org/2000/xmlns/" name space is also predefined as "xmlns", + // but users should not be trying to use that one directly - that's our job.) + if url == xmlURL { + return "xml", false + } + + // Need to define a new name space. + if p.attrPrefix == nil { + p.attrPrefix = make(map[string]string) + p.attrNS = make(map[string]string) + } + + // Pick a name. We try to use the final element of the path + // but fall back to _. + prefix = strings.TrimRight(url, "/") + if i := strings.LastIndex(prefix, "/"); i >= 0 { + prefix = prefix[i+1:] + } + if prefix == "" || !isName([]byte(prefix)) || strings.Contains(prefix, ":") { + prefix = "_" + } + if strings.HasPrefix(prefix, "xml") { + // xmlanything is reserved. + prefix = "_" + prefix + } + if p.attrNS[prefix] != "" { + // Name is taken. Find a better one. + for p.seq++; ; p.seq++ { + if id := prefix + "_" + strconv.Itoa(p.seq); p.attrNS[id] == "" { + prefix = id + break + } + } + } + + p.attrPrefix[url] = prefix + p.attrNS[prefix] = url + + p.WriteString(`xmlns:`) + p.WriteString(prefix) + p.WriteString(`="`) + EscapeText(p, []byte(url)) + p.WriteString(`" `) + + return prefix, true +} + +// deleteAttrPrefix removes an attribute name space prefix. +func (p *printer) deleteAttrPrefix(prefix string) { + delete(p.attrPrefix, p.attrNS[prefix]) + delete(p.attrNS, prefix) } // marshalValue writes one or more XML elements representing val. @@ -185,7 +258,9 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { if xmlns != "" { p.WriteString(` xmlns="`) // TODO: EscapeString, to avoid the allocation. - Escape(p, []byte(xmlns)) + if err := EscapeText(p, []byte(xmlns)); err != nil { + return err + } p.WriteByte('"') } @@ -200,6 +275,14 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error { continue } p.WriteByte(' ') + if finfo.xmlns != "" { + prefix, created := p.createAttrPrefix(finfo.xmlns) + if created { + defer p.deleteAttrPrefix(prefix) + } + p.WriteString(prefix) + p.WriteByte(':') + } p.WriteString(finfo.name) p.WriteString(`="`) if err := p.marshalSimple(fv.Type(), fv); err != nil { @@ -244,19 +327,22 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error { p.WriteString(strconv.FormatFloat(val.Float(), 'g', -1, val.Type().Bits())) case reflect.String: // TODO: Add EscapeString. - Escape(p, []byte(val.String())) + EscapeText(p, []byte(val.String())) case reflect.Bool: p.WriteString(strconv.FormatBool(val.Bool())) case reflect.Array: // will be [...]byte - bytes := make([]byte, val.Len()) - for i := range bytes { - bytes[i] = val.Index(i).Interface().(byte) + var bytes []byte + if val.CanAddr() { + bytes = val.Slice(0, val.Len()).Bytes() + } else { + bytes = make([]byte, val.Len()) + reflect.Copy(reflect.ValueOf(bytes), val) } - Escape(p, bytes) + EscapeText(p, bytes) case reflect.Slice: // will be []byte - Escape(p, val.Bytes()) + EscapeText(p, val.Bytes()) default: return &UnsupportedTypeError{typ} } @@ -273,7 +359,7 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { s := parentStack{printer: p} for i := range tinfo.fields { finfo := &tinfo.fields[i] - if finfo.flags&(fAttr) != 0 { + if finfo.flags&fAttr != 0 { continue } vf := finfo.value(val) @@ -290,10 +376,14 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error { case reflect.Bool: Escape(p, strconv.AppendBool(scratch[:0], vf.Bool())) case reflect.String: - Escape(p, []byte(vf.String())) + if err := EscapeText(p, []byte(vf.String())); err != nil { + return err + } case reflect.Slice: if elem, ok := vf.Interface().([]byte); ok { - Escape(p, elem) + if err := EscapeText(p, elem); err != nil { + return err + } } case reflect.Struct: if vf.Type() == timeType { @@ -387,7 +477,11 @@ func (p *printer) writeIndent(depthDelta int) { } p.indentedIn = false } - p.WriteByte('\n') + if p.putNewline { + p.WriteByte('\n') + } else { + p.putNewline = true + } if len(p.prefix) > 0 { p.WriteString(p.prefix) } diff --git a/libgo/go/encoding/xml/marshal_test.go b/libgo/go/encoding/xml/marshal_test.go index 67fcfd9..ca14a1e 100644 --- a/libgo/go/encoding/xml/marshal_test.go +++ b/libgo/go/encoding/xml/marshal_test.go @@ -7,6 +7,7 @@ package xml import ( "bytes" "errors" + "fmt" "io" "reflect" "strconv" @@ -265,6 +266,16 @@ type Plain struct { V interface{} } +type MyInt int + +type EmbedInt struct { + MyInt +} + +type Strings struct { + X []string `xml:"A>B,omitempty"` +} + // Unless explicitly stated as such (or *Plain), all of the // tests below are two-way tests. When introducing new tests, // please try to make them two-way as well to ensure that @@ -789,6 +800,17 @@ var marshalTests = []struct { }, UnmarshalOnly: true, }, + { + ExpectXML: `<EmbedInt><MyInt>42</MyInt></EmbedInt>`, + Value: &EmbedInt{ + MyInt: 42, + }, + }, + // Test omitempty with parent chain; see golang.org/issue/4168. + { + ExpectXML: `<Strings><A></A></Strings>`, + Value: &Strings{}, + }, } func TestMarshal(t *testing.T) { @@ -811,6 +833,10 @@ func TestMarshal(t *testing.T) { } } +type AttrParent struct { + X string `xml:"X>Y,attr"` +} + var marshalErrorTests = []struct { Value interface{} Err string @@ -838,12 +864,39 @@ var marshalErrorTests = []struct { Value: &Domain{Comment: []byte("f--bar")}, Err: `xml: comments must not contain "--"`, }, + // Reject parent chain with attr, never worked; see golang.org/issue/5033. + { + Value: &AttrParent{}, + Err: `xml: X>Y chain not valid with attr flag`, + }, +} + +var marshalIndentTests = []struct { + Value interface{} + Prefix string + Indent string + ExpectXML string +}{ + { + Value: &SecretAgent{ + Handle: "007", + Identity: "James Bond", + Obfuscate: "<redacted/>", + }, + Prefix: "", + Indent: "\t", + ExpectXML: fmt.Sprintf("<agent handle=\"007\">\n\t<Identity>James Bond</Identity><redacted/>\n</agent>"), + }, } func TestMarshalErrors(t *testing.T) { for idx, test := range marshalErrorTests { - _, err := Marshal(test.Value) - if err == nil || err.Error() != test.Err { + data, err := Marshal(test.Value) + if err == nil { + t.Errorf("#%d: marshal(%#v) = [success] %q, want error %v", idx, test.Value, data, test.Err) + continue + } + if err.Error() != test.Err { t.Errorf("#%d: marshal(%#v) = [error] %v, want %v", idx, test.Value, err, test.Err) } if test.Kind != reflect.Invalid { @@ -884,6 +937,19 @@ func TestUnmarshal(t *testing.T) { } } +func TestMarshalIndent(t *testing.T) { + for i, test := range marshalIndentTests { + data, err := MarshalIndent(test.Value, test.Prefix, test.Indent) + if err != nil { + t.Errorf("#%d: Error: %s", i, err) + continue + } + if got, want := string(data), test.ExpectXML; got != want { + t.Errorf("#%d: MarshalIndent:\nGot:%s\nWant:\n%s", i, got, want) + } + } +} + type limitedBytesWriter struct { w io.Writer remain int // until writes fail @@ -933,6 +999,16 @@ func TestMarshalWriteErrors(t *testing.T) { } } +func TestMarshalWriteIOErrors(t *testing.T) { + enc := NewEncoder(errWriter{}) + + expectErr := "unwritable" + err := enc.Encode(&Passenger{}) + if err == nil || err.Error() != expectErr { + t.Errorf("EscapeTest = [error] %v, want %v", err, expectErr) + } +} + func BenchmarkMarshal(b *testing.B) { for i := 0; i < b.N; i++ { Marshal(atomValue) diff --git a/libgo/go/encoding/xml/read.go b/libgo/go/encoding/xml/read.go index 344ab51..a7a2a96 100644 --- a/libgo/go/encoding/xml/read.go +++ b/libgo/go/encoding/xml/read.go @@ -263,7 +263,7 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { strv := finfo.value(sv) // Look for attribute. for _, a := range start.Attr { - if a.Name.Local == finfo.name { + if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) { copyValue(strv, []byte(a.Value)) break } @@ -441,7 +441,7 @@ func (p *Decoder) unmarshalPath(tinfo *typeInfo, sv reflect.Value, parents []str Loop: for i := range tinfo.fields { finfo := &tinfo.fields[i] - if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) { + if finfo.flags&fElement == 0 || len(finfo.parents) < len(parents) || finfo.xmlns != "" && finfo.xmlns != start.Name.Space { continue } for j := range parents { @@ -493,7 +493,6 @@ Loop: return true, nil } } - panic("unreachable") } // Skip reads tokens until it has consumed the end element @@ -517,5 +516,4 @@ func (d *Decoder) Skip() error { return nil } } - panic("unreachable") } diff --git a/libgo/go/encoding/xml/read_test.go b/libgo/go/encoding/xml/read_test.go index b45e2f0..7d28c5d 100644 --- a/libgo/go/encoding/xml/read_test.go +++ b/libgo/go/encoding/xml/read_test.go @@ -6,6 +6,7 @@ package xml import ( "reflect" + "strings" "testing" "time" ) @@ -399,3 +400,224 @@ func TestUnmarshalAttr(t *testing.T) { t.Fatalf("Unmarshal with %s failed:\nhave %#v,\n want %#v", x, p3.Int, 1) } } + +type Tables struct { + HTable string `xml:"http://www.w3.org/TR/html4/ table"` + FTable string `xml:"http://www.w3schools.com/furniture table"` +} + +var tables = []struct { + xml string + tab Tables + ns string +}{ + { + xml: `<Tables>` + + `<table xmlns="http://www.w3.org/TR/html4/">hello</table>` + + `<table xmlns="http://www.w3schools.com/furniture">world</table>` + + `</Tables>`, + tab: Tables{"hello", "world"}, + }, + { + xml: `<Tables>` + + `<table xmlns="http://www.w3schools.com/furniture">world</table>` + + `<table xmlns="http://www.w3.org/TR/html4/">hello</table>` + + `</Tables>`, + tab: Tables{"hello", "world"}, + }, + { + xml: `<Tables xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/">` + + `<f:table>world</f:table>` + + `<h:table>hello</h:table>` + + `</Tables>`, + tab: Tables{"hello", "world"}, + }, + { + xml: `<Tables>` + + `<table>bogus</table>` + + `</Tables>`, + tab: Tables{}, + }, + { + xml: `<Tables>` + + `<table>only</table>` + + `</Tables>`, + tab: Tables{HTable: "only"}, + ns: "http://www.w3.org/TR/html4/", + }, + { + xml: `<Tables>` + + `<table>only</table>` + + `</Tables>`, + tab: Tables{FTable: "only"}, + ns: "http://www.w3schools.com/furniture", + }, + { + xml: `<Tables>` + + `<table>only</table>` + + `</Tables>`, + tab: Tables{}, + ns: "something else entirely", + }, +} + +func TestUnmarshalNS(t *testing.T) { + for i, tt := range tables { + var dst Tables + var err error + if tt.ns != "" { + d := NewDecoder(strings.NewReader(tt.xml)) + d.DefaultSpace = tt.ns + err = d.Decode(&dst) + } else { + err = Unmarshal([]byte(tt.xml), &dst) + } + if err != nil { + t.Errorf("#%d: Unmarshal: %v", i, err) + continue + } + want := tt.tab + if dst != want { + t.Errorf("#%d: dst=%+v, want %+v", i, dst, want) + } + } +} + +func TestMarshalNS(t *testing.T) { + dst := Tables{"hello", "world"} + data, err := Marshal(&dst) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `<Tables><table xmlns="http://www.w3.org/TR/html4/">hello</table><table xmlns="http://www.w3schools.com/furniture">world</table></Tables>` + str := string(data) + if str != want { + t.Errorf("have: %q\nwant: %q\n", str, want) + } +} + +type TableAttrs struct { + TAttr TAttr +} + +type TAttr struct { + HTable string `xml:"http://www.w3.org/TR/html4/ table,attr"` + FTable string `xml:"http://www.w3schools.com/furniture table,attr"` + Lang string `xml:"http://www.w3.org/XML/1998/namespace lang,attr,omitempty"` + Other1 string `xml:"http://golang.org/xml/ other,attr,omitempty"` + Other2 string `xml:"http://golang.org/xmlfoo/ other,attr,omitempty"` + Other3 string `xml:"http://golang.org/json/ other,attr,omitempty"` + Other4 string `xml:"http://golang.org/2/json/ other,attr,omitempty"` +} + +var tableAttrs = []struct { + xml string + tab TableAttrs + ns string +}{ + { + xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` + + `h:table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}}, + }, + { + xml: `<TableAttrs><TAttr xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/" ` + + `h:table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}}, + }, + { + xml: `<TableAttrs><TAttr ` + + `h:table="hello" f:table="world" xmlns:f="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: "world"}}, + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns="http://www.w3schools.com/furniture" xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` + + `h:table="hello" table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: ""}}, + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture"><TAttr xmlns="http://www.w3.org/TR/html4/" ` + + `table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "", FTable: "world"}}, + }, + { + xml: `<TableAttrs><TAttr ` + + `table="bogus" ` + + `/></TableAttrs>`, + tab: TableAttrs{}, + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns:h="http://www.w3.org/TR/html4/"><TAttr ` + + `h:table="hello" table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "hello", FTable: ""}}, + ns: "http://www.w3schools.com/furniture", + }, + { + // Default space does not apply to attribute names. + xml: `<TableAttrs xmlns:f="http://www.w3schools.com/furniture"><TAttr ` + + `table="hello" f:table="world" ` + + `/></TableAttrs>`, + tab: TableAttrs{TAttr{HTable: "", FTable: "world"}}, + ns: "http://www.w3.org/TR/html4/", + }, + { + xml: `<TableAttrs><TAttr ` + + `table="bogus" ` + + `/></TableAttrs>`, + tab: TableAttrs{}, + ns: "something else entirely", + }, +} + +func TestUnmarshalNSAttr(t *testing.T) { + for i, tt := range tableAttrs { + var dst TableAttrs + var err error + if tt.ns != "" { + d := NewDecoder(strings.NewReader(tt.xml)) + d.DefaultSpace = tt.ns + err = d.Decode(&dst) + } else { + err = Unmarshal([]byte(tt.xml), &dst) + } + if err != nil { + t.Errorf("#%d: Unmarshal: %v", i, err) + continue + } + want := tt.tab + if dst != want { + t.Errorf("#%d: dst=%+v, want %+v", i, dst, want) + } + } +} + +func TestMarshalNSAttr(t *testing.T) { + src := TableAttrs{TAttr{"hello", "world", "en_US", "other1", "other2", "other3", "other4"}} + data, err := Marshal(&src) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + want := `<TableAttrs><TAttr xmlns:html4="http://www.w3.org/TR/html4/" html4:table="hello" xmlns:furniture="http://www.w3schools.com/furniture" furniture:table="world" xml:lang="en_US" xmlns:_xml="http://golang.org/xml/" _xml:other="other1" xmlns:_xmlfoo="http://golang.org/xmlfoo/" _xmlfoo:other="other2" xmlns:json="http://golang.org/json/" json:other="other3" xmlns:json_1="http://golang.org/2/json/" json_1:other="other4"></TAttr></TableAttrs>` + str := string(data) + if str != want { + t.Errorf("Marshal:\nhave: %#q\nwant: %#q\n", str, want) + } + + var dst TableAttrs + if err := Unmarshal(data, &dst); err != nil { + t.Errorf("Unmarshal: %v", err) + } + + if dst != src { + t.Errorf("Unmarshal = %q, want %q", dst, src) + } +} diff --git a/libgo/go/encoding/xml/typeinfo.go b/libgo/go/encoding/xml/typeinfo.go index bbeb28d..83e6540 100644 --- a/libgo/go/encoding/xml/typeinfo.go +++ b/libgo/go/encoding/xml/typeinfo.go @@ -70,20 +70,19 @@ func getTypeInfo(typ reflect.Type) (*typeInfo, error) { if t.Kind() == reflect.Ptr { t = t.Elem() } - if t.Kind() != reflect.Struct { - continue - } - inner, err := getTypeInfo(t) - if err != nil { - return nil, err - } - for _, finfo := range inner.fields { - finfo.idx = append([]int{i}, finfo.idx...) - if err := addFieldInfo(typ, tinfo, &finfo); err != nil { + if t.Kind() == reflect.Struct { + inner, err := getTypeInfo(t) + if err != nil { return nil, err } + for _, finfo := range inner.fields { + finfo.idx = append([]int{i}, finfo.idx...) + if err := addFieldInfo(typ, tinfo, &finfo); err != nil { + return nil, err + } + } + continue } - continue } finfo, err := structFieldInfo(typ, &f) @@ -193,16 +192,19 @@ func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, erro } // Prepare field name and parents. - tokens = strings.Split(tag, ">") - if tokens[0] == "" { - tokens[0] = f.Name + parents := strings.Split(tag, ">") + if parents[0] == "" { + parents[0] = f.Name } - if tokens[len(tokens)-1] == "" { + if parents[len(parents)-1] == "" { return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ) } - finfo.name = tokens[len(tokens)-1] - if len(tokens) > 1 { - finfo.parents = tokens[:len(tokens)-1] + finfo.name = parents[len(parents)-1] + if len(parents) > 1 { + if (finfo.flags & fElement) == 0 { + return nil, fmt.Errorf("xml: %s chain not valid with %s flag", tag, strings.Join(tokens[1:], ",")) + } + finfo.parents = parents[:len(parents)-1] } // If the field type has an XMLName field, the names must match @@ -268,6 +270,9 @@ Loop: if oldf.flags&fMode != newf.flags&fMode { continue } + if oldf.xmlns != "" && newf.xmlns != "" && oldf.xmlns != newf.xmlns { + continue + } minl := min(len(newf.parents), len(oldf.parents)) for p := 0; p < minl; p++ { if oldf.parents[p] != newf.parents[p] { diff --git a/libgo/go/encoding/xml/xml.go b/libgo/go/encoding/xml/xml.go index decb2be..021f7e4 100644 --- a/libgo/go/encoding/xml/xml.go +++ b/libgo/go/encoding/xml/xml.go @@ -169,6 +169,11 @@ type Decoder struct { // the CharsetReader's result values must be non-nil. CharsetReader func(charset string, input io.Reader) (io.Reader, error) + // DefaultSpace sets the default name space used for unadorned tags, + // as if the entire XML stream were wrapped in an element containing + // the attribute xmlns="DefaultSpace". + DefaultSpace string + r io.ByteReader buf bytes.Buffer saved *bytes.Buffer @@ -268,6 +273,8 @@ func (d *Decoder) Token() (t Token, err error) { return } +const xmlURL = "http://www.w3.org/XML/1998/namespace" + // Apply name space translation to name n. // The default name space (for Space=="") // applies only to element names, not to attribute names. @@ -277,11 +284,15 @@ func (d *Decoder) translate(n *Name, isElementName bool) { return case n.Space == "" && !isElementName: return + case n.Space == "xml": + n.Space = xmlURL case n.Space == "" && n.Local == "xmlns": return } if v, ok := d.ns[n.Space]; ok { n.Space = v + } else if n.Space == "" { + n.Space = d.DefaultSpace } } @@ -956,7 +967,7 @@ Input: b0, b1 = 0, 0 continue Input } - ent := string(d.buf.Bytes()[before]) + ent := string(d.buf.Bytes()[before:]) if ent[len(ent)-1] != ';' { ent += " (no semicolon)" } @@ -1692,7 +1703,7 @@ var HTMLAutoClose = htmlAutoClose var htmlAutoClose = []string{ /* hget http://www.w3.org/TR/html4/loose.dtd | - 9 sed -n 's/<!ELEMENT (.*) - O EMPTY.+/ "\1",/p' | tr A-Z a-z + 9 sed -n 's/<!ELEMENT ([^ ]*) +- O EMPTY.+/ "\1",/p' | tr A-Z a-z */ "basefont", "br", @@ -1702,7 +1713,7 @@ var htmlAutoClose = []string{ "param", "hr", "input", - "col ", + "col", "frame", "isindex", "base", @@ -1718,15 +1729,18 @@ var ( esc_tab = []byte("	") esc_nl = []byte("
") esc_cr = []byte("
") + esc_fffd = []byte("\uFFFD") // Unicode replacement character ) -// Escape writes to w the properly escaped XML equivalent +// EscapeText writes to w the properly escaped XML equivalent // of the plain text data s. -func Escape(w io.Writer, s []byte) { +func EscapeText(w io.Writer, s []byte) error { var esc []byte last := 0 - for i, c := range s { - switch c { + for i := 0; i < len(s); { + r, width := utf8.DecodeRune(s[i:]) + i += width + switch r { case '"': esc = esc_quot case '\'': @@ -1744,13 +1758,31 @@ func Escape(w io.Writer, s []byte) { case '\r': esc = esc_cr default: + if !isInCharacterRange(r) { + esc = esc_fffd + break + } continue } - w.Write(s[last:i]) - w.Write(esc) - last = i + 1 + if _, err := w.Write(s[last : i-width]); err != nil { + return err + } + if _, err := w.Write(esc); err != nil { + return err + } + last = i + } + if _, err := w.Write(s[last:]); err != nil { + return err } - w.Write(s[last:]) + return nil +} + +// Escape is like EscapeText but omits the error return value. +// It is provided for backwards compatibility with Go 1.0. +// Code targeting Go 1.1 or later should use EscapeText. +func Escape(w io.Writer, s []byte) { + EscapeText(w, s) } // procInstEncoding parses the `encoding="..."` or `encoding='...'` diff --git a/libgo/go/encoding/xml/xml_test.go b/libgo/go/encoding/xml/xml_test.go index 981d352..eeedbe5 100644 --- a/libgo/go/encoding/xml/xml_test.go +++ b/libgo/go/encoding/xml/xml_test.go @@ -5,6 +5,7 @@ package xml import ( + "bytes" "fmt" "io" "reflect" @@ -595,13 +596,6 @@ func TestEntityInsideCDATA(t *testing.T) { } } -// The last three tests (respectively one for characters in attribute -// names and two for character entities) pass not because of code -// changed for issue 1259, but instead pass with the given messages -// from other parts of xml.Decoder. I provide these to note the -// current behavior of situations where one might think that character -// range checking would detect the error, but it does not in fact. - var characterTests = []struct { in string err string @@ -611,8 +605,10 @@ var characterTests = []struct { {"\xef\xbf\xbe<doc/>", "illegal character code U+FFFE"}, {"<?xml version=\"1.0\"?><doc>\r\n<hiya/>\x07<toots/></doc>", "illegal character code U+0007"}, {"<?xml version=\"1.0\"?><doc \x12='value'>what's up</doc>", "expected attribute name in element"}, + {"<doc>&abc\x01;</doc>", "invalid character entity &abc (no semicolon)"}, {"<doc>&\x01;</doc>", "invalid character entity & (no semicolon)"}, - {"<doc>&\xef\xbf\xbe;</doc>", "invalid character entity & (no semicolon)"}, + {"<doc>&\xef\xbf\xbe;</doc>", "invalid character entity &\uFFFE;"}, + {"<doc>&hello;</doc>", "invalid character entity &hello;"}, } func TestDisallowedCharacters(t *testing.T) { @@ -629,7 +625,7 @@ func TestDisallowedCharacters(t *testing.T) { t.Fatalf("input %d d.Token() = _, %v, want _, *SyntaxError", i, err) } if synerr.Msg != tt.err { - t.Fatalf("input %d synerr.Msg wrong: want '%s', got '%s'", i, tt.err, synerr.Msg) + t.Fatalf("input %d synerr.Msg wrong: want %q, got %q", i, tt.err, synerr.Msg) } } } @@ -689,3 +685,32 @@ func TestDirectivesWithComments(t *testing.T) { } } } + +// Writer whose Write method always returns an error. +type errWriter struct{} + +func (errWriter) Write(p []byte) (n int, err error) { return 0, fmt.Errorf("unwritable") } + +func TestEscapeTextIOErrors(t *testing.T) { + expectErr := "unwritable" + err := EscapeText(errWriter{}, []byte{'A'}) + + if err == nil || err.Error() != expectErr { + t.Errorf("have %v, want %v", err, expectErr) + } +} + +func TestEscapeTextInvalidChar(t *testing.T) { + input := []byte("A \x00 terminated string.") + expected := "A \uFFFD terminated string." + + buff := new(bytes.Buffer) + if err := EscapeText(buff, input); err != nil { + t.Fatalf("have %v, want nil", err) + } + text := buff.String() + + if text != expected { + t.Errorf("have %v, want %v", text, expected) + } +} |