diff options
Diffstat (limited to 'libgo/go/net/dial.go')
-rw-r--r-- | libgo/go/net/dial.go | 370 |
1 files changed, 250 insertions, 120 deletions
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index 193776f..55edb43 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -1,11 +1,12 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package net import ( - "errors" + "context" + "internal/nettrace" "time" ) @@ -61,21 +62,34 @@ type Dialer struct { // Cancel is an optional channel whose closure indicates that // the dial should be canceled. Not all types of dials support // cancelation. + // + // Deprecated: Use DialContext instead. Cancel <-chan struct{} } -// Return either now+Timeout or Deadline, whichever comes first. -// Or zero, if neither is set. -func (d *Dialer) deadline(now time.Time) time.Time { - if d.Timeout == 0 { - return d.Deadline +func minNonzeroTime(a, b time.Time) time.Time { + if a.IsZero() { + return b } - timeoutDeadline := now.Add(d.Timeout) - if d.Deadline.IsZero() || timeoutDeadline.Before(d.Deadline) { - return timeoutDeadline - } else { - return d.Deadline + if b.IsZero() || a.Before(b) { + return a + } + return b +} + +// deadline returns the earliest of: +// - now+Timeout +// - d.Deadline +// - the context's deadline +// Or zero, if none of Timeout, Deadline, or context's deadline is set. +func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) { + if d.Timeout != 0 { // including negative, for historical reasons + earliest = now.Add(d.Timeout) + } + if d, ok := ctx.Deadline(); ok { + earliest = minNonzeroTime(earliest, d) } + return minNonzeroTime(earliest, d.Deadline) } // partialDeadline returns the deadline to use for a single address, @@ -110,7 +124,7 @@ func (d *Dialer) fallbackDelay() time.Duration { } } -func parseNetwork(net string) (afnet string, proto int, err error) { +func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) { i := last(net, ':') if i < 0 { // no colon switch net { @@ -129,7 +143,7 @@ func parseNetwork(net string) (afnet string, proto int, err error) { protostr := net[i+1:] proto, i, ok := dtoi(protostr, 0) if !ok || i != len(protostr) { - proto, err = lookupProtocol(protostr) + proto, err = lookupProtocol(ctx, protostr) if err != nil { return "", 0, err } @@ -139,8 +153,11 @@ func parseNetwork(net string) (afnet string, proto int, err error) { return "", 0, UnknownNetworkError(net) } -func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) { - afnet, _, err := parseNetwork(net) +// resolverAddrList resolves addr using hint and returns a list of +// addresses. The result contains at least one address when error is +// nil. +func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { + afnet, _, err := parseNetwork(ctx, network) if err != nil { return nil, err } @@ -149,13 +166,64 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) } switch afnet { case "unix", "unixgram", "unixpacket": + // TODO(bradfitz): push down context addr, err := ResolveUnixAddr(afnet, addr) if err != nil { return nil, err } + if op == "dial" && hint != nil && addr.Network() != hint.Network() { + return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()} + } return addrList{addr}, nil } - return internetAddrList(afnet, addr, deadline) + addrs, err := internetAddrList(ctx, afnet, addr) + if err != nil || op != "dial" || hint == nil { + return addrs, err + } + var ( + tcp *TCPAddr + udp *UDPAddr + ip *IPAddr + wildcard bool + ) + switch hint := hint.(type) { + case *TCPAddr: + tcp = hint + wildcard = tcp.isWildcard() + case *UDPAddr: + udp = hint + wildcard = udp.isWildcard() + case *IPAddr: + ip = hint + wildcard = ip.isWildcard() + } + naddrs := addrs[:0] + for _, addr := range addrs { + if addr.Network() != hint.Network() { + return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()} + } + switch addr := addr.(type) { + case *TCPAddr: + if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) { + continue + } + naddrs = append(naddrs, addr) + case *UDPAddr: + if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) { + continue + } + naddrs = append(naddrs, addr) + case *IPAddr: + if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) { + continue + } + naddrs = append(naddrs, addr) + } + } + if len(naddrs) == 0 { + return nil, errNoSuitableAddress + } + return naddrs, nil } // Dial connects to the address on the named network. @@ -173,8 +241,8 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) // If the host is empty, as in ":80", the local system is assumed. // // Examples: -// Dial("tcp", "12.34.56.78:80") -// Dial("tcp", "google.com:http") +// Dial("tcp", "192.0.2.1:80") +// Dial("tcp", "golang.org:http") // Dial("tcp", "[2001:db8::1]:http") // Dial("tcp", "[fe80::1%lo0]:80") // Dial("tcp", ":80") @@ -184,8 +252,8 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error) // literal IP address. // // Examples: -// Dial("ip4:1", "127.0.0.1") -// Dial("ip6:ospf", "::1") +// Dial("ip4:1", "192.0.2.1") +// Dial("ip6:ipv6-icmp", "2001:db8::1") // // For Unix networks, the address must be a file system path. func Dial(network, address string) (Conn, error) { @@ -200,11 +268,10 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { return d.Dial(network, address) } -// dialContext holds common state for all dial operations. -type dialContext struct { +// dialParam contains a Dial's parameters and configuration. +type dialParam struct { Dialer network, address string - finalDeadline time.Time } // Dial connects to the address on the named network. @@ -212,17 +279,62 @@ type dialContext struct { // See func Dial for a description of the network and address // parameters. func (d *Dialer) Dial(network, address string) (Conn, error) { - finalDeadline := d.deadline(time.Now()) - addrs, err := resolveAddrList("dial", network, address, finalDeadline) + return d.DialContext(context.Background(), network, address) +} + +// DialContext connects to the address on the named network using +// the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// See func Dial for a description of the network and address +// parameters. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { + if ctx == nil { + panic("nil context") + } + deadline := d.deadline(ctx, time.Now()) + if !deadline.IsZero() { + if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { + subCtx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + ctx = subCtx + } + } + if oldCancel := d.Cancel; oldCancel != nil { + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-oldCancel: + cancel() + case <-subCtx.Done(): + } + }() + ctx = subCtx + } + + // Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups. + resolveCtx := ctx + if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil { + shadow := *trace + shadow.ConnectStart = nil + shadow.ConnectDone = nil + resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow) + } + + addrs, err := resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} } - ctx := &dialContext{ - Dialer: *d, - network: network, - address: address, - finalDeadline: finalDeadline, + dp := &dialParam{ + Dialer: *d, + network: network, + address: address, } var primaries, fallbacks addrList @@ -233,116 +345,128 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { } var c Conn - if len(fallbacks) == 0 { - // dialParallel can accept an empty fallbacks list, - // but this shortcut avoids the goroutine/channel overhead. - c, err = dialSerial(ctx, primaries, nil) + if len(fallbacks) > 0 { + c, err = dialParallel(ctx, dp, primaries, fallbacks) } else { - c, err = dialParallel(ctx, primaries, fallbacks) + c, err = dialSerial(ctx, dp, primaries) + } + if err != nil { + return nil, err } - if d.KeepAlive > 0 && err == nil { - if tc, ok := c.(*TCPConn); ok { - setKeepAlive(tc.fd, true) - setKeepAlivePeriod(tc.fd, d.KeepAlive) - testHookSetKeepAlive() - } + if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 { + setKeepAlive(tc.fd, true) + setKeepAlivePeriod(tc.fd, d.KeepAlive) + testHookSetKeepAlive() } - return c, err + return c, nil } // dialParallel races two copies of dialSerial, giving the first a // head start. It returns the first established connection and // closes the others. Otherwise it returns an error from the first // primary address. -func dialParallel(ctx *dialContext, primaries, fallbacks addrList) (Conn, error) { - results := make(chan dialResult) // unbuffered, so dialSerialAsync can detect race loss & cleanup - cancel := make(chan struct{}) - defer close(cancel) - - // Spawn the primary racer. - go dialSerialAsync(ctx, primaries, nil, cancel, results) - - // Spawn the fallback racer. - fallbackTimer := time.NewTimer(ctx.fallbackDelay()) - go dialSerialAsync(ctx, fallbacks, fallbackTimer, cancel, results) - - var primaryErr error - for nracers := 2; nracers > 0; nracers-- { - res := <-results - // If we're still waiting for a connection, then hasten the delay. - // Otherwise, disable the Timer and let cancel take over. - if fallbackTimer.Stop() && res.error != nil { - fallbackTimer.Reset(0) - } - if res.error == nil { - return res.Conn, nil - } - if res.primary { - primaryErr = res.error - } +func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) { + if len(fallbacks) == 0 { + return dialSerial(ctx, dp, primaries) } - return nil, primaryErr -} -type dialResult struct { - Conn - error - primary bool -} + returned := make(chan struct{}) + defer close(returned) + + type dialResult struct { + Conn + error + primary bool + done bool + } + results := make(chan dialResult) // unbuffered -// dialSerialAsync runs dialSerial after some delay, and returns the -// resulting connection through a channel. When racing two connections, -// the primary goroutine uses a nil timer to omit the delay. -func dialSerialAsync(ctx *dialContext, ras addrList, timer *time.Timer, cancel <-chan struct{}, results chan<- dialResult) { - if timer != nil { - // We're in the fallback goroutine; sleep before connecting. + startRacer := func(ctx context.Context, primary bool) { + ras := primaries + if !primary { + ras = fallbacks + } + c, err := dialSerial(ctx, dp, ras) select { - case <-timer.C: - case <-cancel: - return + case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: + case <-returned: + if c != nil { + c.Close() + } } } - c, err := dialSerial(ctx, ras, cancel) - select { - case results <- dialResult{c, err, timer == nil}: - // We won the race. - case <-cancel: - // The other goroutine won the race. - if c != nil { - c.Close() + + var primary, fallback dialResult + + // Start the main racer. + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + go startRacer(primaryCtx, true) + + // Start the timer for the fallback racer. + fallbackTimer := time.NewTimer(dp.fallbackDelay()) + defer fallbackTimer.Stop() + + for { + select { + case <-fallbackTimer.C: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + go startRacer(fallbackCtx, false) + + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + if res.primary { + primary = res + } else { + fallback = res + } + if primary.done && fallback.done { + return nil, primary.error + } + if res.primary && fallbackTimer.Stop() { + // If we were able to stop the timer, that means it + // was running (hadn't yet started the fallback), but + // we just got an error on the primary path, so start + // the fallback immediately (in 0 nanoseconds). + fallbackTimer.Reset(0) + } } } } // dialSerial connects to a list of addresses in sequence, returning // either the first successful connection, or the first error. -func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, error) { +func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) { var firstErr error // The error from the first address is most relevant. for i, ra := range ras { select { - case <-cancel: - return nil, &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: errCanceled} + case <-ctx.Done(): + return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())} default: } - partialDeadline, err := partialDeadline(time.Now(), ctx.finalDeadline, len(ras)-i) + deadline, _ := ctx.Deadline() + partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i) if err != nil { // Ran out of time. if firstErr == nil { - firstErr = &OpError{Op: "dial", Net: ctx.network, Source: ctx.LocalAddr, Addr: ra, Err: err} + firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err} } break } - - // dialTCP does not support cancelation (see golang.org/issue/11225), - // so if cancel fires, we'll continue trying to connect until the next - // timeout, or return a spurious connection for the caller to close. - dialer := func(d time.Time) (Conn, error) { - return dialSingle(ctx, ra, d) + dialCtx := ctx + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() } - c, err := dial(ctx.network, ra, dialer, partialDeadline) + + c, err := dialSingle(dialCtx, dp, ra) if err == nil { return c, nil } @@ -352,37 +476,43 @@ func dialSerial(ctx *dialContext, ras addrList, cancel <-chan struct{}) (Conn, e } if firstErr == nil { - firstErr = &OpError{Op: "dial", Net: ctx.network, Source: nil, Addr: nil, Err: errMissingAddress} + firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress} } return nil, firstErr } // dialSingle attempts to establish and returns a single connection to -// the destination address. This must be called through the OS-specific -// dial function, because some OSes don't implement the deadline feature. -func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err error) { - la := ctx.LocalAddr - if la != nil && la.Network() != ra.Network() { - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: errors.New("mismatched local address type " + la.Network())} +// the destination address. +func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) { + trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace) + if trace != nil { + raStr := ra.String() + if trace.ConnectStart != nil { + trace.ConnectStart(dp.network, raStr) + } + if trace.ConnectDone != nil { + defer func() { trace.ConnectDone(dp.network, raStr, err) }() + } } + la := dp.LocalAddr switch ra := ra.(type) { case *TCPAddr: la, _ := la.(*TCPAddr) - c, err = testHookDialTCP(ctx.network, la, ra, deadline, ctx.Cancel) + c, err = dialTCP(ctx, dp.network, la, ra) case *UDPAddr: la, _ := la.(*UDPAddr) - c, err = dialUDP(ctx.network, la, ra, deadline) + c, err = dialUDP(ctx, dp.network, la, ra) case *IPAddr: la, _ := la.(*IPAddr) - c, err = dialIP(ctx.network, la, ra, deadline) + c, err = dialIP(ctx, dp.network, la, ra) case *UnixAddr: la, _ := la.(*UnixAddr) - c, err = dialUnix(ctx.network, la, ra, deadline) + c, err = dialUnix(ctx, dp.network, la, ra) default: - return nil, &OpError{Op: "dial", Net: ctx.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: ctx.address}} + return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}} } if err != nil { - return nil, err // c is non-nil interface containing nil pointer + return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer } return c, nil } @@ -395,7 +525,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro // instead of just the interface with the given host address. // See Dial for more details about address syntax. func Listen(net, laddr string) (Listener, error) { - addrs, err := resolveAddrList("listen", net, laddr, noDeadline) + addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } @@ -422,7 +552,7 @@ func Listen(net, laddr string) (Listener, error) { // instead of just the interface with the given host address. // See Dial for the syntax of laddr. func ListenPacket(net, laddr string) (PacketConn, error) { - addrs, err := resolveAddrList("listen", net, laddr, noDeadline) + addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } |