aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/crypto/rsa
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/rsa')
-rw-r--r--libgo/go/crypto/rsa/equal_test.go51
-rw-r--r--libgo/go/crypto/rsa/example_test.go2
-rw-r--r--libgo/go/crypto/rsa/pkcs1v15.go41
-rw-r--r--libgo/go/crypto/rsa/pkcs1v15_test.go18
-rw-r--r--libgo/go/crypto/rsa/pss.go182
-rw-r--r--libgo/go/crypto/rsa/rsa.go86
6 files changed, 231 insertions, 149 deletions
diff --git a/libgo/go/crypto/rsa/equal_test.go b/libgo/go/crypto/rsa/equal_test.go
new file mode 100644
index 0000000..90f4bf9
--- /dev/null
+++ b/libgo/go/crypto/rsa/equal_test.go
@@ -0,0 +1,51 @@
+// Copyright 2020 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package rsa_test
+
+import (
+ "crypto"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "testing"
+)
+
+func TestEqual(t *testing.T) {
+ private, _ := rsa.GenerateKey(rand.Reader, 512)
+ public := &private.PublicKey
+
+ if !public.Equal(public) {
+ t.Errorf("public key is not equal to itself: %v", public)
+ }
+ if !public.Equal(crypto.Signer(private).Public().(*rsa.PublicKey)) {
+ t.Errorf("private.Public() is not Equal to public: %q", public)
+ }
+ if !private.Equal(private) {
+ t.Errorf("private key is not equal to itself: %v", private)
+ }
+
+ enc, err := x509.MarshalPKCS8PrivateKey(private)
+ if err != nil {
+ t.Fatal(err)
+ }
+ decoded, err := x509.ParsePKCS8PrivateKey(enc)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !public.Equal(decoded.(crypto.Signer).Public()) {
+ t.Errorf("public key is not equal to itself after decoding: %v", public)
+ }
+ if !private.Equal(decoded) {
+ t.Errorf("private key is not equal to itself after decoding: %v", private)
+ }
+
+ other, _ := rsa.GenerateKey(rand.Reader, 512)
+ if public.Equal(other.Public()) {
+ t.Errorf("different public keys are Equal")
+ }
+ if private.Equal(other) {
+ t.Errorf("different private keys are Equal")
+ }
+}
diff --git a/libgo/go/crypto/rsa/example_test.go b/libgo/go/crypto/rsa/example_test.go
index 1435b70..ce5c2d9 100644
--- a/libgo/go/crypto/rsa/example_test.go
+++ b/libgo/go/crypto/rsa/example_test.go
@@ -27,7 +27,7 @@ import (
// exponentiation is larger than the modulus. (Otherwise it could be
// decrypted with a square-root.)
//
-// In these designs, when using PKCS#1 v1.5, it's vitally important to
+// In these designs, when using PKCS #1 v1.5, it's vitally important to
// avoid disclosing whether the received RSA message was well-formed
// (that is, whether the result of decrypting is a correctly padded
// message) because this leaks secret information.
diff --git a/libgo/go/crypto/rsa/pkcs1v15.go b/libgo/go/crypto/rsa/pkcs1v15.go
index 37790ac..0cbd6d0 100644
--- a/libgo/go/crypto/rsa/pkcs1v15.go
+++ b/libgo/go/crypto/rsa/pkcs1v15.go
@@ -14,9 +14,9 @@ import (
"crypto/internal/randutil"
)
-// This file implements encryption and decryption using PKCS#1 v1.5 padding.
+// This file implements encryption and decryption using PKCS #1 v1.5 padding.
-// PKCS1v15DecrypterOpts is for passing options to PKCS#1 v1.5 decryption using
+// PKCS1v15DecrypterOpts is for passing options to PKCS #1 v1.5 decryption using
// the crypto.Decrypter interface.
type PKCS1v15DecryptOptions struct {
// SessionKeyLen is the length of the session key that is being
@@ -27,7 +27,7 @@ type PKCS1v15DecryptOptions struct {
}
// EncryptPKCS1v15 encrypts the given message with RSA and the padding
-// scheme from PKCS#1 v1.5. The message must be no longer than the
+// scheme from PKCS #1 v1.5. The message must be no longer than the
// length of the public modulus minus 11 bytes.
//
// The rand parameter is used as a source of entropy to ensure that
@@ -61,11 +61,10 @@ func EncryptPKCS1v15(rand io.Reader, pub *PublicKey, msg []byte) ([]byte, error)
m := new(big.Int).SetBytes(em)
c := encrypt(new(big.Int), pub, m)
- copyWithLeftPad(em, c.Bytes())
- return em, nil
+ return c.FillBytes(em), nil
}
-// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS#1 v1.5.
+// DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5.
// If rand != nil, it uses RSA blinding to avoid timing side-channel attacks.
//
// Note that whether this function returns an error or not discloses secret
@@ -87,7 +86,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) ([]byt
return out[index:], nil
}
-// DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS#1 v1.5.
+// DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS #1 v1.5.
// If rand != nil, it uses RSA blinding to avoid timing side-channel attacks.
// It returns an error if the ciphertext is the wrong length or if the
// ciphertext is greater than the public modulus. Otherwise, no error is
@@ -150,7 +149,7 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid
return
}
- em = leftPad(m.Bytes(), k)
+ em = m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)
@@ -217,7 +216,7 @@ var hashPrefixes = map[crypto.Hash][]byte{
}
// SignPKCS1v15 calculates the signature of hashed using
-// RSASSA-PKCS1-V1_5-SIGN from RSA PKCS#1 v1.5. Note that hashed must
+// RSASSA-PKCS1-V1_5-SIGN from RSA PKCS #1 v1.5. Note that hashed must
// be the result of hashing the input message using the given hash
// function. If hash is zero, hashed is signed directly. This isn't
// advisable except for interoperability.
@@ -256,11 +255,10 @@ func SignPKCS1v15(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []b
return nil, err
}
- copyWithLeftPad(em, c.Bytes())
- return em, nil
+ return c.FillBytes(em), nil
}
-// VerifyPKCS1v15 verifies an RSA PKCS#1 v1.5 signature.
+// VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature.
// hashed is the result of hashing the input message using the given hash
// function and sig is the signature. A valid signature is indicated by
// returning a nil error. If hash is zero then hashed is used directly. This
@@ -277,9 +275,16 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte)
return ErrVerification
}
+ // RFC 8017 Section 8.2.2: If the length of the signature S is not k
+ // octets (where k is the length in octets of the RSA modulus n), output
+ // "invalid signature" and stop.
+ if k != len(sig) {
+ return ErrVerification
+ }
+
c := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, c)
- em := leftPad(m.Bytes(), k)
+ em := m.FillBytes(make([]byte, k))
// EM = 0x00 || 0x01 || PS || 0x00 || T
ok := subtle.ConstantTimeByteEq(em[0], 0)
@@ -316,13 +321,3 @@ func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte,
}
return
}
-
-// copyWithLeftPad copies src to the end of dest, padding with zero bytes as
-// needed.
-func copyWithLeftPad(dest, src []byte) {
- numPaddingBytes := len(dest) - len(src)
- for i := 0; i < numPaddingBytes; i++ {
- dest[i] = 0
- }
- copy(dest[numPaddingBytes:], src)
-}
diff --git a/libgo/go/crypto/rsa/pkcs1v15_test.go b/libgo/go/crypto/rsa/pkcs1v15_test.go
index 7e62560..26b8c5f 100644
--- a/libgo/go/crypto/rsa/pkcs1v15_test.go
+++ b/libgo/go/crypto/rsa/pkcs1v15_test.go
@@ -9,6 +9,7 @@ import (
"crypto"
"crypto/rand"
"crypto/sha1"
+ "crypto/sha256"
"encoding/base64"
"encoding/hex"
"io"
@@ -296,3 +297,20 @@ var rsaPrivateKey = &PrivateKey{
fromBase10("94560208308847015747498523884063394671606671904944666360068158221458669711639"),
},
}
+
+func TestShortPKCS1v15Signature(t *testing.T) {
+ pub := &PublicKey{
+ E: 65537,
+ N: fromBase10("8272693557323587081220342447407965471608219912416565371060697606400726784709760494166080686904546560026343451112103559482851304715739629410219358933351333"),
+ }
+ sig, err := hex.DecodeString("193a310d0dcf64094c6e3a00c8219b80ded70535473acff72c08e1222974bb24a93a535b1dc4c59fc0e65775df7ba2007dd20e9193f4c4025a18a7070aee93")
+ if err != nil {
+ t.Fatalf("failed to decode signature: %s", err)
+ }
+
+ h := sha256.Sum256([]byte("hello"))
+ err = VerifyPKCS1v15(pub, crypto.SHA256, h[:], sig)
+ if err == nil {
+ t.Fatal("VerifyPKCS1v15 accepted a truncated signature")
+ }
+}
diff --git a/libgo/go/crypto/rsa/pss.go b/libgo/go/crypto/rsa/pss.go
index 3ff0c2f..b2adbed 100644
--- a/libgo/go/crypto/rsa/pss.go
+++ b/libgo/go/crypto/rsa/pss.go
@@ -4,9 +4,7 @@
package rsa
-// This file implements the PSS signature scheme [1].
-//
-// [1] https://www.emc.com/collateral/white-papers/h11300-pkcs-1v2-2-rsa-cryptography-standard-wp.pdf
+// This file implements the RSASSA-PSS signature scheme according to RFC 8017.
import (
"bytes"
@@ -17,8 +15,22 @@ import (
"math/big"
)
+// Per RFC 8017, Section 9.1
+//
+// EM = MGF1 xor DB || H( 8*0x00 || mHash || salt ) || 0xbc
+//
+// where
+//
+// DB = PS || 0x01 || salt
+//
+// and PS can be empty so
+//
+// emLen = dbLen + hLen + 1 = psLen + sLen + hLen + 2
+//
+
func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byte, error) {
- // See [1], section 9.1.1
+ // See RFC 8017, Section 9.1.1.
+
hLen := hash.Size()
sLen := len(salt)
emLen := (emBits + 7) / 8
@@ -30,7 +42,7 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 2. Let mHash = Hash(M), an octet string of length hLen.
if len(mHash) != hLen {
- return nil, errors.New("crypto/rsa: input must be hashed message")
+ return nil, errors.New("crypto/rsa: input must be hashed with given hash")
}
// 3. If emLen < hLen + sLen + 2, output "encoding error" and stop.
@@ -40,8 +52,9 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
}
em := make([]byte, emLen)
- db := em[:emLen-sLen-hLen-2+1+sLen]
- h := em[emLen-sLen-hLen-2+1+sLen : emLen-1]
+ psLen := emLen - sLen - hLen - 2
+ db := em[:psLen+1+sLen]
+ h := em[psLen+1+sLen : emLen-1]
// 4. Generate a random octet string salt of length sLen; if sLen = 0,
// then salt is the empty string.
@@ -69,8 +82,8 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 8. Let DB = PS || 0x01 || salt; DB is an octet string of length
// emLen - hLen - 1.
- db[emLen-sLen-hLen-2] = 0x01
- copy(db[emLen-sLen-hLen-1:], salt)
+ db[psLen] = 0x01
+ copy(db[psLen+1:], salt)
// 9. Let dbMask = MGF(H, emLen - hLen - 1).
//
@@ -81,47 +94,57 @@ func emsaPSSEncode(mHash []byte, emBits int, salt []byte, hash hash.Hash) ([]byt
// 11. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB to zero.
- db[0] &= (0xFF >> uint(8*emLen-emBits))
+ db[0] &= 0xff >> (8*emLen - emBits)
// 12. Let EM = maskedDB || H || 0xbc.
- em[emLen-1] = 0xBC
+ em[emLen-1] = 0xbc
// 13. Output EM.
return em, nil
}
func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
+ // See RFC 8017, Section 9.1.2.
+
+ hLen := hash.Size()
+ if sLen == PSSSaltLengthEqualsHash {
+ sLen = hLen
+ }
+ emLen := (emBits + 7) / 8
+ if emLen != len(em) {
+ return errors.New("rsa: internal error: inconsistent length")
+ }
+
// 1. If the length of M is greater than the input limitation for the
// hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
// and stop.
//
// 2. Let mHash = Hash(M), an octet string of length hLen.
- hLen := hash.Size()
if hLen != len(mHash) {
return ErrVerification
}
// 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
- emLen := (emBits + 7) / 8
if emLen < hLen+sLen+2 {
return ErrVerification
}
// 4. If the rightmost octet of EM does not have hexadecimal value
// 0xbc, output "inconsistent" and stop.
- if em[len(em)-1] != 0xBC {
+ if em[emLen-1] != 0xbc {
return ErrVerification
}
// 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
// let H be the next hLen octets.
db := em[:emLen-hLen-1]
- h := em[emLen-hLen-1 : len(em)-1]
+ h := em[emLen-hLen-1 : emLen-1]
// 6. If the leftmost 8 * emLen - emBits bits of the leftmost octet in
// maskedDB are not all equal to zero, output "inconsistent" and
// stop.
- if em[0]&(0xFF<<uint(8-(8*emLen-emBits))) != 0 {
+ var bitMask byte = 0xff >> (8*emLen - emBits)
+ if em[0] & ^bitMask != 0 {
return ErrVerification
}
@@ -132,37 +155,30 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
// 9. Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
// to zero.
- db[0] &= (0xFF >> uint(8*emLen-emBits))
+ db[0] &= bitMask
+ // If we don't know the salt length, look for the 0x01 delimiter.
if sLen == PSSSaltLengthAuto {
- FindSaltLength:
- for sLen = emLen - (hLen + 2); sLen >= 0; sLen-- {
- switch db[emLen-hLen-sLen-2] {
- case 1:
- break FindSaltLength
- case 0:
- continue
- default:
- return ErrVerification
- }
- }
- if sLen < 0 {
+ psLen := bytes.IndexByte(db, 0x01)
+ if psLen < 0 {
return ErrVerification
}
- } else {
- // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
- // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
- // position is "position 1") does not have hexadecimal value 0x01,
- // output "inconsistent" and stop.
- for _, e := range db[:emLen-hLen-sLen-2] {
- if e != 0x00 {
- return ErrVerification
- }
- }
- if db[emLen-hLen-sLen-2] != 0x01 {
+ sLen = len(db) - psLen - 1
+ }
+
+ // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
+ // or if the octet at position emLen - hLen - sLen - 1 (the leftmost
+ // position is "position 1") does not have hexadecimal value 0x01,
+ // output "inconsistent" and stop.
+ psLen := emLen - hLen - sLen - 2
+ for _, e := range db[:psLen] {
+ if e != 0x00 {
return ErrVerification
}
}
+ if db[psLen] != 0x01 {
+ return ErrVerification
+ }
// 11. Let salt be the last sLen octets of DB.
salt := db[len(db)-sLen:]
@@ -181,30 +197,29 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error {
h0 := hash.Sum(nil)
// 14. If H = H', output "consistent." Otherwise, output "inconsistent."
- if !bytes.Equal(h0, h) {
+ if !bytes.Equal(h0, h) { // TODO: constant time?
return ErrVerification
}
return nil
}
-// signPSSWithSalt calculates the signature of hashed using PSS [1] with specified salt.
+// signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
// Note that hashed must be the result of hashing the input message using the
// given hash function. salt is a random sequence of bytes whose length will be
// later used to verify the signature.
-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) (s []byte, err error) {
- nBits := priv.N.BitLen()
- em, err := emsaPSSEncode(hashed, nBits-1, salt, hash.New())
+func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) {
+ emBits := priv.N.BitLen() - 1
+ em, err := emsaPSSEncode(hashed, emBits, salt, hash.New())
if err != nil {
- return
+ return nil, err
}
m := new(big.Int).SetBytes(em)
c, err := decryptAndCheck(rand, priv, m)
if err != nil {
- return
+ return nil, err
}
- s = make([]byte, (nBits+7)/8)
- copyWithLeftPad(s, c.Bytes())
- return
+ s := make([]byte, priv.Size())
+ return c.FillBytes(s), nil
}
const (
@@ -223,16 +238,15 @@ type PSSOptions struct {
// PSSSaltLength constants.
SaltLength int
- // Hash, if not zero, overrides the hash function passed to SignPSS.
- // This is the only way to specify the hash function when using the
- // crypto.Signer interface.
+ // Hash is the hash function used to generate the message digest. If not
+ // zero, it overrides the hash function passed to SignPSS. It's required
+ // when using PrivateKey.Sign.
Hash crypto.Hash
}
-// HashFunc returns pssOpts.Hash so that PSSOptions implements
-// crypto.SignerOpts.
-func (pssOpts *PSSOptions) HashFunc() crypto.Hash {
- return pssOpts.Hash
+// HashFunc returns opts.Hash so that PSSOptions implements crypto.SignerOpts.
+func (opts *PSSOptions) HashFunc() crypto.Hash {
+ return opts.Hash
}
func (opts *PSSOptions) saltLength() int {
@@ -242,56 +256,48 @@ func (opts *PSSOptions) saltLength() int {
return opts.SaltLength
}
-// SignPSS calculates the signature of hashed using RSASSA-PSS [1].
-// Note that hashed must be the result of hashing the input message using the
-// given hash function. The opts argument may be nil, in which case sensible
-// defaults are used.
-func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed []byte, opts *PSSOptions) ([]byte, error) {
+// SignPSS calculates the signature of digest using PSS.
+//
+// digest must be the result of hashing the input message using the given hash
+// function. The opts argument may be nil, in which case sensible defaults are
+// used. If opts.Hash is set, it overrides hash.
+func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, opts *PSSOptions) ([]byte, error) {
+ if opts != nil && opts.Hash != 0 {
+ hash = opts.Hash
+ }
+
saltLength := opts.saltLength()
switch saltLength {
case PSSSaltLengthAuto:
- saltLength = (priv.N.BitLen()+7)/8 - 2 - hash.Size()
+ saltLength = priv.Size() - 2 - hash.Size()
case PSSSaltLengthEqualsHash:
saltLength = hash.Size()
}
- if opts != nil && opts.Hash != 0 {
- hash = opts.Hash
- }
-
salt := make([]byte, saltLength)
if _, err := io.ReadFull(rand, salt); err != nil {
return nil, err
}
- return signPSSWithSalt(rand, priv, hash, hashed, salt)
+ return signPSSWithSalt(rand, priv, hash, digest, salt)
}
// VerifyPSS verifies a PSS signature.
-// hashed is the result of hashing the input message using the given hash
-// function and sig is the signature. A valid signature is indicated by
-// returning a nil error. The opts argument may be nil, in which case sensible
-// defaults are used.
-func VerifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, opts *PSSOptions) error {
- return verifyPSS(pub, hash, hashed, sig, opts.saltLength())
-}
-
-// verifyPSS verifies a PSS signature with the given salt length.
-func verifyPSS(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte, saltLen int) error {
- nBits := pub.N.BitLen()
- if len(sig) != (nBits+7)/8 {
+//
+// A valid signature is indicated by returning a nil error. digest must be the
+// result of hashing the input message using the given hash function. The opts
+// argument may be nil, in which case sensible defaults are used. opts.Hash is
+// ignored.
+func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts *PSSOptions) error {
+ if len(sig) != pub.Size() {
return ErrVerification
}
s := new(big.Int).SetBytes(sig)
m := encrypt(new(big.Int), pub, s)
- emBits := nBits - 1
+ emBits := pub.N.BitLen() - 1
emLen := (emBits + 7) / 8
- if emLen < len(m.Bytes()) {
+ if m.BitLen() > emLen*8 {
return ErrVerification
}
- em := make([]byte, emLen)
- copyWithLeftPad(em, m.Bytes())
- if saltLen == PSSSaltLengthEqualsHash {
- saltLen = hash.Size()
- }
- return emsaPSSVerify(hashed, em, emBits, saltLen, hash.New())
+ em := m.FillBytes(make([]byte, emLen))
+ return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New())
}
diff --git a/libgo/go/crypto/rsa/rsa.go b/libgo/go/crypto/rsa/rsa.go
index d058949..178ade6 100644
--- a/libgo/go/crypto/rsa/rsa.go
+++ b/libgo/go/crypto/rsa/rsa.go
@@ -2,21 +2,21 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Package rsa implements RSA encryption as specified in PKCS#1.
+// Package rsa implements RSA encryption as specified in PKCS #1 and RFC 8017.
//
// RSA is a single, fundamental operation that is used in this package to
// implement either public-key encryption or public-key signatures.
//
-// The original specification for encryption and signatures with RSA is PKCS#1
+// The original specification for encryption and signatures with RSA is PKCS #1
// and the terms "RSA encryption" and "RSA signatures" by default refer to
-// PKCS#1 version 1.5. However, that specification has flaws and new designs
-// should use version two, usually called by just OAEP and PSS, where
+// PKCS #1 version 1.5. However, that specification has flaws and new designs
+// should use version 2, usually called by just OAEP and PSS, where
// possible.
//
// Two sets of interfaces are included in this package. When a more abstract
// interface isn't necessary, there are functions for encrypting/decrypting
// with v1.5/OAEP and signing/verifying with v1.5/PSS. If one needs to abstract
-// over the public-key primitive, the PrivateKey struct implements the
+// over the public key primitive, the PrivateKey type implements the
// Decrypter and Signer interfaces from the crypto package.
//
// The RSA operations in this package are not implemented using constant-time algorithms.
@@ -44,12 +44,24 @@ type PublicKey struct {
E int // public exponent
}
+// Any methods implemented on PublicKey might need to also be implemented on
+// PrivateKey, as the latter embeds the former and will expose its methods.
+
// Size returns the modulus size in bytes. Raw signatures and ciphertexts
// for or by this public key will have the same size.
func (pub *PublicKey) Size() int {
return (pub.N.BitLen() + 7) / 8
}
+// Equal reports whether pub and x have the same value.
+func (pub *PublicKey) Equal(x crypto.PublicKey) bool {
+ xx, ok := x.(*PublicKey)
+ if !ok {
+ return false
+ }
+ return pub.N.Cmp(xx.N) == 0 && pub.E == xx.E
+}
+
// OAEPOptions is an interface for passing options to OAEP decryption using the
// crypto.Decrypter interface.
type OAEPOptions struct {
@@ -100,9 +112,31 @@ func (priv *PrivateKey) Public() crypto.PublicKey {
return &priv.PublicKey
}
+// Equal reports whether priv and x have equivalent values. It ignores
+// Precomputed values.
+func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
+ xx, ok := x.(*PrivateKey)
+ if !ok {
+ return false
+ }
+ if !priv.PublicKey.Equal(&xx.PublicKey) || priv.D.Cmp(xx.D) != 0 {
+ return false
+ }
+ if len(priv.Primes) != len(xx.Primes) {
+ return false
+ }
+ for i := range priv.Primes {
+ if priv.Primes[i].Cmp(xx.Primes[i]) != 0 {
+ return false
+ }
+ }
+ return true
+}
+
// Sign signs digest with priv, reading randomness from rand. If opts is a
-// *PSSOptions then the PSS algorithm will be used, otherwise PKCS#1 v1.5 will
-// be used.
+// *PSSOptions then the PSS algorithm will be used, otherwise PKCS #1 v1.5 will
+// be used. digest must be the result of hashing the input message using
+// opts.HashFunc().
//
// This method implements crypto.Signer, which is an interface to support keys
// where the private part is kept in, for example, a hardware module. Common
@@ -116,7 +150,7 @@ func (priv *PrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOp
}
// Decrypt decrypts ciphertext with priv. If opts is nil or of type
-// *PKCS1v15DecryptOptions then PKCS#1 v1.5 decryption is performed. Otherwise
+// *PKCS1v15DecryptOptions then PKCS #1 v1.5 decryption is performed. Otherwise
// opts must have type *OAEPOptions and OAEP decryption is done.
func (priv *PrivateKey) Decrypt(rand io.Reader, ciphertext []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
if opts == nil {
@@ -152,7 +186,7 @@ type PrecomputedValues struct {
// CRTValues is used for the 3rd and subsequent primes. Due to a
// historical accident, the CRT for the first two primes is handled
- // differently in PKCS#1 and interoperability is sufficiently
+ // differently in PKCS #1 and interoperability is sufficiently
// important that we mirror this.
CRTValues []CRTValue
}
@@ -326,7 +360,7 @@ func incCounter(c *[4]byte) {
}
// mgf1XOR XORs the bytes in out with a mask generated using the MGF1 function
-// specified in PKCS#1 v2.1.
+// specified in PKCS #1 v2.1.
func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
var counter [4]byte
var digest []byte
@@ -406,16 +440,9 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l
m := new(big.Int)
m.SetBytes(em)
c := encrypt(new(big.Int), pub, m)
- out := c.Bytes()
-
- if len(out) < k {
- // If the output is too small, we need to left-pad with zeros.
- t := make([]byte, k)
- copy(t[k-len(out):], out)
- out = t
- }
- return out, nil
+ out := make([]byte, k)
+ return c.FillBytes(out), nil
}
// ErrDecryption represents a failure to decrypt a message.
@@ -587,12 +614,9 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
lHash := hash.Sum(nil)
hash.Reset()
- // Converting the plaintext number to bytes will strip any
- // leading zeros so we may have to left pad. We do this unconditionally
- // to avoid leaking timing information. (Although we still probably
- // leak the number of leading zeros. It's not clear that we can do
- // anything about this.)
- em := leftPad(m.Bytes(), k)
+ // We probably leak the number of leading zeros.
+ // It's not clear that we can do anything about this.
+ em := m.FillBytes(make([]byte, k))
firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
@@ -633,15 +657,3 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext
return rest[index+1:], nil
}
-
-// leftPad returns a new slice of length size. The contents of input are right
-// aligned in the new slice.
-func leftPad(input []byte, size int) (out []byte) {
- n := len(input)
- if n > size {
- n = size
- }
- out = make([]byte, size)
- copy(out[len(out)-n:], input)
- return
-}