aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/crypto/elliptic/elliptic.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/crypto/elliptic/elliptic.go')
-rw-r--r--libgo/go/crypto/elliptic/elliptic.go83
1 files changed, 66 insertions, 17 deletions
diff --git a/libgo/go/crypto/elliptic/elliptic.go b/libgo/go/crypto/elliptic/elliptic.go
index e2f71cd..f93dc16 100644
--- a/libgo/go/crypto/elliptic/elliptic.go
+++ b/libgo/go/crypto/elliptic/elliptic.go
@@ -20,7 +20,10 @@ import (
)
// A Curve represents a short-form Weierstrass curve with a=-3.
-// See https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
+//
+// Note that the point at infinity (0, 0) is not considered on the curve, and
+// although it can be returned by Add, Double, ScalarMult, or ScalarBaseMult, it
+// can't be marshaled or unmarshaled, and IsOnCurve will return false for it.
type Curve interface {
// Params returns the parameters for the curve.
Params() *CurveParams
@@ -52,11 +55,8 @@ func (curve *CurveParams) Params() *CurveParams {
return curve
}
-func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
- // y² = x³ - 3x + b
- y2 := new(big.Int).Mul(y, y)
- y2.Mod(y2, curve.P)
-
+// polynomial returns x³ - 3x + b.
+func (curve *CurveParams) polynomial(x *big.Int) *big.Int {
x3 := new(big.Int).Mul(x, x)
x3.Mul(x3, x)
@@ -67,7 +67,15 @@ func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
x3.Add(x3, curve.B)
x3.Mod(x3, curve.P)
- return x3.Cmp(y2) == 0
+ return x3
+}
+
+func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool {
+ // y² = x³ - 3x + b
+ y2 := new(big.Int).Mul(y, y)
+ y2.Mod(y2, curve.P)
+
+ return curve.polynomial(x).Cmp(y2) == 0
}
// zForAffine returns a Jacobian Z value for the affine point (x, y). If x and
@@ -277,7 +285,7 @@ var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f}
func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) {
N := curve.Params().N
bitSize := N.BitLen()
- byteLen := (bitSize + 7) >> 3
+ byteLen := (bitSize + 7) / 8
priv = make([]byte, byteLen)
for x == nil {
@@ -302,30 +310,40 @@ func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err e
return
}
-// Marshal converts a point into the uncompressed form specified in section 4.3.6 of ANSI X9.62.
+// Marshal converts a point on the curve into the uncompressed form specified in
+// section 4.3.6 of ANSI X9.62.
func Marshal(curve Curve, x, y *big.Int) []byte {
- byteLen := (curve.Params().BitSize + 7) >> 3
+ byteLen := (curve.Params().BitSize + 7) / 8
ret := make([]byte, 1+2*byteLen)
ret[0] = 4 // uncompressed point
- xBytes := x.Bytes()
- copy(ret[1+byteLen-len(xBytes):], xBytes)
- yBytes := y.Bytes()
- copy(ret[1+2*byteLen-len(yBytes):], yBytes)
+ x.FillBytes(ret[1 : 1+byteLen])
+ y.FillBytes(ret[1+byteLen : 1+2*byteLen])
+
return ret
}
+// MarshalCompressed converts a point on the curve into the compressed form
+// specified in section 4.3.6 of ANSI X9.62.
+func MarshalCompressed(curve Curve, x, y *big.Int) []byte {
+ byteLen := (curve.Params().BitSize + 7) / 8
+ compressed := make([]byte, 1+byteLen)
+ compressed[0] = byte(y.Bit(0)) | 2
+ x.FillBytes(compressed[1:])
+ return compressed
+}
+
// Unmarshal converts a point, serialized by Marshal, into an x, y pair.
// It is an error if the point is not in uncompressed form or is not on the curve.
// On error, x = nil.
func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
- byteLen := (curve.Params().BitSize + 7) >> 3
+ byteLen := (curve.Params().BitSize + 7) / 8
if len(data) != 1+2*byteLen {
- return
+ return nil, nil
}
if data[0] != 4 { // uncompressed form
- return
+ return nil, nil
}
p := curve.Params().P
x = new(big.Int).SetBytes(data[1 : 1+byteLen])
@@ -339,6 +357,37 @@ func Unmarshal(curve Curve, data []byte) (x, y *big.Int) {
return
}
+// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into an x, y pair.
+// It is an error if the point is not in compressed form or is not on the curve.
+// On error, x = nil.
+func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) {
+ byteLen := (curve.Params().BitSize + 7) / 8
+ if len(data) != 1+byteLen {
+ return nil, nil
+ }
+ if data[0] != 2 && data[0] != 3 { // compressed form
+ return nil, nil
+ }
+ p := curve.Params().P
+ x = new(big.Int).SetBytes(data[1:])
+ if x.Cmp(p) >= 0 {
+ return nil, nil
+ }
+ // y² = x³ - 3x + b
+ y = curve.Params().polynomial(x)
+ y = y.ModSqrt(y, p)
+ if y == nil {
+ return nil, nil
+ }
+ if byte(y.Bit(0)) != data[0]&1 {
+ y.Neg(y).Mod(y, p)
+ }
+ if !curve.IsOnCurve(x, y) {
+ return nil, nil
+ }
+ return
+}
+
var initonce sync.Once
var p384 *CurveParams
var p521 *CurveParams