aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/encoding/base64
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/encoding/base64')
-rw-r--r--libgo/go/encoding/base64/base64.go24
-rw-r--r--libgo/go/encoding/base64/base64_test.go21
2 files changed, 42 insertions, 3 deletions
diff --git a/libgo/go/encoding/base64/base64.go b/libgo/go/encoding/base64/base64.go
index c2116d8..d2efad4 100644
--- a/libgo/go/encoding/base64/base64.go
+++ b/libgo/go/encoding/base64/base64.go
@@ -23,6 +23,7 @@ type Encoding struct {
encode [64]byte
decodeMap [256]byte
padChar rune
+ strict bool
}
const (
@@ -62,6 +63,14 @@ func (enc Encoding) WithPadding(padding rune) *Encoding {
return &enc
}
+// Strict creates a new encoding identical to enc except with
+// strict decoding enabled. In this mode, the decoder requires that
+// trailing padding bits are zero, as described in RFC 4648 section 3.5.
+func (enc Encoding) Strict() *Encoding {
+ enc.strict = true
+ return &enc
+}
+
// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)
@@ -311,15 +320,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// Convert 4x 6bit source bytes into 3 bytes
val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
+ dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
switch dlen {
case 4:
- dst[2] = byte(val >> 0)
+ dst[2] = dbuf[2]
+ dbuf[2] = 0
fallthrough
case 3:
- dst[1] = byte(val >> 8)
+ dst[1] = dbuf[1]
+ if enc.strict && dbuf[2] != 0 {
+ return n, end, CorruptInputError(si - 1)
+ }
+ dbuf[1] = 0
fallthrough
case 2:
- dst[0] = byte(val >> 16)
+ dst[0] = dbuf[0]
+ if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
+ return n, end, CorruptInputError(si - 2)
+ }
}
dst = dst[dinc:]
n += dlen - 1
diff --git a/libgo/go/encoding/base64/base64_test.go b/libgo/go/encoding/base64/base64_test.go
index 19ddb92..e2e1d59 100644
--- a/libgo/go/encoding/base64/base64_test.go
+++ b/libgo/go/encoding/base64/base64_test.go
@@ -85,6 +85,11 @@ var encodingTests = []encodingTest{
{RawStdEncoding, rawRef},
{RawURLEncoding, rawUrlRef},
{funnyEncoding, funnyRef},
+ {StdEncoding.Strict(), stdRef},
+ {URLEncoding.Strict(), urlRef},
+ {RawStdEncoding.Strict(), rawRef},
+ {RawURLEncoding.Strict(), rawUrlRef},
+ {funnyEncoding.Strict(), funnyRef},
}
var bigtest = testpair{
@@ -436,6 +441,22 @@ func TestDecoderIssue7733(t *testing.T) {
}
}
+func TestDecoderIssue15656(t *testing.T) {
+ _, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
+ want := CorruptInputError(22)
+ if !reflect.DeepEqual(want, err) {
+ t.Errorf("Error = %v; want CorruptInputError(22)", err)
+ }
+ _, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==")
+ if err != nil {
+ t.Errorf("Error = %v; want nil", err)
+ }
+ _, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
+ if err != nil {
+ t.Errorf("Error = %v; want nil", err)
+ }
+}
+
func BenchmarkEncodeToString(b *testing.B) {
data := make([]byte, 8192)
b.SetBytes(int64(len(data)))