diff options
Diffstat (limited to 'libgo/go/net')
117 files changed, 8099 insertions, 2711 deletions
diff --git a/libgo/go/net/conn_test.go b/libgo/go/net/conn_test.go index f733a81..98bd695 100644 --- a/libgo/go/net/conn_test.go +++ b/libgo/go/net/conn_test.go @@ -2,10 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package net_test +// This file implements API tests across platforms and will never have a build +// tag. + +package net import ( - "net" "os" "runtime" "testing" @@ -14,13 +16,18 @@ import ( var connTests = []struct { net string - addr string + addr func() string }{ - {"tcp", "127.0.0.1:0"}, - {"unix", "/tmp/gotest.net1"}, - {"unixpacket", "/tmp/gotest.net2"}, + {"tcp", func() string { return "127.0.0.1:0" }}, + {"unix", testUnixAddr}, + {"unixpacket", testUnixAddr}, } +// someTimeout is used just to test that net.Conn implementations +// don't explode when their SetFooDeadline methods are called. +// It isn't actually used for testing timeouts. +const someTimeout = 10 * time.Second + func TestConnAndListener(t *testing.T) { for _, tt := range connTests { switch tt.net { @@ -32,74 +39,77 @@ func TestConnAndListener(t *testing.T) { if tt.net == "unixpacket" && runtime.GOOS != "linux" { continue } - os.Remove(tt.addr) } - ln, err := net.Listen(tt.net, tt.addr) + addr := tt.addr() + ln, err := Listen(tt.net, addr) if err != nil { - t.Errorf("net.Listen failed: %v", err) - return + t.Fatalf("Listen failed: %v", err) } - ln.Addr() - defer func(ln net.Listener, net, addr string) { + defer func(ln Listener, net, addr string) { ln.Close() switch net { case "unix", "unixpacket": os.Remove(addr) } - }(ln, tt.net, tt.addr) + }(ln, tt.net, addr) + ln.Addr() done := make(chan int) go transponder(t, ln, done) - c, err := net.Dial(tt.net, ln.Addr().String()) + c, err := Dial(tt.net, ln.Addr().String()) if err != nil { - t.Errorf("net.Dial failed: %v", err) - return + t.Fatalf("Dial failed: %v", err) } + defer c.Close() c.LocalAddr() c.RemoteAddr() - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) if _, err := c.Write([]byte("CONN TEST")); err != nil { - t.Errorf("net.Conn.Write failed: %v", err) - return + t.Fatalf("Conn.Write failed: %v", err) } rb := make([]byte, 128) if _, err := c.Read(rb); err != nil { - t.Errorf("net.Conn.Read failed: %v", err) + t.Fatalf("Conn.Read failed: %v", err) } <-done } } -func transponder(t *testing.T, ln net.Listener, done chan<- int) { +func transponder(t *testing.T, ln Listener, done chan<- int) { defer func() { done <- 1 }() + switch ln := ln.(type) { + case *TCPListener: + ln.SetDeadline(time.Now().Add(someTimeout)) + case *UnixListener: + ln.SetDeadline(time.Now().Add(someTimeout)) + } c, err := ln.Accept() if err != nil { - t.Errorf("net.Listener.Accept failed: %v", err) + t.Errorf("Listener.Accept failed: %v", err) return } + defer c.Close() c.LocalAddr() c.RemoteAddr() - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) b := make([]byte, 128) n, err := c.Read(b) if err != nil { - t.Errorf("net.Conn.Read failed: %v", err) + t.Errorf("Conn.Read failed: %v", err) return } if _, err := c.Write(b[:n]); err != nil { - t.Errorf("net.Conn.Write failed: %v", err) + t.Errorf("Conn.Write failed: %v", err) return } } diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index 354028a..b18d283 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -5,10 +5,55 @@ package net import ( + "errors" "time" ) -func parseDialNetwork(net string) (afnet string, proto int, err error) { +// A Dialer contains options for connecting to an address. +// +// The zero value for each field is equivalent to dialing +// without that option. Dialing with the zero value of Dialer +// is therefore equivalent to just calling the Dial function. +type Dialer struct { + // Timeout is the maximum amount of time a dial will wait for + // a connect to complete. If Deadline is also set, it may fail + // earlier. + // + // The default is no timeout. + // + // With or without a timeout, the operating system may impose + // its own earlier timeout. For instance, TCP timeouts are + // often around 3 minutes. + Timeout time.Duration + + // Deadline is the absolute point in time after which dials + // will fail. If Timeout is set, it may fail earlier. + // Zero means no deadline, or dependent on the operating system + // as with the Timeout option. + Deadline time.Time + + // LocalAddr is the local address to use when dialing an + // address. The address must be of a compatible type for the + // network being dialed. + // If nil, a local address is automatically chosen. + LocalAddr Addr +} + +// Return either now+Timeout or Deadline, whichever comes first. +// Or zero, if neither is set. +func (d *Dialer) deadline() time.Time { + if d.Timeout == 0 { + return d.Deadline + } + timeoutDeadline := time.Now().Add(d.Timeout) + if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) { + return timeoutDeadline + } else { + return d.Deadline + } +} + +func parseNetwork(net string) (afnet string, proto int, err error) { i := last(net, ':') if i < 0 { // no colon switch net { @@ -37,74 +82,89 @@ func parseDialNetwork(net string) (afnet string, proto int, err error) { return "", 0, UnknownNetworkError(net) } -func resolveNetAddr(op, net, addr string, deadline time.Time) (afnet string, a Addr, err error) { - afnet, _, err = parseDialNetwork(net) +func resolveAddr(op, net, addr string, deadline time.Time) (Addr, error) { + afnet, _, err := parseNetwork(net) if err != nil { - return "", nil, &OpError{op, net, nil, err} + return nil, &OpError{op, net, nil, err} } if op == "dial" && addr == "" { - return "", nil, &OpError{op, net, nil, errMissingAddress} - } - a, err = resolveAfnetAddr(afnet, addr, deadline) - return -} - -func resolveAfnetAddr(afnet, addr string, deadline time.Time) (Addr, error) { - if addr == "" { - return nil, nil + return nil, &OpError{op, net, nil, errMissingAddress} } switch afnet { - case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip", "ip4", "ip6": - return resolveInternetAddr(afnet, addr, deadline) case "unix", "unixgram", "unixpacket": return ResolveUnixAddr(afnet, addr) } - return nil, nil + return resolveInternetAddr(afnet, addr, deadline) } -// Dial connects to the address addr on the network net. +// Dial connects to the address on the named network. // // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4" -// (IPv4-only), "ip6" (IPv6-only), "unix" and "unixpacket". +// (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and +// "unixpacket". // // For TCP and UDP networks, addresses have the form host:port. -// If host is a literal IPv6 address, it must be enclosed -// in square brackets. The functions JoinHostPort and SplitHostPort -// manipulate addresses in this form. +// If host is a literal IPv6 address or host name, it must be enclosed +// in square brackets as in "[::1]:80", "[ipv6-host]:http" or +// "[ipv6-host%zone]:80". +// The functions JoinHostPort and SplitHostPort manipulate addresses +// in this form. // // Examples: // Dial("tcp", "12.34.56.78:80") -// Dial("tcp", "google.com:80") -// Dial("tcp", "[de:ad:be:ef::ca:fe]:80") +// Dial("tcp", "google.com:http") +// Dial("tcp", "[2001:db8::1]:http") +// Dial("tcp", "[fe80::1%lo0]:80") // -// For IP networks, net must be "ip", "ip4" or "ip6" followed -// by a colon and a protocol number or name. +// For IP networks, the network must be "ip", "ip4" or "ip6" followed +// by a colon and a protocol number or name and the addr must be a +// literal IP address. // // Examples: // Dial("ip4:1", "127.0.0.1") // Dial("ip6:ospf", "::1") // -func Dial(net, addr string) (Conn, error) { - _, addri, err := resolveNetAddr("dial", net, addr, noDeadline) - if err != nil { - return nil, err - } - return dialAddr(net, addr, addri, noDeadline) +// For Unix networks, the address must be a file system path. +func Dial(network, address string) (Conn, error) { + var d Dialer + return d.Dial(network, address) +} + +// DialTimeout acts like Dial but takes a timeout. +// The timeout includes name resolution, if required. +func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { + d := Dialer{Timeout: timeout} + return d.Dial(network, address) +} + +// Dial connects to the address on the named network. +// +// See func Dial for a description of the network and address +// parameters. +func (d *Dialer) Dial(network, address string) (Conn, error) { + return resolveAndDial(network, address, d.LocalAddr, d.deadline()) } -func dialAddr(net, addr string, addri Addr, deadline time.Time) (c Conn, err error) { - switch ra := addri.(type) { +func dial(net, addr string, la, ra Addr, deadline time.Time) (c Conn, err error) { + if la != nil && la.Network() != ra.Network() { + return nil, &OpError{"dial", net, ra, errors.New("mismatched local addr type " + la.Network())} + } + switch ra := ra.(type) { case *TCPAddr: - c, err = dialTCP(net, nil, ra, deadline) + la, _ := la.(*TCPAddr) + c, err = dialTCP(net, la, ra, deadline) case *UDPAddr: - c, err = dialUDP(net, nil, ra, deadline) + la, _ := la.(*UDPAddr) + c, err = dialUDP(net, la, ra, deadline) case *IPAddr: - c, err = dialIP(net, nil, ra, deadline) + la, _ := la.(*IPAddr) + c, err = dialIP(net, la, ra, deadline) case *UnixAddr: - c, err = dialUnix(net, nil, ra, deadline) + la, _ := la.(*UnixAddr) + c, err = dialUnix(net, la, ra, deadline) default: - err = &OpError{"dial", net + " " + addr, nil, UnknownNetworkError(net)} + err = &OpError{"dial", net + " " + addr, ra, UnknownNetworkError(net)} } if err != nil { return nil, err @@ -112,59 +172,6 @@ func dialAddr(net, addr string, addri Addr, deadline time.Time) (c Conn, err err return } -// DialTimeout acts like Dial but takes a timeout. -// The timeout includes name resolution, if required. -func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) { - return dialTimeout(net, addr, timeout) -} - -// dialTimeoutRace is the old implementation of DialTimeout, still used -// on operating systems where the deadline hasn't been pushed down -// into the pollserver. -// TODO: fix this on plan9. -func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) { - t := time.NewTimer(timeout) - defer t.Stop() - type pair struct { - Conn - error - } - ch := make(chan pair, 1) - resolvedAddr := make(chan Addr, 1) - go func() { - _, addri, err := resolveNetAddr("dial", net, addr, noDeadline) - if err != nil { - ch <- pair{nil, err} - return - } - resolvedAddr <- addri // in case we need it for OpError - c, err := dialAddr(net, addr, addri, noDeadline) - ch <- pair{c, err} - }() - select { - case <-t.C: - // Try to use the real Addr in our OpError, if we resolved it - // before the timeout. Otherwise we just use stringAddr. - var addri Addr - select { - case a := <-resolvedAddr: - addri = a - default: - addri = &stringAddr{net, addr} - } - err := &OpError{ - Op: "dial", - Net: net, - Addr: addri, - Err: &timeoutError{}, - } - return nil, err - case p := <-ch: - return p.Conn, p.error - } - panic("unreachable") -} - type stringAddr struct { net, addr string } @@ -173,56 +180,38 @@ func (a stringAddr) Network() string { return a.net } func (a stringAddr) String() string { return a.addr } // Listen announces on the local network address laddr. -// The network string net must be a stream-oriented network: -// "tcp", "tcp4", "tcp6", "unix" or "unixpacket". +// The network net must be a stream-oriented network: "tcp", "tcp4", +// "tcp6", "unix" or "unixpacket". +// See Dial for the syntax of laddr. func Listen(net, laddr string) (Listener, error) { - afnet, a, err := resolveNetAddr("listen", net, laddr, noDeadline) + la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { return nil, err } - switch afnet { - case "tcp", "tcp4", "tcp6": - var la *TCPAddr - if a != nil { - la = a.(*TCPAddr) - } + switch la := la.(type) { + case *TCPAddr: return ListenTCP(net, la) - case "unix", "unixpacket": - var la *UnixAddr - if a != nil { - la = a.(*UnixAddr) - } + case *UnixAddr: return ListenUnix(net, la) } return nil, UnknownNetworkError(net) } // ListenPacket announces on the local network address laddr. -// The network string net must be a packet-oriented network: -// "udp", "udp4", "udp6", "ip", "ip4", "ip6" or "unixgram". +// The network net must be a packet-oriented network: "udp", "udp4", +// "udp6", "ip", "ip4", "ip6" or "unixgram". +// See Dial for the syntax of laddr. func ListenPacket(net, laddr string) (PacketConn, error) { - afnet, a, err := resolveNetAddr("listen", net, laddr, noDeadline) + la, err := resolveAddr("listen", net, laddr, noDeadline) if err != nil { return nil, err } - switch afnet { - case "udp", "udp4", "udp6": - var la *UDPAddr - if a != nil { - la = a.(*UDPAddr) - } + switch la := la.(type) { + case *UDPAddr: return ListenUDP(net, la) - case "ip", "ip4", "ip6": - var la *IPAddr - if a != nil { - la = a.(*IPAddr) - } + case *IPAddr: return ListenIP(net, la) - case "unixgram": - var la *UnixAddr - if a != nil { - la = a.(*UnixAddr) - } + case *UnixAddr: return ListenUnixgram(net, la) } return nil, UnknownNetworkError(net) diff --git a/libgo/go/net/dial_gen.go b/libgo/go/net/dial_gen.go new file mode 100644 index 0000000..19f8681 --- /dev/null +++ b/libgo/go/net/dial_gen.go @@ -0,0 +1,73 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build windows plan9 + +package net + +import ( + "time" +) + +var testingIssue5349 bool // used during tests + +// resolveAndDialChannel is the simple pure-Go implementation of +// resolveAndDial, still used on operating systems where the deadline +// hasn't been pushed down into the pollserver. (Plan 9 and some old +// versions of Windows) +func resolveAndDialChannel(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { + var timeout time.Duration + if !deadline.IsZero() { + timeout = deadline.Sub(time.Now()) + } + if timeout <= 0 { + ra, err := resolveAddr("dial", net, addr, noDeadline) + if err != nil { + return nil, err + } + return dial(net, addr, localAddr, ra, noDeadline) + } + t := time.NewTimer(timeout) + defer t.Stop() + type pair struct { + Conn + error + } + ch := make(chan pair, 1) + resolvedAddr := make(chan Addr, 1) + go func() { + if testingIssue5349 { + time.Sleep(time.Millisecond) + } + ra, err := resolveAddr("dial", net, addr, noDeadline) + if err != nil { + ch <- pair{nil, err} + return + } + resolvedAddr <- ra // in case we need it for OpError + c, err := dial(net, addr, localAddr, ra, noDeadline) + ch <- pair{c, err} + }() + select { + case <-t.C: + // Try to use the real Addr in our OpError, if we resolved it + // before the timeout. Otherwise we just use stringAddr. + var ra Addr + select { + case a := <-resolvedAddr: + ra = a + default: + ra = &stringAddr{net, addr} + } + err := &OpError{ + Op: "dial", + Net: net, + Addr: ra, + Err: &timeoutError{}, + } + return nil, err + case p := <-ch: + return p.Conn, p.error + } +} diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index aa53b66..03a0bad 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "os" + "reflect" "regexp" "runtime" "testing" @@ -27,12 +28,18 @@ func newLocalListener(t *testing.T) Listener { } func TestDialTimeout(t *testing.T) { + origBacklog := listenerBacklog + defer func() { + listenerBacklog = origBacklog + }() + listenerBacklog = 1 + ln := newLocalListener(t) defer ln.Close() errc := make(chan error) - numConns := listenerBacklog + 10 + numConns := listenerBacklog + 100 // TODO(bradfitz): It's hard to test this in a portable // way. This is unfortunate, but works for now. @@ -223,6 +230,31 @@ func TestDialError(t *testing.T) { } } +var invalidDialAndListenArgTests = []struct { + net string + addr string + err error +}{ + {"foo", "bar", &OpError{Op: "dial", Net: "foo", Addr: nil, Err: UnknownNetworkError("foo")}}, + {"baz", "", &OpError{Op: "listen", Net: "baz", Addr: nil, Err: UnknownNetworkError("baz")}}, + {"tcp", "", &OpError{Op: "dial", Net: "tcp", Addr: nil, Err: errMissingAddress}}, +} + +func TestInvalidDialAndListenArgs(t *testing.T) { + for _, tt := range invalidDialAndListenArgTests { + var err error + switch tt.err.(*OpError).Op { + case "dial": + _, err = Dial(tt.net, tt.addr) + case "listen": + _, err = Listen(tt.net, tt.addr) + } + if !reflect.DeepEqual(tt.err, err) { + t.Fatalf("got %#v; expected %#v", err, tt.err) + } + } +} + func TestDialTimeoutFDLeak(t *testing.T) { if runtime.GOOS != "linux" { // TODO(bradfitz): test on other platforms @@ -298,3 +330,80 @@ func numFD() int { // All tests using this should be skipped anyway, but: panic("numFDs not implemented on " + runtime.GOOS) } + +var testPoller = flag.Bool("poller", false, "platform supports runtime-integrated poller") + +// Assert that a failed Dial attempt does not leak +// runtime.PollDesc structures +func TestDialFailPDLeak(t *testing.T) { + if !*testPoller { + t.Skip("test disabled; use -poller to enable") + } + + const loops = 10 + const count = 20000 + var old runtime.MemStats // used by sysdelta + runtime.ReadMemStats(&old) + sysdelta := func() uint64 { + var new runtime.MemStats + runtime.ReadMemStats(&new) + delta := old.Sys - new.Sys + old = new + return delta + } + d := &Dialer{Timeout: time.Nanosecond} // don't bother TCP with handshaking + failcount := 0 + for i := 0; i < loops; i++ { + for i := 0; i < count; i++ { + conn, err := d.Dial("tcp", "127.0.0.1:1") + if err == nil { + t.Error("dial should not succeed") + conn.Close() + t.FailNow() + } + } + if delta := sysdelta(); delta > 0 { + failcount++ + } + // there are always some allocations on the first loop + if failcount > 3 { + t.Error("detected possible memory leak in runtime") + t.FailNow() + } + } +} + +func TestDialer(t *testing.T) { + ln, err := Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + defer ln.Close() + ch := make(chan error, 1) + go func() { + var err error + c, err := ln.Accept() + if err != nil { + ch <- fmt.Errorf("Accept failed: %v", err) + return + } + defer c.Close() + ch <- nil + }() + + laddr, err := ResolveTCPAddr("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("ResolveTCPAddr failed: %v", err) + } + d := &Dialer{LocalAddr: laddr} + c, err := d.Dial("tcp4", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + c.Read(make([]byte, 1)) + err = <-ch + if err != nil { + t.Error(err) + } +} diff --git a/libgo/go/net/fd_bsd.go b/libgo/go/net/fd_bsd.go index f5a55bb..8bb1ae5 100644 --- a/libgo/go/net/fd_bsd.go +++ b/libgo/go/net/fd_bsd.go @@ -33,6 +33,8 @@ func newpollster() (p *pollster, err error) { return p, nil } +// First return value is whether the pollServer should be woken up. +// This version always returns false. func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { // pollServer is locked. @@ -64,6 +66,8 @@ func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { return false, nil } +// Return value is whether the pollServer should be woken up. +// This version always returns false. func (p *pollster) DelFD(fd int, mode int) bool { // pollServer is locked. diff --git a/libgo/go/net/fd_linux.go b/libgo/go/net/fd_linux.go deleted file mode 100644 index 8ecbff8..0000000 --- a/libgo/go/net/fd_linux.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Waiting for FDs via epoll(7). - -package net - -import ( - "os" - "syscall" -) - -const ( - readFlags = syscall.EPOLLIN | syscall.EPOLLRDHUP - writeFlags = syscall.EPOLLOUT -) - -type pollster struct { - epfd int - - // Events we're already waiting for - // Must hold pollServer lock - events map[int]uint32 - - // An event buffer for EpollWait. - // Used without a lock, may only be used by WaitFD. - waitEventBuf [10]syscall.EpollEvent - waitEvents []syscall.EpollEvent - - // An event buffer for EpollCtl, to avoid a malloc. - // Must hold pollServer lock. - ctlEvent syscall.EpollEvent -} - -func newpollster() (p *pollster, err error) { - p = new(pollster) - if p.epfd, err = syscall.EpollCreate1(syscall.EPOLL_CLOEXEC); err != nil { - if err != syscall.ENOSYS { - return nil, os.NewSyscallError("epoll_create1", err) - } - // The arg to epoll_create is a hint to the kernel - // about the number of FDs we will care about. - // We don't know, and since 2.6.8 the kernel ignores it anyhow. - if p.epfd, err = syscall.EpollCreate(16); err != nil { - return nil, os.NewSyscallError("epoll_create", err) - } - syscall.CloseOnExec(p.epfd) - } - p.events = make(map[int]uint32) - return p, nil -} - -func (p *pollster) AddFD(fd int, mode int, repeat bool) (bool, error) { - // pollServer is locked. - - var already bool - p.ctlEvent.Fd = int32(fd) - p.ctlEvent.Events, already = p.events[fd] - if !repeat { - p.ctlEvent.Events |= syscall.EPOLLONESHOT - } - if mode == 'r' { - p.ctlEvent.Events |= readFlags - } else { - p.ctlEvent.Events |= writeFlags - } - - var op int - if already { - op = syscall.EPOLL_CTL_MOD - } else { - op = syscall.EPOLL_CTL_ADD - } - if err := syscall.EpollCtl(p.epfd, op, fd, &p.ctlEvent); err != nil { - return false, os.NewSyscallError("epoll_ctl", err) - } - p.events[fd] = p.ctlEvent.Events - return false, nil -} - -func (p *pollster) StopWaiting(fd int, bits uint) { - // pollServer is locked. - - events, already := p.events[fd] - if !already { - // The fd returned by the kernel may have been - // cancelled already; return silently. - return - } - - // If syscall.EPOLLONESHOT is not set, the wait - // is a repeating wait, so don't change it. - if events&syscall.EPOLLONESHOT == 0 { - return - } - - // Disable the given bits. - // If we're still waiting for other events, modify the fd - // event in the kernel. Otherwise, delete it. - events &= ^uint32(bits) - if int32(events)&^syscall.EPOLLONESHOT != 0 { - p.ctlEvent.Fd = int32(fd) - p.ctlEvent.Events = events - if err := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_MOD, fd, &p.ctlEvent); err != nil { - print("Epoll modify fd=", fd, ": ", err.Error(), "\n") - } - p.events[fd] = events - } else { - if err := syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd, nil); err != nil { - print("Epoll delete fd=", fd, ": ", err.Error(), "\n") - } - delete(p.events, fd) - } -} - -func (p *pollster) DelFD(fd int, mode int) bool { - // pollServer is locked. - - if mode == 'r' { - p.StopWaiting(fd, readFlags) - } else { - p.StopWaiting(fd, writeFlags) - } - - // Discard any queued up events. - i := 0 - for i < len(p.waitEvents) { - if fd == int(p.waitEvents[i].Fd) { - copy(p.waitEvents[i:], p.waitEvents[i+1:]) - p.waitEvents = p.waitEvents[:len(p.waitEvents)-1] - } else { - i++ - } - } - return false -} - -func (p *pollster) WaitFD(s *pollServer, nsec int64) (fd int, mode int, err error) { - for len(p.waitEvents) == 0 { - var msec int = -1 - if nsec > 0 { - msec = int((nsec + 1e6 - 1) / 1e6) - } - - s.Unlock() - n, err := syscall.EpollWait(p.epfd, p.waitEventBuf[0:], msec) - s.Lock() - - if err != nil { - if err == syscall.EAGAIN || err == syscall.EINTR { - continue - } - return -1, 0, os.NewSyscallError("epoll_wait", err) - } - if n == 0 { - return -1, 0, nil - } - p.waitEvents = p.waitEventBuf[0:n] - } - - ev := &p.waitEvents[0] - p.waitEvents = p.waitEvents[1:] - - fd = int(ev.Fd) - - if ev.Events&writeFlags != 0 { - p.StopWaiting(fd, writeFlags) - return fd, 'w', nil - } - if ev.Events&readFlags != 0 { - p.StopWaiting(fd, readFlags) - return fd, 'r', nil - } - - // Other events are error conditions - wake whoever is waiting. - events, _ := p.events[fd] - if events&writeFlags != 0 { - p.StopWaiting(fd, writeFlags) - return fd, 'w', nil - } - p.StopWaiting(fd, readFlags) - return fd, 'r', nil -} - -func (p *pollster) Close() error { - return os.NewSyscallError("close", syscall.Close(p.epfd)) -} diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go index 3462792..e9527a3 100644 --- a/libgo/go/net/fd_plan9.go +++ b/libgo/go/net/fd_plan9.go @@ -23,28 +23,22 @@ var canCancelIO = true // used for testing current package func sysInit() { } -func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { +func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { // On plan9, use the relatively inefficient // goroutine-racing implementation. - return dialTimeoutRace(net, addr, timeout) + return resolveAndDialChannel(net, addr, localAddr, deadline) } -func newFD(proto, name string, ctl *os.File, laddr, raddr Addr) *netFD { - return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr} +func newFD(proto, name string, ctl, data *os.File, laddr, raddr Addr) *netFD { + return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, data, laddr, raddr} } func (fd *netFD) ok() bool { return fd != nil && fd.ctl != nil } func (fd *netFD) Read(b []byte) (n int, err error) { - if !fd.ok() { + if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } - if fd.data == nil { - fd.data, err = os.OpenFile(fd.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, err - } - } n, err = fd.data.Read(b) if fd.proto == "udp" && err == io.EOF { n = 0 @@ -54,15 +48,9 @@ func (fd *netFD) Read(b []byte) (n int, err error) { } func (fd *netFD) Write(b []byte) (n int, err error) { - if !fd.ok() { + if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } - if fd.data == nil { - fd.data, err = os.OpenFile(fd.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, err - } - } return fd.data.Write(b) } @@ -85,19 +73,39 @@ func (fd *netFD) Close() error { return syscall.EINVAL } err := fd.ctl.Close() - if err != nil { - return err - } if fd.data != nil { - err = fd.data.Close() + if err1 := fd.data.Close(); err1 != nil && err == nil { + err = err1 + } } fd.ctl = nil fd.data = nil return err } +// This method is only called via Conn. func (fd *netFD) dup() (*os.File, error) { - return nil, syscall.EPLAN9 + if !fd.ok() || fd.data == nil { + return nil, syscall.EINVAL + } + return fd.file(fd.data, fd.dir+"/data") +} + +func (l *TCPListener) dup() (*os.File, error) { + if !l.fd.ok() { + return nil, syscall.EINVAL + } + return l.fd.file(l.fd.ctl, l.fd.dir+"/ctl") +} + +func (fd *netFD) file(f *os.File, s string) (*os.File, error) { + syscall.ForkLock.RLock() + dfd, err := syscall.Dup(int(f.Fd()), -1) + syscall.ForkLock.RUnlock() + if err != nil { + return nil, &OpError{"dup", s, fd.laddr, err} + } + return os.NewFile(uintptr(dfd), s), nil } func setDeadline(fd *netFD, t time.Time) error { diff --git a/libgo/go/net/fd_poll_runtime.go b/libgo/go/net/fd_poll_runtime.go new file mode 100644 index 0000000..e3b4f7e --- /dev/null +++ b/libgo/go/net/fd_poll_runtime.go @@ -0,0 +1,119 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin linux + +package net + +import ( + "sync" + "syscall" + "time" +) + +func runtime_pollServerInit() +func runtime_pollOpen(fd int) (uintptr, int) +func runtime_pollClose(ctx uintptr) +func runtime_pollWait(ctx uintptr, mode int) int +func runtime_pollReset(ctx uintptr, mode int) int +func runtime_pollSetDeadline(ctx uintptr, d int64, mode int) +func runtime_pollUnblock(ctx uintptr) + +var canCancelIO = true // used for testing current package + +type pollDesc struct { + runtimeCtx uintptr +} + +var serverInit sync.Once + +func sysInit() { +} + +func (pd *pollDesc) Init(fd *netFD) error { + serverInit.Do(runtime_pollServerInit) + ctx, errno := runtime_pollOpen(fd.sysfd) + if errno != 0 { + return syscall.Errno(errno) + } + pd.runtimeCtx = ctx + return nil +} + +func (pd *pollDesc) Close() { + runtime_pollClose(pd.runtimeCtx) +} + +func (pd *pollDesc) Lock() { +} + +func (pd *pollDesc) Unlock() { +} + +func (pd *pollDesc) Wakeup() { +} + +// Evict evicts fd from the pending list, unblocking any I/O running on fd. +// Return value is whether the pollServer should be woken up. +func (pd *pollDesc) Evict() bool { + runtime_pollUnblock(pd.runtimeCtx) + return false +} + +func (pd *pollDesc) PrepareRead() error { + res := runtime_pollReset(pd.runtimeCtx, 'r') + return convertErr(res) +} + +func (pd *pollDesc) PrepareWrite() error { + res := runtime_pollReset(pd.runtimeCtx, 'w') + return convertErr(res) +} + +func (pd *pollDesc) WaitRead() error { + res := runtime_pollWait(pd.runtimeCtx, 'r') + return convertErr(res) +} + +func (pd *pollDesc) WaitWrite() error { + res := runtime_pollWait(pd.runtimeCtx, 'w') + return convertErr(res) +} + +func convertErr(res int) error { + switch res { + case 0: + return nil + case 1: + return errClosing + case 2: + return errTimeout + } + panic("unreachable") +} + +func setReadDeadline(fd *netFD, t time.Time) error { + return setDeadlineImpl(fd, t, 'r') +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + return setDeadlineImpl(fd, t, 'w') +} + +func setDeadline(fd *netFD, t time.Time) error { + return setDeadlineImpl(fd, t, 'r'+'w') +} + +func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { + d := t.UnixNano() + if t.IsZero() { + d = 0 + } + if err := fd.incref(false); err != nil { + return err + } + runtime_pollSetDeadline(fd.pd.runtimeCtx, d, mode) + fd.decref() + return nil +} diff --git a/libgo/go/net/fd_poll_unix.go b/libgo/go/net/fd_poll_unix.go new file mode 100644 index 0000000..307e577 --- /dev/null +++ b/libgo/go/net/fd_poll_unix.go @@ -0,0 +1,360 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build freebsd netbsd openbsd + +package net + +import ( + "os" + "runtime" + "sync" + "syscall" + "time" +) + +// A pollServer helps FDs determine when to retry a non-blocking +// read or write after they get EAGAIN. When an FD needs to wait, +// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server. +// When the pollServer finds that i/o on FD should be possible +// again, it will send on fd.cr/fd.cw to wake any waiting goroutines. +// +// To avoid races in closing, all fd operations are locked and +// refcounted. when netFD.Close() is called, it calls syscall.Shutdown +// and sets a closing flag. Only when the last reference is removed +// will the fd be closed. + +type pollServer struct { + pr, pw *os.File + poll *pollster // low-level OS hooks + sync.Mutex // controls pending and deadline + pending map[int]*pollDesc + deadline int64 // next deadline (nsec since 1970) +} + +// A pollDesc contains netFD state related to pollServer. +type pollDesc struct { + // immutable after Init() + pollServer *pollServer + sysfd int + cr, cw chan error + + // mutable, protected by pollServer mutex + closing bool + ncr, ncw int + + // mutable, safe for concurrent access + rdeadline, wdeadline deadline +} + +func newPollServer() (s *pollServer, err error) { + s = new(pollServer) + if s.pr, s.pw, err = os.Pipe(); err != nil { + return nil, err + } + if err = syscall.SetNonblock(int(s.pr.Fd()), true); err != nil { + goto Errno + } + if err = syscall.SetNonblock(int(s.pw.Fd()), true); err != nil { + goto Errno + } + if s.poll, err = newpollster(); err != nil { + goto Error + } + if _, err = s.poll.AddFD(int(s.pr.Fd()), 'r', true); err != nil { + s.poll.Close() + goto Error + } + s.pending = make(map[int]*pollDesc) + go s.Run() + return s, nil + +Errno: + err = &os.PathError{ + Op: "setnonblock", + Path: s.pr.Name(), + Err: err, + } +Error: + s.pr.Close() + s.pw.Close() + return nil, err +} + +func (s *pollServer) AddFD(pd *pollDesc, mode int) error { + s.Lock() + intfd := pd.sysfd + if intfd < 0 || pd.closing { + // fd closed underfoot + s.Unlock() + return errClosing + } + + var t int64 + key := intfd << 1 + if mode == 'r' { + pd.ncr++ + t = pd.rdeadline.value() + } else { + pd.ncw++ + key++ + t = pd.wdeadline.value() + } + s.pending[key] = pd + doWakeup := false + if t > 0 && (s.deadline == 0 || t < s.deadline) { + s.deadline = t + doWakeup = true + } + + wake, err := s.poll.AddFD(intfd, mode, false) + s.Unlock() + if err != nil { + return err + } + if wake || doWakeup { + s.Wakeup() + } + return nil +} + +// Evict evicts pd from the pending list, unblocking +// any I/O running on pd. The caller must have locked +// pollserver. +// Return value is whether the pollServer should be woken up. +func (s *pollServer) Evict(pd *pollDesc) bool { + pd.closing = true + doWakeup := false + if s.pending[pd.sysfd<<1] == pd { + s.WakeFD(pd, 'r', errClosing) + if s.poll.DelFD(pd.sysfd, 'r') { + doWakeup = true + } + delete(s.pending, pd.sysfd<<1) + } + if s.pending[pd.sysfd<<1|1] == pd { + s.WakeFD(pd, 'w', errClosing) + if s.poll.DelFD(pd.sysfd, 'w') { + doWakeup = true + } + delete(s.pending, pd.sysfd<<1|1) + } + return doWakeup +} + +var wakeupbuf [1]byte + +func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } + +func (s *pollServer) LookupFD(fd int, mode int) *pollDesc { + key := fd << 1 + if mode == 'w' { + key++ + } + netfd, ok := s.pending[key] + if !ok { + return nil + } + delete(s.pending, key) + return netfd +} + +func (s *pollServer) WakeFD(pd *pollDesc, mode int, err error) { + if mode == 'r' { + for pd.ncr > 0 { + pd.ncr-- + pd.cr <- err + } + } else { + for pd.ncw > 0 { + pd.ncw-- + pd.cw <- err + } + } +} + +func (s *pollServer) CheckDeadlines() { + now := time.Now().UnixNano() + // TODO(rsc): This will need to be handled more efficiently, + // probably with a heap indexed by wakeup time. + + var nextDeadline int64 + for key, pd := range s.pending { + var t int64 + var mode int + if key&1 == 0 { + mode = 'r' + } else { + mode = 'w' + } + if mode == 'r' { + t = pd.rdeadline.value() + } else { + t = pd.wdeadline.value() + } + if t > 0 { + if t <= now { + delete(s.pending, key) + s.poll.DelFD(pd.sysfd, mode) + s.WakeFD(pd, mode, errTimeout) + } else if nextDeadline == 0 || t < nextDeadline { + nextDeadline = t + } + } + } + s.deadline = nextDeadline +} + +func (s *pollServer) Run() { + var scratch [100]byte + s.Lock() + defer s.Unlock() + for { + var timeout int64 // nsec to wait for or 0 for none + if s.deadline > 0 { + timeout = s.deadline - time.Now().UnixNano() + if timeout <= 0 { + s.CheckDeadlines() + continue + } + } + fd, mode, err := s.poll.WaitFD(s, timeout) + if err != nil { + print("pollServer WaitFD: ", err.Error(), "\n") + return + } + if fd < 0 { + // Timeout happened. + s.CheckDeadlines() + continue + } + if fd == int(s.pr.Fd()) { + // Drain our wakeup pipe (we could loop here, + // but it's unlikely that there are more than + // len(scratch) wakeup calls). + s.pr.Read(scratch[0:]) + s.CheckDeadlines() + } else { + pd := s.LookupFD(fd, mode) + if pd == nil { + // This can happen because the WaitFD runs without + // holding s's lock, so there might be a pending wakeup + // for an fd that has been evicted. No harm done. + continue + } + s.WakeFD(pd, mode, nil) + } + } +} + +func (pd *pollDesc) Close() { +} + +func (pd *pollDesc) Lock() { + pd.pollServer.Lock() +} + +func (pd *pollDesc) Unlock() { + pd.pollServer.Unlock() +} + +func (pd *pollDesc) Wakeup() { + pd.pollServer.Wakeup() +} + +func (pd *pollDesc) PrepareRead() error { + if pd.rdeadline.expired() { + return errTimeout + } + return nil +} + +func (pd *pollDesc) PrepareWrite() error { + if pd.wdeadline.expired() { + return errTimeout + } + return nil +} + +func (pd *pollDesc) WaitRead() error { + err := pd.pollServer.AddFD(pd, 'r') + if err == nil { + err = <-pd.cr + } + return err +} + +func (pd *pollDesc) WaitWrite() error { + err := pd.pollServer.AddFD(pd, 'w') + if err == nil { + err = <-pd.cw + } + return err +} + +func (pd *pollDesc) Evict() bool { + return pd.pollServer.Evict(pd) +} + +// Spread network FDs over several pollServers. + +var pollMaxN int +var pollservers []*pollServer +var startServersOnce []func() + +var canCancelIO = true // used for testing current package + +func sysInit() { + pollMaxN = runtime.NumCPU() + if pollMaxN > 8 { + pollMaxN = 8 // No improvement then. + } + pollservers = make([]*pollServer, pollMaxN) + startServersOnce = make([]func(), pollMaxN) + for i := 0; i < pollMaxN; i++ { + k := i + once := new(sync.Once) + startServersOnce[i] = func() { once.Do(func() { startServer(k) }) } + } +} + +func startServer(k int) { + p, err := newPollServer() + if err != nil { + panic(err) + } + pollservers[k] = p +} + +func (pd *pollDesc) Init(fd *netFD) error { + pollN := runtime.GOMAXPROCS(0) + if pollN > pollMaxN { + pollN = pollMaxN + } + k := fd.sysfd % pollN + startServersOnce[k]() + pd.sysfd = fd.sysfd + pd.pollServer = pollservers[k] + pd.cr = make(chan error, 1) + pd.cw = make(chan error, 1) + return nil +} + +// TODO(dfc) these unused error returns could be removed + +func setReadDeadline(fd *netFD, t time.Time) error { + fd.pd.rdeadline.setTime(t) + return nil +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + fd.pd.wdeadline.setTime(t) + return nil +} + +func setDeadline(fd *netFD, t time.Time) error { + setReadDeadline(fd, t) + setWriteDeadline(fd, t) + return nil +} diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index 42b0c74..8c59bff 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -9,7 +9,6 @@ package net import ( "io" "os" - "runtime" "sync" "syscall" "time" @@ -21,7 +20,7 @@ type netFD struct { sysmu sync.Mutex sysref int - // must lock both sysmu and pollserver to write + // must lock both sysmu and pollDesc to write // can lock either to read closing bool @@ -31,8 +30,6 @@ type netFD struct { sotype int isConnected bool sysfile *os.File - cr chan error - cw chan error net string laddr Addr raddr Addr @@ -40,269 +37,16 @@ type netFD struct { // serialize access to Read and Write methods rio, wio sync.Mutex - // read and write deadlines - rdeadline, wdeadline deadline - - // owned by fd wait server - ncr, ncw int - // wait server - pollServer *pollServer -} - -// A pollServer helps FDs determine when to retry a non-blocking -// read or write after they get EAGAIN. When an FD needs to wait, -// call s.WaitRead() or s.WaitWrite() to pass the request to the poll server. -// When the pollServer finds that i/o on FD should be possible -// again, it will send on fd.cr/fd.cw to wake any waiting goroutines. -// -// To avoid races in closing, all fd operations are locked and -// refcounted. when netFD.Close() is called, it calls syscall.Shutdown -// and sets a closing flag. Only when the last reference is removed -// will the fd be closed. - -type pollServer struct { - pr, pw *os.File - poll *pollster // low-level OS hooks - sync.Mutex // controls pending and deadline - pending map[int]*netFD - deadline int64 // next deadline (nsec since 1970) -} - -func (s *pollServer) AddFD(fd *netFD, mode int) error { - s.Lock() - intfd := fd.sysfd - if intfd < 0 || fd.closing { - // fd closed underfoot - s.Unlock() - return errClosing - } - - var t int64 - key := intfd << 1 - if mode == 'r' { - fd.ncr++ - t = fd.rdeadline.value() - } else { - fd.ncw++ - key++ - t = fd.wdeadline.value() - } - s.pending[key] = fd - doWakeup := false - if t > 0 && (s.deadline == 0 || t < s.deadline) { - s.deadline = t - doWakeup = true - } - - wake, err := s.poll.AddFD(intfd, mode, false) - s.Unlock() - if err != nil { - return &OpError{"addfd", fd.net, fd.laddr, err} - } - if wake || doWakeup { - s.Wakeup() - } - return nil -} - -// Evict evicts fd from the pending list, unblocking -// any I/O running on fd. The caller must have locked -// pollserver. -func (s *pollServer) Evict(fd *netFD) { - doWakeup := false - if s.pending[fd.sysfd<<1] == fd { - s.WakeFD(fd, 'r', errClosing) - if s.poll.DelFD(fd.sysfd, 'r') { - doWakeup = true - } - delete(s.pending, fd.sysfd<<1) - } - if s.pending[fd.sysfd<<1|1] == fd { - s.WakeFD(fd, 'w', errClosing) - if s.poll.DelFD(fd.sysfd, 'w') { - doWakeup = true - } - delete(s.pending, fd.sysfd<<1|1) - } - if doWakeup { - s.Wakeup() - } -} - -var wakeupbuf [1]byte - -func (s *pollServer) Wakeup() { s.pw.Write(wakeupbuf[0:]) } - -func (s *pollServer) LookupFD(fd int, mode int) *netFD { - key := fd << 1 - if mode == 'w' { - key++ - } - netfd, ok := s.pending[key] - if !ok { - return nil - } - delete(s.pending, key) - return netfd -} - -func (s *pollServer) WakeFD(fd *netFD, mode int, err error) { - if mode == 'r' { - for fd.ncr > 0 { - fd.ncr-- - fd.cr <- err - } - } else { - for fd.ncw > 0 { - fd.ncw-- - fd.cw <- err - } - } -} - -func (s *pollServer) CheckDeadlines() { - now := time.Now().UnixNano() - // TODO(rsc): This will need to be handled more efficiently, - // probably with a heap indexed by wakeup time. - - var nextDeadline int64 - for key, fd := range s.pending { - var t int64 - var mode int - if key&1 == 0 { - mode = 'r' - } else { - mode = 'w' - } - if mode == 'r' { - t = fd.rdeadline.value() - } else { - t = fd.wdeadline.value() - } - if t > 0 { - if t <= now { - delete(s.pending, key) - if mode == 'r' { - s.poll.DelFD(fd.sysfd, mode) - } else { - s.poll.DelFD(fd.sysfd, mode) - } - s.WakeFD(fd, mode, errTimeout) - } else if nextDeadline == 0 || t < nextDeadline { - nextDeadline = t - } - } - } - s.deadline = nextDeadline -} - -func (s *pollServer) Run() { - var scratch [100]byte - s.Lock() - defer s.Unlock() - for { - var timeout int64 // nsec to wait for or 0 for none - if s.deadline > 0 { - timeout = s.deadline - time.Now().UnixNano() - if timeout <= 0 { - s.CheckDeadlines() - continue - } - } - fd, mode, err := s.poll.WaitFD(s, timeout) - if err != nil { - print("pollServer WaitFD: ", err.Error(), "\n") - return - } - if fd < 0 { - // Timeout happened. - s.CheckDeadlines() - continue - } - if fd == int(s.pr.Fd()) { - // Drain our wakeup pipe (we could loop here, - // but it's unlikely that there are more than - // len(scratch) wakeup calls). - s.pr.Read(scratch[0:]) - s.CheckDeadlines() - } else { - netfd := s.LookupFD(fd, mode) - if netfd == nil { - // This can happen because the WaitFD runs without - // holding s's lock, so there might be a pending wakeup - // for an fd that has been evicted. No harm done. - continue - } - s.WakeFD(netfd, mode, nil) - } - } -} - -func (s *pollServer) WaitRead(fd *netFD) error { - err := s.AddFD(fd, 'r') - if err == nil { - err = <-fd.cr - } - return err -} - -func (s *pollServer) WaitWrite(fd *netFD) error { - err := s.AddFD(fd, 'w') - if err == nil { - err = <-fd.cw - } - return err -} - -// Network FD methods. -// Spread network FDs over several pollServers. - -var pollMaxN int -var pollservers []*pollServer -var startServersOnce []func() - -var canCancelIO = true // used for testing current package - -func sysInit() { - pollMaxN = runtime.NumCPU() - if pollMaxN > 8 { - pollMaxN = 8 // No improvement then. - } - pollservers = make([]*pollServer, pollMaxN) - startServersOnce = make([]func(), pollMaxN) - for i := 0; i < pollMaxN; i++ { - k := i - once := new(sync.Once) - startServersOnce[i] = func() { once.Do(func() { startServer(k) }) } - } + pd pollDesc } -func startServer(k int) { - p, err := newPollServer() - if err != nil { - panic(err) - } - pollservers[k] = p -} - -func server(fd int) *pollServer { - pollN := runtime.GOMAXPROCS(0) - if pollN > pollMaxN { - pollN = pollMaxN - } - k := fd % pollN - startServersOnce[k]() - return pollservers[k] -} - -func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { - deadline := time.Now().Add(timeout) - _, addri, err := resolveNetAddr("dial", net, addr, deadline) +func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { + ra, err := resolveAddr("dial", net, addr, deadline) if err != nil { return nil, err } - return dialAddr(net, addr, addri, deadline) + return dial(net, addr, localAddr, ra, deadline) } func newFD(fd, family, sotype int, net string) (*netFD, error) { @@ -312,9 +56,9 @@ func newFD(fd, family, sotype int, net string) (*netFD, error) { sotype: sotype, net: net, } - netfd.cr = make(chan error, 1) - netfd.cw = make(chan error, 1) - netfd.pollServer = server(fd) + if err := netfd.pd.Init(netfd); err != nil { + return nil, err + } return netfd, nil } @@ -335,26 +79,29 @@ func (fd *netFD) name() string { return fd.net + ":" + ls + "->" + rs } -func (fd *netFD) connect(ra syscall.Sockaddr) error { - err := syscall.Connect(fd.sysfd, ra) - if err == syscall.EINPROGRESS { - if err = fd.pollServer.WaitWrite(fd); err != nil { - return err +func (fd *netFD) connect(la, ra syscall.Sockaddr) error { + fd.wio.Lock() + defer fd.wio.Unlock() + if err := fd.pd.PrepareWrite(); err != nil { + return err + } + for { + err := syscall.Connect(fd.sysfd, ra) + if err == nil || err == syscall.EISCONN { + break } - var e int - e, err = syscall.GetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) - if err != nil { - return os.NewSyscallError("getsockopt", err) + if err != syscall.EINPROGRESS && err != syscall.EALREADY && err != syscall.EINTR { + return err } - if e != 0 { - err = syscall.Errno(e) + if err = fd.pd.WaitWrite(); err != nil { + return err } } - return err + return nil } // Add a reference to this fd. -// If closing==true, pollserver must be locked; mark the fd as closing. +// If closing==true, pollDesc must be locked; mark the fd as closing. // Returns an error if the fd cannot be used. func (fd *netFD) incref(closing bool) error { fd.sysmu.Lock() @@ -375,28 +122,38 @@ func (fd *netFD) incref(closing bool) error { func (fd *netFD) decref() { fd.sysmu.Lock() fd.sysref-- - if fd.closing && fd.sysref == 0 && fd.sysfile != nil { - fd.sysfile.Close() - fd.sysfile = nil + if fd.closing && fd.sysref == 0 { + // Poller may want to unregister fd in readiness notification mechanism, + // so this must be executed before sysfile.Close(). + fd.pd.Close() + if fd.sysfile != nil { + fd.sysfile.Close() + fd.sysfile = nil + } else { + closesocket(fd.sysfd) + } fd.sysfd = -1 } fd.sysmu.Unlock() } func (fd *netFD) Close() error { - fd.pollServer.Lock() // needed for both fd.incref(true) and pollserver.Evict + fd.pd.Lock() // needed for both fd.incref(true) and pollDesc.Evict if err := fd.incref(true); err != nil { - fd.pollServer.Unlock() + fd.pd.Unlock() return err } // Unblock any I/O. Once it all unblocks and returns, // so that it cannot be referring to fd.sysfd anymore, // the final decref will close fd.sysfd. This should happen // fairly quickly, since all the I/O is non-blocking, and any - // attempts to block in the pollserver will return errClosing. - fd.pollServer.Evict(fd) - fd.pollServer.Unlock() + // attempts to block in the pollDesc will return errClosing. + doWakeup := fd.pd.Evict() + fd.pd.Unlock() fd.decref() + if doWakeup { + fd.pd.Wakeup() + } return nil } @@ -427,16 +184,15 @@ func (fd *netFD) Read(p []byte) (n int, err error) { return 0, err } defer fd.decref() + if err := fd.pd.PrepareRead(); err != nil { + return 0, &OpError{"read", fd.net, fd.raddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, err = syscall.Read(int(fd.sysfd), p) if err != nil { n = 0 if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } @@ -457,16 +213,15 @@ func (fd *netFD) ReadFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { return 0, nil, err } defer fd.decref() + if err := fd.pd.PrepareRead(); err != nil { + return 0, nil, &OpError{"read", fd.net, fd.laddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) if err != nil { n = 0 if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } @@ -487,16 +242,15 @@ func (fd *netFD) ReadMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.S return 0, 0, 0, nil, err } defer fd.decref() + if err := fd.pd.PrepareRead(); err != nil { + return 0, 0, 0, nil, &OpError{"read", fd.net, fd.laddr, err} + } for { - if fd.rdeadline.expired() { - err = errTimeout - break - } n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) if err != nil { // TODO(dfc) should n and oobn be set to 0 if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } @@ -524,11 +278,10 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { return 0, err } defer fd.decref() + if err := fd.pd.PrepareWrite(); err != nil { + return 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } var n int n, err = syscall.Write(int(fd.sysfd), p[nn:]) if n > 0 { @@ -538,7 +291,7 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { break } if err == syscall.EAGAIN { - if err = fd.pollServer.WaitWrite(fd); err == nil { + if err = fd.pd.WaitWrite(); err == nil { continue } } @@ -564,14 +317,13 @@ func (fd *netFD) WriteTo(p []byte, sa syscall.Sockaddr) (n int, err error) { return 0, err } defer fd.decref() + if err := fd.pd.PrepareWrite(); err != nil { + return 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } err = syscall.Sendto(fd.sysfd, p, 0, sa) if err == syscall.EAGAIN { - if err = fd.pollServer.WaitWrite(fd); err == nil { + if err = fd.pd.WaitWrite(); err == nil { continue } } @@ -592,14 +344,13 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob return 0, 0, err } defer fd.decref() + if err := fd.pd.PrepareWrite(); err != nil { + return 0, 0, &OpError{"write", fd.net, fd.raddr, err} + } for { - if fd.wdeadline.expired() { - err = errTimeout - break - } err = syscall.Sendmsg(fd.sysfd, p, oob, sa, 0) if err == syscall.EAGAIN { - if err = fd.pollServer.WaitWrite(fd); err == nil { + if err = fd.pd.WaitWrite(); err == nil { continue } } @@ -615,6 +366,8 @@ func (fd *netFD) WriteMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob } func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err error) { + fd.rio.Lock() + defer fd.rio.Unlock() if err := fd.incref(false); err != nil { return nil, err } @@ -622,11 +375,14 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e var s int var rsa syscall.Sockaddr + if err = fd.pd.PrepareRead(); err != nil { + return nil, &OpError{"accept", fd.net, fd.laddr, err} + } for { s, rsa, err = accept(fd.sysfd) if err != nil { if err == syscall.EAGAIN { - if err = fd.pollServer.WaitRead(fd); err == nil { + if err = fd.pd.WaitRead(); err == nil { continue } } else if err == syscall.ECONNABORTED { @@ -659,6 +415,9 @@ func (fd *netFD) dup() (f *os.File, err error) { syscall.ForkLock.RUnlock() // We want blocking mode for the new fd, hence the double negative. + // This also puts the old fd into blocking mode, meaning that + // I/O will block the thread instead of letting us use the epoll server. + // Everything will still work, just with more threads. if err = syscall.SetNonblock(ns, false); err != nil { return nil, &OpError{"setnonblock", fd.net, fd.laddr, err} } diff --git a/libgo/go/net/fd_unix_test.go b/libgo/go/net/fd_unix_test.go index fd1385e..664ef1b 100644 --- a/libgo/go/net/fd_unix_test.go +++ b/libgo/go/net/fd_unix_test.go @@ -12,54 +12,6 @@ import ( "testing" ) -// Issue 3590. netFd.AddFD should return an error -// from the underlying pollster rather than panicing. -func TestAddFDReturnsError(t *testing.T) { - ln := newLocalListener(t).(*TCPListener) - defer ln.Close() - connected := make(chan bool) - go func() { - for { - c, err := ln.Accept() - if err != nil { - return - } - connected <- true - defer c.Close() - } - }() - - c, err := DialTCP("tcp", nil, ln.Addr().(*TCPAddr)) - if err != nil { - t.Fatal(err) - } - defer c.Close() - <-connected - - // replace c's pollServer with a closed version. - ps, err := newPollServer() - if err != nil { - t.Fatal(err) - } - ps.poll.Close() - c.conn.fd.pollServer = ps - - var b [1]byte - _, err = c.Read(b[:]) - if err, ok := err.(*OpError); ok { - if err.Op == "addfd" { - return - } - if err, ok := err.Err.(*OpError); ok { - // the err is sometimes wrapped by another OpError - if err.Op == "addfd" { - return - } - } - } - t.Error("unexpected error:", err) -} - var chkReadErrTests = []struct { n int err error diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go index ea6ef10..fefd174 100644 --- a/libgo/go/net/fd_windows.go +++ b/libgo/go/net/fd_windows.go @@ -37,6 +37,7 @@ func sysInit() { } canCancelIO = syscall.LoadCancelIoEx() == nil if syscall.LoadGetAddrInfo() == nil { + lookupPort = newLookupPort lookupIP = newLookupIP } } @@ -53,18 +54,17 @@ func canUseConnectEx(net string) bool { return syscall.LoadConnectEx() == nil } -func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) { +func resolveAndDial(net, addr string, localAddr Addr, deadline time.Time) (Conn, error) { if !canUseConnectEx(net) { // Use the relatively inefficient goroutine-racing // implementation of DialTimeout. - return dialTimeoutRace(net, addr, timeout) + return resolveAndDialChannel(net, addr, localAddr, deadline) } - deadline := time.Now().Add(timeout) - _, addri, err := resolveNetAddr("dial", net, addr, deadline) + ra, err := resolveAddr("dial", net, addr, deadline) if err != nil { return nil, err } - return dialAddr(net, addr, addri, deadline) + return dial(net, addr, localAddr, ra, deadline) } // Interface for all IO operations. @@ -137,12 +137,18 @@ type resultSrv struct { iocp syscall.Handle } +func runtime_blockingSyscallHint() + func (s *resultSrv) Run() { var o *syscall.Overlapped var key uint32 var r ioResult for { - r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) + r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, 0) + if r.err == syscall.Errno(syscall.WAIT_TIMEOUT) && o == nil { + runtime_blockingSyscallHint() + r.err = syscall.GetQueuedCompletionStatus(s.iocp, &(r.qty), &key, &o, syscall.INFINITE) + } switch { case r.err == nil: // Dequeued successfully completed IO packet. @@ -358,22 +364,23 @@ func (o *connectOp) Name() string { return "ConnectEx" } -func (fd *netFD) connect(ra syscall.Sockaddr) error { +func (fd *netFD) connect(la, ra syscall.Sockaddr) error { if !canUseConnectEx(fd.net) { return syscall.Connect(fd.sysfd, ra) } // ConnectEx windows API requires an unconnected, previously bound socket. - var la syscall.Sockaddr - switch ra.(type) { - case *syscall.SockaddrInet4: - la = &syscall.SockaddrInet4{} - case *syscall.SockaddrInet6: - la = &syscall.SockaddrInet6{} - default: - panic("unexpected type in connect") - } - if err := syscall.Bind(fd.sysfd, la); err != nil { - return err + if la == nil { + switch ra.(type) { + case *syscall.SockaddrInet4: + la = &syscall.SockaddrInet4{} + case *syscall.SockaddrInet6: + la = &syscall.SockaddrInet6{} + default: + panic("unexpected type in connect") + } + if err := syscall.Bind(fd.sysfd, la); err != nil { + return err + } } // Call ConnectEx API. var o connectOp @@ -618,15 +625,10 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (*netFD, error) { defer fd.decref() // Get new socket. - // See ../syscall/exec_unix.go for description of ForkLock. - syscall.ForkLock.RLock() - s, err := syscall.Socket(fd.family, fd.sotype, 0) + s, err := sysSocket(fd.family, fd.sotype, 0) if err != nil { - syscall.ForkLock.RUnlock() return nil, &OpError{"socket", fd.net, fd.laddr, err} } - syscall.CloseOnExec(s) - syscall.ForkLock.RUnlock() // Associate our new socket with IOCP. onceStartServer.Do(startServer) diff --git a/libgo/go/net/file_plan9.go b/libgo/go/net/file_plan9.go index ae3ac15..f6ee1c2 100644 --- a/libgo/go/net/file_plan9.go +++ b/libgo/go/net/file_plan9.go @@ -5,16 +5,139 @@ package net import ( + "errors" + "io" "os" "syscall" ) +func (fd *netFD) status(ln int) (string, error) { + if !fd.ok() { + return "", syscall.EINVAL + } + + status, err := os.Open(fd.dir + "/status") + if err != nil { + return "", err + } + defer status.Close() + buf := make([]byte, ln) + n, err := io.ReadFull(status, buf[:]) + if err != nil { + return "", err + } + return string(buf[:n]), nil +} + +func newFileFD(f *os.File) (net *netFD, err error) { + var ctl *os.File + close := func(fd int) { + if err != nil { + syscall.Close(fd) + } + } + + path, err := syscall.Fd2path(int(f.Fd())) + if err != nil { + return nil, os.NewSyscallError("fd2path", err) + } + comp := splitAtBytes(path, "/") + n := len(comp) + if n < 3 || comp[0] != "net" { + return nil, syscall.EPLAN9 + } + + name := comp[2] + switch file := comp[n-1]; file { + case "ctl", "clone": + syscall.ForkLock.RLock() + fd, err := syscall.Dup(int(f.Fd()), -1) + syscall.ForkLock.RUnlock() + if err != nil { + return nil, os.NewSyscallError("dup", err) + } + defer close(fd) + + dir := "/net/" + comp[n-2] + ctl = os.NewFile(uintptr(fd), dir+"/"+file) + ctl.Seek(0, 0) + var buf [16]byte + n, err := ctl.Read(buf[:]) + if err != nil { + return nil, err + } + name = string(buf[:n]) + default: + if len(comp) < 4 { + return nil, errors.New("could not find control file for connection") + } + dir := "/net/" + comp[1] + "/" + name + ctl, err = os.OpenFile(dir+"/ctl", os.O_RDWR, 0) + if err != nil { + return nil, err + } + defer close(int(ctl.Fd())) + } + dir := "/net/" + comp[1] + "/" + name + laddr, err := readPlan9Addr(comp[1], dir+"/local") + if err != nil { + return nil, err + } + return newFD(comp[1], name, ctl, nil, laddr, nil), nil +} + +func newFileConn(f *os.File) (c Conn, err error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + if !fd.ok() { + return nil, syscall.EINVAL + } + + fd.data, err = os.OpenFile(fd.dir+"/data", os.O_RDWR, 0) + if err != nil { + return nil, err + } + + switch fd.laddr.(type) { + case *TCPAddr: + return newTCPConn(fd), nil + case *UDPAddr: + return newUDPConn(fd), nil + } + return nil, syscall.EPLAN9 +} + +func newFileListener(f *os.File) (l Listener, err error) { + fd, err := newFileFD(f) + if err != nil { + return nil, err + } + switch fd.laddr.(type) { + case *TCPAddr: + default: + return nil, syscall.EPLAN9 + } + + // check that file corresponds to a listener + s, err := fd.status(len("Listen")) + if err != nil { + return nil, err + } + if s != "Listen" { + return nil, errors.New("file does not represent a listener") + } + + return &TCPListener{fd}, nil +} + // FileConn returns a copy of the network connection corresponding to // the open file f. It is the caller's responsibility to close f when // finished. Closing c does not affect f, and closing f does not // affect c. func FileConn(f *os.File) (c Conn, err error) { - return nil, syscall.EPLAN9 + return newFileConn(f) } // FileListener returns a copy of the network listener corresponding @@ -22,7 +145,7 @@ func FileConn(f *os.File) (c Conn, err error) { // when finished. Closing l does not affect f, and closing f does not // affect l. func FileListener(f *os.File) (l Listener, err error) { - return nil, syscall.EPLAN9 + return newFileListener(f) } // FilePacketConn returns a copy of the packet network connection diff --git a/libgo/go/net/file_test.go b/libgo/go/net/file_test.go index 78c6222..acaf188 100644 --- a/libgo/go/net/file_test.go +++ b/libgo/go/net/file_test.go @@ -89,7 +89,7 @@ var fileListenerTests = []struct { func TestFileListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "windows": t.Skipf("skipping test on %q", runtime.GOOS) } diff --git a/libgo/go/net/file_windows.go b/libgo/go/net/file_windows.go index c50c32e..ca2b9b2 100644 --- a/libgo/go/net/file_windows.go +++ b/libgo/go/net/file_windows.go @@ -9,16 +9,28 @@ import ( "syscall" ) +// FileConn returns a copy of the network connection corresponding to +// the open file f. It is the caller's responsibility to close f when +// finished. Closing c does not affect f, and closing f does not +// affect c. func FileConn(f *os.File) (c Conn, err error) { // TODO: Implement this return nil, os.NewSyscallError("FileConn", syscall.EWINDOWS) } +// FileListener returns a copy of the network listener corresponding +// to the open file f. It is the caller's responsibility to close l +// when finished. Closing l does not affect f, and closing f does not +// affect l. func FileListener(f *os.File) (l Listener, err error) { // TODO: Implement this return nil, os.NewSyscallError("FileListener", syscall.EWINDOWS) } +// FilePacketConn returns a copy of the packet network connection +// corresponding to the open file f. It is the caller's +// responsibility to close f when finished. Closing c does not affect +// f, and closing f does not affect c. func FilePacketConn(f *os.File) (c PacketConn, err error) { // TODO: Implement this return nil, os.NewSyscallError("FilePacketConn", syscall.EWINDOWS) diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index cb6f1df..b514e10 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -19,7 +19,6 @@ import ( "runtime" "strconv" "strings" - "syscall" "testing" "time" ) @@ -340,11 +339,7 @@ func TestCopyError(t *testing.T) { } childRunning := func() bool { - p, err := os.FindProcess(pid) - if err != nil { - return false - } - return p.Signal(syscall.Signal(0)) == nil + return isProcessRunning(t, pid) } if !childRunning() { diff --git a/libgo/go/net/http/cgi/posix_test.go b/libgo/go/net/http/cgi/posix_test.go new file mode 100644 index 0000000..5ff9e7d --- /dev/null +++ b/libgo/go/net/http/cgi/posix_test.go @@ -0,0 +1,21 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package cgi + +import ( + "os" + "syscall" + "testing" +) + +func isProcessRunning(t *testing.T, pid int) bool { + p, err := os.FindProcess(pid) + if err != nil { + return false + } + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/libgo/go/net/http/cgi/testdata/test.cgi b/libgo/go/net/http/cgi/testdata/test.cgi index 1b25bc2..3214df6 100644 --- a/libgo/go/net/http/cgi/testdata/test.cgi +++ b/libgo/go/net/http/cgi/testdata/test.cgi @@ -24,7 +24,8 @@ print "X-Test-Header: X-Test-Value\r\n"; print "\r\n"; if ($params->{"bigresponse"}) { - for (1..1024) { + # 17 MB, for OS X: golang.org/issue/4958 + for (1..(17 * 1024)) { print "A" x 1024, "\r\n"; } exit 0; diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 5ee0804..a34d47b 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -19,12 +19,16 @@ import ( "strings" ) -// A Client is an HTTP client. Its zero value (DefaultClient) is a usable client -// that uses DefaultTransport. +// A Client is an HTTP client. Its zero value (DefaultClient) is a +// usable client that uses DefaultTransport. // -// The Client's Transport typically has internal state (cached -// TCP connections), so Clients should be reused instead of created as +// The Client's Transport typically has internal state (cached TCP +// connections), so Clients should be reused instead of created as // needed. Clients are safe for concurrent use by multiple goroutines. +// +// A Client is higher-level than a RoundTripper (such as Transport) +// and additionally handles HTTP details such as cookies and +// redirects. type Client struct { // Transport specifies the mechanism by which individual // HTTP requests are made. diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 9514a4b..73f1fe3 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -51,10 +51,10 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { return b, err } } - panic("unreachable") } func TestClient(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -72,6 +72,7 @@ func TestClient(t *testing.T) { } func TestClientHead(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() @@ -94,6 +95,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} url := "http://dummy.faketld/" @@ -110,6 +112,7 @@ func TestGetRequestFormat(t *testing.T) { } func TestPostRequestFormat(t *testing.T) { + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -136,6 +139,7 @@ func TestPostRequestFormat(t *testing.T) { } func TestPostFormRequestFormat(t *testing.T) { + defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -176,7 +180,8 @@ func TestPostFormRequestFormat(t *testing.T) { } } -func TestRedirects(t *testing.T) { +func TestClientRedirects(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) @@ -223,6 +228,7 @@ func TestRedirects(t *testing.T) { if err != nil { t.Fatalf("Get error: %v", err) } + res.Body.Close() finalUrl := res.Request.URL.String() if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { t.Errorf("with custom client, expected error %q, got %q", e, g) @@ -242,12 +248,14 @@ func TestRedirects(t *testing.T) { if res == nil { t.Fatalf("Expected a non-nil Response on CheckRedirect failure (http://golang.org/issue/3795)") } + res.Body.Close() if res.Header.Get("Location") == "" { t.Errorf("no Location header in Response") } } func TestPostRedirects(t *testing.T) { + defer afterTest(t) var log struct { sync.Mutex bytes.Buffer @@ -265,6 +273,7 @@ func TestPostRedirects(t *testing.T) { w.WriteHeader(code) } })) + defer ts.Close() tests := []struct { suffix string want int // response code @@ -364,6 +373,7 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesOnRequest(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -381,6 +391,7 @@ func TestRedirectCookiesOnRequest(t *testing.T) { } func TestRedirectCookiesJar(t *testing.T) { + defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() @@ -393,6 +404,7 @@ func TestRedirectCookiesJar(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } + resp.Body.Close() matchReturnedCookies(t, expectedCookies, resp.Cookies()) } @@ -416,6 +428,7 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } func TestJarCalls(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { pathSuffix := r.RequestURI[1:] if r.RequestURI == "/nosetcookie" { @@ -479,6 +492,7 @@ func (j *RecordingJar) logf(format string, args ...interface{}) { } func TestStreamingGet(t *testing.T) { + defer afterTest(t) say := make(chan string) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() @@ -529,6 +543,7 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. func TestClientWrites(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -562,6 +577,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) })) @@ -576,15 +592,20 @@ func TestClientInsecureTransport(t *testing.T) { InsecureSkipVerify: insecure, }, } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - _, err := c.Get(ts.URL) + res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) } + if res != nil { + res.Body.Close() + } } } func TestClientErrorWithRequestURI(t *testing.T) { + defer afterTest(t) req, _ := NewRequest("GET", "http://localhost:1234/", nil) req.RequestURI = "/this/field/is/illegal/and/should/error/" _, err := DefaultClient.Do(req) @@ -613,6 +634,7 @@ func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { } func TestClientWithCorrectTLSServerName(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS.ServerName != "127.0.0.1" { t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) @@ -627,6 +649,7 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { } func TestClientWithIncorrectTLSServerName(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() @@ -644,6 +667,7 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { // Verify Response.ContentLength is populated. http://golang.org/issue/4126 func TestClientHeadContentLength(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go new file mode 100644 index 0000000..5977d48 --- /dev/null +++ b/libgo/go/net/http/cookiejar/jar.go @@ -0,0 +1,497 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package cookiejar implements an in-memory RFC 6265-compliant http.CookieJar. +package cookiejar + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" +) + +// PublicSuffixList provides the public suffix of a domain. For example: +// - the public suffix of "example.com" is "com", +// - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and +// - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us". +// +// Implementations of PublicSuffixList must be safe for concurrent use by +// multiple goroutines. +// +// An implementation that always returns "" is valid and may be useful for +// testing but it is not secure: it means that the HTTP server for foo.com can +// set a cookie for bar.com. +// +// A public suffix list implementation is in the package +// code.google.com/p/go.net/publicsuffix. +type PublicSuffixList interface { + // PublicSuffix returns the public suffix of domain. + // + // TODO: specify which of the caller and callee is responsible for IP + // addresses, for leading and trailing dots, for case sensitivity, and + // for IDN/Punycode. + PublicSuffix(domain string) string + + // String returns a description of the source of this public suffix + // list. The description will typically contain something like a time + // stamp or version number. + String() string +} + +// Options are the options for creating a new Jar. +type Options struct { + // PublicSuffixList is the public suffix list that determines whether + // an HTTP server can set a cookie for a domain. + // + // A nil value is valid and may be useful for testing but it is not + // secure: it means that the HTTP server for foo.co.uk can set a cookie + // for bar.co.uk. + PublicSuffixList PublicSuffixList +} + +// Jar implements the http.CookieJar interface from the net/http package. +type Jar struct { + psList PublicSuffixList + + // mu locks the remaining fields. + mu sync.Mutex + + // entries is a set of entries, keyed by their eTLD+1 and subkeyed by + // their name/domain/path. + entries map[string]map[string]entry + + // nextSeqNum is the next sequence number assigned to a new cookie + // created SetCookies. + nextSeqNum uint64 +} + +// New returns a new cookie jar. A nil *Options is equivalent to a zero +// Options. +func New(o *Options) (*Jar, error) { + jar := &Jar{ + entries: make(map[string]map[string]entry), + } + if o != nil { + jar.psList = o.PublicSuffixList + } + return jar, nil +} + +// entry is the internal representation of a cookie. +// +// This struct type is not used outside of this package per se, but the exported +// fields are those of RFC 6265. +type entry struct { + Name string + Value string + Domain string + Path string + Secure bool + HttpOnly bool + Persistent bool + HostOnly bool + Expires time.Time + Creation time.Time + LastAccess time.Time + + // seqNum is a sequence number so that Cookies returns cookies in a + // deterministic order, even for cookies that have equal Path length and + // equal Creation time. This simplifies testing. + seqNum uint64 +} + +// Id returns the domain;path;name triple of e as an id. +func (e *entry) id() string { + return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) +} + +// shouldSend determines whether e's cookie qualifies to be included in a +// request to host/path. It is the caller's responsibility to check if the +// cookie is expired. +func (e *entry) shouldSend(https bool, host, path string) bool { + return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure) +} + +// domainMatch implements "domain-match" of RFC 6265 section 5.1.3. +func (e *entry) domainMatch(host string) bool { + if e.Domain == host { + return true + } + return !e.HostOnly && hasDotSuffix(host, e.Domain) +} + +// pathMatch implements "path-match" according to RFC 6265 section 5.1.4. +func (e *entry) pathMatch(requestPath string) bool { + if requestPath == e.Path { + return true + } + if strings.HasPrefix(requestPath, e.Path) { + if e.Path[len(e.Path)-1] == '/' { + return true // The "/any/" matches "/any/path" case. + } else if requestPath[len(e.Path)] == '/' { + return true // The "/any" matches "/any/path" case. + } + } + return false +} + +// hasDotSuffix returns whether s ends in "."+suffix. +func hasDotSuffix(s, suffix string) bool { + return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix +} + +// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 +// section 5.4 point 2: by longest path and then by earliest creation time. +type byPathLength []entry + +func (s byPathLength) Len() int { return len(s) } + +func (s byPathLength) Less(i, j int) bool { + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum +} + +func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// Cookies implements the Cookies method of the http.CookieJar interface. +// +// It returns an empty slice if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) { + return j.cookies(u, time.Now()) +} + +// cookies is like Cookies but takes the current time as a parameter. +func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { + if u.Scheme != "http" && u.Scheme != "https" { + return cookies + } + host, err := canonicalHost(u.Host) + if err != nil { + return cookies + } + key := jarKey(host, j.psList) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + if submap == nil { + return cookies + } + + https := u.Scheme == "https" + path := u.Path + if path == "" { + path = "/" + } + + modified := false + var selected []entry + for id, e := range submap { + if e.Persistent && !e.Expires.After(now) { + delete(submap, id) + modified = true + continue + } + if !e.shouldSend(https, host, path) { + continue + } + e.LastAccess = now + submap[id] = e + selected = append(selected, e) + modified = true + } + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } + + sort.Sort(byPathLength(selected)) + for _, e := range selected { + cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) + } + + return cookies +} + +// SetCookies implements the SetCookies method of the http.CookieJar interface. +// +// It does nothing if the URL's scheme is not HTTP or HTTPS. +func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.setCookies(u, cookies, time.Now()) +} + +// setCookies is like SetCookies but takes the current time as parameter. +func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) { + if len(cookies) == 0 { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + host, err := canonicalHost(u.Host) + if err != nil { + return + } + key := jarKey(host, j.psList) + defPath := defaultPath(u.Path) + + j.mu.Lock() + defer j.mu.Unlock() + + submap := j.entries[key] + + modified := false + for _, cookie := range cookies { + e, remove, err := j.newEntry(cookie, now, defPath, host) + if err != nil { + continue + } + id := e.id() + if remove { + if submap != nil { + if _, ok := submap[id]; ok { + delete(submap, id) + modified = true + } + } + continue + } + if submap == nil { + submap = make(map[string]entry) + } + + if old, ok := submap[id]; ok { + e.Creation = old.Creation + e.seqNum = old.seqNum + } else { + e.Creation = now + e.seqNum = j.nextSeqNum + j.nextSeqNum++ + } + e.LastAccess = now + submap[id] = e + modified = true + } + + if modified { + if len(submap) == 0 { + delete(j.entries, key) + } else { + j.entries[key] = submap + } + } +} + +// canonicalHost strips port from host if present and returns the canonicalized +// host name. +func canonicalHost(host string) (string, error) { + var err error + host = strings.ToLower(host) + if hasPort(host) { + host, _, err = net.SplitHostPort(host) + if err != nil { + return "", err + } + } + if strings.HasSuffix(host, ".") { + // Strip trailing dot from fully qualified domain names. + host = host[:len(host)-1] + } + return toASCII(host) +} + +// hasPort returns whether host contains a port number. host may be a host +// name, an IPv4 or an IPv6 address. +func hasPort(host string) bool { + colons := strings.Count(host, ":") + if colons == 0 { + return false + } + if colons == 1 { + return true + } + return host[0] == '[' && strings.Contains(host, "]:") +} + +// jarKey returns the key to use for a jar. +func jarKey(host string, psl PublicSuffixList) string { + if isIP(host) { + return host + } + + var i int + if psl == nil { + i = strings.LastIndex(host, ".") + if i == -1 { + return host + } + } else { + suffix := psl.PublicSuffix(host) + if suffix == host { + return host + } + i = len(host) - len(suffix) + if i <= 0 || host[i-1] != '.' { + // The provided public suffix list psl is broken. + // Storing cookies under host is a safe stopgap. + return host + } + } + prevDot := strings.LastIndex(host[:i-1], ".") + return host[prevDot+1:] +} + +// isIP returns whether host is an IP address. +func isIP(host string) bool { + return net.ParseIP(host) != nil +} + +// defaultPath returns the directory part of an URL's path according to +// RFC 6265 section 5.1.4. +func defaultPath(path string) string { + if len(path) == 0 || path[0] != '/' { + return "/" // Path is empty or malformed. + } + + i := strings.LastIndex(path, "/") // Path starts with "/", so i != -1. + if i == 0 { + return "/" // Path has the form "/abc". + } + return path[:i] // Path is either of form "/abc/xyz" or "/abc/xyz/". +} + +// newEntry creates an entry from a http.Cookie c. now is the current time and +// is compared to c.Expires to determine deletion of c. defPath and host are the +// default-path and the canonical host name of the URL c was received from. +// +// remove is whether the jar should delete this cookie, as it has already +// expired with respect to now. In this case, e may be incomplete, but it will +// be valid to call e.id (which depends on e's Name, Domain and Path). +// +// A malformed c.Domain will result in an error. +func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) { + e.Name = c.Name + + if c.Path == "" || c.Path[0] != '/' { + e.Path = defPath + } else { + e.Path = c.Path + } + + e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain) + if err != nil { + return e, false, err + } + + // MaxAge takes precedence over Expires. + if c.MaxAge < 0 { + return e, true, nil + } else if c.MaxAge > 0 { + e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second) + e.Persistent = true + } else { + if c.Expires.IsZero() { + e.Expires = endOfTime + e.Persistent = false + } else { + if !c.Expires.After(now) { + return e, true, nil + } + e.Expires = c.Expires + e.Persistent = true + } + } + + e.Value = c.Value + e.Secure = c.Secure + e.HttpOnly = c.HttpOnly + + return e, false, nil +} + +var ( + errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute") + errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute") + errNoHostname = errors.New("cookiejar: no host name available (IP only)") +) + +// endOfTime is the time when session (non-persistent) cookies expire. +// This instant is representable in most date/time formats (not just +// Go's time.Time) and should be far enough in the future. +var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) + +// domainAndType determines the cookie's domain and hostOnly attribute. +func (j *Jar) domainAndType(host, domain string) (string, bool, error) { + if domain == "" { + // No domain attribute in the SetCookie header indicates a + // host cookie. + return host, true, nil + } + + if isIP(host) { + // According to RFC 6265 domain-matching includes not being + // an IP address. + // TODO: This might be relaxed as in common browsers. + return "", false, errNoHostname + } + + // From here on: If the cookie is valid, it is a domain cookie (with + // the one exception of a public suffix below). + // See RFC 6265 section 5.2.3. + if domain[0] == '.' { + domain = domain[1:] + } + + if len(domain) == 0 || domain[0] == '.' { + // Received either "Domain=." or "Domain=..some.thing", + // both are illegal. + return "", false, errMalformedDomain + } + domain = strings.ToLower(domain) + + if domain[len(domain)-1] == '.' { + // We received stuff like "Domain=www.example.com.". + // Browsers do handle such stuff (actually differently) but + // RFC 6265 seems to be clear here (e.g. section 4.1.2.3) in + // requiring a reject. 4.1.2.3 is not normative, but + // "Domain Matching" (5.1.3) and "Canonicalized Host Names" + // (5.1.2) are. + return "", false, errMalformedDomain + } + + // See RFC 6265 section 5.3 #5. + if j.psList != nil { + if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) { + if host == domain { + // This is the one exception in which a cookie + // with a domain attribute is a host cookie. + return host, true, nil + } + return "", false, errIllegalDomain + } + } + + // The domain must domain-match host: www.mycompany.com cannot + // set cookies for .ourcompetitors.com. + if host != domain && !hasDotSuffix(host, domain) { + return "", false, errIllegalDomain + } + + return domain, false, nil +} diff --git a/libgo/go/net/http/cookiejar/jar_test.go b/libgo/go/net/http/cookiejar/jar_test.go new file mode 100644 index 0000000..3aa6015 --- /dev/null +++ b/libgo/go/net/http/cookiejar/jar_test.go @@ -0,0 +1,1267 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "fmt" + "net/http" + "net/url" + "sort" + "strings" + "testing" + "time" +) + +// tNow is the synthetic current time used as now during testing. +var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) + +// testPSL implements PublicSuffixList with just two rules: "co.uk" +// and the default rule "*". +type testPSL struct{} + +func (testPSL) String() string { + return "testPSL" +} +func (testPSL) PublicSuffix(d string) string { + if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { + return "co.uk" + } + return d[strings.LastIndex(d, ".")+1:] +} + +// newTestJar creates an empty Jar with testPSL as the public suffix list. +func newTestJar() *Jar { + jar, err := New(&Options{PublicSuffixList: testPSL{}}) + if err != nil { + panic(err) + } + return jar +} + +var hasDotSuffixTests = [...]struct { + s, suffix string +}{ + {"", ""}, + {"", "."}, + {"", "x"}, + {".", ""}, + {".", "."}, + {".", ".."}, + {".", "x"}, + {".", "x."}, + {".", ".x"}, + {".", ".x."}, + {"x", ""}, + {"x", "."}, + {"x", ".."}, + {"x", "x"}, + {"x", "x."}, + {"x", ".x"}, + {"x", ".x."}, + {".x", ""}, + {".x", "."}, + {".x", ".."}, + {".x", "x"}, + {".x", "x."}, + {".x", ".x"}, + {".x", ".x."}, + {"x.", ""}, + {"x.", "."}, + {"x.", ".."}, + {"x.", "x"}, + {"x.", "x."}, + {"x.", ".x"}, + {"x.", ".x."}, + {"com", ""}, + {"com", "m"}, + {"com", "om"}, + {"com", "com"}, + {"com", ".com"}, + {"com", "x.com"}, + {"com", "xcom"}, + {"com", "xorg"}, + {"com", "org"}, + {"com", "rg"}, + {"foo.com", ""}, + {"foo.com", "m"}, + {"foo.com", "om"}, + {"foo.com", "com"}, + {"foo.com", ".com"}, + {"foo.com", "o.com"}, + {"foo.com", "oo.com"}, + {"foo.com", "foo.com"}, + {"foo.com", ".foo.com"}, + {"foo.com", "x.foo.com"}, + {"foo.com", "xfoo.com"}, + {"foo.com", "xfoo.org"}, + {"foo.com", "foo.org"}, + {"foo.com", "oo.org"}, + {"foo.com", "o.org"}, + {"foo.com", ".org"}, + {"foo.com", "org"}, + {"foo.com", "rg"}, +} + +func TestHasDotSuffix(t *testing.T) { + for _, tc := range hasDotSuffixTests { + got := hasDotSuffix(tc.s, tc.suffix) + want := strings.HasSuffix(tc.s, "."+tc.suffix) + if got != want { + t.Errorf("s=%q, suffix=%q: got %v, want %v", tc.s, tc.suffix, got, want) + } + } +} + +var canonicalHostTests = map[string]string{ + "www.example.com": "www.example.com", + "WWW.EXAMPLE.COM": "www.example.com", + "wWw.eXAmple.CoM": "www.example.com", + "www.example.com:80": "www.example.com", + "192.168.0.10": "192.168.0.10", + "192.168.0.5:8080": "192.168.0.5", + "2001:4860:0:2001::68": "2001:4860:0:2001::68", + "[2001:4860:0:::68]:8080": "2001:4860:0:::68", + "www.bücher.de": "www.xn--bcher-kva.de", + "www.example.com.": "www.example.com", + "[bad.unmatched.bracket:": "error", +} + +func TestCanonicalHost(t *testing.T) { + for h, want := range canonicalHostTests { + got, err := canonicalHost(h) + if want == "error" { + if err == nil { + t.Errorf("%q: got nil error, want non-nil", h) + } + continue + } + if err != nil { + t.Errorf("%q: %v", h, err) + continue + } + if got != want { + t.Errorf("%q: got %q, want %q", h, got, want) + continue + } + } +} + +var hasPortTests = map[string]bool{ + "www.example.com": false, + "www.example.com:80": true, + "127.0.0.1": false, + "127.0.0.1:8080": true, + "2001:4860:0:2001::68": false, + "[2001::0:::68]:80": true, +} + +func TestHasPort(t *testing.T) { + for host, want := range hasPortTests { + if got := hasPort(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var jarKeyTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "bbc.co.uk", + "www.bbc.co.uk": "bbc.co.uk", + "bbc.co.uk": "bbc.co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKey(t *testing.T) { + for host, want := range jarKeyTests { + if got := jarKey(host, testPSL{}); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var jarKeyNilPSLTests = map[string]string{ + "foo.www.example.com": "example.com", + "www.example.com": "example.com", + "example.com": "example.com", + "com": "com", + "foo.www.bbc.co.uk": "co.uk", + "www.bbc.co.uk": "co.uk", + "bbc.co.uk": "co.uk", + "co.uk": "co.uk", + "uk": "uk", + "192.168.0.5": "192.168.0.5", +} + +func TestJarKeyNilPSL(t *testing.T) { + for host, want := range jarKeyNilPSLTests { + if got := jarKey(host, nil); got != want { + t.Errorf("%q: got %q, want %q", host, got, want) + } + } +} + +var isIPTests = map[string]bool{ + "127.0.0.1": true, + "1.2.3.4": true, + "2001:4860:0:2001::68": true, + "example.com": false, + "1.1.1.300": false, + "www.foo.bar.net": false, + "123.foo.bar.net": false, +} + +func TestIsIP(t *testing.T) { + for host, want := range isIPTests { + if got := isIP(host); got != want { + t.Errorf("%q: got %t, want %t", host, got, want) + } + } +} + +var defaultPathTests = map[string]string{ + "/": "/", + "/abc": "/", + "/abc/": "/abc", + "/abc/xyz": "/abc", + "/abc/xyz/": "/abc/xyz", + "/a/b/c.html": "/a/b", + "": "/", + "strange": "/", + "//": "/", + "/a//b": "/a/", + "/a/./b": "/a/.", + "/a/../b": "/a/..", +} + +func TestDefaultPath(t *testing.T) { + for path, want := range defaultPathTests { + if got := defaultPath(path); got != want { + t.Errorf("%q: got %q, want %q", path, got, want) + } + } +} + +var domainAndTypeTests = [...]struct { + host string // host Set-Cookie header was received from + domain string // domain attribute in Set-Cookie header + wantDomain string // expected domain of cookie + wantHostOnly bool // expected host-cookie flag + wantErr error // expected error +}{ + {"www.example.com", "", "www.example.com", true, nil}, + {"127.0.0.1", "", "127.0.0.1", true, nil}, + {"2001:4860:0:2001::68", "", "2001:4860:0:2001::68", true, nil}, + {"www.example.com", "example.com", "example.com", false, nil}, + {"www.example.com", ".example.com", "example.com", false, nil}, + {"www.example.com", "www.example.com", "www.example.com", false, nil}, + {"www.example.com", ".www.example.com", "www.example.com", false, nil}, + {"foo.sso.example.com", "sso.example.com", "sso.example.com", false, nil}, + {"bar.co.uk", "bar.co.uk", "bar.co.uk", false, nil}, + {"foo.bar.co.uk", ".bar.co.uk", "bar.co.uk", false, nil}, + {"127.0.0.1", "127.0.0.1", "", false, errNoHostname}, + {"2001:4860:0:2001::68", "2001:4860:0:2001::68", "2001:4860:0:2001::68", false, errNoHostname}, + {"www.example.com", ".", "", false, errMalformedDomain}, + {"www.example.com", "..", "", false, errMalformedDomain}, + {"www.example.com", "other.com", "", false, errIllegalDomain}, + {"www.example.com", "com", "", false, errIllegalDomain}, + {"www.example.com", ".com", "", false, errIllegalDomain}, + {"foo.bar.co.uk", ".co.uk", "", false, errIllegalDomain}, + {"127.www.0.0.1", "127.0.0.1", "", false, errIllegalDomain}, + {"com", "", "com", true, nil}, + {"com", "com", "com", true, nil}, + {"com", ".com", "com", true, nil}, + {"co.uk", "", "co.uk", true, nil}, + {"co.uk", "co.uk", "co.uk", true, nil}, + {"co.uk", ".co.uk", "co.uk", true, nil}, +} + +func TestDomainAndType(t *testing.T) { + jar := newTestJar() + for _, tc := range domainAndTypeTests { + domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) + if err != tc.wantErr { + t.Errorf("%q/%q: got %q error, want %q", + tc.host, tc.domain, err, tc.wantErr) + continue + } + if err != nil { + continue + } + if domain != tc.wantDomain || hostOnly != tc.wantHostOnly { + t.Errorf("%q/%q: got %q/%t want %q/%t", + tc.host, tc.domain, domain, hostOnly, + tc.wantDomain, tc.wantHostOnly) + } + } +} + +// expiresIn creates an expires attribute delta seconds from tNow. +func expiresIn(delta int) string { + t := tNow.Add(time.Duration(delta) * time.Second) + return "expires=" + t.Format(time.RFC1123) +} + +// mustParseURL parses s to an URL and panics on error. +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil || u.Scheme == "" || u.Host == "" { + panic(fmt.Sprintf("Unable to parse URL %s.", s)) + } + return u +} + +// jarTest encapsulates the following actions on a jar: +// 1. Perform SetCookies with fromURL and the cookies from setCookies. +// (Done at time tNow + 0 ms.) +// 2. Check that the entries in the jar matches content. +// (Done at time tNow + 1001 ms.) +// 3. For each query in tests: Check that Cookies with toURL yields the +// cookies in want. +// (Query n done at tNow + (n+2)*1001 ms.) +type jarTest struct { + description string // The description of what this test is supposed to test + fromURL string // The full URL of the request from which Set-Cookie headers where received + setCookies []string // All the cookies received from fromURL + content string // The whole (non-expired) content of the jar + queries []query // Queries to test the Jar.Cookies method +} + +// query contains one test of the cookies returned from Jar.Cookies. +type query struct { + toURL string // the URL in the Cookies call + want string // the expected list of cookies (order matters) +} + +// run runs the jarTest. +func (test jarTest) run(t *testing.T, jar *Jar) { + now := tNow + + // Populate jar with cookies. + setCookies := make([]*http.Cookie, len(test.setCookies)) + for i, cs := range test.setCookies { + cookies := (&http.Response{Header: http.Header{"Set-Cookie": {cs}}}).Cookies() + if len(cookies) != 1 { + panic(fmt.Sprintf("Wrong cookie line %q: %#v", cs, cookies)) + } + setCookies[i] = cookies[0] + } + jar.setCookies(mustParseURL(test.fromURL), setCookies, now) + now = now.Add(1001 * time.Millisecond) + + // Serialize non-expired entries in the form "name1=val1 name2=val2". + var cs []string + for _, submap := range jar.entries { + for _, cookie := range submap { + if !cookie.Expires.After(now) { + continue + } + cs = append(cs, cookie.Name+"="+cookie.Value) + } + } + sort.Strings(cs) + got := strings.Join(cs, " ") + + // Make sure jar content matches our expectations. + if got != test.content { + t.Errorf("Test %q Content\ngot %q\nwant %q", + test.description, got, test.content) + } + + // Test different calls to Cookies. + for i, query := range test.queries { + now = now.Add(1001 * time.Millisecond) + var s []string + for _, c := range jar.cookies(mustParseURL(query.toURL), now) { + s = append(s, c.Name+"="+c.Value) + } + if got := strings.Join(s, " "); got != query.want { + t.Errorf("Test %q #%d\ngot %q\nwant %q", test.description, i, got, query.want) + } + } +} + +// basicsTests contains fundamental tests. Each jarTest has to be performed on +// a fresh, empty Jar. +var basicsTests = [...]jarTest{ + { + "Retrieval of a plain host cookie.", + "http://www.host.test/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + {"ftp://www.host.test", ""}, + {"ftp://www.host.test/", ""}, + {"ftp://www.host.test/some/path", ""}, + {"http://www.other.org", ""}, + {"http://sibling.host.test", ""}, + {"http://deep.www.host.test", ""}, + }, + }, + { + "Secure cookies are not returned to http.", + "http://www.host.test/", + []string{"A=a; secure"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some/path", ""}, + {"https://www.host.test", "A=a"}, + {"https://www.host.test/", "A=a"}, + {"https://www.host.test/some/path", "A=a"}, + }, + }, + { + "Explicit path.", + "http://www.host.test/", + []string{"A=a; path=/some/path"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #1: path is a directory.", + "http://www.host.test/some/path/", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #2: path is not a directory.", + "http://www.host.test/some/path/index.html", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", ""}, + {"http://www.host.test/", ""}, + {"http://www.host.test/some", ""}, + {"http://www.host.test/some/", ""}, + {"http://www.host.test/some/path", "A=a"}, + {"http://www.host.test/some/paths", ""}, + {"http://www.host.test/some/path/foo", "A=a"}, + {"http://www.host.test/some/path/foo/", "A=a"}, + }, + }, + { + "Implicit path #3: no path in URL at all.", + "http://www.host.test", + []string{"A=a"}, + "A=a", + []query{ + {"http://www.host.test", "A=a"}, + {"http://www.host.test/", "A=a"}, + {"http://www.host.test/some/path", "A=a"}, + }, + }, + { + "Cookies are sorted by path length.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "D=d; path=/foo"}, + "A=a B=b C=c D=d", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "B=b C=c A=a D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c A=a D=d"}, + {"http://www.host.test/foo/bar", "A=a D=d"}, + }, + }, + { + "Creation time determines sorting on same length paths.", + "http://www.host.test/", + []string{ + "A=a; path=/foo/bar", + "X=x; path=/foo/bar", + "Y=y; path=/foo/bar/baz/qux", + "B=b; path=/foo/bar/baz/qux", + "C=c; path=/foo/bar/baz", + "W=w; path=/foo/bar/baz", + "Z=z; path=/foo", + "D=d; path=/foo"}, + "A=a B=b C=c D=d W=w X=x Y=y Z=z", + []query{ + {"http://www.host.test/foo/bar/baz/qux", "Y=y B=b C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar/baz/", "C=c W=w A=a X=x Z=z D=d"}, + {"http://www.host.test/foo/bar", "A=a X=x Z=z D=d"}, + }, + }, + { + "Sorting of same-name cookies.", + "http://www.host.test/", + []string{ + "A=1; path=/", + "A=2; path=/path", + "A=3; path=/quux", + "A=4; path=/path/foo", + "A=5; domain=.host.test; path=/path", + "A=6; domain=.host.test; path=/quux", + "A=7; domain=.host.test; path=/path/foo", + }, + "A=1 A=2 A=3 A=4 A=5 A=6 A=7", + []query{ + {"http://www.host.test/path", "A=2 A=5 A=1"}, + {"http://www.host.test/path/foo", "A=4 A=7 A=2 A=5 A=1"}, + }, + }, + { + "Disallow domain cookie on public suffix.", + "http://www.bbc.co.uk", + []string{ + "a=1", + "b=2; domain=co.uk", + }, + "a=1", + []query{{"http://www.bbc.co.uk", "a=1"}}, + }, + { + "Host cookie on IP.", + "http://192.168.0.10", + []string{"a=1"}, + "a=1", + []query{{"http://192.168.0.10", "a=1"}}, + }, + { + "Port is ignored #1.", + "http://www.host.test/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + }, + }, + { + "Port is ignored #2.", + "http://www.host.test:8080/", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://www.host.test:8080/", "a=1"}, + {"http://www.host.test:1234/", "a=1"}, + }, + }, +} + +func TestBasics(t *testing.T) { + for _, test := range basicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// updateAndDeleteTests contains jarTests which must be performed on the same +// Jar. +var updateAndDeleteTests = [...]jarTest{ + { + "Set initial cookies.", + "http://www.host.test", + []string{ + "a=1", + "b=2; secure", + "c=3; httponly", + "d=4; secure; httponly"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://www.host.test", "a=1 c=3"}, + {"https://www.host.test", "a=1 b=2 c=3 d=4"}, + }, + }, + { + "Update value via http.", + "http://www.host.test", + []string{ + "a=w", + "b=x; secure", + "c=y; httponly", + "d=z; secure; httponly"}, + "a=w b=x c=y d=z", + []query{ + {"http://www.host.test", "a=w c=y"}, + {"https://www.host.test", "a=w b=x c=y d=z"}, + }, + }, + { + "Clear Secure flag from a http.", + "http://www.host.test/", + []string{ + "b=xx", + "d=zz; httponly"}, + "a=w b=xx c=y d=zz", + []query{{"http://www.host.test", "a=w b=xx c=y d=zz"}}, + }, + { + "Delete all.", + "http://www.host.test/", + []string{ + "a=1; max-Age=-1", // delete via MaxAge + "b=2; " + expiresIn(-10), // delete via Expires + "c=2; max-age=-1; " + expiresIn(-10), // delete via both + "d=4; max-age=-1; " + expiresIn(10)}, // MaxAge takes precedence + "", + []query{{"http://www.host.test", ""}}, + }, + { + "Refill #1.", + "http://www.host.test", + []string{ + "A=1", + "A=2; path=/foo", + "A=3; domain=.host.test", + "A=4; path=/foo; domain=.host.test"}, + "A=1 A=2 A=3 A=4", + []query{{"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}}, + }, + { + "Refill #2.", + "http://www.google.com", + []string{ + "A=6", + "A=7; path=/foo", + "A=8; domain=.google.com", + "A=9; path=/foo; domain=.google.com"}, + "A=1 A=2 A=3 A=4 A=6 A=7 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=7 A=9 A=6 A=8"}, + }, + }, + { + "Delete A7.", + "http://www.google.com", + []string{"A=; path=/foo; max-age=-1"}, + "A=1 A=2 A=3 A=4 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=4 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A4.", + "http://www.host.test", + []string{"A=; path=/foo; domain=host.test; max-age=-1"}, + "A=1 A=2 A=3 A=6 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=6 A=8"}, + }, + }, + { + "Delete A6.", + "http://www.google.com", + []string{"A=; max-age=-1"}, + "A=1 A=2 A=3 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1 A=3"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A3.", + "http://www.host.test", + []string{"A=; domain=host.test; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "No cross-domain delete.", + "http://www.host.test", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2 A=8 A=9", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", "A=9 A=8"}, + }, + }, + { + "Delete A8 and A9.", + "http://www.google.com", + []string{ + "A=; domain=google.com; max-age=-1", + "A=; path=/foo; domain=google.com; max-age=-1"}, + "A=1 A=2", + []query{ + {"http://www.host.test/foo", "A=2 A=1"}, + {"http://www.google.com/foo", ""}, + }, + }, +} + +func TestUpdateAndDelete(t *testing.T) { + jar := newTestJar() + for _, test := range updateAndDeleteTests { + test.run(t, jar) + } +} + +func TestExpiration(t *testing.T) { + jar := newTestJar() + jarTest{ + "Expiration.", + "http://www.host.test", + []string{ + "a=1", + "b=2; max-age=3", + "c=3; " + expiresIn(3), + "d=4; max-age=5", + "e=5; " + expiresIn(5), + "f=6; max-age=100", + }, + "a=1 b=2 c=3 d=4 e=5 f=6", // executed at t0 + 1001 ms + []query{ + {"http://www.host.test", "a=1 b=2 c=3 d=4 e=5 f=6"}, // t0 + 2002 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 3003 ms + {"http://www.host.test", "a=1 d=4 e=5 f=6"}, // t0 + 4004 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 5005 ms + {"http://www.host.test", "a=1 f=6"}, // t0 + 6006 ms + }, + }.run(t, jar) +} + +// +// Tests derived from Chromium's cookie_store_unittest.h. +// + +// See http://src.chromium.org/viewvc/chrome/trunk/src/net/cookies/cookie_store_unittest.h?revision=159685&content-type=text/plain +// Some of the original tests are in a bad condition (e.g. +// DomainWithTrailingDotTest) or are not RFC 6265 conforming (e.g. +// TestNonDottedAndTLD #1 and #6) and have not been ported. + +// chromiumBasicsTests contains fundamental tests. Each jarTest has to be +// performed on a fresh, empty Jar. +var chromiumBasicsTests = [...]jarTest{ + { + "DomainWithTrailingDotTest.", + "http://www.google.com/", + []string{ + "a=1; domain=.www.google.com.", + "b=2; domain=.www.google.com.."}, + "", + []query{ + {"http://www.google.com", ""}, + }, + }, + { + "ValidSubdomainTest #1.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com"}, + "a=1 b=2 c=3 d=4", + []query{ + {"http://a.b.c.d.com", "a=1 b=2 c=3 d=4"}, + {"http://b.c.d.com", "b=2 c=3 d=4"}, + {"http://c.d.com", "c=3 d=4"}, + {"http://d.com", "d=4"}, + }, + }, + { + "ValidSubdomainTest #2.", + "http://a.b.c.d.com", + []string{ + "a=1; domain=.a.b.c.d.com", + "b=2; domain=.b.c.d.com", + "c=3; domain=.c.d.com", + "d=4; domain=.d.com", + "X=bcd; domain=.b.c.d.com", + "X=cd; domain=.c.d.com"}, + "X=bcd X=cd a=1 b=2 c=3 d=4", + []query{ + {"http://b.c.d.com", "b=2 c=3 d=4 X=bcd X=cd"}, + {"http://c.d.com", "c=3 d=4 X=cd"}, + }, + }, + { + "InvalidDomainTest #1.", + "http://foo.bar.com", + []string{ + "a=1; domain=.yo.foo.bar.com", + "b=2; domain=.foo.com", + "c=3; domain=.bar.foo.com", + "d=4; domain=.foo.bar.com.net", + "e=5; domain=ar.com", + "f=6; domain=.", + "g=7; domain=/", + "h=8; domain=http://foo.bar.com", + "i=9; domain=..foo.bar.com", + "j=10; domain=..bar.com", + "k=11; domain=.foo.bar.com?blah", + "l=12; domain=.foo.bar.com/blah", + "m=12; domain=.foo.bar.com:80", + "n=14; domain=.foo.bar.com:", + "o=15; domain=.foo.bar.com#sup", + }, + "", // Jar is empty. + []query{{"http://foo.bar.com", ""}}, + }, + { + "InvalidDomainTest #2.", + "http://foo.com.com", + []string{"a=1; domain=.foo.com.com.com"}, + "", + []query{{"http://foo.bar.com", ""}}, + }, + { + "DomainWithoutLeadingDotTest #1.", + "http://manage.hosted.filefront.com", + []string{"a=1; domain=filefront.com"}, + "a=1", + []query{{"http://www.filefront.com", "a=1"}}, + }, + { + "DomainWithoutLeadingDotTest #2.", + "http://www.google.com", + []string{"a=1; domain=www.google.com"}, + "a=1", + []query{ + {"http://www.google.com", "a=1"}, + {"http://sub.www.google.com", "a=1"}, + {"http://something-else.com", ""}, + }, + }, + { + "CaseInsensitiveDomainTest.", + "http://www.google.com", + []string{ + "a=1; domain=.GOOGLE.COM", + "b=2; domain=.www.gOOgLE.coM"}, + "a=1 b=2", + []query{{"http://www.google.com", "a=1 b=2"}}, + }, + { + "TestIpAddress #1.", + "http://1.2.3.4/foo", + []string{"a=1; path=/"}, + "a=1", + []query{{"http://1.2.3.4/foo", "a=1"}}, + }, + { + "TestIpAddress #2.", + "http://1.2.3.4/foo", + []string{ + "a=1; domain=.1.2.3.4", + "b=2; domain=.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestIpAddress #3.", + "http://1.2.3.4/foo", + []string{"a=1; domain=1.2.3.4"}, + "", + []query{{"http://1.2.3.4/foo", ""}}, + }, + { + "TestNonDottedAndTLD #2.", + "http://com./index.html", + []string{"a=1"}, + "a=1", + []query{ + {"http://com./index.html", "a=1"}, + {"http://no-cookies.com./index.html", ""}, + }, + }, + { + "TestNonDottedAndTLD #3.", + "http://a.b", + []string{ + "a=1; domain=.b", + "b=2; domain=b"}, + "", + []query{{"http://bar.foo", ""}}, + }, + { + "TestNonDottedAndTLD #4.", + "http://google.com", + []string{ + "a=1; domain=.com", + "b=2; domain=com"}, + "", + []query{{"http://google.com", ""}}, + }, + { + "TestNonDottedAndTLD #5.", + "http://google.co.uk", + []string{ + "a=1; domain=.co.uk", + "b=2; domain=.uk"}, + "", + []query{ + {"http://google.co.uk", ""}, + {"http://else.co.com", ""}, + {"http://else.uk", ""}, + }, + }, + { + "TestHostEndsWithDot.", + "http://www.google.com", + []string{ + "a=1", + "b=2; domain=.www.google.com."}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "PathTest", + "http://www.google.izzle", + []string{"a=1; path=/wee"}, + "a=1", + []query{ + {"http://www.google.izzle/wee", "a=1"}, + {"http://www.google.izzle/wee/", "a=1"}, + {"http://www.google.izzle/wee/war", "a=1"}, + {"http://www.google.izzle/wee/war/more/more", "a=1"}, + {"http://www.google.izzle/weehee", ""}, + {"http://www.google.izzle/", ""}, + }, + }, +} + +func TestChromiumBasics(t *testing.T) { + for _, test := range chromiumBasicsTests { + jar := newTestJar() + test.run(t, jar) + } +} + +// chromiumDomainTests contains jarTests which must be executed all on the +// same Jar. +var chromiumDomainTests = [...]jarTest{ + { + "Fill #1.", + "http://www.google.izzle", + []string{"A=B"}, + "A=B", + []query{{"http://www.google.izzle", "A=B"}}, + }, + { + "Fill #2.", + "http://www.google.izzle", + []string{"C=D; domain=.google.izzle"}, + "A=B C=D", + []query{{"http://www.google.izzle", "A=B C=D"}}, + }, + { + "Verify A is a host cookie and not accessible from subdomain.", + "http://unused.nil", + []string{}, + "A=B C=D", + []query{{"http://foo.www.google.izzle", "C=D"}}, + }, + { + "Verify domain cookies are found on proper domain.", + "http://www.google.izzle", + []string{"E=F; domain=.www.google.izzle"}, + "A=B C=D E=F", + []query{{"http://www.google.izzle", "A=B C=D E=F"}}, + }, + { + "Leading dots in domain attributes are optional.", + "http://www.google.izzle", + []string{"G=H; domain=www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #1.", + "http://www.google.izzle", + []string{"K=L; domain=.bar.www.google.izzle"}, + "A=B C=D E=F G=H", + []query{{"http://bar.www.google.izzle", "C=D E=F G=H"}}, + }, + { + "Verify domain enforcement works #2.", + "http://unused.nil", + []string{}, + "A=B C=D E=F G=H", + []query{{"http://www.google.izzle", "A=B C=D E=F G=H"}}, + }, +} + +func TestChromiumDomain(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDomainTests { + test.run(t, jar) + } + +} + +// chromiumDeletionTests must be performed all on the same Jar. +var chromiumDeletionTests = [...]jarTest{ + { + "Create session cookie a1.", + "http://www.google.com", + []string{"a=1"}, + "a=1", + []query{{"http://www.google.com", "a=1"}}, + }, + { + "Delete sc a1 via MaxAge.", + "http://www.google.com", + []string{"a=1; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create session cookie b2.", + "http://www.google.com", + []string{"b=2"}, + "b=2", + []query{{"http://www.google.com", "b=2"}}, + }, + { + "Delete sc b2 via Expires.", + "http://www.google.com", + []string{"b=2; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie c3.", + "http://www.google.com", + []string{"c=3; max-age=3600"}, + "c=3", + []query{{"http://www.google.com", "c=3"}}, + }, + { + "Delete pc c3 via MaxAge.", + "http://www.google.com", + []string{"c=3; max-age=-1"}, + "", + []query{{"http://www.google.com", ""}}, + }, + { + "Create persistent cookie d4.", + "http://www.google.com", + []string{"d=4; max-age=3600"}, + "d=4", + []query{{"http://www.google.com", "d=4"}}, + }, + { + "Delete pc d4 via Expires.", + "http://www.google.com", + []string{"d=4; " + expiresIn(-10)}, + "", + []query{{"http://www.google.com", ""}}, + }, +} + +func TestChromiumDeletion(t *testing.T) { + jar := newTestJar() + for _, test := range chromiumDeletionTests { + test.run(t, jar) + } +} + +// domainHandlingTests tests and documents the rules for domain handling. +// Each test must be performed on an empty new Jar. +var domainHandlingTests = [...]jarTest{ + { + "Host cookie", + "http://www.host.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", ""}, + {"http://bar.host.test", ""}, + {"http://foo.www.host.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #1", + "http://www.host.test", + []string{"a=1; domain=host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie #2", + "http://www.host.test", + []string{"a=1; domain=.host.test"}, + "a=1", + []query{ + {"http://www.host.test", "a=1"}, + {"http://host.test", "a=1"}, + {"http://bar.host.test", "a=1"}, + {"http://foo.www.host.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", ""}, + {"http://xn--bcher-kva.test", ""}, + {"http://bar.bücher.test", ""}, + {"http://bar.xn--bcher-kva.test", ""}, + {"http://foo.www.bücher.test", ""}, + {"http://foo.www.xn--bcher-kva.test", ""}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #1", + "http://www.bücher.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Domain cookie on IDNA domain #2", + "http://www.xn--bcher-kva.test", + []string{"a=1; domain=xn--bcher-kva.test"}, + "a=1", + []query{ + {"http://www.bücher.test", "a=1"}, + {"http://www.xn--bcher-kva.test", "a=1"}, + {"http://bücher.test", "a=1"}, + {"http://xn--bcher-kva.test", "a=1"}, + {"http://bar.bücher.test", "a=1"}, + {"http://bar.xn--bcher-kva.test", "a=1"}, + {"http://foo.www.bücher.test", "a=1"}, + {"http://foo.www.xn--bcher-kva.test", "a=1"}, + {"http://other.test", ""}, + {"http://test", ""}, + }, + }, + { + "Host cookie on TLD.", + "http://com", + []string{"a=1"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Domain cookie on TLD becomes a host cookie.", + "http://com", + []string{"a=1; domain=com"}, + "a=1", + []query{ + {"http://com", "a=1"}, + {"http://any.com", ""}, + {"http://any.test", ""}, + }, + }, + { + "Host cookie on public suffix.", + "http://co.uk", + []string{"a=1"}, + "a=1", + []query{ + {"http://co.uk", "a=1"}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, + { + "Domain cookie on public suffix is ignored.", + "http://some.co.uk", + []string{"a=1; domain=co.uk"}, + "", + []query{ + {"http://co.uk", ""}, + {"http://uk", ""}, + {"http://some.co.uk", ""}, + {"http://foo.some.co.uk", ""}, + {"http://any.uk", ""}, + }, + }, +} + +func TestDomainHandling(t *testing.T) { + for _, test := range domainHandlingTests { + jar := newTestJar() + test.run(t, jar) + } +} diff --git a/libgo/go/net/http/cookiejar/punycode.go b/libgo/go/net/http/cookiejar/punycode.go new file mode 100644 index 0000000..ea7ceb5 --- /dev/null +++ b/libgo/go/net/http/cookiejar/punycode.go @@ -0,0 +1,159 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +// This file implements the Punycode algorithm from RFC 3492. + +import ( + "fmt" + "strings" + "unicode/utf8" +) + +// These parameter values are specified in section 5. +// +// All computation is done with int32s, so that overflow behavior is identical +// regardless of whether int is 32-bit or 64-bit. +const ( + base int32 = 36 + damp int32 = 700 + initialBias int32 = 72 + initialN int32 = 128 + skew int32 = 38 + tmax int32 = 26 + tmin int32 = 1 +) + +// encode encodes a string as specified in section 6.3 and prepends prefix to +// the result. +// +// The "while h < length(input)" line in the specification becomes "for +// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes. +func encode(prefix, s string) (string, error) { + output := make([]byte, len(prefix), len(prefix)+1+2*len(s)) + copy(output, prefix) + delta, n, bias := int32(0), initialN, initialBias + b, remaining := int32(0), int32(0) + for _, r := range s { + if r < 0x80 { + b++ + output = append(output, byte(r)) + } else { + remaining++ + } + } + h := b + if b > 0 { + output = append(output, '-') + } + for remaining != 0 { + m := int32(0x7fffffff) + for _, r := range s { + if m > r && r >= n { + m = r + } + } + delta += (m - n) * (h + 1) + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + n = m + for _, r := range s { + if r < n { + delta++ + if delta < 0 { + return "", fmt.Errorf("cookiejar: invalid label %q", s) + } + continue + } + if r > n { + continue + } + q := delta + for k := base; ; k += base { + t := k - bias + if t < tmin { + t = tmin + } else if t > tmax { + t = tmax + } + if q < t { + break + } + output = append(output, encodeDigit(t+(q-t)%(base-t))) + q = (q - t) / (base - t) + } + output = append(output, encodeDigit(q)) + bias = adapt(delta, h+1, h == b) + delta = 0 + h++ + remaining-- + } + delta++ + n++ + } + return string(output), nil +} + +func encodeDigit(digit int32) byte { + switch { + case 0 <= digit && digit < 26: + return byte(digit + 'a') + case 26 <= digit && digit < 36: + return byte(digit + ('0' - 26)) + } + panic("cookiejar: internal error in punycode encoding") +} + +// adapt is the bias adaptation function specified in section 6.1. +func adapt(delta, numPoints int32, firstTime bool) int32 { + if firstTime { + delta /= damp + } else { + delta /= 2 + } + delta += delta / numPoints + k := int32(0) + for delta > ((base-tmin)*tmax)/2 { + delta /= base - tmin + k += base + } + return k + (base-tmin+1)*delta/(delta+skew) +} + +// Strictly speaking, the remaining code below deals with IDNA (RFC 5890 and +// friends) and not Punycode (RFC 3492) per se. + +// acePrefix is the ASCII Compatible Encoding prefix. +const acePrefix = "xn--" + +// toASCII converts a domain or domain label to its ASCII form. For example, +// toASCII("bücher.example.com") is "xn--bcher-kva.example.com", and +// toASCII("golang") is "golang". +func toASCII(s string) (string, error) { + if ascii(s) { + return s, nil + } + labels := strings.Split(s, ".") + for i, label := range labels { + if !ascii(label) { + a, err := encode(acePrefix, label) + if err != nil { + return "", err + } + labels[i] = a + } + } + return strings.Join(labels, "."), nil +} + +func ascii(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} diff --git a/libgo/go/net/http/cookiejar/punycode_test.go b/libgo/go/net/http/cookiejar/punycode_test.go new file mode 100644 index 0000000..0301de1 --- /dev/null +++ b/libgo/go/net/http/cookiejar/punycode_test.go @@ -0,0 +1,161 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cookiejar + +import ( + "testing" +) + +var punycodeTestCases = [...]struct { + s, encoded string +}{ + {"", ""}, + {"-", "--"}, + {"-a", "-a-"}, + {"-a-", "-a--"}, + {"a", "a-"}, + {"a-", "a--"}, + {"a-b", "a-b-"}, + {"books", "books-"}, + {"bücher", "bcher-kva"}, + {"Hello世界", "Hello-ck1hg65u"}, + {"ü", "tda"}, + {"üý", "tdac"}, + + // The test cases below come from RFC 3492 section 7.1 with Errata 3026. + { + // (A) Arabic (Egyptian). + "\u0644\u064A\u0647\u0645\u0627\u0628\u062A\u0643\u0644" + + "\u0645\u0648\u0634\u0639\u0631\u0628\u064A\u061F", + "egbpdaj6bu4bxfgehfvwxn", + }, + { + // (B) Chinese (simplified). + "\u4ED6\u4EEC\u4E3A\u4EC0\u4E48\u4E0D\u8BF4\u4E2D\u6587", + "ihqwcrb4cv8a8dqg056pqjye", + }, + { + // (C) Chinese (traditional). + "\u4ED6\u5011\u7232\u4EC0\u9EBD\u4E0D\u8AAA\u4E2D\u6587", + "ihqwctvzc91f659drss3x8bo0yb", + }, + { + // (D) Czech. + "\u0050\u0072\u006F\u010D\u0070\u0072\u006F\u0073\u0074" + + "\u011B\u006E\u0065\u006D\u006C\u0075\u0076\u00ED\u010D" + + "\u0065\u0073\u006B\u0079", + "Proprostnemluvesky-uyb24dma41a", + }, + { + // (E) Hebrew. + "\u05DC\u05DE\u05D4\u05D4\u05DD\u05E4\u05E9\u05D5\u05D8" + + "\u05DC\u05D0\u05DE\u05D3\u05D1\u05E8\u05D9\u05DD\u05E2" + + "\u05D1\u05E8\u05D9\u05EA", + "4dbcagdahymbxekheh6e0a7fei0b", + }, + { + // (F) Hindi (Devanagari). + "\u092F\u0939\u0932\u094B\u0917\u0939\u093F\u0928\u094D" + + "\u0926\u0940\u0915\u094D\u092F\u094B\u0902\u0928\u0939" + + "\u0940\u0902\u092C\u094B\u0932\u0938\u0915\u0924\u0947" + + "\u0939\u0948\u0902", + "i1baa7eci9glrd9b2ae1bj0hfcgg6iyaf8o0a1dig0cd", + }, + { + // (G) Japanese (kanji and hiragana). + "\u306A\u305C\u307F\u3093\u306A\u65E5\u672C\u8A9E\u3092" + + "\u8A71\u3057\u3066\u304F\u308C\u306A\u3044\u306E\u304B", + "n8jok5ay5dzabd5bym9f0cm5685rrjetr6pdxa", + }, + { + // (H) Korean (Hangul syllables). + "\uC138\uACC4\uC758\uBAA8\uB4E0\uC0AC\uB78C\uB4E4\uC774" + + "\uD55C\uAD6D\uC5B4\uB97C\uC774\uD574\uD55C\uB2E4\uBA74" + + "\uC5BC\uB9C8\uB098\uC88B\uC744\uAE4C", + "989aomsvi5e83db1d2a355cv1e0vak1dwrv93d5xbh15a0dt30a5j" + + "psd879ccm6fea98c", + }, + { + // (I) Russian (Cyrillic). + "\u043F\u043E\u0447\u0435\u043C\u0443\u0436\u0435\u043E" + + "\u043D\u0438\u043D\u0435\u0433\u043E\u0432\u043E\u0440" + + "\u044F\u0442\u043F\u043E\u0440\u0443\u0441\u0441\u043A" + + "\u0438", + "b1abfaaepdrnnbgefbadotcwatmq2g4l", + }, + { + // (J) Spanish. + "\u0050\u006F\u0072\u0071\u0075\u00E9\u006E\u006F\u0070" + + "\u0075\u0065\u0064\u0065\u006E\u0073\u0069\u006D\u0070" + + "\u006C\u0065\u006D\u0065\u006E\u0074\u0065\u0068\u0061" + + "\u0062\u006C\u0061\u0072\u0065\u006E\u0045\u0073\u0070" + + "\u0061\u00F1\u006F\u006C", + "PorqunopuedensimplementehablarenEspaol-fmd56a", + }, + { + // (K) Vietnamese. + "\u0054\u1EA1\u0069\u0073\u0061\u006F\u0068\u1ECD\u006B" + + "\u0068\u00F4\u006E\u0067\u0074\u0068\u1EC3\u0063\u0068" + + "\u1EC9\u006E\u00F3\u0069\u0074\u0069\u1EBF\u006E\u0067" + + "\u0056\u0069\u1EC7\u0074", + "TisaohkhngthchnitingVit-kjcr8268qyxafd2f1b9g", + }, + { + // (L) 3<nen>B<gumi><kinpachi><sensei>. + "\u0033\u5E74\u0042\u7D44\u91D1\u516B\u5148\u751F", + "3B-ww4c5e180e575a65lsy2b", + }, + { + // (M) <amuro><namie>-with-SUPER-MONKEYS. + "\u5B89\u5BA4\u5948\u7F8E\u6075\u002D\u0077\u0069\u0074" + + "\u0068\u002D\u0053\u0055\u0050\u0045\u0052\u002D\u004D" + + "\u004F\u004E\u004B\u0045\u0059\u0053", + "-with-SUPER-MONKEYS-pc58ag80a8qai00g7n9n", + }, + { + // (N) Hello-Another-Way-<sorezore><no><basho>. + "\u0048\u0065\u006C\u006C\u006F\u002D\u0041\u006E\u006F" + + "\u0074\u0068\u0065\u0072\u002D\u0057\u0061\u0079\u002D" + + "\u305D\u308C\u305E\u308C\u306E\u5834\u6240", + "Hello-Another-Way--fc4qua05auwb3674vfr0b", + }, + { + // (O) <hitotsu><yane><no><shita>2. + "\u3072\u3068\u3064\u5C4B\u6839\u306E\u4E0B\u0032", + "2-u9tlzr9756bt3uc0v", + }, + { + // (P) Maji<de>Koi<suru>5<byou><mae> + "\u004D\u0061\u006A\u0069\u3067\u004B\u006F\u0069\u3059" + + "\u308B\u0035\u79D2\u524D", + "MajiKoi5-783gue6qz075azm5e", + }, + { + // (Q) <pafii>de<runba> + "\u30D1\u30D5\u30A3\u30FC\u0064\u0065\u30EB\u30F3\u30D0", + "de-jg4avhby1noc0d", + }, + { + // (R) <sono><supiido><de> + "\u305D\u306E\u30B9\u30D4\u30FC\u30C9\u3067", + "d9juau41awczczp", + }, + { + // (S) -> $1.00 <- + "\u002D\u003E\u0020\u0024\u0031\u002E\u0030\u0030\u0020" + + "\u003C\u002D", + "-> $1.00 <--", + }, +} + +func TestPunycode(t *testing.T) { + for _, tc := range punycodeTestCases { + if got, err := encode("", tc.s); err != nil { + t.Errorf(`encode("", %q): %v`, tc.s, err) + } else if got != tc.encoded { + t.Errorf(`encode("", %q): got %q, want %q`, tc.s, got, tc.encoded) + } + } +} diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go index 22073ea..bc60df7 100644 --- a/libgo/go/net/http/example_test.go +++ b/libgo/go/net/http/example_test.go @@ -51,6 +51,20 @@ func ExampleGet() { } func ExampleFileServer() { - // we use StripPrefix so that /tmpfiles/somefile will access /tmp/somefile + // Simple static webserver: + log.Fatal(http.ListenAndServe(":8080", http.FileServer(http.Dir("/usr/share/doc")))) +} + +func ExampleFileServer_stripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: + http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) +} + +func ExampleStripPrefix() { + // To serve a directory on disk (/tmp) under an alternate URL + // path (/tmpfiles/), use StripPrefix to modify the request + // URL's path before the FileServer sees it: http.Handle("/tmpfiles/", http.StripPrefix("/tmpfiles/", http.FileServer(http.Dir("/tmp")))) } diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index a7a0785..3fc2453 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -16,10 +16,16 @@ func NewLoggingConn(baseName string, c net.Conn) net.Conn { return newLoggingConn(baseName, c) } +func (t *Transport) NumPendingRequestsForTesting() int { + t.reqMu.Lock() + defer t.reqMu.Unlock() + return len(t.reqConn) +} + func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return } @@ -30,8 +36,8 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { } func (t *Transport) IdleConnCountForTesting(cacheKey string) int { - t.idleLk.Lock() - defer t.idleLk.Unlock() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return 0 } @@ -48,3 +54,5 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { } return &timeoutHandler{handler, f, ""} } + +var DefaultUserAgent = defaultUserAgent diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go index c8b9a33..60b794e 100644 --- a/libgo/go/net/http/fcgi/child.go +++ b/libgo/go/net/http/fcgi/child.go @@ -10,10 +10,12 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/http" "net/http/cgi" "os" + "strings" "time" ) @@ -152,20 +154,23 @@ func (c *child) serve() { var errCloseConn = errors.New("fcgi: connection should be closed") +var emptyBody = ioutil.NopCloser(strings.NewReader("")) + func (c *child) handleRecord(rec *record) error { req, ok := c.requests[rec.h.Id] if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { // The spec says to ignore unknown request IDs. return nil } - if ok && rec.h.Type == typeBeginRequest { - // The server is trying to begin a request with the same ID - // as an in-progress request. This is an error. - return errors.New("fcgi: received ID that is already in-flight") - } switch rec.h.Type { case typeBeginRequest: + if req != nil { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return errors.New("fcgi: received ID that is already in-flight") + } + var br beginRequest if err := br.read(rec.content()); err != nil { return err @@ -175,6 +180,7 @@ func (c *child) handleRecord(rec *record) error { return nil } c.requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + return nil case typeParams: // NOTE(eds): Technically a key-value pair can straddle the boundary // between two packets. We buffer until we've received all parameters. @@ -183,6 +189,7 @@ func (c *child) handleRecord(rec *record) error { return nil } req.parseParams() + return nil case typeStdin: content := rec.content() if req.pw == nil { @@ -191,6 +198,8 @@ func (c *child) handleRecord(rec *record) error { // body could be an io.LimitReader, but it shouldn't matter // as long as both sides are behaving. body, req.pw = io.Pipe() + } else { + body = emptyBody } go c.serveRequest(req, body) } @@ -201,24 +210,29 @@ func (c *child) handleRecord(rec *record) error { } else if req.pw != nil { req.pw.Close() } + return nil case typeGetValues: values := map[string]string{"FCGI_MPXS_CONNS": "1"} c.conn.writePairs(typeGetValuesResult, 0, values) + return nil case typeData: // If the filter role is implemented, read the data stream here. + return nil case typeAbortRequest: + println("abort") delete(c.requests, rec.h.Id) c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) if !req.keepConn { // connection will close upon return return errCloseConn } + return nil default: b := make([]byte, 8) b[0] = byte(rec.h.Type) c.conn.writeRecord(typeUnknownType, 0, b) + return nil } - return nil } func (c *child) serveRequest(req *request, body io.ReadCloser) { @@ -232,11 +246,19 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { httpReq.Body = body c.handler.ServeHTTP(r, httpReq) } - if body != nil { - body.Close() - } r.Close() c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + + // Consume the entire body, so the host isn't still writing to + // us when we close the socket below in the !keepConn case, + // otherwise we'd send a RST. (golang.org/issue/4183) + // TODO(bradfitz): also bound this copy in time. Or send + // some sort of abort request to the host, so the host + // can properly cut off the client sending all the data. + // For now just bound it a little and + io.CopyN(ioutil.Discard, body, 100<<20) + body.Close() + if !req.keepConn { c.conn.Close() } @@ -267,5 +289,4 @@ func Serve(l net.Listener, handler http.Handler) error { c := newChild(rw, handler) go c.serve() } - panic("unreachable") } diff --git a/libgo/go/net/http/filetransport_test.go b/libgo/go/net/http/filetransport_test.go index 039926b..6f1a537 100644 --- a/libgo/go/net/http/filetransport_test.go +++ b/libgo/go/net/http/filetransport_test.go @@ -2,11 +2,10 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http_test +package http import ( "io/ioutil" - "net/http" "os" "path/filepath" "testing" @@ -32,9 +31,9 @@ func TestFileTransport(t *testing.T) { defer os.Remove(dname) defer os.Remove(fname) - tr := &http.Transport{} - tr.RegisterProtocol("file", http.NewFileTransport(http.Dir(dname))) - c := &http.Client{Transport: tr} + tr := &Transport{} + tr.RegisterProtocol("file", NewFileTransport(Dir(dname))) + c := &Client{Transport: tr} fooURLs := []string{"file:///foo.txt", "file://../foo.txt"} for _, urlstr := range fooURLs { @@ -62,4 +61,5 @@ func TestFileTransport(t *testing.T) { if res.StatusCode != 404 { t.Errorf("for %s, StatusCode = %d, want 404", badURL, res.StatusCode) } + res.Body.Close() } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index d42014c..2c37376 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -54,6 +54,7 @@ var ServeFileRangeTests = []struct { } func TestServeFile(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") })) @@ -169,6 +170,7 @@ var fsRedirectTestData = []struct { } func TestFSRedirect(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) defer ts.Close() @@ -193,6 +195,7 @@ func (fs *testFileSystem) Open(name string) (File, error) { } func TestFileServerCleans(t *testing.T) { + defer afterTest(t) ch := make(chan string, 1) fs := FileServer(&testFileSystem{func(name string) (File, error) { ch <- name @@ -224,6 +227,7 @@ func mustRemoveAll(dir string) { } func TestFileServerImplicitLeadingSlash(t *testing.T) { + defer afterTest(t) tempDir, err := ioutil.TempDir("", "") if err != nil { t.Fatalf("TempDir: %v", err) @@ -302,6 +306,7 @@ func TestEmptyDirOpenCWD(t *testing.T) { } func TestServeFileContentType(t *testing.T) { + defer afterTest(t) const ctype = "icecream/chocolate" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("override") == "1" { @@ -318,12 +323,14 @@ func TestServeFileContentType(t *testing.T) { if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } + resp.Body.Close() } get("0", "text/plain; charset=utf-8") get("1", ctype) } func TestServeFileMimeType(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") })) @@ -332,6 +339,7 @@ func TestServeFileMimeType(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() want := "text/css; charset=utf-8" if h := resp.Header.Get("Content-Type"); h != want { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) @@ -339,6 +347,7 @@ func TestServeFileMimeType(t *testing.T) { } func TestServeFileFromCWD(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") })) @@ -354,6 +363,7 @@ func TestServeFileFromCWD(t *testing.T) { } func TestServeFileWithContentEncoding(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -363,12 +373,14 @@ func TestServeFileWithContentEncoding(t *testing.T) { if err != nil { t.Fatal(err) } + resp.Body.Close() if g, e := resp.ContentLength, int64(-1); g != e { t.Errorf("Content-Length mismatch: got %d, want %d", g, e) } } func TestServeIndexHtml(t *testing.T) { + defer afterTest(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -390,6 +402,7 @@ func TestServeIndexHtml(t *testing.T) { } func TestFileServerZeroByte(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -458,6 +471,7 @@ func (fs fakeFS) Open(name string) (File, error) { } func TestDirectoryIfNotModified(t *testing.T) { + defer afterTest(t) const indexContents = "I am a fake index.html file" fileMod := time.Unix(1000000000, 0).UTC() fileModStr := fileMod.Format(TimeFormat) @@ -531,6 +545,7 @@ func mustStat(t *testing.T, fileName string) os.FileInfo { } func TestServeContent(t *testing.T) { + defer afterTest(t) type serveParam struct { name string modtime time.Time @@ -663,6 +678,7 @@ func TestServeContent(t *testing.T) { // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { + defer afterTest(t) if runtime.GOOS != "linux" { t.Skip("skipping; linux-only test") } @@ -681,7 +697,7 @@ func TestLinuxSendfile(t *testing.T) { defer ln.Close() var buf bytes.Buffer - child := exec.Command("strace", "-f", "-e!sigaltstack", os.Args[0], "-test.run=TestLinuxSendfileChild") + child := exec.Command("strace", "-f", "-q", "-e", "trace=sendfile,sendfile64", os.Args[0], "-test.run=TestLinuxSendfileChild") child.ExtraFiles = append(child.ExtraFiles, lnf) child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...) child.Stdout = &buf diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index f479b7b..6374237 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -103,21 +103,41 @@ type keyValues struct { values []string } -type byKey []keyValues - -func (s byKey) Len() int { return len(s) } -func (s byKey) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s byKey) Less(i, j int) bool { return s[i].key < s[j].key } - -func (h Header) sortedKeyValues(exclude map[string]bool) []keyValues { - kvs := make([]keyValues, 0, len(h)) +// A headerSorter implements sort.Interface by sorting a []keyValues +// by key. It's used as a pointer, so it can fit in a sort.Interface +// interface value without allocation. +type headerSorter struct { + kvs []keyValues +} + +func (s *headerSorter) Len() int { return len(s.kvs) } +func (s *headerSorter) Swap(i, j int) { s.kvs[i], s.kvs[j] = s.kvs[j], s.kvs[i] } +func (s *headerSorter) Less(i, j int) bool { return s.kvs[i].key < s.kvs[j].key } + +// TODO: convert this to a sync.Cache (issue 4720) +var headerSorterCache = make(chan *headerSorter, 8) + +// sortedKeyValues returns h's keys sorted in the returned kvs +// slice. The headerSorter used to sort is also returned, for possible +// return to headerSorterCache. +func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *headerSorter) { + select { + case hs = <-headerSorterCache: + default: + hs = new(headerSorter) + } + if cap(hs.kvs) < len(h) { + hs.kvs = make([]keyValues, 0, len(h)) + } + kvs = hs.kvs[:0] for k, vv := range h { if !exclude[k] { kvs = append(kvs, keyValues{k, vv}) } } - sort.Sort(byKey(kvs)) - return kvs + hs.kvs = kvs + sort.Sort(hs) + return kvs, hs } // WriteSubset writes a header in wire format. @@ -127,7 +147,8 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { if !ok { ws = stringWriter{w} } - for _, kv := range h.sortedKeyValues(exclude) { + kvs, sorter := h.sortedKeyValues(exclude) + for _, kv := range kvs { for _, v := range kv.values { v = headerNewlineToSpace.Replace(v) v = textproto.TrimString(v) @@ -138,6 +159,10 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { } } } + select { + case headerSorterCache <- sorter: + default: + } return nil } diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index 01bb4dc..584f100 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -175,38 +175,33 @@ func TestHasToken(t *testing.T) { } } -func BenchmarkHeaderWriteSubset(b *testing.B) { - doHeaderWriteSubset(b.N, b) +var testHeader = Header{ + "Content-Length": {"123"}, + "Content-Type": {"text/plain"}, + "Date": {"some date at some time Z"}, + "Server": {DefaultUserAgent}, } -func TestHeaderWriteSubsetMallocs(t *testing.T) { - doHeaderWriteSubset(100, t) -} +var buf bytes.Buffer -type errorfer interface { - Errorf(string, ...interface{}) +func BenchmarkHeaderWriteSubset(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + } } -func doHeaderWriteSubset(n int, t errorfer) { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) - h := Header(map[string][]string{ - "Content-Length": {"123"}, - "Content-Type": {"text/plain"}, - "Date": {"some date at some time Z"}, - "Server": {"Go http package"}, - }) - var buf bytes.Buffer - var m0 runtime.MemStats - runtime.ReadMemStats(&m0) - for i := 0; i < n; i++ { - buf.Reset() - h.WriteSubset(&buf, nil) +func TestHeaderWriteSubsetMallocs(t *testing.T) { + t.Skip("Skipping alloc count test on gccgo") + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") } - var m1 runtime.MemStats - runtime.ReadMemStats(&m1) - if mallocs := m1.Mallocs - m0.Mallocs; n >= 100 && mallocs >= uint64(n) { - // TODO(bradfitz,rsc): once we can sort without allocating, - // make this an error. See http://golang.org/issue/3761 - // t.Errorf("did %d mallocs (>= %d iterations); should have avoided mallocs", mallocs, n) + n := testing.AllocsPerRun(100, func() { + buf.Reset() + testHeader.WriteSubset(&buf, nil) + }) + if n > 0 { + t.Errorf("mallocs = %d; want 0", n) } } diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index fc52c9a..7f26555 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -21,7 +21,11 @@ import ( type Server struct { URL string // base URL of form http://ipaddr:port with no trailing slash Listener net.Listener - TLS *tls.Config // nil if not using TLS + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config // Config may be changed after calling NewUnstartedServer and // before Start or StartTLS. @@ -119,9 +123,16 @@ func (s *Server) StartTLS() { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } - s.TLS = &tls.Config{ - NextProtos: []string{"http/1.1"}, - Certificates: []tls.Certificate{cert}, + existingConfig := s.TLS + s.TLS = new(tls.Config) + if existingConfig != nil { + *s.TLS = *existingConfig + } + if s.TLS.NextProtos == nil { + s.TLS.NextProtos = []string{"http/1.1"} + } + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} } tlsListener := tls.NewListener(s.Listener, s.TLS) @@ -189,28 +200,29 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.h.ServeHTTP(w, r) } -// localhostCert is a PEM-encoded TLS cert with SAN DNS names +// localhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). +// generated from src/pkg/crypto/tls: +// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX -DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 -qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL -8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB -/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC -CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB -nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw -Pg== +MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD +bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj +bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa +IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA +AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud +EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA +AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk +Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== -----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIBPQIBAAJBALLgOZgBTI+kO6qAc3LysyKuJM7k+XqUqdgJHEH8gR5uytd1rO7v -tG+VW/YKk3+XAIiCnK7a11apC/ItVEBegM8CAwEAAQJBAI5sxq7naeR9ahyqRkJi -SIv2iMxLuPEHaezf5CYOPWjSjBPyVhyRevkhtqEjF/WkgL7C2nWpYHsUcBDBQVF0 -3KECIQDtEGB2ulnkZAahl3WuJziXGLB+p8Wgx7wzSM6bHu1c6QIhAMEp++CaS+SJ -/TrU0zwY/fW4SvQeb49BPZUF3oqR8Xz3AiEA1rAJHBzBgdOQKdE3ksMUPcnvNJSN -poCcELmz2clVXtkCIQCLytuLV38XHToTipR4yMl6O+6arzAjZ56uq7m7ZRV0TwIh -AM65XAOw8Dsg9Kq78aYXiOEDc5DL0sbFUu/SlmRcCg93 ------END RSA PRIVATE KEY----- -`) +MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 +0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV +NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d +AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW +MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD +EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA +1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +-----END RSA PRIVATE KEY-----`) diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 5afe9ba..3e87c27 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -68,7 +68,7 @@ var dumpTests = []dumpTest{ WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, @@ -80,7 +80,7 @@ var dumpTests = []dumpTest{ WantDumpOut: "GET /foo HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Accept-Encoding: gzip\r\n\r\n", }, } diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 134c452..1990f64 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -81,6 +81,19 @@ func copyHeader(dst, src http.Header) { } } +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +var hopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { @@ -96,14 +109,21 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.ProtoMinor = 1 outreq.Close = false - // Remove the connection header to the backend. We want a - // persistent connection, regardless of what the client sent - // to us. This is modifying the same underlying map from req - // (shallow copied above) so we only copy it if necessary. - if outreq.Header.Get("Connection") != "" { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, req.Header) - outreq.Header.Del("Connection") + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. This + // is modifying the same underlying map from req (shallow + // copied above) so we only copy it if necessary. + copiedHeaders := false + for _, h := range hopHeaders { + if outreq.Header.Get(h) != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(h) + } } if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { @@ -182,7 +202,6 @@ func (m *maxLatencyWriter) flushLoop() { m.lk.Unlock() } } - panic("unreached") } func (m *maxLatencyWriter) stop() { m.done <- true } diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 8639271..1c0444e 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -29,6 +29,9 @@ func TestReverseProxy(t *testing.T) { if c := r.Header.Get("Connection"); c != "" { t.Errorf("handler got Connection header value %q", c) } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } if g, e := r.Host, "some-name"; g != e { t.Errorf("backend got Host header %q, want %q", g, e) } @@ -49,6 +52,7 @@ func TestReverseProxy(t *testing.T) { getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Host = "some-name" getReq.Header.Set("Connection", "close") + getReq.Header.Set("Upgrade", "foo") getReq.Close = true res, err := http.DefaultClient.Do(getReq) if err != nil { diff --git a/libgo/go/net/http/jar.go b/libgo/go/net/http/jar.go index 35eee68..5c3de0d 100644 --- a/libgo/go/net/http/jar.go +++ b/libgo/go/net/http/jar.go @@ -12,6 +12,8 @@ import ( // // Implementations of CookieJar must be safe for concurrent use by multiple // goroutines. +// +// The net/http/cookiejar package provides a CookieJar implementation. type CookieJar interface { // SetCookies handles the receipt of the cookies in a reply for the // given URL. It may or may not choose to save the cookies, depending diff --git a/libgo/go/net/http/npn_test.go b/libgo/go/net/http/npn_test.go new file mode 100644 index 0000000..98b8930 --- /dev/null +++ b/libgo/go/net/http/npn_test.go @@ -0,0 +1,118 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "io/ioutil" + . "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNextProtoUpgrade(t *testing.T) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) + if r.TLS != nil { + w.Write([]byte(r.TLS.NegotiatedProtocol)) + } + if r.RemoteAddr == "" { + t.Error("request with no RemoteAddr") + } + if r.Body == nil { + t.Errorf("request with nil Body") + } + })) + ts.TLS = &tls.Config{ + NextProtos: []string{"unhandled-proto", "tls-0.9"}, + } + ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){ + "tls-0.9": handleTLSProtocol09, + } + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Normal request, without NPN. + { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if want := "path=/,proto="; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } + + // Request to an advertised but unhandled NPN protocol. + // Server will hang up. + { + tr.CloseIdleConnections() + tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + _, err := c.Get(ts.URL) + if err == nil { + t.Errorf("expected error on unhandled-proto request") + } + } + + // Request using the "tls-0.9" protocol, which we register here. + // It is HTTP/0.9 over TLS. + { + tlsConfig := newTLSTransport(t, ts).TLSClientConfig + tlsConfig.NextProtos = []string{"tls-0.9"} + conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + conn.Write([]byte("GET /foo\n")) + body, err := ioutil.ReadAll(conn) + if err != nil { + t.Fatal(err) + } + if want := "path=/foo,proto=tls-0.9"; string(body) != want { + t.Errorf("plain request = %q; want %q", body, want) + } + } +} + +// handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the +// TestNextProtoUpgrade test. +func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) { + br := bufio.NewReader(conn) + line, err := br.ReadString('\n') + if err != nil { + return + } + line = strings.TrimSpace(line) + path := strings.TrimPrefix(line, "GET ") + if path == line { + return + } + req, _ := NewRequest("GET", path, nil) + req.Proto = "HTTP/0.9" + req.ProtoMajor = 0 + req.ProtoMinor = 9 + rw := &http09Writer{conn, make(Header)} + h.ServeHTTP(rw, req) +} + +type http09Writer struct { + io.Writer + h Header +} + +func (w http09Writer) Header() Header { return w.h } +func (w http09Writer) WriteHeader(int) {} // no headers diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 0c03e5b..0c7548e 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -172,7 +172,7 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // listing the available profiles. func Index(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/debug/pprof/") { - name := r.URL.Path[len("/debug/pprof/"):] + name := strings.TrimPrefix(r.URL.Path, "/debug/pprof/") if name != "" { handler(name).ServeHTTP(w, r) return diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 217f35b..6d45691 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -48,7 +48,7 @@ var ( ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} - ErrMissingBoundary = &ProtocolError{"no multipart boundary param Content-Type"} + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} ) type badStringError struct { @@ -283,7 +283,7 @@ func valueOrDefault(value, def string) string { return def } -const defaultUserAgent = "Go http package" +const defaultUserAgent = "Go 1.1 package http" // Write writes an HTTP/1.1 request -- header and body -- in wire format. // This method consults the following fields of the request: @@ -467,10 +467,42 @@ func (r *Request) SetBasicAuth(username, password string) { r.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(s))) } +// parseRequestLine parses "GET /foo HTTP/1.1" into its three parts. +func parseRequestLine(line string) (method, requestURI, proto string, ok bool) { + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} + +// TODO(bradfitz): use a sync.Cache when available +var textprotoReaderCache = make(chan *textproto.Reader, 4) + +func newTextprotoReader(br *bufio.Reader) *textproto.Reader { + select { + case r := <-textprotoReaderCache: + r.R = br + return r + default: + return textproto.NewReader(br) + } +} + +func putTextprotoReader(r *textproto.Reader) { + r.R = nil + select { + case textprotoReaderCache <- r: + default: + } +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err error) { - tp := textproto.NewReader(b) + tp := newTextprotoReader(b) req = new(Request) // First line: GET /index.html HTTP/1.0 @@ -479,18 +511,18 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) { return nil, err } defer func() { + putTextprotoReader(tp) if err == io.EOF { err = io.ErrUnexpectedEOF } }() - var f []string - if f = strings.SplitN(s, " ", 3); len(f) < 3 { + var ok bool + req.Method, req.RequestURI, req.Proto, ok = parseRequestLine(s) + if !ok { return nil, &badStringError{"malformed HTTP request", s} } - req.Method, req.RequestURI, req.Proto = f[0], f[1], f[2] rawurl := req.RequestURI - var ok bool if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { return nil, &badStringError{"malformed HTTP version", req.Proto} } diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index bd75792..692485c 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -262,7 +262,39 @@ func TestNewRequestContentLength(t *testing.T) { t.Fatal(err) } if req.ContentLength != tt.want { - t.Errorf("ContentLength(%#T) = %d; want %d", tt.r, req.ContentLength, tt.want) + t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + } + } +} + +var parseHTTPVersionTests = []struct { + vers string + major, minor int + ok bool +}{ + {"HTTP/0.9", 0, 9, true}, + {"HTTP/1.0", 1, 0, true}, + {"HTTP/1.1", 1, 1, true}, + {"HTTP/3.14", 3, 14, true}, + + {"HTTP", 0, 0, false}, + {"HTTP/one.one", 0, 0, false}, + {"HTTP/1.1/", 0, 0, false}, + {"HTTP/-1,0", 0, 0, false}, + {"HTTP/0,-1", 0, 0, false}, + {"HTTP/", 0, 0, false}, + {"HTTP/1,1", 0, 0, false}, +} + +func TestParseHTTPVersion(t *testing.T) { + for _, tt := range parseHTTPVersionTests { + major, minor, ok := ParseHTTPVersion(tt.vers) + if ok != tt.ok || major != tt.major || minor != tt.minor { + type version struct { + major, minor int + ok bool + } + t.Errorf("failed to parse %q, expected: %#v, got %#v", tt.vers, version{tt.major, tt.minor, tt.ok}, version{major, minor, ok}) } } } @@ -289,7 +321,7 @@ func TestRequestWriteBufferedWriter(t *testing.T) { want := []string{ "GET / HTTP/1.1\r\n", "Host: foo.com\r\n", - "User-Agent: Go http package\r\n", + "User-Agent: " + DefaultUserAgent + "\r\n", "\r\n", } if !reflect.DeepEqual(got, want) { @@ -401,3 +433,81 @@ Content-Disposition: form-data; name="textb" ` + textbValue + ` --MyBoundary-- ` + +func benchmarkReadRequest(b *testing.B, request string) { + request = request + "\n" // final \n + request = strings.Replace(request, "\n", "\r\n", -1) // expand \n to \r\n + b.SetBytes(int64(len(request))) + r := bufio.NewReader(&infiniteReader{buf: []byte(request)}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := ReadRequest(r) + if err != nil { + b.Fatalf("failed to read request: %v", err) + } + } +} + +// infiniteReader satisfies Read requests as if the contents of buf +// loop indefinitely. +type infiniteReader struct { + buf []byte + offset int +} + +func (r *infiniteReader) Read(b []byte) (int, error) { + n := copy(b, r.buf[r.offset:]) + r.offset = (r.offset + n) % len(r.buf) + return n, nil +} + +func BenchmarkReadRequestChrome(b *testing.B) { + // https://github.com/felixge/node-http-perf/blob/master/fixtures/get.http + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Connection: keep-alive +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 +Cookie: __utma=1.1978842379.1323102373.1323102373.1323102373.1; EPi:NumberOfVisits=1,2012-02-28T13:42:18; CrmSession=5b707226b9563e1bc69084d07a107c98; plushContainerWidth=100%25; plushNoTopMenu=0; hudson_auto_refresh=false +`) +} + +func BenchmarkReadRequestCurl(b *testing.B) { + // curl http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +User-Agent: curl/7.27.0 +Host: localhost:8080 +Accept: */* +`) +} + +func BenchmarkReadRequestApachebench(b *testing.B) { + // ab -n 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.0 +Host: localhost:8080 +User-Agent: ApacheBench/2.3 +Accept: */* +`) +} + +func BenchmarkReadRequestSiege(b *testing.B) { + // siege -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +Accept: */* +Accept-Encoding: gzip +User-Agent: JoeDog/1.00 [en] (X11; I; Siege 2.70) +Connection: keep-alive +`) +} + +func BenchmarkReadRequestWrk(b *testing.B) { + // wrk -t 1 -r 1 -c 1 http://localhost:8080/ + benchmarkReadRequest(b, `GET / HTTP/1.1 +Host: localhost:8080 +`) +} diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index fc3186f..b27b1f7 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -93,13 +93,13 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), WantProxy: "GET http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), }, @@ -123,14 +123,14 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("abcdef") + chunk(""), @@ -156,7 +156,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "\r\n" + @@ -164,7 +164,7 @@ var reqWriteTests = []reqWriteTest{ WantProxy: "POST http://www.google.com/search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Connection: close\r\n" + "Content-Length: 6\r\n" + "\r\n" + @@ -187,14 +187,14 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", WantProxy: "POST http://example.com/ HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 6\r\n" + "\r\n" + "abcdef", @@ -210,7 +210,7 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /search HTTP/1.1\r\n" + "Host: www.google.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "\r\n", }, @@ -232,13 +232,13 @@ var reqWriteTests = []reqWriteTest{ // Also, nginx expects it for POST and PUT. WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 0\r\n" + "\r\n", WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Content-Length: 0\r\n" + "\r\n", }, @@ -258,13 +258,13 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), WantProxy: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + chunk("x") + chunk(""), }, @@ -325,9 +325,96 @@ var reqWriteTests = []reqWriteTest{ WantWrite: "GET /foo HTTP/1.1\r\n" + "Host: \r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "X-Foo: X-Bar\r\n\r\n", }, + + // If no Request.Host and no Request.URL.Host, we send + // an empty Host header, and don't use + // Request.Header["Host"]. This is just testing that + // we don't change Go 1.0 behavior. + { + Req: Request{ + Method: "GET", + Host: "", + URL: &url.URL{ + Scheme: "http", + Host: "", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "Host": []string{"bad.example.com"}, + }, + }, + + WantWrite: "GET /search HTTP/1.1\r\n" + + "Host: \r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Opaque test #1 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Opaque: "/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET /%2F/%2F/ HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Opaque test #2 from golang.org/issue/4860 + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "x.google.com", + Opaque: "//y.google.com/%2F/%2F/", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + }, + + WantWrite: "GET http://y.google.com/%2F/%2F/ HTTP/1.1\r\n" + + "Host: x.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n\r\n", + }, + + // Testing custom case in header keys. Issue 5022. + { + Req: Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{ + "ALL-CAPS": {"x"}, + }, + }, + + WantWrite: "GET / HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go 1.1 package http\r\n" + + "ALL-CAPS: x\r\n" + + "\r\n", + }, } func TestRequestWrite(t *testing.T) { @@ -411,7 +498,7 @@ func TestRequestWriteClosesBody(t *testing.T) { } expected := "POST / HTTP/1.1\r\n" + "Host: foo.com\r\n" + - "User-Agent: Go http package\r\n" + + "User-Agent: Go 1.1 package http\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + // TODO: currently we don't buffer before chunking, so we get a // single "m" chunk before the other chunks, as this was the 1-byte diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 7901c49..9a7e4e3 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -46,6 +46,9 @@ type Response struct { // The http Client and Transport guarantee that Body is always // non-nil, even on responses without a body or responses with // a zero-lengthed body. + // + // The Body is automatically dechunked if the server replied + // with a "chunked" Transfer-Encoding. Body io.ReadCloser // ContentLength records the length of the associated content. The @@ -198,9 +201,7 @@ func (r *Response) Write(w io.Writer) error { } protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) statusCode := strconv.Itoa(r.StatusCode) + " " - if strings.HasPrefix(text, statusCode) { - text = text[len(statusCode):] - } + text = strings.TrimPrefix(text, statusCode) io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n") // Process Body,ContentLength,Close,Trailer diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index a00a4ae..02796e88 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -112,8 +112,8 @@ var respTests = []respTest{ ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{ - "Connection": {"close"}, // TODO(rsc): Delete? - "Content-Length": {"10"}, // TODO(rsc): Delete? + "Connection": {"close"}, + "Content-Length": {"10"}, }, Close: true, ContentLength: 10, @@ -157,7 +157,7 @@ var respTests = []respTest{ "Content-Length: 10\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + "0\r\n" + "\r\n", @@ -170,7 +170,7 @@ var respTests = []respTest{ Request: dummyReq("GET"), Header: Header{}, Close: false, - ContentLength: -1, // TODO(rsc): Fix? + ContentLength: -1, TransferEncoding: []string{"chunked"}, }, @@ -324,16 +324,37 @@ var respTests = []respTest{ "", }, + + // golang.org/issue/4767: don't special-case multipart/byteranges responses + { + `HTTP/1.1 206 Partial Content +Connection: close +Content-Type: multipart/byteranges; boundary=18a75608c8f47cef + +some body`, + Response{ + Status: "206 Partial Content", + StatusCode: 206, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Type": []string{"multipart/byteranges; boundary=18a75608c8f47cef"}, + }, + Close: true, + ContentLength: -1, + }, + + "some body", + }, } func TestReadResponse(t *testing.T) { - for i := range respTests { - tt := &respTests[i] - var braw bytes.Buffer - braw.WriteString(tt.Raw) - resp, err := ReadResponse(bufio.NewReader(&braw), tt.Resp.Request) + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) if err != nil { - t.Errorf("#%d: %s", i, err) + t.Errorf("#%d: %v", i, err) continue } rbody := resp.Body @@ -341,7 +362,11 @@ func TestReadResponse(t *testing.T) { diff(t, fmt.Sprintf("#%d Response", i), resp, &tt.Resp) var bout bytes.Buffer if rbody != nil { - io.Copy(&bout, rbody) + _, err = io.Copy(&bout, rbody) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } rbody.Close() } body := bout.String() @@ -351,6 +376,22 @@ func TestReadResponse(t *testing.T) { } } +func TestWriteResponse(t *testing.T) { + for i, tt := range respTests { + resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + bout := bytes.NewBuffer(nil) + err = resp.Write(bout) + if err != nil { + t.Errorf("#%d: %v", i, err) + continue + } + } +} + var readResponseCloseInMiddleTests = []struct { chunked, compressed bool }{ @@ -425,7 +466,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { if test.compressed { gzReader, err := gzip.NewReader(resp.Body) checkErr(err, "gzip.NewReader") - resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } rbuf := make([]byte, 2500) diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 886ed4e..d7b3215 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -64,10 +65,39 @@ func (a dummyAddr) String() string { return string(a) } +type noopConn struct{} + +func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") } +func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") } +func (noopConn) SetDeadline(t time.Time) error { return nil } +func (noopConn) SetReadDeadline(t time.Time) error { return nil } +func (noopConn) SetWriteDeadline(t time.Time) error { return nil } + +type rwTestConn struct { + io.Reader + io.Writer + noopConn + + closeFunc func() error // called if non-nil + closec chan bool // else, if non-nil, send value to it on close +} + +func (c *rwTestConn) Close() error { + if c.closeFunc != nil { + return c.closeFunc() + } + select { + case c.closec <- true: + default: + } + return nil +} + type testConn struct { readBuf bytes.Buffer writeBuf bytes.Buffer closec chan bool // if non-nil, send value to it on close + noopConn } func (c *testConn) Read(b []byte) (int, error) { @@ -86,26 +116,6 @@ func (c *testConn) Close() error { return nil } -func (c *testConn) LocalAddr() net.Addr { - return dummyAddr("local-addr") -} - -func (c *testConn) RemoteAddr() net.Addr { - return dummyAddr("remote-addr") -} - -func (c *testConn) SetDeadline(t time.Time) error { - return nil -} - -func (c *testConn) SetReadDeadline(t time.Time) error { - return nil -} - -func (c *testConn) SetWriteDeadline(t time.Time) error { - return nil -} - func TestConsumingBodyOnNextConn(t *testing.T) { conn := new(testConn) for i := 0; i < 2; i++ { @@ -184,6 +194,7 @@ var vtests = []struct { } func TestHostHandlers(t *testing.T) { + defer afterTest(t) mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) @@ -256,28 +267,22 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - // TODO(bradfitz): convert this to use httptest.Server - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen error: %v", err) - } - addr, _ := l.Addr().(*net.TCPAddr) - + defer afterTest(t) reqNum := 0 - handler := HandlerFunc(func(res ResponseWriter, req *Request) { + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - }) - - server := &Server{Handler: handler, ReadTimeout: 250 * time.Millisecond, WriteTimeout: 250 * time.Millisecond} - go server.Serve(l) - - url := fmt.Sprintf("http://%s/", addr) + })) + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Start() + defer ts.Close() // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test + defer tr.CloseIdleConnections() c := &Client{Transport: tr} - r, err := c.Get(url) + r, err := c.Get(ts.URL) if err != nil { t.Fatalf("http Get #1: %v", err) } @@ -290,13 +295,13 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Now() - conn, err := net.Dial("tcp", addr.String()) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } buf := make([]byte, 1) n, err := conn.Read(buf) - latency := time.Now().Sub(t1) + latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) } @@ -307,7 +312,7 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, err = Get(url) + r, err = Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } @@ -317,11 +322,87 @@ func TestServerTimeouts(t *testing.T) { t.Errorf("Get #2 got %q, want %q", string(got), expected) } - l.Close() + if !testing.Short() { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + go io.Copy(ioutil.Discard, conn) + for i := 0; i < 5; i++ { + _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) + if err != nil { + t.Fatalf("on write %d: %v", i, err) + } + time.Sleep(ts.Config.ReadTimeout / 2) + } + } +} + +// golang.org/issue/4741 -- setting only a write timeout that triggers +// shouldn't cause a handler to block forever on reads (next HTTP +// request) that will never happen. +func TestOnlyWriteTimeout(t *testing.T) { + defer afterTest(t) + var conn net.Conn + var afterTimeoutErrc = make(chan error, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + buf := make([]byte, 512<<10) + _, err := w.Write(buf) + if err != nil { + t.Errorf("handler Write error: %v", err) + return + } + conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) + _, err = w.Write(buf) + afterTimeoutErrc <- err + })) + ts.Listener = trackLastConnListener{ts.Listener, &conn} + ts.Start() + defer ts.Close() + + tr := &Transport{DisableKeepAlives: false} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + errc := make(chan error) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + errc <- err + return + } + _, err = io.Copy(ioutil.Discard, res.Body) + errc <- err + }() + select { + case err := <-errc: + if err == nil { + t.Errorf("expected an error from Get request") + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Get error") + } + if err := <-afterTimeoutErrc; err == nil { + t.Error("expected write error after timeout") + } +} + +// trackLastConnListener tracks the last net.Conn that was accepted. +type trackLastConnListener struct { + net.Listener + last *net.Conn // destination +} + +func (l trackLastConnListener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + *l.last = c + return } // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { + defer afterTest(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -367,10 +448,12 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - _, err := Get(url) + res, err := Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } + res.Body.Close() + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -395,6 +478,7 @@ func TestIdentityResponse(t *testing.T) { } func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + defer afterTest(t) s := httptest.NewServer(h) defer s.Close() @@ -465,6 +549,7 @@ func TestHandlersCanSetConnectionClose10(t *testing.T) { } func TestSetsRemoteAddr(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) @@ -485,6 +570,7 @@ func TestSetsRemoteAddr(t *testing.T) { } func TestChunkedResponseHeaders(t *testing.T) { + defer afterTest(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) @@ -499,6 +585,7 @@ func TestChunkedResponseHeaders(t *testing.T) { if err != nil { t.Fatalf("Get error: %v", err) } + defer res.Body.Close() if g, e := res.ContentLength, int64(-1); g != e { t.Errorf("expected ContentLength of %d; got %d", e, g) } @@ -514,6 +601,7 @@ func TestChunkedResponseHeaders(t *testing.T) { // chunking in their response headers and aren't allowed to produce // output. func Test304Responses(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) @@ -543,6 +631,7 @@ func Test304Responses(t *testing.T) { // allowed to produce output, and don't set a Content-Type since // the real type of the body data cannot be inferred. func TestHeadResponses(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("Ignored body")) if err != ErrBodyNotAllowed { @@ -577,6 +666,7 @@ func TestHeadResponses(t *testing.T) { } func TestTLSHandshakeTimeout(t *testing.T) { + defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) ts.Config.ReadTimeout = 250 * time.Millisecond ts.StartTLS() @@ -596,6 +686,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { + defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") @@ -678,6 +769,7 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. func TestServerExpect(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only @@ -815,6 +907,7 @@ func TestServerUnreadRequestBodyLarge(t *testing.T) { } func TestTimeoutHandler(t *testing.T) { + defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -889,6 +982,7 @@ func TestRedirectMunging(t *testing.T) { // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. func TestZeroLengthPostAndResponse(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) if err != nil { @@ -939,6 +1033,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { } func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { + defer afterTest(t) // Unlike the other tests that set the log output to ioutil.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -958,6 +1053,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { pr, pw := io.Pipe() log.SetOutput(pw) defer log.SetOutput(os.Stderr) + defer pw.Close() ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if withHijack { @@ -979,7 +1075,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { buf := make([]byte, 4<<10) _, err := pr.Read(buf) pr.Close() - if err != nil { + if err != nil && err != io.EOF { t.Error(err) } done <- true @@ -1003,6 +1099,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) { } func TestNoDate(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Date"] = nil })) @@ -1018,6 +1115,7 @@ func TestNoDate(t *testing.T) { } func TestStripPrefix(t *testing.T) { + defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) }) @@ -1031,6 +1129,7 @@ func TestStripPrefix(t *testing.T) { if g, e := res.Header.Get("X-Path"), "/bar"; g != e { t.Errorf("test 1: got %s, want %s", g, e) } + res.Body.Close() res, err = Get(ts.URL + "/bar") if err != nil { @@ -1039,9 +1138,11 @@ func TestStripPrefix(t *testing.T) { if g, e := res.StatusCode, 404; g != e { t.Errorf("test 2: got status %v, want %v", g, e) } + res.Body.Close() } func TestRequestLimit(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") })) @@ -1058,6 +1159,7 @@ func TestRequestLimit(t *testing.T) { // we do support it (at least currently), so we expect a response below. t.Fatalf("Do: %v", err) } + defer res.Body.Close() if res.StatusCode != 413 { t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status) } @@ -1084,6 +1186,7 @@ func (cr countReader) Read(p []byte) (n int, err error) { } func TestRequestBodyLimit(t *testing.T) { + defer afterTest(t) const limit = 1 << 20 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) @@ -1120,6 +1223,7 @@ func TestRequestBodyLimit(t *testing.T) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. func TestClientWriteShutdown(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1174,6 +1278,7 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See http://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) })) @@ -1216,6 +1321,7 @@ func TestServerGracefulClose(t *testing.T) { } func TestCaseSensitiveMethod(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) @@ -1264,6 +1370,7 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { + defer afterTest(t) gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -1299,6 +1406,31 @@ For: ts.Close() } +func TestCloseNotifierChanLeak(t *testing.T) { + defer afterTest(t) + req := []byte(strings.Replace(`GET / HTTP/1.0 +Host: golang.org + +`, "\n", "\r\n", -1)) + for i := 0; i < 20; i++ { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + // Ignore the return value and never read from + // it, testing that we don't leak goroutines + // on the sending side: + _ = rw.(CloseNotifier).CloseNotify() + }) + go Serve(ln, handler) + <-conn.closec + } +} + func TestOptions(t *testing.T) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() @@ -1351,6 +1483,198 @@ func TestOptions(t *testing.T) { } } +// Tests regarding the ordering of Write, WriteHeader, Header, and +// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the +// (*response).header to the wire. In Go 1.1, the actual wire flush is +// delayed, so we could maybe tack on a Content-Length and better +// Content-Type after we see more (or all) of the output. To preserve +// compatibility with Go 1, we need to be careful to track which +// headers were live at the time of WriteHeader, so we write the same +// ones, even if the handler modifies them (~erroneously) after the +// first Write. +func TestHeaderToWire(t *testing.T) { + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + + tests := []struct { + name string + handler func(ResponseWriter, *Request) + check func(output string) error + }{ + { + name: "write without Header", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("no content-length") + } + return nil + }, + }, + { + name: "Header mutation before write", + handler: func(rw ResponseWriter, r *Request) { + h := rw.Header() + h.Set("Content-Type", "some/type") + rw.Write([]byte("hello world")) + h.Set("Too-Late", "bogus") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Length:") { + return errors.New("no content-length") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-type") + } + if strings.Contains(got, "Too-Late") { + return errors.New("don't want too-late header") + } + return nil + }, + }, + { + name: "write then useless Header mutation", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("hello world")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "flush then write", + handler: func(rw ResponseWriter, r *Request) { + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + return nil + }, + }, + { + name: "header then flush", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.(Flusher).Flush() + rw.Write([]byte("post-flush")) + rw.Header().Set("Too-Late", "Write already wrote headers") + }, + check: func(got string) error { + if !strings.Contains(got, "Transfer-Encoding: chunked") { + return errors.New("not chunked") + } + if strings.Contains(got, "Too-Late") { + return errors.New("header appeared from after WriteHeader") + } + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length") + } + return nil + }, + }, + { + name: "sniff-on-first-write content-type", + handler: func(rw ResponseWriter, r *Request) { + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + rw.Header().Set("Content-Type", "x/wrong") + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/html") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "explicit content-type wins", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "some/type") + rw.Write([]byte("<html><head></head><body>some html</body></html>")) + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: some/type") { + return errors.New("wrong content-length; want html") + } + return nil + }, + }, + { + name: "empty handler", + handler: func(rw ResponseWriter, r *Request) { + }, + check: func(got string) error { + if !strings.Contains(got, "Content-Type: text/plain") { + return errors.New("wrong content-length; want text/plain") + } + if !strings.Contains(got, "Content-Length: 0") { + return errors.New("want 0 content-length") + } + return nil + }, + }, + { + name: "only Header, no write", + handler: func(rw ResponseWriter, r *Request) { + rw.Header().Set("Some-Header", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "Some-Header") { + return errors.New("didn't get header") + } + return nil + }, + }, + { + name: "WriteHeader call", + handler: func(rw ResponseWriter, r *Request) { + rw.WriteHeader(404) + rw.Header().Set("Too-Late", "some-value") + }, + check: func(got string) error { + if !strings.Contains(got, "404") { + return errors.New("wrong status") + } + if strings.Contains(got, "Some-Header") { + return errors.New("shouldn't have seen Too-Late") + } + return nil + }, + }, + } + for _, tc := range tests { + var output bytes.Buffer + conn := &rwTestConn{ + Reader: bytes.NewReader(req), + Writer: &output, + closec: make(chan bool, 1), + } + ln := &oneConnListener{conn: conn} + go Serve(ln, HandlerFunc(tc.handler)) + <-conn.closec + if err := tc.check(output.String()); err != nil { + t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, output.Bytes()) + } + } +} + // goTimeout runs f, failing t if f takes more than ns to complete. func goTimeout(t *testing.T, d time.Duration, f func()) { ch := make(chan bool, 2) @@ -1524,3 +1848,179 @@ func BenchmarkServer(b *testing.B) { b.Errorf("Test failure: %v, with output: %s", err, out) } } + +func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) { + b.ReportAllocs() + req := []byte(strings.Replace(`GET / HTTP/1.0 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &testConn{ + // testConn.Close will not push into the channel + // if it's full. + closec: make(chan bool, 1), + } + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := new(oneConnListener) + for i := 0; i < b.N; i++ { + conn.readBuf.Reset() + conn.writeBuf.Reset() + conn.readBuf.Write(req) + ln.conn = conn + Serve(ln, handler) + <-conn.closec + } +} + +// repeatReader reads content count times, then EOFs. +type repeatReader struct { + content []byte + count int + off int +} + +func (r *repeatReader) Read(p []byte) (n int, err error) { + if r.count <= 0 { + return 0, io.EOF + } + n = copy(p, r.content[r.off:]) + r.off += n + if r.off == len(r.content) { + r.count-- + r.off = 0 + } + return +} + +func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) { + b.ReportAllocs() + + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org +Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8 +User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17 +Accept-Encoding: gzip,deflate,sdch +Accept-Language: en-US,en;q=0.8 +Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +// same as above, but representing the most simple possible request +// and handler. Notably: the handler does not call rw.Header(). +func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) { + b.ReportAllocs() + + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + res := []byte("Hello world!\n") + + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + rw.Write(res) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} + +const someResponse = "<html>some response</html>" + +// A Response that's just no bigger than 2KB, the buffer-before-chunking threshold. +var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse)) + +// Both Content-Type and Content-Length set. Should be no buffering. +func BenchmarkServerHandlerTypeLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// A Content-Type is set, but no length. No sniffing, but will count the Content-Length. +func BenchmarkServerHandlerNoLen(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Type", "text/html") + w.Write(response) + })) +} + +// A Content-Length is set, but the Content-Type will be sniffed. +func BenchmarkServerHandlerNoType(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Length", strconv.Itoa(len(response))) + w.Write(response) + })) +} + +// Neither a Content-Type or Content-Length, so sniffed and counted. +func BenchmarkServerHandlerNoHeader(b *testing.B) { + benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Write(response) + })) +} + +func benchmarkHandler(b *testing.B, h Handler) { + b.ReportAllocs() + req := []byte(strings.Replace(`GET / HTTP/1.1 +Host: golang.org + +`, "\n", "\r\n", -1)) + conn := &rwTestConn{ + Reader: &repeatReader{content: req, count: b.N}, + Writer: ioutil.Discard, + closec: make(chan bool, 1), + } + handled := 0 + handler := HandlerFunc(func(rw ResponseWriter, r *Request) { + handled++ + h.ServeHTTP(rw, r) + }) + ln := &oneConnListener{conn: conn} + go Serve(ln, handler) + <-conn.closec + if b.N != handled { + b.Errorf("b.N=%d but handled %d", b.N, handled) + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 434943d..b259607 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -4,9 +4,6 @@ // HTTP server. See RFC 2616. -// TODO(rsc): -// logging - package http import ( @@ -109,9 +106,11 @@ type conn struct { remoteAddr string // network address of remote side server *Server // the Server on which the connection arrived rwc net.Conn // i/o connection - sr switchReader // where the LimitReader reads from; usually the rwc + sr liveSwitchReader // where the LimitReader reads from; usually the rwc lr *io.LimitedReader // io.LimitReader(sr) buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc + bufswr *switchReader // the *switchReader io.Reader source of buf + bufsww *switchWriter // the *switchWriter io.Writer dest of buf tlsState *tls.ConnectionState // or nil when not using TLS mu sync.Mutex // guards the following @@ -147,7 +146,7 @@ func (c *conn) closeNotify() <-chan bool { c.mu.Lock() defer c.mu.Unlock() if c.closeNotifyc == nil { - c.closeNotifyc = make(chan bool) + c.closeNotifyc = make(chan bool, 1) if c.hijackedv { // to obey the function signature, even though // it'll never receive a value. @@ -180,12 +179,26 @@ func (c *conn) noteClientGone() { c.clientGone = true } +// A switchReader can have its Reader changed at runtime. +// It's not safe for concurrent Reads and switches. type switchReader struct { + io.Reader +} + +// A switchWriter can have its Writer changed at runtime. +// It's not safe for concurrent Writes and switches. +type switchWriter struct { + io.Writer +} + +// A liveSwitchReader is a switchReader that's safe for concurrent +// reads and switches, if its mutex is held. +type liveSwitchReader struct { sync.Mutex r io.Reader } -func (sr *switchReader) Read(p []byte) (n int, err error) { +func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { sr.Lock() r := sr.r sr.Unlock() @@ -206,15 +219,28 @@ const bufferBeforeChunkingSize = 2048 // // See the comment above (*response).Write for the entire write flow. type chunkWriter struct { - res *response - header Header // a deep copy of r.Header, once WriteHeader is called - wroteHeader bool // whether the header's been sent + res *response + + // header is either nil or a deep clone of res.handlerHeader + // at the time of res.WriteHeader, if res.WriteHeader is + // called and extra buffering is being done to calculate + // Content-Type and/or Content-Length. + header Header + + // wroteHeader tells whether the header's been written to "the + // wire" (or rather: w.conn.buf). this is unlike + // (*response).wroteHeader, which tells only whether it was + // logically written. + wroteHeader bool // set by the writeHeader method: chunking bool // using chunked transfer encoding for reply body } -var crlf = []byte("\r\n") +var ( + crlf = []byte("\r\n") + colonSpace = []byte(": ") +) func (cw *chunkWriter) Write(p []byte) (n int, err error) { if !cw.wroteHeader { @@ -223,6 +249,7 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { if cw.chunking { _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) if err != nil { + cw.res.conn.rwc.Close() return } } @@ -230,6 +257,9 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { if cw.chunking && err == nil { _, err = cw.res.conn.buf.Write(crlf) } + if err != nil { + cw.res.conn.rwc.Close() + } return } @@ -260,13 +290,15 @@ type response struct { wroteContinue bool // 100 Continue response was written w *bufio.Writer // buffers output in chunks to chunkWriter - cw *chunkWriter + cw chunkWriter + sw *switchWriter // of the bufio.Writer, for return to putBufioWriter // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. // handlerHeader is copied into cw.header at WriteHeader // time, and privately mutated thereafter. handlerHeader Header + calledHeader bool // handler accessed handlerHeader via Header written int64 // number of bytes written in body contentLength int64 // explicitly-declared Content-Length; or -1 @@ -358,14 +390,98 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } - c.sr = switchReader{r: c.rwc} + c.sr = liveSwitchReader{r: c.rwc} c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br := bufio.NewReader(c.lr) - bw := bufio.NewWriter(c.rwc) + br, sr := newBufioReader(c.lr) + bw, sw := newBufioWriterSize(c.rwc, 4<<10) c.buf = bufio.NewReadWriter(br, bw) + c.bufswr = sr + c.bufsww = sw return c, nil } +// TODO: remove this, if issue 5100 is fixed +type bufioReaderPair struct { + br *bufio.Reader + sr *switchReader // from which the bufio.Reader is reading +} + +// TODO: remove this, if issue 5100 is fixed +type bufioWriterPair struct { + bw *bufio.Writer + sw *switchWriter // to which the bufio.Writer is writing +} + +// TODO: use a sync.Cache instead +var ( + bufioReaderCache = make(chan bufioReaderPair, 4) + bufioWriterCache2k = make(chan bufioWriterPair, 4) + bufioWriterCache4k = make(chan bufioWriterPair, 4) +) + +func bufioWriterCache(size int) chan bufioWriterPair { + switch size { + case 2 << 10: + return bufioWriterCache2k + case 4 << 10: + return bufioWriterCache4k + } + return nil +} + +func newBufioReader(r io.Reader) (*bufio.Reader, *switchReader) { + select { + case p := <-bufioReaderCache: + p.sr.Reader = r + return p.br, p.sr + default: + sr := &switchReader{r} + return bufio.NewReader(sr), sr + } +} + +func putBufioReader(br *bufio.Reader, sr *switchReader) { + if n := br.Buffered(); n > 0 { + io.CopyN(ioutil.Discard, br, int64(n)) + } + br.Read(nil) // clears br.err + sr.Reader = nil + select { + case bufioReaderCache <- bufioReaderPair{br, sr}: + default: + } +} + +func newBufioWriterSize(w io.Writer, size int) (*bufio.Writer, *switchWriter) { + select { + case p := <-bufioWriterCache(size): + p.sw.Writer = w + return p.bw, p.sw + default: + sw := &switchWriter{w} + return bufio.NewWriterSize(sw, size), sw + } +} + +func putBufioWriter(bw *bufio.Writer, sw *switchWriter) { + if bw.Buffered() > 0 { + // It must have failed to flush to its target + // earlier. We can't reuse this bufio.Writer. + return + } + if err := bw.Flush(); err != nil { + // Its sticky error field is set, which is returned by + // Flush even when there's no data buffered. This + // bufio Writer is dead to us. Don't reuse it. + return + } + sw.Writer = nil + select { + case bufioWriterCache(bw.Available()) <- bufioWriterPair{bw, sw}: + default: + } +} + // DefaultMaxHeaderBytes is the maximum permitted size of the headers // in an HTTP request. // This can be overridden by setting Server.MaxHeaderBytes. @@ -416,6 +532,16 @@ func (c *conn) readRequest() (w *response, err error) { if c.hijacked() { return nil, ErrHijacked } + + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + defer func() { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + }() + } + c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { @@ -434,14 +560,20 @@ func (c *conn) readRequest() (w *response, err error) { req: req, handlerHeader: make(Header), contentLength: -1, - cw: new(chunkWriter), } w.cw.res = w - w.w = bufio.NewWriterSize(w.cw, bufferBeforeChunkingSize) + w.w, w.sw = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return w, nil } func (w *response) Header() Header { + if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { + // Accessing the header between logically writing it + // and physically writing it means we need to allocate + // a clone to snapshot the logically written state. + w.cw.header = w.handlerHeader.clone() + } + w.calledHeader = true return w.handlerHeader } @@ -468,15 +600,48 @@ func (w *response) WriteHeader(code int) { w.wroteHeader = true w.status = code - w.cw.header = w.handlerHeader.clone() + if w.calledHeader && w.cw.header == nil { + w.cw.header = w.handlerHeader.clone() + } - if cl := w.cw.header.get("Content-Length"); cl != "" { + if cl := w.handlerHeader.get("Content-Length"); cl != "" { v, err := strconv.ParseInt(cl, 10, 64) if err == nil && v >= 0 { w.contentLength = v } else { log.Printf("http: invalid Content-Length of %q", cl) - w.cw.header.Del("Content-Length") + w.handlerHeader.Del("Content-Length") + } + } +} + +// extraHeader is the set of headers sometimes added by chunkWriter.writeHeader. +// This type is used to avoid extra allocations from cloning and/or populating +// the response Header map and all its 1-element slices. +type extraHeader struct { + contentType string + contentLength string + connection string + date string + transferEncoding string +} + +// Sorted the same as extraHeader.Write's loop. +var extraHeaderKeys = [][]byte{ + []byte("Content-Type"), []byte("Content-Length"), + []byte("Connection"), []byte("Date"), []byte("Transfer-Encoding"), +} + +// The value receiver, despite copying 5 strings to the stack, +// prevents an extra allocation. The escape analysis isn't smart +// enough to realize this doesn't mutate h. +func (h extraHeader) Write(w io.Writer) { + for i, v := range []string{h.contentType, h.contentLength, h.connection, h.date, h.transferEncoding} { + if v != "" { + w.Write(extraHeaderKeys[i]) + w.Write(colonSpace) + io.WriteString(w, v) + w.Write(crlf) } } } @@ -496,23 +661,47 @@ func (cw *chunkWriter) writeHeader(p []byte) { cw.wroteHeader = true w := cw.res - code := w.status - done := w.handlerDone + + // header is written out to w.conn.buf below. Depending on the + // state of the handler, we either own the map or not. If we + // don't own it, the exclude map is created lazily for + // WriteSubset to remove headers. The setHeader struct holds + // headers we need to add. + header := cw.header + owned := header != nil + if !owned { + header = w.handlerHeader + } + var excludeHeader map[string]bool + delHeader := func(key string) { + if owned { + header.Del(key) + return + } + if _, ok := header[key]; !ok { + return + } + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[key] = true + } + var setHeader extraHeader // If the handler is done but never sent a Content-Length // response header and this is our first (and last) write, set // it, even to zero. This helps HTTP/1.0 clients keep their // "keep-alive" connections alive. - if done && cw.header.get("Content-Length") == "" && w.req.Method != "HEAD" { + if w.handlerDone && header.get("Content-Length") == "" && w.req.Method != "HEAD" { w.contentLength = int64(len(p)) - cw.header.Set("Content-Length", strconv.Itoa(len(p))) + setHeader.contentLength = strconv.Itoa(len(p)) } // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... if w.req.wantsHttp10KeepAlive() { - sentLength := cw.header.get("Content-Length") != "" - if sentLength && cw.header.get("Connection") == "keep-alive" { + sentLength := header.get("Content-Length") != "" + if sentLength && header.get("Connection") == "keep-alive" { w.closeAfterReply = false } } @@ -521,15 +710,15 @@ func (cw *chunkWriter) writeHeader(p []byte) { hasCL := w.contentLength != -1 if w.req.wantsHttp10KeepAlive() && (w.req.Method == "HEAD" || hasCL) { - _, connectionHeaderSet := cw.header["Connection"] + _, connectionHeaderSet := header["Connection"] if !connectionHeaderSet { - cw.header.Set("Connection", "keep-alive") + setHeader.connection = "keep-alive" } } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { w.closeAfterReply = true } - if cw.header.get("Connection") == "close" { + if header.get("Connection") == "close" { w.closeAfterReply = true } @@ -543,49 +732,49 @@ func (cw *chunkWriter) writeHeader(p []byte) { n, _ := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1) if n >= maxPostHandlerReadBytes { w.requestTooLarge() - cw.header.Set("Connection", "close") + delHeader("Connection") + setHeader.connection = "close" } else { w.req.Body.Close() } } } + code := w.status if code == StatusNotModified { // Must not have body. - for _, header := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { - // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" - if cw.header.get(header) != "" { - cw.header.Del(header) - } + // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers" + for _, k := range []string{"Content-Type", "Content-Length", "Transfer-Encoding"} { + delHeader(k) } } else { // If no content type, apply sniffing algorithm to body. - if cw.header.get("Content-Type") == "" && w.req.Method != "HEAD" { - cw.header.Set("Content-Type", DetectContentType(p)) + if header.get("Content-Type") == "" && w.req.Method != "HEAD" { + setHeader.contentType = DetectContentType(p) } } - if _, ok := cw.header["Date"]; !ok { - cw.header.Set("Date", time.Now().UTC().Format(TimeFormat)) + if _, ok := header["Date"]; !ok { + setHeader.date = time.Now().UTC().Format(TimeFormat) } - te := cw.header.get("Transfer-Encoding") + te := header.get("Transfer-Encoding") hasTE := te != "" if hasCL && hasTE && te != "identity" { // TODO: return an error if WriteHeader gets a return parameter // For now just ignore the Content-Length. log.Printf("http: WriteHeader called with both Transfer-Encoding of %q and a Content-Length of %d", te, w.contentLength) - cw.header.Del("Content-Length") + delHeader("Content-Length") hasCL = false } if w.req.Method == "HEAD" || code == StatusNotModified { // do nothing } else if code == StatusNoContent { - cw.header.Del("Transfer-Encoding") + delHeader("Transfer-Encoding") } else if hasCL { - cw.header.Del("Transfer-Encoding") + delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -593,29 +782,63 @@ func (cw *chunkWriter) writeHeader(p []byte) { // might have set. Deal with that as need arises once we have a valid // use case. cw.chunking = true - cw.header.Set("Transfer-Encoding", "chunked") + setHeader.transferEncoding = "chunked" } else { // HTTP version < 1.1: cannot do chunked transfer // encoding and we don't know the Content-Length so // signal EOF by closing connection. w.closeAfterReply = true - cw.header.Del("Transfer-Encoding") // in case already set + delHeader("Transfer-Encoding") // in case already set } // Cannot use Content-Length with non-identity Transfer-Encoding. if cw.chunking { - cw.header.Del("Content-Length") + delHeader("Content-Length") } if !w.req.ProtoAtLeast(1, 0) { return } if w.closeAfterReply && !hasToken(cw.header.get("Connection"), "close") { - cw.header.Set("Connection", "close") + delHeader("Connection") + setHeader.connection = "close" } + io.WriteString(w.conn.buf, statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.buf, excludeHeader) + setHeader.Write(w.conn.buf) + w.conn.buf.Write(crlf) +} + +// statusLines is a cache of Status-Line strings, keyed by code (for +// HTTP/1.1) or negative code (for HTTP/1.0). This is faster than a +// map keyed by struct of two fields. This map's max size is bounded +// by 2*len(statusText), two protocol types for each known official +// status code in the statusText map. +var ( + statusMu sync.RWMutex + statusLines = make(map[int]string) +) + +// statusLine returns a response Status-Line (RFC 2616 Section 6.1) +// for the given request and response status code. +func statusLine(req *Request, code int) string { + // Fast path: + key := code + proto11 := req.ProtoAtLeast(1, 1) + if !proto11 { + key = -key + } + statusMu.RLock() + line, ok := statusLines[key] + statusMu.RUnlock() + if ok { + return line + } + + // Slow path: proto := "HTTP/1.0" - if w.req.ProtoAtLeast(1, 1) { + if proto11 { proto = "HTTP/1.1" } codestring := strconv.Itoa(code) @@ -623,9 +846,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { if !ok { text = "status code " + codestring } - io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - cw.header.Write(w.conn.buf) - w.conn.buf.Write(crlf) + line = proto + " " + codestring + " " + text + "\r\n" + if ok { + statusMu.Lock() + defer statusMu.Unlock() + statusLines[key] = line + } + return line } // bodyAllowed returns true if a Write is allowed for this response type. @@ -641,7 +868,7 @@ func (w *response) bodyAllowed() bool { // // Handler starts. No header has been sent. The handler can either // write a header, or just start writing. Writing before sending a header -// sends an implicity empty 200 OK header. +// sends an implicitly empty 200 OK header. // // If the handler didn't declare a Content-Length up front, we either // go into chunking mode or, if the handler finishes running before @@ -699,6 +926,7 @@ func (w *response) finishRequest() { } w.w.Flush() + putBufioWriter(w.w, w.sw) w.cw.close() w.conn.buf.Flush() @@ -728,6 +956,15 @@ func (w *response) Flush() { func (c *conn) finalFlush() { if c.buf != nil { c.buf.Flush() + + // Steal the bufio.Reader (~4KB worth of memory) and its associated + // reader for a future connection. + putBufioReader(c.buf.Reader, c.bufswr) + + // Steal the bufio.Writer (~4KB worth of memory) and its associated + // writer for a future connection. + putBufioWriter(c.buf.Writer, c.bufsww) + c.buf = nil } } @@ -764,6 +1001,18 @@ func (c *conn) closeWriteAndWait() { time.Sleep(rstAvoidanceDelay) } +// validNPN returns whether the proto is not a blacklisted Next +// Protocol Negotiation protocol. Empty and built-in protocol types +// are blacklisted and can't be overridden with alternate +// implementations. +func validNPN(proto string) bool { + switch proto { + case "", "http/1.1", "http/1.0": + return false + } + return true +} + // Serve a new connection. func (c *conn) serve() { defer func() { @@ -779,11 +1028,24 @@ func (c *conn) serve() { }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { + if d := c.server.ReadTimeout; d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + } + if d := c.server.WriteTimeout; d != 0 { + c.rwc.SetWriteDeadline(time.Now().Add(d)) + } if err := tlsConn.Handshake(); err != nil { return } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() + if proto := c.tlsState.NegotiatedProtocol; validNPN(proto) { + if fn := c.server.TLSNextProto[proto]; fn != nil { + h := initNPNRequest{tlsConn, serverHandler{c.server}} + fn(c.server, tlsConn, h) + } + return + } } for { @@ -826,20 +1088,12 @@ func (c *conn) serve() { break } - handler := c.server.Handler - if handler == nil { - handler = DefaultServeMux - } - if req.RequestURI == "*" && req.Method == "OPTIONS" { - handler = globalOptionsHandler{} - } - // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. - handler.ServeHTTP(w, w.req) + serverHandler{c.server}.ServeHTTP(w, w.req) if c.hijacked() { return } @@ -917,13 +1171,16 @@ func NotFoundHandler() Handler { return HandlerFunc(NotFound) } // request for a path that doesn't begin with prefix by // replying with an HTTP 404 not found error. func StripPrefix(prefix string, h Handler) Handler { + if prefix == "" { + return h + } return HandlerFunc(func(w ResponseWriter, r *Request) { - if !strings.HasPrefix(r.URL.Path, prefix) { + if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + r.URL.Path = p + h.ServeHTTP(w, r) + } else { NotFound(w, r) - return } - r.URL.Path = r.URL.Path[len(prefix):] - h.ServeHTTP(w, r) }) } @@ -965,9 +1222,9 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } // clean up but preserve trailing slash - trailing := urlStr[len(urlStr)-1] == '/' + trailing := strings.HasSuffix(urlStr, "/") urlStr = path.Clean(urlStr) - if trailing && urlStr[len(urlStr)-1] != '/' { + if trailing && !strings.HasSuffix(urlStr, "/") { urlStr += "/" } urlStr += query @@ -1232,6 +1489,32 @@ type Server struct { WriteTimeout time.Duration // maximum duration before timing out write of the response MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // TLSNextProto optionally specifies a function to take over + // ownership of the provided TLS connection when an NPN + // protocol upgrade has occurred. The map key is the protocol + // name negotiated. The Handler argument should be used to + // handle HTTP requests and will initialize the Request's TLS + // and RemoteAddr if not already set. The connection is + // automatically closed when the function returns. + TLSNextProto map[string]func(*Server, *tls.Conn, Handler) +} + +// serverHandler delegates to either the server's Handler or +// DefaultServeMux and also handles "OPTIONS *" requests. +type serverHandler struct { + srv *Server +} + +func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) { + handler := sh.srv.Handler + if handler == nil { + handler = DefaultServeMux + } + if req.RequestURI == "*" && req.Method == "OPTIONS" { + handler = globalOptionsHandler{} + } + handler.ServeHTTP(rw, req) } // ListenAndServe listens on the TCP network address srv.Addr and then @@ -1274,19 +1557,12 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - if srv.ReadTimeout != 0 { - rw.SetReadDeadline(time.Now().Add(srv.ReadTimeout)) - } - if srv.WriteTimeout != 0 { - rw.SetWriteDeadline(time.Now().Add(srv.WriteTimeout)) - } c, err := srv.newConn(rw) if err != nil { continue } go c.serve() } - panic("not reached") } // ListenAndServe listens on the TCP network address addr @@ -1425,7 +1701,7 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { - done := make(chan bool) + done := make(chan bool, 1) tw := &timeoutWriter{w: w} go func() { h.handler.ServeHTTP(tw, r) @@ -1494,6 +1770,31 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } +// eofReader is a non-nil io.ReadCloser that always returns EOF. +var eofReader = ioutil.NopCloser(strings.NewReader("")) + +// initNPNRequest is an HTTP handler that initializes certain +// uninitialized fields in its *Request. Such partially-initialized +// Requests come from NPN protocol handlers. +type initNPNRequest struct { + c *tls.Conn + h serverHandler +} + +func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { + if req.TLS == nil { + req.TLS = &tls.ConnectionState{} + *req.TLS = h.c.ConnectionState() + } + if req.Body == nil { + req.Body = eofReader + } + if req.RemoteAddr == "" { + req.RemoteAddr = h.c.RemoteAddr().String() + } + h.h.ServeHTTP(rw, req) +} + // loggingConn is used for debugging. type loggingConn struct { name string diff --git a/libgo/go/net/http/server_test.go b/libgo/go/net/http/server_test.go index 8b4e8c6..e8b69f7 100644 --- a/libgo/go/net/http/server_test.go +++ b/libgo/go/net/http/server_test.go @@ -2,9 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package http +package http_test import ( + . "net/http" + "net/http/httptest" "net/url" "testing" ) @@ -76,20 +78,27 @@ func TestServeMuxHandler(t *testing.T) { }, } h, pattern := mux.Handler(r) - cs := &codeSaver{h: Header{}} - h.ServeHTTP(cs, r) - if pattern != tt.pattern || cs.code != tt.code { - t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, cs.code, pattern, tt.code, tt.pattern) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, r) + if pattern != tt.pattern || rr.Code != tt.code { + t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern) } } } -// A codeSaver is a ResponseWriter that saves the code passed to WriteHeader. -type codeSaver struct { - h Header - code int +func TestServerRedirect(t *testing.T) { + // This used to crash. It's not valid input (bad path), but it + // shouldn't crash. + rr := httptest.NewRecorder() + req := &Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Path: "not-empty-but-no-leading-slash", // bogus + }, + } + Redirect(rr, req, "", 304) + if rr.Code != 304 { + t.Errorf("Code = %d; want 304", rr.Code) + } } - -func (cs *codeSaver) Header() Header { return cs.h } -func (cs *codeSaver) Write(p []byte) (int, error) { return len(p), nil } -func (cs *codeSaver) WriteHeader(code int) { cs.code = code } diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index 8ab72ac..106d94ec 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -54,6 +54,7 @@ func TestDetectContentType(t *testing.T) { } func TestServerContentType(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] @@ -84,6 +85,8 @@ func TestServerContentType(t *testing.T) { } func TestContentTypeWithCopy(t *testing.T) { + defer afterTest(t) + const ( input = "\n<html>\n\t<head>\n" expected = "text/html; charset=utf-8" @@ -116,6 +119,7 @@ func TestContentTypeWithCopy(t *testing.T) { } func TestSniffWriteSize(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) @@ -133,6 +137,11 @@ func TestSniffWriteSize(t *testing.T) { if err != nil { t.Fatalf("size %d: %v", size, err) } - res.Body.Close() + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatalf("size %d: io.Copy of body = %v", size, err) + } + if err := res.Body.Close(); err != nil { + t.Fatalf("size %d: body Close = %v", size, err) + } } } diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go index 5af0b77..d253bd5 100644 --- a/libgo/go/net/http/status.go +++ b/libgo/go/net/http/status.go @@ -51,6 +51,13 @@ const ( StatusServiceUnavailable = 503 StatusGatewayTimeout = 504 StatusHTTPVersionNotSupported = 505 + + // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1. + // See discussion at https://codereview.appspot.com/7678043/ + statusPreconditionRequired = 428 + statusTooManyRequests = 429 + statusRequestHeaderFieldsTooLarge = 431 + statusNetworkAuthenticationRequired = 511 ) var statusText = map[int]string{ @@ -99,6 +106,11 @@ var statusText = map[int]string{ StatusServiceUnavailable: "Service Unavailable", StatusGatewayTimeout: "Gateway Timeout", StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + + statusPreconditionRequired: "Precondition Required", + statusTooManyRequests: "Too Many Requests", + statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", + statusNetworkAuthenticationRequired: "Network Authentication Required", } // StatusText returns a text for the HTTP status code. It returns the empty diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 25b34ad..53569bc 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -194,10 +194,11 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { ncopy, err = io.Copy(w, t.Body) } else { ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) - nextra, err := io.Copy(ioutil.Discard, t.Body) if err != nil { return err } + var nextra int64 + nextra, err = io.Copy(ioutil.Discard, t.Body) ncopy += nextra } if err != nil { @@ -208,7 +209,7 @@ func (t *transferWriter) WriteBody(w io.Writer) (err error) { } } - if t.ContentLength != -1 && t.ContentLength != ncopy { + if !t.ResponseToHEAD && t.ContentLength != -1 && t.ContentLength != ncopy { return fmt.Errorf("http: Request.ContentLength=%d with Body length %d", t.ContentLength, ncopy) } @@ -326,9 +327,14 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} - case realLength >= 0: - // TODO: limit the Content-Length. This is an easy DoS vector. + if noBodyExpected(t.RequestMethod) { + t.Body = &body{Reader: eofReader, closing: t.Close} + } else { + t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close} + } + case realLength == 0: + t.Body = &body{Reader: eofReader, closing: t.Close} + case realLength > 0: t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close} default: // realLength < 0, i.e. "Content-Length" not mentioned in header @@ -337,7 +343,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{Reader: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = &body{Reader: io.LimitReader(r, 0), closing: t.Close} + t.Body = &body{Reader: eofReader, closing: t.Close} } } @@ -449,13 +455,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - // Logic based on media type. The purpose of the following code is just - // to detect whether the unsupported "multipart/byteranges" is being - // used. A proper Content-Type parser is needed in the future. - if strings.Contains(strings.ToLower(header.get("Content-Type")), "multipart/byteranges") { - return -1, ErrNotSupported - } - // Body-EOF logic based on other methods (like closing, or chunked coding) return -1, nil } @@ -614,30 +613,26 @@ func (b *body) Close() error { if b.closed { return nil } - defer func() { - b.closed = true - }() - if b.hdr == nil && b.closing { + var err error + switch { + case b.hdr == nil && b.closing: // no trailer and closing the connection next. // no point in reading to EOF. - return nil - } - - // In a server request, don't continue reading from the client - // if we've already hit the maximum body size set by the - // handler. If this is set, that also means the TCP connection - // is about to be closed, so getting to the next HTTP request - // in the stream is not necessary. - if b.res != nil && b.res.requestBodyLimitHit { - return nil - } - - // Fully consume the body, which will also lead to us reading - // the trailer headers after the body, if present. - if _, err := io.Copy(ioutil.Discard, b); err != nil { - return err + case b.res != nil && b.res.requestBodyLimitHit: + // In a server request, don't continue reading from the client + // if we've already hit the maximum body size set by the + // handler. If this is set, that also means the TCP connection + // is about to be closed, so getting to the next HTTP request + // in the stream is not necessary. + case b.Reader == eofReader: + // Nothing to read. No need to io.Copy from it. + default: + // Fully consume the body, which will also lead to us reading + // the trailer headers after the body, if present. + _, err = io.Copy(ioutil.Discard, b) } - return nil + b.closed = true + return err } // parseContentLength trims whitespace from s and returns -1 if no value diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 98e198e..4cd0533 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -17,7 +17,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/url" @@ -42,14 +41,13 @@ const DefaultMaxIdleConnsPerHost = 2 // https, and http proxies (for either http or https with CONNECT). // Transport can also cache connections for future re-use. type Transport struct { - idleLk sync.Mutex - idleConn map[string][]*persistConn - altLk sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper - - // TODO: tunable on global max cached connections - // TODO: tunable on timeout on cached connections - // TODO: optional pipelining + idleMu sync.Mutex + idleConn map[string][]*persistConn + idleConnCh map[string]chan *persistConn + reqMu sync.Mutex + reqConn map[*Request]*persistConn + altMu sync.RWMutex + altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -60,19 +58,39 @@ type Transport struct { // Dial specifies the dial function for creating TCP // connections. // If Dial is nil, net.Dial is used. - Dial func(net, addr string) (c net.Conn, err error) + Dial func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with // tls.Client. If nil, the default configuration is used. TLSClientConfig *tls.Config - DisableKeepAlives bool + // DisableKeepAlives, if true, prevents re-use of TCP connections + // between different HTTP requests. + DisableKeepAlives bool + + // DisableCompression, if true, prevents the Transport from + // requesting compression with an "Accept-Encoding: gzip" + // request header when the Request contains no existing + // Accept-Encoding value. If the Transport requests gzip on + // its own and gets a gzipped response, it's transparently + // decoded in the Response.Body. However, if the user + // explicitly requested gzip it is not automatically + // uncompressed. DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle // (keep-alive) to keep per-host. If zero, // DefaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int + + // ResponseHeaderTimeout, if non-zero, specifies the amount of + // time to wait for a server's response headers after fully + // writing the request (including its body, if any). This + // time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + + // TODO: tunable on global max cached connections + // TODO: tunable on timeout on cached connections } // ProxyFromEnvironment returns the URL of the proxy to use for a @@ -125,6 +143,9 @@ func (tr *transportRequest) extraHeaders() Header { } // RoundTrip implements the RoundTripper interface. +// +// For higher-level HTTP client support (such as handling of cookies +// and redirects), see Get, Post, and the Client type. func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { if req.URL == nil { return nil, errors.New("http: nil Request.URL") @@ -133,12 +154,12 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) { return nil, errors.New("http: nil Request.Header") } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { - t.altLk.RLock() + t.altMu.RLock() var rt RoundTripper if t.altProto != nil { rt = t.altProto[req.URL.Scheme] } - t.altLk.RUnlock() + t.altMu.RUnlock() if rt == nil { return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme} } @@ -175,8 +196,8 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { if scheme == "http" || scheme == "https" { panic("protocol " + scheme + " already registered") } - t.altLk.Lock() - defer t.altLk.Unlock() + t.altMu.Lock() + defer t.altMu.Unlock() if t.altProto == nil { t.altProto = make(map[string]RoundTripper) } @@ -191,10 +212,10 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { // a "keep-alive" state. It does not interrupt any connections currently // in use. func (t *Transport) CloseIdleConnections() { - t.idleLk.Lock() + t.idleMu.Lock() m := t.idleConn t.idleConn = nil - t.idleLk.Unlock() + t.idleMu.Unlock() if m == nil { return } @@ -205,6 +226,17 @@ func (t *Transport) CloseIdleConnections() { } } +// CancelRequest cancels an in-flight request by closing its +// connection. +func (t *Transport) CancelRequest(req *Request) { + t.reqMu.Lock() + pc := t.reqConn[req] + t.reqMu.Unlock() + if pc != nil { + pc.conn.Close() + } +} + // // Private implementation past this point. // @@ -260,12 +292,23 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { if max == 0 { max = DefaultMaxIdleConnsPerHost } - t.idleLk.Lock() + t.idleMu.Lock() + select { + case t.idleConnCh[key] <- pconn: + // We're done with this pconn and somebody else is + // currently waiting for a conn of this type (they're + // actively dialing, but this conn is ready + // first). Chrome calls this socket late binding. See + // https://insouciant.org/tech/connection-management-in-chromium/ + t.idleMu.Unlock() + return true + default: + } if t.idleConn == nil { t.idleConn = make(map[string][]*persistConn) } if len(t.idleConn[key]) >= max { - t.idleLk.Unlock() + t.idleMu.Unlock() pconn.close() return false } @@ -275,14 +318,29 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool { } } t.idleConn[key] = append(t.idleConn[key], pconn) - t.idleLk.Unlock() + t.idleMu.Unlock() return true } +func (t *Transport) getIdleConnCh(cm *connectMethod) chan *persistConn { + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() + if t.idleConnCh == nil { + t.idleConnCh = make(map[string]chan *persistConn) + } + ch, ok := t.idleConnCh[key] + if !ok { + ch = make(chan *persistConn) + t.idleConnCh[key] = ch + } + return ch +} + func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { - key := cm.String() - t.idleLk.Lock() - defer t.idleLk.Unlock() + key := cm.key() + t.idleMu.Lock() + defer t.idleMu.Unlock() if t.idleConn == nil { return nil } @@ -304,7 +362,19 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) { return } } - panic("unreachable") +} + +func (t *Transport) setReqConn(r *Request, pc *persistConn) { + t.reqMu.Lock() + defer t.reqMu.Unlock() + if t.reqConn == nil { + t.reqConn = make(map[*Request]*persistConn) + } + if pc != nil { + t.reqConn[r] = pc + } else { + delete(t.reqConn, r) + } } func (t *Transport) dial(network, addr string) (c net.Conn, err error) { @@ -323,6 +393,37 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { return pc, nil } + type dialRes struct { + pc *persistConn + err error + } + dialc := make(chan dialRes) + go func() { + pc, err := t.dialConn(cm) + dialc <- dialRes{pc, err} + }() + + idleConnCh := t.getIdleConnCh(cm) + select { + case v := <-dialc: + // Our dial finished. + return v.pc, v.err + case pc := <-idleConnCh: + // Another request finished first and its net.Conn + // became available before our dial. Or somebody + // else's dial that they didn't use. + // But our dial is still going, so give it away + // when it finishes: + go func() { + if v := <-dialc; v.err == nil { + t.putIdleConn(v.pc) + } + }() + return pc, nil + } +} + +func (t *Transport) dialConn(cm *connectMethod) (*persistConn, error) { conn, err := t.dial("tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { @@ -335,7 +436,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { pconn := &persistConn{ t: t, - cacheKey: cm.String(), + cacheKey: cm.key(), conn: conn, reqch: make(chan requestAndChan, 50), writech: make(chan writeRequest, 50), @@ -485,6 +586,10 @@ type connectMethod struct { targetAddr string // Not used if proxy + http targetScheme (4th example in table) } +func (ck *connectMethod) key() string { + return ck.String() // TODO: use a struct type instead +} + func (ck *connectMethod) String() string { proxyStr := "" targetAddr := ck.targetAddr @@ -529,14 +634,13 @@ type persistConn struct { closech chan struct{} // broadcast close when readLoop (TCP connection) closes isProxy bool + lk sync.Mutex // guards following 3 fields + numExpectedResponses int + broken bool // an error has happened on this connection; marked broken so it's not reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the // original Request given to RoundTrip is not modified) mutateHeaderFunc func(Header) - - lk sync.Mutex // guards numExpectedResponses and broken - numExpectedResponses int - broken bool // an error has happened on this connection; marked broken so it's not reused. } func (pc *persistConn) isBroken() bool { @@ -561,7 +665,6 @@ func remoteSideClosed(err error) bool { func (pc *persistConn) readLoop() { defer close(pc.closech) alive := true - var lastbody io.ReadCloser // last response body, if any, read on this connection for alive { pb, err := pc.br.Peek(1) @@ -580,22 +683,23 @@ func (pc *persistConn) readLoop() { rc := <-pc.reqch - // Advance past the previous response's body, if the - // caller hasn't done so. - if lastbody != nil { - lastbody.Close() // assumed idempotent - lastbody = nil - } - var resp *Response if err == nil { resp, err = ReadResponse(pc.br, rc.req) + if err == nil && resp.StatusCode == 100 { + // Skip any 100-continue for now. + // TODO(bradfitz): if rc.req had "Expect: 100-continue", + // actually block the request body write and signal the + // writeLoop now to begin sending it. (Issue 2184) For now we + // eat it, since we're never expecting one. + resp, err = ReadResponse(pc.br, rc.req) + } } + hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 if err != nil { pc.close() } else { - hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -605,21 +709,29 @@ func (pc *persistConn) readLoop() { pc.close() err = zerr } else { - resp.Body = &readFirstCloseBoth{&discardOnCloseReadCloser{gzReader}, resp.Body} + resp.Body = &readerAndCloser{gzReader, resp.Body} } } resp.Body = &bodyEOFSignal{body: resp.Body} } - if err != nil || resp.Close || rc.req.Close { + if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 { + // Don't do keep-alive on error if either party requested a close + // or we get an unexpected informational (1xx) response. + // StatusCode 100 is already handled above. alive = false } - hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0 var waitForBodyRead chan bool if hasBody { - lastbody = resp.Body - waitForBodyRead = make(chan bool, 1) + waitForBodyRead = make(chan bool, 2) + resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { + // Sending false here sets alive to + // false and closes the connection + // below. + waitForBodyRead <- false + return nil + } resp.Body.(*bodyEOFSignal).fn = func(err error) { alive1 := alive if err != nil { @@ -636,15 +748,6 @@ func (pc *persistConn) readLoop() { } if alive && !hasBody { - // When there's no response body, we immediately - // reuse the TCP connection (putIdleConn), but - // we need to prevent ClientConn.Read from - // closing the Response.Body on the next - // loop, otherwise it might close the body - // before the client code has had a chance to - // read it (even though it'll just be 0, EOF). - lastbody = nil - if !pc.t.putIdleConn(pc) { alive = false } @@ -658,6 +761,8 @@ func (pc *persistConn) readLoop() { alive = <-waitForBodyRead } + pc.t.setReqConn(rc.req, nil) + if !alive { pc.close() } @@ -711,8 +816,14 @@ type writeRequest struct { } func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - if pc.mutateHeaderFunc != nil { - pc.mutateHeaderFunc(req.extraHeaders()) + pc.t.setReqConn(req.Request, pc) + pc.lk.Lock() + pc.numExpectedResponses++ + headerFn := pc.mutateHeaderFunc + pc.lk.Unlock() + + if headerFn != nil { + headerFn(req.extraHeaders()) } // Ask for a compressed version if the caller didn't set their @@ -728,10 +839,6 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err req.extraHeaders().Set("Accept-Encoding", "gzip") } - pc.lk.Lock() - pc.numExpectedResponses++ - pc.lk.Unlock() - // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. @@ -744,6 +851,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err var re responseAndError var pconnDeadCh = pc.closech var failTicker <-chan time.Time + var respHeaderTimer <-chan time.Time WaitResponse: for { select { @@ -753,6 +861,9 @@ WaitResponse: pc.close() break WaitResponse } + if d := pc.t.ResponseHeaderTimeout; d > 0 { + respHeaderTimer = time.After(d) + } case <-pconnDeadCh: // The persist connection is dead. This shouldn't // usually happen (only with Connection: close responses @@ -769,7 +880,11 @@ WaitResponse: pconnDeadCh = nil // avoid spinning failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc case <-failTicker: - re = responseAndError{nil, errors.New("net/http: transport closed before response was received")} + re = responseAndError{err: errors.New("net/http: transport closed before response was received")} + break WaitResponse + case <-respHeaderTimer: + pc.close() + re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")} break WaitResponse case re = <-resc: break WaitResponse @@ -780,6 +895,9 @@ WaitResponse: pc.numExpectedResponses-- pc.lk.Unlock() + if re.err != nil { + pc.t.setReqConn(req.Request, nil) + } return re.res, re.err } @@ -823,13 +941,16 @@ func canonicalAddr(url *url.URL) string { // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most // once, right before its final (error-producing) Read or Close call -// returns. +// returns. If earlyCloseFn is non-nil and Close is called before +// io.EOF is seen, earlyCloseFn is called instead of fn, and its +// return value is the return value from Close. type bodyEOFSignal struct { - body io.ReadCloser - mu sync.Mutex // guards closed, rerr and fn - closed bool // whether Close has been called - rerr error // sticky Read error - fn func(error) // error will be nil on Read io.EOF + body io.ReadCloser + mu sync.Mutex // guards following 4 fields + closed bool // whether Close has been called + rerr error // sticky Read error + fn func(error) // error will be nil on Read io.EOF + earlyCloseFn func() error // optional alt Close func used if io.EOF not seen } func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { @@ -862,6 +983,9 @@ func (es *bodyEOFSignal) Close() error { return nil } es.closed = true + if es.earlyCloseFn != nil && es.rerr != io.EOF { + return es.earlyCloseFn() + } err := es.body.Close() es.condfn(err) return err @@ -879,28 +1003,7 @@ func (es *bodyEOFSignal) condfn(err error) { es.fn = nil } -type readFirstCloseBoth struct { - io.ReadCloser +type readerAndCloser struct { + io.Reader io.Closer } - -func (r *readFirstCloseBoth) Close() error { - if err := r.ReadCloser.Close(); err != nil { - r.Closer.Close() - return err - } - if err := r.Closer.Close(); err != nil { - return err - } - return nil -} - -// discardOnCloseReadCloser consumes all its input on Close. -type discardOnCloseReadCloser struct { - io.ReadCloser -} - -func (d *discardOnCloseReadCloser) Close() error { - io.Copy(ioutil.Discard, d.ReadCloser) // ignore errors; likely invalid or already closed - return d.ReadCloser.Close() -} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index daaecae..9f64a6e 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -7,6 +7,7 @@ package http_test import ( + "bufio" "bytes" "compress/gzip" "crypto/rand" @@ -102,11 +103,13 @@ func (tcs *testConnSet) check(t *testing.T) { // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() for _, disableKeepAlive := range []bool{false, true} { tr := &Transport{DisableKeepAlives: disableKeepAlive} + defer tr.CloseIdleConnections() c := &Client{Transport: tr} fetch := func(n int) string { @@ -133,6 +136,7 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -183,6 +187,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { } func TestTransportConnectionCloseOnRequest(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -233,6 +238,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { } func TestTransportIdleCacheKeys(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -265,6 +271,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { + defer afterTest(t) resch := make(chan string) gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -333,6 +340,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportServerClosingUnexpectedly(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() @@ -389,6 +397,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for http://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { + defer afterTest(t) if testing.Short() { t.Skip("skipping test in short mode") } @@ -444,6 +453,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly func TestTransportHeadResponses(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -472,6 +482,7 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) @@ -513,6 +524,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { + defer afterTest(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") @@ -569,6 +581,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { + defer afterTest(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -661,6 +674,7 @@ func TestTransportGzip(t *testing.T) { } func TestTransportProxy(t *testing.T) { + defer afterTest(t) ch := make(chan string, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ch <- "real server" @@ -689,6 +703,7 @@ func TestTransportProxy(t *testing.T) { // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. func TestTransportGzipRecursive(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) @@ -715,6 +730,7 @@ func TestTransportGzipRecursive(t *testing.T) { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + defer afterTest(t) gotReqCh := make(chan bool) unblockCh := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -780,6 +796,7 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() @@ -818,6 +835,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { // This used to crash; http://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { + defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} @@ -847,6 +865,7 @@ func TestTransportIdleConnCrash(t *testing.T) { // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. func TestIssue3644(t *testing.T) { + defer afterTest(t) const numFoos = 5000 ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") @@ -874,6 +893,7 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { + defer afterTest(t) const deniedMsg = "sorry, denied." ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) @@ -898,6 +918,7 @@ func TestIssue3595(t *testing.T) { // From http://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" func TestChunkedNoContent(t *testing.T) { + defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) @@ -920,20 +941,38 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { - const maxProcs = 16 - const numReqs = 500 + defer afterTest(t) + maxProcs, numReqs := 16, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) })) defer ts.Close() - tr := &Transport{} + + var wg sync.WaitGroup + wg.Add(numReqs) + + tr := &Transport{ + Dial: func(netw, addr string) (c net.Conn, err error) { + // Due to the Transport's "socket late + // binding" (see idleConnCh in transport.go), + // the numReqs HTTP requests below can finish + // with a dial still outstanding. So count + // our dials as work too so the leak checker + // doesn't complain at us. + wg.Add(1) + defer wg.Done() + return net.Dial(netw, addr) + }, + } + defer tr.CloseIdleConnections() c := &Client{Transport: tr} reqs := make(chan string) defer close(reqs) - var wg sync.WaitGroup - wg.Add(numReqs) for i := 0; i < maxProcs*2; i++ { go func() { for req := range reqs { @@ -952,8 +991,8 @@ func TestTransportConcurrency(t *testing.T) { if string(all) != req { t.Errorf("body of req %s = %q; want %q", req, all, req) } - wg.Done() res.Body.Close() + wg.Done() } }() } @@ -964,12 +1003,14 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { + defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond client := &Client{ Transport: &Transport{ @@ -978,7 +1019,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { if err != nil { return nil, err } - conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } @@ -988,6 +1029,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { }, } + getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 @@ -998,6 +1040,14 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } sres, err := client.Get(ts.URL + "/get") if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } t.Errorf("Error issuing GET: %v", err) break } @@ -1014,6 +1064,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { + defer afterTest(t) const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -1024,6 +1075,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { io.Copy(ioutil.Discard, r.Body) }) ts := httptest.NewServer(mux) + timeout := 100 * time.Millisecond client := &Client{ Transport: &Transport{ @@ -1032,7 +1084,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { if err != nil { return nil, err } - conn.SetDeadline(time.Now().Add(100 * time.Millisecond)) + conn.SetDeadline(time.Now().Add(timeout)) if debug { conn = NewLoggingConn("client", conn) } @@ -1042,6 +1094,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { }, } + getFailed := false nRuns := 5 if testing.Short() { nRuns = 1 @@ -1052,6 +1105,14 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { } sres, err := client.Get(ts.URL + "/get") if err != nil { + if !getFailed { + // Make the timeout longer, once. + getFailed = true + t.Logf("increasing timeout") + i-- + timeout *= 10 + continue + } t.Errorf("Error issuing GET: %v", err) break } @@ -1070,6 +1131,171 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts.Close() } +func TestTransportResponseHeaderTimeout(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping timeout test in -short mode") + } + mux := NewServeMux() + mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {}) + mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) { + time.Sleep(2 * time.Second) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + tr := &Transport{ + ResponseHeaderTimeout: 500 * time.Millisecond, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + tests := []struct { + path string + want int + wantErr string + }{ + {path: "/fast", want: 200}, + {path: "/slow", wantErr: "timeout awaiting response headers"}, + {path: "/fast", want: 200}, + } + for i, tt := range tests { + res, err := c.Get(ts.URL + tt.path) + if err != nil { + if strings.Contains(err.Error(), tt.wantErr) { + continue + } + t.Errorf("%d. unexpected error: %v", i, err) + continue + } + if tt.wantErr != "" { + t.Errorf("%d. no error. expected error: %v", i, tt.wantErr) + continue + } + if res.StatusCode != tt.want { + t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want) + } + } +} + +func TestTransportCancelRequest(t *testing.T) { + defer afterTest(t) + if testing.Short() { + t.Skip("skipping test in -short mode") + } + unblockc := make(chan bool) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Hello") + w.(Flusher).Flush() // send headers and some body + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(1 * time.Second) + tr.CancelRequest(req) + }() + t0 := time.Now() + body, err := ioutil.ReadAll(res.Body) + d := time.Since(t0) + + if err == nil { + t.Error("expected an error reading the body") + } + if string(body) != "Hello" { + t.Errorf("Body = %q; want Hello", body) + } + if d < 500*time.Millisecond { + t.Errorf("expected ~1 second delay; got %v", d) + } + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + for tries := 3; tries > 0; tries-- { + n := tr.NumPendingRequestsForTesting() + if n == 0 { + break + } + time.Sleep(100 * time.Millisecond) + if tries == 1 { + t.Errorf("pending requests = %d; want 0", n) + } + } +} + +// golang.org/issue/3672 -- Client can't close HTTP stream +// Calling Close on a Response.Body used to just read until EOF. +// Now it actually closes the TCP connection. +func TestTransportCloseResponseBody(t *testing.T) { + defer afterTest(t) + writeErr := make(chan error, 1) + msg := []byte("young\n") + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + for { + _, err := w.Write(msg) + if err != nil { + writeErr <- err + return + } + w.(Flusher).Flush() + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + req, _ := NewRequest("GET", ts.URL, nil) + defer tr.CancelRequest(req) + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + + const repeats = 3 + buf := make([]byte, len(msg)*repeats) + want := bytes.Repeat(msg, repeats) + + _, err = io.ReadFull(res.Body, buf) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, want) { + t.Errorf("read %q; want %q", buf, want) + } + didClose := make(chan error, 1) + go func() { + didClose <- res.Body.Close() + }() + select { + case err := <-didClose: + if err != nil { + t.Errorf("Close = %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for close") + } + select { + case err := <-writeErr: + if err == nil { + t.Errorf("expected non-nil write error") + } + case <-time.After(10 * time.Second): + t.Fatal("too long waiting for write error") + } +} + type fooProto struct{} func (fooProto) RoundTrip(req *Request) (*Response, error) { @@ -1083,6 +1309,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { } func TestTransportAltProto(t *testing.T) { + defer afterTest(t) tr := &Transport{} c := &Client{Transport: tr} tr.RegisterProtocol("foo", fooProto{}) @@ -1101,6 +1328,7 @@ func TestTransportAltProto(t *testing.T) { } func TestTransportNoHost(t *testing.T) { + defer afterTest(t) tr := &Transport{} _, err := tr.RoundTrip(&Request{ Header: make(Header), @@ -1114,6 +1342,172 @@ func TestTransportNoHost(t *testing.T) { } } +func TestTransportSocketLateBinding(t *testing.T) { + defer afterTest(t) + + mux := NewServeMux() + fooGate := make(chan bool, 1) + mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { + w.Header().Set("foo-ipport", r.RemoteAddr) + w.(Flusher).Flush() + <-fooGate + }) + mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { + w.Header().Set("bar-ipport", r.RemoteAddr) + }) + ts := httptest.NewServer(mux) + defer ts.Close() + + dialGate := make(chan bool, 1) + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + <-dialGate + return net.Dial(n, addr) + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + } + + dialGate <- true // only allow one dial + fooRes, err := c.Get(ts.URL + "/foo") + if err != nil { + t.Fatal(err) + } + fooAddr := fooRes.Header.Get("foo-ipport") + if fooAddr == "" { + t.Fatal("No addr on /foo request") + } + time.AfterFunc(200*time.Millisecond, func() { + // let the foo response finish so we can use its + // connection for /bar + fooGate <- true + io.Copy(ioutil.Discard, fooRes.Body) + fooRes.Body.Close() + }) + + barRes, err := c.Get(ts.URL + "/bar") + if err != nil { + t.Fatal(err) + } + barAddr := barRes.Header.Get("bar-ipport") + if barAddr != fooAddr { + t.Fatalf("/foo came from conn %q; /bar came from %q instead", fooAddr, barAddr) + } + barRes.Body.Close() + dialGate <- true +} + +// Issue 2184 +func TestTransportReading100Continue(t *testing.T) { + defer afterTest(t) + + const numReqs = 5 + reqBody := func(n int) string { return fmt.Sprintf("request body %d", n) } + reqID := func(n int) string { return fmt.Sprintf("REQ-ID-%d", n) } + + send100Response := func(w *io.PipeWriter, r *io.PipeReader) { + defer w.Close() + defer r.Close() + br := bufio.NewReader(r) + n := 0 + for { + n++ + req, err := ReadRequest(br) + if err == io.EOF { + return + } + if err != nil { + t.Error(err) + return + } + slurp, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("Server request body slurp: %v", err) + return + } + id := req.Header.Get("Request-Id") + resCode := req.Header.Get("X-Want-Response-Code") + if resCode == "" { + resCode = "100 Continue" + if string(slurp) != reqBody(n) { + t.Errorf("Server got %q, %v; want %q", slurp, err, reqBody(n)) + } + } + body := fmt.Sprintf("Response number %d", n) + v := []byte(strings.Replace(fmt.Sprintf(`HTTP/1.1 %s +Date: Thu, 28 Feb 2013 17:55:41 GMT + +HTTP/1.1 200 OK +Content-Type: text/html +Echo-Request-Id: %s +Content-Length: %d + +%s`, resCode, id, len(body), body), "\n", "\r\n", -1)) + w.Write(v) + if id == reqID(numReqs) { + return + } + } + + } + + tr := &Transport{ + Dial: func(n, addr string) (net.Conn, error) { + sr, sw := io.Pipe() // server read/write + cr, cw := io.Pipe() // client read/write + conn := &rwTestConn{ + Reader: cr, + Writer: sw, + closeFunc: func() error { + sw.Close() + cw.Close() + return nil + }, + } + go send100Response(cw, sr) + return conn, nil + }, + DisableKeepAlives: false, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + testResponse := func(req *Request, name string, wantCode int) { + res, err := c.Do(req) + if err != nil { + t.Fatalf("%s: Do: %v", name, err) + } + if res.StatusCode != wantCode { + t.Fatalf("%s: Response Statuscode=%d; want %d", name, res.StatusCode, wantCode) + } + if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { + t.Errorf("%s: response id %q != request id %q", name, idBack, id) + } + _, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: Slurp error: %v", name, err) + } + } + + // Few 100 responses, making sure we're not off-by-one. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://dummy.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("Request-Id", reqID(i)) + testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) + } + + // And some other informational 1xx but non-100 responses, to test + // we return them but don't re-use the connection. + for i := 1; i <= numReqs; i++ { + req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i))) + req.Header.Set("X-Want-Response-Code", "123 Sesame Street") + testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123) + } +} + type proxyFromEnvTest struct { req string // URL to fetch; blank means "http://example.com" env string diff --git a/libgo/go/net/http/z_last_test.go b/libgo/go/net/http/z_last_test.go new file mode 100644 index 0000000..2161db7 --- /dev/null +++ b/libgo/go/net/http/z_last_test.go @@ -0,0 +1,98 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "net/http" + "runtime" + "sort" + "strings" + "testing" + "time" +) + +func interestingGoroutines() (gs []string) { + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + for _, g := range strings.Split(string(buf), "\n\n") { + sl := strings.SplitN(g, "\n", 2) + if len(sl) != 2 { + continue + } + stack := strings.TrimSpace(sl[1]) + if stack == "" || + strings.Contains(stack, "created by net.newPollServer") || + strings.Contains(stack, "created by net.startServer") || + strings.Contains(stack, "created by testing.RunTests") || + strings.Contains(stack, "closeWriteAndWait") || + strings.Contains(stack, "testing.Main(") || + // These only show up with GOTRACEBACK=2; Issue 5005 (comment 28) + strings.Contains(stack, "runtime.goexit") || + strings.Contains(stack, "created by runtime.gc") || + strings.Contains(stack, "runtime.MHeap_Scavenger") { + continue + } + gs = append(gs, stack) + } + sort.Strings(gs) + return +} + +// Verify the other tests didn't leave any goroutines running. +// This is in a file named z_last_test.go so it sorts at the end. +func TestGoroutinesRunning(t *testing.T) { + if testing.Short() { + t.Skip("not counting goroutines for leakage in -short mode") + } + gs := interestingGoroutines() + + n := 0 + stackCount := make(map[string]int) + for _, g := range gs { + stackCount[g]++ + n++ + } + + t.Logf("num goroutines = %d", n) + if n > 0 { + t.Error("Too many goroutines.") + for stack, count := range stackCount { + t.Logf("%d instances of:\n%s", count, stack) + } + } +} + +func afterTest(t *testing.T) { + http.DefaultTransport.(*http.Transport).CloseIdleConnections() + if testing.Short() { + return + } + var bad string + badSubstring := map[string]string{ + ").readLoop(": "a Transport", + ").writeLoop(": "a Transport", + "created by net/http/httptest.(*Server).Start": "an httptest.Server", + "timeoutHandler": "a TimeoutHandler", + "net.(*netFD).connect(": "a timing out dial", + ").noteClientGone(": "a closenotifier sender", + } + var stacks string + for i := 0; i < 4; i++ { + bad = "" + stacks = strings.Join(interestingGoroutines(), "\n\n") + for substr, what := range badSubstring { + if strings.Contains(stacks, substr) { + bad = what + } + } + if bad == "" { + return + } + // Bad stuff found, but goroutines might just still be + // shutting down, so give it some time. + time.Sleep(250 * time.Millisecond) + } + t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) +} diff --git a/libgo/go/net/interface.go b/libgo/go/net/interface.go index ee23570..0713e9c 100644 --- a/libgo/go/net/interface.go +++ b/libgo/go/net/interface.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification - package net import "errors" @@ -66,7 +64,7 @@ func (ifi *Interface) Addrs() ([]Addr, error) { if ifi == nil { return nil, errInvalidInterface } - return interfaceAddrTable(ifi.Index) + return interfaceAddrTable(ifi) } // MulticastAddrs returns multicast, joined group addresses for @@ -75,7 +73,7 @@ func (ifi *Interface) MulticastAddrs() ([]Addr, error) { if ifi == nil { return nil, errInvalidInterface } - return interfaceMulticastAddrTable(ifi.Index) + return interfaceMulticastAddrTable(ifi) } // Interfaces returns a list of the system's network interfaces. @@ -86,7 +84,7 @@ func Interfaces() ([]Interface, error) { // InterfaceAddrs returns a list of the system's network interface // addresses. func InterfaceAddrs() ([]Addr, error) { - return interfaceAddrTable(0) + return interfaceAddrTable(nil) } // InterfaceByIndex returns the interface specified by index. @@ -98,8 +96,14 @@ func InterfaceByIndex(index int) (*Interface, error) { if err != nil { return nil, err } + return interfaceByIndex(ift, index) +} + +func interfaceByIndex(ift []Interface, index int) (*Interface, error) { for _, ifi := range ift { - return &ifi, nil + if index == ifi.Index { + return &ifi, nil + } } return nil, errNoSuchInterface } diff --git a/libgo/go/net/interface_bsd.go b/libgo/go/net/interface_bsd.go index df9b3a2..716b60a 100644 --- a/libgo/go/net/interface_bsd.go +++ b/libgo/go/net/interface_bsd.go @@ -4,8 +4,6 @@ // +build darwin freebsd netbsd openbsd -// Network interface identification for BSD variants - package net import ( @@ -22,57 +20,60 @@ func interfaceTable(ifindex int) ([]Interface, error) { if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } + return parseInterfaceTable(ifindex, msgs) +} +func parseInterfaceTable(ifindex int, msgs []syscall.RoutingMessage) ([]Interface, error) { var ift []Interface +loop: for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifi, err := newLink(v) + if ifindex == 0 || ifindex == int(m.Header.Index) { + ifi, err := newLink(m) if err != nil { return nil, err } - ift = append(ift, ifi...) + ift = append(ift, *ifi) + if ifindex == int(m.Header.Index) { + break loop + } } } } return ift, nil } -func newLink(m *syscall.InterfaceMessage) ([]Interface, error) { +func newLink(m *syscall.InterfaceMessage) (*Interface, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - - var ift []Interface - for _, s := range sas { - switch v := s.(type) { + ifi := &Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)} + for _, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrDatalink: // NOTE: SockaddrDatalink.Data is minimum work area, // can be larger. - m.Data = m.Data[unsafe.Offsetof(v.Data):] - ifi := Interface{Index: int(m.Header.Index), Flags: linkFlags(m.Header.Flags)} + m.Data = m.Data[unsafe.Offsetof(sa.Data):] var name [syscall.IFNAMSIZ]byte - for i := 0; i < int(v.Nlen); i++ { + for i := 0; i < int(sa.Nlen); i++ { name[i] = byte(m.Data[i]) } - ifi.Name = string(name[:v.Nlen]) + ifi.Name = string(name[:sa.Nlen]) ifi.MTU = int(m.Header.Data.Mtu) - addr := make([]byte, v.Alen) - for i := 0; i < int(v.Alen); i++ { - addr[i] = byte(m.Data[int(v.Nlen)+i]) + addr := make([]byte, sa.Alen) + for i := 0; i < int(sa.Alen); i++ { + addr[i] = byte(m.Data[int(sa.Nlen)+i]) } - ifi.HardwareAddr = addr[:v.Alen] - ift = append(ift, ifi) + ifi.HardwareAddr = addr[:sa.Alen] } } - return ift, nil + return ifi, nil } func linkFlags(rawFlags int32) Flags { @@ -95,26 +96,42 @@ func linkFlags(rawFlags int32) Flags { return f } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { - tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, ifindex) +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { + index := 0 + if ifi != nil { + index = ifi.Index + } + tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST, index) if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } - + var ift []Interface + if index == 0 { + ift, err = parseInterfaceTable(index, msgs) + if err != nil { + return nil, err + } + } var ifat []Addr for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceAddrMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifa, err := newAddr(v) + if index == 0 || index == int(m.Header.Index) { + if index == 0 { + var err error + ifi, err = interfaceByIndex(ift, int(m.Header.Index)) + if err != nil { + return nil, err + } + } + ifa, err := newAddr(ifi, m) if err != nil { return nil, err } @@ -127,35 +144,33 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) { return ifat, nil } -func newAddr(m *syscall.InterfaceAddrMessage) (Addr, error) { +func newAddr(ifi *Interface, m *syscall.InterfaceAddrMessage) (Addr, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - ifa := &IPNet{} - for i, s := range sas { - switch v := s.(type) { + for i, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrInet4: switch i { case 0: - ifa.Mask = IPv4Mask(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + ifa.Mask = IPv4Mask(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) case 1: - ifa.IP = IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3]) + ifa.IP = IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]) } case *syscall.SockaddrInet6: switch i { case 0: ifa.Mask = make(IPMask, IPv6len) - copy(ifa.Mask, v.Addr[:]) + copy(ifa.Mask, sa.Addr[:]) case 1: ifa.IP = make(IP, IPv6len) - copy(ifa.IP, v.Addr[:]) + copy(ifa.IP, sa.Addr[:]) // NOTE: KAME based IPv6 protcol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. if ifa.IP.IsLinkLocalUnicast() { - // remove embedded scope zone ID ifa.IP[2], ifa.IP[3] = 0, 0 } } diff --git a/libgo/go/net/interface_darwin.go b/libgo/go/net/interface_darwin.go index 0b5fb5fb..ad0937d 100644 --- a/libgo/go/net/interface_darwin.go +++ b/libgo/go/net/interface_darwin.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for Darwin - package net import ( @@ -11,26 +9,23 @@ import ( "syscall" ) -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { - tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifindex) +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + tab, err := syscall.RouteRIB(syscall.NET_RT_IFLIST2, ifi.Index) if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } - var ifmat []Addr for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceMulticastAddrMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifma, err := newMulticastAddr(v) + if ifi.Index == int(m.Header.Index) { + ifma, err := newMulticastAddr(ifi, m) if err != nil { return nil, err } @@ -41,27 +36,24 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return ifmat, nil } -func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { +func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - var ifmat []Addr - for _, s := range sas { - switch v := s.(type) { + for _, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrInet4: - ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifma := &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])} ifmat = append(ifmat, ifma.toAddr()) case *syscall.SockaddrInet6: ifma := &IPAddr{IP: make(IP, IPv6len)} - copy(ifma.IP, v.Addr[:]) - // NOTE: KAME based IPv6 protcol stack usually embeds + copy(ifma.IP, sa.Addr[:]) + // NOTE: KAME based IPv6 protocol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. - if ifma.IP.IsInterfaceLocalMulticast() || - ifma.IP.IsLinkLocalMulticast() { - // remove embedded scope zone ID + if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { ifma.IP[2], ifma.IP[3] = 0, 0 } ifmat = append(ifmat, ifma.toAddr()) diff --git a/libgo/go/net/interface_freebsd.go b/libgo/go/net/interface_freebsd.go index 3cba28f..5df7679 100644 --- a/libgo/go/net/interface_freebsd.go +++ b/libgo/go/net/interface_freebsd.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for FreeBSD - package net import ( @@ -11,26 +9,23 @@ import ( "syscall" ) -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { - tab, err := syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifindex) +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + tab, err := syscall.RouteRIB(syscall.NET_RT_IFMALIST, ifi.Index) if err != nil { return nil, os.NewSyscallError("route rib", err) } - msgs, err := syscall.ParseRoutingMessage(tab) if err != nil { return nil, os.NewSyscallError("route message", err) } - var ifmat []Addr for _, m := range msgs { - switch v := m.(type) { + switch m := m.(type) { case *syscall.InterfaceMulticastAddrMessage: - if ifindex == 0 || ifindex == int(v.Header.Index) { - ifma, err := newMulticastAddr(v) + if ifi.Index == int(m.Header.Index) { + ifma, err := newMulticastAddr(ifi, m) if err != nil { return nil, err } @@ -41,27 +36,24 @@ func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { return ifmat, nil } -func newMulticastAddr(m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { +func newMulticastAddr(ifi *Interface, m *syscall.InterfaceMulticastAddrMessage) ([]Addr, error) { sas, err := syscall.ParseRoutingSockaddr(m) if err != nil { return nil, os.NewSyscallError("route sockaddr", err) } - var ifmat []Addr - for _, s := range sas { - switch v := s.(type) { + for _, sa := range sas { + switch sa := sa.(type) { case *syscall.SockaddrInet4: - ifma := &IPAddr{IP: IPv4(v.Addr[0], v.Addr[1], v.Addr[2], v.Addr[3])} + ifma := &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])} ifmat = append(ifmat, ifma.toAddr()) case *syscall.SockaddrInet6: ifma := &IPAddr{IP: make(IP, IPv6len)} - copy(ifma.IP, v.Addr[:]) - // NOTE: KAME based IPv6 protcol stack usually embeds + copy(ifma.IP, sa.Addr[:]) + // NOTE: KAME based IPv6 protocol stack usually embeds // the interface index in the interface-local or link- // local address as the kernel-internal form. - if ifma.IP.IsInterfaceLocalMulticast() || - ifma.IP.IsLinkLocalMulticast() { - // remove embedded scope zone ID + if ifma.IP.IsInterfaceLocalMulticast() || ifma.IP.IsLinkLocalMulticast() { ifma.IP[2], ifma.IP[3] = 0, 0 } ifmat = append(ifmat, ifma.toAddr()) diff --git a/libgo/go/net/interface_linux.go b/libgo/go/net/interface_linux.go index ce2e921..1207c0f 100644 --- a/libgo/go/net/interface_linux.go +++ b/libgo/go/net/interface_linux.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for Linux - package net import ( @@ -20,17 +18,16 @@ func interfaceTable(ifindex int) ([]Interface, error) { if err != nil { return nil, os.NewSyscallError("netlink rib", err) } - msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { return nil, os.NewSyscallError("netlink message", err) } - var ift []Interface +loop: for _, m := range msgs { switch m.Header.Type { case syscall.NLMSG_DONE: - goto done + break loop case syscall.RTM_NEWLINK: ifim := (*syscall.IfInfomsg)(unsafe.Pointer(&m.Data[0])) if ifindex == 0 || ifindex == int(ifim.Index) { @@ -38,17 +35,18 @@ func interfaceTable(ifindex int) ([]Interface, error) { if err != nil { return nil, os.NewSyscallError("netlink routeattr", err) } - ifi := newLink(ifim, attrs) - ift = append(ift, ifi) + ift = append(ift, *newLink(ifim, attrs)) + if ifindex == int(ifim.Index) { + break loop + } } } } -done: return ift, nil } -func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) Interface { - ifi := Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} +func newLink(ifim *syscall.IfInfomsg, attrs []syscall.NetlinkRouteAttr) *Interface { + ifi := &Interface{Index: int(ifim.Index), Flags: linkFlags(ifim.Flags)} for _, a := range attrs { switch a.Attr.Type { case syscall.IFLA_ADDRESS: @@ -90,81 +88,84 @@ func linkFlags(rawFlags uint32) Flags { return f } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { tab, err := syscall.NetlinkRIB(syscall.RTM_GETADDR, syscall.AF_UNSPEC) if err != nil { return nil, os.NewSyscallError("netlink rib", err) } - msgs, err := syscall.ParseNetlinkMessage(tab) if err != nil { return nil, os.NewSyscallError("netlink message", err) } - - ifat, err := addrTable(msgs, ifindex) + var ift []Interface + if ifi == nil { + var err error + ift, err = interfaceTable(0) + if err != nil { + return nil, err + } + } + ifat, err := addrTable(ift, ifi, msgs) if err != nil { return nil, err } return ifat, nil } -func addrTable(msgs []syscall.NetlinkMessage, ifindex int) ([]Addr, error) { +func addrTable(ift []Interface, ifi *Interface, msgs []syscall.NetlinkMessage) ([]Addr, error) { var ifat []Addr +loop: for _, m := range msgs { switch m.Header.Type { case syscall.NLMSG_DONE: - goto done + break loop case syscall.RTM_NEWADDR: ifam := (*syscall.IfAddrmsg)(unsafe.Pointer(&m.Data[0])) - if ifindex == 0 || ifindex == int(ifam.Index) { + if len(ift) != 0 || ifi.Index == int(ifam.Index) { + if len(ift) != 0 { + var err error + ifi, err = interfaceByIndex(ift, int(ifam.Index)) + if err != nil { + return nil, err + } + } attrs, err := syscall.ParseNetlinkRouteAttr(&m) if err != nil { return nil, os.NewSyscallError("netlink routeattr", err) } - ifat = append(ifat, newAddr(attrs, int(ifam.Family), int(ifam.Prefixlen))) + ifa := newAddr(ifi, ifam, attrs) + if ifa != nil { + ifat = append(ifat, ifa) + } } } } -done: return ifat, nil } -func newAddr(attrs []syscall.NetlinkRouteAttr, family, pfxlen int) Addr { - ifa := &IPNet{} +func newAddr(ifi *Interface, ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr { for _, a := range attrs { - switch a.Attr.Type { - case syscall.IFA_ADDRESS: - switch family { + if ifi.Flags&FlagPointToPoint != 0 && a.Attr.Type == syscall.IFA_LOCAL || + ifi.Flags&FlagPointToPoint == 0 && a.Attr.Type == syscall.IFA_ADDRESS { + switch ifam.Family { case syscall.AF_INET: - ifa.IP = IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]) - ifa.Mask = CIDRMask(pfxlen, 8*IPv4len) + return &IPNet{IP: IPv4(a.Value[0], a.Value[1], a.Value[2], a.Value[3]), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv4len)} case syscall.AF_INET6: - ifa.IP = make(IP, IPv6len) + ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(ifam.Prefixlen), 8*IPv6len)} copy(ifa.IP, a.Value[:]) - ifa.Mask = CIDRMask(pfxlen, 8*IPv6len) + return ifa } } } - return ifa + return nil } -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { - var ( - err error - ifi *Interface - ) - if ifindex > 0 { - ifi, err = InterfaceByIndex(ifindex) - if err != nil { - return nil, err - } - } +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { ifmat4 := parseProcNetIGMP("/proc/net/igmp", ifi) ifmat6 := parseProcNetIGMP6("/proc/net/igmp6", ifi) return append(ifmat4, ifmat6...), nil @@ -176,7 +177,6 @@ func parseProcNetIGMP(path string, ifi *Interface) []Addr { return nil } defer fd.close() - var ( ifmat []Addr name string @@ -214,7 +214,6 @@ func parseProcNetIGMP6(path string, ifi *Interface) []Addr { return nil } defer fd.close() - var ifmat []Addr b := make([]byte, IPv6len) for l, ok := fd.readLine(); ok; l, ok = fd.readLine() { diff --git a/libgo/go/net/interface_netbsd.go b/libgo/go/net/interface_netbsd.go index 4150e9a..c9ce5a7 100644 --- a/libgo/go/net/interface_netbsd.go +++ b/libgo/go/net/interface_netbsd.go @@ -2,13 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for NetBSD - package net -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. return nil, nil } diff --git a/libgo/go/net/interface_openbsd.go b/libgo/go/net/interface_openbsd.go index d8adb46..c9ce5a7 100644 --- a/libgo/go/net/interface_openbsd.go +++ b/libgo/go/net/interface_openbsd.go @@ -2,13 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for OpenBSD - package net -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. return nil, nil } diff --git a/libgo/go/net/interface_stub.go b/libgo/go/net/interface_stub.go index d4d7ce9..a4eb731 100644 --- a/libgo/go/net/interface_stub.go +++ b/libgo/go/net/interface_stub.go @@ -4,8 +4,6 @@ // +build plan9 -// Network interface identification - package net // If the ifindex is zero, interfaceTable returns mappings of all @@ -15,16 +13,15 @@ func interfaceTable(ifindex int) ([]Interface, error) { return nil, nil } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { return nil, nil } -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { return nil, nil } diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go index 803c1f4..e31894a 100644 --- a/libgo/go/net/interface_test.go +++ b/libgo/go/net/interface_test.go @@ -5,18 +5,50 @@ package net import ( - "bytes" + "reflect" "testing" ) -func sameInterface(i, j *Interface) bool { - if i == nil || j == nil { - return false +// loopbackInterface returns an available logical network interface +// for loopback tests. It returns nil if no suitable interface is +// found. +func loopbackInterface() *Interface { + ift, err := Interfaces() + if err != nil { + return nil + } + for _, ifi := range ift { + if ifi.Flags&FlagLoopback != 0 && ifi.Flags&FlagUp != 0 { + return &ifi + } + } + return nil +} + +// ipv6LinkLocalUnicastAddr returns an IPv6 link-local unicast address +// on the given network interface for tests. It returns "" if no +// suitable address is found. +func ipv6LinkLocalUnicastAddr(ifi *Interface) string { + if ifi == nil { + return "" + } + ifat, err := ifi.Addrs() + if err != nil { + return "" } - if i.Index == j.Index && i.Name == j.Name && bytes.Equal(i.HardwareAddr, j.HardwareAddr) { - return true + for _, ifa := range ifat { + switch ifa := ifa.(type) { + case *IPAddr: + if ifa.IP.To4() == nil && ifa.IP.IsLinkLocalUnicast() { + return ifa.IP.String() + } + case *IPNet: + if ifa.IP.To4() == nil && ifa.IP.IsLinkLocalUnicast() { + return ifa.IP.String() + } + } } - return false + return "" } func TestInterfaces(t *testing.T) { @@ -29,17 +61,17 @@ func TestInterfaces(t *testing.T) { for _, ifi := range ift { ifxi, err := InterfaceByIndex(ifi.Index) if err != nil { - t.Fatalf("InterfaceByIndex(%q) failed: %v", ifi.Index, err) + t.Fatalf("InterfaceByIndex(%v) failed: %v", ifi.Index, err) } - if !sameInterface(ifxi, &ifi) { - t.Fatalf("InterfaceByIndex(%q) = %v, want %v", ifi.Index, *ifxi, ifi) + if !reflect.DeepEqual(ifxi, &ifi) { + t.Fatalf("InterfaceByIndex(%v) = %v, want %v", ifi.Index, ifxi, ifi) } ifxn, err := InterfaceByName(ifi.Name) if err != nil { t.Fatalf("InterfaceByName(%q) failed: %v", ifi.Name, err) } - if !sameInterface(ifxn, &ifi) { - t.Fatalf("InterfaceByName(%q) = %v, want %v", ifi.Name, *ifxn, ifi) + if !reflect.DeepEqual(ifxn, &ifi) { + t.Fatalf("InterfaceByName(%q) = %v, want %v", ifi.Name, ifxn, ifi) } t.Logf("%q: flags %q, ifindex %v, mtu %v", ifi.Name, ifi.Flags.String(), ifi.Index, ifi.MTU) t.Logf("\thardware address %q", ifi.HardwareAddr.String()) @@ -75,9 +107,9 @@ func testInterfaceMulticastAddrs(t *testing.T, ifi *Interface) { func testAddrs(t *testing.T, ifat []Addr) { for _, ifa := range ifat { - switch v := ifa.(type) { + switch ifa := ifa.(type) { case *IPAddr, *IPNet: - if v == nil { + if ifa == nil { t.Errorf("\tunexpected value: %v", ifa) } else { t.Logf("\tinterface address %q", ifa.String()) @@ -90,9 +122,9 @@ func testAddrs(t *testing.T, ifat []Addr) { func testMulticastAddrs(t *testing.T, ifmat []Addr) { for _, ifma := range ifmat { - switch v := ifma.(type) { + switch ifma := ifma.(type) { case *IPAddr: - if v == nil { + if ifma == nil { t.Errorf("\tunexpected value: %v", ifma) } else { t.Logf("\tjoined group address %q", ifma.String()) @@ -102,3 +134,67 @@ func testMulticastAddrs(t *testing.T, ifmat []Addr) { } } } + +func BenchmarkInterfaces(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := Interfaces(); err != nil { + b.Fatalf("Interfaces failed: %v", err) + } + } +} + +func BenchmarkInterfaceByIndex(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := InterfaceByIndex(ifi.Index); err != nil { + b.Fatalf("InterfaceByIndex failed: %v", err) + } + } +} + +func BenchmarkInterfaceByName(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := InterfaceByName(ifi.Name); err != nil { + b.Fatalf("InterfaceByName failed: %v", err) + } + } +} + +func BenchmarkInterfaceAddrs(b *testing.B) { + for i := 0; i < b.N; i++ { + if _, err := InterfaceAddrs(); err != nil { + b.Fatalf("InterfaceAddrs failed: %v", err) + } + } +} + +func BenchmarkInterfacesAndAddrs(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := ifi.Addrs(); err != nil { + b.Fatalf("Interface.Addrs failed: %v", err) + } + } +} + +func BenchmarkInterfacesAndMulticastAddrs(b *testing.B) { + ifi := loopbackInterface() + if ifi == nil { + b.Skip("loopback interface not found") + } + for i := 0; i < b.N; i++ { + if _, err := ifi.MulticastAddrs(); err != nil { + b.Fatalf("Interface.MulticastAddrs failed: %v", err) + } + } +} diff --git a/libgo/go/net/interface_windows.go b/libgo/go/net/interface_windows.go index 4368b33..0759dc2 100644 --- a/libgo/go/net/interface_windows.go +++ b/libgo/go/net/interface_windows.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Network interface identification for Windows - package net import ( @@ -25,6 +23,9 @@ func getAdapterList() (*syscall.IpAdapterInfo, error) { b := make([]byte, 1000) l := uint32(len(b)) a := (*syscall.IpAdapterInfo)(unsafe.Pointer(&b[0])) + // TODO(mikio): GetAdaptersInfo returns IP_ADAPTER_INFO that + // contains IPv4 address list only. We should use another API + // for fetching IPv6 stuff from the kernel. err := syscall.GetAdaptersInfo(a, &l) if err == syscall.ERROR_BUFFER_OVERFLOW { b = make([]byte, l) @@ -38,7 +39,7 @@ func getAdapterList() (*syscall.IpAdapterInfo, error) { } func getInterfaceList() ([]syscall.InterfaceInfo, error) { - s, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) + s, err := sysSocket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP) if err != nil { return nil, os.NewSyscallError("Socket", err) } @@ -126,10 +127,10 @@ func interfaceTable(ifindex int) ([]Interface, error) { return ift, nil } -// If the ifindex is zero, interfaceAddrTable returns addresses -// for all network interfaces. Otherwise it returns addresses -// for a specific interface. -func interfaceAddrTable(ifindex int) ([]Addr, error) { +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { ai, err := getAdapterList() if err != nil { return nil, err @@ -138,11 +139,10 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) { var ifat []Addr for ; ai != nil; ai = ai.Next { index := ai.Index - if ifindex == 0 || ifindex == int(index) { + if ifi == nil || ifi.Index == int(index) { ipl := &ai.IpAddressList for ; ipl != nil; ipl = ipl.Next { - ifa := IPAddr{} - ifa.IP = parseIPv4(bytePtrToString(&ipl.IpAddress.String[0])) + ifa := IPAddr{IP: parseIPv4(bytePtrToString(&ipl.IpAddress.String[0]))} ifat = append(ifat, ifa.toAddr()) } } @@ -150,9 +150,9 @@ func interfaceAddrTable(ifindex int) ([]Addr, error) { return ifat, nil } -// If the ifindex is zero, interfaceMulticastAddrTable returns -// addresses for all network interfaces. Otherwise it returns -// addresses for a specific interface. -func interfaceMulticastAddrTable(ifindex int) ([]Addr, error) { +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + // TODO(mikio): Implement this like other platforms. return nil, nil } diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index d588e3a..0e42da2 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -36,7 +36,6 @@ type IPMask []byte type IPNet struct { IP IP // network number Mask IPMask // network mask - Zone string // IPv6 scoped addressing zone } // IPv4 returns the IP address (in 16-byte form) of the @@ -223,7 +222,6 @@ func (ip IP) DefaultMask() IPMask { default: return classCMask } - return nil // not reached } func allFF(b []byte) bool { @@ -433,6 +431,9 @@ func (n *IPNet) Contains(ip IP) bool { return true } +// Network returns the address's network name, "ip+net". +func (n *IPNet) Network() string { return "ip+net" } + // String returns the CIDR notation of n like "192.168.100.1/24" // or "2001:DB8::/48" as defined in RFC 4632 and RFC 4291. // If the mask is not in the canonical form, it returns the @@ -451,9 +452,6 @@ func (n *IPNet) String() string { return nn.String() + "/" + itod(uint(l)) } -// Network returns the address's network name, "ip+net". -func (n *IPNet) Network() string { return "ip+net" } - // Parse IPv4 address (d.d.d.d). func parseIPv4(s string) IP { var p [IPv4len]byte @@ -485,26 +483,26 @@ func parseIPv4(s string) IP { return IPv4(p[0], p[1], p[2], p[3]) } -// Parse IPv6 address. Many forms. -// The basic form is a sequence of eight colon-separated -// 16-bit hex numbers separated by colons, -// as in 0123:4567:89ab:cdef:0123:4567:89ab:cdef. -// Two exceptions: -// * A run of zeros can be replaced with "::". -// * The last 32 bits can be in IPv4 form. -// Thus, ::ffff:1.2.3.4 is the IPv4 address 1.2.3.4. -func parseIPv6(s string) IP { - p := make(IP, IPv6len) +// parseIPv6 parses s as a literal IPv6 address described in RFC 4291 +// and RFC 5952. It can also parse a literal scoped IPv6 address with +// zone identifier which is described in RFC 4007 when zoneAllowed is +// true. +func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) { + ip = make(IP, IPv6len) ellipsis := -1 // position of ellipsis in p i := 0 // index in string s + if zoneAllowed { + s, zone = splitHostZone(s) + } + // Might have leading ellipsis if len(s) >= 2 && s[0] == ':' && s[1] == ':' { ellipsis = 0 i = 2 // Might be only ellipsis if i == len(s) { - return p + return ip, zone } } @@ -514,35 +512,35 @@ func parseIPv6(s string) IP { // Hex number. n, i1, ok := xtoi(s, i) if !ok || n > 0xFFFF { - return nil + return nil, zone } // If followed by dot, might be in trailing IPv4. if i1 < len(s) && s[i1] == '.' { if ellipsis < 0 && j != IPv6len-IPv4len { // Not the right place. - return nil + return nil, zone } if j+IPv4len > IPv6len { // Not enough room. - return nil + return nil, zone } - p4 := parseIPv4(s[i:]) - if p4 == nil { - return nil + ip4 := parseIPv4(s[i:]) + if ip4 == nil { + return nil, zone } - p[j] = p4[12] - p[j+1] = p4[13] - p[j+2] = p4[14] - p[j+3] = p4[15] + ip[j] = ip4[12] + ip[j+1] = ip4[13] + ip[j+2] = ip4[14] + ip[j+3] = ip4[15] i = len(s) j += IPv4len break } // Save this 16-bit chunk. - p[j] = byte(n >> 8) - p[j+1] = byte(n) + ip[j] = byte(n >> 8) + ip[j+1] = byte(n) j += 2 // Stop at end of string. @@ -553,14 +551,14 @@ func parseIPv6(s string) IP { // Otherwise must be followed by colon and more. if s[i] != ':' || i+1 == len(s) { - return nil + return nil, zone } i++ // Look for ellipsis. if s[i] == ':' { if ellipsis >= 0 { // already have one - return nil + return nil, zone } ellipsis = j if i++; i == len(s) { // can be at end @@ -571,23 +569,23 @@ func parseIPv6(s string) IP { // Must have used entire string. if i != len(s) { - return nil + return nil, zone } // If didn't parse enough, expand ellipsis. if j < IPv6len { if ellipsis < 0 { - return nil + return nil, zone } n := IPv6len - j for k := j - 1; k >= ellipsis; k-- { - p[k+n] = p[k] + ip[k+n] = ip[k] } for k := ellipsis + n - 1; k >= ellipsis; k-- { - p[k] = 0 + ip[k] = 0 } } - return p + return ip, zone } // A ParseError represents a malformed text string and the type of string that was expected. @@ -600,26 +598,17 @@ func (e *ParseError) Error() string { return "invalid " + e.Type + ": " + e.Text } -func parseIP(s string) IP { - if p := parseIPv4(s); p != nil { - return p - } - if p := parseIPv6(s); p != nil { - return p - } - return nil -} - // ParseIP parses s as an IP address, returning the result. // The string s can be in dotted decimal ("74.125.19.99") // or IPv6 ("2001:4860:0:2001::68") form. // If s is not a valid textual representation of an IP address, // ParseIP returns nil. func ParseIP(s string) IP { - if p := parseIPv4(s); p != nil { - return p + if ip := parseIPv4(s); ip != nil { + return ip } - return parseIPv6(s) + ip, _ := parseIPv6(s, false) + return ip } // ParseCIDR parses s as a CIDR notation IP address and mask, @@ -634,15 +623,15 @@ func ParseCIDR(s string) (IP, *IPNet, error) { if i < 0 { return nil, nil, &ParseError{"CIDR address", s} } - ipstr, maskstr := s[:i], s[i+1:] + addr, mask := s[:i], s[i+1:] iplen := IPv4len - ip := parseIPv4(ipstr) + ip := parseIPv4(addr) if ip == nil { iplen = IPv6len - ip = parseIPv6(ipstr) + ip, _ = parseIPv6(addr, false) } - n, i, ok := dtoi(maskstr, 0) - if ip == nil || !ok || i != len(maskstr) || n < 0 || n > 8*iplen { + n, i, ok := dtoi(mask, 0) + if ip == nil || !ok || i != len(mask) || n < 0 || n > 8*iplen { return nil, nil, &ParseError{"CIDR address", s} } m := CIDRMask(n, 8*iplen) diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go index 8324d2a..16f30d4 100644 --- a/libgo/go/net/ip_test.go +++ b/libgo/go/net/ip_test.go @@ -5,23 +5,12 @@ package net import ( - "bytes" "reflect" "runtime" "testing" ) -func isEqual(a, b []byte) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - return bytes.Equal(a, b) -} - -var parseiptests = []struct { +var parseIPTests = []struct { in string out IP }{ @@ -33,22 +22,23 @@ var parseiptests = []struct { {"::ffff:127.0.0.1", IPv4(127, 0, 0, 1)}, {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}}, {"::ffff:4a7d:1363", IPv4(74, 125, 19, 99)}, + {"fe80::1%lo0", nil}, + {"fe80::1%911", nil}, {"", nil}, } func TestParseIP(t *testing.T) { - for _, tt := range parseiptests { - if out := ParseIP(tt.in); !isEqual(out, tt.out) { + for _, tt := range parseIPTests { + if out := ParseIP(tt.in); !reflect.DeepEqual(out, tt.out) { t.Errorf("ParseIP(%q) = %v, want %v", tt.in, out, tt.out) } } } -var ipstringtests = []struct { +var ipStringTests = []struct { in IP - out string + out string // see RFC 5952 }{ - // cf. RFC 5952 (A Recommendation for IPv6 Address Text Representation) {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0x1, 0x23, 0, 0x12, 0, 0x1}, "2001:db8::123:12:1"}, {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x1}, "2001:db8::1"}, {IP{0x20, 0x1, 0xd, 0xb8, 0, 0, 0, 0x1, 0, 0, 0, 0x1, 0, 0, 0, 0x1}, "2001:db8:0:1:0:1:0:1"}, @@ -61,14 +51,14 @@ var ipstringtests = []struct { } func TestIPString(t *testing.T) { - for _, tt := range ipstringtests { + for _, tt := range ipStringTests { if out := tt.in.String(); out != tt.out { t.Errorf("IP.String(%v) = %q, want %q", tt.in, out, tt.out) } } } -var ipmasktests = []struct { +var ipMaskTests = []struct { in IP mask IPMask out IP @@ -82,14 +72,14 @@ var ipmasktests = []struct { } func TestIPMask(t *testing.T) { - for _, tt := range ipmasktests { + for _, tt := range ipMaskTests { if out := tt.in.Mask(tt.mask); out == nil || !tt.out.Equal(out) { t.Errorf("IP(%v).Mask(%v) = %v, want %v", tt.in, tt.mask, out, tt.out) } } } -var ipmaskstringtests = []struct { +var ipMaskStringTests = []struct { in IPMask out string }{ @@ -101,14 +91,14 @@ var ipmaskstringtests = []struct { } func TestIPMaskString(t *testing.T) { - for _, tt := range ipmaskstringtests { + for _, tt := range ipMaskStringTests { if out := tt.in.String(); out != tt.out { t.Errorf("IPMask.String(%v) = %q, want %q", tt.in, out, tt.out) } } } -var parsecidrtests = []struct { +var parseCIDRTests = []struct { in string ip IP net *IPNet @@ -138,18 +128,18 @@ var parsecidrtests = []struct { } func TestParseCIDR(t *testing.T) { - for _, tt := range parsecidrtests { + for _, tt := range parseCIDRTests { ip, net, err := ParseCIDR(tt.in) if !reflect.DeepEqual(err, tt.err) { t.Errorf("ParseCIDR(%q) = %v, %v; want %v, %v", tt.in, ip, net, tt.ip, tt.net) } - if err == nil && (!tt.ip.Equal(ip) || !tt.net.IP.Equal(net.IP) || !isEqual(net.Mask, tt.net.Mask)) { - t.Errorf("ParseCIDR(%q) = %v, {%v, %v}; want %v {%v, %v}", tt.in, ip, net.IP, net.Mask, tt.ip, tt.net.IP, tt.net.Mask) + if err == nil && (!tt.ip.Equal(ip) || !tt.net.IP.Equal(net.IP) || !reflect.DeepEqual(net.Mask, tt.net.Mask)) { + t.Errorf("ParseCIDR(%q) = %v, {%v, %v}; want %v, {%v, %v}", tt.in, ip, net.IP, net.Mask, tt.ip, tt.net.IP, tt.net.Mask) } } } -var ipnetcontainstests = []struct { +var ipNetContainsTests = []struct { ip IP net *IPNet ok bool @@ -165,14 +155,14 @@ var ipnetcontainstests = []struct { } func TestIPNetContains(t *testing.T) { - for _, tt := range ipnetcontainstests { + for _, tt := range ipNetContainsTests { if ok := tt.net.Contains(tt.ip); ok != tt.ok { t.Errorf("IPNet(%v).Contains(%v) = %v, want %v", tt.net, tt.ip, ok, tt.ok) } } } -var ipnetstringtests = []struct { +var ipNetStringTests = []struct { in *IPNet out string }{ @@ -183,14 +173,14 @@ var ipnetstringtests = []struct { } func TestIPNetString(t *testing.T) { - for _, tt := range ipnetstringtests { + for _, tt := range ipNetStringTests { if out := tt.in.String(); out != tt.out { t.Errorf("IPNet.String(%v) = %q, want %q", tt.in, out, tt.out) } } } -var cidrmasktests = []struct { +var cidrMaskTests = []struct { ones int bits int out IPMask @@ -210,8 +200,8 @@ var cidrmasktests = []struct { } func TestCIDRMask(t *testing.T) { - for _, tt := range cidrmasktests { - if out := CIDRMask(tt.ones, tt.bits); !isEqual(out, tt.out) { + for _, tt := range cidrMaskTests { + if out := CIDRMask(tt.ones, tt.bits); !reflect.DeepEqual(out, tt.out) { t.Errorf("CIDRMask(%v, %v) = %v, want %v", tt.ones, tt.bits, out, tt.out) } } @@ -229,7 +219,7 @@ var ( v4maskzero = IPMask{0, 0, 0, 0} ) -var networknumberandmasktests = []struct { +var networkNumberAndMaskTests = []struct { in IPNet out IPNet }{ @@ -251,43 +241,90 @@ var networknumberandmasktests = []struct { } func TestNetworkNumberAndMask(t *testing.T) { - for _, tt := range networknumberandmasktests { + for _, tt := range networkNumberAndMaskTests { ip, m := networkNumberAndMask(&tt.in) out := &IPNet{IP: ip, Mask: m} if !reflect.DeepEqual(&tt.out, out) { - t.Errorf("networkNumberAndMask(%v) = %v; want %v", tt.in, out, &tt.out) + t.Errorf("networkNumberAndMask(%v) = %v, want %v", tt.in, out, &tt.out) } } } -var splitjointests = []struct { - Host string - Port string - Join string +var splitJoinTests = []struct { + host string + port string + join string }{ {"www.google.com", "80", "www.google.com:80"}, {"127.0.0.1", "1234", "127.0.0.1:1234"}, {"::1", "80", "[::1]:80"}, + {"fe80::1%lo0", "80", "[fe80::1%lo0]:80"}, + {"localhost%lo0", "80", "[localhost%lo0]:80"}, + {"", "0", ":0"}, + {"google.com", "https%foo", "google.com:https%foo"}, // Go 1.0 behavior + {"127.0.0.1", "", "127.0.0.1:"}, // Go 1.0 behaviour + {"www.google.com", "", "www.google.com:"}, // Go 1.0 behaviour +} + +var splitFailureTests = []struct { + hostPort string + err string +}{ + {"www.google.com", "missing port in address"}, + {"127.0.0.1", "missing port in address"}, + {"[::1]", "missing port in address"}, + {"[fe80::1%lo0]", "missing port in address"}, + {"[localhost%lo0]", "missing port in address"}, + {"localhost%lo0", "missing port in address"}, + + {"::1", "too many colons in address"}, + {"fe80::1%lo0", "too many colons in address"}, + {"fe80::1%lo0:80", "too many colons in address"}, + + {"localhost%lo0:80", "missing brackets in address"}, + + // Test cases that didn't fail in Go 1.0 + + {"[foo:bar]", "missing port in address"}, + {"[foo:bar]baz", "missing port in address"}, + {"[foo]bar:baz", "missing port in address"}, + + {"[foo]:[bar]:baz", "too many colons in address"}, + + {"[foo]:[bar]baz", "unexpected '[' in address"}, + {"foo[bar]:baz", "unexpected '[' in address"}, + + {"foo]bar:baz", "unexpected ']' in address"}, } func TestSplitHostPort(t *testing.T) { - for _, tt := range splitjointests { - if host, port, err := SplitHostPort(tt.Join); host != tt.Host || port != tt.Port || err != nil { - t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.Join, host, port, err, tt.Host, tt.Port) + for _, tt := range splitJoinTests { + if host, port, err := SplitHostPort(tt.join); host != tt.host || port != tt.port || err != nil { + t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.join, host, port, err, tt.host, tt.port) + } + } + for _, tt := range splitFailureTests { + if _, _, err := SplitHostPort(tt.hostPort); err == nil { + t.Errorf("SplitHostPort(%q) should have failed", tt.hostPort) + } else { + e := err.(*AddrError) + if e.Err != tt.err { + t.Errorf("SplitHostPort(%q) = _, _, %q; want %q", tt.hostPort, e.Err, tt.err) + } } } } func TestJoinHostPort(t *testing.T) { - for _, tt := range splitjointests { - if join := JoinHostPort(tt.Host, tt.Port); join != tt.Join { - t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.Host, tt.Port, join, tt.Join) + for _, tt := range splitJoinTests { + if join := JoinHostPort(tt.host, tt.port); join != tt.join { + t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.host, tt.port, join, tt.join) } } } -var ipaftests = []struct { +var ipAddrFamilyTests = []struct { in IP af4 bool af6 bool @@ -310,7 +347,7 @@ var ipaftests = []struct { } func TestIPAddrFamily(t *testing.T) { - for _, tt := range ipaftests { + for _, tt := range ipAddrFamilyTests { if af := tt.in.To4() != nil; af != tt.af4 { t.Errorf("verifying IPv4 address family for %q = %v, want %v", tt.in, af, tt.af4) } @@ -320,7 +357,7 @@ func TestIPAddrFamily(t *testing.T) { } } -var ipscopetests = []struct { +var ipAddrScopeTests = []struct { scope func(IP) bool in IP ok bool @@ -361,7 +398,7 @@ func name(f interface{}) string { } func TestIPAddrScope(t *testing.T) { - for _, tt := range ipscopetests { + for _, tt := range ipAddrScopeTests { if ok := tt.scope(tt.in); ok != tt.ok { t.Errorf("%s(%q) = %v, want %v", name(tt.scope), tt.in, ok, tt.ok) } diff --git a/libgo/go/net/ipraw_test.go b/libgo/go/net/ipraw_test.go index db1c769..12c199d 100644 --- a/libgo/go/net/ipraw_test.go +++ b/libgo/go/net/ipraw_test.go @@ -2,32 +2,36 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !plan9 - package net import ( "bytes" + "errors" + "fmt" "os" "reflect" - "syscall" "testing" "time" ) -var resolveIPAddrTests = []struct { +type resolveIPAddrTest struct { net string litAddr string addr *IPAddr err error -}{ +} + +var resolveIPAddrTests = []resolveIPAddrTest{ {"ip", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, {"ip4", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, {"ip4:icmp", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, {"ip", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, {"ip6", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, - {"ip6:icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + {"ip6:ipv6-icmp", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, + + {"ip", "::1%en0", &IPAddr{IP: ParseIP("::1"), Zone: "en0"}, nil}, + {"ip6", "::1%911", &IPAddr{IP: ParseIP("::1"), Zone: "911"}, nil}, {"", "127.0.0.1", &IPAddr{IP: IPv4(127, 0, 0, 1)}, nil}, // Go 1.0 behavior {"", "::1", &IPAddr{IP: ParseIP("::1")}, nil}, // Go 1.0 behavior @@ -37,208 +41,290 @@ var resolveIPAddrTests = []struct { {"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")}, } +func init() { + if ifi := loopbackInterface(); ifi != nil { + index := fmt.Sprintf("%v", ifi.Index) + resolveIPAddrTests = append(resolveIPAddrTests, []resolveIPAddrTest{ + {"ip6", "fe80::1%" + ifi.Name, &IPAddr{IP: ParseIP("fe80::1"), Zone: zoneToString(ifi.Index)}, nil}, + {"ip6", "fe80::1%" + index, &IPAddr{IP: ParseIP("fe80::1"), Zone: index}, nil}, + }...) + } +} + func TestResolveIPAddr(t *testing.T) { for _, tt := range resolveIPAddrTests { addr, err := ResolveIPAddr(tt.net, tt.litAddr) if err != tt.err { - t.Fatalf("ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) - } - if !reflect.DeepEqual(addr, tt.addr) { + condFatalf(t, "ResolveIPAddr(%v, %v) failed: %v", tt.net, tt.litAddr, err) + } else if !reflect.DeepEqual(addr, tt.addr) { t.Fatalf("got %#v; expected %#v", addr, tt.addr) } } } -var icmpTests = []struct { +var icmpEchoTests = []struct { net string laddr string raddr string - ipv6 bool // test with underlying AF_INET6 socket }{ - {"ip4:icmp", "", "127.0.0.1", false}, - {"ip6:ipv6-icmp", "", "::1", true}, + {"ip4:icmp", "0.0.0.0", "127.0.0.1"}, + {"ip6:ipv6-icmp", "::", "::1"}, } -func TestICMP(t *testing.T) { +func TestConnICMPEcho(t *testing.T) { if os.Getuid() != 0 { t.Skip("skipping test; must be root") } - seqnum := 61455 - for _, tt := range icmpTests { - if tt.ipv6 && !supportsIPv6 { + for i, tt := range icmpEchoTests { + net, _, err := parseNetwork(tt.net) + if err != nil { + t.Fatalf("parseNetwork failed: %v", err) + } + if net == "ip6" && !supportsIPv6 { continue } - id := os.Getpid() & 0xffff - seqnum++ - echo := newICMPEchoRequest(tt.net, id, seqnum, 128, []byte("Go Go Gadget Ping!!!")) - exchangeICMPEcho(t, tt.net, tt.laddr, tt.raddr, echo) - } -} - -func exchangeICMPEcho(t *testing.T, net, laddr, raddr string, echo []byte) { - c, err := ListenPacket(net, laddr) - if err != nil { - t.Errorf("ListenPacket(%q, %q) failed: %v", net, laddr, err) - return - } - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() - - ra, err := ResolveIPAddr(net, raddr) - if err != nil { - t.Errorf("ResolveIPAddr(%q, %q) failed: %v", net, raddr, err) - return - } - - waitForReady := make(chan bool) - go icmpEchoTransponder(t, net, raddr, waitForReady) - <-waitForReady - _, err = c.WriteTo(echo, ra) - if err != nil { - t.Errorf("WriteTo failed: %v", err) - return - } + c, err := Dial(tt.net, tt.raddr) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + defer c.Close() - reply := make([]byte, 256) - for { - _, _, err := c.ReadFrom(reply) + typ := icmpv4EchoRequest + if net == "ip6" { + typ = icmpv6EchoRequest + } + xid, xseq := os.Getpid()&0xffff, i+1 + b, err := (&icmpMessage{ + Type: typ, Code: 0, + Body: &icmpEcho{ + ID: xid, Seq: xseq, + Data: bytes.Repeat([]byte("Go Go Gadget Ping!!!"), 3), + }, + }).Marshal() if err != nil { - t.Errorf("ReadFrom failed: %v", err) - return + t.Fatalf("icmpMessage.Marshal failed: %v", err) } - switch c.(*IPConn).fd.family { - case syscall.AF_INET: - if reply[0] != ICMP4_ECHO_REPLY { - continue + if _, err := c.Write(b); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + var m *icmpMessage + for { + if _, err := c.Read(b); err != nil { + t.Fatalf("Conn.Read failed: %v", err) + } + if net == "ip4" { + b = ipv4Payload(b) } - case syscall.AF_INET6: - if reply[0] != ICMP6_ECHO_REPLY { + if m, err = parseICMPMessage(b); err != nil { + t.Fatalf("parseICMPMessage failed: %v", err) + } + switch m.Type { + case icmpv4EchoRequest, icmpv6EchoRequest: continue } + break } - xid, xseqnum := parseICMPEchoReply(echo) - rid, rseqnum := parseICMPEchoReply(reply) - if rid != xid || rseqnum != xseqnum { - t.Errorf("ID = %v, Seqnum = %v, want ID = %v, Seqnum = %v", rid, rseqnum, xid, xseqnum) - return + switch p := m.Body.(type) { + case *icmpEcho: + if p.ID != xid || p.Seq != xseq { + t.Fatalf("got id=%v, seqnum=%v; expected id=%v, seqnum=%v", p.ID, p.Seq, xid, xseq) + } + default: + t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, typ, 0) } - break } } -func icmpEchoTransponder(t *testing.T, net, raddr string, waitForReady chan bool) { - c, err := Dial(net, raddr) - if err != nil { - waitForReady <- true - t.Errorf("Dial(%q, %q) failed: %v", net, raddr, err) - return +func TestPacketConnICMPEcho(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") } - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() - waitForReady <- true - echo := make([]byte, 256) - var nr int - for { - nr, err = c.Read(echo) + for i, tt := range icmpEchoTests { + net, _, err := parseNetwork(tt.net) if err != nil { - t.Errorf("Read failed: %v", err) - return + t.Fatalf("parseNetwork failed: %v", err) } - switch c.(*IPConn).fd.family { - case syscall.AF_INET: - if echo[0] != ICMP4_ECHO_REQUEST { - continue + if net == "ip6" && !supportsIPv6 { + continue + } + + c, err := ListenPacket(tt.net, tt.laddr) + if err != nil { + t.Fatalf("ListenPacket failed: %v", err) + } + c.SetDeadline(time.Now().Add(100 * time.Millisecond)) + defer c.Close() + + ra, err := ResolveIPAddr(tt.net, tt.raddr) + if err != nil { + t.Fatalf("ResolveIPAddr failed: %v", err) + } + typ := icmpv4EchoRequest + if net == "ip6" { + typ = icmpv6EchoRequest + } + xid, xseq := os.Getpid()&0xffff, i+1 + b, err := (&icmpMessage{ + Type: typ, Code: 0, + Body: &icmpEcho{ + ID: xid, Seq: xseq, + Data: bytes.Repeat([]byte("Go Go Gadget Ping!!!"), 3), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + if _, err := c.WriteTo(b, ra); err != nil { + t.Fatalf("PacketConn.WriteTo failed: %v", err) + } + var m *icmpMessage + for { + if _, _, err := c.ReadFrom(b); err != nil { + t.Fatalf("PacketConn.ReadFrom failed: %v", err) + } + // TODO: fix issue 3944 + //if net == "ip4" { + // b = ipv4Payload(b) + //} + if m, err = parseICMPMessage(b); err != nil { + t.Fatalf("parseICMPMessage failed: %v", err) } - case syscall.AF_INET6: - if echo[0] != ICMP6_ECHO_REQUEST { + switch m.Type { + case icmpv4EchoRequest, icmpv6EchoRequest: continue } + break + } + switch p := m.Body.(type) { + case *icmpEcho: + if p.ID != xid || p.Seq != xseq { + t.Fatalf("got id=%v, seqnum=%v; expected id=%v, seqnum=%v", p.ID, p.Seq, xid, xseq) + } + default: + t.Fatalf("got type=%v, code=%v; expected type=%v, code=%v", m.Type, m.Code, typ, 0) } - break - } - - switch c.(*IPConn).fd.family { - case syscall.AF_INET: - echo[0] = ICMP4_ECHO_REPLY - case syscall.AF_INET6: - echo[0] = ICMP6_ECHO_REPLY } +} - _, err = c.Write(echo[:nr]) - if err != nil { - t.Errorf("Write failed: %v", err) - return +func ipv4Payload(b []byte) []byte { + if len(b) < 20 { + return b } + hdrlen := int(b[0]&0x0f) << 2 + return b[hdrlen:] } const ( - ICMP4_ECHO_REQUEST = 8 - ICMP4_ECHO_REPLY = 0 - ICMP6_ECHO_REQUEST = 128 - ICMP6_ECHO_REPLY = 129 + icmpv4EchoRequest = 8 + icmpv4EchoReply = 0 + icmpv6EchoRequest = 128 + icmpv6EchoReply = 129 ) -func newICMPEchoRequest(net string, id, seqnum, msglen int, filler []byte) []byte { - afnet, _, _ := parseDialNetwork(net) - switch afnet { - case "ip4": - return newICMPv4EchoRequest(id, seqnum, msglen, filler) - case "ip6": - return newICMPv6EchoRequest(id, seqnum, msglen, filler) - } - return nil +// icmpMessage represents an ICMP message. +type icmpMessage struct { + Type int // type + Code int // code + Checksum int // checksum + Body icmpMessageBody // body } -func newICMPv4EchoRequest(id, seqnum, msglen int, filler []byte) []byte { - b := newICMPInfoMessage(id, seqnum, msglen, filler) - b[0] = ICMP4_ECHO_REQUEST +// icmpMessageBody represents an ICMP message body. +type icmpMessageBody interface { + Len() int + Marshal() ([]byte, error) +} - // calculate ICMP checksum - cklen := len(b) +// Marshal returns the binary enconding of the ICMP echo request or +// reply message m. +func (m *icmpMessage) Marshal() ([]byte, error) { + b := []byte{byte(m.Type), byte(m.Code), 0, 0} + if m.Body != nil && m.Body.Len() != 0 { + mb, err := m.Body.Marshal() + if err != nil { + return nil, err + } + b = append(b, mb...) + } + switch m.Type { + case icmpv6EchoRequest, icmpv6EchoReply: + return b, nil + } + csumcv := len(b) - 1 // checksum coverage s := uint32(0) - for i := 0; i < cklen-1; i += 2 { + for i := 0; i < csumcv; i += 2 { s += uint32(b[i+1])<<8 | uint32(b[i]) } - if cklen&1 == 1 { - s += uint32(b[cklen-1]) + if csumcv&1 == 0 { + s += uint32(b[csumcv]) } - s = (s >> 16) + (s & 0xffff) - s = s + (s >> 16) - // place checksum back in header; using ^= avoids the - // assumption the checksum bytes are zero - b[2] ^= uint8(^s & 0xff) - b[3] ^= uint8(^s >> 8) - - return b + s = s>>16 + s&0xffff + s = s + s>>16 + // Place checksum back in header; using ^= avoids the + // assumption the checksum bytes are zero. + b[2] ^= byte(^s & 0xff) + b[3] ^= byte(^s >> 8) + return b, nil } -func newICMPv6EchoRequest(id, seqnum, msglen int, filler []byte) []byte { - b := newICMPInfoMessage(id, seqnum, msglen, filler) - b[0] = ICMP6_ECHO_REQUEST - return b +// parseICMPMessage parses b as an ICMP message. +func parseICMPMessage(b []byte) (*icmpMessage, error) { + msglen := len(b) + if msglen < 4 { + return nil, errors.New("message too short") + } + m := &icmpMessage{Type: int(b[0]), Code: int(b[1]), Checksum: int(b[2])<<8 | int(b[3])} + if msglen > 4 { + var err error + switch m.Type { + case icmpv4EchoRequest, icmpv4EchoReply, icmpv6EchoRequest, icmpv6EchoReply: + m.Body, err = parseICMPEcho(b[4:]) + if err != nil { + return nil, err + } + } + } + return m, nil +} + +// imcpEcho represenets an ICMP echo request or reply message body. +type icmpEcho struct { + ID int // identifier + Seq int // sequence number + Data []byte // data } -func newICMPInfoMessage(id, seqnum, msglen int, filler []byte) []byte { - b := make([]byte, msglen) - copy(b[8:], bytes.Repeat(filler, (msglen-8)/len(filler)+1)) - b[0] = 0 // type - b[1] = 0 // code - b[2] = 0 // checksum - b[3] = 0 // checksum - b[4] = uint8(id >> 8) // identifier - b[5] = uint8(id & 0xff) // identifier - b[6] = uint8(seqnum >> 8) // sequence number - b[7] = uint8(seqnum & 0xff) // sequence number - return b +func (p *icmpEcho) Len() int { + if p == nil { + return 0 + } + return 4 + len(p.Data) } -func parseICMPEchoReply(b []byte) (id, seqnum int) { - id = int(b[4])<<8 | int(b[5]) - seqnum = int(b[6])<<8 | int(b[7]) - return +// Marshal returns the binary enconding of the ICMP echo request or +// reply message body p. +func (p *icmpEcho) Marshal() ([]byte, error) { + b := make([]byte, 4+len(p.Data)) + b[0], b[1] = byte(p.ID>>8), byte(p.ID&0xff) + b[2], b[3] = byte(p.Seq>>8), byte(p.Seq&0xff) + copy(b[4:], p.Data) + return b, nil +} + +// parseICMPEcho parses b as an ICMP echo request or reply message +// body. +func parseICMPEcho(b []byte) (*icmpEcho, error) { + bodylen := len(b) + p := &icmpEcho{ID: int(b[0])<<8 | int(b[1]), Seq: int(b[2])<<8 | int(b[3])} + if bodylen > 4 { + p.Data = make([]byte, bodylen-4) + copy(p.Data, b[4:]) + } + return p, nil } var ipConnLocalNameTests = []struct { @@ -258,14 +344,27 @@ func TestIPConnLocalName(t *testing.T) { for _, tt := range ipConnLocalNameTests { c, err := ListenIP(tt.net, tt.laddr) if err != nil { - t.Errorf("ListenIP failed: %v", err) - return + t.Fatalf("ListenIP failed: %v", err) } defer c.Close() - la := c.LocalAddr() - if la == nil { - t.Error("IPConn.LocalAddr failed") - return + if la := c.LocalAddr(); la == nil { + t.Fatal("IPConn.LocalAddr failed") } } } + +func TestIPConnRemoteName(t *testing.T) { + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + raddr := &IPAddr{IP: IPv4(127, 0, 0, 10).To4()} + c, err := DialIP("ip:tcp", &IPAddr{IP: IPv4(127, 0, 0, 1)}, raddr) + if err != nil { + t.Fatalf("DialIP failed: %v", err) + } + defer c.Close() + if !reflect.DeepEqual(raddr, c.RemoteAddr()) { + t.Fatalf("got %#v, expected %#v", c.RemoteAddr(), raddr) + } +} diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index 13bfd62..0be94eb 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Raw IP sockets - package net // IPAddr represents the address of an IP end point. @@ -19,17 +17,20 @@ func (a *IPAddr) String() string { if a == nil { return "<nil>" } + if a.Zone != "" { + return a.IP.String() + "%" + a.Zone + } return a.IP.String() } -// ResolveIPAddr parses addr as an IP address and resolves domain -// names to numeric addresses on the network net, which must be -// "ip", "ip4" or "ip6". +// ResolveIPAddr parses addr as an IP address of the form "host" or +// "ipv6-host%zone" and resolves the domain name on the network net, +// which must be "ip", "ip4" or "ip6". func ResolveIPAddr(net, addr string) (*IPAddr, error) { if net == "" { // a hint wildcard for Go 1.0 undocumented behavior net = "ip" } - afnet, _, err := parseDialNetwork(net) + afnet, _, err := parseNetwork(net) if err != nil { return nil, err } diff --git a/libgo/go/net/iprawsock_plan9.go b/libgo/go/net/iprawsock_plan9.go index 88e3b2c..e62d116 100644 --- a/libgo/go/net/iprawsock_plan9.go +++ b/libgo/go/net/iprawsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Raw IP sockets for Plan 9 - package net import ( @@ -34,7 +32,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgIP reads a packet from c, copying the payload into b and the -// associdated out-of-band data into oob. It returns the number of +// associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the packet and the source address of the packet. func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { @@ -76,7 +74,7 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, } // ListenIP listens for incoming IP packets addressed to the local -// address laddr. The returned connection c's ReadFrom and WriteTo +// address laddr. The returned connection's ReadFrom and WriteTo // methods can be used to receive and send IP packets with per-packet // addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go index 7a8cd44..caeeb46 100644 --- a/libgo/go/net/iprawsock_posix.go +++ b/libgo/go/net/iprawsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// Raw IP sockets for POSIX - package net import ( @@ -51,8 +49,8 @@ func (a *IPAddr) toAddr() sockaddr { return a } -// IPConn is the implementation of the Conn and PacketConn -// interfaces for IP network connections. +// IPConn is the implementation of the Conn and PacketConn interfaces +// for IP network connections. type IPConn struct { conn } @@ -98,7 +96,7 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgIP reads a packet from c, copying the payload into b and the -// associdated out-of-band data into oob. It returns the number of +// associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the packet and the source address of the packet. func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { @@ -116,12 +114,13 @@ func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err return } -// WriteToIP writes an IP packet to addr via c, copying the payload from b. +// WriteToIP writes an IP packet to addr via c, copying the payload +// from b. // -// WriteToIP can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. +// WriteToIP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -159,14 +158,15 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error return c.fd.WriteMsg(b, oob, sa) } -// DialIP connects to the remote address raddr on the network protocol netProto, -// which must be "ip", "ip4", or "ip6" followed by a colon and a protocol number or name. +// DialIP connects to the remote address raddr on the network protocol +// netProto, which must be "ip", "ip4", or "ip6" followed by a colon +// and a protocol number or name. func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { return dialIP(netProto, laddr, raddr, noDeadline) } func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, error) { - net, proto, err := parseDialNetwork(netProto) + net, proto, err := parseNetwork(netProto) if err != nil { return nil, err } @@ -185,12 +185,12 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn, return newIPConn(fd), nil } -// ListenIP listens for incoming IP packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send IP -// packets with per-packet addressing. +// ListenIP listens for incoming IP packets addressed to the local +// address laddr. The returned connection's ReadFrom and WriteTo +// methods can be used to receive and send IP packets with per-packet +// addressing. func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { - net, proto, err := parseDialNetwork(netProto) + net, proto, err := parseNetwork(netProto) if err != nil { return nil, err } diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go index 5636c85..d930595 100644 --- a/libgo/go/net/ipsock.go +++ b/libgo/go/net/ipsock.go @@ -68,40 +68,94 @@ func (e InvalidAddrError) Error() string { return string(e) } func (e InvalidAddrError) Timeout() bool { return false } func (e InvalidAddrError) Temporary() bool { return false } -// SplitHostPort splits a network address of the form -// "host:port" or "[host]:port" into host and port. -// The latter form must be used when host contains a colon. +// SplitHostPort splits a network address of the form "host:port", +// "[host]:port" or "[ipv6-host%zone]:port" into host or +// ipv6-host%zone and port. A literal address or host name for IPv6 +// must be enclosed in square brackets, as in "[::1]:80", +// "[ipv6-host]:http" or "[ipv6-host%zone]:80". func SplitHostPort(hostport string) (host, port string, err error) { - host, port, _, err = splitHostPort(hostport) - return -} + j, k := 0, 0 -func splitHostPort(hostport string) (host, port, zone string, err error) { // The port starts after the last colon. i := last(hostport, ':') if i < 0 { - err = &AddrError{"missing port in address", hostport} - return + goto missingPort } - host, port = hostport[:i], hostport[i+1:] - // Can put brackets around host ... - if len(host) > 0 && host[0] == '[' && host[len(host)-1] == ']' { - host = host[1 : len(host)-1] + + if hostport[0] == '[' { + // Expect the first ']' just before the last ':'. + end := byteIndex(hostport, ']') + if end < 0 { + err = &AddrError{"missing ']' in address", hostport} + return + } + switch end + 1 { + case len(hostport): + // There can't be a ':' behind the ']' now. + goto missingPort + case i: + // The expected result. + default: + // Either ']' isn't followed by a colon, or it is + // followed by a colon that is not the last one. + if hostport[end+1] == ':' { + goto tooManyColons + } + goto missingPort + } + host = hostport[1:end] + j, k = 1, end+1 // there can't be a '[' resp. ']' before these positions } else { - // ... but if there are no brackets, no colons. + host = hostport[:i] if byteIndex(host, ':') >= 0 { - err = &AddrError{"too many colons in address", hostport} - return + goto tooManyColons } + if byteIndex(host, '%') >= 0 { + goto missingBrackets + } + } + if byteIndex(hostport[j:], '[') >= 0 { + err = &AddrError{"unexpected '[' in address", hostport} + return + } + if byteIndex(hostport[k:], ']') >= 0 { + err = &AddrError{"unexpected ']' in address", hostport} + return + } + + port = hostport[i+1:] + return + +missingPort: + err = &AddrError{"missing port in address", hostport} + return + +tooManyColons: + err = &AddrError{"too many colons in address", hostport} + return + +missingBrackets: + err = &AddrError{"missing brackets in address", hostport} + return +} + +func splitHostZone(s string) (host, zone string) { + // The IPv6 scoped addressing zone identifer starts after the + // last percent sign. + if i := last(s, '%'); i > 0 { + host, zone = s[:i], s[i+1:] + } else { + host = s } return } -// JoinHostPort combines host and port into a network address -// of the form "host:port" or, if host contains a colon, "[host]:port". +// JoinHostPort combines host and port into a network address of the +// form "host:port" or, if host contains a colon or a percent sign, +// "[host]:port". func JoinHostPort(host, port string) string { - // If host has colons, have to bracket it. - if byteIndex(host, ':') >= 0 { + // If host has colons or a percent sign, have to bracket it. + if byteIndex(host, ':') >= 0 || byteIndex(host, '%') >= 0 { return "[" + host + "]:" + port } return host + ":" + port @@ -116,7 +170,7 @@ func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { switch net { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": if addr != "" { - if host, port, zone, err = splitHostPort(addr); err != nil { + if host, port, err = SplitHostPort(addr); err != nil { return nil, err } if portnum, err = parsePort(net, port); err != nil { @@ -145,21 +199,25 @@ func resolveInternetAddr(net, addr string, deadline time.Time) (Addr, error) { return inetaddr(net, nil, portnum, zone), nil } // Try as an IP address. - if ip := ParseIP(host); ip != nil { + if ip := parseIPv4(host); ip != nil { + return inetaddr(net, ip, portnum, zone), nil + } + if ip, zone := parseIPv6(host, true); ip != nil { return inetaddr(net, ip, portnum, zone), nil } + // Try as a domain name. + host, zone = splitHostZone(host) + addrs, err := lookupHostDeadline(host, deadline) + if err != nil { + return nil, err + } var filter func(IP) IP if net != "" && net[len(net)-1] == '4' { filter = ipv4only } - if net != "" && net[len(net)-1] == '6' { + if net != "" && net[len(net)-1] == '6' || zone != "" { filter = ipv6only } - // Try as a DNS name. - addrs, err := lookupHostDeadline(host, deadline) - if err != nil { - return nil, err - } ip := firstFavoriteAddr(filter, addrs) if ip == nil { // should not happen diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go index eaef768..c7d542d 100644 --- a/libgo/go/net/ipsock_plan9.go +++ b/libgo/go/net/ipsock_plan9.go @@ -9,6 +9,7 @@ package net import ( "errors" "os" + "syscall" ) // /sys/include/ape/sys/socket.h:/SOMAXCONN @@ -104,72 +105,89 @@ func startPlan9(net string, addr Addr) (ctl *os.File, dest, proto, name string, return f, dest, proto, string(buf[:n]), nil } -func dialPlan9(net string, laddr, raddr Addr) (*netFD, error) { +func netErr(e error) { + oe, ok := e.(*OpError) + if !ok { + return + } + if pe, ok := oe.Err.(*os.PathError); ok { + if _, ok = pe.Err.(syscall.ErrorString); ok { + oe.Err = pe.Err + } + } +} + +func dialPlan9(net string, laddr, raddr Addr) (fd *netFD, err error) { + defer func() { netErr(err) }() f, dest, proto, name, err := startPlan9(net, raddr) if err != nil { - return nil, err + return nil, &OpError{"dial", net, raddr, err} } _, err = f.WriteString("connect " + dest) if err != nil { f.Close() - return nil, err + return nil, &OpError{"dial", f.Name(), raddr, err} } - laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") + data, err := os.OpenFile("/net/"+proto+"/"+name+"/data", os.O_RDWR, 0) if err != nil { f.Close() - return nil, err + return nil, &OpError{"dial", net, raddr, err} } - raddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/remote") + laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") if err != nil { + data.Close() f.Close() - return nil, err + return nil, &OpError{"dial", proto, raddr, err} } - return newFD(proto, name, f, laddr, raddr), nil + return newFD(proto, name, f, data, laddr, raddr), nil } -func listenPlan9(net string, laddr Addr) (*netFD, error) { +func listenPlan9(net string, laddr Addr) (fd *netFD, err error) { + defer func() { netErr(err) }() f, dest, proto, name, err := startPlan9(net, laddr) if err != nil { - return nil, err + return nil, &OpError{"listen", net, laddr, err} } _, err = f.WriteString("announce " + dest) if err != nil { f.Close() - return nil, err + return nil, &OpError{"announce", proto, laddr, err} } laddr, err = readPlan9Addr(proto, "/net/"+proto+"/"+name+"/local") if err != nil { f.Close() - return nil, err + return nil, &OpError{Op: "listen", Net: net, Err: err} } - return &netFD{proto: proto, name: name, dir: "/net/" + proto + "/" + name, ctl: f, laddr: laddr}, nil + return newFD(proto, name, f, nil, laddr, nil), nil } func (l *netFD) netFD() *netFD { - return newFD(l.proto, l.name, l.ctl, l.laddr, nil) + return newFD(l.proto, l.name, l.ctl, l.data, l.laddr, l.raddr) } -func (l *netFD) acceptPlan9() (*netFD, error) { +func (l *netFD) acceptPlan9() (fd *netFD, err error) { + defer func() { netErr(err) }() f, err := os.Open(l.dir + "/listen") if err != nil { - return nil, err + return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err} } var buf [16]byte n, err := f.Read(buf[:]) if err != nil { f.Close() - return nil, err + return nil, &OpError{"accept", l.dir + "/listen", l.laddr, err} } name := string(buf[:n]) - laddr, err := readPlan9Addr(l.proto, l.dir+"/local") + data, err := os.OpenFile("/net/"+l.proto+"/"+name+"/data", os.O_RDWR, 0) if err != nil { f.Close() - return nil, err + return nil, &OpError{"accept", l.proto, l.laddr, err} } - raddr, err := readPlan9Addr(l.proto, l.dir+"/remote") + raddr, err := readPlan9Addr(l.proto, "/net/"+l.proto+"/"+name+"/remote") if err != nil { + data.Close() f.Close() - return nil, err + return nil, &OpError{"accept", l.proto, l.laddr, err} } - return newFD(l.proto, name, f, laddr, raddr), nil + return newFD(l.proto, name, f, data, l.laddr, raddr), nil } diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go index ae7cf79..94c5533 100644 --- a/libgo/go/net/lookup_plan9.go +++ b/libgo/go/net/lookup_plan9.go @@ -7,7 +7,6 @@ package net import ( "errors" "os" - "syscall" ) func query(filename, query string, bufSize int) (res []string, err error) { @@ -70,9 +69,26 @@ func queryDNS(addr string, typ string) (res []string, err error) { return query("/net/dns", addr+" "+typ, 1024) } +// lookupProtocol looks up IP protocol name and returns +// the corresponding protocol number. func lookupProtocol(name string) (proto int, err error) { - // TODO: Implement this - return 0, syscall.EPLAN9 + lines, err := query("/net/cs", "!protocol="+name, 128) + if err != nil { + return 0, err + } + unknownProtoError := errors.New("unknown IP protocol specified: " + name) + if len(lines) == 0 { + return 0, unknownProtoError + } + f := getFields(lines[0]) + if len(f) < 2 { + return 0, unknownProtoError + } + s := f[1] + if n, _, ok := dtoi(s, byteIndex(s, '=')+1); ok { + return n, nil + } + return 0, unknownProtoError } func lookupHost(host string) (addrs []string, err error) { diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 390fe7f..3b29724 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -6,21 +6,17 @@ package net import ( "os" - "sync" + "runtime" "syscall" "unsafe" ) var ( - protoentLock sync.Mutex - hostentLock sync.Mutex - serventLock sync.Mutex + lookupPort = oldLookupPort + lookupIP = oldLookupIP ) -// lookupProtocol looks up IP protocol name and returns correspondent protocol number. -func lookupProtocol(name string) (proto int, err error) { - protoentLock.Lock() - defer protoentLock.Unlock() +func getprotobyname(name string) (proto int, err error) { p, err := syscall.GetProtoByName(name) if err != nil { return 0, os.NewSyscallError("GetProtoByName", err) @@ -28,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) { return int(p.Proto), nil } +// lookupProtocol looks up IP protocol name and returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err error) { + // GetProtoByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + proto int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + proto, err := getprotobyname(name) + ch <- result{proto: proto, err: err} + }() + r := <-ch + return r.proto, r.err +} + func lookupHost(name string) (addrs []string, err error) { ips, err := LookupIP(name) if err != nil { @@ -40,11 +55,7 @@ func lookupHost(name string) (addrs []string, err error) { return } -var lookupIP = oldLookupIP - -func oldLookupIP(name string) (addrs []IP, err error) { - hostentLock.Lock() - defer hostentLock.Unlock() +func gethostbyname(name string) (addrs []IP, err error) { h, err := syscall.GetHostByName(name) if err != nil { return nil, os.NewSyscallError("GetHostByName", err) @@ -63,6 +74,24 @@ func oldLookupIP(name string) (addrs []IP, err error) { return addrs, nil } +func oldLookupIP(name string) (addrs []IP, err error) { + // GetHostByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + addrs []IP + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + addrs, err := gethostbyname(name) + ch <- result{addrs: addrs, err: err} + }() + r := <-ch + return r.addrs, r.err +} + func newLookupIP(name string) (addrs []IP, err error) { hints := syscall.AddrinfoW{ Family: syscall.AF_UNSPEC, @@ -92,15 +121,13 @@ func newLookupIP(name string) (addrs []IP, err error) { return addrs, nil } -func lookupPort(network, service string) (port int, err error) { +func getservbyname(network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - serventLock.Lock() - defer serventLock.Unlock() s, err := syscall.GetServByName(service, network) if err != nil { return 0, os.NewSyscallError("GetServByName", err) @@ -108,6 +135,58 @@ func lookupPort(network, service string) (port int, err error) { return int(syscall.Ntohs(s.Port)), nil } +func oldLookupPort(network, service string) (port int, err error) { + // GetServByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + port int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + port, err := getservbyname(network, service) + ch <- result{port: port, err: err} + }() + r := <-ch + return r.port, r.err +} + +func newLookupPort(network, service string) (port int, err error) { + var stype int32 + switch network { + case "tcp4", "tcp6": + stype = syscall.SOCK_STREAM + case "udp4", "udp6": + stype = syscall.SOCK_DGRAM + } + hints := syscall.AddrinfoW{ + Family: syscall.AF_UNSPEC, + Socktype: stype, + Protocol: syscall.IPPROTO_IP, + } + var result *syscall.AddrinfoW + e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result) + if e != nil { + return 0, os.NewSyscallError("GetAddrInfoW", e) + } + defer syscall.FreeAddrInfoW(result) + if result == nil { + return 0, os.NewSyscallError("LookupPort", syscall.EINVAL) + } + addr := unsafe.Pointer(result.Addr) + switch result.Family { + case syscall.AF_INET: + a := (*syscall.RawSockaddrInet4)(addr) + return int(syscall.Ntohs(a.Port)), nil + case syscall.AF_INET6: + a := (*syscall.RawSockaddrInet6)(addr) + return int(syscall.Ntohs(a.Port)), nil + } + return 0, os.NewSyscallError("LookupPort", syscall.EINVAL) +} + func lookupCNAME(name string) (cname string, err error) { var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil) diff --git a/libgo/go/net/multicast_posix_test.go b/libgo/go/net/multicast_posix_test.go deleted file mode 100644 index ff1edaf..0000000 --- a/libgo/go/net/multicast_posix_test.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build !plan9 - -package net - -import ( - "errors" - "os" - "runtime" - "testing" -) - -var multicastListenerTests = []struct { - net string - gaddr *UDPAddr - flags Flags - ipv6 bool // test with underlying AF_INET6 socket -}{ - // cf. RFC 4727: Experimental Values in IPv4, IPv6, ICMPv4, ICMPv6, UDP, and TCP Headers - - {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false}, - {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false}, - {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true}, - - {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, FlagUp | FlagLoopback, false}, - {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}, 0, false}, - - {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}, 0, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, FlagUp | FlagLoopback, true}, - {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}, 0, true}, -} - -// TestMulticastListener tests both single and double listen to a test -// listener with same address family, same group address and same port. -func TestMulticastListener(t *testing.T) { - switch runtime.GOOS { - case "netbsd", "openbsd", "plan9", "solaris", "windows": - t.Skipf("skipping test on %q", runtime.GOOS) - case "linux": - if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { - t.Skipf("skipping test on %q/%q", runtime.GOOS, runtime.GOARCH) - } - } - - for _, tt := range multicastListenerTests { - if tt.ipv6 && (!*testIPv6 || !supportsIPv6 || os.Getuid() != 0) { - continue - } - ifi, err := availMulticastInterface(t, tt.flags) - if err != nil { - continue - } - c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("First ListenMulticastUDP failed: %v", err) - } - checkMulticastListener(t, err, c1, tt.gaddr) - c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("Second ListenMulticastUDP failed: %v", err) - } - checkMulticastListener(t, err, c2, tt.gaddr) - c2.Close() - c1.Close() - } -} - -func TestSimpleMulticastListener(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("skipping test on %q", runtime.GOOS) - case "windows": - if testing.Short() || !*testExternal { - t.Skip("skipping test on windows to avoid firewall") - } - } - - for _, tt := range multicastListenerTests { - if tt.ipv6 { - continue - } - tt.flags = FlagUp | FlagMulticast // for windows testing - ifi, err := availMulticastInterface(t, tt.flags) - if err != nil { - continue - } - c1, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("First ListenMulticastUDP failed: %v", err) - } - checkSimpleMulticastListener(t, err, c1, tt.gaddr) - c2, err := ListenMulticastUDP(tt.net, ifi, tt.gaddr) - if err != nil { - t.Fatalf("Second ListenMulticastUDP failed: %v", err) - } - checkSimpleMulticastListener(t, err, c2, tt.gaddr) - c2.Close() - c1.Close() - } -} - -func checkMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { - if !multicastRIBContains(t, gaddr.IP) { - t.Errorf("%q not found in RIB", gaddr.String()) - return - } - la := c.LocalAddr() - if la == nil { - t.Error("LocalAddr failed") - return - } - if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return - } -} - -func checkSimpleMulticastListener(t *testing.T, err error, c *UDPConn, gaddr *UDPAddr) { - la := c.LocalAddr() - if la == nil { - t.Error("LocalAddr failed") - return - } - if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return - } -} - -func availMulticastInterface(t *testing.T, flags Flags) (*Interface, error) { - var ifi *Interface - if flags != Flags(0) { - ift, err := Interfaces() - if err != nil { - t.Fatalf("Interfaces failed: %v", err) - } - for _, x := range ift { - if x.Flags&flags == flags { - ifi = &x - break - } - } - if ifi == nil { - return nil, errors.New("an appropriate multicast interface not found") - } - } - return ifi, nil -} - -func multicastRIBContains(t *testing.T, ip IP) bool { - ift, err := Interfaces() - if err != nil { - t.Fatalf("Interfaces failed: %v", err) - } - for _, ifi := range ift { - ifmat, err := ifi.MulticastAddrs() - if err != nil { - t.Fatalf("MulticastAddrs failed: %v", err) - } - for _, ifma := range ifmat { - if ifma.(*IPAddr).IP.Equal(ip) { - return true - } - } - } - return false -} diff --git a/libgo/go/net/multicast_test.go b/libgo/go/net/multicast_test.go new file mode 100644 index 0000000..8ff02a3 --- /dev/null +++ b/libgo/go/net/multicast_test.go @@ -0,0 +1,184 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "fmt" + "os" + "runtime" + "testing" +) + +var ipv4MulticastListenerTests = []struct { + net string + gaddr *UDPAddr // see RFC 4727 +}{ + {"udp", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}}, + + {"udp4", &UDPAddr{IP: IPv4(224, 0, 0, 254), Port: 12345}}, +} + +// TestIPv4MulticastListener tests both single and double listen to a +// test listener with same address family, same group address and same +// port. +func TestIPv4MulticastListener(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + closer := func(cs []*UDPConn) { + for _, c := range cs { + if c != nil { + c.Close() + } + } + } + + for _, ifi := range []*Interface{loopbackInterface(), nil} { + // Note that multicast interface assignment by system + // is not recommended because it usually relies on + // routing stuff for finding out an appropriate + // nexthop containing both network and link layer + // adjacencies. + if ifi == nil && !*testExternal { + continue + } + for _, tt := range ipv4MulticastListenerTests { + var err error + cs := make([]*UDPConn, 2) + if cs[0], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + t.Fatalf("First ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[0], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + if cs[1], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + closer(cs) + t.Fatalf("Second ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[1], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + closer(cs) + } + } +} + +var ipv6MulticastListenerTests = []struct { + net string + gaddr *UDPAddr // see RFC 4727 +}{ + {"udp", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}}, + {"udp", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}}, + + {"udp6", &UDPAddr{IP: ParseIP("ff01::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff02::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff04::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff05::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff08::114"), Port: 12345}}, + {"udp6", &UDPAddr{IP: ParseIP("ff0e::114"), Port: 12345}}, +} + +// TestIPv6MulticastListener tests both single and double listen to a +// test listener with same address family, same group address and same +// port. +func TestIPv6MulticastListener(t *testing.T) { + switch runtime.GOOS { + case "plan9", "solaris": + t.Skipf("skipping test on %q", runtime.GOOS) + } + if !supportsIPv6 { + t.Skip("ipv6 is not supported") + } + if os.Getuid() != 0 { + t.Skip("skipping test; must be root") + } + + closer := func(cs []*UDPConn) { + for _, c := range cs { + if c != nil { + c.Close() + } + } + } + + for _, ifi := range []*Interface{loopbackInterface(), nil} { + // Note that multicast interface assignment by system + // is not recommended because it usually relies on + // routing stuff for finding out an appropriate + // nexthop containing both network and link layer + // adjacencies. + if ifi == nil && (!*testExternal || !*testIPv6) { + continue + } + for _, tt := range ipv6MulticastListenerTests { + var err error + cs := make([]*UDPConn, 2) + if cs[0], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + t.Fatalf("First ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[0], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + if cs[1], err = ListenMulticastUDP(tt.net, ifi, tt.gaddr); err != nil { + closer(cs) + t.Fatalf("Second ListenMulticastUDP on %v failed: %v", ifi, err) + } + if err := checkMulticastListener(cs[1], tt.gaddr.IP); err != nil { + closer(cs) + t.Fatal(err) + } + closer(cs) + } + } +} + +func checkMulticastListener(c *UDPConn, ip IP) error { + if ok, err := multicastRIBContains(ip); err != nil { + return err + } else if !ok { + return fmt.Errorf("%q not found in multicast RIB", ip.String()) + } + la := c.LocalAddr() + if la, ok := la.(*UDPAddr); !ok || la.Port == 0 { + return fmt.Errorf("got %v; expected a proper address with non-zero port number", la) + } + return nil +} + +func multicastRIBContains(ip IP) (bool, error) { + switch runtime.GOOS { + case "netbsd", "openbsd", "plan9", "solaris", "windows": + return true, nil // not implemented yet + case "linux": + if runtime.GOARCH == "arm" || runtime.GOARCH == "alpha" { + return true, nil // not implemented yet + } + } + ift, err := Interfaces() + if err != nil { + return false, err + } + for _, ifi := range ift { + ifmat, err := ifi.MulticastAddrs() + if err != nil { + return false, err + } + for _, ifma := range ifmat { + if ifma.(*IPAddr).IP.Equal(ip) { + return true, nil + } + } + } + return false, nil +} diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index a3d1759..72b2b64 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -276,11 +276,23 @@ type Listener interface { var errMissingAddress = errors.New("missing address") +// OpError is the error type usually returned by functions in the net +// package. It describes the operation, network type, and address of +// an error. type OpError struct { - Op string - Net string + // Op is the operation which caused the error, such as + // "read" or "write". + Op string + + // Net is the network type on which this error occurred, + // such as "tcp" or "udp6". + Net string + + // Addr is the network address on which this error occurred. Addr Addr - Err error + + // Err is the error that occurred during the operation. + Err error } func (e *OpError) Error() string { diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go index 8a560b5..1a512a5 100644 --- a/libgo/go/net/net_test.go +++ b/libgo/go/net/net_test.go @@ -173,6 +173,10 @@ func TestUDPListenClose(t *testing.T) { } func TestTCPClose(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } l, err := Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) diff --git a/libgo/go/net/newpollserver_unix.go b/libgo/go/net/newpollserver_unix.go deleted file mode 100644 index 618b5b1..0000000 --- a/libgo/go/net/newpollserver_unix.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin freebsd linux netbsd openbsd - -package net - -import ( - "os" - "syscall" -) - -func newPollServer() (s *pollServer, err error) { - s = new(pollServer) - if s.pr, s.pw, err = os.Pipe(); err != nil { - return nil, err - } - if err = syscall.SetNonblock(int(s.pr.Fd()), true); err != nil { - goto Errno - } - if err = syscall.SetNonblock(int(s.pw.Fd()), true); err != nil { - goto Errno - } - if s.poll, err = newpollster(); err != nil { - goto Error - } - if _, err = s.poll.AddFD(int(s.pr.Fd()), 'r', true); err != nil { - s.poll.Close() - goto Error - } - s.pending = make(map[int]*netFD) - go s.Run() - return s, nil - -Errno: - err = &os.PathError{ - Op: "setnonblock", - Path: s.pr.Name(), - Err: err, - } -Error: - s.pr.Close() - s.pw.Close() - return nil, err -} diff --git a/libgo/go/net/packetconn_test.go b/libgo/go/net/packetconn_test.go index ff29e24..ec5dd71 100644 --- a/libgo/go/net/packetconn_test.go +++ b/libgo/go/net/packetconn_test.go @@ -2,10 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package net_test +// This file implements API tests across platforms and will never have a build +// tag. + +package net import ( - "net" "os" "runtime" "strings" @@ -13,18 +15,24 @@ import ( "time" ) +func strfunc(s string) func() string { + return func() string { + return s + } +} + var packetConnTests = []struct { net string - addr1 string - addr2 string + addr1 func() string + addr2 func() string }{ - {"udp", "127.0.0.1:0", "127.0.0.1:0"}, - {"ip:icmp", "127.0.0.1", "127.0.0.1"}, - {"unixgram", "/tmp/gotest.net1", "/tmp/gotest.net2"}, + {"udp", strfunc("127.0.0.1:0"), strfunc("127.0.0.1:0")}, + {"ip:icmp", strfunc("127.0.0.1"), strfunc("127.0.0.1")}, + {"unixgram", testUnixAddr, testUnixAddr}, } func TestPacketConn(t *testing.T) { - closer := func(c net.PacketConn, net, addr1, addr2 string) { + closer := func(c PacketConn, net, addr1, addr2 string) { c.Close() switch net { case "unixgram": @@ -33,7 +41,7 @@ func TestPacketConn(t *testing.T) { } } - for _, tt := range packetConnTests { + for i, tt := range packetConnTests { var wb []byte netstr := strings.Split(tt.net, ":") switch netstr[0] { @@ -47,59 +55,76 @@ func TestPacketConn(t *testing.T) { if os.Getuid() != 0 { continue } - id := os.Getpid() & 0xffff - wb = newICMPEchoRequest(id, 1, 128, []byte("IP PACKETCONN TEST")) + var err error + wb, err = (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: i + 1, + Data: []byte("IP PACKETCONN TEST"), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } case "unixgram": switch runtime.GOOS { case "plan9", "windows": continue } - os.Remove(tt.addr1) - os.Remove(tt.addr2) wb = []byte("UNIXGRAM PACKETCONN TEST") default: continue } - c1, err := net.ListenPacket(tt.net, tt.addr1) + addr1, addr2 := tt.addr1(), tt.addr2() + c1, err := ListenPacket(tt.net, addr1) if err != nil { - t.Fatalf("net.ListenPacket failed: %v", err) + t.Fatalf("ListenPacket failed: %v", err) } + defer closer(c1, netstr[0], addr1, addr2) c1.LocalAddr() c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer closer(c1, netstr[0], tt.addr1, tt.addr2) - c2, err := net.ListenPacket(tt.net, tt.addr2) + c2, err := ListenPacket(tt.net, addr2) if err != nil { - t.Fatalf("net.ListenPacket failed: %v", err) + t.Fatalf("ListenPacket failed: %v", err) } + defer closer(c2, netstr[0], addr1, addr2) c2.LocalAddr() c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer closer(c2, netstr[0], tt.addr1, tt.addr2) if _, err := c1.WriteTo(wb, c2.LocalAddr()); err != nil { - t.Fatalf("net.PacketConn.WriteTo failed: %v", err) + t.Fatalf("PacketConn.WriteTo failed: %v", err) } rb2 := make([]byte, 128) if _, _, err := c2.ReadFrom(rb2); err != nil { - t.Fatalf("net.PacketConn.ReadFrom failed: %v", err) + t.Fatalf("PacketConn.ReadFrom failed: %v", err) } if _, err := c2.WriteTo(wb, c1.LocalAddr()); err != nil { - t.Fatalf("net.PacketConn.WriteTo failed: %v", err) + t.Fatalf("PacketConn.WriteTo failed: %v", err) } rb1 := make([]byte, 128) if _, _, err := c1.ReadFrom(rb1); err != nil { - t.Fatalf("net.PacketConn.ReadFrom failed: %v", err) + t.Fatalf("PacketConn.ReadFrom failed: %v", err) } } } func TestConnAndPacketConn(t *testing.T) { - for _, tt := range packetConnTests { + closer := func(c PacketConn, net, addr1, addr2 string) { + c.Close() + switch net { + case "unixgram": + os.Remove(addr1) + os.Remove(addr2) + } + } + + for i, tt := range packetConnTests { var wb []byte netstr := strings.Split(tt.net, ":") switch netstr[0] { @@ -113,52 +138,71 @@ func TestConnAndPacketConn(t *testing.T) { if os.Getuid() != 0 { continue } - id := os.Getpid() & 0xffff - wb = newICMPEchoRequest(id, 1, 128, []byte("IP PACKETCONN TEST")) + var err error + wb, err = (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: i + 1, + Data: []byte("IP PACKETCONN TEST"), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } + case "unixgram": + switch runtime.GOOS { + case "plan9", "windows": + continue + } + wb = []byte("UNIXGRAM PACKETCONN TEST") default: continue } - c1, err := net.ListenPacket(tt.net, tt.addr1) + addr1, addr2 := tt.addr1(), tt.addr2() + c1, err := ListenPacket(tt.net, addr1) if err != nil { - t.Fatalf("net.ListenPacket failed: %v", err) + t.Fatalf("ListenPacket failed: %v", err) } + defer closer(c1, netstr[0], addr1, addr2) c1.LocalAddr() c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer c1.Close() - c2, err := net.Dial(tt.net, c1.LocalAddr().String()) + c2, err := Dial(tt.net, c1.LocalAddr().String()) if err != nil { - t.Fatalf("net.Dial failed: %v", err) + t.Fatalf("Dial failed: %v", err) } + defer c2.Close() c2.LocalAddr() c2.RemoteAddr() c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer c2.Close() if _, err := c2.Write(wb); err != nil { - t.Fatalf("net.Conn.Write failed: %v", err) + t.Fatalf("Conn.Write failed: %v", err) } rb1 := make([]byte, 128) if _, _, err := c1.ReadFrom(rb1); err != nil { - t.Fatalf("net.PacetConn.ReadFrom failed: %v", err) + t.Fatalf("PacetConn.ReadFrom failed: %v", err) } - var dst net.Addr - if netstr[0] == "ip" { - dst = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} - } else { + var dst Addr + switch netstr[0] { + case "ip": + dst = &IPAddr{IP: IPv4(127, 0, 0, 1)} + case "unixgram": + continue + default: dst = c2.LocalAddr() } if _, err := c1.WriteTo(wb, dst); err != nil { - t.Fatalf("net.PacketConn.WriteTo failed: %v", err) + t.Fatalf("PacketConn.WriteTo failed: %v", err) } rb2 := make([]byte, 128) if _, err := c2.Read(rb2); err != nil { - t.Fatalf("net.Conn.Read failed: %v", err) + t.Fatalf("Conn.Read failed: %v", err) } } } diff --git a/libgo/go/net/port_test.go b/libgo/go/net/port_test.go index 329b169..9e8968f 100644 --- a/libgo/go/net/port_test.go +++ b/libgo/go/net/port_test.go @@ -46,7 +46,7 @@ func TestLookupPort(t *testing.T) { for i := 0; i < len(porttests); i++ { tt := porttests[i] if port, err := LookupPort(tt.netw, tt.name); port != tt.port || (err == nil) != tt.ok { - t.Errorf("LookupPort(%q, %q) = %v, %s; want %v", + t.Errorf("LookupPort(%q, %q) = %v, %v; want %v", tt.netw, tt.name, port, err, tt.port) } } diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go index 1344fba..b59925e 100644 --- a/libgo/go/net/protoconn_test.go +++ b/libgo/go/net/protoconn_test.go @@ -2,152 +2,161 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package net_test +// This file implements API tests across platforms and will never have a build +// tag. + +package net import ( - "bytes" - "net" + "io/ioutil" "os" "runtime" "testing" "time" ) -var condErrorf = func() func(*testing.T, string, ...interface{}) { +// testUnixAddr uses ioutil.TempFile to get a name that is unique. It +// also uses /tmp directory in case it is prohibited to create UNIX +// sockets in TMPDIR. +func testUnixAddr() string { + f, err := ioutil.TempFile("/tmp", "nettest") + if err != nil { + panic(err) + } + addr := f.Name() + f.Close() + os.Remove(addr) + return addr +} + +var condFatalf = func() func(*testing.T, string, ...interface{}) { // A few APIs are not implemented yet on both Plan 9 and Windows. switch runtime.GOOS { case "plan9", "windows": return (*testing.T).Logf } - return (*testing.T).Errorf + return (*testing.T).Fatalf }() func TestTCPListenerSpecificMethods(t *testing.T) { - la, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:0") + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0") if err != nil { - t.Fatalf("net.ResolveTCPAddr failed: %v", err) + t.Fatalf("ResolveTCPAddr failed: %v", err) } - ln, err := net.ListenTCP("tcp4", la) + ln, err := ListenTCP("tcp4", la) if err != nil { - t.Fatalf("net.ListenTCP failed: %v", err) + t.Fatalf("ListenTCP failed: %v", err) } + defer ln.Close() ln.Addr() ln.SetDeadline(time.Now().Add(30 * time.Nanosecond)) - defer ln.Close() if c, err := ln.Accept(); err != nil { - if !err.(net.Error).Timeout() { - t.Errorf("net.TCPListener.Accept failed: %v", err) - return + if !err.(Error).Timeout() { + t.Fatalf("TCPListener.Accept failed: %v", err) } } else { c.Close() } if c, err := ln.AcceptTCP(); err != nil { - if !err.(net.Error).Timeout() { - t.Errorf("net.TCPListener.AcceptTCP failed: %v", err) - return + if !err.(Error).Timeout() { + t.Fatalf("TCPListener.AcceptTCP failed: %v", err) } } else { c.Close() } if f, err := ln.File(); err != nil { - condErrorf(t, "net.TCPListener.File failed: %v", err) - return + condFatalf(t, "TCPListener.File failed: %v", err) } else { f.Close() } } func TestTCPConnSpecificMethods(t *testing.T) { - la, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:0") + la, err := ResolveTCPAddr("tcp4", "127.0.0.1:0") if err != nil { - t.Fatalf("net.ResolveTCPAddr failed: %v", err) + t.Fatalf("ResolveTCPAddr failed: %v", err) } - ln, err := net.ListenTCP("tcp4", la) + ln, err := ListenTCP("tcp4", la) if err != nil { - t.Fatalf("net.ListenTCP failed: %v", err) + t.Fatalf("ListenTCP failed: %v", err) } - ln.Addr() defer ln.Close() + ln.Addr() done := make(chan int) go transponder(t, ln, done) - ra, err := net.ResolveTCPAddr("tcp4", ln.Addr().String()) + ra, err := ResolveTCPAddr("tcp4", ln.Addr().String()) if err != nil { - t.Errorf("net.ResolveTCPAddr failed: %v", err) - return + t.Fatalf("ResolveTCPAddr failed: %v", err) } - c, err := net.DialTCP("tcp4", nil, ra) + c, err := DialTCP("tcp4", nil, ra) if err != nil { - t.Errorf("net.DialTCP failed: %v", err) - return + t.Fatalf("DialTCP failed: %v", err) } + defer c.Close() c.SetKeepAlive(false) c.SetLinger(0) c.SetNoDelay(false) c.LocalAddr() c.RemoteAddr() - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) - defer c.Close() + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) if _, err := c.Write([]byte("TCPCONN TEST")); err != nil { - t.Errorf("net.TCPConn.Write failed: %v", err) - return + t.Fatalf("TCPConn.Write failed: %v", err) } rb := make([]byte, 128) if _, err := c.Read(rb); err != nil { - t.Errorf("net.TCPConn.Read failed: %v", err) - return + t.Fatalf("TCPConn.Read failed: %v", err) } <-done } func TestUDPConnSpecificMethods(t *testing.T) { - la, err := net.ResolveUDPAddr("udp4", "127.0.0.1:0") + la, err := ResolveUDPAddr("udp4", "127.0.0.1:0") if err != nil { - t.Fatalf("net.ResolveUDPAddr failed: %v", err) + t.Fatalf("ResolveUDPAddr failed: %v", err) } - c, err := net.ListenUDP("udp4", la) + c, err := ListenUDP("udp4", la) if err != nil { - t.Fatalf("net.ListenUDP failed: %v", err) + t.Fatalf("ListenUDP failed: %v", err) } + defer c.Close() c.LocalAddr() c.RemoteAddr() - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) c.SetReadBuffer(2048) c.SetWriteBuffer(2048) - defer c.Close() wb := []byte("UDPCONN TEST") rb := make([]byte, 128) - if _, err := c.WriteToUDP(wb, c.LocalAddr().(*net.UDPAddr)); err != nil { - t.Errorf("net.UDPConn.WriteToUDP failed: %v", err) - return + if _, err := c.WriteToUDP(wb, c.LocalAddr().(*UDPAddr)); err != nil { + t.Fatalf("UDPConn.WriteToUDP failed: %v", err) } if _, _, err := c.ReadFromUDP(rb); err != nil { - t.Errorf("net.UDPConn.ReadFromUDP failed: %v", err) - return + t.Fatalf("UDPConn.ReadFromUDP failed: %v", err) } - if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*net.UDPAddr)); err != nil { - condErrorf(t, "net.UDPConn.WriteMsgUDP failed: %v", err) - return + if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil { + condFatalf(t, "UDPConn.WriteMsgUDP failed: %v", err) } if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil { - condErrorf(t, "net.UDPConn.ReadMsgUDP failed: %v", err) - return + condFatalf(t, "UDPConn.ReadMsgUDP failed: %v", err) } if f, err := c.File(); err != nil { - condErrorf(t, "net.UDPConn.File failed: %v", err) - return + condFatalf(t, "UDPConn.File failed: %v", err) } else { f.Close() } @@ -156,52 +165,55 @@ func TestUDPConnSpecificMethods(t *testing.T) { func TestIPConnSpecificMethods(t *testing.T) { switch runtime.GOOS { case "plan9": - t.Skipf("skipping read test on %q", runtime.GOOS) + t.Skipf("skipping test on %q", runtime.GOOS) } if os.Getuid() != 0 { t.Skipf("skipping test; must be root") } - la, err := net.ResolveIPAddr("ip4", "127.0.0.1") + la, err := ResolveIPAddr("ip4", "127.0.0.1") if err != nil { - t.Fatalf("net.ResolveIPAddr failed: %v", err) + t.Fatalf("ResolveIPAddr failed: %v", err) } - c, err := net.ListenIP("ip4:icmp", la) + c, err := ListenIP("ip4:icmp", la) if err != nil { - t.Fatalf("net.ListenIP failed: %v", err) + t.Fatalf("ListenIP failed: %v", err) } + defer c.Close() c.LocalAddr() c.RemoteAddr() - c.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + c.SetDeadline(time.Now().Add(someTimeout)) + c.SetReadDeadline(time.Now().Add(someTimeout)) + c.SetWriteDeadline(time.Now().Add(someTimeout)) c.SetReadBuffer(2048) c.SetWriteBuffer(2048) - defer c.Close() - id := os.Getpid() & 0xffff - wb := newICMPEchoRequest(id, 1, 128, []byte("IPCONN TEST ")) + wb, err := (&icmpMessage{ + Type: icmpv4EchoRequest, Code: 0, + Body: &icmpEcho{ + ID: os.Getpid() & 0xffff, Seq: 1, + Data: []byte("IPCONN TEST "), + }, + }).Marshal() + if err != nil { + t.Fatalf("icmpMessage.Marshal failed: %v", err) + } rb := make([]byte, 20+128) - if _, err := c.WriteToIP(wb, c.LocalAddr().(*net.IPAddr)); err != nil { - t.Errorf("net.IPConn.WriteToIP failed: %v", err) - return + if _, err := c.WriteToIP(wb, c.LocalAddr().(*IPAddr)); err != nil { + t.Fatalf("IPConn.WriteToIP failed: %v", err) } if _, _, err := c.ReadFromIP(rb); err != nil { - t.Errorf("net.IPConn.ReadFromIP failed: %v", err) - return + t.Fatalf("IPConn.ReadFromIP failed: %v", err) } - if _, _, err := c.WriteMsgIP(wb, nil, c.LocalAddr().(*net.IPAddr)); err != nil { - condErrorf(t, "net.UDPConn.WriteMsgIP failed: %v", err) - return + if _, _, err := c.WriteMsgIP(wb, nil, c.LocalAddr().(*IPAddr)); err != nil { + condFatalf(t, "IPConn.WriteMsgIP failed: %v", err) } if _, _, _, _, err := c.ReadMsgIP(rb, nil); err != nil { - condErrorf(t, "net.UDPConn.ReadMsgIP failed: %v", err) - return + condFatalf(t, "IPConn.ReadMsgIP failed: %v", err) } if f, err := c.File(); err != nil { - condErrorf(t, "net.IPConn.File failed: %v", err) - return + condFatalf(t, "IPConn.File failed: %v", err) } else { f.Close() } @@ -210,44 +222,40 @@ func TestIPConnSpecificMethods(t *testing.T) { func TestUnixListenerSpecificMethods(t *testing.T) { switch runtime.GOOS { case "plan9", "windows": - t.Skipf("skipping read test on %q", runtime.GOOS) + t.Skipf("skipping test on %q", runtime.GOOS) } - p := "/tmp/gotest.net" - os.Remove(p) - la, err := net.ResolveUnixAddr("unix", p) + addr := testUnixAddr() + la, err := ResolveUnixAddr("unix", addr) if err != nil { - t.Fatalf("net.ResolveUnixAddr failed: %v", err) + t.Fatalf("ResolveUnixAddr failed: %v", err) } - ln, err := net.ListenUnix("unix", la) + ln, err := ListenUnix("unix", la) if err != nil { - t.Fatalf("net.ListenUnix failed: %v", err) + t.Fatalf("ListenUnix failed: %v", err) } + defer ln.Close() + defer os.Remove(addr) ln.Addr() ln.SetDeadline(time.Now().Add(30 * time.Nanosecond)) - defer ln.Close() - defer os.Remove(p) if c, err := ln.Accept(); err != nil { - if !err.(net.Error).Timeout() { - t.Errorf("net.TCPListener.AcceptTCP failed: %v", err) - return + if !err.(Error).Timeout() { + t.Fatalf("UnixListener.Accept failed: %v", err) } } else { c.Close() } if c, err := ln.AcceptUnix(); err != nil { - if !err.(net.Error).Timeout() { - t.Errorf("net.TCPListener.AcceptTCP failed: %v", err) - return + if !err.(Error).Timeout() { + t.Fatalf("UnixListener.AcceptUnix failed: %v", err) } } else { c.Close() } if f, err := ln.File(); err != nil { - t.Errorf("net.UnixListener.File failed: %v", err) - return + t.Fatalf("UnixListener.File failed: %v", err) } else { f.Close() } @@ -259,145 +267,94 @@ func TestUnixConnSpecificMethods(t *testing.T) { t.Skipf("skipping test on %q", runtime.GOOS) } - p1, p2, p3 := "/tmp/gotest.net1", "/tmp/gotest.net2", "/tmp/gotest.net3" - os.Remove(p1) - os.Remove(p2) - os.Remove(p3) + addr1, addr2, addr3 := testUnixAddr(), testUnixAddr(), testUnixAddr() - a1, err := net.ResolveUnixAddr("unixgram", p1) + a1, err := ResolveUnixAddr("unixgram", addr1) if err != nil { - t.Fatalf("net.ResolveUnixAddr failed: %v", err) + t.Fatalf("ResolveUnixAddr failed: %v", err) } - c1, err := net.DialUnix("unixgram", a1, nil) + c1, err := DialUnix("unixgram", a1, nil) if err != nil { - t.Fatalf("net.DialUnix failed: %v", err) + t.Fatalf("DialUnix failed: %v", err) } + defer c1.Close() + defer os.Remove(addr1) c1.LocalAddr() c1.RemoteAddr() - c1.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c1.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + c1.SetDeadline(time.Now().Add(someTimeout)) + c1.SetReadDeadline(time.Now().Add(someTimeout)) + c1.SetWriteDeadline(time.Now().Add(someTimeout)) c1.SetReadBuffer(2048) c1.SetWriteBuffer(2048) - defer c1.Close() - defer os.Remove(p1) - a2, err := net.ResolveUnixAddr("unixgram", p2) + a2, err := ResolveUnixAddr("unixgram", addr2) if err != nil { - t.Errorf("net.ResolveUnixAddr failed: %v", err) - return + t.Fatalf("ResolveUnixAddr failed: %v", err) } - c2, err := net.DialUnix("unixgram", a2, nil) + c2, err := DialUnix("unixgram", a2, nil) if err != nil { - t.Errorf("net.DialUnix failed: %v", err) - return + t.Fatalf("DialUnix failed: %v", err) } + defer c2.Close() + defer os.Remove(addr2) c2.LocalAddr() c2.RemoteAddr() - c2.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c2.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c2.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + c2.SetDeadline(time.Now().Add(someTimeout)) + c2.SetReadDeadline(time.Now().Add(someTimeout)) + c2.SetWriteDeadline(time.Now().Add(someTimeout)) c2.SetReadBuffer(2048) c2.SetWriteBuffer(2048) - defer c2.Close() - defer os.Remove(p2) - a3, err := net.ResolveUnixAddr("unixgram", p3) + a3, err := ResolveUnixAddr("unixgram", addr3) if err != nil { - t.Errorf("net.ResolveUnixAddr failed: %v", err) - return + t.Fatalf("ResolveUnixAddr failed: %v", err) } - c3, err := net.ListenUnixgram("unixgram", a3) + c3, err := ListenUnixgram("unixgram", a3) if err != nil { - t.Errorf("net.ListenUnixgram failed: %v", err) - return + t.Fatalf("ListenUnixgram failed: %v", err) } + defer c3.Close() + defer os.Remove(addr3) c3.LocalAddr() c3.RemoteAddr() - c3.SetDeadline(time.Now().Add(100 * time.Millisecond)) - c3.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) - c3.SetWriteDeadline(time.Now().Add(100 * time.Millisecond)) + c3.SetDeadline(time.Now().Add(someTimeout)) + c3.SetReadDeadline(time.Now().Add(someTimeout)) + c3.SetWriteDeadline(time.Now().Add(someTimeout)) c3.SetReadBuffer(2048) c3.SetWriteBuffer(2048) - defer c3.Close() - defer os.Remove(p3) wb := []byte("UNIXCONN TEST") rb1 := make([]byte, 128) rb2 := make([]byte, 128) rb3 := make([]byte, 128) if _, _, err := c1.WriteMsgUnix(wb, nil, a2); err != nil { - t.Errorf("net.UnixConn.WriteMsgUnix failed: %v", err) - return + t.Fatalf("UnixConn.WriteMsgUnix failed: %v", err) } if _, _, _, _, err := c2.ReadMsgUnix(rb2, nil); err != nil { - t.Errorf("net.UnixConn.ReadMsgUnix failed: %v", err) - return + t.Fatalf("UnixConn.ReadMsgUnix failed: %v", err) } if _, err := c2.WriteToUnix(wb, a1); err != nil { - t.Errorf("net.UnixConn.WriteToUnix failed: %v", err) - return + t.Fatalf("UnixConn.WriteToUnix failed: %v", err) } if _, _, err := c1.ReadFromUnix(rb1); err != nil { - t.Errorf("net.UnixConn.ReadFromUnix failed: %v", err) - return + t.Fatalf("UnixConn.ReadFromUnix failed: %v", err) } if _, err := c3.WriteToUnix(wb, a1); err != nil { - t.Errorf("net.UnixConn.WriteToUnix failed: %v", err) - return + t.Fatalf("UnixConn.WriteToUnix failed: %v", err) } if _, _, err := c1.ReadFromUnix(rb1); err != nil { - t.Errorf("net.UnixConn.ReadFromUnix failed: %v", err) - return + t.Fatalf("UnixConn.ReadFromUnix failed: %v", err) } if _, err := c2.WriteToUnix(wb, a3); err != nil { - t.Errorf("net.UnixConn.WriteToUnix failed: %v", err) - return + t.Fatalf("UnixConn.WriteToUnix failed: %v", err) } if _, _, err := c3.ReadFromUnix(rb3); err != nil { - t.Errorf("net.UnixConn.ReadFromUnix failed: %v", err) - return + t.Fatalf("UnixConn.ReadFromUnix failed: %v", err) } if f, err := c1.File(); err != nil { - t.Errorf("net.UnixConn.File failed: %v", err) - return + t.Fatalf("UnixConn.File failed: %v", err) } else { f.Close() } } - -func newICMPEchoRequest(id, seqnum, msglen int, filler []byte) []byte { - b := newICMPInfoMessage(id, seqnum, msglen, filler) - b[0] = 8 - // calculate ICMP checksum - cklen := len(b) - s := uint32(0) - for i := 0; i < cklen-1; i += 2 { - s += uint32(b[i+1])<<8 | uint32(b[i]) - } - if cklen&1 == 1 { - s += uint32(b[cklen-1]) - } - s = (s >> 16) + (s & 0xffff) - s = s + (s >> 16) - // place checksum back in header; using ^= avoids the - // assumption the checksum bytes are zero - b[2] ^= byte(^s & 0xff) - b[3] ^= byte(^s >> 8) - return b -} - -func newICMPInfoMessage(id, seqnum, msglen int, filler []byte) []byte { - b := make([]byte, msglen) - copy(b[8:], bytes.Repeat(filler, (msglen-8)/len(filler)+1)) - b[0] = 0 // type - b[1] = 0 // code - b[2] = 0 // checksum - b[3] = 0 // checksum - b[4] = byte(id >> 8) // identifier - b[5] = byte(id & 0xff) // identifier - b[6] = byte(seqnum >> 8) // sequence number - b[7] = byte(seqnum & 0xff) // sequence number - return b -} diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go index ee3cc4d..4b0c9c3 100644 --- a/libgo/go/net/rpc/client.go +++ b/libgo/go/net/rpc/client.go @@ -71,7 +71,7 @@ func (client *Client) send(call *Call) { // Register this call. client.mutex.Lock() - if client.shutdown { + if client.shutdown || client.closing { call.Error = ErrShutdown client.mutex.Unlock() call.done() @@ -105,9 +105,6 @@ func (client *Client) input() { response = Response{} err = client.codec.ReadResponseHeader(&response) if err != nil { - if err == io.EOF && !client.closing { - err = io.ErrUnexpectedEOF - } break } seq := response.Seq @@ -150,6 +147,13 @@ func (client *Client) input() { client.mutex.Lock() client.shutdown = true closing := client.closing + if err == io.EOF { + if closing { + err = ErrShutdown + } else { + err = io.ErrUnexpectedEOF + } + } for _, call := range client.pending { call.Error = err call.done() diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go index 3c7c4d4..40d4b82 100644 --- a/libgo/go/net/rpc/jsonrpc/all_test.go +++ b/libgo/go/net/rpc/jsonrpc/all_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net" "net/rpc" "testing" @@ -185,6 +186,22 @@ func TestMalformedInput(t *testing.T) { ServeConn(srv) // must return, not loop } +func TestMalformedOutput(t *testing.T) { + cli, srv := net.Pipe() + go srv.Write([]byte(`{"id":0,"result":null,"error":null}`)) + go ioutil.ReadAll(srv) + + client := NewClient(cli) + defer client.Close() + + args := &Args{7, 8} + reply := new(Reply) + err := client.Call("Arith.Add", args, reply) + if err == nil { + t.Error("expected error") + } +} + func TestUnexpectedError(t *testing.T) { cli, srv := myPipe() go cli.PipeWriter.CloseWithError(errors.New("unexpected error!")) // reader will get this error diff --git a/libgo/go/net/rpc/jsonrpc/client.go b/libgo/go/net/rpc/jsonrpc/client.go index 3fa8cbf..2194f21 100644 --- a/libgo/go/net/rpc/jsonrpc/client.go +++ b/libgo/go/net/rpc/jsonrpc/client.go @@ -83,7 +83,7 @@ func (c *clientCodec) ReadResponseHeader(r *rpc.Response) error { r.Error = "" r.Seq = c.resp.Id - if c.resp.Error != nil { + if c.resp.Error != nil || c.resp.Result == nil { x, ok := c.resp.Error.(string) if !ok { return fmt.Errorf("invalid error %v", c.resp.Error) diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go index 2c734a4..eb17210 100644 --- a/libgo/go/net/rpc/server_test.go +++ b/libgo/go/net/rpc/server_test.go @@ -399,12 +399,10 @@ func (WriteFailCodec) WriteRequest(*Request, interface{}) error { func (WriteFailCodec) ReadResponseHeader(*Response) error { select {} - panic("unreachable") } func (WriteFailCodec) ReadResponseBody(interface{}) error { select {} - panic("unreachable") } func (WriteFailCodec) Close() error { @@ -445,8 +443,7 @@ func dialHTTP() (*Client, error) { return DialHTTP("tcp", httpServerAddr) } -func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1)) +func countMallocs(dial func() (*Client, error), t *testing.T) float64 { once.Do(startServer) client, err := dial() if err != nil { @@ -454,11 +451,7 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { } args := &Args{7, 8} reply := new(Reply) - memstats := new(runtime.MemStats) - runtime.ReadMemStats(memstats) - mallocs := 0 - memstats.Mallocs - const count = 100 - for i := 0; i < count; i++ { + return testing.AllocsPerRun(100, func() { err := client.Call("Arith.Add", args, reply) if err != nil { t.Errorf("Add: expected no error but got string %q", err.Error()) @@ -466,18 +459,21 @@ func countMallocs(dial func() (*Client, error), t *testing.T) uint64 { if reply.C != args.A+args.B { t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) } - } - runtime.ReadMemStats(memstats) - mallocs += memstats.Mallocs - return mallocs / count + }) } func TestCountMallocs(t *testing.T) { - fmt.Printf("mallocs per rpc round trip: %d\n", countMallocs(dialDirect, t)) + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) } func TestCountMallocsOverHTTP(t *testing.T) { - fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) + if runtime.GOMAXPROCS(0) > 1 { + t.Skip("skipping; GOMAXPROCS>1") + } + fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) } type writeCrasher struct { @@ -532,6 +528,23 @@ func TestTCPClose(t *testing.T) { } } +func TestErrorAfterClientClose(t *testing.T) { + once.Do(startServer) + + client, err := dialHTTP() + if err != nil { + t.Fatalf("dialing: %v", err) + } + err = client.Close() + if err != nil { + t.Fatal("close error:", err) + } + err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) + if err != ErrShutdown { + t.Errorf("Forever: expected ErrShutdown got %v", err) + } +} + func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { b.StopTimer() once.Do(startServer) diff --git a/libgo/go/net/sendfile_freebsd.go b/libgo/go/net/sendfile_freebsd.go index 8008bc3..dc5b767 100644 --- a/libgo/go/net/sendfile_freebsd.go +++ b/libgo/go/net/sendfile_freebsd.go @@ -83,7 +83,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { break } if err1 == syscall.EAGAIN { - if err1 = c.pollServer.WaitWrite(c); err1 == nil { + if err1 = c.pd.WaitWrite(); err1 == nil { continue } } diff --git a/libgo/go/net/sendfile_linux.go b/libgo/go/net/sendfile_linux.go index 3357e65..6f1323b 100644 --- a/libgo/go/net/sendfile_linux.go +++ b/libgo/go/net/sendfile_linux.go @@ -59,7 +59,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { break } if err1 == syscall.EAGAIN { - if err1 = c.pollServer.WaitWrite(c); err1 == nil { + if err1 = c.pd.WaitWrite(); err1 == nil { continue } } diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go index eba1e7d..9194a8e 100644 --- a/libgo/go/net/server_test.go +++ b/libgo/go/net/server_test.go @@ -9,6 +9,7 @@ import ( "io" "os" "runtime" + "strconv" "testing" "time" ) @@ -41,6 +42,12 @@ func skipServerTest(net, unixsotype, addr string, ipv6, ipv4map, linuxonly bool) return false } +func tempfile(filename string) string { + // use /tmp in case it is prohibited to create + // UNIX sockets in TMPDIR + return "/tmp/" + filename + "." + strconv.Itoa(os.Getpid()) +} + var streamConnServerTests = []struct { snet string // server side saddr string @@ -86,7 +93,7 @@ var streamConnServerTests = []struct { {snet: "tcp6", saddr: "[::1]", cnet: "tcp6", caddr: "[::1]", ipv6: true}, - {snet: "unix", saddr: "/tmp/gotest1.net", cnet: "unix", caddr: "/tmp/gotest1.net.local"}, + {snet: "unix", saddr: tempfile("gotest1.net"), cnet: "unix", caddr: tempfile("gotest1.net.local")}, {snet: "unix", saddr: "@gotest2/net", cnet: "unix", caddr: "@gotest2/net.local", linux: true}, } @@ -113,8 +120,7 @@ func TestStreamConnServer(t *testing.T) { case "tcp", "tcp4", "tcp6": _, port, err := SplitHostPort(taddr) if err != nil { - t.Errorf("SplitHostPort(%q) failed: %v", taddr, err) - return + t.Fatalf("SplitHostPort(%q) failed: %v", taddr, err) } taddr = tt.caddr + ":" + port } @@ -136,7 +142,7 @@ var seqpacketConnServerTests = []struct { caddr string // client address empty bool // test with empty data }{ - {net: "unixpacket", saddr: "/tmp/gotest3.net", caddr: "/tmp/gotest3.net.local"}, + {net: "unixpacket", saddr: tempfile("/gotest3.net"), caddr: tempfile("gotest3.net.local")}, {net: "unixpacket", saddr: "@gotest4/net", caddr: "@gotest4/net.local"}, } @@ -169,11 +175,11 @@ func TestSeqpacketConnServer(t *testing.T) { } func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- string, done chan<- int) { + defer close(done) l, err := Listen(net, laddr) if err != nil { t.Errorf("Listen(%q, %q) failed: %v", net, laddr, err) listening <- "<nil>" - done <- 1 return } defer l.Close() @@ -188,13 +194,14 @@ func runStreamConnServer(t *testing.T, net, laddr string, listening chan<- strin } rw.Write(buf[0:n]) } - done <- 1 + close(done) } run: for { c, err := l.Accept() if err != nil { + t.Logf("Accept failed: %v", err) continue run } echodone := make(chan int) @@ -203,14 +210,12 @@ run: c.Close() break run } - done <- 1 } func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) { c, err := Dial(net, taddr) if err != nil { - t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err) } defer c.Close() c.SetReadDeadline(time.Now().Add(1 * time.Second)) @@ -220,14 +225,12 @@ func runStreamConnClient(t *testing.T, net, taddr string, isEmpty bool) { wb = []byte("StreamConnClient by Dial\n") } if n, err := c.Write(wb); err != nil || n != len(wb) { - t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) } rb := make([]byte, 1024) if n, err := c.Read(rb[0:]); err != nil || n != len(wb) { - t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) } // Send explicit ending for unixpacket. @@ -298,10 +301,10 @@ var datagramPacketConnServerTests = []struct { {snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, empty: true}, {snet: "udp", saddr: "[::1]", cnet: "udp", caddr: "[::1]", ipv6: true, dial: true, empty: true}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local"}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", dial: true}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", empty: true}, - {snet: "unixgram", saddr: "/tmp/gotest5.net", cnet: "unixgram", caddr: "/tmp/gotest5.net.local", dial: true, empty: true}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local")}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), dial: true}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), empty: true}, + {snet: "unixgram", saddr: tempfile("gotest5.net"), cnet: "unixgram", caddr: tempfile("gotest5.net.local"), dial: true, empty: true}, {snet: "unixgram", saddr: "@gotest6/net", cnet: "unixgram", caddr: "@gotest6/net.local", linux: true}, } @@ -333,8 +336,7 @@ func TestDatagramPacketConnServer(t *testing.T) { case "udp", "udp4", "udp6": _, port, err := SplitHostPort(taddr) if err != nil { - t.Errorf("SplitHostPort(%q) failed: %v", taddr, err) - return + t.Fatalf("SplitHostPort(%q) failed: %v", taddr, err) } taddr = tt.caddr + ":" + port tt.caddr += ":0" @@ -397,14 +399,12 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) case "udp", "udp4", "udp6": c, err = Dial(net, taddr) if err != nil { - t.Errorf("Dial(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("Dial(%q, %q) failed: %v", net, taddr, err) } case "unixgram": - c, err = DialUnix(net, &UnixAddr{laddr, net}, &UnixAddr{taddr, net}) + c, err = DialUnix(net, &UnixAddr{Name: laddr, Net: net}, &UnixAddr{Name: taddr, Net: net}) if err != nil { - t.Errorf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err) - return + t.Fatalf("DialUnix(%q, {%q, %q}) failed: %v", net, laddr, taddr, err) } } defer c.Close() @@ -415,14 +415,12 @@ func runDatagramConnClient(t *testing.T, net, laddr, taddr string, isEmpty bool) wb = []byte("DatagramConnClient by Dial\n") } if n, err := c.Write(wb[0:]); err != nil || n != len(wb) { - t.Errorf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Write failed: %v, %v; want %v, <nil>", n, err, len(wb)) } rb := make([]byte, 1024) if n, err := c.Read(rb[0:]); err != nil || n != len(wb) { - t.Errorf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("Read failed: %v, %v; want %v, <nil>", n, err, len(wb)) } } @@ -433,20 +431,17 @@ func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty case "udp", "udp4", "udp6": ra, err = ResolveUDPAddr(net, taddr) if err != nil { - t.Errorf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("ResolveUDPAddr(%q, %q) failed: %v", net, taddr, err) } case "unixgram": ra, err = ResolveUnixAddr(net, taddr) if err != nil { - t.Errorf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err) - return + t.Fatalf("ResolveUxixAddr(%q, %q) failed: %v", net, taddr, err) } } c, err := ListenPacket(net, laddr) if err != nil { - t.Errorf("ListenPacket(%q, %q) faild: %v", net, laddr, err) - return + t.Fatalf("ListenPacket(%q, %q) faild: %v", net, laddr, err) } defer c.Close() c.SetReadDeadline(time.Now().Add(1 * time.Second)) @@ -456,13 +451,11 @@ func runDatagramPacketConnClient(t *testing.T, net, laddr, taddr string, isEmpty wb = []byte("DatagramPacketConnClient by ListenPacket\n") } if n, err := c.WriteTo(wb[0:], ra); err != nil || n != len(wb) { - t.Errorf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb)) - return + t.Fatalf("WriteTo(%v) failed: %v, %v; want %v, <nil>", ra, n, err, len(wb)) } rb := make([]byte, 1024) if n, _, err := c.ReadFrom(rb[0:]); err != nil || n != len(wb) { - t.Errorf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb)) - return + t.Fatalf("ReadFrom failed: %v, %v; want %v, <nil>", n, err, len(wb)) } } diff --git a/libgo/go/net/smtp/auth.go b/libgo/go/net/smtp/auth.go index d401e3c..3f1339e 100644 --- a/libgo/go/net/smtp/auth.go +++ b/libgo/go/net/smtp/auth.go @@ -54,7 +54,16 @@ func PlainAuth(identity, username, password, host string) Auth { func (a *plainAuth) Start(server *ServerInfo) (string, []byte, error) { if !server.TLS { - return "", nil, errors.New("unencrypted connection") + advertised := false + for _, mechanism := range server.Auth { + if mechanism == "PLAIN" { + advertised = true + break + } + } + if !advertised { + return "", nil, errors.New("unencrypted connection") + } } if server.Name != a.host { return "", nil, errors.New("wrong host name") diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index 8317428..c190b32 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -57,6 +57,41 @@ testLoop: } } +func TestAuthPlain(t *testing.T) { + auth := PlainAuth("foo", "bar", "baz", "servername") + + tests := []struct { + server *ServerInfo + err string + }{ + { + server: &ServerInfo{Name: "servername", TLS: true}, + }, + { + // Okay; explicitly advertised by server. + server: &ServerInfo{Name: "servername", Auth: []string{"PLAIN"}}, + }, + { + server: &ServerInfo{Name: "servername", Auth: []string{"CRAM-MD5"}}, + err: "unencrypted connection", + }, + { + server: &ServerInfo{Name: "attacker", TLS: true}, + err: "wrong host name", + }, + } + for i, tt := range tests { + _, _, err := auth.Start(tt.server) + got := "" + if err != nil { + got = err.Error() + } + if got != tt.err { + t.Errorf("%d. got error = %q; want %q", i, got, tt.err) + } + } +} + type faker struct { io.ReadWriter } diff --git a/libgo/go/net/sock_bsd.go b/libgo/go/net/sock_bsd.go index 2607b04..d993492 100644 --- a/libgo/go/net/sock_bsd.go +++ b/libgo/go/net/sock_bsd.go @@ -4,8 +4,6 @@ // +build darwin freebsd netbsd openbsd -// Sockets for BSD variants - package net import ( @@ -29,34 +27,11 @@ func maxListenerBacklog() int { if n == 0 || err != nil { return syscall.SOMAXCONN } - return int(n) -} - -func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { - a := toAddr(la) - if a == nil { - return la, nil - } - switch v := a.(type) { - case *TCPAddr, *UnixAddr: - err := setDefaultListenerSockopts(s) - if err != nil { - return nil, err - } - case *UDPAddr: - if v.IP.IsMulticast() { - err := setDefaultMulticastSockopts(s) - if err != nil { - return nil, err - } - switch f { - case syscall.AF_INET: - v.IP = IPv4zero - case syscall.AF_INET6: - v.IP = IPv6unspecified - } - return v.sockaddr(f) - } + // FreeBSD stores the backlog in a uint16, as does Linux. + // Assume the other BSDs do too. Truncate number to avoid wrapping. + // See issue 5030. + if n > 1<<16-1 { + n = 1<<16 - 1 } - return la, nil + return int(n) } diff --git a/libgo/go/net/sock_cloexec.go b/libgo/go/net/sock_cloexec.go index e2a5ef7..3f22cd8 100644 --- a/libgo/go/net/sock_cloexec.go +++ b/libgo/go/net/sock_cloexec.go @@ -44,20 +44,20 @@ func sysSocket(f, t, p int) (int, error) { func accept(fd int) (int, syscall.Sockaddr, error) { nfd, sa, err := syscall.Accept4(fd, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) // The accept4 system call was introduced in Linux 2.6.28. If - // we get an ENOSYS error, fall back to using accept. - if err == nil || err != syscall.ENOSYS { + // we get an ENOSYS or EINVAL error, fall back to using accept. + if err == nil || (err != syscall.ENOSYS && err != syscall.EINVAL) { return nfd, sa, err } // See ../syscall/exec_unix.go for description of ForkLock. - // It is okay to hold the lock across syscall.Accept + // It is probably okay to hold the lock across syscall.Accept // because we have put fd.sysfd into non-blocking mode. - syscall.ForkLock.RLock() + // However, a call to the File method will put it back into + // blocking mode. We can't take that risk, so no use of ForkLock here. nfd, sa, err = syscall.Accept(fd) if err == nil { syscall.CloseOnExec(nfd) } - syscall.ForkLock.RUnlock() if err != nil { return -1, nil, err } diff --git a/libgo/go/net/sock_linux.go b/libgo/go/net/sock_linux.go index e509d93..cc5ce15 100644 --- a/libgo/go/net/sock_linux.go +++ b/libgo/go/net/sock_linux.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Sockets for Linux - package net import "syscall" @@ -23,34 +21,11 @@ func maxListenerBacklog() int { if n == 0 || !ok { return syscall.SOMAXCONN } - return n -} - -func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { - a := toAddr(la) - if a == nil { - return la, nil - } - switch v := a.(type) { - case *TCPAddr, *UnixAddr: - err := setDefaultListenerSockopts(s) - if err != nil { - return nil, err - } - case *UDPAddr: - if v.IP.IsMulticast() { - err := setDefaultMulticastSockopts(s) - if err != nil { - return nil, err - } - switch f { - case syscall.AF_INET: - v.IP = IPv4zero - case syscall.AF_INET6: - v.IP = IPv6unspecified - } - return v.sockaddr(f) - } + // Linux stores the backlog in a uint16. + // Truncate number to avoid wrapping. + // See issue 5030. + if n > 1<<16-1 { + n = 1<<16 - 1 } - return la, nil + return n } diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go index 9cd149e..be89c26d 100644 --- a/libgo/go/net/sock_posix.go +++ b/libgo/go/net/sock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// Sockets - package net import ( @@ -15,7 +13,7 @@ import ( var listenerBacklog = maxListenerBacklog() -// Generic socket creation. +// Generic POSIX socket creation. func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, deadline time.Time, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) { s, err := sysSocket(f, t, p) if err != nil { @@ -27,7 +25,8 @@ func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, return nil, err } - if ulsa != nil { + // This socket is used by a listener. + if ulsa != nil && ursa == nil { // We provide a socket that listens to a wildcard // address with reusable UDP port when the given ulsa // is an appropriate UDP multicast address prefix. @@ -39,6 +38,9 @@ func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, closesocket(s) return nil, err } + } + + if ulsa != nil { if err = syscall.Bind(s, ulsa); err != nil { closesocket(s) return nil, err @@ -50,19 +52,27 @@ func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, return nil, err } + // This socket is used by a dialer. if ursa != nil { - fd.wdeadline.setTime(deadline) - if err = fd.connect(ursa); err != nil { - closesocket(s) + if !deadline.IsZero() { + setWriteDeadline(fd, deadline) + } + if err = fd.connect(ulsa, ursa); err != nil { + fd.Close() return nil, err } fd.isConnected = true - fd.wdeadline.set(0) + if !deadline.IsZero() { + setWriteDeadline(fd, time.Time{}) + } } lsa, _ := syscall.Getsockname(s) laddr := toAddr(lsa) rsa, _ := syscall.Getpeername(s) + if rsa == nil { + rsa = ursa + } raddr := toAddr(rsa) fd.setAddr(laddr, raddr) return fd, nil diff --git a/libgo/go/net/sock_unix.go b/libgo/go/net/sock_unix.go new file mode 100644 index 0000000..b0d6d49 --- /dev/null +++ b/libgo/go/net/sock_unix.go @@ -0,0 +1,36 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin freebsd linux netbsd openbsd + +package net + +import "syscall" + +func listenerSockaddr(s, f int, la syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (syscall.Sockaddr, error) { + a := toAddr(la) + if a == nil { + return la, nil + } + switch a := a.(type) { + case *TCPAddr, *UnixAddr: + if err := setDefaultListenerSockopts(s); err != nil { + return nil, err + } + case *UDPAddr: + if a.IP.IsMulticast() { + if err := setDefaultMulticastSockopts(s); err != nil { + return nil, err + } + switch f { + case syscall.AF_INET: + a.IP = IPv4zero + case syscall.AF_INET6: + a.IP = IPv6unspecified + } + return a.sockaddr(f) + } + } + return la, nil +} diff --git a/libgo/go/net/sock_windows.go b/libgo/go/net/sock_windows.go index fc5d9e5..41368d3 100644 --- a/libgo/go/net/sock_windows.go +++ b/libgo/go/net/sock_windows.go @@ -2,14 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Sockets for Windows - package net import "syscall" func maxListenerBacklog() int { // TODO: Implement this + // NOTE: Never return a number bigger than 1<<16 - 1. See issue 5030. return syscall.SOMAXCONN } @@ -18,25 +17,23 @@ func listenerSockaddr(s syscall.Handle, f int, la syscall.Sockaddr, toAddr func( if a == nil { return la, nil } - switch v := a.(type) { + switch a := a.(type) { case *TCPAddr, *UnixAddr: - err := setDefaultListenerSockopts(s) - if err != nil { + if err := setDefaultListenerSockopts(s); err != nil { return nil, err } case *UDPAddr: - if v.IP.IsMulticast() { - err := setDefaultMulticastSockopts(s) - if err != nil { + if a.IP.IsMulticast() { + if err := setDefaultMulticastSockopts(s); err != nil { return nil, err } switch f { case syscall.AF_INET: - v.IP = IPv4zero + a.IP = IPv4zero case syscall.AF_INET6: - v.IP = IPv6unspecified + a.IP = IPv6unspecified } - return v.sockaddr(f) + return a.sockaddr(f) } } return la, nil diff --git a/libgo/go/net/sockopt_posix.go b/libgo/go/net/sockopt_posix.go index fe371fe..1590f4e 100644 --- a/libgo/go/net/sockopt_posix.go +++ b/libgo/go/net/sockopt_posix.go @@ -11,7 +11,6 @@ package net import ( "os" "syscall" - "time" ) // Boolean to int. @@ -119,24 +118,6 @@ func setWriteBuffer(fd *netFD, bytes int) error { return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)) } -// TODO(dfc) these unused error returns could be removed - -func setReadDeadline(fd *netFD, t time.Time) error { - fd.rdeadline.setTime(t) - return nil -} - -func setWriteDeadline(fd *netFD, t time.Time) error { - fd.wdeadline.setTime(t) - return nil -} - -func setDeadline(fd *netFD, t time.Time) error { - setReadDeadline(fd, t) - setWriteDeadline(fd, t) - return nil -} - func setKeepAlive(fd *netFD, keepalive bool) error { if err := fd.incref(false); err != nil { return err diff --git a/libgo/go/net/sockopt_windows.go b/libgo/go/net/sockopt_windows.go index 509b596..0861fe8 100644 --- a/libgo/go/net/sockopt_windows.go +++ b/libgo/go/net/sockopt_windows.go @@ -9,6 +9,7 @@ package net import ( "os" "syscall" + "time" ) func setDefaultSockopts(s syscall.Handle, f, t int, ipv6only bool) error { @@ -47,3 +48,21 @@ func setDefaultMulticastSockopts(s syscall.Handle) error { } return nil } + +// TODO(dfc) these unused error returns could be removed + +func setReadDeadline(fd *netFD, t time.Time) error { + fd.rdeadline.setTime(t) + return nil +} + +func setWriteDeadline(fd *netFD, t time.Time) error { + fd.wdeadline.setTime(t) + return nil +} + +func setDeadline(fd *netFD, t time.Time) error { + setReadDeadline(fd, t) + setWriteDeadline(fd, t) + return nil +} diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go index 75d5688..17e8749 100644 --- a/libgo/go/net/sys_cloexec.go +++ b/libgo/go/net/sys_cloexec.go @@ -35,14 +35,14 @@ func sysSocket(f, t, p int) (int, error) { // descriptor as nonblocking and close-on-exec. func accept(fd int) (int, syscall.Sockaddr, error) { // See ../syscall/exec_unix.go for description of ForkLock. - // It is okay to hold the lock across syscall.Accept + // It is probably okay to hold the lock across syscall.Accept // because we have put fd.sysfd into non-blocking mode. - syscall.ForkLock.RLock() + // However, a call to the File method will put it back into + // blocking mode. We can't take that risk, so no use of ForkLock here. nfd, sa, err := syscall.Accept(fd) if err == nil { syscall.CloseOnExec(nfd) } - syscall.ForkLock.RUnlock() if err != nil { return -1, nil, err } diff --git a/libgo/go/net/tcp_test.go b/libgo/go/net/tcp_test.go index 1d54b3a..a71b02b 100644 --- a/libgo/go/net/tcp_test.go +++ b/libgo/go/net/tcp_test.go @@ -5,29 +5,58 @@ package net import ( + "fmt" "reflect" "runtime" "testing" "time" ) -func BenchmarkTCPOneShot(b *testing.B) { - benchmarkTCP(b, false, false) +func BenchmarkTCP4OneShot(b *testing.B) { + benchmarkTCP(b, false, false, "127.0.0.1:0") } -func BenchmarkTCPOneShotTimeout(b *testing.B) { - benchmarkTCP(b, false, true) +func BenchmarkTCP4OneShotTimeout(b *testing.B) { + benchmarkTCP(b, false, true, "127.0.0.1:0") } -func BenchmarkTCPPersistent(b *testing.B) { - benchmarkTCP(b, true, false) +func BenchmarkTCP4Persistent(b *testing.B) { + benchmarkTCP(b, true, false, "127.0.0.1:0") } -func BenchmarkTCPPersistentTimeout(b *testing.B) { - benchmarkTCP(b, true, true) +func BenchmarkTCP4PersistentTimeout(b *testing.B) { + benchmarkTCP(b, true, true, "127.0.0.1:0") } -func benchmarkTCP(b *testing.B, persistent, timeout bool) { +func BenchmarkTCP6OneShot(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, false, false, "[::1]:0") +} + +func BenchmarkTCP6OneShotTimeout(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, false, true, "[::1]:0") +} + +func BenchmarkTCP6Persistent(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, true, false, "[::1]:0") +} + +func BenchmarkTCP6PersistentTimeout(b *testing.B) { + if !supportsIPv6 { + b.Skip("ipv6 is not supported") + } + benchmarkTCP(b, true, true, "[::1]:0") +} + +func benchmarkTCP(b *testing.B, persistent, timeout bool, laddr string) { const msgLen = 512 conns := b.N numConcurrent := runtime.GOMAXPROCS(-1) * 16 @@ -61,7 +90,7 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool) { } return true } - ln, err := Listen("tcp", "127.0.0.1:0") + ln, err := Listen("tcp", laddr) if err != nil { b.Fatalf("Listen failed: %v", err) } @@ -118,24 +147,39 @@ func benchmarkTCP(b *testing.B, persistent, timeout bool) { } } -var resolveTCPAddrTests = []struct { +type resolveTCPAddrTest struct { net string litAddr string addr *TCPAddr err error -}{ +} + +var resolveTCPAddrTests = []resolveTCPAddrTest{ {"tcp", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, {"tcp4", "127.0.0.1:65535", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, {"tcp", "[::1]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1}, nil}, {"tcp6", "[::1]:65534", &TCPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, + {"tcp", "[::1%en0]:1", &TCPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil}, + {"tcp6", "[::1%911]:2", &TCPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil}, + {"", "127.0.0.1:0", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior {"", "[::1]:0", &TCPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")}, } +func init() { + if ifi := loopbackInterface(); ifi != nil { + index := fmt.Sprintf("%v", ifi.Index) + resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{ + {"tcp6", "[fe80::1%" + ifi.Name + "]:3", &TCPAddr{IP: ParseIP("fe80::1"), Port: 3, Zone: zoneToString(ifi.Index)}, nil}, + {"tcp6", "[fe80::1%" + index + "]:4", &TCPAddr{IP: ParseIP("fe80::1"), Port: 4, Zone: index}, nil}, + }...) + } +} + func TestResolveTCPAddr(t *testing.T) { for _, tt := range resolveTCPAddrTests { addr, err := ResolveTCPAddr(tt.net, tt.litAddr) @@ -165,14 +209,88 @@ func TestTCPListenerName(t *testing.T) { for _, tt := range tcpListenerNameTests { ln, err := ListenTCP(tt.net, tt.laddr) if err != nil { - t.Errorf("ListenTCP failed: %v", err) - return + t.Fatalf("ListenTCP failed: %v", err) } defer ln.Close() la := ln.Addr() if a, ok := la.(*TCPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return + t.Fatalf("got %v; expected a proper address with non-zero port number", la) + } + } +} + +func TestIPv6LinkLocalUnicastTCP(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + if !supportsIPv6 { + t.Skip("ipv6 is not supported") + } + ifi := loopbackInterface() + if ifi == nil { + t.Skip("loopback interface not found") + } + laddr := ipv6LinkLocalUnicastAddr(ifi) + if laddr == "" { + t.Skip("ipv6 unicast address on loopback not found") + } + + type test struct { + net, addr string + nameLookup bool + } + var tests = []test{ + {"tcp", "[" + laddr + "%" + ifi.Name + "]:0", false}, + {"tcp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, + } + switch runtime.GOOS { + case "darwin", "freebsd", "opensbd", "netbsd": + tests = append(tests, []test{ + {"tcp", "[localhost%" + ifi.Name + "]:0", true}, + {"tcp6", "[localhost%" + ifi.Name + "]:0", true}, + }...) + case "linux": + tests = append(tests, []test{ + {"tcp", "[ip6-localhost%" + ifi.Name + "]:0", true}, + {"tcp6", "[ip6-localhost%" + ifi.Name + "]:0", true}, + }...) + } + for _, tt := range tests { + ln, err := Listen(tt.net, tt.addr) + if err != nil { + // It might return "LookupHost returned no + // suitable address" error on some platforms. + t.Logf("Listen failed: %v", err) + continue } + defer ln.Close() + if la, ok := ln.Addr().(*TCPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + + done := make(chan int) + go transponder(t, ln, done) + + c, err := Dial(tt.net, ln.Addr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c.Close() + if la, ok := c.LocalAddr().(*TCPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + if ra, ok := c.RemoteAddr().(*TCPAddr); !ok || !tt.nameLookup && ra.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", ra) + } + + if _, err := c.Write([]byte("TCP OVER IPV6 LINKLOCAL TEST")); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + b := make([]byte, 32) + if _, err := c.Read(b); err != nil { + t.Fatalf("Conn.Read failed: %v", err) + } + + <-done } } diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index d5158b2..4d9ebd2 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TCP sockets - package net // TCPAddr represents the address of a TCP end point. @@ -20,14 +18,18 @@ func (a *TCPAddr) String() string { if a == nil { return "<nil>" } + if a.Zone != "" { + return JoinHostPort(a.IP.String()+"%"+a.Zone, itoa(a.Port)) + } return JoinHostPort(a.IP.String(), itoa(a.Port)) } -// ResolveTCPAddr parses addr as a TCP address of the form -// host:port and resolves domain names or port names to -// numeric addresses on the network net, which must be "tcp", -// "tcp4" or "tcp6". A literal IPv6 host address must be -// enclosed in square brackets, as in "[::]:80". +// ResolveTCPAddr parses addr as a TCP address of the form "host:port" +// or "[ipv6-host%zone]:port" and resolves a pair of domain name and +// port name on the network net, which must be "tcp", "tcp4" or +// "tcp6". A literal address or host name for IPv6 must be enclosed +// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or +// "[ipv6-host%zone]:80". func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { switch net { case "tcp", "tcp4", "tcp6": diff --git a/libgo/go/net/tcpsock_plan9.go b/libgo/go/net/tcpsock_plan9.go index 954c99a..48334fed 100644 --- a/libgo/go/net/tcpsock_plan9.go +++ b/libgo/go/net/tcpsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TCP sockets for Plan 9 - package net import ( @@ -89,7 +87,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, e switch net { case "tcp", "tcp4", "tcp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{"dial", net, raddr, UnknownNetworkError(net)} } if raddr == nil { return nil, &OpError{"dial", net, nil, errMissingAddress} @@ -98,7 +96,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, e if err != nil { return nil, err } - return &TCPConn{conn{fd}}, nil + return newTCPConn(fd), nil } // TCPListener is a TCP network listener. Clients should typically @@ -141,7 +139,7 @@ func (l *TCPListener) Close() error { } if _, err := l.fd.ctl.WriteString("hangup"); err != nil { l.fd.ctl.Close() - return err + return &OpError{"close", l.fd.ctl.Name(), l.fd.laddr, err} } return l.fd.ctl.Close() } @@ -161,17 +159,21 @@ func (l *TCPListener) SetDeadline(t time.Time) error { // File returns a copy of the underlying os.File, set to blocking // mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. -func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() } +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. +func (l *TCPListener) File() (f *os.File, err error) { return l.dup() } // ListenTCP announces on the TCP address laddr and returns a TCP // listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a -// port of 0, it means to listen on some available port. The caller -// can use l.Addr() to retrieve the chosen address. +// port of 0, ListenTCP will choose an available port. The caller can +// use the Addr method of TCPListener to retrieve the chosen address. func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { switch net { case "tcp", "tcp4", "tcp6": default: - return nil, UnknownNetworkError(net) + return nil, &OpError{"listen", net, laddr, UnknownNetworkError(net)} } if laddr == nil { laddr = &TCPAddr{} diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go index bd5a2a2..876edb1 100644 --- a/libgo/go/net/tcpsock_posix.go +++ b/libgo/go/net/tcpsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// TCP sockets - package net import ( @@ -58,8 +56,8 @@ func (a *TCPAddr) toAddr() sockaddr { return a } -// TCPConn is an implementation of the Conn interface -// for TCP network connections. +// TCPConn is an implementation of the Conn interface for TCP network +// connections. type TCPConn struct { conn } @@ -96,17 +94,17 @@ func (c *TCPConn) CloseWrite() error { return c.fd.CloseWrite() } -// SetLinger sets the behavior of Close() on a connection -// which still has data waiting to be sent or to be acknowledged. +// SetLinger sets the behavior of Close() on a connection which still +// has data waiting to be sent or to be acknowledged. // -// If sec < 0 (the default), Close returns immediately and -// the operating system finishes sending the data in the background. +// If sec < 0 (the default), Close returns immediately and the +// operating system finishes sending the data in the background. // // If sec == 0, Close returns immediately and the operating system // discards any unsent or unacknowledged data. // -// If sec > 0, Close blocks for at most sec seconds waiting for -// data to be sent and acknowledged. +// If sec > 0, Close blocks for at most sec seconds waiting for data +// to be sent and acknowledged. func (c *TCPConn) SetLinger(sec int) error { if !c.ok() { return syscall.EINVAL @@ -124,9 +122,9 @@ func (c *TCPConn) SetKeepAlive(keepalive bool) error { } // SetNoDelay controls whether the operating system should delay -// packet transmission in hopes of sending fewer packets -// (Nagle's algorithm). The default is true (no delay), meaning -// that data is sent as soon as possible after a Write. +// packet transmission in hopes of sending fewer packets (Nagle's +// algorithm). The default is true (no delay), meaning that data is +// sent as soon as possible after a Write. func (c *TCPConn) SetNoDelay(noDelay bool) error { if !c.ok() { return syscall.EINVAL @@ -135,8 +133,8 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error { } // DialTCP connects to the remote address raddr on the network net, -// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used -// as the local address for the connection. +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is +// used as the local address for the connection. func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { switch net { case "tcp", "tcp4", "tcp6": @@ -216,16 +214,15 @@ func spuriousENOTAVAIL(err error) bool { return ok && e.Err == syscall.EADDRNOTAVAIL } -// TCPListener is a TCP network listener. -// Clients should typically use variables of type Listener -// instead of assuming TCP. +// TCPListener is a TCP network listener. Clients should typically +// use variables of type Listener instead of assuming TCP. type TCPListener struct { fd *netFD } -// AcceptTCP accepts the next incoming call and returns the new connection -// and the remote address. -func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) { +// AcceptTCP accepts the next incoming call and returns the new +// connection and the remote address. +func (l *TCPListener) AcceptTCP() (*TCPConn, error) { if l == nil || l.fd == nil { return nil, syscall.EINVAL } @@ -236,14 +233,14 @@ func (l *TCPListener) AcceptTCP() (c *TCPConn, err error) { return newTCPConn(fd), nil } -// Accept implements the Accept method in the Listener interface; -// it waits for the next call and returns a generic Conn. -func (l *TCPListener) Accept() (c Conn, err error) { - c1, err := l.AcceptTCP() +// Accept implements the Accept method in the Listener interface; it +// waits for the next call and returns a generic Conn. +func (l *TCPListener) Accept() (Conn, error) { + c, err := l.AcceptTCP() if err != nil { return nil, err } - return c1, nil + return c, nil } // Close stops listening on the TCP address. @@ -267,15 +264,19 @@ func (l *TCPListener) SetDeadline(t time.Time) error { return setDeadline(l.fd, t) } -// File returns a copy of the underlying os.File, set to blocking mode. -// It is the caller's responsibility to close f when finished. +// File returns a copy of the underlying os.File, set to blocking +// mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *TCPListener) File() (f *os.File, err error) { return l.fd.dup() } -// ListenTCP announces on the TCP address laddr and returns a TCP listener. -// Net must be "tcp", "tcp4", or "tcp6". -// If laddr has a port of 0, it means to listen on some available port. -// The caller can use l.Addr() to retrieve the chosen address. +// ListenTCP announces on the TCP address laddr and returns a TCP +// listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a +// port of 0, ListenTCP will choose an available port. The caller can +// use the Addr method of TCPListener to retrieve the chosen address. func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { switch net { case "tcp", "tcp4", "tcp6": @@ -291,7 +292,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { } err = syscall.Listen(fd.sysfd, listenerBacklog) if err != nil { - closesocket(fd.sysfd) + fd.Close() return nil, &OpError{"listen", net, laddr, err} } return &TCPListener{fd}, nil diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go index 855350c..5bd26ac 100644 --- a/libgo/go/net/textproto/reader.go +++ b/libgo/go/net/textproto/reader.go @@ -128,6 +128,17 @@ func (r *Reader) readContinuedLineSlice() ([]byte, error) { return line, nil } + // Optimistically assume that we have started to buffer the next line + // and it starts with an ASCII letter (the next header key), so we can + // avoid copying that buffered data around in memory and skipping over + // non-existent whitespace. + if r.R.Buffered() > 1 { + peek, err := r.R.Peek(1) + if err == nil && isASCIILetter(peek[0]) { + return trim(line), nil + } + } + // ReadByte or the next readLineSlice will flush the read buffer; // copy the slice into buf. r.buf = append(r.buf[:0], trim(line)...) @@ -445,7 +456,7 @@ func (r *Reader) ReadDotLines() ([]string, error) { // } // func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { - m := make(MIMEHeader) + m := make(MIMEHeader, 4) for { kv, err := r.readContinuedLineSlice() if len(kv) == 0 { @@ -478,7 +489,6 @@ func (r *Reader) ReadMIMEHeader() (MIMEHeader, error) { return m, err } } - panic("unreachable") } // CanonicalMIMEHeaderKey returns the canonical format of the @@ -564,6 +574,7 @@ var commonHeaders = []string{ "Content-Length", "Content-Transfer-Encoding", "Content-Type", + "Cookie", "Date", "Dkim-Signature", "Etag", diff --git a/libgo/go/net/textproto/reader_test.go b/libgo/go/net/textproto/reader_test.go index 26987f6..f27042d 100644 --- a/libgo/go/net/textproto/reader_test.go +++ b/libgo/go/net/textproto/reader_test.go @@ -290,6 +290,7 @@ Non-Interned: test `, "\n", "\r\n", -1) func BenchmarkReadMIMEHeader(b *testing.B) { + b.ReportAllocs() var buf bytes.Buffer br := bufio.NewReader(&buf) r := NewReader(br) @@ -319,6 +320,7 @@ func BenchmarkReadMIMEHeader(b *testing.B) { } func BenchmarkUncommon(b *testing.B) { + b.ReportAllocs() var buf bytes.Buffer br := bufio.NewReader(&buf) r := NewReader(br) diff --git a/libgo/go/net/textproto/textproto.go b/libgo/go/net/textproto/textproto.go index e7ad877..eb6ced1 100644 --- a/libgo/go/net/textproto/textproto.go +++ b/libgo/go/net/textproto/textproto.go @@ -147,3 +147,8 @@ func TrimBytes(b []byte) []byte { func isASCIISpace(b byte) bool { return b == ' ' || b == '\t' || b == '\n' || b == '\r' } + +func isASCIILetter(b byte) bool { + b |= 0x20 // make lower case + return 'a' <= b && b <= 'z' +} diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index 7cf45ca..2e92147 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -420,6 +420,11 @@ func TestVariousDeadlines4Proc(t *testing.T) { } func testVariousDeadlines(t *testing.T, maxProcs int) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) ln := newLocalListener(t) defer ln.Close() @@ -518,11 +523,16 @@ func testVariousDeadlines(t *testing.T, maxProcs int) { // TestReadDeadlineDataAvailable tests that read deadlines work, even // if there's data ready to be read. func TestReadDeadlineDataAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + ln := newLocalListener(t) defer ln.Close() servec := make(chan copyRes) - const msg = "data client shouldn't read, even though it it'll be waiting" + const msg = "data client shouldn't read, even though it'll be waiting" go func() { c, err := ln.Accept() if err != nil { @@ -552,6 +562,11 @@ func TestReadDeadlineDataAvailable(t *testing.T) { // TestWriteDeadlineBufferAvailable tests that write deadlines work, even // if there's buffer space available to write. func TestWriteDeadlineBufferAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + ln := newLocalListener(t) defer ln.Close() @@ -581,6 +596,64 @@ func TestWriteDeadlineBufferAvailable(t *testing.T) { } } +// TestAcceptDeadlineConnectionAvailable tests that accept deadlines work, even +// if there's incoming connections available. +func TestAcceptDeadlineConnectionAvailable(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + + go func() { + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + var buf [1]byte + c.Read(buf[:]) // block until the connection or listener is closed + }() + time.Sleep(10 * time.Millisecond) + ln.SetDeadline(time.Now().Add(-5 * time.Second)) // in the past + c, err := ln.Accept() + if err == nil { + defer c.Close() + } + if !isTimeout(err) { + t.Fatalf("Accept: got %v; want timeout", err) + } +} + +// TestConnectDeadlineInThePast tests that connect deadlines work, even +// if the connection can be established w/o blocking. +func TestConnectDeadlineInThePast(t *testing.T) { + switch runtime.GOOS { + case "plan9": + t.Skipf("skipping test on %q", runtime.GOOS) + } + + ln := newLocalListener(t).(*TCPListener) + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err == nil { + defer c.Close() + } + }() + time.Sleep(10 * time.Millisecond) + c, err := DialTimeout("tcp", ln.Addr().String(), -5*time.Second) // in the past + if err == nil { + defer c.Close() + } + if !isTimeout(err) { + t.Fatalf("DialTimeout: got %v; want timeout", err) + } +} + // TestProlongTimeout tests concurrent deadline modification. // Known to cause data races in the past. func TestProlongTimeout(t *testing.T) { diff --git a/libgo/go/net/udp_test.go b/libgo/go/net/udp_test.go index 220422e..4278f6d 100644 --- a/libgo/go/net/udp_test.go +++ b/libgo/go/net/udp_test.go @@ -5,29 +5,45 @@ package net import ( + "fmt" "reflect" "runtime" "testing" ) -var resolveUDPAddrTests = []struct { +type resolveUDPAddrTest struct { net string litAddr string addr *UDPAddr err error -}{ +} + +var resolveUDPAddrTests = []resolveUDPAddrTest{ {"udp", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, {"udp4", "127.0.0.1:65535", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 65535}, nil}, {"udp", "[::1]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1}, nil}, {"udp6", "[::1]:65534", &UDPAddr{IP: ParseIP("::1"), Port: 65534}, nil}, + {"udp", "[::1%en0]:1", &UDPAddr{IP: ParseIP("::1"), Port: 1, Zone: "en0"}, nil}, + {"udp6", "[::1%911]:2", &UDPAddr{IP: ParseIP("::1"), Port: 2, Zone: "911"}, nil}, + {"", "127.0.0.1:0", &UDPAddr{IP: IPv4(127, 0, 0, 1), Port: 0}, nil}, // Go 1.0 behavior {"", "[::1]:0", &UDPAddr{IP: ParseIP("::1"), Port: 0}, nil}, // Go 1.0 behavior {"sip", "127.0.0.1:0", nil, UnknownNetworkError("sip")}, } +func init() { + if ifi := loopbackInterface(); ifi != nil { + index := fmt.Sprintf("%v", ifi.Index) + resolveUDPAddrTests = append(resolveUDPAddrTests, []resolveUDPAddrTest{ + {"udp6", "[fe80::1%" + ifi.Name + "]:3", &UDPAddr{IP: ParseIP("fe80::1"), Port: 3, Zone: zoneToString(ifi.Index)}, nil}, + {"udp6", "[fe80::1%" + index + "]:4", &UDPAddr{IP: ParseIP("fe80::1"), Port: 4, Zone: index}, nil}, + }...) + } +} + func TestResolveUDPAddr(t *testing.T) { for _, tt := range resolveUDPAddrTests { addr, err := ResolveUDPAddr(tt.net, tt.litAddr) @@ -135,14 +151,125 @@ func TestUDPConnLocalName(t *testing.T) { for _, tt := range udpConnLocalNameTests { c, err := ListenUDP(tt.net, tt.laddr) if err != nil { - t.Errorf("ListenUDP failed: %v", err) - return + t.Fatalf("ListenUDP failed: %v", err) } defer c.Close() la := c.LocalAddr() if a, ok := la.(*UDPAddr); !ok || a.Port == 0 { - t.Errorf("got %v; expected a proper address with non-zero port number", la) - return + t.Fatalf("got %v; expected a proper address with non-zero port number", la) + } + } +} + +func TestUDPConnLocalAndRemoteNames(t *testing.T) { + for _, laddr := range []string{"", "127.0.0.1:0"} { + c1, err := ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("ListenUDP failed: %v", err) + } + defer c1.Close() + + var la *UDPAddr + if laddr != "" { + var err error + if la, err = ResolveUDPAddr("udp", laddr); err != nil { + t.Fatalf("ResolveUDPAddr failed: %v", err) + } + } + c2, err := DialUDP("udp", la, c1.LocalAddr().(*UDPAddr)) + if err != nil { + t.Fatalf("DialUDP failed: %v", err) + } + defer c2.Close() + + var connAddrs = [4]struct { + got Addr + ok bool + }{ + {c1.LocalAddr(), true}, + {c1.(*UDPConn).RemoteAddr(), false}, + {c2.LocalAddr(), true}, + {c2.RemoteAddr(), true}, + } + for _, ca := range connAddrs { + if a, ok := ca.got.(*UDPAddr); ok != ca.ok || ok && a.Port == 0 { + t.Fatalf("got %v; expected a proper address with non-zero port number", ca.got) + } + } + } +} + +func TestIPv6LinkLocalUnicastUDP(t *testing.T) { + if testing.Short() || !*testExternal { + t.Skip("skipping test to avoid external network") + } + if !supportsIPv6 { + t.Skip("ipv6 is not supported") + } + ifi := loopbackInterface() + if ifi == nil { + t.Skip("loopback interface not found") + } + laddr := ipv6LinkLocalUnicastAddr(ifi) + if laddr == "" { + t.Skip("ipv6 unicast address on loopback not found") + } + + type test struct { + net, addr string + nameLookup bool + } + var tests = []test{ + {"udp", "[" + laddr + "%" + ifi.Name + "]:0", false}, + {"udp6", "[" + laddr + "%" + ifi.Name + "]:0", false}, + } + switch runtime.GOOS { + case "darwin", "freebsd", "openbsd", "netbsd": + tests = append(tests, []test{ + {"udp", "[localhost%" + ifi.Name + "]:0", true}, + {"udp6", "[localhost%" + ifi.Name + "]:0", true}, + }...) + case "linux": + tests = append(tests, []test{ + {"udp", "[ip6-localhost%" + ifi.Name + "]:0", true}, + {"udp6", "[ip6-localhost%" + ifi.Name + "]:0", true}, + }...) + } + for _, tt := range tests { + c1, err := ListenPacket(tt.net, tt.addr) + if err != nil { + // It might return "LookupHost returned no + // suitable address" error on some platforms. + t.Logf("ListenPacket failed: %v", err) + continue + } + defer c1.Close() + if la, ok := c1.LocalAddr().(*UDPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + + c2, err := Dial(tt.net, c1.LocalAddr().String()) + if err != nil { + t.Fatalf("Dial failed: %v", err) + } + defer c2.Close() + if la, ok := c2.LocalAddr().(*UDPAddr); !ok || !tt.nameLookup && la.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", la) + } + if ra, ok := c2.RemoteAddr().(*UDPAddr); !ok || !tt.nameLookup && ra.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", ra) + } + + if _, err := c2.Write([]byte("UDP OVER IPV6 LINKLOCAL TEST")); err != nil { + t.Fatalf("Conn.Write failed: %v", err) + } + b := make([]byte, 32) + if _, from, err := c1.ReadFrom(b); err != nil { + t.Fatalf("PacketConn.ReadFrom failed: %v", err) + } else { + if ra, ok := from.(*UDPAddr); !ok || !tt.nameLookup && ra.Zone == "" { + t.Fatalf("got %v; expected a proper address with zone identifier", ra) + } } } } diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 6e5e902..5ce7d6b 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// UDP sockets - package net import "errors" @@ -24,14 +22,18 @@ func (a *UDPAddr) String() string { if a == nil { return "<nil>" } + if a.Zone != "" { + return JoinHostPort(a.IP.String()+"%"+a.Zone, itoa(a.Port)) + } return JoinHostPort(a.IP.String(), itoa(a.Port)) } -// ResolveUDPAddr parses addr as a UDP address of the form -// host:port and resolves domain names or port names to -// numeric addresses on the network net, which must be "udp", -// "udp4" or "udp6". A literal IPv6 host address must be -// enclosed in square brackets, as in "[::]:80". +// ResolveUDPAddr parses addr as a UDP address of the form "host:port" +// or "[ipv6-host%zone]:port" and resolves a pair of domain name and +// port name on the network net, which must be "udp", "udp4" or +// "udp6". A literal address or host name for IPv6 must be enclosed +// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or +// "[ipv6-host%zone]:80". func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { switch net { case "udp", "udp4", "udp6": diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go index b9ade48..12a3483 100644 --- a/libgo/go/net/udpsock_plan9.go +++ b/libgo/go/net/udpsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// UDP sockets for Plan 9 - package net import ( @@ -13,12 +11,14 @@ import ( "time" ) -// UDPConn is the implementation of the Conn and PacketConn -// interfaces for UDP network connections. +// UDPConn is the implementation of the Conn and PacketConn interfaces +// for UDP network connections. type UDPConn struct { conn } +func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } + // ReadFromUDP reads a UDP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. @@ -27,15 +27,9 @@ type UDPConn struct { // Timeout() == true after a fixed time limit; see SetDeadline and // SetReadDeadline. func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { - if !c.ok() { + if !c.ok() || c.fd.data == nil { return 0, nil, syscall.EINVAL } - if c.fd.data == nil { - c.fd.data, err = os.OpenFile(c.fd.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, nil, err - } - } buf := make([]byte, udpHeaderSize+len(b)) m, err := c.fd.data.Read(buf) if err != nil { @@ -60,7 +54,7 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgUDP reads a packet from c, copying the payload into b and -// the associdated out-of-band data into oob. It returns the number +// the associated out-of-band data into oob. It returns the number // of bytes copied into b, the number of bytes copied into oob, the // flags that were set on the packet and the source address of the // packet. @@ -76,16 +70,9 @@ func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, // SetWriteDeadline. On packet-oriented connections, write timeouts // are rare. func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { - if !c.ok() { + if !c.ok() || c.fd.data == nil { return 0, syscall.EINVAL } - if c.fd.data == nil { - f, err := os.OpenFile(c.fd.dir+"/data", os.O_RDWR, 0) - if err != nil { - return 0, err - } - c.fd.data = f - } h := new(udpHeader) h.raddr = addr.IP.To16() h.laddr = c.fd.laddr.(*UDPAddr).IP.To16() @@ -141,7 +128,7 @@ func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, e if err != nil { return nil, err } - return &UDPConn{conn{fd}}, nil + return newUDPConn(fd), nil } const udpHeaderSize = 16*3 + 2*2 @@ -173,7 +160,10 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) { } // ListenUDP listens for incoming UDP packets addressed to the local -// address laddr. The returned connection c's ReadFrom and WriteTo +// address laddr. Net must be "udp", "udp4", or "udp6". If laddr has +// a port of 0, ListenUDP will choose an available port. +// The LocalAddr method of the returned UDPConn can be used to +// discover the port. The returned connection's ReadFrom and WriteTo // methods can be used to receive and send UDP packets with per-packet // addressing. func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { @@ -193,7 +183,11 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { if err != nil { return nil, err } - return &UDPConn{conn{l.netFD()}}, nil + l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0) + if err != nil { + return nil, err + } + return newUDPConn(l.netFD()), nil } // ListenMulticastUDP listens for incoming multicast UDP packets diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go index 385cd90..b90cb03 100644 --- a/libgo/go/net/udpsock_posix.go +++ b/libgo/go/net/udpsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// UDP sockets for POSIX - package net import ( @@ -51,8 +49,8 @@ func (a *UDPAddr) toAddr() sockaddr { return a } -// UDPConn is the implementation of the Conn and PacketConn -// interfaces for UDP network connections. +// UDPConn is the implementation of the Conn and PacketConn interfaces +// for UDP network connections. type UDPConn struct { conn } @@ -63,8 +61,9 @@ func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } // It returns the number of bytes copied into b and the return address // that was on the packet. // -// ReadFromUDP can be made to time out and return an error with Timeout() == true -// after a fixed time limit; see SetDeadline and SetReadDeadline. +// ReadFromUDP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetReadDeadline. func (c *UDPConn) ReadFromUDP(b []byte) (n int, addr *UDPAddr, err error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -89,7 +88,7 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { } // ReadMsgUDP reads a packet from c, copying the payload into b and -// the associdated out-of-band data into oob. It returns the number +// the associated out-of-band data into oob. It returns the number // of bytes copied into b, the number of bytes copied into oob, the // flags that were set on the packet and the source address of the // packet. @@ -108,12 +107,13 @@ func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, return } -// WriteToUDP writes a UDP packet to addr via c, copying the payload from b. +// WriteToUDP writes a UDP packet to addr via c, copying the payload +// from b. // -// WriteToUDP can be made to time out and return -// an error with Timeout() == true after a fixed time limit; -// see SetDeadline and SetWriteDeadline. -// On packet-oriented connections, write timeouts are rare. +// WriteToUDP can be made to time out and return an error with +// Timeout() == true after a fixed time limit; see SetDeadline and +// SetWriteDeadline. On packet-oriented connections, write timeouts +// are rare. func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -158,8 +158,8 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er } // DialUDP connects to the remote address raddr on the network net, -// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is used -// as the local address for the connection. +// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is +// used as the local address for the connection. func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { return dialUDP(net, laddr, raddr, noDeadline) } @@ -180,10 +180,13 @@ func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, e return newUDPConn(fd), nil } -// ListenUDP listens for incoming UDP packets addressed to the -// local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send UDP -// packets with per-packet addressing. +// ListenUDP listens for incoming UDP packets addressed to the local +// address laddr. Net must be "udp", "udp4", or "udp6". If laddr has +// a port of 0, ListenUDP will choose an available port. +// The LocalAddr method of the returned UDPConn can be used to +// discover the port. The returned connection's ReadFrom and WriteTo +// methods can be used to receive and send UDP packets with per-packet +// addressing. func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": @@ -201,9 +204,9 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { } // ListenMulticastUDP listens for incoming multicast UDP packets -// addressed to the group address gaddr on ifi, which specifies -// the interface to join. ListenMulticastUDP uses default -// multicast interface if ifi is nil. +// addressed to the group address gaddr on ifi, which specifies the +// interface to join. ListenMulticastUDP uses default multicast +// interface if ifi is nil. func ListenMulticastUDP(net string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { switch net { case "udp", "udp4", "udp6": diff --git a/libgo/go/net/unicast_posix_test.go b/libgo/go/net/unicast_posix_test.go index a8855ca..b0588f4 100644 --- a/libgo/go/net/unicast_posix_test.go +++ b/libgo/go/net/unicast_posix_test.go @@ -45,7 +45,7 @@ var listenerTests = []struct { // same port. func TestTCPListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": + case "plan9": t.Skipf("skipping test on %q", runtime.GOOS) } @@ -69,65 +69,8 @@ func TestTCPListener(t *testing.T) { // same port. func TestUDPListener(t *testing.T) { switch runtime.GOOS { - case "plan9", "windows": - t.Skipf("skipping test on %q", runtime.GOOS) - } - - toudpnet := func(net string) string { - switch net { - case "tcp": - return "udp" - case "tcp4": - return "udp4" - case "tcp6": - return "udp6" - } - return "<nil>" - } - - for _, tt := range listenerTests { - if tt.wildcard && (testing.Short() || !*testExternal) { - continue - } - if tt.ipv6 && !supportsIPv6 { - continue - } - tt.net = toudpnet(tt.net) - l1, port := usableListenPacketPort(t, tt.net, tt.laddr) - checkFirstListener(t, tt.net, tt.laddr+":"+port, l1) - l2, err := ListenPacket(tt.net, tt.laddr+":"+port) - checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2) - l1.Close() - } -} - -func TestSimpleTCPListener(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("skipping test on %q", runtime.GOOS) - return - } - - for _, tt := range listenerTests { - if tt.wildcard && (testing.Short() || !*testExternal) { - continue - } - if tt.ipv6 { - continue - } - l1, port := usableListenPort(t, tt.net, tt.laddr) - checkFirstListener(t, tt.net, tt.laddr+":"+port, l1) - l2, err := Listen(tt.net, tt.laddr+":"+port) - checkSecondListener(t, tt.net, tt.laddr+":"+port, err, l2) - l1.Close() - } -} - -func TestSimpleUDPListener(t *testing.T) { - switch runtime.GOOS { case "plan9": t.Skipf("skipping test on %q", runtime.GOOS) - return } toudpnet := func(net string) string { @@ -146,7 +89,7 @@ func TestSimpleUDPListener(t *testing.T) { if tt.wildcard && (testing.Short() || !*testExternal) { continue } - if tt.ipv6 { + if tt.ipv6 && !supportsIPv6 { continue } tt.net = toudpnet(tt.net) @@ -231,7 +174,7 @@ func TestDualStackTCPListener(t *testing.T) { t.Skipf("skipping test on %q", runtime.GOOS) } if !supportsIPv6 { - return + t.Skip("ipv6 is not supported") } for _, tt := range dualStackListenerTests { @@ -263,7 +206,7 @@ func TestDualStackUDPListener(t *testing.T) { t.Skipf("skipping test on %q", runtime.GOOS) } if !supportsIPv6 { - return + t.Skip("ipv6 is not supported") } toudpnet := func(net string) string { diff --git a/libgo/go/net/unix_test.go b/libgo/go/net/unix_test.go new file mode 100644 index 0000000..5e63e9d9 --- /dev/null +++ b/libgo/go/net/unix_test.go @@ -0,0 +1,246 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,!windows + +package net + +import ( + "bytes" + "os" + "reflect" + "runtime" + "syscall" + "testing" + "time" +) + +func TestReadUnixgramWithUnnamedSocket(t *testing.T) { + addr := testUnixAddr() + la, err := ResolveUnixAddr("unixgram", addr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := ListenUnixgram("unixgram", la) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer func() { + c.Close() + os.Remove(addr) + }() + + off := make(chan bool) + data := [5]byte{1, 2, 3, 4, 5} + go func() { + defer func() { off <- true }() + s, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + t.Errorf("syscall.Socket failed: %v", err) + return + } + defer syscall.Close(s) + rsa := &syscall.SockaddrUnix{Name: addr} + if err := syscall.Sendto(s, data[:], 0, rsa); err != nil { + t.Errorf("syscall.Sendto failed: %v", err) + return + } + }() + + <-off + b := make([]byte, 64) + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + n, from, err := c.ReadFrom(b) + if err != nil { + t.Fatalf("UnixConn.ReadFrom failed: %v", err) + } + if from != nil { + t.Fatalf("neighbor address is %v", from) + } + if !bytes.Equal(b[:n], data[:]) { + t.Fatalf("got %v, want %v", b[:n], data[:]) + } +} + +func TestReadUnixgramWithZeroBytesBuffer(t *testing.T) { + // issue 4352: Recvfrom failed with "address family not + // supported by protocol family" if zero-length buffer provided + + addr := testUnixAddr() + la, err := ResolveUnixAddr("unixgram", addr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := ListenUnixgram("unixgram", la) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer func() { + c.Close() + os.Remove(addr) + }() + + off := make(chan bool) + go func() { + defer func() { off <- true }() + c, err := DialUnix("unixgram", nil, la) + if err != nil { + t.Errorf("DialUnix failed: %v", err) + return + } + defer c.Close() + if _, err := c.Write([]byte{1, 2, 3, 4, 5}); err != nil { + t.Errorf("UnixConn.Write failed: %v", err) + return + } + }() + + <-off + c.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + _, from, err := c.ReadFrom(nil) + if err != nil { + t.Fatalf("UnixConn.ReadFrom failed: %v", err) + } + if from != nil { + t.Fatalf("neighbor address is %v", from) + } +} + +func TestUnixAutobind(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("skipping: autobind is linux only") + } + + laddr := &UnixAddr{Name: "", Net: "unixgram"} + c1, err := ListenUnixgram("unixgram", laddr) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer c1.Close() + + // retrieve the autobind address + autoAddr := c1.LocalAddr().(*UnixAddr) + if len(autoAddr.Name) <= 1 { + t.Fatalf("invalid autobind address: %v", autoAddr) + } + if autoAddr.Name[0] != '@' { + t.Fatalf("invalid autobind address: %v", autoAddr) + } + + c2, err := DialUnix("unixgram", nil, autoAddr) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer c2.Close() + + if !reflect.DeepEqual(c1.LocalAddr(), c2.RemoteAddr()) { + t.Fatalf("expected autobind address %v, got %v", c1.LocalAddr(), c2.RemoteAddr()) + } +} + +func TestUnixConnLocalAndRemoteNames(t *testing.T) { + for _, laddr := range []string{"", testUnixAddr()} { + taddr := testUnixAddr() + ta, err := ResolveUnixAddr("unix", taddr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + ln, err := ListenUnix("unix", ta) + if err != nil { + t.Fatalf("ListenUnix failed: %v", err) + } + defer func() { + ln.Close() + os.Remove(taddr) + }() + + done := make(chan int) + go transponder(t, ln, done) + + la, err := ResolveUnixAddr("unix", laddr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c, err := DialUnix("unix", la, ta) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer func() { + c.Close() + if la != nil { + defer os.Remove(laddr) + } + }() + if _, err := c.Write([]byte("UNIXCONN LOCAL AND REMOTE NAME TEST")); err != nil { + t.Fatalf("UnixConn.Write failed: %v", err) + } + + if runtime.GOOS == "linux" && laddr == "" { + laddr = "@" // autobind feature + } + var connAddrs = [3]struct{ got, want Addr }{ + {ln.Addr(), ta}, + {c.LocalAddr(), &UnixAddr{Name: laddr, Net: "unix"}}, + {c.RemoteAddr(), ta}, + } + for _, ca := range connAddrs { + if !reflect.DeepEqual(ca.got, ca.want) { + t.Fatalf("got %#v, expected %#v", ca.got, ca.want) + } + } + + <-done + } +} + +func TestUnixgramConnLocalAndRemoteNames(t *testing.T) { + for _, laddr := range []string{"", testUnixAddr()} { + taddr := testUnixAddr() + ta, err := ResolveUnixAddr("unixgram", taddr) + if err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + c1, err := ListenUnixgram("unixgram", ta) + if err != nil { + t.Fatalf("ListenUnixgram failed: %v", err) + } + defer func() { + c1.Close() + os.Remove(taddr) + }() + + var la *UnixAddr + if laddr != "" { + var err error + if la, err = ResolveUnixAddr("unixgram", laddr); err != nil { + t.Fatalf("ResolveUnixAddr failed: %v", err) + } + } + c2, err := DialUnix("unixgram", la, ta) + if err != nil { + t.Fatalf("DialUnix failed: %v", err) + } + defer func() { + c2.Close() + if la != nil { + defer os.Remove(laddr) + } + }() + + if runtime.GOOS == "linux" && laddr == "" { + laddr = "@" // autobind feature + } + var connAddrs = [4]struct{ got, want Addr }{ + {c1.LocalAddr(), ta}, + {c1.RemoteAddr(), nil}, + {c2.LocalAddr(), &UnixAddr{Name: laddr, Net: "unixgram"}}, + {c2.RemoteAddr(), ta}, + } + for _, ca := range connAddrs { + if !reflect.DeepEqual(ca.got, ca.want) { + t.Fatalf("got %#v, expected %#v", ca.got, ca.want) + } + } + } +} diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go index ae09569..21a19ec 100644 --- a/libgo/go/net/unixsock.go +++ b/libgo/go/net/unixsock.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Unix domain sockets - package net // UnixAddr represents the address of a Unix domain socket end point. @@ -12,7 +10,8 @@ type UnixAddr struct { Net string } -// Network returns the address's network name, "unix" or "unixgram". +// Network returns the address's network name, "unix", "unixgram" or +// "unixpacket". func (a *UnixAddr) Network() string { return a.Net } @@ -36,11 +35,9 @@ func (a *UnixAddr) toAddr() Addr { // "unixpacket". func ResolveUnixAddr(net, addr string) (*UnixAddr, error) { switch net { - case "unix": - case "unixpacket": - case "unixgram": + case "unix", "unixgram", "unixpacket": + return &UnixAddr{Name: addr, Net: net}, nil default: return nil, UnknownNetworkError(net) } - return &UnixAddr{addr, net}, nil } diff --git a/libgo/go/net/unixsock_plan9.go b/libgo/go/net/unixsock_plan9.go index 713820c..8a1281f 100644 --- a/libgo/go/net/unixsock_plan9.go +++ b/libgo/go/net/unixsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Unix domain sockets stubs for Plan 9 - package net import ( @@ -93,8 +91,7 @@ func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn type UnixListener struct{} // ListenUnix announces on the Unix domain socket laddr and returns a -// Unix listener. The network net must be "unix", "unixgram" or -// "unixpacket". +// Unix listener. The network net must be "unix" or "unixpacket". func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { return nil, syscall.EPLAN9 } @@ -129,14 +126,18 @@ func (l *UnixListener) SetDeadline(t time.Time) error { // File returns a copy of the underlying os.File, set to blocking // mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *UnixListener) File() (*os.File, error) { return nil, syscall.EPLAN9 } // ListenUnixgram listens for incoming Unix datagram packets addressed -// to the local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send packets with -// per-packet addressing. The network net must be "unixgram". +// to the local address laddr. The network net must be "unixgram". +// The returned connection's ReadFrom and WriteTo methods can be used +// to receive and send packets with per-packet addressing. func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { return nil, syscall.EPLAN9 } diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go index 653190c..5db30df 100644 --- a/libgo/go/net/unixsock_posix.go +++ b/libgo/go/net/unixsock_posix.go @@ -4,8 +4,6 @@ // +build darwin freebsd linux netbsd openbsd windows -// Unix domain sockets - package net import ( @@ -15,6 +13,13 @@ import ( "time" ) +func (a *UnixAddr) isUnnamed() bool { + if a == nil || a.Name == "" { + return true + } + return false +} + func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.Time) (*netFD, error) { var sotype int switch net { @@ -31,12 +36,12 @@ func unixSocket(net string, laddr, raddr *UnixAddr, mode string, deadline time.T var la, ra syscall.Sockaddr switch mode { case "dial": - if laddr != nil { + if !laddr.isUnnamed() { la = &syscall.SockaddrUnix{Name: laddr.Name} } if raddr != nil { ra = &syscall.SockaddrUnix{Name: raddr.Name} - } else if sotype != syscall.SOCK_DGRAM || laddr == nil { + } else if sotype != syscall.SOCK_DGRAM || laddr.isUnnamed() { return nil, &OpError{Op: mode, Net: net, Err: errMissingAddress} } case "listen": @@ -69,21 +74,21 @@ error: func sockaddrToUnix(sa syscall.Sockaddr) Addr { if s, ok := sa.(*syscall.SockaddrUnix); ok { - return &UnixAddr{s.Name, "unix"} + return &UnixAddr{Name: s.Name, Net: "unix"} } return nil } func sockaddrToUnixgram(sa syscall.Sockaddr) Addr { if s, ok := sa.(*syscall.SockaddrUnix); ok { - return &UnixAddr{s.Name, "unixgram"} + return &UnixAddr{Name: s.Name, Net: "unixgram"} } return nil } func sockaddrToUnixpacket(sa syscall.Sockaddr) Addr { if s, ok := sa.(*syscall.SockaddrUnix); ok { - return &UnixAddr{s.Name, "unixpacket"} + return &UnixAddr{Name: s.Name, Net: "unixpacket"} } return nil } @@ -92,14 +97,13 @@ func sotypeToNet(sotype int) string { switch sotype { case syscall.SOCK_STREAM: return "unix" - case syscall.SOCK_SEQPACKET: - return "unixpacket" case syscall.SOCK_DGRAM: return "unixgram" + case syscall.SOCK_SEQPACKET: + return "unixpacket" default: panic("sotypeToNet unknown socket type") } - return "" } // UnixConn is an implementation of the Conn interface for connections @@ -124,18 +128,20 @@ func (c *UnixConn) ReadFromUnix(b []byte) (n int, addr *UnixAddr, err error) { n, sa, err := c.fd.ReadFrom(b) switch sa := sa.(type) { case *syscall.SockaddrUnix: - addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + if sa.Name != "" { + addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)} + } } return } // ReadFrom implements the PacketConn ReadFrom method. -func (c *UnixConn) ReadFrom(b []byte) (n int, addr Addr, err error) { +func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) { if !c.ok() { return 0, nil, syscall.EINVAL } - n, uaddr, err := c.ReadFromUnix(b) - return n, uaddr.toAddr(), err + n, addr, err := c.ReadFromUnix(b) + return n, addr.toAddr(), err } // ReadMsgUnix reads a packet from c, copying the payload into b and @@ -149,7 +155,9 @@ func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAdd n, oobn, flags, sa, err := c.fd.ReadMsg(b, oob) switch sa := sa.(type) { case *syscall.SockaddrUnix: - addr = &UnixAddr{sa.Name, sotypeToNet(c.fd.sotype)} + if sa.Name != "" { + addr = &UnixAddr{Name: sa.Name, Net: sotypeToNet(c.fd.sotype)} + } } return } @@ -247,11 +255,10 @@ type UnixListener struct { } // ListenUnix announces on the Unix domain socket laddr and returns a -// Unix listener. The network net must be "unix", "unixgram" or -// "unixpacket". +// Unix listener. The network net must be "unix" or "unixpacket". func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { switch net { - case "unix", "unixgram", "unixpacket": + case "unix", "unixpacket": default: return nil, UnknownNetworkError(net) } @@ -264,7 +271,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { } err = syscall.Listen(fd.sysfd, listenerBacklog) if err != nil { - closesocket(fd.sysfd) + fd.Close() return nil, &OpError{Op: "listen", Net: net, Addr: laddr, Err: err} } return &UnixListener{fd, laddr.Name}, nil @@ -332,12 +339,16 @@ func (l *UnixListener) SetDeadline(t time.Time) (err error) { // File returns a copy of the underlying os.File, set to blocking // mode. It is the caller's responsibility to close f when finished. // Closing l does not affect f, and closing f does not affect l. +// +// The returned os.File's file descriptor is different from the +// connection's. Attempting to change properties of the original +// using this duplicate may or may not have the desired effect. func (l *UnixListener) File() (f *os.File, err error) { return l.fd.dup() } // ListenUnixgram listens for incoming Unix datagram packets addressed -// to the local address laddr. The returned connection c's ReadFrom -// and WriteTo methods can be used to receive and send packets with -// per-packet addressing. The network net must be "unixgram". +// to the local address laddr. The network net must be "unixgram". +// The returned connection's ReadFrom and WriteTo methods can be used +// to receive and send packets with per-packet addressing. func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { switch net { case "unixgram": diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 68f2c2f..459dc47 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -220,6 +220,13 @@ func escape(s string, mode encoding) string { // // scheme:opaque[?query][#fragment] // +// Note that the Path field is stored in decoded form: /%47%6f%2f becomes /Go/. +// A consequence is that it is impossible to tell which slashes in the Path were +// slashes in the raw URL and which were %2f. This distinction is rarely important, +// but when it is a client must use other routines to parse the raw URL or construct +// the parsed URL. For example, an HTTP server can consult req.RequestURI, and +// an HTTP client can use URL{Host: "example.com", Opaque: "//example.com/Go%2f"} +// instead of URL{Host: "example.com", Path: "/Go/"}. type URL struct { Scheme string Opaque string // encoded opaque data @@ -310,23 +317,22 @@ func getscheme(rawurl string) (scheme, path string, err error) { // Maybe s is of the form t c u. // If so, return t, c u (or t, u if cutc == true). // If not, return s, "". -func split(s string, c byte, cutc bool) (string, string) { - for i := 0; i < len(s); i++ { - if s[i] == c { - if cutc { - return s[0:i], s[i+1:] - } - return s[0:i], s[i:] - } +func split(s string, c string, cutc bool) (string, string) { + i := strings.Index(s, c) + if i < 0 { + return s, "" + } + if cutc { + return s[0:i], s[i+len(c):] } - return s, "" + return s[0:i], s[i:] } // Parse parses rawurl into a URL structure. // The rawurl may be relative or absolute. func Parse(rawurl string) (url *URL, err error) { // Cut off #frag - u, frag := split(rawurl, '#', true) + u, frag := split(rawurl, "#", true) if url, err = parse(u, false); err != nil { return nil, err } @@ -355,7 +361,7 @@ func ParseRequestURI(rawurl string) (url *URL, err error) { func parse(rawurl string, viaRequest bool) (url *URL, err error) { var rest string - if rawurl == "" { + if rawurl == "" && viaRequest { err = errors.New("empty url") goto Error } @@ -371,8 +377,9 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { if url.Scheme, rest, err = getscheme(rawurl); err != nil { goto Error } + url.Scheme = strings.ToLower(url.Scheme) - rest, url.RawQuery = split(rest, '?', true) + rest, url.RawQuery = split(rest, "?", true) if !strings.HasPrefix(rest, "/") { if url.Scheme != "" { @@ -388,7 +395,7 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { if (url.Scheme != "" || !viaRequest && !strings.HasPrefix(rest, "///")) && strings.HasPrefix(rest, "//") { var authority string - authority, rest = split(rest[2:], '/', false) + authority, rest = split(rest[2:], "/", false) url.User, url.Host, err = parseAuthority(authority) if err != nil { goto Error @@ -420,7 +427,7 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) { } user = User(userinfo) } else { - username, password := split(userinfo, ':', true) + username, password := split(userinfo, ":", true) if username, err = unescape(username, encodeUserPassword); err != nil { return } @@ -575,43 +582,39 @@ func (v Values) Encode() string { } // resolvePath applies special path segments from refs and applies -// them to base, per RFC 2396. -func resolvePath(basepath string, refpath string) string { - base := strings.Split(basepath, "/") - refs := strings.Split(refpath, "/") - if len(base) == 0 { - base = []string{""} +// them to base, per RFC 3986. +func resolvePath(base, ref string) string { + var full string + if ref == "" { + full = base + } else if ref[0] != '/' { + i := strings.LastIndex(base, "/") + full = base[:i+1] + ref + } else { + full = ref } - - rm := true - for idx, ref := range refs { - switch { - case ref == ".": - if idx == 0 { - base[len(base)-1] = "" - rm = true - } else { - rm = false - } - case ref == "..": - newLen := len(base) - 1 - if newLen < 1 { - newLen = 1 - } - base = base[0:newLen] - if rm { - base[len(base)-1] = "" + if full == "" { + return "" + } + var dst []string + src := strings.Split(full, "/") + for _, elem := range src { + switch elem { + case ".": + // drop + case "..": + if len(dst) > 0 { + dst = dst[:len(dst)-1] } default: - if idx == 0 || base[len(base)-1] == "" { - base[len(base)-1] = ref - } else { - base = append(base, ref) - } - rm = false + dst = append(dst, elem) } } - return strings.Join(base, "/") + if last := src[len(src)-1]; last == "." || last == ".." { + // Add final slash to the joined path. + dst = append(dst, "") + } + return "/" + strings.TrimLeft(strings.Join(dst, "/"), "/") } // IsAbs returns true if the URL is absolute. @@ -631,43 +634,39 @@ func (u *URL) Parse(ref string) (*URL, error) { } // ResolveReference resolves a URI reference to an absolute URI from -// an absolute base URI, per RFC 2396 Section 5.2. The URI reference +// an absolute base URI, per RFC 3986 Section 5.2. The URI reference // may be relative or absolute. ResolveReference always returns a new // URL instance, even if the returned URL is identical to either the // base or reference. If ref is an absolute URL, then ResolveReference // ignores base and returns a copy of ref. func (u *URL) ResolveReference(ref *URL) *URL { - if ref.IsAbs() { - url := *ref + url := *ref + if ref.Scheme == "" { + url.Scheme = u.Scheme + } + if ref.Scheme != "" || ref.Host != "" || ref.User != nil { + // The "absoluteURI" or "net_path" cases. + url.Path = resolvePath(ref.Path, "") return &url } - // relativeURI = ( net_path | abs_path | rel_path ) [ "?" query ] - url := *u - url.RawQuery = ref.RawQuery - url.Fragment = ref.Fragment if ref.Opaque != "" { - url.Opaque = ref.Opaque url.User = nil url.Host = "" url.Path = "" return &url } - if ref.Host != "" || ref.User != nil { - // The "net_path" case. - url.Host = ref.Host - url.User = ref.User - } - if strings.HasPrefix(ref.Path, "/") { - // The "abs_path" case. - url.Path = ref.Path - } else { - // The "rel_path" case. - path := resolvePath(u.Path, ref.Path) - if !strings.HasPrefix(path, "/") { - path = "/" + path + if ref.Path == "" { + if ref.RawQuery == "" { + url.RawQuery = u.RawQuery + if ref.Fragment == "" { + url.Fragment = u.Fragment + } } - url.Path = path } + // The "abs_path" or "rel_path" cases. + url.Host = u.Host + url.User = u.User + url.Path = resolvePath(u.Path, ref.Path) return &url } @@ -686,6 +685,10 @@ func (u *URL) RequestURI() string { if result == "" { result = "/" } + } else { + if strings.HasPrefix(result, "//") { + result = u.Scheme + ":" + result + } } if u.RawQuery != "" { result += "?" + u.RawQuery diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index cd3b0b9..9d81289 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -251,6 +251,15 @@ var urltests = []URLTest{ }, "file:///home/adg/rabbits", }, + // case-insensitive scheme + { + "MaIlTo:webmaster@golang.org", + &URL{ + Scheme: "mailto", + Opaque: "webmaster@golang.org", + }, + "mailto:webmaster@golang.org", + }, } // more useful string for debugging than fmt's struct printer @@ -514,18 +523,18 @@ func TestEncodeQuery(t *testing.T) { var resolvePathTests = []struct { base, ref, expected string }{ - {"a/b", ".", "a/"}, - {"a/b", "c", "a/c"}, - {"a/b", "..", ""}, - {"a/", "..", ""}, - {"a/", "../..", ""}, - {"a/b/c", "..", "a/"}, - {"a/b/c", "../d", "a/d"}, - {"a/b/c", ".././d", "a/d"}, - {"a/b", "./..", ""}, - {"a/./b", ".", "a/./"}, - {"a/../", ".", "a/../"}, - {"a/.././b", "c", "a/.././c"}, + {"a/b", ".", "/a/"}, + {"a/b", "c", "/a/c"}, + {"a/b", "..", "/"}, + {"a/", "..", "/"}, + {"a/", "../..", "/"}, + {"a/b/c", "..", "/a/"}, + {"a/b/c", "../d", "/a/d"}, + {"a/b/c", ".././d", "/a/d"}, + {"a/b", "./..", "/"}, + {"a/./b", ".", "/a/"}, + {"a/../", ".", "/"}, + {"a/.././b", "c", "/c"}, } func TestResolvePath(t *testing.T) { @@ -578,16 +587,71 @@ var resolveReferenceTests = []struct { {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/././../../tail", "http://foo.com/bar/quux/tail"}, {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/./.././../tail", "http://foo.com/bar/quux/tail"}, {"http://foo.com/bar/baz", "quux/./dotdot/dotdot/dotdot/./../../.././././tail", "http://foo.com/bar/quux/tail"}, - {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot"}, + {"http://foo.com/bar/baz", "quux/./dotdot/../dotdot/../dot/./tail/..", "http://foo.com/bar/quux/dot/"}, - // "." and ".." in the base aren't special - {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/./dotdot/../baz"}, + // Remove any dot-segments prior to forming the target URI. + // http://tools.ietf.org/html/rfc3986#section-5.2.4 + {"http://foo.com/dot/./dotdot/../foo/bar", "../baz", "http://foo.com/dot/baz"}, // Triple dot isn't special {"http://foo.com/bar", "...", "http://foo.com/..."}, // Fragment {"http://foo.com/bar", ".#frag", "http://foo.com/#frag"}, + + // RFC 3986: Normal Examples + // http://tools.ietf.org/html/rfc3986#section-5.4.1 + {"http://a/b/c/d;p?q", "g:h", "g:h"}, + {"http://a/b/c/d;p?q", "g", "http://a/b/c/g"}, + {"http://a/b/c/d;p?q", "./g", "http://a/b/c/g"}, + {"http://a/b/c/d;p?q", "g/", "http://a/b/c/g/"}, + {"http://a/b/c/d;p?q", "/g", "http://a/g"}, + {"http://a/b/c/d;p?q", "//g", "http://g"}, + {"http://a/b/c/d;p?q", "?y", "http://a/b/c/d;p?y"}, + {"http://a/b/c/d;p?q", "g?y", "http://a/b/c/g?y"}, + {"http://a/b/c/d;p?q", "#s", "http://a/b/c/d;p?q#s"}, + {"http://a/b/c/d;p?q", "g#s", "http://a/b/c/g#s"}, + {"http://a/b/c/d;p?q", "g?y#s", "http://a/b/c/g?y#s"}, + {"http://a/b/c/d;p?q", ";x", "http://a/b/c/;x"}, + {"http://a/b/c/d;p?q", "g;x", "http://a/b/c/g;x"}, + {"http://a/b/c/d;p?q", "g;x?y#s", "http://a/b/c/g;x?y#s"}, + {"http://a/b/c/d;p?q", "", "http://a/b/c/d;p?q"}, + {"http://a/b/c/d;p?q", ".", "http://a/b/c/"}, + {"http://a/b/c/d;p?q", "./", "http://a/b/c/"}, + {"http://a/b/c/d;p?q", "..", "http://a/b/"}, + {"http://a/b/c/d;p?q", "../", "http://a/b/"}, + {"http://a/b/c/d;p?q", "../g", "http://a/b/g"}, + {"http://a/b/c/d;p?q", "../..", "http://a/"}, + {"http://a/b/c/d;p?q", "../../", "http://a/"}, + {"http://a/b/c/d;p?q", "../../g", "http://a/g"}, + + // RFC 3986: Abnormal Examples + // http://tools.ietf.org/html/rfc3986#section-5.4.2 + {"http://a/b/c/d;p?q", "../../../g", "http://a/g"}, + {"http://a/b/c/d;p?q", "../../../../g", "http://a/g"}, + {"http://a/b/c/d;p?q", "/./g", "http://a/g"}, + {"http://a/b/c/d;p?q", "/../g", "http://a/g"}, + {"http://a/b/c/d;p?q", "g.", "http://a/b/c/g."}, + {"http://a/b/c/d;p?q", ".g", "http://a/b/c/.g"}, + {"http://a/b/c/d;p?q", "g..", "http://a/b/c/g.."}, + {"http://a/b/c/d;p?q", "..g", "http://a/b/c/..g"}, + {"http://a/b/c/d;p?q", "./../g", "http://a/b/g"}, + {"http://a/b/c/d;p?q", "./g/.", "http://a/b/c/g/"}, + {"http://a/b/c/d;p?q", "g/./h", "http://a/b/c/g/h"}, + {"http://a/b/c/d;p?q", "g/../h", "http://a/b/c/h"}, + {"http://a/b/c/d;p?q", "g;x=1/./y", "http://a/b/c/g;x=1/y"}, + {"http://a/b/c/d;p?q", "g;x=1/../y", "http://a/b/c/y"}, + {"http://a/b/c/d;p?q", "g?y/./x", "http://a/b/c/g?y/./x"}, + {"http://a/b/c/d;p?q", "g?y/../x", "http://a/b/c/g?y/../x"}, + {"http://a/b/c/d;p?q", "g#s/./x", "http://a/b/c/g#s/./x"}, + {"http://a/b/c/d;p?q", "g#s/../x", "http://a/b/c/g#s/../x"}, + + // Extras. + {"https://a/b/c/d;p?q", "//g?q", "https://g?q"}, + {"https://a/b/c/d;p?q", "//g#s", "https://g#s"}, + {"https://a/b/c/d;p?q", "//g/d/e/f?y#s", "https://g/d/e/f?y#s"}, + {"https://a/b/c/d;p#s", "?y", "https://a/b/c/d;p?y"}, + {"https://a/b/c/d;p?q#s", "?y", "https://a/b/c/d;p?y"}, } func TestResolveReference(t *testing.T) { @@ -598,91 +662,44 @@ func TestResolveReference(t *testing.T) { } return u } + opaque := &URL{Scheme: "scheme", Opaque: "opaque"} for _, test := range resolveReferenceTests { base := mustParse(test.base) rel := mustParse(test.rel) url := base.ResolveReference(rel) - urlStr := url.String() - if urlStr != test.expected { - t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr) + if url.String() != test.expected { + t.Errorf("URL(%q).ResolveReference(%q) == %q, got %q", test.base, test.rel, test.expected, url.String()) } - } - - // Test that new instances are returned. - base := mustParse("http://foo.com/") - abs := base.ResolveReference(mustParse(".")) - if base == abs { - t.Errorf("Expected no-op reference to return new URL instance.") - } - barRef := mustParse("http://bar.com/") - abs = base.ResolveReference(barRef) - if abs == barRef { - t.Errorf("Expected resolution of absolute reference to return new URL instance.") - } - - // Test the convenience wrapper too - base = mustParse("http://foo.com/path/one/") - abs, _ = base.Parse("../two") - expected := "http://foo.com/path/two" - if abs.String() != expected { - t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected) - } - _, err := base.Parse("") - if err == nil { - t.Errorf("Expected an error from Parse wrapper parsing an empty string.") - } - - // Ensure Opaque resets the URL. - base = mustParse("scheme://user@foo.com/bar") - abs = base.ResolveReference(&URL{Opaque: "opaque"}) - want := mustParse("scheme:opaque") - if *abs != *want { - t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", abs, want) - } -} - -func TestResolveReferenceOpaque(t *testing.T) { - mustParse := func(url string) *URL { - u, err := Parse(url) + // Ensure that new instances are returned. + if base == url { + t.Errorf("Expected URL.ResolveReference to return new URL instance.") + } + // Test the convenience wrapper too. + url, err := base.Parse(test.rel) if err != nil { - t.Fatalf("Expected URL to parse: %q, got error: %v", url, err) + t.Errorf("URL(%q).Parse(%q) failed: %v", test.base, test.rel, err) + } else if url.String() != test.expected { + t.Errorf("URL(%q).Parse(%q) == %q, got %q", test.base, test.rel, test.expected, url.String()) + } else if base == url { + // Ensure that new instances are returned for the wrapper too. + t.Errorf("Expected URL.Parse to return new URL instance.") } - return u - } - for _, test := range resolveReferenceTests { - base := mustParse(test.base) - rel := mustParse(test.rel) - url := base.ResolveReference(rel) - urlStr := url.String() - if urlStr != test.expected { - t.Errorf("Resolving %q + %q != %q; got %q", test.base, test.rel, test.expected, urlStr) + // Ensure Opaque resets the URL. + url = base.ResolveReference(opaque) + if *url != *opaque { + t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", url, opaque) + } + // Test the convenience wrapper with an opaque URL too. + url, err = base.Parse("scheme:opaque") + if err != nil { + t.Errorf(`URL(%q).Parse("scheme:opaque") failed: %v`, test.base, err) + } else if *url != *opaque { + t.Errorf("Parse failed to resolve opaque URL: want %#v, got %#v", url, opaque) + } else if base == url { + // Ensure that new instances are returned, again. + t.Errorf("Expected URL.Parse to return new URL instance.") } } - - // Test that new instances are returned. - base := mustParse("http://foo.com/") - abs := base.ResolveReference(mustParse(".")) - if base == abs { - t.Errorf("Expected no-op reference to return new URL instance.") - } - barRef := mustParse("http://bar.com/") - abs = base.ResolveReference(barRef) - if abs == barRef { - t.Errorf("Expected resolution of absolute reference to return new URL instance.") - } - - // Test the convenience wrapper too - base = mustParse("http://foo.com/path/one/") - abs, _ = base.Parse("../two") - expected := "http://foo.com/path/two" - if abs.String() != expected { - t.Errorf("Parse wrapper got %q; expected %q", abs.String(), expected) - } - _, err := base.Parse("") - if err == nil { - t.Errorf("Expected an error from Parse wrapper parsing an empty string.") - } - } func TestQueryValues(t *testing.T) { @@ -789,6 +806,24 @@ var requritests = []RequestURITest{ }, "/a%20b", }, + // golang.org/issue/4860 variant 1 + { + &URL{ + Scheme: "http", + Host: "example.com", + Opaque: "/%2F/%2F/", + }, + "/%2F/%2F/", + }, + // golang.org/issue/4860 variant 2 + { + &URL{ + Scheme: "http", + Host: "example.com", + Opaque: "//other.example.com/%2F/%2F/", + }, + "http://other.example.com/%2F/%2F/", + }, { &URL{ Scheme: "http", |