diff options
author | Ian Lance Taylor <iant@golang.org> | 2017-09-14 17:11:35 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2017-09-14 17:11:35 +0000 |
commit | bc998d034f45d1828a8663b2eed928faf22a7d01 (patch) | |
tree | 8d262a22ca7318f4bcd64269fe8fe9e45bcf8d0f /libgo/go/net | |
parent | a41a6142df74219f596e612d3a7775f68ca6e96f (diff) | |
download | gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.zip gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.tar.gz gcc-bc998d034f45d1828a8663b2eed928faf22a7d01.tar.bz2 |
libgo: update to go1.9
Reviewed-on: https://go-review.googlesource.com/63753
From-SVN: r252767
Diffstat (limited to 'libgo/go/net')
137 files changed, 6491 insertions, 4716 deletions
diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go index 09cfb2a..0de3ff8 100644 --- a/libgo/go/net/cgo_unix.go +++ b/libgo/go/net/cgo_unix.go @@ -220,7 +220,7 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) { addrs = append(addrs, addr) case syscall.AF_INET6: sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.Ai_addr)) - addr := IPAddr{IP: copyIP(sa.Addr[:]), Zone: zoneToString(int(sa.Scope_id))} + addr := IPAddr{IP: copyIP(sa.Addr[:]), Zone: zoneCache.name(int(sa.Scope_id))} addrs = append(addrs, addr) } } @@ -345,7 +345,7 @@ func cgoSockaddr(ip IP, zone string) (*syscall.RawSockaddr, syscall.Socklen_t) { return cgoSockaddrInet4(ip4), syscall.Socklen_t(syscall.SizeofSockaddrInet4) } if ip6 := ip.To16(); ip6 != nil { - return cgoSockaddrInet6(ip6, zoneToInt(zone)), syscall.Socklen_t(syscall.SizeofSockaddrInet6) + return cgoSockaddrInet6(ip6, zoneCache.index(zone)), syscall.Socklen_t(syscall.SizeofSockaddrInet6) } return nil, 0 } diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index 50bba5a..f8b4aa2 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -7,6 +7,7 @@ package net import ( "context" "internal/nettrace" + "internal/poll" "time" ) @@ -22,8 +23,8 @@ type Dialer struct { // // The default is no timeout. // - // When dialing a name with multiple IP addresses, the timeout - // may be divided between them. + // When using TCP and dialing a host name with multiple IP + // addresses, the timeout may be divided between them. // // With or without a timeout, the operating system may impose // its own earlier timeout. For instance, TCP timeouts are @@ -42,10 +43,11 @@ type Dialer struct { // If nil, a local address is automatically chosen. LocalAddr Addr - // DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing - // when the network is "tcp" and the destination is a host name - // with both IPv4 and IPv6 addresses. This allows a client to - // tolerate networks where one address family is silently broken. + // DualStack enables RFC 6555-compliant "Happy Eyeballs" + // dialing when the network is "tcp" and the host in the + // address parameter resolves to both IPv4 and IPv6 addresses. + // This allows a client to tolerate networks where one address + // family is silently broken. DualStack bool // FallbackDelay specifies the length of time to wait before @@ -110,7 +112,7 @@ func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, er } timeRemaining := deadline.Sub(now) if timeRemaining <= 0 { - return time.Time{}, errTimeout + return time.Time{}, poll.ErrTimeout } // Tentatively allocate equal time to each remaining address. timeout := timeRemaining / time.Duration(addrsRemaining) @@ -134,23 +136,26 @@ func (d *Dialer) fallbackDelay() time.Duration { } } -func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) { - i := last(net, ':') +func parseNetwork(ctx context.Context, network string, needsProto bool) (afnet string, proto int, err error) { + i := last(network, ':') if i < 0 { // no colon - switch net { + switch network { case "tcp", "tcp4", "tcp6": case "udp", "udp4", "udp6": case "ip", "ip4", "ip6": + if needsProto { + return "", 0, UnknownNetworkError(network) + } case "unix", "unixgram", "unixpacket": default: - return "", 0, UnknownNetworkError(net) + return "", 0, UnknownNetworkError(network) } - return net, 0, nil + return network, 0, nil } - afnet = net[:i] + afnet = network[:i] switch afnet { case "ip", "ip4", "ip6": - protostr := net[i+1:] + protostr := network[i+1:] proto, i, ok := dtoi(protostr) if !ok || i != len(protostr) { proto, err = lookupProtocol(ctx, protostr) @@ -160,14 +165,14 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err } return afnet, proto, nil } - return "", 0, UnknownNetworkError(net) + return "", 0, UnknownNetworkError(network) } // resolveAddrList resolves addr using hint and returns a list of // addresses. The result contains at least one address when error is // nil. func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { - afnet, _, err := parseNetwork(ctx, network) + afnet, _, err := parseNetwork(ctx, network, true) if err != nil { return nil, err } @@ -242,39 +247,60 @@ func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string // (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and // "unixpacket". // -// For TCP and UDP networks, addresses have the form host:port. -// If host is a literal IPv6 address it must be enclosed -// in square brackets as in "[::1]:80" or "[ipv6-host%zone]:80". -// The functions JoinHostPort and SplitHostPort manipulate addresses -// in this form. -// If the host is empty, as in ":80", the local system is assumed. +// For TCP and UDP networks, the address has the form "host:port". +// The host must be a literal IP address, or a host name that can be +// resolved to IP addresses. +// The port must be a literal port number or a service name. +// If the host is a literal IPv6 address it must be enclosed in square +// brackets, as in "[2001:db8::1]:80" or "[fe80::1%zone]:80". +// The zone specifies the scope of the literal IPv6 address as defined +// in RFC 4007. +// The functions JoinHostPort and SplitHostPort manipulate a pair of +// host and port in this form. +// When using TCP, and the host resolves to multiple IP addresses, +// Dial will try each IP address in order until one succeeds. // // Examples: -// Dial("tcp", "192.0.2.1:80") // Dial("tcp", "golang.org:http") -// Dial("tcp", "[2001:db8::1]:http") -// Dial("tcp", "[fe80::1%lo0]:80") +// Dial("tcp", "192.0.2.1:http") +// Dial("tcp", "198.51.100.1:80") +// Dial("udp", "[2001:db8::1]:domain") +// Dial("udp", "[fe80::1%lo0]:53") // Dial("tcp", ":80") // // For IP networks, the network must be "ip", "ip4" or "ip6" followed -// by a colon and a protocol number or name and the addr must be a -// literal IP address. +// by a colon and a literal protocol number or a protocol name, and +// the address has the form "host". The host must be a literal IP +// address or a literal IPv6 address with zone. +// It depends on each operating system how the operating system +// behaves with a non-well known protocol number such as "0" or "255". // // Examples: // Dial("ip4:1", "192.0.2.1") // Dial("ip6:ipv6-icmp", "2001:db8::1") +// Dial("ip6:58", "fe80::1%lo0") // -// For Unix networks, the address must be a file system path. +// For TCP, UDP and IP networks, if the host is empty or a literal +// unspecified IP address, as in ":80", "0.0.0.0:80" or "[::]:80" for +// TCP and UDP, "", "0.0.0.0" or "::" for IP, the local system is +// assumed. // -// If the host is resolved to multiple addresses, -// Dial will try each address in order until one succeeds. +// For Unix networks, the address must be a file system path. func Dial(network, address string) (Conn, error) { var d Dialer return d.Dial(network, address) } // DialTimeout acts like Dial but takes a timeout. +// // The timeout includes name resolution, if required. +// When using TCP, and the host in the address parameter resolves to +// multiple IP addresses, the timeout is spread over each consecutive +// dial, such that each is given an appropriate fraction of the time +// to connect. +// +// See func Dial for a description of the network and address +// parameters. func DialTimeout(network, address string, timeout time.Duration) (Conn, error) { d := Dialer{Timeout: timeout} return d.Dial(network, address) @@ -537,29 +563,37 @@ func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) return c, nil } -// Listen announces on the local network address laddr. -// The network net must be a stream-oriented network: "tcp", "tcp4", -// "tcp6", "unix" or "unixpacket". -// For TCP and UDP, the syntax of laddr is "host:port", like "127.0.0.1:8080". -// If host is omitted, as in ":8080", Listen listens on all available interfaces -// instead of just the interface with the given host address. -// See Dial for more details about address syntax. +// Listen announces on the local network address. +// +// The network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket". // -// Listening on a hostname is not recommended because this creates a socket -// for at most one of its IP addresses. -func Listen(net, laddr string) (Listener, error) { - addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil) +// For TCP networks, if the host in the address parameter is empty or +// a literal unspecified IP address, Listen listens on all available +// unicast and anycast IP addresses of the local system. +// To only use IPv4, use network "tcp4". +// The address can use a host name, but this is not recommended, +// because it will create a listener for at most one of the host's IP +// addresses. +// If the port in the address parameter is empty or "0", as in +// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen. +// The Addr method of Listener can be used to discover the chosen +// port. +// +// See func Dial for a description of the network and address +// parameters. +func Listen(network, address string) (Listener, error) { + addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil) if err != nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err} } var l Listener switch la := addrs.first(isIPv4).(type) { case *TCPAddr: - l, err = ListenTCP(net, la) + l, err = ListenTCP(network, la) case *UnixAddr: - l, err = ListenUnix(net, la) + l, err = ListenUnix(network, la) default: - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: laddr}} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}} } if err != nil { return nil, err // l is non-nil interface containing nil pointer @@ -567,31 +601,43 @@ func Listen(net, laddr string) (Listener, error) { return l, nil } -// ListenPacket announces on the local network address laddr. -// The network net must be a packet-oriented network: "udp", "udp4", -// "udp6", "ip", "ip4", "ip6" or "unixgram". -// For TCP and UDP, the syntax of laddr is "host:port", like "127.0.0.1:8080". -// If host is omitted, as in ":8080", ListenPacket listens on all available interfaces -// instead of just the interface with the given host address. -// See Dial for the syntax of laddr. +// ListenPacket announces on the local network address. +// +// The network must be "udp", "udp4", "udp6", "unixgram", or an IP +// transport. The IP transports are "ip", "ip4", or "ip6" followed by +// a colon and a literal protocol number or a protocol name, as in +// "ip:1" or "ip:icmp". // -// Listening on a hostname is not recommended because this creates a socket -// for at most one of its IP addresses. -func ListenPacket(net, laddr string) (PacketConn, error) { - addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil) +// For UDP and IP networks, if the host in the address parameter is +// empty or a literal unspecified IP address, ListenPacket listens on +// all available IP addresses of the local system except multicast IP +// addresses. +// To only use IPv4, use network "udp4" or "ip4:proto". +// The address can use a host name, but this is not recommended, +// because it will create a listener for at most one of the host's IP +// addresses. +// If the port in the address parameter is empty or "0", as in +// "127.0.0.1:" or "[::1]:0", a port number is automatically chosen. +// The LocalAddr method of PacketConn can be used to discover the +// chosen port. +// +// See func Dial for a description of the network and address +// parameters. +func ListenPacket(network, address string) (PacketConn, error) { + addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil) if err != nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err} } var l PacketConn switch la := addrs.first(isIPv4).(type) { case *UDPAddr: - l, err = ListenUDP(net, la) + l, err = ListenUDP(network, la) case *IPAddr: - l, err = ListenIP(net, la) + l, err = ListenIP(network, la) case *UnixAddr: - l, err = ListenUnixgram(net, la) + l, err = ListenUnixgram(network, la) default: - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: laddr}} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}} } if err != nil { return nil, err // l is non-nil interface containing nil pointer diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index 9919d72..a892bf1 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -7,9 +7,9 @@ package net import ( "bufio" "context" + "internal/poll" "internal/testenv" "io" - "net/internal/socktest" "runtime" "sync" "testing" @@ -31,7 +31,7 @@ func TestProhibitionaryDialArg(t *testing.T) { case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } - if !supportsIPv4map { + if !supportsIPv4map() { t.Skip("mapping ipv4 address inside ipv6 address not supported") } @@ -72,70 +72,6 @@ func TestDialLocal(t *testing.T) { c.Close() } -func TestDialTimeoutFDLeak(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("%s does not have full support of socktest", runtime.GOOS) - case "openbsd": - testenv.SkipFlaky(t, 15157) - } - - const T = 100 * time.Millisecond - - switch runtime.GOOS { - case "plan9", "windows": - origTestHookDialChannel := testHookDialChannel - testHookDialChannel = func() { time.Sleep(2 * T) } - defer func() { testHookDialChannel = origTestHookDialChannel }() - if runtime.GOOS == "plan9" { - break - } - fallthrough - default: - sw.Set(socktest.FilterConnect, func(so *socktest.Status) (socktest.AfterFilter, error) { - time.Sleep(2 * T) - return nil, errTimeout - }) - defer sw.Set(socktest.FilterConnect, nil) - } - - // Avoid tracking open-close jitterbugs between netFD and - // socket that leads to confusion of information inside - // socktest.Switch. - // It may happen when the Dial call bumps against TCP - // simultaneous open. See selfConnect in tcpsock_posix.go. - defer func() { sw.Set(socktest.FilterClose, nil) }() - var mu sync.Mutex - var attempts int - sw.Set(socktest.FilterClose, func(so *socktest.Status) (socktest.AfterFilter, error) { - mu.Lock() - attempts++ - mu.Unlock() - return nil, nil - }) - - const N = 100 - var wg sync.WaitGroup - wg.Add(N) - for i := 0; i < N; i++ { - go func() { - defer wg.Done() - // This dial never starts to send any SYN - // segment because of above socket filter and - // test hook. - c, err := DialTimeout("tcp", "127.0.0.1:0", T) - if err == nil { - t.Errorf("unexpectedly established: tcp:%s->%s", c.LocalAddr(), c.RemoteAddr()) - c.Close() - } - }() - } - wg.Wait() - if attempts < N { - t.Errorf("got %d; want >= %d", attempts, N) - } -} - func TestDialerDualStackFDLeak(t *testing.T) { switch runtime.GOOS { case "plan9": @@ -145,7 +81,7 @@ func TestDialerDualStackFDLeak(t *testing.T) { case "openbsd": testenv.SkipFlaky(t, 15157) } - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -254,7 +190,7 @@ func dialClosedPort() (actual, expected time.Duration) { func TestDialParallel(t *testing.T) { testenv.MustHaveExternalNetwork(t) - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -425,7 +361,7 @@ func lookupSlowFast(ctx context.Context, fn func(context.Context, string) ([]IPA func TestDialerFallbackDelay(t *testing.T) { testenv.MustHaveExternalNetwork(t) - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -491,7 +427,7 @@ func TestDialerFallbackDelay(t *testing.T) { } func TestDialParallelSpuriousConnection(t *testing.T) { - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -585,22 +521,22 @@ func TestDialerPartialDeadline(t *testing.T) { {now, noDeadline, 1, noDeadline, nil}, // Step the clock forward and cross the deadline. {now.Add(-1 * time.Millisecond), now, 1, now, nil}, - {now.Add(0 * time.Millisecond), now, 1, noDeadline, errTimeout}, - {now.Add(1 * time.Millisecond), now, 1, noDeadline, errTimeout}, + {now.Add(0 * time.Millisecond), now, 1, noDeadline, poll.ErrTimeout}, + {now.Add(1 * time.Millisecond), now, 1, noDeadline, poll.ErrTimeout}, } for i, tt := range testCases { deadline, err := partialDeadline(tt.now, tt.deadline, tt.addrs) if err != tt.expectErr { t.Errorf("#%d: got %v; want %v", i, err, tt.expectErr) } - if deadline != tt.expectDeadline { + if !deadline.Equal(tt.expectDeadline) { t.Errorf("#%d: got %v; want %v", i, deadline, tt.expectDeadline) } } } func TestDialerLocalAddr(t *testing.T) { - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -654,7 +590,7 @@ func TestDialerLocalAddr(t *testing.T) { {"tcp", "::1", &UnixAddr{}, &AddrError{Err: "some error"}}, } - if supportsIPv4map { + if supportsIPv4map() { tests = append(tests, test{ "tcp", "127.0.0.1", &TCPAddr{IP: ParseIP("::")}, nil, }) @@ -714,12 +650,9 @@ func TestDialerLocalAddr(t *testing.T) { } func TestDialerDualStack(t *testing.T) { - // This test is known to be flaky. Don't frighten regular - // users about it; only fail on the build dashboard. - if testenv.Builder() == "" { - testenv.SkipFlaky(t, 13324) - } - if !supportsIPv4 || !supportsIPv6 { + testenv.SkipFlaky(t, 13324) + + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -822,7 +755,7 @@ func TestDialCancel(t *testing.T) { } blackholeIPPort := JoinHostPort(slowDst4, "1234") - if !supportsIPv4 { + if !supportsIPv4() { blackholeIPPort = JoinHostPort(slowDst6, "1234") } @@ -954,3 +887,24 @@ func TestCancelAfterDial(t *testing.T) { try() } } + +// Issue 18806: it should always be possible to net.Dial a +// net.Listener().Addr().String when the listen address was ":n", even +// if the machine has halfway configured IPv6 such that it can bind on +// "::" not connect back to that same address. +func TestDialListenerAddr(t *testing.T) { + if testenv.Builder() == "" { + testenv.MustHaveExternalNetwork(t) + } + ln, err := Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + addr := ln.Addr().String() + c, err := Dial("tcp", addr) + if err != nil { + t.Fatalf("for addr %q, dial error: %v", addr, err) + } + c.Close() +} diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 0647b9c..ff6a4f6 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -25,13 +25,6 @@ import ( "time" ) -// A dnsDialer provides dialing suitable for DNS queries. -type dnsDialer interface { - dialDNS(ctx context.Context, network, addr string) (dnsConn, error) -} - -var testHookDNSDialer = func() dnsDialer { return &Dialer{} } - // A dnsConn represents a DNS transport endpoint. type dnsConn interface { io.Closer @@ -43,14 +36,14 @@ type dnsConn interface { dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) } -func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { - return dnsRoundTripUDP(c, query) +// dnsPacketConn implements the dnsConn interface for RFC 1035's +// "UDP usage" transport mechanism. Conn is a packet-oriented connection, +// such as a *UDPConn. +type dnsPacketConn struct { + Conn } -// dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's -// "UDP usage" transport mechanism. c should be a packet-oriented connection, -// such as a *UDPConn. -func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { +func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { b, ok := query.Pack() if !ok { return nil, errors.New("cannot marshal DNS message") @@ -76,14 +69,14 @@ func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { } } -func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) { - return dnsRoundTripTCP(c, out) +// dnsStreamConn implements the dnsConn interface for RFC 1035's +// "TCP usage" transport mechanism. Conn is a stream-oriented connection, +// such as a *TCPConn. +type dnsStreamConn struct { + Conn } -// dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's -// "TCP usage" transport mechanism. c should be a stream-oriented connection, -// such as a *TCPConn. -func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { +func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { b, ok := query.Pack() if !ok { return nil, errors.New("cannot marshal DNS message") @@ -116,33 +109,8 @@ func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) { return resp, nil } -func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) { - switch network { - case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": - default: - return nil, UnknownNetworkError(network) - } - // Calling Dial here is scary -- we have to be sure not to - // dial a name that will require a DNS lookup, or Dial will - // call back here to translate it. The DNS config parser has - // already checked that all the cfg.servers are IP - // addresses, which Dial will use without a DNS lookup. - c, err := d.DialContext(ctx, network, server) - if err != nil { - return nil, mapErr(err) - } - switch network { - case "tcp", "tcp4", "tcp6": - return c.(*TCPConn), nil - case "udp", "udp4", "udp6": - return c.(*UDPConn), nil - } - panic("unreachable") -} - // exchange sends a query on the connection and hopes for a response. -func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { - d := testHookDNSDialer() +func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { out := dnsMsg{ dnsMsgHdr: dnsMsgHdr{ recursion_desired: true, @@ -158,7 +126,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) defer cancel() - c, err := d.dialDNS(ctx, network, server) + c, err := r.dial(ctx, network, server) if err != nil { return nil, err } @@ -181,7 +149,7 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { +func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { var lastErr error serverOffset := cfg.serverOffset() sLen := uint32(len(cfg.servers)) @@ -190,7 +158,7 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) for j := uint32(0); j < sLen; j++ { server := cfg.servers[(serverOffset+j)%sLen] - msg, err := exchange(ctx, server, name, qtype, cfg.timeout) + msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -200,6 +168,11 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) if nerr, ok := err.(Error); ok && nerr.Timeout() { lastErr.(*DNSError).IsTimeout = true } + // Set IsTemporary for socket-level errors. Note that this flag + // may also be used to indicate a SERVFAIL response. + if _, ok := err.(*OpError); ok { + lastErr.(*DNSError).IsTemporary = true + } continue } // libresolv continues to the next server when it receives @@ -314,7 +287,7 @@ func (conf *resolverConfig) releaseSema() { <-conf.ch } -func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { +func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { if !isDomainName(name) { // We used to use "invalid domain name" as the error, // but that is a detail of the specific lookup mechanism. @@ -328,10 +301,15 @@ func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs [ conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() for _, fqdn := range conf.nameList(name) { - cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype) + cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype) if err == nil { break } + if nerr, ok := err.(Error); ok && nerr.Temporary() && r.StrictErrors { + // If we hit a temporary error with StrictErrors enabled, + // stop immediately instead of trying more names. + break + } } if err, ok := err.(*DNSError); ok { // Show original name passed to lookup, not suffixed one. @@ -432,11 +410,11 @@ func (o hostLookupOrder) String() string { // Normally we let cgo use the C library resolver instead of // depending on our lookup code, so that Go and C get the same // answers. -func goLookupHost(ctx context.Context, name string) (addrs []string, err error) { - return goLookupHostOrder(ctx, name, hostLookupFilesDNS) +func (r *Resolver) goLookupHost(ctx context.Context, name string) (addrs []string, err error) { + return r.goLookupHostOrder(ctx, name, hostLookupFilesDNS) } -func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) { +func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { // Use entries from /etc/hosts if they match. addrs = lookupStaticHost(name) @@ -444,7 +422,7 @@ func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) return } } - ips, _, err := goLookupIPCNAMEOrder(ctx, name, order) + ips, _, err := r.goLookupIPCNAMEOrder(ctx, name, order) if err != nil { return } @@ -470,13 +448,13 @@ func goLookupIPFiles(name string) (addrs []IPAddr) { // goLookupIP is the native Go implementation of LookupIP. // The libc versions are in cgo_*.go. -func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { +func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { order := systemConf().hostLookupOrder(host) - addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order) + addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order) return } -func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) { +func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { addrs = goLookupIPFiles(name) if len(addrs) > 0 || order == hostLookupFiles { @@ -502,15 +480,20 @@ func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrde for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { go func(qtype uint16) { - cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype) + cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype) lane <- racer{cname, rrs, err} }(qtype) } + hitStrictError := false for range qtypes { racer := <-lane if racer.error != nil { - // Prefer error for original name. - if lastErr == nil || fqdn == name+"." { + if nerr, ok := racer.error.(Error); ok && nerr.Temporary() && r.StrictErrors { + // This error will abort the nameList loop. + hitStrictError = true + lastErr = racer.error + } else if lastErr == nil || fqdn == name+"." { + // Prefer error for original name. lastErr = racer.error } continue @@ -520,6 +503,13 @@ func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrde cname = racer.cname } } + if hitStrictError { + // If either family hit an error with StrictErrors enabled, + // discard all addresses. This ensures that network flakiness + // cannot turn a dualstack hostname IPv4/IPv6-only. + addrs = nil + break + } if len(addrs) > 0 { break } @@ -543,9 +533,9 @@ func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrde } // goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME. -func goLookupCNAME(ctx context.Context, host string) (cname string, err error) { +func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (cname string, err error) { order := systemConf().hostLookupOrder(host) - _, cname, err = goLookupIPCNAMEOrder(ctx, host, order) + _, cname, err = r.goLookupIPCNAMEOrder(ctx, host, order) return } @@ -554,7 +544,7 @@ func goLookupCNAME(ctx context.Context, host string) (cname string, err error) { // only if cgoLookupPTR is the stub in cgo_stub.go). // Normally we let cgo use the C library resolver instead of depending // on our lookup code, so that Go and C get the same answers. -func goLookupPTR(ctx context.Context, addr string) ([]string, error) { +func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, error) { names := lookupStaticAddr(addr) if len(names) > 0 { return names, nil @@ -563,7 +553,7 @@ func goLookupPTR(ctx context.Context, addr string) ([]string, error) { if err != nil { return nil, err } - _, rrs, err := lookup(ctx, arpa, dnsTypePTR) + _, rrs, err := r.lookup(ctx, arpa, dnsTypePTR) if err != nil { return nil, err } diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go index c66d2d1..94811c9 100644 --- a/libgo/go/net/dnsclient_unix_test.go +++ b/libgo/go/net/dnsclient_unix_test.go @@ -8,8 +8,9 @@ package net import ( "context" + "errors" "fmt" - "internal/testenv" + "internal/poll" "io/ioutil" "os" "path" @@ -20,9 +21,14 @@ import ( "time" ) +var goResolver = Resolver{PreferGo: true} + // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation. const TestAddr uint32 = 0xc0000201 +// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation. +var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + var dnsTransportFallbackTests = []struct { server string name string @@ -37,18 +43,33 @@ var dnsTransportFallbackTests = []struct { } func TestDNSTransportFallback(t *testing.T) { - testenv.MustHaveExternalNetwork(t) - + fake := fakeDNSServer{ + rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + rcode: dnsRcodeSuccess, + }, + question: q.question, + } + if n == "udp" { + r.truncated = true + } + return r, nil + }, + } + r := Resolver{PreferGo: true, Dial: fake.DialContext} for _, tt := range dnsTransportFallbackTests { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - msg, err := exchange(ctx, tt.server, tt.name, tt.qtype, time.Second) + msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second) if err != nil { t.Error(err) continue } switch msg.rcode { - case tt.rcode, dnsRcodeServerFailure: + case tt.rcode: default: t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode) continue @@ -78,13 +99,30 @@ var specialDomainNameTests = []struct { } func TestSpecialDomainName(t *testing.T) { - testenv.MustHaveExternalNetwork(t) + fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + }, + question: q.question, + } + + switch q.question[0].Name { + case "example.com.": + r.rcode = dnsRcodeSuccess + default: + r.rcode = dnsRcodeNameError + } + return r, nil + }} + r := Resolver{PreferGo: true, Dial: fake.DialContext} server := "8.8.8.8:53" for _, tt := range specialDomainNameTests { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - msg, err := exchange(ctx, server, tt.name, tt.qtype, 3*time.Second) + msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second) if err != nil { t.Error(err) continue @@ -139,15 +177,40 @@ func TestAvoidDNSName(t *testing.T) { } } +var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + }, + question: q.question, + } + if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA { + r.answer = []dnsRR{ + &dnsRR_A{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeA, + Class: dnsClassINET, + Rdlength: 4, + }, + A: TestAddr, + }, + } + } + return r, nil +}} + // Issue 13705: don't try to resolve onion addresses, etc func TestLookupTorOnion(t *testing.T) { - addrs, err := goLookupIP(context.Background(), "foo.onion") - if len(addrs) > 0 { - t.Errorf("unexpected addresses: %v", addrs) - } + r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} + addrs, err := r.LookupIPAddr(context.Background(), "foo.onion") if err != nil { t.Fatalf("lookup = %v; want nil", err) } + if len(addrs) > 0 { + t.Errorf("unexpected addresses: %v", addrs) + } } type resolvConfTest struct { @@ -237,7 +300,7 @@ var updateResolvConfTests = []struct { } func TestUpdateResolvConf(t *testing.T) { - testenv.MustHaveExternalNetwork(t) + r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} conf, err := newResolvConfTest() if err != nil { @@ -257,7 +320,7 @@ func TestUpdateResolvConf(t *testing.T) { for j := 0; j < N; j++ { go func(name string) { defer wg.Done() - ips, err := goLookupIP(context.Background(), name) + ips, err := r.LookupIPAddr(context.Background(), name) if err != nil { t.Error(err) return @@ -392,7 +455,60 @@ var goLookupIPWithResolverConfigTests = []struct { } func TestGoLookupIPWithResolverConfig(t *testing.T) { - testenv.MustHaveExternalNetwork(t) + fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + switch s { + case "[2001:4860:4860::8888]:53", "8.8.8.8:53": + break + default: + time.Sleep(10 * time.Millisecond) + return nil, poll.ErrTimeout + } + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + }, + question: q.question, + } + for _, question := range q.question { + switch question.Qtype { + case dnsTypeA: + switch question.Name { + case "hostname.as112.net.": + break + case "ipv4.google.com.": + r.answer = append(r.answer, &dnsRR_A{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeA, + Class: dnsClassINET, + Rdlength: 4, + }, + A: TestAddr, + }) + default: + + } + case dnsTypeAAAA: + switch question.Name { + case "hostname.as112.net.": + break + case "ipv6.google.com.": + r.answer = append(r.answer, &dnsRR_AAAA{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeAAAA, + Class: dnsClassINET, + Rdlength: 16, + }, + AAAA: TestAddr6, + }) + } + } + } + return r, nil + }} + r := Resolver{PreferGo: true, Dial: fake.DialContext} conf, err := newResolvConfTest() if err != nil { @@ -405,14 +521,8 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { t.Error(err) continue } - addrs, err := goLookupIP(context.Background(), tt.name) + addrs, err := r.LookupIPAddr(context.Background(), tt.name) if err != nil { - // This test uses external network connectivity. - // We need to take care with errors on both - // DNS message exchange layer and DNS - // transport layer because goLookupIP may fail - // when the IP connectivity on node under test - // gets lost during its run. if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) { t.Errorf("got %v; want %v", err, tt.error) } @@ -437,7 +547,17 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { // Test that goLookupIPOrder falls back to the host file when no DNS servers are available. func TestGoLookupIPOrderFallbackToFile(t *testing.T) { - testenv.MustHaveExternalNetwork(t) + fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) { + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + }, + question: q.question, + } + return r, nil + }} + r := Resolver{PreferGo: true, Dial: fake.DialContext} // Add a config that simulates no dns servers being available. conf, err := newResolvConfTest() @@ -455,14 +575,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { name := fmt.Sprintf("order %v", order) // First ensure that we get an error when contacting a non-existent host. - _, _, err := goLookupIPCNAMEOrder(context.Background(), "notarealhost", order) + _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order) if err == nil { t.Errorf("%s: expected error while looking up name not in hosts file", name) continue } // Now check that we get an address when the name appears in the hosts file. - addrs, _, err := goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" + addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" if err != nil { t.Errorf("%s: expected to successfully lookup host entry", name) continue @@ -485,9 +605,6 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { func TestErrorForOriginalNameWhenSearching(t *testing.T) { const fqdn = "doesnotexist.domain" - origTestHookDNSDialer := testHookDNSDialer - defer func() { testHookDNSDialer = origTestHookDNSDialer }() - conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -498,14 +615,13 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { t.Fatal(err) } - d := &fakeDNSDialer{} - testHookDNSDialer = func() dnsDialer { return d } - - d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ - id: q.id, + id: q.id, + response: true, }, + question: q.question, } switch q.question[0].Name { @@ -516,24 +632,31 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { } return r, nil - } + }} - _, err = goLookupIP(context.Background(), fqdn) - if err == nil { - t.Fatal("expected an error") + cases := []struct { + strictErrors bool + wantErr *DNSError + }{ + {true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}}, + {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}}, } + for _, tt := range cases { + r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext} + _, err = r.LookupIPAddr(context.Background(), fqdn) + if err == nil { + t.Fatal("expected an error") + } - want := &DNSError{Name: fqdn, Err: errNoSuchHost.Error()} - if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err { - t.Errorf("got %v; want %v", err, want) + want := tt.wantErr + if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary { + t.Errorf("got %v; want %v", err, want) + } } } // Issue 15434. If a name server gives a lame referral, continue to the next. func TestIgnoreLameReferrals(t *testing.T) { - origTestHookDNSDialer := testHookDNSDialer - defer func() { testHookDNSDialer = origTestHookDNSDialer }() - conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -545,10 +668,7 @@ func TestIgnoreLameReferrals(t *testing.T) { t.Fatal(err) } - d := &fakeDNSDialer{} - testHookDNSDialer = func() dnsDialer { return d } - - d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { t.Log(s, q) r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ @@ -576,9 +696,10 @@ func TestIgnoreLameReferrals(t *testing.T) { } return r, nil - } + }} + r := Resolver{PreferGo: true, Dial: fake.DialContext} - addrs, err := goLookupIP(context.Background(), "www.golang.org") + addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org") if err != nil { t.Fatal(err) } @@ -597,7 +718,7 @@ func BenchmarkGoLookupIP(b *testing.B) { ctx := context.Background() for i := 0; i < b.N; i++ { - goLookupIP(ctx, "www.example.com") + goResolver.LookupIPAddr(ctx, "www.example.com") } } @@ -606,7 +727,7 @@ func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { ctx := context.Background() for i := 0; i < b.N; i++ { - goLookupIP(ctx, "some.nonexistent") + goResolver.LookupIPAddr(ctx, "some.nonexistent") } } @@ -629,38 +750,70 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { ctx := context.Background() for i := 0; i < b.N; i++ { - goLookupIP(ctx, "www.example.com") + goResolver.LookupIPAddr(ctx, "www.example.com") } } -type fakeDNSDialer struct { - // reply handler - rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error) +type fakeDNSServer struct { + rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) } -func (f *fakeDNSDialer) dialDNS(_ context.Context, n, s string) (dnsConn, error) { - return &fakeDNSConn{f.rh, s, time.Time{}}, nil +func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { + return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil } type fakeDNSConn struct { - rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error) - s string - t time.Time + Conn + server *fakeDNSServer + n string + s string + q *dnsMsg + t time.Time } func (f *fakeDNSConn) Close() error { return nil } +func (f *fakeDNSConn) Read(b []byte) (int, error) { + resp, err := f.server.rh(f.n, f.s, f.q, f.t) + if err != nil { + return 0, err + } + + bb, ok := resp.Pack() + if !ok { + return 0, errors.New("cannot marshal DNS message") + } + if len(b) < len(bb) { + return 0, errors.New("read would fragment DNS message") + } + + copy(b, bb) + return len(bb), nil +} + +func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) { + return 0, nil, nil +} + +func (f *fakeDNSConn) Write(b []byte) (int, error) { + f.q = new(dnsMsg) + if !f.q.Unpack(b) { + return 0, errors.New("cannot unmarshal DNS message") + } + return len(b), nil +} + +func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) { + return 0, nil +} + func (f *fakeDNSConn) SetDeadline(t time.Time) error { f.t = t return nil } -func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) { - return f.rh(f.s, q, f.t) -} - // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). func TestIgnoreDNSForgeries(t *testing.T) { c, s := Pipe() @@ -723,7 +876,8 @@ func TestIgnoreDNSForgeries(t *testing.T) { }, } - resp, err := dnsRoundTripUDP(c, msg) + dc := &dnsPacketConn{c} + resp, err := dc.dnsRoundTrip(msg) if err != nil { t.Fatalf("dnsRoundTripUDP failed: %v", err) } @@ -735,9 +889,6 @@ func TestIgnoreDNSForgeries(t *testing.T) { // Issue 16865. If a name server times out, continue to the next. func TestRetryTimeout(t *testing.T) { - origTestHookDNSDialer := testHookDNSDialer - defer func() { testHookDNSDialer = origTestHookDNSDialer }() - conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -752,12 +903,9 @@ func TestRetryTimeout(t *testing.T) { t.Fatal(err) } - d := &fakeDNSDialer{} - testHookDNSDialer = func() dnsDialer { return d } - var deadline0 time.Time - d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { t.Log(s, q, deadline) if deadline.IsZero() { @@ -767,17 +915,18 @@ func TestRetryTimeout(t *testing.T) { if s == "192.0.2.1:53" { deadline0 = deadline time.Sleep(10 * time.Millisecond) - return nil, errTimeout + return nil, poll.ErrTimeout } - if deadline == deadline0 { + if deadline.Equal(deadline0) { t.Error("deadline didn't change") } return mockTXTResponse(q), nil - } + }} + r := &Resolver{PreferGo: true, Dial: fake.DialContext} - _, err = LookupTXT("www.golang.org") + _, err = r.LookupTXT(context.Background(), "www.golang.org") if err != nil { t.Fatal(err) } @@ -796,9 +945,6 @@ func TestRotate(t *testing.T) { } func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { - origTestHookDNSDialer := testHookDNSDialer - defer func() { testHookDNSDialer = origTestHookDNSDialer }() - conf, err := newResolvConfTest() if err != nil { t.Fatal(err) @@ -817,18 +963,16 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { t.Fatal(err) } - d := &fakeDNSDialer{} - testHookDNSDialer = func() dnsDialer { return d } - var usedServers []string - d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { usedServers = append(usedServers, s) return mockTXTResponse(q), nil - } + }} + r := Resolver{PreferGo: true, Dial: fake.DialContext} // len(nameservers) + 1 to allow rotation to get back to start for i := 0; i < len(nameservers)+1; i++ { - if _, err := LookupTXT("www.golang.org"); err != nil { + if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil { t.Fatal(err) } } @@ -860,3 +1004,311 @@ func mockTXTResponse(q *dnsMsg) *dnsMsg { return r } + +// Issue 17448. With StrictErrors enabled, temporary errors should make +// LookupIP fail rather than return a partial result. +func TestStrictErrorsLookupIP(t *testing.T) { + conf, err := newResolvConfTest() + if err != nil { + t.Fatal(err) + } + defer conf.teardown() + + confData := []string{ + "nameserver 192.0.2.53", + "search x.golang.org y.golang.org", + } + if err := conf.writeAndUpdate(confData); err != nil { + t.Fatal(err) + } + + const name = "test-issue19592" + const server = "192.0.2.53:53" + const searchX = "test-issue19592.x.golang.org." + const searchY = "test-issue19592.y.golang.org." + const ip4 = "192.0.2.1" + const ip6 = "2001:db8::1" + + type resolveWhichEnum int + const ( + resolveOK resolveWhichEnum = iota + resolveOpError + resolveServfail + resolveTimeout + ) + + makeTempError := func(err string) error { + return &DNSError{ + Err: err, + Name: name, + Server: server, + IsTemporary: true, + } + } + makeTimeout := func() error { + return &DNSError{ + Err: poll.ErrTimeout.Error(), + Name: name, + Server: server, + IsTimeout: true, + } + } + makeNxDomain := func() error { + return &DNSError{ + Err: errNoSuchHost.Error(), + Name: name, + Server: server, + } + } + + cases := []struct { + desc string + resolveWhich func(quest *dnsQuestion) resolveWhichEnum + wantStrictErr error + wantLaxErr error + wantIPs []string + }{ + { + desc: "No errors", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + return resolveOK + }, + wantIPs: []string{ip4, ip6}, + }, + { + desc: "searchX error fails in strict mode", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + if quest.Name == searchX { + return resolveTimeout + } + return resolveOK + }, + wantStrictErr: makeTimeout(), + wantIPs: []string{ip4, ip6}, + }, + { + desc: "searchX IPv4-only timeout fails in strict mode", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + if quest.Name == searchX && quest.Qtype == dnsTypeA { + return resolveTimeout + } + return resolveOK + }, + wantStrictErr: makeTimeout(), + wantIPs: []string{ip4, ip6}, + }, + { + desc: "searchX IPv6-only servfail fails in strict mode", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + if quest.Name == searchX && quest.Qtype == dnsTypeAAAA { + return resolveServfail + } + return resolveOK + }, + wantStrictErr: makeTempError("server misbehaving"), + wantIPs: []string{ip4, ip6}, + }, + { + desc: "searchY error always fails", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + if quest.Name == searchY { + return resolveTimeout + } + return resolveOK + }, + wantStrictErr: makeTimeout(), + wantLaxErr: makeNxDomain(), // This one reaches the "test." FQDN. + }, + { + desc: "searchY IPv4-only socket error fails in strict mode", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + if quest.Name == searchY && quest.Qtype == dnsTypeA { + return resolveOpError + } + return resolveOK + }, + wantStrictErr: makeTempError("write: socket on fire"), + wantIPs: []string{ip6}, + }, + { + desc: "searchY IPv6-only timeout fails in strict mode", + resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + if quest.Name == searchY && quest.Qtype == dnsTypeAAAA { + return resolveTimeout + } + return resolveOK + }, + wantStrictErr: makeTimeout(), + wantIPs: []string{ip4}, + }, + } + + for i, tt := range cases { + fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + t.Log(s, q) + + switch tt.resolveWhich(&q.question[0]) { + case resolveOK: + // Handle below. + case resolveOpError: + return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} + case resolveServfail: + return &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + rcode: dnsRcodeServerFailure, + }, + question: q.question, + }, nil + case resolveTimeout: + return nil, poll.ErrTimeout + default: + t.Fatal("Impossible resolveWhich") + } + + switch q.question[0].Name { + case searchX, name + ".": + // Return NXDOMAIN to utilize the search list. + return &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + rcode: dnsRcodeNameError, + }, + question: q.question, + }, nil + case searchY: + // Return records below. + default: + return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) + } + + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + }, + question: q.question, + } + switch q.question[0].Qtype { + case dnsTypeA: + r.answer = []dnsRR{ + &dnsRR_A{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeA, + Class: dnsClassINET, + Rdlength: 4, + }, + A: TestAddr, + }, + } + case dnsTypeAAAA: + r.answer = []dnsRR{ + &dnsRR_AAAA{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeAAAA, + Class: dnsClassINET, + Rdlength: 16, + }, + AAAA: TestAddr6, + }, + } + default: + return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype) + } + return r, nil + }} + + for _, strict := range []bool{true, false} { + r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext} + ips, err := r.LookupIPAddr(context.Background(), name) + + var wantErr error + if strict { + wantErr = tt.wantStrictErr + } else { + wantErr = tt.wantLaxErr + } + if !reflect.DeepEqual(err, wantErr) { + t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr) + } + + gotIPs := map[string]struct{}{} + for _, ip := range ips { + gotIPs[ip.String()] = struct{}{} + } + wantIPs := map[string]struct{}{} + if wantErr == nil { + for _, ip := range tt.wantIPs { + wantIPs[ip] = struct{}{} + } + } + if !reflect.DeepEqual(gotIPs, wantIPs) { + t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs) + } + } + } +} + +// Issue 17448. With StrictErrors enabled, temporary errors should make +// LookupTXT stop walking the search list. +func TestStrictErrorsLookupTXT(t *testing.T) { + conf, err := newResolvConfTest() + if err != nil { + t.Fatal(err) + } + defer conf.teardown() + + confData := []string{ + "nameserver 192.0.2.53", + "search x.golang.org y.golang.org", + } + if err := conf.writeAndUpdate(confData); err != nil { + t.Fatal(err) + } + + const name = "test" + const server = "192.0.2.53:53" + const searchX = "test.x.golang.org." + const searchY = "test.y.golang.org." + const txt = "Hello World" + + fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + t.Log(s, q) + + switch q.question[0].Name { + case searchX: + return nil, poll.ErrTimeout + case searchY: + return mockTXTResponse(q), nil + default: + return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) + } + }} + + for _, strict := range []bool{true, false} { + r := Resolver{StrictErrors: strict, Dial: fake.DialContext} + _, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT) + var wantErr error + var wantRRs int + if strict { + wantErr = &DNSError{ + Err: poll.ErrTimeout.Error(), + Name: name, + Server: server, + IsTimeout: true, + } + } else { + wantRRs = 1 + } + if !reflect.DeepEqual(err, wantErr) { + t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr) + } + if len(rrs) != wantRRs { + t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs) + } + } +} diff --git a/libgo/go/net/error_posix.go b/libgo/go/net/error_posix.go new file mode 100644 index 0000000..dd9754c --- /dev/null +++ b/libgo/go/net/error_posix.go @@ -0,0 +1,21 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows + +package net + +import ( + "os" + "syscall" +) + +// wrapSyscallError takes an error and a syscall name. If the error is +// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name. +func wrapSyscallError(name string, err error) error { + if _, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError(name, err) + } + return err +} diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go index c23da49..9791e6f 100644 --- a/libgo/go/net/error_test.go +++ b/libgo/go/net/error_test.go @@ -7,11 +7,13 @@ package net import ( "context" "fmt" + "internal/poll" "io" "io/ioutil" "net/internal/socktest" "os" "runtime" + "strings" "testing" "time" ) @@ -87,7 +89,7 @@ second: return nil } switch err := nestedErr.(type) { - case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *timeoutError, UnknownNetworkError: + case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *poll.TimeoutError, UnknownNetworkError: return nil case *os.SyscallError: nestedErr = err.Err @@ -97,7 +99,7 @@ second: goto third } switch nestedErr { - case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress, + case errCanceled, poll.ErrNetClosing, errMissingAddress, errNoSuitableAddress, context.DeadlineExceeded, context.Canceled: return nil } @@ -213,7 +215,7 @@ func TestDialAddrError(t *testing.T) { case "nacl", "plan9": t.Skipf("not supported on %s", runtime.GOOS) } - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -432,7 +434,7 @@ second: goto third } switch nestedErr { - case errClosing, errTimeout: + case poll.ErrNetClosing, poll.ErrTimeout: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -467,14 +469,14 @@ second: return nil } switch err := nestedErr.(type) { - case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *timeoutError, UnknownNetworkError: + case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *poll.TimeoutError, UnknownNetworkError: return nil case *os.SyscallError: nestedErr = err.Err goto third } switch nestedErr { - case errCanceled, errClosing, errMissingAddress, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF: + case errCanceled, poll.ErrNetClosing, errMissingAddress, poll.ErrTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -489,11 +491,21 @@ third: // parseCloseError parses nestedErr and reports whether it is a valid // error value from Close functions. // It returns nil when nestedErr is valid. -func parseCloseError(nestedErr error) error { +func parseCloseError(nestedErr error, isShutdown bool) error { if nestedErr == nil { return nil } + // Because historically we have not exported the error that we + // return for an operation on a closed network connection, + // there are programs that test for the exact error string. + // Verify that string here so that we don't break those + // programs unexpectedly. See issues #4373 and #19252. + want := "use of closed network connection" + if !isShutdown && !strings.Contains(nestedErr.Error(), want) { + return fmt.Errorf("error string %q does not contain expected string %q", nestedErr, want) + } + switch err := nestedErr.(type) { case *OpError: if err := err.isValid(); err != nil { @@ -517,7 +529,7 @@ second: goto third } switch nestedErr { - case errClosing: + case poll.ErrNetClosing: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -547,23 +559,23 @@ func TestCloseError(t *testing.T) { for i := 0; i < 3; i++ { err = c.(*TCPConn).CloseRead() - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, true); perr != nil { t.Errorf("#%d: %v", i, perr) } } for i := 0; i < 3; i++ { err = c.(*TCPConn).CloseWrite() - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, true); perr != nil { t.Errorf("#%d: %v", i, perr) } } for i := 0; i < 3; i++ { err = c.Close() - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Errorf("#%d: %v", i, perr) } err = ln.Close() - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Errorf("#%d: %v", i, perr) } } @@ -576,7 +588,7 @@ func TestCloseError(t *testing.T) { for i := 0; i < 3; i++ { err = pc.Close() - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Errorf("#%d: %v", i, perr) } } @@ -613,7 +625,7 @@ second: goto third } switch nestedErr { - case errClosing, errTimeout: + case poll.ErrNetClosing, poll.ErrTimeout: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -692,7 +704,7 @@ second: goto third } switch nestedErr { - case errClosing: + case poll.ErrNetClosing: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) diff --git a/libgo/go/net/external_test.go b/libgo/go/net/external_test.go index e18b547..38788ef 100644 --- a/libgo/go/net/external_test.go +++ b/libgo/go/net/external_test.go @@ -15,7 +15,7 @@ import ( func TestResolveGoogle(t *testing.T) { testenv.MustHaveExternalNetwork(t) - if !supportsIPv4 || !supportsIPv6 || !*testIPv4 || !*testIPv6 { + if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 { t.Skip("both IPv4 and IPv6 are required") } @@ -62,7 +62,7 @@ var dialGoogleTests = []struct { func TestDialGoogle(t *testing.T) { testenv.MustHaveExternalNetwork(t) - if !supportsIPv4 || !supportsIPv6 || !*testIPv4 || !*testIPv6 { + if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 { t.Skip("both IPv4 and IPv6 are required") } diff --git a/libgo/go/net/fd_io_plan9.go b/libgo/go/net/fd_io_plan9.go deleted file mode 100644 index 76da0c5..0000000 --- a/libgo/go/net/fd_io_plan9.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "os" - "runtime" - "sync" - "syscall" -) - -// asyncIO implements asynchronous cancelable I/O. -// An asyncIO represents a single asynchronous Read or Write -// operation. The result is returned on the result channel. -// The undergoing I/O system call can either complete or be -// interrupted by a note. -type asyncIO struct { - res chan result - - // mu guards the pid field. - mu sync.Mutex - - // pid holds the process id of - // the process running the IO operation. - pid int -} - -// result is the return value of a Read or Write operation. -type result struct { - n int - err error -} - -// newAsyncIO returns a new asyncIO that performs an I/O -// operation by calling fn, which must do one and only one -// interruptible system call. -func newAsyncIO(fn func([]byte) (int, error), b []byte) *asyncIO { - aio := &asyncIO{ - res: make(chan result, 0), - } - aio.mu.Lock() - go func() { - // Lock the current goroutine to its process - // and store the pid in io so that Cancel can - // interrupt it. We ignore the "hangup" signal, - // so the signal does not take down the entire - // Go runtime. - runtime.LockOSThread() - runtime_ignoreHangup() - aio.pid = os.Getpid() - aio.mu.Unlock() - - n, err := fn(b) - - aio.mu.Lock() - aio.pid = -1 - runtime_unignoreHangup() - aio.mu.Unlock() - - aio.res <- result{n, err} - }() - return aio -} - -var hangupNote os.Signal = syscall.Note("hangup") - -// Cancel interrupts the I/O operation, causing -// the Wait function to return. -func (aio *asyncIO) Cancel() { - aio.mu.Lock() - defer aio.mu.Unlock() - if aio.pid == -1 { - return - } - proc, err := os.FindProcess(aio.pid) - if err != nil { - return - } - proc.Signal(hangupNote) -} - -// Wait for the I/O operation to complete. -func (aio *asyncIO) Wait() (int, error) { - res := <-aio.res - return res.n, res.err -} - -// The following functions, provided by the runtime, are used to -// ignore and unignore the "hangup" signal received by the process. -func runtime_ignoreHangup() -func runtime_unignoreHangup() diff --git a/libgo/go/net/fd_mutex.go b/libgo/go/net/fd_mutex.go deleted file mode 100644 index 4591fd1..0000000 --- a/libgo/go/net/fd_mutex.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import "sync/atomic" - -// fdMutex is a specialized synchronization primitive that manages -// lifetime of an fd and serializes access to Read, Write and Close -// methods on netFD. -type fdMutex struct { - state uint64 - rsema uint32 - wsema uint32 -} - -// fdMutex.state is organized as follows: -// 1 bit - whether netFD is closed, if set all subsequent lock operations will fail. -// 1 bit - lock for read operations. -// 1 bit - lock for write operations. -// 20 bits - total number of references (read+write+misc). -// 20 bits - number of outstanding read waiters. -// 20 bits - number of outstanding write waiters. -const ( - mutexClosed = 1 << 0 - mutexRLock = 1 << 1 - mutexWLock = 1 << 2 - mutexRef = 1 << 3 - mutexRefMask = (1<<20 - 1) << 3 - mutexRWait = 1 << 23 - mutexRMask = (1<<20 - 1) << 23 - mutexWWait = 1 << 43 - mutexWMask = (1<<20 - 1) << 43 -) - -// Read operations must do rwlock(true)/rwunlock(true). -// -// Write operations must do rwlock(false)/rwunlock(false). -// -// Misc operations must do incref/decref. -// Misc operations include functions like setsockopt and setDeadline. -// They need to use incref/decref to ensure that they operate on the -// correct fd in presence of a concurrent close call (otherwise fd can -// be closed under their feet). -// -// Close operations must do increfAndClose/decref. - -// incref adds a reference to mu. -// It reports whether mu is available for reading or writing. -func (mu *fdMutex) incref() bool { - for { - old := atomic.LoadUint64(&mu.state) - if old&mutexClosed != 0 { - return false - } - new := old + mutexRef - if new&mutexRefMask == 0 { - panic("net: inconsistent fdMutex") - } - if atomic.CompareAndSwapUint64(&mu.state, old, new) { - return true - } - } -} - -// increfAndClose sets the state of mu to closed. -// It reports whether there is no remaining reference. -func (mu *fdMutex) increfAndClose() bool { - for { - old := atomic.LoadUint64(&mu.state) - if old&mutexClosed != 0 { - return false - } - // Mark as closed and acquire a reference. - new := (old | mutexClosed) + mutexRef - if new&mutexRefMask == 0 { - panic("net: inconsistent fdMutex") - } - // Remove all read and write waiters. - new &^= mutexRMask | mutexWMask - if atomic.CompareAndSwapUint64(&mu.state, old, new) { - // Wake all read and write waiters, - // they will observe closed flag after wakeup. - for old&mutexRMask != 0 { - old -= mutexRWait - runtime_Semrelease(&mu.rsema) - } - for old&mutexWMask != 0 { - old -= mutexWWait - runtime_Semrelease(&mu.wsema) - } - return true - } - } -} - -// decref removes a reference from mu. -// It reports whether there is no remaining reference. -func (mu *fdMutex) decref() bool { - for { - old := atomic.LoadUint64(&mu.state) - if old&mutexRefMask == 0 { - panic("net: inconsistent fdMutex") - } - new := old - mutexRef - if atomic.CompareAndSwapUint64(&mu.state, old, new) { - return new&(mutexClosed|mutexRefMask) == mutexClosed - } - } -} - -// lock adds a reference to mu and locks mu. -// It reports whether mu is available for reading or writing. -func (mu *fdMutex) rwlock(read bool) bool { - var mutexBit, mutexWait, mutexMask uint64 - var mutexSema *uint32 - if read { - mutexBit = mutexRLock - mutexWait = mutexRWait - mutexMask = mutexRMask - mutexSema = &mu.rsema - } else { - mutexBit = mutexWLock - mutexWait = mutexWWait - mutexMask = mutexWMask - mutexSema = &mu.wsema - } - for { - old := atomic.LoadUint64(&mu.state) - if old&mutexClosed != 0 { - return false - } - var new uint64 - if old&mutexBit == 0 { - // Lock is free, acquire it. - new = (old | mutexBit) + mutexRef - if new&mutexRefMask == 0 { - panic("net: inconsistent fdMutex") - } - } else { - // Wait for lock. - new = old + mutexWait - if new&mutexMask == 0 { - panic("net: inconsistent fdMutex") - } - } - if atomic.CompareAndSwapUint64(&mu.state, old, new) { - if old&mutexBit == 0 { - return true - } - runtime_Semacquire(mutexSema) - // The signaller has subtracted mutexWait. - } - } -} - -// unlock removes a reference from mu and unlocks mu. -// It reports whether there is no remaining reference. -func (mu *fdMutex) rwunlock(read bool) bool { - var mutexBit, mutexWait, mutexMask uint64 - var mutexSema *uint32 - if read { - mutexBit = mutexRLock - mutexWait = mutexRWait - mutexMask = mutexRMask - mutexSema = &mu.rsema - } else { - mutexBit = mutexWLock - mutexWait = mutexWWait - mutexMask = mutexWMask - mutexSema = &mu.wsema - } - for { - old := atomic.LoadUint64(&mu.state) - if old&mutexBit == 0 || old&mutexRefMask == 0 { - panic("net: inconsistent fdMutex") - } - // Drop lock, drop reference and wake read waiter if present. - new := (old &^ mutexBit) - mutexRef - if old&mutexMask != 0 { - new -= mutexWait - } - if atomic.CompareAndSwapUint64(&mu.state, old, new) { - if old&mutexMask != 0 { - runtime_Semrelease(mutexSema) - } - return new&(mutexClosed|mutexRefMask) == mutexClosed - } - } -} - -// Implemented in runtime package. -func runtime_Semacquire(sema *uint32) -func runtime_Semrelease(sema *uint32) - -// incref adds a reference to fd. -// It returns an error when fd cannot be used. -func (fd *netFD) incref() error { - if !fd.fdmu.incref() { - return errClosing - } - return nil -} - -// decref removes a reference from fd. -// It also closes fd when the state of fd is set to closed and there -// is no remaining reference. -func (fd *netFD) decref() { - if fd.fdmu.decref() { - fd.destroy() - } -} - -// readLock adds a reference to fd and locks fd for reading. -// It returns an error when fd cannot be used for reading. -func (fd *netFD) readLock() error { - if !fd.fdmu.rwlock(true) { - return errClosing - } - return nil -} - -// readUnlock removes a reference from fd and unlocks fd for reading. -// It also closes fd when the state of fd is set to closed and there -// is no remaining reference. -func (fd *netFD) readUnlock() { - if fd.fdmu.rwunlock(true) { - fd.destroy() - } -} - -// writeLock adds a reference to fd and locks fd for writing. -// It returns an error when fd cannot be used for writing. -func (fd *netFD) writeLock() error { - if !fd.fdmu.rwlock(false) { - return errClosing - } - return nil -} - -// writeUnlock removes a reference from fd and unlocks fd for writing. -// It also closes fd when the state of fd is set to closed and there -// is no remaining reference. -func (fd *netFD) writeUnlock() { - if fd.fdmu.rwunlock(false) { - fd.destroy() - } -} diff --git a/libgo/go/net/fd_mutex_test.go b/libgo/go/net/fd_mutex_test.go deleted file mode 100644 index 3542c70..0000000 --- a/libgo/go/net/fd_mutex_test.go +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "math/rand" - "runtime" - "testing" - "time" -) - -func TestMutexLock(t *testing.T) { - var mu fdMutex - - if !mu.incref() { - t.Fatal("broken") - } - if mu.decref() { - t.Fatal("broken") - } - - if !mu.rwlock(true) { - t.Fatal("broken") - } - if mu.rwunlock(true) { - t.Fatal("broken") - } - - if !mu.rwlock(false) { - t.Fatal("broken") - } - if mu.rwunlock(false) { - t.Fatal("broken") - } -} - -func TestMutexClose(t *testing.T) { - var mu fdMutex - if !mu.increfAndClose() { - t.Fatal("broken") - } - - if mu.incref() { - t.Fatal("broken") - } - if mu.rwlock(true) { - t.Fatal("broken") - } - if mu.rwlock(false) { - t.Fatal("broken") - } - if mu.increfAndClose() { - t.Fatal("broken") - } -} - -func TestMutexCloseUnblock(t *testing.T) { - c := make(chan bool) - var mu fdMutex - mu.rwlock(true) - for i := 0; i < 4; i++ { - go func() { - if mu.rwlock(true) { - t.Error("broken") - return - } - c <- true - }() - } - // Concurrent goroutines must not be able to read lock the mutex. - time.Sleep(time.Millisecond) - select { - case <-c: - t.Fatal("broken") - default: - } - mu.increfAndClose() // Must unblock the readers. - for i := 0; i < 4; i++ { - select { - case <-c: - case <-time.After(10 * time.Second): - t.Fatal("broken") - } - } - if mu.decref() { - t.Fatal("broken") - } - if !mu.rwunlock(true) { - t.Fatal("broken") - } -} - -func TestMutexPanic(t *testing.T) { - ensurePanics := func(f func()) { - defer func() { - if recover() == nil { - t.Fatal("does not panic") - } - }() - f() - } - - var mu fdMutex - ensurePanics(func() { mu.decref() }) - ensurePanics(func() { mu.rwunlock(true) }) - ensurePanics(func() { mu.rwunlock(false) }) - - ensurePanics(func() { mu.incref(); mu.decref(); mu.decref() }) - ensurePanics(func() { mu.rwlock(true); mu.rwunlock(true); mu.rwunlock(true) }) - ensurePanics(func() { mu.rwlock(false); mu.rwunlock(false); mu.rwunlock(false) }) - - // ensure that it's still not broken - mu.incref() - mu.decref() - mu.rwlock(true) - mu.rwunlock(true) - mu.rwlock(false) - mu.rwunlock(false) -} - -func TestMutexStress(t *testing.T) { - P := 8 - N := int(1e6) - if testing.Short() { - P = 4 - N = 1e4 - } - defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(P)) - done := make(chan bool) - var mu fdMutex - var readState [2]uint64 - var writeState [2]uint64 - for p := 0; p < P; p++ { - go func() { - r := rand.New(rand.NewSource(rand.Int63())) - for i := 0; i < N; i++ { - switch r.Intn(3) { - case 0: - if !mu.incref() { - t.Error("broken") - return - } - if mu.decref() { - t.Error("broken") - return - } - case 1: - if !mu.rwlock(true) { - t.Error("broken") - return - } - // Ensure that it provides mutual exclusion for readers. - if readState[0] != readState[1] { - t.Error("broken") - return - } - readState[0]++ - readState[1]++ - if mu.rwunlock(true) { - t.Error("broken") - return - } - case 2: - if !mu.rwlock(false) { - t.Error("broken") - return - } - // Ensure that it provides mutual exclusion for writers. - if writeState[0] != writeState[1] { - t.Error("broken") - return - } - writeState[0]++ - writeState[1]++ - if mu.rwunlock(false) { - t.Error("broken") - return - } - } - } - done <- true - }() - } - for p := 0; p < P; p++ { - <-done - } - if !mu.increfAndClose() { - t.Fatal("broken") - } - if !mu.decref() { - t.Fatal("broken") - } -} diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go index 300d8c4..46ee5d9 100644 --- a/libgo/go/net/fd_plan9.go +++ b/libgo/go/net/fd_plan9.go @@ -5,23 +5,15 @@ package net import ( + "internal/poll" "io" "os" - "sync/atomic" "syscall" - "time" ) -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } - // Network file descriptor. type netFD struct { - // locking/lifetime of sysfd + serialize access to Read and Write methods - fdmu fdMutex + pfd poll.FD // immutable until Close net string @@ -30,26 +22,12 @@ type netFD struct { listen, ctl, data *os.File laddr, raddr Addr isStream bool - - // deadlines - raio *asyncIO - waio *asyncIO - rtimer *time.Timer - wtimer *time.Timer - rtimedout atomicBool // set true when read deadline has been reached - wtimedout atomicBool // set true when write deadline has been reached } -var ( - netdir string // default network -) - -func sysInit() { - netdir = "/net" -} +var netdir = "/net" // default network func newFD(net, name string, listen, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) { - return &netFD{ + ret := &netFD{ net: net, n: name, dir: netdir + "/" + net + "/" + name, @@ -57,7 +35,9 @@ func newFD(net, name string, listen, ctl, data *os.File, laddr, raddr Addr) (*ne ctl: ctl, data: data, laddr: laddr, raddr: raddr, - }, nil + } + ret.pfd.Destroy = ret.destroy + return ret, nil } func (fd *netFD) init() error { @@ -99,28 +79,10 @@ func (fd *netFD) destroy() { } func (fd *netFD) Read(b []byte) (n int, err error) { - if fd.rtimedout.isSet() { - return 0, errTimeout - } if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } - if err := fd.readLock(); err != nil { - return 0, err - } - defer fd.readUnlock() - if len(b) == 0 { - return 0, nil - } - fd.raio = newAsyncIO(fd.data.Read, b) - n, err = fd.raio.Wait() - fd.raio = nil - if isHangup(err) { - err = io.EOF - } - if isInterrupted(err) { - err = errTimeout - } + n, err = fd.pfd.Read(fd.data.Read, b) if fd.net == "udp" && err == io.EOF { n = 0 err = nil @@ -129,23 +91,10 @@ func (fd *netFD) Read(b []byte) (n int, err error) { } func (fd *netFD) Write(b []byte) (n int, err error) { - if fd.wtimedout.isSet() { - return 0, errTimeout - } if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - fd.waio = newAsyncIO(fd.data.Write, b) - n, err = fd.waio.Wait() - fd.waio = nil - if isInterrupted(err) { - err = errTimeout - } - return + return fd.pfd.Write(fd.data.Write, b) } func (fd *netFD) closeRead() error { @@ -163,8 +112,8 @@ func (fd *netFD) closeWrite() error { } func (fd *netFD) Close() error { - if !fd.fdmu.increfAndClose() { - return errClosing + if err := fd.pfd.Close(); err != nil { + return err } if !fd.ok() { return syscall.EINVAL @@ -216,77 +165,6 @@ func (fd *netFD) file(f *os.File, s string) (*os.File, error) { return os.NewFile(uintptr(dfd), s), nil } -func (fd *netFD) setDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'r'+'w') -} - -func (fd *netFD) setReadDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'r') -} - -func (fd *netFD) setWriteDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'w') -} - -func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { - d := t.Sub(time.Now()) - if mode == 'r' || mode == 'r'+'w' { - fd.rtimedout.setFalse() - } - if mode == 'w' || mode == 'r'+'w' { - fd.wtimedout.setFalse() - } - if t.IsZero() || d < 0 { - // Stop timer - if mode == 'r' || mode == 'r'+'w' { - if fd.rtimer != nil { - fd.rtimer.Stop() - } - fd.rtimer = nil - } - if mode == 'w' || mode == 'r'+'w' { - if fd.wtimer != nil { - fd.wtimer.Stop() - } - fd.wtimer = nil - } - } else { - // Interrupt I/O operation once timer has expired - if mode == 'r' || mode == 'r'+'w' { - fd.rtimer = time.AfterFunc(d, func() { - fd.rtimedout.setTrue() - if fd.raio != nil { - fd.raio.Cancel() - } - }) - } - if mode == 'w' || mode == 'r'+'w' { - fd.wtimer = time.AfterFunc(d, func() { - fd.wtimedout.setTrue() - if fd.waio != nil { - fd.waio.Cancel() - } - }) - } - } - if !t.IsZero() && d < 0 { - // Interrupt current I/O operation - if mode == 'r' || mode == 'r'+'w' { - fd.rtimedout.setTrue() - if fd.raio != nil { - fd.raio.Cancel() - } - } - if mode == 'w' || mode == 'r'+'w' { - fd.wtimedout.setTrue() - if fd.waio != nil { - fd.waio.Cancel() - } - } - } - return nil -} - func setReadBuffer(fd *netFD, bytes int) error { return syscall.EPLAN9 } @@ -294,11 +172,3 @@ func setReadBuffer(fd *netFD, bytes int) error { func setWriteBuffer(fd *netFD, bytes int) error { return syscall.EPLAN9 } - -func isHangup(err error) bool { - return err != nil && stringsHasSuffix(err.Error(), "Hangup") -} - -func isInterrupted(err error) bool { - return err != nil && stringsHasSuffix(err.Error(), "interrupted") -} diff --git a/libgo/go/net/fd_poll_nacl.go b/libgo/go/net/fd_poll_nacl.go deleted file mode 100644 index 8398760..0000000 --- a/libgo/go/net/fd_poll_nacl.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "runtime" - "syscall" - "time" -) - -type pollDesc struct { - fd *netFD - closing bool -} - -func (pd *pollDesc) init(fd *netFD) error { pd.fd = fd; return nil } - -func (pd *pollDesc) close() {} - -func (pd *pollDesc) evict() { - pd.closing = true - if pd.fd != nil { - syscall.StopIO(pd.fd.sysfd) - runtime.KeepAlive(pd.fd) - } -} - -func (pd *pollDesc) prepare(mode int) error { - if pd.closing { - return errClosing - } - return nil -} - -func (pd *pollDesc) prepareRead() error { return pd.prepare('r') } - -func (pd *pollDesc) prepareWrite() error { return pd.prepare('w') } - -func (pd *pollDesc) wait(mode int) error { - if pd.closing { - return errClosing - } - return errTimeout -} - -func (pd *pollDesc) waitRead() error { return pd.wait('r') } - -func (pd *pollDesc) waitWrite() error { return pd.wait('w') } - -func (pd *pollDesc) waitCanceled(mode int) {} - -func (pd *pollDesc) waitCanceledRead() {} - -func (pd *pollDesc) waitCanceledWrite() {} - -func (fd *netFD) setDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'r'+'w') -} - -func (fd *netFD) setReadDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'r') -} - -func (fd *netFD) setWriteDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'w') -} - -func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { - d := t.UnixNano() - if t.IsZero() { - d = 0 - } - if err := fd.incref(); err != nil { - return err - } - switch mode { - case 'r': - syscall.SetReadDeadline(fd.sysfd, d) - case 'w': - syscall.SetWriteDeadline(fd.sysfd, d) - case 'r' + 'w': - syscall.SetReadDeadline(fd.sysfd, d) - syscall.SetWriteDeadline(fd.sysfd, d) - } - fd.decref() - return nil -} diff --git a/libgo/go/net/fd_poll_runtime.go b/libgo/go/net/fd_poll_runtime.go deleted file mode 100644 index 4ea92cb..0000000 --- a/libgo/go/net/fd_poll_runtime.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2013 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build aix darwin dragonfly freebsd linux netbsd openbsd windows solaris - -package net - -import ( - "runtime" - "sync" - "syscall" - "time" -) - -// runtimeNano returns the current value of the runtime clock in nanoseconds. -func runtimeNano() int64 - -func runtime_pollServerInit() -func runtime_pollOpen(fd uintptr) (uintptr, int) -func runtime_pollClose(ctx uintptr) -func runtime_pollWait(ctx uintptr, mode int) int -func runtime_pollWaitCanceled(ctx uintptr, mode int) int -func runtime_pollReset(ctx uintptr, mode int) int -func runtime_pollSetDeadline(ctx uintptr, d int64, mode int) -func runtime_pollUnblock(ctx uintptr) - -type pollDesc struct { - runtimeCtx uintptr -} - -var serverInit sync.Once - -func (pd *pollDesc) init(fd *netFD) error { - serverInit.Do(runtime_pollServerInit) - ctx, errno := runtime_pollOpen(uintptr(fd.sysfd)) - runtime.KeepAlive(fd) - if errno != 0 { - return syscall.Errno(errno) - } - pd.runtimeCtx = ctx - return nil -} - -func (pd *pollDesc) close() { - if pd.runtimeCtx == 0 { - return - } - runtime_pollClose(pd.runtimeCtx) - pd.runtimeCtx = 0 -} - -// Evict evicts fd from the pending list, unblocking any I/O running on fd. -func (pd *pollDesc) evict() { - if pd.runtimeCtx == 0 { - return - } - runtime_pollUnblock(pd.runtimeCtx) -} - -func (pd *pollDesc) prepare(mode int) error { - res := runtime_pollReset(pd.runtimeCtx, mode) - return convertErr(res) -} - -func (pd *pollDesc) prepareRead() error { - return pd.prepare('r') -} - -func (pd *pollDesc) prepareWrite() error { - return pd.prepare('w') -} - -func (pd *pollDesc) wait(mode int) error { - res := runtime_pollWait(pd.runtimeCtx, mode) - return convertErr(res) -} - -func (pd *pollDesc) waitRead() error { - return pd.wait('r') -} - -func (pd *pollDesc) waitWrite() error { - return pd.wait('w') -} - -func (pd *pollDesc) waitCanceled(mode int) { - runtime_pollWaitCanceled(pd.runtimeCtx, mode) -} - -func (pd *pollDesc) waitCanceledRead() { - pd.waitCanceled('r') -} - -func (pd *pollDesc) waitCanceledWrite() { - pd.waitCanceled('w') -} - -func convertErr(res int) error { - switch res { - case 0: - return nil - case 1: - return errClosing - case 2: - return errTimeout - } - println("unreachable: ", res) - panic("unreachable") -} - -func (fd *netFD) setDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'r'+'w') -} - -func (fd *netFD) setReadDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'r') -} - -func (fd *netFD) setWriteDeadline(t time.Time) error { - return setDeadlineImpl(fd, t, 'w') -} - -func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { - diff := int64(time.Until(t)) - d := runtimeNano() + diff - if d <= 0 && diff > 0 { - // If the user has a deadline in the future, but the delay calculation - // overflows, then set the deadline to the maximum possible value. - d = 1<<63 - 1 - } - if t.IsZero() { - d = 0 - } - if err := fd.incref(); err != nil { - return err - } - runtime_pollSetDeadline(fd.pd.runtimeCtx, d, mode) - fd.decref() - return nil -} diff --git a/libgo/go/net/fd_posix.go b/libgo/go/net/fd_posix.go deleted file mode 100644 index 7230479..0000000 --- a/libgo/go/net/fd_posix.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows - -package net - -import ( - "io" - "syscall" -) - -// eofError returns io.EOF when fd is available for reading end of -// file. -func (fd *netFD) eofError(n int, err error) error { - if n == 0 && err == nil && fd.sotype != syscall.SOCK_DGRAM && fd.sotype != syscall.SOCK_RAW { - return io.EOF - } - return err -} diff --git a/libgo/go/net/fd_posix_test.go b/libgo/go/net/fd_posix_test.go deleted file mode 100644 index 85711ef..0000000 --- a/libgo/go/net/fd_posix_test.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows - -package net - -import ( - "io" - "syscall" - "testing" -) - -var eofErrorTests = []struct { - n int - err error - fd *netFD - expected error -}{ - {100, nil, &netFD{sotype: syscall.SOCK_STREAM}, nil}, - {100, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF}, - {100, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing}, - {0, nil, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF}, - {0, io.EOF, &netFD{sotype: syscall.SOCK_STREAM}, io.EOF}, - {0, errClosing, &netFD{sotype: syscall.SOCK_STREAM}, errClosing}, - - {100, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil}, - {100, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF}, - {100, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing}, - {0, nil, &netFD{sotype: syscall.SOCK_DGRAM}, nil}, - {0, io.EOF, &netFD{sotype: syscall.SOCK_DGRAM}, io.EOF}, - {0, errClosing, &netFD{sotype: syscall.SOCK_DGRAM}, errClosing}, - - {100, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, nil}, - {100, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF}, - {100, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing}, - {0, nil, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF}, - {0, io.EOF, &netFD{sotype: syscall.SOCK_SEQPACKET}, io.EOF}, - {0, errClosing, &netFD{sotype: syscall.SOCK_SEQPACKET}, errClosing}, - - {100, nil, &netFD{sotype: syscall.SOCK_RAW}, nil}, - {100, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF}, - {100, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing}, - {0, nil, &netFD{sotype: syscall.SOCK_RAW}, nil}, - {0, io.EOF, &netFD{sotype: syscall.SOCK_RAW}, io.EOF}, - {0, errClosing, &netFD{sotype: syscall.SOCK_RAW}, errClosing}, -} - -func TestEOFError(t *testing.T) { - for _, tt := range eofErrorTests { - actual := tt.fd.eofError(tt.n, tt.err) - if actual != tt.expected { - t.Errorf("eofError(%v, %v, %v): expected %v, actual %v", tt.n, tt.err, tt.fd.sotype, tt.expected, actual) - } - } -} diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index b6ee059..e5afd1a 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -8,7 +8,7 @@ package net import ( "context" - "io" + "internal/poll" "os" "runtime" "sync/atomic" @@ -17,38 +17,33 @@ import ( // Network file descriptor. type netFD struct { - // locking/lifetime of sysfd + serialize access to Read and Write methods - fdmu fdMutex + pfd poll.FD // immutable until Close - sysfd int family int sotype int - isStream bool isConnected bool net string laddr Addr raddr Addr - - // writev cache. - iovecs *[]syscall.Iovec - - // wait server - pd pollDesc -} - -func sysInit() { } func newFD(sysfd, family, sotype int, net string) (*netFD, error) { - return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil + ret := &netFD{ + pfd: poll.FD{ + Sysfd: sysfd, + IsStream: sotype == syscall.SOCK_STREAM, + ZeroReadIsEOF: sotype != syscall.SOCK_DGRAM && sotype != syscall.SOCK_RAW, + }, + family: family, + sotype: sotype, + net: net, + } + return ret, nil } func (fd *netFD) init() error { - if err := fd.pd.init(fd); err != nil { - return err - } - return nil + return fd.pfd.Init(fd.net, true) } func (fd *netFD) setAddr(laddr, raddr Addr) { @@ -68,22 +63,23 @@ func (fd *netFD) name() string { return fd.net + ":" + ls + "->" + rs } -func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret error) { +func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa syscall.Sockaddr, ret error) { // Do not need to call fd.writeLock here, // because fd is not yet accessible to user, // so no concurrent operations are possible. - switch err := connectFunc(fd.sysfd, ra); err { + switch err := connectFunc(fd.pfd.Sysfd, ra); err { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: case nil, syscall.EISCONN: select { case <-ctx.Done(): - return mapErr(ctx.Err()) + return nil, mapErr(ctx.Err()) default: } - if err := fd.init(); err != nil { - return err + if err := fd.pfd.Init(fd.net, true); err != nil { + return nil, err } - return nil + runtime.KeepAlive(fd) + return nil, nil case syscall.EINVAL: // On Solaris we can see EINVAL if the socket has // already been accepted and closed by the server. @@ -91,18 +87,18 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret erro // the socket will see EOF. For details and a test // case in C see https://golang.org/issue/6828. if runtime.GOOS == "solaris" { - return nil + return nil, nil } fallthrough default: - return os.NewSyscallError("connect", err) + return nil, os.NewSyscallError("connect", err) } - if err := fd.init(); err != nil { - return err + if err := fd.pfd.Init(fd.net, true); err != nil { + return nil, err } if deadline, _ := ctx.Deadline(); !deadline.IsZero() { - fd.setWriteDeadline(deadline) - defer fd.setWriteDeadline(noDeadline) + fd.pfd.SetWriteDeadline(deadline) + defer fd.pfd.SetWriteDeadline(noDeadline) } // Start the "interrupter" goroutine, if this context might be canceled. @@ -119,7 +115,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret erro defer func() { close(done) if ctxErr := <-interruptRes; ctxErr != nil && ret == nil { - // The interrupter goroutine called setWriteDeadline, + // The interrupter goroutine called SetWriteDeadline, // but the connect code below had returned from // waitWrite already and did a successful connect (ret // == nil). Because we've now poisoned the connection @@ -135,7 +131,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret erro // Force the runtime's poller to immediately give up // waiting for writability, unblocking waitWrite // below. - fd.setWriteDeadline(aLongTimeAgo) + fd.pfd.SetWriteDeadline(aLongTimeAgo) testHookCanceledDial() interruptRes <- ctx.Err() case <-done: @@ -153,66 +149,45 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret erro // SO_ERROR socket option to see if the connection // succeeded or failed. See issue 7474 for further // details. - if err := fd.pd.waitWrite(); err != nil { + if err := fd.pfd.WaitWrite(); err != nil { select { case <-ctx.Done(): - return mapErr(ctx.Err()) + return nil, mapErr(ctx.Err()) default: } - return err + return nil, err } - nerr, err := getsockoptIntFunc(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) + nerr, err := getsockoptIntFunc(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR) if err != nil { - return os.NewSyscallError("getsockopt", err) + return nil, os.NewSyscallError("getsockopt", err) } switch err := syscall.Errno(nerr); err { case syscall.EINPROGRESS, syscall.EALREADY, syscall.EINTR: - case syscall.Errno(0), syscall.EISCONN: - if runtime.GOOS != "darwin" { - return nil - } - // See golang.org/issue/14548. - // On Darwin, multiple connect system calls on - // a non-blocking socket never harm SO_ERROR. - switch err := connectFunc(fd.sysfd, ra); err { - case nil, syscall.EISCONN: - return nil + case syscall.EISCONN: + return nil, nil + case syscall.Errno(0): + // The runtime poller can wake us up spuriously; + // see issues 14548 and 19289. Check that we are + // really connected; if not, wait again. + if rsa, err := syscall.Getpeername(fd.pfd.Sysfd); err == nil { + return rsa, nil } default: - return os.NewSyscallError("getsockopt", err) + return nil, os.NewSyscallError("getsockopt", err) } + runtime.KeepAlive(fd) } } -func (fd *netFD) destroy() { - // Poller may want to unregister fd in readiness notification mechanism, - // so this must be executed before closeFunc. - fd.pd.close() - closeFunc(fd.sysfd) - fd.sysfd = -1 - runtime.SetFinalizer(fd, nil) -} - func (fd *netFD) Close() error { - if !fd.fdmu.increfAndClose() { - return errClosing - } - // Unblock any I/O. Once it all unblocks and returns, - // so that it cannot be referring to fd.sysfd anymore, - // the final decref will close fd.sysfd. This should happen - // fairly quickly, since all the I/O is non-blocking, and any - // attempts to block in the pollDesc will return errClosing. - fd.pd.evict() - fd.decref() - return nil + runtime.SetFinalizer(fd, nil) + return fd.pfd.Close() } func (fd *netFD) shutdown(how int) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("shutdown", syscall.Shutdown(fd.sysfd, how)) + err := fd.pfd.Shutdown(how) + runtime.KeepAlive(fd) + return wrapSyscallError("shutdown", err) } func (fd *netFD) closeRead() error { @@ -224,233 +199,59 @@ func (fd *netFD) closeWrite() error { } func (fd *netFD) Read(p []byte) (n int, err error) { - if err := fd.readLock(); err != nil { - return 0, err - } - defer fd.readUnlock() - if len(p) == 0 { - // If the caller wanted a zero byte read, return immediately - // without trying. (But after acquiring the readLock.) Otherwise - // syscall.Read returns 0, nil and eofError turns that into - // io.EOF. - // TODO(bradfitz): make it wait for readability? (Issue 15735) - return 0, nil - } - if err := fd.pd.prepareRead(); err != nil { - return 0, err - } - if fd.isStream && len(p) > 1<<30 { - p = p[:1<<30] - } - for { - n, err = syscall.Read(fd.sysfd, p) - if err != nil { - n = 0 - if err == syscall.EAGAIN { - if err = fd.pd.waitRead(); err == nil { - continue - } - } - } - err = fd.eofError(n, err) - break - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("read", err) - } - return + n, err = fd.pfd.Read(p) + runtime.KeepAlive(fd) + return n, wrapSyscallError("read", err) } func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) { - if err := fd.readLock(); err != nil { - return 0, nil, err - } - defer fd.readUnlock() - if err := fd.pd.prepareRead(); err != nil { - return 0, nil, err - } - for { - n, sa, err = syscall.Recvfrom(fd.sysfd, p, 0) - if err != nil { - n = 0 - if err == syscall.EAGAIN { - if err = fd.pd.waitRead(); err == nil { - continue - } - } - } - err = fd.eofError(n, err) - break - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("recvfrom", err) - } - return + n, sa, err = fd.pfd.ReadFrom(p) + runtime.KeepAlive(fd) + return n, sa, wrapSyscallError("recvfrom", err) } func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) { - if err := fd.readLock(); err != nil { - return 0, 0, 0, nil, err - } - defer fd.readUnlock() - if err := fd.pd.prepareRead(); err != nil { - return 0, 0, 0, nil, err - } - for { - n, oobn, flags, sa, err = syscall.Recvmsg(fd.sysfd, p, oob, 0) - if err != nil { - // TODO(dfc) should n and oobn be set to 0 - if err == syscall.EAGAIN { - if err = fd.pd.waitRead(); err == nil { - continue - } - } - } - err = fd.eofError(n, err) - break - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("recvmsg", err) - } - return + n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob) + runtime.KeepAlive(fd) + return n, oobn, flags, sa, wrapSyscallError("recvmsg", err) } func (fd *netFD) Write(p []byte) (nn int, err error) { - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - if err := fd.pd.prepareWrite(); err != nil { - return 0, err - } - for { - var n int - max := len(p) - if fd.isStream && max-nn > 1<<30 { - max = nn + 1<<30 - } - n, err = syscall.Write(fd.sysfd, p[nn:max]) - if n > 0 { - nn += n - } - if nn == len(p) { - break - } - if err == syscall.EAGAIN { - if err = fd.pd.waitWrite(); err == nil { - continue - } - } - if err != nil { - break - } - if n == 0 { - err = io.ErrUnexpectedEOF - break - } - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("write", err) - } - return nn, err + nn, err = fd.pfd.Write(p) + runtime.KeepAlive(fd) + return nn, wrapSyscallError("write", err) } func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) { - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - if err := fd.pd.prepareWrite(); err != nil { - return 0, err - } - for { - err = syscall.Sendto(fd.sysfd, p, 0, sa) - if err == syscall.EAGAIN { - if err = fd.pd.waitWrite(); err == nil { - continue - } - } - break - } - if err == nil { - n = len(p) - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("sendto", err) - } - return + n, err = fd.pfd.WriteTo(p, sa) + runtime.KeepAlive(fd) + return n, wrapSyscallError("sendto", err) } func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { - if err := fd.writeLock(); err != nil { - return 0, 0, err - } - defer fd.writeUnlock() - if err := fd.pd.prepareWrite(); err != nil { - return 0, 0, err - } - for { - n, err = syscall.SendmsgN(fd.sysfd, p, oob, sa, 0) - if err == syscall.EAGAIN { - if err = fd.pd.waitWrite(); err == nil { - continue - } - } - break - } - if err == nil { - oobn = len(oob) - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("sendmsg", err) - } - return + n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) + runtime.KeepAlive(fd) + return n, oobn, wrapSyscallError("sendmsg", err) } func (fd *netFD) accept() (netfd *netFD, err error) { - if err := fd.readLock(); err != nil { - return nil, err - } - defer fd.readUnlock() - - var s int - var rsa syscall.Sockaddr - if err = fd.pd.prepareRead(); err != nil { - return nil, err - } - for { - s, rsa, err = accept(fd.sysfd) - if err != nil { - nerr, ok := err.(*os.SyscallError) - if !ok { - return nil, err - } - switch nerr.Err { - case syscall.EAGAIN: - if err = fd.pd.waitRead(); err == nil { - continue - } - case syscall.ECONNABORTED: - // This means that a socket on the - // listen queue was closed before we - // Accept()ed it; it's a silly error, - // so try again. - continue - } - return nil, err + d, rsa, errcall, err := fd.pfd.Accept() + if err != nil { + if errcall != "" { + err = wrapSyscallError(errcall, err) } - break + return nil, err } - if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil { - closeFunc(s) + if netfd, err = newFD(d, fd.family, fd.sotype, fd.net); err != nil { + poll.CloseFunc(d) return nil, err } if err = netfd.init(); err != nil { fd.Close() return nil, err } - lsa, _ := syscall.Getsockname(netfd.sysfd) + lsa, _ := syscall.Getsockname(netfd.pfd.Sysfd) netfd.setAddr(netfd.addrFunc()(lsa), netfd.addrFunc()(rsa)) return netfd, nil } @@ -511,7 +312,7 @@ func dupCloseOnExecOld(fd int) (newfd int, err error) { } func (fd *netFD) dup() (f *os.File, err error) { - ns, err := dupCloseOnExec(fd.sysfd) + ns, err := dupCloseOnExec(fd.pfd.Sysfd) if err != nil { return nil, err } diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go index a976f2a..c2156b2 100644 --- a/libgo/go/net/fd_windows.go +++ b/libgo/go/net/fd_windows.go @@ -6,64 +6,13 @@ package net import ( "context" - "internal/race" + "internal/poll" "os" "runtime" - "sync" "syscall" "unsafe" ) -var ( - initErr error - ioSync uint64 -) - -// CancelIo Windows API cancels all outstanding IO for a particular -// socket on current thread. To overcome that limitation, we run -// special goroutine, locked to OS single thread, that both starts -// and cancels IO. It means, there are 2 unavoidable thread switches -// for every IO. -// Some newer versions of Windows has new CancelIoEx API, that does -// not have that limitation and can be used from any thread. This -// package uses CancelIoEx API, if present, otherwise it fallback -// to CancelIo. - -var ( - canCancelIO bool // determines if CancelIoEx API is present - skipSyncNotif bool - hasLoadSetFileCompletionNotificationModes bool -) - -func sysInit() { - var d syscall.WSAData - e := syscall.WSAStartup(uint32(0x202), &d) - if e != nil { - initErr = os.NewSyscallError("wsastartup", e) - } - canCancelIO = syscall.LoadCancelIoEx() == nil - hasLoadSetFileCompletionNotificationModes = syscall.LoadSetFileCompletionNotificationModes() == nil - if hasLoadSetFileCompletionNotificationModes { - // It's not safe to use FILE_SKIP_COMPLETION_PORT_ON_SUCCESS if non IFS providers are installed: - // http://support.microsoft.com/kb/2568167 - skipSyncNotif = true - protos := [2]int32{syscall.IPPROTO_TCP, 0} - var buf [32]syscall.WSAProtocolInfo - len := uint32(unsafe.Sizeof(buf)) - n, err := syscall.WSAEnumProtocols(&protos[0], &buf[0], &len) - if err != nil { - skipSyncNotif = false - } else { - for i := int32(0); i < n; i++ { - if buf[i].ServiceFlags1&syscall.XP1_IFS_HANDLES == 0 { - skipSyncNotif = false - break - } - } - } - } -} - // canUseConnectEx reports whether we can use the ConnectEx Windows API call // for the given network type. func canUseConnectEx(net string) bool { @@ -75,257 +24,39 @@ func canUseConnectEx(net string) bool { return false } -// operation contains superset of data necessary to perform all async IO. -type operation struct { - // Used by IOCP interface, it must be first field - // of the struct, as our code rely on it. - o syscall.Overlapped - - // fields used by runtime.netpoll - runtimeCtx uintptr - mode int32 - errno int32 - qty uint32 - - // fields used only by net package - fd *netFD - errc chan error - buf syscall.WSABuf - sa syscall.Sockaddr - rsa *syscall.RawSockaddrAny - rsan int32 - handle syscall.Handle - flags uint32 - bufs []syscall.WSABuf -} - -func (o *operation) InitBuf(buf []byte) { - o.buf.Len = uint32(len(buf)) - o.buf.Buf = nil - if len(buf) != 0 { - o.buf.Buf = &buf[0] - } -} - -func (o *operation) InitBufs(buf *Buffers) { - if o.bufs == nil { - o.bufs = make([]syscall.WSABuf, 0, len(*buf)) - } else { - o.bufs = o.bufs[:0] - } - for _, b := range *buf { - var p *byte - if len(b) > 0 { - p = &b[0] - } - o.bufs = append(o.bufs, syscall.WSABuf{Len: uint32(len(b)), Buf: p}) - } -} - -// ClearBufs clears all pointers to Buffers parameter captured -// by InitBufs, so it can be released by garbage collector. -func (o *operation) ClearBufs() { - for i := range o.bufs { - o.bufs[i].Buf = nil - } - o.bufs = o.bufs[:0] -} - -// ioSrv executes net IO requests. -type ioSrv struct { - req chan ioSrvReq -} - -type ioSrvReq struct { - o *operation - submit func(o *operation) error // if nil, cancel the operation -} - -// ProcessRemoteIO will execute submit IO requests on behalf -// of other goroutines, all on a single os thread, so it can -// cancel them later. Results of all operations will be sent -// back to their requesters via channel supplied in request. -// It is used only when the CancelIoEx API is unavailable. -func (s *ioSrv) ProcessRemoteIO() { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - for r := range s.req { - if r.submit != nil { - r.o.errc <- r.submit(r.o) - } else { - r.o.errc <- syscall.CancelIo(r.o.fd.sysfd) - } - } -} - -// ExecIO executes a single IO operation o. It submits and cancels -// IO in the current thread for systems where Windows CancelIoEx API -// is available. Alternatively, it passes the request onto -// runtime netpoll and waits for completion or cancels request. -func (s *ioSrv) ExecIO(o *operation, name string, submit func(o *operation) error) (int, error) { - fd := o.fd - // Notify runtime netpoll about starting IO. - err := fd.pd.prepare(int(o.mode)) - if err != nil { - return 0, err - } - // Start IO. - if canCancelIO { - err = submit(o) - } else { - // Send request to a special dedicated thread, - // so it can stop the IO with CancelIO later. - s.req <- ioSrvReq{o, submit} - err = <-o.errc - } - switch err { - case nil: - // IO completed immediately - if o.fd.skipSyncNotif { - // No completion message will follow, so return immediately. - return int(o.qty), nil - } - // Need to get our completion message anyway. - case syscall.ERROR_IO_PENDING: - // IO started, and we have to wait for its completion. - err = nil - default: - return 0, err - } - // Wait for our request to complete. - err = fd.pd.wait(int(o.mode)) - if err == nil { - // All is good. Extract our IO results and return. - if o.errno != 0 { - err = syscall.Errno(o.errno) - return 0, err - } - return int(o.qty), nil - } - // IO is interrupted by "close" or "timeout" - netpollErr := err - switch netpollErr { - case errClosing, errTimeout: - // will deal with those. - default: - panic("net: unexpected runtime.netpoll error: " + netpollErr.Error()) - } - // Cancel our request. - if canCancelIO { - err := syscall.CancelIoEx(fd.sysfd, &o.o) - // Assuming ERROR_NOT_FOUND is returned, if IO is completed. - if err != nil && err != syscall.ERROR_NOT_FOUND { - // TODO(brainman): maybe do something else, but panic. - panic(err) - } - } else { - s.req <- ioSrvReq{o, nil} - <-o.errc - } - // Wait for cancelation to complete. - fd.pd.waitCanceled(int(o.mode)) - if o.errno != 0 { - err = syscall.Errno(o.errno) - if err == syscall.ERROR_OPERATION_ABORTED { // IO Canceled - err = netpollErr - } - return 0, err - } - // We issued a cancelation request. But, it seems, IO operation succeeded - // before the cancelation request run. We need to treat the IO operation as - // succeeded (the bytes are actually sent/recv from network). - return int(o.qty), nil -} - -// Start helper goroutines. -var rsrv, wsrv *ioSrv -var onceStartServer sync.Once - -func startServer() { - rsrv = new(ioSrv) - wsrv = new(ioSrv) - if !canCancelIO { - // Only CancelIo API is available. Lets start two special goroutines - // locked to an OS thread, that both starts and cancels IO. One will - // process read requests, while other will do writes. - rsrv.req = make(chan ioSrvReq) - go rsrv.ProcessRemoteIO() - wsrv.req = make(chan ioSrvReq) - go wsrv.ProcessRemoteIO() - } -} - // Network file descriptor. type netFD struct { - // locking/lifetime of sysfd + serialize access to Read and Write methods - fdmu fdMutex + pfd poll.FD // immutable until Close - sysfd syscall.Handle - family int - sotype int - isStream bool - isConnected bool - skipSyncNotif bool - net string - laddr Addr - raddr Addr - - rop operation // read operation - wop operation // write operation - - // wait server - pd pollDesc + family int + sotype int + isConnected bool + net string + laddr Addr + raddr Addr } func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) { - if initErr != nil { - return nil, initErr + ret := &netFD{ + pfd: poll.FD{ + Sysfd: sysfd, + IsStream: sotype == syscall.SOCK_STREAM, + ZeroReadIsEOF: sotype != syscall.SOCK_DGRAM && sotype != syscall.SOCK_RAW, + }, + family: family, + sotype: sotype, + net: net, } - onceStartServer.Do(startServer) - return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil + return ret, nil } func (fd *netFD) init() error { - if err := fd.pd.init(fd); err != nil { - return err - } - if hasLoadSetFileCompletionNotificationModes { - // We do not use events, so we can skip them always. - flags := uint8(syscall.FILE_SKIP_SET_EVENT_ON_HANDLE) - // It's not safe to skip completion notifications for UDP: - // http://blogs.technet.com/b/winserverperformance/archive/2008/06/26/designing-applications-for-high-performance-part-iii.aspx - if skipSyncNotif && fd.net == "tcp" { - flags |= syscall.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS - } - err := syscall.SetFileCompletionNotificationModes(fd.sysfd, flags) - if err == nil && flags&syscall.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS != 0 { - fd.skipSyncNotif = true - } - } - // Disable SIO_UDP_CONNRESET behavior. - // http://support.microsoft.com/kb/263823 - switch fd.net { - case "udp", "udp4", "udp6": - ret := uint32(0) - flag := uint32(0) - size := uint32(unsafe.Sizeof(flag)) - err := syscall.WSAIoctl(fd.sysfd, syscall.SIO_UDP_CONNRESET, (*byte)(unsafe.Pointer(&flag)), size, nil, 0, &ret, nil, 0) - if err != nil { - return os.NewSyscallError("wsaioctl", err) - } + errcall, err := fd.pfd.Init(fd.net) + if errcall != "" { + err = wrapSyscallError(errcall, err) } - fd.rop.mode = 'r' - fd.wop.mode = 'w' - fd.rop.fd = fd - fd.wop.fd = fd - fd.rop.runtimeCtx = fd.pd.runtimeCtx - fd.wop.runtimeCtx = fd.pd.runtimeCtx - if !canCancelIO { - fd.rop.errc = make(chan error) - fd.wop.errc = make(chan error) - } - return nil + return err } func (fd *netFD) setAddr(laddr, raddr Addr) { @@ -334,20 +65,21 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { runtime.SetFinalizer(fd, (*netFD).Close) } -func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { +// Always returns nil for connected peer address result. +func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) { // Do not need to call fd.writeLock here, // because fd is not yet accessible to user, // so no concurrent operations are possible. if err := fd.init(); err != nil { - return err + return nil, err } if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { - fd.setWriteDeadline(deadline) - defer fd.setWriteDeadline(noDeadline) + fd.pfd.SetWriteDeadline(deadline) + defer fd.pfd.SetWriteDeadline(noDeadline) } if !canUseConnectEx(fd.net) { - err := connectFunc(fd.sysfd, ra) - return os.NewSyscallError("connect", err) + err := connectFunc(fd.pfd.Sysfd, ra) + return nil, os.NewSyscallError("connect", err) } // ConnectEx windows API requires an unconnected, previously bound socket. if la == nil { @@ -359,13 +91,10 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { default: panic("unexpected type in connect") } - if err := syscall.Bind(fd.sysfd, la); err != nil { - return os.NewSyscallError("bind", err) + if err := syscall.Bind(fd.pfd.Sysfd, la); err != nil { + return nil, os.NewSyscallError("bind", err) } } - // Call ConnectEx API. - o := &fd.wop - o.sa = ra // Wait for the goroutine converting context.Done into a write timeout // to exist, otherwise our caller might cancel the context and @@ -377,59 +106,37 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { case <-ctx.Done(): // Force the runtime's poller to immediately give // up waiting for writability. - fd.setWriteDeadline(aLongTimeAgo) + fd.pfd.SetWriteDeadline(aLongTimeAgo) <-done case <-done: } }() - _, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error { - return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o) - }) - if err != nil { + // Call ConnectEx API. + if err := fd.pfd.ConnectEx(ra); err != nil { select { case <-ctx.Done(): - return mapErr(ctx.Err()) + return nil, mapErr(ctx.Err()) default: if _, ok := err.(syscall.Errno); ok { err = os.NewSyscallError("connectex", err) } - return err + return nil, err } } // Refresh socket properties. - return os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))) -} - -func (fd *netFD) destroy() { - if fd.sysfd == syscall.InvalidHandle { - return - } - // Poller may want to unregister fd in readiness notification mechanism, - // so this must be executed before closeFunc. - fd.pd.close() - closeFunc(fd.sysfd) - fd.sysfd = syscall.InvalidHandle - // no need for a finalizer anymore - runtime.SetFinalizer(fd, nil) + return nil, os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.pfd.Sysfd)), int32(unsafe.Sizeof(fd.pfd.Sysfd)))) } func (fd *netFD) Close() error { - if !fd.fdmu.increfAndClose() { - return errClosing - } - // unblock pending reader and writer - fd.pd.evict() - fd.decref() - return nil + runtime.SetFinalizer(fd, nil) + return fd.pfd.Close() } func (fd *netFD) shutdown(how int) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return syscall.Shutdown(fd.sysfd, how) + err := fd.pfd.Shutdown(how) + runtime.KeepAlive(fd) + return err } func (fd *netFD) closeRead() error { @@ -441,72 +148,21 @@ func (fd *netFD) closeWrite() error { } func (fd *netFD) Read(buf []byte) (int, error) { - if err := fd.readLock(); err != nil { - return 0, err - } - defer fd.readUnlock() - o := &fd.rop - o.InitBuf(buf) - n, err := rsrv.ExecIO(o, "WSARecv", func(o *operation) error { - return syscall.WSARecv(o.fd.sysfd, &o.buf, 1, &o.qty, &o.flags, &o.o, nil) - }) - if race.Enabled { - race.Acquire(unsafe.Pointer(&ioSync)) - } - if len(buf) != 0 { - err = fd.eofError(n, err) - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsarecv", err) - } - return n, err + n, err := fd.pfd.Read(buf) + runtime.KeepAlive(fd) + return n, wrapSyscallError("wsarecv", err) } func (fd *netFD) readFrom(buf []byte) (int, syscall.Sockaddr, error) { - if len(buf) == 0 { - return 0, nil, nil - } - if err := fd.readLock(); err != nil { - return 0, nil, err - } - defer fd.readUnlock() - o := &fd.rop - o.InitBuf(buf) - n, err := rsrv.ExecIO(o, "WSARecvFrom", func(o *operation) error { - if o.rsa == nil { - o.rsa = new(syscall.RawSockaddrAny) - } - o.rsan = int32(unsafe.Sizeof(*o.rsa)) - return syscall.WSARecvFrom(o.fd.sysfd, &o.buf, 1, &o.qty, &o.flags, o.rsa, &o.rsan, &o.o, nil) - }) - err = fd.eofError(n, err) - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsarecvfrom", err) - } - if err != nil { - return n, nil, err - } - sa, _ := o.rsa.Sockaddr() - return n, sa, nil + n, sa, err := fd.pfd.ReadFrom(buf) + runtime.KeepAlive(fd) + return n, sa, wrapSyscallError("wsarecvfrom", err) } func (fd *netFD) Write(buf []byte) (int, error) { - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - if race.Enabled { - race.ReleaseMerge(unsafe.Pointer(&ioSync)) - } - o := &fd.wop - o.InitBuf(buf) - n, err := wsrv.ExecIO(o, "WSASend", func(o *operation) error { - return syscall.WSASend(o.fd.sysfd, &o.buf, 1, &o.qty, 0, &o.o, nil) - }) - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsasend", err) - } - return n, err + n, err := fd.pfd.Write(buf) + runtime.KeepAlive(fd) + return n, wrapSyscallError("wsasend", err) } func (c *conn) writeBuffers(v *Buffers) (int64, error) { @@ -515,67 +171,39 @@ func (c *conn) writeBuffers(v *Buffers) (int64, error) { } n, err := c.fd.writeBuffers(v) if err != nil { - return n, &OpError{Op: "WSASend", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + return n, &OpError{Op: "wsasend", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} } return n, nil } func (fd *netFD) writeBuffers(buf *Buffers) (int64, error) { - if len(*buf) == 0 { - return 0, nil - } - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - if race.Enabled { - race.ReleaseMerge(unsafe.Pointer(&ioSync)) - } - o := &fd.wop - o.InitBufs(buf) - n, err := wsrv.ExecIO(o, "WSASend", func(o *operation) error { - return syscall.WSASend(o.fd.sysfd, &o.bufs[0], uint32(len(*buf)), &o.qty, 0, &o.o, nil) - }) - o.ClearBufs() - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsasend", err) - } - testHookDidWritev(n) - buf.consume(int64(n)) - return int64(n), err + n, err := fd.pfd.Writev((*[][]byte)(buf)) + runtime.KeepAlive(fd) + return n, wrapSyscallError("wsasend", err) } func (fd *netFD) writeTo(buf []byte, sa syscall.Sockaddr) (int, error) { - if len(buf) == 0 { - return 0, nil - } - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - o := &fd.wop - o.InitBuf(buf) - o.sa = sa - n, err := wsrv.ExecIO(o, "WSASendto", func(o *operation) error { - return syscall.WSASendto(o.fd.sysfd, &o.buf, 1, &o.qty, 0, o.sa, &o.o, nil) - }) - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsasendto", err) - } - return n, err + n, err := fd.pfd.WriteTo(buf, sa) + runtime.KeepAlive(fd) + return n, wrapSyscallError("wsasendto", err) } -func (fd *netFD) acceptOne(rawsa []syscall.RawSockaddrAny, o *operation) (*netFD, error) { - // Get new socket. - s, err := sysSocket(fd.family, fd.sotype, 0) +func (fd *netFD) accept() (*netFD, error) { + s, rawsa, rsan, errcall, err := fd.pfd.Accept(func() (syscall.Handle, error) { + return sysSocket(fd.family, fd.sotype, 0) + }) + if err != nil { + if errcall != "" { + err = wrapSyscallError(errcall, err) + } return nil, err } // Associate our new socket with IOCP. netfd, err := newFD(s, fd.family, fd.sotype, fd.net) if err != nil { - closeFunc(s) + poll.CloseFunc(s) return nil, err } if err := netfd.init(); err != nil { @@ -583,71 +211,11 @@ func (fd *netFD) acceptOne(rawsa []syscall.RawSockaddrAny, o *operation) (*netFD return nil, err } - // Submit accept request. - o.handle = s - o.rsan = int32(unsafe.Sizeof(rawsa[0])) - _, err = rsrv.ExecIO(o, "AcceptEx", func(o *operation) error { - return acceptFunc(o.fd.sysfd, o.handle, (*byte)(unsafe.Pointer(&rawsa[0])), 0, uint32(o.rsan), uint32(o.rsan), &o.qty, &o.o) - }) - if err != nil { - netfd.Close() - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("acceptex", err) - } - return nil, err - } - - // Inherit properties of the listening socket. - err = syscall.Setsockopt(s, syscall.SOL_SOCKET, syscall.SO_UPDATE_ACCEPT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))) - if err != nil { - netfd.Close() - return nil, os.NewSyscallError("setsockopt", err) - } - runtime.KeepAlive(fd) - return netfd, nil -} - -func (fd *netFD) accept() (*netFD, error) { - if err := fd.readLock(); err != nil { - return nil, err - } - defer fd.readUnlock() - - o := &fd.rop - var netfd *netFD - var err error - var rawsa [2]syscall.RawSockaddrAny - for { - netfd, err = fd.acceptOne(rawsa[:], o) - if err == nil { - break - } - // Sometimes we see WSAECONNRESET and ERROR_NETNAME_DELETED is - // returned here. These happen if connection reset is received - // before AcceptEx could complete. These errors relate to new - // connection, not to AcceptEx, so ignore broken connection and - // try AcceptEx again for more connections. - nerr, ok := err.(*os.SyscallError) - if !ok { - return nil, err - } - errno, ok := nerr.Err.(syscall.Errno) - if !ok { - return nil, err - } - switch errno { - case syscall.ERROR_NETNAME_DELETED, syscall.WSAECONNRESET: - // ignore these and try again - default: - return nil, err - } - } - // Get local and peer addr out of AcceptEx buffer. var lrsa, rrsa *syscall.RawSockaddrAny var llen, rlen int32 syscall.GetAcceptExSockaddrs((*byte)(unsafe.Pointer(&rawsa[0])), - 0, uint32(o.rsan), uint32(o.rsan), &lrsa, &llen, &rrsa, &rlen) + 0, rsan, rsan, &lrsa, &llen, &rrsa, &rlen) lsa, _ := lrsa.Sockaddr() rsa, _ := rrsa.Sockaddr() diff --git a/libgo/go/net/file_test.go b/libgo/go/net/file_test.go index 6566ce2..abf8b3a 100644 --- a/libgo/go/net/file_test.go +++ b/libgo/go/net/file_test.go @@ -90,7 +90,7 @@ func TestFileConn(t *testing.T) { f, err = c1.File() } if err := c1.Close(); err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Error(err) @@ -256,7 +256,7 @@ func TestFilePacketConn(t *testing.T) { f, err = c1.File() } if err := c1.Close(); err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Error(err) diff --git a/libgo/go/net/file_unix.go b/libgo/go/net/file_unix.go index b47a614..3655a89 100644 --- a/libgo/go/net/file_unix.go +++ b/libgo/go/net/file_unix.go @@ -7,6 +7,7 @@ package net import ( + "internal/poll" "os" "syscall" ) @@ -17,7 +18,7 @@ func dupSocket(f *os.File) (int, error) { return -1, err } if err := syscall.SetNonblock(s, true); err != nil { - closeFunc(s) + poll.CloseFunc(s) return -1, os.NewSyscallError("setnonblock", err) } return s, nil @@ -31,7 +32,7 @@ func newFileFD(f *os.File) (*netFD, error) { family := syscall.AF_UNSPEC sotype, err := syscall.GetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_TYPE) if err != nil { - closeFunc(s) + poll.CloseFunc(s) return nil, os.NewSyscallError("getsockopt", err) } lsa, _ := syscall.Getsockname(s) @@ -44,12 +45,12 @@ func newFileFD(f *os.File) (*netFD, error) { case *syscall.SockaddrUnix: family = syscall.AF_UNIX default: - closeFunc(s) + poll.CloseFunc(s) return nil, syscall.EPROTONOSUPPORT } fd, err := newFD(s, family, sotype, "") if err != nil { - closeFunc(s) + poll.CloseFunc(s) return nil, err } laddr := fd.addrFunc()(lsa) diff --git a/libgo/go/net/hook_cloexec.go b/libgo/go/net/hook_cloexec.go deleted file mode 100644 index 870f0d7..0000000 --- a/libgo/go/net/hook_cloexec.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright 2015 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build freebsd linux - -package net - -import "syscall" - -var ( - // Placeholders for socket system calls. - accept4Func func(int, int) (int, syscall.Sockaddr, error) = syscall.Accept4 -) diff --git a/libgo/go/net/hook_unix.go b/libgo/go/net/hook_unix.go index b2522a2..7d58d0f 100644 --- a/libgo/go/net/hook_unix.go +++ b/libgo/go/net/hook_unix.go @@ -13,10 +13,8 @@ var ( testHookCanceledDial = func() {} // for golang.org/issue/16523 // Placeholders for socket system calls. - socketFunc func(int, int, int) (int, error) = syscall.Socket - closeFunc func(int) error = syscall.Close - connectFunc func(int, syscall.Sockaddr) error = syscall.Connect - listenFunc func(int, int) error = syscall.Listen - acceptFunc func(int) (int, syscall.Sockaddr, error) = syscall.Accept - getsockoptIntFunc func(int, int, int) (int, error) = syscall.GetsockoptInt + socketFunc func(int, int, int) (int, error) = syscall.Socket + connectFunc func(int, syscall.Sockaddr) error = syscall.Connect + listenFunc func(int, int) error = syscall.Listen + getsockoptIntFunc func(int, int, int) (int, error) = syscall.GetsockoptInt ) diff --git a/libgo/go/net/hook_windows.go b/libgo/go/net/hook_windows.go index 63ea35a..4e64dce 100644 --- a/libgo/go/net/hook_windows.go +++ b/libgo/go/net/hook_windows.go @@ -13,10 +13,7 @@ var ( testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349 // Placeholders for socket system calls. - socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket - closeFunc func(syscall.Handle) error = syscall.Closesocket - connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect - connectExFunc func(syscall.Handle, syscall.Sockaddr, *byte, uint32, *uint32, *syscall.Overlapped) error = syscall.ConnectEx - listenFunc func(syscall.Handle, int) error = syscall.Listen - acceptFunc func(syscall.Handle, syscall.Handle, *byte, uint32, uint32, uint32, *uint32, *syscall.Overlapped) error = syscall.AcceptEx + socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket + connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect + listenFunc func(syscall.Handle, int) error = syscall.Listen ) diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index f058372..1336300 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -409,7 +409,7 @@ func TestCopyError(t *testing.T) { } childRunning := func() bool { - return isProcessRunning(t, pid) + return isProcessRunning(pid) } if !childRunning() { diff --git a/libgo/go/net/http/cgi/posix_test.go b/libgo/go/net/http/cgi/posix_test.go index 5ff9e7d..9396ce0 100644 --- a/libgo/go/net/http/cgi/posix_test.go +++ b/libgo/go/net/http/cgi/posix_test.go @@ -9,10 +9,9 @@ package cgi import ( "os" "syscall" - "testing" ) -func isProcessRunning(t *testing.T, pid int) bool { +func isProcessRunning(pid int) bool { p, err := os.FindProcess(pid) if err != nil { return false diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 0005538..4c9084a 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -38,20 +38,20 @@ import ( // When following redirects, the Client will forward all headers set on the // initial Request except: // -// * when forwarding sensitive headers like "Authorization", -// "WWW-Authenticate", and "Cookie" to untrusted targets. -// These headers will be ignored when following a redirect to a domain -// that is not a subdomain match or exact match of the initial domain. -// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" -// will forward the sensitive headers, but a redirect to "bar.com" will not. -// -// * when forwarding the "Cookie" header with a non-nil cookie Jar. -// Since each redirect may mutate the state of the cookie jar, -// a redirect may possibly alter a cookie set in the initial request. -// When forwarding the "Cookie" header, any mutated cookies will be omitted, -// with the expectation that the Jar will insert those mutated cookies -// with the updated values (assuming the origin matches). -// If Jar is nil, the initial cookies are forwarded without change. +// • when forwarding sensitive headers like "Authorization", +// "WWW-Authenticate", and "Cookie" to untrusted targets. +// These headers will be ignored when following a redirect to a domain +// that is not a subdomain match or exact match of the initial domain. +// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" +// will forward the sensitive headers, but a redirect to "bar.com" will not. +// +// • when forwarding the "Cookie" header with a non-nil cookie Jar. +// Since each redirect may mutate the state of the cookie jar, +// a redirect may possibly alter a cookie set in the initial request. +// When forwarding the "Cookie" header, any mutated cookies will be omitted, +// with the expectation that the Jar will insert those mutated cookies +// with the updated values (assuming the origin matches). +// If Jar is nil, the initial cookies are forwarded without change. // type Client struct { // Transport specifies the mechanism by which individual @@ -494,17 +494,21 @@ func (c *Client) Do(req *Request) (*Response, error) { } var ( - deadline = c.deadline() - reqs []*Request - resp *Response - copyHeaders = c.makeHeadersCopier(req) + deadline = c.deadline() + reqs []*Request + resp *Response + copyHeaders = c.makeHeadersCopier(req) + reqBodyClosed = false // have we closed the current req.Body? // Redirect behavior: redirectMethod string includeBody bool ) uerr := func(err error) error { - req.closeBody() + // the body may have been closed already by c.send() + if !reqBodyClosed { + req.closeBody() + } method := valueOrDefault(reqs[0].Method, "GET") var urlStr string if resp != nil && resp.Request != nil { @@ -524,10 +528,12 @@ func (c *Client) Do(req *Request) (*Response, error) { if len(reqs) > 0 { loc := resp.Header.Get("Location") if loc == "" { + resp.closeBody() return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode)) } u, err := req.URL.Parse(loc) if err != nil { + resp.closeBody() return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) } ireq := reqs[0] @@ -542,6 +548,7 @@ func (c *Client) Do(req *Request) (*Response, error) { if includeBody && ireq.GetBody != nil { req.Body, err = ireq.GetBody() if err != nil { + resp.closeBody() return nil, uerr(err) } req.ContentLength = ireq.ContentLength @@ -593,6 +600,8 @@ func (c *Client) Do(req *Request) (*Response, error) { var err error var didTimeout func() bool if resp, didTimeout, err = c.send(req, deadline); err != nil { + // c.send() always closes req.Body + reqBodyClosed = true if !deadline.IsZero() && didTimeout() { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 4f674dd..b9a1c31 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -10,7 +10,6 @@ import ( "bytes" "context" "crypto/tls" - "crypto/x509" "encoding/base64" "errors" "fmt" @@ -73,7 +72,7 @@ func TestClient(t *testing.T) { ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() - c := &Client{Transport: &Transport{DisableKeepAlives: true}} + c := ts.Client() r, err := c.Get(ts.URL) var b []byte if err == nil { @@ -220,10 +219,7 @@ func TestClientRedirects(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - - c := &Client{Transport: tr} + c := ts.Client() _, err := c.Get(ts.URL) if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { t.Errorf("with default client Get, expected error %q, got %q", e, g) @@ -252,13 +248,10 @@ func TestClientRedirects(t *testing.T) { var checkErr error var lastVia []*Request var lastReq *Request - c = &Client{ - Transport: tr, - CheckRedirect: func(req *Request, via []*Request) error { - lastReq = req - lastVia = via - return checkErr - }, + c.CheckRedirect = func(req *Request, via []*Request) error { + lastReq = req + lastVia = via + return checkErr } res, err := c.Get(ts.URL) if err != nil { @@ -304,6 +297,7 @@ func TestClientRedirects(t *testing.T) { } } +// Tests that Client redirects' contexts are derived from the original request's context. func TestClientRedirectContext(t *testing.T) { setParallel(t) defer afterTest(t) @@ -312,19 +306,16 @@ func TestClientRedirectContext(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - ctx, cancel := context.WithCancel(context.Background()) - c := &Client{ - Transport: tr, - CheckRedirect: func(req *Request, via []*Request) error { - cancel() - if len(via) > 2 { - return errors.New("too many redirects") - } + c := ts.Client() + c.CheckRedirect = func(req *Request, via []*Request) error { + cancel() + select { + case <-req.Context().Done(): return nil - }, + case <-time.After(5 * time.Second): + return errors.New("redirected request's context never expired after root request canceled") + } } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(ctx) @@ -458,11 +449,12 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa })) defer ts.Close() + c := ts.Client() for _, tt := range table { content := tt.redirectBody req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil } - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if err != nil { t.Fatal(err) @@ -516,17 +508,12 @@ func TestClientRedirectUseResponse(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - - c := &Client{ - Transport: tr, - CheckRedirect: func(req *Request, via []*Request) error { - if req.Response == nil { - t.Error("expected non-nil Request.Response") - } - return ErrUseLastResponse - }, + c := ts.Client() + c.CheckRedirect = func(req *Request, via []*Request) error { + if req.Response == nil { + t.Error("expected non-nil Request.Response") + } + return ErrUseLastResponse } res, err := c.Get(ts.URL) if err != nil { @@ -555,7 +542,8 @@ func TestClientRedirect308NoLocation(t *testing.T) { w.WriteHeader(308) })) defer ts.Close() - res, err := Get(ts.URL) + c := ts.Client() + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -582,8 +570,9 @@ func TestClientRedirect308NoGetBody(t *testing.T) { if err != nil { t.Fatal(err) } + c := ts.Client() req.GetBody = nil // so it can't rewind. - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if err != nil { t.Fatal(err) } @@ -673,12 +662,8 @@ func TestRedirectCookiesJar(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{ - Transport: tr, - Jar: new(TestJar), - } + c := ts.Client() + c.Jar = new(TestJar) u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) resp, err := c.Get(ts.URL) @@ -722,13 +707,10 @@ func TestJarCalls(t *testing.T) { })) defer ts.Close() jar := new(RecordingJar) - c := &Client{ - Jar: jar, - Transport: &Transport{ - Dial: func(_ string, _ string) (net.Conn, error) { - return net.Dial("tcp", ts.Listener.Addr().String()) - }, - }, + c := ts.Client() + c.Jar = jar + c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) { + return net.Dial("tcp", ts.Listener.Addr().String()) } _, err := c.Get("http://firsthost.fake/") if err != nil { @@ -840,7 +822,8 @@ func TestClientWrites(t *testing.T) { } return c, err } - c := &Client{Transport: &Transport{Dial: dialer}} + c := ts.Client() + c.Transport.(*Transport).Dial = dialer _, err := c.Get(ts.URL) if err != nil { @@ -873,14 +856,11 @@ func TestClientInsecureTransport(t *testing.T) { // TODO(bradfitz): add tests for skipping hostname checks too? // would require a new cert for testing, and probably // redundant with these tests. + c := ts.Client() for _, insecure := range []bool{true, false} { - tr := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecure, - }, + c.Transport.(*Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: insecure, } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} res, err := c.Get(ts.URL) if (err == nil) != insecure { t.Errorf("insecure=%v: got unexpected err=%v", insecure, err) @@ -914,22 +894,6 @@ func TestClientErrorWithRequestURI(t *testing.T) { } } -func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { - certs := x509.NewCertPool() - for _, c := range ts.TLS.Certificates { - roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) - if err != nil { - t.Fatalf("error parsing server's root cert: %v", err) - } - for _, root := range roots { - certs.AddCert(root) - } - } - return &Transport{ - TLSClientConfig: &tls.Config{RootCAs: certs}, - } -} - func TestClientWithCorrectTLSServerName(t *testing.T) { defer afterTest(t) @@ -941,9 +905,8 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { })) defer ts.Close() - trans := newTLSTransport(t, ts) - trans.TLSClientConfig.ServerName = serverName - c := &Client{Transport: trans} + c := ts.Client() + c.Transport.(*Transport).TLSClientConfig.ServerName = serverName if _, err := c.Get(ts.URL); err != nil { t.Fatalf("expected successful TLS connection, got error: %v", err) } @@ -956,9 +919,8 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) - trans := newTLSTransport(t, ts) - trans.TLSClientConfig.ServerName = "badserver" - c := &Client{Transport: trans} + c := ts.Client() + c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver" _, err := c.Get(ts.URL) if err == nil { t.Fatalf("expected an error") @@ -992,13 +954,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { })) defer ts.Close() - tr := newTLSTransport(t, ts) + c := ts.Client() + tr := c.Transport.(*Transport) tr.TLSClientConfig.ServerName = "example.com" // one of httptest's Server cert names tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} res, err := c.Get("https://some-other-host.tld/") if err != nil { t.Fatal(err) @@ -1013,13 +974,12 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { })) defer ts.Close() - tr := newTLSTransport(t, ts) + c := ts.Client() + tr := c.Transport.(*Transport) tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA} tr.Dial = func(netw, addr string) (net.Conn, error) { return net.Dial(netw, ts.Listener.Addr().String()) } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} res, err := c.Get("https://example.com/") if err != nil { t.Fatal(err) @@ -1114,14 +1074,12 @@ func TestEmptyPasswordAuth(t *testing.T) { } })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) } req.URL.User = url.User(gopher) + c := ts.Client() resp, err := c.Do(req) if err != nil { t.Fatal(err) @@ -1498,21 +1456,17 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { defer ts2.Close() ts2URL = ts2.URL - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{ - Transport: tr, - CheckRedirect: func(r *Request, via []*Request) error { - want := Header{ - "User-Agent": []string{ua}, - "X-Foo": []string{xfoo}, - "Referer": []string{ts2URL}, - } - if !reflect.DeepEqual(r.Header, want) { - t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) - } - return nil - }, + c := ts1.Client() + c.CheckRedirect = func(r *Request, via []*Request) error { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) + } + return nil } req, _ := NewRequest("GET", ts2.URL, nil) @@ -1601,13 +1555,9 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() jar, _ := cookiejar.New(nil) - c := &Client{ - Transport: tr, - Jar: jar, - } + c := ts.Client() + c.Jar = jar u, _ := url.Parse(ts.URL) req, _ := NewRequest("GET", ts.URL, nil) @@ -1725,9 +1675,7 @@ func TestClientRedirectTypes(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - + c := ts.Client() for i, tt := range tests { handlerc <- func(w ResponseWriter, r *Request) { w.Header().Set("Location", ts.URL) @@ -1740,7 +1688,6 @@ func TestClientRedirectTypes(t *testing.T) { continue } - c := &Client{Transport: tr} c.CheckRedirect = func(req *Request, via []*Request) error { if got, want := req.Method, tt.wantMethod; got != want { return fmt.Errorf("#%d: got next method %q; want %q", i, got, want) @@ -1780,8 +1727,8 @@ func (b issue18239Body) Close() error { return nil } -// Issue 18239: make sure the Transport doesn't retry requests with bodies. -// (Especially if Request.GetBody is not defined.) +// Issue 18239: make sure the Transport doesn't retry requests with bodies +// if Request.GetBody is not defined. func TestTransportBodyReadError(t *testing.T) { setParallel(t) defer afterTest(t) @@ -1794,9 +1741,8 @@ func TestTransportBodyReadError(t *testing.T) { w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) // Do one initial successful request to create an idle TCP connection // for the subsequent request to reuse. (The Transport only retries @@ -1816,6 +1762,7 @@ func TestTransportBodyReadError(t *testing.T) { if err != nil { t.Fatal(err) } + req = req.WithT(t) _, err = tr.RoundTrip(req) if err != someErr { t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr) diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 580115c..20feaa7 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -1385,3 +1385,30 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) } } + +func TestBadResponseAfterReadingBody(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := io.Copy(ioutil.Discard, r.Body) + if err != nil { + t.Fatal(err) + } + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Fatal(err) + } + defer c.Close() + fmt.Fprintln(c, "some bogus crap") + })) + defer cst.close() + + closes := 0 + res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + if err == nil { + res.Body.Close() + t.Fatal("expected an error to be returned from Post") + } + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 5a67476..cf52248 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -328,7 +328,7 @@ func sanitizeCookieValue(v string) string { if len(v) == 0 { return v } - if v[0] == ' ' || v[0] == ',' || v[len(v)-1] == ' ' || v[len(v)-1] == ',' { + if strings.IndexByte(v, ' ') >= 0 || strings.IndexByte(v, ',') >= 0 { return `"` + v + `"` } return v diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index b3e54f8..9d199a3 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -69,7 +69,7 @@ var writeSetCookiesTests = []struct { // are disallowed by RFC 6265 but are common in the wild. { &Cookie{Name: "special-1", Value: "a z"}, - `special-1=a z`, + `special-1="a z"`, }, { &Cookie{Name: "special-2", Value: " z"}, @@ -85,7 +85,7 @@ var writeSetCookiesTests = []struct { }, { &Cookie{Name: "special-5", Value: "a,z"}, - `special-5=a,z`, + `special-5="a,z"`, }, { &Cookie{Name: "special-6", Value: ",z"}, @@ -398,9 +398,12 @@ func TestCookieSanitizeValue(t *testing.T) { {"foo\"bar", "foobar"}, {"\x00\x7e\x7f\x80", "\x7e"}, {`"withquotes"`, "withquotes"}, - {"a z", "a z"}, + {"a z", `"a z"`}, {" z", `" z"`}, {"a ", `"a "`}, + {"a,z", `"a,z"`}, + {",z", `",z"`}, + {"a,", `"a,"`}, } for _, tt := range tests { if got := sanitizeCookieValue(tt.in); got != tt.want { diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go index f89abbc..ef8c35b 100644 --- a/libgo/go/net/http/cookiejar/jar.go +++ b/libgo/go/net/http/cookiejar/jar.go @@ -331,7 +331,7 @@ func jarKey(host string, psl PublicSuffixList) string { var i int if psl == nil { i = strings.LastIndex(host, ".") - if i == -1 { + if i <= 0 { return host } } else { @@ -345,6 +345,9 @@ func jarKey(host string, psl PublicSuffixList) string { // Storing cookies under host is a safe stopgap. return host } + // Only len(suffix) is used to determine the jar key from + // here on, so it is okay if psl.PublicSuffix("www.buggy.psl") + // returns "com" as the jar key is generated from host. } prevDot := strings.LastIndex(host[:i-1], ".") return host[prevDot+1:] diff --git a/libgo/go/net/http/cookiejar/jar_test.go b/libgo/go/net/http/cookiejar/jar_test.go index 3aa6015..47fb1ab 100644 --- a/libgo/go/net/http/cookiejar/jar_test.go +++ b/libgo/go/net/http/cookiejar/jar_test.go @@ -19,6 +19,9 @@ var tNow = time.Date(2013, 1, 1, 12, 0, 0, 0, time.UTC) // testPSL implements PublicSuffixList with just two rules: "co.uk" // and the default rule "*". +// The implementation has two intentional bugs: +// PublicSuffix("www.buggy.psl") == "xy" +// PublicSuffix("www2.buggy.psl") == "com" type testPSL struct{} func (testPSL) String() string { @@ -28,6 +31,12 @@ func (testPSL) PublicSuffix(d string) string { if d == "co.uk" || strings.HasSuffix(d, ".co.uk") { return "co.uk" } + if d == "www.buggy.psl" { + return "xy" + } + if d == "www2.buggy.psl" { + return "com" + } return d[strings.LastIndex(d, ".")+1:] } @@ -125,6 +134,17 @@ var canonicalHostTests = map[string]string{ "[2001:4860:0:::68]:8080": "2001:4860:0:::68", "www.bücher.de": "www.xn--bcher-kva.de", "www.example.com.": "www.example.com", + // TODO: Fix canonicalHost so that all of the following malformed + // domain names trigger an error. (This list is not exhaustive, e.g. + // malformed internationalized domain names are missing.) + ".": "", + "..": ".", + "...": "..", + ".net": ".net", + ".net.": ".net", + "a..": "a.", + "b.a..": "b.a.", + "weird.stuff...": "weird.stuff..", "[bad.unmatched.bracket:": "error", } @@ -133,7 +153,7 @@ func TestCanonicalHost(t *testing.T) { got, err := canonicalHost(h) if want == "error" { if err == nil { - t.Errorf("%q: got nil error, want non-nil", h) + t.Errorf("%q: got %q and nil error, want non-nil", h, got) } continue } @@ -176,6 +196,17 @@ var jarKeyTests = map[string]string{ "co.uk": "co.uk", "uk": "uk", "192.168.0.5": "192.168.0.5", + "www.buggy.psl": "www.buggy.psl", + "www2.buggy.psl": "buggy.psl", + // The following are actual outputs of canonicalHost for + // malformed inputs to canonicalHost (see above). + "": "", + ".": ".", + "..": ".", + ".net": ".net", + "a.": "a.", + "b.a.": "a.", + "weird.stuff..": ".", } func TestJarKey(t *testing.T) { @@ -197,6 +228,15 @@ var jarKeyNilPSLTests = map[string]string{ "co.uk": "co.uk", "uk": "uk", "192.168.0.5": "192.168.0.5", + // The following are actual outputs of canonicalHost for + // malformed inputs to canonicalHost. + "": "", + ".": ".", + "..": "..", + ".net": ".net", + "a.": "a.", + "b.a.": "a.", + "weird.stuff..": "stuff..", } func TestJarKeyNilPSL(t *testing.T) { @@ -1265,3 +1305,18 @@ func TestDomainHandling(t *testing.T) { test.run(t, jar) } } + +func TestIssue19384(t *testing.T) { + cookies := []*http.Cookie{{Name: "name", Value: "value"}} + for _, host := range []string{"", ".", "..", "..."} { + jar, _ := New(nil) + u := &url.URL{Scheme: "http", Host: host, Path: "/"} + if got := jar.Cookies(u); len(got) != 0 { + t.Errorf("host %q, got %v", host, got) + } + jar.SetCookies(u, cookies) + if got := jar.Cookies(u); len(got) != 1 || got[0].Value != "value" { + t.Errorf("host %q, got %v", host, got) + } + } +} diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index b61f58b..2ef145e 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -8,24 +8,29 @@ package http import ( + "context" "net" "sort" "sync" + "testing" "time" ) var ( - DefaultUserAgent = defaultUserAgent - NewLoggingConn = newLoggingConn - ExportAppendTime = appendTime - ExportRefererForURL = refererForURL - ExportServerNewConn = (*Server).newConn - ExportCloseWriteAndWait = (*conn).closeWriteAndWait - ExportErrRequestCanceled = errRequestCanceled - ExportErrRequestCanceledConn = errRequestCanceledConn - ExportServeFile = serveFile - ExportScanETag = scanETag - ExportHttp2ConfigureServer = http2ConfigureServer + DefaultUserAgent = defaultUserAgent + NewLoggingConn = newLoggingConn + ExportAppendTime = appendTime + ExportRefererForURL = refererForURL + ExportServerNewConn = (*Server).newConn + ExportCloseWriteAndWait = (*conn).closeWriteAndWait + ExportErrRequestCanceled = errRequestCanceled + ExportErrRequestCanceledConn = errRequestCanceledConn + ExportErrServerClosedIdle = errServerClosedIdle + ExportServeFile = serveFile + ExportScanETag = scanETag + ExportHttp2ConfigureServer = http2ConfigureServer + Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect + Export_writeStatusLine = writeStatusLine ) func init() { @@ -186,8 +191,6 @@ func ExportHttp2ConfigureTransport(t *Transport) error { return nil } -var Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect - func (s *Server) ExportAllConnsIdle() bool { s.mu.Lock() defer s.mu.Unlock() @@ -199,3 +202,7 @@ func (s *Server) ExportAllConnsIdle() bool { } return true } + +func (r *Request) WithT(t *testing.T) *Request { + return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf)) +} diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go index 8870424..30a6b2c 100644 --- a/libgo/go/net/http/fcgi/child.go +++ b/libgo/go/net/http/fcgi/child.go @@ -7,6 +7,7 @@ package fcgi // This file implements FastCGI from the perspective of a child process. import ( + "context" "errors" "fmt" "io" @@ -31,6 +32,10 @@ type request struct { keepConn bool } +// envVarsContextKey uniquely identifies a mapping of CGI +// environment variables to their values in a request context +type envVarsContextKey struct{} + func newRequest(reqId uint16, flags uint8) *request { r := &request{ reqId: reqId, @@ -259,6 +264,18 @@ func (c *child) handleRecord(rec *record) error { } } +// filterOutUsedEnvVars returns a new map of env vars without the +// variables in the given envVars map that are read for creating each http.Request +func filterOutUsedEnvVars(envVars map[string]string) map[string]string { + withoutUsedEnvVars := make(map[string]string) + for k, v := range envVars { + if addFastCGIEnvToContext(k) { + withoutUsedEnvVars[k] = v + } + } + return withoutUsedEnvVars +} + func (c *child) serveRequest(req *request, body io.ReadCloser) { r := newResponse(c, req) httpReq, err := cgi.RequestFromMap(req.params) @@ -268,6 +285,9 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { c.conn.writeRecord(typeStderr, req.reqId, []byte(err.Error())) } else { httpReq.Body = body + withoutUsedEnvVars := filterOutUsedEnvVars(req.params) + envVarCtx := context.WithValue(httpReq.Context(), envVarsContextKey{}, withoutUsedEnvVars) + httpReq = httpReq.WithContext(envVarCtx) c.handler.ServeHTTP(r, httpReq) } r.Close() @@ -329,3 +349,39 @@ func Serve(l net.Listener, handler http.Handler) error { go c.serve() } } + +// ProcessEnv returns FastCGI environment variables associated with the request r +// for which no effort was made to be included in the request itself - the data +// is hidden in the request's context. As an example, if REMOTE_USER is set for a +// request, it will not be found anywhere in r, but it will be included in +// ProcessEnv's response (via r's context). +func ProcessEnv(r *http.Request) map[string]string { + env, _ := r.Context().Value(envVarsContextKey{}).(map[string]string) + return env +} + +// addFastCGIEnvToContext reports whether to include the FastCGI environment variable s +// in the http.Request.Context, accessible via ProcessEnv. +func addFastCGIEnvToContext(s string) bool { + // Exclude things supported by net/http natively: + switch s { + case "CONTENT_LENGTH", "CONTENT_TYPE", "HTTPS", + "PATH_INFO", "QUERY_STRING", "REMOTE_ADDR", + "REMOTE_HOST", "REMOTE_PORT", "REQUEST_METHOD", + "REQUEST_URI", "SCRIPT_NAME", "SERVER_PROTOCOL": + return false + } + if strings.HasPrefix(s, "HTTP_") { + return false + } + // Explicitly include FastCGI-specific things. + // This list is redundant with the default "return true" below. + // Consider this documentation of the sorts of things we expect + // to maybe see. + switch s { + case "REMOTE_USER": + return true + } + // Unknown, so include it to be safe. + return true +} diff --git a/libgo/go/net/http/fcgi/fcgi.go b/libgo/go/net/http/fcgi/fcgi.go index 5057d70..8f3449a 100644 --- a/libgo/go/net/http/fcgi/fcgi.go +++ b/libgo/go/net/http/fcgi/fcgi.go @@ -24,7 +24,7 @@ import ( ) // recType is a record type, as defined by -// http://www.fastcgi.com/devkit/doc/fcgi-spec.html#S8 +// https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22#S8 type recType uint8 const ( diff --git a/libgo/go/net/http/fcgi/fcgi_test.go b/libgo/go/net/http/fcgi/fcgi_test.go index b6013bf..e9d2b34 100644 --- a/libgo/go/net/http/fcgi/fcgi_test.go +++ b/libgo/go/net/http/fcgi/fcgi_test.go @@ -278,3 +278,69 @@ func TestMalformedParams(t *testing.T) { c := newChild(rw, http.DefaultServeMux) c.serve() } + +// a series of FastCGI records that start and end a request +var streamFullRequestStdin = bytes.Join([][]byte{ + // set up request + makeRecord(typeBeginRequest, 1, + []byte{0, byte(roleResponder), 0, 0, 0, 0, 0, 0}), + // add required parameters + makeRecord(typeParams, 1, nameValuePair11("REQUEST_METHOD", "GET")), + makeRecord(typeParams, 1, nameValuePair11("SERVER_PROTOCOL", "HTTP/1.1")), + // set optional parameters + makeRecord(typeParams, 1, nameValuePair11("REMOTE_USER", "jane.doe")), + makeRecord(typeParams, 1, nameValuePair11("QUERY_STRING", "/foo/bar")), + makeRecord(typeParams, 1, nil), + // begin sending body of request + makeRecord(typeStdin, 1, []byte("0123456789abcdef")), + // end request + makeRecord(typeEndRequest, 1, nil), +}, + nil) + +var envVarTests = []struct { + input []byte + envVar string + expectedVal string + expectedFilteredOut bool +}{ + { + streamFullRequestStdin, + "REMOTE_USER", + "jane.doe", + false, + }, + { + streamFullRequestStdin, + "QUERY_STRING", + "", + true, + }, +} + +// Test that environment variables set for a request can be +// read by a handler. Ensures that variables not set will not +// be exposed to a handler. +func TestChildServeReadsEnvVars(t *testing.T) { + for _, tt := range envVarTests { + input := make([]byte, len(tt.input)) + copy(input, tt.input) + rc := nopWriteCloser{bytes.NewBuffer(input)} + done := make(chan bool) + c := newChild(rc, http.HandlerFunc(func( + w http.ResponseWriter, + r *http.Request, + ) { + env := ProcessEnv(r) + if _, ok := env[tt.envVar]; ok && tt.expectedFilteredOut { + t.Errorf("Expected environment variable %s to not be set, but set to %s", + tt.envVar, env[tt.envVar]) + } else if env[tt.envVar] != tt.expectedVal { + t.Errorf("Expected %s, got %s", tt.expectedVal, env[tt.envVar]) + } + done <- true + })) + go c.serve() + <-done + } +} diff --git a/libgo/go/net/http/filetransport_test.go b/libgo/go/net/http/filetransport_test.go index 6f1a537..2a2f32c 100644 --- a/libgo/go/net/http/filetransport_test.go +++ b/libgo/go/net/http/filetransport_test.go @@ -49,6 +49,7 @@ func TestFileTransport(t *testing.T) { t.Fatalf("for %s, nil Body", urlstr) } slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() check("ReadAll "+urlstr, err) if string(slurp) != "Bar" { t.Errorf("for %s, got content %q, want %q", urlstr, string(slurp), "Bar") diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index bf63bb5..5819334 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -30,21 +30,51 @@ import ( // value is a filename on the native file system, not a URL, so it is separated // by filepath.Separator, which isn't necessarily '/'. // +// Note that Dir will allow access to files and directories starting with a +// period, which could expose sensitive directories like a .git directory or +// sensitive files like .htpasswd. To exclude files with a leading period, +// remove the files/directories from the server or create a custom FileSystem +// implementation. +// // An empty Dir is treated as ".". type Dir string +// mapDirOpenError maps the provided non-nil error from opening name +// to a possibly better non-nil error. In particular, it turns OS-specific errors +// about opening files in non-directories into os.ErrNotExist. See Issue 18984. +func mapDirOpenError(originalErr error, name string) error { + if os.IsNotExist(originalErr) || os.IsPermission(originalErr) { + return originalErr + } + + parts := strings.Split(name, string(filepath.Separator)) + for i := range parts { + if parts[i] == "" { + continue + } + fi, err := os.Stat(strings.Join(parts[:i+1], string(filepath.Separator))) + if err != nil { + return originalErr + } + if !fi.IsDir() { + return os.ErrNotExist + } + } + return originalErr +} + func (d Dir) Open(name string) (File, error) { - if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) || - strings.Contains(name, "\x00") { + if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) { return nil, errors.New("http: invalid character in file path") } dir := string(d) if dir == "" { dir = "." } - f, err := os.Open(filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name)))) + fullName := filepath.Join(dir, filepath.FromSlash(path.Clean("/"+name))) + f, err := os.Open(fullName) if err != nil { - return nil, err + return nil, mapDirOpenError(err, fullName) } return f, nil } @@ -291,7 +321,7 @@ func scanETag(s string) (etag string, remain string) { case c == '"': return string(s[:i+1]), s[i+1:] default: - break + return "", "" } } return "", "" @@ -349,7 +379,7 @@ func checkIfMatch(w ResponseWriter, r *Request) condResult { return condFalse } -func checkIfUnmodifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult { +func checkIfUnmodifiedSince(r *Request, modtime time.Time) condResult { ius := r.Header.Get("If-Unmodified-Since") if ius == "" || isZeroTime(modtime) { return condNone @@ -394,7 +424,7 @@ func checkIfNoneMatch(w ResponseWriter, r *Request) condResult { return condTrue } -func checkIfModifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult { +func checkIfModifiedSince(r *Request, modtime time.Time) condResult { if r.Method != "GET" && r.Method != "HEAD" { return condNone } @@ -479,7 +509,7 @@ func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done b // This function carefully follows RFC 7232 section 6. ch := checkIfMatch(w, r) if ch == condNone { - ch = checkIfUnmodifiedSince(w, r, modtime) + ch = checkIfUnmodifiedSince(r, modtime) } if ch == condFalse { w.WriteHeader(StatusPreconditionFailed) @@ -495,7 +525,7 @@ func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done b return true, "" } case condNone: - if checkIfModifiedSince(w, r, modtime) == condFalse { + if checkIfModifiedSince(r, modtime) == condFalse { writeNotModified(w) return true, "" } @@ -580,7 +610,7 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec // Still a directory? (we didn't find an index.html file) if d.IsDir() { - if checkIfModifiedSince(w, r, d.ModTime()) == condFalse { + if checkIfModifiedSince(r, d.ModTime()) == condFalse { writeNotModified(w) return } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index bba5682..f6eab0f 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -74,6 +74,7 @@ func TestServeFile(t *testing.T) { ServeFile(w, r, "testdata/file") })) defer ts.Close() + c := ts.Client() var err error @@ -91,7 +92,7 @@ func TestServeFile(t *testing.T) { req.Method = "GET" // straight GET - _, body := getBody(t, "straight get", req) + _, body := getBody(t, "straight get", req, c) if !bytes.Equal(body, file) { t.Fatalf("body mismatch: got %q, want %q", body, file) } @@ -102,7 +103,7 @@ Cases: if rt.r != "" { req.Header.Set("Range", rt.r) } - resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req) + resp, body := getBody(t, fmt.Sprintf("range test %q", rt.r), req, c) if resp.StatusCode != rt.code { t.Errorf("range=%q: StatusCode=%d, want %d", rt.r, resp.StatusCode, rt.code) } @@ -704,7 +705,8 @@ func TestDirectoryIfNotModified(t *testing.T) { req, _ := NewRequest("GET", ts.URL, nil) req.Header.Set("If-Modified-Since", lastMod) - res, err = DefaultClient.Do(req) + c := ts.Client() + res, err = c.Do(req) if err != nil { t.Fatal(err) } @@ -716,7 +718,7 @@ func TestDirectoryIfNotModified(t *testing.T) { // Advance the index.html file's modtime, but not the directory's. indexFile.modtime = indexFile.modtime.Add(1 * time.Hour) - res, err = DefaultClient.Do(req) + res, err = c.Do(req) if err != nil { t.Fatal(err) } @@ -995,7 +997,9 @@ func TestServeContent(t *testing.T) { for k, v := range tt.reqHeader { req.Header.Set(k, v) } - res, err := DefaultClient.Do(req) + + c := ts.Client() + res, err := c.Do(req) if err != nil { t.Fatal(err) } @@ -1050,8 +1054,9 @@ func TestServeContentErrorMessages(t *testing.T) { } ts := httptest.NewServer(FileServer(fs)) defer ts.Close() + c := ts.Client() for _, code := range []int{403, 404, 500} { - res, err := DefaultClient.Get(fmt.Sprintf("%s/%d", ts.URL, code)) + res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) if err != nil { t.Errorf("Error fetching /%d: %v", code, err) continue @@ -1090,8 +1095,11 @@ func TestLinuxSendfile(t *testing.T) { // strace on the above platforms doesn't support sendfile64 // and will error out if we specify that with `-e trace='. syscalls = "sendfile" - case "mips64": - t.Skip("TODO: update this test to be robust against various versions of strace on mips64. See golang.org/issue/33430") + } + + // Attempt to run strace, and skip on failure - this test requires SYS_PTRACE. + if err := exec.Command("strace", "-f", "-q", "-e", "trace="+syscalls, os.Args[0], "-test.run=^$").Run(); err != nil { + t.Skipf("skipping; failed to run strace: %v", err) } var buf bytes.Buffer @@ -1125,8 +1133,8 @@ func TestLinuxSendfile(t *testing.T) { } } -func getBody(t *testing.T, testName string, req Request) (*Response, []byte) { - r, err := DefaultClient.Do(&req) +func getBody(t *testing.T, testName string, req Request, client *Client) (*Response, []byte) { + r, err := client.Do(&req) if err != nil { t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) } @@ -1161,6 +1169,50 @@ func TestLinuxSendfileChild(*testing.T) { } } +// Issue 18984: tests that requests for paths beyond files return not-found errors +func TestFileServerNotDirError(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(FileServer(Dir("testdata"))) + defer ts.Close() + + res, err := Get(ts.URL + "/index.html/not-a-file") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 404 { + t.Errorf("StatusCode = %v; want 404", res.StatusCode) + } + + test := func(name string, dir Dir) { + t.Run(name, func(t *testing.T) { + _, err = dir.Open("/index.html/not-a-file") + if err == nil { + t.Fatal("err == nil; want != nil") + } + if !os.IsNotExist(err) { + t.Errorf("err = %v; os.IsNotExist(err) = %v; want true", err, os.IsNotExist(err)) + } + + _, err = dir.Open("/index.html/not-a-dir/not-a-file") + if err == nil { + t.Fatal("err == nil; want != nil") + } + if !os.IsNotExist(err) { + t.Errorf("err = %v; os.IsNotExist(err) = %v; want true", err, os.IsNotExist(err)) + } + }) + } + + absPath, err := filepath.Abs("testdata") + if err != nil { + t.Fatal("get abs path:", err) + } + + test("RelativePath", Dir("testdata")) + test("AbsolutePath", Dir(absPath)) +} + func TestFileServerCleanPath(t *testing.T) { tests := []struct { path string @@ -1210,10 +1262,10 @@ func Test_scanETag(t *testing.T) { {`"etag-2"`, `"etag-2"`, ""}, {`"etag-1", "etag-2"`, `"etag-1"`, `, "etag-2"`}, {"", "", ""}, - {"", "", ""}, {"W/", "", ""}, {`W/"truc`, "", ""}, {`w/"case-sensitive"`, "", ""}, + {`"spaced etag"`, "", ""}, } for _, test := range tests { etag, remain := ExportScanETag(test.in) diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 6fbbcd0..373f550 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -48,6 +48,642 @@ import ( "golang_org/x/net/lex/httplex" ) +// A list of the possible cipher suite ids. Taken from +// http://www.iana.org/assignments/tls-parameters/tls-parameters.txt + +const ( + http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000 + http2cipher_TLS_RSA_WITH_NULL_MD5 uint16 = 0x0001 + http2cipher_TLS_RSA_WITH_NULL_SHA uint16 = 0x0002 + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0003 + http2cipher_TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + http2cipher_TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x0006 + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA uint16 = 0x0007 + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0008 + http2cipher_TLS_RSA_WITH_DES_CBC_SHA uint16 = 0x0009 + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000A + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000B + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA uint16 = 0x000C + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x000D + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x000E + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA uint16 = 0x000F + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0010 + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0011 + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA uint16 = 0x0012 + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0x0013 + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0014 + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA uint16 = 0x0015 + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5 uint16 = 0x0017 + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5 uint16 = 0x0018 + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA uint16 = 0x0019 + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA uint16 = 0x001A + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0x001B + // Reserved uint16 = 0x001C-1D + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA uint16 = 0x001E + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA uint16 = 0x001F + http2cipher_TLS_KRB5_WITH_RC4_128_SHA uint16 = 0x0020 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA uint16 = 0x0021 + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5 uint16 = 0x0022 + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5 uint16 = 0x0023 + http2cipher_TLS_KRB5_WITH_RC4_128_MD5 uint16 = 0x0024 + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5 uint16 = 0x0025 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA uint16 = 0x0026 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA uint16 = 0x0027 + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA uint16 = 0x0028 + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5 uint16 = 0x0029 + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5 uint16 = 0x002A + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5 uint16 = 0x002B + http2cipher_TLS_PSK_WITH_NULL_SHA uint16 = 0x002C + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA uint16 = 0x002D + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA uint16 = 0x002E + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002F + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0030 + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0031 + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA uint16 = 0x0032 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA uint16 = 0x0034 + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0036 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0037 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA uint16 = 0x0038 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA uint16 = 0x003A + http2cipher_TLS_RSA_WITH_NULL_SHA256 uint16 = 0x003B + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003C + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003D + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x003E + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003F + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 uint16 = 0x0040 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0041 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0042 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0043 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0044 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0045 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA uint16 = 0x0046 + // Reserved uint16 = 0x0047-4F + // Reserved uint16 = 0x0050-58 + // Reserved uint16 = 0x0059-5C + // Unassigned uint16 = 0x005D-5F + // Reserved uint16 = 0x0060-66 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x0068 + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x0069 + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 uint16 = 0x006A + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006B + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256 uint16 = 0x006C + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256 uint16 = 0x006D + // Unassigned uint16 = 0x006E-83 + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0084 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0085 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0086 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0087 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0088 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA uint16 = 0x0089 + http2cipher_TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008A + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008B + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008C + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008D + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA uint16 = 0x008E + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x008F + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0090 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0091 + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA uint16 = 0x0092 + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0x0093 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA uint16 = 0x0094 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA uint16 = 0x0095 + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA uint16 = 0x0096 + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA uint16 = 0x0097 + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA uint16 = 0x0098 + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA uint16 = 0x0099 + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA uint16 = 0x009A + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA uint16 = 0x009B + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009C + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009D + http2cipher_TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009E + http2cipher_TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009F + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x00A0 + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x00A1 + http2cipher_TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A2 + http2cipher_TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A3 + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256 uint16 = 0x00A4 + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384 uint16 = 0x00A5 + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256 uint16 = 0x00A6 + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384 uint16 = 0x00A7 + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00A8 + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00A9 + http2cipher_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AA + http2cipher_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AB + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0x00AC + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 uint16 = 0x00AD + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00AE + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00AF + http2cipher_TLS_PSK_WITH_NULL_SHA256 uint16 = 0x00B0 + http2cipher_TLS_PSK_WITH_NULL_SHA384 uint16 = 0x00B1 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B2 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B3 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256 uint16 = 0x00B4 + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384 uint16 = 0x00B5 + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0x00B6 + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0x00B7 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256 uint16 = 0x00B8 + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384 uint16 = 0x00B9 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BA + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BB + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BC + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BD + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BE + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0x00BF + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C0 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C1 + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C2 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C3 + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C4 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256 uint16 = 0x00C5 + // Unassigned uint16 = 0x00C6-FE + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV uint16 = 0x00FF + // Unassigned uint16 = 0x01-55,* + http2cipher_TLS_FALLBACK_SCSV uint16 = 0x5600 + // Unassigned uint16 = 0x5601 - 0xC000 + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA uint16 = 0xC001 + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA uint16 = 0xC002 + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC003 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC004 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC005 + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA uint16 = 0xC006 + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xC007 + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC008 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xC009 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xC00A + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA uint16 = 0xC00B + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA uint16 = 0xC00C + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC00D + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC00E + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC00F + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA uint16 = 0xC010 + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xC011 + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC012 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC013 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC014 + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA uint16 = 0xC015 + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA uint16 = 0xC016 + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA uint16 = 0xC017 + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA uint16 = 0xC018 + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA uint16 = 0xC019 + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01A + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01B + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA uint16 = 0xC01C + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA uint16 = 0xC01D + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA uint16 = 0xC01E + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA uint16 = 0xC01F + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA uint16 = 0xC020 + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA uint16 = 0xC021 + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA uint16 = 0xC022 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC023 + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC024 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC025 + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC026 + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC027 + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC028 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xC029 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xC02A + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02B + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02C + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02D + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC02E + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC02F + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC030 + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xC031 + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xC032 + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA uint16 = 0xC033 + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA uint16 = 0xC034 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA uint16 = 0xC035 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA uint16 = 0xC036 + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 uint16 = 0xC037 + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 uint16 = 0xC038 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA uint16 = 0xC039 + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256 uint16 = 0xC03A + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384 uint16 = 0xC03B + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03C + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03D + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC03E + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC03F + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC040 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC041 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC042 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC043 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC044 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC045 + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC046 + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC047 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC048 + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC049 + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04A + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04B + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04C + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04D + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC04E + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC04F + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC050 + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC051 + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC052 + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC053 + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC054 + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC055 + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC056 + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC057 + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC058 + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC059 + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05A + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05B + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05C + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05D + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC05E + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC05F + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC060 + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC061 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC062 + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC063 + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC064 + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC065 + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC066 + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC067 + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC068 + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC069 + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06A + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06B + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06C + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06D + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256 uint16 = 0xC06E + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384 uint16 = 0xC06F + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256 uint16 = 0xC070 + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384 uint16 = 0xC071 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC072 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC073 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC074 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC075 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC076 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC077 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC078 + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC079 + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07A + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07B + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07C + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07D + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC07E + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC07F + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC080 + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC081 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC082 + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC083 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC084 + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC085 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC086 + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC087 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC088 + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC089 + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08A + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08B + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08C + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08D + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC08E + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC08F + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC090 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC091 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256 uint16 = 0xC092 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384 uint16 = 0xC093 + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC094 + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC095 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC096 + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC097 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC098 + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC099 + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256 uint16 = 0xC09A + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384 uint16 = 0xC09B + http2cipher_TLS_RSA_WITH_AES_128_CCM uint16 = 0xC09C + http2cipher_TLS_RSA_WITH_AES_256_CCM uint16 = 0xC09D + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM uint16 = 0xC09E + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM uint16 = 0xC09F + http2cipher_TLS_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A0 + http2cipher_TLS_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A1 + http2cipher_TLS_DHE_RSA_WITH_AES_128_CCM_8 uint16 = 0xC0A2 + http2cipher_TLS_DHE_RSA_WITH_AES_256_CCM_8 uint16 = 0xC0A3 + http2cipher_TLS_PSK_WITH_AES_128_CCM uint16 = 0xC0A4 + http2cipher_TLS_PSK_WITH_AES_256_CCM uint16 = 0xC0A5 + http2cipher_TLS_DHE_PSK_WITH_AES_128_CCM uint16 = 0xC0A6 + http2cipher_TLS_DHE_PSK_WITH_AES_256_CCM uint16 = 0xC0A7 + http2cipher_TLS_PSK_WITH_AES_128_CCM_8 uint16 = 0xC0A8 + http2cipher_TLS_PSK_WITH_AES_256_CCM_8 uint16 = 0xC0A9 + http2cipher_TLS_PSK_DHE_WITH_AES_128_CCM_8 uint16 = 0xC0AA + http2cipher_TLS_PSK_DHE_WITH_AES_256_CCM_8 uint16 = 0xC0AB + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM uint16 = 0xC0AC + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM uint16 = 0xC0AD + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 uint16 = 0xC0AE + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 uint16 = 0xC0AF + // Unassigned uint16 = 0xC0B0-FF + // Unassigned uint16 = 0xC1-CB,* + // Unassigned uint16 = 0xCC00-A7 + http2cipher_TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA8 + http2cipher_TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCA9 + http2cipher_TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAA + http2cipher_TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAB + http2cipher_TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAC + http2cipher_TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAD + http2cipher_TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 uint16 = 0xCCAE +) + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +// References: +// https://tools.ietf.org/html/rfc7540#appendix-A +// Reject cipher suites from Appendix A. +// "This list includes those cipher suites that do not +// offer an ephemeral key exchange and those that are +// based on the TLS null, stream or block cipher type" +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case http2cipher_TLS_NULL_WITH_NULL_NULL, + http2cipher_TLS_RSA_WITH_NULL_MD5, + http2cipher_TLS_RSA_WITH_NULL_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_MD5, + http2cipher_TLS_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_RSA_WITH_IDEA_CBC_SHA, + http2cipher_TLS_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_DES_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DH_anon_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_DH_anon_WITH_RC4_128_MD5, + http2cipher_TLS_DH_anon_EXPORT_WITH_DES40_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_DES_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_SHA, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_KRB5_WITH_RC4_128_SHA, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_SHA, + http2cipher_TLS_KRB5_WITH_DES_CBC_MD5, + http2cipher_TLS_KRB5_WITH_3DES_EDE_CBC_MD5, + http2cipher_TLS_KRB5_WITH_RC4_128_MD5, + http2cipher_TLS_KRB5_WITH_IDEA_CBC_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_SHA, + http2cipher_TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5, + http2cipher_TLS_KRB5_EXPORT_WITH_RC4_40_MD5, + http2cipher_TLS_PSK_WITH_NULL_SHA, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_NULL_SHA256, + http2cipher_TLS_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA, + http2cipher_TLS_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_DHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_RSA_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_DSS_WITH_SEED_CBC_SHA, + http2cipher_TLS_DHE_RSA_WITH_SEED_CBC_SHA, + http2cipher_TLS_DH_anon_WITH_SEED_CBC_SHA, + http2cipher_TLS_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_NULL_SHA256, + http2cipher_TLS_PSK_WITH_NULL_SHA384, + http2cipher_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_DHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA256, + http2cipher_TLS_RSA_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_CBC_SHA256, + http2cipher_TLS_EMPTY_RENEGOTIATION_INFO_SCSV, + http2cipher_TLS_ECDH_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDH_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_NULL_SHA, + http2cipher_TLS_ECDH_anon_WITH_RC4_128_SHA, + http2cipher_TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDH_anon_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_128_CBC_SHA, + http2cipher_TLS_SRP_SHA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_RSA_WITH_AES_256_CBC_SHA, + http2cipher_TLS_SRP_SHA_DSS_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_AES_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_RC4_128_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_3DES_EDE_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_NULL_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_DSS_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_DSS_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_ARIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_ARIA_256_GCM_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_ARIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_DSS_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_DH_anon_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_ECDSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_ECDH_RSA_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_GCM_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_GCM_SHA384, + http2cipher_TLS_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_DHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_RSA_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_128_CBC_SHA256, + http2cipher_TLS_ECDHE_PSK_WITH_CAMELLIA_256_CBC_SHA384, + http2cipher_TLS_RSA_WITH_AES_128_CCM, + http2cipher_TLS_RSA_WITH_AES_256_CCM, + http2cipher_TLS_RSA_WITH_AES_128_CCM_8, + http2cipher_TLS_RSA_WITH_AES_256_CCM_8, + http2cipher_TLS_PSK_WITH_AES_128_CCM, + http2cipher_TLS_PSK_WITH_AES_256_CCM, + http2cipher_TLS_PSK_WITH_AES_128_CCM_8, + http2cipher_TLS_PSK_WITH_AES_256_CCM_8: + return true + default: + return false + } +} + // ClientConnPool manages a pool of HTTP/2 client connections. type http2ClientConnPool interface { GetClientConn(req *Request, addr string) (*http2ClientConn, error) @@ -126,7 +762,7 @@ type http2dialCall struct { // requires p.mu is held. func (p *http2clientConnPool) getStartDialLocked(addr string) *http2dialCall { if call, ok := p.dialing[addr]; ok { - + // A dial is already in-flight. Don't start another. return call } call := &http2dialCall{p: p, done: make(chan struct{})} @@ -254,7 +890,12 @@ func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) { func (p *http2clientConnPool) closeIdleConnections() { p.mu.Lock() defer p.mu.Unlock() - + // TODO: don't close a cc if it was just added to the pool + // milliseconds ago and has never been used. There's currently + // a small race window with the HTTP/1 Transport's integration + // where it can add an idle conn just before using it, and + // somebody else can concurrently call CloseIdleConns and + // break some caller's RoundTrip. for _, vv := range p.conns { for _, cc := range vv { cc.closeIfIdle() @@ -269,7 +910,8 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ out = append(out, v) } } - + // If we filtered it out, zero out the last item to prevent + // the GC from seeing it. if len(in) != len(out) { in[len(in)-1] = nil } @@ -277,7 +919,7 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ } // noDialClientConnPool is an implementation of http2.ClientConnPool -// which never dials. We let the HTTP/1.1 client dial and use its TLS +// which never dials. We let the HTTP/1.1 client dial and use its TLS // connection instead. type http2noDialClientConnPool struct{ *http2clientConnPool } @@ -310,7 +952,10 @@ func http2configureTransport(t1 *Transport) (*http2Transport, error) { go c.Close() return http2erringRoundTripper{err} } else if !used { - + // Turns out we don't need this c. + // For example, two goroutines made requests to the same host + // at the same time, both kicking off TCP dials. (since protocol + // was unknown) go c.Close() } return t2 @@ -326,7 +971,7 @@ func http2configureTransport(t1 *Transport) (*http2Transport, error) { } // registerHTTPSProtocol calls Transport.RegisterProtocol but -// convering panics into errors. +// converting panics into errors. func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) { defer func() { if e := recover(); e != nil { @@ -349,6 +994,141 @@ func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) { return res, err } +// Buffer chunks are allocated from a pool to reduce pressure on GC. +// The maximum wasted space per dataBuffer is 2x the largest size class, +// which happens when the dataBuffer has multiple chunks and there is +// one unread byte in both the first and last chunks. We use a few size +// classes to minimize overheads for servers that typically receive very +// small request bodies. +// +// TODO: Benchmark to determine if the pools are necessary. The GC may have +// improved enough that we can instead allocate chunks like this: +// make([]byte, max(16<<10, expectedBytesRemaining)) +var ( + http2dataChunkSizeClasses = []int{ + 1 << 10, + 2 << 10, + 4 << 10, + 8 << 10, + 16 << 10, + } + http2dataChunkPools = [...]sync.Pool{ + {New: func() interface{} { return make([]byte, 1<<10) }}, + {New: func() interface{} { return make([]byte, 2<<10) }}, + {New: func() interface{} { return make([]byte, 4<<10) }}, + {New: func() interface{} { return make([]byte, 8<<10) }}, + {New: func() interface{} { return make([]byte, 16<<10) }}, + } +) + +func http2getDataBufferChunk(size int64) []byte { + i := 0 + for ; i < len(http2dataChunkSizeClasses)-1; i++ { + if size <= int64(http2dataChunkSizeClasses[i]) { + break + } + } + return http2dataChunkPools[i].Get().([]byte) +} + +func http2putDataBufferChunk(p []byte) { + for i, n := range http2dataChunkSizeClasses { + if len(p) == n { + http2dataChunkPools[i].Put(p) + return + } + } + panic(fmt.Sprintf("unexpected buffer len=%v", len(p))) +} + +// dataBuffer is an io.ReadWriter backed by a list of data chunks. +// Each dataBuffer is used to read DATA frames on a single stream. +// The buffer is divided into chunks so the server can limit the +// total memory used by a single connection without limiting the +// request body size on any single stream. +type http2dataBuffer struct { + chunks [][]byte + r int // next byte to read is chunks[0][r] + w int // next byte to write is chunks[len(chunks)-1][w] + size int // total buffered bytes + expected int64 // we expect at least this many bytes in future Write calls (ignored if <= 0) +} + +var http2errReadEmpty = errors.New("read from empty dataBuffer") + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *http2dataBuffer) Read(p []byte) (int, error) { + if b.size == 0 { + return 0, http2errReadEmpty + } + var ntotal int + for len(p) > 0 && b.size > 0 { + readFrom := b.bytesFromFirstChunk() + n := copy(p, readFrom) + p = p[n:] + ntotal += n + b.r += n + b.size -= n + // If the first chunk has been consumed, advance to the next chunk. + if b.r == len(b.chunks[0]) { + http2putDataBufferChunk(b.chunks[0]) + end := len(b.chunks) - 1 + copy(b.chunks[:end], b.chunks[1:]) + b.chunks[end] = nil + b.chunks = b.chunks[:end] + b.r = 0 + } + } + return ntotal, nil +} + +func (b *http2dataBuffer) bytesFromFirstChunk() []byte { + if len(b.chunks) == 1 { + return b.chunks[0][b.r:b.w] + } + return b.chunks[0][b.r:] +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *http2dataBuffer) Len() int { + return b.size +} + +// Write appends p to the buffer. +func (b *http2dataBuffer) Write(p []byte) (int, error) { + ntotal := len(p) + for len(p) > 0 { + // If the last chunk is empty, allocate a new chunk. Try to allocate + // enough to fully copy p plus any additional bytes we expect to + // receive. However, this may allocate less than len(p). + want := int64(len(p)) + if b.expected > want { + want = b.expected + } + chunk := b.lastChunkOrAlloc(want) + n := copy(chunk[b.w:], p) + p = p[n:] + b.w += n + b.size += n + b.expected -= int64(n) + } + return ntotal, nil +} + +func (b *http2dataBuffer) lastChunkOrAlloc(want int64) []byte { + if len(b.chunks) != 0 { + last := b.chunks[len(b.chunks)-1] + if b.w < len(last) { + return last + } + } + chunk := http2getDataBufferChunk(want) + b.chunks = append(b.chunks, chunk) + b.w = 0 + return chunk +} + // An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. type http2ErrCode uint32 @@ -429,11 +1209,16 @@ type http2goAwayFlowError struct{} func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" } +// connError represents an HTTP/2 ConnectionError error code, along +// with a string (for debugging) explaining why. +// // Errors of this type are only returned by the frame parser functions -// and converted into ConnectionError(ErrCodeProtocol). +// and converted into ConnectionError(Code), after stashing away +// the Reason into the Framer's errDetail field, accessible via +// the (*Framer).ErrorDetail method. type http2connError struct { - Code http2ErrCode - Reason string + Code http2ErrCode // the ConnectionError error code + Reason string // additional reason } func (e http2connError) Error() string { @@ -469,56 +1254,6 @@ var ( http2errPseudoAfterRegular = errors.New("pseudo header field after regular") ) -// fixedBuffer is an io.ReadWriter backed by a fixed size buffer. -// It never allocates, but moves old data as new data is written. -type http2fixedBuffer struct { - buf []byte - r, w int -} - -var ( - http2errReadEmpty = errors.New("read from empty fixedBuffer") - http2errWriteFull = errors.New("write on full fixedBuffer") -) - -// Read copies bytes from the buffer into p. -// It is an error to read when no data is available. -func (b *http2fixedBuffer) Read(p []byte) (n int, err error) { - if b.r == b.w { - return 0, http2errReadEmpty - } - n = copy(p, b.buf[b.r:b.w]) - b.r += n - if b.r == b.w { - b.r = 0 - b.w = 0 - } - return n, nil -} - -// Len returns the number of bytes of the unread portion of the buffer. -func (b *http2fixedBuffer) Len() int { - return b.w - b.r -} - -// Write copies bytes from p into the buffer. -// It is an error to write more data than the buffer can hold. -func (b *http2fixedBuffer) Write(p []byte) (n int, err error) { - - if b.r > 0 && len(p) > len(b.buf)-b.w { - copy(b.buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 - } - - n = copy(b.buf[b.w:], p) - b.w += n - if n < len(p) { - err = http2errWriteFull - } - return n, err -} - // flow is the flow control window's size. type http2flow struct { // n is the number of DATA bytes we're allowed to send. @@ -666,7 +1401,7 @@ var http2flagName = map[http2FrameType]map[http2Flags]string{ // a frameParser parses a frame given its FrameHeader and payload // bytes. The length of payload will always equal fh.Length (which // might be 0). -type http2frameParser func(fh http2FrameHeader, payload []byte) (http2Frame, error) +type http2frameParser func(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) var http2frameParsers = map[http2FrameType]http2frameParser{ http2FrameData: http2parseDataFrame, @@ -855,25 +1590,33 @@ type http2Framer struct { // If the limit is hit, MetaHeadersFrame.Truncated is set true. MaxHeaderListSize uint32 + // TODO: track which type of frame & with which flags was sent + // last. Then return an error (unless AllowIllegalWrites) if + // we're in the middle of a header block and a + // non-Continuation or Continuation on a different stream is + // attempted to be written. + logReads, logWrites bool debugFramer *http2Framer // only use for logging written writes debugFramerBuf *bytes.Buffer debugReadLoggerf func(string, ...interface{}) debugWriteLoggerf func(string, ...interface{}) + + frameCache *http2frameCache // nil if frames aren't reused (default) } func (fr *http2Framer) maxHeaderListSize() uint32 { if fr.MaxHeaderListSize == 0 { - return 16 << 20 + return 16 << 20 // sane default, per docs } return fr.MaxHeaderListSize } func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) { - + // Write the FrameHeader. f.wbuf = append(f.wbuf[:0], - 0, + 0, // 3 bytes of length, filled in in endWrite 0, 0, byte(ftype), @@ -885,7 +1628,8 @@ func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamI } func (f *http2Framer) endWrite() error { - + // Now that we know the final size, fill in the FrameHeader in + // the space previously reserved for it. Abuse append. length := len(f.wbuf) - http2frameHeaderLen if length >= (1 << 24) { return http2ErrFrameTooLarge @@ -909,8 +1653,9 @@ func (f *http2Framer) logWrite() { if f.debugFramer == nil { f.debugFramerBuf = new(bytes.Buffer) f.debugFramer = http2NewFramer(nil, f.debugFramerBuf) - f.debugFramer.logReads = false - + f.debugFramer.logReads = false // we log it ourselves, saying "wrote" below + // Let us read anything, even if we accidentally wrote it + // in the wrong order: f.debugFramer.AllowIllegalReads = true } f.debugFramerBuf.Write(f.wbuf) @@ -937,6 +1682,27 @@ const ( http2maxFrameSize = 1<<24 - 1 ) +// SetReuseFrames allows the Framer to reuse Frames. +// If called on a Framer, Frames returned by calls to ReadFrame are only +// valid until the next call to ReadFrame. +func (fr *http2Framer) SetReuseFrames() { + if fr.frameCache != nil { + return + } + fr.frameCache = &http2frameCache{} +} + +type http2frameCache struct { + dataFrame http2DataFrame +} + +func (fc *http2frameCache) getDataFrame() *http2DataFrame { + if fc == nil { + return &http2DataFrame{} + } + return &fc.dataFrame +} + // NewFramer returns a Framer that writes frames to w and reads them from r. func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ @@ -1016,7 +1782,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { if _, err := io.ReadFull(fr.r, payload); err != nil { return nil, err } - f, err := http2typeFrameParser(fh.Type)(fh, payload) + f, err := http2typeFrameParser(fh.Type)(fr.frameCache, fh, payload) if err != nil { if ce, ok := err.(http2connError); ok { return nil, fr.connError(ce.Code, ce.Reason) @@ -1104,14 +1870,18 @@ func (f *http2DataFrame) Data() []byte { return f.data } -func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parseDataFrame(fc *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { if fh.StreamID == 0 { - + // DATA frames MUST be associated with a stream. If a + // DATA frame is received whose stream identifier + // field is 0x0, the recipient MUST respond with a + // connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"} } - f := &http2DataFrame{ - http2FrameHeader: fh, - } + f := fc.getDataFrame() + f.http2FrameHeader = fh + var padSize byte if fh.Flags.Has(http2FlagDataPadded) { var err error @@ -1121,7 +1891,10 @@ func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error } } if int(padSize) > len(payload) { - + // If the length of the padding is greater than the + // length of the frame payload, the recipient MUST + // treat this as a connection error. + // Filed: https://github.com/http2/http2-spec/issues/610 return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"} } f.data = payload[:len(payload)-int(padSize)] @@ -1132,6 +1905,7 @@ var ( http2errStreamID = errors.New("invalid stream ID") http2errDepStreamID = errors.New("invalid dependent stream ID") http2errPadLength = errors.New("pad length too large") + http2errPadBytes = errors.New("padding bytes must all be zeros unless AllowIllegalWrites is enabled") ) func http2validStreamIDOrZero(streamID uint32) bool { @@ -1155,6 +1929,7 @@ func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) er // // If pad is nil, the padding bit is not sent. // The length of pad must not exceed 255 bytes. +// The bytes of pad must all be zero, unless f.AllowIllegalWrites is set. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility not to violate the maximum frame size @@ -1163,8 +1938,18 @@ func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad if !http2validStreamID(streamID) && !f.AllowIllegalWrites { return http2errStreamID } - if len(pad) > 255 { - return http2errPadLength + if len(pad) > 0 { + if len(pad) > 255 { + return http2errPadLength + } + if !f.AllowIllegalWrites { + for _, b := range pad { + if b != 0 { + // "Padding octets MUST be set to zero when sending." + return http2errPadBytes + } + } + } } var flags http2Flags if endStream { @@ -1192,22 +1977,35 @@ type http2SettingsFrame struct { p []byte } -func http2parseSettingsFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseSettingsFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 { - + // When this (ACK 0x1) bit is set, the payload of the + // SETTINGS frame MUST be empty. Receipt of a + // SETTINGS frame with the ACK flag set and a length + // field value other than 0 MUST be treated as a + // connection error (Section 5.4.1) of type + // FRAME_SIZE_ERROR. return nil, http2ConnectionError(http2ErrCodeFrameSize) } if fh.StreamID != 0 { - + // SETTINGS frames always apply to a connection, + // never a single stream. The stream identifier for a + // SETTINGS frame MUST be zero (0x0). If an endpoint + // receives a SETTINGS frame whose stream identifier + // field is anything other than 0x0, the endpoint MUST + // respond with a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR. return nil, http2ConnectionError(http2ErrCodeProtocol) } if len(p)%6 != 0 { - + // Expecting even number of 6 byte settings. return nil, http2ConnectionError(http2ErrCodeFrameSize) } f := &http2SettingsFrame{http2FrameHeader: fh, p: p} if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 { - + // Values above the maximum flow control window size of 2^31 - 1 MUST + // be treated as a connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. return nil, http2ConnectionError(http2ErrCodeFlowControl) } return f, nil @@ -1281,7 +2079,7 @@ type http2PingFrame struct { func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) } -func http2parsePingFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parsePingFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { if len(payload) != 8 { return nil, http2ConnectionError(http2ErrCodeFrameSize) } @@ -1321,7 +2119,7 @@ func (f *http2GoAwayFrame) DebugData() []byte { return f.debugData } -func http2parseGoAwayFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseGoAwayFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { if fh.StreamID != 0 { return nil, http2ConnectionError(http2ErrCodeProtocol) } @@ -1361,7 +2159,7 @@ func (f *http2UnknownFrame) Payload() []byte { return f.p } -func http2parseUnknownFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseUnknownFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { return &http2UnknownFrame{fh, p}, nil } @@ -1372,13 +2170,18 @@ type http2WindowUpdateFrame struct { Increment uint32 // never read with high bit set } -func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseWindowUpdateFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { if len(p) != 4 { return nil, http2ConnectionError(http2ErrCodeFrameSize) } - inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff + inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit if inc == 0 { - + // A receiver MUST treat the receipt of a + // WINDOW_UPDATE frame with an flow control window + // increment of 0 as a stream error (Section 5.4.2) of + // type PROTOCOL_ERROR; errors on the connection flow + // control window MUST be treated as a connection + // error (Section 5.4.1). if fh.StreamID == 0 { return nil, http2ConnectionError(http2ErrCodeProtocol) } @@ -1395,7 +2198,7 @@ func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, err // If the Stream ID is zero, the window update applies to the // connection as a whole. func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error { - + // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { return errors.New("illegal window increment value") } @@ -1432,12 +2235,15 @@ func (f *http2HeadersFrame) HasPriority() bool { return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority) } -func http2parseHeadersFrame(fh http2FrameHeader, p []byte) (_ http2Frame, err error) { +func http2parseHeadersFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { hf := &http2HeadersFrame{ http2FrameHeader: fh, } if fh.StreamID == 0 { - + // HEADERS frames MUST be associated with a stream. If a HEADERS frame + // is received whose stream identifier field is 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"} } var padLength uint8 @@ -1453,7 +2259,7 @@ func http2parseHeadersFrame(fh http2FrameHeader, p []byte) (_ http2Frame, err er return nil, err } hf.Priority.StreamDep = v & 0x7fffffff - hf.Priority.Exclusive = (v != hf.Priority.StreamDep) + hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set p, hf.Priority.Weight, err = http2readByte(p) if err != nil { return nil, err @@ -1556,7 +2362,7 @@ type http2PriorityParam struct { Exclusive bool // Weight is the stream's zero-indexed weight. It should be - // set together with StreamDep, or neither should be set. Per + // set together with StreamDep, or neither should be set. Per // the spec, "Add one to the value to obtain a weight between // 1 and 256." Weight uint8 @@ -1566,7 +2372,7 @@ func (p http2PriorityParam) IsZero() bool { return p == http2PriorityParam{} } -func http2parsePriorityFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) { +func http2parsePriorityFrame(_ *http2frameCache, fh http2FrameHeader, payload []byte) (http2Frame, error) { if fh.StreamID == 0 { return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"} } @@ -1574,13 +2380,13 @@ func http2parsePriorityFrame(fh http2FrameHeader, payload []byte) (http2Frame, e return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))} } v := binary.BigEndian.Uint32(payload[:4]) - streamID := v & 0x7fffffff + streamID := v & 0x7fffffff // mask off high bit return &http2PriorityFrame{ http2FrameHeader: fh, http2PriorityParam: http2PriorityParam{ Weight: payload[4], StreamDep: streamID, - Exclusive: streamID != v, + Exclusive: streamID != v, // was high bit set? }, }, nil } @@ -1613,7 +2419,7 @@ type http2RSTStreamFrame struct { ErrCode http2ErrCode } -func http2parseRSTStreamFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseRSTStreamFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { if len(p) != 4 { return nil, http2ConnectionError(http2ErrCodeFrameSize) } @@ -1643,7 +2449,7 @@ type http2ContinuationFrame struct { headerFragBuf []byte } -func http2parseContinuationFrame(fh http2FrameHeader, p []byte) (http2Frame, error) { +func http2parseContinuationFrame(_ *http2frameCache, fh http2FrameHeader, p []byte) (http2Frame, error) { if fh.StreamID == 0 { return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"} } @@ -1693,12 +2499,17 @@ func (f *http2PushPromiseFrame) HeadersEnded() bool { return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders) } -func http2parsePushPromise(fh http2FrameHeader, p []byte) (_ http2Frame, err error) { +func http2parsePushPromise(_ *http2frameCache, fh http2FrameHeader, p []byte) (_ http2Frame, err error) { pp := &http2PushPromiseFrame{ http2FrameHeader: fh, } if pp.StreamID == 0 { - + // PUSH_PROMISE frames MUST be associated with an existing, + // peer-initiated stream. The stream identifier of a + // PUSH_PROMISE frame indicates the stream it is associated + // with. If the stream identifier field specifies the value + // 0x0, a recipient MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. return nil, http2ConnectionError(http2ErrCodeProtocol) } // The PUSH_PROMISE frame includes optional padding. @@ -1717,7 +2528,7 @@ func http2parsePushPromise(fh http2FrameHeader, p []byte) (_ http2Frame, err err pp.PromiseID = pp.PromiseID & (1<<31 - 1) if int(padLength) > len(p) { - + // like the DATA frame, error out if padding is longer than the body. return nil, http2ConnectionError(http2ErrCodeProtocol) } pp.headerFragBuf = p[:len(p)-int(padLength)] @@ -1887,7 +2698,9 @@ func (mh *http2MetaHeadersFrame) checkPseudos() error { default: return http2pseudoHeaderError(hf.Name) } - + // Check for duplicates. + // This would be a bad algorithm, but N is 4. + // And this doesn't allocate. for _, hf2 := range pf[:i] { if hf.Name == hf2.Name { return http2duplicatePseudoHeaderError(hf.Name) @@ -1905,7 +2718,8 @@ func (fr *http2Framer) maxHeaderStringLen() int { if uint32(int(v)) == v { return int(v) } - + // They had a crazy big number for MaxHeaderBytes anyway, + // so give them unlimited header lengths: return 0 } @@ -1960,7 +2774,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr mh.Fields = append(mh.Fields, hf) }) - + // Lose reference to MetaHeadersFrame: defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) var hc http2headersOrContinuation = hf @@ -1976,7 +2790,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr if f, err := fr.ReadFrame(); err != nil { return nil, err } else { - hc = f.(*http2ContinuationFrame) + hc = f.(*http2ContinuationFrame) // guaranteed by checkFrameOrder } } @@ -2018,7 +2832,7 @@ func http2summarizeFrame(f http2Frame) string { return nil }) if n > 0 { - buf.Truncate(buf.Len() - 1) + buf.Truncate(buf.Len() - 1) // remove trailing comma } case *http2DataFrame: data := f.Data() @@ -2050,29 +2864,6 @@ func http2transportExpectContinueTimeout(t1 *Transport) time.Duration { return t1.ExpectContinueTimeout } -// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. -func http2isBadCipher(cipher uint16) bool { - switch cipher { - case tls.TLS_RSA_WITH_RC4_128_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_GCM_SHA256, - tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: - - return true - default: - return false - } -} - type http2contextContext interface { context.Context } @@ -2164,7 +2955,11 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { return cc.ping(ctx) } -func http2cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } +func http2cloneTLSConfig(c *tls.Config) *tls.Config { + c2 := c.Clone() + c2.GetClientCertificate = c.GetClientCertificate // golang.org/issue/19264 + return c2 +} var _ Pusher = (*http2responseWriter)(nil) @@ -2201,6 +2996,13 @@ func http2reqBodyIsNoBody(body io.ReadCloser) bool { return body == NoBody } +func http2go18httpNoBody() io.ReadCloser { return NoBody } // for tests only + +func http2configureServer19(s *Server, conf *http2Server) error { + s.RegisterOnShutdown(conf.state.startGracefulShutdown) + return nil +} + var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -2237,7 +3039,7 @@ func http2curGoroutineID() uint64 { defer http2littleBuf.Put(bp) b := *bp b = b[:runtime.Stack(b, false)] - + // Parse the 4707 out of "goroutine 4707 [" b = bytes.TrimPrefix(b, http2goroutineSpace) i := bytes.IndexByte(b, ' ') if i < 0 { @@ -2273,9 +3075,10 @@ func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) goto Error case 2 <= base && base <= 36: + // valid base; nothing to do case base == 0: - + // Look for octal, hex prefix. switch { case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): base = 16 @@ -2321,7 +3124,7 @@ func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) } if n >= cutoff { - + // n*base overflows n = 1<<64 - 1 err = strconv.ErrRange goto Error @@ -2330,7 +3133,7 @@ func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) n1 := n + uint64(v) if n1 < n || n1 > maxVal { - + // n+v overflows n = 1<<64 - 1 err = strconv.ErrRange goto Error @@ -2514,7 +3317,7 @@ func (s http2Setting) String() string { // Valid reports whether the setting is valid. func (s http2Setting) Valid() error { - + // Limits and error codes from 6.5.2 Defined SETTINGS Parameters switch s.ID { case http2SettingEnablePush: if s.Val != 1 && s.Val != 0 { @@ -2758,7 +3561,8 @@ func (s *http2sorter) Keys(h Header) []string { } func (s *http2sorter) SortStrings(ss []string) { - + // Our sorter works on s.v, which sorter owns, so + // stash it away while we sort the user's buffer. save := s.v s.v = ss sort.Sort(s) @@ -2768,27 +3572,31 @@ func (s *http2sorter) SortStrings(ss []string) { // validPseudoPath reports whether v is a valid :path pseudo-header // value. It must be either: // -// *) a non-empty string starting with '/', but not with with "//", +// *) a non-empty string starting with '/' // *) the string '*', for OPTIONS requests. // // For now this is only used a quick check for deciding when to clean // up Opaque URLs before sending requests from the Transport. // See golang.org/issue/16847 +// +// We used to enforce that the path also didn't start with "//", but +// Google's GFE accepts such paths and Chrome sends them, so ignore +// that part of the spec. See golang.org/issue/19103. func http2validPseudoPath(v string) bool { - return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" + return (len(v) > 0 && v[0] == '/') || v == "*" } -// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like +// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) type http2pipe struct { mu sync.Mutex - c sync.Cond // c.L lazily initialized to &p.mu - b http2pipeBuffer - err error // read error once empty. non-nil means closed. - breakErr error // immediate read error (caller doesn't see rest of b) - donec chan struct{} // closed on error - readFn func() // optional code to run in Read before error + c sync.Cond // c.L lazily initialized to &p.mu + b http2pipeBuffer // nil when done reading + err error // read error once empty. non-nil means closed. + breakErr error // immediate read error (caller doesn't see rest of b) + donec chan struct{} // closed on error + readFn func() // optional code to run in Read before error } type http2pipeBuffer interface { @@ -2800,6 +3608,9 @@ type http2pipeBuffer interface { func (p *http2pipe) Len() int { p.mu.Lock() defer p.mu.Unlock() + if p.b == nil { + return 0 + } return p.b.Len() } @@ -2815,14 +3626,15 @@ func (p *http2pipe) Read(d []byte) (n int, err error) { if p.breakErr != nil { return 0, p.breakErr } - if p.b.Len() > 0 { + if p.b != nil && p.b.Len() > 0 { return p.b.Read(d) } if p.err != nil { if p.readFn != nil { - p.readFn() - p.readFn = nil + p.readFn() // e.g. copy trailers + p.readFn = nil // not sticky like p.err } + p.b = nil return 0, p.err } p.c.Wait() @@ -2843,6 +3655,9 @@ func (p *http2pipe) Write(d []byte) (n int, err error) { if p.err != nil { return 0, http2errClosedPipeWrite } + if p.breakErr != nil { + return len(d), nil // discard when there is no reader + } return p.b.Write(d) } @@ -2873,10 +3688,13 @@ func (p *http2pipe) closeWithError(dst *error, err error, fn func()) { } defer p.c.Signal() if *dst != nil { - + // Already been done. return } p.readFn = fn + if dst == &p.breakErr { + p.b = nil + } *dst = err p.closeDoneLocked() } @@ -2886,7 +3704,8 @@ func (p *http2pipe) closeDoneLocked() { if p.donec == nil { return } - + // Close if unclosed. This isn't racy since we always + // hold p.mu while closing. select { case <-p.donec: default: @@ -2912,7 +3731,7 @@ func (p *http2pipe) Done() <-chan struct{} { if p.donec == nil { p.donec = make(chan struct{}) if p.err != nil || p.breakErr != nil { - + // Already hit an error. p.closeDoneLocked() } } @@ -2980,9 +3799,41 @@ type http2Server struct { // activity for the purposes of IdleTimeout. IdleTimeout time.Duration + // MaxUploadBufferPerConnection is the size of the initial flow + // control window for each connections. The HTTP/2 spec does not + // allow this to be smaller than 65535 or larger than 2^32-1. + // If the value is outside this range, a default value will be + // used instead. + MaxUploadBufferPerConnection int32 + + // MaxUploadBufferPerStream is the size of the initial flow control + // window for each stream. The HTTP/2 spec does not allow this to + // be larger than 2^32-1. If the value is zero or larger than the + // maximum, a default value will be used instead. + MaxUploadBufferPerStream int32 + // NewWriteScheduler constructs a write scheduler for a connection. // If nil, a default scheduler is chosen. NewWriteScheduler func() http2WriteScheduler + + // Internal state. This is a pointer (rather than embedded directly) + // so that we don't embed a Mutex in this struct, which will make the + // struct non-copyable, which might break some callers. + state *http2serverInternalState +} + +func (s *http2Server) initialConnRecvWindowSize() int32 { + if s.MaxUploadBufferPerConnection > http2initialWindowSize { + return s.MaxUploadBufferPerConnection + } + return 1 << 20 +} + +func (s *http2Server) initialStreamRecvWindowSize() int32 { + if s.MaxUploadBufferPerStream > 0 { + return s.MaxUploadBufferPerStream + } + return 1 << 20 } func (s *http2Server) maxReadFrameSize() uint32 { @@ -2999,6 +3850,40 @@ func (s *http2Server) maxConcurrentStreams() uint32 { return http2defaultMaxStreams } +type http2serverInternalState struct { + mu sync.Mutex + activeConns map[*http2serverConn]struct{} +} + +func (s *http2serverInternalState) registerConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + s.activeConns[sc] = struct{}{} + s.mu.Unlock() +} + +func (s *http2serverInternalState) unregisterConn(sc *http2serverConn) { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + delete(s.activeConns, sc) + s.mu.Unlock() +} + +func (s *http2serverInternalState) startGracefulShutdown() { + if s == nil { + return // if the Server was used without calling ConfigureServer + } + s.mu.Lock() + for sc := range s.activeConns { + sc.startGracefulShutdown() + } + s.mu.Unlock() +} + // ConfigureServer adds HTTP/2 support to a net/http Server. // // The configuration conf may be nil. @@ -3011,9 +3896,13 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { if conf == nil { conf = new(http2Server) } + conf.state = &http2serverInternalState{activeConns: make(map[*http2serverConn]struct{})} if err := http2configureServer18(s, conf); err != nil { return err } + if err := http2configureServer19(s, conf); err != nil { + return err + } if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -3039,6 +3928,13 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { } } + // Note: not setting MinVersion to tls.VersionTLS12, + // as we don't want to interfere with HTTP/1.1 traffic + // on the user's server. We enforce TLS 1.2 later once + // we accept a connection. Ideally this should be done + // during next-proto selection, but using TLS <1.2 with + // HTTP/2 is still the client's bug. + s.TLSConfig.PreferServerCipherSuites = true haveNPN := false @@ -3118,29 +4014,37 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { defer cancel() sc := &http2serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: http2newBufferedWriter(c), - handler: opts.handler(), - streams: make(map[uint32]*http2stream), - readFrameCh: make(chan http2readFrameResult), - wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), - wantStartPushCh: make(chan http2startPushRequest, 8), - wroteFrameCh: make(chan http2frameWriteResult, 1), - bodyReadCh: make(chan http2bodyReadMsg), - doneServing: make(chan struct{}), - clientMaxStreams: math.MaxUint32, - advMaxStreams: s.maxConcurrentStreams(), - initialWindowSize: http2initialWindowSize, - maxFrameSize: http2initialMaxFrameSize, - headerTableSize: http2initialHeaderTableSize, - serveG: http2newGoroutineLock(), - pushEnabled: true, - } - + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + serveMsgCh: make(chan interface{}, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), // buffered; one send in writeFrameAsync + bodyReadCh: make(chan http2bodyReadMsg), // buffering doesn't matter either way + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" + advMaxStreams: s.maxConcurrentStreams(), + initialStreamSendWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, + headerTableSize: http2initialHeaderTableSize, + serveG: http2newGoroutineLock(), + pushEnabled: true, + } + + s.state.registerConn(sc) + defer s.state.unregisterConn(sc) + + // The net/http package sets the write deadline from the + // http.Server.WriteTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already set. + // Write deadlines are set per stream in serverConn.newStream. + // Disarm the net.Conn write deadline here. if sc.hs.WriteTimeout != 0 { sc.conn.SetWriteDeadline(time.Time{}) } @@ -3151,6 +4055,9 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.writeSched = http2NewRandomWriteScheduler() } + // These start at the RFC-specified defaults. If there is a higher + // configured value for inflow, that will be updated when we send a + // WINDOW_UPDATE shortly after sending SETTINGS. sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -3164,18 +4071,44 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { if tc, ok := c.(http2connectionStater); ok { sc.tlsState = new(tls.ConnectionState) *sc.tlsState = tc.ConnectionState() - + // 9.2 Use of TLS Features + // An implementation of HTTP/2 over TLS MUST use TLS + // 1.2 or higher with the restrictions on feature set + // and cipher suite described in this section. Due to + // implementation limitations, it might not be + // possible to fail TLS negotiation. An endpoint MUST + // immediately terminate an HTTP/2 connection that + // does not meet the TLS requirements described in + // this section with a connection error (Section + // 5.4.1) of type INADEQUATE_SECURITY. if sc.tlsState.Version < tls.VersionTLS12 { sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low") return } if sc.tlsState.ServerName == "" { - + // Client must use SNI, but we don't enforce that anymore, + // since it was causing problems when connecting to bare IP + // addresses during development. + // + // TODO: optionally enforce? Or enforce at the time we receive + // a new request, and verify the the ServerName matches the :authority? + // But that precludes proxy situations, perhaps. + // + // So for now, do nothing here again. } if !s.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) { - + // "Endpoints MAY choose to generate a connection error + // (Section 5.4.1) of type INADEQUATE_SECURITY if one of + // the prohibited cipher suites are negotiated." + // + // We choose that. In my opinion, the spec is weak + // here. It also says both parties must support at least + // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no + // excuses here. If we really must, we could allow an + // "AllowInsecureWeakCiphers" option on the server later. + // Let's see how it plays out first. sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) return } @@ -3189,7 +4122,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) - + // ignoring errors. hanging up anyway. sc.framer.WriteGoAway(0, err, []byte(debug)) sc.bw.Flush() sc.conn.Close() @@ -3207,10 +4140,9 @@ type http2serverConn struct { doneServing chan struct{} // closed when serverConn.serve ends readFrameCh chan http2readFrameResult // written by serverConn.readFrames wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve - wantStartPushCh chan http2startPushRequest // from handlers -> serve wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes bodyReadCh chan http2bodyReadMsg // from handlers -> serve - testHookCh chan func(int) // code to run on the serve loop + serveMsgCh chan interface{} // misc messages & code to send to / run on the serve loop flow http2flow // conn-wide (not stream-specific) outbound flow control inflow http2flow // conn-wide inbound flow control tlsState *tls.ConnectionState // shared by all handlers, like net/http @@ -3218,38 +4150,39 @@ type http2serverConn struct { writeSched http2WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): - serveG http2goroutineLock // used to verify funcs are on serve() - pushEnabled bool - sawFirstSettings bool // got the initial SETTINGS frame after the preface - needToSendSettingsAck bool - unackedSettings int // how many SETTINGS have we sent without ACKs? - clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) - advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curClientStreams uint32 // number of open streams initiated by the client - curPushedStreams uint32 // number of open streams initiated by server push - maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests - maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes - streams map[uint32]*http2stream - initialWindowSize int32 - maxFrameSize int32 - headerTableSize uint32 - peerMaxHeaderListSize uint32 // zero means unknown (default) - canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started writing a frame (on serve goroutine or separate) - writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh - needsFrameFlush bool // last frame write wasn't a flush - inGoAway bool // we've started to or sent GOAWAY - inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop - needToSendGoAway bool // we need to schedule a GOAWAY frame write - goAwayCode http2ErrCode - shutdownTimerCh <-chan time.Time // nil until used - shutdownTimer *time.Timer // nil until used - idleTimer *time.Timer // nil if unused - idleTimerCh <-chan time.Time // nil if unused + serveG http2goroutineLock // used to verify funcs are on serve() + pushEnabled bool + sawFirstSettings bool // got the initial SETTINGS frame after the preface + needToSendSettingsAck bool + unackedSettings int // how many SETTINGS have we sent without ACKs? + clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) + advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes + streams map[uint32]*http2stream + initialStreamSendWindowSize int32 + maxFrameSize int32 + headerTableSize uint32 + peerMaxHeaderListSize uint32 // zero means unknown (default) + canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh + needsFrameFlush bool // last frame write wasn't a flush + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write + goAwayCode http2ErrCode + shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer hpackEncoder *hpack.Encoder + + // Used by startGracefulShutdown. + shutdownOnce sync.Once } func (sc *http2serverConn) maxHeaderListSize() uint32 { @@ -3294,10 +4227,10 @@ type http2stream struct { numTrailerValues int64 weight uint8 state http2streamState - resetQueued bool // RST_STREAM queued for write; set by sc.resetStream - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - reqBuf []byte // if non-nil, body pipe buffer to return later at EOF + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + writeDeadline *time.Timer // nil if unused trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -3315,11 +4248,16 @@ func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) { sc.serveG.check() - + // http://tools.ietf.org/html/rfc7540#section-5.1 if st, ok := sc.streams[streamID]; ok { return st.state, st } - + // "The first use of a new stream identifier implicitly closes all + // streams in the "idle" state that might have been initiated by + // that peer with a lower-valued stream identifier. For example, if + // a client sends a HEADERS frame on stream 7 without ever sending a + // frame on stream 5, then stream 5 transitions to the "closed" + // state when the first frame for stream 7 is sent or received." if streamID%2 == 1 { if streamID <= sc.maxClientStreamID { return http2stateClosed, nil @@ -3373,11 +4311,18 @@ func http2isClosedConnError(err error) bool { return false } + // TODO: remove this string search and be more like the Windows + // case below. That might involve modifying the standard library + // to return better error types. str := err.Error() if strings.Contains(str, "use of closed network connection") { return true } + // TODO(bradfitz): x/tools/cmd/bundle doesn't really support + // build tags, so I can't make an http2_windows.go file with + // Windows-specific stuff. Fix that and move this, once we + // have a way to bundle this into std's net/http somehow. if runtime.GOOS == "windows" { if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" { @@ -3397,7 +4342,7 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ return } if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) { - + // Boring, expected errors. sc.vlogf(format, args...) } else { sc.logf(format, args...) @@ -3487,7 +4432,7 @@ func (sc *http2serverConn) stopShutdownTimer() { } func (sc *http2serverConn) notePanic() { - + // Note: this is for serverConn.serve panicking, not http.Handler code. if http2testHookOnPanicMu != nil { http2testHookOnPanicMu.Lock() defer http2testHookOnPanicMu.Unlock() @@ -3507,7 +4452,7 @@ func (sc *http2serverConn) serve() { defer sc.conn.Close() defer sc.closeAllStreamsOnConnClose() defer sc.stopShutdownTimer() - defer close(sc.doneServing) + defer close(sc.doneServing) // unblocks handlers trying to send if http2VerboseLogs { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) @@ -3518,44 +4463,48 @@ func (sc *http2serverConn) serve() { {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()}, + {http2SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, }, }) sc.unackedSettings++ + // Each connection starts with intialWindowSize inflow tokens. + // If a higher value is configured, we add more tokens. + if diff := sc.srv.initialConnRecvWindowSize() - http2initialWindowSize; diff > 0 { + sc.sendWindowUpdate(nil, int(diff)) + } + if err := sc.readPreface(); err != nil { sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) return } - + // Now that we've got the preface, get us out of the + // "StateNew" state. We can't go directly to idle, though. + // Active means we read some data and anticipate a request. We'll + // do another Active when we get a HEADERS frame. sc.setConnState(StateActive) sc.setConnState(StateIdle) if sc.srv.IdleTimeout != 0 { - sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() - sc.idleTimerCh = sc.idleTimer.C } - var gracefulShutdownCh chan struct{} - if sc.hs != nil { - ch := http2h1ServerShutdownChan(sc.hs) - if ch != nil { - gracefulShutdownCh = make(chan struct{}) - go sc.awaitGracefulShutdown(ch, gracefulShutdownCh) - } - } + go sc.readFrames() // closed by defer sc.conn.Close above - go sc.readFrames() + settingsTimer := time.AfterFunc(http2firstSettingsTimeout, sc.onSettingsTimer) + defer settingsTimer.Stop() - settingsTimer := time.NewTimer(http2firstSettingsTimeout) loopNum := 0 for { loopNum++ select { case wr := <-sc.wantWriteFrameCh: + if se, ok := wr.write.(http2StreamError); ok { + sc.resetStream(se) + break + } sc.writeFrame(wr) - case spr := <-sc.wantStartPushCh: - sc.startPush(spr) case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: @@ -3563,26 +4512,37 @@ func (sc *http2serverConn) serve() { return } res.readMore() - if settingsTimer.C != nil { + if settingsTimer != nil { settingsTimer.Stop() - settingsTimer.C = nil + settingsTimer = nil } case m := <-sc.bodyReadCh: sc.noteBodyRead(m.st, m.n) - case <-settingsTimer.C: - sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) - return - case <-gracefulShutdownCh: - gracefulShutdownCh = nil - sc.startGracefulShutdown() - case <-sc.shutdownTimerCh: - sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) - return - case <-sc.idleTimerCh: - sc.vlogf("connection is idle") - sc.goAway(http2ErrCodeNo) - case fn := <-sc.testHookCh: - fn(loopNum) + case msg := <-sc.serveMsgCh: + switch v := msg.(type) { + case func(int): + v(loopNum) // for testing + case *http2serverMessage: + switch v { + case http2settingsTimerMsg: + sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) + return + case http2idleTimerMsg: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) + case http2shutdownTimerMsg: + sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) + return + case http2gracefulShutdownMsg: + sc.startGracefulShutdownInternal() + default: + panic("unknown timer") + } + case *http2startPushRequest: + sc.startPush(v) + default: + panic(fmt.Sprintf("unexpected type %T", v)) + } } if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame { @@ -3599,12 +4559,36 @@ func (sc *http2serverConn) awaitGracefulShutdown(sharedCh <-chan struct{}, priva } } +type http2serverMessage int + +// Message values sent to serveMsgCh. +var ( + http2settingsTimerMsg = new(http2serverMessage) + http2idleTimerMsg = new(http2serverMessage) + http2shutdownTimerMsg = new(http2serverMessage) + http2gracefulShutdownMsg = new(http2serverMessage) +) + +func (sc *http2serverConn) onSettingsTimer() { sc.sendServeMsg(http2settingsTimerMsg) } + +func (sc *http2serverConn) onIdleTimer() { sc.sendServeMsg(http2idleTimerMsg) } + +func (sc *http2serverConn) onShutdownTimer() { sc.sendServeMsg(http2shutdownTimerMsg) } + +func (sc *http2serverConn) sendServeMsg(msg interface{}) { + sc.serveG.checkNotOn() // NOT + select { + case sc.serveMsgCh <- msg: + case <-sc.doneServing: + } +} + // readPreface reads the ClientPreface greeting from the peer // or returns an error on timeout or an invalid greeting. func (sc *http2serverConn) readPreface() error { errc := make(chan error, 1) go func() { - + // Read the client preface buf := make([]byte, len(http2ClientPreface)) if _, err := io.ReadFull(sc.conn, buf); err != nil { errc <- err @@ -3614,7 +4598,7 @@ func (sc *http2serverConn) readPreface() error { errc <- nil } }() - timer := time.NewTimer(http2prefaceTimeout) + timer := time.NewTimer(http2prefaceTimeout) // TODO: configurable on *Server? defer timer.Stop() select { case <-timer.C: @@ -3658,7 +4642,13 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte case <-sc.doneServing: return http2errClientDisconnected case <-stream.cw: - + // If both ch and stream.cw were ready (as might + // happen on the final Write after an http.Handler + // ends), prefer the write result. Otherwise this + // might just be us successfully closing the stream. + // The writeFrameAsync and serve goroutines guarantee + // that the ch send will happen before the stream.cw + // close. select { case err = <-ch: frameWriteDone = true @@ -3681,12 +4671,13 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { - sc.serveG.checkNotOn() + sc.serveG.checkNotOn() // NOT select { case sc.wantWriteFrameCh <- wr: return nil case <-sc.doneServing: - + // Serve loop is gone. + // Client has closed their connection to the server. return http2errClientDisconnected } } @@ -3705,6 +4696,24 @@ func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { // If true, wr will not be written and wr.done will not be signaled. var ignoreWrite bool + // We are not allowed to write frames on closed streams. RFC 7540 Section + // 5.1.1 says: "An endpoint MUST NOT send frames other than PRIORITY on + // a closed stream." Our server never sends PRIORITY, so that exception + // does not apply. + // + // The serverConn might close an open stream while the stream's handler + // is still running. For example, the server might close a stream when it + // receives bad data from the client. If this happens, the handler might + // attempt to write a frame after the stream has been closed (since the + // handler hasn't yet been notified of the close). In this case, we simply + // ignore the frame. The handler will notice that the stream is closed when + // it waits for the frame to be written. + // + // As an exception to this rule, we allow sending RST_STREAM after close. + // This allows us to immediately reject new streams without tracking any + // state for those streams (except for the queued RST_STREAM frame). This + // may result in duplicate RST_STREAMs in some cases, but the client should + // ignore those. if wr.StreamID() != 0 { _, isReset := wr.write.(http2StreamError) if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { @@ -3712,12 +4721,15 @@ func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { } } + // Don't send a 100-continue response if we've already sent headers. + // See golang.org/issue/14030. switch wr.write.(type) { case *http2writeResHeaders: wr.stream.wroteHeaders = true case http2write100ContinueHeadersFrame: if wr.stream.wroteHeaders { - + // We do not need to notify wr.done because this frame is + // never written with wr.done != nil. if wr.done != nil { panic("wr.done != nil for write100ContinueHeadersFrame") } @@ -3746,7 +4758,8 @@ func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { case http2stateHalfClosedLocal: switch wr.write.(type) { case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: - + // RFC 7540 Section 5.1 allows sending RST_STREAM, PRIORITY, and WINDOW_UPDATE + // in this state. (We never send PRIORITY from the server, so that is not checked.) default: panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) } @@ -3800,16 +4813,29 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { } switch st.state { case http2stateOpen: - + // Here we would go to stateHalfClosedLocal in + // theory, but since our handler is done and + // the net/http package provides no mechanism + // for closing a ResponseWriter while still + // reading data (see possible TODO at top of + // this file), we go into closed state here + // anyway, after telling the peer we're + // hanging up on them. We'll transition to + // stateClosed after the RST_STREAM frame is + // written. st.state = http2stateHalfClosedLocal - sc.resetStream(http2streamError(st.id, http2ErrCodeCancel)) + // Section 8.1: a server MAY request that the client abort + // transmission of a request without error by sending a + // RST_STREAM with an error code of NO_ERROR after sending + // a complete response. + sc.resetStream(http2streamError(st.id, http2ErrCodeNo)) case http2stateHalfClosedRemote: sc.closeStream(st, http2errHandlerComplete) } } else { switch v := wr.write.(type) { case http2StreamError: - + // st may be unknown if the RST_STREAM was generated to reject bad input. if st, ok := sc.streams[v.StreamID]; ok { sc.closeStream(st, v) } @@ -3818,6 +4844,7 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { } } + // Reply (if requested) to unblock the ServeHTTP goroutine. wr.replyToWriter(res.err) sc.scheduleFrameWrite() @@ -3865,7 +4892,7 @@ func (sc *http2serverConn) scheduleFrameWrite() { } if sc.needsFrameFlush { sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) - sc.needsFrameFlush = false + sc.needsFrameFlush = false // after startFrameWrite, since it sets this true continue } break @@ -3873,10 +4900,19 @@ func (sc *http2serverConn) scheduleFrameWrite() { sc.inFrameScheduleLoop = false } -// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the -// client we're gracefully shutting down. The connection isn't closed -// until all current streams are done. +// startGracefulShutdown gracefully shuts down a connection. This +// sends GOAWAY with ErrCodeNo to tell the client we're gracefully +// shutting down. The connection isn't closed until all current +// streams are done. +// +// startGracefulShutdown returns immediately; it does not wait until +// the connection has shut down. func (sc *http2serverConn) startGracefulShutdown() { + sc.serveG.checkNotOn() // NOT + sc.shutdownOnce.Do(func() { sc.sendServeMsg(http2gracefulShutdownMsg) }) +} + +func (sc *http2serverConn) startGracefulShutdownInternal() { sc.goAwayIn(http2ErrCodeNo, 0) } @@ -3886,7 +4922,7 @@ func (sc *http2serverConn) goAway(code http2ErrCode) { if code != http2ErrCodeNo { forceCloseIn = 250 * time.Millisecond } else { - + // TODO: configurable forceCloseIn = 1 * time.Second } sc.goAwayIn(code, forceCloseIn) @@ -3908,8 +4944,7 @@ func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duratio func (sc *http2serverConn) shutDownIn(d time.Duration) { sc.serveG.check() - sc.shutdownTimer = time.NewTimer(d) - sc.shutdownTimerCh = sc.shutdownTimer.C + sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer) } func (sc *http2serverConn) resetStream(se http2StreamError) { @@ -3929,11 +4964,18 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool if err != nil { if err == http2ErrFrameTooLarge { sc.goAway(http2ErrCodeFrameSize) - return true + return true // goAway will close the loop } clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) if clientGone { - + // TODO: could we also get into this state if + // the peer does a half close + // (e.g. CloseWrite) because they're done + // sending frames but they're still wanting + // our open replies? Investigate. + // TODO: add CloseWrite to crypto/tls.Conn first + // so we have a way to test this? I suppose + // just for testing we could have a non-TLS mode. return false } } else { @@ -3957,7 +4999,7 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool case http2ConnectionError: sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev) sc.goAway(http2ErrCode(ev)) - return true + return true // goAway will handle shutdown default: if res.err != nil { sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) @@ -3971,6 +5013,7 @@ func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool func (sc *http2serverConn) processFrame(f http2Frame) error { sc.serveG.check() + // First frame received must be SETTINGS. if !sc.sawFirstSettings { if _, ok := f.(*http2SettingsFrame); !ok { return http2ConnectionError(http2ErrCodeProtocol) @@ -3996,7 +5039,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { case *http2GoAwayFrame: return sc.processGoAway(f) case *http2PushPromiseFrame: - + // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE + // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. return http2ConnectionError(http2ErrCodeProtocol) default: sc.vlogf("http2: server ignoring frame: %v", f.Header()) @@ -4007,11 +5051,16 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { func (sc *http2serverConn) processPing(f *http2PingFrame) error { sc.serveG.check() if f.IsAck() { - + // 6.7 PING: " An endpoint MUST NOT respond to PING frames + // containing this flag." return nil } if f.StreamID != 0 { - + // "PING frames are not associated with any individual + // stream. If a PING frame is received with a stream + // identifier field value other than 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." return http2ConnectionError(http2ErrCodeProtocol) } if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { @@ -4024,20 +5073,27 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { sc.serveG.check() switch { - case f.StreamID != 0: + case f.StreamID != 0: // stream-level flow control state, st := sc.state(f.StreamID) if state == http2stateIdle { - + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." return http2ConnectionError(http2ErrCodeProtocol) } if st == nil { - + // "WINDOW_UPDATE can be sent by a peer that has sent a + // frame bearing the END_STREAM flag. This means that a + // receiver could receive a WINDOW_UPDATE frame on a "half + // closed (remote)" or "closed" stream. A receiver MUST + // NOT treat this as an error, see Section 5.1." return nil } if !st.flow.add(int32(f.Increment)) { return http2streamError(f.StreamID, http2ErrCodeFlowControl) } - default: + default: // connection-level flow control if !sc.flow.add(int32(f.Increment)) { return http2goAwayFlowError{} } @@ -4051,7 +5107,11 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { state, st := sc.state(f.StreamID) if state == http2stateIdle { - + // 6.4 "RST_STREAM frames MUST NOT be sent for a + // stream in the "idle" state. If a RST_STREAM frame + // identifying an idle stream is received, the + // recipient MUST treat this as a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. return http2ConnectionError(http2ErrCodeProtocol) } if st != nil { @@ -4067,6 +5127,9 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = http2stateClosed + if st.writeDeadline != nil { + st.writeDeadline.Stop() + } if st.isPushed() { sc.curPushedStreams-- } else { @@ -4079,16 +5142,17 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if http2h1ServerKeepAlivesDisabled(sc.hs) { - sc.startGracefulShutdown() + sc.startGracefulShutdownInternal() } } if p := st.body; p != nil { - + // Return any buffered unread bytes worth of conn-level flow control. + // See golang.org/issue/16481 sc.sendWindowUpdate(nil, p.Len()) p.CloseWithError(err) } - st.cw.Close() + st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc sc.writeSched.CloseStream(st.id) } @@ -4097,7 +5161,9 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { if f.IsAck() { sc.unackedSettings-- if sc.unackedSettings < 0 { - + // Why is the peer ACKing settings we never sent? + // The spec doesn't mention this case, but + // hang up on them anyway. return http2ConnectionError(http2ErrCodeProtocol) } return nil @@ -4129,11 +5195,13 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { case http2SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) case http2SettingMaxFrameSize: - sc.maxFrameSize = int32(s.Val) + sc.maxFrameSize = int32(s.Val) // the maximum valid s.Val is < 2^31 case http2SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: - + // Unknown setting: "An endpoint that receives a SETTINGS + // frame with any unknown or unsupported identifier MUST + // ignore that setting." if http2VerboseLogs { sc.vlogf("http2: server ignoring unknown setting %v", s) } @@ -4143,13 +5211,26 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { sc.serveG.check() - - old := sc.initialWindowSize - sc.initialWindowSize = int32(val) - growth := sc.initialWindowSize - old + // Note: val already validated to be within range by + // processSetting's Valid call. + + // "A SETTINGS frame can alter the initial flow control window + // size for all current streams. When the value of + // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST + // adjust the size of all stream flow control windows that it + // maintains by the difference between the new value and the + // old value." + old := sc.initialStreamSendWindowSize + sc.initialStreamSendWindowSize = int32(val) + growth := int32(val) - old // may be negative for _, st := range sc.streams { if !st.flow.add(growth) { - + // 6.9.2 Initial Flow Control Window Size + // "An endpoint MUST treat a change to + // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow + // control window to exceed the maximum size as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR." return http2ConnectionError(http2ErrCodeFlowControl) } } @@ -4163,23 +5244,40 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { } data := f.Data() + // "If a DATA frame is received whose stream is not in "open" + // or "half closed (local)" state, the recipient MUST respond + // with a stream error (Section 5.4.2) of type STREAM_CLOSED." id := f.Header().StreamID state, st := sc.state(id) if id == 0 || state == http2stateIdle { - + // Section 5.1: "Receiving any frame other than HEADERS + // or PRIORITY on a stream in this state MUST be + // treated as a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR." return http2ConnectionError(http2ErrCodeProtocol) } if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { - + // This includes sending a RST_STREAM if the stream is + // in stateHalfClosedLocal (which currently means that + // the http.Handler returned, so it's done reading & + // done writing). Try to stop the client from sending + // more DATA. + + // But still enforce their connection-level flow control, + // and return any flow control bytes since we're not going + // to consume them. if sc.inflow.available() < int32(f.Length) { return http2streamError(id, http2ErrCodeFlowControl) } - + // Deduct the flow control from inflow, since we're + // going to immediately add it back in + // sendWindowUpdate, which also schedules sending the + // frames. sc.inflow.take(int32(f.Length)) - sc.sendWindowUpdate(nil, int(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) // conn-level if st != nil && st.resetQueued { - + // Already have a stream error in flight. Don't send another. return nil } return http2streamError(id, http2ErrCodeStreamClosed) @@ -4188,12 +5286,13 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { panic("internal error: should have a body in this state") } + // Sender sending more than they'd declared? if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) return http2streamError(id, http2ErrCodeStreamClosed) } if f.Length > 0 { - + // Check whether the client has flow control quota. if st.inflow.available() < int32(f.Length) { return http2streamError(id, http2ErrCodeFlowControl) } @@ -4210,6 +5309,8 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { st.bodyBytes += int64(len(data)) } + // Return any padded flow control now, since we won't + // refund it later on body reads. if pad := int32(f.Length) - int32(len(data)); pad > 0 { sc.sendWindowUpdate32(nil, pad) sc.sendWindowUpdate32(st, pad) @@ -4228,8 +5329,9 @@ func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { } else { sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) } - sc.startGracefulShutdown() - + sc.startGracefulShutdownInternal() + // http://tools.ietf.org/html/rfc7540#section-6.8 + // We should not create any new streams, which means we should disable push. sc.pushEnabled = false return nil } @@ -4260,32 +5362,51 @@ func (st *http2stream) endStream() { func (st *http2stream) copyTrailersToHandlerRequest() { for k, vv := range st.trailer { if _, ok := st.reqTrailer[k]; ok { - + // Only copy it over it was pre-declared. st.reqTrailer[k] = vv } } } +// onWriteTimeout is run on its own goroutine (from time.AfterFunc) +// when the stream's WriteTimeout has fired. +func (st *http2stream) onWriteTimeout() { + st.sc.writeFrameFromHandler(http2FrameWriteRequest{write: http2streamError(st.id, http2ErrCodeInternal)}) +} + func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() id := f.StreamID if sc.inGoAway { - + // Ignore. return nil } - + // http://tools.ietf.org/html/rfc7540#section-5.1.1 + // Streams initiated by a client MUST use odd-numbered stream + // identifiers. [...] An endpoint that receives an unexpected + // stream identifier MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. if id%2 != 1 { return http2ConnectionError(http2ErrCodeProtocol) } - + // A HEADERS frame can be used to create a new stream or + // send a trailer for an open one. If we already have a stream + // open, let it process its own HEADERS frame (trailers at this + // point, if it's valid). if st := sc.streams[f.StreamID]; st != nil { if st.resetQueued { - + // We're sending RST_STREAM to close the stream, so don't bother + // processing this frame. return nil } return st.processTrailerHeaders(f) } + // [...] The identifier of a newly established stream MUST be + // numerically greater than all streams that the initiating + // endpoint has opened or reserved. [...] An endpoint that + // receives an unexpected stream identifier MUST respond with + // a connection error (Section 5.4.1) of type PROTOCOL_ERROR. if id <= sc.maxClientStreamID { return http2ConnectionError(http2ErrCodeProtocol) } @@ -4295,12 +5416,22 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.idleTimer.Stop() } + // http://tools.ietf.org/html/rfc7540#section-5.1.2 + // [...] Endpoints MUST NOT exceed the limit set by their peer. An + // endpoint that receives a HEADERS frame that causes their + // advertised concurrent stream limit to be exceeded MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR + // or REFUSED_STREAM. if sc.curClientStreams+1 > sc.advMaxStreams { if sc.unackedSettings == 0 { - + // They should know better. return http2streamError(id, http2ErrCodeProtocol) } - + // Assume it's a network race, where they just haven't + // received our last SETTINGS update. But actually + // this can't happen yet, because we don't yet provide + // a way for users to adjust server parameters at + // runtime. return http2streamError(id, http2ErrCodeRefusedStream) } @@ -4325,17 +5456,24 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if st.reqTrailer != nil { st.trailer = make(Header) } - st.body = req.Body.(*http2requestBody).pipe + st.body = req.Body.(*http2requestBody).pipe // may be nil st.declBodyBytes = req.ContentLength handler := sc.handler.ServeHTTP if f.Truncated { - + // Their header list was too long. Send a 431 error. handler = http2handleHeaderListTooLong } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { handler = http2new400Handler(err) } + // The net/http package sets the read deadline from the + // http.Server.ReadTimeout during the TLS handshake, but then + // passes the connection off to us with the deadline already + // set. Disarm it here after the request headers are read, + // similar to how the http1 server works. Here it's + // technically more like the http1 Server's ReadHeaderTimeout + // (in Go 1.8), though. That's a more sane option anyway. if sc.hs.ReadTimeout != 0 { sc.conn.SetReadDeadline(time.Time{}) } @@ -4362,7 +5500,9 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { for _, hf := range f.RegularFields() { key := sc.canonicalHeader(hf.Name) if !http2ValidTrailerHeader(key) { - + // TODO: send more details to the peer somehow. But http2 has + // no way to send debug data at a stream level. Discuss with + // HTTP folk. return http2streamError(st.id, http2ErrCodeProtocol) } st.trailer[key] = append(st.trailer[key], hf.Value) @@ -4374,7 +5514,10 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { func http2checkPriority(streamID uint32, p http2PriorityParam) error { if streamID == p.StreamDep { - + // Section 5.3.1: "A stream cannot depend on itself. An endpoint MUST treat + // this as a stream error (Section 5.4.2) of type PROTOCOL_ERROR." + // Section 5.3.3 says that a stream can depend on one of its dependencies, + // so it's only self-dependencies that are forbidden. return http2streamError(streamID, http2ErrCodeProtocol) } return nil @@ -4406,10 +5549,13 @@ func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState cancelCtx: cancelCtx, } st.cw.Init() - st.flow.conn = &sc.flow - st.flow.add(sc.initialWindowSize) - st.inflow.conn = &sc.inflow - st.inflow.add(http2initialWindowSize) + st.flow.conn = &sc.flow // link to conn-level counter + st.flow.add(sc.initialStreamSendWindowSize) + st.inflow.conn = &sc.inflow // link to conn-level counter + st.inflow.add(sc.srv.initialStreamRecvWindowSize()) + if sc.hs.WriteTimeout != 0 { + st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) + } sc.streams[id] = st sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) @@ -4441,13 +5587,22 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { - + // See 8.1.2.6 Malformed Requests and Responses: + // + // Malformed requests or responses that are detected + // MUST be treated as a stream error (Section 5.4.2) + // of type PROTOCOL_ERROR." + // + // 8.1.2.3 Request Pseudo-Header Fields + // "All HTTP/2 requests MUST include exactly one valid + // value for the :method, :scheme, and :path + // pseudo-header fields" return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } bodyOpen := !f.StreamEnded() if rp.method == "HEAD" && bodyOpen { - + // HEAD requests can't have bodies return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } @@ -4464,16 +5619,14 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead return nil, nil, err } if bodyOpen { - st.reqBuf = http2getRequestBodyBuf() - req.Body.(*http2requestBody).pipe = &http2pipe{ - b: &http2fixedBuffer{buf: st.reqBuf}, - } - if vv, ok := rp.header["Content-Length"]; ok { req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) } else { req.ContentLength = -1 } + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2dataBuffer{expected: req.ContentLength}, + } } return rw, req, nil } @@ -4496,7 +5649,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re if needsContinue { rp.header.Del("Expect") } - + // Merge Cookie headers into one "; "-delimited value. if cookies := rp.header["Cookie"]; len(cookies) > 1 { rp.header.Set("Cookie", strings.Join(cookies, "; ")) } @@ -4508,7 +5661,8 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { case "Transfer-Encoding", "Trailer", "Content-Length": - + // Bogus. (copy of http1 rules) + // Ignore. default: if trailer == nil { trailer = make(Header) @@ -4523,7 +5677,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re var requestURI string if rp.method == "CONNECT" { url_ = &url.URL{Host: rp.authority} - requestURI = rp.authority + requestURI = rp.authority // mimic HTTP/1 server behavior } else { var err error url_, err = url.ParseRequestURI(rp.path) @@ -4556,7 +5710,7 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw - *rws = http2responseWriterState{} + *rws = http2responseWriterState{} // zero all the fields rws.conn = sc rws.bw = bwSave rws.bw.Reset(http2chunkWriter{rws}) @@ -4568,24 +5722,6 @@ func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2re return rw, req, nil } -var http2reqBodyCache = make(chan []byte, 8) - -func http2getRequestBodyBuf() []byte { - select { - case b := <-http2reqBodyCache: - return b - default: - return make([]byte, http2initialWindowSize) - } -} - -func http2putRequestBodyBuf(b []byte) { - select { - case http2reqBodyCache <- b: - default: - } -} - // Run on its own goroutine. func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { didPanic := true @@ -4597,7 +5733,7 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han write: http2handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) - + // Same as net/http: if http2shouldLogPanic(e) { const size = 64 << 10 buf := make([]byte, size) @@ -4625,10 +5761,13 @@ func http2handleHeaderListTooLong(w ResponseWriter, r *Request) { // called from handler goroutines. // h may be nil. func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error { - sc.serveG.checkNotOn() + sc.serveG.checkNotOn() // NOT on var errc chan error if headerData.h != nil { - + // If there's a header map (which we don't own), so we have to block on + // waiting for this frame to be written, so an http.Flush mid-handler + // writes out the correct value of keys, before a handler later potentially + // mutates it. errc = http2errChanPool.Get().(chan error) } if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ @@ -4671,26 +5810,21 @@ type http2bodyReadMsg struct { // Notes that the handler for the given stream ID read n bytes of its body // and schedules flow control tokens to be sent. func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { - sc.serveG.checkNotOn() + sc.serveG.checkNotOn() // NOT on if n > 0 { select { case sc.bodyReadCh <- http2bodyReadMsg{st, n}: case <-sc.doneServing: } } - if err == io.EOF { - if buf := st.reqBuf; buf != nil { - st.reqBuf = nil - http2putRequestBodyBuf(buf) - } - } } func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) { sc.serveG.check() - sc.sendWindowUpdate(nil, n) + sc.sendWindowUpdate(nil, n) // conn-level if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed { - + // Don't send this WINDOW_UPDATE if the stream is closed + // remotely. sc.sendWindowUpdate(st, n) } } @@ -4777,8 +5911,8 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) { return } -// responseWriter is the http.ResponseWriter implementation. It's -// intentionally small (1 pointer wide) to minimize garbage. The +// responseWriter is the http.ResponseWriter implementation. It's +// intentionally small (1 pointer wide) to minimize garbage. The // responseWriterState pointer inside is zeroed at the end of a // request (in handlerDone) and calls on the responseWriter thereafter // simply crash (caller's mistake), but the much larger responseWriterState @@ -4812,6 +5946,7 @@ type http2responseWriterState struct { wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. sentHeader bool // have we sent the header frame? handlerDone bool // handler has finished + dirty bool // a Write failed; don't reuse this responseWriterState sentContentLen int64 // non-zero if handler set a Content-Length header wroteBytes int64 @@ -4832,7 +5967,7 @@ func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailer func (rws *http2responseWriterState) declareTrailer(k string) { k = CanonicalHeaderKey(k) if !http2ValidTrailerHeader(k) { - + // Forbidden by RFC 2616 14.40. rws.conn.logf("ignoring invalid trailer %q", k) return } @@ -4874,7 +6009,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { } var date string if _, ok := rws.snapHeader["Date"]; !ok { - + // TODO(bradfitz): be faster here, like net/http? measure. date = time.Now().UTC().Format(TimeFormat) } @@ -4893,6 +6028,7 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { date: date, }) if err != nil { + rws.dirty = true return 0, err } if endStream { @@ -4912,8 +6048,9 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { endStream := rws.handlerDone && !rws.hasTrailers() if len(p) > 0 || endStream { - + // only send a 0 byte DATA frame if we're ending the stream. if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil { + rws.dirty = true return 0, err } } @@ -4925,6 +6062,9 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) { trailers: rws.trailers, endStream: true, }) + if err != nil { + rws.dirty = true + } return len(p), err } return len(p), nil @@ -4952,7 +6092,7 @@ const http2TrailerPrefix = "Trailer:" // says you SHOULD (but not must) predeclare any trailers in the // header, the official ResponseWriter rules said trailers in Go must // be predeclared, and then we reuse the same ResponseWriter.Header() -// map to mean both Headers and Trailers. When it's time to write the +// map to mean both Headers and Trailers. When it's time to write the // Trailers, we pick out the fields of Headers that were declared as // trailers. That worked for a while, until we found the first major // user of Trailers in the wild: gRPC (using them only over http2), @@ -4989,11 +6129,14 @@ func (w *http2responseWriter) Flush() { } if rws.bw.Buffered() > 0 { if err := rws.bw.Flush(); err != nil { - + // Ignore the error. The frame writer already knows. return } } else { - + // The bufio.Writer won't call chunkWriter.Write + // (writeChunk with zero bytes, so we have to do it + // ourselves to force the HTTP response header and/or + // final DATA frame (with END_STREAM) to be sent. rws.writeChunk(nil) } } @@ -5010,7 +6153,7 @@ func (w *http2responseWriter) CloseNotify() <-chan bool { rws.closeNotifierCh = ch cw := rws.stream.cw go func() { - cw.Wait() + cw.Wait() // wait for close ch <- true }() } @@ -5061,7 +6204,7 @@ func http2cloneHeader(h Header) Header { // // * Handler calls w.Write or w.WriteString -> // * -> rws.bw (*bufio.Writer) -> -// * (Handler migth call Flush) +// * (Handler might call Flush) // * -> chunkWriter{rws} // * -> responseWriterState.writeChunk(p []byte) // * -> responseWriterState.writeChunk (most of the magic; see comment there) @@ -5085,9 +6228,9 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n if !http2bodyAllowedForStatus(rws.status) { return 0, ErrBodyNotAllowed } - rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) + rws.wroteBytes += int64(len(dataB)) + int64(len(dataS)) // only one can be set if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen { - + // TODO: send a RST_STREAM return 0, errors.New("http2: handler wrote more than declared Content-Length") } @@ -5100,10 +6243,19 @@ func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n func (w *http2responseWriter) handlerDone() { rws := w.rws + dirty := rws.dirty rws.handlerDone = true w.Flush() w.rws = nil - http2responseWriterStatePool.Put(rws) + if !dirty { + // Only recycle the pool if all prior Write calls to + // the serverConn goroutine completed successfully. If + // they returned earlier due to resets from the peer + // there might still be write goroutines outstanding + // from the serverConn referencing the rws memory. See + // issue 20704. + http2responseWriterStatePool.Put(rws) + } } // Push errors. @@ -5124,10 +6276,13 @@ func (w *http2responseWriter) push(target string, opts http2pushOptions) error { sc := st.sc sc.serveG.checkNotOn() + // No recursive pushes: "PUSH_PROMISE frames MUST only be sent on a peer-initiated stream." + // http://tools.ietf.org/html/rfc7540#section-6.6 if st.isPushed() { return http2ErrRecursivePush } + // Default options. if opts.Method == "" { opts.Method = "GET" } @@ -5139,6 +6294,7 @@ func (w *http2responseWriter) push(target string, opts http2pushOptions) error { wantScheme = "https" } + // Validate the request. u, err := url.Parse(target) if err != nil { return err @@ -5161,7 +6317,10 @@ func (w *http2responseWriter) push(target string, opts http2pushOptions) error { if strings.HasPrefix(k, ":") { return fmt.Errorf("promised request headers cannot include pseudo header %q", k) } - + // These headers are meaningful only if the request has a body, + // but PUSH_PROMISE requests cannot have a body. + // http://tools.ietf.org/html/rfc7540#section-8.2 + // Also disallow Host, since the promised URL must be absolute. switch strings.ToLower(k) { case "content-length", "content-encoding", "trailer", "te", "expect", "host": return fmt.Errorf("promised request headers cannot include %q", k) @@ -5171,11 +6330,14 @@ func (w *http2responseWriter) push(target string, opts http2pushOptions) error { return err } + // The RFC effectively limits promised requests to GET and HEAD: + // "Promised requests MUST be cacheable [GET, HEAD, or POST], and MUST be safe [GET or HEAD]" + // http://tools.ietf.org/html/rfc7540#section-8.2 if opts.Method != "GET" && opts.Method != "HEAD" { return fmt.Errorf("method %q must be GET or HEAD", opts.Method) } - msg := http2startPushRequest{ + msg := &http2startPushRequest{ parent: st, method: opts.Method, url: u, @@ -5188,7 +6350,7 @@ func (w *http2responseWriter) push(target string, opts http2pushOptions) error { return http2errClientDisconnected case <-st.cw: return http2errStreamClosed - case sc.wantStartPushCh <- msg: + case sc.serveMsgCh <- msg: } select { @@ -5210,48 +6372,66 @@ type http2startPushRequest struct { done chan error } -func (sc *http2serverConn) startPush(msg http2startPushRequest) { +func (sc *http2serverConn) startPush(msg *http2startPushRequest) { sc.serveG.check() + // http://tools.ietf.org/html/rfc7540#section-6.6. + // PUSH_PROMISE frames MUST only be sent on a peer-initiated stream that + // is in either the "open" or "half-closed (remote)" state. if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { - + // responseWriter.Push checks that the stream is peer-initiaed. msg.done <- http2errStreamClosed return } + // http://tools.ietf.org/html/rfc7540#section-6.6. if !sc.pushEnabled { msg.done <- ErrNotSupported return } + // PUSH_PROMISE frames must be sent in increasing order by stream ID, so + // we allocate an ID for the promised stream lazily, when the PUSH_PROMISE + // is written. Once the ID is allocated, we start the request handler. allocatePromisedID := func() (uint32, error) { sc.serveG.check() + // Check this again, just in case. Technically, we might have received + // an updated SETTINGS by the time we got around to writing this frame. if !sc.pushEnabled { return 0, ErrNotSupported } - + // http://tools.ietf.org/html/rfc7540#section-6.5.2. if sc.curPushedStreams+1 > sc.clientMaxStreams { return 0, http2ErrPushLimitReached } + // http://tools.ietf.org/html/rfc7540#section-5.1.1. + // Streams initiated by the server MUST use even-numbered identifiers. + // A server that is unable to establish a new stream identifier can send a GOAWAY + // frame so that the client is forced to open a new connection for new streams. if sc.maxPushPromiseID+2 >= 1<<31 { - sc.startGracefulShutdown() + sc.startGracefulShutdownInternal() return 0, http2ErrPushLimitReached } sc.maxPushPromiseID += 2 promisedID := sc.maxPushPromiseID + // http://tools.ietf.org/html/rfc7540#section-8.2. + // Strictly speaking, the new stream should start in "reserved (local)", then + // transition to "half closed (remote)" after sending the initial HEADERS, but + // we start in "half closed (remote)" for simplicity. + // See further comments at the definition of stateHalfClosedRemote. promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ method: msg.method, scheme: msg.url.Scheme, authority: msg.url.Host, path: msg.url.RequestURI(), - header: http2cloneHeader(msg.header), + header: http2cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE }) if err != nil { - + // Should not happen, since we've already validated msg.url. panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) } @@ -5356,31 +6536,6 @@ var http2badTrailer = map[string]bool{ "Www-Authenticate": true, } -// h1ServerShutdownChan returns a channel that will be closed when the -// provided *http.Server wants to shut down. -// -// This is a somewhat hacky way to get at http1 innards. It works -// when the http2 code is bundled into the net/http package in the -// standard library. The alternatives ended up making the cmd/go tool -// depend on http Servers. This is the lightest option for now. -// This is tested via the TestServeShutdown* tests in net/http. -func http2h1ServerShutdownChan(hs *Server) <-chan struct{} { - if fn := http2testh1ServerShutdownChan; fn != nil { - return fn(hs) - } - var x interface{} = hs - type I interface { - getDoneChan() <-chan struct{} - } - if hs, ok := x.(I); ok { - return hs.getDoneChan() - } - return nil -} - -// optional test hook for h1ServerShutdownChan. -var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{} - // h1ServerKeepAlivesDisabled reports whether hs has its keep-alives // disabled. See comments on h1ServerShutdownChan above for why // the code is written this way. @@ -5486,7 +6641,7 @@ var http2errTransportVersion = errors.New("http2: ConfigureTransport is only sup // It requires Go 1.6 or later and returns an error if the net/http package is too old // or if t1 has already been HTTP/2-enabled. func http2ConfigureTransport(t1 *Transport) error { - _, err := http2configureTransport(t1) + _, err := http2configureTransport(t1) // in configure_transport.go (go1.6) or not_go16.go return err } @@ -5669,7 +6824,7 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { // and returns a host:port. The port 443 is added if needed. func http2authorityAddr(scheme string, authority string) (addr string) { host, port, err := net.SplitHostPort(authority) - if err != nil { + if err != nil { // authority didn't have a port port = "443" if scheme == "http" { port = "80" @@ -5679,7 +6834,7 @@ func http2authorityAddr(scheme string, authority string) (addr string) { if a, err := idna.ToASCII(host); err == nil { host = a } - + // IPv6 address literal, without a port: if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { return host + ":" + port } @@ -5742,12 +6897,14 @@ func http2shouldRetryRequest(req *Request, err error) (*Request, error) { case http2errClientConnUnusable, http2errClientConnGotGoAway: return req, nil case http2errClientConnGotGoAwayAfterSomeReqBody: - + // If the Body is nil (or http.NoBody), it's safe to reuse + // this request and its Body. if req.Body == nil || http2reqBodyIsNoBody(req.Body) { return req, nil } - - getBody := http2reqGetBody(req) + // Otherwise we depend on the Request having its GetBody + // func defined. + getBody := http2reqGetBody(req) // Go 1.8: getBody = req.GetBody if getBody == nil { return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error") } @@ -5840,9 +6997,9 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client tconn: c, readerDone: make(chan struct{}), nextStreamID: 1, - maxFrameSize: 16 << 10, - initialWindowSize: 65535, - maxConcurrentStreams: 1000, + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. streams: make(map[uint32]*http2clientStream), singleUse: singleUse, wantSettingsAck: true, @@ -5859,12 +7016,16 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client cc.cond = sync.NewCond(&cc.mu) cc.flow.add(int32(http2initialWindowSize)) + // TODO: adjust this writer size to account for frame size + + // MTU + crypto/tls record padding. cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) cc.fr.MaxHeaderListSize = t.maxHeaderListSize() + // TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on + // henc in response to SETTINGS frames? cc.henc = hpack.NewEncoder(&cc.hbuf) if cs, ok := c.(http2connectionStater); ok { @@ -5900,6 +7061,7 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { old := cc.goAway cc.goAway = f + // Merge the previous and current GoAway error frames. if cc.goAwayDebug == "" { cc.goAwayDebug = string(f.DebugData()) } @@ -5932,7 +7094,7 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { cc.nextStreamID < math.MaxInt32 } -// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// onIdleTimeout is called from a time.AfterFunc goroutine. It will // only be called when we're idle, but because we're coming from a new // goroutine, there could be a new request coming in at the same time, // so this simply calls the synchronized closeIfIdle to shut down this @@ -5950,7 +7112,7 @@ func (cc *http2ClientConn) closeIfIdle() { } cc.closed = true nextID := cc.nextStreamID - + // TODO: do clients send GOAWAY too? maybe? Just Close: cc.mu.Unlock() if http2VerboseLogs { @@ -5996,7 +7158,7 @@ func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) { return } } - + // forget about it. } // errRequestCanceled is a copy of net/http's errRequestCanceled because it's not @@ -6024,7 +7186,10 @@ func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { if cc.t.t1 != nil { return cc.t.t1.ResponseHeaderTimeout } - + // No way to do this (yet?) with just an http2.Transport. Probably + // no need. Request.Cancel this is the new way. We only need to support + // this for compatibility with the old http.Transport fields when + // we're doing transparent http2. return 0 } @@ -6048,7 +7213,7 @@ func http2checkConnHeaders(req *Request) error { // req.ContentLength, where 0 actually means zero (not unknown) and -1 // means unknown. func http2actualContentLength(req *Request) int64 { - if req.Body == nil { + if req.Body == nil || http2reqBodyIsNoBody(req.Body) { return 0 } if req.ContentLength != 0 { @@ -6079,8 +7244,8 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } body := req.Body - hasBody := body != nil contentLen := http2actualContentLength(req) + hasBody := contentLen != 0 // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? var requestedGzip bool @@ -6088,10 +7253,24 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { - + // Request gzip only, not deflate. Deflate is ambiguous and + // not as universally supported anyway. + // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + // + // Note that we don't request this for HEAD requests, + // due to a bug in nginx: + // http://trac.nginx.org/nginx/ticket/358 + // https://golang.org/issue/5522 + // + // We don't request gzip if the request is for a range, since + // auto-decoding a portion of a gzipped document will just fail + // anyway. See https://golang.org/issue/8923 requestedGzip = true } + // we send: HEADERS{1}, CONTINUATION{0,} + DATA{0,} (DATA is + // sent by writeRequestBody below, along with any Trailers, + // again in form HEADERS{1}, CONTINUATION{0,}) hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) if err != nil { cc.mu.Unlock() @@ -6114,11 +7293,12 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if werr != nil { if hasBody { - req.Body.Close() + req.Body.Close() // per RoundTripper contract bodyWriter.cancel() } cc.forgetStreamID(cs.ID) - + // Don't bother sending a RST_STREAM (our write already failed; + // no need to keep writing) http2traceWroteRequest(cs.trace, werr) return nil, werr } @@ -6142,7 +7322,15 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { handleReadLoopResponse := func(re http2resAndError) (*Response, error) { res := re.res if re.err != nil || res.StatusCode > 299 { - + // On error or status code 3xx, 4xx, 5xx, etc abort any + // ongoing write, assuming that the server doesn't care + // about our request body. If the server replied with 1xx or + // 2xx, however, then assume the server DOES potentially + // want our body (e.g. full-duplex streaming: + // golang.org/issue/13444). If it turns out the server + // doesn't, they'll RST_STREAM us soon enough. This is a + // heuristic to avoid adding knobs to Transport. Hopefully + // we can keep it. bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } @@ -6209,9 +7397,12 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { return handleReadLoopResponse(re) default: } + // processResetStream already removed the + // stream from the streams map; no need for + // forgetStreamID. return nil, cs.resetErr case err := <-bodyWriter.resc: - + // Prefer the read loop's response, if available. Issue 16102. select { case re := <-readLoopResCh: return handleReadLoopResponse(re) @@ -6232,7 +7423,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { // requires cc.wmu be held func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error { - first := true + first := true // first frame written (HEADERS is first, then CONTINUATION) frameSize := int(cc.maxFrameSize) for len(hdrs) > 0 && cc.werr == nil { chunk := hdrs @@ -6253,7 +7444,10 @@ func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs [] cc.fr.WriteContinuation(streamID, endHeaders, chunk) } } - + // TODO(bradfitz): this Flush could potentially block (as + // could the WriteHeaders call(s) above), which means they + // wouldn't respond to Request.Cancel being readable. That's + // rare, but this should probably be in a goroutine. cc.bw.Flush() return cc.werr } @@ -6269,13 +7463,16 @@ var ( func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) { cc := cs.cc - sentEnd := false + sentEnd := false // whether we sent the final DATA frame w/ END_STREAM buf := cc.frameScratchBuffer() defer cc.putFrameScratchBuffer(buf) defer func() { http2traceWroteRequest(cs.trace, err) - + // TODO: write h12Compare test showing whether + // Request.Body is closed by the Transport, + // and in multiple cases: server replies <=299 and >299 + // while still writing request body cerr := bodyCloser.Close() if err == nil { err = cerr @@ -6314,7 +7511,12 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos sentEnd = sawEOF && len(remain) == 0 && !hasTrailers err = cc.fr.WriteData(cs.ID, sentEnd, data) if err == nil { - + // TODO(bradfitz): this flush is for latency, not bandwidth. + // Most requests won't need this. Make this opt-in or + // opt-out? Use some heuristic on the body type? Nagel-like + // timers? Based on 'n'? Only last chunk of this for loop, + // unless flow control tokens are low? For now, always. + // If we change this, see comment below. err = cc.bw.Flush() } cc.wmu.Unlock() @@ -6325,7 +7527,9 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos } if sentEnd { - + // Already sent END_STREAM (which implies we have no + // trailers) and flushed, because currently all + // WriteData frames above get a flush. So we're done. return nil } @@ -6339,6 +7543,8 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos cc.wmu.Lock() defer cc.wmu.Unlock() + // Two ways to send END_STREAM: either with trailers, or + // with an empty DATA frame. if len(trls) > 0 { err = cc.writeHeaders(cs.ID, true, trls) } else { @@ -6372,7 +7578,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er take := a if int(take) > maxBytes { - take = int32(maxBytes) + take = int32(maxBytes) // can't truncate int; take is int32 } if take > int32(cc.maxFrameSize) { take = int32(cc.maxFrameSize) @@ -6420,6 +7626,9 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail } } + // Check for any invalid headers and return an error before we + // potentially pollute our hpack state. (We want to be able to + // continue to reuse the hpack encoder for future requests) for k, vv := range req.Header { if !httplex.ValidHeaderFieldName(k) { return nil, fmt.Errorf("invalid HTTP header name %q", k) @@ -6431,6 +7640,11 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail } } + // 8.1.2.3 Request Pseudo-Header Fields + // The :path pseudo-header field includes the path and query parts of the + // target URI (the path-absolute production and optionally a '?' character + // followed by the query production (see Sections 3.3 and 3.4 of + // [RFC3986]). cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { @@ -6446,13 +7660,20 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail lowKey := strings.ToLower(k) switch lowKey { case "host", "content-length": - + // Host is :authority, already sent. + // Content-Length is automatic, set below. continue case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": - + // Per 8.1.2.2 Connection-Specific Header + // Fields, don't send connection-specific + // fields. We have already checked if any + // are error-worthy so just ignore the rest. continue case "user-agent": - + // Match Go's http1 behavior: at most one + // User-Agent. If set to nil or empty string, + // then omit it. Otherwise if not mentioned, + // include the default (below). didUA = true if len(vv) < 1 { continue @@ -6490,7 +7711,8 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { if contentLength < 0 { return false } - + // For zero bodies, whether we send a content-length depends on the method. + // It also kinda doesn't matter for http2 either way, with END_STREAM. switch method { case "POST", "PUT", "PATCH": return true @@ -6503,7 +7725,8 @@ func http2shouldSendReqContentLength(method string, contentLength int64) bool { func (cc *http2ClientConn) encodeTrailers(req *Request) []byte { cc.hbuf.Reset() for k, vv := range req.Trailer { - + // Transfer-Encoding, etc.. have already been filter at the + // start of RoundTrip lowKey := strings.ToLower(k) for _, v := range vv { cc.writeHeader(lowKey, v) @@ -6557,7 +7780,7 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr cc.idleTimer.Reset(cc.idleTimeout) } close(cs.done) - cc.cond.Broadcast() + cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl } return cs } @@ -6616,6 +7839,9 @@ func (rl *http2clientConnReadLoop) cleanup() { cc.idleTimer.Stop() } + // Close any response bodies if the server closes prematurely. + // TODO: also do this if we've written the headers but not + // gotten a response yet. err := cc.readerErr cc.mu.Lock() if cc.goAway != nil && http2isEOFOrNetReadError(err) { @@ -6645,7 +7871,7 @@ func (rl *http2clientConnReadLoop) cleanup() { func (rl *http2clientConnReadLoop) run() error { cc := rl.cc rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse - gotReply := false + gotReply := false // ever saw a HEADERS reply gotSettings := false for { f, err := cc.fr.ReadFrame() @@ -6653,7 +7879,7 @@ func (rl *http2clientConnReadLoop) run() error { cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } if se, ok := err.(http2StreamError); ok { - if cs := cc.streamByID(se.StreamID, true); cs != nil { + if cs := cc.streamByID(se.StreamID, true /*ended; remove it*/); cs != nil { cs.cc.writeStreamReset(cs.ID, se.Code, err) if se.Cause == nil { se.Cause = cc.fr.errDetail @@ -6674,7 +7900,7 @@ func (rl *http2clientConnReadLoop) run() error { } gotSettings = true } - maybeIdle := false + maybeIdle := false // whether frame might transition us to idle switch f := f.(type) { case *http2MetaHeadersFrame: @@ -6717,12 +7943,17 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro cc := rl.cc cs := cc.streamByID(f.StreamID, f.StreamEnded()) if cs == nil { - + // We'd get here if we canceled a request while the + // server had its response still in flight. So if this + // was just something we canceled, ignore it. return nil } if !cs.firstByte { if cs.trace != nil { - + // TODO(bradfitz): move first response byte earlier, + // when we first read the 9 byte header, not waiting + // until all the HEADERS+CONTINUATION frames have been + // merged. This works for now. http2traceFirstResponseByte(cs.trace) } cs.firstByte = true @@ -6738,13 +7969,13 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro if _, ok := err.(http2ConnectionError); ok { return err } - + // Any other error type is a stream error. cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) cs.resc <- http2resAndError{err: err} - return nil + return nil // return nil from process* funcs to keep conn alive } if res == nil { - + // (nil, nil) special case. See handleResponse docs. return nil } if res.Body != http2noBody { @@ -6779,9 +8010,9 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http if statusCode == 100 { http2traceGot100Continue(cs.trace) if cs.on100 != nil { - cs.on100() + cs.on100() // forces any write delay timer to fire } - cs.pastHeaders = false + cs.pastHeaders = false // do it all again return nil, nil } @@ -6817,10 +8048,12 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { res.ContentLength = clen64 } else { - + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. } } else if len(clens) > 1 { - + // TODO: care? unlike http/1, it won't mess up our framing, so it's + // more safe smuggling-wise to ignore. } } @@ -6829,8 +8062,7 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http return res, nil } - buf := new(bytes.Buffer) - cs.bufPipe = http2pipe{b: buf} + cs.bufPipe = http2pipe{b: &http2dataBuffer{expected: res.ContentLength}} cs.bytesRemain = res.ContentLength res.Body = http2transportResponseBody{cs} go cs.awaitRequestCancel(cs.req) @@ -6847,16 +8079,18 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { if cs.pastTrailers { - + // Too many HEADERS frames for this stream. return http2ConnectionError(http2ErrCodeProtocol) } cs.pastTrailers = true if !f.StreamEnded() { - + // We expect that any headers for trailers also + // has END_STREAM. return http2ConnectionError(http2ErrCodeProtocol) } if len(f.PseudoFields()) > 0 { - + // No pseudo header fields are defined for trailers. + // TODO: ConnectionError might be overly harsh? Check. return http2ConnectionError(http2ErrCodeProtocol) } @@ -6904,7 +8138,7 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { } } if n == 0 { - + // No flow control tokens to send back. return } @@ -6912,13 +8146,15 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { defer cc.mu.Unlock() var connAdd, streamAdd int32 - + // Check the conn-level first, before the stream-level. if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 { connAdd = http2transportDefaultConnFlow - v cc.inflow.add(connAdd) } - if err == nil { - + if err == nil { // No need to refresh if the stream is over or failed. + // Consider any buffered body data (read from the conn but not + // consumed by the client) when computing flow control for this + // stream. v := int(cs.inflow.available()) + cs.bufPipe.Len() if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { streamAdd = int32(http2transportDefaultStreamFlow - v) @@ -6953,8 +8189,9 @@ func (b http2transportResponseBody) Close() error { cc.wmu.Lock() if !serverSentStreamEnd { cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel) + cs.didReset = true } - + // Return connection-level flow control. if unread > 0 { cc.inflow.add(int32(unread)) cc.fr.WriteWindowUpdate(0, uint32(unread)) @@ -6977,11 +8214,16 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { neverSent := cc.nextStreamID cc.mu.Unlock() if f.StreamID >= neverSent { - + // We never asked for this. cc.logf("http2: Transport received unsolicited DATA frame; closing connection") return http2ConnectionError(http2ErrCodeProtocol) } + // We probably did ask for this, but canceled. Just ignore it. + // TODO: be stricter here? only silently ignore things which + // we canceled, but not things which were closed normally + // by the peer? Tough without accumulating too much state. + // But at least return their flow control: if f.Length > 0 { cc.mu.Lock() cc.inflow.add(int32(f.Length)) @@ -6995,12 +8237,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { return nil } if f.Length > 0 { - if len(data) > 0 && cs.bufPipe.b == nil { - - cc.logf("http2: Transport received DATA frame for closed stream; closing connection") - return http2ConnectionError(http2ErrCodeProtocol) - } - + // Check connection-level flow control. cc.mu.Lock() if cs.inflow.available() >= int32(f.Length) { cs.inflow.take(int32(f.Length)) @@ -7008,17 +8245,29 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc.mu.Unlock() return http2ConnectionError(http2ErrCodeFlowControl) } - - if pad := int32(f.Length) - int32(len(data)); pad > 0 { - cs.inflow.add(pad) - cc.inflow.add(pad) + // Return any padded flow control now, since we won't + // refund it later on body reads. + var refund int + if pad := int(f.Length) - len(data); pad > 0 { + refund += pad + } + // Return len(data) now if the stream is already closed, + // since data will never be read. + didReset := cs.didReset + if didReset { + refund += len(data) + } + if refund > 0 { + cc.inflow.add(int32(refund)) cc.wmu.Lock() - cc.fr.WriteWindowUpdate(0, uint32(pad)) - cc.fr.WriteWindowUpdate(cs.ID, uint32(pad)) + cc.fr.WriteWindowUpdate(0, uint32(refund)) + if !didReset { + cs.inflow.add(int32(refund)) + cc.fr.WriteWindowUpdate(cs.ID, uint32(refund)) + } cc.bw.Flush() cc.wmu.Unlock() } - didReset := cs.didReset cc.mu.Unlock() if len(data) > 0 && !didReset { @@ -7038,7 +8287,8 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { var http2errInvalidTrailers = errors.New("http2: invalid trailers") func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { - + // TODO: check that any declared content-length matches, like + // server.go's (*stream).endStream method. rl.endStreamError(cs, nil) } @@ -7074,7 +8324,7 @@ func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error { cc := rl.cc cc.t.connPool().MarkDead(cc) if f.ErrCode != 0 { - + // TODO: deal with GOAWAY more. particularly the error code cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode) } cc.setGoAway(f) @@ -7101,11 +8351,17 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error case http2SettingMaxConcurrentStreams: cc.maxConcurrentStreams = s.Val case http2SettingInitialWindowSize: - + // Values above the maximum flow-control + // window size of 2^31-1 MUST be treated as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. if s.Val > math.MaxInt32 { return http2ConnectionError(http2ErrCodeFlowControl) } + // Adjust flow control of currently-open + // frames by the difference of the old initial + // window size and this one. delta := int32(s.Val) - int32(cc.initialWindowSize) for _, cs := range cc.streams { cs.flow.add(delta) @@ -7114,7 +8370,7 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error cc.initialWindowSize = s.Val default: - + // TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably. cc.vlogf("Unhandled Setting: %v", s) } return nil @@ -7155,18 +8411,21 @@ func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error { cs := rl.cc.streamByID(f.StreamID, true) if cs == nil { - + // TODO: return error if server tries to RST_STEAM an idle stream return nil } select { case <-cs.peerReset: - + // Already reset. + // This is the only goroutine + // which closes this, so there + // isn't a race. default: err := http2streamError(cs.ID, f.ErrCode) cs.resetErr = err close(cs.peerReset) cs.bufPipe.CloseWithError(err) - cs.cc.cond.Broadcast() + cs.cc.cond.Broadcast() // wake up checkResetOrDone via clientStream.awaitFlowControl } delete(rl.activeRes, cs.ID) return nil @@ -7183,7 +8442,7 @@ func (cc *http2ClientConn) ping(ctx http2contextContext) error { return err } cc.mu.Lock() - + // check for dup before insert if _, found := cc.pings[p]; !found { cc.pings[p] = c cc.mu.Unlock() @@ -7207,7 +8466,7 @@ func (cc *http2ClientConn) ping(ctx http2contextContext) error { case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: - + // connection closed return cc.readerErr } } @@ -7217,7 +8476,7 @@ func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() - + // If ack, notify listener if any if c, ok := cc.pings[f.Data]; ok { close(c) delete(cc.pings, f.Data) @@ -7234,12 +8493,21 @@ func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { } func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error { - + // We told the peer we don't want them. + // Spec says: + // "PUSH_PROMISE MUST NOT be sent if the SETTINGS_ENABLE_PUSH + // setting of the peer endpoint is set to 0. An endpoint that + // has set this setting and has received acknowledgement MUST + // treat the receipt of a PUSH_PROMISE frame as a connection + // error (Section 5.4.1) of type PROTOCOL_ERROR." return http2ConnectionError(http2ErrCodeProtocol) } func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) { - + // TODO: map err to more interesting error codes, once the + // HTTP community comes up with some. But currently for + // RST_STREAM there's no equivalent to GOAWAY frame's debug + // data, and the error codes are all pretty vague ("cancel"). cc.wmu.Lock() cc.fr.WriteRSTStream(streamID, code) cc.bw.Flush() @@ -7368,7 +8636,8 @@ func (s http2bodyWriterState) cancel() { func (s http2bodyWriterState) on100() { if s.timer == nil { - + // If we didn't do a delayed write, ignore the server's + // bogus 100 continue response. return } s.timer.Stop() @@ -7380,7 +8649,9 @@ func (s http2bodyWriterState) on100() { // called until after the headers have been written. func (s http2bodyWriterState) scheduleBodyWrite() { if s.timer == nil { - + // We're not doing a delayed write (see + // getBodyWriterState), so just start the writing + // goroutine immediately. go s.fn() return } @@ -7435,7 +8706,9 @@ func http2writeEndsStream(w http2writeFramer) bool { case *http2writeResHeaders: return v.endStream case nil: - + // This can only happen if the caller reuses w after it's + // been intentionally nil'ed out to prevent use. Keep this + // here to catch future refactoring breaking it. panic("writeEndsStream called on nil writeFramer") } return false @@ -7469,14 +8742,14 @@ type http2writeGoAway struct { func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) if p.code != 0 { - ctx.Flush() + ctx.Flush() // ignore error: we're hanging up on them anyway time.Sleep(50 * time.Millisecond) ctx.CloseConn() } return err } -func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } // flushes type http2writeData struct { streamID uint32 @@ -7581,7 +8854,13 @@ func http2encKV(enc *hpack.Encoder, k, v string) { } func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { - + // TODO: this is a common one. It'd be nice to return true + // here and get into the fast path if we could be clever and + // calculate the size fast enough, or at least a conservative + // uppper bound that usually fires. (Maybe if w.h and + // w.trailers are nil, so we don't need to enumerate it.) + // Otherwise I'm afraid that just calculating the length to + // answer this question would be slower than the ~2µs benefit. return false } @@ -7640,7 +8919,7 @@ type http2writePushPromise struct { } func (w *http2writePushPromise) staysWithinBuffer(max int) bool { - + // TODO: see writeResHeaders.staysWithinBuffer return false } @@ -7692,7 +8971,7 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err } func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { - + // Sloppy but conservative: return 9+2*(len(":status")+len("100")) <= max } @@ -7712,7 +8991,9 @@ func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := http2sorterPool.Get().(*http2sorter) - + // Using defer here, since the returned keys from the + // sorter.Keys method is only valid until the sorter + // is returned: defer http2sorterPool.Put(sorter) keys = sorter.Keys(h) } @@ -7720,16 +9001,19 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { vv := h[k] k = http2lowerHeader(k) if !http2validWireHeaderFieldName(k) { - + // Skip it as backup paranoia. Per + // golang.org/issue/14048, these should + // already be rejected at a higher level. continue } isTE := k == "transfer-encoding" for _, v := range vv { if !httplex.ValidHeaderFieldValue(v) { - + // TODO: return an error? golang.org/issue/14048 + // For now just omit it. continue } - + // TODO: more of "8.1.2.2 Connection-Specific Header Fields" if isTE && v != "trailers" { continue } @@ -7797,7 +9081,10 @@ type http2FrameWriteRequest struct { func (wr http2FrameWriteRequest) StreamID() uint32 { if wr.stream == nil { if se, ok := wr.write.(http2StreamError); ok { - + // (*serverConn).resetStream doesn't set + // stream because it doesn't necessarily have + // one. So special case this type of write + // message. return se.StreamID } return 0 @@ -7827,11 +9114,13 @@ func (wr http2FrameWriteRequest) DataSize() int { func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { var empty http2FrameWriteRequest + // Non-DATA frames are always consumed whole. wd, ok := wr.write.(*http2writeData) if !ok || len(wd.p) == 0 { return wr, empty, 1 } + // Might need to split after applying limits. allowed := wr.stream.flow.available() if n < allowed { allowed = n @@ -7849,10 +9138,13 @@ func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2 write: &http2writeData{ streamID: wd.streamID, p: wd.p[:allowed], - + // Even if the original had endStream set, there + // are bytes remaining because len(wd.p) > allowed, + // so we know endStream is false. endStream: false, }, - + // Our caller is blocking on the final DATA frame, not + // this intermediate frame, so no need to wait. done: nil, } rest := http2FrameWriteRequest{ @@ -7867,6 +9159,8 @@ func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2 return consumed, rest, 2 } + // The frame is consumed whole. + // NB: This cast cannot overflow because allowed is <= math.MaxInt32. wr.stream.flow.take(int32(len(wd.p))) return wr, empty, 1 } @@ -7893,7 +9187,7 @@ func (wr *http2FrameWriteRequest) replyToWriter(err error) { default: panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) } - wr.write = nil + wr.write = nil // prevent use (assume it's tainted after wr.done send) } // writeQueue is used by implementations of WriteScheduler. @@ -7912,7 +9206,7 @@ func (q *http2writeQueue) shift() http2FrameWriteRequest { panic("invalid use of queue") } wr := q.s[0] - + // TODO: less copy-happy queue. copy(q.s, q.s[1:]) q.s[len(q.s)-1] = http2FrameWriteRequest{} q.s = q.s[:len(q.s)-1] @@ -7942,6 +9236,8 @@ func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { type http2writeQueuePool []*http2writeQueue // put inserts an unused writeQueue into the pool. + +// put inserts an unused writeQueue into the pool. func (p *http2writeQueuePool) put(q *http2writeQueue) { for i := range q.s { q.s[i] = http2FrameWriteRequest{} @@ -8006,11 +9302,12 @@ type http2PriorityWriteSchedulerConfig struct { } // NewPriorityWriteScheduler constructs a WriteScheduler that schedules -// frames by following HTTP/2 priorities as described in RFC 7340 Section 5.3. +// frames by following HTTP/2 priorities as described in RFC 7540 Section 5.3. // If cfg is nil, default options are used. func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { if cfg == nil { - + // For justification of these defaults, see: + // https://docs.google.com/document/d/1oLhNg1skaWD4_DtaoCxdSRN5erEXrH-KnLrMwEpOtFY cfg = &http2PriorityWriteSchedulerConfig{ MaxClosedNodesInTree: 10, MaxIdleNodesInTree: 10, @@ -8065,7 +9362,7 @@ func (n *http2priorityNode) setParent(parent *http2priorityNode) { if n.parent == parent { return } - + // Unlink from current parent. if parent := n.parent; parent != nil { if n.prev == nil { parent.kids = n.next @@ -8076,7 +9373,9 @@ func (n *http2priorityNode) setParent(parent *http2priorityNode) { n.next.prev = n.prev } } - + // Link to new parent. + // If parent=nil, remove n from the tree. + // Always insert at the head of parent.kids (this is assumed by walkReadyInOrder). n.parent = parent if parent == nil { n.next = nil @@ -8112,10 +9411,15 @@ func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2prior return false } + // Don't consider the root "open" when updating openParent since + // we can't send data frames on the root stream (only control frames). if n.id != 0 { openParent = openParent || (n.state == http2priorityNodeOpen) } + // Common case: only one kid or all kids have the same weight. + // Some clients don't use weights; other clients (like web browsers) + // use mostly-linear priority trees. w := n.kids.weight needSort := false for k := n.kids.next; k != nil; k = k.next { @@ -8133,6 +9437,8 @@ func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2prior return false } + // Uncommon case: sort the child nodes. We remove the kids from the parent, + // then re-insert after sorting so we can reuse tmp for future sort calls. *tmp = (*tmp)[:0] for n.kids != nil { *tmp = append(*tmp, n.kids) @@ -8140,7 +9446,7 @@ func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2prior } sort.Sort(http2sortPriorityNodeSiblings(*tmp)) for i := len(*tmp) - 1; i >= 0; i-- { - (*tmp)[i].setParent(n) + (*tmp)[i].setParent(n) // setParent inserts at the head of n.kids } for k := n.kids; k != nil; k = k.next { if k.walkReadyInOrder(openParent, tmp, f) { @@ -8157,7 +9463,8 @@ func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { - + // Prefer the subtree that has sent fewer bytes relative to its weight. + // See sections 5.3.2 and 5.3.4. wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) if bi == 0 && bk == 0 { @@ -8199,7 +9506,7 @@ type http2priorityWriteScheduler struct { } func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { - + // The stream may be currently idle but cannot be opened or closed. if curr := ws.nodes[streamID]; curr != nil { if curr.state != http2priorityNodeIdle { panic(fmt.Sprintf("stream %d already opened", streamID)) @@ -8208,6 +9515,10 @@ func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2 return } + // RFC 7540, Section 5.3.5: + // "All streams are initially assigned a non-exclusive dependency on stream 0x0. + // Pushed streams initially depend on their associated stream. In both cases, + // streams are assigned a default weight of 16." parent := ws.nodes[options.PusherID] if parent == nil { parent = &ws.root @@ -8255,6 +9566,9 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht panic("adjustPriority on root") } + // If streamID does not exist, there are two cases: + // - A closed stream that has been removed (this will have ID <= maxID) + // - An idle stream that is being used for "grouping" (this will have ID > maxID) n := ws.nodes[streamID] if n == nil { if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { @@ -8272,6 +9586,8 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) } + // Section 5.3.1: A dependency on a stream that is not currently in the tree + // results in that stream being given a default priority (Section 5.3.5). parent := ws.nodes[priority.StreamDep] if parent == nil { n.setParent(&ws.root) @@ -8279,10 +9595,18 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht return } + // Ignore if the client tries to make a node its own parent. if n == parent { return } + // Section 5.3.3: + // "If a stream is made dependent on one of its own dependencies, the + // formerly dependent stream is first moved to be dependent on the + // reprioritized stream's previous parent. The moved dependency retains + // its weight." + // + // That is: if parent depends on n, move parent to depend on n.parent. for x := parent.parent; x != nil; x = x.parent { if x == n { parent.setParent(n.parent) @@ -8290,6 +9614,9 @@ func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority ht } } + // Section 5.3.3: The exclusive flag causes the stream to become the sole + // dependency of its parent stream, causing other dependencies to become + // dependent on the exclusive stream. if priority.Exclusive { k := parent.kids for k != nil { @@ -8312,7 +9639,11 @@ func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { } else { n = ws.nodes[id] if n == nil { - + // id is an idle or closed stream. wr should not be a HEADERS or + // DATA frame. However, wr can be a RST_STREAM. In this case, we + // push wr onto the root, rather than creating a new priorityNode, + // since RST_STREAM is tiny and the stream's priority is unknown + // anyway. See issue #17919. if wr.DataSize() > 0 { panic("add DATA on non-open stream") } @@ -8333,7 +9664,9 @@ func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool return false } n.addBytes(int64(wr.DataSize())) - + // If B depends on A and B continuously has data available but A + // does not, gradually increase the throttling limit to allow B to + // steal more and more bandwidth from A. if openParent { ws.writeThrottleLimit += 1024 if ws.writeThrottleLimit < 0 { @@ -8352,7 +9685,7 @@ func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorit return } if len(*list) == maxSize { - + // Remove the oldest node, then shift left. ws.removeNode((*list)[0]) x := (*list)[1:] copy(*list, x) @@ -8390,7 +9723,7 @@ type http2randomWriteScheduler struct { } func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { - + // no-op: idle streams are not tracked } func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { @@ -8403,7 +9736,7 @@ func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { } func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - + // no-op: priorities are ignored } func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { @@ -8421,11 +9754,11 @@ func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { } func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { - + // Control frames first. if !ws.zero.empty() { return ws.zero.shift(), true } - + // Iterate over all non-idle streams until finding one that can be consumed. for _, q := range ws.sq { if wr, ok := q.consume(math.MaxInt32); ok { return wr, true diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 5f1aa6a..741f076 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -6,6 +6,7 @@ package httptest import ( "bytes" + "fmt" "io/ioutil" "net/http" "strconv" @@ -176,7 +177,7 @@ func (rw *ResponseRecorder) Result() *http.Response { if res.StatusCode == 0 { res.StatusCode = 200 } - res.Status = http.StatusText(res.StatusCode) + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) if rw.Body != nil { res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) } diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go index 9afba4e..a6259eb 100644 --- a/libgo/go/net/http/httptest/recorder_test.go +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -23,7 +23,15 @@ func TestRecorder(t *testing.T) { return nil } } - hasResultStatus := func(wantCode int) checkFunc { + hasResultStatus := func(want string) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Result().Status != want { + return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want) + } + return nil + } + } + hasResultStatusCode := func(wantCode int) checkFunc { return func(rec *ResponseRecorder) error { if rec.Result().StatusCode != wantCode { return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) @@ -235,7 +243,8 @@ func TestRecorder(t *testing.T) { hasOldHeader("X-Foo", "1"), hasStatus(0), hasHeader("X-Foo", "1"), - hasResultStatus(200), + hasResultStatus("200 OK"), + hasResultStatusCode(200), ), }, { diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index 5e9ace5..e543672 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -9,6 +9,7 @@ package httptest import ( "bytes" "crypto/tls" + "crypto/x509" "flag" "fmt" "log" @@ -35,6 +36,9 @@ type Server struct { // before Start or StartTLS. Config *http.Server + // certificate is a parsed version of the TLS config certificate, if present. + certificate *x509.Certificate + // wg counts the number of outstanding HTTP requests on this server. // Close blocks until all requests are finished. wg sync.WaitGroup @@ -42,6 +46,10 @@ type Server struct { mu sync.Mutex // guards closed and conns closed bool conns map[net.Conn]http.ConnState // except terminal states + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client } func newLocalListener() net.Listener { @@ -93,6 +101,9 @@ func (s *Server) Start() { if s.URL != "" { panic("Server already started") } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } s.URL = "http://" + s.Listener.Addr().String() s.wrap() s.goServe() @@ -107,6 +118,9 @@ func (s *Server) StartTLS() { if s.URL != "" { panic("Server already started") } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) @@ -124,6 +138,17 @@ func (s *Server) StartTLS() { if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } + s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + certpool := x509.NewCertPool() + certpool.AddCert(s.certificate) + s.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + } s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() s.wrap() @@ -186,6 +211,13 @@ func (s *Server) Close() { t.CloseIdleConnections() } + // Also close the client idle connections. + if s.client != nil { + if t, ok := s.client.Transport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + } + s.wg.Wait() } @@ -206,7 +238,7 @@ func (s *Server) CloseClientConnections() { nconn := len(s.conns) ch := make(chan struct{}, nconn) for c := range s.conns { - s.closeConnChan(c, ch) + go s.closeConnChan(c, ch) } s.mu.Unlock() @@ -228,6 +260,19 @@ func (s *Server) CloseClientConnections() { } } +// Certificate returns the certificate used by the server, or nil if +// the server doesn't use TLS. +func (s *Server) Certificate() *x509.Certificate { + return s.certificate +} + +// Client returns an HTTP client configured for making requests to the server. +// It is configured to trust the server's TLS test certificate and will +// close its idle connections on Server.Close. +func (s *Server) Client() *http.Client { + return s.client +} + func (s *Server) goServe() { s.wg.Add(1) go func() { diff --git a/libgo/go/net/http/httptest/server_test.go b/libgo/go/net/http/httptest/server_test.go index d032c59..8ab50cd 100644 --- a/libgo/go/net/http/httptest/server_test.go +++ b/libgo/go/net/http/httptest/server_test.go @@ -12,8 +12,48 @@ import ( "testing" ) +type newServerFunc func(http.Handler) *Server + +var newServers = map[string]newServerFunc{ + "NewServer": NewServer, + "NewTLSServer": NewTLSServer, + + // The manual variants of newServer create a Server manually by only filling + // in the exported fields of Server. + "NewServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.Start() + return ts + }, + "NewTLSServerManual": func(h http.Handler) *Server { + ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}} + ts.StartTLS() + return ts + }, +} + func TestServer(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, name := range []string{"NewServer", "NewServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("Server", func(t *testing.T) { testServer(t, newServer) }) + t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) }) + t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) }) + t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) }) + t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) }) + }) + } + for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} { + t.Run(name, func(t *testing.T) { + newServer := newServers[name] + t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) }) + t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) }) + }) + } +} + +func testServer(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) defer ts.Close() @@ -22,6 +62,7 @@ func TestServer(t *testing.T) { t.Fatal(err) } got, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { t.Fatal(err) } @@ -31,8 +72,8 @@ func TestServer(t *testing.T) { } // Issue 12781 -func TestGetAfterClose(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testGetAfterClose(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) @@ -57,8 +98,8 @@ func TestGetAfterClose(t *testing.T) { } } -func TestServerCloseBlocking(t *testing.T) { - ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func testServerCloseBlocking(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hello")) })) dial := func() net.Conn { @@ -86,9 +127,9 @@ func TestServerCloseBlocking(t *testing.T) { } // Issue 14290 -func TestServerCloseClientConnections(t *testing.T) { +func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) { var s *Server - s = NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s.CloseClientConnections() })) defer s.Close() @@ -98,3 +139,66 @@ func TestServerCloseClientConnections(t *testing.T) { t.Fatalf("Unexpected response: %#v", res) } } + +// Tests that the Server.Client method works and returns an http.Client that can hit +// NewTLSServer without cert warnings. +func testServerClient(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +} + +// Tests that the Server.Client.Transport interface is implemented +// by a *http.Transport. +func testServerClientTransportType(t *testing.T, newServer newServerFunc) { + ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +// Tests that the TLS Server.Client.Transport interface is implemented +// by a *http.Transport. +func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) { + ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + client := ts.Client() + if _, ok := client.Transport.(*http.Transport); !ok { + t.Errorf("got %T, want *http.Transport", client.Transport) + } +} + +type onlyCloseListener struct { + net.Listener +} + +func (onlyCloseListener) Close() error { return nil } + +// Issue 19729: panic in Server.Close for values created directly +// without a constructor (so the unexported client field is nil). +func TestServerZeroValueClose(t *testing.T) { + ts := &Server{ + Listener: onlyCloseListener{}, + Config: &http.Server{}, + } + + ts.Close() // tests that it doesn't panic +} diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 79c8fe2..0d514f5 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -114,6 +114,16 @@ func copyHeader(dst, src http.Header) { } } +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + // Hop-by-hop headers. These are removed when sent to the backend. // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html var hopHeaders = []string{ @@ -149,30 +159,21 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { }() } - outreq := new(http.Request) - *outreq = *req // includes shallow copies of maps, but okay + outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay if req.ContentLength == 0 { outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } - outreq = outreq.WithContext(ctx) + + outreq.Header = cloneHeader(req.Header) p.Director(outreq) outreq.Close = false - // We are modifying the same underlying map from req (shallow - // copied above) so we only copy it if necessary. - copiedHeaders := false - // Remove hop-by-hop headers listed in the "Connection" header. // See RFC 2616, section 14.10. if c := outreq.Header.Get("Connection"); c != "" { for _, f := range strings.Split(c, ",") { if f = strings.TrimSpace(f); f != "" { - if !copiedHeaders { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, req.Header) - copiedHeaders = true - } outreq.Header.Del(f) } } @@ -183,11 +184,6 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // connection, regardless of what the client sent to us. for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { - if !copiedHeaders { - outreq.Header = make(http.Header) - copyHeader(outreq.Header, req.Header) - copiedHeaders = true - } outreq.Header.Del(h) } } @@ -235,7 +231,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // The "Trailer" header isn't included in the Transport's response, // at least for *http.Transport. Build it up from Trailer. - if len(res.Trailer) > 0 { + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { trailerKeys := make([]string, 0, len(res.Trailer)) for k := range res.Trailer { trailerKeys = append(trailerKeys, k) @@ -254,7 +251,18 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } p.copyResponse(rw, res.Body) res.Body.Close() // close now, instead of defer, to populate res.Trailer - copyHeader(rw.Header(), res.Trailer) + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } } func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { @@ -288,7 +296,7 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int var written int64 for { nr, rerr := src.Read(buf) - if rerr != nil && rerr != io.EOF { + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) } if nr > 0 { diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 20c4e16..37a9992 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -69,6 +69,7 @@ func TestReverseProxy(t *testing.T) { w.WriteHeader(backendStatus) w.Write([]byte(backendResponse)) w.Header().Set("X-Trailer", "trailer_value") + w.Header().Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") })) defer backend.Close() backendURL, err := url.Parse(backend.URL) @@ -79,6 +80,7 @@ func TestReverseProxy(t *testing.T) { proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() + frontendClient := frontend.Client() getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Host = "some-name" @@ -86,7 +88,7 @@ func TestReverseProxy(t *testing.T) { getReq.Header.Set("Proxy-Connection", "should be deleted") getReq.Header.Set("Upgrade", "foo") getReq.Close = true - res, err := http.DefaultClient.Do(getReq) + res, err := frontendClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -121,12 +123,15 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) } + if g, e := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != e { + t.Errorf("Trailer(X-Unannounced-Trailer) = %q ; want %q", g, e) + } // Test that a backend failing to be reached or one which doesn't return // a response results in a StatusBadGateway. getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) getReq.Close = true - res, err = http.DefaultClient.Do(getReq) + res, err = frontendClient.Do(getReq) if err != nil { t.Fatal(err) } @@ -172,7 +177,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) getReq.Header.Set("Upgrade", "original value") getReq.Header.Set(fakeConnectionToken, "should be deleted") - res, err := http.DefaultClient.Do(getReq) + res, err := frontend.Client().Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -220,7 +225,7 @@ func TestXForwardedFor(t *testing.T) { getReq.Header.Set("Connection", "close") getReq.Header.Set("X-Forwarded-For", prevForwardedFor) getReq.Close = true - res, err := http.DefaultClient.Do(getReq) + res, err := frontend.Client().Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -259,7 +264,7 @@ func TestReverseProxyQuery(t *testing.T) { frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL)) req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil) req.Close = true - res, err := http.DefaultClient.Do(req) + res, err := frontend.Client().Do(req) if err != nil { t.Fatalf("%d. Get: %v", i, err) } @@ -295,7 +300,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { req, _ := http.NewRequest("GET", frontend.URL, nil) req.Close = true - res, err := http.DefaultClient.Do(req) + res, err := frontend.Client().Do(req) if err != nil { t.Fatalf("Get: %v", err) } @@ -349,13 +354,14 @@ func TestReverseProxyCancelation(t *testing.T) { frontend := httptest.NewServer(proxyHandler) defer frontend.Close() + frontendClient := frontend.Client() getReq, _ := http.NewRequest("GET", frontend.URL, nil) go func() { <-reqInFlight - http.DefaultTransport.(*http.Transport).CancelRequest(getReq) + frontendClient.Transport.(*http.Transport).CancelRequest(getReq) }() - res, err := http.DefaultClient.Do(getReq) + res, err := frontendClient.Do(getReq) if res != nil { t.Errorf("got response %v; want nil", res.Status) } @@ -363,7 +369,7 @@ func TestReverseProxyCancelation(t *testing.T) { // This should be an error like: // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: // use of closed network connection - t.Error("DefaultClient.Do() returned nil error; want non-nil error") + t.Error("Server.Client().Do() returned nil error; want non-nil error") } } @@ -428,11 +434,12 @@ func TestUserAgentHeader(t *testing.T) { proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() + frontendClient := frontend.Client() getReq, _ := http.NewRequest("GET", frontend.URL, nil) getReq.Header.Set("User-Agent", explicitUA) getReq.Close = true - res, err := http.DefaultClient.Do(getReq) + res, err := frontendClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -441,7 +448,7 @@ func TestUserAgentHeader(t *testing.T) { getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) getReq.Header.Set("User-Agent", "") getReq.Close = true - res, err = http.DefaultClient.Do(getReq) + res, err = frontendClient.Do(getReq) if err != nil { t.Fatalf("Get: %v", err) } @@ -493,7 +500,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { req, _ := http.NewRequest("GET", frontend.URL, nil) req.Close = true - res, err := http.DefaultClient.Do(req) + res, err := frontend.Client().Do(req) if err != nil { t.Fatalf("Get: %v", err) } @@ -540,7 +547,7 @@ func TestReverseProxy_Post(t *testing.T) { defer frontend.Close() postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody)) - res, err := http.DefaultClient.Do(postReq) + res, err := frontend.Client().Do(postReq) if err != nil { t.Fatalf("Do: %v", err) } @@ -573,7 +580,7 @@ func TestReverseProxy_NilBody(t *testing.T) { frontend := httptest.NewServer(proxyHandler) defer frontend.Close() - res, err := http.DefaultClient.Get(frontend.URL) + res, err := frontend.Client().Get(frontend.URL) if err != nil { t.Fatal(err) } @@ -664,3 +671,101 @@ func TestReverseProxy_CopyBuffer(t *testing.T) { } } } + +type staticTransport struct { + res *http.Response +} + +func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { + return t.res, nil +} + +func BenchmarkServeHTTP(b *testing.B) { + res := &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(strings.NewReader("")), + } + proxy := &ReverseProxy{ + Director: func(*http.Request) {}, + Transport: &staticTransport{res}, + } + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + proxy.ServeHTTP(w, r) + } +} + +func TestServeHTTPDeepCopy(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("Hello Gopher!")) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + + type result struct { + before, after string + } + + resultChan := make(chan result, 1) + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + before := r.URL.String() + proxyHandler.ServeHTTP(w, r) + after := r.URL.String() + resultChan <- result{before: before, after: after} + })) + defer frontend.Close() + + want := result{before: "/", after: "/"} + + res, err := frontend.Client().Get(frontend.URL) + if err != nil { + t.Fatalf("Do: %v", err) + } + res.Body.Close() + + got := <-resultChan + if got != want { + t.Errorf("got = %+v; want = %+v", got, want) + } +} + +// Issue 18327: verify we always do a deep copy of the Request.Header map +// before any mutations. +func TestClonesRequestHeaders(t *testing.T) { + req, _ := http.NewRequest("GET", "http://foo.tld/", nil) + req.RemoteAddr = "1.2.3.4:56789" + rp := &ReverseProxy{ + Director: func(req *http.Request) { + req.Header.Set("From-Director", "1") + }, + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + if v := req.Header.Get("From-Director"); v != "1" { + t.Errorf("From-Directory value = %q; want 1", v) + } + return nil, io.EOF + }), + } + rp.ServeHTTP(httptest.NewRecorder(), req) + + if req.Header.Get("From-Director") == "1" { + t.Error("Director header mutation modified caller's request") + } + if req.Header.Get("X-Forwarded-For") != "" { + t.Error("X-Forward-For header mutation modified caller's request") + } + +} + +type roundTripperFunc func(req *http.Request) (*http.Response, error) + +func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index 438bd2e..21c8505 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -37,6 +37,8 @@ func interestingGoroutines() (gs []string) { } stack := strings.TrimSpace(sl[1]) if stack == "" || + strings.Contains(stack, "testing.(*M).before.func1") || + strings.Contains(stack, "os/signal.signal_recv") || strings.Contains(stack, "created by net.startServer") || strings.Contains(stack, "created by testing.RunTests") || strings.Contains(stack, "closeWriteAndWait") || @@ -56,8 +58,9 @@ func interestingGoroutines() (gs []string) { // Verify the other tests didn't leave any goroutines running. func goroutineLeaked() bool { - if testing.Short() { - // not counting goroutines for leakage in -short mode + if testing.Short() || runningBenchmarks() { + // Don't worry about goroutine leaks in -short mode or in + // benchmark mode. Too distracting when there are false positives. return false } @@ -92,6 +95,18 @@ func setParallel(t *testing.T) { } } +func runningBenchmarks() bool { + for i, arg := range os.Args { + if strings.HasPrefix(arg, "-test.bench=") && !strings.HasSuffix(arg, "=") { + return true + } + if arg == "-test.bench" && i < len(os.Args)-1 && os.Args[i+1] != "" { + return true + } + } + return false +} + func afterTest(t testing.TB) { http.DefaultTransport.(*http.Transport).CloseIdleConnections() if testing.Short() { @@ -151,7 +166,3 @@ func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error } return err } - -func closeClient(c *http.Client) { - c.Transport.(*http.Transport).CloseIdleConnections() -} diff --git a/libgo/go/net/http/npn_test.go b/libgo/go/net/http/npn_test.go index 4c1f6b5..618bdbe 100644 --- a/libgo/go/net/http/npn_test.go +++ b/libgo/go/net/http/npn_test.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "crypto/tls" + "crypto/x509" "fmt" "io" "io/ioutil" @@ -43,10 +44,7 @@ func TestNextProtoUpgrade(t *testing.T) { // Normal request, without NPN. { - tr := newTLSTransport(t, ts) - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -63,11 +61,18 @@ func TestNextProtoUpgrade(t *testing.T) { // Request to an advertised but unhandled NPN protocol. // Server will hang up. { - tr := newTLSTransport(t, ts) - tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"} + certPool := x509.NewCertPool() + certPool.AddCert(ts.Certificate()) + tr := &Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certPool, + NextProtos: []string{"unhandled-proto"}, + }, + } defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c := &Client{ + Transport: tr, + } res, err := c.Get(ts.URL) if err == nil { defer res.Body.Close() @@ -80,7 +85,8 @@ func TestNextProtoUpgrade(t *testing.T) { // Request using the "tls-0.9" protocol, which we register here. // It is HTTP/0.9 over TLS. { - tlsConfig := newTLSTransport(t, ts).TLSClientConfig + c := ts.Client() + tlsConfig := c.Transport.(*Transport).TLSClientConfig tlsConfig.NextProtos = []string{"tls-0.9"} conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) if err != nil { diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 05d0890..12c7599 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -37,6 +37,11 @@ // // wget http://localhost:6060/debug/pprof/trace?seconds=5 // +// Or to look at the holders of contended mutexes, after calling +// runtime.SetMutexProfileFraction in your program: +// +// go tool pprof http://localhost:6060/debug/pprof/mutex +// // To view all available profiles, open http://localhost:6060/debug/pprof/ // in your browser. // @@ -57,6 +62,7 @@ import ( "os" "runtime" "runtime/pprof" + "runtime/trace" "strconv" "strings" "time" @@ -89,6 +95,11 @@ func sleep(w http.ResponseWriter, d time.Duration) { } } +func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool { + srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server) + return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds() +} + // Profile responds with the pprof-formatted cpu profile. // The package initialization registers it as /debug/pprof/profile. func Profile(w http.ResponseWriter, r *http.Request) { @@ -97,6 +108,14 @@ func Profile(w http.ResponseWriter, r *http.Request) { sec = 30 } + if durationExceedsWriteTimeout(r, float64(sec)) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, "profile duration exceeds server's WriteTimeout") + return + } + // Set Content Type assuming StartCPUProfile will work, // because if it does it starts writing. w.Header().Set("Content-Type", "application/octet-stream") @@ -105,6 +124,7 @@ func Profile(w http.ResponseWriter, r *http.Request) { // Can change header back to text content // and send error code. w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) return @@ -122,20 +142,28 @@ func Trace(w http.ResponseWriter, r *http.Request) { sec = 1 } + if durationExceedsWriteTimeout(r, sec) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, "profile duration exceeds server's WriteTimeout") + return + } + // Set Content Type assuming trace.Start will work, // because if it does it starts writing. w.Header().Set("Content-Type", "application/octet-stream") - w.Write([]byte("tracing not yet supported with gccgo")) - // if err := trace.Start(w); err != nil { - // // trace.Start failed, so no writes yet. - // // Can change header back to text content and send error code. - // w.Header().Set("Content-Type", "text/plain; charset=utf-8") - // w.WriteHeader(http.StatusInternalServerError) - // fmt.Fprintf(w, "Could not enable tracing: %s\n", err) - // return - // } - // sleep(w, time.Duration(sec*float64(time.Second))) - // trace.Stop() + if err := trace.Start(w); err != nil { + // trace.Start failed, so no writes yet. + // Can change header back to text content and send error code. + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Could not enable tracing: %s\n", err) + return + } + sleep(w, time.Duration(sec*float64(time.Second))) + trace.Stop() } // Symbol looks up the program counters listed in the request, @@ -207,7 +235,6 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { runtime.GC() } p.WriteTo(w, debug) - return } // Index responds with the pprof-formatted profile named by the request. diff --git a/libgo/go/net/http/proxy_test.go b/libgo/go/net/http/proxy_test.go index 823d144..f59a551 100644 --- a/libgo/go/net/http/proxy_test.go +++ b/libgo/go/net/http/proxy_test.go @@ -75,7 +75,13 @@ func TestCacheKeys(t *testing.T) { func ResetProxyEnv() { for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} { - os.Setenv(v, "") + os.Unsetenv(v) } ResetCachedEnvironment() } + +func TestInvalidNoProxy(t *testing.T) { + ResetProxyEnv() + os.Setenv("NO_PROXY", ":1") + useProxy("example.com:80") // should not panic +} diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index fb6bb0a..13f367c 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -27,8 +27,6 @@ import ( "sync" "golang_org/x/net/idna" - "golang_org/x/text/unicode/norm" - "golang_org/x/text/width" ) const ( @@ -331,6 +329,16 @@ func (r *Request) WithContext(ctx context.Context) *Request { r2 := new(Request) *r2 = *r r2.ctx = ctx + + // Deep copy the URL because it isn't + // a map and the URL is mutable by users + // of WithContext. + if r.URL != nil { + r2URL := new(url.URL) + *r2URL = *r.URL + r2.URL = r2URL + } + return r2 } @@ -341,18 +349,6 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } -// protoAtLeastOutgoing is like ProtoAtLeast, but is for outgoing -// requests (see issue 18407) where these fields aren't supposed to -// matter. As a minor fix for Go 1.8, at least treat (0, 0) as -// matching HTTP/1.1 or HTTP/1.0. Only HTTP/1.1 is used. -// TODO(bradfitz): ideally remove this whole method. It shouldn't be used. -func (r *Request) protoAtLeastOutgoing(major, minor int) bool { - if r.ProtoMajor == 0 && r.ProtoMinor == 0 && major == 1 && minor <= 1 { - return true - } - return r.ProtoAtLeast(major, minor) -} - // UserAgent returns the client's User-Agent, if sent in the request. func (r *Request) UserAgent() string { return r.Header.Get("User-Agent") @@ -621,6 +617,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // Write body and trailer err = tw.WriteBody(w) if err != nil { + if tw.bodyReadError == err { + err = requestBodyReadError{err} + } return err } @@ -630,17 +629,25 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return nil } +// requestBodyReadError wraps an error from (*Request).write to indicate +// that the error came from a Read call on the Request.Body. +// This error type should not escape the net/http package to users. +type requestBodyReadError struct{ error } + func idnaASCII(v string) (string, error) { + // TODO: Consider removing this check after verifying performance is okay. + // Right now punycode verification, length checks, context checks, and the + // permissible character tests are all omitted. It also prevents the ToASCII + // call from salvaging an invalid IDN, when possible. As a result it may be + // possible to have two IDNs that appear identical to the user where the + // ASCII-only version causes an error downstream whereas the non-ASCII + // version does not. + // Note that for correct ASCII IDNs ToASCII will only do considerably more + // work, but it will not cause an allocation. if isASCII(v) { return v, nil } - // The idna package doesn't do everything from - // https://tools.ietf.org/html/rfc5895 so we do it here. - // TODO(bradfitz): should the idna package do this instead? - v = strings.ToLower(v) - v = width.Fold.String(v) - v = norm.NFC.String(v) - return idna.ToASCII(v) + return idna.Lookup.ToASCII(v) } // cleanHost cleans up the host sent in request's Host header. @@ -755,7 +762,7 @@ func validMethod(method string) bool { // exact value (instead of -1), GetBody is populated (so 307 and 308 // redirects can replay the body), and Body is set to NoBody if the // ContentLength is 0. -func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { +func NewRequest(method, url string, body io.Reader) (*Request, error) { if method == "" { // We document that "" means "GET" for Request.Method, and people have // relied on that from NewRequest, so keep that working. @@ -765,7 +772,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if !validMethod(method) { return nil, fmt.Errorf("net/http: invalid method %q", method) } - u, err := url.Parse(urlStr) + u, err := parseURL(url) // Just url.Parse (url is shadowed for godoc). if err != nil { return nil, err } @@ -930,6 +937,9 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro if !ok { return nil, &badStringError{"malformed HTTP request", s} } + if !validMethod(req.Method) { + return nil, &badStringError{"invalid method", req.Method} + } rawurl := req.RequestURI if req.ProtoMajor, req.ProtoMinor, ok = ParseHTTPVersion(req.Proto); !ok { return nil, &badStringError{"malformed HTTP version", req.Proto} @@ -1021,11 +1031,6 @@ type maxBytesReader struct { err error // sticky error } -func (l *maxBytesReader) tooLarge() (n int, err error) { - l.err = errors.New("http: request body too large") - return 0, l.err -} - func (l *maxBytesReader) Read(p []byte) (n int, err error) { if l.err != nil { return 0, l.err @@ -1297,7 +1302,7 @@ func (r *Request) closeBody() { } func (r *Request) isReplayable() bool { - if r.Body == nil { + if r.Body == nil || r.Body == NoBody || r.GetBody != nil { switch valueOrDefault(r.Method, "GET") { case "GET", "HEAD", "OPTIONS", "TRACE": return true diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index e674837..967156b 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -7,6 +7,7 @@ package http_test import ( "bufio" "bytes" + "context" "encoding/base64" "fmt" "io" @@ -785,6 +786,28 @@ func TestMaxBytesReaderStickyError(t *testing.T) { } } +func TestWithContextDeepCopiesURL(t *testing.T) { + req, err := NewRequest("POST", "https://golang.org/", nil) + if err != nil { + t.Fatal(err) + } + + reqCopy := req.WithContext(context.Background()) + reqCopy.URL.Scheme = "http" + + firstURL, secondURL := req.URL.String(), reqCopy.URL.String() + if firstURL == secondURL { + t.Errorf("unexpected change to original request's URL") + } + + // And also check we don't crash on nil (Issue 20601) + req.URL = nil + reqCopy = req.WithContext(context.Background()) + if reqCopy.URL != nil { + t.Error("expected nil URL in cloned request") + } +} + // verify that NewRequest sets Request.GetBody and that it works func TestNewRequestGetBody(t *testing.T) { tests := []struct { diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index ae118fb..0357b60 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -37,9 +37,10 @@ type Response struct { // Header maps header keys to values. If the response had multiple // headers with the same key, they may be concatenated, with comma // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers - // be semantically equivalent to a comma-delimited sequence.) Values - // duplicated by other fields in this struct (e.g., ContentLength) are - // omitted from Header. + // be semantically equivalent to a comma-delimited sequence.) When + // Header values are duplicated by other fields in this struct (e.g., + // ContentLength, TransferEncoding, Trailer), the field values are + // authoritative. // // Keys in the map are canonicalized (see CanonicalHeaderKey). Header Header @@ -152,23 +153,23 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { } return nil, err } - f := strings.SplitN(line, " ", 3) - if len(f) < 2 { + if i := strings.IndexByte(line, ' '); i == -1 { return nil, &badStringError{"malformed HTTP response", line} + } else { + resp.Proto = line[:i] + resp.Status = strings.TrimLeft(line[i+1:], " ") } - reasonPhrase := "" - if len(f) > 2 { - reasonPhrase = f[2] + statusCode := resp.Status + if i := strings.IndexByte(resp.Status, ' '); i != -1 { + statusCode = resp.Status[:i] } - if len(f[1]) != 3 { - return nil, &badStringError{"malformed HTTP status code", f[1]} + if len(statusCode) != 3 { + return nil, &badStringError{"malformed HTTP status code", statusCode} } - resp.StatusCode, err = strconv.Atoi(f[1]) + resp.StatusCode, err = strconv.Atoi(statusCode) if err != nil || resp.StatusCode < 0 { - return nil, &badStringError{"malformed HTTP status code", f[1]} + return nil, &badStringError{"malformed HTTP status code", statusCode} } - resp.Status = f[1] + " " + reasonPhrase - resp.Proto = f[0] var ok bool if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok { return nil, &badStringError{"malformed HTTP version", resp.Proto} @@ -320,3 +321,9 @@ func (r *Response) Write(w io.Writer) error { // Success return nil } + +func (r *Response) closeBody() { + if r.Body != nil { + r.Body.Close() + } +} diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 660d517..f1a50bd 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -318,7 +318,7 @@ var respTests = []respTest{ { "HTTP/1.0 303\r\n\r\n", Response{ - Status: "303 ", + Status: "303", StatusCode: 303, Proto: "HTTP/1.0", ProtoMajor: 1, @@ -532,6 +532,29 @@ some body`, }, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00", }, + + // Issue 19989: two spaces between HTTP version and status. + { + "HTTP/1.0 401 Unauthorized\r\n" + + "Content-type: text/html\r\n" + + "WWW-Authenticate: Basic realm=\"\"\r\n\r\n" + + "Your Authentication failed.\r\n", + Response{ + Status: "401 Unauthorized", + StatusCode: 401, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{ + "Content-Type": {"text/html"}, + "Www-Authenticate": {`Basic realm=""`}, + }, + Close: true, + ContentLength: -1, + }, + "Your Authentication failed.\r\n", + }, } // tests successful calls to ReadResponse, and inspects the returned Response. @@ -926,3 +949,29 @@ func TestNeedsSniff(t *testing.T) { t.Errorf("needsSniff empty Content-Type = %t; want %t", got, want) } } + +// A response should only write out single Connection: close header. Tests #19499. +func TestResponseWritesOnlySingleConnectionClose(t *testing.T) { + const connectionCloseHeader = "Connection: close" + + res, err := ReadResponse(bufio.NewReader(strings.NewReader("HTTP/1.0 200 OK\r\n\r\nAAAA")), nil) + if err != nil { + t.Fatalf("ReadResponse failed %v", err) + } + + var buf1 bytes.Buffer + if err = res.Write(&buf1); err != nil { + t.Fatalf("Write failed %v", err) + } + if res, err = ReadResponse(bufio.NewReader(&buf1), nil); err != nil { + t.Fatalf("ReadResponse failed %v", err) + } + + var buf2 bytes.Buffer + if err = res.Write(&buf2); err != nil { + t.Fatalf("Write failed %v", err) + } + if count := strings.Count(buf2.String(), connectionCloseHeader); count != 1 { + t.Errorf("Found %d %q header", count, connectionCloseHeader) + } +} diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 73dd56e..7137599 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -337,6 +337,7 @@ var serveMuxTests = []struct { {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"}, {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"}, {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"}, + {"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"}, {"GET", "images.google.com", "/search", 201, "/search"}, {"GET", "images.google.com", "/search/", 404, ""}, {"GET", "images.google.com", "/search/foo", 404, ""}, @@ -460,31 +461,86 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } } +func BenchmarkServeMux(b *testing.B) { + + type test struct { + path string + code int + req *Request + } + + // Build example handlers and requests + var tests []test + endpoints := []string{"search", "dir", "file", "change", "count", "s"} + for _, e := range endpoints { + for i := 200; i < 230; i++ { + p := fmt.Sprintf("/%s/%d/", e, i) + tests = append(tests, test{ + path: p, + code: i, + req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}}, + }) + } + } + mux := NewServeMux() + for _, tt := range tests { + mux.Handle(tt.path, serve(tt.code)) + } + + rw := httptest.NewRecorder() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tt := range tests { + *rw = httptest.ResponseRecorder{} + h, pattern := mux.Handler(tt.req) + h.ServeHTTP(rw, tt.req) + if pattern != tt.path || rw.Code != tt.code { + b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path) + } + } + } +} + func TestServerTimeouts(t *testing.T) { setParallel(t) defer afterTest(t) + // Try three times, with increasing timeouts. + tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} + for i, timeout := range tries { + err := testServerTimeouts(timeout) + if err == nil { + return + } + t.Logf("failed at %v: %v", timeout, err) + if i != len(tries)-1 { + t.Logf("retrying at %v ...", tries[i+1]) + } + } + t.Fatal("all attempts failed") +} + +func testServerTimeouts(timeout time.Duration) error { reqNum := 0 ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) })) - ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.WriteTimeout = 250 * time.Millisecond + ts.Config.ReadTimeout = timeout + ts.Config.WriteTimeout = timeout ts.Start() defer ts.Close() // Hit the HTTP server successfully. - tr := &Transport{DisableKeepAlives: true} // they interfere with this test - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() r, err := c.Get(ts.URL) if err != nil { - t.Fatalf("http Get #1: %v", err) + return fmt.Errorf("http Get #1: %v", err) } got, err := ioutil.ReadAll(r.Body) expected := "req=1" if string(got) != expected || err != nil { - t.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil", + return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil", string(got), err, expected) } @@ -492,17 +548,18 @@ func TestServerTimeouts(t *testing.T) { t1 := time.Now() conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { - t.Fatalf("Dial: %v", err) + return fmt.Errorf("Dial: %v", err) } buf := make([]byte, 1) n, err := conn.Read(buf) conn.Close() latency := time.Since(t1) if n != 0 || err != io.EOF { - t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) + return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) } - if latency < 200*time.Millisecond /* fudge from 250 ms above */ { - t.Errorf("got EOF after %s, want >= %s", latency, 200*time.Millisecond) + minLatency := timeout / 5 * 4 + if latency < minLatency { + return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency) } // Hit the HTTP server successfully again, verifying that the @@ -510,29 +567,31 @@ func TestServerTimeouts(t *testing.T) { // get "req=2", not "req=3") r, err = c.Get(ts.URL) if err != nil { - t.Fatalf("http Get #2: %v", err) + return fmt.Errorf("http Get #2: %v", err) } got, err = ioutil.ReadAll(r.Body) + r.Body.Close() expected = "req=2" if string(got) != expected || err != nil { - t.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected) + return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected) } if !testing.Short() { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { - t.Fatalf("Dial: %v", err) + return fmt.Errorf("long Dial: %v", err) } defer conn.Close() go io.Copy(ioutil.Discard, conn) for i := 0; i < 5; i++ { _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) if err != nil { - t.Fatalf("on write %d: %v", i, err) + return fmt.Errorf("on write %d: %v", i, err) } - time.Sleep(ts.Config.ReadTimeout / 2) + time.Sleep(timeout / 2) } } + return nil } // Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) @@ -548,12 +607,10 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { ts.StartTLS() defer ts.Close() - tr := newTLSTransport(t, ts) - defer tr.CloseIdleConnections() - if err := ExportHttp2ConfigureTransport(tr); err != nil { + c := ts.Client() + if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { t.Fatal(err) } - c := &Client{Transport: tr} for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -585,13 +642,139 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { } } +// tryTimeouts runs testFunc with increasing timeouts. Test passes on first success, +// and fails if all timeouts fail. +func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) { + tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} + for i, timeout := range tries { + err := testFunc(timeout) + if err == nil { + return + } + t.Logf("failed at %v: %v", timeout, err) + if i != len(tries)-1 { + t.Logf("retrying at %v ...", tries[i+1]) + } + } + t.Fatal("all attempts failed") +} + +// Test that the HTTP/2 server RSTs stream on slow write. +func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + setParallel(t) + defer afterTest(t) + tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream) +} + +func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { + reqNum := 0 + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + reqNum++ + if reqNum == 1 { + return // first request succeeds + } + time.Sleep(timeout) // second request times out + })) + ts.Config.WriteTimeout = timeout / 2 + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + defer ts.Close() + + c := ts.Client() + if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { + return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) + } + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + return fmt.Errorf("NewRequest: %v", err) + } + r, err := c.Do(req) + if err != nil { + return fmt.Errorf("http2 Get #1: %v", err) + } + r.Body.Close() + if r.ProtoMajor != 2 { + return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) + } + + req, err = NewRequest("GET", ts.URL, nil) + if err != nil { + return fmt.Errorf("NewRequest: %v", err) + } + r, err = c.Do(req) + if err == nil { + r.Body.Close() + if r.ProtoMajor != 2 { + return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) + } + return fmt.Errorf("http2 Get #2 expected error, got nil") + } + expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 + if !strings.Contains(err.Error(), expected) { + return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + } + return nil +} + +// Test that the HTTP/2 server does not send RST when WriteDeadline not set. +func TestHTTP2NoWriteDeadline(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + setParallel(t) + defer afterTest(t) + tryTimeouts(t, testHTTP2NoWriteDeadline) +} + +func testHTTP2NoWriteDeadline(timeout time.Duration) error { + reqNum := 0 + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + reqNum++ + if reqNum == 1 { + return // first request succeeds + } + time.Sleep(timeout) // second request timesout + })) + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + defer ts.Close() + + c := ts.Client() + if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { + return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) + } + + for i := 0; i < 2; i++ { + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + return fmt.Errorf("NewRequest: %v", err) + } + r, err := c.Do(req) + if err != nil { + return fmt.Errorf("http2 Get #%d: %v", i, err) + } + r.Body.Close() + if r.ProtoMajor != 2 { + return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) + } + } + return nil +} + // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. func TestOnlyWriteTimeout(t *testing.T) { setParallel(t) defer afterTest(t) - var conn net.Conn + var ( + mu sync.RWMutex + conn net.Conn + ) var afterTimeoutErrc = make(chan error, 1) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { buf := make([]byte, 512<<10) @@ -600,17 +783,21 @@ func TestOnlyWriteTimeout(t *testing.T) { t.Errorf("handler Write error: %v", err) return } + mu.RLock() + defer mu.RUnlock() + if conn == nil { + t.Error("no established connection found") + return + } conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) _, err = w.Write(buf) afterTimeoutErrc <- err })) - ts.Listener = trackLastConnListener{ts.Listener, &conn} + ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} ts.Start() defer ts.Close() - tr := &Transport{DisableKeepAlives: false} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() errc := make(chan error) go func() { @@ -620,6 +807,7 @@ func TestOnlyWriteTimeout(t *testing.T) { return } _, err = io.Copy(ioutil.Discard, res.Body) + res.Body.Close() errc <- err }() select { @@ -638,12 +826,18 @@ func TestOnlyWriteTimeout(t *testing.T) { // trackLastConnListener tracks the last net.Conn that was accepted. type trackLastConnListener struct { net.Listener + + mu *sync.RWMutex last *net.Conn // destination } func (l trackLastConnListener) Accept() (c net.Conn, err error) { c, err = l.Listener.Accept() - *l.last = c + if err == nil { + l.mu.Lock() + *l.last = c + l.mu.Unlock() + } return } @@ -671,8 +865,7 @@ func TestIdentityResponse(t *testing.T) { ts := httptest.NewServer(handler) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() // Note: this relies on the assumption (which is true) that // Get sends HTTP/1.1 or greater requests. Otherwise the @@ -936,7 +1129,6 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { - setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "RA:%s", r.RemoteAddr) @@ -949,21 +1141,22 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{DisableKeepAlives: true} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr, Timeout: time.Second} + c := ts.Client() + c.Timeout = time.Second + // Force separate connection for each: + c.Transport.(*Transport).DisableKeepAlives = true - fetch := func(response chan string) { + fetch := func(num int, response chan<- string) { resp, err := c.Get(ts.URL) if err != nil { - t.Error(err) + t.Errorf("Request %d: %v", num, err) response <- "" return } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { - t.Error(err) + t.Errorf("Request %d: %v", num, err) response <- "" return } @@ -972,14 +1165,14 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { // Start a request. The server will block on getting conn.RemoteAddr. response1c := make(chan string, 1) - go fetch(response1c) + go fetch(1, response1c) // Wait for the server to accept it; grab the connection. conn1 := <-conns // Start another request and grab its connection response2c := make(chan string, 1) - go fetch(response2c) + go fetch(2, response2c) var conn2 net.Conn select { @@ -1022,9 +1215,7 @@ func TestIdentityResponseHeaders(t *testing.T) { })) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) - + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) @@ -1145,12 +1336,7 @@ func TestTLSServer(t *testing.T) { t.Errorf("expected test TLS server to start with https://, got %q", ts.URL) return } - noVerifyTransport := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - client := &Client{Transport: noVerifyTransport} + client := ts.Client() res, err := client.Get(ts.URL) if err != nil { t.Error(err) @@ -1171,6 +1357,59 @@ func TestTLSServer(t *testing.T) { }) } +func TestServeTLS(t *testing.T) { + // Not parallel: uses global test hooks. + defer afterTest(t) + defer SetTestHookServerServe(nil) + + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + t.Fatal(err) + } + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + ln := newLocalListener(t) + defer ln.Close() + addr := ln.Addr().String() + + serving := make(chan bool, 1) + SetTestHookServerServe(func(s *Server, ln net.Listener) { + serving <- true + }) + handler := HandlerFunc(func(w ResponseWriter, r *Request) {}) + s := &Server{ + Addr: addr, + TLSConfig: tlsConf, + Handler: handler, + } + errc := make(chan error, 1) + go func() { errc <- s.ServeTLS(ln, "", "") }() + select { + case err := <-errc: + t.Fatalf("ServeTLS: %v", err) + case <-serving: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + + c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2", "http/1.1"}, + }) + if err != nil { + t.Fatal(err) + } + defer c.Close() + if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want { + t.Errorf("NegotiatedProtocol = %q; want %q", got, want) + } + if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want { + t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want) + } +} + // Issue 15908 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) { testAutomaticHTTP2_Serve(t, nil, true) @@ -1967,8 +2206,7 @@ func TestTimeoutHandlerRace(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() var wg sync.WaitGroup gate := make(chan bool, 10) @@ -2011,8 +2249,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { if testing.Short() { n = 10 } - c := &Client{Transport: new(Transport)} - defer closeClient(c) + + c := ts.Client() for i := 0; i < n; i++ { gate <- true wg.Add(1) @@ -2099,8 +2337,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() // Issue was caused by the timeout handler starting the timer when // was created, not when the request. So wait for more than the timeout @@ -2127,8 +2364,7 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { @@ -2364,9 +2600,7 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -2411,8 +2645,7 @@ func TestStripPrefix(t *testing.T) { ts := httptest.NewServer(StripPrefix("/foo", h)) defer ts.Close() - c := &Client{Transport: new(Transport)} - defer closeClient(c) + c := ts.Client() res, err := c.Get(ts.URL + "/foo/bar") if err != nil { @@ -2433,6 +2666,16 @@ func TestStripPrefix(t *testing.T) { res.Body.Close() } +// https://golang.org/issue/18952. +func TestStripPrefix_notModifyRequest(t *testing.T) { + h := StripPrefix("/foo", NotFoundHandler()) + req := httptest.NewRequest("GET", "/foo/bar", nil) + h.ServeHTTP(httptest.NewRecorder(), req) + if req.URL.Path != "/foo/bar" { + t.Errorf("StripPrefix should not modify the provided Request, but it did") + } +} + func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } func testRequestLimit(t *testing.T, h2 bool) { @@ -3512,8 +3755,8 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // Test that a hanging Request.Body.Read from another goroutine can't // cause the Handler goroutine's Request.Body.Close to block. +// See issue 7121. func TestRequestBodyCloseDoesntBlock(t *testing.T) { - t.Skipf("Skipping known issue; see golang.org/issue/7121") if testing.Short() { t.Skip("skipping in -short mode") } @@ -3644,9 +3887,7 @@ func TestServerConnState(t *testing.T) { } ts.Start() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() mustGet := func(url string, headers ...string) { req, err := NewRequest("GET", url, nil) @@ -4170,6 +4411,9 @@ func TestServerValidatesHostHeader(t *testing.T) { // Make an exception for HTTP upgrade requests: {"PRI * HTTP/2.0", "", 200}, + // Also an exception for CONNECT requests: (Issue 18215) + {"CONNECT golang.org:443 HTTP/1.1", "", 200}, + // But not other HTTP/2 stuff: {"PRI / HTTP/2.0", "", 400}, {"GET / HTTP/2.0", "", 400}, @@ -4373,13 +4617,6 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { if _, ok := got.(*Server); !ok { t.Errorf("context value = %T; want *http.Server", got) } - - got = ctx.Value(LocalAddrContextKey) - if addr, ok := got.(net.Addr); !ok { - t.Errorf("local addr value = %T; want net.Addr", got) - } else if fmt.Sprint(addr) != r.Host { - t.Errorf("local addr = %v; want %v", addr, r.Host) - } })) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -4389,6 +4626,37 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { res.Body.Close() } +func TestServerContext_LocalAddrContextKey_h1(t *testing.T) { + testServerContext_LocalAddrContextKey(t, h1Mode) +} +func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { + testServerContext_LocalAddrContextKey(t, h2Mode) +} +func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + ch := make(chan interface{}, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- r.Context().Value(LocalAddrContextKey) + })) + defer cst.close() + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) + } + + host := cst.ts.Listener.Addr().String() + select { + case got := <-ch: + if addr, ok := got.(net.Addr); !ok { + t.Errorf("local addr value = %T; want net.Addr", got) + } else if fmt.Sprint(addr) != host { + t.Errorf("local addr = %v; want %v", addr, host) + } + case <-time.After(5 * time.Second): + t.Error("timed out") + } +} + // https://golang.org/issue/15960 func TestHandlerSetTransferEncodingChunked(t *testing.T) { setParallel(t) @@ -4481,15 +4749,9 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { - noVerifyTransport := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - defer noVerifyTransport.CloseIdleConnections() - client := &Client{Transport: noVerifyTransport} + c := ts.Client() for pb.Next() { - res, err := client.Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Logf("Get: %v", err) continue @@ -4924,10 +5186,7 @@ func TestServerIdleTimeout(t *testing.T) { ts.Config.IdleTimeout = 2 * time.Second ts.Start() defer ts.Close() - - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() get := func() string { res, err := c.Get(ts.URL) @@ -4988,9 +5247,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) get := func() string { return get(t, c, ts.URL) } @@ -5030,7 +5288,8 @@ func testServerShutdown(t *testing.T, h2 bool) { defer afterTest(t) var doShutdown func() // set later var shutdownRes = make(chan error, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var gotOnShutdown = make(chan struct{}, 1) + handler := HandlerFunc(func(w ResponseWriter, r *Request) { go doShutdown() // Shutdown is graceful, so it should not interrupt // this in-flight response. Add a tiny sleep here to @@ -5038,7 +5297,10 @@ func testServerShutdown(t *testing.T, h2 bool) { // bugs. time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) - })) + }) + cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) + }) defer cst.close() doShutdown = func() { @@ -5049,6 +5311,11 @@ func testServerShutdown(t *testing.T, h2 bool) { if err := <-shutdownRes; err != nil { t.Fatalf("Shutdown: %v", err) } + select { + case <-gotOnShutdown: + case <-time.After(5 * time.Second): + t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?") + } res, err := cst.c.Get(cst.ts.URL) if err == nil { @@ -5109,9 +5376,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { ts.Start() defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { @@ -5312,3 +5577,41 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { t.Error("timeout") } } + +// Issue 18319: test that the Server validates the request method. +func TestServerValidatesMethod(t *testing.T) { + tests := []struct { + method string + want int + }{ + {"GET", 200}, + {"GE(T", 400}, + } + for _, tt := range tests { + conn := &testConn{closec: make(chan bool, 1)} + io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n") + + ln := &oneConnListener{conn} + go Serve(ln, serve(200)) + <-conn.closec + res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) + if err != nil { + t.Errorf("For %s, ReadResponse: %v", tt.method, res) + continue + } + if res.StatusCode != tt.want { + t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want) + } + } +} + +func BenchmarkResponseStatusLine(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + bw := bufio.NewWriter(ioutil.Discard) + var buf3 [3]byte + for pb.Next() { + Export_writeStatusLine(bw, true, 200, buf3[:]) + } + }) +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index df70a15..2fa8ab2 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -75,9 +75,10 @@ var ( // If ServeHTTP panics, the server (the caller of ServeHTTP) assumes // that the effect of the panic was isolated to the active request. // It recovers the panic, logs a stack trace to the server error log, -// and hangs up the connection. To abort a handler so the client sees -// an interrupted response but the server doesn't log an error, panic -// with the value ErrAbortHandler. +// and either closes the network connection or sends an HTTP/2 +// RST_STREAM, depending on the HTTP protocol. To abort a handler so +// the client sees an interrupted response but the server doesn't log +// an error, panic with the value ErrAbortHandler. type Handler interface { ServeHTTP(ResponseWriter, *Request) } @@ -177,6 +178,9 @@ type Hijacker interface { // // The returned bufio.Reader may contain unprocessed buffered // data from the client. + // + // After a call to Hijack, the original Request.Body should + // not be used. Hijack() (net.Conn, *bufio.ReadWriter, error) } @@ -439,9 +443,10 @@ type response struct { handlerDone atomicBool // set true when the handler exits - // Buffers for Date and Content-Length - dateBuf [len(TimeFormat)]byte - clenBuf [10]byte + // Buffers for Date, Content-Length, and status code + dateBuf [len(TimeFormat)]byte + clenBuf [10]byte + statusBuf [3]byte // closeNotifyCh is the channel returned by CloseNotify. // TODO(bradfitz): this is currently (for Go 1.8) always @@ -622,7 +627,6 @@ type connReader struct { mu sync.Mutex // guards following hasByte bool byteBuf [1]byte - bgErr error // non-nil means error happened on background read cond *sync.Cond inRead bool aborted bool // set true before conn.rwc deadline is set to past @@ -731,11 +735,6 @@ func (cr *connReader) Read(p []byte) (n int, err error) { cr.unlock() return 0, io.EOF } - if cr.bgErr != nil { - err = cr.bgErr - cr.unlock() - return 0, err - } if len(p) == 0 { cr.unlock() return 0, nil @@ -839,7 +838,7 @@ func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } -// wrapper around io.ReaderCloser which on first read, sends an +// wrapper around io.ReadCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { resp *response @@ -948,7 +947,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { hosts, haveHost := req.Header["Host"] isH2Upgrade := req.isH2Upgrade() - if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade { + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade && req.Method != "CONNECT" { return nil, badRequestError("missing required Host header") } if len(hosts) > 1 { @@ -1379,7 +1378,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } } - w.conn.bufw.WriteString(statusLine(w.req, code)) + writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) cw.header.WriteSubset(w.conn.bufw, excludeHeader) setHeader.Write(w.conn.bufw) w.conn.bufw.Write(crlf) @@ -1403,49 +1402,25 @@ func foreachHeaderElement(v string, fn func(string)) { } } -// statusLines is a cache of Status-Line strings, keyed by code (for -// HTTP/1.1) or negative code (for HTTP/1.0). This is faster than a -// map keyed by struct of two fields. This map's max size is bounded -// by 2*len(statusText), two protocol types for each known official -// status code in the statusText map. -var ( - statusMu sync.RWMutex - statusLines = make(map[int]string) -) - -// statusLine returns a response Status-Line (RFC 2616 Section 6.1) -// for the given request and response status code. -func statusLine(req *Request, code int) string { - // Fast path: - key := code - proto11 := req.ProtoAtLeast(1, 1) - if !proto11 { - key = -key - } - statusMu.RLock() - line, ok := statusLines[key] - statusMu.RUnlock() - if ok { - return line - } - - // Slow path: - proto := "HTTP/1.0" - if proto11 { - proto = "HTTP/1.1" - } - codestring := fmt.Sprintf("%03d", code) - text, ok := statusText[code] - if !ok { - text = "status code " + codestring +// writeStatusLine writes an HTTP/1.x Status-Line (RFC 2616 Section 6.1) +// to bw. is11 is whether the HTTP request is HTTP/1.1. false means HTTP/1.0. +// code is the response status code. +// scratch is an optional scratch buffer. If it has at least capacity 3, it's used. +func writeStatusLine(bw *bufio.Writer, is11 bool, code int, scratch []byte) { + if is11 { + bw.WriteString("HTTP/1.1 ") + } else { + bw.WriteString("HTTP/1.0 ") } - line = proto + " " + codestring + " " + text + "\r\n" - if ok { - statusMu.Lock() - defer statusMu.Unlock() - statusLines[key] = line + if text, ok := statusText[code]; ok { + bw.Write(strconv.AppendInt(scratch[:0], int64(code), 10)) + bw.WriteByte(' ') + bw.WriteString(text) + bw.WriteString("\r\n") + } else { + // don't worry about performance + fmt.Fprintf(bw, "%03d status code %d\r\n", code, code) } - return line } // bodyAllowed reports whether a Write is allowed for this response type. @@ -1714,6 +1689,7 @@ func isCommonNetReadError(err error) bool { // Serve a new connection. func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() + ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) defer func() { if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 @@ -1973,8 +1949,12 @@ func StripPrefix(prefix string, h Handler) Handler { } return HandlerFunc(func(w ResponseWriter, r *Request) { if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { - r.URL.Path = p - h.ServeHTTP(w, r) + r2 := new(Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = p + h.ServeHTTP(w, r2) } else { NotFound(w, r) } @@ -1986,8 +1966,9 @@ func StripPrefix(prefix string, h Handler) Handler { // // The provided code should be in the 3xx range and is usually // StatusMovedPermanently, StatusFound or StatusSeeOther. -func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { - if u, err := url.Parse(urlStr); err == nil { +func Redirect(w ResponseWriter, r *Request, url string, code int) { + // parseURL is just url.Parse (url is shadowed for godoc). + if u, err := parseURL(url); err == nil { // If url was relative, make absolute by // combining with request path. // The browser would probably do this for us, @@ -2011,39 +1992,43 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } // no leading http://server - if urlStr == "" || urlStr[0] != '/' { + if url == "" || url[0] != '/' { // make relative path absolute olddir, _ := path.Split(oldpath) - urlStr = olddir + urlStr + url = olddir + url } var query string - if i := strings.Index(urlStr, "?"); i != -1 { - urlStr, query = urlStr[:i], urlStr[i:] + if i := strings.Index(url, "?"); i != -1 { + url, query = url[:i], url[i:] } // clean up but preserve trailing slash - trailing := strings.HasSuffix(urlStr, "/") - urlStr = path.Clean(urlStr) - if trailing && !strings.HasSuffix(urlStr, "/") { - urlStr += "/" + trailing := strings.HasSuffix(url, "/") + url = path.Clean(url) + if trailing && !strings.HasSuffix(url, "/") { + url += "/" } - urlStr += query + url += query } } - w.Header().Set("Location", hexEscapeNonASCII(urlStr)) + w.Header().Set("Location", hexEscapeNonASCII(url)) w.WriteHeader(code) // RFC 2616 recommends that a short note "SHOULD" be included in the // response because older user agents may not understand 301/307. // Shouldn't send the response for POST or HEAD; that leaves GET. if r.Method == "GET" { - note := "<a href=\"" + htmlEscape(urlStr) + "\">" + statusText[code] + "</a>.\n" + note := "<a href=\"" + htmlEscape(url) + "\">" + statusText[code] + "</a>.\n" fmt.Fprintln(w, note) } } +// parseURL is just url.Parse. It exists only so that url.Parse can be called +// in places where url is shadowed for godoc. See https://golang.org/cl/49930. +var parseURL = url.Parse + var htmlReplacer = strings.NewReplacer( "&", "&", "<", "<", @@ -2163,9 +2148,29 @@ func cleanPath(p string) string { return np } -// Find a handler on a handler map given a path string -// Most-specific (longest) pattern wins +// stripHostPort returns h without any trailing ":<port>". +func stripHostPort(h string) string { + // If no port on host, return unchanged + if strings.IndexByte(h, ':') == -1 { + return h + } + host, _, err := net.SplitHostPort(h) + if err != nil { + return h // on error, return unchanged + } + return host +} + +// Find a handler on a handler map given a path string. +// Most-specific (longest) pattern wins. func (mux *ServeMux) match(path string) (h Handler, pattern string) { + // Check for exact match first. + v, ok := mux.m[path] + if ok { + return v.h, v.pattern + } + + // Check for longest valid match. var n = 0 for k, v := range mux.m { if !pathMatch(k, path) { @@ -2184,7 +2189,10 @@ func (mux *ServeMux) match(path string) (h Handler, pattern string) { // consulting r.Method, r.Host, and r.URL.Path. It always returns // a non-nil handler. If the path is not in its canonical form, the // handler will be an internally-generated handler that redirects -// to the canonical path. +// to the canonical path. If the host contains a port, it is ignored +// when matching handlers. +// +// The path and host are used unchanged for CONNECT requests. // // Handler also returns the registered pattern that matches the // request or, in the case of internally-generated redirects, @@ -2193,16 +2201,24 @@ func (mux *ServeMux) match(path string) (h Handler, pattern string) { // If there is no registered handler that applies to the request, // Handler returns a ``page not found'' handler and an empty pattern. func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { - if r.Method != "CONNECT" { - if p := cleanPath(r.URL.Path); p != r.URL.Path { - _, pattern = mux.handler(r.Host, p) - url := *r.URL - url.Path = p - return RedirectHandler(url.String(), StatusMovedPermanently), pattern - } + + // CONNECT requests are not canonicalized. + if r.Method == "CONNECT" { + return mux.handler(r.Host, r.URL.Path) } - return mux.handler(r.Host, r.URL.Path) + // All other requests have any port stripped and path cleaned + // before passing to mux.handler. + host := stripHostPort(r.Host) + path := cleanPath(r.URL.Path) + if path != r.URL.Path { + _, pattern = mux.handler(host, path) + url := *r.URL + url.Path = path + return RedirectHandler(url.String(), StatusMovedPermanently), pattern + } + + return mux.handler(host, r.URL.Path) } // handler is the main implementation of Handler. @@ -2307,12 +2323,27 @@ func Serve(l net.Listener, handler Handler) error { return srv.Serve(l) } +// Serve accepts incoming HTTPS connections on the listener l, +// creating a new service goroutine for each. The service goroutines +// read requests and then call handler to reply to them. +// +// Handler is typically nil, in which case the DefaultServeMux is used. +// +// Additionally, files containing a certificate and matching private key +// for the server must be provided. If the certificate is signed by a +// certificate authority, the certFile should be the concatenation +// of the server's certificate, any intermediates, and the CA's certificate. +func ServeTLS(l net.Listener, handler Handler, certFile, keyFile string) error { + srv := &Server{Handler: handler} + return srv.ServeTLS(l, certFile, keyFile) +} + // A Server defines parameters for running an HTTP server. // The zero value for Server is a valid configuration. type Server struct { Addr string // TCP address to listen on, ":http" if empty Handler Handler // handler to invoke, http.DefaultServeMux if nil - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + TLSConfig *tls.Config // optional TLS config, used by ServeTLS and ListenAndServeTLS // ReadTimeout is the maximum duration for reading the entire // request, including the body. @@ -2338,7 +2369,7 @@ type Server struct { // IdleTimeout is the maximum amount of time to wait for the // next request when keep-alives are enabled. If IdleTimeout // is zero, the value of ReadTimeout is used. If both are - // zero, there is no timeout. + // zero, ReadHeaderTimeout is used. IdleTimeout time.Duration // MaxHeaderBytes controls the maximum number of bytes the @@ -2379,6 +2410,7 @@ type Server struct { listeners map[net.Listener]struct{} activeConn map[*conn]struct{} doneChan chan struct{} + onShutdown []func() } func (s *Server) getDoneChan() <-chan struct{} { @@ -2441,7 +2473,12 @@ var shutdownPollInterval = 500 * time.Millisecond // listeners, then closing all idle connections, and then waiting // indefinitely for connections to return to idle and then shut down. // If the provided context expires before the shutdown is complete, -// then the context's error is returned. +// Shutdown returns the context's error, otherwise it returns any +// error returned from closing the Server's underlying Listener(s). +// +// When Shutdown is called, Serve, ListenAndServe, and +// ListenAndServeTLS immediately return ErrServerClosed. Make sure the +// program doesn't exit and waits instead for Shutdown to return. // // Shutdown does not attempt to close nor wait for hijacked // connections such as WebSockets. The caller of Shutdown should @@ -2454,6 +2491,9 @@ func (srv *Server) Shutdown(ctx context.Context) error { srv.mu.Lock() lnerr := srv.closeListenersLocked() srv.closeDoneChanLocked() + for _, f := range srv.onShutdown { + go f() + } srv.mu.Unlock() ticker := time.NewTicker(shutdownPollInterval) @@ -2470,6 +2510,17 @@ func (srv *Server) Shutdown(ctx context.Context) error { } } +// RegisterOnShutdown registers a function to call on Shutdown. +// This can be used to gracefully shutdown connections that have +// undergone NPN/ALPN protocol upgrade or that have been hijacked. +// This function should start protocol-specific graceful shutdown, +// but should not wait for shutdown to complete. +func (srv *Server) RegisterOnShutdown(f func()) { + srv.mu.Lock() + srv.onShutdown = append(srv.onShutdown, f) + srv.mu.Unlock() +} + // closeIdleConns closes all idle connections and reports whether the // server is quiescent. func (s *Server) closeIdleConns() bool { @@ -2609,6 +2660,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) } +// ErrServerClosed is returned by the Server's Serve, ServeTLS, ListenAndServe, +// and ListenAndServeTLS methods after a call to Shutdown or Close. var ErrServerClosed = errors.New("http: Server closed") // Serve accepts incoming connections on the Listener l, creating a @@ -2638,7 +2691,6 @@ func (srv *Server) Serve(l net.Listener) error { baseCtx := context.Background() // base is always background, per Issue 16220 ctx := context.WithValue(baseCtx, ServerContextKey, srv) - ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) for { rw, e := l.Accept() if e != nil { @@ -2669,6 +2721,49 @@ func (srv *Server) Serve(l net.Listener) error { } } +// ServeTLS accepts incoming connections on the Listener l, creating a +// new service goroutine for each. The service goroutines read requests and +// then call srv.Handler to reply to them. +// +// Additionally, files containing a certificate and matching private key for +// the server must be provided if neither the Server's TLSConfig.Certificates +// nor TLSConfig.GetCertificate are populated.. If the certificate is signed by +// a certificate authority, the certFile should be the concatenation of the +// server's certificate, any intermediates, and the CA's certificate. +// +// For HTTP/2 support, srv.TLSConfig should be initialized to the +// provided listener's TLS Config before calling Serve. If +// srv.TLSConfig is non-nil and doesn't include the string "h2" in +// Config.NextProtos, HTTP/2 support is not enabled. +// +// ServeTLS always returns a non-nil error. After Shutdown or Close, the +// returned error is ErrServerClosed. +func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error { + // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig + // before we clone it and create the TLS Listener. + if err := srv.setupHTTP2_ServeTLS(); err != nil { + return err + } + + config := cloneTLSConfig(srv.TLSConfig) + if !strSliceContains(config.NextProtos, "http/1.1") { + config.NextProtos = append(config.NextProtos, "http/1.1") + } + + configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil + if !configHasCert || certFile != "" || keyFile != "" { + var err error + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + } + + tlsListener := tls.NewListener(l, config) + return srv.Serve(tlsListener) +} + func (s *Server) trackListener(ln net.Listener, add bool) { s.mu.Lock() defer s.mu.Unlock() @@ -2840,47 +2935,25 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { addr = ":https" } - // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig - // before we clone it and create the TLS Listener. - if err := srv.setupHTTP2_ListenAndServeTLS(); err != nil { - return err - } - - config := cloneTLSConfig(srv.TLSConfig) - if !strSliceContains(config.NextProtos, "http/1.1") { - config.NextProtos = append(config.NextProtos, "http/1.1") - } - - configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil - if !configHasCert || certFile != "" || keyFile != "" { - var err error - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return err - } - } - ln, err := net.Listen("tcp", addr) if err != nil { return err } - tlsListener := tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) - return srv.Serve(tlsListener) + return srv.ServeTLS(tcpKeepAliveListener{ln.(*net.TCPListener)}, certFile, keyFile) } -// setupHTTP2_ListenAndServeTLS conditionally configures HTTP/2 on +// setupHTTP2_ServeTLS conditionally configures HTTP/2 on // srv and returns whether there was an error setting it up. If it is // not configured for policy reasons, nil is returned. -func (srv *Server) setupHTTP2_ListenAndServeTLS() error { +func (srv *Server) setupHTTP2_ServeTLS() error { srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults) return srv.nextProtoErr } // setupHTTP2_Serve is called from (*Server).Serve and conditionally // configures HTTP/2 on srv using a more conservative policy than -// setupHTTP2_ListenAndServeTLS because Serve may be called +// setupHTTP2_ServeTLS because Serve may be called // concurrently. // // The tests named TestTransportAutomaticHTTP2* and @@ -2907,7 +2980,10 @@ func (srv *Server) onceSetNextProtoDefaults() { // Enable HTTP/2 by default if the user hasn't otherwise // configured their TLSNextProto map. if srv.TLSNextProto == nil { - srv.nextProtoErr = http2ConfigureServer(srv, nil) + conf := &http2Server{ + NewWriteScheduler: func() http2WriteScheduler { return http2NewPriorityWriteScheduler(nil) }, + } + srv.nextProtoErr = http2ConfigureServer(srv, conf) } } diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go index 0d21b44..ecc65e4 100644 --- a/libgo/go/net/http/sniff.go +++ b/libgo/go/net/http/sniff.go @@ -107,8 +107,8 @@ var sniffSignatures = []sniffSig{ ct: "audio/basic", }, &maskedSig{ - mask: []byte("OggS\x00"), - pat: []byte("\x4F\x67\x67\x53\x00"), + mask: []byte("\xFF\xFF\xFF\xFF\xFF"), + pat: []byte("OggS\x00"), ct: "application/ogg", }, &maskedSig{ diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index 38f3f81..24f1298 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -45,7 +45,11 @@ var sniffTests = []struct { {"WAV audio #1", []byte("RIFFb\xb8\x00\x00WAVEfmt \x12\x00\x00\x00\x06"), "audio/wave"}, {"WAV audio #2", []byte("RIFF,\x00\x00\x00WAVEfmt \x12\x00\x00\x00\x06"), "audio/wave"}, {"AIFF audio #1", []byte("FORM\x00\x00\x00\x00AIFFCOMM\x00\x00\x00\x12\x00\x01\x00\x00\x57\x55\x00\x10\x40\x0d\xf3\x34"), "audio/aiff"}, + {"OGG audio", []byte("OggS\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x7e\x46\x00\x00\x00\x00\x00\x00\x1f\xf6\xb4\xfc\x01\x1e\x01\x76\x6f\x72"), "application/ogg"}, + {"Must not match OGG", []byte("owow\x00"), "application/octet-stream"}, + {"Must not match OGG", []byte("oooS\x00"), "application/octet-stream"}, + {"Must not match OGG", []byte("oggS\x00"), "application/octet-stream"}, // Video types. {"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 4f47637..8faff2d 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -51,6 +51,19 @@ func (br *byteReader) Read(p []byte) (n int, err error) { return 1, io.EOF } +// transferBodyReader is an io.Reader that reads from tw.Body +// and records any non-EOF error in tw.bodyReadError. +// It is exactly 1 pointer wide to avoid allocations into interfaces. +type transferBodyReader struct{ tw *transferWriter } + +func (br transferBodyReader) Read(p []byte) (n int, err error) { + n, err = br.tw.Body.Read(p) + if err != nil && err != io.EOF { + br.tw.bodyReadError = err + } + return +} + // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -62,8 +75,10 @@ type transferWriter struct { ContentLength int64 // -1 means unknown, 0 means exactly none Close bool TransferEncoding []string + Header Header Trailer Header IsResponse bool + bodyReadError error // any non-EOF error from reading Body FlushHeaders bool // flush headers to network before body ByteReadCh chan readResult // non-nil if probeRequestBody called @@ -82,14 +97,15 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { t.Method = valueOrDefault(rr.Method, "GET") t.Close = rr.Close t.TransferEncoding = rr.TransferEncoding + t.Header = rr.Header t.Trailer = rr.Trailer - atLeastHTTP11 = rr.protoAtLeastOutgoing(1, 1) t.Body = rr.Body t.BodyCloser = rr.Body t.ContentLength = rr.outgoingLength() - if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && atLeastHTTP11 && t.shouldSendChunkedRequestBody() { + if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() { t.TransferEncoding = []string{"chunked"} } + atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0 case *Response: t.IsResponse = true if rr.Request != nil { @@ -100,6 +116,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { t.ContentLength = rr.ContentLength t.Close = rr.Close t.TransferEncoding = rr.TransferEncoding + t.Header = rr.Header t.Trailer = rr.Trailer atLeastHTTP11 = rr.ProtoAtLeast(1, 1) t.ResponseToHEAD = noResponseBodyExpected(t.Method) @@ -252,7 +269,7 @@ func (t *transferWriter) shouldSendContentLength() bool { } func (t *transferWriter) WriteHeader(w io.Writer) error { - if t.Close { + if t.Close && !hasToken(t.Header.get("Connection"), "close") { if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil { return err } @@ -304,24 +321,25 @@ func (t *transferWriter) WriteBody(w io.Writer) error { // Write body if t.Body != nil { + var body = transferBodyReader{t} if chunked(t.TransferEncoding) { if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { w = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(w) - _, err = io.Copy(cw, t.Body) + _, err = io.Copy(cw, body) if err == nil { err = cw.Close() } } else if t.ContentLength == -1 { - ncopy, err = io.Copy(w, t.Body) + ncopy, err = io.Copy(w, body) } else { - ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength)) + ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength)) if err != nil { return err } var nextra int64 - nextra, err = io.Copy(ioutil.Discard, t.Body) + nextra, err = io.Copy(ioutil.Discard, body) ncopy += nextra } if err != nil { diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 571943d6..6a89392 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -29,6 +29,7 @@ import ( "time" "golang_org/x/net/lex/httplex" + "golang_org/x/net/proxy" ) // DefaultTransport is the default implementation of Transport and is @@ -88,6 +89,11 @@ type Transport struct { // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. + // + // The proxy type is determined by the URL scheme. "http" + // and "socks5" are supported. If the scheme is empty, + // "http" is assumed. + // // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*Request) (*url.URL, error) @@ -275,13 +281,17 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) { return nil, nil } proxyURL, err := url.Parse(proxy) - if err != nil || !strings.HasPrefix(proxyURL.Scheme, "http") { + if err != nil || + (proxyURL.Scheme != "http" && + proxyURL.Scheme != "https" && + proxyURL.Scheme != "socks5") { // proxy was bogus. Try prepending "http://" to it and // see if that parses correctly. If not, we fall // through and complain about the original one. if proxyURL, err := url.Parse("http://" + proxy); err == nil { return proxyURL, nil } + } if err != nil { return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err) @@ -298,11 +308,15 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { } // transportRequest is a wrapper around a *Request that adds -// optional extra headers to write. +// optional extra headers to write and stores any error to return +// from roundTrip. type transportRequest struct { *Request // original request, not to be mutated extra Header // extra headers to write, or nil trace *httptrace.ClientTrace // optional + + mu sync.Mutex // guards err + err error // first setError value for mapRoundTripError to consider } func (tr *transportRequest) extraHeaders() Header { @@ -312,6 +326,14 @@ func (tr *transportRequest) extraHeaders() Header { return tr.extra } +func (tr *transportRequest) setError(err error) { + tr.mu.Lock() + if tr.err == nil { + tr.err = err + } + tr.mu.Unlock() +} + // RoundTrip implements the RoundTripper interface. // // For higher-level HTTP client support (such as handling of cookies @@ -402,6 +424,18 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return nil, err } testHookRoundTripRetried() + + // Rewind the body if we're able to. (HTTP/2 does this itself so we only + // need to do it for HTTP/1.1 connections.) + if req.GetBody != nil && pconn.alt == nil { + newReq := *req + var err error + newReq.Body, err = req.GetBody() + if err != nil { + return nil, err + } + req = &newReq + } } } @@ -433,8 +467,9 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { return false } if _, ok := err.(nothingWrittenError); ok { - // We never wrote anything, so it's safe to retry. - return true + // We never wrote anything, so it's safe to retry, if there's no body or we + // can "rewind" the body with GetBody. + return req.outgoingLength() == 0 || req.GetBody != nil } if !req.isReplayable() { // Don't retry non-idempotent requests. @@ -788,7 +823,7 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { } t.idleLRU.remove(pconn) key := pconn.cacheKey - pconns, _ := t.idleConn[key] + pconns := t.idleConn[key] switch len(pconns) { case 0: // Nothing @@ -964,6 +999,23 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC } } +type oneConnDialer <-chan net.Conn + +func newOneConnDialer(c net.Conn) proxy.Dialer { + ch := make(chan net.Conn, 1) + ch <- c + return oneConnDialer(ch) +} + +func (d oneConnDialer) Dial(network, addr string) (net.Conn, error) { + select { + case c := <-d: + return c, nil + default: + return nil, io.EOF + } +} + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) { pconn := &persistConn{ t: t, @@ -1020,6 +1072,23 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon switch { case cm.proxyURL == nil: // Do nothing. Not using a proxy. + case cm.proxyURL.Scheme == "socks5": + conn := pconn.conn + var auth *proxy.Auth + if u := cm.proxyURL.User; u != nil { + auth = &proxy.Auth{} + auth.User = u.Username() + auth.Password, _ = u.Password() + } + p, err := proxy.SOCKS5("", cm.addr(), auth, newOneConnDialer(conn)) + if err != nil { + conn.Close() + return nil, err + } + if _, err := p.Dial("tcp", cm.targetAddr); err != nil { + conn.Close() + return nil, err + } case cm.targetScheme == "http": pconn.isProxy = true if pa := cm.proxyAuth(); pa != "" { @@ -1176,6 +1245,10 @@ func useProxy(addr string) bool { if addr == p { return false } + if len(p) == 0 { + // There is no host part, likely the entry is malformed; ignore. + continue + } if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) { // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com" return false @@ -1193,19 +1266,21 @@ func useProxy(addr string) bool { // // A connect method may be of the following types: // -// Cache key form Description -// ----------------- ------------------------- -// |http|foo.com http directly to server, no proxy -// |https|foo.com https directly to server, no proxy -// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com -// http://proxy.com|http http to proxy, http to anywhere after that +// Cache key form Description +// ----------------- ------------------------- +// |http|foo.com http directly to server, no proxy +// |https|foo.com https directly to server, no proxy +// http://proxy.com|https|foo.com http to proxy, then CONNECT to foo.com +// http://proxy.com|http http to proxy, http to anywhere after that +// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com +// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com // // Note: no support to https to the proxy yet. // type connectMethod struct { proxyURL *url.URL // nil for no proxy, else full proxy URL targetScheme string // "http" or "https" - targetAddr string // Not used if proxy + http targetScheme (4th example in table) + targetAddr string // Not used if http proxy + http targetScheme (4th example in table) } func (cm *connectMethod) key() connectMethodKey { @@ -1213,7 +1288,7 @@ func (cm *connectMethod) key() connectMethodKey { targetAddr := cm.targetAddr if cm.proxyURL != nil { proxyStr = cm.proxyURL.String() - if cm.targetScheme == "http" { + if strings.HasPrefix(cm.proxyURL.Scheme, "http") && cm.targetScheme == "http" { targetAddr = "" } } @@ -1379,63 +1454,53 @@ func (pc *persistConn) closeConnIfStillIdle() { pc.close(errIdleConnTimeout) } -// mapRoundTripErrorFromReadLoop maps the provided readLoop error into -// the error value that should be returned from persistConn.roundTrip. +// mapRoundTripError returns the appropriate error value for +// persistConn.roundTrip. +// +// The provided err is the first error that (*persistConn).roundTrip +// happened to receive from its select statement. // // The startBytesWritten value should be the value of pc.nwrite before the roundTrip // started writing the request. -func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) { +func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error { if err == nil { return nil } - if err := pc.canceled(); err != nil { - return err + + // If the request was canceled, that's better than network + // failures that were likely the result of tearing down the + // connection. + if cerr := pc.canceled(); cerr != nil { + return cerr + } + + // See if an error was set explicitly. + req.mu.Lock() + reqErr := req.err + req.mu.Unlock() + if reqErr != nil { + return reqErr } + if err == errServerClosedIdle { + // Don't decorate return err } + if _, ok := err.(transportReadFromServerError); ok { + // Don't decorate return err } if pc.isBroken() { <-pc.writeLoopDone - if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { + if pc.nwrite == startBytesWritten { return nothingWrittenError{err} } + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) } return err } -// mapRoundTripErrorAfterClosed returns the error value to be propagated -// up to Transport.RoundTrip method when persistConn.roundTrip sees -// its pc.closech channel close, indicating the persistConn is dead. -// (after closech is closed, pc.closed is valid). -func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error { - if err := pc.canceled(); err != nil { - return err - } - err := pc.closed - if err == errServerClosedIdle { - // Don't decorate - return err - } - if _, ok := err.(transportReadFromServerError); ok { - // Don't decorate - return err - } - - // Wait for the writeLoop goroutine to terminated, and then - // see if we actually managed to write anything. If not, we - // can retry the request. - <-pc.writeLoopDone - if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { - return nothingWrittenError{err} - } - - return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) - -} - func (pc *persistConn) readLoop() { closeErr := errReadLoopExiting // default value, if not changed below defer func() { @@ -1497,16 +1562,6 @@ func (pc *persistConn) readLoop() { err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize()) } - // If we won't be able to retry this request later (from the - // roundTrip goroutine), mark it as done now. - // BEFORE the send on rc.ch, as the client might re-use the - // same *Request pointer, and we don't want to set call - // t.setReqCanceler from this persistConn while the Transport - // potentially spins up a different persistConn for the - // caller's subsequent request. - if !pc.shouldRetryRequest(rc.req, err) { - pc.t.setReqCanceler(rc.req, nil) - } select { case rc.ch <- responseAndError{err: err}: case <-rc.callerGone: @@ -1579,7 +1634,7 @@ func (pc *persistConn) readLoop() { } resp.Body = body - if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { + if rc.addedGzip && strings.EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { resp.Body = &gzipReader{body: body} resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -1705,12 +1760,23 @@ func (pc *persistConn) writeLoop() { case wr := <-pc.writech: startBytesWritten := pc.nwrite err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) + if bre, ok := err.(requestBodyReadError); ok { + err = bre.error + // Errors reading from the user's + // Request.Body are high priority. + // Set it here before sending on the + // channels below or calling + // pc.close() which tears town + // connections and causes other + // errors. + wr.req.setError(err) + } if err == nil { err = pc.bw.Flush() } if err != nil { wr.req.Request.closeBody() - if pc.nwrite == startBytesWritten && wr.req.outgoingLength() == 0 { + if pc.nwrite == startBytesWritten { err = nothingWrittenError{err} } } @@ -1872,6 +1938,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err gone := make(chan struct{}) defer close(gone) + defer func() { + if err != nil { + pc.t.setReqCanceler(req.Request, nil) + } + }() + + const debugRoundTrip = false + // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. @@ -1888,38 +1962,50 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err callerGone: gone, } - var re responseAndError var respHeaderTimer <-chan time.Time cancelChan := req.Request.Cancel ctxDoneChan := req.Context().Done() -WaitResponse: for { testHookWaitResLoop() select { case err := <-writeErrCh: + if debugRoundTrip { + req.logf("writeErrCh resv: %T/%#v", err, err) + } if err != nil { - if cerr := pc.canceled(); cerr != nil { - err = cerr - } - re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) - break WaitResponse + return nil, pc.mapRoundTripError(req, startBytesWritten, err) } if d := pc.t.ResponseHeaderTimeout; d > 0 { + if debugRoundTrip { + req.logf("starting timer for %v", d) + } timer := time.NewTimer(d) defer timer.Stop() // prevent leaks respHeaderTimer = timer.C } case <-pc.closech: - re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)} - break WaitResponse + if debugRoundTrip { + req.logf("closech recv: %T %#v", pc.closed, pc.closed) + } + return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed) case <-respHeaderTimer: + if debugRoundTrip { + req.logf("timeout waiting for response headers.") + } pc.close(errTimeout) - re = responseAndError{err: errTimeout} - break WaitResponse - case re = <-resc: - re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err) - break WaitResponse + return nil, errTimeout + case re := <-resc: + if (re.res == nil) == (re.err == nil) { + panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil)) + } + if debugRoundTrip { + req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err) + } + if re.err != nil { + return nil, pc.mapRoundTripError(req, startBytesWritten, re.err) + } + return re.res, nil case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil @@ -1929,14 +2015,16 @@ WaitResponse: ctxDoneChan = nil } } +} - if re.err != nil { - pc.t.setReqCanceler(req.Request, nil) - } - if (re.res == nil) == (re.err == nil) { - panic("internal error: exactly one of res or err should be set") +// tLogKey is a context WithValue key for test debugging contexts containing +// a t.Logf func. See export_test.go's Request.WithT method. +type tLogKey struct{} + +func (r *transportRequest) logf(format string, args ...interface{}) { + if logf, ok := r.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok { + logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...) } - return re.res, re.err } // markReused marks this connection as having been successfully used for a @@ -1982,8 +2070,9 @@ func (pc *persistConn) closeLocked(err error) { } var portMap = map[string]string{ - "http": "80", - "https": "443", + "http": "80", + "https": "443", + "socks5": "1080", } // canonicalAddr returns url.Host but always with a ":port" suffix diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go index 3d24fc1..594bf6e 100644 --- a/libgo/go/net/http/transport_internal_test.go +++ b/libgo/go/net/http/transport_internal_test.go @@ -9,6 +9,7 @@ package http import ( "errors" "net" + "strings" "testing" ) @@ -30,6 +31,7 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { tr := new(Transport) req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) + req = req.WithT(t) treq := &transportRequest{Request: req} cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} pc, err := tr.getConn(treq, cm) @@ -47,13 +49,13 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { _, err = pc.roundTrip(treq) if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) } <-pc.closech err = pc.closed if !isTransportReadFromServerError(err) && err != errServerClosedIdle { - t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err) } } @@ -80,6 +82,19 @@ func dummyRequest(method string) *Request { } return req } +func dummyRequestWithBody(method string) *Request { + req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo")) + if err != nil { + panic(err) + } + return req +} + +func dummyRequestWithBodyNoGetBody(method string) *Request { + req := dummyRequestWithBody(method) + req.GetBody = nil + return req +} func TestTransportShouldRetryRequest(t *testing.T) { tests := []struct { @@ -131,6 +146,18 @@ func TestTransportShouldRetryRequest(t *testing.T) { err: errServerClosedIdle, want: true, }, + 7: { + pc: &persistConn{reused: true}, + req: dummyRequestWithBody("POST"), + err: nothingWrittenError{}, + want: true, + }, + 8: { + pc: &persistConn{reused: true}, + req: dummyRequestWithBodyNoGetBody("POST"), + err: nothingWrittenError{}, + want: false, + }, } for i, tt := range tests { got := tt.pc.shouldRetryRequest(tt.req, tt.err) diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index a58b183..27b55dc 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -16,6 +16,7 @@ import ( "context" "crypto/rand" "crypto/tls" + "encoding/binary" "errors" "fmt" "internal/nettrace" @@ -130,11 +131,9 @@ func TestTransportKeepAlives(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() + c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { - tr := &Transport{DisableKeepAlives: disableKeepAlive} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c.Transport.(*Transport).DisableKeepAlives = disableKeepAlive fetch := func(n int) string { res, err := c.Get(ts.URL) if err != nil { @@ -165,12 +164,11 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { connSet, testDial := makeTestDial(t) - for _, connectionClose := range []bool{false, true} { - tr := &Transport{ - Dial: testDial, - } - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = testDial + for _, connectionClose := range []bool{false, true} { fetch := func(n int) string { req := new(Request) var err error @@ -216,12 +214,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { connSet, testDial := makeTestDial(t) + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = testDial for _, connectionClose := range []bool{false, true} { - tr := &Transport{ - Dial: testDial, - } - c := &Client{Transport: tr} - fetch := func(n int) string { req := new(Request) var err error @@ -272,10 +268,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() - tr := &Transport{ - DisableKeepAlives: true, - } - c := &Client{Transport: tr} + c := ts.Client() + c.Transport.(*Transport).DisableKeepAlives = true + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -290,9 +285,8 @@ func TestTransportIdleCacheKeys(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() - - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g) @@ -384,9 +378,11 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } })) defer ts.Close() + + c := ts.Client() + tr := c.Transport.(*Transport) maxIdleConnsPerHost := 2 - tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConnsPerHost} - c := &Client{Transport: tr} + tr.MaxIdleConnsPerHost = maxIdleConnsPerHost // Start 3 outstanding requests and wait for the server to get them. // Their responses will hang until we write to resch, though. @@ -449,9 +445,8 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) doReq := func(name string) string { // Do a POST instead of a GET to prevent the Transport's @@ -495,9 +490,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(hostPortHandler) defer ts.Close() - - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() fetch := func(n, retries int) string { condFatalf := func(format string, arg ...interface{}) { @@ -563,10 +556,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { conn.Close() })) defer ts.Close() - - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} - defer tr.CloseIdleConnections() + c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc // after each request completes, regardless of whether it failed. @@ -619,9 +609,8 @@ func TestTransportHeadResponses(t *testing.T) { w.WriteHeader(200) })) defer ts.Close() + c := ts.Client() - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} for i := 0; i < 2; i++ { res, err := c.Head(ts.URL) if err != nil { @@ -655,10 +644,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) { w.WriteHeader(200) })) defer ts.Close() - - tr := &Transport{DisableKeepAlives: false} - c := &Client{Transport: tr} - defer tr.CloseIdleConnections() + c := ts.Client() // Ensure that we wait for the readLoop to complete before // calling Head again @@ -719,6 +705,7 @@ func TestRoundTripGzip(t *testing.T) { } })) defer ts.Close() + tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { // Test basic request (no accept-encoding) @@ -726,7 +713,7 @@ func TestRoundTripGzip(t *testing.T) { if test.accept != "" { req.Header.Set("Accept-Encoding", test.accept) } - res, err := DefaultTransport.RoundTrip(req) + res, err := tr.RoundTrip(req) var body []byte if test.compressed { var r *gzip.Reader @@ -791,10 +778,9 @@ func TestTransportGzip(t *testing.T) { gz.Close() })) defer ts.Close() + c := ts.Client() for _, chunked := range []string{"1", "0"} { - c := &Client{Transport: &Transport{}} - // First fetch something large, but only read some of it. res, err := c.Get(ts.URL + "/?body=large&chunked=" + chunked) if err != nil { @@ -844,7 +830,6 @@ func TestTransportGzip(t *testing.T) { } // And a HEAD request too, because they're always weird. - c := &Client{Transport: &Transport{}} res, err := c.Head(ts.URL) if err != nil { t.Fatalf("Head: %v", err) @@ -914,11 +899,13 @@ func TestTransportExpect100Continue(t *testing.T) { {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. } + c := ts.Client() for i, v := range tests { - tr := &Transport{ExpectContinueTimeout: 2 * time.Second} + tr := &Transport{ + ExpectContinueTimeout: 2 * time.Second, + } defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - + c.Transport = tr body := bytes.NewReader(v.body) req, err := NewRequest("PUT", ts.URL+v.path, body) if err != nil { @@ -943,6 +930,99 @@ func TestTransportExpect100Continue(t *testing.T) { } } +func TestSocks5Proxy(t *testing.T) { + defer afterTest(t) + ch := make(chan string, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "real server" + })) + defer ts.Close() + l := newLocalListener(t) + defer l.Close() + go func() { + defer close(ch) + s, err := l.Accept() + if err != nil { + t.Errorf("socks5 proxy Accept(): %v", err) + return + } + defer s.Close() + var buf [22]byte + if _, err := io.ReadFull(s, buf[:3]); err != nil { + t.Errorf("socks5 proxy initial read: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { + t.Errorf("socks5 proxy initial read: got %v, want %v", buf[:3], want) + return + } + if _, err := s.Write([]byte{5, 0}); err != nil { + t.Errorf("socks5 proxy initial write: %v", err) + return + } + if _, err := io.ReadFull(s, buf[:4]); err != nil { + t.Errorf("socks5 proxy second read: %v", err) + return + } + if want := []byte{5, 1, 0}; !bytes.Equal(buf[:3], want) { + t.Errorf("socks5 proxy second read: got %v, want %v", buf[:3], want) + return + } + var ipLen int + switch buf[3] { + case 1: + ipLen = 4 + case 4: + ipLen = 16 + default: + t.Fatalf("socks5 proxy second read: unexpected address type %v", buf[4]) + } + if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil { + t.Errorf("socks5 proxy address read: %v", err) + return + } + ip := net.IP(buf[4 : ipLen+4]) + port := binary.BigEndian.Uint16(buf[ipLen+4 : ipLen+6]) + copy(buf[:3], []byte{5, 0, 0}) + if _, err := s.Write(buf[:ipLen+6]); err != nil { + t.Errorf("socks5 proxy connect write: %v", err) + return + } + done := make(chan struct{}) + srv := &Server{Handler: HandlerFunc(func(w ResponseWriter, r *Request) { + done <- struct{}{} + })} + srv.Serve(&oneConnListener{conn: s}) + <-done + srv.Shutdown(context.Background()) + ch <- fmt.Sprintf("proxy for %s:%d", ip, port) + }() + + pu, err := url.Parse("socks5://" + l.Addr().String()) + if err != nil { + t.Fatal(err) + } + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) + if _, err := c.Head(ts.URL); err != nil { + t.Error(err) + } + var got string + select { + case got = <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to socks5 proxy") + } + tsu, err := url.Parse(ts.URL) + if err != nil { + t.Fatal(err) + } + want := "proxy for " + tsu.Host + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + func TestTransportProxy(t *testing.T) { defer afterTest(t) ch := make(chan string, 1) @@ -959,12 +1039,20 @@ func TestTransportProxy(t *testing.T) { if err != nil { t.Fatal(err) } - c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} - c.Head(ts.URL) - got := <-ch + c := ts.Client() + c.Transport.(*Transport).Proxy = ProxyURL(pu) + if _, err := c.Head(ts.URL); err != nil { + t.Error(err) + } + var got string + select { + case got = <-ch: + case <-time.After(5 * time.Second): + t.Fatal("timeout connecting to http proxy") + } want := "proxy for " + ts.URL + "/" if got != want { - t.Errorf("want %q, got %q", want, got) + t.Errorf("got %q, want %q", got, want) } } @@ -1022,9 +1110,7 @@ func TestTransportGzipRecursive(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -1052,9 +1138,7 @@ func TestTransportGzipShort(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -1095,9 +1179,8 @@ func TestTransportPersistConnLeak(t *testing.T) { w.WriteHeader(204) })) defer ts.Close() - - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() @@ -1160,9 +1243,8 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) defer ts.Close() - - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) n0 := runtime.NumGoroutine() body := []byte("Hello") @@ -1194,8 +1276,7 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { // This used to crash; https://golang.org/issue/3266 func TestTransportIdleConnCrash(t *testing.T) { defer afterTest(t) - tr := &Transport{} - c := &Client{Transport: tr} + var tr *Transport unblockCh := make(chan bool, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1203,6 +1284,8 @@ func TestTransportIdleConnCrash(t *testing.T) { tr.CloseIdleConnections() })) defer ts.Close() + c := ts.Client() + tr = c.Transport.(*Transport) didreq := make(chan bool) go func() { @@ -1232,8 +1315,7 @@ func TestIssue3644(t *testing.T) { } })) defer ts.Close() - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -1258,8 +1340,7 @@ func TestIssue3595(t *testing.T) { Error(w, deniedMsg, StatusUnauthorized) })) defer ts.Close() - tr := &Transport{} - c := &Client{Transport: tr} + c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { t.Errorf("Post: %v", err) @@ -1283,8 +1364,8 @@ func TestChunkedNoContent(t *testing.T) { })) defer ts.Close() + c := ts.Client() for _, closeBody := range []bool{true, false} { - c := &Client{Transport: &Transport{}} const n = 4 for i := 1; i <= n; i++ { res, err := c.Get(ts.URL) @@ -1324,10 +1405,7 @@ func TestTransportConcurrency(t *testing.T) { SetPendingDialHooks(func() { wg.Add(1) }, wg.Done) defer SetPendingDialHooks(nil, nil) - tr := &Transport{} - defer tr.CloseIdleConnections() - - c := &Client{Transport: tr} + c := ts.Client() reqs := make(chan string) defer close(reqs) @@ -1369,23 +1447,20 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { io.Copy(w, neverEnding('a')) }) ts := httptest.NewServer(mux) + defer ts.Close() timeout := 100 * time.Millisecond - client := &Client{ - Transport: &Transport{ - Dial: func(n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = NewLoggingConn("client", conn) - } - return conn, nil - }, - DisableKeepAlives: true, - }, + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil } getFailed := false @@ -1397,7 +1472,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { if debug { println("run", i+1, "of", nRuns) } - sres, err := client.Get(ts.URL + "/get") + sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. @@ -1419,7 +1494,6 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { if debug { println("tests complete; waiting for handlers to finish") } - ts.Close() } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { @@ -1437,21 +1511,17 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts := httptest.NewServer(mux) timeout := 100 * time.Millisecond - client := &Client{ - Transport: &Transport{ - Dial: func(n, addr string) (net.Conn, error) { - conn, err := net.Dial(n, addr) - if err != nil { - return nil, err - } - conn.SetDeadline(time.Now().Add(timeout)) - if debug { - conn = NewLoggingConn("client", conn) - } - return conn, nil - }, - DisableKeepAlives: true, - }, + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + conn, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + conn.SetDeadline(time.Now().Add(timeout)) + if debug { + conn = NewLoggingConn("client", conn) + } + return conn, nil } getFailed := false @@ -1463,7 +1533,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { if debug { println("run", i+1, "of", nRuns) } - sres, err := client.Get(ts.URL + "/get") + sres, err := c.Get(ts.URL + "/get") if err != nil { if !getFailed { // Make the timeout longer, once. @@ -1477,7 +1547,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { break } req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body) - _, err = client.Do(req) + _, err = c.Do(req) if err == nil { sres.Body.Close() t.Errorf("Unexpected successful PUT") @@ -1509,11 +1579,8 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { ts := httptest.NewServer(mux) defer ts.Close() - tr := &Transport{ - ResponseHeaderTimeout: 500 * time.Millisecond, - } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond tests := []struct { path string @@ -1525,7 +1592,9 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { {path: "/fast", want: 200}, } for i, tt := range tests { - res, err := c.Get(ts.URL + tt.path) + req, _ := NewRequest("GET", ts.URL+tt.path, nil) + req = req.WithT(t) + res, err := c.Do(req) select { case <-inHandler: case <-time.After(5 * time.Second): @@ -1578,9 +1647,8 @@ func TestTransportCancelRequest(t *testing.T) { defer ts.Close() defer close(unblockc) - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) res, err := c.Do(req) @@ -1688,9 +1756,8 @@ func TestCancelRequestWithChannel(t *testing.T) { defer ts.Close() defer close(unblockc) - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) ch := make(chan struct{}) @@ -1747,9 +1814,7 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { defer ts.Close() defer close(unblockc) - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) if withCtx { @@ -1837,9 +1902,8 @@ func TestTransportCloseResponseBody(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c := ts.Client() + tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) defer tr.CancelRequest(req) @@ -1959,18 +2023,12 @@ func TestTransportSocketLateBinding(t *testing.T) { defer ts.Close() dialGate := make(chan bool, 1) - tr := &Transport{ - Dial: func(n, addr string) (net.Conn, error) { - if <-dialGate { - return net.Dial(n, addr) - } - return nil, errors.New("manually closed") - }, - DisableKeepAlives: false, - } - defer tr.CloseIdleConnections() - c := &Client{ - Transport: tr, + c := ts.Client() + c.Transport.(*Transport).Dial = func(n, addr string) (net.Conn, error) { + if <-dialGate { + return net.Dial(n, addr) + } + return nil, errors.New("manually closed") } dialGate <- true // only allow one dial @@ -2160,6 +2218,7 @@ var proxyFromEnvTests = []proxyFromEnvTest{ {env: "https://cache.corp.example.com", want: "https://cache.corp.example.com"}, {env: "http://127.0.0.1:8080", want: "http://127.0.0.1:8080"}, {env: "https://127.0.0.1:8080", want: "https://127.0.0.1:8080"}, + {env: "socks5://127.0.0.1", want: "socks5://127.0.0.1"}, // Don't use secure for http {req: "http://insecure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://http.proxy.tld"}, @@ -2184,6 +2243,7 @@ var proxyFromEnvTests = []proxyFromEnvTest{ func TestProxyFromEnvironment(t *testing.T) { ResetProxyEnv() + defer ResetProxyEnv() for _, tt := range proxyFromEnvTests { os.Setenv("HTTP_PROXY", tt.env) os.Setenv("HTTPS_PROXY", tt.httpsenv) @@ -2223,14 +2283,11 @@ func TestIdleConnChannelLeak(t *testing.T) { SetReadLoopBeforeNextReadHook(func() { didRead <- true }) defer SetReadLoopBeforeNextReadHook(nil) - tr := &Transport{ - Dial: func(netw, addr string) (net.Conn, error) { - return net.Dial(netw, ts.Listener.Addr().String()) - }, + c := ts.Client() + tr := c.Transport.(*Transport) + tr.Dial = func(netw, addr string) (net.Conn, error) { + return net.Dial(netw, ts.Listener.Addr().String()) } - defer tr.CloseIdleConnections() - - c := &Client{Transport: tr} // First, without keep-alives. for _, disableKeep := range []bool{true, false} { @@ -2273,13 +2330,11 @@ func TestTransportClosesRequestBody(t *testing.T) { })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - cl := &Client{Transport: tr} + c := ts.Client() closes := 0 - res, err := cl.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) + res, err := c.Post(ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) if err != nil { t.Fatal(err) } @@ -2365,20 +2420,16 @@ func TestTLSServerClosesConnection(t *testing.T) { fmt.Fprintf(w, "hello") })) defer ts.Close() - tr := &Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} + + c := ts.Client() + tr := c.Transport.(*Transport) var nSuccess = 0 var errs []error const trials = 20 for i := 0; i < trials; i++ { tr.CloseIdleConnections() - res, err := client.Get(ts.URL + "/keep-alive-then-die") + res, err := c.Get(ts.URL + "/keep-alive-then-die") if err != nil { t.Fatal(err) } @@ -2393,7 +2444,7 @@ func TestTLSServerClosesConnection(t *testing.T) { // Now try again and see if we successfully // pick a new connection. - res, err = client.Get(ts.URL + "/") + res, err = c.Get(ts.URL + "/") if err != nil { errs = append(errs, err) continue @@ -2472,22 +2523,20 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { go io.Copy(ioutil.Discard, conn) })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} + c := ts.Client() const bodySize = 256 << 10 finalBit := make(byteFromChanReader, 1) req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit)) req.ContentLength = bodySize - res, err := client.Do(req) + res, err := c.Do(req) if err := wantBody(res, err, "foo"); err != nil { t.Errorf("POST response: %v", err) } donec := make(chan bool) go func() { defer close(donec) - res, err = client.Get(ts.URL) + res, err = c.Get(ts.URL) if err := wantBody(res, err, "bar"); err != nil { t.Errorf("GET response: %v", err) return @@ -2519,10 +2568,9 @@ func TestTransportIssue10457(t *testing.T) { conn.Close() })) defer ts.Close() - tr := &Transport{} - defer tr.CloseIdleConnections() - cl := &Client{Transport: tr} - res, err := cl.Get(ts.URL) + c := ts.Client() + + res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } @@ -2553,89 +2601,160 @@ type writerFuncConn struct { func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } -// Issue 4677. If we try to reuse a connection that the server is in the -// process of closing, we may end up successfully writing out our request (or a -// portion of our request) only to find a connection error when we try to read -// from (or finish writing to) the socket. +// Issues 4677, 18241, and 17844. If we try to reuse a connection that the +// server is in the process of closing, we may end up successfully writing out +// our request (or a portion of our request) only to find a connection error +// when we try to read from (or finish writing to) the socket. // -// NOTE: we resend a request only if the request is idempotent, we reused a -// keep-alive connection, and we haven't yet received any header data. This -// automatically prevents an infinite resend loop because we'll run out of the -// cached keep-alive connections eventually. -func TestRetryIdempotentRequestsOnError(t *testing.T) { - defer afterTest(t) +// NOTE: we resend a request only if: +// - we reused a keep-alive connection +// - we haven't yet received any header data +// - either we wrote no bytes to the server, or the request is idempotent +// This automatically prevents an infinite resend loop because we'll run out of +// the cached keep-alive connections eventually. +func TestRetryRequestsOnError(t *testing.T) { + newRequest := func(method, urlStr string, body io.Reader) *Request { + req, err := NewRequest(method, urlStr, body) + if err != nil { + t.Fatal(err) + } + return req + } - var ( - mu sync.Mutex - logbuf bytes.Buffer - ) - logf := func(format string, args ...interface{}) { - mu.Lock() - defer mu.Unlock() - fmt.Fprintf(&logbuf, format, args...) - logbuf.WriteByte('\n') + testCases := []struct { + name string + failureN int + failureErr error + // Note that we can't just re-use the Request object across calls to c.Do + // because we need to rewind Body between calls. (GetBody is only used to + // rewind Body on failure and redirects, not just because it's done.) + req func() *Request + reqString string + }{ + { + name: "IdempotentNoBodySomeWritten", + // Believe that we've written some bytes to the server, so we know we're + // not just in the "retry when no bytes sent" case". + failureN: 1, + // Use the specific error that shouldRetryRequest looks for with idempotent requests. + failureErr: ExportErrServerClosedIdle, + req: func() *Request { + return newRequest("GET", "http://fake.golang", nil) + }, + reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, + }, + { + name: "IdempotentGetBodySomeWritten", + // Believe that we've written some bytes to the server, so we know we're + // not just in the "retry when no bytes sent" case". + failureN: 1, + // Use the specific error that shouldRetryRequest looks for with idempotent requests. + failureErr: ExportErrServerClosedIdle, + req: func() *Request { + return newRequest("GET", "http://fake.golang", strings.NewReader("foo\n")) + }, + reqString: `GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, + }, + { + name: "NothingWrittenNoBody", + // It's key that we return 0 here -- that's what enables Transport to know + // that nothing was written, even though this is a non-idempotent request. + failureN: 0, + failureErr: errors.New("second write fails"), + req: func() *Request { + return newRequest("DELETE", "http://fake.golang", nil) + }, + reqString: `DELETE / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n`, + }, + { + name: "NothingWrittenGetBody", + // It's key that we return 0 here -- that's what enables Transport to know + // that nothing was written, even though this is a non-idempotent request. + failureN: 0, + failureErr: errors.New("second write fails"), + // Note that NewRequest will set up GetBody for strings.Reader, which is + // required for the retry to occur + req: func() *Request { + return newRequest("POST", "http://fake.golang", strings.NewReader("foo\n")) + }, + reqString: `POST / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 4\r\nAccept-Encoding: gzip\r\n\r\nfoo\n`, + }, } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - logf("Handler") - w.Header().Set("X-Status", "ok") - })) - defer ts.Close() + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defer afterTest(t) - var writeNumAtomic int32 - tr := &Transport{ - Dial: func(network, addr string) (net.Conn, error) { - logf("Dial") - c, err := net.Dial(network, ts.Listener.Addr().String()) - if err != nil { - logf("Dial error: %v", err) - return nil, err + var ( + mu sync.Mutex + logbuf bytes.Buffer + ) + logf := func(format string, args ...interface{}) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&logbuf, format, args...) + logbuf.WriteByte('\n') } - return &writerFuncConn{ - Conn: c, - write: func(p []byte) (n int, err error) { - if atomic.AddInt32(&writeNumAtomic, 1) == 2 { - logf("intentional write failure") - return 0, errors.New("second write fails") - } - logf("Write(%q)", p) - return c.Write(p) - }, - }, nil - }, - } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - SetRoundTripRetried(func() { - logf("Retried.") - }) - defer SetRoundTripRetried(nil) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + logf("Handler") + w.Header().Set("X-Status", "ok") + })) + defer ts.Close() + + var writeNumAtomic int32 + c := ts.Client() + c.Transport.(*Transport).Dial = func(network, addr string) (net.Conn, error) { + logf("Dial") + c, err := net.Dial(network, ts.Listener.Addr().String()) + if err != nil { + logf("Dial error: %v", err) + return nil, err + } + return &writerFuncConn{ + Conn: c, + write: func(p []byte) (n int, err error) { + if atomic.AddInt32(&writeNumAtomic, 1) == 2 { + logf("intentional write failure") + return tc.failureN, tc.failureErr + } + logf("Write(%q)", p) + return c.Write(p) + }, + }, nil + } - for i := 0; i < 3; i++ { - res, err := c.Get("http://fake.golang/") - if err != nil { - t.Fatalf("i=%d: Get = %v", i, err) - } - res.Body.Close() - } + SetRoundTripRetried(func() { + logf("Retried.") + }) + defer SetRoundTripRetried(nil) - mu.Lock() - got := logbuf.String() - mu.Unlock() - const want = `Dial -Write("GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n") + for i := 0; i < 3; i++ { + res, err := c.Do(tc.req()) + if err != nil { + t.Fatalf("i=%d: Do = %v", i, err) + } + res.Body.Close() + } + + mu.Lock() + got := logbuf.String() + mu.Unlock() + want := fmt.Sprintf(`Dial +Write("%s") Handler intentional write failure Retried. Dial -Write("GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n") +Write("%s") Handler -Write("GET / HTTP/1.1\r\nHost: fake.golang\r\nUser-Agent: Go-http-client/1.1\r\nAccept-Encoding: gzip\r\n\r\n") +Write("%s") Handler -` - if got != want { - t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want) +`, tc.reqString, tc.reqString, tc.reqString) + if got != want { + t.Errorf("Log of events differs. Got:\n%s\nWant:\n%s", got, want) + } + }) } } @@ -2649,6 +2768,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { readBody <- err })) defer ts.Close() + c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) req, _ := NewRequest("POST", ts.URL, struct { @@ -2664,7 +2784,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { return nil }), }) - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if res != nil { defer res.Body.Close() } @@ -2698,23 +2818,19 @@ func TestTransportDialTLS(t *testing.T) { mu.Unlock() })) defer ts.Close() - tr := &Transport{ - DialTLS: func(netw, addr string) (net.Conn, error) { - mu.Lock() - didDial = true - mu.Unlock() - c, err := tls.Dial(netw, addr, &tls.Config{ - InsecureSkipVerify: true, - }) - if err != nil { - return nil, err - } - return c, c.Handshake() - }, + c := ts.Client() + c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { + mu.Lock() + didDial = true + mu.Unlock() + c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) + if err != nil { + return nil, err + } + return c, c.Handshake() } - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} - res, err := client.Get(ts.URL) + + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -2796,10 +2912,11 @@ func TestTransportRangeAndGzip(t *testing.T) { reqc <- r })) defer ts.Close() + c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) req.Header.Set("Range", "bytes=7-11") - res, err := DefaultClient.Do(req) + res, err := c.Do(req) if err != nil { t.Fatal(err) } @@ -2828,9 +2945,7 @@ func TestTransportResponseCancelRace(t *testing.T) { w.Write(b[:]) })) defer ts.Close() - - tr := &Transport{} - defer tr.CloseIdleConnections() + tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -2859,14 +2974,46 @@ func TestTransportResponseCancelRace(t *testing.T) { res.Body.Close() } +// Test for issue 19248: Content-Encoding's value is case insensitive. +func TestTransportContentEncodingCaseInsensitive(t *testing.T) { + setParallel(t) + defer afterTest(t) + for _, ce := range []string{"gzip", "GZIP"} { + ce := ce + t.Run(ce, func(t *testing.T) { + const encodedString = "Hello Gopher" + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", ce) + gz := gzip.NewWriter(w) + gz.Write([]byte(encodedString)) + gz.Close() + })) + defer ts.Close() + + res, err := ts.Client().Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + body, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + + if string(body) != encodedString { + t.Fatalf("Expected body %q, got: %q\n", encodedString, string(body)) + } + }) + } +} + func TestTransportDialCancelRace(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) defer ts.Close() - - tr := &Transport{} - defer tr.CloseIdleConnections() + tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -2993,6 +3140,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { w.WriteHeader(StatusOK) })) defer ts.Close() + c := ts.Client() fail := 0 count := 100 @@ -3002,10 +3150,7 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) { if err != nil { t.Fatal(err) } - tr := new(Transport) - defer tr.CloseIdleConnections() - client := &Client{Transport: tr} - resp, err := client.Do(req) + resp, err := c.Do(req) if err != nil { fail++ t.Logf("%d = %#v", i, err) @@ -3218,10 +3363,8 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { w.Write(rgz) // arbitrary gzip response })) defer ts.Close() + c := ts.Client() - tr := &Transport{} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} for i := 0; i < 2; i++ { res, err := c.Get(ts.URL) if err != nil { @@ -3250,12 +3393,9 @@ func TestTransportResponseHeaderLength(t *testing.T) { } })) defer ts.Close() + c := ts.Client() + c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 - tr := &Transport{ - MaxResponseHeaderBytes: 512 << 10, - } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} if res, err := c.Get(ts.URL); err != nil { t.Fatal(err) } else { @@ -3426,16 +3566,26 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } } -func TestTransportEventTraceRealDNS(t *testing.T) { - if testing.Short() && testenv.Builder() == "" { - // Skip this test in short mode (the default for - // all.bash), in case the user is using a shady/ISP - // DNS server hijacking queries. - // See issues 16732, 16716. - // Our builders use 8.8.8.8, though, which correctly - // returns NXDOMAIN, so still run this test there. - t.Skip("skipping in short mode") +var ( + isDNSHijackedOnce sync.Once + isDNSHijacked bool +) + +func skipIfDNSHijacked(t *testing.T) { + // Skip this test if the user is using a shady/ISP + // DNS server hijacking queries. + // See issues 16732, 16716. + isDNSHijackedOnce.Do(func() { + addrs, _ := net.LookupHost("dns-should-not-resolve.golang") + isDNSHijacked = len(addrs) != 0 + }) + if isDNSHijacked { + t.Skip("skipping; test requires non-hijacking DNS server") } +} + +func TestTransportEventTraceRealDNS(t *testing.T) { + skipIfDNSHijacked(t) defer afterTest(t) tr := &Transport{} defer tr.CloseIdleConnections() @@ -3506,8 +3656,8 @@ func TestTransportRejectsAlphaPort(t *testing.T) { // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { defer afterTest(t) - s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer s.Close() + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() var mu sync.Mutex var start, done bool @@ -3527,10 +3677,8 @@ func TestTLSHandshakeTrace(t *testing.T) { }, } - tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} - req, err := NewRequest("GET", s.URL, nil) + c := ts.Client() + req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal("Unable to construct test request:", err) } @@ -3557,16 +3705,14 @@ func TestTransportMaxIdleConns(t *testing.T) { // No body for convenience. })) defer ts.Close() - tr := &Transport{ - MaxIdleConns: 4, - } - defer tr.CloseIdleConnections() + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxIdleConns = 4 ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) if err != nil { t.Fatal(err) } - c := &Client{Transport: tr} ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil }) @@ -3862,17 +4008,16 @@ func TestTransportProxyConnectHeader(t *testing.T) { c.Close() })) defer ts.Close() - tr := &Transport{ - ProxyConnectHeader: Header{ - "User-Agent": {"foo"}, - "Other": {"bar"}, - }, - Proxy: func(r *Request) (*url.URL, error) { - return url.Parse(ts.URL) - }, + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) } - defer tr.CloseIdleConnections() - c := &Client{Transport: tr} + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT if err == nil { res.Body.Close() diff --git a/libgo/go/net/interface.go b/libgo/go/net/interface.go index b3297f2..4036a7f 100644 --- a/libgo/go/net/interface.go +++ b/libgo/go/net/interface.go @@ -211,30 +211,30 @@ func (zc *ipv6ZoneCache) update(ift []Interface) { } } -func zoneToString(zone int) string { - if zone == 0 { +func (zc *ipv6ZoneCache) name(index int) string { + if index == 0 { return "" } zoneCache.update(nil) zoneCache.RLock() defer zoneCache.RUnlock() - name, ok := zoneCache.toName[zone] + name, ok := zoneCache.toName[index] if !ok { - name = uitoa(uint(zone)) + name = uitoa(uint(index)) } return name } -func zoneToInt(zone string) int { - if zone == "" { +func (zc *ipv6ZoneCache) index(name string) int { + if name == "" { return 0 } zoneCache.update(nil) zoneCache.RLock() defer zoneCache.RUnlock() - index, ok := zoneCache.toIndex[zone] + index, ok := zoneCache.toIndex[name] if !ok { - index, _, _ = dtoi(zone) + index, _, _ = dtoi(name) } return index } diff --git a/libgo/go/net/interface_linux.go b/libgo/go/net/interface_linux.go index 5e391b2..441ab2f 100644 --- a/libgo/go/net/interface_linux.go +++ b/libgo/go/net/interface_linux.go @@ -162,7 +162,7 @@ loop: if err != nil { return nil, os.NewSyscallError("parsenetlinkrouteattr", err) } - ifa := newAddr(ifi, ifam, attrs) + ifa := newAddr(ifam, attrs) if ifa != nil { ifat = append(ifat, ifa) } @@ -172,7 +172,7 @@ loop: return ifat, nil } -func newAddr(ifi *Interface, ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr { +func newAddr(ifam *syscall.IfAddrmsg, attrs []syscall.NetlinkRouteAttr) Addr { var ipPointToPoint bool // Seems like we need to make sure whether the IP interface // stack consists of IP point-to-point numbered or unnumbered diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go index 38a2ca4..534137a 100644 --- a/libgo/go/net/interface_test.go +++ b/libgo/go/net/interface_test.go @@ -262,13 +262,13 @@ func validateInterfaceMulticastAddrs(ifat []Addr) (*routeStats, error) { func checkUnicastStats(ifStats *ifStats, uniStats *routeStats) error { // Test the existence of connected unicast routes for IPv4. - if supportsIPv4 && ifStats.loop+ifStats.other > 0 && uniStats.ipv4 == 0 { + if supportsIPv4() && ifStats.loop+ifStats.other > 0 && uniStats.ipv4 == 0 { return fmt.Errorf("num IPv4 unicast routes = 0; want >0; summary: %+v, %+v", ifStats, uniStats) } // Test the existence of connected unicast routes for IPv6. // We can assume the existence of ::1/128 when at least one // loopback interface is installed. - if supportsIPv6 && ifStats.loop > 0 && uniStats.ipv6 == 0 { + if supportsIPv6() && ifStats.loop > 0 && uniStats.ipv6 == 0 { return fmt.Errorf("num IPv6 unicast routes = 0; want >0; summary: %+v, %+v", ifStats, uniStats) } return nil @@ -290,7 +290,7 @@ func checkMulticastStats(ifStats *ifStats, uniStats, multiStats *routeStats) err // We can assume the existence of connected multicast // route clones when at least two connected unicast // routes, ::1/128 and other, are installed. - if supportsIPv6 && ifStats.loop > 0 && uniStats.ipv6 > 1 && multiStats.ipv6 == 0 { + if supportsIPv6() && ifStats.loop > 0 && uniStats.ipv6 > 1 && multiStats.ipv6 == 0 { return fmt.Errorf("num IPv6 multicast route clones = 0; want >0; summary: %+v, %+v, %+v", ifStats, uniStats, multiStats) } } diff --git a/libgo/go/net/interface_windows.go b/libgo/go/net/interface_windows.go index 8b976e5..b08d158 100644 --- a/libgo/go/net/interface_windows.go +++ b/libgo/go/net/interface_windows.go @@ -24,10 +24,7 @@ func probeWindowsIPStack() (supportsVistaIP bool) { if err != nil { return true // Windows 10 and above will deprecate this API } - if byte(v) < 6 { // major version of Windows Vista is 6 - return false - } - return true + return byte(v) >= 6 // major version of Windows Vista is 6 } // adapterAddresses returns a list of IP adapter and address diff --git a/libgo/go/net/internal/socktest/sys_cloexec.go b/libgo/go/net/internal/socktest/sys_cloexec.go index 340ff07..d1b8f4f 100644 --- a/libgo/go/net/internal/socktest/sys_cloexec.go +++ b/libgo/go/net/internal/socktest/sys_cloexec.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build freebsd linux +// +build dragonfly freebsd linux package socktest @@ -15,7 +15,7 @@ func (sw *Switch) Accept4(s, flags int) (ns int, sa syscall.Sockaddr, err error) return syscall.Accept4(s, flags) } sw.fmu.RLock() - f, _ := sw.fltab[FilterAccept] + f := sw.fltab[FilterAccept] sw.fmu.RUnlock() af, err := f.apply(so) diff --git a/libgo/go/net/internal/socktest/sys_unix.go b/libgo/go/net/internal/socktest/sys_unix.go index a3d1282..397c524 100644 --- a/libgo/go/net/internal/socktest/sys_unix.go +++ b/libgo/go/net/internal/socktest/sys_unix.go @@ -14,7 +14,7 @@ func (sw *Switch) Socket(family, sotype, proto int) (s int, err error) { so := &Status{Cookie: cookie(family, sotype, proto)} sw.fmu.RLock() - f, _ := sw.fltab[FilterSocket] + f := sw.fltab[FilterSocket] sw.fmu.RUnlock() af, err := f.apply(so) @@ -47,7 +47,7 @@ func (sw *Switch) Close(s int) (err error) { return syscall.Close(s) } sw.fmu.RLock() - f, _ := sw.fltab[FilterClose] + f := sw.fltab[FilterClose] sw.fmu.RUnlock() af, err := f.apply(so) @@ -77,7 +77,7 @@ func (sw *Switch) Connect(s int, sa syscall.Sockaddr) (err error) { return syscall.Connect(s, sa) } sw.fmu.RLock() - f, _ := sw.fltab[FilterConnect] + f := sw.fltab[FilterConnect] sw.fmu.RUnlock() af, err := f.apply(so) @@ -106,7 +106,7 @@ func (sw *Switch) Listen(s, backlog int) (err error) { return syscall.Listen(s, backlog) } sw.fmu.RLock() - f, _ := sw.fltab[FilterListen] + f := sw.fltab[FilterListen] sw.fmu.RUnlock() af, err := f.apply(so) @@ -135,7 +135,7 @@ func (sw *Switch) Accept(s int) (ns int, sa syscall.Sockaddr, err error) { return syscall.Accept(s) } sw.fmu.RLock() - f, _ := sw.fltab[FilterAccept] + f := sw.fltab[FilterAccept] sw.fmu.RUnlock() af, err := f.apply(so) @@ -168,7 +168,7 @@ func (sw *Switch) GetsockoptInt(s, level, opt int) (soerr int, err error) { return syscall.GetsockoptInt(s, level, opt) } sw.fmu.RLock() - f, _ := sw.fltab[FilterGetsockoptInt] + f := sw.fltab[FilterGetsockoptInt] sw.fmu.RUnlock() af, err := f.apply(so) diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index db3364c..6b7ba4c 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -12,6 +12,8 @@ package net +import _ "unsafe" // for go:linkname + // IP address lengths (bytes). const ( IPv4len = 4 @@ -106,7 +108,8 @@ var ( IPv6linklocalallrouters = IP{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x02} ) -// IsUnspecified reports whether ip is an unspecified address. +// IsUnspecified reports whether ip is an unspecified address, either +// the IPv4 address "0.0.0.0" or the IPv6 address "::". func (ip IP) IsUnspecified() bool { return ip.Equal(IPv4zero) || ip.Equal(IPv6unspecified) } @@ -338,7 +341,8 @@ func ipEmptyString(ip IP) string { } // MarshalText implements the encoding.TextMarshaler interface. -// The encoding is the same as returned by String. +// The encoding is the same as returned by String, with one exception: +// When len(ip) is zero, it returns an empty slice. func (ip IP) MarshalText() ([]byte, error) { if len(ip) == 0 { return []byte(""), nil @@ -381,17 +385,9 @@ func (ip IP) Equal(x IP) bool { return false } -func bytesEqual(x, y []byte) bool { - if len(x) != len(y) { - return false - } - for i, b := range x { - if y[i] != b { - return false - } - } - return true -} +// bytes.Equal is implemented in runtime/asm_$goarch.s +//go:linkname bytesEqual bytes.Equal +func bytesEqual(x, y []byte) bool func (ip IP) matchAddrFamily(x IP) bool { return ip.To4() != nil && x.To4() != nil || ip.To16() != nil && ip.To4() == nil && x.To16() != nil && x.To4() == nil @@ -667,7 +663,7 @@ func ParseIP(s string) IP { // It returns the IP address and the network implied by the IP and // prefix length. // For example, ParseCIDR("192.0.2.1/24") returns the IP address -// 198.0.2.1 and the network 198.0.2.0/24. +// 192.0.2.1 and the network 192.0.2.0/24. func ParseCIDR(s string) (IP, *IPNet, error) { i := byteIndex(s, '/') if i < 0 { diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go index 4655163..ad13388 100644 --- a/libgo/go/net/ip_test.go +++ b/libgo/go/net/ip_test.go @@ -6,6 +6,7 @@ package net import ( "bytes" + "math/rand" "reflect" "runtime" "testing" @@ -468,61 +469,77 @@ func TestNetworkNumberAndMask(t *testing.T) { } } -var splitJoinTests = []struct { - host string - port string - join string -}{ - {"www.google.com", "80", "www.google.com:80"}, - {"127.0.0.1", "1234", "127.0.0.1:1234"}, - {"::1", "80", "[::1]:80"}, - {"fe80::1%lo0", "80", "[fe80::1%lo0]:80"}, - {"localhost%lo0", "80", "[localhost%lo0]:80"}, - {"", "0", ":0"}, - - {"google.com", "https%foo", "google.com:https%foo"}, // Go 1.0 behavior - {"127.0.0.1", "", "127.0.0.1:"}, // Go 1.0 behavior - {"www.google.com", "", "www.google.com:"}, // Go 1.0 behavior -} - -var splitFailureTests = []struct { - hostPort string - err string -}{ - {"www.google.com", "missing port in address"}, - {"127.0.0.1", "missing port in address"}, - {"[::1]", "missing port in address"}, - {"[fe80::1%lo0]", "missing port in address"}, - {"[localhost%lo0]", "missing port in address"}, - {"localhost%lo0", "missing port in address"}, - - {"::1", "too many colons in address"}, - {"fe80::1%lo0", "too many colons in address"}, - {"fe80::1%lo0:80", "too many colons in address"}, +func TestSplitHostPort(t *testing.T) { + for _, tt := range []struct { + hostPort string + host string + port string + }{ + // Host name + {"localhost:http", "localhost", "http"}, + {"localhost:80", "localhost", "80"}, + + // Go-specific host name with zone identifier + {"localhost%lo0:http", "localhost%lo0", "http"}, + {"localhost%lo0:80", "localhost%lo0", "80"}, + {"[localhost%lo0]:http", "localhost%lo0", "http"}, // Go 1 behavior + {"[localhost%lo0]:80", "localhost%lo0", "80"}, // Go 1 behavior + + // IP literal + {"127.0.0.1:http", "127.0.0.1", "http"}, + {"127.0.0.1:80", "127.0.0.1", "80"}, + {"[::1]:http", "::1", "http"}, + {"[::1]:80", "::1", "80"}, + + // IP literal with zone identifier + {"[::1%lo0]:http", "::1%lo0", "http"}, + {"[::1%lo0]:80", "::1%lo0", "80"}, + + // Go-specific wildcard for host name + {":http", "", "http"}, // Go 1 behavior + {":80", "", "80"}, // Go 1 behavior + + // Go-specific wildcard for service name or transport port number + {"golang.org:", "golang.org", ""}, // Go 1 behavior + {"127.0.0.1:", "127.0.0.1", ""}, // Go 1 behavior + {"[::1]:", "::1", ""}, // Go 1 behavior + + // Opaque service name + {"golang.org:https%foo", "golang.org", "https%foo"}, // Go 1 behavior + } { + if host, port, err := SplitHostPort(tt.hostPort); host != tt.host || port != tt.port || err != nil { + t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.hostPort, host, port, err, tt.host, tt.port) + } + } - {"localhost%lo0:80", "missing brackets in address"}, + for _, tt := range []struct { + hostPort string + err string + }{ + {"golang.org", "missing port in address"}, + {"127.0.0.1", "missing port in address"}, + {"[::1]", "missing port in address"}, + {"[fe80::1%lo0]", "missing port in address"}, + {"[localhost%lo0]", "missing port in address"}, + {"localhost%lo0", "missing port in address"}, - // Test cases that didn't fail in Go 1.0 + {"::1", "too many colons in address"}, + {"fe80::1%lo0", "too many colons in address"}, + {"fe80::1%lo0:80", "too many colons in address"}, - {"[foo:bar]", "missing port in address"}, - {"[foo:bar]baz", "missing port in address"}, - {"[foo]bar:baz", "missing port in address"}, + // Test cases that didn't fail in Go 1 - {"[foo]:[bar]:baz", "too many colons in address"}, + {"[foo:bar]", "missing port in address"}, + {"[foo:bar]baz", "missing port in address"}, + {"[foo]bar:baz", "missing port in address"}, - {"[foo]:[bar]baz", "unexpected '[' in address"}, - {"foo[bar]:baz", "unexpected '[' in address"}, + {"[foo]:[bar]:baz", "too many colons in address"}, - {"foo]bar:baz", "unexpected ']' in address"}, -} + {"[foo]:[bar]baz", "unexpected '[' in address"}, + {"foo[bar]:baz", "unexpected '[' in address"}, -func TestSplitHostPort(t *testing.T) { - for _, tt := range splitJoinTests { - if host, port, err := SplitHostPort(tt.join); host != tt.host || port != tt.port || err != nil { - t.Errorf("SplitHostPort(%q) = %q, %q, %v; want %q, %q, nil", tt.join, host, port, err, tt.host, tt.port) - } - } - for _, tt := range splitFailureTests { + {"foo]bar:baz", "unexpected ']' in address"}, + } { if host, port, err := SplitHostPort(tt.hostPort); err == nil { t.Errorf("SplitHostPort(%q) should have failed", tt.hostPort) } else { @@ -538,9 +555,43 @@ func TestSplitHostPort(t *testing.T) { } func TestJoinHostPort(t *testing.T) { - for _, tt := range splitJoinTests { - if join := JoinHostPort(tt.host, tt.port); join != tt.join { - t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.host, tt.port, join, tt.join) + for _, tt := range []struct { + host string + port string + hostPort string + }{ + // Host name + {"localhost", "http", "localhost:http"}, + {"localhost", "80", "localhost:80"}, + + // Go-specific host name with zone identifier + {"localhost%lo0", "http", "localhost%lo0:http"}, + {"localhost%lo0", "80", "localhost%lo0:80"}, + + // IP literal + {"127.0.0.1", "http", "127.0.0.1:http"}, + {"127.0.0.1", "80", "127.0.0.1:80"}, + {"::1", "http", "[::1]:http"}, + {"::1", "80", "[::1]:80"}, + + // IP literal with zone identifier + {"::1%lo0", "http", "[::1%lo0]:http"}, + {"::1%lo0", "80", "[::1%lo0]:80"}, + + // Go-specific wildcard for host name + {"", "http", ":http"}, // Go 1 behavior + {"", "80", ":80"}, // Go 1 behavior + + // Go-specific wildcard for service name or transport port number + {"golang.org", "", "golang.org:"}, // Go 1 behavior + {"127.0.0.1", "", "127.0.0.1:"}, // Go 1 behavior + {"::1", "", "[::1]:"}, // Go 1 behavior + + // Opaque service name + {"golang.org", "https%foo", "golang.org:https%foo"}, // Go 1 behavior + } { + if hostPort := JoinHostPort(tt.host, tt.port); hostPort != tt.hostPort { + t.Errorf("JoinHostPort(%q, %q) = %q; want %q", tt.host, tt.port, hostPort, tt.hostPort) } } } @@ -645,3 +696,32 @@ func TestIPAddrScope(t *testing.T) { } } } + +func BenchmarkIPEqual(b *testing.B) { + b.Run("IPv4", func(b *testing.B) { + benchmarkIPEqual(b, IPv4len) + }) + b.Run("IPv6", func(b *testing.B) { + benchmarkIPEqual(b, IPv6len) + }) +} + +func benchmarkIPEqual(b *testing.B, size int) { + ips := make([]IP, 1000) + for i := range ips { + ips[i] = make(IP, size) + rand.Read(ips[i]) + } + // Half of the N are equal. + for i := 0; i < b.N/2; i++ { + x := ips[i%len(ips)] + y := ips[i%len(ips)] + x.Equal(y) + } + // The other half are not equal. + for i := 0; i < b.N/2; i++ { + x := ips[i%len(ips)] + y := ips[(i+1)%len(ips)] + x.Equal(y) + } +} diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index d994fc6..c4b54f0 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -61,30 +61,37 @@ func (a *IPAddr) opAddr() Addr { return a } -// ResolveIPAddr parses addr as an IP address of the form "host" or -// "ipv6-host%zone" and resolves the domain name on the network net, -// which must be "ip", "ip4" or "ip6". +// ResolveIPAddr returns an address of IP end point. // -// Resolving a hostname is not recommended because this returns at most -// one of its IP addresses. -func ResolveIPAddr(net, addr string) (*IPAddr, error) { - if net == "" { // a hint wildcard for Go 1.0 undocumented behavior - net = "ip" +// The network must be an IP network name. +// +// If the host in the address parameter is not a literal IP address, +// ResolveIPAddr resolves the address to an address of IP end point. +// Otherwise, it parses the address as a literal IP address. +// The address parameter can use a host name, but this is not +// recommended, because it will return at most one of the host name's +// IP addresses. +// +// See func Dial for a description of the network and address +// parameters. +func ResolveIPAddr(network, address string) (*IPAddr, error) { + if network == "" { // a hint wildcard for Go 1.0 undocumented behavior + network = "ip" } - afnet, _, err := parseNetwork(context.Background(), net) + afnet, _, err := parseNetwork(context.Background(), network, false) if err != nil { return nil, err } switch afnet { case "ip", "ip4", "ip6": default: - return nil, UnknownNetworkError(net) + return nil, UnknownNetworkError(network) } - addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, addr) + addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, address) if err != nil { return nil, err } - return addrs.first(isIPv4).(*IPAddr), nil + return addrs.forResolve(network, address).(*IPAddr), nil } // IPConn is the implementation of the Conn and PacketConn interfaces @@ -93,13 +100,16 @@ type IPConn struct { conn } -// ReadFromIP reads an IP packet from c, copying the payload into b. -// It returns the number of bytes copied into b and the return address -// that was on the packet. -// -// ReadFromIP can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetDeadline and -// SetReadDeadline. +// SyscallConn returns a raw network connection. +// This implements the syscall.Conn interface. +func (c *IPConn) SyscallConn() (syscall.RawConn, error) { + if !c.ok() { + return nil, syscall.EINVAL + } + return newRawConn(c.fd) +} + +// ReadFromIP acts like ReadFrom but returns an IPAddr. func (c *IPConn) ReadFromIP(b []byte) (int, *IPAddr, error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -126,10 +136,13 @@ func (c *IPConn) ReadFrom(b []byte) (int, Addr, error) { return n, addr, err } -// ReadMsgIP reads a packet from c, copying the payload into b and the -// associated out-of-band data into oob. It returns the number of +// ReadMsgIP reads a message from c, copying the payload into b and +// the associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags -// that were set on the packet and the source address of the packet. +// that were set on the message and the source address of the message. +// +// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be +// used to manipulate IP-level socket options in oob. func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err error) { if !c.ok() { return 0, 0, 0, nil, syscall.EINVAL @@ -141,13 +154,7 @@ func (c *IPConn) ReadMsgIP(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err return } -// WriteToIP writes an IP packet to addr via c, copying the payload -// from b. -// -// WriteToIP can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetDeadline and -// SetWriteDeadline. On packet-oriented connections, write timeouts -// are rare. +// WriteToIP acts like WriteTo but takes an IPAddr. func (c *IPConn) WriteToIP(b []byte, addr *IPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -175,9 +182,12 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (int, error) { return n, err } -// WriteMsgIP writes a packet to addr via c, copying the payload from +// WriteMsgIP writes a message to addr via c, copying the payload from // b and the associated out-of-band data from oob. It returns the // number of payload and out-of-band bytes written. +// +// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be +// used to manipulate IP-level socket options in oob. func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error) { if !c.ok() { return 0, 0, syscall.EINVAL @@ -191,25 +201,32 @@ func (c *IPConn) WriteMsgIP(b, oob []byte, addr *IPAddr) (n, oobn int, err error func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} } -// DialIP connects to the remote address raddr on the network protocol -// netProto, which must be "ip", "ip4", or "ip6" followed by a colon -// and a protocol number or name. -func DialIP(netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - c, err := dialIP(context.Background(), netProto, laddr, raddr) +// DialIP acts like Dial for IP networks. +// +// The network must be an IP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) { + c, err := dialIP(context.Background(), network, laddr, raddr) if err != nil { - return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -// ListenIP listens for incoming IP packets addressed to the local -// address laddr. The returned connection's ReadFrom and WriteTo -// methods can be used to receive and send IP packets with per-packet -// addressing. -func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) { - c, err := listenIP(context.Background(), netProto, laddr) +// ListenIP acts like ListenPacket for IP networks. +// +// The network must be an IP network name; see func Dial for details. +// +// If the IP field of laddr is nil or an unspecified IP address, +// ListenIP listens on all available IP addresses of the local system +// except multicast IP addresses. +func ListenIP(network string, laddr *IPAddr) (*IPConn, error) { + c, err := listenIP(context.Background(), network, laddr) if err != nil { - return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err} } return c, nil } diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go index 16e65dc..d613e6f 100644 --- a/libgo/go/net/iprawsock_posix.go +++ b/libgo/go/net/iprawsock_posix.go @@ -16,7 +16,7 @@ func sockaddrToIP(sa syscall.Sockaddr) Addr { case *syscall.SockaddrInet4: return &IPAddr{IP: sa.Addr[0:]} case *syscall.SockaddrInet6: - return &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))} + return &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))} } return nil } @@ -52,7 +52,7 @@ func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) { addr = &IPAddr{IP: sa.Addr[0:]} n = stripIPv4Header(n, b) case *syscall.SockaddrInet6: - addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))} + addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))} } return n, addr, err } @@ -79,7 +79,7 @@ func (c *IPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *IPAddr, err e case *syscall.SockaddrInet4: addr = &IPAddr{IP: sa.Addr[0:]} case *syscall.SockaddrInet6: - addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneToString(int(sa.ZoneId))} + addr = &IPAddr{IP: sa.Addr[0:], Zone: zoneCache.name(int(sa.ZoneId))} } return } @@ -113,7 +113,7 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error) } func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(ctx, netProto) + network, proto, err := parseNetwork(ctx, netProto, true) if err != nil { return nil, err } @@ -133,7 +133,7 @@ func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn } func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) { - network, proto, err := parseNetwork(ctx, netProto) + network, proto, err := parseNetwork(ctx, netProto, true) if err != nil { return nil, err } diff --git a/libgo/go/net/iprawsock_test.go b/libgo/go/net/iprawsock_test.go index 5d33b26..8972051 100644 --- a/libgo/go/net/iprawsock_test.go +++ b/libgo/go/net/iprawsock_test.go @@ -117,3 +117,75 @@ func TestIPConnRemoteName(t *testing.T) { t.Fatalf("got %#v; want %#v", c.RemoteAddr(), raddr) } } + +func TestDialListenIPArgs(t *testing.T) { + type test struct { + argLists [][2]string + shouldFail bool + } + tests := []test{ + { + argLists: [][2]string{ + {"ip", "127.0.0.1"}, + {"ip:", "127.0.0.1"}, + {"ip::", "127.0.0.1"}, + {"ip", "::1"}, + {"ip:", "::1"}, + {"ip::", "::1"}, + {"ip4", "127.0.0.1"}, + {"ip4:", "127.0.0.1"}, + {"ip4::", "127.0.0.1"}, + {"ip6", "::1"}, + {"ip6:", "::1"}, + {"ip6::", "::1"}, + }, + shouldFail: true, + }, + } + if testableNetwork("ip") { + priv := test{shouldFail: false} + for _, tt := range []struct { + network, address string + args [2]string + }{ + {"ip4:47", "127.0.0.1", [2]string{"ip4:47", "127.0.0.1"}}, + {"ip6:47", "::1", [2]string{"ip6:47", "::1"}}, + } { + c, err := ListenPacket(tt.network, tt.address) + if err != nil { + continue + } + c.Close() + priv.argLists = append(priv.argLists, tt.args) + } + if len(priv.argLists) > 0 { + tests = append(tests, priv) + } + } + + for _, tt := range tests { + for _, args := range tt.argLists { + _, err := Dial(args[0], args[1]) + if tt.shouldFail != (err != nil) { + t.Errorf("Dial(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail) + } + _, err = ListenPacket(args[0], args[1]) + if tt.shouldFail != (err != nil) { + t.Errorf("ListenPacket(%q, %q) = %v; want (err != nil) is %t", args[0], args[1], err, tt.shouldFail) + } + a, err := ResolveIPAddr("ip", args[1]) + if err != nil { + t.Errorf("ResolveIPAddr(\"ip\", %q) = %v", args[1], err) + continue + } + _, err = DialIP(args[0], nil, a) + if tt.shouldFail != (err != nil) { + t.Errorf("DialIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail) + } + _, err = ListenIP(args[0], a) + if tt.shouldFail != (err != nil) { + t.Errorf("ListenIP(%q, %v) = %v; want (err != nil) is %t", args[0], a, err, tt.shouldFail) + } + } + } +} diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go index f1394a7..947bdf3 100644 --- a/libgo/go/net/ipsock.go +++ b/libgo/go/net/ipsock.go @@ -2,12 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Internet protocol family sockets - package net import ( "context" + "sync" ) // BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the @@ -17,25 +16,41 @@ import ( // both address families are to be supported. // See inet6(4) for details. -var ( - // supportsIPv4 reports whether the platform supports IPv4 - // networking functionality. - supportsIPv4 bool +type ipStackCapabilities struct { + sync.Once // guards following + ipv4Enabled bool + ipv6Enabled bool + ipv4MappedIPv6Enabled bool +} - // supportsIPv6 reports whether the platform supports IPv6 - // networking functionality. - supportsIPv6 bool +var ipStackCaps ipStackCapabilities - // supportsIPv4map reports whether the platform supports - // mapping an IPv4 address inside an IPv6 address at transport - // layer protocols. See RFC 4291, RFC 4038 and RFC 3493. - supportsIPv4map bool -) +// supportsIPv4 reports whether the platform supports IPv4 networking +// functionality. +func supportsIPv4() bool { + ipStackCaps.Once.Do(ipStackCaps.probe) + return ipStackCaps.ipv4Enabled +} + +// supportsIPv6 reports whether the platform supports IPv6 networking +// functionality. +func supportsIPv6() bool { + ipStackCaps.Once.Do(ipStackCaps.probe) + return ipStackCaps.ipv6Enabled +} + +// supportsIPv4map reports whether the platform supports mapping an +// IPv4 address inside an IPv6 address at transport layer +// protocols. See RFC 4291, RFC 4038 and RFC 3493. +func supportsIPv4map() bool { + ipStackCaps.Once.Do(ipStackCaps.probe) + return ipStackCaps.ipv4MappedIPv6Enabled +} // An addrList represents a list of network endpoint addresses. type addrList []Addr -// isIPv4 returns true if the Addr contains an IPv4 address. +// isIPv4 reports whether addr contains an IPv4 address. func isIPv4(addr Addr) bool { switch addr := addr.(type) { case *TCPAddr: @@ -48,6 +63,28 @@ func isIPv4(addr Addr) bool { return false } +// isNotIPv4 reports whether addr does not contain an IPv4 address. +func isNotIPv4(addr Addr) bool { return !isIPv4(addr) } + +// forResolve returns the most appropriate address in address for +// a call to ResolveTCPAddr, ResolveUDPAddr, or ResolveIPAddr. +// IPv4 is preferred, unless addr contains an IPv6 literal. +func (addrs addrList) forResolve(network, addr string) Addr { + var want6 bool + switch network { + case "ip": + // IPv6 literal (addr does NOT contain a port) + want6 = count(addr, ':') > 0 + case "tcp", "udp": + // IPv6 literal. (addr contains a port, so look for '[') + want6 = count(addr, '[') > 0 + } + if want6 { + return addrs.first(isNotIPv4) + } + return addrs.first(isIPv4) +} + // first returns the first address which satisfies strategy, or if // none do, then the first address of any kind. func (addrs addrList) first(strategy func(Addr) bool) Addr { @@ -107,10 +144,14 @@ func ipv6only(addr IPAddr) bool { } // SplitHostPort splits a network address of the form "host:port", -// "[host]:port" or "[ipv6-host%zone]:port" into host or -// ipv6-host%zone and port. A literal address or host name for IPv6 -// must be enclosed in square brackets, as in "[::1]:80", -// "[ipv6-host]:http" or "[ipv6-host%zone]:80". +// "host%zone:port", "[host]:port" or "[host%zone]:port" into host or +// host%zone and port. +// +// A literal IPv6 address in hostport must be enclosed in square +// brackets, as in "[::1]:80", "[::1%lo0]:80". +// +// See func Dial for a description of the hostport parameter, and host +// and port results. func SplitHostPort(hostport string) (host, port string, err error) { const ( missingPort = "missing port in address" @@ -154,9 +195,6 @@ func SplitHostPort(hostport string) (host, port string, err error) { if byteIndex(host, ':') >= 0 { return addrErr(hostport, tooManyColons) } - if byteIndex(host, '%') >= 0 { - return addrErr(hostport, "missing brackets in address") - } } if byteIndex(hostport[j:], '[') >= 0 { return addrErr(hostport, "unexpected '[' in address") @@ -181,11 +219,14 @@ func splitHostZone(s string) (host, zone string) { } // JoinHostPort combines host and port into a network address of the -// form "host:port" or, if host contains a colon or a percent sign, -// "[host]:port". +// form "host:port". If host contains a colon, as found in literal +// IPv6 addresses, then JoinHostPort returns "[host]:port". +// +// See func Dial for a description of the host and port parameters. func JoinHostPort(host, port string) string { - // If host has colons or a percent sign, have to bracket it. - if byteIndex(host, ':') >= 0 || byteIndex(host, '%') >= 0 { + // We assume that host is a literal IPv6 address if host has + // colons. + if byteIndex(host, ':') >= 0 { return "[" + host + "]:" + port } return host + ":" + port @@ -240,6 +281,13 @@ func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addr ips = []IPAddr{{IP: ip}} } else if ip, zone := parseIPv6(host, true); ip != nil { ips = []IPAddr{{IP: ip, Zone: zone}} + // Issue 18806: if the machine has halfway configured + // IPv6 such that it can bind on "::" (IPv6unspecified) + // but not connect back to that same address, fall + // back to dialing 0.0.0.0. + if ip.Equal(IPv6unspecified) { + ips = append(ips, IPAddr{IP: IPv4zero}) + } } else { // Try as a DNS name. ips, err = r.LookupIPAddr(ctx, host) diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go index b7fd344..312e4ad 100644 --- a/libgo/go/net/ipsock_plan9.go +++ b/libgo/go/net/ipsock_plan9.go @@ -2,8 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Internet protocol family sockets for Plan 9 - package net import ( @@ -12,12 +10,25 @@ import ( "syscall" ) +// Probe probes IPv4, IPv6 and IPv4-mapped IPv6 communication +// capabilities. +// +// Plan 9 uses IPv6 natively, see ip(3). +func (p *ipStackCapabilities) probe() { + p.ipv4Enabled = probe(netdir+"/iproute", "4i") + p.ipv6Enabled = probe(netdir+"/iproute", "6i") + if p.ipv4Enabled && p.ipv6Enabled { + p.ipv4MappedIPv6Enabled = true + } +} + func probe(filename, query string) bool { var file *file var err error if file, err = open(filename); err != nil { return false } + defer file.close() r := false for line, ok := file.readLine(); ok && !r; line, ok = file.readLine() { @@ -32,27 +43,9 @@ func probe(filename, query string) bool { } } } - file.close() return r } -func probeIPv4Stack() bool { - return probe(netdir+"/iproute", "4i") -} - -// probeIPv6Stack returns two boolean values. If the first boolean -// value is true, kernel supports basic IPv6 functionality. If the -// second boolean value is true, kernel supports IPv6 IPv4-mapping. -func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { - // Plan 9 uses IPv6 natively, see ip(3). - r := probe(netdir+"/iproute", "6i") - v := false - if r { - v = probe(netdir+"/iproute", "4i") - } - return r, v -} - // parsePlan9Addr parses address of the form [ip!]port (e.g. 127.0.0.1!80). func parsePlan9Addr(s string) (ip IP, iport int, err error) { addr := IPv4zero // address contains port only @@ -249,10 +242,10 @@ func (fd *netFD) netFD() (*netFD, error) { func (fd *netFD) acceptPlan9() (nfd *netFD, err error) { defer func() { fixErr(err) }() - if err := fd.readLock(); err != nil { + if err := fd.pfd.ReadLock(); err != nil { return nil, err } - defer fd.readUnlock() + defer fd.pfd.ReadUnlock() listen, err := os.Open(fd.dir + "/listen") if err != nil { return nil, err diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go index 05bf939..4b4363a 100644 --- a/libgo/go/net/ipsock_posix.go +++ b/libgo/go/net/ipsock_posix.go @@ -8,35 +8,29 @@ package net import ( "context" + "internal/poll" "runtime" "syscall" ) -func probeIPv4Stack() bool { +// Probe probes IPv4, IPv6 and IPv4-mapped IPv6 communication +// capabilities which are controlled by the IPV6_V6ONLY socket option +// and kernel configuration. +// +// Should we try to use the IPv4 socket interface if we're only +// dealing with IPv4 sockets? As long as the host system understands +// IPv4-mapped IPv6, it's okay to pass IPv4-mapeed IPv6 addresses to +// the IPv6 interface. That simplifies our code and is most +// general. Unfortunately, we need to run on kernels built without +// IPv6 support too. So probe the kernel to figure it out. +func (p *ipStackCapabilities) probe() { s, err := socketFunc(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) switch err { case syscall.EAFNOSUPPORT, syscall.EPROTONOSUPPORT: - return false case nil: - closeFunc(s) + poll.CloseFunc(s) + p.ipv4Enabled = true } - return true -} - -// Should we try to use the IPv4 socket interface if we're -// only dealing with IPv4 sockets? As long as the host system -// understands IPv6, it's okay to pass IPv4 addresses to the IPv6 -// interface. That simplifies our code and is most general. -// Unfortunately, we need to run on kernels built without IPv6 -// support too. So probe the kernel to figure it out. -// -// probeIPv6Stack probes both basic IPv6 capability and IPv6 IPv4- -// mapping capability which is controlled by IPV6_V6ONLY socket -// option and/or kernel state "net.inet6.ip6.v6only". -// It returns two boolean values. If the first boolean value is -// true, kernel supports basic IPv6 functionality. If the second -// boolean value is true, kernel supports IPv6 IPv4-mapping. -func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { var probes = []struct { laddr TCPAddr value int @@ -46,29 +40,19 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { // IPv4-mapped IPv6 address communication capability {laddr: TCPAddr{IP: IPv4(127, 0, 0, 1)}, value: 0}, } - var supps [2]bool switch runtime.GOOS { case "dragonfly", "openbsd": - // Some released versions of DragonFly BSD pretend to - // accept IPV6_V6ONLY=0 successfully, but the state - // still stays IPV6_V6ONLY=1. Eventually DragonFly BSD - // stops pretending, but the transition period would - // cause unpredictable behavior and we need to avoid - // it. - // - // OpenBSD also doesn't support IPV6_V6ONLY=0 but it - // never pretends to accept IPV6_V6OLY=0. It always - // returns an error and we don't need to probe the - // capability. + // The latest DragonFly BSD and OpenBSD kernels don't + // support IPV6_V6ONLY=0. They always return an error + // and we don't need to probe the capability. probes = probes[:1] } - for i := range probes { s, err := socketFunc(syscall.AF_INET6, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) if err != nil { continue } - defer closeFunc(s) + defer poll.CloseFunc(s) syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_V6ONLY, probes[i].value) sa, err := probes[i].laddr.sockaddr(syscall.AF_INET6) if err != nil { @@ -77,51 +61,55 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) { if err := syscall.Bind(s, sa); err != nil { continue } - supps[i] = true + if i == 0 { + p.ipv6Enabled = true + } else { + p.ipv4MappedIPv6Enabled = true + } } - - return supps[0], supps[1] } -// favoriteAddrFamily returns the appropriate address family to -// the given net, laddr, raddr and mode. At first it figures -// address family out from the net. If mode indicates "listen" -// and laddr is a wildcard, it assumes that the user wants to -// make a passive connection with a wildcard address family, both -// AF_INET and AF_INET6, and a wildcard address like following: +// favoriteAddrFamily returns the appropriate address family for the +// given network, laddr, raddr and mode. +// +// If mode indicates "listen" and laddr is a wildcard, we assume that +// the user wants to make a passive-open connection with a wildcard +// address family, both AF_INET and AF_INET6, and a wildcard address +// like the following: // -// 1. A wild-wild listen, "tcp" + "" -// If the platform supports both IPv6 and IPv6 IPv4-mapping -// capabilities, or does not support IPv4, we assume that -// the user wants to listen on both IPv4 and IPv6 wildcard -// addresses over an AF_INET6 socket with IPV6_V6ONLY=0. -// Otherwise we prefer an IPv4 wildcard address listen over -// an AF_INET socket. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with a wildcard address: If the platform supports +// both IPv6 and IPv4-mapped IPv6 communication capabilities, +// or does not support IPv4, we use a dual stack, AF_INET6 and +// IPV6_V6ONLY=0, wildcard address listen. The dual stack +// wildcard address listen may fall back to an IPv6-only, +// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. +// Otherwise we prefer an IPv4-only, AF_INET, wildcard address +// listen. // -// 2. A wild-ipv4wild listen, "tcp" + "0.0.0.0" -// Same as 1. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv4 wildcard address: same as above. // -// 3. A wild-ipv6wild listen, "tcp" + "[::]" -// Almost same as 1 but we prefer an IPv6 wildcard address -// listen over an AF_INET6 socket with IPV6_V6ONLY=0 when -// the platform supports IPv6 capability but not IPv6 IPv4- -// mapping capability. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv6 wildcard address: same as above. // -// 4. A ipv4-ipv4wild listen, "tcp4" + "" or "0.0.0.0" -// We use an IPv4 (AF_INET) wildcard address listen. +// - A listen for an IPv4 communication domain, "tcp4" or "udp4", +// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, +// wildcard address listen. // -// 5. A ipv6-ipv6wild listen, "tcp6" + "" or "[::]" -// We use an IPv6 (AF_INET6, IPV6_V6ONLY=1) wildcard address -// listen. +// - A listen for an IPv6 communication domain, "tcp6" or "udp6", +// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 +// and IPV6_V6ONLY=1, wildcard address listen. // -// Otherwise guess: if the addresses are IPv4 then returns AF_INET, -// or else returns AF_INET6. It also returns a boolean value what +// Otherwise guess: If the addresses are IPv4 then returns AF_INET, +// or else returns AF_INET6. It also returns a boolean value what // designates IPV6_V6ONLY option. // -// Note that OpenBSD allows neither "net.inet6.ip6.v6only=1" change -// nor IPPROTO_IPV6 level IPV6_V6ONLY socket option setting. -func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) { - switch net[len(net)-1] { +// Note that the latest DragonFly BSD and OpenBSD kernels allow +// neither "net.inet6.ip6.v6only=1" change nor IPPROTO_IPV6 level +// IPV6_V6ONLY socket option setting. +func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (family int, ipv6only bool) { + switch network[len(network)-1] { case '4': return syscall.AF_INET, false case '6': @@ -129,7 +117,7 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family } if mode == "listen" && (laddr == nil || laddr.isWildcard()) { - if supportsIPv4map || !supportsIPv4 { + if supportsIPv4map() || !supportsIPv4() { return syscall.AF_INET6, false } if laddr == nil { @@ -145,7 +133,6 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family return syscall.AF_INET6, false } -// Internet sockets (TCP, UDP, IP) func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) { if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() { raddr = raddr.toLocal(net) @@ -187,7 +174,7 @@ func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, e if ip6 == nil { return nil, &AddrError{Err: "non-IPv6 address", Addr: ip.String()} } - sa := &syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneToInt(zone))} + sa := &syscall.SockaddrInet6{Port: port, ZoneId: uint32(zoneCache.index(zone))} copy(sa.Addr[:], ip6) return sa, nil } diff --git a/libgo/go/net/ipsock_test.go b/libgo/go/net/ipsock_test.go index 1d0f00f..aede354 100644 --- a/libgo/go/net/ipsock_test.go +++ b/libgo/go/net/ipsock_test.go @@ -215,7 +215,7 @@ var addrListTests = []struct { } func TestAddrList(t *testing.T) { - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } diff --git a/libgo/go/net/listen_test.go b/libgo/go/net/listen_test.go index 6037f36..21ad446 100644 --- a/libgo/go/net/listen_test.go +++ b/libgo/go/net/listen_test.go @@ -225,7 +225,7 @@ func TestDualStackTCPListener(t *testing.T) { case "nacl", "plan9": t.Skipf("not supported on %s", runtime.GOOS) } - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -235,7 +235,7 @@ func TestDualStackTCPListener(t *testing.T) { continue } - if !supportsIPv4map && differentWildcardAddr(tt.address1, tt.address2) { + if !supportsIPv4map() && differentWildcardAddr(tt.address1, tt.address2) { tt.xerr = nil } var firstErr, secondErr error @@ -315,7 +315,7 @@ func TestDualStackUDPListener(t *testing.T) { case "nacl", "plan9": t.Skipf("not supported on %s", runtime.GOOS) } - if !supportsIPv4 || !supportsIPv6 { + if !supportsIPv4() || !supportsIPv6() { t.Skip("both IPv4 and IPv6 are required") } @@ -325,7 +325,7 @@ func TestDualStackUDPListener(t *testing.T) { continue } - if !supportsIPv4map && differentWildcardAddr(tt.address1, tt.address2) { + if !supportsIPv4map() && differentWildcardAddr(tt.address1, tt.address2) { tt.xerr = nil } var firstErr, secondErr error @@ -454,7 +454,7 @@ func checkDualStackAddrFamily(fd *netFD) error { // and IPv6 IPv4-mapping capability, we can assume // that the node listens on a wildcard address with an // AF_INET6 socket. - if supportsIPv4map && fd.laddr.(*TCPAddr).isWildcard() { + if supportsIPv4map() && fd.laddr.(*TCPAddr).isWildcard() { if fd.family != syscall.AF_INET6 { return fmt.Errorf("Listen(%s, %v) returns %v; want %v", fd.net, fd.laddr, fd.family, syscall.AF_INET6) } @@ -468,7 +468,7 @@ func checkDualStackAddrFamily(fd *netFD) error { // and IPv6 IPv4-mapping capability, we can assume // that the node listens on a wildcard address with an // AF_INET6 socket. - if supportsIPv4map && fd.laddr.(*UDPAddr).isWildcard() { + if supportsIPv4map() && fd.laddr.(*UDPAddr).isWildcard() { if fd.family != syscall.AF_INET6 { return fmt.Errorf("ListenPacket(%s, %v) returns %v; want %v", fd.net, fd.laddr, fd.family, syscall.AF_INET6) } @@ -535,7 +535,7 @@ func TestIPv4MulticastListener(t *testing.T) { case "solaris": t.Skipf("not supported on solaris, see golang.org/issue/7399") } - if !supportsIPv4 { + if !supportsIPv4() { t.Skip("IPv4 is not supported") } @@ -610,7 +610,7 @@ func TestIPv6MulticastListener(t *testing.T) { case "solaris": t.Skipf("not supported on solaris, see issue 7399") } - if !supportsIPv6 { + if !supportsIPv6() { t.Skip("IPv6 is not supported") } if os.Getuid() != 0 { diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go index cc2013e..c9f3270 100644 --- a/libgo/go/net/lookup.go +++ b/libgo/go/net/lookup.go @@ -28,6 +28,9 @@ var protocols = map[string]int{ // services contains minimal mappings between services names and port // numbers for platforms that don't have a complete list of port numbers // (some Solaris distros, nacl, etc). +// +// See https://www.iana.org/assignments/service-names-port-numbers +// // On Unix, this map is augmented by readServices via goLookupPort. var services = map[string]map[string]int{ "udp": { @@ -63,7 +66,12 @@ func lookupProtocolMap(name string) (int, error) { return proto, nil } -const maxServiceLength = len("mobility-header") + 10 // with room to grow +// maxPortBufSize is the longest reasonable name of a service +// (non-numeric port). +// Currently the longest known IANA-unregistered name is +// "mobility-header", so we use that length, plus some slop in case +// something longer is added in the future. +const maxPortBufSize = len("mobility-header") + 10 func lookupPortMap(network, service string) (port int, error error) { switch network { @@ -74,7 +82,7 @@ func lookupPortMap(network, service string) (port int, error error) { } if m, ok := services[network]; ok { - var lowerService [maxServiceLength]byte + var lowerService [maxPortBufSize]byte n := copy(lowerService[:], service) lowerASCIIBytes(lowerService[:n]) if port, ok := m[string(lowerService[:n])]; ok && n == len(service) { @@ -97,6 +105,29 @@ type Resolver struct { // GODEBUG=netdns=go, but scoped to just this resolver. PreferGo bool + // StrictErrors controls the behavior of temporary errors + // (including timeout, socket errors, and SERVFAIL) when using + // Go's built-in resolver. For a query composed of multiple + // sub-queries (such as an A+AAAA address lookup, or walking the + // DNS search list), this option causes such errors to abort the + // whole query instead of returning a partial result. This is + // not enabled by default because it may affect compatibility + // with resolvers that process AAAA queries incorrectly. + StrictErrors bool + + // Dial optionally specifies an alternate dialer for use by + // Go's built-in DNS resolver to make TCP and UDP connections + // to DNS services. The host in the address parameter will + // always be a literal IP address and not a host name, and the + // port in the address parameter will be a literal port number + // and not a service name. + // If the Conn returned is also a PacketConn, sent and received DNS + // messages must adhere to RFC 1035 section 4.2.1, "UDP usage". + // Otherwise, DNS messages transmitted over Conn must adhere + // to RFC 7766 section 5, "Transport Protocol Selection". + // If nil, the default dialer is used. + Dial func(ctx context.Context, network, address string) (Conn, error) + // TODO(bradfitz): optional interface impl override hook // TODO(bradfitz): Timeout time.Duration? } @@ -164,12 +195,15 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err select { case <-ctx.Done(): - // The DNS lookup timed out for some reason. Force + // If the DNS lookup timed out for some reason, force // future requests to start the DNS lookup again // rather than waiting for the current lookup to // complete. See issue 8602. - err := mapErr(ctx.Err()) - lookupGroup.Forget(host) + ctxErr := ctx.Err() + if ctxErr == context.DeadlineExceeded { + lookupGroup.Forget(host) + } + err := mapErr(ctxErr) if trace != nil && trace.DNSDone != nil { trace.DNSDone(nil, false, err) } diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go index 36db56a..68a7abe 100644 --- a/libgo/go/net/lookup_test.go +++ b/libgo/go/net/lookup_test.go @@ -63,7 +63,7 @@ func TestLookupGoogleSRV(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -99,7 +99,7 @@ func TestLookupGmailMX(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -131,7 +131,7 @@ func TestLookupGmailNS(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -164,7 +164,7 @@ func TestLookupGmailTXT(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -199,7 +199,7 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !supportsIPv6 || !*testIPv4 || !*testIPv6 { + if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 { t.Skip("both IPv4 and IPv6 are required") } @@ -220,7 +220,7 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) { } func TestLookupIPv6LinkLocalAddr(t *testing.T) { - if !supportsIPv6 || !*testIPv6 { + if !supportsIPv6() || !*testIPv6 { t.Skip("IPv6 is required") } @@ -256,7 +256,7 @@ func TestLookupCNAME(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -283,7 +283,7 @@ func TestLookupGoogleHost(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -315,7 +315,7 @@ func TestLookupGoogleIP(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -450,7 +450,7 @@ func TestDNSFlood(t *testing.T) { } func TestLookupDotsWithLocalSource(t *testing.T) { - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } @@ -499,7 +499,7 @@ func TestLookupDotsWithRemoteSource(t *testing.T) { testenv.MustHaveExternalNetwork(t) } - if !supportsIPv4 || !*testIPv4 { + if !supportsIPv4() || !*testIPv4 { t.Skip("IPv4 is required") } diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go index f96c8be..2813f14 100644 --- a/libgo/go/net/lookup_unix.go +++ b/libgo/go/net/lookup_unix.go @@ -16,28 +16,31 @@ var onceReadProtocols sync.Once // readProtocols loads contents of /etc/protocols into protocols map // for quick access. func readProtocols() { - if file, err := open("/etc/protocols"); err == nil { - for line, ok := file.readLine(); ok; line, ok = file.readLine() { - // tcp 6 TCP # transmission control protocol - if i := byteIndex(line, '#'); i >= 0 { - line = line[0:i] - } - f := getFields(line) - if len(f) < 2 { - continue + file, err := open("/etc/protocols") + if err != nil { + return + } + defer file.close() + + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // tcp 6 TCP # transmission control protocol + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + if proto, _, ok := dtoi(f[1]); ok { + if _, ok := protocols[f[0]]; !ok { + protocols[f[0]] = proto } - if proto, _, ok := dtoi(f[1]); ok { - if _, ok := protocols[f[0]]; !ok { - protocols[f[0]] = proto - } - for _, alias := range f[2:] { - if _, ok := protocols[alias]; !ok { - protocols[alias] = proto - } + for _, alias := range f[2:] { + if _, ok := protocols[alias]; !ok { + protocols[alias] = proto } } } - file.close() } } @@ -48,6 +51,29 @@ func lookupProtocol(_ context.Context, name string) (int, error) { return lookupProtocolMap(name) } +func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) { + // Calling Dial here is scary -- we have to be sure not to + // dial a name that will require a DNS lookup, or Dial will + // call back here to translate it. The DNS config parser has + // already checked that all the cfg.servers are IP + // addresses, which Dial will use without a DNS lookup. + var c Conn + var err error + if r.Dial != nil { + c, err = r.Dial(ctx, network, server) + } else { + var d Dialer + c, err = d.DialContext(ctx, network, server) + } + if err != nil { + return nil, mapErr(err) + } + if _, ok := c.(PacketConn); ok { + return &dnsPacketConn{c}, nil + } + return &dnsStreamConn{c}, nil +} + func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { order := systemConf().hostLookupOrder(host) if !r.PreferGo && order == hostLookupCgo { @@ -57,12 +83,12 @@ func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, // cgo not available (or netgo); fall back to Go's DNS resolver order = hostLookupFilesDNS } - return goLookupHostOrder(ctx, host, order) + return r.goLookupHostOrder(ctx, host, order) } func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { if r.PreferGo { - return goLookupIP(ctx, host) + return r.goLookupIP(ctx, host) } order := systemConf().hostLookupOrder(host) if order == hostLookupCgo { @@ -72,7 +98,7 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, e // cgo not available (or netgo); fall back to Go's DNS resolver order = hostLookupFilesDNS } - addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order) + addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order) return } @@ -98,17 +124,17 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) return cname, err } } - return goLookupCNAME(ctx, name) + return r.goLookupCNAME(ctx, name) } -func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { +func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { var target string if service == "" && proto == "" { target = name } else { target = "_" + service + "._" + proto + "." + name } - cname, rrs, err := lookup(ctx, target, dnsTypeSRV) + cname, rrs, err := r.lookup(ctx, target, dnsTypeSRV) if err != nil { return "", nil, err } @@ -121,8 +147,8 @@ func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (st return cname, srvs, nil } -func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { - _, rrs, err := lookup(ctx, name, dnsTypeMX) +func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { + _, rrs, err := r.lookup(ctx, name, dnsTypeMX) if err != nil { return nil, err } @@ -135,8 +161,8 @@ func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { return mxs, nil } -func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { - _, rrs, err := lookup(ctx, name, dnsTypeNS) +func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { + _, rrs, err := r.lookup(ctx, name, dnsTypeNS) if err != nil { return nil, err } @@ -148,7 +174,7 @@ func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { } func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { - _, rrs, err := lookup(ctx, name, dnsTypeTXT) + _, rrs, err := r.lookup(ctx, name, dnsTypeTXT) if err != nil { return nil, err } @@ -165,5 +191,5 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error return ptrs, err } } - return goLookupPTR(ctx, addr) + return r.goLookupPTR(ctx, addr) } diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 5808293..0036d89 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -107,7 +107,7 @@ func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])}) case syscall.AF_INET6: a := (*syscall.RawSockaddrInet6)(addr).Addr - zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id)) + zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id)) addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone}) default: ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}} diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 702b765..45a995e 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -49,7 +49,7 @@ type Message struct { // ReadMessage reads a message from r. // The headers are parsed, and the body of the message will be available -// for reading from r. +// for reading from msg.Body. func ReadMessage(r io.Reader) (msg *Message, err error) { tp := textproto.NewReader(bufio.NewReader(r)) @@ -387,13 +387,15 @@ func (p *addrParser) consumePhrase() (phrase string, err error) { debug.Printf("consumePhrase: [%s]", p.s) // phrase = 1*word var words []string + var isPrevEncoded bool for { // word = atom / quoted-string var word string p.skipSpace() if p.empty() { - return "", errors.New("mail: missing phrase") + break } + isEncoded := false if p.peek() == '"' { // quoted-string word, err = p.consumeQuotedString() @@ -403,7 +405,7 @@ func (p *addrParser) consumePhrase() (phrase string, err error) { // than what RFC 5322 specifies. word, err = p.consumeAtom(true, true) if err == nil { - word, err = p.decodeRFC2047Word(word) + word, isEncoded, err = p.decodeRFC2047Word(word) } } @@ -411,7 +413,12 @@ func (p *addrParser) consumePhrase() (phrase string, err error) { break } debug.Printf("consumePhrase: consumed %q", word) - words = append(words, word) + if isPrevEncoded && isEncoded { + words[len(words)-1] += word + } else { + words = append(words, word) + } + isPrevEncoded = isEncoded } // Ignore any error if we got at least one word. if err != nil && len(words) == 0 { @@ -540,22 +547,23 @@ func (p *addrParser) len() int { return len(p.s) } -func (p *addrParser) decodeRFC2047Word(s string) (string, error) { +func (p *addrParser) decodeRFC2047Word(s string) (word string, isEncoded bool, err error) { if p.dec != nil { - return p.dec.DecodeHeader(s) + word, err = p.dec.Decode(s) + } else { + word, err = rfc2047Decoder.Decode(s) } - dec, err := rfc2047Decoder.Decode(s) if err == nil { - return dec, nil + return word, true, nil } if _, ok := err.(charsetError); ok { - return s, err + return s, true, err } // Ignore invalid RFC 2047 encoded-word errors. - return s, nil + return s, false, nil } var rfc2047Decoder = mime.WordDecoder{ diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go index f0761ab..2106a0b 100644 --- a/libgo/go/net/mail/message_test.go +++ b/libgo/go/net/mail/message_test.go @@ -136,6 +136,7 @@ func TestAddressParsingError(t *testing.T) { 4: {"\"\\" + string([]byte{0x80}) + "\" <escaped-invalid-unicode@example.net>", "invalid utf-8 in quoted-string"}, 5: {"\"\x00\" <null@example.net>", "bad character in quoted-string"}, 6: {"\"\\\x00\" <escaped-null@example.net>", "bad character in quoted-string"}, + 7: {"John Doe", "no angle-addr"}, } for i, tc := range mustErrTestCases { @@ -235,6 +236,16 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // RFC 2047 "Q"-encoded UTF-8 address with multiple encoded-words. + { + `=?utf-8?q?J=C3=B6rg?= =?utf-8?q?Doe?= <joerg@example.com>`, + []*Address{ + { + Name: `JörgDoe`, + Address: "joerg@example.com", + }, + }, + }, // RFC 2047, Section 8. { `=?ISO-8859-1?Q?Andr=E9?= Pirard <PIRARD@vm1.ulg.ac.be>`, diff --git a/libgo/go/net/main_cloexec_test.go b/libgo/go/net/main_cloexec_test.go index 7903819..fa1ed02 100644 --- a/libgo/go/net/main_cloexec_test.go +++ b/libgo/go/net/main_cloexec_test.go @@ -2,10 +2,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build freebsd linux +// +build dragonfly freebsd linux package net +import "internal/poll" + func init() { extraTestHookInstallers = append(extraTestHookInstallers, installAccept4TestHook) extraTestHookUninstallers = append(extraTestHookUninstallers, uninstallAccept4TestHook) @@ -13,13 +15,13 @@ func init() { var ( // Placeholders for saving original socket system calls. - origAccept4 = accept4Func + origAccept4 = poll.Accept4Func ) func installAccept4TestHook() { - accept4Func = sw.Accept4 + poll.Accept4Func = sw.Accept4 } func uninstallAccept4TestHook() { - accept4Func = origAccept4 + poll.Accept4Func = origAccept4 } diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go index 28a8ff6..3e7a85a 100644 --- a/libgo/go/net/main_test.go +++ b/libgo/go/net/main_test.go @@ -70,7 +70,7 @@ var ( ) func setupTestData() { - if supportsIPv4 { + if supportsIPv4() { resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{ {"tcp", "localhost:1", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 1}, nil}, {"tcp4", "localhost:2", &TCPAddr{IP: IPv4(127, 0, 0, 1), Port: 2}, nil}, @@ -85,25 +85,31 @@ func setupTestData() { }...) } - if supportsIPv6 { + if supportsIPv6() { resolveTCPAddrTests = append(resolveTCPAddrTests, resolveTCPAddrTest{"tcp6", "localhost:3", &TCPAddr{IP: IPv6loopback, Port: 3}, nil}) resolveUDPAddrTests = append(resolveUDPAddrTests, resolveUDPAddrTest{"udp6", "localhost:3", &UDPAddr{IP: IPv6loopback, Port: 3}, nil}) resolveIPAddrTests = append(resolveIPAddrTests, resolveIPAddrTest{"ip6", "localhost", &IPAddr{IP: IPv6loopback}, nil}) + + // Issue 20911: don't return IPv4 addresses for + // Resolve*Addr calls of the IPv6 unspecified address. + resolveTCPAddrTests = append(resolveTCPAddrTests, resolveTCPAddrTest{"tcp", "[::]:4", &TCPAddr{IP: IPv6unspecified, Port: 4}, nil}) + resolveUDPAddrTests = append(resolveUDPAddrTests, resolveUDPAddrTest{"udp", "[::]:4", &UDPAddr{IP: IPv6unspecified, Port: 4}, nil}) + resolveIPAddrTests = append(resolveIPAddrTests, resolveIPAddrTest{"ip", "::", &IPAddr{IP: IPv6unspecified}, nil}) } ifi := loopbackInterface() if ifi != nil { index := fmt.Sprintf("%v", ifi.Index) resolveTCPAddrTests = append(resolveTCPAddrTests, []resolveTCPAddrTest{ - {"tcp6", "[fe80::1%" + ifi.Name + "]:1", &TCPAddr{IP: ParseIP("fe80::1"), Port: 1, Zone: zoneToString(ifi.Index)}, nil}, + {"tcp6", "[fe80::1%" + ifi.Name + "]:1", &TCPAddr{IP: ParseIP("fe80::1"), Port: 1, Zone: zoneCache.name(ifi.Index)}, nil}, {"tcp6", "[fe80::1%" + index + "]:2", &TCPAddr{IP: ParseIP("fe80::1"), Port: 2, Zone: index}, nil}, }...) resolveUDPAddrTests = append(resolveUDPAddrTests, []resolveUDPAddrTest{ - {"udp6", "[fe80::1%" + ifi.Name + "]:1", &UDPAddr{IP: ParseIP("fe80::1"), Port: 1, Zone: zoneToString(ifi.Index)}, nil}, + {"udp6", "[fe80::1%" + ifi.Name + "]:1", &UDPAddr{IP: ParseIP("fe80::1"), Port: 1, Zone: zoneCache.name(ifi.Index)}, nil}, {"udp6", "[fe80::1%" + index + "]:2", &UDPAddr{IP: ParseIP("fe80::1"), Port: 2, Zone: index}, nil}, }...) resolveIPAddrTests = append(resolveIPAddrTests, []resolveIPAddrTest{ - {"ip6", "fe80::1%" + ifi.Name, &IPAddr{IP: ParseIP("fe80::1"), Zone: zoneToString(ifi.Index)}, nil}, + {"ip6", "fe80::1%" + ifi.Name, &IPAddr{IP: ParseIP("fe80::1"), Zone: zoneCache.name(ifi.Index)}, nil}, {"ip6", "fe80::1%" + index, &IPAddr{IP: ParseIP("fe80::1"), Zone: index}, nil}, }...) } diff --git a/libgo/go/net/main_unix_test.go b/libgo/go/net/main_unix_test.go index 8c8f944..34a8a10 100644 --- a/libgo/go/net/main_unix_test.go +++ b/libgo/go/net/main_unix_test.go @@ -6,13 +6,15 @@ package net +import "internal/poll" + var ( // Placeholders for saving original socket system calls. origSocket = socketFunc - origClose = closeFunc + origClose = poll.CloseFunc origConnect = connectFunc origListen = listenFunc - origAccept = acceptFunc + origAccept = poll.AcceptFunc origGetsockoptInt = getsockoptIntFunc extraTestHookInstallers []func() @@ -21,10 +23,10 @@ var ( func installTestHooks() { socketFunc = sw.Socket - closeFunc = sw.Close + poll.CloseFunc = sw.Close connectFunc = sw.Connect listenFunc = sw.Listen - acceptFunc = sw.Accept + poll.AcceptFunc = sw.Accept getsockoptIntFunc = sw.GetsockoptInt for _, fn := range extraTestHookInstallers { @@ -34,10 +36,10 @@ func installTestHooks() { func uninstallTestHooks() { socketFunc = origSocket - closeFunc = origClose + poll.CloseFunc = origClose connectFunc = origConnect listenFunc = origListen - acceptFunc = origAccept + poll.AcceptFunc = origAccept getsockoptIntFunc = origGetsockoptInt for _, fn := range extraTestHookUninstallers { @@ -48,6 +50,6 @@ func uninstallTestHooks() { // forceCloseSockets must be called only from TestMain. func forceCloseSockets() { for s := range sw.Sockets() { - closeFunc(s) + poll.CloseFunc(s) } } diff --git a/libgo/go/net/main_windows_test.go b/libgo/go/net/main_windows_test.go index 6ea318c..f38a3a0 100644 --- a/libgo/go/net/main_windows_test.go +++ b/libgo/go/net/main_windows_test.go @@ -4,37 +4,39 @@ package net +import "internal/poll" + var ( // Placeholders for saving original socket system calls. origSocket = socketFunc - origClosesocket = closeFunc + origClosesocket = poll.CloseFunc origConnect = connectFunc - origConnectEx = connectExFunc + origConnectEx = poll.ConnectExFunc origListen = listenFunc - origAccept = acceptFunc + origAccept = poll.AcceptFunc ) func installTestHooks() { socketFunc = sw.Socket - closeFunc = sw.Closesocket + poll.CloseFunc = sw.Closesocket connectFunc = sw.Connect - connectExFunc = sw.ConnectEx + poll.ConnectExFunc = sw.ConnectEx listenFunc = sw.Listen - acceptFunc = sw.AcceptEx + poll.AcceptFunc = sw.AcceptEx } func uninstallTestHooks() { socketFunc = origSocket - closeFunc = origClosesocket + poll.CloseFunc = origClosesocket connectFunc = origConnect - connectExFunc = origConnectEx + poll.ConnectExFunc = origConnectEx listenFunc = origListen - acceptFunc = origAccept + poll.AcceptFunc = origAccept } // forceCloseSockets must be called only from TestMain. func forceCloseSockets() { for s := range sw.Sockets() { - closeFunc(s) + poll.CloseFunc(s) } } diff --git a/libgo/go/net/mockserver_test.go b/libgo/go/net/mockserver_test.go index 766de6a..44581d9 100644 --- a/libgo/go/net/mockserver_test.go +++ b/libgo/go/net/mockserver_test.go @@ -31,20 +31,20 @@ func testUnixAddr() string { func newLocalListener(network string) (Listener, error) { switch network { case "tcp": - if supportsIPv4 { + if supportsIPv4() { if ln, err := Listen("tcp4", "127.0.0.1:0"); err == nil { return ln, nil } } - if supportsIPv6 { + if supportsIPv6() { return Listen("tcp6", "[::1]:0") } case "tcp4": - if supportsIPv4 { + if supportsIPv4() { return Listen("tcp4", "127.0.0.1:0") } case "tcp6": - if supportsIPv6 { + if supportsIPv6() { return Listen("tcp6", "[::1]:0") } case "unix", "unixpacket": @@ -333,18 +333,18 @@ func timeoutTransmitter(c Conn, d, min, max time.Duration, ch chan<- error) { func newLocalPacketListener(network string) (PacketConn, error) { switch network { case "udp": - if supportsIPv4 { + if supportsIPv4() { return ListenPacket("udp4", "127.0.0.1:0") } - if supportsIPv6 { + if supportsIPv6() { return ListenPacket("udp6", "[::1]:0") } case "udp4": - if supportsIPv4 { + if supportsIPv4() { return ListenPacket("udp4", "127.0.0.1:0") } case "udp6": - if supportsIPv6 { + if supportsIPv6() { return ListenPacket("udp6", "[::1]:0") } case "unixgram": diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index a8b5736..91ec048 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -81,6 +81,7 @@ package net import ( "context" "errors" + "internal/poll" "io" "os" "syscall" @@ -95,12 +96,6 @@ var ( netCgo bool // set true in conf_netcgo.go for build tag "netcgo" ) -func init() { - sysInit() - supportsIPv4 = probeIPv4Stack() - supportsIPv6, supportsIPv4map = probeIPv6Stack() -} - // Addr represents a network end point address. // // The two methods Network and String conventionally return strings @@ -234,7 +229,7 @@ func (c *conn) SetDeadline(t time.Time) error { if !c.ok() { return syscall.EINVAL } - if err := c.fd.setDeadline(t); err != nil { + if err := c.fd.pfd.SetDeadline(t); err != nil { return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err} } return nil @@ -245,7 +240,7 @@ func (c *conn) SetReadDeadline(t time.Time) error { if !c.ok() { return syscall.EINVAL } - if err := c.fd.setReadDeadline(t); err != nil { + if err := c.fd.pfd.SetReadDeadline(t); err != nil { return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err} } return nil @@ -256,7 +251,7 @@ func (c *conn) SetWriteDeadline(t time.Time) error { if !c.ok() { return syscall.EINVAL } - if err := c.fd.setWriteDeadline(t); err != nil { + if err := c.fd.pfd.SetWriteDeadline(t); err != nil { return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err} } return nil @@ -391,10 +386,8 @@ var ( errMissingAddress = errors.New("missing address") // For both read and write operations. - errTimeout error = &timeoutError{} - errCanceled = errors.New("operation was canceled") - errClosing = errors.New("use of closed network connection") - ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection") + errCanceled = errors.New("operation was canceled") + ErrWriteToConnected = errors.New("use of WriteTo with pre-connected connection") ) // mapErr maps from the context errors to the historical internal net @@ -407,7 +400,7 @@ func mapErr(err error) error { case context.Canceled: return errCanceled case context.DeadlineExceeded: - return errTimeout + return poll.ErrTimeout default: return err } @@ -502,12 +495,6 @@ func (e *OpError) Temporary() bool { return ok && t.Temporary() } -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - // A ParseError is the error type of literal network address parsers. type ParseError struct { // Type is the type of string that was expected, such as @@ -632,8 +619,6 @@ type buffersWriter interface { writeBuffers(*Buffers) (int64, error) } -var testHookDidWritev = func(wrote int) {} - // Buffers contains zero or more runs of bytes to write. // // On certain machines, for certain types of connections, this is diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go index 9a9a7e5..024505e 100644 --- a/libgo/go/net/net_test.go +++ b/libgo/go/net/net_test.go @@ -54,7 +54,7 @@ func TestCloseRead(t *testing.T) { err = c.CloseRead() } if err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, true); perr != nil { t.Error(perr) } t.Fatal(err) @@ -94,7 +94,7 @@ func TestCloseWrite(t *testing.T) { err = c.CloseWrite() } if err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, true); perr != nil { t.Error(perr) } t.Error(err) @@ -139,7 +139,7 @@ func TestCloseWrite(t *testing.T) { err = c.CloseWrite() } if err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, true); perr != nil { t.Error(perr) } t.Fatal(err) @@ -184,7 +184,7 @@ func TestConnClose(t *testing.T) { defer c.Close() if err := c.Close(); err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Fatal(err) @@ -215,7 +215,7 @@ func TestListenerClose(t *testing.T) { dst := ln.Addr().String() if err := ln.Close(); err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Fatal(err) @@ -269,7 +269,7 @@ func TestPacketConnClose(t *testing.T) { defer c.Close() if err := c.Close(); err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Fatal(err) @@ -292,7 +292,7 @@ func TestListenCloseListen(t *testing.T) { } addr := ln.Addr().String() if err := ln.Close(); err != nil { - if perr := parseCloseError(err); perr != nil { + if perr := parseCloseError(err, false); perr != nil { t.Error(perr) } t.Fatal(err) diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go index 2a14095..5841ca3 100644 --- a/libgo/go/net/platform_test.go +++ b/libgo/go/net/platform_test.go @@ -50,11 +50,11 @@ func testableNetwork(network string) bool { } switch ss[0] { case "tcp4", "udp4", "ip4": - if !supportsIPv4 { + if !supportsIPv4() { return false } case "tcp6", "udp6", "ip6": - if !supportsIPv6 { + if !supportsIPv6() { return false } } @@ -117,25 +117,25 @@ func testableListenArgs(network, address, client string) bool { // Test functionality of IPv4 communication using AF_INET and // IPv6 communication using AF_INET6 sockets. - if !supportsIPv4 && ip.To4() != nil { + if !supportsIPv4() && ip.To4() != nil { return false } - if !supportsIPv6 && ip.To16() != nil && ip.To4() == nil { + if !supportsIPv6() && ip.To16() != nil && ip.To4() == nil { return false } cip := ParseIP(client) if cip != nil { - if !supportsIPv4 && cip.To4() != nil { + if !supportsIPv4() && cip.To4() != nil { return false } - if !supportsIPv6 && cip.To16() != nil && cip.To4() == nil { + if !supportsIPv6() && cip.To16() != nil && cip.To4() == nil { return false } } // Test functionality of IPv4 communication using AF_INET6 // sockets. - if !supportsIPv4map && supportsIPv4 && (network == "tcp" || network == "udp" || network == "ip") && wildcard { + if !supportsIPv4map() && supportsIPv4() && (network == "tcp" || network == "udp" || network == "ip") && wildcard { // At this point, we prefer IPv4 when ip is nil. // See favoriteAddrFamily for further information. if ip.To16() != nil && ip.To4() == nil && cip.To4() != nil { // a pair of IPv6 server and IPv4 client diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go index 3120ba1..8dd1c32 100644 --- a/libgo/go/net/port_unix.go +++ b/libgo/go/net/port_unix.go @@ -17,6 +17,8 @@ func readServices() { if err != nil { return } + defer file.close() + for line, ok := file.readLine(); ok; line, ok = file.readLine() { // "http 80/tcp www www-http # World Wide Web HTTP" if i := byteIndex(line, '#'); i >= 0 { @@ -43,7 +45,6 @@ func readServices() { } } } - file.close() } // goLookupPort is the native Go implementation of LookupPort. diff --git a/libgo/go/net/rawconn.go b/libgo/go/net/rawconn.go new file mode 100644 index 0000000..d67be64 --- /dev/null +++ b/libgo/go/net/rawconn.go @@ -0,0 +1,62 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "runtime" + "syscall" +) + +// BUG(mikio): On Windows, the Read and Write methods of +// syscall.RawConn are not implemented. + +// BUG(mikio): On NaCl and Plan 9, the Control, Read and Write methods +// of syscall.RawConn are not implemented. + +type rawConn struct { + fd *netFD +} + +func (c *rawConn) ok() bool { return c != nil && c.fd != nil } + +func (c *rawConn) Control(f func(uintptr)) error { + if !c.ok() { + return syscall.EINVAL + } + err := c.fd.pfd.RawControl(f) + runtime.KeepAlive(c.fd) + if err != nil { + err = &OpError{Op: "raw-control", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err} + } + return err +} + +func (c *rawConn) Read(f func(uintptr) bool) error { + if !c.ok() { + return syscall.EINVAL + } + err := c.fd.pfd.RawRead(f) + runtime.KeepAlive(c.fd) + if err != nil { + err = &OpError{Op: "raw-read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return err +} + +func (c *rawConn) Write(f func(uintptr) bool) error { + if !c.ok() { + return syscall.EINVAL + } + err := c.fd.pfd.RawWrite(f) + runtime.KeepAlive(c.fd) + if err != nil { + err = &OpError{Op: "raw-write", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return err +} + +func newRawConn(fd *netFD) (*rawConn, error) { + return &rawConn{fd: fd}, nil +} diff --git a/libgo/go/net/rawconn_unix_test.go b/libgo/go/net/rawconn_unix_test.go new file mode 100644 index 0000000..294249b --- /dev/null +++ b/libgo/go/net/rawconn_unix_test.go @@ -0,0 +1,94 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd solaris + +package net + +import ( + "bytes" + "syscall" + "testing" +) + +func TestRawConn(t *testing.T) { + handler := func(ls *localServer, ln Listener) { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + var b [32]byte + n, err := c.Read(b[:]) + if err != nil { + t.Error(err) + return + } + if _, err := c.Write(b[:n]); err != nil { + t.Error(err) + return + } + } + ls, err := newLocalServer("tcp") + if err != nil { + t.Fatal(err) + } + defer ls.teardown() + if err := ls.buildup(handler); err != nil { + t.Fatal(err) + } + + c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + cc, err := c.(*TCPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + var operr error + data := []byte("HELLO-R-U-THERE") + err = cc.Write(func(s uintptr) bool { + _, operr = syscall.Write(int(s), data) + if operr == syscall.EAGAIN { + return false + } + return true + }) + if err != nil || operr != nil { + t.Fatal(err, operr) + } + + var nr int + var b [32]byte + err = cc.Read(func(s uintptr) bool { + nr, operr = syscall.Read(int(s), b[:]) + if operr == syscall.EAGAIN { + return false + } + return true + }) + if err != nil || operr != nil { + t.Fatal(err, operr) + } + if bytes.Compare(b[:nr], data) != 0 { + t.Fatalf("got %#v; want %#v", b[:nr], data) + } + + fn := func(s uintptr) { + operr = syscall.SetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + } + err = cc.Control(fn) + if err != nil || operr != nil { + t.Fatal(err, operr) + } + c.Close() + err = cc.Control(fn) + if err == nil { + t.Fatal("should fail") + } +} diff --git a/libgo/go/net/rawconn_windows_test.go b/libgo/go/net/rawconn_windows_test.go new file mode 100644 index 0000000..5fb6de7 --- /dev/null +++ b/libgo/go/net/rawconn_windows_test.go @@ -0,0 +1,36 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + "testing" +) + +func TestRawConn(t *testing.T) { + c, err := newLocalPacketListener("udp") + if err != nil { + t.Fatal(err) + } + defer c.Close() + cc, err := c.(*UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + var operr error + fn := func(s uintptr) { + operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) + } + err = cc.Control(fn) + if err != nil || operr != nil { + t.Fatal(err, operr) + } + c.Close() + err = cc.Control(fn) + if err == nil { + t.Fatal("should fail") + } +} diff --git a/libgo/go/net/rpc/debug.go b/libgo/go/net/rpc/debug.go index 98b2c1c..a1d799f 100644 --- a/libgo/go/net/rpc/debug.go +++ b/libgo/go/net/rpc/debug.go @@ -71,20 +71,17 @@ type debugHTTP struct { // Runs at /debug/rpc func (server debugHTTP) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Build a sorted version of the data. - var services = make(serviceArray, len(server.serviceMap)) - i := 0 - server.mu.Lock() - for sname, service := range server.serviceMap { - services[i] = debugService{service, sname, make(methodArray, len(service.method))} - j := 0 - for mname, method := range service.method { - services[i].Method[j] = debugMethod{method, mname} - j++ + var services serviceArray + server.serviceMap.Range(func(snamei, svci interface{}) bool { + svc := svci.(*service) + ds := debugService{svc, snamei.(string), make(methodArray, 0, len(svc.method))} + for mname, method := range svc.method { + ds.Method = append(ds.Method, debugMethod{method, mname}) } - sort.Sort(services[i].Method) - i++ - } - server.mu.Unlock() + sort.Sort(ds.Method) + services = append(services, ds) + return true + }) sort.Sort(services) err := debug.Execute(w, services) if err != nil { diff --git a/libgo/go/net/rpc/jsonrpc/all_test.go b/libgo/go/net/rpc/jsonrpc/all_test.go index b811d3c..bbb8eb0 100644 --- a/libgo/go/net/rpc/jsonrpc/all_test.go +++ b/libgo/go/net/rpc/jsonrpc/all_test.go @@ -13,6 +13,7 @@ import ( "io/ioutil" "net" "net/rpc" + "reflect" "strings" "testing" ) @@ -55,8 +56,26 @@ func (t *Arith) Error(args *Args, reply *Reply) error { panic("ERROR") } +type BuiltinTypes struct{} + +func (BuiltinTypes) Map(i int, reply *map[int]int) error { + (*reply)[i] = i + return nil +} + +func (BuiltinTypes) Slice(i int, reply *[]int) error { + *reply = append(*reply, i) + return nil +} + +func (BuiltinTypes) Array(i int, reply *[1]int) error { + (*reply)[0] = i + return nil +} + func init() { rpc.Register(new(Arith)) + rpc.Register(BuiltinTypes{}) } func TestServerNoParams(t *testing.T) { @@ -182,6 +201,45 @@ func TestClient(t *testing.T) { } } +func TestBuiltinTypes(t *testing.T) { + cli, srv := net.Pipe() + go ServeConn(srv) + + client := NewClient(cli) + defer client.Close() + + // Map + arg := 7 + replyMap := map[int]int{} + err := client.Call("BuiltinTypes.Map", arg, &replyMap) + if err != nil { + t.Errorf("Map: expected no error but got string %q", err.Error()) + } + if replyMap[arg] != arg { + t.Errorf("Map: expected %d got %d", arg, replyMap[arg]) + } + + // Slice + replySlice := []int{} + err = client.Call("BuiltinTypes.Slice", arg, &replySlice) + if err != nil { + t.Errorf("Slice: expected no error but got string %q", err.Error()) + } + if e := []int{arg}; !reflect.DeepEqual(replySlice, e) { + t.Errorf("Slice: expected %v got %v", e, replySlice) + } + + // Array + replyArray := [1]int{} + err = client.Call("BuiltinTypes.Array", arg, &replyArray) + if err != nil { + t.Errorf("Array: expected no error but got string %q", err.Error()) + } + if e := [1]int{arg}; !reflect.DeepEqual(replyArray, e) { + t.Errorf("Array: expected %v got %v", e, replyArray) + } +} + func TestMalformedInput(t *testing.T) { cli, srv := net.Pipe() go cli.Write([]byte(`{id:1}`)) // invalid json diff --git a/libgo/go/net/rpc/jsonrpc/client.go b/libgo/go/net/rpc/jsonrpc/client.go index da1b816..e6359be 100644 --- a/libgo/go/net/rpc/jsonrpc/client.go +++ b/libgo/go/net/rpc/jsonrpc/client.go @@ -2,8 +2,9 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package jsonrpc implements a JSON-RPC ClientCodec and ServerCodec +// Package jsonrpc implements a JSON-RPC 1.0 ClientCodec and ServerCodec // for the rpc package. +// For JSON-RPC 2.0 support, see https://godoc.org/?q=json-rpc+2.0 package jsonrpc import ( diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go index 18ea629..29aae7e 100644 --- a/libgo/go/net/rpc/server.go +++ b/libgo/go/net/rpc/server.go @@ -187,8 +187,7 @@ type Response struct { // Server represents an RPC Server. type Server struct { - mu sync.RWMutex // protects the serviceMap - serviceMap map[string]*service + serviceMap sync.Map // map[string]*service reqLock sync.Mutex // protects freeReq freeReq *Request respLock sync.Mutex // protects freeResp @@ -197,7 +196,7 @@ type Server struct { // NewServer returns a new Server. func NewServer() *Server { - return &Server{serviceMap: make(map[string]*service)} + return &Server{} } // DefaultServer is the default instance of *Server. @@ -240,11 +239,6 @@ func (server *Server) RegisterName(name string, rcvr interface{}) error { } func (server *Server) register(rcvr interface{}, name string, useName bool) error { - server.mu.Lock() - defer server.mu.Unlock() - if server.serviceMap == nil { - server.serviceMap = make(map[string]*service) - } s := new(service) s.typ = reflect.TypeOf(rcvr) s.rcvr = reflect.ValueOf(rcvr) @@ -262,9 +256,6 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro log.Print(s) return errors.New(s) } - if _, present := server.serviceMap[sname]; present { - return errors.New("rpc: service already defined: " + sname) - } s.name = sname // Install the methods @@ -283,7 +274,10 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro log.Print(str) return errors.New(str) } - server.serviceMap[s.name] = s + + if _, dup := server.serviceMap.LoadOrStore(sname, s); dup { + return errors.New("rpc: service already defined: " + sname) + } return nil } @@ -571,10 +565,17 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m } replyv = reflect.New(mtype.ReplyType.Elem()) + + switch mtype.ReplyType.Elem().Kind() { + case reflect.Map: + replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem())) + case reflect.Slice: + replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0)) + } return } -func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err error) { +func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) { // Grab the request header. req = server.getRequest() err = codec.ReadRequestHeader(req) @@ -600,14 +601,13 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt methodName := req.ServiceMethod[dot+1:] // Look up the request. - server.mu.RLock() - service = server.serviceMap[serviceName] - server.mu.RUnlock() - if service == nil { + svci, ok := server.serviceMap.Load(serviceName) + if !ok { err = errors.New("rpc: can't find service " + req.ServiceMethod) return } - mtype = service.method[methodName] + svc = svci.(*service) + mtype = svc.method[methodName] if mtype == nil { err = errors.New("rpc: can't find method " + req.ServiceMethod) } diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go index 8369c9d..fb97f82 100644 --- a/libgo/go/net/rpc/server_test.go +++ b/libgo/go/net/rpc/server_test.go @@ -11,6 +11,7 @@ import ( "log" "net" "net/http/httptest" + "reflect" "runtime" "strings" "sync" @@ -85,6 +86,24 @@ type Embed struct { hidden } +type BuiltinTypes struct{} + +func (BuiltinTypes) Map(args *Args, reply *map[int]int) error { + (*reply)[args.A] = args.B + return nil +} + +func (BuiltinTypes) Slice(args *Args, reply *[]int) error { + *reply = append(*reply, args.A, args.B) + return nil +} + +func (BuiltinTypes) Array(args *Args, reply *[2]int) error { + (*reply)[0] = args.A + (*reply)[1] = args.B + return nil +} + func listenTCP() (net.Listener, string) { l, e := net.Listen("tcp", "127.0.0.1:0") // any available address if e != nil { @@ -97,6 +116,7 @@ func startServer() { Register(new(Arith)) Register(new(Embed)) RegisterName("net.rpc.Arith", new(Arith)) + Register(BuiltinTypes{}) var l net.Listener l, serverAddr = listenTCP() @@ -326,6 +346,49 @@ func testHTTPRPC(t *testing.T, path string) { } } +func TestBuiltinTypes(t *testing.T) { + once.Do(startServer) + + client, err := DialHTTP("tcp", httpServerAddr) + if err != nil { + t.Fatal("dialing", err) + } + defer client.Close() + + // Map + args := &Args{7, 8} + replyMap := map[int]int{} + err = client.Call("BuiltinTypes.Map", args, &replyMap) + if err != nil { + t.Errorf("Map: expected no error but got string %q", err.Error()) + } + if replyMap[args.A] != args.B { + t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A]) + } + + // Slice + args = &Args{7, 8} + replySlice := []int{} + err = client.Call("BuiltinTypes.Slice", args, &replySlice) + if err != nil { + t.Errorf("Slice: expected no error but got string %q", err.Error()) + } + if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) { + t.Errorf("Slice: expected %v got %v", e, replySlice) + } + + // Array + args = &Args{7, 8} + replyArray := [2]int{} + err = client.Call("BuiltinTypes.Array", args, &replyArray) + if err != nil { + t.Errorf("Array: expected no error but got string %q", err.Error()) + } + if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) { + t.Errorf("Array: expected %v got %v", e, replyArray) + } +} + // CodecEmulator provides a client-like api and a ServerCodec interface. // Can be used to test ServeRequest. type CodecEmulator struct { @@ -619,13 +682,13 @@ func TestErrorAfterClientClose(t *testing.T) { // Tests the fix to issue 11221. Without the fix, this loops forever or crashes. func TestAcceptExitAfterListenerClose(t *testing.T) { - newServer = NewServer() + newServer := NewServer() newServer.Register(new(Arith)) newServer.RegisterName("net.rpc.Arith", new(Arith)) newServer.RegisterName("newServer.Arith", new(Arith)) var l net.Listener - l, newServerAddr = listenTCP() + l, _ = listenTCP() l.Close() newServer.Accept(l) } diff --git a/libgo/go/net/sendfile_bsd.go b/libgo/go/net/sendfile_bsd.go new file mode 100644 index 0000000..7a2b48c --- /dev/null +++ b/libgo/go/net/sendfile_bsd.go @@ -0,0 +1,67 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build dragonfly freebsd + +package net + +import ( + "internal/poll" + "io" + "os" +) + +// sendFile copies the contents of r to c using the sendfile +// system call to minimize copies. +// +// if handled == true, sendFile returns the number of bytes copied and any +// non-EOF error. +// +// if handled == false, sendFile performed no work. +func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { + // FreeBSD and DragonFly use 0 as the "until EOF" value. + // If you pass in more bytes than the file contains, it will + // loop back to the beginning ad nauseam until it's sent + // exactly the number of bytes told to. As such, we need to + // know exactly how many bytes to send. + var remain int64 = 0 + + lr, ok := r.(*io.LimitedReader) + if ok { + remain, r = lr.N, lr.R + if remain <= 0 { + return 0, nil, true + } + } + f, ok := r.(*os.File) + if !ok { + return 0, nil, false + } + + if remain == 0 { + fi, err := f.Stat() + if err != nil { + return 0, err, false + } + + remain = fi.Size() + } + + // The other quirk with FreeBSD/DragonFly's sendfile + // implementation is that it doesn't use the current position + // of the file -- if you pass it offset 0, it starts from + // offset 0. There's no way to tell it "start from current + // position", so we have to manage that explicitly. + pos, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err, false + } + + written, err = poll.SendFile(&c.pfd, int(f.Fd()), pos, remain) + + if lr != nil { + lr.N = remain - written + } + return written, wrapSyscallError("sendfile", err), written > 0 +} diff --git a/libgo/go/net/sendfile_dragonfly.go b/libgo/go/net/sendfile_dragonfly.go deleted file mode 100644 index d4b825c..0000000 --- a/libgo/go/net/sendfile_dragonfly.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "io" - "os" - "syscall" -) - -// maxSendfileSize is the largest chunk size we ask the kernel to copy -// at a time. -const maxSendfileSize int = 4 << 20 - -// sendFile copies the contents of r to c using the sendfile -// system call to minimize copies. -// -// if handled == true, sendFile returns the number of bytes copied and any -// non-EOF error. -// -// if handled == false, sendFile performed no work. -func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { - // DragonFly uses 0 as the "until EOF" value. If you pass in more bytes than the - // file contains, it will loop back to the beginning ad nauseam until it's sent - // exactly the number of bytes told to. As such, we need to know exactly how many - // bytes to send. - var remain int64 = 0 - - lr, ok := r.(*io.LimitedReader) - if ok { - remain, r = lr.N, lr.R - if remain <= 0 { - return 0, nil, true - } - } - f, ok := r.(*os.File) - if !ok { - return 0, nil, false - } - - if remain == 0 { - fi, err := f.Stat() - if err != nil { - return 0, err, false - } - - remain = fi.Size() - } - - // The other quirk with DragonFly's sendfile implementation is that it doesn't - // use the current position of the file -- if you pass it offset 0, it starts - // from offset 0. There's no way to tell it "start from current position", so - // we have to manage that explicitly. - pos, err := f.Seek(0, io.SeekCurrent) - if err != nil { - return 0, err, false - } - - if err := c.writeLock(); err != nil { - return 0, err, true - } - defer c.writeUnlock() - - dst := c.sysfd - src := int(f.Fd()) - for remain > 0 { - n := maxSendfileSize - if int64(n) > remain { - n = int(remain) - } - pos1 := pos - n, err1 := syscall.Sendfile(dst, src, &pos1, n) - if n > 0 { - pos += int64(n) - written += int64(n) - remain -= int64(n) - } - if n == 0 && err1 == nil { - break - } - if err1 == syscall.EAGAIN { - if err1 = c.pd.waitWrite(); err1 == nil { - continue - } - } - if err1 == syscall.EINTR { - continue - } - if err1 != nil { - // This includes syscall.ENOSYS (no kernel - // support) and syscall.EINVAL (fd types which - // don't implement sendfile) - err = err1 - break - } - } - if lr != nil { - lr.N = remain - } - if err != nil { - err = os.NewSyscallError("sendfile", err) - } - return written, err, written > 0 -} diff --git a/libgo/go/net/sendfile_freebsd.go b/libgo/go/net/sendfile_freebsd.go deleted file mode 100644 index 18cbb27..0000000 --- a/libgo/go/net/sendfile_freebsd.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package net - -import ( - "io" - "os" - "syscall" -) - -// maxSendfileSize is the largest chunk size we ask the kernel to copy -// at a time. -const maxSendfileSize int = 4 << 20 - -// sendFile copies the contents of r to c using the sendfile -// system call to minimize copies. -// -// if handled == true, sendFile returns the number of bytes copied and any -// non-EOF error. -// -// if handled == false, sendFile performed no work. -func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { - // FreeBSD uses 0 as the "until EOF" value. If you pass in more bytes than the - // file contains, it will loop back to the beginning ad nauseam until it's sent - // exactly the number of bytes told to. As such, we need to know exactly how many - // bytes to send. - var remain int64 = 0 - - lr, ok := r.(*io.LimitedReader) - if ok { - remain, r = lr.N, lr.R - if remain <= 0 { - return 0, nil, true - } - } - f, ok := r.(*os.File) - if !ok { - return 0, nil, false - } - - if remain == 0 { - fi, err := f.Stat() - if err != nil { - return 0, err, false - } - - remain = fi.Size() - } - - // The other quirk with FreeBSD's sendfile implementation is that it doesn't - // use the current position of the file -- if you pass it offset 0, it starts - // from offset 0. There's no way to tell it "start from current position", so - // we have to manage that explicitly. - pos, err := f.Seek(0, io.SeekCurrent) - if err != nil { - return 0, err, false - } - - if err := c.writeLock(); err != nil { - return 0, err, true - } - defer c.writeUnlock() - - dst := c.sysfd - src := int(f.Fd()) - for remain > 0 { - n := maxSendfileSize - if int64(n) > remain { - n = int(remain) - } - pos1 := pos - n, err1 := syscall.Sendfile(dst, src, &pos1, n) - if n > 0 { - pos += int64(n) - written += int64(n) - remain -= int64(n) - } - if n == 0 && err1 == nil { - break - } - if err1 == syscall.EAGAIN { - if err1 = c.pd.waitWrite(); err1 == nil { - continue - } - } - if err1 == syscall.EINTR { - continue - } - if err1 != nil { - // This includes syscall.ENOSYS (no kernel - // support) and syscall.EINVAL (fd types which - // don't implement sendfile) - err = err1 - break - } - } - if lr != nil { - lr.N = remain - } - if err != nil { - err = os.NewSyscallError("sendfile", err) - } - return written, err, written > 0 -} diff --git a/libgo/go/net/sendfile_linux.go b/libgo/go/net/sendfile_linux.go index 7e741f9..c537ea6 100644 --- a/libgo/go/net/sendfile_linux.go +++ b/libgo/go/net/sendfile_linux.go @@ -5,15 +5,11 @@ package net import ( + "internal/poll" "io" "os" - "syscall" ) -// maxSendfileSize is the largest chunk size we ask the kernel to copy -// at a time. -const maxSendfileSize int = 4 << 20 - // sendFile copies the contents of r to c using the sendfile // system call to minimize copies. // @@ -36,44 +32,10 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, false } - if err := c.writeLock(); err != nil { - return 0, err, true - } - defer c.writeUnlock() + written, err = poll.SendFile(&c.pfd, int(f.Fd()), remain) - dst := c.sysfd - src := int(f.Fd()) - for remain > 0 { - n := maxSendfileSize - if int64(n) > remain { - n = int(remain) - } - n, err1 := syscall.Sendfile(dst, src, nil, n) - if n > 0 { - written += int64(n) - remain -= int64(n) - } - if n == 0 && err1 == nil { - break - } - if err1 == syscall.EAGAIN { - if err1 = c.pd.waitWrite(); err1 == nil { - continue - } - } - if err1 != nil { - // This includes syscall.ENOSYS (no kernel - // support) and syscall.EINVAL (fd types which - // don't implement sendfile) - err = err1 - break - } - } if lr != nil { - lr.N = remain - } - if err != nil { - err = os.NewSyscallError("sendfile", err) + lr.N = remain - written } - return written, err, written > 0 + return written, wrapSyscallError("sendfile", err), written > 0 } diff --git a/libgo/go/net/sendfile_solaris.go b/libgo/go/net/sendfile_solaris.go index add70c3..63ca9d4 100644 --- a/libgo/go/net/sendfile_solaris.go +++ b/libgo/go/net/sendfile_solaris.go @@ -5,19 +5,11 @@ package net import ( + "internal/poll" "io" "os" - "syscall" ) -// Not strictly needed, but very helpful for debugging, see issue #10221. -//go:cgo_import_dynamic _ _ "libsendfile.so" -//go:cgo_import_dynamic _ _ "libsocket.so" - -// maxSendfileSize is the largest chunk size we ask the kernel to copy -// at a time. -const maxSendfileSize int = 4 << 20 - // sendFile copies the contents of r to c using the sendfile // system call to minimize copies. // @@ -62,56 +54,10 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, err, false } - if err := c.writeLock(); err != nil { - return 0, err, true - } - defer c.writeUnlock() + written, err = poll.SendFile(&c.pfd, int(f.Fd()), pos, remain) - dst := c.sysfd - src := int(f.Fd()) - for remain > 0 { - n := maxSendfileSize - if int64(n) > remain { - n = int(remain) - } - pos1 := pos - n, err1 := syscall.Sendfile(dst, src, &pos1, n) - if err1 == syscall.EAGAIN || err1 == syscall.EINTR { - // partial write may have occurred - if n = int(pos1 - pos); n == 0 { - // nothing more to write - err1 = nil - } - } - if n > 0 { - pos += int64(n) - written += int64(n) - remain -= int64(n) - } - if n == 0 && err1 == nil { - break - } - if err1 == syscall.EAGAIN { - if err1 = c.pd.waitWrite(); err1 == nil { - continue - } - } - if err1 == syscall.EINTR { - continue - } - if err1 != nil { - // This includes syscall.ENOSYS (no kernel - // support) and syscall.EINVAL (fd types which - // don't implement sendfile) - err = err1 - break - } - } if lr != nil { - lr.N = remain - } - if err != nil { - err = os.NewSyscallError("sendfile", err) + lr.N = remain - written } - return written, err, written > 0 + return written, wrapSyscallError("sendfile", err), written > 0 } diff --git a/libgo/go/net/sendfile_windows.go b/libgo/go/net/sendfile_windows.go index bc0b7fb..bccd8b1 100644 --- a/libgo/go/net/sendfile_windows.go +++ b/libgo/go/net/sendfile_windows.go @@ -5,6 +5,7 @@ package net import ( + "internal/poll" "io" "os" "syscall" @@ -34,19 +35,10 @@ func sendFile(fd *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, false } - if err := fd.writeLock(); err != nil { - return 0, err, true - } - defer fd.writeUnlock() + done, err := poll.SendFile(&fd.pfd, syscall.Handle(f.Fd()), n) - o := &fd.wop - o.qty = uint32(n) - o.handle = syscall.Handle(f.Fd()) - done, err := wsrv.ExecIO(o, "TransmitFile", func(o *operation) error { - return syscall.TransmitFile(o.fd.sysfd, o.handle, o.qty, 0, &o.o, nil, syscall.TF_WRITE_BEHIND) - }) if err != nil { - return 0, os.NewSyscallError("transmitfile", err), false + return 0, wrapSyscallError("transmitfile", err), false } if lr != nil { lr.N -= int64(done) diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go index a408fa5..28472e4 100644 --- a/libgo/go/net/smtp/smtp.go +++ b/libgo/go/net/smtp/smtp.go @@ -298,7 +298,7 @@ var testHookStartTLS func(*tls.Config) // nil, except for tests // messages is accomplished by including an email address in the to // parameter but not including it in the msg headers. // -// The SendMail function and the the net/smtp package are low-level +// The SendMail function and the net/smtp package are low-level // mechanisms and provide no support for DKIM signing, MIME // attachments (see the mime/multipart package), or other mail // functionality. Higher-level packages exist outside of the standard diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index c48fae6d..9dbe3eb 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -9,9 +9,11 @@ import ( "bytes" "crypto/tls" "crypto/x509" + "internal/testenv" "io" "net" "net/textproto" + "runtime" "strings" "testing" "time" @@ -592,6 +594,9 @@ QUIT ` func TestTLSClient(t *testing.T) { + if runtime.GOOS == "freebsd" && runtime.GOARCH == "amd64" { + testenv.SkipFlaky(t, 19229) + } ln := newLocalListener(t) defer ln.Close() errc := make(chan error) diff --git a/libgo/go/net/sock_cloexec.go b/libgo/go/net/sock_cloexec.go index 616a101e..06ff10d 100644 --- a/libgo/go/net/sock_cloexec.go +++ b/libgo/go/net/sock_cloexec.go @@ -5,11 +5,12 @@ // This file implements sysSocket and accept for platforms that // provide a fast path for setting SetNonblock and CloseOnExec. -// +build freebsd linux +// +build dragonfly freebsd linux package net import ( + "internal/poll" "os" "syscall" ) @@ -42,46 +43,8 @@ func sysSocket(family, sotype, proto int) (int, error) { return -1, os.NewSyscallError("socket", err) } if err = syscall.SetNonblock(s, true); err != nil { - closeFunc(s) + poll.CloseFunc(s) return -1, os.NewSyscallError("setnonblock", err) } return s, nil } - -// Wrapper around the accept system call that marks the returned file -// descriptor as nonblocking and close-on-exec. -func accept(s int) (int, syscall.Sockaddr, error) { - ns, sa, err := accept4Func(s, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) - // On Linux the accept4 system call was introduced in 2.6.28 - // kernel and on FreeBSD it was introduced in 10 kernel. If we - // get an ENOSYS error on both Linux and FreeBSD, or EINVAL - // error on Linux, fall back to using accept. - switch err { - case nil: - return ns, sa, nil - default: // errors other than the ones listed - return -1, sa, os.NewSyscallError("accept4", err) - case syscall.ENOSYS: // syscall missing - case syscall.EINVAL: // some Linux use this instead of ENOSYS - case syscall.EACCES: // some Linux use this instead of ENOSYS - case syscall.EFAULT: // some Linux use this instead of ENOSYS - } - - // See ../syscall/exec_unix.go for description of ForkLock. - // It is probably okay to hold the lock across syscall.Accept - // because we have put fd.sysfd into non-blocking mode. - // However, a call to the File method will put it back into - // blocking mode. We can't take that risk, so no use of ForkLock here. - ns, sa, err = acceptFunc(s) - if err == nil { - syscall.CloseOnExec(ns) - } - if err != nil { - return -1, nil, os.NewSyscallError("accept", err) - } - if err = syscall.SetNonblock(ns, true); err != nil { - closeFunc(ns) - return -1, nil, os.NewSyscallError("setnonblock", err) - } - return ns, sa, nil -} diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go index 6bbfd12..a30efe2 100644 --- a/libgo/go/net/sock_posix.go +++ b/libgo/go/net/sock_posix.go @@ -8,6 +8,7 @@ package net import ( "context" + "internal/poll" "os" "syscall" ) @@ -43,11 +44,11 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only return nil, err } if err = setDefaultSockopts(s, family, sotype, ipv6only); err != nil { - closeFunc(s) + poll.CloseFunc(s) return nil, err } if fd, err = newFD(s, family, sotype, net); err != nil { - closeFunc(s) + poll.CloseFunc(s) return nil, err } @@ -127,17 +128,18 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error { if lsa, err = laddr.sockaddr(fd.family); err != nil { return err } else if lsa != nil { - if err := syscall.Bind(fd.sysfd, lsa); err != nil { + if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { return os.NewSyscallError("bind", err) } } } - var rsa syscall.Sockaddr + var rsa syscall.Sockaddr // remote address from the user + var crsa syscall.Sockaddr // remote address we actually connected to if raddr != nil { if rsa, err = raddr.sockaddr(fd.family); err != nil { return err } - if err := fd.connect(ctx, lsa, rsa); err != nil { + if crsa, err = fd.connect(ctx, lsa, rsa); err != nil { return err } fd.isConnected = true @@ -146,8 +148,16 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error { return err } } - lsa, _ = syscall.Getsockname(fd.sysfd) - if rsa, _ = syscall.Getpeername(fd.sysfd); rsa != nil { + // Record the local and remote addresses from the actual socket. + // Get the local address by calling Getsockname. + // For the remote address, use + // 1) the one returned by the connect method, if any; or + // 2) the one from Getpeername, if it succeeds; or + // 3) the one passed to us as the raddr parameter. + lsa, _ = syscall.Getsockname(fd.pfd.Sysfd) + if crsa != nil { + fd.setAddr(fd.addrFunc()(lsa), fd.addrFunc()(crsa)) + } else if rsa, _ = syscall.Getpeername(fd.pfd.Sysfd); rsa != nil { fd.setAddr(fd.addrFunc()(lsa), fd.addrFunc()(rsa)) } else { fd.setAddr(fd.addrFunc()(lsa), raddr) @@ -156,23 +166,23 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error { } func (fd *netFD) listenStream(laddr sockaddr, backlog int) error { - if err := setDefaultListenerSockopts(fd.sysfd); err != nil { + if err := setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil { return err } if lsa, err := laddr.sockaddr(fd.family); err != nil { return err } else if lsa != nil { - if err := syscall.Bind(fd.sysfd, lsa); err != nil { + if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { return os.NewSyscallError("bind", err) } } - if err := listenFunc(fd.sysfd, backlog); err != nil { + if err := listenFunc(fd.pfd.Sysfd, backlog); err != nil { return os.NewSyscallError("listen", err) } if err := fd.init(); err != nil { return err } - lsa, _ := syscall.Getsockname(fd.sysfd) + lsa, _ := syscall.Getsockname(fd.pfd.Sysfd) fd.setAddr(fd.addrFunc()(lsa), nil) return nil } @@ -188,7 +198,7 @@ func (fd *netFD) listenDatagram(laddr sockaddr) error { // multiple UDP listeners that listen on the same UDP // port to join the same group address. if addr.IP != nil && addr.IP.IsMulticast() { - if err := setDefaultMulticastSockopts(fd.sysfd); err != nil { + if err := setDefaultMulticastSockopts(fd.pfd.Sysfd); err != nil { return err } addr := *addr @@ -204,14 +214,14 @@ func (fd *netFD) listenDatagram(laddr sockaddr) error { if lsa, err := laddr.sockaddr(fd.family); err != nil { return err } else if lsa != nil { - if err := syscall.Bind(fd.sysfd, lsa); err != nil { + if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil { return os.NewSyscallError("bind", err) } } if err := fd.init(); err != nil { return err } - lsa, _ := syscall.Getsockname(fd.sysfd) + lsa, _ := syscall.Getsockname(fd.pfd.Sysfd) fd.setAddr(fd.addrFunc()(lsa), nil) return nil } diff --git a/libgo/go/net/sockopt_bsd.go b/libgo/go/net/sockopt_bsd.go index 734a109..1aae88a 100644 --- a/libgo/go/net/sockopt_bsd.go +++ b/libgo/go/net/sockopt_bsd.go @@ -25,7 +25,7 @@ func setDefaultSockopts(s, family, sotype int, ipv6only bool) error { syscall.SetsockoptInt(s, syscall.IPPROTO_IPV6, syscall.IPV6_PORTRANGE, syscall.IPV6_PORTRANGE_HIGH) } } - if supportsIPv4map && family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { + if supportsIPv4map() && family == syscall.AF_INET6 && sotype != syscall.SOCK_RAW { // Allow both IP versions even if the OS default // is otherwise. Note that some operating systems // never admit this option. diff --git a/libgo/go/net/sockopt_posix.go b/libgo/go/net/sockopt_posix.go index cacd048..29edddb 100644 --- a/libgo/go/net/sockopt_posix.go +++ b/libgo/go/net/sockopt_posix.go @@ -7,7 +7,7 @@ package net import ( - "os" + "runtime" "syscall" ) @@ -101,27 +101,21 @@ done: } func setReadBuffer(fd *netFD, bytes int) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes)) + err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_RCVBUF, bytes) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setWriteBuffer(fd *netFD, bytes int) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes)) + err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_SNDBUF, bytes) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setKeepAlive(fd *netFD, keepalive bool) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive))) + err := fd.pfd.SetsockoptInt(syscall.SOL_SOCKET, syscall.SO_KEEPALIVE, boolint(keepalive)) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setLinger(fd *netFD, sec int) error { @@ -133,9 +127,7 @@ func setLinger(fd *netFD, sec int) error { l.Onoff = 0 l.Linger = 0 } - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptLinger(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_LINGER, &l)) + err := fd.pfd.SetsockoptLinger(syscall.SOL_SOCKET, syscall.SO_LINGER, &l) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/sockoptip_bsd.go b/libgo/go/net/sockoptip_bsd.go index b15c639..b11f3a4 100644 --- a/libgo/go/net/sockoptip_bsd.go +++ b/libgo/go/net/sockoptip_bsd.go @@ -7,28 +7,24 @@ package net import ( - "os" + "runtime" "syscall" ) func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { ip, err := interfaceToIPv4Addr(ifi) if err != nil { - return os.NewSyscallError("setsockopt", err) + return wrapSyscallError("setsockopt", err) } var a [4]byte copy(a[:], ip.To4()) - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInet4Addr(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a)) + err = fd.pfd.SetsockoptInet4Addr(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, a) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptByte(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v)))) + err := fd.pfd.SetsockoptByte(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, byte(boolint(v))) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/sockoptip_linux.go b/libgo/go/net/sockoptip_linux.go index c1dcc91..bd7d834 100644 --- a/libgo/go/net/sockoptip_linux.go +++ b/libgo/go/net/sockoptip_linux.go @@ -5,7 +5,7 @@ package net import ( - "os" + "runtime" "syscall" ) @@ -15,17 +15,13 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { v = int32(ifi.Index) } mreq := &syscall.IPMreqn{Ifindex: v} - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreqn(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq)) + err := fd.pfd.SetsockoptIPMreqn(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, mreq) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/sockoptip_posix.go b/libgo/go/net/sockoptip_posix.go index 4afd4c8..92af764 100644 --- a/libgo/go/net/sockoptip_posix.go +++ b/libgo/go/net/sockoptip_posix.go @@ -7,7 +7,7 @@ package net import ( - "os" + "runtime" "syscall" ) @@ -16,11 +16,9 @@ func joinIPv4Group(fd *netFD, ifi *Interface, ip IP) error { if err := setIPv4MreqToInterface(mreq, ifi); err != nil { return err } - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPMreq(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq)) + err := fd.pfd.SetsockoptIPMreq(syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { @@ -28,19 +26,15 @@ func setIPv6MulticastInterface(fd *netFD, ifi *Interface) error { if ifi != nil { v = ifi.Index } - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v)) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_IF, v) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setIPv6MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v))) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, boolint(v)) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { @@ -49,9 +43,7 @@ func joinIPv6Group(fd *netFD, ifi *Interface, ip IP) error { if ifi != nil { mreq.Interface = uint32(ifi.Index) } - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptIPv6Mreq(fd.sysfd, syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq)) + err := fd.pfd.SetsockoptIPv6Mreq(syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/sockoptip_windows.go b/libgo/go/net/sockoptip_windows.go index 916debe..6267603 100644 --- a/libgo/go/net/sockoptip_windows.go +++ b/libgo/go/net/sockoptip_windows.go @@ -6,6 +6,7 @@ package net import ( "os" + "runtime" "syscall" "unsafe" ) @@ -17,17 +18,13 @@ func setIPv4MulticastInterface(fd *netFD, ifi *Interface) error { } var a [4]byte copy(a[:], ip.To4()) - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, (*byte)(unsafe.Pointer(&a[0])), 4)) + err = fd.pfd.Setsockopt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_IF, (*byte)(unsafe.Pointer(&a[0])), 4) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } func setIPv4MulticastLoopback(fd *netFD, v bool) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v))) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, boolint(v)) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go index f2ea842..def05cb 100644 --- a/libgo/go/net/sys_cloexec.go +++ b/libgo/go/net/sys_cloexec.go @@ -5,11 +5,12 @@ // This file implements sysSocket and accept for platforms that do not // provide a fast path for setting SetNonblock and CloseOnExec. -// +build aix darwin dragonfly nacl netbsd openbsd solaris +// +build aix darwin nacl netbsd openbsd solaris package net import ( + "internal/poll" "os" "syscall" ) @@ -28,30 +29,8 @@ func sysSocket(family, sotype, proto int) (int, error) { return -1, os.NewSyscallError("socket", err) } if err = syscall.SetNonblock(s, true); err != nil { - closeFunc(s) + poll.CloseFunc(s) return -1, os.NewSyscallError("setnonblock", err) } return s, nil } - -// Wrapper around the accept system call that marks the returned file -// descriptor as nonblocking and close-on-exec. -func accept(s int) (int, syscall.Sockaddr, error) { - // See ../syscall/exec_unix.go for description of ForkLock. - // It is probably okay to hold the lock across syscall.Accept - // because we have put fd.sysfd into non-blocking mode. - // However, a call to the File method will put it back into - // blocking mode. We can't take that risk, so no use of ForkLock here. - ns, sa, err := acceptFunc(s) - if err == nil { - syscall.CloseOnExec(ns) - } - if err != nil { - return -1, nil, os.NewSyscallError("accept", err) - } - if err = syscall.SetNonblock(ns, true); err != nil { - closeFunc(ns) - return -1, nil, os.NewSyscallError("setnonblock", err) - } - return ns, sa, nil -} diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index 69731eb..e957aa3 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -50,28 +50,34 @@ func (a *TCPAddr) opAddr() Addr { return a } -// ResolveTCPAddr parses addr as a TCP address of the form "host:port" -// or "[ipv6-host%zone]:port" and resolves a pair of domain name and -// port name on the network net, which must be "tcp", "tcp4" or -// "tcp6". A literal address or host name for IPv6 must be enclosed -// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or -// "[ipv6-host%zone]:80". +// ResolveTCPAddr returns an address of TCP end point. // -// Resolving a hostname is not recommended because this returns at most -// one of its IP addresses. -func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { - switch net { +// The network must be a TCP network name. +// +// If the host in the address parameter is not a literal IP address or +// the port is not a literal port number, ResolveTCPAddr resolves the +// address to an address of TCP end point. +// Otherwise, it parses the address as a pair of literal IP address +// and port number. +// The address parameter can use a host name, but this is not +// recommended, because it will return at most one of the host name's +// IP addresses. +// +// See func Dial for a description of the network and address +// parameters. +func ResolveTCPAddr(network, address string) (*TCPAddr, error) { + switch network { case "tcp", "tcp4", "tcp6": case "": // a hint wildcard for Go 1.0 undocumented behavior - net = "tcp" + network = "tcp" default: - return nil, UnknownNetworkError(net) + return nil, UnknownNetworkError(network) } - addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr) + addrs, err := DefaultResolver.internetAddrList(context.Background(), network, address) if err != nil { return nil, err } - return addrs.first(isIPv4).(*TCPAddr), nil + return addrs.forResolve(network, address).(*TCPAddr), nil } // TCPConn is an implementation of the Conn interface for TCP network @@ -80,6 +86,15 @@ type TCPConn struct { conn } +// SyscallConn returns a raw network connection. +// This implements the syscall.Conn interface. +func (c *TCPConn) SyscallConn() (syscall.RawConn, error) { + if !c.ok() { + return nil, syscall.EINVAL + } + return newRawConn(c.fd) +} + // ReadFrom implements the io.ReaderFrom ReadFrom method. func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { if !c.ok() { @@ -181,21 +196,25 @@ func newTCPConn(fd *netFD) *TCPConn { return c } -// DialTCP connects to the remote address raddr on the network net, -// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is -// used as the local address for the connection. -func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) { - switch net { +// DialTCP acts like Dial for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) { + switch network { case "tcp", "tcp4", "tcp6": default: - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)} } if raddr == nil { - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} } - c, err := dialTCP(context.Background(), net, laddr, raddr) + c, err := dialTCP(context.Background(), network, laddr, raddr) if err != nil { - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } @@ -255,7 +274,7 @@ func (l *TCPListener) SetDeadline(t time.Time) error { if !l.ok() { return syscall.EINVAL } - if err := l.fd.setDeadline(t); err != nil { + if err := l.fd.pfd.SetDeadline(t); err != nil { return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err} } return nil @@ -279,22 +298,27 @@ func (l *TCPListener) File() (f *os.File, err error) { return } -// ListenTCP announces on the TCP address laddr and returns a TCP -// listener. Net must be "tcp", "tcp4", or "tcp6". If laddr has a -// port of 0, ListenTCP will choose an available port. The caller can -// use the Addr method of TCPListener to retrieve the chosen address. -func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) { - switch net { +// ListenTCP acts like Listen for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If the IP field of laddr is nil or an unspecified IP address, +// ListenTCP listens on all available unicast and anycast IP addresses +// of the local system. +// If the Port field of laddr is 0, a port number is automatically +// chosen. +func ListenTCP(network string, laddr *TCPAddr) (*TCPListener, error) { + switch network { case "tcp", "tcp4", "tcp6": default: - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)} } if laddr == nil { laddr = &TCPAddr{} } - ln, err := listenTCP(context.Background(), net, laddr) + ln, err := listenTCP(context.Background(), network, laddr) if err != nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err} } return ln, nil } diff --git a/libgo/go/net/tcpsock_plan9.go b/libgo/go/net/tcpsock_plan9.go index d286060..e37f065 100644 --- a/libgo/go/net/tcpsock_plan9.go +++ b/libgo/go/net/tcpsock_plan9.go @@ -48,6 +48,9 @@ func (ln *TCPListener) accept() (*TCPConn, error) { } func (ln *TCPListener) close() error { + if err := ln.fd.pfd.Close(); err != nil { + return err + } if _, err := ln.fd.ctl.WriteString("hangup"); err != nil { ln.fd.ctl.Close() return err diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go index 7533c24..9ba199d 100644 --- a/libgo/go/net/tcpsock_posix.go +++ b/libgo/go/net/tcpsock_posix.go @@ -18,7 +18,7 @@ func sockaddrToTCP(sa syscall.Sockaddr) Addr { case *syscall.SockaddrInet4: return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} + return &TCPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))} } return nil } diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go index 5115422..660f424 100644 --- a/libgo/go/net/tcpsock_test.go +++ b/libgo/go/net/tcpsock_test.go @@ -32,28 +32,28 @@ func BenchmarkTCP4PersistentTimeout(b *testing.B) { } func BenchmarkTCP6OneShot(b *testing.B) { - if !supportsIPv6 { + if !supportsIPv6() { b.Skip("ipv6 is not supported") } benchmarkTCP(b, false, false, "[::1]:0") } func BenchmarkTCP6OneShotTimeout(b *testing.B) { - if !supportsIPv6 { + if !supportsIPv6() { b.Skip("ipv6 is not supported") } benchmarkTCP(b, false, true, "[::1]:0") } func BenchmarkTCP6Persistent(b *testing.B) { - if !supportsIPv6 { + if !supportsIPv6() { b.Skip("ipv6 is not supported") } benchmarkTCP(b, true, false, "[::1]:0") } func BenchmarkTCP6PersistentTimeout(b *testing.B) { - if !supportsIPv6 { + if !supportsIPv6() { b.Skip("ipv6 is not supported") } benchmarkTCP(b, true, true, "[::1]:0") @@ -163,7 +163,7 @@ func BenchmarkTCP4ConcurrentReadWrite(b *testing.B) { } func BenchmarkTCP6ConcurrentReadWrite(b *testing.B) { - if !supportsIPv6 { + if !supportsIPv6() { b.Skip("ipv6 is not supported") } benchmarkTCPConcurrentReadWrite(b, "[::1]:0") @@ -372,7 +372,7 @@ func TestTCPListenerName(t *testing.T) { func TestIPv6LinkLocalUnicastTCP(t *testing.T) { testenv.MustHaveExternalNetwork(t) - if !supportsIPv6 { + if !supportsIPv6() { t.Skip("IPv6 is not supported") } diff --git a/libgo/go/net/tcpsock_unix_test.go b/libgo/go/net/tcpsock_unix_test.go index 2375fe2..3af1834 100644 --- a/libgo/go/net/tcpsock_unix_test.go +++ b/libgo/go/net/tcpsock_unix_test.go @@ -2,11 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin +// +build !plan9,!windows package net import ( + "context" + "internal/testenv" + "math/rand" "runtime" "sync" "syscall" @@ -77,3 +80,37 @@ func TestTCPSpuriousConnSetupCompletion(t *testing.T) { ln.Close() wg.Wait() } + +// Issue 19289. +// Test that a canceled Dial does not cause a subsequent Dial to succeed. +func TestTCPSpuriousConnSetupCompletionWithCancel(t *testing.T) { + if testenv.Builder() == "" { + testenv.MustHaveExternalNetwork(t) + } + t.Parallel() + const tries = 10000 + var wg sync.WaitGroup + wg.Add(tries * 2) + sem := make(chan bool, 5) + for i := 0; i < tries; i++ { + sem <- true + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer wg.Done() + time.Sleep(time.Duration(rand.Int63n(int64(5 * time.Millisecond)))) + cancel() + }() + go func(i int) { + defer wg.Done() + var dialer Dialer + // Try to connect to a real host on a port + // that it is not listening on. + _, err := dialer.DialContext(ctx, "tcp", "golang.org:3") + if err == nil { + t.Errorf("Dial to unbound port succeeded on attempt %d", i) + } + <-sem + }(i) + } + wg.Wait() +} diff --git a/libgo/go/net/tcpsockopt_darwin.go b/libgo/go/net/tcpsockopt_darwin.go index 0d1310e..7415c76 100644 --- a/libgo/go/net/tcpsockopt_darwin.go +++ b/libgo/go/net/tcpsockopt_darwin.go @@ -5,7 +5,7 @@ package net import ( - "os" + "runtime" "syscall" "time" ) @@ -13,17 +13,15 @@ import ( const sysTCP_KEEPINTVL = 0x101 func setKeepAlivePeriod(fd *netFD, d time.Duration) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() // The kernel expects seconds so round to next highest second. d += (time.Second - time.Nanosecond) secs := int(d.Seconds()) - switch err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs); err { + switch err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs); err { case nil, syscall.ENOPROTOOPT: // OS X 10.7 and earlier don't support this option default: - return os.NewSyscallError("setsockopt", err) + return wrapSyscallError("setsockopt", err) } - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, secs)) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, secs) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/tcpsockopt_dragonfly.go b/libgo/go/net/tcpsockopt_dragonfly.go index 7cc716b..2b018f2 100644 --- a/libgo/go/net/tcpsockopt_dragonfly.go +++ b/libgo/go/net/tcpsockopt_dragonfly.go @@ -5,22 +5,20 @@ package net import ( - "os" + "runtime" "syscall" "time" ) func setKeepAlivePeriod(fd *netFD, d time.Duration) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() // The kernel expects milliseconds so round to next highest // millisecond. d += (time.Millisecond - time.Nanosecond) msecs := int(d / time.Millisecond) - if err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, msecs); err != nil { - return os.NewSyscallError("setsockopt", err) + if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, msecs); err != nil { + return wrapSyscallError("setsockopt", err) } - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, msecs)) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, msecs) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/tcpsockopt_posix.go b/libgo/go/net/tcpsockopt_posix.go index 36866ac..5e00ba1 100644 --- a/libgo/go/net/tcpsockopt_posix.go +++ b/libgo/go/net/tcpsockopt_posix.go @@ -7,14 +7,12 @@ package net import ( - "os" + "runtime" "syscall" ) func setNoDelay(fd *netFD, noDelay bool) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay))) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_NODELAY, boolint(noDelay)) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/tcpsockopt_solaris.go b/libgo/go/net/tcpsockopt_solaris.go index 347c17d..aa86a29 100644 --- a/libgo/go/net/tcpsockopt_solaris.go +++ b/libgo/go/net/tcpsockopt_solaris.go @@ -2,26 +2,20 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TCP socket options for solaris - package net import ( - "os" + "runtime" "syscall" "time" ) -// Set keep alive period. func setKeepAlivePeriod(fd *netFD, d time.Duration) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() - // The kernel expects seconds so round to next highest second. d += (time.Second - time.Nanosecond) secs := int(d.Seconds()) - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.SO_KEEPALIVE, secs)) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.SO_KEEPALIVE, secs) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/tcpsockopt_unix.go b/libgo/go/net/tcpsockopt_unix.go index 46e5e6d..d589258 100644 --- a/libgo/go/net/tcpsockopt_unix.go +++ b/libgo/go/net/tcpsockopt_unix.go @@ -7,21 +7,19 @@ package net import ( - "os" + "runtime" "syscall" "time" ) func setKeepAlivePeriod(fd *netFD, d time.Duration) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() // The kernel expects seconds so round to next highest second. d += (time.Second - time.Nanosecond) secs := int(d.Seconds()) - if err := syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs); err != nil { - return os.NewSyscallError("setsockopt", err) + if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPINTVL, secs); err != nil { + return wrapSyscallError("setsockopt", err) } - return os.NewSyscallError("setsockopt", syscall.SetsockoptInt(fd.sysfd, syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs)) + err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPIDLE, secs) + runtime.KeepAlive(fd) + return wrapSyscallError("setsockopt", err) } diff --git a/libgo/go/net/tcpsockopt_windows.go b/libgo/go/net/tcpsockopt_windows.go index 45a4dca..73dead1 100644 --- a/libgo/go/net/tcpsockopt_windows.go +++ b/libgo/go/net/tcpsockopt_windows.go @@ -6,16 +6,13 @@ package net import ( "os" + "runtime" "syscall" "time" "unsafe" ) func setKeepAlivePeriod(fd *netFD, d time.Duration) error { - if err := fd.incref(); err != nil { - return err - } - defer fd.decref() // The kernel expects milliseconds so round to next highest // millisecond. d += (time.Millisecond - time.Nanosecond) @@ -27,6 +24,7 @@ func setKeepAlivePeriod(fd *netFD, d time.Duration) error { } ret := uint32(0) size := uint32(unsafe.Sizeof(ka)) - err := syscall.WSAIoctl(fd.sysfd, syscall.SIO_KEEPALIVE_VALS, (*byte)(unsafe.Pointer(&ka)), size, nil, 0, &ret, nil, 0) + err := fd.pfd.WSAIoctl(syscall.SIO_KEEPALIVE_VALS, (*byte)(unsafe.Pointer(&ka)), size, nil, 0, &ret, nil, 0) + runtime.KeepAlive(fd) return os.NewSyscallError("wsaioctl", err) } diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index 55bbf44..9de7801 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -6,6 +6,7 @@ package net import ( "fmt" + "internal/poll" "internal/testenv" "io" "io/ioutil" @@ -145,9 +146,9 @@ var acceptTimeoutTests = []struct { }{ // Tests that accept deadlines in the past work, even if // there's incoming connections available. - {-5 * time.Second, [2]error{errTimeout, errTimeout}}, + {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, - {50 * time.Millisecond, [2]error{nil, errTimeout}}, + {50 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, } func TestAcceptTimeout(t *testing.T) { @@ -299,9 +300,9 @@ var readTimeoutTests = []struct { }{ // Tests that read deadlines work, even if there's data ready // to be read. - {-5 * time.Second, [2]error{errTimeout, errTimeout}}, + {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, - {50 * time.Millisecond, [2]error{nil, errTimeout}}, + {50 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, } func TestReadTimeout(t *testing.T) { @@ -423,9 +424,9 @@ var readFromTimeoutTests = []struct { }{ // Tests that read deadlines work, even if there's data ready // to be read. - {-5 * time.Second, [2]error{errTimeout, errTimeout}}, + {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, - {50 * time.Millisecond, [2]error{nil, errTimeout}}, + {50 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, } func TestReadFromTimeout(t *testing.T) { @@ -496,9 +497,9 @@ var writeTimeoutTests = []struct { }{ // Tests that write deadlines work, even if there's buffer // space available to write. - {-5 * time.Second, [2]error{errTimeout, errTimeout}}, + {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, - {10 * time.Millisecond, [2]error{nil, errTimeout}}, + {10 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, } func TestWriteTimeout(t *testing.T) { @@ -610,9 +611,9 @@ var writeToTimeoutTests = []struct { }{ // Tests that write deadlines work, even if there's buffer // space available to write. - {-5 * time.Second, [2]error{errTimeout, errTimeout}}, + {-5 * time.Second, [2]error{poll.ErrTimeout, poll.ErrTimeout}}, - {10 * time.Millisecond, [2]error{nil, errTimeout}}, + {10 * time.Millisecond, [2]error{nil, poll.ErrTimeout}}, } func TestWriteToTimeout(t *testing.T) { diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 841ef53..2c0f74f 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -53,28 +53,34 @@ func (a *UDPAddr) opAddr() Addr { return a } -// ResolveUDPAddr parses addr as a UDP address of the form "host:port" -// or "[ipv6-host%zone]:port" and resolves a pair of domain name and -// port name on the network net, which must be "udp", "udp4" or -// "udp6". A literal address or host name for IPv6 must be enclosed -// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or -// "[ipv6-host%zone]:80". +// ResolveUDPAddr returns an address of UDP end point. // -// Resolving a hostname is not recommended because this returns at most -// one of its IP addresses. -func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { - switch net { +// The network must be a UDP network name. +// +// If the host in the address parameter is not a literal IP address or +// the port is not a literal port number, ResolveUDPAddr resolves the +// address to an address of UDP end point. +// Otherwise, it parses the address as a pair of literal IP address +// and port number. +// The address parameter can use a host name, but this is not +// recommended, because it will return at most one of the host name's +// IP addresses. +// +// See func Dial for a description of the network and address +// parameters. +func ResolveUDPAddr(network, address string) (*UDPAddr, error) { + switch network { case "udp", "udp4", "udp6": case "": // a hint wildcard for Go 1.0 undocumented behavior - net = "udp" + network = "udp" default: - return nil, UnknownNetworkError(net) + return nil, UnknownNetworkError(network) } - addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr) + addrs, err := DefaultResolver.internetAddrList(context.Background(), network, address) if err != nil { return nil, err } - return addrs.first(isIPv4).(*UDPAddr), nil + return addrs.forResolve(network, address).(*UDPAddr), nil } // UDPConn is the implementation of the Conn and PacketConn interfaces @@ -83,13 +89,16 @@ type UDPConn struct { conn } -// ReadFromUDP reads a UDP packet from c, copying the payload into b. -// It returns the number of bytes copied into b and the return address -// that was on the packet. -// -// ReadFromUDP can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetDeadline and -// SetReadDeadline. +// SyscallConn returns a raw network connection. +// This implements the syscall.Conn interface. +func (c *UDPConn) SyscallConn() (syscall.RawConn, error) { + if !c.ok() { + return nil, syscall.EINVAL + } + return newRawConn(c.fd) +} + +// ReadFromUDP acts like ReadFrom but returns a UDPAddr. func (c *UDPConn) ReadFromUDP(b []byte) (int, *UDPAddr, error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -116,11 +125,13 @@ func (c *UDPConn) ReadFrom(b []byte) (int, Addr, error) { return n, addr, err } -// ReadMsgUDP reads a packet from c, copying the payload into b and -// the associated out-of-band data into oob. It returns the number -// of bytes copied into b, the number of bytes copied into oob, the -// flags that were set on the packet and the source address of the -// packet. +// ReadMsgUDP reads a message from c, copying the payload into b and +// the associated out-of-band data into oob. It returns the number of +// bytes copied into b, the number of bytes copied into oob, the flags +// that were set on the message and the source address of the message. +// +// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be +// used to manipulate IP-level socket options in oob. func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err error) { if !c.ok() { return 0, 0, 0, nil, syscall.EINVAL @@ -132,13 +143,7 @@ func (c *UDPConn) ReadMsgUDP(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, return } -// WriteToUDP writes a UDP packet to addr via c, copying the payload -// from b. -// -// WriteToUDP can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetDeadline and -// SetWriteDeadline. On packet-oriented connections, write timeouts -// are rare. +// WriteToUDP acts like WriteTo but takes a UDPAddr. func (c *UDPConn) WriteToUDP(b []byte, addr *UDPAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -166,11 +171,14 @@ func (c *UDPConn) WriteTo(b []byte, addr Addr) (int, error) { return n, err } -// WriteMsgUDP writes a packet to addr via c if c isn't connected, or -// to c's remote destination address if c is connected (in which case -// addr must be nil). The payload is copied from b and the associated -// out-of-band data is copied from oob. It returns the number of -// payload and out-of-band bytes written. +// WriteMsgUDP writes a message to addr via c if c isn't connected, or +// to c's remote address if c is connected (in which case addr must be +// nil). The payload is copied from b and the associated out-of-band +// data is copied from oob. It returns the number of payload and +// out-of-band bytes written. +// +// The packages golang.org/x/net/ipv4 and golang.org/x/net/ipv6 can be +// used to manipulate IP-level socket options in oob. func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err error) { if !c.ok() { return 0, 0, syscall.EINVAL @@ -184,55 +192,67 @@ func (c *UDPConn) WriteMsgUDP(b, oob []byte, addr *UDPAddr) (n, oobn int, err er func newUDPConn(fd *netFD) *UDPConn { return &UDPConn{conn{fd}} } -// DialUDP connects to the remote address raddr on the network net, -// which must be "udp", "udp4", or "udp6". If laddr is not nil, it is -// used as the local address for the connection. -func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) { - switch net { +// DialUDP acts like Dial for UDP networks. +// +// The network must be a UDP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) { + switch network { case "udp", "udp4", "udp6": default: - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)} } if raddr == nil { - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress} } - c, err := dialUDP(context.Background(), net, laddr, raddr) + c, err := dialUDP(context.Background(), network, laddr, raddr) if err != nil { - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -// ListenUDP listens for incoming UDP packets addressed to the local -// address laddr. Net must be "udp", "udp4", or "udp6". If laddr has -// a port of 0, ListenUDP will choose an available port. -// The LocalAddr method of the returned UDPConn can be used to -// discover the port. The returned connection's ReadFrom and WriteTo -// methods can be used to receive and send UDP packets with per-packet -// addressing. -func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) { - switch net { +// ListenUDP acts like ListenPacket for UDP networks. +// +// The network must be a UDP network name; see func Dial for details. +// +// If the IP field of laddr is nil or an unspecified IP address, +// ListenUDP listens on all available IP addresses of the local system +// except multicast IP addresses. +// If the Port field of laddr is 0, a port number is automatically +// chosen. +func ListenUDP(network string, laddr *UDPAddr) (*UDPConn, error) { + switch network { case "udp", "udp4", "udp6": default: - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)} } if laddr == nil { laddr = &UDPAddr{} } - c, err := listenUDP(context.Background(), net, laddr) + c, err := listenUDP(context.Background(), network, laddr) if err != nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err} } return c, nil } -// ListenMulticastUDP listens for incoming multicast UDP packets -// addressed to the group address gaddr on the interface ifi. -// Network must be "udp", "udp4" or "udp6". -// ListenMulticastUDP uses the system-assigned multicast interface -// when ifi is nil, although this is not recommended because the +// ListenMulticastUDP acts like ListenPacket for UDP networks but +// takes a group address on a specific network interface. +// +// The network must be a UDP network name; see func Dial for details. +// +// ListenMulticastUDP listens on all available IP addresses of the +// local system including the group, multicast IP address. +// If ifi is nil, ListenMulticastUDP uses the system-assigned +// multicast interface, although this is not recommended because the // assignment depends on platforms and sometimes it might require // routing configuration. +// If the Port field of gaddr is 0, a port number is automatically +// chosen. // // ListenMulticastUDP is just for convenience of simple, small // applications. There are golang.org/x/net/ipv4 and diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go index 0c905af..fe552ba 100644 --- a/libgo/go/net/udpsock_posix.go +++ b/libgo/go/net/udpsock_posix.go @@ -16,7 +16,7 @@ func sockaddrToUDP(sa syscall.Sockaddr) Addr { case *syscall.SockaddrInet4: return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} + return &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))} } return nil } @@ -49,7 +49,7 @@ func (c *UDPConn) readFrom(b []byte) (int, *UDPAddr, error) { case *syscall.SockaddrInet4: addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} + addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))} } return n, addr, err } @@ -61,7 +61,7 @@ func (c *UDPConn) readMsg(b, oob []byte) (n, oobn, flags int, addr *UDPAddr, err case *syscall.SockaddrInet4: addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port} case *syscall.SockaddrInet6: - addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneToString(int(sa.ZoneId))} + addr = &UDPAddr{IP: sa.Addr[0:], Port: sa.Port, Zone: zoneCache.name(int(sa.ZoneId))} } return } diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go index 708cc10..6d4974e 100644 --- a/libgo/go/net/udpsock_test.go +++ b/libgo/go/net/udpsock_test.go @@ -15,7 +15,7 @@ import ( func BenchmarkUDP6LinkLocalUnicast(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) - if !supportsIPv6 { + if !supportsIPv6() { b.Skip("IPv6 is not supported") } ifi := loopbackInterface() @@ -279,7 +279,7 @@ func TestUDPConnLocalAndRemoteNames(t *testing.T) { func TestIPv6LinkLocalUnicastUDP(t *testing.T) { testenv.MustHaveExternalNetwork(t) - if !supportsIPv6 { + if !supportsIPv6() { t.Skip("IPv6 is not supported") } diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go index b25d492..057940a 100644 --- a/libgo/go/net/unixsock.go +++ b/libgo/go/net/unixsock.go @@ -42,15 +42,18 @@ func (a *UnixAddr) opAddr() Addr { return a } -// ResolveUnixAddr parses addr as a Unix domain socket address. -// The string net gives the network name, "unix", "unixgram" or -// "unixpacket". -func ResolveUnixAddr(net, addr string) (*UnixAddr, error) { - switch net { +// ResolveUnixAddr returns an address of Unix domain socket end point. +// +// The network must be a Unix network name. +// +// See func Dial for a description of the network and address +// parameters. +func ResolveUnixAddr(network, address string) (*UnixAddr, error) { + switch network { case "unix", "unixgram", "unixpacket": - return &UnixAddr{Name: addr, Net: net}, nil + return &UnixAddr{Name: address, Net: network}, nil default: - return nil, UnknownNetworkError(net) + return nil, UnknownNetworkError(network) } } @@ -60,6 +63,15 @@ type UnixConn struct { conn } +// SyscallConn returns a raw network connection. +// This implements the syscall.Conn interface. +func (c *UnixConn) SyscallConn() (syscall.RawConn, error) { + if !c.ok() { + return nil, syscall.EINVAL + } + return newRawConn(c.fd) +} + // CloseRead shuts down the reading side of the Unix domain connection. // Most callers should just use Close. func (c *UnixConn) CloseRead() error { @@ -84,13 +96,7 @@ func (c *UnixConn) CloseWrite() error { return nil } -// ReadFromUnix reads a packet from c, copying the payload into b. It -// returns the number of bytes copied into b and the source address of -// the packet. -// -// ReadFromUnix can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetDeadline and -// SetReadDeadline. +// ReadFromUnix acts like ReadFrom but returns a UnixAddr. func (c *UnixConn) ReadFromUnix(b []byte) (int, *UnixAddr, error) { if !c.ok() { return 0, nil, syscall.EINVAL @@ -117,10 +123,10 @@ func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) { return n, addr, err } -// ReadMsgUnix reads a packet from c, copying the payload into b and +// ReadMsgUnix reads a message from c, copying the payload into b and // the associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags -// that were set on the packet, and the source address of the packet. +// that were set on the message and the source address of the message. // // Note that if len(b) == 0 and len(oob) > 0, this function will still // read (and discard) 1 byte from the connection. @@ -135,12 +141,7 @@ func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAdd return } -// WriteToUnix writes a packet to addr via c, copying the payload from b. -// -// WriteToUnix can be made to time out and return an error with -// Timeout() == true after a fixed time limit; see SetDeadline and -// SetWriteDeadline. On packet-oriented connections, write timeouts -// are rare. +// WriteToUnix acts like WriteTo but takes a UnixAddr. func (c *UnixConn) WriteToUnix(b []byte, addr *UnixAddr) (int, error) { if !c.ok() { return 0, syscall.EINVAL @@ -168,9 +169,9 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) { return n, err } -// WriteMsgUnix writes a packet to addr via c, copying the payload -// from b and the associated out-of-band data from oob. It returns -// the number of payload and out-of-band bytes written. +// WriteMsgUnix writes a message to addr via c, copying the payload +// from b and the associated out-of-band data from oob. It returns the +// number of payload and out-of-band bytes written. // // Note that if len(b) == 0 and len(oob) > 0, this function will still // write 1 byte to the connection. @@ -187,18 +188,21 @@ func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err func newUnixConn(fd *netFD) *UnixConn { return &UnixConn{conn{fd}} } -// DialUnix connects to the remote address raddr on the network net, -// which must be "unix", "unixgram" or "unixpacket". If laddr is not -// nil, it is used as the local address for the connection. -func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { - switch net { +// DialUnix acts like Dial for Unix networks. +// +// The network must be a Unix network name; see func Dial for details. +// +// If laddr is non-nil, it is used as the local address for the +// connection. +func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) { + switch network { case "unix", "unixgram", "unixpacket": default: - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)} } - c, err := dialUnix(context.Background(), net, laddr, raddr) + c, err := dialUnix(context.Background(), network, laddr, raddr) if err != nil { - return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} + return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } @@ -264,7 +268,7 @@ func (l *UnixListener) SetDeadline(t time.Time) error { if !l.ok() { return syscall.EINVAL } - if err := l.fd.setDeadline(t); err != nil { + if err := l.fd.pfd.SetDeadline(t); err != nil { return &OpError{Op: "set", Net: l.fd.net, Source: nil, Addr: l.fd.laddr, Err: err} } return nil @@ -288,40 +292,40 @@ func (l *UnixListener) File() (f *os.File, err error) { return } -// ListenUnix announces on the Unix domain socket laddr and returns a -// Unix listener. The network net must be "unix" or "unixpacket". -func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) { - switch net { +// ListenUnix acts like Listen for Unix networks. +// +// The network must be "unix" or "unixpacket". +func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) { + switch network { case "unix", "unixpacket": default: - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)} } if laddr == nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress} } - ln, err := listenUnix(context.Background(), net, laddr) + ln, err := listenUnix(context.Background(), network, laddr) if err != nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err} } return ln, nil } -// ListenUnixgram listens for incoming Unix datagram packets addressed -// to the local address laddr. The network net must be "unixgram". -// The returned connection's ReadFrom and WriteTo methods can be used -// to receive and send packets with per-packet addressing. -func ListenUnixgram(net string, laddr *UnixAddr) (*UnixConn, error) { - switch net { +// ListenUnixgram acts like ListenPacket for Unix networks. +// +// The network must be "unixgram". +func ListenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) { + switch network { case "unixgram": default: - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(net)} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(network)} } if laddr == nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: errMissingAddress} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: errMissingAddress} } - c, err := listenUnixgram(context.Background(), net, laddr) + c, err := listenUnixgram(context.Background(), network, laddr) if err != nil { - return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err} + return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err} } return c, nil } diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 42a514b..2ac2472 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -309,9 +309,10 @@ func escape(s string, mode encoding) string { } // A URL represents a parsed URL (technically, a URI reference). +// // The general form represented is: // -// scheme://[userinfo@]host/path[?query][#fragment] +// [scheme:][//[userinfo@]host][/]path[?query][#fragment] // // URLs that do not start with a slash after the scheme are interpreted as: // @@ -321,26 +322,19 @@ func escape(s string, mode encoding) string { // A consequence is that it is impossible to tell which slashes in the Path were // slashes in the raw URL and which were %2f. This distinction is rarely important, // but when it is, code must not use Path directly. -// -// Go 1.5 introduced the RawPath field to hold the encoded form of Path. // The Parse function sets both Path and RawPath in the URL it returns, // and URL's String method uses RawPath if it is a valid encoding of Path, // by calling the EscapedPath method. -// -// In earlier versions of Go, the more indirect workarounds were that an -// HTTP server could consult req.RequestURI and an HTTP client could -// construct a URL struct directly and set the Opaque field instead of Path. -// These still work as well. type URL struct { Scheme string Opaque string // encoded opaque data User *Userinfo // username and password information Host string // host or host:port - Path string - RawPath string // encoded path hint (Go 1.5 and later only; see EscapedPath method) - ForceQuery bool // append a query ('?') even if RawQuery is empty - RawQuery string // encoded query values, without '?' - Fragment string // fragment for references, without '#' + Path string // path (relative paths may omit leading slash) + RawPath string // encoded path hint (see EscapedPath method) + ForceQuery bool // append a query ('?') even if RawQuery is empty + RawQuery string // encoded query values, without '?' + Fragment string // fragment for references, without '#' } // User returns a Userinfo containing the provided username @@ -351,6 +345,7 @@ func User(username string) *Userinfo { // UserPassword returns a Userinfo containing the provided username // and password. +// // This functionality should only be used with legacy web sites. // RFC 2396 warns that interpreting Userinfo this way // ``is NOT RECOMMENDED, because the passing of authentication @@ -974,6 +969,8 @@ func (u *URL) ResolveReference(ref *URL) *URL { } // Query parses RawQuery and returns the corresponding values. +// It silently discards malformed value pairs. +// To check errors use ParseQuery. func (u *URL) Query() Values { v, _ := ParseQuery(u.RawQuery) return v diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go index 7160d28..4c05be4 100644 --- a/libgo/go/net/writev_test.go +++ b/libgo/go/net/writev_test.go @@ -7,6 +7,7 @@ package net import ( "bytes" "fmt" + "internal/poll" "io" "io/ioutil" "reflect" @@ -99,13 +100,13 @@ func TestBuffers_WriteTo(t *testing.T) { } func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) { - oldHook := testHookDidWritev - defer func() { testHookDidWritev = oldHook }() + oldHook := poll.TestHookDidWritev + defer func() { poll.TestHookDidWritev = oldHook }() var writeLog struct { sync.Mutex log []int } - testHookDidWritev = func(size int) { + poll.TestHookDidWritev = func(size int) { writeLog.Lock() writeLog.log = append(writeLog.log, size) writeLog.Unlock() diff --git a/libgo/go/net/writev_unix.go b/libgo/go/net/writev_unix.go index 174e6bc..bf0fbf8 100644 --- a/libgo/go/net/writev_unix.go +++ b/libgo/go/net/writev_unix.go @@ -7,10 +7,8 @@ package net import ( - "io" - "os" + "runtime" "syscall" - "unsafe" ) func (c *conn) writeBuffers(v *Buffers) (int64, error) { @@ -25,71 +23,7 @@ func (c *conn) writeBuffers(v *Buffers) (int64, error) { } func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) { - if err := fd.writeLock(); err != nil { - return 0, err - } - defer fd.writeUnlock() - if err := fd.pd.prepareWrite(); err != nil { - return 0, err - } - - var iovecs []syscall.Iovec - if fd.iovecs != nil { - iovecs = *fd.iovecs - } - // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is - // 1024 and this seems conservative enough for now. Darwin's - // UIO_MAXIOV also seems to be 1024. - maxVec := 1024 - - for len(*v) > 0 { - iovecs = iovecs[:0] - for _, chunk := range *v { - if len(chunk) == 0 { - continue - } - iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]}) - if fd.isStream && len(chunk) > 1<<30 { - iovecs[len(iovecs)-1].SetLen(1 << 30) - break // continue chunk on next writev - } - iovecs[len(iovecs)-1].SetLen(len(chunk)) - if len(iovecs) == maxVec { - break - } - } - if len(iovecs) == 0 { - break - } - fd.iovecs = &iovecs // cache - - wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV, - uintptr(fd.sysfd), - uintptr(unsafe.Pointer(&iovecs[0])), - uintptr(len(iovecs))) - if wrote == ^uintptr(0) { - wrote = 0 - } - testHookDidWritev(int(wrote)) - n += int64(wrote) - v.consume(int64(wrote)) - if e0 == syscall.EAGAIN { - if err = fd.pd.waitWrite(); err == nil { - continue - } - } else if e0 != 0 { - err = syscall.Errno(e0) - } - if err != nil { - break - } - if n == 0 { - err = io.ErrUnexpectedEOF - break - } - } - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("writev", err) - } - return n, err + n, err = fd.pfd.Writev((*[][]byte)(v)) + runtime.KeepAlive(fd) + return n, wrapSyscallError("writev", err) } |