From b740cb6335fd36e5847f2abd3c890b7f14d1ef30 Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Tue, 29 Nov 2011 23:02:54 +0000 Subject: libgo: update to weekly.2011-10-25 Changes were mainly straightforward to merge. From-SVN: r181824 --- libgo/go/big/int.go | 31 +- libgo/go/big/int_test.go | 11 +- libgo/go/big/nat.go | 13 +- libgo/go/big/nat_test.go | 20 +- libgo/go/big/rat.go | 201 +- libgo/go/big/rat_test.go | 124 +- libgo/go/crypto/x509/x509.go | 4 +- libgo/go/crypto/x509/x509_test.go | 13 +- libgo/go/exp/inotify/inotify_linux.go | 288 ++ libgo/go/exp/inotify/inotify_linux_test.go | 96 + libgo/go/exp/ssh/channel.go | 2 +- libgo/go/exp/ssh/client.go | 490 ++ libgo/go/exp/ssh/doc.go | 48 +- libgo/go/exp/ssh/messages.go | 2 +- libgo/go/exp/ssh/server.go | 151 +- libgo/go/exp/ssh/session.go | 132 + libgo/go/exp/ssh/transport.go | 17 +- libgo/go/exp/ssh/transport_test.go | 10 +- libgo/go/exp/types/gcimporter.go | 5 +- libgo/go/exp/winfsnotify/winfsnotify_test.go | 10 +- libgo/go/fmt/fmt_test.go | 5 + libgo/go/fmt/print.go | 179 +- libgo/go/go/ast/print_test.go | 7 +- libgo/go/html/doc.go | 3 - libgo/go/html/parse.go | 353 +- libgo/go/html/parse_test.go | 22 +- libgo/go/html/render.go | 12 +- libgo/go/html/token.go | 50 +- libgo/go/html/token_test.go | 1 - libgo/go/http/client.go | 5 +- libgo/go/http/client_test.go | 24 + libgo/go/http/doc.go | 79 + libgo/go/http/request.go | 2 - libgo/go/http/transport.go | 11 +- libgo/go/net/sock_windows.go | 3 - libgo/go/os/inotify/inotify_linux.go | 288 -- libgo/go/os/inotify/inotify_linux_test.go | 96 - libgo/go/unicode/tables.go | 6976 +++++++++++++------------- 38 files changed, 5521 insertions(+), 4263 deletions(-) create mode 100644 libgo/go/exp/inotify/inotify_linux.go create mode 100644 libgo/go/exp/inotify/inotify_linux_test.go create mode 100644 libgo/go/exp/ssh/client.go create mode 100644 libgo/go/exp/ssh/session.go create mode 100644 libgo/go/http/doc.go delete mode 100644 libgo/go/os/inotify/inotify_linux.go delete mode 100644 libgo/go/os/inotify/inotify_linux_test.go (limited to 'libgo/go') diff --git a/libgo/go/big/int.go b/libgo/go/big/int.go index 9e1d1ae..b0dde1e 100644 --- a/libgo/go/big/int.go +++ b/libgo/go/big/int.go @@ -58,22 +58,24 @@ func NewInt(x int64) *Int { // Set sets z to x and returns z. func (z *Int) Set(x *Int) *Int { - z.abs = z.abs.set(x.abs) - z.neg = x.neg + if z != x { + z.abs = z.abs.set(x.abs) + z.neg = x.neg + } return z } // Abs sets z to |x| (the absolute value of x) and returns z. func (z *Int) Abs(x *Int) *Int { - z.abs = z.abs.set(x.abs) + z.Set(x) z.neg = false return z } // Neg sets z to -x and returns z. func (z *Int) Neg(x *Int) *Int { - z.abs = z.abs.set(x.abs) - z.neg = len(z.abs) > 0 && !x.neg // 0 has no sign + z.Set(x) + z.neg = len(z.abs) > 0 && !z.neg // 0 has no sign return z } @@ -174,7 +176,7 @@ func (z *Int) Quo(x, y *Int) *Int { // If y == 0, a division-by-zero run-time panic occurs. // Rem implements truncated modulus (like Go); see QuoRem for more details. func (z *Int) Rem(x, y *Int) *Int { - _, z.abs = nat(nil).div(z.abs, x.abs, y.abs) + _, z.abs = nat{}.div(z.abs, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg // 0 has no sign return z } @@ -422,8 +424,8 @@ func (x *Int) Format(s fmt.State, ch int) { // scan sets z to the integer value corresponding to the longest possible prefix // read from r representing a signed integer number in a given conversion base. // It returns z, the actual conversion base used, and an error, if any. In the -// error case, the value of z is undefined. The syntax follows the syntax of -// integer literals in Go. +// error case, the value of z is undefined but the returned value is nil. The +// syntax follows the syntax of integer literals in Go. // // The base argument must be 0 or a value from 2 through MaxBase. If the base // is 0, the string prefix determines the actual conversion base. A prefix of @@ -434,7 +436,7 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) { // determine sign ch, _, err := r.ReadRune() if err != nil { - return z, 0, err + return nil, 0, err } neg := false switch ch { @@ -448,7 +450,7 @@ func (z *Int) scan(r io.RuneScanner, base int) (*Int, int, os.Error) { // determine mantissa z.abs, base, err = z.abs.scan(r, base) if err != nil { - return z, base, err + return nil, base, err } z.neg = len(z.abs) > 0 && neg // 0 has no sign @@ -497,7 +499,7 @@ func (x *Int) Int64() int64 { // SetString sets z to the value of s, interpreted in the given base, // and returns z and a boolean indicating success. If SetString fails, -// the value of z is undefined. +// the value of z is undefined but the returned value is nil. // // The base argument must be 0 or a value from 2 through MaxBase. If the base // is 0, the string prefix determines the actual conversion base. A prefix of @@ -508,10 +510,13 @@ func (z *Int) SetString(s string, base int) (*Int, bool) { r := strings.NewReader(s) _, _, err := z.scan(r, base) if err != nil { - return z, false + return nil, false } _, _, err = r.ReadRune() - return z, err == os.EOF // err == os.EOF => scan consumed all of s + if err != os.EOF { + return nil, false + } + return z, true // err == os.EOF => scan consumed all of s } // SetBytes interprets buf as the bytes of a big-endian unsigned diff --git a/libgo/go/big/int_test.go b/libgo/go/big/int_test.go index b2e16921..fde19c2 100644 --- a/libgo/go/big/int_test.go +++ b/libgo/go/big/int_test.go @@ -311,7 +311,16 @@ func TestSetString(t *testing.T) { t.Errorf("#%d (input '%s') ok incorrect (should be %t)", i, test.in, test.ok) continue } - if !ok1 || !ok2 { + if !ok1 { + if n1 != nil { + t.Errorf("#%d (input '%s') n1 != nil", i, test.in) + } + continue + } + if !ok2 { + if n2 != nil { + t.Errorf("#%d (input '%s') n2 != nil", i, test.in) + } continue } diff --git a/libgo/go/big/nat.go b/libgo/go/big/nat.go index 33d6bb1..c0769d8 100644 --- a/libgo/go/big/nat.go +++ b/libgo/go/big/nat.go @@ -35,7 +35,7 @@ import ( // During arithmetic operations, denormalized values may occur but are // always normalized before returning the final result. The normalized // representation of 0 is the empty or nil slice (length = 0). - +// type nat []Word var ( @@ -447,10 +447,10 @@ func (z nat) mulRange(a, b uint64) nat { case a == b: return z.setUint64(a) case a+1 == b: - return z.mul(nat(nil).setUint64(a), nat(nil).setUint64(b)) + return z.mul(nat{}.setUint64(a), nat{}.setUint64(b)) } m := (a + b) / 2 - return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) + return z.mul(nat{}.mulRange(a, m), nat{}.mulRange(m+1, b)) } // q = (x-r)/y, with 0 <= r < y @@ -589,7 +589,6 @@ func (x nat) bitLen() int { // MaxBase is the largest number base accepted for string conversions. const MaxBase = 'z' - 'a' + 10 + 1 // = hexValue('z') + 1 - func hexValue(ch int) Word { d := MaxBase + 1 // illegal base switch { @@ -786,7 +785,7 @@ func (x nat) string(charset string) string { } // preserve x, create local copy for use in repeated divisions - q := nat(nil).set(x) + q := nat{}.set(x) var r Word // convert @@ -1192,11 +1191,11 @@ func (n nat) probablyPrime(reps int) bool { return false } - nm1 := nat(nil).sub(n, natOne) + nm1 := nat{}.sub(n, natOne) // 1< 0 { @@ -212,7 +212,7 @@ func TestString(t *testing.T) { t.Errorf("string%+v\n\tgot s = %s; want %s", a, s, a.s) } - x, b, err := nat(nil).scan(strings.NewReader(a.s), len(a.c)) + x, b, err := nat{}.scan(strings.NewReader(a.s), len(a.c)) if x.cmp(a.x) != 0 { t.Errorf("scan%+v\n\tgot z = %v; want %v", a, x, a.x) } @@ -271,7 +271,7 @@ var natScanTests = []struct { func TestScanBase(t *testing.T) { for _, a := range natScanTests { r := strings.NewReader(a.s) - x, b, err := nat(nil).scan(r, a.base) + x, b, err := nat{}.scan(r, a.base) if err == nil && !a.ok { t.Errorf("scan%+v\n\texpected error", a) } @@ -651,17 +651,17 @@ var expNNTests = []struct { func TestExpNN(t *testing.T) { for i, test := range expNNTests { - x, _, _ := nat(nil).scan(strings.NewReader(test.x), 0) - y, _, _ := nat(nil).scan(strings.NewReader(test.y), 0) - out, _, _ := nat(nil).scan(strings.NewReader(test.out), 0) + x, _, _ := nat{}.scan(strings.NewReader(test.x), 0) + y, _, _ := nat{}.scan(strings.NewReader(test.y), 0) + out, _, _ := nat{}.scan(strings.NewReader(test.out), 0) var m nat if len(test.m) > 0 { - m, _, _ = nat(nil).scan(strings.NewReader(test.m), 0) + m, _, _ = nat{}.scan(strings.NewReader(test.m), 0) } - z := nat(nil).expNN(x, y, m) + z := nat{}.expNN(x, y, m) if z.cmp(out) != 0 { t.Errorf("#%d got %v want %v", i, z, out) } diff --git a/libgo/go/big/rat.go b/libgo/go/big/rat.go index f435e63..6b86062 100644 --- a/libgo/go/big/rat.go +++ b/libgo/go/big/rat.go @@ -13,11 +13,11 @@ import ( "strings" ) -// A Rat represents a quotient a/b of arbitrary precision. The zero value for -// a Rat, 0/0, is not a legal Rat. +// A Rat represents a quotient a/b of arbitrary precision. +// The zero value for a Rat represents the value 0. type Rat struct { a Int - b nat + b nat // len(b) == 0 acts like b == 1 } // NewRat creates a new Rat with numerator a and denominator b. @@ -29,8 +29,11 @@ func NewRat(a, b int64) *Rat { func (z *Rat) SetFrac(a, b *Int) *Rat { z.a.neg = a.neg != b.neg babs := b.abs + if len(babs) == 0 { + panic("division by zero") + } if &z.a == b || alias(z.a.abs, babs) { - babs = nat(nil).set(babs) // make a copy + babs = nat{}.set(babs) // make a copy } z.a.abs = z.a.abs.set(a.abs) z.b = z.b.set(babs) @@ -40,6 +43,9 @@ func (z *Rat) SetFrac(a, b *Int) *Rat { // SetFrac64 sets z to a/b and returns z. func (z *Rat) SetFrac64(a, b int64) *Rat { z.a.SetInt64(a) + if b == 0 { + panic("division by zero") + } if b < 0 { b = -b z.a.neg = !z.a.neg @@ -51,14 +57,55 @@ func (z *Rat) SetFrac64(a, b int64) *Rat { // SetInt sets z to x (by making a copy of x) and returns z. func (z *Rat) SetInt(x *Int) *Rat { z.a.Set(x) - z.b = z.b.setWord(1) + z.b = z.b.make(0) return z } // SetInt64 sets z to x and returns z. func (z *Rat) SetInt64(x int64) *Rat { z.a.SetInt64(x) - z.b = z.b.setWord(1) + z.b = z.b.make(0) + return z +} + +// Set sets z to x (by making a copy of x) and returns z. +func (z *Rat) Set(x *Rat) *Rat { + if z != x { + z.a.Set(&x.a) + z.b = z.b.set(x.b) + } + return z +} + +// Abs sets z to |x| (the absolute value of x) and returns z. +func (z *Rat) Abs(x *Rat) *Rat { + z.Set(x) + z.a.neg = false + return z +} + +// Neg sets z to -x and returns z. +func (z *Rat) Neg(x *Rat) *Rat { + z.Set(x) + z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign + return z +} + +// Inv sets z to 1/x and returns z. +func (z *Rat) Inv(x *Rat) *Rat { + if len(x.a.abs) == 0 { + panic("division by zero") + } + z.Set(x) + a := z.b + if len(a) == 0 { + a = a.setWord(1) // materialize numerator + } + b := z.a.abs + if b.cmp(natOne) == 0 { + b = b.make(0) // normalize denominator + } + z.a.abs, z.b = a, b // sign doesn't change return z } @@ -74,21 +121,24 @@ func (x *Rat) Sign() int { // IsInt returns true if the denominator of x is 1. func (x *Rat) IsInt() bool { - return len(x.b) == 1 && x.b[0] == 1 + return len(x.b) == 0 || x.b.cmp(natOne) == 0 } -// Num returns the numerator of z; it may be <= 0. -// The result is a reference to z's numerator; it -// may change if a new value is assigned to z. -func (z *Rat) Num() *Int { - return &z.a +// Num returns the numerator of x; it may be <= 0. +// The result is a reference to x's numerator; it +// may change if a new value is assigned to x. +func (x *Rat) Num() *Int { + return &x.a } -// Denom returns the denominator of z; it is always > 0. -// The result is a reference to z's denominator; it -// may change if a new value is assigned to z. -func (z *Rat) Denom() *Int { - return &Int{false, z.b} +// Denom returns the denominator of x; it is always > 0. +// The result is a reference to x's denominator; it +// may change if a new value is assigned to x. +func (x *Rat) Denom() *Int { + if len(x.b) == 0 { + return &Int{abs: nat{1}} + } + return &Int{abs: x.b} } func gcd(x, y nat) nat { @@ -106,24 +156,47 @@ func gcd(x, y nat) nat { } func (z *Rat) norm() *Rat { - f := gcd(z.a.abs, z.b) - if len(z.a.abs) == 0 { - // z == 0 - z.a.neg = false // normalize sign - z.b = z.b.setWord(1) - return z - } - if f.cmp(natOne) != 0 { - z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f) - z.b, _ = z.b.div(nil, z.b, f) + switch { + case len(z.a.abs) == 0: + // z == 0 - normalize sign and denominator + z.a.neg = false + z.b = z.b.make(0) + case len(z.b) == 0: + // z is normalized int - nothing to do + case z.b.cmp(natOne) == 0: + // z is int - normalize denominator + z.b = z.b.make(0) + default: + if f := gcd(z.a.abs, z.b); f.cmp(natOne) != 0 { + z.a.abs, _ = z.a.abs.div(nil, z.a.abs, f) + z.b, _ = z.b.div(nil, z.b, f) + } } return z } -func mulNat(x *Int, y nat) *Int { +// mulDenom sets z to the denominator product x*y (by taking into +// account that 0 values for x or y must be interpreted as 1) and +// returns z. +func mulDenom(z, x, y nat) nat { + switch { + case len(x) == 0: + return z.set(y) + case len(y) == 0: + return z.set(x) + } + return z.mul(x, y) +} + +// scaleDenom computes x*f. +// If f == 0 (zero value of denominator), the result is (a copy of) x. +func scaleDenom(x *Int, f nat) *Int { var z Int - z.abs = z.abs.mul(x.abs, y) - z.neg = len(z.abs) > 0 && x.neg + if len(f) == 0 { + return z.Set(x) + } + z.abs = z.abs.mul(x.abs, f) + z.neg = x.neg return &z } @@ -133,39 +206,32 @@ func mulNat(x *Int, y nat) *Int { // 0 if x == y // +1 if x > y // -func (x *Rat) Cmp(y *Rat) (r int) { - return mulNat(&x.a, y.b).Cmp(mulNat(&y.a, x.b)) -} - -// Abs sets z to |x| (the absolute value of x) and returns z. -func (z *Rat) Abs(x *Rat) *Rat { - z.a.Abs(&x.a) - z.b = z.b.set(x.b) - return z +func (x *Rat) Cmp(y *Rat) int { + return scaleDenom(&x.a, y.b).Cmp(scaleDenom(&y.a, x.b)) } // Add sets z to the sum x+y and returns z. func (z *Rat) Add(x, y *Rat) *Rat { - a1 := mulNat(&x.a, y.b) - a2 := mulNat(&y.a, x.b) + a1 := scaleDenom(&x.a, y.b) + a2 := scaleDenom(&y.a, x.b) z.a.Add(a1, a2) - z.b = z.b.mul(x.b, y.b) + z.b = mulDenom(z.b, x.b, y.b) return z.norm() } // Sub sets z to the difference x-y and returns z. func (z *Rat) Sub(x, y *Rat) *Rat { - a1 := mulNat(&x.a, y.b) - a2 := mulNat(&y.a, x.b) + a1 := scaleDenom(&x.a, y.b) + a2 := scaleDenom(&y.a, x.b) z.a.Sub(a1, a2) - z.b = z.b.mul(x.b, y.b) + z.b = mulDenom(z.b, x.b, y.b) return z.norm() } // Mul sets z to the product x*y and returns z. func (z *Rat) Mul(x, y *Rat) *Rat { z.a.Mul(&x.a, &y.a) - z.b = z.b.mul(x.b, y.b) + z.b = mulDenom(z.b, x.b, y.b) return z.norm() } @@ -175,28 +241,14 @@ func (z *Rat) Quo(x, y *Rat) *Rat { if len(y.a.abs) == 0 { panic("division by zero") } - a := mulNat(&x.a, y.b) - b := mulNat(&y.a, x.b) + a := scaleDenom(&x.a, y.b) + b := scaleDenom(&y.a, x.b) z.a.abs = a.abs z.b = b.abs z.a.neg = a.neg != b.neg return z.norm() } -// Neg sets z to -x (by making a copy of x if necessary) and returns z. -func (z *Rat) Neg(x *Rat) *Rat { - z.a.Neg(&x.a) - z.b = z.b.set(x.b) - return z -} - -// Set sets z to x (by making a copy of x if necessary) and returns z. -func (z *Rat) Set(x *Rat) *Rat { - z.a.Set(&x.a) - z.b = z.b.set(x.b) - return z -} - func ratTok(ch int) bool { return strings.IndexRune("+-/0123456789.eE", ch) >= 0 } @@ -219,23 +271,23 @@ func (z *Rat) Scan(s fmt.ScanState, ch int) os.Error { // SetString sets z to the value of s and returns z and a boolean indicating // success. s can be given as a fraction "a/b" or as a floating-point number -// optionally followed by an exponent. If the operation failed, the value of z -// is undefined. +// optionally followed by an exponent. If the operation failed, the value of +// z is undefined but the returned value is nil. func (z *Rat) SetString(s string) (*Rat, bool) { if len(s) == 0 { - return z, false + return nil, false } // check for a quotient sep := strings.Index(s, "/") if sep >= 0 { if _, ok := z.a.SetString(s[0:sep], 10); !ok { - return z, false + return nil, false } s = s[sep+1:] var err os.Error if z.b, _, err = z.b.scan(strings.NewReader(s), 10); err != nil { - return z, false + return nil, false } return z.norm(), true } @@ -248,10 +300,10 @@ func (z *Rat) SetString(s string) (*Rat, bool) { if e >= 0 { if e < sep { // The E must come after the decimal point. - return z, false + return nil, false } if _, ok := exp.SetString(s[e+1:], 10); !ok { - return z, false + return nil, false } s = s[0:e] } @@ -261,7 +313,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { } if _, ok := z.a.SetString(s, 10); !ok { - return z, false + return nil, false } powTen := nat{}.expNN(natTen, exp.abs, nil) if exp.neg { @@ -269,7 +321,7 @@ func (z *Rat) SetString(s string) (*Rat, bool) { z.norm() } else { z.a.abs = z.a.abs.mul(z.a.abs, powTen) - z.b = z.b.setWord(1) + z.b = z.b.make(0) } return z, true @@ -277,7 +329,11 @@ func (z *Rat) SetString(s string) (*Rat, bool) { // String returns a string representation of z in the form "a/b" (even if b == 1). func (z *Rat) String() string { - return z.a.String() + "/" + z.b.decimalString() + s := "/1" + if len(z.b) != 0 { + s = "/" + z.b.decimalString() + } + return z.a.String() + s } // RatString returns a string representation of z in the form "a/b" if b != 1, @@ -299,6 +355,7 @@ func (z *Rat) FloatString(prec int) string { } return s } + // z.b != 0 q, r := nat{}.div(nat{}, z.a.abs, z.b) diff --git a/libgo/go/big/rat_test.go b/libgo/go/big/rat_test.go index a2b9055..a95e5fe 100644 --- a/libgo/go/big/rat_test.go +++ b/libgo/go/big/rat_test.go @@ -11,6 +11,46 @@ import ( "testing" ) +func TestZeroRat(t *testing.T) { + var x, y, z Rat + y.SetFrac64(0, 42) + + if x.Cmp(&y) != 0 { + t.Errorf("x and y should be both equal and zero") + } + + if s := x.String(); s != "0/1" { + t.Errorf("got x = %s, want 0/1", s) + } + + if s := x.RatString(); s != "0" { + t.Errorf("got x = %s, want 0", s) + } + + z.Add(&x, &y) + if s := z.RatString(); s != "0" { + t.Errorf("got x+y = %s, want 0", s) + } + + z.Sub(&x, &y) + if s := z.RatString(); s != "0" { + t.Errorf("got x-y = %s, want 0", s) + } + + z.Mul(&x, &y) + if s := z.RatString(); s != "0" { + t.Errorf("got x*y = %s, want 0", s) + } + + // check for division by zero + defer func() { + if s := recover(); s == nil || s.(string) != "division by zero" { + panic(s) + } + }() + z.Quo(&x, &y) +} + var setStringTests = []struct { in, out string ok bool @@ -50,8 +90,14 @@ func TestRatSetString(t *testing.T) { for i, test := range setStringTests { x, ok := new(Rat).SetString(test.in) - if ok != test.ok || ok && x.RatString() != test.out { - t.Errorf("#%d got %s want %s", i, x.RatString(), test.out) + if ok { + if !test.ok { + t.Errorf("#%d SetString(%q) expected failure", i, test.in) + } else if x.RatString() != test.out { + t.Errorf("#%d SetString(%q) got %s want %s", i, test.in, x.RatString(), test.out) + } + } else if x != nil { + t.Errorf("#%d SetString(%q) got %p want nil", i, test.in, x) } } } @@ -113,8 +159,10 @@ func TestFloatString(t *testing.T) { func TestRatSign(t *testing.T) { zero := NewRat(0, 1) for _, a := range setStringTests { - var x Rat - x.SetString(a.in) + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } s := x.Sign() e := x.Cmp(zero) if s != e { @@ -153,29 +201,65 @@ func TestRatCmp(t *testing.T) { func TestIsInt(t *testing.T) { one := NewInt(1) for _, a := range setStringTests { - var x Rat - x.SetString(a.in) + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } i := x.IsInt() e := x.Denom().Cmp(one) == 0 if i != e { - t.Errorf("got %v; want %v for z = %v", i, e, &x) + t.Errorf("got IsInt(%v) == %v; want %v", x, i, e) } } } func TestRatAbs(t *testing.T) { - zero := NewRat(0, 1) + zero := new(Rat) for _, a := range setStringTests { - var z Rat - z.SetString(a.in) - var e Rat - e.Set(&z) + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + e := new(Rat).Set(x) if e.Cmp(zero) < 0 { - e.Sub(zero, &e) + e.Sub(zero, e) + } + z := new(Rat).Abs(x) + if z.Cmp(e) != 0 { + t.Errorf("got Abs(%v) = %v; want %v", x, z, e) + } + } +} + +func TestRatNeg(t *testing.T) { + zero := new(Rat) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + e := new(Rat).Sub(zero, x) + z := new(Rat).Neg(x) + if z.Cmp(e) != 0 { + t.Errorf("got Neg(%v) = %v; want %v", x, z, e) + } + } +} + +func TestRatInv(t *testing.T) { + zero := new(Rat) + for _, a := range setStringTests { + x, ok := new(Rat).SetString(a.in) + if !ok { + continue + } + if x.Cmp(zero) == 0 { + continue // avoid division by zero } - z.Abs(&z) - if z.Cmp(&e) != 0 { - t.Errorf("got z = %v; want %v", &z, &e) + e := new(Rat).SetFrac(x.Denom(), x.Num()) + z := new(Rat).Inv(x) + if z.Cmp(e) != 0 { + t.Errorf("got Inv(%v) = %v; want %v", x, z, e) } } } @@ -186,10 +270,10 @@ type ratBinArg struct { } func testRatBin(t *testing.T, i int, name string, f ratBinFun, a ratBinArg) { - x, _ := NewRat(0, 1).SetString(a.x) - y, _ := NewRat(0, 1).SetString(a.y) - z, _ := NewRat(0, 1).SetString(a.z) - out := f(NewRat(0, 1), x, y) + x, _ := new(Rat).SetString(a.x) + y, _ := new(Rat).SetString(a.y) + z, _ := new(Rat).SetString(a.z) + out := f(new(Rat), x, y) if out.Cmp(z) != 0 { t.Errorf("%s #%d got %s want %s", name, i, out, z) diff --git a/libgo/go/crypto/x509/x509.go b/libgo/go/crypto/x509/x509.go index 4b8ecc5..73b32e7 100644 --- a/libgo/go/crypto/x509/x509.go +++ b/libgo/go/crypto/x509/x509.go @@ -928,11 +928,11 @@ func CreateCertificate(rand io.Reader, template, parent *Certificate, pub *rsa.P return } - asn1Issuer, err := asn1.Marshal(parent.Issuer.ToRDNSequence()) + asn1Issuer, err := asn1.Marshal(parent.Subject.ToRDNSequence()) if err != nil { return } - asn1Subject, err := asn1.Marshal(parent.Subject.ToRDNSequence()) + asn1Subject, err := asn1.Marshal(template.Subject.ToRDNSequence()) if err != nil { return } diff --git a/libgo/go/crypto/x509/x509_test.go b/libgo/go/crypto/x509/x509_test.go index dbc5273..d113f85 100644 --- a/libgo/go/crypto/x509/x509_test.go +++ b/libgo/go/crypto/x509/x509_test.go @@ -6,8 +6,8 @@ package x509 import ( "asn1" - "bytes" "big" + "bytes" "crypto/dsa" "crypto/rand" "crypto/rsa" @@ -243,10 +243,11 @@ func TestCreateSelfSignedCertificate(t *testing.T) { return } + commonName := "test.example.com" template := Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{ - CommonName: "test.example.com", + CommonName: commonName, Organization: []string{"Acme Co"}, }, NotBefore: time.SecondsToUTC(1000), @@ -283,6 +284,14 @@ func TestCreateSelfSignedCertificate(t *testing.T) { t.Errorf("Failed to parse name constraints: %#v", cert.PermittedDNSDomains) } + if cert.Subject.CommonName != commonName { + t.Errorf("Subject wasn't correctly copied from the template. Got %s, want %s", cert.Subject.CommonName, commonName) + } + + if cert.Issuer.CommonName != commonName { + t.Errorf("Issuer wasn't correctly copied from the template. Got %s, want %s", cert.Issuer.CommonName, commonName) + } + err = cert.CheckSignatureFrom(cert) if err != nil { t.Errorf("Signature verification failed: %s", err) diff --git a/libgo/go/exp/inotify/inotify_linux.go b/libgo/go/exp/inotify/inotify_linux.go new file mode 100644 index 0000000..ee3c75f --- /dev/null +++ b/libgo/go/exp/inotify/inotify_linux.go @@ -0,0 +1,288 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package inotify implements a wrapper for the Linux inotify system. + +Example: + watcher, err := inotify.NewWatcher() + if err != nil { + log.Fatal(err) + } + err = watcher.Watch("/tmp") + if err != nil { + log.Fatal(err) + } + for { + select { + case ev := <-watcher.Event: + log.Println("event:", ev) + case err := <-watcher.Error: + log.Println("error:", err) + } + } + +*/ +package inotify + +import ( + "fmt" + "os" + "strings" + "syscall" + "unsafe" +) + +type Event struct { + Mask uint32 // Mask of events + Cookie uint32 // Unique cookie associating related events (for rename(2)) + Name string // File name (optional) +} + +type watch struct { + wd uint32 // Watch descriptor (as returned by the inotify_add_watch() syscall) + flags uint32 // inotify flags of this watch (see inotify(7) for the list of valid flags) +} + +type Watcher struct { + fd int // File descriptor (as returned by the inotify_init() syscall) + watches map[string]*watch // Map of inotify watches (key: path) + paths map[int]string // Map of watched paths (key: watch descriptor) + Error chan os.Error // Errors are sent on this channel + Event chan *Event // Events are returned on this channel + done chan bool // Channel for sending a "quit message" to the reader goroutine + isClosed bool // Set to true when Close() is first called +} + +// NewWatcher creates and returns a new inotify instance using inotify_init(2) +func NewWatcher() (*Watcher, os.Error) { + fd, errno := syscall.InotifyInit() + if fd == -1 { + return nil, os.NewSyscallError("inotify_init", errno) + } + w := &Watcher{ + fd: fd, + watches: make(map[string]*watch), + paths: make(map[int]string), + Event: make(chan *Event), + Error: make(chan os.Error), + done: make(chan bool, 1), + } + + go w.readEvents() + return w, nil +} + +// Close closes an inotify watcher instance +// It sends a message to the reader goroutine to quit and removes all watches +// associated with the inotify instance +func (w *Watcher) Close() os.Error { + if w.isClosed { + return nil + } + w.isClosed = true + + // Send "quit" message to the reader goroutine + w.done <- true + for path := range w.watches { + w.RemoveWatch(path) + } + + return nil +} + +// AddWatch adds path to the watched file set. +// The flags are interpreted as described in inotify_add_watch(2). +func (w *Watcher) AddWatch(path string, flags uint32) os.Error { + if w.isClosed { + return os.NewError("inotify instance already closed") + } + + watchEntry, found := w.watches[path] + if found { + watchEntry.flags |= flags + flags |= syscall.IN_MASK_ADD + } + wd, errno := syscall.InotifyAddWatch(w.fd, path, flags) + if wd == -1 { + return &os.PathError{"inotify_add_watch", path, os.Errno(errno)} + } + + if !found { + w.watches[path] = &watch{wd: uint32(wd), flags: flags} + w.paths[wd] = path + } + return nil +} + +// Watch adds path to the watched file set, watching all events. +func (w *Watcher) Watch(path string) os.Error { + return w.AddWatch(path, IN_ALL_EVENTS) +} + +// RemoveWatch removes path from the watched file set. +func (w *Watcher) RemoveWatch(path string) os.Error { + watch, ok := w.watches[path] + if !ok { + return os.NewError(fmt.Sprintf("can't remove non-existent inotify watch for: %s", path)) + } + success, errno := syscall.InotifyRmWatch(w.fd, watch.wd) + if success == -1 { + return os.NewSyscallError("inotify_rm_watch", errno) + } + delete(w.watches, path) + return nil +} + +// readEvents reads from the inotify file descriptor, converts the +// received events into Event objects and sends them via the Event channel +func (w *Watcher) readEvents() { + var ( + buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events + n int // Number of bytes read with read() + errno int // Syscall errno + ) + + for { + n, errno = syscall.Read(w.fd, buf[0:]) + // See if there is a message on the "done" channel + var done bool + select { + case done = <-w.done: + default: + } + + // If EOF or a "done" message is received + if n == 0 || done { + errno := syscall.Close(w.fd) + if errno == -1 { + w.Error <- os.NewSyscallError("close", errno) + } + close(w.Event) + close(w.Error) + return + } + if n < 0 { + w.Error <- os.NewSyscallError("read", errno) + continue + } + if n < syscall.SizeofInotifyEvent { + w.Error <- os.NewError("inotify: short read in readEvents()") + continue + } + + var offset uint32 = 0 + // We don't know how many events we just read into the buffer + // While the offset points to at least one whole event... + for offset <= uint32(n-syscall.SizeofInotifyEvent) { + // Point "raw" to the event in the buffer + raw := (*syscall.InotifyEvent)(unsafe.Pointer(&buf[offset])) + event := new(Event) + event.Mask = uint32(raw.Mask) + event.Cookie = uint32(raw.Cookie) + nameLen := uint32(raw.Len) + // If the event happened to the watched directory or the watched file, the kernel + // doesn't append the filename to the event, but we would like to always fill the + // the "Name" field with a valid filename. We retrieve the path of the watch from + // the "paths" map. + event.Name = w.paths[int(raw.Wd)] + if nameLen > 0 { + // Point "bytes" at the first byte of the filename + bytes := (*[syscall.PathMax]byte)(unsafe.Pointer(&buf[offset+syscall.SizeofInotifyEvent])) + // The filename is padded with NUL bytes. TrimRight() gets rid of those. + event.Name += "/" + strings.TrimRight(string(bytes[0:nameLen]), "\000") + } + // Send the event on the events channel + w.Event <- event + + // Move to the next event in the buffer + offset += syscall.SizeofInotifyEvent + nameLen + } + } +} + +// String formats the event e in the form +// "filename: 0xEventMask = IN_ACCESS|IN_ATTRIB_|..." +func (e *Event) String() string { + var events string = "" + + m := e.Mask + for _, b := range eventBits { + if m&b.Value != 0 { + m &^= b.Value + events += "|" + b.Name + } + } + + if m != 0 { + events += fmt.Sprintf("|%#x", m) + } + if len(events) > 0 { + events = " == " + events[1:] + } + + return fmt.Sprintf("%q: %#x%s", e.Name, e.Mask, events) +} + +const ( + // Options for inotify_init() are not exported + // IN_CLOEXEC uint32 = syscall.IN_CLOEXEC + // IN_NONBLOCK uint32 = syscall.IN_NONBLOCK + + // Options for AddWatch + IN_DONT_FOLLOW uint32 = syscall.IN_DONT_FOLLOW + IN_ONESHOT uint32 = syscall.IN_ONESHOT + IN_ONLYDIR uint32 = syscall.IN_ONLYDIR + + // The "IN_MASK_ADD" option is not exported, as AddWatch + // adds it automatically, if there is already a watch for the given path + // IN_MASK_ADD uint32 = syscall.IN_MASK_ADD + + // Events + IN_ACCESS uint32 = syscall.IN_ACCESS + IN_ALL_EVENTS uint32 = syscall.IN_ALL_EVENTS + IN_ATTRIB uint32 = syscall.IN_ATTRIB + IN_CLOSE uint32 = syscall.IN_CLOSE + IN_CLOSE_NOWRITE uint32 = syscall.IN_CLOSE_NOWRITE + IN_CLOSE_WRITE uint32 = syscall.IN_CLOSE_WRITE + IN_CREATE uint32 = syscall.IN_CREATE + IN_DELETE uint32 = syscall.IN_DELETE + IN_DELETE_SELF uint32 = syscall.IN_DELETE_SELF + IN_MODIFY uint32 = syscall.IN_MODIFY + IN_MOVE uint32 = syscall.IN_MOVE + IN_MOVED_FROM uint32 = syscall.IN_MOVED_FROM + IN_MOVED_TO uint32 = syscall.IN_MOVED_TO + IN_MOVE_SELF uint32 = syscall.IN_MOVE_SELF + IN_OPEN uint32 = syscall.IN_OPEN + + // Special events + IN_ISDIR uint32 = syscall.IN_ISDIR + IN_IGNORED uint32 = syscall.IN_IGNORED + IN_Q_OVERFLOW uint32 = syscall.IN_Q_OVERFLOW + IN_UNMOUNT uint32 = syscall.IN_UNMOUNT +) + +var eventBits = []struct { + Value uint32 + Name string +}{ + {IN_ACCESS, "IN_ACCESS"}, + {IN_ATTRIB, "IN_ATTRIB"}, + {IN_CLOSE, "IN_CLOSE"}, + {IN_CLOSE_NOWRITE, "IN_CLOSE_NOWRITE"}, + {IN_CLOSE_WRITE, "IN_CLOSE_WRITE"}, + {IN_CREATE, "IN_CREATE"}, + {IN_DELETE, "IN_DELETE"}, + {IN_DELETE_SELF, "IN_DELETE_SELF"}, + {IN_MODIFY, "IN_MODIFY"}, + {IN_MOVE, "IN_MOVE"}, + {IN_MOVED_FROM, "IN_MOVED_FROM"}, + {IN_MOVED_TO, "IN_MOVED_TO"}, + {IN_MOVE_SELF, "IN_MOVE_SELF"}, + {IN_OPEN, "IN_OPEN"}, + {IN_ISDIR, "IN_ISDIR"}, + {IN_IGNORED, "IN_IGNORED"}, + {IN_Q_OVERFLOW, "IN_Q_OVERFLOW"}, + {IN_UNMOUNT, "IN_UNMOUNT"}, +} diff --git a/libgo/go/exp/inotify/inotify_linux_test.go b/libgo/go/exp/inotify/inotify_linux_test.go new file mode 100644 index 0000000..a6bb46f --- /dev/null +++ b/libgo/go/exp/inotify/inotify_linux_test.go @@ -0,0 +1,96 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package inotify + +import ( + "os" + "testing" + "time" +) + +func TestInotifyEvents(t *testing.T) { + // Create an inotify watcher instance and initialize it + watcher, err := NewWatcher() + if err != nil { + t.Fatalf("NewWatcher() failed: %s", err) + } + + // Add a watch for "_test" + err = watcher.Watch("_test") + if err != nil { + t.Fatalf("Watcher.Watch() failed: %s", err) + } + + // Receive errors on the error channel on a separate goroutine + go func() { + for err := range watcher.Error { + t.Fatalf("error received: %s", err) + } + }() + + const testFile string = "_test/TestInotifyEvents.testfile" + + // Receive events on the event channel on a separate goroutine + eventstream := watcher.Event + var eventsReceived = 0 + done := make(chan bool) + go func() { + for event := range eventstream { + // Only count relevant events + if event.Name == testFile { + eventsReceived++ + t.Logf("event received: %s", event) + } else { + t.Logf("unexpected event received: %s", event) + } + } + done <- true + }() + + // Create a file + // This should add at least one event to the inotify event queue + _, err = os.OpenFile(testFile, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + t.Fatalf("creating test file failed: %s", err) + } + + // We expect this event to be received almost immediately, but let's wait 1 s to be sure + time.Sleep(1000e6) // 1000 ms + if eventsReceived == 0 { + t.Fatal("inotify event hasn't been received after 1 second") + } + + // Try closing the inotify instance + t.Log("calling Close()") + watcher.Close() + t.Log("waiting for the event channel to become closed...") + select { + case <-done: + t.Log("event channel closed") + case <-time.After(1e9): + t.Fatal("event stream was not closed after 1 second") + } +} + +func TestInotifyClose(t *testing.T) { + watcher, _ := NewWatcher() + watcher.Close() + + done := false + go func() { + watcher.Close() + done = true + }() + + time.Sleep(50e6) // 50 ms + if !done { + t.Fatal("double Close() test failed: second Close() call didn't return") + } + + err := watcher.Watch("_test") + if err == nil { + t.Fatal("expected error on Watch() after Close(), got nil") + } +} diff --git a/libgo/go/exp/ssh/channel.go b/libgo/go/exp/ssh/channel.go index 922584f..f69b735 100644 --- a/libgo/go/exp/ssh/channel.go +++ b/libgo/go/exp/ssh/channel.go @@ -68,7 +68,7 @@ type channel struct { weClosed bool dead bool - serverConn *ServerConnection + serverConn *ServerConn myId, theirId uint32 myWindow, theirWindow uint32 maxPacketSize uint32 diff --git a/libgo/go/exp/ssh/client.go b/libgo/go/exp/ssh/client.go new file mode 100644 index 0000000..3311385 --- /dev/null +++ b/libgo/go/exp/ssh/client.go @@ -0,0 +1,490 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "big" + "crypto" + "crypto/rand" + "fmt" + "io" + "net" + "os" + "sync" +) + +// clientVersion is the fixed identification string that the client will use. +var clientVersion = []byte("SSH-2.0-Go\r\n") + +// ClientConn represents the client side of an SSH connection. +type ClientConn struct { + *transport + config *ClientConfig + chanlist +} + +// Client returns a new SSH client connection using c as the underlying transport. +func Client(c net.Conn, config *ClientConfig) (*ClientConn, os.Error) { + conn := &ClientConn{ + transport: newTransport(c, config.rand()), + config: config, + } + if err := conn.handshake(); err != nil { + conn.Close() + return nil, err + } + if err := conn.authenticate(); err != nil { + conn.Close() + return nil, err + } + go conn.mainLoop() + return conn, nil +} + +// handshake performs the client side key exchange. See RFC 4253 Section 7. +func (c *ClientConn) handshake() os.Error { + var magics handshakeMagics + + if _, err := c.Write(clientVersion); err != nil { + return err + } + if err := c.Flush(); err != nil { + return err + } + magics.clientVersion = clientVersion[:len(clientVersion)-2] + + // read remote server version + version, err := readVersion(c) + if err != nil { + return err + } + magics.serverVersion = version + clientKexInit := kexInitMsg{ + KexAlgos: supportedKexAlgos, + ServerHostKeyAlgos: supportedHostKeyAlgos, + CiphersClientServer: supportedCiphers, + CiphersServerClient: supportedCiphers, + MACsClientServer: supportedMACs, + MACsServerClient: supportedMACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + kexInitPacket := marshal(msgKexInit, clientKexInit) + magics.clientKexInit = kexInitPacket + + if err := c.writePacket(kexInitPacket); err != nil { + return err + } + packet, err := c.readPacket() + if err != nil { + return err + } + + magics.serverKexInit = packet + + var serverKexInit kexInitMsg + if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil { + return err + } + + kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit) + if !ok { + return os.NewError("ssh: no common algorithms") + } + + if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] { + // The server sent a Kex message for the wrong algorithm, + // which we have to ignore. + if _, err := c.readPacket(); err != nil { + return err + } + } + + var H, K []byte + var hashFunc crypto.Hash + switch kexAlgo { + case kexAlgoDH14SHA1: + hashFunc = crypto.SHA1 + dhGroup14Once.Do(initDHGroup14) + H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) + default: + err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo) + } + if err != nil { + return err + } + + if err = c.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil { + return err + } + if packet, err = c.readPacket(); err != nil { + return err + } + if packet[0] != msgNewKeys { + return UnexpectedMessageError{msgNewKeys, packet[0]} + } + return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc) +} + +// authenticate authenticates with the remote server. See RFC 4252. +// Only "password" authentication is supported. +func (c *ClientConn) authenticate() os.Error { + if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.readPacket() + if err != nil { + return err + } + + var serviceAccept serviceAcceptMsg + if err = unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil { + return err + } + + // TODO(dfc) support proper authentication method negotation + method := "none" + if c.config.Password != "" { + method = "password" + } + if err := c.sendUserAuthReq(method); err != nil { + return err + } + + if packet, err = c.readPacket(); err != nil { + return err + } + + if packet[0] != msgUserAuthSuccess { + return UnexpectedMessageError{msgUserAuthSuccess, packet[0]} + } + return nil +} + +func (c *ClientConn) sendUserAuthReq(method string) os.Error { + length := stringLength([]byte(c.config.Password)) + 1 + payload := make([]byte, length) + // always false for password auth, see RFC 4252 Section 8. + payload[0] = 0 + marshalString(payload[1:], []byte(c.config.Password)) + + return c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{ + User: c.config.User, + Service: serviceSSH, + Method: method, + Payload: payload, + })) +} + +// kexDH performs Diffie-Hellman key agreement on a ClientConn. The +// returned values are given the same names as in RFC 4253, section 8. +func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, os.Error) { + x, err := rand.Int(c.config.rand(), group.p) + if err != nil { + return nil, nil, err + } + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil { + return nil, nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, nil, err + } + + var kexDHReply = new(kexDHReplyMsg) + if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil { + return nil, nil, err + } + + if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 { + return nil, nil, os.NewError("server DH parameter out of bounds") + } + + kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p) + h := hashFunc.New() + writeString(h, magics.clientVersion) + writeString(h, magics.serverVersion) + writeString(h, magics.clientKexInit) + writeString(h, magics.serverKexInit) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum() + + return H, K, nil +} + +// openChan opens a new client channel. The most common session type is "session". +// The full set of valid session types are listed in RFC 4250 4.9.1. +func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) { + ch := c.newChan(c.transport) + if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{ + ChanType: typ, + PeersId: ch.id, + PeersWindow: 1 << 14, + MaxPacketSize: 1 << 15, // RFC 4253 6.1 + })); err != nil { + c.chanlist.remove(ch.id) + return nil, err + } + // wait for response + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + ch.peersId = msg.MyId + case *channelOpenFailureMsg: + c.chanlist.remove(ch.id) + return nil, os.NewError(msg.Message) + default: + c.chanlist.remove(ch.id) + return nil, os.NewError("Unexpected packet") + } + return ch, nil +} + +// mainloop reads incoming messages and routes channel messages +// to their respective ClientChans. +func (c *ClientConn) mainLoop() { + for { + packet, err := c.readPacket() + if err != nil { + // TODO(dfc) signal the underlying close to all channels + c.Close() + return + } + // TODO(dfc) A note on blocking channel use. + // The msg, win, data and dataExt channels of a clientChan can + // cause this loop to block indefinately if the consumer does + // not service them. + switch msg := decode(packet).(type) { + case *channelOpenMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelOpenConfirmMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelOpenFailureMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelCloseMsg: + ch := c.getChan(msg.PeersId) + close(ch.win) + close(ch.data) + close(ch.dataExt) + c.chanlist.remove(msg.PeersId) + case *channelEOFMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelRequestSuccessMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelRequestFailureMsg: + c.getChan(msg.PeersId).msg <- msg + case *channelRequestMsg: + c.getChan(msg.PeersId).msg <- msg + case *windowAdjustMsg: + c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes) + case *channelData: + c.getChan(msg.PeersId).data <- msg.Payload + case *channelExtendedData: + // RFC 4254 5.2 defines data_type_code 1 to be data destined + // for stderr on interactive sessions. Other data types are + // silently discarded. + if msg.Datatype == 1 { + c.getChan(msg.PeersId).dataExt <- msg.Payload + } + default: + fmt.Printf("mainLoop: unhandled %#v\n", msg) + } + } +} + +// Dial connects to the given network address using net.Dial and +// then initiates a SSH handshake, returning the resulting client connection. +func Dial(network, addr string, config *ClientConfig) (*ClientConn, os.Error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + return Client(conn, config) +} + +// A ClientConfig structure is used to configure a ClientConn. After one has +// been passed to an SSH function it must not be modified. +type ClientConfig struct { + // Rand provides the source of entropy for key exchange. If Rand is + // nil, the cryptographic random reader in package crypto/rand will + // be used. + Rand io.Reader + + // The username to authenticate. + User string + + // Used for "password" method authentication. + Password string +} + +func (c *ClientConfig) rand() io.Reader { + if c.Rand == nil { + return rand.Reader + } + return c.Rand +} + +// A clientChan represents a single RFC 4254 channel that is multiplexed +// over a single SSH connection. +type clientChan struct { + packetWriter + id, peersId uint32 + data chan []byte // receives the payload of channelData messages + dataExt chan []byte // receives the payload of channelExtendedData messages + win chan int // receives window adjustments + msg chan interface{} // incoming messages +} + +func newClientChan(t *transport, id uint32) *clientChan { + return &clientChan{ + packetWriter: t, + id: id, + data: make(chan []byte, 16), + dataExt: make(chan []byte, 16), + win: make(chan int, 16), + msg: make(chan interface{}, 16), + } +} + +// Close closes the channel. This does not close the underlying connection. +func (c *clientChan) Close() os.Error { + return c.writePacket(marshal(msgChannelClose, channelCloseMsg{ + PeersId: c.id, + })) +} + +func (c *clientChan) sendChanReq(req channelRequestMsg) os.Error { + if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil { + return err + } + msg := <-c.msg + if _, ok := msg.(*channelRequestSuccessMsg); ok { + return nil + } + return fmt.Errorf("failed to complete request: %s, %#v", req.Request, msg) +} + +// Thread safe channel list. +type chanlist struct { + // protects concurrent access to chans + sync.Mutex + // chans are indexed by the local id of the channel, clientChan.id. + // The PeersId value of messages received by ClientConn.mainloop is + // used to locate the right local clientChan in this slice. + chans []*clientChan +} + +// Allocate a new ClientChan with the next avail local id. +func (c *chanlist) newChan(t *transport) *clientChan { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + ch := newClientChan(t, uint32(i)) + c.chans[i] = ch + return ch + } + } + i := len(c.chans) + ch := newClientChan(t, uint32(i)) + c.chans = append(c.chans, ch) + return ch +} + +func (c *chanlist) getChan(id uint32) *clientChan { + c.Lock() + defer c.Unlock() + return c.chans[int(id)] +} + +func (c *chanlist) remove(id uint32) { + c.Lock() + defer c.Unlock() + c.chans[int(id)] = nil +} + +// A chanWriter represents the stdin of a remote process. +type chanWriter struct { + win chan int // receives window adjustments + id uint32 // this channel's id + rwin int // current rwin size + packetWriter // for sending channelDataMsg +} + +// Write writes data to the remote process's standard input. +func (w *chanWriter) Write(data []byte) (n int, err os.Error) { + for { + if w.rwin == 0 { + win, ok := <-w.win + if !ok { + return 0, os.EOF + } + w.rwin += win + continue + } + n = len(data) + packet := make([]byte, 0, 9+n) + packet = append(packet, msgChannelData, + byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id), + byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n)) + err = w.writePacket(append(packet, data...)) + w.rwin -= n + return + } + panic("unreachable") +} + +func (w *chanWriter) Close() os.Error { + return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id})) +} + +// A chanReader represents stdout or stderr of a remote process. +type chanReader struct { + // TODO(dfc) a fixed size channel may not be the right data structure. + // If writes to this channel block, they will block mainLoop, making + // it unable to receive new messages from the remote side. + data chan []byte // receives data from remote + id uint32 + packetWriter // for sending windowAdjustMsg + buf []byte +} + +// Read reads data from the remote process's stdout or stderr. +func (r *chanReader) Read(data []byte) (int, os.Error) { + var ok bool + for { + if len(r.buf) > 0 { + n := copy(data, r.buf) + r.buf = r.buf[n:] + msg := windowAdjustMsg{ + PeersId: r.id, + AdditionalBytes: uint32(n), + } + return n, r.writePacket(marshal(msgChannelWindowAdjust, msg)) + } + r.buf, ok = <-r.data + if !ok { + return 0, os.EOF + } + } + panic("unreachable") +} + +func (r *chanReader) Close() os.Error { + return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id})) +} diff --git a/libgo/go/exp/ssh/doc.go b/libgo/go/exp/ssh/doc.go index 54a7ba9..fc842b0 100644 --- a/libgo/go/exp/ssh/doc.go +++ b/libgo/go/exp/ssh/doc.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. /* -Package ssh implements an SSH server. +Package ssh implements an SSH client and server. SSH is a transport security protocol, an authentication protocol and a family of application protocols. The most typical application level @@ -11,26 +11,29 @@ protocol is a remote shell and this is specifically implemented. However, the multiplexed nature of SSH is exposed to users that wish to support others. -An SSH server is represented by a Server, which manages a number of -ServerConnections and handles authentication. +An SSH server is represented by a ServerConfig, which holds certificate +details and handles authentication of ServerConns. - var s Server - s.PubKeyCallback = pubKeyAuth - s.PasswordCallback = passwordAuth + config := new(ServerConfig) + config.PubKeyCallback = pubKeyAuth + config.PasswordCallback = passwordAuth pemBytes, err := ioutil.ReadFile("id_rsa") if err != nil { panic("Failed to load private key") } - err = s.SetRSAPrivateKey(pemBytes) + err = config.SetRSAPrivateKey(pemBytes) if err != nil { panic("Failed to parse private key") } -Once a Server has been set up, connections can be attached. +Once a ServerConfig has been configured, connections can be accepted. - var sConn ServerConnection - sConn.Server = &s + listener := Listen("tcp", "0.0.0.0:2022", config) + sConn, err := listener.Accept() + if err != nil { + panic("failed to accept incoming connection") + } err = sConn.Handshake(conn) if err != nil { panic("failed to handshake") @@ -38,7 +41,6 @@ Once a Server has been set up, connections can be attached. An SSH connection multiplexes several channels, which must be accepted themselves: - for { channel, err := sConn.Accept() if err != nil { @@ -75,5 +77,29 @@ present a simple terminal interface. } return }() + +An SSH client is represented with a ClientConn. Currently only the "password" +authentication method is supported. + + config := &ClientConfig{ + User: "username", + Password: "123456", + } + client, err := Dial("yourserver.com:22", config) + +Each ClientConn can support multiple interactive sessions, represented by a Session. + + session, err := client.NewSession() + +Once a Session is created, you can execute a single command on the remote side +using the Exec method. + + if err := session.Exec("/usr/bin/whoami"); err != nil { + panic("Failed to exec: " + err.String()) + } + reader := bufio.NewReader(session.Stdin) + line, _, _ := reader.ReadLine() + fmt.Println(line) + session.Close() */ package ssh diff --git a/libgo/go/exp/ssh/messages.go b/libgo/go/exp/ssh/messages.go index 1d0bc57..7771f2b 100644 --- a/libgo/go/exp/ssh/messages.go +++ b/libgo/go/exp/ssh/messages.go @@ -154,7 +154,7 @@ type channelData struct { type channelExtendedData struct { PeersId uint32 Datatype uint32 - Data string + Payload []byte `ssh:"rest"` } type channelRequestMsg struct { diff --git a/libgo/go/exp/ssh/server.go b/libgo/go/exp/ssh/server.go index 410cafc..3a640fc 100644 --- a/libgo/go/exp/ssh/server.go +++ b/libgo/go/exp/ssh/server.go @@ -10,19 +10,23 @@ import ( "crypto" "crypto/rand" "crypto/rsa" - _ "crypto/sha1" "crypto/x509" "encoding/pem" + "io" "net" "os" "sync" ) -// Server represents an SSH server. A Server may have several ServerConnections. -type Server struct { +type ServerConfig struct { rsa *rsa.PrivateKey rsaSerialized []byte + // Rand provides the source of entropy for key exchange. If Rand is + // nil, the cryptographic random reader in package crypto/rand will + // be used. + Rand io.Reader + // NoClientAuth is true if clients are allowed to connect without // authenticating. NoClientAuth bool @@ -38,11 +42,18 @@ type Server struct { PubKeyCallback func(user, algo string, pubkey []byte) bool } +func (c *ServerConfig) rand() io.Reader { + if c.Rand == nil { + return rand.Reader + } + return c.Rand +} + // SetRSAPrivateKey sets the private key for a Server. A Server must have a // private key configured in order to accept connections. The private key must // be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" // typically contains such a key. -func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error { +func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) os.Error { block, _ := pem.Decode(pemBytes) if block == nil { return os.NewError("ssh: no key found") @@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) { } // cachedPubKey contains the results of querying whether a public key is -// acceptable for a user. The cache only applies to a single ServerConnection. +// acceptable for a user. The cache only applies to a single ServerConn. type cachedPubKey struct { user, algo string pubKey []byte @@ -118,11 +129,10 @@ type cachedPubKey struct { const maxCachedPubKeys = 16 -// ServerConnection represents an incomming connection to a Server. -type ServerConnection struct { - Server *Server - +// A ServerConn represents an incomming connection. +type ServerConn struct { *transport + config *ServerConfig channels map[uint32]*channel nextChanId uint32 @@ -139,9 +149,20 @@ type ServerConnection struct { cachedPubKeys []cachedPubKey } +// Server returns a new SSH server connection +// using c as the underlying transport. +func Server(c net.Conn, config *ServerConfig) *ServerConn { + conn := &ServerConn{ + transport: newTransport(c, config.rand()), + channels: make(map[uint32]*channel), + config: config, + } + return conn +} + // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The // returned values are given the same names as in RFC 4253, section 8. -func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { +func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { packet, err := s.readPacket() if err != nil { return @@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h return nil, nil, os.NewError("client DH parameter out of bounds") } - y, err := rand.Int(rand.Reader, group.p) + y, err := rand.Int(s.config.rand(), group.p) if err != nil { return } @@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h var serializedHostKey []byte switch hostKeyAlgo { case hostAlgoRSA: - serializedHostKey = s.Server.rsaSerialized + serializedHostKey = s.config.rsaSerialized default: return nil, nil, os.NewError("internal error") } @@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h var sig []byte switch hostKeyAlgo { case hostAlgoRSA: - sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh) + sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh) if err != nil { return } @@ -257,19 +278,20 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK return ret } -// Handshake performs an SSH transport and client authentication on the given ServerConnection. -func (s *ServerConnection) Handshake(conn net.Conn) os.Error { +// Handshake performs an SSH transport and client authentication on the given ServerConn. +func (s *ServerConn) Handshake() os.Error { var magics handshakeMagics - s.transport = newTransport(conn, rand.Reader) - - if _, err := conn.Write(serverVersion); err != nil { + if _, err := s.Write(serverVersion); err != nil { + return err + } + if err := s.Flush(); err != nil { return err } magics.serverVersion = serverVersion[:len(serverVersion)-2] - version, ok := readVersion(s.transport) - if !ok { - return os.NewError("failed to read version string from client") + version, err := readVersion(s) + if err != nil { + return err } magics.clientVersion = version @@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { // The client sent a Kex message for the wrong algorithm, // which we have to ignore. - _, err := s.readPacket() - if err != nil { + if _, err := s.readPacket(); err != nil { return err } } @@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { dhGroup14Once.Do(initDHGroup14) H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) default: - err = os.NewError("ssh: internal error") + err = os.NewError("ssh: unexpected key exchange algorithm " + kexAlgo) } - if err != nil { return err } - packet = []byte{msgNewKeys} - if err = s.writePacket(packet); err != nil { + if err = s.writePacket([]byte{msgNewKeys}); err != nil { return err } if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { return err } - if packet, err = s.readPacket(); err != nil { return err } + if packet[0] != msgNewKeys { return UnexpectedMessageError{msgNewKeys, packet[0]} } - s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) - - packet, err = s.readPacket() - if err != nil { + if packet, err = s.readPacket(); err != nil { return err } @@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { if serviceRequest.Service != serviceUserAuth { return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating") } - serviceAccept := serviceAcceptMsg{ Service: serviceUserAuth, } - packet = marshal(msgServiceAccept, serviceAccept) - if err = s.writePacket(packet); err != nil { + if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil { return err } if err = s.authenticate(H); err != nil { return err } - - s.channels = make(map[uint32]*channel) return nil } @@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool { } // testPubKey returns true if the given public key is acceptable for the user. -func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { - if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) { +func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool { + if s.config.PubKeyCallback == nil || !isAcceptableAlgo(algo) { return false } @@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { } } - result := s.Server.PubKeyCallback(user, algo, pubKey) + result := s.config.PubKeyCallback(user, algo, pubKey) if len(s.cachedPubKeys) < maxCachedPubKeys { c := cachedPubKey{ user: user, @@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { return result } -func (s *ServerConnection) authenticate(H []byte) os.Error { +func (s *ServerConn) authenticate(H []byte) os.Error { var userAuthReq userAuthRequestMsg var err os.Error var packet []byte @@ -428,11 +440,11 @@ userAuthLoop: switch userAuthReq.Method { case "none": - if s.Server.NoClientAuth { + if s.config.NoClientAuth { break userAuthLoop } case "password": - if s.Server.PasswordCallback == nil { + if s.config.PasswordCallback == nil { break } payload := userAuthReq.Payload @@ -445,11 +457,11 @@ userAuthLoop: return ParseError{msgUserAuthRequest} } - if s.Server.PasswordCallback(userAuthReq.User, string(password)) { + if s.config.PasswordCallback(userAuthReq.User, string(password)) { break userAuthLoop } case "publickey": - if s.Server.PubKeyCallback == nil { + if s.config.PubKeyCallback == nil { break } payload := userAuthReq.Payload @@ -520,10 +532,10 @@ userAuthLoop: } var failureMsg userAuthFailureMsg - if s.Server.PasswordCallback != nil { + if s.config.PasswordCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "password") } - if s.Server.PubKeyCallback != nil { + if s.config.PubKeyCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "publickey") } @@ -546,9 +558,9 @@ userAuthLoop: const defaultWindowSize = 32768 -// Accept reads and processes messages on a ServerConnection. It must be called +// Accept reads and processes messages on a ServerConn. It must be called // in order to demultiplex messages to any resulting Channels. -func (s *ServerConnection) Accept() (Channel, os.Error) { +func (s *ServerConn) Accept() (Channel, os.Error) { if s.err != nil { return nil, s.err } @@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) { panic("unreachable") } + +// A Listener implements a network listener (net.Listener) for SSH connections. +type Listener struct { + listener net.Listener + config *ServerConfig +} + +// Accept waits for and returns the next incoming SSH connection. +// The receiver should call Handshake() in another goroutine +// to avoid blocking the accepter. +func (l *Listener) Accept() (*ServerConn, os.Error) { + c, err := l.listener.Accept() + if err != nil { + return nil, err + } + conn := Server(c, l.config) + return conn, nil +} + +// Addr returns the listener's network address. +func (l *Listener) Addr() net.Addr { + return l.listener.Addr() +} + +// Close closes the listener. +func (l *Listener) Close() os.Error { + return l.listener.Close() +} + +// Listen creates an SSH listener accepting connections on +// the given network address using net.Listen. +func Listen(network, addr string, config *ServerConfig) (*Listener, os.Error) { + l, err := net.Listen(network, addr) + if err != nil { + return nil, err + } + return &Listener{ + l, + config, + }, nil +} diff --git a/libgo/go/exp/ssh/session.go b/libgo/go/exp/ssh/session.go new file mode 100644 index 0000000..13df2f0 --- /dev/null +++ b/libgo/go/exp/ssh/session.go @@ -0,0 +1,132 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "encoding/binary" + "io" + "os" +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Writes to Stdin are made available to the remote command's standard input. + // Closing Stdin causes the command to observe an EOF on its standard input. + Stdin io.WriteCloser + + // Reads from Stdout and Stderr consume from the remote command's standard + // output and error streams, respectively. + // There is a fixed amount of buffering that is shared for the two streams. + // Failing to read from either may eventually cause the command to block. + // Closing Stdout unblocks such writes and causes them to return errors. + Stdout io.ReadCloser + Stderr io.Reader + + *clientChan // the channel backing this session + + started bool // started is set to true once a Shell or Exec is invoked. +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Exec. +func (s *Session) Setenv(name, value string) os.Error { + n, v := []byte(name), []byte(value) + nlen, vlen := stringLength(n), stringLength(v) + payload := make([]byte, nlen+vlen) + marshalString(payload[:nlen], n) + marshalString(payload[nlen:], v) + + return s.sendChanReq(channelRequestMsg{ + PeersId: s.id, + Request: "env", + WantReply: true, + RequestSpecificData: payload, + }) +} + +// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8. +var emptyModeList = []byte{0, 0, 0, 1, 0} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int) os.Error { + buf := make([]byte, 4+len(term)+16+len(emptyModeList)) + b := marshalString(buf, []byte(term)) + binary.BigEndian.PutUint32(b, uint32(h)) + binary.BigEndian.PutUint32(b[4:], uint32(w)) + binary.BigEndian.PutUint32(b[8:], uint32(h*8)) + binary.BigEndian.PutUint32(b[12:], uint32(w*8)) + copy(b[16:], emptyModeList) + + return s.sendChanReq(channelRequestMsg{ + PeersId: s.id, + Request: "pty-req", + WantReply: true, + RequestSpecificData: buf, + }) +} + +// Exec runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Exec or Shell. +func (s *Session) Exec(cmd string) os.Error { + if s.started { + return os.NewError("session already started") + } + cmdLen := stringLength([]byte(cmd)) + payload := make([]byte, cmdLen) + marshalString(payload, []byte(cmd)) + s.started = true + + return s.sendChanReq(channelRequestMsg{ + PeersId: s.id, + Request: "exec", + WantReply: true, + RequestSpecificData: payload, + }) +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Exec or Shell. +func (s *Session) Shell() os.Error { + if s.started { + return os.NewError("session already started") + } + s.started = true + + return s.sendChanReq(channelRequestMsg{ + PeersId: s.id, + Request: "shell", + WantReply: true, + }) +} + +// NewSession returns a new interactive session on the remote host. +func (c *ClientConn) NewSession() (*Session, os.Error) { + ch, err := c.openChan("session") + if err != nil { + return nil, err + } + return &Session{ + Stdin: &chanWriter{ + packetWriter: ch, + id: ch.id, + win: ch.win, + }, + Stdout: &chanReader{ + packetWriter: ch, + id: ch.id, + data: ch.data, + }, + Stderr: &chanReader{ + packetWriter: ch, + id: ch.id, + data: ch.dataExt, + }, + clientChan: ch, + }, nil +} diff --git a/libgo/go/exp/ssh/transport.go b/libgo/go/exp/ssh/transport.go index 5994004..97eaf97 100644 --- a/libgo/go/exp/ssh/transport.go +++ b/libgo/go/exp/ssh/transport.go @@ -332,16 +332,15 @@ func (t truncatingMAC) Size() int { const maxVersionStringBytes = 1024 // Read version string as specified by RFC 4253, section 4.2. -func readVersion(r io.Reader) (versionString []byte, ok bool) { - versionString = make([]byte, 0, 64) - seenCR := false - +func readVersion(r io.Reader) ([]byte, os.Error) { + versionString := make([]byte, 0, 64) + var ok, seenCR bool var buf [1]byte forEachByte: for len(versionString) < maxVersionStringBytes { _, err := io.ReadFull(r, buf[:]) if err != nil { - return + return nil, err } b := buf[0] @@ -360,10 +359,10 @@ forEachByte: versionString = append(versionString, b) } - if ok { - // We need to remove the CR from versionString - versionString = versionString[:len(versionString)-1] + if !ok { + return nil, os.NewError("failed to read version string") } - return + // We need to remove the CR from versionString + return versionString[:len(versionString)-1], nil } diff --git a/libgo/go/exp/ssh/transport_test.go b/libgo/go/exp/ssh/transport_test.go index 9a610a7..b2e2a7f 100644 --- a/libgo/go/exp/ssh/transport_test.go +++ b/libgo/go/exp/ssh/transport_test.go @@ -12,9 +12,9 @@ import ( func TestReadVersion(t *testing.T) { buf := []byte(serverVersion) - result, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))) - if !ok { - t.Error("readVersion didn't read version correctly") + result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))) + if err != nil { + t.Errorf("readVersion didn't read version correctly: %s", err) } if !bytes.Equal(buf[:len(buf)-2], result) { t.Error("version read did not match expected") @@ -23,7 +23,7 @@ func TestReadVersion(t *testing.T) { func TestReadVersionTooLong(t *testing.T) { buf := make([]byte, maxVersionStringBytes+1) - if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok { + if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil { t.Errorf("readVersion consumed %d bytes without error", len(buf)) } } @@ -31,7 +31,7 @@ func TestReadVersionTooLong(t *testing.T) { func TestReadVersionWithoutCRLF(t *testing.T) { buf := []byte(serverVersion) buf = buf[:len(buf)-1] - if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok { + if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil { t.Error("readVersion did not notice \\n was missing") } } diff --git a/libgo/go/exp/types/gcimporter.go b/libgo/go/exp/types/gcimporter.go index fe90f91..e744a63 100644 --- a/libgo/go/exp/types/gcimporter.go +++ b/libgo/go/exp/types/gcimporter.go @@ -289,9 +289,10 @@ func (p *gcParser) parseExportedName() (*ast.Object, string) { // BasicType = identifier . // func (p *gcParser) parseBasicType() Type { - obj := Universe.Lookup(p.expect(scanner.Ident)) + id := p.expect(scanner.Ident) + obj := Universe.Lookup(id) if obj == nil || obj.Kind != ast.Typ { - p.errorf("not a basic type: %s", obj.Name) + p.errorf("not a basic type: %s", id) } return obj.Type.(Type) } diff --git a/libgo/go/exp/winfsnotify/winfsnotify_test.go b/libgo/go/exp/winfsnotify/winfsnotify_test.go index edf2165..6e264d0 100644 --- a/libgo/go/exp/winfsnotify/winfsnotify_test.go +++ b/libgo/go/exp/winfsnotify/winfsnotify_test.go @@ -6,8 +6,8 @@ package winfsnotify import ( "os" - "time" "testing" + "time" ) func expect(t *testing.T, eventstream <-chan *Event, name string, mask uint32) { @@ -70,15 +70,11 @@ func TestNotifyEvents(t *testing.T) { if _, err = file.WriteString("hello, world"); err != nil { t.Fatalf("failed to write to test file: %s", err) } - if err = file.Sync(); err != nil { - t.Fatalf("failed to sync test file: %s", err) - } - expect(t, watcher.Event, testFile, FS_MODIFY) - expect(t, watcher.Event, testFile, FS_MODIFY) - if err = file.Close(); err != nil { t.Fatalf("failed to close test file: %s", err) } + expect(t, watcher.Event, testFile, FS_MODIFY) + expect(t, watcher.Event, testFile, FS_MODIFY) if err = os.Rename(testFile, testFile2); err != nil { t.Fatalf("failed to rename test file: %s", err) diff --git a/libgo/go/fmt/fmt_test.go b/libgo/go/fmt/fmt_test.go index 030ad61..38280d6 100644 --- a/libgo/go/fmt/fmt_test.go +++ b/libgo/go/fmt/fmt_test.go @@ -88,6 +88,10 @@ type S struct { G G // a struct field that GoStrings } +type SI struct { + I interface{} +} + // A type with a String method with pointer receiver for testing %p type P int @@ -352,6 +356,7 @@ var fmttests = []struct { {"%#v", map[string]int{"a": 1}, `map[string] int{"a":1}`}, {"%#v", map[string]B{"a": {1, 2}}, `map[string] fmt_test.B{"a":fmt_test.B{I:1, j:2}}`}, {"%#v", []string{"a", "b"}, `[]string{"a", "b"}`}, + {"%#v", SI{}, `fmt_test.SI{I:interface { }(nil)}`}, // slices with other formats {"%#x", []int{1, 2, 15}, `[0x1 0x2 0xf]`}, diff --git a/libgo/go/fmt/print.go b/libgo/go/fmt/print.go index 7721e72..710baee 100644 --- a/libgo/go/fmt/print.go +++ b/libgo/go/fmt/print.go @@ -74,6 +74,8 @@ type pp struct { n int panicking bool buf bytes.Buffer + // field holds the current item, as an interface{}. + field interface{} // value holds the current item, as a reflect.Value, and will be // the zero Value if the item has not been reflected. value reflect.Value @@ -132,6 +134,7 @@ func (p *pp) free() { return } p.buf.Reset() + p.field = nil p.value = reflect.Value{} ppFree.put(p) } @@ -294,16 +297,16 @@ func (p *pp) unknownType(v interface{}) { p.buf.WriteByte('?') } -func (p *pp) badVerb(verb int, val interface{}) { +func (p *pp) badVerb(verb int) { p.add('%') p.add('!') p.add(verb) p.add('(') switch { - case val != nil: - p.buf.WriteString(reflect.TypeOf(val).String()) + case p.field != nil: + p.buf.WriteString(reflect.TypeOf(p.field).String()) p.add('=') - p.printField(val, 'v', false, false, 0) + p.printField(p.field, 'v', false, false, 0) case p.value.IsValid(): p.buf.WriteString(p.value.Type().String()) p.add('=') @@ -314,12 +317,12 @@ func (p *pp) badVerb(verb int, val interface{}) { p.add(')') } -func (p *pp) fmtBool(v bool, verb int, value interface{}) { +func (p *pp) fmtBool(v bool, verb int) { switch verb { case 't', 'v': p.fmt.fmt_boolean(v) default: - p.badVerb(verb, value) + p.badVerb(verb) } } @@ -333,7 +336,7 @@ func (p *pp) fmtC(c int64) { p.fmt.pad(p.runeBuf[0:w]) } -func (p *pp) fmtInt64(v int64, verb int, value interface{}) { +func (p *pp) fmtInt64(v int64, verb int) { switch verb { case 'b': p.fmt.integer(v, 2, signed, ldigits) @@ -347,7 +350,7 @@ func (p *pp) fmtInt64(v int64, verb int, value interface{}) { if 0 <= v && v <= unicode.MaxRune { p.fmt.fmt_qc(v) } else { - p.badVerb(verb, value) + p.badVerb(verb) } case 'x': p.fmt.integer(v, 16, signed, ldigits) @@ -356,7 +359,7 @@ func (p *pp) fmtInt64(v int64, verb int, value interface{}) { case 'X': p.fmt.integer(v, 16, signed, udigits) default: - p.badVerb(verb, value) + p.badVerb(verb) } } @@ -391,7 +394,7 @@ func (p *pp) fmtUnicode(v int64) { p.fmt.sharp = sharp } -func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { +func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool) { switch verb { case 'b': p.fmt.integer(int64(v), 2, unsigned, ldigits) @@ -411,7 +414,7 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { if 0 <= v && v <= unicode.MaxRune { p.fmt.fmt_qc(int64(v)) } else { - p.badVerb(verb, value) + p.badVerb(verb) } case 'x': p.fmt.integer(int64(v), 16, unsigned, ldigits) @@ -420,11 +423,11 @@ func (p *pp) fmtUint64(v uint64, verb int, goSyntax bool, value interface{}) { case 'U': p.fmtUnicode(int64(v)) default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtFloat32(v float32, verb int, value interface{}) { +func (p *pp) fmtFloat32(v float32, verb int) { switch verb { case 'b': p.fmt.fmt_fb32(v) @@ -439,11 +442,11 @@ func (p *pp) fmtFloat32(v float32, verb int, value interface{}) { case 'G': p.fmt.fmt_G32(v) default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtFloat64(v float64, verb int, value interface{}) { +func (p *pp) fmtFloat64(v float64, verb int) { switch verb { case 'b': p.fmt.fmt_fb64(v) @@ -458,33 +461,33 @@ func (p *pp) fmtFloat64(v float64, verb int, value interface{}) { case 'G': p.fmt.fmt_G64(v) default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtComplex64(v complex64, verb int, value interface{}) { +func (p *pp) fmtComplex64(v complex64, verb int) { switch verb { case 'e', 'E', 'f', 'F', 'g', 'G': p.fmt.fmt_c64(v, verb) case 'v': p.fmt.fmt_c64(v, 'g') default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtComplex128(v complex128, verb int, value interface{}) { +func (p *pp) fmtComplex128(v complex128, verb int) { switch verb { case 'e', 'E', 'f', 'F', 'g', 'G': p.fmt.fmt_c128(v, verb) case 'v': p.fmt.fmt_c128(v, 'g') default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) { +func (p *pp) fmtString(v string, verb int, goSyntax bool) { switch verb { case 'v': if goSyntax { @@ -501,11 +504,11 @@ func (p *pp) fmtString(v string, verb int, goSyntax bool, value interface{}) { case 'q': p.fmt.fmt_q(v) default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interface{}) { +func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int) { if verb == 'v' || verb == 'd' { if goSyntax { p.buf.Write(bytesBytes) @@ -540,17 +543,17 @@ func (p *pp) fmtBytes(v []byte, verb int, goSyntax bool, depth int, value interf case 'q': p.fmt.fmt_q(s) default: - p.badVerb(verb, value) + p.badVerb(verb) } } -func (p *pp) fmtPointer(field interface{}, value reflect.Value, verb int, goSyntax bool) { +func (p *pp) fmtPointer(value reflect.Value, verb int, goSyntax bool) { var u uintptr switch value.Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer: u = value.Pointer() default: - p.badVerb(verb, field) + p.badVerb(verb) return } if goSyntax { @@ -576,12 +579,12 @@ var ( uintptrBits = reflect.TypeOf(uintptr(0)).Bits() ) -func (p *pp) catchPanic(val interface{}, verb int) { +func (p *pp) catchPanic(field interface{}, verb int) { if err := recover(); err != nil { // If it's a nil pointer, just say "". The likeliest causes are a // Stringer that fails to guard against nil or a nil pointer for a // value receiver, and in either case, "" is a nice result. - if v := reflect.ValueOf(val); v.Kind() == reflect.Ptr && v.IsNil() { + if v := reflect.ValueOf(field); v.Kind() == reflect.Ptr && v.IsNil() { p.buf.Write(nilAngleBytes) return } @@ -601,12 +604,12 @@ func (p *pp) catchPanic(val interface{}, verb int) { } } -func (p *pp) handleMethods(field interface{}, verb int, plus, goSyntax bool, depth int) (wasString, handled bool) { +func (p *pp) handleMethods(verb int, plus, goSyntax bool, depth int) (wasString, handled bool) { // Is it a Formatter? - if formatter, ok := field.(Formatter); ok { + if formatter, ok := p.field.(Formatter); ok { handled = true wasString = false - defer p.catchPanic(field, verb) + defer p.catchPanic(p.field, verb) formatter.Format(p, verb) return } @@ -618,20 +621,20 @@ func (p *pp) handleMethods(field interface{}, verb int, plus, goSyntax bool, dep // If we're doing Go syntax and the field knows how to supply it, take care of it now. if goSyntax { p.fmt.sharp = false - if stringer, ok := field.(GoStringer); ok { + if stringer, ok := p.field.(GoStringer); ok { wasString = false handled = true - defer p.catchPanic(field, verb) + defer p.catchPanic(p.field, verb) // Print the result of GoString unadorned. - p.fmtString(stringer.GoString(), 's', false, field) + p.fmtString(stringer.GoString(), 's', false) return } } else { // Is it a Stringer? - if stringer, ok := field.(Stringer); ok { + if stringer, ok := p.field.(Stringer); ok { wasString = false handled = true - defer p.catchPanic(field, verb) + defer p.catchPanic(p.field, verb) p.printField(stringer.String(), verb, plus, false, depth) return } @@ -645,11 +648,13 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth if verb == 'T' || verb == 'v' { p.buf.Write(nilAngleBytes) } else { - p.badVerb(verb, field) + p.badVerb(verb) } return false } + p.field = field + p.value = reflect.Value{} // Special processing considerations. // %T (the value's type) and %p (its address) are special; we always do them first. switch verb { @@ -657,74 +662,60 @@ func (p *pp) printField(field interface{}, verb int, plus, goSyntax bool, depth p.printField(reflect.TypeOf(field).String(), 's', false, false, 0) return false case 'p': - p.fmtPointer(field, reflect.ValueOf(field), verb, goSyntax) + p.fmtPointer(reflect.ValueOf(field), verb, goSyntax) return false } - if wasString, handled := p.handleMethods(field, verb, plus, goSyntax, depth); handled { + if wasString, handled := p.handleMethods(verb, plus, goSyntax, depth); handled { return wasString } // Some types can be done without reflection. switch f := field.(type) { case bool: - p.fmtBool(f, verb, field) - return false + p.fmtBool(f, verb) case float32: - p.fmtFloat32(f, verb, field) - return false + p.fmtFloat32(f, verb) case float64: - p.fmtFloat64(f, verb, field) - return false + p.fmtFloat64(f, verb) case complex64: - p.fmtComplex64(complex64(f), verb, field) - return false + p.fmtComplex64(complex64(f), verb) case complex128: - p.fmtComplex128(f, verb, field) - return false + p.fmtComplex128(f, verb) case int: - p.fmtInt64(int64(f), verb, field) - return false + p.fmtInt64(int64(f), verb) case int8: - p.fmtInt64(int64(f), verb, field) - return false + p.fmtInt64(int64(f), verb) case int16: - p.fmtInt64(int64(f), verb, field) - return false + p.fmtInt64(int64(f), verb) case int32: - p.fmtInt64(int64(f), verb, field) - return false + p.fmtInt64(int64(f), verb) case int64: - p.fmtInt64(f, verb, field) - return false + p.fmtInt64(f, verb) case uint: - p.fmtUint64(uint64(f), verb, goSyntax, field) - return false + p.fmtUint64(uint64(f), verb, goSyntax) case uint8: - p.fmtUint64(uint64(f), verb, goSyntax, field) - return false + p.fmtUint64(uint64(f), verb, goSyntax) case uint16: - p.fmtUint64(uint64(f), verb, goSyntax, field) - return false + p.fmtUint64(uint64(f), verb, goSyntax) case uint32: - p.fmtUint64(uint64(f), verb, goSyntax, field) - return false + p.fmtUint64(uint64(f), verb, goSyntax) case uint64: - p.fmtUint64(f, verb, goSyntax, field) - return false + p.fmtUint64(f, verb, goSyntax) case uintptr: - p.fmtUint64(uint64(f), verb, goSyntax, field) - return false + p.fmtUint64(uint64(f), verb, goSyntax) case string: - p.fmtString(f, verb, goSyntax, field) - return verb == 's' || verb == 'v' + p.fmtString(f, verb, goSyntax) + wasString = verb == 's' || verb == 'v' case []byte: - p.fmtBytes(f, verb, goSyntax, depth, field) - return verb == 's' + p.fmtBytes(f, verb, goSyntax, depth) + wasString = verb == 's' + default: + // Need to use reflection + return p.printReflectValue(reflect.ValueOf(field), verb, plus, goSyntax, depth) } - - // Need to use reflection - return p.printReflectValue(reflect.ValueOf(field), verb, plus, goSyntax, depth) + p.field = nil + return } // printValue is like printField but starts with a reflect value, not an interface{} value. @@ -733,7 +724,7 @@ func (p *pp) printValue(value reflect.Value, verb int, plus, goSyntax bool, dept if verb == 'T' || verb == 'v' { p.buf.Write(nilAngleBytes) } else { - p.badVerb(verb, nil) + p.badVerb(verb) } return false } @@ -745,17 +736,17 @@ func (p *pp) printValue(value reflect.Value, verb int, plus, goSyntax bool, dept p.printField(value.Type().String(), 's', false, false, 0) return false case 'p': - p.fmtPointer(nil, value, verb, goSyntax) + p.fmtPointer(value, verb, goSyntax) return false } // Handle values with special methods. // Call always, even when field == nil, because handleMethods clears p.fmt.plus for us. - var field interface{} + p.field = nil // Make sure it's cleared, for safety. if value.CanInterface() { - field = value.Interface() + p.field = value.Interface() } - if wasString, handled := p.handleMethods(field, verb, plus, goSyntax, depth); handled { + if wasString, handled := p.handleMethods(verb, plus, goSyntax, depth); handled { return wasString } @@ -770,25 +761,25 @@ func (p *pp) printReflectValue(value reflect.Value, verb int, plus, goSyntax boo BigSwitch: switch f := value; f.Kind() { case reflect.Bool: - p.fmtBool(f.Bool(), verb, nil) + p.fmtBool(f.Bool(), verb) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - p.fmtInt64(f.Int(), verb, nil) + p.fmtInt64(f.Int(), verb) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - p.fmtUint64(uint64(f.Uint()), verb, goSyntax, nil) + p.fmtUint64(uint64(f.Uint()), verb, goSyntax) case reflect.Float32, reflect.Float64: if f.Type().Size() == 4 { - p.fmtFloat32(float32(f.Float()), verb, nil) + p.fmtFloat32(float32(f.Float()), verb) } else { - p.fmtFloat64(float64(f.Float()), verb, nil) + p.fmtFloat64(float64(f.Float()), verb) } case reflect.Complex64, reflect.Complex128: if f.Type().Size() == 8 { - p.fmtComplex64(complex64(f.Complex()), verb, nil) + p.fmtComplex64(complex64(f.Complex()), verb) } else { - p.fmtComplex128(complex128(f.Complex()), verb, nil) + p.fmtComplex128(complex128(f.Complex()), verb) } case reflect.String: - p.fmtString(f.String(), verb, goSyntax, nil) + p.fmtString(f.String(), verb, goSyntax) case reflect.Map: if goSyntax { p.buf.WriteString(f.Type().String()) @@ -842,7 +833,7 @@ BigSwitch: value := f.Elem() if !value.IsValid() { if goSyntax { - p.buf.WriteString(value.Type().String()) + p.buf.WriteString(f.Type().String()) p.buf.Write(nilParenBytes) } else { p.buf.Write(nilAngleBytes) @@ -864,7 +855,7 @@ BigSwitch: for i := range bytes { bytes[i] = byte(f.Index(i).Uint()) } - p.fmtBytes(bytes, verb, goSyntax, depth, nil) + p.fmtBytes(bytes, verb, goSyntax, depth) wasString = verb == 's' break } @@ -924,7 +915,7 @@ BigSwitch: } p.fmt0x64(uint64(v), true) case reflect.Chan, reflect.Func, reflect.UnsafePointer: - p.fmtPointer(nil, value, verb, goSyntax) + p.fmtPointer(value, verb, goSyntax) default: p.unknownType(f) } diff --git a/libgo/go/go/ast/print_test.go b/libgo/go/go/ast/print_test.go index a4bc3bb..c3153ed 100644 --- a/libgo/go/go/ast/print_test.go +++ b/libgo/go/go/ast/print_test.go @@ -23,11 +23,10 @@ var tests = []struct { {"foobar", "0 \"foobar\""}, // maps - {map[string]int{"a": 1, "b": 2}, - `0 map[string] int (len = 2) { + {map[string]int{"a": 1}, + `0 map[string] int (len = 1) { 1 . "a": 1 - 2 . "b": 2 - 3 }`}, + 2 }`}, // pointers {new(int), "0 *0"}, diff --git a/libgo/go/html/doc.go b/libgo/go/html/doc.go index 5bc0630..ba9d188 100644 --- a/libgo/go/html/doc.go +++ b/libgo/go/html/doc.go @@ -70,9 +70,6 @@ call to Next. For example, to extract an HTML page's anchor text: } } -A Tokenizer typically skips over HTML comments. To return comment tokens, set -Tokenizer.ReturnComments to true before looping over calls to Next. - Parsing is done by calling Parse with an io.Reader, which returns the root of the parse tree (the document element) as a *Node. It is the caller's responsibility to ensure that the Reader provides UTF-8 encoded HTML. For diff --git a/libgo/go/html/parse.go b/libgo/go/html/parse.go index 582437f..530942aa 100644 --- a/libgo/go/html/parse.go +++ b/libgo/go/html/parse.go @@ -32,6 +32,9 @@ type parser struct { // originalIM is the insertion mode to go back to after completing a text // or inTableText insertion mode. originalIM insertionMode + // fosterParenting is whether new elements should be inserted according to + // the foster parenting rules (section 11.2.5.3). + fosterParenting bool } func (p *parser) top() *Node { @@ -49,6 +52,11 @@ var ( tableScopeStopTags = []string{"html", "table"} ) +// stopTags for use in clearStackToContext. +var ( + tableRowContextStopTags = []string{"tr", "html"} +) + // popUntil pops the stack of open elements at the highest element whose tag // is in matchTags, provided there is no higher element in stopTags. It returns // whether or not there was such an element. If there was not, popUntil leaves @@ -103,12 +111,61 @@ func (p *parser) elementInScope(stopTags []string, matchTags ...string) bool { // addChild adds a child node n to the top element, and pushes n onto the stack // of open elements if it is an element node. func (p *parser) addChild(n *Node) { - p.top().Add(n) + if p.fosterParenting { + p.fosterParent(n) + } else { + p.top().Add(n) + } + if n.Type == ElementNode { p.oe = append(p.oe, n) } } +// fosterParent adds a child node according to the foster parenting rules. +// Section 11.2.5.3, "foster parenting". +func (p *parser) fosterParent(n *Node) { + var table, parent *Node + var i int + for i = len(p.oe) - 1; i >= 0; i-- { + if p.oe[i].Data == "table" { + table = p.oe[i] + break + } + } + + if table == nil { + // The foster parent is the html element. + parent = p.oe[0] + } else { + parent = table.Parent + } + if parent == nil { + parent = p.oe[i-1] + } + + var child *Node + for i, child = range parent.Child { + if child == table { + break + } + } + + if i > 0 && parent.Child[i-1].Type == TextNode && n.Type == TextNode { + parent.Child[i-1].Data += n.Data + return + } + + if i == len(parent.Child) { + parent.Add(n) + } else { + // Insert n into parent.Child at index i. + parent.Child = append(parent.Child[:i+1], parent.Child[i:]...) + parent.Child[i] = n + n.Parent = parent + } +} + // addText adds text to the preceding node if it is a text node, or else it // calls addChild with a new text node. func (p *parser) addText(text string) { @@ -170,9 +227,9 @@ func (p *parser) reconstructActiveFormattingElements() { } for { i++ - n = p.afe[i] - p.addChild(n.clone()) - p.afe[i] = n + clone := p.afe[i].clone() + p.addChild(clone) + p.afe[i] = clone if i == len(p.afe)-1 { break } @@ -234,10 +291,52 @@ func (p *parser) setOriginalIM(im insertionMode) { p.originalIM = im } +// Section 11.2.3.1, "reset the insertion mode". +func (p *parser) resetInsertionMode() insertionMode { + for i := len(p.oe) - 1; i >= 0; i-- { + n := p.oe[i] + if i == 0 { + // TODO: set n to the context element, for HTML fragment parsing. + } + switch n.Data { + case "select": + return inSelectIM + case "td", "th": + return inCellIM + case "tr": + return inRowIM + case "tbody", "thead", "tfoot": + return inTableBodyIM + case "caption": + // TODO: return inCaptionIM + case "colgroup": + // TODO: return inColumnGroupIM + case "table": + return inTableIM + case "head": + return inBodyIM + case "body": + return inBodyIM + case "frameset": + // TODO: return inFramesetIM + case "html": + return beforeHeadIM + } + } + return inBodyIM +} + // Section 11.2.5.4.1. func initialIM(p *parser) (insertionMode, bool) { - if p.tok.Type == DoctypeToken { - p.addChild(&Node{ + switch p.tok.Type { + case CommentToken: + p.doc.Add(&Node{ + Type: CommentNode, + Data: p.tok.Data, + }) + return initialIM, true + case DoctypeToken: + p.doc.Add(&Node{ Type: DoctypeNode, Data: p.tok.Data, }) @@ -275,6 +374,12 @@ func beforeHTMLIM(p *parser) (insertionMode, bool) { default: // Ignore the token. } + case CommentToken: + p.doc.Add(&Node{ + Type: CommentNode, + Data: p.tok.Data, + }) + return beforeHTMLIM, true } if add || implied { p.addElement("html", attr) @@ -312,6 +417,12 @@ func beforeHeadIM(p *parser) (insertionMode, bool) { default: // Ignore the token. } + case CommentToken: + p.addChild(&Node{ + Type: CommentNode, + Data: p.tok.Data, + }) + return beforeHeadIM, true } if add || implied { p.addElement("head", attr) @@ -344,11 +455,17 @@ func inHeadIM(p *parser) (insertionMode, bool) { pop = true } // TODO. + case CommentToken: + p.addChild(&Node{ + Type: CommentNode, + Data: p.tok.Data, + }) + return inHeadIM, true } if pop || implied { n := p.oe.pop() if n.Data != "head" { - panic("html: bad parser state") + panic("html: bad parser state: element not found, in the in-head insertion mode") } return afterHeadIM, !implied } @@ -387,6 +504,12 @@ func afterHeadIM(p *parser) (insertionMode, bool) { } case EndTagToken: // TODO. + case CommentToken: + p.addChild(&Node{ + Type: CommentNode, + Data: p.tok.Data, + }) + return afterHeadIM, true } if add || implied { p.addElement("body", attr) @@ -447,6 +570,30 @@ func inBodyIM(p *parser) (insertionMode, bool) { p.oe.pop() p.acknowledgeSelfClosingTag() p.framesetOK = false + case "select": + p.reconstructActiveFormattingElements() + p.addElement(p.tok.Data, p.tok.Attr) + p.framesetOK = false + // TODO: detect