diff options
Diffstat (limited to 'libgo/go/math/big/nat.go')
-rw-r--r-- | libgo/go/math/big/nat.go | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/libgo/go/math/big/nat.go b/libgo/go/math/big/nat.go index 1b771ca..6a3989b 100644 --- a/libgo/go/math/big/nat.go +++ b/libgo/go/math/big/nat.go @@ -740,7 +740,8 @@ func (z nat) divLarge(u, uIn, vIn nat) (q, r nat) { // The remainder overwrites input u. // // Precondition: -// - len(q) >= len(u)-len(v) +// - q is large enough to hold the quotient u / v +// which has a maximum length of len(u)-len(v)+1. func (q nat) divBasic(u, v nat) { n := len(v) m := len(u) - n @@ -779,6 +780,8 @@ func (q nat) divBasic(u, v nat) { } // D4. + // Compute the remainder u - (q̂*v) << (_W*j). + // The subtraction may overflow if q̂ estimate was off by one. qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0) qhl := len(qhatv) if j+qhl > len(u) && qhatv[n] == 0 { @@ -787,7 +790,11 @@ func (q nat) divBasic(u, v nat) { c := subVV(u[j:j+qhl], u[j:], qhatv) if c != 0 { c := addVV(u[j:j+n], u[j:], v) - u[j+n] += c + // If n == qhl, the carry from subVV and the carry from addVV + // cancel out and don't affect u[j+n]. + if n < qhl { + u[j+n] += c + } qhat-- } @@ -827,6 +834,10 @@ func (z nat) divRecursive(u, v nat) { putNat(tmp) } +// divRecursiveStep computes the division of u by v. +// - z must be large enough to hold the quotient +// - the quotient will overwrite z +// - the remainder will overwrite u func (z nat) divRecursiveStep(u, v nat, depth int, tmp *nat, temps []*nat) { u = u.norm() v = v.norm() @@ -1465,19 +1476,26 @@ func (z nat) expNNMontgomery(x, y, m nat) nat { } // bytes writes the value of z into buf using big-endian encoding. -// len(buf) must be >= len(z)*_S. The value of z is encoded in the -// slice buf[i:]. The number i of unused bytes at the beginning of -// buf is returned as result. +// The value of z is encoded in the slice buf[i:]. If the value of z +// cannot be represented in buf, bytes panics. The number i of unused +// bytes at the beginning of buf is returned as result. func (z nat) bytes(buf []byte) (i int) { i = len(buf) for _, d := range z { for j := 0; j < _S; j++ { i-- - buf[i] = byte(d) + if i >= 0 { + buf[i] = byte(d) + } else if byte(d) != 0 { + panic("math/big: buffer too small to fit value") + } d >>= 8 } } + if i < 0 { + i = 0 + } for i < len(buf) && buf[i] == 0 { i++ } |