diff options
author | Ian Lance Taylor <iant@golang.org> | 2017-01-14 00:05:42 +0000 |
---|---|---|
committer | Ian Lance Taylor <ian@gcc.gnu.org> | 2017-01-14 00:05:42 +0000 |
commit | c2047754c300b68c05d65faa8dc2925fe67b71b4 (patch) | |
tree | e183ae81a1f48a02945cb6de463a70c5be1b06f6 /libgo/go/net | |
parent | 829afb8f05602bb31c9c597b24df7377fed4f059 (diff) | |
download | gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.zip gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.tar.gz gcc-c2047754c300b68c05d65faa8dc2925fe67b71b4.tar.bz2 |
libgo: update to Go 1.8 release candidate 1
Compiler changes:
* Change map assignment to use mapassign and assign value directly.
* Change string iteration to use decoderune, faster for ASCII strings.
* Change makeslice to take int, and use makeslice64 for larger values.
* Add new noverflow field to hmap struct used for maps.
Unresolved problems, to be fixed later:
* Commented out test in go/types/sizes_test.go that doesn't compile.
* Commented out reflect.TestStructOf test for padding after zero-sized field.
Reviewed-on: https://go-review.googlesource.com/35231
gotools/:
Updates for Go 1.8rc1.
* Makefile.am (go_cmd_go_files): Add bug.go.
(s-zdefaultcc): Write defaultPkgConfig.
* Makefile.in: Rebuild.
From-SVN: r244456
Diffstat (limited to 'libgo/go/net')
130 files changed, 9938 insertions, 2571 deletions
diff --git a/libgo/go/net/addrselect.go b/libgo/go/net/addrselect.go index 0b9d160..1ab9fc53 100644 --- a/libgo/go/net/addrselect.go +++ b/libgo/go/net/addrselect.go @@ -188,33 +188,17 @@ func (s *byRFC6724) Less(i, j int) bool { // Rule 9: Use longest matching prefix. // When DA and DB belong to the same address family (both are IPv6 or - // both are IPv4): If CommonPrefixLen(Source(DA), DA) > + // both are IPv4 [but see below]): If CommonPrefixLen(Source(DA), DA) > // CommonPrefixLen(Source(DB), DB), then prefer DA. Similarly, if // CommonPrefixLen(Source(DA), DA) < CommonPrefixLen(Source(DB), DB), // then prefer DB. - da4 := DA.To4() != nil - db4 := DB.To4() != nil - if da4 == db4 { + // + // However, applying this rule to IPv4 addresses causes + // problems (see issues 13283 and 18518), so limit to IPv6. + if DA.To4() == nil && DB.To4() == nil { commonA := commonPrefixLen(SourceDA, DA) commonB := commonPrefixLen(SourceDB, DB) - // CommonPrefixLen doesn't really make sense for IPv4, and even - // causes problems for common load balancing practices - // (e.g., https://golang.org/issue/13283). Glibc instead only - // uses CommonPrefixLen for IPv4 when the source and destination - // addresses are on the same subnet, but that requires extra - // work to find the netmask for our source addresses. As a - // simpler heuristic, we limit its use to when the source and - // destination belong to the same special purpose block. - if da4 { - if !sameIPv4SpecialPurposeBlock(SourceDA, DA) { - commonA = 0 - } - if !sameIPv4SpecialPurposeBlock(SourceDB, DB) { - commonB = 0 - } - } - if commonA > commonB { return preferDA } @@ -404,28 +388,3 @@ func commonPrefixLen(a, b IP) (cpl int) { } return } - -// sameIPv4SpecialPurposeBlock reports whether a and b belong to the same -// address block reserved by the IANA IPv4 Special-Purpose Address Registry: -// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml -func sameIPv4SpecialPurposeBlock(a, b IP) bool { - a, b = a.To4(), b.To4() - if a == nil || b == nil || a[0] != b[0] { - return false - } - // IANA defines more special-purpose blocks, but these are the only - // ones likely to be relevant to typical Go systems. - switch a[0] { - case 10: // 10.0.0.0/8: Private-Use - return true - case 127: // 127.0.0.0/8: Loopback - return true - case 169: // 169.254.0.0/16: Link Local - return a[1] == 254 && b[1] == 254 - case 172: // 172.16.0.0/12: Private-Use - return a[1]&0xf0 == 16 && b[1]&0xf0 == 16 - case 192: // 192.168.0.0/16: Private-Use - return a[1] == 168 && b[1] == 168 - } - return false -} diff --git a/libgo/go/net/addrselect_test.go b/libgo/go/net/addrselect_test.go index 80aa4eb..d6e0e63 100644 --- a/libgo/go/net/addrselect_test.go +++ b/libgo/go/net/addrselect_test.go @@ -117,27 +117,6 @@ func TestSortByRFC6724(t *testing.T) { }, reverse: false, }, - - // Prefer longer common prefixes, but only for IPv4 address - // pairs in the same special-purpose block. - { - in: []IPAddr{ - {IP: ParseIP("1.2.3.4")}, - {IP: ParseIP("10.55.0.1")}, - {IP: ParseIP("10.66.0.1")}, - }, - srcs: []IP{ - ParseIP("1.2.3.5"), - ParseIP("10.66.1.2"), - ParseIP("10.66.1.2"), - }, - want: []IPAddr{ - {IP: ParseIP("10.66.0.1")}, - {IP: ParseIP("10.55.0.1")}, - {IP: ParseIP("1.2.3.4")}, - }, - reverse: true, - }, } for i, tt := range tests { inCopy := make([]IPAddr, len(tt.in)) @@ -268,67 +247,3 @@ func TestRFC6724CommonPrefixLength(t *testing.T) { } } - -func mustParseCIDRs(t *testing.T, blocks ...string) []*IPNet { - res := make([]*IPNet, len(blocks)) - for i, block := range blocks { - var err error - _, res[i], err = ParseCIDR(block) - if err != nil { - t.Fatalf("ParseCIDR(%s) failed: %v", block, err) - } - } - return res -} - -func TestSameIPv4SpecialPurposeBlock(t *testing.T) { - blocks := mustParseCIDRs(t, - "10.0.0.0/8", - "127.0.0.0/8", - "169.254.0.0/16", - "172.16.0.0/12", - "192.168.0.0/16", - ) - - addrs := []struct { - ip IP - block int // index or -1 - }{ - {IP{1, 2, 3, 4}, -1}, - {IP{2, 3, 4, 5}, -1}, - {IP{10, 2, 3, 4}, 0}, - {IP{10, 6, 7, 8}, 0}, - {IP{127, 0, 0, 1}, 1}, - {IP{127, 255, 255, 255}, 1}, - {IP{169, 254, 77, 99}, 2}, - {IP{169, 254, 44, 22}, 2}, - {IP{169, 255, 0, 1}, -1}, - {IP{172, 15, 5, 6}, -1}, - {IP{172, 16, 32, 41}, 3}, - {IP{172, 31, 128, 9}, 3}, - {IP{172, 32, 88, 100}, -1}, - {IP{192, 168, 1, 1}, 4}, - {IP{192, 168, 128, 42}, 4}, - {IP{192, 169, 1, 1}, -1}, - } - - for i, addr := range addrs { - for j, block := range blocks { - got := block.Contains(addr.ip) - want := addr.block == j - if got != want { - t.Errorf("%d/%d. %s.Contains(%s): got %v, want %v", i, j, block, addr.ip, got, want) - } - } - } - - for i, addr1 := range addrs { - for j, addr2 := range addrs { - got := sameIPv4SpecialPurposeBlock(addr1.ip, addr2.ip) - want := addr1.block >= 0 && addr1.block == addr2.block - if got != want { - t.Errorf("%d/%d. sameIPv4SpecialPurposeBlock(%s, %s): got %v, want %v", i, j, addr1.ip, addr2.ip, got, want) - } - } - } -} diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go index 525c63c..a90aaa9 100644 --- a/libgo/go/net/cgo_unix.go +++ b/libgo/go/net/cgo_unix.go @@ -114,7 +114,15 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err } func cgoLookupServicePort(hints *syscall.Addrinfo, network, service string) (port int, err error) { - s := syscall.StringBytePtr(service) + s, err := syscall.BytePtrFromString(service) + if err != nil { + return 0, err + } + // Lowercase the service name in the memory passed to C. + for i := 0; i < len(service); i++ { + bp := (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(s)) + uintptr(i))) + *bp = lowerASCII(*bp) + } var res *syscall.Addrinfo syscall.Entersyscall() gerrno := libc_getaddrinfo(nil, s, hints, &res) diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go index eb72916..c10aafe 100644 --- a/libgo/go/net/conf.go +++ b/libgo/go/net/conf.go @@ -179,8 +179,6 @@ func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) { } } - hasDot := byteIndex(hostname, '.') != -1 - // Canonicalize the hostname by removing any trailing dot. if stringsHasSuffix(hostname, ".") { hostname = hostname[:len(hostname)-1] @@ -220,10 +218,14 @@ func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) { var first string for _, src := range srcs { if src.source == "myhostname" { - if hostname == "" || hasDot { - continue + if isLocalhost(hostname) || isGateway(hostname) { + return fallbackOrder } - return fallbackOrder + hn, err := getHostname() + if err != nil || stringsEqualFold(hostname, hn) { + return fallbackOrder + } + continue } if src.source == "files" || src.source == "dns" { if !src.standardCriteria() { @@ -293,7 +295,7 @@ func goDebugNetDNS() (dnsMode string, debugLevel int) { return } if '0' <= s[0] && s[0] <= '9' { - debugLevel, _, _ = dtoi(s, 0) + debugLevel, _, _ = dtoi(s) } else { dnsMode = s } @@ -306,3 +308,15 @@ func goDebugNetDNS() (dnsMode string, debugLevel int) { parsePart(goDebug) return } + +// isLocalhost reports whether h should be considered a "localhost" +// name for the myhostname NSS module. +func isLocalhost(h string) bool { + return stringsEqualFold(h, "localhost") || stringsEqualFold(h, "localhost.localdomain") || stringsHasSuffixFold(h, ".localhost") || stringsHasSuffixFold(h, ".localhost.localdomain") +} + +// isGateway reports whether h should be considered a "gateway" +// name for the myhostname NSS module. +func isGateway(h string) bool { + return stringsEqualFold(h, "gateway") +} diff --git a/libgo/go/net/conf_test.go b/libgo/go/net/conf_test.go index ec8814b..17d03f4 100644 --- a/libgo/go/net/conf_test.go +++ b/libgo/go/net/conf_test.go @@ -13,8 +13,9 @@ import ( ) type nssHostTest struct { - host string - want hostLookupOrder + host string + localhost string + want hostLookupOrder } func nssStr(s string) *nssConf { return parseNSSConf(strings.NewReader(s)) } @@ -42,8 +43,8 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"foo.local", hostLookupCgo}, - {"google.com", hostLookupCgo}, + {"foo.local", "myhostname", hostLookupCgo}, + {"google.com", "myhostname", hostLookupCgo}, }, }, { @@ -54,7 +55,7 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupDNSFiles}, + {"x.com", "myhostname", hostLookupDNSFiles}, }, }, { @@ -65,7 +66,7 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupFilesDNS}, + {"x.com", "myhostname", hostLookupFilesDNS}, }, }, { @@ -75,11 +76,11 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"foo.local", hostLookupCgo}, - {"foo.local.", hostLookupCgo}, - {"foo.LOCAL", hostLookupCgo}, - {"foo.LOCAL.", hostLookupCgo}, - {"google.com", hostLookupFilesDNS}, + {"foo.local", "myhostname", hostLookupCgo}, + {"foo.local.", "myhostname", hostLookupCgo}, + {"foo.LOCAL", "myhostname", hostLookupCgo}, + {"foo.LOCAL.", "myhostname", hostLookupCgo}, + {"google.com", "myhostname", hostLookupFilesDNS}, }, }, { @@ -89,7 +90,7 @@ func TestConfHostLookupOrder(t *testing.T) { nss: nssStr("foo: bar"), resolv: defaultResolvConf, }, - hostTests: []nssHostTest{{"google.com", hostLookupFilesDNS}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}}, }, // On OpenBSD, no resolv.conf means no DNS. { @@ -98,7 +99,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: defaultResolvConf, }, - hostTests: []nssHostTest{{"google.com", hostLookupFiles}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFiles}}, }, { name: "solaris_no_nsswitch", @@ -107,7 +108,7 @@ func TestConfHostLookupOrder(t *testing.T) { nss: &nssConf{err: os.ErrNotExist}, resolv: defaultResolvConf, }, - hostTests: []nssHostTest{{"google.com", hostLookupCgo}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}}, }, { name: "openbsd_lookup_bind_file", @@ -116,8 +117,8 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: &dnsConfig{lookup: []string{"bind", "file"}}, }, hostTests: []nssHostTest{ - {"google.com", hostLookupDNSFiles}, - {"foo.local", hostLookupDNSFiles}, + {"google.com", "myhostname", hostLookupDNSFiles}, + {"foo.local", "myhostname", hostLookupDNSFiles}, }, }, { @@ -126,7 +127,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: &dnsConfig{lookup: []string{"file", "bind"}}, }, - hostTests: []nssHostTest{{"google.com", hostLookupFilesDNS}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}}, }, { name: "openbsd_lookup_bind", @@ -134,7 +135,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: &dnsConfig{lookup: []string{"bind"}}, }, - hostTests: []nssHostTest{{"google.com", hostLookupDNS}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNS}}, }, { name: "openbsd_lookup_file", @@ -142,7 +143,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: &dnsConfig{lookup: []string{"file"}}, }, - hostTests: []nssHostTest{{"google.com", hostLookupFiles}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFiles}}, }, { name: "openbsd_lookup_yp", @@ -150,7 +151,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: &dnsConfig{lookup: []string{"file", "bind", "yp"}}, }, - hostTests: []nssHostTest{{"google.com", hostLookupCgo}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}}, }, { name: "openbsd_lookup_two", @@ -158,7 +159,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: &dnsConfig{lookup: []string{"file", "foo"}}, }, - hostTests: []nssHostTest{{"google.com", hostLookupCgo}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}}, }, { name: "openbsd_lookup_empty", @@ -166,7 +167,7 @@ func TestConfHostLookupOrder(t *testing.T) { goos: "openbsd", resolv: &dnsConfig{lookup: nil}, }, - hostTests: []nssHostTest{{"google.com", hostLookupDNSFiles}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}}, }, // glibc lacking an nsswitch.conf, per // http://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html @@ -177,7 +178,7 @@ func TestConfHostLookupOrder(t *testing.T) { nss: &nssConf{err: os.ErrNotExist}, resolv: defaultResolvConf, }, - hostTests: []nssHostTest{{"google.com", hostLookupDNSFiles}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}}, }, { name: "files_mdns_dns", @@ -186,8 +187,8 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupFilesDNS}, - {"x.local", hostLookupCgo}, + {"x.com", "myhostname", hostLookupFilesDNS}, + {"x.local", "myhostname", hostLookupCgo}, }, }, { @@ -197,9 +198,9 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupDNS}, - {"x\\.com", hostLookupCgo}, // punt on weird glibc escape - {"foo.com%en0", hostLookupCgo}, // and IPv6 zones + {"x.com", "myhostname", hostLookupDNS}, + {"x\\.com", "myhostname", hostLookupCgo}, // punt on weird glibc escape + {"foo.com%en0", "myhostname", hostLookupCgo}, // and IPv6 zones }, }, { @@ -210,8 +211,8 @@ func TestConfHostLookupOrder(t *testing.T) { hasMDNSAllow: true, }, hostTests: []nssHostTest{ - {"x.com", hostLookupCgo}, - {"x.local", hostLookupCgo}, + {"x.com", "myhostname", hostLookupCgo}, + {"x.local", "myhostname", hostLookupCgo}, }, }, { @@ -221,9 +222,9 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupFilesDNS}, - {"x", hostLookupFilesDNS}, - {"x.local", hostLookupCgo}, + {"x.com", "myhostname", hostLookupFilesDNS}, + {"x", "myhostname", hostLookupFilesDNS}, + {"x.local", "myhostname", hostLookupCgo}, }, }, { @@ -233,9 +234,9 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupDNSFiles}, - {"x", hostLookupDNSFiles}, - {"x.local", hostLookupCgo}, + {"x.com", "myhostname", hostLookupDNSFiles}, + {"x", "myhostname", hostLookupDNSFiles}, + {"x.local", "myhostname", hostLookupCgo}, }, }, { @@ -245,7 +246,7 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupCgo}, + {"x.com", "myhostname", hostLookupCgo}, }, }, { @@ -255,9 +256,23 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupFilesDNS}, - {"somehostname", hostLookupCgo}, - {"", hostLookupFilesDNS}, // Issue 13623 + {"x.com", "myhostname", hostLookupFilesDNS}, + {"myhostname", "myhostname", hostLookupCgo}, + {"myHostname", "myhostname", hostLookupCgo}, + {"myhostname.dot", "myhostname.dot", hostLookupCgo}, + {"myHostname.dot", "myhostname.dot", hostLookupCgo}, + {"gateway", "myhostname", hostLookupCgo}, + {"Gateway", "myhostname", hostLookupCgo}, + {"localhost", "myhostname", hostLookupCgo}, + {"Localhost", "myhostname", hostLookupCgo}, + {"anything.localhost", "myhostname", hostLookupCgo}, + {"Anything.localhost", "myhostname", hostLookupCgo}, + {"localhost.localdomain", "myhostname", hostLookupCgo}, + {"Localhost.Localdomain", "myhostname", hostLookupCgo}, + {"anything.localhost.localdomain", "myhostname", hostLookupCgo}, + {"Anything.Localhost.Localdomain", "myhostname", hostLookupCgo}, + {"somehostname", "myhostname", hostLookupFilesDNS}, + {"", "myhostname", hostLookupFilesDNS}, // Issue 13623 }, }, { @@ -267,8 +282,9 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupFilesDNS}, - {"somehostname", hostLookupCgo}, + {"x.com", "myhostname", hostLookupFilesDNS}, + {"somehostname", "myhostname", hostLookupFilesDNS}, + {"myhostname", "myhostname", hostLookupCgo}, }, }, // Debian Squeeze is just "dns,files", but lists all @@ -282,8 +298,8 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupDNSFiles}, - {"somehostname", hostLookupDNSFiles}, + {"x.com", "myhostname", hostLookupDNSFiles}, + {"somehostname", "myhostname", hostLookupDNSFiles}, }, }, { @@ -292,7 +308,7 @@ func TestConfHostLookupOrder(t *testing.T) { nss: nssStr("foo: bar"), resolv: &dnsConfig{servers: defaultNS, ndots: 1, timeout: 5, attempts: 2, unknownOpt: true}, }, - hostTests: []nssHostTest{{"google.com", hostLookupCgo}}, + hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}}, }, // Android should always use cgo. { @@ -303,12 +319,18 @@ func TestConfHostLookupOrder(t *testing.T) { resolv: defaultResolvConf, }, hostTests: []nssHostTest{ - {"x.com", hostLookupCgo}, + {"x.com", "myhostname", hostLookupCgo}, }, }, } + + origGetHostname := getHostname + defer func() { getHostname = origGetHostname }() + for _, tt := range tests { for _, ht := range tt.hostTests { + getHostname = func() (string, error) { return ht.localhost, nil } + gotOrder := tt.c.hostLookupOrder(ht.host) if gotOrder != ht.want { t.Errorf("%s: hostLookupOrder(%q) = %v; want %v", tt.name, ht.host, gotOrder, ht.want) diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go index 55edb43..50bba5a 100644 --- a/libgo/go/net/dial.go +++ b/libgo/go/net/dial.go @@ -59,6 +59,9 @@ type Dialer struct { // that do not support keep-alives ignore this field. KeepAlive time.Duration + // Resolver optionally specifies an alternate resolver to use. + Resolver *Resolver + // Cancel is an optional channel whose closure indicates that // the dial should be canceled. Not all types of dials support // cancelation. @@ -92,6 +95,13 @@ func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Tim return minNonzeroTime(earliest, d.Deadline) } +func (d *Dialer) resolver() *Resolver { + if d.Resolver != nil { + return d.Resolver + } + return DefaultResolver +} + // partialDeadline returns the deadline to use for a single address, // when multiple addresses are pending. func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { @@ -141,7 +151,7 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err switch afnet { case "ip", "ip4", "ip6": protostr := net[i+1:] - proto, i, ok := dtoi(protostr, 0) + proto, i, ok := dtoi(protostr) if !ok || i != len(protostr) { proto, err = lookupProtocol(ctx, protostr) if err != nil { @@ -153,10 +163,10 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err return "", 0, UnknownNetworkError(net) } -// resolverAddrList resolves addr using hint and returns a list of +// resolveAddrList resolves addr using hint and returns a list of // addresses. The result contains at least one address when error is // nil. -func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { +func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) { afnet, _, err := parseNetwork(ctx, network) if err != nil { return nil, err @@ -166,7 +176,6 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) ( } switch afnet { case "unix", "unixgram", "unixpacket": - // TODO(bradfitz): push down context addr, err := ResolveUnixAddr(afnet, addr) if err != nil { return nil, err @@ -176,7 +185,7 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) ( } return addrList{addr}, nil } - addrs, err := internetAddrList(ctx, afnet, addr) + addrs, err := r.internetAddrList(ctx, afnet, addr) if err != nil || op != "dial" || hint == nil { return addrs, err } @@ -221,7 +230,7 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) ( } } if len(naddrs) == 0 { - return nil, errNoSuitableAddress + return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()} } return naddrs, nil } @@ -256,6 +265,9 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) ( // Dial("ip6:ipv6-icmp", "2001:db8::1") // // For Unix networks, the address must be a file system path. +// +// If the host is resolved to multiple addresses, +// Dial will try each address in order until one succeeds. func Dial(network, address string) (Conn, error) { var d Dialer return d.Dial(network, address) @@ -290,6 +302,14 @@ func (d *Dialer) Dial(network, address string) (Conn, error) { // connected, any expiration of the context will not affect the // connection. // +// When using TCP, and the host in the address parameter resolves to multiple +// network addresses, any dial timeout (from d.Timeout or ctx) is spread +// over each consecutive dial, such that each is given an appropriate +// fraction of the time to connect. +// For example, if a host has 4 IP addresses and the timeout is 1 minute, +// the connect to each single address will be given 15 seconds to complete +// before trying the next one. +// // See func Dial for a description of the network and address // parameters. func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) { @@ -326,7 +346,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow) } - addrs, err := resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr) + addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr) if err != nil { return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err} } @@ -524,8 +544,11 @@ func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) // If host is omitted, as in ":8080", Listen listens on all available interfaces // instead of just the interface with the given host address. // See Dial for more details about address syntax. +// +// Listening on a hostname is not recommended because this creates a socket +// for at most one of its IP addresses. func Listen(net, laddr string) (Listener, error) { - addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) + addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } @@ -551,8 +574,11 @@ func Listen(net, laddr string) (Listener, error) { // If host is omitted, as in ":8080", ListenPacket listens on all available interfaces // instead of just the interface with the given host address. // See Dial for the syntax of laddr. +// +// Listening on a hostname is not recommended because this creates a socket +// for at most one of its IP addresses. func ListenPacket(net, laddr string) (PacketConn, error) { - addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil) + addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil) if err != nil { return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err} } diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go index 8b21e6b..9919d72 100644 --- a/libgo/go/net/dial_test.go +++ b/libgo/go/net/dial_test.go @@ -55,6 +55,23 @@ func TestProhibitionaryDialArg(t *testing.T) { } } +func TestDialLocal(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + _, port, err := SplitHostPort(ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + c, err := Dial("tcp", JoinHostPort("", port)) + if err != nil { + t.Fatal(err) + } + c.Close() +} + func TestDialTimeoutFDLeak(t *testing.T) { switch runtime.GOOS { case "plan9": @@ -125,6 +142,8 @@ func TestDialerDualStackFDLeak(t *testing.T) { t.Skipf("%s does not have full support of socktest", runtime.GOOS) case "windows": t.Skipf("not implemented a way to cancel dial racers in TCP SYN-SENT state on %s", runtime.GOOS) + case "openbsd": + testenv.SkipFlaky(t, 15157) } if !supportsIPv4 || !supportsIPv6 { t.Skip("both IPv4 and IPv6 are required") diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go index f1835b8..2ab5639 100644 --- a/libgo/go/net/dnsclient.go +++ b/libgo/go/net/dnsclient.go @@ -113,12 +113,20 @@ func equalASCIILabel(x, y string) bool { return true } +// isDomainName checks if a string is a presentation-format domain name +// (currently restricted to hostname-compatible "preferred name" LDH labels and +// SRV-like "underscore labels"; see golang.org/issue/12421). func isDomainName(s string) bool { // See RFC 1035, RFC 3696. - if len(s) == 0 { - return false - } - if len(s) > 255 { + // Presentation format has dots before every label except the first, and the + // terminal empty label is optional here because we assume fully-qualified + // (absolute) input. We must therefore reserve space for the first and last + // labels' length octets in wire format, where they are necessary and the + // maximum total length is 255. + // So our _effective_ maximum is 253, but 254 is not rejected if the last + // character is a dot. + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { return false } diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index b5b6ffb..4dd4e16 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -125,7 +125,7 @@ func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, // Calling Dial here is scary -- we have to be sure not to // dial a name that will require a DNS lookup, or Dial will // call back here to translate it. The DNS config parser has - // already checked that all the cfg.servers[i] are IP + // already checked that all the cfg.servers are IP // addresses, which Dial will use without a DNS lookup. c, err := d.DialContext(ctx, network, server) if err != nil { @@ -182,13 +182,14 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { - if len(cfg.servers) == 0 { - return "", nil, &DNSError{Err: "no DNS servers", Name: name} - } - var lastErr error + serverOffset := cfg.serverOffset() + sLen := uint32(len(cfg.servers)) + for i := 0; i < cfg.attempts; i++ { - for _, server := range cfg.servers { + for j := uint32(0); j < sLen; j++ { + server := cfg.servers[(serverOffset+j)%sLen] + msg, err := exchange(ctx, server, name, qtype, cfg.timeout) if err != nil { lastErr = &DNSError{ @@ -315,7 +316,12 @@ func (conf *resolverConfig) releaseSema() { func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { if !isDomainName(name) { - return "", nil, &DNSError{Err: "invalid domain name", Name: name} + // We used to use "invalid domain name" as the error, + // but that is a detail of the specific lookup mechanism. + // Other lookups might allow broader name syntax + // (for example Multicast DNS allows UTF-8; see RFC 6762). + // For consistency with libc resolvers, report no such host. + return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() @@ -356,14 +362,21 @@ func (conf *dnsConfig) nameList(name string) []string { return nil } + // Check name length (see isDomainName). + l := len(name) + rooted := l > 0 && name[l-1] == '.' + if l > 254 || l == 254 && rooted { + return nil + } + // If name is rooted (trailing dot), try only that name. - rooted := len(name) > 0 && name[len(name)-1] == '.' if rooted { return []string{name} } hasNdots := count(name, '.') >= conf.ndots name += "." + l++ // Build list of search choices. names := make([]string, 0, 1+len(conf.search)) @@ -371,9 +384,11 @@ func (conf *dnsConfig) nameList(name string) []string { if hasNdots { names = append(names, name) } - // Try suffixes. + // Try suffixes that are not too long (see isDomainName). for _, suffix := range conf.search { - names = append(names, name+suffix) + if l+len(suffix) <= 254 { + names = append(names, name+suffix) + } } // Try unsuffixed, if not tried first above. if !hasNdots { @@ -429,7 +444,7 @@ func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) return } } - ips, err := goLookupIPOrder(ctx, name, order) + ips, _, err := goLookupIPCNAMEOrder(ctx, name, order) if err != nil { return } @@ -455,27 +470,30 @@ func goLookupIPFiles(name string) (addrs []IPAddr) { // goLookupIP is the native Go implementation of LookupIP. // The libc versions are in cgo_*.go. -func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) { - return goLookupIPOrder(ctx, name, hostLookupFilesDNS) +func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + order := systemConf().hostLookupOrder(host) + addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order) + return } -func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) { +func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { addrs = goLookupIPFiles(name) if len(addrs) > 0 || order == hostLookupFiles { - return addrs, nil + return addrs, name, nil } } if !isDomainName(name) { - return nil, &DNSError{Err: "invalid domain name", Name: name} + // See comment in func lookup above about use of errNoSuchHost. + return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() type racer struct { - fqdn string - rrs []dnsRR + cname string + rrs []dnsRR error } lane := make(chan racer, 1) @@ -484,20 +502,23 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { go func(qtype uint16) { - _, rrs, err := tryOneName(ctx, conf, fqdn, qtype) - lane <- racer{fqdn, rrs, err} + cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype) + lane <- racer{cname, rrs, err} }(qtype) } for range qtypes { racer := <-lane if racer.error != nil { // Prefer error for original name. - if lastErr == nil || racer.fqdn == name+"." { + if lastErr == nil || fqdn == name+"." { lastErr = racer.error } continue } addrs = append(addrs, addrRecordList(racer.rrs)...) + if cname == "" { + cname = racer.cname + } } if len(addrs) > 0 { break @@ -515,24 +536,16 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a addrs = goLookupIPFiles(name) } if len(addrs) == 0 && lastErr != nil { - return nil, lastErr + return nil, "", lastErr } } - return addrs, nil + return addrs, cname, nil } -// goLookupCNAME is the native Go implementation of LookupCNAME. -// Used only if cgoLookupCNAME refuses to handle the request -// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go). -// Normally we let cgo use the C library resolver instead of -// depending on our lookup code, so that Go and C get the same -// answers. -func goLookupCNAME(ctx context.Context, name string) (cname string, err error) { - _, rrs, err := lookup(ctx, name, dnsTypeCNAME) - if err != nil { - return - } - cname = rrs[0].(*dnsRR_CNAME).Cname +// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME. +func goLookupCNAME(ctx context.Context, host string) (cname string, err error) { + order := systemConf().hostLookupOrder(host) + _, cname, err = goLookupIPCNAMEOrder(ctx, host, order) return } diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go index 6ebeeae..85267bb 100644 --- a/libgo/go/net/dnsclient_unix_test.go +++ b/libgo/go/net/dnsclient_unix_test.go @@ -411,7 +411,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { // We need to take care with errors on both // DNS message exchange layer and DNS // transport layer because goLookupIP may fail - // when the IP connectivty on node under test + // when the IP connectivity on node under test // gets lost during its run. if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) { t.Errorf("got %v; want %v", err, tt.error) @@ -455,14 +455,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) { name := fmt.Sprintf("order %v", order) // First ensure that we get an error when contacting a non-existent host. - _, err := goLookupIPOrder(context.Background(), "notarealhost", order) + _, _, err := goLookupIPCNAMEOrder(context.Background(), "notarealhost", order) if err == nil { t.Errorf("%s: expected error while looking up name not in hosts file", name) continue } // Now check that we get an address when the name appears in the hosts file. - addrs, err := goLookupIPOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" + addrs, _, err := goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" if err != nil { t.Errorf("%s: expected to successfully lookup host entry", name) continue @@ -668,12 +668,14 @@ func TestIgnoreDNSForgeries(t *testing.T) { b := make([]byte, 512) n, err := s.Read(b) if err != nil { - t.Fatal(err) + t.Error(err) + return } msg := &dnsMsg{} if !msg.Unpack(b[:n]) { - t.Fatal("invalid DNS query") + t.Error("invalid DNS query") + return } s.Write([]byte("garbage DNS response packet")) @@ -682,7 +684,8 @@ func TestIgnoreDNSForgeries(t *testing.T) { msg.id++ // make invalid ID b, ok := msg.Pack() if !ok { - t.Fatal("failed to pack DNS response") + t.Error("failed to pack DNS response") + return } s.Write(b) @@ -701,7 +704,8 @@ func TestIgnoreDNSForgeries(t *testing.T) { b, ok = msg.Pack() if !ok { - t.Fatal("failed to pack DNS response") + t.Error("failed to pack DNS response") + return } s.Write(b) }() @@ -740,8 +744,11 @@ func TestRetryTimeout(t *testing.T) { } defer conf.teardown() - if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will timeout - "nameserver 192.0.2.2"}); err != nil { + testConf := []string{ + "nameserver 192.0.2.1", // the one that will timeout + "nameserver 192.0.2.2", + } + if err := conf.writeAndUpdate(testConf); err != nil { t.Fatal(err) } @@ -767,28 +774,10 @@ func TestRetryTimeout(t *testing.T) { t.Error("deadline didn't change") } - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, - recursion_available: true, - }, - question: q.question, - answer: []dnsRR{ - &dnsRR_CNAME{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeCNAME, - Class: dnsClassINET, - }, - Cname: "golang.org", - }, - }, - } - return r, nil + return mockTXTResponse(q), nil } - _, err = goLookupCNAME(context.Background(), "www.golang.org") + _, err = LookupTXT("www.golang.org") if err != nil { t.Fatal(err) } @@ -797,3 +786,77 @@ func TestRetryTimeout(t *testing.T) { t.Error("deadline0 still zero", deadline0) } } + +func TestRotate(t *testing.T) { + // without rotation, always uses the first server + testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"}) + + // with rotation, rotates through back to first + testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"}) +} + +func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { + origTestHookDNSDialer := testHookDNSDialer + defer func() { testHookDNSDialer = origTestHookDNSDialer }() + + conf, err := newResolvConfTest() + if err != nil { + t.Fatal(err) + } + defer conf.teardown() + + var confLines []string + for _, ns := range nameservers { + confLines = append(confLines, "nameserver "+ns) + } + if rotate { + confLines = append(confLines, "options rotate") + } + + if err := conf.writeAndUpdate(confLines); err != nil { + t.Fatal(err) + } + + d := &fakeDNSDialer{} + testHookDNSDialer = func() dnsDialer { return d } + + var usedServers []string + d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + usedServers = append(usedServers, s) + return mockTXTResponse(q), nil + } + + // len(nameservers) + 1 to allow rotation to get back to start + for i := 0; i < len(nameservers)+1; i++ { + if _, err := LookupTXT("www.golang.org"); err != nil { + t.Fatal(err) + } + } + + if !reflect.DeepEqual(usedServers, wantServers) { + t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers) + } +} + +func mockTXTResponse(q *dnsMsg) *dnsMsg { + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + recursion_available: true, + }, + question: q.question, + answer: []dnsRR{ + &dnsRR_TXT{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeTXT, + Class: dnsClassINET, + }, + Txt: "ok", + }, + }, + } + + return r +} diff --git a/libgo/go/net/dnsconfig_unix.go b/libgo/go/net/dnsconfig_unix.go index aec575e..9c8108d 100644 --- a/libgo/go/net/dnsconfig_unix.go +++ b/libgo/go/net/dnsconfig_unix.go @@ -10,6 +10,7 @@ package net import ( "os" + "sync/atomic" "time" ) @@ -29,6 +30,7 @@ type dnsConfig struct { lookup []string // OpenBSD top-level database "lookup" order err error // any error that occurs during open of resolv.conf mtime time.Time // time of resolv.conf modification + soffset uint32 // used by serverOffset } // See resolv.conf(5) on a Linux machine. @@ -91,19 +93,21 @@ func dnsReadConfig(filename string) *dnsConfig { for _, s := range f[1:] { switch { case hasPrefix(s, "ndots:"): - n, _, _ := dtoi(s, 6) - if n < 1 { - n = 1 + n, _, _ := dtoi(s[6:]) + if n < 0 { + n = 0 + } else if n > 15 { + n = 15 } conf.ndots = n case hasPrefix(s, "timeout:"): - n, _, _ := dtoi(s, 8) + n, _, _ := dtoi(s[8:]) if n < 1 { n = 1 } conf.timeout = time.Duration(n) * time.Second case hasPrefix(s, "attempts:"): - n, _, _ := dtoi(s, 9) + n, _, _ := dtoi(s[9:]) if n < 1 { n = 1 } @@ -134,6 +138,17 @@ func dnsReadConfig(filename string) *dnsConfig { return conf } +// serverOffset returns an offset that can be used to determine +// indices of servers in c.servers when making queries. +// When the rotate option is enabled, this offset increases. +// Otherwise it is always 0. +func (c *dnsConfig) serverOffset() uint32 { + if c.rotate { + return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start + } + return 0 +} + func dnsDefaultSearch() []string { hn, err := getHostname() if err != nil { diff --git a/libgo/go/net/dnsconfig_unix_test.go b/libgo/go/net/dnsconfig_unix_test.go index 9fd6dbf..37bdeb0 100644 --- a/libgo/go/net/dnsconfig_unix_test.go +++ b/libgo/go/net/dnsconfig_unix_test.go @@ -10,6 +10,7 @@ import ( "errors" "os" "reflect" + "strings" "testing" "time" ) @@ -61,6 +62,36 @@ var dnsReadConfigTests = []struct { }, }, { + name: "testdata/invalid-ndots-resolv.conf", + want: &dnsConfig{ + servers: defaultNS, + ndots: 0, + timeout: 5 * time.Second, + attempts: 2, + search: []string{"domain.local."}, + }, + }, + { + name: "testdata/large-ndots-resolv.conf", + want: &dnsConfig{ + servers: defaultNS, + ndots: 15, + timeout: 5 * time.Second, + attempts: 2, + search: []string{"domain.local."}, + }, + }, + { + name: "testdata/negative-ndots-resolv.conf", + want: &dnsConfig{ + servers: defaultNS, + ndots: 0, + timeout: 5 * time.Second, + attempts: 2, + search: []string{"domain.local."}, + }, + }, + { name: "testdata/openbsd-resolv.conf", want: &dnsConfig{ ndots: 1, @@ -154,3 +185,55 @@ func TestDNSDefaultSearch(t *testing.T) { } } } + +func TestDNSNameLength(t *testing.T) { + origGetHostname := getHostname + defer func() { getHostname = origGetHostname }() + getHostname = func() (string, error) { return "host.domain.local", nil } + + var char63 = "" + for i := 0; i < 63; i++ { + char63 += "a" + } + longDomain := strings.Repeat(char63+".", 5) + "example" + + for _, tt := range dnsReadConfigTests { + conf := dnsReadConfig(tt.name) + if conf.err != nil { + t.Fatal(conf.err) + } + + var shortestSuffix int + for _, suffix := range tt.want.search { + if shortestSuffix == 0 || len(suffix) < shortestSuffix { + shortestSuffix = len(suffix) + } + } + + // Test a name that will be maximally long when prefixing the shortest + // suffix (accounting for the intervening dot). + longName := longDomain[len(longDomain)-254+1+shortestSuffix:] + if longName[0] == '.' || longName[1] == '.' { + longName = "aa." + longName[3:] + } + for _, fqdn := range conf.nameList(longName) { + if len(fqdn) > 254 { + t.Errorf("got %d; want less than or equal to 254", len(fqdn)) + } + } + + // Now test a name that's too long for suffixing. + unsuffixable := "a." + longName[1:] + unsuffixableResults := conf.nameList(unsuffixable) + if len(unsuffixableResults) != 1 { + t.Errorf("suffixed names %v; want []", unsuffixableResults[1:]) + } + + // Now test a name that's too long for DNS. + tooLong := "a." + longDomain + tooLongResults := conf.nameList(tooLong) + if tooLongResults != nil { + t.Errorf("suffixed names %v; want nil", tooLongResults) + } + } +} diff --git a/libgo/go/net/dnsmsg.go b/libgo/go/net/dnsmsg.go index afdb44c..8f6c7b6 100644 --- a/libgo/go/net/dnsmsg.go +++ b/libgo/go/net/dnsmsg.go @@ -69,7 +69,7 @@ const ( ) // A dnsStruct describes how to iterate over its fields to emulate -// reflective marshalling. +// reflective marshaling. type dnsStruct interface { // Walk iterates over fields of a structure and calls f // with a reference to that field, the name of the field diff --git a/libgo/go/net/dnsmsg_test.go b/libgo/go/net/dnsmsg_test.go index 25bd98c..2a25a21 100644 --- a/libgo/go/net/dnsmsg_test.go +++ b/libgo/go/net/dnsmsg_test.go @@ -117,7 +117,7 @@ func TestDNSParseSRVReply(t *testing.T) { if !ok { t.Fatal("unpacking packet failed") } - msg.String() // exercise this code path + _ = msg.String() // exercise this code path if g, e := len(msg.answer), 5; g != e { t.Errorf("len(msg.answer) = %d; want %d", g, e) } @@ -165,7 +165,7 @@ func TestDNSParseCorruptSRVReply(t *testing.T) { if !ok { t.Fatal("unpacking packet failed") } - msg.String() // exercise this code path + _ = msg.String() // exercise this code path if g, e := len(msg.answer), 5; g != e { t.Errorf("len(msg.answer) = %d; want %d", g, e) } @@ -393,7 +393,7 @@ func TestIsResponseTo(t *testing.T) { for i := range badResponses { if badResponses[i].IsResponseTo(&query) { - t.Error("%v: got true, want false", i) + t.Errorf("%v: got true, want false", i) } } } diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go index bc777b8..e0f786d 100644 --- a/libgo/go/net/dnsname_test.go +++ b/libgo/go/net/dnsname_test.go @@ -32,14 +32,12 @@ var dnsNameTests = []dnsNameTest{ func emitDNSNameTest(ch chan<- dnsNameTest) { defer close(ch) - var char59 = "" var char63 = "" - var char64 = "" - for i := 0; i < 59; i++ { - char59 += "a" + for i := 0; i < 63; i++ { + char63 += "a" } - char63 = char59 + "aaaa" - char64 = char63 + "a" + char64 := char63 + "a" + longDomain := strings.Repeat(char63+".", 5) + "example" for _, tc := range dnsNameTests { ch <- tc @@ -47,14 +45,15 @@ func emitDNSNameTest(ch chan<- dnsNameTest) { ch <- dnsNameTest{char63 + ".com", true} ch <- dnsNameTest{char64 + ".com", false} - // 255 char name is fine: - ch <- dnsNameTest{char59 + "." + char63 + "." + char63 + "." + - char63 + ".com", - true} - // 256 char name is bad: - ch <- dnsNameTest{char59 + "a." + char63 + "." + char63 + "." + - char63 + ".com", - false} + + // Remember: wire format is two octets longer than presentation + // (length octets for the first and [root] last labels). + // 253 is fine: + ch <- dnsNameTest{longDomain[len(longDomain)-253:], true} + // A terminal dot doesn't contribute to length: + ch <- dnsNameTest{longDomain[len(longDomain)-253:] + ".", true} + // 254 is bad: + ch <- dnsNameTest{longDomain[len(longDomain)-254:], false} } func TestDNSName(t *testing.T) { diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go index d6de5a3..c23da49 100644 --- a/libgo/go/net/error_test.go +++ b/libgo/go/net/error_test.go @@ -97,7 +97,8 @@ second: goto third } switch nestedErr { - case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress: + case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress, + context.DeadlineExceeded, context.Canceled: return nil } return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr) @@ -230,23 +231,27 @@ func TestDialAddrError(t *testing.T) { } { var err error var c Conn + var op string if tt.lit != "" { c, err = Dial(tt.network, JoinHostPort(tt.lit, "0")) + op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0")) } else { c, err = DialTCP(tt.network, nil, tt.addr) + op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr) } if err == nil { c.Close() - t.Errorf("%s %q/%v: should fail", tt.network, tt.lit, tt.addr) + t.Errorf("%s succeeded, want error", op) continue } if perr := parseDialError(err); perr != nil { - t.Error(perr) + t.Errorf("%s: %v", op, perr) continue } - aerr, ok := err.(*OpError).Err.(*AddrError) + operr := err.(*OpError).Err + aerr, ok := operr.(*AddrError) if !ok { - t.Errorf("%s %q/%v: should be AddrError: %v", tt.network, tt.lit, tt.addr, err) + t.Errorf("%s: %v is %T, want *AddrError", op, err, operr) continue } want := tt.lit @@ -254,7 +259,7 @@ func TestDialAddrError(t *testing.T) { want = tt.addr.IP.String() } if aerr.Addr != want { - t.Fatalf("%s: got %q; want %q", tt.network, aerr.Addr, want) + t.Errorf("%s: %v, error Addr=%q, want %q", op, err, aerr.Addr, want) } } } @@ -521,6 +526,10 @@ third: if isPlatformError(nestedErr) { return nil } + switch nestedErr { + case os.ErrClosed: // for Plan 9 + return nil + } return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr) } diff --git a/libgo/go/net/fd_io_plan9.go b/libgo/go/net/fd_io_plan9.go new file mode 100644 index 0000000..76da0c5 --- /dev/null +++ b/libgo/go/net/fd_io_plan9.go @@ -0,0 +1,93 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "os" + "runtime" + "sync" + "syscall" +) + +// asyncIO implements asynchronous cancelable I/O. +// An asyncIO represents a single asynchronous Read or Write +// operation. The result is returned on the result channel. +// The undergoing I/O system call can either complete or be +// interrupted by a note. +type asyncIO struct { + res chan result + + // mu guards the pid field. + mu sync.Mutex + + // pid holds the process id of + // the process running the IO operation. + pid int +} + +// result is the return value of a Read or Write operation. +type result struct { + n int + err error +} + +// newAsyncIO returns a new asyncIO that performs an I/O +// operation by calling fn, which must do one and only one +// interruptible system call. +func newAsyncIO(fn func([]byte) (int, error), b []byte) *asyncIO { + aio := &asyncIO{ + res: make(chan result, 0), + } + aio.mu.Lock() + go func() { + // Lock the current goroutine to its process + // and store the pid in io so that Cancel can + // interrupt it. We ignore the "hangup" signal, + // so the signal does not take down the entire + // Go runtime. + runtime.LockOSThread() + runtime_ignoreHangup() + aio.pid = os.Getpid() + aio.mu.Unlock() + + n, err := fn(b) + + aio.mu.Lock() + aio.pid = -1 + runtime_unignoreHangup() + aio.mu.Unlock() + + aio.res <- result{n, err} + }() + return aio +} + +var hangupNote os.Signal = syscall.Note("hangup") + +// Cancel interrupts the I/O operation, causing +// the Wait function to return. +func (aio *asyncIO) Cancel() { + aio.mu.Lock() + defer aio.mu.Unlock() + if aio.pid == -1 { + return + } + proc, err := os.FindProcess(aio.pid) + if err != nil { + return + } + proc.Signal(hangupNote) +} + +// Wait for the I/O operation to complete. +func (aio *asyncIO) Wait() (int, error) { + res := <-aio.res + return res.n, res.err +} + +// The following functions, provided by the runtime, are used to +// ignore and unignore the "hangup" signal received by the process. +func runtime_ignoreHangup() +func runtime_unignoreHangup() diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go index 7533232..300d8c4 100644 --- a/libgo/go/net/fd_plan9.go +++ b/libgo/go/net/fd_plan9.go @@ -7,21 +7,37 @@ package net import ( "io" "os" + "sync/atomic" "syscall" "time" ) +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } + // Network file descriptor. type netFD struct { // locking/lifetime of sysfd + serialize access to Read and Write methods fdmu fdMutex // immutable until Close - net string - n string - dir string - ctl, data *os.File - laddr, raddr Addr + net string + n string + dir string + listen, ctl, data *os.File + laddr, raddr Addr + isStream bool + + // deadlines + raio *asyncIO + waio *asyncIO + rtimer *time.Timer + wtimer *time.Timer + rtimedout atomicBool // set true when read deadline has been reached + wtimedout atomicBool // set true when write deadline has been reached } var ( @@ -32,8 +48,16 @@ func sysInit() { netdir = "/net" } -func newFD(net, name string, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) { - return &netFD{net: net, n: name, dir: netdir + "/" + net + "/" + name, ctl: ctl, data: data, laddr: laddr, raddr: raddr}, nil +func newFD(net, name string, listen, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) { + return &netFD{ + net: net, + n: name, + dir: netdir + "/" + net + "/" + name, + listen: listen, + ctl: ctl, data: data, + laddr: laddr, + raddr: raddr, + }, nil } func (fd *netFD) init() error { @@ -64,11 +88,20 @@ func (fd *netFD) destroy() { err = err1 } } + if fd.listen != nil { + if err1 := fd.listen.Close(); err1 != nil && err == nil { + err = err1 + } + } fd.ctl = nil fd.data = nil + fd.listen = nil } func (fd *netFD) Read(b []byte) (n int, err error) { + if fd.rtimedout.isSet() { + return 0, errTimeout + } if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } @@ -79,10 +112,15 @@ func (fd *netFD) Read(b []byte) (n int, err error) { if len(b) == 0 { return 0, nil } - n, err = fd.data.Read(b) + fd.raio = newAsyncIO(fd.data.Read, b) + n, err = fd.raio.Wait() + fd.raio = nil if isHangup(err) { err = io.EOF } + if isInterrupted(err) { + err = errTimeout + } if fd.net == "udp" && err == io.EOF { n = 0 err = nil @@ -91,6 +129,9 @@ func (fd *netFD) Read(b []byte) (n int, err error) { } func (fd *netFD) Write(b []byte) (n int, err error) { + if fd.wtimedout.isSet() { + return 0, errTimeout + } if !fd.ok() || fd.data == nil { return 0, syscall.EINVAL } @@ -98,7 +139,13 @@ func (fd *netFD) Write(b []byte) (n int, err error) { return 0, err } defer fd.writeUnlock() - return fd.data.Write(b) + fd.waio = newAsyncIO(fd.data.Write, b) + n, err = fd.waio.Wait() + fd.waio = nil + if isInterrupted(err) { + err = errTimeout + } + return } func (fd *netFD) closeRead() error { @@ -124,11 +171,10 @@ func (fd *netFD) Close() error { } if fd.net == "tcp" { // The following line is required to unblock Reads. - // For some reason, WriteString returns an error: - // "write /net/tcp/39/listen: inappropriate use of fd" - // But without it, Reads on dead conns hang forever. - // See Issue 9554. - fd.ctl.WriteString("hangup") + _, err := fd.ctl.WriteString("close") + if err != nil { + return err + } } err := fd.ctl.Close() if fd.data != nil { @@ -136,8 +182,14 @@ func (fd *netFD) Close() error { err = err1 } } + if fd.listen != nil { + if err1 := fd.listen.Close(); err1 != nil && err == nil { + err = err1 + } + } fd.ctl = nil fd.data = nil + fd.listen = nil return err } @@ -165,15 +217,74 @@ func (fd *netFD) file(f *os.File, s string) (*os.File, error) { } func (fd *netFD) setDeadline(t time.Time) error { - return syscall.EPLAN9 + return setDeadlineImpl(fd, t, 'r'+'w') } func (fd *netFD) setReadDeadline(t time.Time) error { - return syscall.EPLAN9 + return setDeadlineImpl(fd, t, 'r') } func (fd *netFD) setWriteDeadline(t time.Time) error { - return syscall.EPLAN9 + return setDeadlineImpl(fd, t, 'w') +} + +func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { + d := t.Sub(time.Now()) + if mode == 'r' || mode == 'r'+'w' { + fd.rtimedout.setFalse() + } + if mode == 'w' || mode == 'r'+'w' { + fd.wtimedout.setFalse() + } + if t.IsZero() || d < 0 { + // Stop timer + if mode == 'r' || mode == 'r'+'w' { + if fd.rtimer != nil { + fd.rtimer.Stop() + } + fd.rtimer = nil + } + if mode == 'w' || mode == 'r'+'w' { + if fd.wtimer != nil { + fd.wtimer.Stop() + } + fd.wtimer = nil + } + } else { + // Interrupt I/O operation once timer has expired + if mode == 'r' || mode == 'r'+'w' { + fd.rtimer = time.AfterFunc(d, func() { + fd.rtimedout.setTrue() + if fd.raio != nil { + fd.raio.Cancel() + } + }) + } + if mode == 'w' || mode == 'r'+'w' { + fd.wtimer = time.AfterFunc(d, func() { + fd.wtimedout.setTrue() + if fd.waio != nil { + fd.waio.Cancel() + } + }) + } + } + if !t.IsZero() && d < 0 { + // Interrupt current I/O operation + if mode == 'r' || mode == 'r'+'w' { + fd.rtimedout.setTrue() + if fd.raio != nil { + fd.raio.Cancel() + } + } + if mode == 'w' || mode == 'r'+'w' { + fd.wtimedout.setTrue() + if fd.waio != nil { + fd.waio.Cancel() + } + } + } + return nil } func setReadBuffer(fd *netFD, bytes int) error { @@ -187,3 +298,7 @@ func setWriteBuffer(fd *netFD, bytes int) error { func isHangup(err error) bool { return err != nil && stringsHasSuffix(err.Error(), "Hangup") } + +func isInterrupted(err error) bool { + return err != nil && stringsHasSuffix(err.Error(), "interrupted") +} diff --git a/libgo/go/net/fd_poll_nacl.go b/libgo/go/net/fd_poll_nacl.go index cda8b82..8398760 100644 --- a/libgo/go/net/fd_poll_nacl.go +++ b/libgo/go/net/fd_poll_nacl.go @@ -5,6 +5,7 @@ package net import ( + "runtime" "syscall" "time" ) @@ -22,6 +23,7 @@ func (pd *pollDesc) evict() { pd.closing = true if pd.fd != nil { syscall.StopIO(pd.fd.sysfd) + runtime.KeepAlive(pd.fd) } } diff --git a/libgo/go/net/fd_poll_runtime.go b/libgo/go/net/fd_poll_runtime.go index 6c1d095..62b69fc 100644 --- a/libgo/go/net/fd_poll_runtime.go +++ b/libgo/go/net/fd_poll_runtime.go @@ -7,6 +7,7 @@ package net import ( + "runtime" "sync" "syscall" "time" @@ -33,6 +34,7 @@ var serverInit sync.Once func (pd *pollDesc) init(fd *netFD) error { serverInit.Do(runtime_pollServerInit) ctx, errno := runtime_pollOpen(uintptr(fd.sysfd)) + runtime.KeepAlive(fd) if errno != 0 { return syscall.Errno(errno) } @@ -120,7 +122,7 @@ func (fd *netFD) setWriteDeadline(t time.Time) error { } func setDeadlineImpl(fd *netFD, t time.Time, mode int) error { - diff := int64(t.Sub(time.Now())) + diff := int64(time.Until(t)) d := runtimeNano() + diff if d <= 0 && diff > 0 { // If the user has a deadline in the future, but the delay calculation diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index 0309db0..9bc5ebc 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -24,11 +24,15 @@ type netFD struct { sysfd int family int sotype int + isStream bool isConnected bool net string laddr Addr raddr Addr + // writev cache. + iovecs *[]syscall.Iovec + // wait server pd pollDesc } @@ -37,7 +41,7 @@ func sysInit() { } func newFD(sysfd, family, sotype int, net string) (*netFD, error) { - return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil + return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil } func (fd *netFD) init() error { @@ -235,6 +239,9 @@ func (fd *netFD) Read(p []byte) (n int, err error) { if err := fd.pd.prepareRead(); err != nil { return 0, err } + if fd.isStream && len(p) > 1<<30 { + p = p[:1<<30] + } for { n, err = syscall.Read(fd.sysfd, p) if err != nil { @@ -318,7 +325,11 @@ func (fd *netFD) Write(p []byte) (nn int, err error) { } for { var n int - n, err = syscall.Write(fd.sysfd, p[nn:]) + max := len(p) + if fd.isStream && max-nn > 1<<30 { + max = nn + 1<<30 + } + n, err = syscall.Write(fd.sysfd, p[nn:max]) if n > 0 { nn += n } diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go index b0b6769..a976f2a 100644 --- a/libgo/go/net/fd_windows.go +++ b/libgo/go/net/fd_windows.go @@ -96,6 +96,7 @@ type operation struct { rsan int32 handle syscall.Handle flags uint32 + bufs []syscall.WSABuf } func (o *operation) InitBuf(buf []byte) { @@ -106,6 +107,30 @@ func (o *operation) InitBuf(buf []byte) { } } +func (o *operation) InitBufs(buf *Buffers) { + if o.bufs == nil { + o.bufs = make([]syscall.WSABuf, 0, len(*buf)) + } else { + o.bufs = o.bufs[:0] + } + for _, b := range *buf { + var p *byte + if len(b) > 0 { + p = &b[0] + } + o.bufs = append(o.bufs, syscall.WSABuf{Len: uint32(len(b)), Buf: p}) + } +} + +// ClearBufs clears all pointers to Buffers parameter captured +// by InitBufs, so it can be released by garbage collector. +func (o *operation) ClearBufs() { + for i := range o.bufs { + o.bufs[i].Buf = nil + } + o.bufs = o.bufs[:0] +} + // ioSrv executes net IO requests. type ioSrv struct { req chan ioSrvReq @@ -239,6 +264,7 @@ type netFD struct { sysfd syscall.Handle family int sotype int + isStream bool isConnected bool skipSyncNotif bool net string @@ -257,7 +283,7 @@ func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) return nil, initErr } onceStartServer.Do(startServer) - return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil + return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil } func (fd *netFD) init() error { @@ -483,6 +509,42 @@ func (fd *netFD) Write(buf []byte) (int, error) { return n, err } +func (c *conn) writeBuffers(v *Buffers) (int64, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + n, err := c.fd.writeBuffers(v) + if err != nil { + return n, &OpError{Op: "WSASend", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return n, nil +} + +func (fd *netFD) writeBuffers(buf *Buffers) (int64, error) { + if len(*buf) == 0 { + return 0, nil + } + if err := fd.writeLock(); err != nil { + return 0, err + } + defer fd.writeUnlock() + if race.Enabled { + race.ReleaseMerge(unsafe.Pointer(&ioSync)) + } + o := &fd.wop + o.InitBufs(buf) + n, err := wsrv.ExecIO(o, "WSASend", func(o *operation) error { + return syscall.WSASend(o.fd.sysfd, &o.bufs[0], uint32(len(*buf)), &o.qty, 0, &o.o, nil) + }) + o.ClearBufs() + if _, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError("wsasend", err) + } + testHookDidWritev(n) + buf.consume(int64(n)) + return int64(n), err +} + func (fd *netFD) writeTo(buf []byte, sa syscall.Sockaddr) (int, error) { if len(buf) == 0 { return 0, nil @@ -541,7 +603,7 @@ func (fd *netFD) acceptOne(rawsa []syscall.RawSockaddrAny, o *operation) (*netFD netfd.Close() return nil, os.NewSyscallError("setsockopt", err) } - + runtime.KeepAlive(fd) return netfd, nil } diff --git a/libgo/go/net/file.go b/libgo/go/net/file.go index 1aad477..07099851 100644 --- a/libgo/go/net/file.go +++ b/libgo/go/net/file.go @@ -6,6 +6,9 @@ package net import "os" +// BUG(mikio): On NaCl and Windows, the FileConn, FileListener and +// FilePacketConn functions are not implemented. + type fileAddr string func (fileAddr) Network() string { return "file+net" } diff --git a/libgo/go/net/file_plan9.go b/libgo/go/net/file_plan9.go index 2939c09..d16e5a1 100644 --- a/libgo/go/net/file_plan9.go +++ b/libgo/go/net/file_plan9.go @@ -81,7 +81,7 @@ func newFileFD(f *os.File) (net *netFD, err error) { if err != nil { return nil, err } - return newFD(comp[1], name, ctl, nil, laddr, nil) + return newFD(comp[1], name, nil, ctl, nil, laddr, nil) } func fileConn(f *os.File) (Conn, error) { diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 993c247..d368bae 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -18,6 +18,7 @@ import ( "io/ioutil" "log" "net/url" + "sort" "strings" "sync" "time" @@ -33,6 +34,25 @@ import ( // A Client is higher-level than a RoundTripper (such as Transport) // and additionally handles HTTP details such as cookies and // redirects. +// +// When following redirects, the Client will forward all headers set on the +// initial Request except: +// +// * when forwarding sensitive headers like "Authorization", +// "WWW-Authenticate", and "Cookie" to untrusted targets. +// These headers will be ignored when following a redirect to a domain +// that is not a subdomain match or exact match of the initial domain. +// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" +// will forward the sensitive headers, but a redirect to "bar.com" will not. +// +// * when forwarding the "Cookie" header with a non-nil cookie Jar. +// Since each redirect may mutate the state of the cookie jar, +// a redirect may possibly alter a cookie set in the initial request. +// When forwarding the "Cookie" header, any mutated cookies will be omitted, +// with the expectation that the Jar will insert those mutated cookies +// with the updated values (assuming the origin matches). +// If Jar is nil, the initial cookies are forwarded without change. +// type Client struct { // Transport specifies the mechanism by which individual // HTTP requests are made. @@ -56,8 +76,14 @@ type Client struct { CheckRedirect func(req *Request, via []*Request) error // Jar specifies the cookie jar. - // If Jar is nil, cookies are not sent in requests and ignored - // in responses. + // + // The Jar is used to insert relevant cookies into every + // outbound Request and is updated with the cookie values + // of every inbound Response. The Jar is consulted for every + // redirect that the Client follows. + // + // If Jar is nil, cookies are only sent if they are explicitly + // set on the Request. Jar CookieJar // Timeout specifies a time limit for requests made by this @@ -137,56 +163,23 @@ func refererForURL(lastReq, newReq *url.URL) string { return referer } -func (c *Client) send(req *Request, deadline time.Time) (*Response, error) { +// didTimeout is non-nil only if err != nil. +func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) } } - resp, err := send(req, c.transport(), deadline) + resp, didTimeout, err = send(req, c.transport(), deadline) if err != nil { - return nil, err + return nil, didTimeout, err } if c.Jar != nil { if rc := resp.Cookies(); len(rc) > 0 { c.Jar.SetCookies(req.URL, rc) } } - return resp, nil -} - -// Do sends an HTTP request and returns an HTTP response, following -// policy (such as redirects, cookies, auth) as configured on the -// client. -// -// An error is returned if caused by client policy (such as -// CheckRedirect), or failure to speak HTTP (such as a network -// connectivity problem). A non-2xx status code doesn't cause an -// error. -// -// If the returned error is nil, the Response will contain a non-nil -// Body which the user is expected to close. If the Body is not -// closed, the Client's underlying RoundTripper (typically Transport) -// may not be able to re-use a persistent TCP connection to the server -// for a subsequent "keep-alive" request. -// -// The request Body, if non-nil, will be closed by the underlying -// Transport, even on errors. -// -// On error, any Response can be ignored. A non-nil Response with a -// non-nil error only occurs when CheckRedirect fails, and even then -// the returned Response.Body is already closed. -// -// Generally Get, Post, or PostForm will be used instead of Do. -func (c *Client) Do(req *Request) (*Response, error) { - method := valueOrDefault(req.Method, "GET") - if method == "GET" || method == "HEAD" { - return c.doFollowingRedirects(req, shouldRedirectGet) - } - if method == "POST" || method == "PUT" { - return c.doFollowingRedirects(req, shouldRedirectPost) - } - return c.send(req, c.deadline()) + return resp, nil, nil } func (c *Client) deadline() time.Time { @@ -205,22 +198,22 @@ func (c *Client) transport() RoundTripper { // send issues an HTTP request. // Caller should close resp.Body when done reading from it. -func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) { +func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { req := ireq // req is either the original request, or a modified fork if rt == nil { req.closeBody() - return nil, errors.New("http: no Client.Transport or DefaultTransport") + return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport") } if req.URL == nil { req.closeBody() - return nil, errors.New("http: nil Request.URL") + return nil, alwaysFalse, errors.New("http: nil Request.URL") } if req.RequestURI != "" { req.closeBody() - return nil, errors.New("http: Request.RequestURI can't be set in client requests.") + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests.") } // forkReq forks req into a shallow clone of ireq the first @@ -251,9 +244,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) if !deadline.IsZero() { forkReq() } - stopTimer, wasCanceled := setRequestCancel(req, rt, deadline) + stopTimer, didTimeout := setRequestCancel(req, rt, deadline) - resp, err := rt.RoundTrip(req) + resp, err = rt.RoundTrip(req) if err != nil { stopTimer() if resp != nil { @@ -267,22 +260,27 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) err = errors.New("http: server gave HTTP response to HTTPS client") } } - return nil, err + return nil, didTimeout, err } if !deadline.IsZero() { resp.Body = &cancelTimerBody{ - stop: stopTimer, - rc: resp.Body, - reqWasCanceled: wasCanceled, + stop: stopTimer, + rc: resp.Body, + reqDidTimeout: didTimeout, } } - return resp, nil + return resp, nil, nil } // setRequestCancel sets the Cancel field of req, if deadline is // non-zero. The RoundTripper's type is used to determine whether the legacy // CancelRequest behavior should be used. -func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) { +// +// As background, there are three ways to cancel a request: +// First was Transport.CancelRequest. (deprecated) +// Second was Request.Cancel (this mechanism). +// Third was Request.Context. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { if deadline.IsZero() { return nop, alwaysFalse } @@ -292,17 +290,8 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi cancel := make(chan struct{}) req.Cancel = cancel - wasCanceled = func() bool { - select { - case <-cancel: - return true - default: - return false - } - } - doCancel := func() { - // The new way: + // The newer way (the second way in the func comment): close(cancel) // The legacy compatibility way, used only @@ -324,19 +313,23 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi var once sync.Once stopTimer = func() { once.Do(func() { close(stopTimerCh) }) } - timer := time.NewTimer(deadline.Sub(time.Now())) + timer := time.NewTimer(time.Until(deadline)) + var timedOut atomicBool + go func() { select { case <-initialReqCancel: doCancel() + timer.Stop() case <-timer.C: + timedOut.setTrue() doCancel() case <-stopTimerCh: timer.Stop() } }() - return stopTimer, wasCanceled + return stopTimer, timedOut.isSet } // See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt @@ -349,26 +342,6 @@ func basicAuth(username, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } -// True if the specified HTTP status code is one for which the Get utility should -// automatically redirect. -func shouldRedirectGet(statusCode int) bool { - switch statusCode { - case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: - return true - } - return false -} - -// True if the specified HTTP status code is one for which the Post utility should -// automatically redirect. -func shouldRedirectPost(statusCode int) bool { - switch statusCode { - case StatusFound, StatusSeeOther: - return true - } - return false -} - // Get issues a GET to the specified URL. If the response is one of // the following redirect codes, Get follows the redirect, up to a // maximum of 10 redirects: @@ -377,6 +350,7 @@ func shouldRedirectPost(statusCode int) bool { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) // // An error is returned if there were too many redirects or if there // was an HTTP protocol error. A non-2xx response doesn't cause an @@ -401,6 +375,7 @@ func Get(url string) (resp *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) // // An error is returned if the Client's CheckRedirect function fails // or if there was an HTTP protocol error. A non-2xx response doesn't @@ -415,7 +390,7 @@ func (c *Client) Get(url string) (resp *Response, err error) { if err != nil { return nil, err } - return c.doFollowingRedirects(req, shouldRedirectGet) + return c.Do(req) } func alwaysFalse() bool { return false } @@ -436,16 +411,92 @@ func (c *Client) checkRedirect(req *Request, via []*Request) error { return fn(req, via) } -func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) bool) (*Response, error) { +// redirectBehavior describes what should happen when the +// client encounters a 3xx status code from the server +func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect bool) { + switch resp.StatusCode { + case 301, 302, 303: + redirectMethod = reqMethod + shouldRedirect = true + + // RFC 2616 allowed automatic redirection only with GET and + // HEAD requests. RFC 7231 lifts this restriction, but we still + // restrict other methods to GET to maintain compatibility. + // See Issue 18570. + if reqMethod != "GET" && reqMethod != "HEAD" { + redirectMethod = "GET" + } + case 307, 308: + redirectMethod = reqMethod + shouldRedirect = true + + // Treat 307 and 308 specially, since they're new in + // Go 1.8, and they also require re-sending the request body. + if resp.Header.Get("Location") == "" { + // 308s have been observed in the wild being served + // without Location headers. Since Go 1.7 and earlier + // didn't follow these codes, just stop here instead + // of returning an error. + // See Issue 17773. + shouldRedirect = false + break + } + if ireq.GetBody == nil && ireq.outgoingLength() != 0 { + // We had a request body, and 307/308 require + // re-sending it, but GetBody is not defined. So just + // return this response to the user instead of an + // error, like we did in Go 1.7 and earlier. + shouldRedirect = false + } + } + return redirectMethod, shouldRedirect +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (such as redirects, cookies, auth) as configured on the +// client. +// +// An error is returned if caused by client policy (such as +// CheckRedirect), or failure to speak HTTP (such as a network +// connectivity problem). A non-2xx status code doesn't cause an +// error. +// +// If the returned error is nil, the Response will contain a non-nil +// Body which the user is expected to close. If the Body is not +// closed, the Client's underlying RoundTripper (typically Transport) +// may not be able to re-use a persistent TCP connection to the server +// for a subsequent "keep-alive" request. +// +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when CheckRedirect fails, and even then +// the returned Response.Body is already closed. +// +// Generally Get, Post, or PostForm will be used instead of Do. +// +// If the server replies with a redirect, the Client first uses the +// CheckRedirect function to determine whether the redirect should be +// followed. If permitted, a 301, 302, or 303 redirect causes +// subsequent requests to use HTTP method GET +// (or HEAD if the original request was HEAD), with no body. +// A 307 or 308 redirect preserves the original HTTP method and body, +// provided that the Request.GetBody function is defined. +// The NewRequest function automatically sets GetBody for common +// standard library body types. +func (c *Client) Do(req *Request) (*Response, error) { if req.URL == nil { req.closeBody() return nil, errors.New("http: nil Request.URL") } var ( - deadline = c.deadline() - reqs []*Request - resp *Response + deadline = c.deadline() + reqs []*Request + resp *Response + copyHeaders = c.makeHeadersCopier(req) + redirectMethod string ) uerr := func(err error) error { req.closeBody() @@ -476,16 +527,27 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo } ireq := reqs[0] req = &Request{ - Method: ireq.Method, + Method: redirectMethod, Response: resp, URL: u, Header: make(Header), Cancel: ireq.Cancel, ctx: ireq.ctx, } - if ireq.Method == "POST" || ireq.Method == "PUT" { - req.Method = "GET" + if ireq.GetBody != nil { + req.Body, err = ireq.GetBody() + if err != nil { + return nil, uerr(err) + } + req.ContentLength = ireq.ContentLength } + + // Copy original headers before setting the Referer, + // in case the user set Referer on their first request. + // If they really want to override, they can do it in + // their CheckRedirect func. + copyHeaders(req) + // Add the Referer header from the most recent // request URL to the new one, if it's not https->http: if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { @@ -523,10 +585,10 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo } reqs = append(reqs, req) - var err error - if resp, err = c.send(req, deadline); err != nil { - if !deadline.IsZero() && !time.Now().Before(deadline) { + var didTimeout func() bool + if resp, didTimeout, err = c.send(req, deadline); err != nil { + if !deadline.IsZero() && didTimeout() { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", timeout: true, @@ -535,9 +597,77 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo return nil, uerr(err) } - if !shouldRedirect(resp.StatusCode) { + var shouldRedirect bool + redirectMethod, shouldRedirect = redirectBehavior(req.Method, resp, reqs[0]) + if !shouldRedirect { return resp, nil } + + req.closeBody() + } +} + +// makeHeadersCopier makes a function that copies headers from the +// initial Request, ireq. For every redirect, this function must be called +// so that it can copy headers into the upcoming Request. +func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) { + // The headers to copy are from the very initial request. + // We use a closured callback to keep a reference to these original headers. + var ( + ireqhdr = ireq.Header.clone() + icookies map[string][]*Cookie + ) + if c.Jar != nil && ireq.Header.Get("Cookie") != "" { + icookies = make(map[string][]*Cookie) + for _, c := range ireq.Cookies() { + icookies[c.Name] = append(icookies[c.Name], c) + } + } + + preq := ireq // The previous request + return func(req *Request) { + // If Jar is present and there was some initial cookies provided + // via the request header, then we may need to alter the initial + // cookies as we follow redirects since each redirect may end up + // modifying a pre-existing cookie. + // + // Since cookies already set in the request header do not contain + // information about the original domain and path, the logic below + // assumes any new set cookies override the original cookie + // regardless of domain or path. + // + // See https://golang.org/issue/17494 + if c.Jar != nil && icookies != nil { + var changed bool + resp := req.Response // The response that caused the upcoming redirect + for _, c := range resp.Cookies() { + if _, ok := icookies[c.Name]; ok { + delete(icookies, c.Name) + changed = true + } + } + if changed { + ireqhdr.Del("Cookie") + var ss []string + for _, cs := range icookies { + for _, c := range cs { + ss = append(ss, c.Name+"="+c.Value) + } + } + sort.Strings(ss) // Ensure deterministic headers + ireqhdr.Set("Cookie", strings.Join(ss, "; ")) + } + } + + // Copy the initial request's Header values + // (at least the safe ones). + for k, vv := range ireqhdr { + if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) { + req.Header[k] = vv + } + } + + preq = req // Update previous Request with the current request } } @@ -558,8 +688,11 @@ func defaultCheckRedirect(req *Request, via []*Request) error { // Post is a wrapper around DefaultClient.Post. // // To set custom headers, use NewRequest and DefaultClient.Do. -func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { - return DefaultClient.Post(url, bodyType, body) +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func Post(url string, contentType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, contentType, body) } // Post issues a POST to the specified URL. @@ -570,13 +703,16 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro // request. // // To set custom headers, use NewRequest and Client.Do. -func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func (c *Client) Post(url string, contentType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { return nil, err } - req.Header.Set("Content-Type", bodyType) - return c.doFollowingRedirects(req, shouldRedirectPost) + req.Header.Set("Content-Type", contentType) + return c.Do(req) } // PostForm issues a POST to the specified URL, with data's keys and @@ -589,6 +725,9 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon // Caller should close resp.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm. +// +// See the Client.Do method documentation for details on how redirects +// are handled. func PostForm(url string, data url.Values) (resp *Response, err error) { return DefaultClient.PostForm(url, data) } @@ -601,11 +740,14 @@ func PostForm(url string, data url.Values) (resp *Response, err error) { // // When err is nil, resp always contains a non-nil resp.Body. // Caller should close resp.Body when done reading from it. +// +// See the Client.Do method documentation for details on how redirects +// are handled. func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } -// Head issues a HEAD to the specified URL. If the response is one of +// Head issues a HEAD to the specified URL. If the response is one of // the following redirect codes, Head follows the redirect, up to a // maximum of 10 redirects: // @@ -613,13 +755,14 @@ func (c *Client) PostForm(url string, data url.Values) (resp *Response, err erro // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) // // Head is a wrapper around DefaultClient.Head func Head(url string) (resp *Response, err error) { return DefaultClient.Head(url) } -// Head issues a HEAD to the specified URL. If the response is one of the +// Head issues a HEAD to the specified URL. If the response is one of the // following redirect codes, Head follows the redirect after calling the // Client's CheckRedirect function: // @@ -627,22 +770,23 @@ func Head(url string) (resp *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) func (c *Client) Head(url string) (resp *Response, err error) { req, err := NewRequest("HEAD", url, nil) if err != nil { return nil, err } - return c.doFollowingRedirects(req, shouldRedirectGet) + return c.Do(req) } // cancelTimerBody is an io.ReadCloser that wraps rc with two features: // 1) on Read error or close, the stop func is called. -// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and +// 2) On Read failure, if reqDidTimeout is true, the error is wrapped and // marked as net.Error that hit its timeout. type cancelTimerBody struct { - stop func() // stops the time.Timer waiting to cancel the request - rc io.ReadCloser - reqWasCanceled func() bool + stop func() // stops the time.Timer waiting to cancel the request + rc io.ReadCloser + reqDidTimeout func() bool } func (b *cancelTimerBody) Read(p []byte) (n int, err error) { @@ -654,7 +798,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) { if err == io.EOF { return n, err } - if b.reqWasCanceled() { + if b.reqDidTimeout() { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while reading body)", timeout: true, @@ -668,3 +812,52 @@ func (b *cancelTimerBody) Close() error { b.stop() return err } + +func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { + switch CanonicalHeaderKey(headerKey) { + case "Authorization", "Www-Authenticate", "Cookie", "Cookie2": + // Permit sending auth/cookie headers from "foo.com" + // to "sub.foo.com". + + // Note that we don't send all cookies to subdomains + // automatically. This function is only used for + // Cookies set explicitly on the initial outgoing + // client request. Cookies automatically added via the + // CookieJar mechanism continue to follow each + // cookie's scope as set by Set-Cookie. But for + // outgoing requests with the Cookie header set + // directly, we don't know their scope, so we assume + // it's for *.domain.com. + + // TODO(bradfitz): once issue 16142 is fixed, make + // this code use those URL accessors, and consider + // "http://foo.com" and "http://foo.com:80" as + // equivalent? + + // TODO(bradfitz): better hostname canonicalization, + // at least once we figure out IDNA/Punycode (issue + // 13835). + ihost := strings.ToLower(initial.Host) + dhost := strings.ToLower(dest.Host) + return isDomainOrSubdomain(dhost, ihost) + } + // All other headers are copied: + return true +} + +// isDomainOrSubdomain reports whether sub is a subdomain (or exact +// match) of the parent domain. +// +// Both domains must already be in canonical form. +func isDomainOrSubdomain(sub, parent string) bool { + if sub == parent { + return true + } + // If sub is "foo.example.com" and parent is "example.com", + // that means sub must end in "."+parent. + // Do it without allocating. + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index a9b1948..eaf2cdc 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -19,11 +19,14 @@ import ( "log" "net" . "net/http" + "net/http/cookiejar" "net/http/httptest" "net/url" + "reflect" "strconv" "strings" "sync" + "sync/atomic" "testing" "time" ) @@ -65,11 +68,13 @@ func (w chanWriter) Write(p []byte) (n int, err error) { } func TestClient(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() - r, err := Get(ts.URL) + c := &Client{Transport: &Transport{DisableKeepAlives: true}} + r, err := c.Get(ts.URL) var b []byte if err == nil { b, err = pedanticReadAll(r.Body) @@ -109,6 +114,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + setParallel(t) defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -195,6 +201,7 @@ func TestPostFormRequestFormat(t *testing.T) { } func TestClientRedirects(t *testing.T) { + setParallel(t) defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -206,14 +213,17 @@ func TestClientRedirects(t *testing.T) { } } if n < 15 { - Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect) return } fmt.Fprintf(w, "n=%d", n) })) defer ts.Close() - c := &Client{} + tr := &Transport{} + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} _, err := c.Get(ts.URL) if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { t.Errorf("with default client Get, expected error %q, got %q", e, g) @@ -242,11 +252,14 @@ func TestClientRedirects(t *testing.T) { var checkErr error var lastVia []*Request var lastReq *Request - c = &Client{CheckRedirect: func(req *Request, via []*Request) error { - lastReq = req - lastVia = via - return checkErr - }} + c = &Client{ + Transport: tr, + CheckRedirect: func(req *Request, via []*Request) error { + lastReq = req + lastVia = via + return checkErr + }, + } res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) @@ -292,20 +305,27 @@ func TestClientRedirects(t *testing.T) { } func TestClientRedirectContext(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - Redirect(w, r, "/", StatusFound) + Redirect(w, r, "/", StatusTemporaryRedirect) })) defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + ctx, cancel := context.WithCancel(context.Background()) - c := &Client{CheckRedirect: func(req *Request, via []*Request) error { - cancel() - if len(via) > 2 { - return errors.New("too many redirects") - } - return nil - }} + c := &Client{ + Transport: tr, + CheckRedirect: func(req *Request, via []*Request) error { + cancel() + if len(via) > 2 { + return errors.New("too many redirects") + } + return nil + }, + } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(ctx) _, err := c.Do(req) @@ -313,12 +333,96 @@ func TestClientRedirectContext(t *testing.T) { if !ok { t.Fatalf("got error %T; want *url.Error", err) } - if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn { - t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err) + if ue.Err != context.Canceled { + t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled) } } +type redirectTest struct { + suffix string + want int // response code + redirectBody string +} + func TestPostRedirects(t *testing.T) { + postRedirectTests := []redirectTest{ + {"/", 200, "first"}, + {"/?code=301&next=302", 200, "c301"}, + {"/?code=302&next=302", 200, "c302"}, + {"/?code=303&next=301", 200, "c303wc301"}, // Issue 9348 + {"/?code=304", 304, "c304"}, + {"/?code=305", 305, "c305"}, + {"/?code=307&next=303,308,302", 200, "c307"}, + {"/?code=308&next=302,301", 200, "c308"}, + {"/?code=404", 404, "c404"}, + } + + wantSegments := []string{ + `POST / "first"`, + `POST /?code=301&next=302 "c301"`, + `GET /?code=302 "c301"`, + `GET / "c301"`, + `POST /?code=302&next=302 "c302"`, + `GET /?code=302 "c302"`, + `GET / "c302"`, + `POST /?code=303&next=301 "c303wc301"`, + `GET /?code=301 "c303wc301"`, + `GET / "c303wc301"`, + `POST /?code=304 "c304"`, + `POST /?code=305 "c305"`, + `POST /?code=307&next=303,308,302 "c307"`, + `POST /?code=303&next=308,302 "c307"`, + `GET /?code=308&next=302 "c307"`, + `GET /?code=302 "c307"`, + `GET / "c307"`, + `POST /?code=308&next=302,301 "c308"`, + `POST /?code=302&next=301 "c308"`, + `GET /?code=301 "c308"`, + `GET / "c308"`, + `POST /?code=404 "c404"`, + } + want := strings.Join(wantSegments, "\n") + testRedirectsByMethod(t, "POST", postRedirectTests, want) +} + +func TestDeleteRedirects(t *testing.T) { + deleteRedirectTests := []redirectTest{ + {"/", 200, "first"}, + {"/?code=301&next=302,308", 200, "c301"}, + {"/?code=302&next=302", 200, "c302"}, + {"/?code=303", 200, "c303"}, + {"/?code=307&next=301,308,303,302,304", 304, "c307"}, + {"/?code=308&next=307", 200, "c308"}, + {"/?code=404", 404, "c404"}, + } + + wantSegments := []string{ + `DELETE / "first"`, + `DELETE /?code=301&next=302,308 "c301"`, + `GET /?code=302&next=308 "c301"`, + `GET /?code=308 "c301"`, + `GET / "c301"`, + `DELETE /?code=302&next=302 "c302"`, + `GET /?code=302 "c302"`, + `GET / "c302"`, + `DELETE /?code=303 "c303"`, + `GET / "c303"`, + `DELETE /?code=307&next=301,308,303,302,304 "c307"`, + `DELETE /?code=301&next=308,303,302,304 "c307"`, + `GET /?code=308&next=303,302,304 "c307"`, + `GET /?code=303&next=302,304 "c307"`, + `GET /?code=302&next=304 "c307"`, + `GET /?code=304 "c307"`, + `DELETE /?code=308&next=307 "c308"`, + `DELETE /?code=307 "c308"`, + `DELETE / "c308"`, + `DELETE /?code=404 "c404"`, + } + want := strings.Join(wantSegments, "\n") + testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want) +} + +func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) { defer afterTest(t) var log struct { sync.Mutex @@ -327,29 +431,35 @@ func TestPostRedirects(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() - fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + slurp, _ := ioutil.ReadAll(r.Body) + fmt.Fprintf(&log.Buffer, "%s %s %q\n", r.Method, r.RequestURI, slurp) log.Unlock() - if v := r.URL.Query().Get("code"); v != "" { + urlQuery := r.URL.Query() + if v := urlQuery.Get("code"); v != "" { + location := ts.URL + if final := urlQuery.Get("next"); final != "" { + splits := strings.Split(final, ",") + first, rest := splits[0], splits[1:] + location = fmt.Sprintf("%s?code=%s", location, first) + if len(rest) > 0 { + location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ",")) + } + } code, _ := strconv.Atoi(v) if code/100 == 3 { - w.Header().Set("Location", ts.URL) + w.Header().Set("Location", location) } w.WriteHeader(code) } })) defer ts.Close() - tests := []struct { - suffix string - want int // response code - }{ - {"/", 200}, - {"/?code=301", 301}, - {"/?code=302", 200}, - {"/?code=303", 200}, - {"/?code=404", 404}, - } - for _, tt := range tests { - res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + + for _, tt := range table { + content := tt.redirectBody + req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) + req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil } + res, err := DefaultClient.Do(req) + if err != nil { t.Fatal(err) } @@ -360,13 +470,17 @@ func TestPostRedirects(t *testing.T) { log.Lock() got := log.String() log.Unlock() - want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + + got = strings.TrimSpace(got) + want = strings.TrimSpace(want) + if got != want { - t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) + t.Errorf("Log differs.\n Got:\n%s\nWant:\n%s\n", got, want) } } func TestClientRedirectUseResponse(t *testing.T) { + setParallel(t) defer afterTest(t) const body = "Hello, world." var ts *httptest.Server @@ -381,12 +495,18 @@ func TestClientRedirectUseResponse(t *testing.T) { })) defer ts.Close() - c := &Client{CheckRedirect: func(req *Request, via []*Request) error { - if req.Response == nil { - t.Error("expected non-nil Request.Response") - } - return ErrUseLastResponse - }} + tr := &Transport{} + defer tr.CloseIdleConnections() + + c := &Client{ + Transport: tr, + CheckRedirect: func(req *Request, via []*Request) error { + if req.Response == nil { + t.Error("expected non-nil Request.Response") + } + return ErrUseLastResponse + }, + } res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -404,6 +524,57 @@ func TestClientRedirectUseResponse(t *testing.T) { } } +// Issue 17773: don't follow a 308 (or 307) if the response doesn't +// have a Location header. +func TestClientRedirect308NoLocation(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Foo", "Bar") + w.WriteHeader(308) + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 308 { + t.Errorf("status = %d; want %d", res.StatusCode, 308) + } + if got := res.Header.Get("Foo"); got != "Bar" { + t.Errorf("Foo header = %q; want Bar", got) + } +} + +// Don't follow a 307/308 if we can't resent the request body. +func TestClientRedirect308NoGetBody(t *testing.T) { + setParallel(t) + defer afterTest(t) + const fakeURL = "https://localhost:1234/" // won't be hit + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Location", fakeURL) + w.WriteHeader(308) + })) + defer ts.Close() + req, err := NewRequest("POST", ts.URL, strings.NewReader("some body")) + if err != nil { + t.Fatal(err) + } + req.GetBody = nil // so it can't rewind. + res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 308 { + t.Errorf("status = %d; want %d", res.StatusCode, 308) + } + if got := res.Header.Get("Location"); got != fakeURL { + t.Errorf("Location header = %q; want %q", got, fakeURL) + } +} + var expectedCookies = []*Cookie{ {Name: "ChocolateChip", Value: "tasty"}, {Name: "First", Value: "Hit"}, @@ -476,12 +647,16 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesJar(t *testing.T) { + setParallel(t) defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() c := &Client{ - Jar: new(TestJar), + Transport: tr, + Jar: new(TestJar), } u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) @@ -665,6 +840,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) @@ -842,6 +1018,7 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { func TestHTTPSClientDetectsHTTPServer(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.ErrorLog = quietLog defer ts.Close() _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) @@ -895,6 +1072,7 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { } func TestEmptyPasswordAuth(t *testing.T) { + setParallel(t) defer afterTest(t) gopher := "gopher" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -915,7 +1093,9 @@ func TestEmptyPasswordAuth(t *testing.T) { } })) defer ts.Close() - c := &Client{} + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) @@ -1007,10 +1187,10 @@ func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } func testClientTimeout(t *testing.T, h2 bool) { - if testing.Short() { - t.Skip("skipping in short mode") - } + setParallel(t) defer afterTest(t) + testDone := make(chan struct{}) // closed in defer below + sawRoot := make(chan bool, 1) sawSlow := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1020,19 +1200,26 @@ func testClientTimeout(t *testing.T, h2 bool) { return } if r.URL.Path == "/slow" { + sawSlow <- true w.Write([]byte("Hello")) w.(Flusher).Flush() - sawSlow <- true - time.Sleep(2 * time.Second) + <-testDone return } })) defer cst.close() - const timeout = 500 * time.Millisecond + defer close(testDone) // before cst.close, to unblock /slow handler + + // 200ms should be long enough to get a normal request (the / + // handler), but not so long that it makes the test slow. + const timeout = 200 * time.Millisecond cst.c.Timeout = timeout res, err := cst.c.Get(cst.ts.URL) if err != nil { + if strings.Contains(err.Error(), "Client.Timeout") { + t.Skipf("host too slow to get fast resource in %v", timeout) + } t.Fatal(err) } @@ -1057,7 +1244,7 @@ func testClientTimeout(t *testing.T, h2 bool) { res.Body.Close() }() - const failTime = timeout * 2 + const failTime = 5 * time.Second select { case err := <-errc: if err == nil { @@ -1082,11 +1269,9 @@ func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h // Client.Timeout firing before getting to the body func testClientTimeout_Headers(t *testing.T, h2 bool) { - if testing.Short() { - t.Skip("skipping in short mode") - } + setParallel(t) defer afterTest(t) - donec := make(chan bool) + donec := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec })) @@ -1100,9 +1285,10 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { // doesn't know this, so synchronize explicitly. defer func() { donec <- true }() - cst.c.Timeout = 500 * time.Millisecond - _, err := cst.c.Get(cst.ts.URL) + cst.c.Timeout = 5 * time.Millisecond + res, err := cst.c.Get(cst.ts.URL) if err == nil { + res.Body.Close() t.Fatal("got response from Get; expected error") } if _, ok := err.(*url.Error); !ok { @@ -1120,9 +1306,40 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { } } +// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be +// returned. +func TestClientTimeoutCancel(t *testing.T) { + setParallel(t) + defer afterTest(t) + + testDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + <-testDone + })) + defer cst.close() + defer close(testDone) + + cst.c.Timeout = 1 * time.Hour + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Cancel = ctx.Done() + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + cancel() + _, err = io.Copy(ioutil.Discard, res.Body) + if err != ExportErrRequestCanceled { + t.Fatalf("error = %v; want errRequestCanceled", err) + } +} + func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } func testClientRedirectEatsBody(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) saw := make(chan string, 2) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1138,10 +1355,10 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) { t.Fatal(err) } _, err = ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { t.Fatal(err) } - res.Body.Close() var first string select { @@ -1229,3 +1446,369 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { // Check that this doesn't crash: c.Get("http://dummy.tld") } + +// Issue 4800: copy (some) headers when Client follows a redirect +func TestClientCopyHeadersOnRedirect(t *testing.T) { + const ( + ua = "some-agent/1.2" + xfoo = "foo-val" + ) + var ts2URL string + ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + "Accept-Encoding": []string{"gzip"}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("Request.Header = %#v; want %#v", r.Header, want) + } + if t.Failed() { + w.Header().Set("Result", "got errors") + } else { + w.Header().Set("Result", "ok") + } + })) + defer ts1.Close() + ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Redirect(w, r, ts1.URL, StatusFound) + })) + defer ts2.Close() + ts2URL = ts2.URL + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + CheckRedirect: func(r *Request, via []*Request) error { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) + } + return nil + }, + } + + req, _ := NewRequest("GET", ts2.URL, nil) + req.Header.Add("User-Agent", ua) + req.Header.Add("X-Foo", xfoo) + req.Header.Add("Cookie", "foo=bar") + req.Header.Add("Authorization", "secretpassword") + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatal(res.Status) + } + if got := res.Header.Get("Result"); got != "ok" { + t.Errorf("result = %q; want ok", got) + } +} + +// Issue 17494: cookies should be altered when Client follows redirects. +func TestClientAltersCookiesOnRedirect(t *testing.T) { + cookieMap := func(cs []*Cookie) map[string][]string { + m := make(map[string][]string) + for _, c := range cs { + m[c.Name] = append(m[c.Name], c.Value) + } + return m + } + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + var want map[string][]string + got := cookieMap(r.Cookies()) + + c, _ := r.Cookie("Cycle") + switch c.Value { + case "0": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie2": {"OldValue2"}, + "Cookie3": {"OldValue3a", "OldValue3b"}, + "Cookie4": {"OldValue4"}, + "Cycle": {"0"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1}) // Delete cookie from Header + Redirect(w, r, "/", StatusFound) + case "1": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"OldValue3a", "OldValue3b"}, + "Cookie4": {"OldValue4"}, + "Cycle": {"1"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"}) // Modify cookie in Header + SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"}) // Modify cookie in Jar + Redirect(w, r, "/", StatusFound) + case "2": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"NewValue3"}, + "Cookie4": {"NewValue4"}, + "Cycle": {"2"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"}) // Insert cookie into Jar + Redirect(w, r, "/", StatusFound) + case "3": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"NewValue3"}, + "Cookie4": {"NewValue4"}, + "Cookie5": {"NewValue5"}, + "Cycle": {"3"}, + } + // Don't redirect to ensure the loop ends. + default: + t.Errorf("unexpected redirect cycle") + return + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want) + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + jar, _ := cookiejar.New(nil) + c := &Client{ + Transport: tr, + Jar: jar, + } + + u, _ := url.Parse(ts.URL) + req, _ := NewRequest("GET", ts.URL, nil) + req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"}) + req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"}) + req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"}) + req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"}) + req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"}) + jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}}) + jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}}) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatal(res.Status) + } +} + +// Part of Issue 4800 +func TestShouldCopyHeaderOnRedirect(t *testing.T) { + tests := []struct { + header string + initialURL string + destURL string + want bool + }{ + {"User-Agent", "http://foo.com/", "http://bar.com/", true}, + {"X-Foo", "http://foo.com/", "http://bar.com/", true}, + + // Sensitive headers: + {"cookie", "http://foo.com/", "http://bar.com/", false}, + {"cookie2", "http://foo.com/", "http://bar.com/", false}, + {"authorization", "http://foo.com/", "http://bar.com/", false}, + {"www-authenticate", "http://foo.com/", "http://bar.com/", false}, + + // But subdomains should work: + {"www-authenticate", "http://foo.com/", "http://foo.com/", true}, + {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false}, + // TODO(bradfitz): make this test work, once issue 16142 is fixed: + // {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true}, + } + for i, tt := range tests { + u0, err := url.Parse(tt.initialURL) + if err != nil { + t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err) + continue + } + u1, err := url.Parse(tt.destURL) + if err != nil { + t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err) + continue + } + got := Export_shouldCopyHeaderOnRedirect(tt.header, u0, u1) + if got != tt.want { + t.Errorf("%d. shouldCopyHeaderOnRedirect(%q, %q => %q) = %v; want %v", + i, tt.header, tt.initialURL, tt.destURL, got, tt.want) + } + } +} + +func TestClientRedirectTypes(t *testing.T) { + setParallel(t) + defer afterTest(t) + + tests := [...]struct { + method string + serverStatus int + wantMethod string // desired subsequent client method + }{ + 0: {method: "POST", serverStatus: 301, wantMethod: "GET"}, + 1: {method: "POST", serverStatus: 302, wantMethod: "GET"}, + 2: {method: "POST", serverStatus: 303, wantMethod: "GET"}, + 3: {method: "POST", serverStatus: 307, wantMethod: "POST"}, + 4: {method: "POST", serverStatus: 308, wantMethod: "POST"}, + + 5: {method: "HEAD", serverStatus: 301, wantMethod: "HEAD"}, + 6: {method: "HEAD", serverStatus: 302, wantMethod: "HEAD"}, + 7: {method: "HEAD", serverStatus: 303, wantMethod: "HEAD"}, + 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"}, + 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"}, + + 10: {method: "GET", serverStatus: 301, wantMethod: "GET"}, + 11: {method: "GET", serverStatus: 302, wantMethod: "GET"}, + 12: {method: "GET", serverStatus: 303, wantMethod: "GET"}, + 13: {method: "GET", serverStatus: 307, wantMethod: "GET"}, + 14: {method: "GET", serverStatus: 308, wantMethod: "GET"}, + + 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"}, + 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"}, + 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"}, + 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"}, + 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"}, + + 20: {method: "PUT", serverStatus: 301, wantMethod: "GET"}, + 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"}, + 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"}, + 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"}, + 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"}, + + 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"}, + 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"}, + 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"}, + 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"}, + 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"}, + } + + handlerc := make(chan HandlerFunc, 1) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + h := <-handlerc + h(rw, req) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + + for i, tt := range tests { + handlerc <- func(w ResponseWriter, r *Request) { + w.Header().Set("Location", ts.URL) + w.WriteHeader(tt.serverStatus) + } + + req, err := NewRequest(tt.method, ts.URL, nil) + if err != nil { + t.Errorf("#%d: NewRequest: %v", i, err) + continue + } + + c := &Client{Transport: tr} + c.CheckRedirect = func(req *Request, via []*Request) error { + if got, want := req.Method, tt.wantMethod; got != want { + return fmt.Errorf("#%d: got next method %q; want %q", i, got, want) + } + handlerc <- func(rw ResponseWriter, req *Request) { + // TODO: Check that the body is valid when we do 307 and 308 support + } + return nil + } + + res, err := c.Do(req) + if err != nil { + t.Errorf("#%d: Response: %v", i, err) + continue + } + + res.Body.Close() + } +} + +// issue18239Body is an io.ReadCloser for TestTransportBodyReadError. +// Its Read returns readErr and increments *readCalls atomically. +// Its Close returns nil and increments *closeCalls atomically. +type issue18239Body struct { + readCalls *int32 + closeCalls *int32 + readErr error +} + +func (b issue18239Body) Read([]byte) (int, error) { + atomic.AddInt32(b.readCalls, 1) + return 0, b.readErr +} + +func (b issue18239Body) Close() error { + atomic.AddInt32(b.closeCalls, 1) + return nil +} + +// Issue 18239: make sure the Transport doesn't retry requests with bodies. +// (Especially if Request.GetBody is not defined.) +func TestTransportBodyReadError(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/ping" { + return + } + buf := make([]byte, 1) + n, err := r.Body.Read(buf) + w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) + })) + defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Do one initial successful request to create an idle TCP connection + // for the subsequent request to reuse. (The Transport only retries + // requests on reused connections.) + res, err := c.Get(ts.URL + "/ping") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + var readCallsAtomic int32 + var closeCallsAtomic int32 // atomic + someErr := errors.New("some body read error") + body := issue18239Body{&readCallsAtomic, &closeCallsAtomic, someErr} + + req, err := NewRequest("POST", ts.URL, body) + if err != nil { + t.Fatal(err) + } + _, err = tr.RoundTrip(req) + if err != someErr { + t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr) + } + + // And verify that our Body wasn't used multiple times, which + // would indicate retries. (as it buggily was during part of + // Go 1.8's dev cycle) + readCalls := atomic.LoadInt32(&readCallsAtomic) + closeCalls := atomic.LoadInt32(&closeCallsAtomic) + if readCalls != 1 { + t.Errorf("read calls = %d; want 1", readCalls) + } + if closeCalls != 1 { + t.Errorf("close calls = %d; want 1", closeCalls) + } +} diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 3d1f09c..580115c 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -44,6 +44,19 @@ func (t *clientServerTest) close() { t.ts.Close() } +func (t *clientServerTest) getURL(u string) string { + res, err := t.c.Get(u) + if err != nil { + t.t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.t.Fatal(err) + } + return string(slurp) +} + func (t *clientServerTest) scheme() string { if t.h2 { return "https" @@ -56,6 +69,10 @@ const ( h2Mode = true ) +var optQuietLog = func(ts *httptest.Server) { + ts.Config.ErrorLog = quietLog +} + func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { cst := &clientServerTest{ t: t, @@ -64,21 +81,23 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) tr: &Transport{}, } cst.c = &Client{Transport: cst.tr} + cst.ts = httptest.NewUnstartedServer(h) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): opt(cst.tr) + case func(*httptest.Server): + opt(cst.ts) default: t.Fatalf("unhandled option type %T", opt) } } if !h2 { - cst.ts = httptest.NewServer(h) + cst.ts.Start() return cst } - cst.ts = httptest.NewUnstartedServer(h) ExportHttp2ConfigureServer(cst.ts.Config, nil) cst.ts.TLS = cst.ts.Config.TLSConfig cst.ts.StartTLS() @@ -170,6 +189,7 @@ func (tt h12Compare) reqFunc() reqFunc { } func (tt h12Compare) run(t *testing.T) { + setParallel(t) cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) @@ -468,7 +488,7 @@ func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { } func TestH12_RequestContentLength_Known_Zero(t *testing.T) { - h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0) + h12requestContentLength(t, func() io.Reader { return nil }, 0) } func TestH12_RequestContentLength_Unknown(t *testing.T) { @@ -938,6 +958,7 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) @@ -1026,6 +1047,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { t.Skip("skipping on gccgo because conservative GC means that finalizer may never run") } + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ioutil.ReadAll(r.Body) @@ -1072,10 +1094,11 @@ func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { testTransportRejectsInvalidHeaders(t, h2Mode) } func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) - })) + }), optQuietLog) defer cst.close() cst.tr.DisableKeepAlives = true @@ -1143,24 +1166,44 @@ func testBogusStatusWorks(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode) } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode) } -func testInterruptWithPanic(t *testing.T, h2 bool) { - log.SetOutput(ioutil.Discard) // is noisy otherwise - defer log.SetOutput(os.Stderr) - +func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } +func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } +func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } +func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } +func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { + testInterruptWithPanic(t, h1Mode, ErrAbortHandler) +} +func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { + testInterruptWithPanic(t, h2Mode, ErrAbortHandler) +} +func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { + setParallel(t) const msg = "hello" defer afterTest(t) + + testDone := make(chan struct{}) + defer close(testDone) + + var errorLog lockedBytesBuffer + gotHeaders := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() - panic("no more") - })) + + select { + case <-gotHeaders: + case <-testDone: + } + panic(panicValue) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(&errorLog, "", 0) + }) defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } + gotHeaders <- true defer res.Body.Close() slurp, err := ioutil.ReadAll(res.Body) if string(slurp) != msg { @@ -1169,6 +1212,42 @@ func testInterruptWithPanic(t *testing.T, h2 bool) { if err == nil { t.Errorf("client read all successfully; want some error") } + logOutput := func() string { + errorLog.Lock() + defer errorLog.Unlock() + return errorLog.String() + } + wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler + + if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error { + gotLog := logOutput() + if !wantStackLogged { + if gotLog == "" { + return nil + } + return fmt.Errorf("want no log output; got: %s", gotLog) + } + if gotLog == "" { + return fmt.Errorf("wanted a stack trace logged; got nothing") + } + if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 { + return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +type lockedBytesBuffer struct { + sync.Mutex + bytes.Buffer +} + +func (b *lockedBytesBuffer) Write(p []byte) (int, error) { + b.Lock() + defer b.Unlock() + return b.Buffer.Write(p) } // Issue 15366 @@ -1204,6 +1283,7 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } func testCloseIdleConnections(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) @@ -1238,3 +1318,70 @@ func (x noteCloseConn) Close() error { x.closeFunc() return x.Conn.Close() } + +type testErrorReader struct{ t *testing.T } + +func (r testErrorReader) Read(p []byte) (n int, err error) { + r.t.Error("unexpected Read call") + return 0, io.EOF +} + +func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } +func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } + +func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusUnauthorized) + })) + defer cst.close() + + // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. + cst.tr.ExpectContinueTimeout = 10 * time.Second + + req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) + if err != nil { + t.Fatal(err) + } + req.ContentLength = 0 // so transport is tempted to sniff it + req.Header.Set("Expect", "100-continue") + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != StatusUnauthorized { + t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) + } +} + +func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } +func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } +func testServerUndeclaredTrailers(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Foo", "Bar") + w.Header().Set("Trailer:Foo", "Baz") + w.(Flusher).Flush() + w.Header().Add("Trailer:Foo", "Baz2") + w.Header().Set("Trailer:Bar", "Quux") + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatal(err) + } + res.Body.Close() + delete(res.Header, "Date") + delete(res.Header, "Content-Type") + + if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { + t.Errorf("Header = %#v; want %#v", res.Header, want) + } + if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 1ea0e93..5a67476 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -6,7 +6,6 @@ package http import ( "bytes" - "fmt" "log" "net" "strconv" @@ -40,7 +39,11 @@ type Cookie struct { // readSetCookies parses all "Set-Cookie" values from // the header h and returns the successfully parsed Cookies. func readSetCookies(h Header) []*Cookie { - cookies := []*Cookie{} + cookieCount := len(h["Set-Cookie"]) + if cookieCount == 0 { + return []*Cookie{} + } + cookies := make([]*Cookie, 0, cookieCount) for _, line := range h["Set-Cookie"] { parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 1 && parts[0] == "" { @@ -55,8 +58,8 @@ func readSetCookies(h Header) []*Cookie { if !isCookieNameValid(name) { continue } - value, success := parseCookieValue(value, true) - if !success { + value, ok := parseCookieValue(value, true) + if !ok { continue } c := &Cookie{ @@ -75,8 +78,8 @@ func readSetCookies(h Header) []*Cookie { attr, val = attr[:j], attr[j+1:] } lowerAttr := strings.ToLower(attr) - val, success = parseCookieValue(val, false) - if !success { + val, ok = parseCookieValue(val, false) + if !ok { c.Unparsed = append(c.Unparsed, parts[i]) continue } @@ -96,10 +99,9 @@ func readSetCookies(h Header) []*Cookie { break } if secs <= 0 { - c.MaxAge = -1 - } else { - c.MaxAge = secs + secs = -1 } + c.MaxAge = secs continue case "expires": c.RawExpires = val @@ -142,9 +144,13 @@ func (c *Cookie) String() string { return "" } var b bytes.Buffer - fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + b.WriteString(sanitizeCookieName(c.Name)) + b.WriteRune('=') + b.WriteString(sanitizeCookieValue(c.Value)) + if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path)) + b.WriteString("; Path=") + b.WriteString(sanitizeCookiePath(c.Path)) } if len(c.Domain) > 0 { if validCookieDomain(c.Domain) { @@ -156,25 +162,31 @@ func (c *Cookie) String() string { if d[0] == '.' { d = d[1:] } - fmt.Fprintf(&b, "; Domain=%s", d) + b.WriteString("; Domain=") + b.WriteString(d) } else { - log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", - c.Domain) + log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain) } } - if c.Expires.Unix() > 0 { - fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(TimeFormat)) + if validCookieExpires(c.Expires) { + b.WriteString("; Expires=") + b2 := b.Bytes() + b.Reset() + b.Write(c.Expires.UTC().AppendFormat(b2, TimeFormat)) } if c.MaxAge > 0 { - fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) + b.WriteString("; Max-Age=") + b2 := b.Bytes() + b.Reset() + b.Write(strconv.AppendInt(b2, int64(c.MaxAge), 10)) } else if c.MaxAge < 0 { - fmt.Fprintf(&b, "; Max-Age=0") + b.WriteString("; Max-Age=0") } if c.HttpOnly { - fmt.Fprintf(&b, "; HttpOnly") + b.WriteString("; HttpOnly") } if c.Secure { - fmt.Fprintf(&b, "; Secure") + b.WriteString("; Secure") } return b.String() } @@ -184,12 +196,12 @@ func (c *Cookie) String() string { // // if filter isn't empty, only cookies of that name are returned func readCookies(h Header, filter string) []*Cookie { - cookies := []*Cookie{} lines, ok := h["Cookie"] if !ok { - return cookies + return []*Cookie{} } + cookies := []*Cookie{} for _, line := range lines { parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 1 && parts[0] == "" { @@ -212,8 +224,8 @@ func readCookies(h Header, filter string) []*Cookie { if filter != "" && filter != name { continue } - val, success := parseCookieValue(val, true) - if !success { + val, ok := parseCookieValue(val, true) + if !ok { continue } cookies = append(cookies, &Cookie{Name: name, Value: val}) @@ -234,6 +246,12 @@ func validCookieDomain(v string) bool { return false } +// validCookieExpires returns whether v is a valid cookie expires-value. +func validCookieExpires(t time.Time) bool { + // IETF RFC 6265 Section 5.1.1.5, the year must not be less than 1601 + return t.Year() >= 1601 +} + // isCookieDomainName returns whether s is a valid domain name or a valid // domain name with a leading dot '.'. It is almost a direct copy of // package net's isDomainName. diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index 95e6147..b3e54f8 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -56,6 +56,15 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "cookie-9", Value: "expiring", Expires: time.Unix(1257894000, 0)}, "cookie-9=expiring; Expires=Tue, 10 Nov 2009 23:00:00 GMT", }, + // According to IETF 6265 Section 5.1.1.5, the year cannot be less than 1601 + { + &Cookie{Name: "cookie-10", Value: "expiring-1601", Expires: time.Date(1601, 1, 1, 1, 1, 1, 1, time.UTC)}, + "cookie-10=expiring-1601; Expires=Mon, 01 Jan 1601 01:01:01 GMT", + }, + { + &Cookie{Name: "cookie-11", Value: "invalid-expiry", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, + "cookie-11=invalid-expiry", + }, // The "special" cookies have values containing commas or spaces which // are disallowed by RFC 6265 but are common in the wild. { @@ -426,3 +435,92 @@ func TestCookieSanitizePath(t *testing.T) { t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got) } } + +func BenchmarkCookieString(b *testing.B) { + const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600` + c := &Cookie{ + Name: "cookie-9", + Value: "i3e01nf61b6t23bvfmplnanol3", + Expires: time.Unix(1257894000, 0), + Path: "/restricted/", + Domain: ".example.com", + MaxAge: 3600, + } + var benchmarkCookieString string + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkCookieString = c.String() + } + if have, want := benchmarkCookieString, wantCookieString; have != want { + b.Fatalf("Have: %v Want: %v", have, want) + } +} + +func BenchmarkReadSetCookies(b *testing.B) { + header := Header{ + "Set-Cookie": { + "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }, + } + wantCookies := []*Cookie{ + { + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC), + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }, + { + Name: ".ASPXAUTH", + Value: "7E3AA", + Path: "/", + Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC), + RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT", + HttpOnly: true, + Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }, + } + var c []*Cookie + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c = readSetCookies(header) + } + if !reflect.DeepEqual(c, wantCookies) { + b.Fatalf("readSetCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies)) + } +} + +func BenchmarkReadCookies(b *testing.B) { + header := Header{ + "Cookie": { + `de=; client_region=0; rpld1=0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|; rpld0=1:08|; backplane-channel=newspaper.com:1471; devicetype=0; osfam=0; rplmct=2; s_pers=%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B; s_sess=%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B`, + }, + } + wantCookies := []*Cookie{ + {Name: "de", Value: ""}, + {Name: "client_region", Value: "0"}, + {Name: "rpld1", Value: "0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|"}, + {Name: "rpld0", Value: "1:08|"}, + {Name: "backplane-channel", Value: "newspaper.com:1471"}, + {Name: "devicetype", Value: "0"}, + {Name: "osfam", Value: "0"}, + {Name: "rplmct", Value: "2"}, + {Name: "s_pers", Value: "%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B"}, + {Name: "s_sess", Value: "%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B"}, + } + var c []*Cookie + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c = readCookies(header, "") + } + if !reflect.DeepEqual(c, wantCookies) { + b.Fatalf("readCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies)) + } +} diff --git a/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go b/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go new file mode 100644 index 0000000..748ec5c --- /dev/null +++ b/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go @@ -0,0 +1,23 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package cookiejar_test + +import "net/http/cookiejar" + +type dummypsl struct { + List cookiejar.PublicSuffixList +} + +func (dummypsl) PublicSuffix(domain string) string { + return domain +} + +func (dummypsl) String() string { + return "dummy" +} + +var publicsuffix = dummypsl{} diff --git a/libgo/go/net/http/cookiejar/example_test.go b/libgo/go/net/http/cookiejar/example_test.go new file mode 100644 index 0000000..19a5746 --- /dev/null +++ b/libgo/go/net/http/cookiejar/example_test.go @@ -0,0 +1,67 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package cookiejar_test + +import ( + "fmt" + "log" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" +) + +func ExampleNew() { + // Start a server to give us cookies. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cookie, err := r.Cookie("Flavor"); err != nil { + http.SetCookie(w, &http.Cookie{Name: "Flavor", Value: "Chocolate Chip"}) + } else { + cookie.Value = "Oatmeal Raisin" + http.SetCookie(w, cookie) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + log.Fatal(err) + } + + // All users of cookiejar should import "golang.org/x/net/publicsuffix" + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + log.Fatal(err) + } + + client := &http.Client{ + Jar: jar, + } + + if _, err = client.Get(u.String()); err != nil { + log.Fatal(err) + } + + fmt.Println("After 1st request:") + for _, cookie := range jar.Cookies(u) { + fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value) + } + + if _, err = client.Get(u.String()); err != nil { + log.Fatal(err) + } + + fmt.Println("After 2nd request:") + for _, cookie := range jar.Cookies(u) { + fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value) + } + // Output: + // After 1st request: + // Flavor: Chocolate Chip + // After 2nd request: + // Flavor: Oatmeal Raisin +} diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go index 0e0fac9..f89abbc 100644 --- a/libgo/go/net/http/cookiejar/jar.go +++ b/libgo/go/net/http/cookiejar/jar.go @@ -107,7 +107,7 @@ type entry struct { seqNum uint64 } -// Id returns the domain;path;name triple of e as an id. +// id returns the domain;path;name triple of e as an id. func (e *entry) id() string { return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) } @@ -147,24 +147,6 @@ func hasDotSuffix(s, suffix string) bool { return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix } -// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 -// section 5.4 point 2: by longest path and then by earliest creation time. -type byPathLength []entry - -func (s byPathLength) Len() int { return len(s) } - -func (s byPathLength) Less(i, j int) bool { - if len(s[i].Path) != len(s[j].Path) { - return len(s[i].Path) > len(s[j].Path) - } - if !s[i].Creation.Equal(s[j].Creation) { - return s[i].Creation.Before(s[j].Creation) - } - return s[i].seqNum < s[j].seqNum -} - -func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - // Cookies implements the Cookies method of the http.CookieJar interface. // // It returns an empty slice if the URL's scheme is not HTTP or HTTPS. @@ -221,7 +203,18 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { } } - sort.Sort(byPathLength(selected)) + // sort according to RFC 6265 section 5.4 point 2: by longest + // path and then by earliest creation time. + sort.Slice(selected, func(i, j int) bool { + s := selected + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum + }) for _, e := range selected { cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) } diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go index 4ec8272..7855fea 100644 --- a/libgo/go/net/http/doc.go +++ b/libgo/go/net/http/doc.go @@ -44,7 +44,8 @@ For control over proxies, TLS configuration, keep-alives, compression, and other settings, create a Transport: tr := &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: pool}, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, DisableCompression: true, } client := &http.Client{Transport: tr} @@ -77,19 +78,30 @@ custom Server: } log.Fatal(s.ListenAndServe()) -The http package has transparent support for the HTTP/2 protocol when -using HTTPS. Programs that must disable HTTP/2 can do so by setting -Transport.TLSNextProto (for clients) or Server.TLSNextProto (for -servers) to a non-nil, empty map. Alternatively, the following GODEBUG -environment variables are currently supported: +Starting with Go 1.6, the http package has transparent support for the +HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2 +can do so by setting Transport.TLSNextProto (for clients) or +Server.TLSNextProto (for servers) to a non-nil, empty +map. Alternatively, the following GODEBUG environment variables are +currently supported: GODEBUG=http2client=0 # disable HTTP/2 client support GODEBUG=http2server=0 # disable HTTP/2 server support GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs GODEBUG=http2debug=2 # ... even more verbose, with frame dumps -The GODEBUG variables are not covered by Go's API compatibility promise. -HTTP/2 support was added in Go 1.6. Please report any issues instead of -disabling HTTP/2 support: https://golang.org/s/http2bug +The GODEBUG variables are not covered by Go's API compatibility +promise. Please report any issues before disabling HTTP/2 +support: https://golang.org/s/http2bug + +The http package's Transport and Server both automatically enable +HTTP/2 support for simple configurations. To enable HTTP/2 for more +complex configurations, to use lower-level HTTP/2 features, or to use +a newer version of Go's http2 package, import "golang.org/x/net/http2" +directly and use its ConfigureTransport and/or ConfigureServer +functions. Manually configuring HTTP/2 via the golang.org/x/net/http2 +package takes precedence over the net/http package's built-in HTTP/2 +support. + */ package http diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 9c5ba08..b61f58b 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -24,6 +24,7 @@ var ( ExportErrRequestCanceled = errRequestCanceled ExportErrRequestCanceledConn = errRequestCanceledConn ExportServeFile = serveFile + ExportScanETag = scanETag ExportHttp2ConfigureServer = http2ConfigureServer ) @@ -87,6 +88,12 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { return } +func (t *Transport) IdleConnKeyCountForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConn) +} + func (t *Transport) IdleConnStrsForTesting() []string { var ret []string t.idleMu.Lock() @@ -100,6 +107,24 @@ func (t *Transport) IdleConnStrsForTesting() []string { return ret } +func (t *Transport) IdleConnStrsForTesting_h2() []string { + var ret []string + noDialPool := t.h2transport.ConnPool.(http2noDialClientConnPool) + pool := noDialPool.http2clientConnPool + + pool.mu.Lock() + defer pool.mu.Unlock() + + for k, cc := range pool.conns { + for range cc { + ret = append(ret, k) + } + } + + sort.Strings(ret) + return ret +} + func (t *Transport) IdleConnCountForTesting(cacheKey string) int { t.idleMu.Lock() defer t.idleMu.Unlock() @@ -160,3 +185,17 @@ func ExportHttp2ConfigureTransport(t *Transport) error { t.h2transport = t2 return nil } + +var Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect + +func (s *Server) ExportAllConnsIdle() bool { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.activeConn { + st, ok := c.curState.Load().(ConnState) + if !ok || st != StateIdle { + return false + } + } + return true +} diff --git a/libgo/go/net/http/fcgi/fcgi.go b/libgo/go/net/http/fcgi/fcgi.go index 3374841..5057d70 100644 --- a/libgo/go/net/http/fcgi/fcgi.go +++ b/libgo/go/net/http/fcgi/fcgi.go @@ -3,8 +3,12 @@ // license that can be found in the LICENSE file. // Package fcgi implements the FastCGI protocol. +// +// The protocol is not an official standard and the original +// documentation is no longer online. See the Internet Archive's +// mirror at: https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22 +// // Currently only the responder role is supported. -// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 package fcgi // This file defines the raw protocol and some utilities used by the child and diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index c7a58a6..bf63bb5 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -77,7 +77,7 @@ func dirList(w ResponseWriter, f File) { Error(w, "Error reading directory", StatusInternalServerError) return } - sort.Sort(byName(dirs)) + sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "<pre>\n") @@ -98,7 +98,8 @@ func dirList(w ResponseWriter, f File) { // ServeContent replies to the request using the content in the // provided ReadSeeker. The main benefit of ServeContent over io.Copy // is that it handles Range requests properly, sets the MIME type, and -// handles If-Modified-Since requests. +// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since, +// and If-Range requests. // // If the response's Content-Type header is not set, ServeContent // first tries to deduce the type from name's file extension and, @@ -115,8 +116,8 @@ func dirList(w ResponseWriter, f File) { // The content's Seek method must work: ServeContent uses // a seek to the end of the content to determine its size. // -// If the caller has set w's ETag header, ServeContent uses it to -// handle requests using If-Range and If-None-Match. +// If the caller has set w's ETag header formatted per RFC 7232, section 2.3, +// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range. // // Note that *os.File implements the io.ReadSeeker interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { @@ -140,15 +141,17 @@ func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time // users. var errSeeker = errors.New("seeker can't seek") +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + // if name is empty, filename is unknown. (used for mime type, before sniffing) // if modtime.IsZero(), modtime is unknown. // content must be seeked to the beginning of the file. // The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) { - if checkLastModified(w, r, modtime) { - return - } - rangeReq, done := checkETag(w, r, modtime) + setLastModified(w, modtime) + done, rangeReq := checkPreconditions(w, r, modtime) if done { return } @@ -189,6 +192,9 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, if size >= 0 { ranges, err := parseRange(rangeReq, size) if err != nil { + if err == errNoOverlap { + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + } Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } @@ -263,90 +269,245 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } } -var unixEpochTime = time.Unix(0, 0) - -// modtime is the modification time of the resource to be served, or IsZero(). -// return value is whether this request is now complete. -func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { - if modtime.IsZero() || modtime.Equal(unixEpochTime) { - // If the file doesn't have a modtime (IsZero), or the modtime - // is obviously garbage (Unix time == 0), then ignore modtimes - // and don't process the If-Modified-Since header. - return false +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. + case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return string(s[:i+1]), s[i+1:] + default: + break + } } + return "", "" +} - // The Date-Modified header truncates sub-second precision, so - // use mtime < t+1s instead of mtime <= t to check for unmodified. - if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) { - h := w.Header() - delete(h, "Content-Type") - delete(h, "Content-Length") - w.WriteHeader(StatusNotModified) - return true - } - w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) - return false +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' } -// checkETag implements If-None-Match and If-Range checks. -// -// The ETag or modtime must have been previously set in the -// ResponseWriter's headers. The modtime is only compared at second -// granularity and may be the zero value to mean unknown. -// -// The return value is the effective request "Range" header to use and -// whether this request is now considered done. -func checkETag(w ResponseWriter, r *Request, modtime time.Time) (rangeReq string, done bool) { - etag := w.Header().get("Etag") - rangeReq = r.Header.get("Range") - - // Invalidate the range request if the entity doesn't match the one - // the client was expecting. - // "If-Range: version" means "ignore the Range: header unless version matches the - // current file." - // We only support ETag versions. - // The caller must have set the ETag on the response already. - if ir := r.Header.get("If-Range"); ir != "" && ir != etag { - // The If-Range value is typically the ETag value, but it may also be - // the modtime date. See golang.org/issue/8367. - timeMatches := false - if !modtime.IsZero() { - if t, err := ParseTime(ir); err == nil && t.Unix() == modtime.Unix() { - timeMatches = true - } +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w ResponseWriter, r *Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue } - if !timeMatches { - rangeReq = "" + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().get("Etag")) { + return condTrue } + im = remain } - if inm := r.Header.get("If-None-Match"); inm != "" { - // Must know ETag. + return condFalse +} + +func checkIfUnmodifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + if t, err := ParseTime(ius); err == nil { + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if modtime.Before(t.Add(1 * time.Second)) { + return condTrue + } + return condFalse + } + return condNone +} + +func checkIfNoneMatch(w ResponseWriter, r *Request) condResult { + inm := r.Header.get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) if etag == "" { - return rangeReq, false + break + } + if etagWeakMatch(etag, w.Header().get("Etag")) { + return condFalse } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := ParseTime(ims) + if err != nil { + return condNone + } + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if modtime.Before(t.Add(1 * time.Second)) { + return condFalse + } + return condTrue +} + +func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult { + if r.Method != "GET" { + return condNone + } + ir := r.Header.get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if modtime.IsZero() { + return condFalse + } + t, err := ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). +func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func setLastModified(w ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) + } +} - // TODO(bradfitz): non-GET/HEAD requests require more work: - // sending a different status code on matches, and - // also can't use weak cache validators (those with a "W/ - // prefix). But most users of ServeContent will be using - // it on GET or HEAD, so only support those for now. - if r.Method != "GET" && r.Method != "HEAD" { - return rangeReq, false +func writeNotModified(w ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(StatusNotModified) +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. + ch := checkIfMatch(w, r) + if ch == condNone { + ch = checkIfUnmodifiedSince(w, r, modtime) + } + if ch == condFalse { + w.WriteHeader(StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(w, r) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(StatusPreconditionFailed) + return true, "" } + case condNone: + if checkIfModifiedSince(w, r, modtime) == condFalse { + writeNotModified(w) + return true, "" + } + } - // TODO(bradfitz): deal with comma-separated or multiple-valued - // list of If-None-match values. For now just handle the common - // case of a single item. - if inm == etag || inm == "*" { - h := w.Header() - delete(h, "Content-Type") - delete(h, "Content-Length") - w.WriteHeader(StatusNotModified) - return "", true + rangeHeader = r.Header.get("Range") + if rangeHeader != "" { + if checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" } } - return rangeReq, false + return false, rangeHeader } // name is '/'-separated, not filepath.Separator. @@ -419,9 +580,11 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec // Still a directory? (we didn't find an index.html file) if d.IsDir() { - if checkLastModified(w, r, d.ModTime()) { + if checkIfModifiedSince(w, r, d.ModTime()) == condFalse { + writeNotModified(w) return } + w.Header().Set("Last-Modified", d.ModTime().UTC().Format(TimeFormat)) dirList(w, f) return } @@ -543,6 +706,7 @@ func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHead } // parseRange parses a Range header string as per RFC 2616. +// errNoOverlap is returned if none of the ranges overlap. func parseRange(s string, size int64) ([]httpRange, error) { if s == "" { return nil, nil // header not present @@ -552,6 +716,7 @@ func parseRange(s string, size int64) ([]httpRange, error) { return nil, errors.New("invalid range") } var ranges []httpRange + noOverlap := false for _, ra := range strings.Split(s[len(b):], ",") { ra = strings.TrimSpace(ra) if ra == "" { @@ -577,9 +742,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { r.length = size - r.start } else { i, err := strconv.ParseInt(start, 10, 64) - if err != nil || i >= size || i < 0 { + if err != nil || i < 0 { return nil, errors.New("invalid range") } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } r.start = i if end == "" { // If no end is specified, range extends to end of the file. @@ -597,6 +768,10 @@ func parseRange(s string, size int64) ([]httpRange, error) { } ranges = append(ranges, r) } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } return ranges, nil } @@ -628,9 +803,3 @@ func sumRangesSize(ranges []httpRange) (size int64) { } return } - -type byName []os.FileInfo - -func (s byName) Len() int { return len(s) } -func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() } -func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index 22be389..bba5682 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -68,6 +68,7 @@ var ServeFileRangeTests = []struct { } func TestServeFile(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") @@ -274,6 +275,7 @@ func TestFileServerEscapesNames(t *testing.T) { {`"'<>&`, `<a href="%22%27%3C%3E&">"'<>&</a>`}, {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`}, {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo"><combo>?foo</a>`}, + {`foo:bar`, `<a href="./foo:bar">foo:bar</a>`}, } // We put each test file in its own directory in the fakeFS so we can look at it in isolation. @@ -765,6 +767,7 @@ func TestServeContent(t *testing.T) { reqHeader map[string]string wantLastMod string wantContentType string + wantContentRange string wantStatus int } htmlModTime := mustStat(t, "testdata/index.html").ModTime() @@ -782,8 +785,9 @@ func TestServeContent(t *testing.T) { wantStatus: 200, }, "not_modified_modtime": { - file: "testdata/style.css", - modtime: htmlModTime, + file: "testdata/style.css", + serveETag: `"foo"`, // Last-Modified sent only when no ETag + modtime: htmlModTime, reqHeader: map[string]string{ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), }, @@ -792,6 +796,7 @@ func TestServeContent(t *testing.T) { "not_modified_modtime_with_contenttype": { file: "testdata/style.css", serveContentType: "text/css", // explicit content type + serveETag: `"foo"`, // Last-Modified sent only when no ETag modtime: htmlModTime, reqHeader: map[string]string{ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), @@ -808,21 +813,62 @@ func TestServeContent(t *testing.T) { }, "not_modified_etag_no_seek": { content: panicOnSeek{nil}, // should never be called - serveETag: `"foo"`, + serveETag: `W/"foo"`, // If-None-Match uses weak ETag comparison reqHeader: map[string]string{ - "If-None-Match": `"foo"`, + "If-None-Match": `"baz", W/"foo"`, }, wantStatus: 304, }, + "if_none_match_mismatch": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"Foo"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, "range_good": { file: "testdata/style.css", serveETag: `"A"`, reqHeader: map[string]string{ "Range": "bytes=0-4", }, - wantStatus: StatusPartialContent, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + }, + "range_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"A"`, + }, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + }, + "range_match_weak_etag": { + file: "testdata/style.css", + serveETag: `W/"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `W/"A"`, + }, + wantStatus: 200, wantContentType: "text/css; charset=utf-8", }, + "range_no_overlap": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=10-20", + }, + wantStatus: StatusRequestedRangeNotSatisfiable, + wantContentType: "text/plain; charset=utf-8", + wantContentRange: "bytes */8", + }, // An If-Range resource for entity "A", but entity "B" is now current. // The Range request should be ignored. "range_no_match": { @@ -842,9 +888,10 @@ func TestServeContent(t *testing.T) { "Range": "bytes=0-4", "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", }, - wantStatus: StatusPartialContent, - wantContentType: "text/css; charset=utf-8", - wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", }, "range_with_modtime_nanos": { file: "testdata/style.css", @@ -853,9 +900,10 @@ func TestServeContent(t *testing.T) { "Range": "bytes=0-4", "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", }, - wantStatus: StatusPartialContent, - wantContentType: "text/css; charset=utf-8", - wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", }, "unix_zero_modtime": { content: strings.NewReader("<html>foo"), @@ -863,6 +911,62 @@ func TestServeContent(t *testing.T) { wantStatus: StatusOK, wantContentType: "text/html; charset=utf-8", }, + "ifmatch_matches": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `"Z", "A"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "ifmatch_star": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `*`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "ifmatch_failed": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `"B"`, + }, + wantStatus: 412, + wantContentType: "text/plain; charset=utf-8", + }, + "ifmatch_fails_on_weak_etag": { + file: "testdata/style.css", + serveETag: `W/"A"`, + reqHeader: map[string]string{ + "If-Match": `W/"A"`, + }, + wantStatus: 412, + wantContentType: "text/plain; charset=utf-8", + }, + "if_unmodified_since_true": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Unmodified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + }, + "if_unmodified_since_false": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Unmodified-Since": htmlModTime.Add(-2 * time.Second).UTC().Format(TimeFormat), + }, + wantStatus: 412, + wantContentType: "text/plain; charset=utf-8", + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + }, } for testName, tt := range tests { var content io.ReadSeeker @@ -903,6 +1007,9 @@ func TestServeContent(t *testing.T) { if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { t.Errorf("test %q: content-type = %q, want %q", testName, g, e) } + if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e { + t.Errorf("test %q: content-range = %q, want %q", testName, g, e) + } if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) } @@ -958,6 +1065,7 @@ func TestServeContentErrorMessages(t *testing.T) { // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { + setParallel(t) defer afterTest(t) if runtime.GOOS != "linux" { t.Skip("skipping; linux-only test") @@ -982,6 +1090,8 @@ func TestLinuxSendfile(t *testing.T) { // strace on the above platforms doesn't support sendfile64 // and will error out if we specify that with `-e trace='. syscalls = "sendfile" + case "mips64": + t.Skip("TODO: update this test to be robust against various versions of strace on mips64. See golang.org/issue/33430") } var buf bytes.Buffer @@ -1008,10 +1118,9 @@ func TestLinuxSendfile(t *testing.T) { Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil) child.Wait() - rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`) - rxResume := regexp.MustCompile(`<\.\.\. sendfile(64)? resumed> \)\s*=\s*\d+\s*\n`) + rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+`) out := buf.String() - if !rx.MatchString(out) && !rxResume.MatchString(out) { + if !rx.MatchString(out) { t.Errorf("no sendfile system call found in:\n%s", out) } } @@ -1090,3 +1199,26 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) { } type panicOnSeek struct{ io.ReadSeeker } + +func Test_scanETag(t *testing.T) { + tests := []struct { + in string + wantETag string + wantRemain string + }{ + {`W/"etag-1"`, `W/"etag-1"`, ""}, + {`"etag-2"`, `"etag-2"`, ""}, + {`"etag-1", "etag-2"`, `"etag-1"`, `, "etag-2"`}, + {"", "", ""}, + {"", "", ""}, + {"W/", "", ""}, + {`W/"truc`, "", ""}, + {`w/"case-sensitive"`, "", ""}, + } + for _, test := range tests { + etag, remain := ExportScanETag(test.in) + if etag != test.wantETag || remain != test.wantRemain { + t.Errorf("scanETag(%q)=%q %q, want %q %q", test.in, etag, remain, test.wantETag, test.wantRemain) + } + } +} diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 5826bb7..25fdf09 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -1,5 +1,5 @@ // Code generated by golang.org/x/tools/cmd/bundle. -//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2 +//go:generate bundle -o h2_bundle.go -prefix http2 -underscore golang.org/x/net/http2 // Package http2 implements the HTTP/2 protocol. // @@ -21,6 +21,7 @@ import ( "bytes" "compress/gzip" "context" + "crypto/rand" "crypto/tls" "encoding/binary" "errors" @@ -43,6 +44,7 @@ import ( "time" "golang_org/x/net/http2/hpack" + "golang_org/x/net/idna" "golang_org/x/net/lex/httplex" ) @@ -853,10 +855,12 @@ type http2Framer struct { // If the limit is hit, MetaHeadersFrame.Truncated is set true. MaxHeaderListSize uint32 - logReads bool + logReads, logWrites bool - debugFramer *http2Framer // only use for logging written writes - debugFramerBuf *bytes.Buffer + debugFramer *http2Framer // only use for logging written writes + debugFramerBuf *bytes.Buffer + debugReadLoggerf func(string, ...interface{}) + debugWriteLoggerf func(string, ...interface{}) } func (fr *http2Framer) maxHeaderListSize() uint32 { @@ -890,7 +894,7 @@ func (f *http2Framer) endWrite() error { byte(length>>16), byte(length>>8), byte(length)) - if http2logFrameWrites { + if f.logWrites { f.logWrite() } @@ -912,10 +916,10 @@ func (f *http2Framer) logWrite() { f.debugFramerBuf.Write(f.wbuf) fr, err := f.debugFramer.ReadFrame() if err != nil { - log.Printf("http2: Framer %p: failed to decode just-written frame", f) + f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) return } - log.Printf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) + f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) } func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } @@ -936,9 +940,12 @@ const ( // NewFramer returns a Framer that writes frames to w and reads them from r. func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ - w: w, - r: r, - logReads: http2logFrameReads, + w: w, + r: r, + logReads: http2logFrameReads, + logWrites: http2logFrameWrites, + debugReadLoggerf: log.Printf, + debugWriteLoggerf: log.Printf, } fr.getReadBuf = func(size uint32) []byte { if cap(fr.readBuf) >= int(size) { @@ -1020,7 +1027,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { return nil, err } if fr.logReads { - log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { return fr.readMetaFrame(f.(*http2HeadersFrame)) @@ -1254,7 +1261,7 @@ func (f *http2Framer) WriteSettings(settings ...http2Setting) error { return f.endWrite() } -// WriteSettings writes an empty SETTINGS frame with the ACK bit set. +// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. @@ -1920,8 +1927,8 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr hdec.SetEmitEnabled(true) hdec.SetMaxStringLength(fr.maxHeaderStringLen()) hdec.SetEmitFunc(func(hf hpack.HeaderField) { - if http2VerboseLogs && http2logFrameReads { - log.Printf("http2: decoded hpack field %+v", hf) + if http2VerboseLogs && fr.logReads { + fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) } if !httplex.ValidHeaderFieldValue(hf.Value) { invalid = http2headerFieldValueError(hf.Value) @@ -2091,6 +2098,13 @@ type http2clientTrace httptrace.ClientTrace func http2reqContext(r *Request) context.Context { return r.Context() } +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } func http2traceGotConn(req *Request, cc *http2ClientConn) { @@ -2145,6 +2159,48 @@ func http2requestTrace(req *Request) *http2clientTrace { return (*http2clientTrace)(trace) } +// Ping sends a PING frame to the server and waits for the ack. +func (cc *http2ClientConn) Ping(ctx context.Context) error { + return cc.ping(ctx) +} + +func http2cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } + +var _ Pusher = (*http2responseWriter)(nil) + +// Push implements http.Pusher. +func (w *http2responseWriter) Push(target string, opts *PushOptions) error { + internalOpts := http2pushOptions{} + if opts != nil { + internalOpts.Method = opts.Method + internalOpts.Header = opts.Header + } + return w.push(target, internalOpts) +} + +func http2configureServer18(h1 *Server, h2 *http2Server) error { + if h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } + } + return nil +} + +func http2shouldLogPanic(panicValue interface{}) bool { + return panicValue != nil && panicValue != ErrAbortHandler +} + +func http2reqGetBody(req *Request) func() (io.ReadCloser, error) { + return req.GetBody +} + +func http2reqBodyIsNoBody(body io.ReadCloser) bool { + return body == NoBody +} + var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -2368,6 +2424,7 @@ var ( http2VerboseLogs bool http2logFrameWrites bool http2logFrameReads bool + http2inTests bool ) func init() { @@ -2409,13 +2466,23 @@ var ( type http2streamState int +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. const ( http2stateIdle http2streamState = iota http2stateOpen http2stateHalfClosedLocal http2stateHalfClosedRemote - http2stateResvLocal - http2stateResvRemote http2stateClosed ) @@ -2424,8 +2491,6 @@ var http2stateName = [...]string{ http2stateOpen: "Open", http2stateHalfClosedLocal: "HalfClosedLocal", http2stateHalfClosedRemote: "HalfClosedRemote", - http2stateResvLocal: "ResvLocal", - http2stateResvRemote: "ResvRemote", http2stateClosed: "Closed", } @@ -2586,13 +2651,27 @@ func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { return &http2bufferedWriter{w: w} } +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + var http2bufWriterPool = sync.Pool{ New: func() interface{} { - - return bufio.NewWriterSize(nil, 4<<10) + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) }, } +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { bw := http2bufWriterPool.Get().(*bufio.Writer) @@ -2686,6 +2765,19 @@ func (s *http2sorter) SortStrings(ss []string) { s.v = save } +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/', but not with with "//", +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +func http2validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" +} + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -2882,6 +2974,15 @@ type http2Server struct { // PermitProhibitedCipherSuites, if true, permits the use of // cipher suites prohibited by the HTTP/2 spec. PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler } func (s *http2Server) maxReadFrameSize() uint32 { @@ -2904,9 +3005,15 @@ func (s *http2Server) maxConcurrentStreams() uint32 { // // ConfigureServer must be called before s begins serving. func http2ConfigureServer(s *Server, conf *http2Server) error { + if s == nil { + panic("nil *http.Server") + } if conf == nil { conf = new(http2Server) } + if err := http2configureServer18(s, conf); err != nil { + return err + } if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -2945,8 +3052,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) } - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14") - if s.TLSNextProto == nil { s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} } @@ -2960,7 +3065,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { }) } s.TLSNextProto[http2NextProtoTLS] = protoHandler - s.TLSNextProto["h2-14"] = protoHandler return nil } @@ -3014,29 +3118,39 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { defer cancel() sc := &http2serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: http2newBufferedWriter(c), - handler: opts.handler(), - streams: make(map[uint32]*http2stream), - readFrameCh: make(chan http2readFrameResult), - wantWriteFrameCh: make(chan http2frameWriteMsg, 8), - wroteFrameCh: make(chan http2frameWriteResult, 1), - bodyReadCh: make(chan http2bodyReadMsg), - doneServing: make(chan struct{}), - advMaxStreams: s.maxConcurrentStreams(), - writeSched: http2writeScheduler{ - maxFrameSize: http2initialMaxFrameSize, - }, + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + wantStartPushCh: make(chan http2startPushRequest, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), + bodyReadCh: make(chan http2bodyReadMsg), + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, + advMaxStreams: s.maxConcurrentStreams(), initialWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, headerTableSize: http2initialHeaderTableSize, serveG: http2newGoroutineLock(), pushEnabled: true, } + if sc.hs.WriteTimeout != 0 { + sc.conn.SetWriteDeadline(time.Time{}) + } + + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = http2NewRandomWriteScheduler() + } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -3090,16 +3204,18 @@ type http2serverConn struct { handler Handler baseCtx http2contextContext framer *http2Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan http2readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve - wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan http2bodyReadMsg // from handlers -> serve - testHookCh chan func(int) // code to run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wantStartPushCh chan http2startPushRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + testHookCh chan func(int) // code to run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string + writeSched http2WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): serveG http2goroutineLock // used to verify funcs are on serve() @@ -3109,22 +3225,27 @@ type http2serverConn struct { unackedSettings int // how many SETTINGS have we sent without ACKs? clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curOpenStreams uint32 // client's number of open streams - maxStreamID uint32 // max ever seen + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes streams map[uint32]*http2stream initialWindowSize int32 + maxFrameSize int32 headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush - writeSched http2writeScheduler - inGoAway bool // we've started to or sent GOAWAY - needToSendGoAway bool // we need to schedule a GOAWAY frame write + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode http2ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used - freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf + idleTimer *time.Timer // nil if unused + idleTimerCh <-chan time.Time // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -3143,6 +3264,11 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { return uint32(n + typicalHeaders*perFieldOverhead) } +func (sc *http2serverConn) curOpenStreams() uint32 { + sc.serveG.check() + return sc.curClientStreams + sc.curPushedStreams +} + // stream represents a stream. This is the minimal metadata needed by // the serve goroutine. Most of the actual stream state is owned by // the http.Handler's goroutine in the responseWriter. Because the @@ -3168,11 +3294,10 @@ type http2stream struct { numTrailerValues int64 weight uint8 state http2streamState - sentReset bool // only true once detached from streams map - gotReset bool // only true once detacted from streams map - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - reqBuf []byte + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + reqBuf []byte // if non-nil, body pipe buffer to return later at EOF trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -3195,8 +3320,14 @@ func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2strea return st.state, st } - if streamID <= sc.maxStreamID { - return http2stateClosed, nil + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return http2stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return http2stateClosed, nil + } } return http2stateIdle, nil } @@ -3328,17 +3459,17 @@ func (sc *http2serverConn) readFrames() { // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. type http2frameWriteResult struct { - wm http2frameWriteMsg // what was written (or attempted) - err error // result of the writeFrame call + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call } // writeFrameAsync runs in its own goroutine and writes a single frame // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) { - err := wm.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wm, err} +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr, err} } func (sc *http2serverConn) closeAllStreamsOnConnClose() { @@ -3382,7 +3513,7 @@ func (sc *http2serverConn) serve() { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeSettings{ {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, @@ -3399,6 +3530,17 @@ func (sc *http2serverConn) serve() { sc.setConnState(StateActive) sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + defer sc.idleTimer.Stop() + sc.idleTimerCh = sc.idleTimer.C + } + + var gracefulShutdownCh <-chan struct{} + if sc.hs != nil { + gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs) + } + go sc.readFrames() settingsTimer := time.NewTimer(http2firstSettingsTimeout) @@ -3406,8 +3548,10 @@ func (sc *http2serverConn) serve() { for { loopNum++ select { - case wm := <-sc.wantWriteFrameCh: - sc.writeFrame(wm) + case wr := <-sc.wantWriteFrameCh: + sc.writeFrame(wr) + case spr := <-sc.wantStartPushCh: + sc.startPush(spr) case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: @@ -3424,12 +3568,22 @@ func (sc *http2serverConn) serve() { case <-settingsTimer.C: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) return + case <-gracefulShutdownCh: + gracefulShutdownCh = nil + sc.startGracefulShutdown() case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case <-sc.idleTimerCh: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) case fn := <-sc.testHookCh: fn(loopNum) } + + if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame { + return + } } } @@ -3477,7 +3631,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte ch := http2errChanPool.Get().(chan error) writeArg := http2writeDataPool.Get().(*http2writeData) *writeArg = http2writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(http2frameWriteMsg{ + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: writeArg, stream: stream, done: ch, @@ -3507,17 +3661,17 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte return err } -// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts // if the connection has gone away. // // This must not be run from the serve goroutine itself, else it might // deadlock writing to sc.wantWriteFrameCh (which is only mildly // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. -func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { sc.serveG.checkNotOn() select { - case sc.wantWriteFrameCh <- wm: + case sc.wantWriteFrameCh <- wr: return nil case <-sc.doneServing: @@ -3533,53 +3687,81 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { // make it onto the wire // // If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { sc.serveG.check() + // If true, wr will not be written and wr.done will not be signaled. var ignoreWrite bool - switch wm.write.(type) { + if wr.StreamID() != 0 { + _, isReset := wr.write.(http2StreamError) + if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + ignoreWrite = true + } + } + + switch wr.write.(type) { case *http2writeResHeaders: - wm.stream.wroteHeaders = true + wr.stream.wroteHeaders = true case http2write100ContinueHeadersFrame: - if wm.stream.wroteHeaders { + if wr.stream.wroteHeaders { + + if wr.done != nil { + panic("wr.done != nil for write100ContinueHeadersFrame") + } ignoreWrite = true } } if !ignoreWrite { - sc.writeSched.add(wm) + sc.writeSched.Push(wr) } sc.scheduleFrameWrite() } -// startFrameWrite starts a goroutine to write wm (in a separate +// startFrameWrite starts a goroutine to write wr (in a separate // goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wm. -func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) { +// serve goroutine's state about the world, updated from info in wr. +func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { sc.serveG.check() if sc.writingFrame { panic("internal error: can only be writing one frame at a time") } - st := wm.stream + st := wr.stream if st != nil { switch st.state { case http2stateHalfClosedLocal: - panic("internal error: attempt to send frame on half-closed-local stream") - case http2stateClosed: - if st.sentReset || st.gotReset { + switch wr.write.(type) { + case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: - sc.scheduleFrameWrite() - return + default: + panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) } - panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + case http2stateClosed: + panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + wr.replyToWriter(err) + return } } sc.writingFrame = true sc.needsFrameFlush = true - go sc.writeFrameAsync(wm) + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr, err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } } // errHandlerPanicked is the error given to any callers blocked in a read from @@ -3595,26 +3777,12 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { panic("internal error: expected to be already writing a frame") } sc.writingFrame = false + sc.writingFrameAsync = false - wm := res.wm - st := wm.stream - - closeStream := http2endsStream(wm.write) - - if _, ok := wm.write.(http2handlerPanicRST); ok { - sc.closeStream(st, http2errHandlerPanicked) - } + wr := res.wr - if ch := wm.done; ch != nil { - select { - case ch <- res.err: - default: - panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write)) - } - } - wm.write = nil - - if closeStream { + if http2writeEndsStream(wr.write) { + st := wr.stream if st == nil { panic("internal error: expecting non-nil stream") } @@ -3622,13 +3790,24 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { case http2stateOpen: st.state = http2stateHalfClosedLocal - errCancel := http2streamError(st.id, http2ErrCodeCancel) - sc.resetStream(errCancel) + sc.resetStream(http2streamError(st.id, http2ErrCodeCancel)) case http2stateHalfClosedRemote: sc.closeStream(st, http2errHandlerComplete) } + } else { + switch v := wr.write.(type) { + case http2StreamError: + + if st, ok := sc.streams[v.StreamID]; ok { + sc.closeStream(st, v) + } + case http2handlerPanicRST: + sc.closeStream(wr.stream, http2errHandlerPanicked) + } } + wr.replyToWriter(res.err) + sc.scheduleFrameWrite() } @@ -3646,47 +3825,68 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // flush the write buffer. func (sc *http2serverConn) scheduleFrameWrite() { sc.serveG.check() - if sc.writingFrame { - return - } - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(http2frameWriteMsg{ - write: &http2writeGoAway{ - maxStreamID: sc.maxStreamID, - code: sc.goAwayCode, - }, - }) - return - } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}}) + if sc.writingFrame || sc.inFrameScheduleLoop { return } - if !sc.inGoAway { - if wm, ok := sc.writeSched.take(); ok { - sc.startFrameWrite(wm) - return + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false + continue + } + break } - if sc.needsFrameFlush { - sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}}) - sc.needsFrameFlush = false - return - } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the +// client we're gracefully shutting down. The connection isn't closed +// until all current streams are done. +func (sc *http2serverConn) startGracefulShutdown() { + sc.goAwayIn(http2ErrCodeNo, 0) } func (sc *http2serverConn) goAway(code http2ErrCode) { sc.serveG.check() - if sc.inGoAway { - return - } + var forceCloseIn time.Duration if code != http2ErrCodeNo { - sc.shutDownIn(250 * time.Millisecond) + forceCloseIn = 250 * time.Millisecond } else { - sc.shutDownIn(1 * time.Second) + forceCloseIn = 1 * time.Second + } + sc.goAwayIn(code, forceCloseIn) +} + +func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) { + sc.serveG.check() + if sc.inGoAway { + return + } + if forceCloseIn != 0 { + sc.shutDownIn(forceCloseIn) } sc.inGoAway = true sc.needToSendGoAway = true @@ -3702,10 +3902,9 @@ func (sc *http2serverConn) shutDownIn(d time.Duration) { func (sc *http2serverConn) resetStream(se http2StreamError) { sc.serveG.check() - sc.writeFrame(http2frameWriteMsg{write: se}) + sc.writeFrame(http2FrameWriteRequest{write: se}) if st, ok := sc.streams[se.StreamID]; ok { - st.sentReset = true - sc.closeStream(st, se) + st.resetQueued = true } } @@ -3782,6 +3981,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { return sc.processResetStream(f) case *http2PriorityFrame: return sc.processPriority(f) + case *http2GoAwayFrame: + return sc.processGoAway(f) case *http2PushPromiseFrame: return http2ConnectionError(http2ErrCodeProtocol) @@ -3801,7 +4002,10 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}}) + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) return nil } @@ -3809,7 +4013,11 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error sc.serveG.check() switch { case f.StreamID != 0: - st := sc.streams[f.StreamID] + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } if st == nil { return nil @@ -3835,7 +4043,6 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } if st != nil { - st.gotReset = true st.cancelCtx() sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) } @@ -3848,11 +4055,21 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = http2stateClosed - sc.curOpenStreams-- - if sc.curOpenStreams == 0 { - sc.setConnState(StateIdle) + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- } delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if http2h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdown() + } + } if p := st.body; p != nil { sc.sendWindowUpdate(nil, p.Len()) @@ -3860,11 +4077,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { p.CloseWithError(err) } st.cw.Close() - sc.writeSched.forgetStream(st.id) - if st.reqBuf != nil { - - sc.freeRequestBodyBuf = st.reqBuf - } + sc.writeSched.CloseStream(st.id) } func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { @@ -3904,7 +4117,7 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { case http2SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) case http2SettingMaxFrameSize: - sc.writeSched.maxFrameSize = s.Val + sc.maxFrameSize = int32(s.Val) case http2SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: @@ -3933,11 +4146,18 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.serveG.check() + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } data := f.Data() id := f.Header().StreamID - st, ok := sc.streams[id] - if !ok || st.state != http2stateOpen || st.gotTrailerHeader { + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } + if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { if sc.inflow.available() < int32(f.Length) { return http2streamError(id, http2ErrCodeFlowControl) @@ -3946,6 +4166,10 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) + if st != nil && st.resetQueued { + + return nil + } return http2streamError(id, http2ErrCodeStreamClosed) } if st.body == nil { @@ -3985,6 +4209,24 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { return nil } +func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != http2ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdown() + + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. +func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + // endStream closes a Request.Body's pipe. It is called when a DATA // frame says a request body is over (or after trailers). func (st *http2stream) endStream() { @@ -4014,7 +4256,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() { func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() - id := f.Header().StreamID + id := f.StreamID if sc.inGoAway { return nil @@ -4024,50 +4266,43 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - st := sc.streams[f.Header().StreamID] - if st != nil { + if st := sc.streams[f.StreamID]; st != nil { + if st.resetQueued { + + return nil + } return st.processTrailerHeaders(f) } - if id <= sc.maxStreamID { + if id <= sc.maxClientStreamID { return http2ConnectionError(http2ErrCodeProtocol) } - sc.maxStreamID = id + sc.maxClientStreamID = id - ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) - st = &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, - ctx: ctx, - cancelCtx: cancelCtx, - } - if f.StreamEnded() { - st.state = http2stateHalfClosedRemote + if sc.idleTimer != nil { + sc.idleTimer.Stop() } - st.cw.Init() - st.flow.conn = &sc.flow - st.flow.add(sc.initialWindowSize) - st.inflow.conn = &sc.inflow - st.inflow.add(http2initialWindowSize) + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { - sc.streams[id] = st - if f.HasPriority() { - http2adjustStreamPriority(sc.streams, st.id, f.Priority) - } - sc.curOpenStreams++ - if sc.curOpenStreams == 1 { - sc.setConnState(StateActive) + return http2streamError(id, http2ErrCodeProtocol) + } + + return http2streamError(id, http2ErrCodeRefusedStream) } - if sc.curOpenStreams > sc.advMaxStreams { - if sc.unackedSettings == 0 { + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) - return http2streamError(st.id, http2ErrCodeProtocol) + if f.HasPriority() { + if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + return err } - - return http2streamError(st.id, http2ErrCodeRefusedStream) + sc.writeSched.AdjustStream(st.id, f.Priority) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -4085,10 +4320,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if f.Truncated { handler = http2handleHeaderListTooLong - } else if err := http2checkValidHTTP2Request(req); err != nil { + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { handler = http2new400Handler(err) } + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + go sc.runHandler(rw, req, handler) return nil } @@ -4121,90 +4360,138 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { - http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam) +func http2checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + + return http2streamError(streamID, http2ErrCodeProtocol) + } return nil } -func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) { - st, ok := streams[streamID] - if !ok { +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} - return +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") } - st.weight = priority.Weight - parent := streams[priority.StreamDep] - if parent == st { - return + ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, } + st.cw.Init() + st.flow.conn = &sc.flow + st.flow.add(sc.initialWindowSize) + st.inflow.conn = &sc.inflow + st.inflow.add(http2initialWindowSize) - for piter := parent; piter != nil; piter = piter.parent { - if piter == st { - parent.parent = st.parent - break - } + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ } - st.parent = parent - if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) { - for _, openStream := range streams { - if openStream != st && openStream.parent == st.parent { - openStream.parent = st - } - } + if sc.curOpenStreams() == 1 { + sc.setConnState(StateActive) } + + return st } func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { sc.serveG.check() - method := f.PseudoValue("method") - path := f.PseudoValue("path") - scheme := f.PseudoValue("scheme") - authority := f.PseudoValue("authority") + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } - isConnect := method == "CONNECT" + isConnect := rp.method == "CONNECT" if isConnect { - if path != "" || scheme != "" || authority == "" { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - } else if method == "" || path == "" || - (scheme != "https" && scheme != "http") { + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } bodyOpen := !f.StreamEnded() - if method == "HEAD" && bodyOpen { + if rp.method == "HEAD" && bodyOpen { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - var tlsState *tls.ConnectionState // nil if not scheme https - if scheme == "https" { - tlsState = sc.tlsState + rp.header = make(Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") } - header := make(Header) - for _, hf := range f.RegularFields() { - header.Add(sc.canonicalHeader(hf.Name), hf.Value) + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err } + if bodyOpen { + st.reqBuf = http2getRequestBodyBuf() + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2fixedBuffer{buf: st.reqBuf}, + } - if authority == "" { - authority = header.Get("Host") + if vv, ok := rp.header["Content-Length"]; ok { + req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + } else { + req.ContentLength = -1 + } } - needsContinue := header.Get("Expect") == "100-continue" + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + + needsContinue := rp.header.Get("Expect") == "100-continue" if needsContinue { - header.Del("Expect") + rp.header.Del("Expect") } - if cookies := header["Cookie"]; len(cookies) > 1 { - header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer Header - for _, v := range header["Trailer"] { + for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -4218,55 +4505,42 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } } } - delete(header, "Trailer") + delete(rp.header, "Trailer") - body := &http2requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } var url_ *url.URL var requestURI string - if isConnect { - url_ = &url.URL{Host: authority} - requestURI = authority + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority } else { var err error - url_, err = url.ParseRequestURI(path) + url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) } - requestURI = path + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, } req := &Request{ - Method: method, + Method: rp.method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: header, + Header: rp.header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: authority, + Host: rp.authority, Body: body, Trailer: trailer, } req = http2requestWithContext(req, st.ctx) - if bodyOpen { - - buf := make([]byte, http2initialWindowSize) - - body.pipe = &http2pipe{ - b: &http2fixedBuffer{buf: buf}, - } - - if vv, ok := header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) - } else { - req.ContentLength = -1 - } - } rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw @@ -4282,13 +4556,22 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead return rw, req, nil } -func (sc *http2serverConn) getRequestBodyBuf() []byte { - sc.serveG.check() - if buf := sc.freeRequestBodyBuf; buf != nil { - sc.freeRequestBodyBuf = nil - return buf +var http2reqBodyCache = make(chan []byte, 8) + +func http2getRequestBodyBuf() []byte { + select { + case b := <-http2reqBodyCache: + return b + default: + return make([]byte, http2initialWindowSize) + } +} + +func http2putRequestBodyBuf(b []byte) { + select { + case http2reqBodyCache <- b: + default: } - return make([]byte, http2initialWindowSize) } // Run on its own goroutine. @@ -4298,15 +4581,17 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han rw.rws.stream.cancelCtx() if didPanic { e := recover() - // Same as net/http: - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) - sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + + if http2shouldLogPanic(e) { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } return } rw.handlerDone() @@ -4334,7 +4619,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR errc = http2errChanPool.Get().(chan error) } - if err := sc.writeFrameFromHandler(http2frameWriteMsg{ + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: headerData, stream: st, done: errc, @@ -4357,7 +4642,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR // called from handler goroutines. func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2write100ContinueHeadersFrame{st.id}, stream: st, }) @@ -4373,11 +4658,19 @@ type http2bodyReadMsg struct { // called from handler goroutines. // Notes that the handler for the given stream ID read n bytes of its body // and schedules flow control tokens to be sent. -func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) { +func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { sc.serveG.checkNotOn() - select { - case sc.bodyReadCh <- http2bodyReadMsg{st, n}: - case <-sc.doneServing: + if n > 0 { + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } + if err == io.EOF { + if buf := st.reqBuf; buf != nil { + st.reqBuf = nil + http2putRequestBodyBuf(buf) + } } } @@ -4419,7 +4712,7 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { if st != nil { streamID = st.id } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, stream: st, }) @@ -4434,16 +4727,19 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { } } +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. type http2requestBody struct { stream *http2stream conn *http2serverConn - closed bool + closed bool // for use by Close only + sawEOF bool // for use by Read only pipe *http2pipe // non-nil if we have a HTTP entity message body needsContinue bool // need to send a 100-continue } func (b *http2requestBody) Close() error { - if b.pipe != nil { + if b.pipe != nil && !b.closed { b.pipe.BreakWithError(http2errClosedBody) } b.closed = true @@ -4455,13 +4751,17 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) { b.needsContinue = false b.conn.write100ContinueHeaders(b.stream) } - if b.pipe == nil { + if b.pipe == nil || b.sawEOF { return 0, io.EOF } n, err = b.pipe.Read(p) - if n > 0 { - b.conn.noteBodyReadFromHandler(b.stream, n) + if err == io.EOF { + b.sawEOF = true } + if b.conn == nil && http2inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) return } @@ -4696,8 +4996,9 @@ func (w *http2responseWriter) CloseNotify() <-chan bool { if ch == nil { ch = make(chan bool, 1) rws.closeNotifierCh = ch + cw := rws.stream.cw go func() { - rws.stream.cw.Wait() + cw.Wait() ch <- true }() } @@ -4793,6 +5094,172 @@ func (w *http2responseWriter) handlerDone() { http2responseWriterStatePool.Put(rws) } +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +// pushOptions is the internal version of http.PushOptions, which we +// cannot include here because it's only defined in Go 1.8 and later. +type http2pushOptions struct { + Method string + Header Header +} + +func (w *http2responseWriter) push(target string, opts http2pushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + if st.isPushed() { + return http2ErrRecursivePush + } + + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + + switch strings.ToLower(k) { + case "content-length", "content-encoding", "trailer", "te", "expect", "host": + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.wantStartPushCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header Header + done chan error +} + +func (sc *http2serverConn) startPush(msg http2startPushRequest) { + sc.serveG.check() + + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + + msg.done <- http2errStreamClosed + return + } + + if !sc.pushEnabled { + msg.done <- ErrNotSupported + return + } + + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + if !sc.pushEnabled { + return 0, ErrNotSupported + } + + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdown() + return 0, http2ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: http2cloneHeader(msg.header), + }) + if err != nil { + + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + // foreachHeaderElement splits v according to the "#rule" construction // in RFC 2616 section 2.1 and calls fn for each non-empty element. func http2foreachHeaderElement(v string, fn func(string)) { @@ -4820,16 +5287,16 @@ var http2connHeaders = []string{ "Upgrade", } -// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, // per RFC 7540 Section 8.1.2.2. // The returned error is reported to users. -func http2checkValidHTTP2Request(req *Request) error { - for _, h := range http2connHeaders { - if _, ok := req.Header[h]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", h) +func http2checkValidHTTP2RequestHeaders(h Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) } } - te := req.Header["Te"] + te := h["Te"] if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) } @@ -4877,6 +5344,45 @@ var http2badTrailer = map[string]bool{ "Www-Authenticate": true, } +// h1ServerShutdownChan returns a channel that will be closed when the +// provided *http.Server wants to shut down. +// +// This is a somewhat hacky way to get at http1 innards. It works +// when the http2 code is bundled into the net/http package in the +// standard library. The alternatives ended up making the cmd/go tool +// depend on http Servers. This is the lightest option for now. +// This is tested via the TestServeShutdown* tests in net/http. +func http2h1ServerShutdownChan(hs *Server) <-chan struct{} { + if fn := http2testh1ServerShutdownChan; fn != nil { + return fn(hs) + } + var x interface{} = hs + type I interface { + getDoneChan() <-chan struct{} + } + if hs, ok := x.(I); ok { + return hs.getDoneChan() + } + return nil +} + +// optional test hook for h1ServerShutdownChan. +var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -4997,6 +5503,9 @@ type http2ClientConn struct { readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer + mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes flow http2flow // our conn-level flow control quota (cs.flow is per stream) @@ -5007,6 +5516,7 @@ type http2ClientConn struct { goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 + pings map[[8]byte]chan struct{} // in flight ping data to notification channel bw *bufio.Writer br *bufio.Reader fr *http2Framer @@ -5033,6 +5543,7 @@ type http2clientStream struct { ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload + startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool on100 func() // optional code to run if get a 100 continue response @@ -5041,6 +5552,7 @@ type http2clientStream struct { bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu + didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu peerReset chan struct{} // closed on peer reset resetErr error // populated before peerReset is closed @@ -5068,15 +5580,26 @@ func (cs *http2clientStream) awaitRequestCancel(req *Request) { } select { case <-req.Cancel: + cs.cancelStream() cs.bufPipe.CloseWithError(http2errRequestCanceled) - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-ctx.Done(): + cs.cancelStream() cs.bufPipe.CloseWithError(ctx.Err()) - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-cs.done: } } +func (cs *http2clientStream) cancelStream() { + cs.cc.mu.Lock() + didReset := cs.didReset + cs.didReset = true + cs.cc.mu.Unlock() + + if !didReset { + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } +} + // checkResetOrDone reports any error sent in a RST_STREAM frame by the // server, or errStreamClosed if the stream is complete. func (cs *http2clientStream) checkResetOrDone() error { @@ -5133,14 +5656,22 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. func http2authorityAddr(scheme string, authority string) (addr string) { - if _, _, err := net.SplitHostPort(authority); err == nil { - return authority + host, port, err := net.SplitHostPort(authority) + if err != nil { + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a } - port := "443" - if scheme == "http" { - port = "80" + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port } - return net.JoinHostPort(authority, port) + return net.JoinHostPort(host, port) } // RoundTripOpt is like RoundTrip, but takes options. @@ -5158,8 +5689,10 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } http2traceGotConn(req, cc) res, err := cc.RoundTrip(req) - if http2shouldRetryRequest(req, err) { - continue + if err != nil { + if req, err = http2shouldRetryRequest(req, err); err == nil { + continue + } } if err != nil { t.vlogf("RoundTrip failure: %v", err) @@ -5181,11 +5714,39 @@ func (t *http2Transport) CloseIdleConnections() { var ( http2errClientConnClosed = errors.New("http2: client conn is closed") http2errClientConnUnusable = errors.New("http2: client conn not usable") + + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + http2errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written") ) -func http2shouldRetryRequest(req *Request, err error) bool { +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. +func http2shouldRetryRequest(req *Request, err error) (*Request, error) { + switch err { + default: + return nil, err + case http2errClientConnUnusable, http2errClientConnGotGoAway: + return req, nil + case http2errClientConnGotGoAwayAfterSomeReqBody: + + if req.Body == nil || http2reqBodyIsNoBody(req.Body) { + return req, nil + } - return err == http2errClientConnUnusable + getBody := http2reqGetBody(req) + if getBody == nil { + return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error") + } + body, err := getBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil + } } func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) { @@ -5203,7 +5764,7 @@ func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2Clie func (t *http2Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) if t.TLSClientConfig != nil { - *cfg = *t.TLSClientConfig + *cfg = *http2cloneTLSConfig(t.TLSClientConfig) } if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) @@ -5273,6 +5834,11 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client streams: make(map[uint32]*http2clientStream), singleUse: singleUse, wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + } + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } if http2VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -5328,6 +5894,15 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { if old != nil && old.ErrCode != http2ErrCodeNo { cc.goAway.ErrCode = old.ErrCode } + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + select { + case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}: + default: + } + } + } } func (cc *http2ClientConn) CanTakeNewRequest() bool { @@ -5345,6 +5920,16 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { cc.nextStreamID < math.MaxInt32 } +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *http2ClientConn) onIdleTimeout() { + cc.closeIfIdle() +} + func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 { @@ -5437,48 +6022,37 @@ func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { // Certain headers are special-cased as okay but not transmitted later. func http2checkConnHeaders(req *Request) error { if v := req.Header.Get("Upgrade"); v != "" { - return errors.New("http2: invalid Upgrade request header") + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) } - if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 { - return errors.New("http2: invalid Transfer-Encoding request header") + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) } - if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 { - return errors.New("http2: invalid Connection request header") + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) } return nil } -func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) { - body = req.Body - if body == nil { - return nil, 0 +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func http2actualContentLength(req *Request) int64 { + if req.Body == nil { + return 0 } if req.ContentLength != 0 { - return req.Body, req.ContentLength - } - - // We have a body but a zero content length. Test to see if - // it's actually zero or just unset. - var buf [1]byte - n, rerr := body.Read(buf[:]) - if rerr != nil && rerr != io.EOF { - return http2errorReader{rerr}, -1 - } - if n == 1 { - - if rerr == io.EOF { - return bytes.NewReader(buf[:]), 1 - } - return io.MultiReader(bytes.NewReader(buf[:]), body), -1 + return req.ContentLength } - - return nil, 0 + return -1 } func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if err := http2checkConnHeaders(req); err != nil { return nil, err } + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } trailers, err := http2commaSeparatedTrailers(req) if err != nil { @@ -5486,9 +6060,6 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } hasTrailers := trailers != "" - body, contentLen := http2bodyAndLength(req) - hasBody := body != nil - cc.mu.Lock() cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { @@ -5496,6 +6067,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { return nil, http2errClientConnUnusable } + body := req.Body + hasBody := body != nil + contentLen := http2actualContentLength(req) + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? var requestedGzip bool if !cc.t.disableCompression() && @@ -5561,6 +6136,13 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } if re.err != nil { + if re.err == http2errClientConnGotGoAway { + cc.mu.Lock() + if cs.startedWrite { + re.err = http2errClientConnGotGoAwayAfterSomeReqBody + } + cc.mu.Unlock() + } cc.forgetStreamID(cs.ID) return nil, re.err } @@ -5806,6 +6388,26 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail if host == "" { host = req.URL.Host } + host, err := httplex.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !http2validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !http2validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } for k, vv := range req.Header { if !httplex.ValidHeaderFieldName(k) { @@ -5821,8 +6423,8 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { - cc.writeHeader(":path", req.URL.RequestURI()) - cc.writeHeader(":scheme", "https") + cc.writeHeader(":path", path) + cc.writeHeader(":scheme", req.URL.Scheme) } if trailers != "" { cc.writeHeader("trailer", trailers) @@ -5940,6 +6542,9 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr if andRemove && cs != nil && !cc.closed { cc.lastActive = time.Now() delete(cc.streams, id) + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + } close(cs.done) cc.cond.Broadcast() } @@ -5996,6 +6601,10 @@ func (rl *http2clientConnReadLoop) cleanup() { defer cc.t.connPool().MarkDead(cc) defer close(cc.readerDone) + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + err := cc.readerErr cc.mu.Lock() if cc.goAway != nil && http2isEOFOrNetReadError(err) { @@ -6398,9 +7007,10 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc.bw.Flush() cc.wmu.Unlock() } + didReset := cs.didReset cc.mu.Unlock() - if len(data) > 0 { + if len(data) > 0 && !didReset { if _, err := cs.bufPipe.Write(data); err != nil { rl.endStreamError(cs, err) return err @@ -6551,9 +7161,56 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er return nil } +// Ping sends a PING frame to the server and waits for the ack. +// Public implementation is in go17.go and not_go17.go +func (cc *http2ClientConn) ping(ctx http2contextContext) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + cc.wmu.Lock() + if err := cc.fr.WritePing(false, p); err != nil { + cc.wmu.Unlock() + return err + } + if err := cc.bw.Flush(); err != nil { + cc.wmu.Unlock() + return err + } + cc.wmu.Unlock() + select { + case <-c: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + + return cc.readerErr + } +} + func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { if f.IsAck() { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } return nil } cc := rl.cc @@ -6666,6 +7323,9 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade resc := make(chan error, 1) s.resc = resc s.fn = func() { + cs.cc.mu.Lock() + cs.startedWrite = true + cs.cc.mu.Unlock() resc <- cs.writeRequestBody(body, cs.req.Body) } s.delay = t.expectContinueTimeout() @@ -6728,6 +7388,11 @@ func http2isConnectionCloseRequest(req *Request) bool { // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool } // writeContext is the interface needed by the various frame writer @@ -6749,9 +7414,10 @@ type http2writeContext interface { HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) } -// endsStream reports whether the given frame writer w will locally -// close the stream. -func http2endsStream(w http2writeFramer) bool { +// writeEndsStream reports whether w writes a frame that will transition +// the stream to a half-closed local state. This returns false for RST_STREAM, +// which closes the entire stream (not just the local half). +func http2writeEndsStream(w http2writeFramer) bool { switch v := w.(type) { case *http2writeData: return v.endStream @@ -6759,7 +7425,7 @@ func http2endsStream(w http2writeFramer) bool { return v.endStream case nil: - panic("endsStream called on nil writeFramer") + panic("writeEndsStream called on nil writeFramer") } return false } @@ -6770,8 +7436,16 @@ func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { return ctx.Flush() } +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + type http2writeSettings []http2Setting +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + func (s http2writeSettings) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettings([]http2Setting(s)...) } @@ -6791,6 +7465,8 @@ func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { return err } +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } + type http2writeData struct { streamID uint32 p []byte @@ -6805,6 +7481,10 @@ func (w *http2writeData) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) } +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + // handlerPanicRST is the message sent from handler goroutines when // the handler panics. type http2handlerPanicRST struct { @@ -6815,22 +7495,59 @@ func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) } +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (se http2StreamError) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + type http2writePingAck struct{ pf *http2PingFrame } func (w http2writePingAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WritePing(true, w.pf.Data) } +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + type http2writeSettingsAck struct{} func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettingsAck() } +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames // for HTTP response headers or trailers from a server handler. type http2writeResHeaders struct { @@ -6852,6 +7569,11 @@ func http2encKV(enc *hpack.Encoder, k, v string) { enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + + return false +} + func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() @@ -6877,39 +7599,69 @@ func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { panic("unexpected empty hpack") } - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. - const maxFrameSize = 16384 + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - endHeaders := len(headerBlock) == 0 - var err error - if first { - first = false - err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: endHeaders, - }) - } else { - err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) - } - if err != nil { - return err - } +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } - return nil } type http2write100ContinueHeadersFrame struct { @@ -6928,15 +7680,24 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err }) } +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + + return 9+2*(len(":status")+len("100")) <= max +} + type http2writeWindowUpdate struct { streamID uint32 // or 0 for conn-level n uint32 } +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) } +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only only if k is in keys. func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := http2sorterPool.Get().(*http2sorter) @@ -6966,14 +7727,53 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { } } -// frameWriteMsg is a request to write a frame. -type http2frameWriteMsg struct { +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { // write is the interface value that does the writing, once the - // writeScheduler (below) has decided to select this frame - // to write. The write functions are all defined in write.go. + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. write http2writeFramer - stream *http2stream // used for prioritization. nil for non-stream frames. + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + stream *http2stream // done, if non-nil, must be a buffered channel with space for // 1 message and is sent the return value from write (or an @@ -6981,247 +7781,644 @@ type http2frameWriteMsg struct { done chan error } -// for debugging only: -func (wm http2frameWriteMsg) String() string { - var streamID uint32 - if wm.stream != nil { - streamID = wm.stream.id +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(http2StreamError); ok { + + return se.StreamID + } + return 0 + } + return wr.stream.id +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + + endStream: false, + }, + + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 } + + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { var des string - if s, ok := wm.write.(fmt.Stringer); ok { + if s, ok := wr.write.(fmt.Stringer); ok { des = s.String() } else { - des = fmt.Sprintf("%T", wm.write) + des = fmt.Sprintf("%T", wr.write) } - return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) } -// writeScheduler tracks pending frames to write, priorities, and decides -// the next one to use. It is not thread-safe. -type http2writeScheduler struct { - // zero are frames not associated with a specific stream. - // They're sent before any stream-specific freams. - zero http2writeQueue +// replyToWriter sends err to wr.done and panics if the send must block +// This does nothing if wr.done is nil. +func (wr *http2FrameWriteRequest) replyToWriter(err error) { + if wr.done == nil { + return + } + select { + case wr.done <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + } + wr.write = nil +} - // maxFrameSize is the maximum size of a DATA frame - // we'll write. Must be non-zero and between 16K-16M. - maxFrameSize uint32 +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} - // sq contains the stream-specific queues, keyed by stream ID. - // when a stream is idle, it's deleted from the map. - sq map[uint32]*http2writeQueue +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } - // canSend is a slice of memory that's reused between frame - // scheduling decisions to hold the list of writeQueues (from sq) - // which have enough flow control data to send. After canSend is - // built, the best is selected. - canSend []*http2writeQueue +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} - // pool of empty queues for reuse. - queuePool []*http2writeQueue +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] + + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr } -func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) { - if len(q.s) != 0 { - panic("queue must be empty") +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false } - ws.queuePool = append(ws.queuePool, q) + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true } -func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue { - ln := len(ws.queuePool) +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) +} + +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) if ln == 0 { return new(http2writeQueue) } - q := ws.queuePool[ln-1] - ws.queuePool = ws.queuePool[:ln-1] + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] return q } -func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 -func (ws *http2writeScheduler) add(wm http2frameWriteMsg) { - st := wm.stream - if st == nil { - ws.zero.push(wm) +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7340 Section 5.3. +// If cfg is nil, default options are used. +func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 } else { - ws.streamQueue(st.id).push(wm) + ws.writeThrottleLimit = math.MaxInt32 } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings } -func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue { - if q, ok := ws.sq[streamID]; ok { - return q +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") } - if ws.sq == nil { - ws.sq = make(map[uint32]*http2writeQueue) + if n.parent == parent { + return + } + + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n } - q := ws.getEmptyQueue() - ws.sq[streamID] = q - return q } -// take returns the most important frame to write and removes it from the scheduler. -// It is illegal to call this if the scheduler is empty or if there are no connection-level -// flow control bytes available. -func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) { - if ws.maxFrameSize == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b } +} - if !ws.zero.empty() { - return ws.zero.shift(), true +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this funcion returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true } - if len(ws.sq) == 0 { - return + if n.kids == nil { + return false + } + + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) } - for id, q := range ws.sq { - if q.firstIsNoCost() { - return ws.takeFrom(id, q) + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break } } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } - if len(ws.canSend) != 0 { - panic("should be empty") + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) } - for _, q := range ws.sq { - if n := ws.streamWritableBytes(q); n > 0 { - ws.canSend = append(ws.canSend, q) + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true } } - if len(ws.canSend) == 0 { - return + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk } - defer ws.zeroCanSend() + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} - q := ws.canSend[0] +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root http2priorityNode - return ws.takeFrom(q.streamID(), q) + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// zeroCanSend is defered from take. -func (ws *http2writeScheduler) zeroCanSend() { - for i := range ws.canSend { - ws.canSend[i] = nil +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID } - ws.canSend = ws.canSend[:0] } -// streamWritableBytes returns the number of DATA bytes we could write -// from the given queue's stream, if this stream/queue were -// selected. It is an error to call this if q's head isn't a -// *writeData. -func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 { - wm := q.head() - ret := wm.stream.flow.available() - if ret == 0 { - return 0 +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") } - if int32(ws.maxFrameSize) < ret { - ret = int32(ws.maxFrameSize) + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) } - if ret == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) } - wd := wm.write.(*http2writeData) - if len(wd.p) < int(ret) { - ret = int32(len(wd.p)) + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) } - return ret } -func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) { - wm = q.head() - - if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 { - allowed := wm.stream.flow.available() - if allowed == 0 { +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } - return http2frameWriteMsg{}, false + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return } - if int32(ws.maxFrameSize) < allowed { - allowed = int32(ws.maxFrameSize) + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } - if len(wd.p) > int(allowed) { - wm.stream.flow.take(allowed) - chunk := wd.p[:allowed] - wd.p = wd.p[allowed:] + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } - return http2frameWriteMsg{ - stream: wm.stream, - write: &http2writeData{ - streamID: wd.streamID, - p: chunk, + if n == parent { + return + } - endStream: false, - }, + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } - done: nil, - }, true + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next } - wm.stream.flow.take(int32(len(wd.p))) } - q.shift() - if q.empty() { - ws.putEmptyQueue(q) - delete(ws.sq, id) + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } } - return wm, true + n.q.push(wr) } -func (ws *http2writeScheduler) forgetStream(id uint32) { - q, ok := ws.sq[id] - if !ok { +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { return } - delete(ws.sq, id) + if len(*list) == maxSize { - for i := range q.s { - q.s[i] = http2frameWriteMsg{} + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] } - q.s = q.s[:0] - ws.putEmptyQueue(q) + *list = append(*list, n) } -type http2writeQueue struct { - s []http2frameWriteMsg +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) } -// streamID returns the stream ID for a non-empty stream-specific queue. -func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id } +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} -func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle or closed, it's deleted from the map. + sq map[uint32]*http2writeQueue -func (q *http2writeQueue) push(wm http2frameWriteMsg) { - q.s = append(q.s, wm) + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// head returns the next item that would be removed by shift. -func (q *http2writeQueue) head() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") - } - return q.s[0] +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + } -func (q *http2writeQueue) shift() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return } - wm := q.s[0] + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = http2frameWriteMsg{} - q.s = q.s[:len(q.s)-1] - return wm } -func (q *http2writeQueue) firstIsNoCost() bool { - if df, ok := q.s[0].write.(*http2writeData); ok { - return len(df.p) == 0 +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + id := wr.StreamID() + if id == 0 { + ws.zero.push(wr) + return } - return true + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + + if !ws.zero.empty() { + return ws.zero.shift(), true + } + + for _, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + return wr, true + } + } + return http2FrameWriteRequest{}, false } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 6343165..8321692 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -32,9 +32,11 @@ func (h Header) Set(key, value string) { } // Get gets the first value associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is used +// to canonicalize the provided key. // If there are no values associated with the key, Get returns "". -// To access multiple values of a key, access the map directly -// with CanonicalHeaderKey. +// To access multiple values of a key, or to use non-canonical keys, +// access the map directly. func (h Header) Get(key string) string { return textproto.MIMEHeader(h).Get(key) } diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go index b34ae41..826f7ff 100644 --- a/libgo/go/net/http/http.go +++ b/libgo/go/net/http/http.go @@ -5,7 +5,11 @@ package http import ( + "io" + "strconv" "strings" + "time" + "unicode/utf8" "golang_org/x/net/lex/httplex" ) @@ -14,6 +18,10 @@ import ( // Transport's byte-limiting readers. const maxInt64 = 1<<63 - 1 +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancelation of network operations. +var aLongTimeAgo = time.Unix(233431200, 0) + // TODO(bradfitz): move common stuff here. The other files have accumulated // generic http stuff in random places. @@ -41,3 +49,93 @@ func removeEmptyPort(host string) string { func isNotToken(r rune) bool { return !httplex.IsTokenRune(r) } + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + return false + } + } + return true +} + +func hexEscapeNonASCII(s string) string { + newLen := 0 + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + newLen += 3 + } else { + newLen++ + } + } + if newLen == len(s) { + return s + } + b := make([]byte, 0, newLen) + for i := 0; i < len(s); i++ { + if s[i] >= utf8.RuneSelf { + b = append(b, '%') + b = strconv.AppendInt(b, int64(s[i]), 16) + } else { + b = append(b, s[i]) + } + } + return string(b) +} + +// NoBody is an io.ReadCloser with no bytes. Read always returns EOF +// and Close always returns nil. It can be used in an outgoing client +// request to explicitly signal that a request has zero bytes. +// An alternative, however, is to simply set Request.Body to nil. +var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +var ( + // verify that an io.Copy from NoBody won't require a buffer: + _ io.WriterTo = NoBody + _ io.ReadCloser = NoBody +) + +// PushOptions describes options for Pusher.Push. +type PushOptions struct { + // Method specifies the HTTP method for the promised request. + // If set, it must be "GET" or "HEAD". Empty means "GET". + Method string + + // Header specifies additional promised request headers. This cannot + // include HTTP/2 pseudo header fields like ":path" and ":scheme", + // which will be added automatically. + Header Header +} + +// Pusher is the interface implemented by ResponseWriters that support +// HTTP/2 server push. For more background, see +// https://tools.ietf.org/html/rfc7540#section-8.2. +type Pusher interface { + // Push initiates an HTTP/2 server push. This constructs a synthetic + // request using the given target and options, serializes that request + // into a PUSH_PROMISE frame, then dispatches that request using the + // server's request handler. If opts is nil, default options are used. + // + // The target must either be an absolute path (like "/path") or an absolute + // URL that contains a valid host and the same scheme as the parent request. + // If the target is a path, it will inherit the scheme and host of the + // parent request. + // + // The HTTP/2 spec disallows recursive pushes and cross-authority pushes. + // Push may or may not detect these invalid pushes; however, invalid + // pushes will be detected and canceled by conforming clients. + // + // Handlers that wish to push URL X should call Push before sending any + // data that may trigger a request for URL X. This avoids a race where the + // client issues requests for X before receiving the PUSH_PROMISE for X. + // + // Push returns ErrNotSupported if the client has disabled push or if push + // is not supported on the underlying connection. + Push(target string, opts *PushOptions) error +} diff --git a/libgo/go/net/http/http_test.go b/libgo/go/net/http/http_test.go index 34da4bb..8f466bb 100644 --- a/libgo/go/net/http/http_test.go +++ b/libgo/go/net/http/http_test.go @@ -12,8 +12,13 @@ import ( "os/exec" "reflect" "testing" + "time" ) +func init() { + shutdownPollInterval = 5 * time.Millisecond +} + func TestForeachHeaderElement(t *testing.T) { tests := []struct { in string @@ -51,6 +56,18 @@ func TestCleanHost(t *testing.T) { {"www.google.com foo", "www.google.com"}, {"www.google.com/foo", "www.google.com"}, {" first character is a space", ""}, + {"[1::6]:8080", "[1::6]:8080"}, + + // Punycode: + {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"}, + {"bücher.de", "xn--bcher-kva.de"}, + {"bücher.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we convert to lowercase before punycode: + {"BÜCHER.de", "xn--bcher-kva.de"}, + {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we normalize to NFC before punycode: + {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed + {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input } for _, tt := range tests { got := cleanHost(tt.in) @@ -65,8 +82,9 @@ func TestCleanHost(t *testing.T) { // This catches accidental dependencies between the HTTP transport and // server code. func TestCmdGoNoHTTPServer(t *testing.T) { + t.Parallel() goBin := testenv.GoToolPath(t) - out, err := exec.Command("go", "tool", "nm", goBin).CombinedOutput() + out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput() if err != nil { t.Fatalf("go tool nm: %v: %s", err, out) } diff --git a/libgo/go/net/http/httptest/httptest.go b/libgo/go/net/http/httptest/httptest.go index e2148a6..f7202da 100644 --- a/libgo/go/net/http/httptest/httptest.go +++ b/libgo/go/net/http/httptest/httptest.go @@ -35,6 +35,9 @@ import ( // // NewRequest panics on error for ease of use in testing, where a // panic is acceptable. +// +// To generate a client HTTP request instead of a server request, see +// the NewRequest function in the net/http package. func NewRequest(method, target string, body io.Reader) *http.Request { if method == "" { method = "GET" diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 0ad26a3..5f1aa6a 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -8,15 +8,33 @@ import ( "bytes" "io/ioutil" "net/http" + "strconv" + "strings" ) // ResponseRecorder is an implementation of http.ResponseWriter that // records its mutations for later inspection in tests. type ResponseRecorder struct { - Code int // the HTTP response code from WriteHeader - HeaderMap http.Header // the HTTP response headers - Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to - Flushed bool + // Code is the HTTP response code set by WriteHeader. + // + // Note that if a Handler never calls WriteHeader or Write, + // this might end up being 0, rather than the implicit + // http.StatusOK. To get the implicit value, use the Result + // method. + Code int + + // HeaderMap contains the headers explicitly set by the Handler. + // + // To get the implicit headers set by the server (such as + // automatic Content-Type), use the Result method. + HeaderMap http.Header + + // Body is the buffer to which the Handler's Write calls are sent. + // If nil, the Writes are silently discarded. + Body *bytes.Buffer + + // Flushed is whether the Handler called Flush. + Flushed bool result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write @@ -136,6 +154,9 @@ func (rw *ResponseRecorder) Flush() { // first write call, or at the time of this call, if the handler never // did a write. // +// The Response.Body is guaranteed to be non-nil and Body.Read call is +// guaranteed to not return any error other than io.EOF. +// // Result must only be called after the handler has finished running. func (rw *ResponseRecorder) Result() *http.Response { if rw.result != nil { @@ -159,6 +180,7 @@ func (rw *ResponseRecorder) Result() *http.Response { if rw.Body != nil { res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) if trailers, ok := rw.snapHeader["Trailer"]; ok { res.Trailer = make(http.Header, len(trailers)) @@ -181,5 +203,33 @@ func (rw *ResponseRecorder) Result() *http.Response { res.Trailer[k] = vv2 } } + for k, vv := range rw.HeaderMap { + if !strings.HasPrefix(k, http.TrailerPrefix) { + continue + } + if res.Trailer == nil { + res.Trailer = make(http.Header) + } + for _, v := range vv { + res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) + } + } return res } + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. +func parseContentLength(cl string) int64 { + cl = strings.TrimSpace(cl) + if cl == "" { + return -1 + } + n, err := strconv.ParseInt(cl, 10, 64) + if err != nil { + return -1 + } + return n +} diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go index d4e7137..9afba4e 100644 --- a/libgo/go/net/http/httptest/recorder_test.go +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) { return nil } } + hasContentLength := func(length int64) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().ContentLength; got != length { + return fmt.Errorf("ContentLength = %d; want %d", got, length) + } + return nil + } + } tests := []struct { name string @@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) { w.(http.Flusher).Flush() // also sends a 200 w.WriteHeader(201) }, - check(hasStatus(200), hasFlush(true)), + check(hasStatus(200), hasFlush(true), hasContentLength(-1)), }, { "Content-Type detection", @@ -199,6 +207,7 @@ func TestRecorder(t *testing.T) { w.Header().Set("Trailer-A", "valuea") w.Header().Set("Trailer-C", "valuec") w.Header().Set("Trailer-NotDeclared", "should be omitted") + w.Header().Set("Trailer:Trailer-D", "with prefix") }, check( hasStatus(200), @@ -208,6 +217,7 @@ func TestRecorder(t *testing.T) { hasTrailer("Trailer-A", "valuea"), hasTrailer("Trailer-C", "valuec"), hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), + hasTrailer("Trailer-D", "with prefix"), ), }, { @@ -244,6 +254,16 @@ func TestRecorder(t *testing.T) { hasNotHeaders("X-Bar"), ), }, + { + "setting Content-Length header", + func(w http.ResponseWriter, r *http.Request) { + body := "Some body" + contentLength := fmt.Sprintf("%d", len(body)) + w.Header().Set("Content-Length", contentLength) + io.WriteString(w, body) + }, + check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests { diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index e27526a..5e9ace5 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -16,7 +16,6 @@ import ( "net/http" "net/http/internal" "os" - "runtime" "sync" "time" ) @@ -114,9 +113,10 @@ func (s *Server) StartTLS() { } existingConfig := s.TLS - s.TLS = new(tls.Config) if existingConfig != nil { - *s.TLS = *existingConfig + s.TLS = existingConfig.Clone() + } else { + s.TLS = new(tls.Config) } if s.TLS.NextProtos == nil { s.TLS.NextProtos = []string{"http/1.1"} @@ -293,15 +293,6 @@ func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) } // closeConnChan is like closeConn, but takes an optional channel to receive a value // when the goroutine closing c is done. func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { - if runtime.GOOS == "plan9" { - // Go's Plan 9 net package isn't great at unblocking reads when - // their underlying TCP connections are closed. Don't trust - // that that the ConnState state machine will get to - // StateClosed. Instead, just go there directly. Plan 9 may leak - // resources if the syscall doesn't end up returning. Oh well. - s.forgetConn(c) - } - c.Close() if done != nil { done <- struct{}{} diff --git a/libgo/go/net/http/httptrace/example_test.go b/libgo/go/net/http/httptrace/example_test.go new file mode 100644 index 0000000..27cdcde --- /dev/null +++ b/libgo/go/net/http/httptrace/example_test.go @@ -0,0 +1,31 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package httptrace_test + +import ( + "fmt" + "log" + "net/http" + "net/http/httptrace" +) + +func Example() { + req, _ := http.NewRequest("GET", "http://example.com", nil) + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + fmt.Printf("Got Conn: %+v\n", connInfo) + }, + DNSDone: func(dnsInfo httptrace.DNSDoneInfo) { + fmt.Printf("DNS Info: %+v\n", dnsInfo) + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + _, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + log.Fatal(err) + } +} diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go index 6f187a7..ea7b38c 100644 --- a/libgo/go/net/http/httptrace/trace.go +++ b/libgo/go/net/http/httptrace/trace.go @@ -1,6 +1,6 @@ // Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file.h +// license that can be found in the LICENSE file. // Package httptrace provides mechanisms to trace the events within // HTTP client requests. @@ -8,6 +8,7 @@ package httptrace import ( "context" + "crypto/tls" "internal/nettrace" "net" "reflect" @@ -65,11 +66,16 @@ func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { return ctx } -// ClientTrace is a set of hooks to run at various stages of an HTTP -// client request. Any particular hook may be nil. Functions may be -// called concurrently from different goroutines, starting after the -// call to Transport.RoundTrip and ending either when RoundTrip -// returns an error, or when the Response.Body is closed. +// ClientTrace is a set of hooks to run at various stages of an outgoing +// HTTP request. Any particular hook may be nil. Functions may be +// called concurrently from different goroutines and some may be called +// after the request has completed or failed. +// +// ClientTrace currently traces a single HTTP request & response +// during a single round trip and has no hooks that span a series +// of redirected requests. +// +// See https://blog.golang.org/http-tracing for more. type ClientTrace struct { // GetConn is called before a connection is created or // retrieved from an idle pool. The hostPort is the @@ -119,6 +125,16 @@ type ClientTrace struct { // enabled, this may be called multiple times. ConnectDone func(network, addr string, err error) + // TLSHandshakeStart is called when the TLS handshake is started. When + // connecting to a HTTPS site via a HTTP proxy, the handshake happens after + // the CONNECT request is processed by the proxy. + TLSHandshakeStart func() + + // TLSHandshakeDone is called after the TLS handshake with either the + // successful handshake's connection state, or a non-nil error on handshake + // failure. + TLSHandshakeDone func(tls.ConnectionState, error) + // WroteHeaders is called after the Transport has written // the request headers. WroteHeaders func() @@ -130,7 +146,8 @@ type ClientTrace struct { Wait100Continue func() // WroteRequest is called with the result of writing the - // request and any body. + // request and any body. It may be called multiple times + // in the case of retried requests. WroteRequest func(WroteRequestInfo) } diff --git a/libgo/go/net/http/httptrace/trace_test.go b/libgo/go/net/http/httptrace/trace_test.go index c7eaed8..bb57ada 100644 --- a/libgo/go/net/http/httptrace/trace_test.go +++ b/libgo/go/net/http/httptrace/trace_test.go @@ -1,14 +1,41 @@ // Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file.h +// license that can be found in the LICENSE file. package httptrace import ( "bytes" + "context" "testing" ) +func TestWithClientTrace(t *testing.T) { + var buf bytes.Buffer + connectStart := func(b byte) func(network, addr string) { + return func(network, addr string) { + buf.WriteByte(b) + } + } + + ctx := context.Background() + oldtrace := &ClientTrace{ + ConnectStart: connectStart('O'), + } + ctx = WithClientTrace(ctx, oldtrace) + newtrace := &ClientTrace{ + ConnectStart: connectStart('N'), + } + ctx = WithClientTrace(ctx, newtrace) + trace := ContextClientTrace(ctx) + + buf.Reset() + trace.ConnectStart("net", "addr") + if got, want := buf.String(), "NO"; got != want { + t.Errorf("got %q; want %q", got, want) + } +} + func TestCompose(t *testing.T) { var buf bytes.Buffer var testNum int diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index 1511681..7104c37 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -18,11 +18,16 @@ import ( "time" ) -// One of the copies, say from b to r2, could be avoided by using a more -// elaborate trick where the other copy is made during Request/Response.Write. -// This would complicate things too much, given that these functions are for -// debugging only. +// drainBody reads all of b to memory and then returns two equivalent +// ReadClosers yielding the same bytes. +// +// It returns an error if the initial slurp of all bytes fails. It does not attempt +// to make the returned ReadClosers have identical error-matching behavior. func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + if b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. + return http.NoBody, http.NoBody, nil + } var buf bytes.Buffer if _, err = buf.ReadFrom(b); err != nil { return nil, b, err diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 2e980d3..f881020 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -184,6 +184,18 @@ var dumpTests = []dumpTest{ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + "Host: passport.myhost.com\r\n\r\n", }, + + // Issue 18506: make drainBody recognize NoBody. Otherwise + // this was turning into a chunked request. + { + Req: *mustNewRequest("POST", "http://example.com/foo", http.NoBody), + + WantDumpOut: "POST /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 0\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, } func TestDumpRequest(t *testing.T) { diff --git a/libgo/go/net/http/httputil/persist.go b/libgo/go/net/http/httputil/persist.go index 87ddd52..cbedf25 100644 --- a/libgo/go/net/http/httputil/persist.go +++ b/libgo/go/net/http/httputil/persist.go @@ -15,9 +15,14 @@ import ( ) var ( + // Deprecated: No longer used. ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"} - ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} - ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} + + // Deprecated: No longer used. + ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} + + // Deprecated: No longer used. + ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} ) // This is an API usage error - the local side is closed. diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 49c120a..79c8fe2 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -7,6 +7,7 @@ package httputil import ( + "context" "io" "log" "net" @@ -29,6 +30,8 @@ type ReverseProxy struct { // the request into a new request to be sent // using Transport. Its response is then copied // back to the original client unmodified. + // Director must not access the provided Request + // after returning. Director func(*http.Request) // The transport used to perform proxy requests. @@ -51,6 +54,11 @@ type ReverseProxy struct { // get byte slices for use by io.CopyBuffer when // copying HTTP response bodies. BufferPool BufferPool + + // ModifyResponse is an optional function that + // modifies the Response from the backend. + // If it returns an error, the proxy returns a StatusBadGateway error. + ModifyResponse func(*http.Response) error } // A BufferPool is an interface for getting and returning temporary @@ -120,76 +128,59 @@ var hopHeaders = []string{ "Upgrade", } -type requestCanceler interface { - CancelRequest(*http.Request) -} - -type runOnFirstRead struct { - io.Reader // optional; nil means empty body - - fn func() // Run before first Read, then set to nil -} - -func (c *runOnFirstRead) Read(bs []byte) (int, error) { - if c.fn != nil { - c.fn() - c.fn = nil - } - if c.Reader == nil { - return 0, io.EOF - } - return c.Reader.Read(bs) -} - func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { transport = http.DefaultTransport } + ctx := req.Context() + if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + outreq := new(http.Request) *outreq = *req // includes shallow copies of maps, but okay - - if closeNotifier, ok := rw.(http.CloseNotifier); ok { - if requestCanceler, ok := transport.(requestCanceler); ok { - reqDone := make(chan struct{}) - defer close(reqDone) - - clientGone := closeNotifier.CloseNotify() - - outreq.Body = struct { - io.Reader - io.Closer - }{ - Reader: &runOnFirstRead{ - Reader: outreq.Body, - fn: func() { - go func() { - select { - case <-clientGone: - requestCanceler.CancelRequest(outreq) - case <-reqDone: - } - }() - }, - }, - Closer: outreq.Body, - } - } + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } + outreq = outreq.WithContext(ctx) p.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 outreq.Close = false - // Remove hop-by-hop headers to the backend. Especially - // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from req (shallow + // We are modifying the same underlying map from req (shallow // copied above) so we only copy it if necessary. copiedHeaders := false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(f) + } + } + } + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { if !copiedHeaders { @@ -218,16 +209,34 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + for _, h := range hopHeaders { res.Header.Del(h) } + if p.ModifyResponse != nil { + if err := p.ModifyResponse(res); err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) + return + } + } + copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, // at least for *http.Transport. Build it up from Trailer. if len(res.Trailer) > 0 { - var trailerKeys []string + trailerKeys := make([]string, 0, len(res.Trailer)) for k := range res.Trailer { trailerKeys = append(trailerKeys, k) } @@ -266,12 +275,40 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { if p.BufferPool != nil { buf = p.BufferPool.Get() } - io.CopyBuffer(dst, src, buf) + p.copyBuffer(dst, src, buf) if p.BufferPool != nil { p.BufferPool.Put(buf) } } +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + return written, rerr + } + } +} + func (p *ReverseProxy) logf(format string, args ...interface{}) { if p.ErrorLog != nil { p.ErrorLog.Printf(format, args...) diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index fe7cdb8..20c4e16 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -9,6 +9,8 @@ package httputil import ( "bufio" "bytes" + "errors" + "fmt" "io" "io/ioutil" "log" @@ -135,6 +137,61 @@ func TestReverseProxy(t *testing.T) { } +// Issue 16875: remove any proxied headers mentioned in the "Connection" +// header value. +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + const fakeConnectionToken = "X-Fake-Connection-Token" + const backendResponse = "I am the backend" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Upgrade", c) + } + w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Set("Upgrade", "should be deleted") + w.Header().Set(fakeConnectionToken, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get("Upgrade"); c != "original value" { + t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value") + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Set("Upgrade", "original value") + getReq.Header.Set(fakeConnectionToken, "should be deleted") + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Upgrade", c) + } + if c := res.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + func TestXForwardedFor(t *testing.T) { const prevForwardedFor = "client ip" const backendResponse = "I am the backend" @@ -260,14 +317,14 @@ func TestReverseProxyCancelation(t *testing.T) { reqInFlight := make(chan struct{}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(reqInFlight) + close(reqInFlight) // cause the client to cancel its request select { case <-time.After(10 * time.Second): // Note: this should only happen in broken implementations, and the // closenotify case should be instantaneous. - t.Log("Failed to close backend connection") - t.Fail() + t.Error("Handler never saw CloseNotify") + return case <-w.(http.CloseNotifier).CloseNotify(): } @@ -300,13 +357,13 @@ func TestReverseProxyCancelation(t *testing.T) { }() res, err := http.DefaultClient.Do(getReq) if res != nil { - t.Fatal("Non-nil response") + t.Errorf("got response %v; want nil", res.Status) } if err == nil { // This should be an error like: // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: // use of closed network connection - t.Fatal("DefaultClient.Do() returned nil error") + t.Error("DefaultClient.Do() returned nil error; want non-nil error") } } @@ -495,3 +552,115 @@ func TestReverseProxy_Post(t *testing.T) { t.Errorf("got body %q; expected %q", g, e) } } + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// Issue 16036: send a Request with a nil Body when possible +func TestReverseProxy_NilBody(t *testing.T) { + backendURL, _ := url.Parse("http://fake.tld/") + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Body != nil { + t.Error("Body != nil; want a nil Body") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := http.DefaultClient.Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 502 { + t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) + } +} + +// Issue 14237. Test ModifyResponse and that an error from it +// causes the proxy to return StatusBadGateway, or StatusOK otherwise. +func TestReverseProxyModifyResponse(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) + })) + defer backendServer.Close() + + rpURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(resp *http.Response) error { + if resp.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + } + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + tests := []struct { + url string + wantCode int + }{ + {frontendProxy.URL + "/mod", http.StatusOK}, + {frontendProxy.URL + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + resp, err := http.Get(tt.url) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + resp.Body.Close() + } +} + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + defer resp.Body.Close() + + if _, err := ioutil.ReadAll(resp.Body); err == nil { + t.Fatalf("want non-nil error") + } + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +} diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go index 2e62c00..63f321d 100644 --- a/libgo/go/net/http/internal/chunked.go +++ b/libgo/go/net/http/internal/chunked.go @@ -35,10 +35,11 @@ func NewChunkedReader(r io.Reader) io.Reader { } type chunkedReader struct { - r *bufio.Reader - n uint64 // unread bytes in chunk - err error - buf [2]byte + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte + checkEnd bool // whether need to check for \r\n chunk footer } func (cr *chunkedReader) beginChunk() { @@ -68,6 +69,21 @@ func (cr *chunkedReader) chunkHeaderAvailable() bool { func (cr *chunkedReader) Read(b []uint8) (n int, err error) { for cr.err == nil { + if cr.checkEnd { + if n > 0 && cr.r.Buffered() < 2 { + // We have some data. Return early (per the io.Reader + // contract) instead of potentially blocking while + // reading more. + break + } + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if string(cr.buf[:]) != "\r\n" { + cr.err = errors.New("malformed chunked encoding") + break + } + } + cr.checkEnd = false + } if cr.n == 0 { if n > 0 && !cr.chunkHeaderAvailable() { // We've read enough. Don't potentially block @@ -92,11 +108,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // If we're at the end of a chunk, read the next two // bytes to verify they are "\r\n". if cr.n == 0 && cr.err == nil { - if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { - if cr.buf[0] != '\r' || cr.buf[1] != '\n' { - cr.err = errors.New("malformed chunked encoding") - } - } + cr.checkEnd = true } } return n, cr.err diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index 9abe1ab..d067165 100644 --- a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -185,3 +185,30 @@ func TestChunkReadingIgnoresExtensions(t *testing.T) { t.Errorf("read %q; want %q", g, e) } } + +// Issue 17355: ChunkedReader shouldn't block waiting for more data +// if it can return something. +func TestChunkReadPartial(t *testing.T) { + pr, pw := io.Pipe() + go func() { + pw.Write([]byte("7\r\n1234567")) + }() + cr := NewChunkedReader(pr) + readBuf := make([]byte, 7) + n, err := cr.Read(readBuf) + if err != nil { + t.Fatal(err) + } + want := "1234567" + if n != 7 || string(readBuf) != want { + t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want) + } + go func() { + pw.Write([]byte("xx")) + }() + _, err = cr.Read(readBuf) + if got := fmt.Sprint(err); !strings.Contains(got, "malformed") { + t.Fatalf("second read = %v; want malformed error", err) + } + +} diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index aea6e12..438bd2e 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -6,6 +6,8 @@ package http_test import ( "fmt" + "io/ioutil" + "log" "net/http" "os" "runtime" @@ -15,6 +17,8 @@ import ( "time" ) +var quietLog = log.New(ioutil.Discard, "", 0) + func TestMain(m *testing.M) { v := m.Run() if v == 0 && goroutineLeaked() { @@ -134,3 +138,20 @@ func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { } return false } + +// waitErrCondition is like waitCondition but with errors instead of bools. +func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error { + deadline := time.Now().Add(waitFor) + var err error + for time.Now().Before(deadline) { + if err = fn(); err == nil { + return nil + } + time.Sleep(checkEvery) + } + return err +} + +func closeClient(c *http.Client) { + c.Transport.(*http.Transport).CloseIdleConnections() +} diff --git a/libgo/go/net/http/npn_test.go b/libgo/go/net/http/npn_test.go index e2e911d..4c1f6b5 100644 --- a/libgo/go/net/http/npn_test.go +++ b/libgo/go/net/http/npn_test.go @@ -18,6 +18,7 @@ import ( ) func TestNextProtoUpgrade(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) diff --git a/libgo/go/net/http/range_test.go b/libgo/go/net/http/range_test.go index ef911af..114987e 100644 --- a/libgo/go/net/http/range_test.go +++ b/libgo/go/net/http/range_test.go @@ -38,7 +38,7 @@ var ParseRangeTests = []struct { {"bytes=0-", 10, []httpRange{{0, 10}}}, {"bytes=5-", 10, []httpRange{{5, 5}}}, {"bytes=0-20", 10, []httpRange{{0, 10}}}, - {"bytes=15-,0-5", 10, nil}, + {"bytes=15-,0-5", 10, []httpRange{{0, 6}}}, {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}}, {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}}, {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}}, diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go index 4bf646b..28a148b 100644 --- a/libgo/go/net/http/readrequest_test.go +++ b/libgo/go/net/http/readrequest_test.go @@ -25,7 +25,7 @@ type reqTest struct { } var noError = "" -var noBody = "" +var noBodyStr = "" var noTrailer Header = nil var reqTests = []reqTest{ @@ -95,7 +95,7 @@ var reqTests = []reqTest{ RequestURI: "/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -121,7 +121,7 @@ var reqTests = []reqTest{ RequestURI: "//user@host/is/actually/a/path/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -131,7 +131,7 @@ var reqTests = []reqTest{ "GET ../../../../etc/passwd HTTP/1.1\r\n" + "Host: test\r\n\r\n", nil, - noBody, + noBodyStr, noTrailer, "parse ../../../../etc/passwd: invalid URI for request", }, @@ -141,7 +141,7 @@ var reqTests = []reqTest{ "GET HTTP/1.1\r\n" + "Host: test\r\n\r\n", nil, - noBody, + noBodyStr, noTrailer, "parse : empty url", }, @@ -227,7 +227,7 @@ var reqTests = []reqTest{ RequestURI: "www.google.com:443", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -251,7 +251,7 @@ var reqTests = []reqTest{ RequestURI: "127.0.0.1:6060", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -275,7 +275,7 @@ var reqTests = []reqTest{ RequestURI: "/_goRPC_", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -299,7 +299,7 @@ var reqTests = []reqTest{ RequestURI: "*", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -323,7 +323,7 @@ var reqTests = []reqTest{ RequestURI: "*", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -350,7 +350,7 @@ var reqTests = []reqTest{ RequestURI: "/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -376,7 +376,7 @@ var reqTests = []reqTest{ RequestURI: "/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -397,7 +397,7 @@ var reqTests = []reqTest{ ContentLength: -1, Close: true, }, - noBody, + noBodyStr, noTrailer, noError, }, diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index dc55592..fb6bb0a 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -18,12 +18,17 @@ import ( "io/ioutil" "mime" "mime/multipart" + "net" "net/http/httptrace" "net/textproto" "net/url" "strconv" "strings" "sync" + + "golang_org/x/net/idna" + "golang_org/x/text/unicode/norm" + "golang_org/x/text/width" ) const ( @@ -34,21 +39,40 @@ const ( // is either not present in the request or not a file field. var ErrMissingFile = errors.New("http: no such file") -// HTTP request parsing errors. +// ProtocolError represents an HTTP protocol error. +// +// Deprecated: Not all errors in the http package related to protocol errors +// are of type ProtocolError. type ProtocolError struct { ErrorString string } -func (err *ProtocolError) Error() string { return err.ErrorString } +func (pe *ProtocolError) Error() string { return pe.ErrorString } var ( - ErrHeaderTooLong = &ProtocolError{"header too long"} - ErrShortBody = &ProtocolError{"entity body too short"} - ErrNotSupported = &ProtocolError{"feature not supported"} - ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + // ErrNotSupported is returned by the Push method of Pusher + // implementations to indicate that HTTP/2 Push support is not + // available. + ErrNotSupported = &ProtocolError{"feature not supported"} + + // ErrUnexpectedTrailer is returned by the Transport when a server + // replies with a Trailer header, but without a chunked reply. + ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + + // ErrMissingBoundary is returned by Request.MultipartReader when the + // request's Content-Type does not include a "boundary" parameter. + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} + + // ErrNotMultipart is returned by Request.MultipartReader when the + // request's Content-Type is not multipart/form-data. + ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} + + // Deprecated: ErrHeaderTooLong is not used. + ErrHeaderTooLong = &ProtocolError{"header too long"} + // Deprecated: ErrShortBody is not used. + ErrShortBody = &ProtocolError{"entity body too short"} + // Deprecated: ErrMissingContentLength is not used. ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} - ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} - ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} ) type badStringError struct { @@ -146,11 +170,20 @@ type Request struct { // Handler does not need to. Body io.ReadCloser + // GetBody defines an optional func to return a new copy of + // Body. It is used for client requests when a redirect requires + // reading the body more than once. Use of GetBody still + // requires setting Body. + // + // For server requests it is unused. + GetBody func() (io.ReadCloser, error) + // ContentLength records the length of the associated content. // The value -1 indicates that the length is unknown. // Values >= 0 indicate that the given number of bytes may // be read from Body. - // For client requests, a value of 0 means unknown if Body is not nil. + // For client requests, a value of 0 with a non-nil Body is + // also treated as unknown. ContentLength int64 // TransferEncoding lists the transfer encodings from outermost to @@ -175,11 +208,15 @@ type Request struct { // For server requests Host specifies the host on which the // URL is sought. Per RFC 2616, this is either the value of // the "Host" header or the host name given in the URL itself. - // It may be of the form "host:port". + // It may be of the form "host:port". For international domain + // names, Host may be in Punycode or Unicode form. Use + // golang.org/x/net/idna to convert it to either format if + // needed. // // For client requests Host optionally overrides the Host // header to send. If empty, the Request.Write method uses - // the value of URL.Host. + // the value of URL.Host. Host may contain an international + // domain name. Host string // Form contains the parsed form data, including both the URL @@ -276,8 +313,8 @@ type Request struct { // For outgoing client requests, the context controls cancelation. // // For incoming server requests, the context is canceled when the -// ServeHTTP method returns. For its associated values, see -// ServerContextKey and LocalAddrContextKey. +// client's connection closes, the request is canceled (with HTTP/2), +// or when the ServeHTTP method returns. func (r *Request) Context() context.Context { if r.ctx != nil { return r.ctx @@ -304,6 +341,18 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// protoAtLeastOutgoing is like ProtoAtLeast, but is for outgoing +// requests (see issue 18407) where these fields aren't supposed to +// matter. As a minor fix for Go 1.8, at least treat (0, 0) as +// matching HTTP/1.1 or HTTP/1.0. Only HTTP/1.1 is used. +// TODO(bradfitz): ideally remove this whole method. It shouldn't be used. +func (r *Request) protoAtLeastOutgoing(major, minor int) bool { + if r.ProtoMajor == 0 && r.ProtoMinor == 0 && major == 1 && minor <= 1 { + return true + } + return r.ProtoAtLeast(major, minor) +} + // UserAgent returns the client's User-Agent, if sent in the request. func (r *Request) UserAgent() string { return r.Header.Get("User-Agent") @@ -319,6 +368,8 @@ var ErrNoCookie = errors.New("http: named cookie not present") // Cookie returns the named cookie provided in the request or // ErrNoCookie if not found. +// If multiple cookies match the given name, only one cookie will +// be returned. func (r *Request) Cookie(name string) (*Cookie, error) { for _, c := range readCookies(r.Header, name) { return c, nil @@ -561,6 +612,12 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai } } + if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders { + if err := bw.Flush(); err != nil { + return err + } + } + // Write body and trailer err = tw.WriteBody(w) if err != nil { @@ -573,7 +630,24 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return nil } -// cleanHost strips anything after '/' or ' '. +func idnaASCII(v string) (string, error) { + if isASCII(v) { + return v, nil + } + // The idna package doesn't do everything from + // https://tools.ietf.org/html/rfc5895 so we do it here. + // TODO(bradfitz): should the idna package do this instead? + v = strings.ToLower(v) + v = width.Fold.String(v) + v = norm.NFC.String(v) + return idna.ToASCII(v) +} + +// cleanHost cleans up the host sent in request's Host header. +// +// It both strips anything after '/' or ' ', and puts the value +// into Punycode form, if necessary. +// // Ideally we'd clean the Host header according to the spec: // https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]") // https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host) @@ -584,9 +658,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // first offending character. func cleanHost(in string) string { if i := strings.IndexAny(in, " /"); i != -1 { - return in[:i] + in = in[:i] + } + host, port, err := net.SplitHostPort(in) + if err != nil { // input was just a host + a, err := idnaASCII(in) + if err != nil { + return in // garbage in, garbage out + } + return a } - return in + a, err := idnaASCII(host) + if err != nil { + return in // garbage in, garbage out + } + return net.JoinHostPort(a, port) } // removeZone removes IPv6 zone identifier from host. @@ -658,11 +744,17 @@ func validMethod(method string) bool { // methods Do, Post, and PostForm, and Transport.RoundTrip. // // NewRequest returns a Request suitable for use with Client.Do or -// Transport.RoundTrip. -// To create a request for use with testing a Server Handler use either -// ReadRequest or manually update the Request fields. See the Request -// type's documentation for the difference between inbound and outbound -// request fields. +// Transport.RoundTrip. To create a request for use with testing a +// Server Handler, either use the NewRequest function in the +// net/http/httptest package, use ReadRequest, or manually update the +// Request fields. See the Request type's documentation for the +// difference between inbound and outbound request fields. +// +// If body is of type *bytes.Buffer, *bytes.Reader, or +// *strings.Reader, the returned request's ContentLength is set to its +// exact value (instead of -1), GetBody is populated (so 307 and 308 +// redirects can replay the body), and Body is set to NoBody if the +// ContentLength is 0. func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if method == "" { // We document that "" means "GET" for Request.Method, and people have @@ -697,10 +789,43 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { switch v := body.(type) { case *bytes.Buffer: req.ContentLength = int64(v.Len()) + buf := v.Bytes() + req.GetBody = func() (io.ReadCloser, error) { + r := bytes.NewReader(buf) + return ioutil.NopCloser(r), nil + } case *bytes.Reader: req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return ioutil.NopCloser(&r), nil + } case *strings.Reader: req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return ioutil.NopCloser(&r), nil + } + default: + // This is where we'd set it to -1 (at least + // if body != NoBody) to mean unknown, but + // that broke people during the Go 1.8 testing + // period. People depend on it being 0 I + // guess. Maybe retry later. See Issue 18117. + } + // For client requests, Request.ContentLength of 0 + // means either actually 0, or unknown. The only way + // to explicitly say that the ContentLength is zero is + // to set the Body to nil. But turns out too much code + // depends on NewRequest returning a non-nil Body, + // so we use a well-known ReadCloser variable instead + // and have the http package also treat that sentinel + // variable to mean explicitly zero. + if req.GetBody != nil && req.ContentLength == 0 { + req.Body = NoBody + req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil } } } @@ -1000,18 +1125,24 @@ func parsePostForm(r *Request) (vs url.Values, err error) { return } -// ParseForm parses the raw query from the URL and updates r.Form. +// ParseForm populates r.Form and r.PostForm. +// +// For all requests, ParseForm parses the raw query from the URL and updates +// r.Form. +// +// For POST, PUT, and PATCH requests, it also parses the request body as a form +// and puts the results into both r.PostForm and r.Form. Request body parameters +// take precedence over URL query string values in r.Form. // -// For POST or PUT requests, it also parses the request body as a form and -// put the results into both r.PostForm and r.Form. -// POST and PUT body parameters take precedence over URL query string values -// in r.Form. +// For other HTTP methods, or when the Content-Type is not +// application/x-www-form-urlencoded, the request Body is not read, and +// r.PostForm is initialized to a non-nil, empty value. // // If the request Body's size has not already been limited by MaxBytesReader, // the size is capped at 10MB. // // ParseMultipartForm calls ParseForm automatically. -// It is idempotent. +// ParseForm is idempotent. func (r *Request) ParseForm() error { var err error if r.PostForm == nil { @@ -1174,3 +1305,30 @@ func (r *Request) isReplayable() bool { } return false } + +// outgoingLength reports the Content-Length of this outgoing (Client) request. +// It maps 0 into -1 (unknown) when the Body is non-nil. +func (r *Request) outgoingLength() int64 { + if r.Body == nil || r.Body == NoBody { + return 0 + } + if r.ContentLength != 0 { + return r.ContentLength + } + return -1 +} + +// requestMethodUsuallyLacksBody reports whether the given request +// method is one that typically does not involve a request body. +// This is used by the Transport (via +// transferWriter.shouldSendChunkedRequestBody) to determine whether +// we try to test-read a byte from a non-nil Request.Body when +// Request.outgoingLength() returns -1. See the comments in +// shouldSendChunkedRequestBody. +func requestMethodUsuallyLacksBody(method string) bool { + switch method { + case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH": + return true + } + return false +} diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index a4c88c0..e674837 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -29,9 +29,9 @@ func TestQuery(t *testing.T) { } } -func TestPostQuery(t *testing.T) { - req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", - strings.NewReader("z=post&both=y&prio=2&empty=")) +func TestParseFormQuery(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&orphan=nope&empty=not", + strings.NewReader("z=post&both=y&prio=2&=nokey&orphan;empty=&")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") if q := req.FormValue("q"); q != "foo" { @@ -55,39 +55,30 @@ func TestPostQuery(t *testing.T) { if prio := req.FormValue("prio"); prio != "2" { t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) } - if empty := req.FormValue("empty"); empty != "" { + if orphan := req.Form["orphan"]; !reflect.DeepEqual(orphan, []string{"", "nope"}) { + t.Errorf(`req.FormValue("orphan") = %q, want "" (from body)`, orphan) + } + if empty := req.Form["empty"]; !reflect.DeepEqual(empty, []string{"", "not"}) { t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) } + if nokey := req.Form[""]; !reflect.DeepEqual(nokey, []string{"nokey"}) { + t.Errorf(`req.FormValue("nokey") = %q, want "nokey" (from body)`, nokey) + } } -func TestPatchQuery(t *testing.T) { - req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", - strings.NewReader("z=post&both=y&prio=2&empty=")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") - - if q := req.FormValue("q"); q != "foo" { - t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) - } - if z := req.FormValue("z"); z != "post" { - t.Errorf(`req.FormValue("z") = %q, want "post"`, z) - } - if bq, found := req.PostForm["q"]; found { - t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) - } - if bz := req.PostFormValue("z"); bz != "post" { - t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) - } - if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { - t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) - } - if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { - t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) - } - if prio := req.FormValue("prio"); prio != "2" { - t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) - } - if empty := req.FormValue("empty"); empty != "" { - t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) +// Tests that we only parse the form automatically for certain methods. +func TestParseFormQueryMethods(t *testing.T) { + for _, method := range []string{"POST", "PATCH", "PUT", "FOO"} { + req, _ := NewRequest(method, "http://www.google.com/search", + strings.NewReader("foo=bar")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + want := "bar" + if method == "FOO" { + want = "" + } + if got := req.FormValue("foo"); got != want { + t.Errorf(`for method %s, FormValue("foo") = %q; want %q`, method, got, want) + } } } @@ -374,18 +365,68 @@ func TestFormFileOrder(t *testing.T) { var readRequestErrorTests = []struct { in string - err error + err string + + header Header }{ - {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", nil}, - {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF}, - {"", io.EOF}, + 0: {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", "", Header{"Header": {"foo"}}}, + 1: {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF.Error(), nil}, + 2: {"", io.EOF.Error(), nil}, + 3: { + in: "HEAD / HTTP/1.1\r\nContent-Length:4\r\n\r\n", + err: "http: method cannot contain a Content-Length", + }, + 4: { + in: "HEAD / HTTP/1.1\r\n\r\n", + header: Header{}, + }, + + // Multiple Content-Length values should either be + // deduplicated if same or reject otherwise + // See Issue 16490. + 5: { + in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 0\r\n\r\nGopher hey\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 6: { + in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 6\r\n\r\nGopher\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 7: { + in: "PUT / HTTP/1.1\r\nContent-Length: 6 \r\nContent-Length: 6\r\nContent-Length:6\r\n\r\nGopher\r\n", + err: "", + header: Header{"Content-Length": {"6"}}, + }, + 8: { + in: "PUT / HTTP/1.1\r\nContent-Length: 1\r\nContent-Length: 6 \r\n\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 9: { + in: "POST / HTTP/1.1\r\nContent-Length:\r\nContent-Length: 3\r\n\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 10: { + in: "HEAD / HTTP/1.1\r\nContent-Length:0\r\nContent-Length: 0\r\n\r\n", + header: Header{"Content-Length": {"0"}}, + }, } func TestReadRequestErrors(t *testing.T) { for i, tt := range readRequestErrorTests { - _, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in))) - if err != tt.err { - t.Errorf("%d. got error = %v; want %v", i, err, tt.err) + req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in))) + if err == nil { + if tt.err != "" { + t.Errorf("#%d: got nil err; want %q", i, tt.err) + } + + if !reflect.DeepEqual(tt.header, req.Header) { + t.Errorf("#%d: gotHeader: %q wantHeader: %q", i, req.Header, tt.header) + } + continue + } + + if tt.err == "" || !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d: got error = %v; want %v", i, err, tt.err) } } } @@ -456,18 +497,23 @@ func TestNewRequestContentLength(t *testing.T) { {bytes.NewReader([]byte("123")), 3}, {bytes.NewBuffer([]byte("1234")), 4}, {strings.NewReader("12345"), 5}, - // Not detected: + {strings.NewReader(""), 0}, + {NoBody, 0}, + + // Not detected. During Go 1.8 we tried to make these set to -1, but + // due to Issue 18117, we keep these returning 0, even though they're + // unknown. {struct{ io.Reader }{strings.NewReader("xyz")}, 0}, {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0}, {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0}, } - for _, tt := range tests { + for i, tt := range tests { req, err := NewRequest("POST", "http://localhost/", tt.r) if err != nil { t.Fatal(err) } if req.ContentLength != tt.want { - t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + t.Errorf("test[%d]: ContentLength(%T) = %d; want %d", i, tt.r, req.ContentLength, tt.want) } } } @@ -626,11 +672,31 @@ func TestStarRequest(t *testing.T) { if err != nil { return } + if req.ContentLength != 0 { + t.Errorf("ContentLength = %d; want 0", req.ContentLength) + } + if req.Body == nil { + t.Errorf("Body = nil; want non-nil") + } + + // Request.Write has Client semantics for Body/ContentLength, + // where ContentLength 0 means unknown if Body is non-nil, and + // thus chunking will happen unless we change semantics and + // signal that we want to serialize it as exactly zero. The + // only way to do that for outbound requests is with a nil + // Body: + clientReq := *req + clientReq.Body = nil + var out bytes.Buffer - if err := req.Write(&out); err != nil { + if err := clientReq.Write(&out); err != nil { t.Fatal(err) } - back, err := ReadRequest(bufio.NewReader(&out)) + + if strings.Contains(out.String(), "chunked") { + t.Error("wrote chunked request; want no body") + } + back, err := ReadRequest(bufio.NewReader(bytes.NewReader(out.Bytes()))) if err != nil { t.Fatal(err) } @@ -719,6 +785,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) { } } +// verify that NewRequest sets Request.GetBody and that it works +func TestNewRequestGetBody(t *testing.T) { + tests := []struct { + r io.Reader + }{ + {r: strings.NewReader("hello")}, + {r: bytes.NewReader([]byte("hello"))}, + {r: bytes.NewBuffer([]byte("hello"))}, + } + for i, tt := range tests { + req, err := NewRequest("POST", "http://foo.tld/", tt.r) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if req.Body == nil { + t.Errorf("test[%d]: Body = nil", i) + continue + } + if req.GetBody == nil { + t.Errorf("test[%d]: GetBody = nil", i) + continue + } + slurp1, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("test[%d]: ReadAll(Body) = %v", i, err) + } + newBody, err := req.GetBody() + if err != nil { + t.Errorf("test[%d]: GetBody = %v", i, err) + } + slurp2, err := ioutil.ReadAll(newBody) + if err != nil { + t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err) + } + if string(slurp1) != string(slurp2) { + t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2) + } + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index 2545f6f..eb65b9f 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -5,14 +5,17 @@ package http import ( + "bufio" "bytes" "errors" "fmt" "io" "io/ioutil" + "net" "net/url" "strings" "testing" + "time" ) type reqWriteTest struct { @@ -28,7 +31,7 @@ type reqWriteTest struct { var reqWriteTests = []reqWriteTest{ // HTTP/1.1 => chunked coding; no body; no trailer - { + 0: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -75,7 +78,7 @@ var reqWriteTests = []reqWriteTest{ "Proxy-Connection: keep-alive\r\n\r\n", }, // HTTP/1.1 => chunked coding; body; empty trailer - { + 1: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -104,7 +107,7 @@ var reqWriteTests = []reqWriteTest{ chunk("abcdef") + chunk(""), }, // HTTP/1.1 POST => chunked coding; body; empty trailer - { + 2: { Req: Request{ Method: "POST", URL: &url.URL{ @@ -137,7 +140,7 @@ var reqWriteTests = []reqWriteTest{ }, // HTTP/1.1 POST with Content-Length, no chunking - { + 3: { Req: Request{ Method: "POST", URL: &url.URL{ @@ -172,7 +175,7 @@ var reqWriteTests = []reqWriteTest{ }, // HTTP/1.1 POST with Content-Length in headers - { + 4: { Req: Request{ Method: "POST", URL: mustParseURL("http://example.com/"), @@ -201,7 +204,7 @@ var reqWriteTests = []reqWriteTest{ }, // default to HTTP/1.1 - { + 5: { Req: Request{ Method: "GET", URL: mustParseURL("/search"), @@ -215,7 +218,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a 0 byte body. - { + 6: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -227,9 +230,32 @@ var reqWriteTests = []reqWriteTest{ Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, - // RFC 2616 Section 14.13 says Content-Length should be specified - // unless body is prohibited by the request method. - // Also, nginx expects it for POST and PUT. + WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n0\r\n\r\n", + + WantProxy: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n0\r\n\r\n", + }, + + // Request with a 0 ContentLength and a nil body. + 7: { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { return nil }, + WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go-http-client/1.1\r\n" + @@ -244,7 +270,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a 1 byte body. - { + 8: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -270,7 +296,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a ContentLength of 10 but a 5 byte body. - { + 9: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -284,7 +310,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a ContentLength of 4 but an 8 byte body. - { + 10: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -298,7 +324,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 5 ContentLength and nil body. - { + 11: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -311,7 +337,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a body with 1 byte content and an error. - { + 12: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -331,7 +357,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a body without content and an error. - { + 13: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -352,7 +378,7 @@ var reqWriteTests = []reqWriteTest{ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, // and doesn't add a User-Agent. - { + 14: { Req: Request{ Method: "GET", URL: mustParseURL("/foo"), @@ -373,7 +399,7 @@ var reqWriteTests = []reqWriteTest{ // an empty Host header, and don't use // Request.Header["Host"]. This is just testing that // we don't change Go 1.0 behavior. - { + 15: { Req: Request{ Method: "GET", Host: "", @@ -395,7 +421,7 @@ var reqWriteTests = []reqWriteTest{ }, // Opaque test #1 from golang.org/issue/4860 - { + 16: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -414,7 +440,7 @@ var reqWriteTests = []reqWriteTest{ }, // Opaque test #2 from golang.org/issue/4860 - { + 17: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -433,7 +459,7 @@ var reqWriteTests = []reqWriteTest{ }, // Testing custom case in header keys. Issue 5022. - { + 18: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -457,7 +483,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with host header field; IPv6 address with zone identifier - { + 19: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -472,7 +498,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with optional host header field; IPv6 address with zone identifier - { + 20: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -543,6 +569,138 @@ func TestRequestWrite(t *testing.T) { } } +func TestRequestWriteTransport(t *testing.T) { + t.Parallel() + + matchSubstr := func(substr string) func(string) error { + return func(written string) error { + if !strings.Contains(written, substr) { + return fmt.Errorf("expected substring %q in request: %s", substr, written) + } + return nil + } + } + + noContentLengthOrTransferEncoding := func(req string) error { + if strings.Contains(req, "Content-Length: ") { + return fmt.Errorf("unexpected Content-Length in request: %s", req) + } + if strings.Contains(req, "Transfer-Encoding: ") { + return fmt.Errorf("unexpected Transfer-Encoding in request: %s", req) + } + return nil + } + + all := func(checks ...func(string) error) func(string) error { + return func(req string) error { + for _, c := range checks { + if err := c(req); err != nil { + return err + } + } + return nil + } + } + + type testCase struct { + method string + clen int64 // ContentLength + body io.ReadCloser + want func(string) error + + // optional: + init func(*testCase) + afterReqRead func() + } + + tests := []testCase{ + { + method: "GET", + want: noContentLengthOrTransferEncoding, + }, + { + method: "GET", + body: ioutil.NopCloser(strings.NewReader("")), + want: noContentLengthOrTransferEncoding, + }, + { + method: "GET", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("")), + want: noContentLengthOrTransferEncoding, + }, + // A GET with a body, with explicit content length: + { + method: "GET", + clen: 7, + body: ioutil.NopCloser(strings.NewReader("foobody")), + want: all(matchSubstr("Content-Length: 7"), + matchSubstr("foobody")), + }, + // A GET with a body, sniffing the leading "f" from "foobody". + { + method: "GET", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("foobody")), + want: all(matchSubstr("Transfer-Encoding: chunked"), + matchSubstr("\r\n1\r\nf\r\n"), + matchSubstr("oobody")), + }, + // But a POST request is expected to have a body, so + // no sniffing happens: + { + method: "POST", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("foobody")), + want: all(matchSubstr("Transfer-Encoding: chunked"), + matchSubstr("foobody")), + }, + { + method: "POST", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("")), + want: all(matchSubstr("Transfer-Encoding: chunked")), + }, + // Verify that a blocking Request.Body doesn't block forever. + { + method: "GET", + clen: -1, + init: func(tt *testCase) { + pr, pw := io.Pipe() + tt.afterReqRead = func() { + pw.Close() + } + tt.body = ioutil.NopCloser(pr) + }, + want: matchSubstr("Transfer-Encoding: chunked"), + }, + } + + for i, tt := range tests { + if tt.init != nil { + tt.init(&tt) + } + req := &Request{ + Method: tt.method, + URL: &url.URL{ + Scheme: "http", + Host: "example.com", + }, + Header: make(Header), + ContentLength: tt.clen, + Body: tt.body, + } + got, err := dumpRequestOut(req, tt.afterReqRead) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if err := tt.want(string(got)); err != nil { + t.Errorf("test[%d]: %v", i, err) + } + } +} + type closeChecker struct { io.Reader closed bool @@ -553,17 +711,19 @@ func (rc *closeChecker) Close() error { return nil } -// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// TestRequestWriteClosesBody tests that Request.Write closes its request.Body. // It also indirectly tests NewRequest and that it doesn't wrap an existing Closer // inside a NopCloser, and that it serializes it correctly. func TestRequestWriteClosesBody(t *testing.T) { rc := &closeChecker{Reader: strings.NewReader("my body")} - req, _ := NewRequest("POST", "http://foo.com/", rc) - if req.ContentLength != 0 { - t.Errorf("got req.ContentLength %d, want 0", req.ContentLength) + req, err := NewRequest("POST", "http://foo.com/", rc) + if err != nil { + t.Fatal(err) } buf := new(bytes.Buffer) - req.Write(buf) + if err := req.Write(buf); err != nil { + t.Error(err) + } if !rc.closed { t.Error("body not closed after write") } @@ -571,12 +731,7 @@ func TestRequestWriteClosesBody(t *testing.T) { "Host: foo.com\r\n" + "User-Agent: Go-http-client/1.1\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + - // TODO: currently we don't buffer before chunking, so we get a - // single "m" chunk before the other chunks, as this was the 1-byte - // read from our MultiReader where we stitched the Body back together - // after sniffing whether the Body was 0 bytes or not. - chunk("m") + - chunk("y body") + + chunk("my body") + chunk("") if buf.String() != expected { t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected) @@ -652,3 +807,76 @@ func TestRequestWriteError(t *testing.T) { t.Fatalf("writeCalls constant is outdated in test") } } + +// dumpRequestOut is a modified copy of net/http/httputil.DumpRequestOut. +// Unlike the original, this version doesn't mutate the req.Body and +// try to restore it. It always dumps the whole body. +// And it doesn't support https. +func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) { + + // Use the actual Transport code to record what we would send + // on the wire, but not using TCP. Use a Transport with a + // custom dialer that returns a fake net.Conn that waits + // for the full input (and recording it), and then responds + // with a dummy response. + var buf bytes.Buffer // records the output + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil + }, + } + defer t.CloseIdleConnections() + + // Wait for the request before replying with a dummy response: + go func() { + req, err := ReadRequest(bufio.NewReader(pr)) + if err == nil { + if onReadHeaders != nil { + onReadHeaders() + } + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + } + dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") + }() + + _, err := t.RoundTrip(req) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + r.r = <-r.c + } + return r.r.Read(p) +} + +// dumpConn is a net.Conn that writes to Writer and reads from Reader. +type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 5450d50..ae118fb 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -261,7 +261,7 @@ func (r *Response) Write(w io.Writer) error { if n == 0 { // Reset it to a known zero reader, in case underlying one // is unhappy being read repeatedly. - r1.Body = eofReader + r1.Body = NoBody } else { r1.ContentLength = -1 r1.Body = struct { @@ -300,7 +300,7 @@ func (r *Response) Write(w io.Writer) error { // contentLengthAlreadySent may have been already sent for // POST/PUT requests, even if zero length. See Issue 8180. contentLengthAlreadySent := tw.shouldSendContentLength() - if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent { + if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent && bodyAllowedForStatus(r.StatusCode) { if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil { return err } diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 126da92..660d517 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -589,6 +589,7 @@ var readResponseCloseInMiddleTests = []struct { // reading only part of its contents advances the read to the end of // the request, right up until the next request. func TestReadResponseCloseInMiddle(t *testing.T) { + t.Parallel() for _, test := range readResponseCloseInMiddleTests { fatalf := func(format string, args ...interface{}) { args = append([]interface{}{test.chunked, test.compressed}, args...) @@ -792,6 +793,7 @@ func TestReadResponseErrors(t *testing.T) { type testCase struct { name string // optional, defaults to in in string + header Header wantErr interface{} // nil, err value, or string substring } @@ -817,11 +819,22 @@ func TestReadResponseErrors(t *testing.T) { } } + contentLength := func(status, body string, wantErr interface{}, header Header) testCase { + return testCase{ + name: fmt.Sprintf("status %q %q", status, body), + in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body), + wantErr: wantErr, + header: header, + } + } + + errMultiCL := "message cannot contain multiple Content-Length headers" + tests := []testCase{ - {"", "", io.ErrUnexpectedEOF}, - {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF}, - {"", "HTTP/1.1", "malformed HTTP response"}, - {"", "HTTP/2.0", "malformed HTTP response"}, + {"", "", nil, io.ErrUnexpectedEOF}, + {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", nil, io.ErrUnexpectedEOF}, + {"", "HTTP/1.1", nil, "malformed HTTP response"}, + {"", "HTTP/2.0", nil, "malformed HTTP response"}, status("20X Unknown", true), status("abcd Unknown", true), status("二百/两百 OK", true), @@ -846,7 +859,21 @@ func TestReadResponseErrors(t *testing.T) { version("HTTP/A.B", true), version("HTTP/1", true), version("http/1.1", true), + + contentLength("200 OK", "Content-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey\r\n", errMultiCL, nil), + contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"7"}}), + contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL, nil), + contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"0"}}), + contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil, nil), + contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL, nil), + + // multiple content-length headers for 204 and 304 should still be checked + contentLength("204 OK", "Content-Length: 7\r\nContent-Length: 8\r\n\r\n", errMultiCL, nil), + contentLength("204 OK", "Content-Length: 3\r\nContent-Length: 3\r\n\r\n", nil, nil), + contentLength("304 OK", "Content-Length: 880\r\nContent-Length: 1\r\n\r\n", errMultiCL, nil), + contentLength("304 OK", "Content-Length: 961\r\nContent-Length: 961\r\n\r\n", nil, nil), } + for i, tt := range tests { br := bufio.NewReader(strings.NewReader(tt.in)) _, rerr := ReadResponse(br, nil) diff --git a/libgo/go/net/http/responsewrite_test.go b/libgo/go/net/http/responsewrite_test.go index 90f6767..d41d898 100644 --- a/libgo/go/net/http/responsewrite_test.go +++ b/libgo/go/net/http/responsewrite_test.go @@ -241,7 +241,8 @@ func TestResponseWrite(t *testing.T) { "HTTP/1.0 007 license to violate specs\r\nContent-Length: 0\r\n\r\n", }, - // No stutter. + // No stutter. Status code in 1xx range response should + // not include a Content-Length header. See issue #16942. { Response{ StatusCode: 123, @@ -253,7 +254,23 @@ func TestResponseWrite(t *testing.T) { Body: nil, }, - "HTTP/1.0 123 Sesame Street\r\nContent-Length: 0\r\n\r\n", + "HTTP/1.0 123 Sesame Street\r\n\r\n", + }, + + // Status code 204 (No content) response should not include a + // Content-Length header. See issue #16942. + { + Response{ + StatusCode: 204, + Status: "No Content", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: nil, + }, + + "HTTP/1.0 204 No Content\r\n\r\n", }, } diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 13e5f28..072da25 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -156,6 +156,7 @@ func (ht handlerTest) rawResponse(req string) string { } func TestConsumingBodyOnNextConn(t *testing.T) { + t.Parallel() defer afterTest(t) conn := new(testConn) for i := 0; i < 2; i++ { @@ -237,6 +238,7 @@ var vtests = []struct { } func TestHostHandlers(t *testing.T) { + setParallel(t) defer afterTest(t) mux := NewServeMux() for _, h := range handlers { @@ -353,6 +355,7 @@ var serveMuxTests = []struct { } func TestServeMuxHandler(t *testing.T) { + setParallel(t) mux := NewServeMux() for _, e := range serveMuxRegister { mux.Handle(e.pattern, e.h) @@ -390,15 +393,16 @@ var serveMuxTests2 = []struct { // TestServeMuxHandlerRedirects tests that automatic redirects generated by // mux.Handler() shouldn't clear the request's query string. func TestServeMuxHandlerRedirects(t *testing.T) { + setParallel(t) mux := NewServeMux() for _, e := range serveMuxRegister { mux.Handle(e.pattern, e.h) } for _, tt := range serveMuxTests2 { - tries := 1 + tries := 1 // expect at most 1 redirection if redirOk is true. turl := tt.url - for tries > 0 { + for { u, e := url.Parse(turl) if e != nil { t.Fatal(e) @@ -432,6 +436,7 @@ func TestServeMuxHandlerRedirects(t *testing.T) { // Tests for https://golang.org/issue/900 func TestMuxRedirectLeadingSlashes(t *testing.T) { + setParallel(t) paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} for _, path := range paths { req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) @@ -456,9 +461,6 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } setParallel(t) defer afterTest(t) reqNum := 0 @@ -479,11 +481,11 @@ func TestServerTimeouts(t *testing.T) { if err != nil { t.Fatalf("http Get #1: %v", err) } - got, _ := ioutil.ReadAll(r.Body) + got, err := ioutil.ReadAll(r.Body) expected := "req=1" - if string(got) != expected { - t.Errorf("Unexpected response for request #1; got %q; expected %q", - string(got), expected) + if string(got) != expected || err != nil { + t.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil", + string(got), err, expected) } // Slow client that should timeout. @@ -494,6 +496,7 @@ func TestServerTimeouts(t *testing.T) { } buf := make([]byte, 1) n, err := conn.Read(buf) + conn.Close() latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) @@ -505,14 +508,14 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. (that we // get "req=2", not "req=3") - r, err = Get(ts.URL) + r, err = c.Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } - got, _ = ioutil.ReadAll(r.Body) + got, err = ioutil.ReadAll(r.Body) expected = "req=2" - if string(got) != expected { - t.Errorf("Get #2 got %q, want %q", string(got), expected) + if string(got) != expected || err != nil { + t.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected) } if !testing.Short() { @@ -532,13 +535,61 @@ func TestServerTimeouts(t *testing.T) { } } +// Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) +func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + setParallel(t) + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatal(err) + } + c := &Client{Transport: tr} + + for i := 1; i <= 3; i++ { + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + + // fail test if no response after 1 second + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req = req.WithContext(ctx) + + r, err := c.Do(req) + select { + case <-ctx.Done(): + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("http2 Get #%d response timed out", i) + } + default: + } + if err != nil { + t.Fatalf("http2 Get #%d: %v", i, err) + } + r.Body.Close() + if r.ProtoMajor != 2 { + t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) + } + time.Sleep(ts.Config.WriteTimeout / 2) + } +} + // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. func TestOnlyWriteTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) var conn net.Conn var afterTimeoutErrc = make(chan error, 1) @@ -598,6 +649,7 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { + setParallel(t) defer afterTest(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") @@ -619,13 +671,16 @@ func TestIdentityResponse(t *testing.T) { ts := httptest.NewServer(handler) defer ts.Close() + c := &Client{Transport: new(Transport)} + defer closeClient(c) + // Note: this relies on the assumption (which is true) that // Get sends HTTP/1.1 or greater requests. Otherwise the // server wouldn't have the choice to send back chunked // responses. for _, te := range []string{"", "identity"} { url := ts.URL + "/?te=" + te - res, err := Get(url) + res, err := c.Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } @@ -644,7 +699,7 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - res, err := Get(url) + res, err := c.Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } @@ -674,6 +729,7 @@ func TestIdentityResponse(t *testing.T) { } func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + setParallel(t) defer afterTest(t) s := httptest.NewServer(h) defer s.Close() @@ -717,6 +773,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { } func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(handler) defer ts.Close() @@ -750,7 +807,7 @@ func TestServeHTTP10Close(t *testing.T) { // TestClientCanClose verifies that clients can also force a connection to close. func TestClientCanClose(t *testing.T) { - testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { // Nothing. })) } @@ -758,7 +815,7 @@ func TestClientCanClose(t *testing.T) { // TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, // even for HTTP/1.1 requests. func TestHandlersCanSetConnectionClose11(t *testing.T) { - testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } @@ -796,6 +853,7 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { // Issue 15703 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding @@ -828,6 +886,7 @@ func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } func testSetsRemoteAddr(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) @@ -877,6 +936,7 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "RA:%s", r.RemoteAddr) @@ -948,7 +1008,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { t.Fatalf("response 1 addr = %q; want %q", g, e) } } + func TestIdentityResponseHeaders(t *testing.T) { + // Not parallel; changes log output. defer afterTest(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) @@ -960,7 +1022,10 @@ func TestIdentityResponseHeaders(t *testing.T) { })) defer ts.Close() - res, err := Get(ts.URL) + c := &Client{Transport: new(Transport)} + defer closeClient(c) + + res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) } @@ -983,6 +1048,7 @@ func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } func testHeadResponses(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("<html>")) @@ -1020,9 +1086,6 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) @@ -1054,6 +1117,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { @@ -1121,6 +1185,7 @@ func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) { } func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) { + setParallel(t) defer afterTest(t) ln := newLocalListener(t) ln.Close() // immediately (not a defer!) @@ -1136,6 +1201,7 @@ func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) { } func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) { + setParallel(t) defer afterTest(t) ln := newLocalListener(t) ln.Close() // immediately (not a defer!) @@ -1177,6 +1243,7 @@ func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) { } func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) { + // Not parallel: uses global test hooks. defer afterTest(t) defer SetTestHookServerServe(nil) var ok bool @@ -1280,6 +1347,7 @@ var serverExpectTests = []serverExpectTest{ // correctly. // http2 test: TestServer_Response_Automatic100Continue func TestServerExpect(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST @@ -1373,6 +1441,7 @@ func TestServerExpect(t *testing.T) { // Under a ~256KB (maxPostHandlerReadBytes) threshold, the server // should consume client request bodies that a handler didn't read. func TestServerUnreadRequestBodyLittle(t *testing.T) { + setParallel(t) defer afterTest(t) conn := new(testConn) body := strings.Repeat("x", 100<<10) @@ -1413,6 +1482,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { // should ignore client request bodies that a handler didn't read // and close the connection. func TestServerUnreadRequestBodyLarge(t *testing.T) { + setParallel(t) if testing.Short() && testenv.Builder() == "" { t.Log("skipping in short mode") } @@ -1546,6 +1616,7 @@ var handlerBodyCloseTests = [...]handlerBodyCloseTest{ } func TestHandlerBodyClose(t *testing.T) { + setParallel(t) if testing.Short() && testenv.Builder() == "" { t.Skip("skipping in -short mode") } @@ -1625,6 +1696,7 @@ var testHandlerBodyConsumers = []testHandlerBodyConsumer{ } func TestRequestBodyReadErrorClosesConnection(t *testing.T) { + setParallel(t) defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := new(testConn) @@ -1655,6 +1727,7 @@ func TestRequestBodyReadErrorClosesConnection(t *testing.T) { } func TestInvalidTrailerClosesConnection(t *testing.T) { + setParallel(t) defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := new(testConn) @@ -1737,7 +1810,7 @@ restart: if !c.rd.IsZero() { // If the deadline falls in the middle of our sleep window, deduct // part of the sleep, then return a timeout. - if remaining := c.rd.Sub(time.Now()); remaining < cue { + if remaining := time.Until(c.rd); remaining < cue { c.script[0] = cue - remaining time.Sleep(remaining) return 0, syscall.ETIMEDOUT @@ -1823,6 +1896,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) @@ -1876,6 +1950,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { // See issues 8209 and 8414. func TestTimeoutHandlerRace(t *testing.T) { + setParallel(t) defer afterTest(t) delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1892,6 +1967,9 @@ func TestTimeoutHandlerRace(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) defer ts.Close() + c := &Client{Transport: new(Transport)} + defer closeClient(c) + var wg sync.WaitGroup gate := make(chan bool, 10) n := 50 @@ -1905,7 +1983,7 @@ func TestTimeoutHandlerRace(t *testing.T) { go func() { defer wg.Done() defer func() { <-gate }() - res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) + res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) if err == nil { io.Copy(ioutil.Discard, res.Body) res.Body.Close() @@ -1917,6 +1995,7 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. func TestTimeoutHandlerRaceHeader(t *testing.T) { + setParallel(t) defer afterTest(t) delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1932,13 +2011,15 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { if testing.Short() { n = 10 } + c := &Client{Transport: new(Transport)} + defer closeClient(c) for i := 0; i < n; i++ { gate <- true wg.Add(1) go func() { defer wg.Done() defer func() { <-gate }() - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { t.Error(err) return @@ -1952,6 +2033,7 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { // Issue 9162 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { + setParallel(t) defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) @@ -2016,11 +2098,15 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { timeout := 300 * time.Millisecond ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() + + c := &Client{Transport: new(Transport)} + defer closeClient(c) + // Issue was caused by the timeout handler starting the timer when // was created, not when the request. So wait for more than the timeout // to ensure that's not the case. time.Sleep(2 * timeout) - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -2032,6 +2118,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { // https://golang.org/issue/15948 func TestTimeoutHandlerEmptyResponse(t *testing.T) { + setParallel(t) defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. @@ -2040,7 +2127,10 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() - res, err := Get(ts.URL) + c := &Client{Transport: new(Transport)} + defer closeClient(c) + + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -2050,23 +2140,6 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { } } -// Verifies we don't path.Clean() on the wrong parts in redirects. -func TestRedirectMunging(t *testing.T) { - req, _ := NewRequest("GET", "http://example.com/", nil) - - resp := httptest.NewRecorder() - Redirect(resp, req, "/foo?next=http://bar.com/", 302) - if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e { - t.Errorf("Location header was %q; want %q", g, e) - } - - resp = httptest.NewRecorder() - Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302) - if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e { - t.Errorf("Location header was %q; want %q", g, e) - } -} - func TestRedirectBadPath(t *testing.T) { // This used to crash. It's not valid input (bad path), but it // shouldn't crash. @@ -2085,7 +2158,7 @@ func TestRedirectBadPath(t *testing.T) { } // Test different URL formats and schemes -func TestRedirectURLFormat(t *testing.T) { +func TestRedirect(t *testing.T) { req, _ := NewRequest("GET", "http://example.com/qux/", nil) var tests = []struct { @@ -2108,6 +2181,14 @@ func TestRedirectURLFormat(t *testing.T) { {"../quux/foobar.com/baz", "/quux/foobar.com/baz"}, // incorrect number of slashes {"///foobar.com/baz", "/foobar.com/baz"}, + + // Verifies we don't path.Clean() on the wrong parts in redirects: + {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"}, + {"http://localhost:8080/_ah/login?continue=http://localhost:8080/", + "http://localhost:8080/_ah/login?continue=http://localhost:8080/"}, + + {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"}, + {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"}, } for _, tt := range tests { @@ -2133,6 +2214,7 @@ func TestZeroLengthPostAndResponse_h2(t *testing.T) { } func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) @@ -2252,12 +2334,58 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) } } +type terrorWriter struct{ t *testing.T } + +func (w terrorWriter) Write(p []byte) (int, error) { + w.t.Errorf("%s", p) + return len(p), nil +} + +// Issue 16456: allow writing 0 bytes on hijacked conn to test hijack +// without any log spam. +func TestServerWriteHijackZeroBytes(t *testing.T) { + defer afterTest(t) + done := make(chan struct{}) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(done) + w.(Flusher).Flush() + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + defer conn.Close() + _, err = w.Write(nil) + if err != ErrHijacked { + t.Errorf("Write error = %v; want ErrHijacked", err) + } + })) + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + ts.Start() + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} + func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } func testServerNoHeader(t *testing.T, h2 bool, header string) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil @@ -2275,6 +2403,7 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } func TestStripPrefix(t *testing.T) { + setParallel(t) defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) @@ -2282,7 +2411,10 @@ func TestStripPrefix(t *testing.T) { ts := httptest.NewServer(StripPrefix("/foo", h)) defer ts.Close() - res, err := Get(ts.URL + "/foo/bar") + c := &Client{Transport: new(Transport)} + defer closeClient(c) + + res, err := c.Get(ts.URL + "/foo/bar") if err != nil { t.Fatal(err) } @@ -2304,10 +2436,11 @@ func TestStripPrefix(t *testing.T) { func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } func testRequestLimit(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") - })) + }), optQuietLog) defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") @@ -2350,6 +2483,7 @@ func (cr countReader) Read(p []byte) (n int, err error) { func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } func testRequestBodyLimit(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) const limit = 1 << 20 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2399,14 +2533,14 @@ func TestClientWriteShutdown(t *testing.T) { } err = conn.(*net.TCPConn).CloseWrite() if err != nil { - t.Fatalf("Dial: %v", err) + t.Fatalf("CloseWrite: %v", err) } donec := make(chan bool) go func() { defer close(donec) bs, err := ioutil.ReadAll(conn) if err != nil { - t.Fatalf("ReadAll: %v", err) + t.Errorf("ReadAll: %v", err) } got := string(bs) if got != "" { @@ -2445,6 +2579,7 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) @@ -2557,7 +2692,8 @@ func TestCloseNotifier(t *testing.T) { go func() { _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") if err != nil { - t.Fatal(err) + t.Error(err) + return } <-diec conn.Close() @@ -2599,7 +2735,8 @@ func TestCloseNotifierPipelined(t *testing.T) { const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n" _, err = io.WriteString(conn, req+req) // two requests if err != nil { - t.Fatal(err) + t.Error(err) + return } <-diec conn.Close() @@ -2707,6 +2844,7 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { + setParallel(t) defer afterTest(t) var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) @@ -3028,15 +3166,18 @@ func (l *errorListener) Addr() net.Addr { } func TestAcceptMaxFds(t *testing.T) { - log.SetOutput(ioutil.Discard) // is noisy otherwise - defer log.SetOutput(os.Stderr) + setParallel(t) ln := &errorListener{[]error{ &net.OpError{ Op: "accept", Err: syscall.EMFILE, }}} - err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {}))) + server := &Server{ + Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})), + ErrorLog: log.New(ioutil.Discard, "", 0), // noisy otherwise + } + err := server.Serve(ln) if err != io.EOF { t.Errorf("got error %v, want EOF", err) } @@ -3161,6 +3302,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) { func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } func testServerReaderFromOrder(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) pr, pw := io.Pipe() const size = 3 << 20 @@ -3265,6 +3407,7 @@ func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { testTransportAndServerSharedBodyRace(t, h2Mode) } func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) const bodySize = 1 << 20 @@ -3453,6 +3596,7 @@ func TestAppendTime(t *testing.T) { } func TestServerConnState(t *testing.T) { + setParallel(t) defer afterTest(t) handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { @@ -3500,14 +3644,39 @@ func TestServerConnState(t *testing.T) { } ts.Start() - mustGet(t, ts.URL+"/") - mustGet(t, ts.URL+"/close") + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + mustGet := func(url string, headers ...string) { + req, err := NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + for len(headers) > 0 { + req.Header.Add(headers[0], headers[1]) + headers = headers[2:] + } + res, err := c.Do(req) + if err != nil { + t.Errorf("Error fetching %s: %v", url, err) + return + } + _, err = ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Errorf("Error reading %s: %v", url, err) + } + } + + mustGet(ts.URL + "/") + mustGet(ts.URL + "/close") - mustGet(t, ts.URL+"/") - mustGet(t, ts.URL+"/", "Connection", "close") + mustGet(ts.URL + "/") + mustGet(ts.URL+"/", "Connection", "close") - mustGet(t, ts.URL+"/hijack") - mustGet(t, ts.URL+"/hijack-panic") + mustGet(ts.URL + "/hijack") + mustGet(ts.URL + "/hijack-panic") // New->Closed { @@ -3587,31 +3756,10 @@ func TestServerConnState(t *testing.T) { } mu.Lock() - t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want)) + t.Errorf("Unexpected events.\nGot log:\n%s\n Want:\n%s\n", logString(stateLog), logString(want)) mu.Unlock() } -func mustGet(t *testing.T, url string, headers ...string) { - req, err := NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - for len(headers) > 0 { - req.Header.Add(headers[0], headers[1]) - headers = headers[2:] - } - res, err := DefaultClient.Do(req) - if err != nil { - t.Errorf("Error fetching %s: %v", url, err) - return - } - _, err = ioutil.ReadAll(res.Body) - defer res.Body.Close() - if err != nil { - t.Errorf("Error reading %s: %v", url, err) - } -} - func TestServerKeepAlivesEnabled(t *testing.T) { defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) @@ -3632,6 +3780,7 @@ func TestServerKeepAlivesEnabled(t *testing.T) { func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } func testServerEmptyBodyRace(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) var n int32 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -3695,6 +3844,7 @@ func (c *closeWriteTestConn) CloseWrite() error { } func TestCloseWrite(t *testing.T) { + setParallel(t) var srv Server var testConn closeWriteTestConn c := ExportServerNewConn(&srv, &testConn) @@ -3935,6 +4085,7 @@ Host: foo // If a Handler finishes and there's an unread request body, // verify the server try to do implicit read on it before replying. func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { + setParallel(t) conn := &testConn{closec: make(chan bool)} conn.readBuf.Write([]byte(fmt.Sprintf( "POST / HTTP/1.1\r\n" + @@ -4033,7 +4184,11 @@ func TestServerValidatesHostHeader(t *testing.T) { io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n") ln := &oneConnListener{conn} - go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + srv := Server{ + ErrorLog: quietLog, + Handler: HandlerFunc(func(ResponseWriter, *Request) {}), + } + go srv.Serve(ln) <-conn.closec res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) if err != nil { @@ -4088,6 +4243,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { // Test that we validate the valid bytes in HTTP/1 headers. // Issue 11207. func TestServerValidatesHeaders(t *testing.T) { + setParallel(t) tests := []struct { header string want int @@ -4097,9 +4253,10 @@ func TestServerValidatesHeaders(t *testing.T) { {"X-Foo: bar\r\n", 200}, {"Foo: a space\r\n", 200}, - {"A space: foo\r\n", 400}, // space in header - {"foo\xffbar: foo\r\n", 400}, // binary in header - {"foo\x00bar: foo\r\n", 400}, // binary in header + {"A space: foo\r\n", 400}, // space in header + {"foo\xffbar: foo\r\n", 400}, // binary in header + {"foo\x00bar: foo\r\n", 400}, // binary in header + {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431}, // header too large {"foo: foo foo\r\n", 200}, // LWS space is okay {"foo: foo\tfoo\r\n", 200}, // LWS tab is okay @@ -4112,7 +4269,11 @@ func TestServerValidatesHeaders(t *testing.T) { io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n") ln := &oneConnListener{conn} - go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + srv := Server{ + ErrorLog: quietLog, + Handler: HandlerFunc(func(ResponseWriter, *Request) {}), + } + go srv.Serve(ln) <-conn.closec res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) if err != nil { @@ -4132,6 +4293,7 @@ func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) } func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) ctxc := make(chan context.Context, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -4157,13 +4319,12 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } } +// Tests that the Request.Context available to the Handler is canceled +// if the peer closes their TCP connection. This requires that the server +// is always blocked in a Read call so it notices the EOF from the client. +// See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - // Currently the context is not canceled when the connection - // is closed because we're not reading from the connection - // until after ServeHTTP for the previous handler is done. - // Until the server code is modified to always be in a read - // (Issue 15224), this test doesn't work yet. - t.Skip("TODO(bradfitz): this test doesn't yet work; golang.org/issue/15224") + setParallel(t) defer afterTest(t) inHandler := make(chan struct{}) handlerDone := make(chan struct{}) @@ -4192,7 +4353,7 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { select { case <-handlerDone: - case <-time.After(3 * time.Second): + case <-time.After(4 * time.Second): t.Fatalf("timeout waiting to see ServeHTTP exit") } } @@ -4204,6 +4365,7 @@ func TestServerContext_ServerContextKey_h2(t *testing.T) { testServerContext_ServerContextKey(t, h2Mode) } func testServerContext_ServerContextKey(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() @@ -4229,6 +4391,7 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { // https://golang.org/issue/15960 func TestHandlerSetTransferEncodingChunked(t *testing.T) { + setParallel(t) defer afterTest(t) ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "chunked") @@ -4243,6 +4406,7 @@ func TestHandlerSetTransferEncodingChunked(t *testing.T) { // https://golang.org/issue/16063 func TestHandlerSetTransferEncodingGzip(t *testing.T) { + setParallel(t) defer afterTest(t) ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "gzip") @@ -4416,13 +4580,19 @@ func BenchmarkClient(b *testing.B) { b.StopTimer() defer afterTest(b) - port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user - if port == "" { - port = "39207" - } var data = []byte("Hello world.\n") if server := os.Getenv("TEST_BENCH_SERVER"); server != "" { // Server process mode. + port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user + if port == "" { + port = "0" + } + ln, err := net.Listen("tcp", "localhost:"+port) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + fmt.Println(ln.Addr().String()) HandleFunc("/", func(w ResponseWriter, r *Request) { r.ParseForm() if r.Form.Get("stop") != "" { @@ -4431,33 +4601,44 @@ func BenchmarkClient(b *testing.B) { w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Write(data) }) - log.Fatal(ListenAndServe("localhost:"+port, nil)) + var srv Server + log.Fatal(srv.Serve(ln)) } // Start server process. cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$") cmd.Env = append(os.Environ(), "TEST_BENCH_SERVER=yes") + cmd.Stderr = os.Stderr + stdout, err := cmd.StdoutPipe() + if err != nil { + b.Fatal(err) + } if err := cmd.Start(); err != nil { b.Fatalf("subprocess failed to start: %v", err) } defer cmd.Process.Kill() + + // Wait for the server in the child process to respond and tell us + // its listening address, once it's started listening: + timer := time.AfterFunc(10*time.Second, func() { + cmd.Process.Kill() + }) + defer timer.Stop() + bs := bufio.NewScanner(stdout) + if !bs.Scan() { + b.Fatalf("failed to read listening URL from child: %v", bs.Err()) + } + url := "http://" + strings.TrimSpace(bs.Text()) + "/" + timer.Stop() + if _, err := getNoBody(url); err != nil { + b.Fatalf("initial probe of child process failed: %v", err) + } + done := make(chan error) go func() { done <- cmd.Wait() }() - // Wait for the server process to respond. - url := "http://localhost:" + port + "/" - for i := 0; i < 100; i++ { - time.Sleep(100 * time.Millisecond) - if _, err := getNoBody(url); err == nil { - break - } - if i == 99 { - b.Fatalf("subprocess does not respond") - } - } - // Do b.N requests to the server. b.StartTimer() for i := 0; i < b.N; i++ { @@ -4719,6 +4900,7 @@ func BenchmarkCloseNotifier(b *testing.B) { // Verify this doesn't race (Issue 16505) func TestConcurrentServerServe(t *testing.T) { + setParallel(t) for i := 0; i < 100; i++ { ln1 := &oneConnListener{conn: nil} ln2 := &oneConnListener{conn: nil} @@ -4727,3 +4909,267 @@ func TestConcurrentServerServe(t *testing.T) { go func() { srv.Serve(ln2) }() } } + +func TestServerIdleTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + setParallel(t) + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + io.WriteString(w, r.RemoteAddr) + })) + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + ts.Start() + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + get := func() string { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + return string(slurp) + } + + a1, a2 := get(), get() + if a1 != a2 { + t.Fatalf("did requests on different connections") + } + time.Sleep(3 * time.Second) + a3 := get() + if a2 == a3 { + t.Fatal("request three unexpectedly on same connection") + } + + // And test that ReadHeaderTimeout still works: + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n")) + time.Sleep(2 * time.Second) + if _, err := io.CopyN(ioutil.Discard, conn, 1); err == nil { + t.Fatal("copy byte succeeded; want err") + } +} + +func get(t *testing.T, c *Client, url string) string { + res, err := c.Get(url) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + return string(slurp) +} + +// Tests that calls to Server.SetKeepAlivesEnabled(false) closes any +// currently-open connections. +func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, r.RemoteAddr) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + get := func() string { return get(t, c, ts.URL) } + + a1, a2 := get(), get() + if a1 != a2 { + t.Fatal("expected first two requests on same connection") + } + var idle0 int + if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { + idle0 = tr.IdleConnKeyCountForTesting() + return idle0 == 1 + }) { + t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0) + } + + ts.Config.SetKeepAlivesEnabled(false) + + var idle1 int + if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { + idle1 = tr.IdleConnKeyCountForTesting() + return idle1 == 0 + }) { + t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1) + } + + a3 := get() + if a3 == a2 { + t.Fatal("expected third request on new connection") + } +} + +func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) } +func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) } + +func testServerShutdown(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + var doShutdown func() // set later + var shutdownRes = make(chan error, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + go doShutdown() + // Shutdown is graceful, so it should not interrupt + // this in-flight response. Add a tiny sleep here to + // increase the odds of a failure if shutdown has + // bugs. + time.Sleep(20 * time.Millisecond) + io.WriteString(w, r.RemoteAddr) + })) + defer cst.close() + + doShutdown = func() { + shutdownRes <- cst.ts.Config.Shutdown(context.Background()) + } + get(t, cst.c, cst.ts.URL) // calls t.Fail on failure + + if err := <-shutdownRes; err != nil { + t.Fatalf("Shutdown: %v", err) + } + + res, err := cst.c.Get(cst.ts.URL) + if err == nil { + res.Body.Close() + t.Fatal("second request should fail. server should be shut down") + } +} + +// Issue 17878: tests that we can call Close twice. +func TestServerCloseDeadlock(t *testing.T) { + var s Server + s.Close() + s.Close() +} + +// Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by +// both HTTP/1 and HTTP/2. +func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } +func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } +func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.RemoteAddr) + })) + defer cst.close() + srv := cst.ts.Config + srv.SetKeepAlivesEnabled(false) + a := cst.getURL(cst.ts.URL) + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("test server has active conns") + } + b := cst.getURL(cst.ts.URL) + if a == b { + t.Errorf("got same connection between first and second requests") + } + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("test server has active conns") + } +} + +// Issue 18447: test that the Server's ReadTimeout is stopped while +// the server's doing its 1-byte background read between requests, +// waiting for the connection to maybe close. +func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { + setParallel(t) + defer afterTest(t) + const timeout = 250 * time.Millisecond + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + select { + case <-time.After(2 * timeout): + fmt.Fprint(w, "ok") + case <-r.Context().Done(): + fmt.Fprint(w, r.Context().Err()) + } + })) + ts.Config.ReadTimeout = timeout + ts.Start() + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(slurp) != "ok" { + t.Fatalf("Got: %q, want ok", slurp) + } +} + +// Issue 18535: test that the Server doesn't try to do a background +// read if it's already done one. +func TestServerDuplicateBackgroundRead(t *testing.T) { + setParallel(t) + defer afterTest(t) + + const goroutines = 5 + const requests = 2000 + + hts := httptest.NewServer(HandlerFunc(NotFound)) + defer hts.Close() + + reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") + + var wg sync.WaitGroup + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cn, err := net.Dial("tcp", hts.Listener.Addr().String()) + if err != nil { + t.Error(err) + return + } + defer cn.Close() + + wg.Add(1) + go func() { + defer wg.Done() + io.Copy(ioutil.Discard, cn) + }() + + for j := 0; j < requests; j++ { + if t.Failed() { + return + } + _, err := cn.Write(reqBytes) + if err != nil { + t.Error(err) + return + } + } + }() + } + wg.Wait() +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 89574a8b..9623648 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -40,7 +40,9 @@ var ( // ErrHijacked is returned by ResponseWriter.Write calls when // the underlying connection has been hijacked using the - // Hijacker interfaced. + // Hijacker interface. A zero-byte write on a hijacked + // connection will return ErrHijacked without any other side + // effects. ErrHijacked = errors.New("http: connection has been hijacked") // ErrContentLength is returned by ResponseWriter.Write calls @@ -73,7 +75,9 @@ var ( // If ServeHTTP panics, the server (the caller of ServeHTTP) assumes // that the effect of the panic was isolated to the active request. // It recovers the panic, logs a stack trace to the server error log, -// and hangs up the connection. +// and hangs up the connection. To abort a handler so the client sees +// an interrupted response but the server doesn't log an error, panic +// with the value ErrAbortHandler. type Handler interface { ServeHTTP(ResponseWriter, *Request) } @@ -85,11 +89,25 @@ type Handler interface { // has returned. type ResponseWriter interface { // Header returns the header map that will be sent by - // WriteHeader. Changing the header after a call to - // WriteHeader (or Write) has no effect unless the modified - // headers were declared as trailers by setting the - // "Trailer" header before the call to WriteHeader (see example). - // To suppress implicit response headers, set their value to nil. + // WriteHeader. The Header map also is the mechanism with which + // Handlers can set HTTP trailers. + // + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the modified headers are + // trailers. + // + // There are two ways to set Trailers. The preferred way is to + // predeclare in the headers which trailers you will later + // send by setting the "Trailer" header to the names of the + // trailer keys which will come later. In this case, those + // keys of the Header map are treated as if they were + // trailers. See the example. The second way, for trailer + // keys not known to the Handler until after the first Write, + // is to prefix the Header map keys with the TrailerPrefix + // constant value. See TrailerPrefix. + // + // To suppress implicit response headers (such as "Date"), set + // their value to nil. Header() Header // Write writes the data to the connection as part of an HTTP reply. @@ -206,6 +224,9 @@ type conn struct { // Immutable; never nil. server *Server + // cancelCtx cancels the connection-level context. + cancelCtx context.CancelFunc + // rwc is the underlying network connection. // This is never wrapped by other types and is the value given out // to CloseNotifier callers. It is usually of type *net.TCPConn or @@ -232,7 +253,6 @@ type conn struct { r *connReader // bufr reads from r. - // Users of bufr must hold mu. bufr *bufio.Reader // bufw writes to checkConnErrorWriter{c}, which populates werr on error. @@ -242,7 +262,11 @@ type conn struct { // on this connection, if any. lastMethod string - // mu guards hijackedv, use of bufr, (*response).closeNotifyCh. + curReq atomic.Value // of *response (which has a Request in it) + + curState atomic.Value // of ConnState + + // mu guards hijackedv mu sync.Mutex // hijackedv is whether this connection has been hijacked @@ -262,8 +286,12 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if c.hijackedv { return nil, nil, ErrHijacked } + c.r.abortPendingRead() + c.hijackedv = true rwc = c.rwc + rwc.SetDeadline(time.Time{}) + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) c.setState(rwc, StateHijacked) return @@ -346,13 +374,7 @@ func (cw *chunkWriter) close() { bw := cw.res.conn.bufw // conn's bufio writer // zero chunk to mark EOF bw.WriteString("0\r\n") - if len(cw.res.trailers) > 0 { - trailers := make(Header) - for _, h := range cw.res.trailers { - if vv := cw.res.handlerHeader[h]; len(vv) > 0 { - trailers[h] = vv - } - } + if trailers := cw.res.finalTrailers(); trailers != nil { trailers.Write(bw) // the writer handles noting errors } // final blank line after the trailers (whether @@ -413,9 +435,48 @@ type response struct { dateBuf [len(TimeFormat)]byte clenBuf [10]byte - // closeNotifyCh is non-nil once CloseNotify is called. - // Guarded by conn.mu - closeNotifyCh <-chan bool + // closeNotifyCh is the channel returned by CloseNotify. + // TODO(bradfitz): this is currently (for Go 1.8) always + // non-nil. Make this lazily-created again as it used to be? + closeNotifyCh chan bool + didCloseNotify int32 // atomic (only 0->1 winner should send) +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const TrailerPrefix = "Trailer:" + +// finalTrailers is called after the Handler exits and returns a non-nil +// value if the Handler set any trailers. +func (w *response) finalTrailers() Header { + var t Header + for k, vv := range w.handlerHeader { + if strings.HasPrefix(k, TrailerPrefix) { + if t == nil { + t = make(Header) + } + t[strings.TrimPrefix(k, TrailerPrefix)] = vv + } + } + for _, k := range w.trailers { + if t == nil { + t = make(Header) + } + for _, v := range w.handlerHeader[k] { + t.Add(k, v) + } + } + return t } type atomicBool int32 @@ -548,60 +609,152 @@ type readResult struct { // call blocked in a background goroutine to wait for activity and // trigger a CloseNotifier channel. type connReader struct { - r io.Reader - remain int64 // bytes remaining + conn *conn - // ch is non-nil if a background read is in progress. - // It is guarded by conn.mu. - ch chan readResult + mu sync.Mutex // guards following + hasByte bool + byteBuf [1]byte + bgErr error // non-nil means error happened on background read + cond *sync.Cond + inRead bool + aborted bool // set true before conn.rwc deadline is set to past + remain int64 // bytes remaining +} + +func (cr *connReader) lock() { + cr.mu.Lock() + if cr.cond == nil { + cr.cond = sync.NewCond(&cr.mu) + } +} + +func (cr *connReader) unlock() { cr.mu.Unlock() } + +func (cr *connReader) startBackgroundRead() { + cr.lock() + defer cr.unlock() + if cr.inRead { + panic("invalid concurrent Body.Read call") + } + if cr.hasByte { + return + } + cr.inRead = true + cr.conn.rwc.SetReadDeadline(time.Time{}) + go cr.backgroundRead() +} + +func (cr *connReader) backgroundRead() { + n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + cr.lock() + if n == 1 { + cr.hasByte = true + // We were at EOF already (since we wouldn't be in a + // background read otherwise), so this is a pipelined + // HTTP request. + cr.closeNotifyFromPipelinedRequest() + } + if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() { + // Ignore this error. It's the expected error from + // another goroutine calling abortPendingRead. + } else if err != nil { + cr.handleReadError(err) + } + cr.aborted = false + cr.inRead = false + cr.unlock() + cr.cond.Broadcast() +} + +func (cr *connReader) abortPendingRead() { + cr.lock() + defer cr.unlock() + if !cr.inRead { + return + } + cr.aborted = true + cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + for cr.inRead { + cr.cond.Wait() + } + cr.conn.rwc.SetReadDeadline(time.Time{}) } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } +// may be called from multiple goroutines. +func (cr *connReader) handleReadError(err error) { + cr.conn.cancelCtx() + cr.closeNotify() +} + +// closeNotifyFromPipelinedRequest simply calls closeNotify. +// +// This method wrapper is here for documentation. The callers are the +// cases where we send on the closenotify channel because of a +// pipelined HTTP request, per the previous Go behavior and +// documentation (that this "MAY" happen). +// +// TODO: consider changing this behavior and making context +// cancelation and closenotify work the same. +func (cr *connReader) closeNotifyFromPipelinedRequest() { + cr.closeNotify() +} + +// may be called from multiple goroutines. +func (cr *connReader) closeNotify() { + res, _ := cr.conn.curReq.Load().(*response) + if res != nil { + if atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res.closeNotifyCh <- true + } + } +} + func (cr *connReader) Read(p []byte) (n int, err error) { + cr.lock() + if cr.inRead { + cr.unlock() + panic("invalid concurrent Body.Read call") + } if cr.hitReadLimit() { + cr.unlock() return 0, io.EOF } + if cr.bgErr != nil { + err = cr.bgErr + cr.unlock() + return 0, err + } if len(p) == 0 { - return + cr.unlock() + return 0, nil } if int64(len(p)) > cr.remain { p = p[:cr.remain] } - - // Is a background read (started by CloseNotifier) already in - // flight? If so, wait for it and use its result. - ch := cr.ch - if ch != nil { - cr.ch = nil - res := <-ch - if res.n == 1 { - p[0] = res.b - cr.remain -= 1 - } - return res.n, res.err + if cr.hasByte { + p[0] = cr.byteBuf[0] + cr.hasByte = false + cr.unlock() + return 1, nil } - n, err = cr.r.Read(p) - cr.remain -= int64(n) - return -} + cr.inRead = true + cr.unlock() + n, err = cr.conn.rwc.Read(p) -func (cr *connReader) startBackgroundRead(onReadComplete func()) { - if cr.ch != nil { - // Background read already started. - return + cr.lock() + cr.inRead = false + if err != nil { + cr.handleReadError(err) } - cr.ch = make(chan readResult, 1) - go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete) -} + cr.remain -= int64(n) + cr.unlock() -func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) { - var buf [1]byte - n, err := cr.r.Read(buf[:1]) - onReadComplete() - ch <- readResult{n, err, buf[0]} + cr.cond.Broadcast() + return n, err } var ( @@ -633,7 +786,7 @@ func newBufioReader(r io.Reader) *bufio.Reader { br.Reset(r) return br } - // Note: if this reader size is every changed, update + // Note: if this reader size is ever changed, update // TestHandlerBodyClose's assumptions. return bufio.NewReader(r) } @@ -746,9 +899,18 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { return nil, ErrHijacked } + var ( + wholeReqDeadline time.Time // or zero if none + hdrDeadline time.Time // or zero if none + ) + t0 := time.Now() + if d := c.server.readHeaderTimeout(); d != 0 { + hdrDeadline = t0.Add(d) + } if d := c.server.ReadTimeout; d != 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) + wholeReqDeadline = t0.Add(d) } + c.rwc.SetReadDeadline(hdrDeadline) if d := c.server.WriteTimeout; d != 0 { defer func() { c.rwc.SetWriteDeadline(time.Now().Add(d)) @@ -756,14 +918,12 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { } c.r.setReadLimit(c.server.initialReadLimitSize()) - c.mu.Lock() // while using bufr if c.lastMethod == "POST" { // RFC 2616 section 4.1 tolerance for old buggy clients. peek, _ := c.bufr.Peek(4) // ReadRequest will get err below c.bufr.Discard(numLeadingCRorLF(peek)) } req, err := readRequest(c.bufr, keepHostHeader) - c.mu.Unlock() if err != nil { if c.r.hitReadLimit() { return nil, errTooLarge @@ -809,6 +969,11 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { body.doEarlyClose = true } + // Adjust the read deadline if necessary. + if !hdrDeadline.Equal(wholeReqDeadline) { + c.rwc.SetReadDeadline(wholeReqDeadline) + } + w = &response{ conn: c, cancelCtx: cancelCtx, @@ -816,6 +981,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { reqBody: req.Body, handlerHeader: make(Header), contentLength: -1, + closeNotifyCh: make(chan bool, 1), // We populate these ahead of time so we're not // reading from req.Header after their Handler starts @@ -990,7 +1156,17 @@ func (cw *chunkWriter) writeHeader(p []byte) { } var setHeader extraHeader + // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix. trailers := false + for k := range cw.header { + if strings.HasPrefix(k, TrailerPrefix) { + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[k] = true + trailers = true + } + } for _, v := range cw.header["Trailer"] { trailers = true foreachHeaderElement(v, cw.res.declareTrailer) @@ -1318,7 +1494,9 @@ func (w *response) WriteString(data string) (n int, err error) { // either dataB or dataS is non-zero. func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { - w.conn.server.logf("http: response.Write on hijacked connection") + if lenData > 0 { + w.conn.server.logf("http: response.Write on hijacked connection") + } return 0, ErrHijacked } if !w.wroteHeader { @@ -1354,6 +1532,8 @@ func (w *response) finishRequest() { w.cw.close() w.conn.bufw.Flush() + w.conn.r.abortPendingRead() + // Close the body (regardless of w.closeAfterReply) so we can // re-use its bufio.Reader later safely. w.reqBody.Close() @@ -1469,11 +1649,30 @@ func validNPN(proto string) bool { } func (c *conn) setState(nc net.Conn, state ConnState) { - if hook := c.server.ConnState; hook != nil { + srv := c.server + switch state { + case StateNew: + srv.trackConn(c, true) + case StateHijacked, StateClosed: + srv.trackConn(c, false) + } + c.curState.Store(connStateInterface[state]) + if hook := srv.ConnState; hook != nil { hook(nc, state) } } +// connStateInterface is an array of the interface{} versions of +// ConnState values, so we can use them in atomic.Values later without +// paying the cost of shoving their integers in an interface{}. +var connStateInterface = [...]interface{}{ + StateNew: StateNew, + StateActive: StateActive, + StateIdle: StateIdle, + StateHijacked: StateHijacked, + StateClosed: StateClosed, +} + // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should // be plain text without user info or other embedded errors. @@ -1481,11 +1680,34 @@ type badRequestError string func (e badRequestError) Error() string { return "Bad Request: " + string(e) } +// ErrAbortHandler is a sentinel panic value to abort a handler. +// While any panic from ServeHTTP aborts the response to the client, +// panicking with ErrAbortHandler also suppresses logging of a stack +// trace to the server's error log. +var ErrAbortHandler = errors.New("net/http: abort Handler") + +// isCommonNetReadError reports whether err is a common error +// encountered during reading a request off the network when the +// client has gone away or had its read fail somehow. This is used to +// determine which logs are interesting enough to log about. +func isCommonNetReadError(err error) bool { + if err == io.EOF { + return true + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return true + } + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + return true + } + return false +} + // Serve a new connection. func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() defer func() { - if err := recover(); err != nil { + if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] @@ -1521,13 +1743,14 @@ func (c *conn) serve(ctx context.Context) { // HTTP/1.x from here on. - c.r = &connReader{r: c.rwc} - c.bufr = newBufioReader(c.r) - c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) - ctx, cancelCtx := context.WithCancel(ctx) + c.cancelCtx = cancelCtx defer cancelCtx() + c.r = &connReader{conn: c} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + for { w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { @@ -1535,27 +1758,29 @@ func (c *conn) serve(ctx context.Context) { c.setState(c.rwc, StateActive) } if err != nil { + const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" + if err == errTooLarge { // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. - io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") + const publicErr = "431 Request Header Fields Too Large" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) c.closeWriteAndWait() return } - if err == io.EOF { - return // don't reply - } - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + if isCommonNetReadError(err) { return // don't reply } - var publicErr string + + publicErr := "400 Bad Request" if v, ok := err.(badRequestError); ok { - publicErr = ": " + string(v) + publicErr = publicErr + ": " + string(v) } - io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr) + + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } @@ -1571,11 +1796,24 @@ func (c *conn) serve(ctx context.Context) { return } + c.curReq.Store(w) + + if requestBodyRemains(req.Body) { + registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead) + } else { + if w.conn.bufr.Buffered() > 0 { + w.conn.r.closeNotifyFromPipelinedRequest() + } + w.conn.r.startBackgroundRead() + } + // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. + // But we're not going to implement HTTP pipelining because it + // was never deployed in the wild and the answer is HTTP/2. serverHandler{c.server}.ServeHTTP(w, w.req) w.cancelCtx() if c.hijacked() { @@ -1589,6 +1827,23 @@ func (c *conn) serve(ctx context.Context) { return } c.setState(c.rwc, StateIdle) + c.curReq.Store((*response)(nil)) + + if !w.conn.server.doKeepAlives() { + // We're in shutdown mode. We might've replied + // to the user without "Connection: close" and + // they might think they can send another + // request, but such is life with HTTP/1.1. + return + } + + if d := c.server.idleTimeout(); d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + if _, err := c.bufr.Peek(4); err != nil { + return + } + } + c.rwc.SetReadDeadline(time.Time{}) } } @@ -1624,10 +1879,6 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { c.mu.Lock() defer c.mu.Unlock() - if w.closeNotifyCh != nil { - return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call") - } - // Release the bufioWriter that writes to the chunk writer, it is not // used after a connection has been hijacked. rwc, buf, err = c.hijackLocked() @@ -1642,50 +1893,7 @@ func (w *response) CloseNotify() <-chan bool { if w.handlerDone.isSet() { panic("net/http: CloseNotify called after ServeHTTP finished") } - c := w.conn - c.mu.Lock() - defer c.mu.Unlock() - - if w.closeNotifyCh != nil { - return w.closeNotifyCh - } - ch := make(chan bool, 1) - w.closeNotifyCh = ch - - if w.conn.hijackedv { - // CloseNotify is undefined after a hijack, but we have - // no place to return an error, so just return a channel, - // even though it'll never receive a value. - return ch - } - - var once sync.Once - notify := func() { once.Do(func() { ch <- true }) } - - if requestBodyRemains(w.reqBody) { - // They're still consuming the request body, so we - // shouldn't notify yet. - registerOnHitEOF(w.reqBody, func() { - c.mu.Lock() - defer c.mu.Unlock() - startCloseNotifyBackgroundRead(c, notify) - }) - } else { - startCloseNotifyBackgroundRead(c, notify) - } - return ch -} - -// c.mu must be held. -func startCloseNotifyBackgroundRead(c *conn, notify func()) { - if c.bufr.Buffered() > 0 { - // They've consumed the request body, so anything - // remaining is a pipelined request, which we - // document as firing on. - notify() - } else { - c.r.startBackgroundRead(notify) - } + return w.closeNotifyCh } func registerOnHitEOF(rc io.ReadCloser, fn func()) { @@ -1702,7 +1910,7 @@ func registerOnHitEOF(rc io.ReadCloser, fn func()) { // requestBodyRemains reports whether future calls to Read // on rc might yield more data. func requestBodyRemains(rc io.ReadCloser) bool { - if rc == eofReader { + if rc == NoBody { return false } switch v := rc.(type) { @@ -1816,7 +2024,7 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } } - w.Header().Set("Location", urlStr) + w.Header().Set("Location", hexEscapeNonASCII(urlStr)) w.WriteHeader(code) // RFC 2616 recommends that a short note "SHOULD" be included in the @@ -2094,11 +2302,36 @@ func Serve(l net.Listener, handler Handler) error { // A Server defines parameters for running an HTTP server. // The zero value for Server is a valid configuration. type Server struct { - Addr string // TCP address to listen on, ":http" if empty - Handler Handler // handler to invoke, http.DefaultServeMux if nil - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + Addr string // TCP address to listen on, ":http" if empty + Handler Handler // handler to invoke, http.DefaultServeMux if nil + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. + ReadHeaderTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + IdleTimeout time.Duration // MaxHeaderBytes controls the maximum number of bytes the // server will read parsing the request header's keys and @@ -2114,7 +2347,8 @@ type Server struct { // handle HTTP requests and will initialize the Request's TLS // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. - // If TLSNextProto is nil, HTTP/2 support is enabled automatically. + // If TLSNextProto is not nil, HTTP/2 support is not enabled + // automatically. TLSNextProto map[string]func(*Server, *tls.Conn, Handler) // ConnState specifies an optional callback function that is @@ -2129,8 +2363,132 @@ type Server struct { ErrorLog *log.Logger disableKeepAlives int32 // accessed atomically. + inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init nextProtoErr error // result of http2.ConfigureServer if used + + mu sync.Mutex + listeners map[net.Listener]struct{} + activeConn map[*conn]struct{} + doneChan chan struct{} +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) + } +} + +// Close immediately closes all active net.Listeners and any +// connections in state StateNew, StateActive, or StateIdle. For a +// graceful shutdown, use Shutdown. +// +// Close does not attempt to close (and does not even know about) +// any hijacked connections, such as WebSockets. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.activeConn { + c.rwc.Close() + delete(srv.activeConn, c) + } + return err +} + +// shutdownPollInterval is how often we poll for quiescence +// during Server.Shutdown. This is lower during tests, to +// speed up tests. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +var shutdownPollInterval = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +// +// Shutdown does not attempt to close nor wait for hijacked +// connections such as WebSockets. The caller of Shutdown should +// separately notify such long-lived connections of shutdown and wait +// for them to close, if desired. +func (srv *Server) Shutdown(ctx context.Context) error { + atomic.AddInt32(&srv.inShutdown, 1) + defer atomic.AddInt32(&srv.inShutdown, -1) + + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + + ticker := time.NewTicker(shutdownPollInterval) + defer ticker.Stop() + for { + if srv.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, ok := c.curState.Load().(ConnState) + if !ok || st != StateIdle { + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(s.listeners, ln) + } + return err } // A ConnState represents the state of a client connection to a server. @@ -2243,6 +2601,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) } +var ErrServerClosed = errors.New("http: Server closed") + // Serve accepts incoming connections on the Listener l, creating a // new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. @@ -2252,7 +2612,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { // srv.TLSConfig is non-nil and doesn't include the string "h2" in // Config.NextProtos, HTTP/2 support is not enabled. // -// Serve always returns a non-nil error. +// Serve always returns a non-nil error. After Shutdown or Close, the +// returned error is ErrServerClosed. func (srv *Server) Serve(l net.Listener) error { defer l.Close() if fn := testHookServerServe; fn != nil { @@ -2264,14 +2625,20 @@ func (srv *Server) Serve(l net.Listener) error { return err } - // TODO: allow changing base context? can't imagine concrete - // use cases yet. - baseCtx := context.Background() + srv.trackListener(l, true) + defer srv.trackListener(l, false) + + baseCtx := context.Background() // base is always background, per Issue 16220 ctx := context.WithValue(baseCtx, ServerContextKey, srv) ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) for { rw, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond @@ -2294,8 +2661,57 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (s *Server) trackListener(ln net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(s.listeners) == 0 && len(s.activeConn) == 0 { + s.doneChan = nil + } + s.listeners[ln] = struct{}{} + } else { + delete(s.listeners, ln) + } +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + +func (s *Server) idleTimeout() time.Duration { + if s.IdleTimeout != 0 { + return s.IdleTimeout + } + return s.ReadTimeout +} + +func (s *Server) readHeaderTimeout() time.Duration { + if s.ReadHeaderTimeout != 0 { + return s.ReadHeaderTimeout + } + return s.ReadTimeout +} + func (s *Server) doKeepAlives() bool { - return atomic.LoadInt32(&s.disableKeepAlives) == 0 + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() +} + +func (s *Server) shuttingDown() bool { + return atomic.LoadInt32(&s.inShutdown) != 0 } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. @@ -2305,9 +2721,21 @@ func (s *Server) doKeepAlives() bool { func (srv *Server) SetKeepAlivesEnabled(v bool) { if v { atomic.StoreInt32(&srv.disableKeepAlives, 0) - } else { - atomic.StoreInt32(&srv.disableKeepAlives, 1) + return } + atomic.StoreInt32(&srv.disableKeepAlives, 1) + + // Close idle HTTP/1 conns: + srv.closeIdleConns() + + // Close HTTP/2 conns, as soon as they become idle, but reset + // the chan so future conns (if the listener is still active) + // still work and don't get a GOAWAY immediately, before their + // first request: + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() // closes http2 conns + srv.doneChan = nil } func (s *Server) logf(format string, args ...interface{}) { @@ -2630,24 +3058,6 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } -type eofReaderWithWriteTo struct{} - -func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil } -func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF } - -// eofReader is a non-nil io.ReadCloser that always returns EOF. -// It has a WriteTo method so io.Copy won't need a buffer. -var eofReader = &struct { - eofReaderWithWriteTo - io.Closer -}{ - eofReaderWithWriteTo{}, - ioutil.NopCloser(nil), -} - -// Verify that an io.Copy from an eofReader won't require a buffer. -var _ io.WriterTo = eofReader - // initNPNRequest is an HTTP handler that initializes certain // uninitialized fields in its *Request. Such partially-initialized // Requests come from NPN protocol handlers. @@ -2662,7 +3072,7 @@ func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { *req.TLS = h.c.ConnectionState() } if req.Body == nil { - req.Body = eofReader + req.Body = NoBody } if req.RemoteAddr == "" { req.RemoteAddr = h.c.RemoteAddr().String() @@ -2723,6 +3133,7 @@ func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err + w.c.cancelCtx() } return } diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index ac404bf..38f3f81 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -66,6 +66,7 @@ func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) } func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) } func testServerContentType(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) @@ -160,6 +161,7 @@ func testContentTypeWithCopy(t *testing.T, h2 bool) { func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) } func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) } func testSniffWriteSize(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index c653467..4f47637 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -17,6 +17,7 @@ import ( "strconv" "strings" "sync" + "time" "golang_org/x/net/lex/httplex" ) @@ -33,6 +34,23 @@ func (r errorReader) Read(p []byte) (n int, err error) { return 0, r.err } +type byteReader struct { + b byte + done bool +} + +func (br *byteReader) Read(p []byte) (n int, err error) { + if br.done { + return 0, io.EOF + } + if len(p) == 0 { + return 0, nil + } + br.done = true + p[0] = br.b + return 1, io.EOF +} + // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -46,6 +64,9 @@ type transferWriter struct { TransferEncoding []string Trailer Header IsResponse bool + + FlushHeaders bool // flush headers to network before body + ByteReadCh chan readResult // non-nil if probeRequestBody called } func newTransferWriter(r interface{}) (t *transferWriter, err error) { @@ -59,37 +80,15 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) } t.Method = valueOrDefault(rr.Method, "GET") - t.Body = rr.Body - t.BodyCloser = rr.Body - t.ContentLength = rr.ContentLength t.Close = rr.Close t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer - atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 { - if t.ContentLength == 0 { - // Test to see if it's actually zero or just unset. - var buf [1]byte - n, rerr := io.ReadFull(t.Body, buf[:]) - if rerr != nil && rerr != io.EOF { - t.ContentLength = -1 - t.Body = errorReader{rerr} - } else if n == 1 { - // Oh, guess there is data in this Body Reader after all. - // The ContentLength field just wasn't set. - // Stich the Body back together again, re-attaching our - // consumed byte. - t.ContentLength = -1 - t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body) - } else { - // Body is actually empty. - t.Body = nil - t.BodyCloser = nil - } - } - if t.ContentLength < 0 { - t.TransferEncoding = []string{"chunked"} - } + atLeastHTTP11 = rr.protoAtLeastOutgoing(1, 1) + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.outgoingLength() + if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && atLeastHTTP11 && t.shouldSendChunkedRequestBody() { + t.TransferEncoding = []string{"chunked"} } case *Response: t.IsResponse = true @@ -103,7 +102,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - t.ResponseToHEAD = noBodyExpected(t.Method) + t.ResponseToHEAD = noResponseBodyExpected(t.Method) } // Sanitize Body,ContentLength,TransferEncoding @@ -131,7 +130,100 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { return t, nil } -func noBodyExpected(requestMethod string) bool { +// shouldSendChunkedRequestBody reports whether we should try to send a +// chunked request body to the server. In particular, the case we really +// want to prevent is sending a GET or other typically-bodyless request to a +// server with a chunked body when the body has zero bytes, since GETs with +// bodies (while acceptable according to specs), even zero-byte chunked +// bodies, are approximately never seen in the wild and confuse most +// servers. See Issue 18257, as one example. +// +// The only reason we'd send such a request is if the user set the Body to a +// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't +// set ContentLength, or NewRequest set it to -1 (unknown), so then we assume +// there's bytes to send. +// +// This code tries to read a byte from the Request.Body in such cases to see +// whether the body actually has content (super rare) or is actually just +// a non-nil content-less ReadCloser (the more common case). In that more +// common case, we act as if their Body were nil instead, and don't send +// a body. +func (t *transferWriter) shouldSendChunkedRequestBody() bool { + // Note that t.ContentLength is the corrected content length + // from rr.outgoingLength, so 0 actually means zero, not unknown. + if t.ContentLength >= 0 || t.Body == nil { // redundant checks; caller did them + return false + } + if requestMethodUsuallyLacksBody(t.Method) { + // Only probe the Request.Body for GET/HEAD/DELETE/etc + // requests, because it's only those types of requests + // that confuse servers. + t.probeRequestBody() // adjusts t.Body, t.ContentLength + return t.Body != nil + } + // For all other request types (PUT, POST, PATCH, or anything + // made-up we've never heard of), assume it's normal and the server + // can deal with a chunked request body. Maybe we'll adjust this + // later. + return true +} + +// probeRequestBody reads a byte from t.Body to see whether it's empty +// (returns io.EOF right away). +// +// But because we've had problems with this blocking users in the past +// (issue 17480) when the body is a pipe (perhaps waiting on the response +// headers before the pipe is fed data), we need to be careful and bound how +// long we wait for it. This delay will only affect users if all the following +// are true: +// * the request body blocks +// * the content length is not set (or set to -1) +// * the method doesn't usually have a body (GET, HEAD, DELETE, ...) +// * there is no transfer-encoding=chunked already set. +// In other words, this delay will not normally affect anybody, and there +// are workarounds if it does. +func (t *transferWriter) probeRequestBody() { + t.ByteReadCh = make(chan readResult, 1) + go func(body io.Reader) { + var buf [1]byte + var rres readResult + rres.n, rres.err = body.Read(buf[:]) + if rres.n == 1 { + rres.b = buf[0] + } + t.ByteReadCh <- rres + }(t.Body) + timer := time.NewTimer(200 * time.Millisecond) + select { + case rres := <-t.ByteReadCh: + timer.Stop() + if rres.n == 0 && rres.err == io.EOF { + // It was empty. + t.Body = nil + t.ContentLength = 0 + } else if rres.n == 1 { + if rres.err != nil { + t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err}) + } else { + t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body) + } + } else if rres.err != nil { + t.Body = errorReader{rres.err} + } + case <-timer.C: + // Too slow. Don't wait. Read it later, and keep + // assuming that this is ContentLength == -1 + // (unknown), which means we'll send a + // "Transfer-Encoding: chunked" header. + t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body) + // Request that Request.Write flush the headers to the + // network before writing the body, since our body may not + // become readable until it's seen the response headers. + t.FlushHeaders = true + } +} + +func noResponseBodyExpected(requestMethod string) bool { return requestMethod == "HEAD" } @@ -214,7 +306,7 @@ func (t *transferWriter) WriteBody(w io.Writer) error { if t.Body != nil { if chunked(t.TransferEncoding) { if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { - w = &internal.FlushAfterChunkWriter{bw} + w = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(w) _, err = io.Copy(cw, t.Body) @@ -235,7 +327,9 @@ func (t *transferWriter) WriteBody(w io.Writer) error { if err != nil { return err } - if err = t.BodyCloser.Close(); err != nil { + } + if t.BodyCloser != nil { + if err := t.BodyCloser.Close(); err != nil { return err } } @@ -385,13 +479,13 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - if noBodyExpected(t.RequestMethod) { - t.Body = eofReader + if noResponseBodyExpected(t.RequestMethod) { + t.Body = NoBody } else { t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} } case realLength == 0: - t.Body = eofReader + t.Body = NoBody case realLength > 0: t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} default: @@ -401,7 +495,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{src: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = eofReader + t.Body = NoBody } } @@ -493,10 +587,31 @@ func (t *transferReader) fixTransferEncoding() error { // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) { - contentLens := header["Content-Length"] isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := strings.TrimSpace(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != strings.TrimSpace(ct) { + return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + // Logic based on response type or status - if noBodyExpected(requestMethod) { + if noResponseBodyExpected(requestMethod) { // For HTTP requests, as part of hardening against request // smuggling (RFC 7230), don't allow a Content-Length header for // methods which don't permit bodies. As an exception, allow @@ -514,11 +629,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - if len(contentLens) > 1 { - // harden against HTTP request smuggling. See RFC 7230. - return 0, errors.New("http: message cannot contain multiple Content-Length headers") - } - // Logic based on Transfer-Encoding if chunked(te) { return -1, nil @@ -539,7 +649,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, header.Del("Content-Length") } - if !isResponse { + if isRequest { // RFC 2616 neither explicitly permits nor forbids an // entity-body on a GET request so we permit one if // declared, but we default to 0 here (not -1 below) @@ -864,3 +974,21 @@ func parseContentLength(cl string) (int64, error) { return n, nil } + +// finishAsyncByteRead finishes reading the 1-byte sniff +// from the ContentLength==0, Body!=nil case. +type finishAsyncByteRead struct { + tw *transferWriter +} + +func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + rres := <-fr.tw.ByteReadCh + n, err = rres.n, rres.err + if n == 1 { + p[0] = rres.b + } + return +} diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 1f07634..571943d6 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -25,6 +25,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "time" "golang_org/x/net/lex/httplex" @@ -40,6 +41,7 @@ var DefaultTransport RoundTripper = &Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, + DualStack: true, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -66,8 +68,10 @@ const DefaultMaxIdleConnsPerHost = 2 // For high-level functionality, such as cookies and redirects, see Client. // // Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2 -// for HTTPS URLs, depending on whether the server supports HTTP/2. -// See the package docs for more about HTTP/2. +// for HTTPS URLs, depending on whether the server supports HTTP/2, +// and how the Transport is configured. The DefaultTransport supports HTTP/2. +// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2 +// and call ConfigureTransport. See the package docs for more about HTTP/2. type Transport struct { idleMu sync.Mutex wantIdle bool // user has requested to close all idle conns @@ -76,10 +80,10 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[*Request]func() + reqCanceler map[*Request]func(error) - altMu sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + altMu sync.Mutex // guards changing altProto only + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -111,7 +115,9 @@ type Transport struct { DialTLS func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with - // tls.Client. If nil, the default configuration is used. + // tls.Client. + // If nil, the default configuration is used. + // If non-nil, HTTP/2 support may not be enabled by default. TLSClientConfig *tls.Config // TLSHandshakeTimeout specifies the maximum amount of time waiting to @@ -156,7 +162,9 @@ type Transport struct { // ExpectContinueTimeout, if non-zero, specifies the amount of // time to wait for a server's first response headers after fully // writing the request headers if the request has an - // "Expect: 100-continue" header. Zero means no timeout. + // "Expect: 100-continue" header. Zero means no timeout and + // causes the body to be sent immediately, without + // waiting for the server to approve. // This time does not include the time to send the request header. ExpectContinueTimeout time.Duration @@ -168,9 +176,14 @@ type Transport struct { // called with the request's authority (such as "example.com" // or "example.com:1234") and the TLS connection. The function // must return a RoundTripper that then handles the request. - // If TLSNextProto is nil, HTTP/2 support is enabled automatically. + // If TLSNextProto is not nil, HTTP/2 support is not enabled + // automatically. TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper + // ProxyConnectHeader optionally specifies headers to send to + // proxies during CONNECT requests. + ProxyConnectHeader Header + // MaxResponseHeaderBytes specifies a limit on how many // response bytes are allowed in the server's response // header. @@ -330,11 +343,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } } } - // TODO(bradfitz): switch to atomic.Value for this map instead of RWMutex - t.altMu.RLock() - altRT := t.altProto[scheme] - t.altMu.RUnlock() - if altRT != nil { + + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + if altRT := altProto[scheme]; altRT != nil { if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { return resp, err } @@ -421,19 +432,15 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // our request (as opposed to sending an error). return false } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry. + return true + } if !req.isReplayable() { // Don't retry non-idempotent requests. - - // TODO: swap the nothingWrittenError and isReplayable checks, - // putting the "if nothingWrittenError => return true" case - // first, per golang.org/issue/15723 return false } - switch err.(type) { - case nothingWrittenError: - // We never wrote anything, so it's safe to retry. - return true - case transportReadFromServerError: + if _, ok := err.(transportReadFromServerError); ok { // We got some non-EOF net.Conn.Read failure reading // the 1st response byte from the server. return true @@ -463,13 +470,16 @@ var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { t.altMu.Lock() defer t.altMu.Unlock() - if t.altProto == nil { - t.altProto = make(map[string]RoundTripper) - } - if _, exists := t.altProto[scheme]; exists { + oldMap, _ := t.altProto.Load().(map[string]RoundTripper) + if _, exists := oldMap[scheme]; exists { panic("protocol " + scheme + " already registered") } - t.altProto[scheme] = rt + newMap := make(map[string]RoundTripper) + for k, v := range oldMap { + newMap[k] = v + } + newMap[scheme] = rt + t.altProto.Store(newMap) } // CloseIdleConnections closes any connections which were previously @@ -502,12 +512,17 @@ func (t *Transport) CloseIdleConnections() { // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *Request) { + t.cancelRequest(req, errRequestCanceled) +} + +// Cancel an in-flight request, recording the error value. +func (t *Transport) cancelRequest(req *Request, err error) { t.reqMu.Lock() cancel := t.reqCanceler[req] delete(t.reqCanceler, req) t.reqMu.Unlock() if cancel != nil { - cancel() + cancel(err) } } @@ -557,10 +572,18 @@ func (e *envOnce) reset() { } func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + if port := treq.URL.Port(); !validPort(port) { + return cm, fmt.Errorf("invalid URL port %q", port) + } cm.targetScheme = treq.URL.Scheme cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) + if err == nil && cm.proxyURL != nil { + if port := cm.proxyURL.Port(); !validPort(port) { + return cm, fmt.Errorf("invalid proxy URL port %q", port) + } + } } return cm, err } @@ -787,11 +810,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { } } -func (t *Transport) setReqCanceler(r *Request, fn func()) { +func (t *Transport) setReqCanceler(r *Request, fn func(error)) { t.reqMu.Lock() defer t.reqMu.Unlock() if t.reqCanceler == nil { - t.reqCanceler = make(map[*Request]func()) + t.reqCanceler = make(map[*Request]func(error)) } if fn != nil { t.reqCanceler[r] = fn @@ -804,7 +827,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) { // for the request, we don't set the function and return false. // Since CancelRequest will clear the canceler, we can use the return value to detect if // the request was canceled since the last setReqCancel call. -func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { +func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool { t.reqMu.Lock() defer t.reqMu.Unlock() _, ok := t.reqCanceler[r] @@ -853,7 +876,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC // set request canceler to some non-nil function so we // can detect whether it was cleared between now and when // we enter roundTrip - t.setReqCanceler(req, func() {}) + t.setReqCanceler(req, func(error) {}) return pc, nil } @@ -878,8 +901,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC }() } - cancelc := make(chan struct{}) - t.setReqCanceler(req, func() { close(cancelc) }) + cancelc := make(chan error, 1) + t.setReqCanceler(req, func(err error) { cancelc <- err }) go func() { pc, err := t.dialConn(ctx, cm) @@ -900,16 +923,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC // value. select { case <-req.Cancel: + // It was an error due to cancelation, so prioritize that + // error value. (Issue 16049) + return nil, errRequestCanceledConn case <-req.Context().Done(): - case <-cancelc: + return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err default: // It wasn't an error due to cancelation, so // return the original error message: return nil, v.err } - // It was an error due to cancelation, so prioritize that - // error value. (Issue 16049) - return nil, errRequestCanceledConn case pc := <-idleConnCh: // Another request finished first and its net.Conn // became available before our dial. Or somebody @@ -926,10 +954,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC return nil, errRequestCanceledConn case <-req.Context().Done(): handlePendingDial() - return nil, errRequestCanceledConn - case <-cancelc: + return nil, req.Context().Err() + case err := <-cancelc: handlePendingDial() - return nil, errRequestCanceledConn + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err } } @@ -943,6 +974,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon writeErrCh: make(chan error, 1), writeLoopDone: make(chan struct{}), } + trace := httptrace.ContextClientTrace(ctx) tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil if tlsDial { var err error @@ -956,18 +988,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if tc, ok := pconn.conn.(*tls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } if err := tc.Handshake(); err != nil { go pconn.conn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } return nil, err } cs := tc.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } pconn.tlsState = &cs } } else { conn, err := t.dial(ctx, "tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { - err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + // Return a typed error, per Issue 16997: + err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} } return nil, err } @@ -987,11 +1029,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } case cm.targetScheme == "https": conn := pconn.conn + hdr := t.ProxyConnectHeader + if hdr == nil { + hdr = make(Header) + } connectReq := &Request{ Method: "CONNECT", URL: &url.URL{Opaque: cm.targetAddr}, Host: cm.targetAddr, - Header: make(Header), + Header: hdr, } if pa := cm.proxyAuth(); pa != "" { connectReq.Header.Set("Proxy-Authorization", pa) @@ -1016,7 +1062,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if cm.targetScheme == "https" && !tlsDial { // Initiate TLS and check remote host name against certificate. - cfg := cloneTLSClientConfig(t.TLSClientConfig) + cfg := cloneTLSConfig(t.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = cm.tlsHost() } @@ -1030,6 +1076,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon }) } go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } err := tlsConn.Handshake() if timer != nil { timer.Stop() @@ -1038,6 +1087,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon }() if err := <-errc; err != nil { plainConn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } return nil, err } if !cfg.InsecureSkipVerify { @@ -1047,6 +1099,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } } cs := tlsConn.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } pconn.tlsState = &cs pconn.conn = tlsConn } @@ -1235,8 +1290,8 @@ type persistConn struct { mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled broken bool // an error has happened on this connection; marked broken so it's not reused. - canceled bool // whether this conn was broken due a CancelRequest reused bool // whether conn has had successful request/response and is being reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the @@ -1274,11 +1329,12 @@ func (pc *persistConn) isBroken() bool { return b } -// isCanceled reports whether this connection was closed due to CancelRequest. -func (pc *persistConn) isCanceled() bool { +// canceled returns non-nil if the connection was closed due to +// CancelRequest or due to context cancelation. +func (pc *persistConn) canceled() error { pc.mu.Lock() defer pc.mu.Unlock() - return pc.canceled + return pc.canceledErr } // isReused reports whether this connection is in a known broken state. @@ -1301,10 +1357,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn return } -func (pc *persistConn) cancelRequest() { +func (pc *persistConn) cancelRequest(err error) { pc.mu.Lock() defer pc.mu.Unlock() - pc.canceled = true + pc.canceledErr = err pc.closeLocked(errRequestCanceled) } @@ -1328,12 +1384,12 @@ func (pc *persistConn) closeConnIfStillIdle() { // // The startBytesWritten value should be the value of pc.nwrite before the roundTrip // started writing the request. -func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, err error) (out error) { +func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) { if err == nil { return nil } - if pc.isCanceled() { - return errRequestCanceled + if err := pc.canceled(); err != nil { + return err } if err == errServerClosedIdle { return err @@ -1343,7 +1399,7 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er } if pc.isBroken() { <-pc.writeLoopDone - if pc.nwrite == startBytesWritten { + if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { return nothingWrittenError{err} } } @@ -1354,9 +1410,9 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er // up to Transport.RoundTrip method when persistConn.roundTrip sees // its pc.closech channel close, indicating the persistConn is dead. // (after closech is closed, pc.closed is valid). -func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error { - if pc.isCanceled() { - return errRequestCanceled +func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error { + if err := pc.canceled(); err != nil { + return err } err := pc.closed if err == errServerClosedIdle { @@ -1372,7 +1428,7 @@ func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) err // see if we actually managed to write anything. If not, we // can retry the request. <-pc.writeLoopDone - if pc.nwrite == startBytesWritten { + if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { return nothingWrittenError{err} } @@ -1513,8 +1569,10 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- isEOF if isEOF { <-eofc // see comment above eofc declaration - } else if err != nil && pc.isCanceled() { - return errRequestCanceled + } else if err != nil { + if cerr := pc.canceled(); cerr != nil { + return cerr + } } return err }, @@ -1554,7 +1612,7 @@ func (pc *persistConn) readLoop() { pc.t.CancelRequest(rc.req) case <-rc.req.Context().Done(): alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.req, rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -1652,7 +1710,7 @@ func (pc *persistConn) writeLoop() { } if err != nil { wr.req.Request.closeBody() - if pc.nwrite == startBytesWritten { + if pc.nwrite == startBytesWritten && wr.req.outgoingLength() == 0 { err = nothingWrittenError{err} } } @@ -1840,8 +1898,8 @@ WaitResponse: select { case err := <-writeErrCh: if err != nil { - if pc.isCanceled() { - err = errRequestCanceled + if cerr := pc.canceled(); cerr != nil { + err = cerr } re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) @@ -1853,21 +1911,20 @@ WaitResponse: respHeaderTimer = timer.C } case <-pc.closech: - re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(startBytesWritten)} + re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)} break WaitResponse case <-respHeaderTimer: pc.close(errTimeout) re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: - re.err = pc.mapRoundTripErrorFromReadLoop(startBytesWritten, re.err) + re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err) break WaitResponse case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil - ctxDoneChan = nil case <-ctxDoneChan: - pc.t.CancelRequest(req.Request) + pc.t.cancelRequest(req.Request, req.Context().Err()) cancelChan = nil ctxDoneChan = nil } @@ -1931,11 +1988,15 @@ var portMap = map[string]string{ // canonicalAddr returns url.Host but always with a ":port" suffix func canonicalAddr(url *url.URL) string { - addr := url.Host - if !hasPort(addr) { - return addr + ":" + portMap[url.Scheme] + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v + } + port := url.Port() + if port == "" { + port = portMap[url.Scheme] } - return addr + return net.JoinHostPort(addr, port) } // bodyEOFSignal is used by the HTTP/1 transport when reading response @@ -2060,75 +2121,14 @@ type fakeLocker struct{} func (fakeLocker) Lock() {} func (fakeLocker) Unlock() {} -// cloneTLSConfig returns a shallow clone of the exported -// fields of cfg, ignoring the unexported sync.Once, which -// contains a mutex and must not be copied. -// -// The cfg must not be in active use by tls.Server, or else -// there can still be a race with tls.Server updating SessionTicketKey -// and our copying it, and also a race with the server setting -// SessionTicketsDisabled=false on failure to set the random -// ticket key. -// -// If cfg is nil, a new zero tls.Config is returned. +// clneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if +// cfg is nil. This is safe to call even if cfg is in active use by a TLS +// client or server. func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - SessionTicketsDisabled: cfg.SessionTicketsDisabled, - SessionTicketKey: cfg.SessionTicketKey, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, - Renegotiation: cfg.Renegotiation, - } -} - -// cloneTLSClientConfig is like cloneTLSConfig but omits -// the fields SessionTicketsDisabled and SessionTicketKey. -// This makes it safe to call cloneTLSClientConfig on a config -// in active use by a server. -func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, - Renegotiation: cfg.Renegotiation, - } + return cfg.Clone() } type connLRU struct { @@ -2169,3 +2169,15 @@ func (cl *connLRU) remove(pc *persistConn) { func (cl *connLRU) len() int { return len(cl.m) } + +// validPort reports whether p (without the colon) is a valid port in +// a URL, per RFC 3986 Section 3.2.3, which says the port may be +// empty, or only contain digits. +func validPort(p string) bool { + for _, r := range []byte(p) { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go index a05ca6e..3d24fc1 100644 --- a/libgo/go/net/http/transport_internal_test.go +++ b/libgo/go/net/http/transport_internal_test.go @@ -72,3 +72,70 @@ func newLocalListener(t *testing.T) net.Listener { } return ln } + +func dummyRequest(method string) *Request { + req, err := NewRequest(method, "http://fake.tld/", nil) + if err != nil { + panic(err) + } + return req +} + +func TestTransportShouldRetryRequest(t *testing.T) { + tests := []struct { + pc *persistConn + req *Request + + err error + want bool + }{ + 0: { + pc: &persistConn{reused: false}, + req: dummyRequest("POST"), + err: nothingWrittenError{}, + want: false, + }, + 1: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: nothingWrittenError{}, + want: true, + }, + 2: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: http2ErrNoCachedConn, + want: true, + }, + 3: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: errMissingHost, + want: false, + }, + 4: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: transportReadFromServerError{}, + want: false, + }, + 5: { + pc: &persistConn{reused: true}, + req: dummyRequest("GET"), + err: transportReadFromServerError{}, + want: true, + }, + 6: { + pc: &persistConn{reused: true}, + req: dummyRequest("GET"), + err: errServerClosedIdle, + want: true, + }, + } + for i, tt := range tests { + got := tt.pc.shouldRetryRequest(tt.req, tt.err) + if got != tt.want { + t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want) + } + } +} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 298682d..d5ddf6a 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -441,9 +441,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportRemovesDeadIdleConnections(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/15464") - } + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) @@ -700,6 +698,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { + setParallel(t) defer afterTest(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -758,6 +757,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { + setParallel(t) defer afterTest(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 @@ -856,6 +856,7 @@ func TestTransportGzip(t *testing.T) { // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. func TestTransportExpect100Continue(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -966,6 +967,48 @@ func TestTransportProxy(t *testing.T) { } } +// Issue 16997: test transport dial preserves typed errors +func TestTransportDialPreservesNetOpProxyError(t *testing.T) { + defer afterTest(t) + + var errDial = errors.New("some dial error") + + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { + return url.Parse("http://proxy.fake.tld/") + }, + Dial: func(string, string) (net.Conn, error) { + return nil, errDial + }, + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + req, _ := NewRequest("GET", "http://fake.tld", nil) + res, err := c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("wanted a non-nil error") + } + + uerr, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %T, want *url.Error", err) + } + oe, ok := uerr.Err.(*net.OpError) + if !ok { + t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) + } + want := &net.OpError{ + Op: "proxyconnect", + Net: "tcp", + Err: errDial, // original error, unwrapped. + } + if !reflect.DeepEqual(oe, want) { + t.Errorf("Got error %#v; want %#v", oe, want) + } +} + // TestTransportGzipRecursive sends a gzip quine and checks that the // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that @@ -1038,10 +1081,12 @@ func waitNumGoroutine(nmax int) int { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { - setParallel(t) + // Not parallel: counts goroutines defer afterTest(t) - gotReqCh := make(chan bool) - unblockCh := make(chan bool) + + const numReq = 25 + gotReqCh := make(chan bool, numReq) + unblockCh := make(chan bool, numReq) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh @@ -1055,14 +1100,15 @@ func TestTransportPersistConnLeak(t *testing.T) { n0 := runtime.NumGoroutine() - const numReq = 25 - didReqCh := make(chan bool) + didReqCh := make(chan bool, numReq) + failed := make(chan bool, numReq) for i := 0; i < numReq; i++ { go func() { res, err := c.Get(ts.URL) didReqCh <- true if err != nil { t.Errorf("client fetch error: %v", err) + failed <- true return } res.Body.Close() @@ -1071,7 +1117,13 @@ func TestTransportPersistConnLeak(t *testing.T) { // Wait for all goroutines to be stuck in the Handler. for i := 0; i < numReq; i++ { - <-gotReqCh + select { + case <-gotReqCh: + // ok + case <-failed: + close(unblockCh) + return + } } nhigh := runtime.NumGoroutine() @@ -1102,7 +1154,7 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { - setParallel(t) + // Not parallel: measures goroutines. defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) @@ -1198,6 +1250,7 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { + setParallel(t) defer afterTest(t) const deniedMsg = "sorry, denied." ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1246,6 +1299,7 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { + // Not parallel: uses global test hooks. defer afterTest(t) maxProcs, numReqs := 16, 500 if testing.Short() { @@ -1306,9 +1360,7 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) const debug = false mux := NewServeMux() @@ -1370,9 +1422,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) const debug = false mux := NewServeMux() @@ -1696,12 +1746,6 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { defer ts.Close() defer close(unblockc) - // Don't interfere with the next test on plan9. - // Cf. https://golang.org/issues/11476 - if runtime.GOOS == "plan9" { - defer time.Sleep(500 * time.Millisecond) - } - tr := &Transport{} defer tr.CloseIdleConnections() c := &Client{Transport: tr} @@ -1718,8 +1762,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { } _, err := c.Do(req) - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancelation", err) + if ue, ok := err.(*url.Error); ok { + err = ue.Err + } + if withCtx { + if err != context.Canceled { + t.Errorf("Do error = %v; want %v", err, context.Canceled) + } + } else { + if err == nil || !strings.Contains(err.Error(), "canceled") { + t.Errorf("Do error = %v; want cancelation", err) + } } } @@ -1888,6 +1941,7 @@ func TestTransportEmptyMethod(t *testing.T) { } func TestTransportSocketLateBinding(t *testing.T) { + setParallel(t) defer afterTest(t) mux := NewServeMux() @@ -2152,6 +2206,7 @@ func TestProxyFromEnvironment(t *testing.T) { } func TestIdleConnChannelLeak(t *testing.T) { + // Not parallel: uses global test hooks. var mu sync.Mutex var n int @@ -2383,6 +2438,7 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) { // questionable state. // golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + setParallel(t) defer afterTest(t) var sconn struct { sync.Mutex @@ -2485,22 +2541,6 @@ type errorReader struct { func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } -type plan9SleepReader struct{} - -func (plan9SleepReader) Read(p []byte) (int, error) { - if runtime.GOOS == "plan9" { - // After the fix to unblock TCP Reads in - // https://golang.org/cl/15941, this sleep is required - // on plan9 to make sure TCP Writes before an - // immediate TCP close go out on the wire. On Plan 9, - // it seems that a hangup of a TCP connection with - // queued data doesn't send the queued data first. - // https://golang.org/issue/9554 - time.Sleep(50 * time.Millisecond) - } - return 0, io.EOF -} - type closerFunc func() error func (f closerFunc) Close() error { return f() } @@ -2595,7 +2635,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { io.Reader io.Closer }{ - io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), plan9SleepReader{}, errorReader{fakeErr}), + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), closerFunc(func() error { select { case didClose <- true: @@ -2627,6 +2667,8 @@ func TestTransportClosesBodyOnError(t *testing.T) { } func TestTransportDialTLS(t *testing.T) { + setParallel(t) + defer afterTest(t) var mu sync.Mutex // guards following var gotReq, didDial bool @@ -2904,14 +2946,8 @@ func TestTransportFlushesBodyChunks(t *testing.T) { defer res.Body.Close() want := []string{ - // Because Request.ContentLength = 0, the body is sniffed for 1 byte to determine whether there's content. - // That explains the initial "num0" being split into "n" and "um0". - // The first byte is included with the request headers Write. Perhaps in the future - // we will want to flush the headers out early if the first byte of the request body is - // taking a long time to arrive. But not yet. "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n" + - "1\r\nn\r\n", - "4\r\num0\n\r\n", + "5\r\nnum0\n\r\n", "5\r\nnum1\n\r\n", "5\r\nnum2\n\r\n", "0\r\n\r\n", @@ -3150,6 +3186,7 @@ func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { // Make sure we re-use underlying TCP connection for gzipped responses too. func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { + setParallel(t) defer afterTest(t) addr := make(chan string, 2) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -3185,6 +3222,7 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { } func TestTransportResponseHeaderLength(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { @@ -3248,7 +3286,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { cst.tr.ExpectContinueTimeout = 1 * time.Second - var mu sync.Mutex + var mu sync.Mutex // guards buf var buf bytes.Buffer logf := func(format string, args ...interface{}) { mu.Lock() @@ -3290,10 +3328,16 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { Wait100Continue: func() { logf("Wait100Continue") }, Got100Continue: func() { logf("Got100Continue") }, WroteRequest: func(e httptrace.WroteRequestInfo) { - close(gotWroteReqEvent) logf("WroteRequest: %+v", e) + close(gotWroteReqEvent) }, } + if h2 { + trace.TLSHandshakeStart = func() { logf("tls handshake start") } + trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { + logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) + } + } if noHooks { // zero out all func pointers, trying to get some path to crash *trace = httptrace.ClientTrace{} @@ -3323,7 +3367,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { return } + mu.Lock() got := buf.String() + mu.Unlock() + wantOnce := func(sub string) { if strings.Count(got, sub) != 1 { t.Errorf("expected substring %q exactly once in output.", sub) @@ -3342,7 +3389,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { wantOnceOrMore("connected to tcp " + addrStr + " = <nil>") wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("first response byte") - if !h2 { + if h2 { + wantOnce("tls handshake start") + wantOnce("tls handshake done") + } else { wantOnce("PutIdleConn = <nil>") } wantOnce("Wait100Continue") @@ -3357,12 +3407,21 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } func TestTransportEventTraceRealDNS(t *testing.T) { + if testing.Short() && testenv.Builder() == "" { + // Skip this test in short mode (the default for + // all.bash), in case the user is using a shady/ISP + // DNS server hijacking queries. + // See issues 16732, 16716. + // Our builders use 8.8.8.8, though, which correctly + // returns NXDOMAIN, so still run this test there. + t.Skip("skipping in short mode") + } defer afterTest(t) tr := &Transport{} defer tr.CloseIdleConnections() c := &Client{Transport: tr} - var mu sync.Mutex + var mu sync.Mutex // guards buf var buf bytes.Buffer logf := func(format string, args ...interface{}) { mu.Lock() @@ -3386,7 +3445,10 @@ func TestTransportEventTraceRealDNS(t *testing.T) { t.Fatal("expected error during DNS lookup") } + mu.Lock() got := buf.String() + mu.Unlock() + wantSub := func(sub string) { if !strings.Contains(got, sub) { t.Errorf("expected substring %q in output.", sub) @@ -3402,6 +3464,73 @@ func TestTransportEventTraceRealDNS(t *testing.T) { } } +// Issue 14353: port can only contain digits. +func TestTransportRejectsAlphaPort(t *testing.T) { + res, err := Get("http://dummy.tld:123foo/bar") + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %#v; want *url.Error", err) + } + got := ue.Err.Error() + want := `invalid URL port "123foo"` + if got != want { + t.Errorf("got error %q; want %q", got, want) + } +} + +// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 +// connections. The http2 test is done in TestTransportEventTrace_h2 +func TestTLSHandshakeTrace(t *testing.T) { + defer afterTest(t) + s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer s.Close() + + var mu sync.Mutex + var start, done bool + trace := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { + mu.Lock() + defer mu.Unlock() + start = true + }, + TLSHandshakeDone: func(s tls.ConnectionState, err error) { + mu.Lock() + defer mu.Unlock() + done = true + if err != nil { + t.Fatal("Expected error to be nil but was:", err) + } + }, + } + + tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + req, err := NewRequest("GET", s.URL, nil) + if err != nil { + t.Fatal("Unable to construct test request:", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + r, err := c.Do(req) + if err != nil { + t.Fatal("Unexpected error making request:", err) + } + r.Body.Close() + mu.Lock() + defer mu.Unlock() + if !start { + t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") + } + if !done { + t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") + } +} + func TestTransportMaxIdleConns(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -3457,27 +3586,36 @@ func TestTransportMaxIdleConns(t *testing.T) { } } -func TestTransportIdleConnTimeout(t *testing.T) { +func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } +func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } +func testTransportIdleConnTimeout(t *testing.T, h2 bool) { if testing.Short() { t.Skip("skipping in short mode") } defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + const timeout = 1 * time.Second + + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. })) - defer ts.Close() - - const timeout = 1 * time.Second - tr := &Transport{ - IdleConnTimeout: timeout, - } + defer cst.close() + tr := cst.tr + tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} + idleConns := func() []string { + if h2 { + return tr.IdleConnStrsForTesting_h2() + } else { + return tr.IdleConnStrsForTesting() + } + } + var conn string doReq := func(n int) { - req, _ := NewRequest("GET", ts.URL, nil) + req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ PutIdleConn: func(err error) { if err != nil { @@ -3490,7 +3628,7 @@ func TestTransportIdleConnTimeout(t *testing.T) { t.Fatal(err) } res.Body.Close() - conns := tr.IdleConnStrsForTesting() + conns := idleConns() if len(conns) != 1 { t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) } @@ -3506,7 +3644,7 @@ func TestTransportIdleConnTimeout(t *testing.T) { time.Sleep(timeout / 2) } time.Sleep(timeout * 3 / 2) - if got := tr.IdleConnStrsForTesting(); len(got) != 0 { + if got := idleConns(); len(got) != 0 { t.Errorf("idle conns = %q; want none", got) } } @@ -3523,6 +3661,7 @@ func TestTransportIdleConnTimeout(t *testing.T) { // know the successful tls.Dial from DialTLS will need to go into the // idle pool. Then we give it a of time to explode. func TestIdleConnH2Crash(t *testing.T) { + setParallel(t) cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { // nothing })) @@ -3531,12 +3670,12 @@ func TestIdleConnH2Crash(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - gotErr := make(chan bool, 1) + sawDoErr := make(chan bool, 1) + testDone := make(chan struct{}) + defer close(testDone) cst.tr.IdleConnTimeout = 5 * time.Millisecond cst.tr.DialTLS = func(network, addr string) (net.Conn, error) { - cancel() - <-gotErr c, err := tls.Dial(network, addr, &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"h2"}, @@ -3550,6 +3689,17 @@ func TestIdleConnH2Crash(t *testing.T) { c.Close() return nil, errors.New("bogus") } + + cancel() + + failTimer := time.NewTimer(5 * time.Second) + defer failTimer.Stop() + select { + case <-sawDoErr: + case <-testDone: + case <-failTimer.C: + t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") + } return c, nil } @@ -3560,7 +3710,7 @@ func TestIdleConnH2Crash(t *testing.T) { res.Body.Close() t.Fatal("unexpected success") } - gotErr <- true + sawDoErr <- true // Wait for the explosion. time.Sleep(cst.tr.IdleConnTimeout * 10) @@ -3605,6 +3755,122 @@ func TestTransportReturnsPeekError(t *testing.T) { } } +// Issue 13835: international domain names should work +func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) } +func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) } +func testTransportIDNA(t *testing.T, h2 bool) { + defer afterTest(t) + + const uniDomain = "гофер.го" + const punyDomain = "xn--c1ae0ajs.xn--c1aw" + + var port string + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + want := punyDomain + ":" + port + if r.Host != want { + t.Errorf("Host header = %q; want %q", r.Host, want) + } + if h2 { + if r.TLS == nil { + t.Errorf("r.TLS == nil") + } else if r.TLS.ServerName != punyDomain { + t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) + } + } + w.Header().Set("Hit-Handler", "1") + })) + defer cst.close() + + ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + // Install a fake DNS server. + ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { + if host != punyDomain { + t.Errorf("got DNS host lookup for %q; want %q", host, punyDomain) + return nil, nil + } + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { + want := net.JoinHostPort(punyDomain, port) + if hostPort != want { + t.Errorf("getting conn for %q; want %q", hostPort, want) + } + }, + DNSStart: func(e httptrace.DNSStartInfo) { + if e.Host != punyDomain { + t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.Header.Get("Hit-Handler") != "1" { + out, err := httputil.DumpResponse(res, true) + if err != nil { + t.Fatal(err) + } + t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out) + } +} + +// Issue 13290: send User-Agent in proxy CONNECT +func TestTransportProxyConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *Request, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + tr := &Transport{ + ProxyConnectHeader: Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + }, + Proxy: func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + }, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() diff --git a/libgo/go/net/interface.go b/libgo/go/net/interface.go index 52b857c..b3297f2 100644 --- a/libgo/go/net/interface.go +++ b/libgo/go/net/interface.go @@ -10,6 +10,12 @@ import ( "time" ) +// BUG(mikio): On NaCl, methods and functions related to +// Interface are not implemented. + +// BUG(mikio): On DragonFly BSD, NetBSD, OpenBSD, Plan 9 and Solaris, +// the MulticastAddrs method of Interface is not implemented. + var ( errInvalidInterface = errors.New("invalid network interface") errInvalidInterfaceIndex = errors.New("invalid network interface index") @@ -63,7 +69,8 @@ func (f Flags) String() string { return s } -// Addrs returns interface addresses for a specific interface. +// Addrs returns a list of unicast interface addresses for a specific +// interface. func (ifi *Interface) Addrs() ([]Addr, error) { if ifi == nil { return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface} @@ -75,8 +82,8 @@ func (ifi *Interface) Addrs() ([]Addr, error) { return ifat, err } -// MulticastAddrs returns multicast, joined group addresses for -// a specific interface. +// MulticastAddrs returns a list of multicast, joined group addresses +// for a specific interface. func (ifi *Interface) MulticastAddrs() ([]Addr, error) { if ifi == nil { return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface} @@ -100,8 +107,11 @@ func Interfaces() ([]Interface, error) { return ift, nil } -// InterfaceAddrs returns a list of the system's network interface +// InterfaceAddrs returns a list of the system's unicast interface // addresses. +// +// The returned list does not identify the associated interface; use +// Interfaces and Interface.Addrs for more detail. func InterfaceAddrs() ([]Addr, error) { ifat, err := interfaceAddrTable(nil) if err != nil { @@ -111,6 +121,10 @@ func InterfaceAddrs() ([]Addr, error) { } // InterfaceByIndex returns the interface specified by index. +// +// On Solaris, it returns one of the logical network interfaces +// sharing the logical data link; for more precision use +// InterfaceByName. func InterfaceByIndex(index int) (*Interface, error) { if index <= 0 { return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex} @@ -158,6 +172,9 @@ func InterfaceByName(name string) (*Interface, error) { // An ipv6ZoneCache represents a cache holding partial network // interface information. It is used for reducing the cost of IPv6 // addressing scope zone resolution. +// +// Multiple names sharing the index are managed by first-come +// first-served basis for consistency. type ipv6ZoneCache struct { sync.RWMutex // guard the following lastFetched time.Time // last time routing information was fetched @@ -188,7 +205,9 @@ func (zc *ipv6ZoneCache) update(ift []Interface) { zc.toName = make(map[int]string, len(ift)) for _, ifi := range ift { zc.toIndex[ifi.Name] = ifi.Index - zc.toName[ifi.Index] = ifi.Name + if _, ok := zc.toName[ifi.Index]; !ok { + zc.toName[ifi.Index] = ifi.Name + } } } @@ -215,7 +234,7 @@ func zoneToInt(zone string) int { defer zoneCache.RUnlock() index, ok := zoneCache.toIndex[zone] if !ok { - index, _, _ = dtoi(zone, 0) + index, _, _ = dtoi(zone) } return index } diff --git a/libgo/go/net/interface_plan9.go b/libgo/go/net/interface_plan9.go new file mode 100644 index 0000000..e5d7739 --- /dev/null +++ b/libgo/go/net/interface_plan9.go @@ -0,0 +1,198 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "errors" + "os" +) + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otherwise it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]Interface, error) { + if ifindex == 0 { + n, err := interfaceCount() + if err != nil { + return nil, err + } + ifcs := make([]Interface, n) + for i := range ifcs { + ifc, err := readInterface(i) + if err != nil { + return nil, err + } + ifcs[i] = *ifc + } + return ifcs, nil + } + + ifc, err := readInterface(ifindex - 1) + if err != nil { + return nil, err + } + return []Interface{*ifc}, nil +} + +func readInterface(i int) (*Interface, error) { + ifc := &Interface{ + Index: i + 1, // Offset the index by one to suit the contract + Name: netdir + "/ipifc/" + itoa(i), // Name is the full path to the interface path in plan9 + } + + ifcstat := ifc.Name + "/status" + ifcstatf, err := open(ifcstat) + if err != nil { + return nil, err + } + defer ifcstatf.close() + + line, ok := ifcstatf.readLine() + if !ok { + return nil, errors.New("invalid interface status file: " + ifcstat) + } + + fields := getFields(line) + if len(fields) < 4 { + return nil, errors.New("invalid interface status file: " + ifcstat) + } + + device := fields[1] + mtustr := fields[3] + + mtu, _, ok := dtoi(mtustr) + if !ok { + return nil, errors.New("invalid status file of interface: " + ifcstat) + } + ifc.MTU = mtu + + // Not a loopback device + if device != "/dev/null" { + deviceaddrf, err := open(device + "/addr") + if err != nil { + return nil, err + } + defer deviceaddrf.close() + + line, ok = deviceaddrf.readLine() + if !ok { + return nil, errors.New("invalid address file for interface: " + device + "/addr") + } + + if len(line) > 0 && len(line)%2 == 0 { + ifc.HardwareAddr = make([]byte, len(line)/2) + var ok bool + for i := range ifc.HardwareAddr { + j := (i + 1) * 2 + ifc.HardwareAddr[i], ok = xtoi2(line[i*2:j], 0) + if !ok { + ifc.HardwareAddr = ifc.HardwareAddr[:i] + break + } + } + } + + ifc.Flags = FlagUp | FlagBroadcast | FlagMulticast + } else { + ifc.Flags = FlagUp | FlagMulticast | FlagLoopback + } + + return ifc, nil +} + +func interfaceCount() (int, error) { + d, err := os.Open(netdir + "/ipifc") + if err != nil { + return -1, err + } + defer d.Close() + + names, err := d.Readdirnames(0) + if err != nil { + return -1, err + } + + // Assumes that numbered files in ipifc are strictly + // the incrementing numbered directories for the + // interfaces + c := 0 + for _, name := range names { + if _, _, ok := dtoi(name); !ok { + continue + } + c++ + } + + return c, nil +} + +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { + var ifcs []Interface + if ifi == nil { + var err error + ifcs, err = interfaceTable(0) + if err != nil { + return nil, err + } + } else { + ifcs = []Interface{*ifi} + } + + addrs := make([]Addr, len(ifcs)) + for i, ifc := range ifcs { + status := ifc.Name + "/status" + statusf, err := open(status) + if err != nil { + return nil, err + } + defer statusf.close() + + line, ok := statusf.readLine() + line, ok = statusf.readLine() + if !ok { + return nil, errors.New("cannot parse IP address for interface: " + status) + } + + // This assumes only a single address for the interface. + fields := getFields(line) + if len(fields) < 1 { + return nil, errors.New("cannot parse IP address for interface: " + status) + } + addr := fields[0] + ip := ParseIP(addr) + if ip == nil { + return nil, errors.New("cannot parse IP address for interface: " + status) + } + + // The mask is represented as CIDR relative to the IPv6 address. + // Plan 9 internal representation is always IPv6. + maskfld := fields[1] + maskfld = maskfld[1:] + pfxlen, _, ok := dtoi(maskfld) + if !ok { + return nil, errors.New("cannot parse network mask for interface: " + status) + } + var mask IPMask + if ip.To4() != nil { // IPv4 or IPv6 IPv4-mapped address + mask = CIDRMask(pfxlen-8*len(v4InV6Prefix), 8*IPv4len) + } + if ip.To16() != nil && ip.To4() == nil { // IPv6 address + mask = CIDRMask(pfxlen, 8*IPv6len) + } + + addrs[i] = &IPNet{IP: ip, Mask: mask} + } + + return addrs, nil +} + +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + return nil, nil +} diff --git a/libgo/go/net/interface_solaris.go b/libgo/go/net/interface_solaris.go new file mode 100644 index 0000000..dc8ffbf --- /dev/null +++ b/libgo/go/net/interface_solaris.go @@ -0,0 +1,107 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "syscall" + + "golang_org/x/net/lif" +) + +// If the ifindex is zero, interfaceTable returns mappings of all +// network interfaces. Otherwise it returns a mapping of a specific +// interface. +func interfaceTable(ifindex int) ([]Interface, error) { + lls, err := lif.Links(syscall.AF_UNSPEC, "") + if err != nil { + return nil, err + } + var ift []Interface + for _, ll := range lls { + if ifindex != 0 && ifindex != ll.Index { + continue + } + ifi := Interface{Index: ll.Index, MTU: ll.MTU, Name: ll.Name, Flags: linkFlags(ll.Flags)} + if len(ll.Addr) > 0 { + ifi.HardwareAddr = HardwareAddr(ll.Addr) + } + ift = append(ift, ifi) + } + return ift, nil +} + +const ( + sysIFF_UP = 0x1 + sysIFF_BROADCAST = 0x2 + sysIFF_DEBUG = 0x4 + sysIFF_LOOPBACK = 0x8 + sysIFF_POINTOPOINT = 0x10 + sysIFF_NOTRAILERS = 0x20 + sysIFF_RUNNING = 0x40 + sysIFF_NOARP = 0x80 + sysIFF_PROMISC = 0x100 + sysIFF_ALLMULTI = 0x200 + sysIFF_INTELLIGENT = 0x400 + sysIFF_MULTICAST = 0x800 + sysIFF_MULTI_BCAST = 0x1000 + sysIFF_UNNUMBERED = 0x2000 + sysIFF_PRIVATE = 0x8000 +) + +func linkFlags(rawFlags int) Flags { + var f Flags + if rawFlags&sysIFF_UP != 0 { + f |= FlagUp + } + if rawFlags&sysIFF_BROADCAST != 0 { + f |= FlagBroadcast + } + if rawFlags&sysIFF_LOOPBACK != 0 { + f |= FlagLoopback + } + if rawFlags&sysIFF_POINTOPOINT != 0 { + f |= FlagPointToPoint + } + if rawFlags&sysIFF_MULTICAST != 0 { + f |= FlagMulticast + } + return f +} + +// If the ifi is nil, interfaceAddrTable returns addresses for all +// network interfaces. Otherwise it returns addresses for a specific +// interface. +func interfaceAddrTable(ifi *Interface) ([]Addr, error) { + var name string + if ifi != nil { + name = ifi.Name + } + as, err := lif.Addrs(syscall.AF_UNSPEC, name) + if err != nil { + return nil, err + } + var ifat []Addr + for _, a := range as { + var ip IP + var mask IPMask + switch a := a.(type) { + case *lif.Inet4Addr: + ip = IPv4(a.IP[0], a.IP[1], a.IP[2], a.IP[3]) + mask = CIDRMask(a.PrefixLen, 8*IPv4len) + case *lif.Inet6Addr: + ip = make(IP, IPv6len) + copy(ip, a.IP[:]) + mask = CIDRMask(a.PrefixLen, 8*IPv6len) + } + ifat = append(ifat, &IPNet{IP: ip, Mask: mask}) + } + return ifat, nil +} + +// interfaceMulticastAddrTable returns addresses for a specific +// interface. +func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) { + return nil, nil +} diff --git a/libgo/go/net/interface_stub.go b/libgo/go/net/interface_stub.go index f64174c..3b0a1ae 100644 --- a/libgo/go/net/interface_stub.go +++ b/libgo/go/net/interface_stub.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build nacl plan9 solaris +// +build nacl package net diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go index 4c695b9..38a2ca4 100644 --- a/libgo/go/net/interface_test.go +++ b/libgo/go/net/interface_test.go @@ -58,8 +58,15 @@ func TestInterfaces(t *testing.T) { if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(ifxi, &ifi) { - t.Errorf("got %v; want %v", ifxi, ifi) + switch runtime.GOOS { + case "solaris": + if ifxi.Index != ifi.Index { + t.Errorf("got %v; want %v", ifxi, ifi) + } + default: + if !reflect.DeepEqual(ifxi, &ifi) { + t.Errorf("got %v; want %v", ifxi, ifi) + } } ifxn, err := InterfaceByName(ifi.Name) if err != nil { diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go index d0c8263..db3364c 100644 --- a/libgo/go/net/ip.go +++ b/libgo/go/net/ip.go @@ -90,7 +90,7 @@ func CIDRMask(ones, bits int) IPMask { // Well-known IPv4 addresses var ( - IPv4bcast = IPv4(255, 255, 255, 255) // broadcast + IPv4bcast = IPv4(255, 255, 255, 255) // limited broadcast IPv4allsys = IPv4(224, 0, 0, 1) // all systems IPv4allrouter = IPv4(224, 0, 0, 2) // all routers IPv4zero = IPv4(0, 0, 0, 0) // all zeros @@ -153,6 +153,12 @@ func (ip IP) IsLinkLocalUnicast() bool { // IsGlobalUnicast reports whether ip is a global unicast // address. +// +// The identification of global unicast addresses uses address type +// identification as defined in RFC 1122, RFC 4632 and RFC 4291 with +// the exception of IPv4 directed broadcast addresses. +// It returns true even if ip is in IPv4 private address space or +// local IPv6 unicast address space. func (ip IP) IsGlobalUnicast() bool { return (len(ip) == IPv4len || len(ip) == IPv6len) && !ip.Equal(IPv4bcast) && @@ -504,29 +510,25 @@ func (n *IPNet) String() string { // Parse IPv4 address (d.d.d.d). func parseIPv4(s string) IP { var p [IPv4len]byte - i := 0 - for j := 0; j < IPv4len; j++ { - if i >= len(s) { + for i := 0; i < IPv4len; i++ { + if len(s) == 0 { // Missing octets. return nil } - if j > 0 { - if s[i] != '.' { + if i > 0 { + if s[0] != '.' { return nil } - i++ + s = s[1:] } - var ( - n int - ok bool - ) - n, i, ok = dtoi(s, i) + n, c, ok := dtoi(s) if !ok || n > 0xFF { return nil } - p[j] = byte(n) + s = s[c:] + p[i] = byte(n) } - if i != len(s) { + if len(s) != 0 { return nil } return IPv4(p[0], p[1], p[2], p[3]) @@ -538,8 +540,7 @@ func parseIPv4(s string) IP { // true. func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) { ip = make(IP, IPv6len) - ellipsis := -1 // position of ellipsis in p - i := 0 // index in string s + ellipsis := -1 // position of ellipsis in ip if zoneAllowed { s, zone = splitHostZone(s) @@ -548,90 +549,91 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) { // Might have leading ellipsis if len(s) >= 2 && s[0] == ':' && s[1] == ':' { ellipsis = 0 - i = 2 + s = s[2:] // Might be only ellipsis - if i == len(s) { + if len(s) == 0 { return ip, zone } } // Loop, parsing hex numbers followed by colon. - j := 0 - for j < IPv6len { + i := 0 + for i < IPv6len { // Hex number. - n, i1, ok := xtoi(s, i) + n, c, ok := xtoi(s) if !ok || n > 0xFFFF { return nil, zone } // If followed by dot, might be in trailing IPv4. - if i1 < len(s) && s[i1] == '.' { - if ellipsis < 0 && j != IPv6len-IPv4len { + if c < len(s) && s[c] == '.' { + if ellipsis < 0 && i != IPv6len-IPv4len { // Not the right place. return nil, zone } - if j+IPv4len > IPv6len { + if i+IPv4len > IPv6len { // Not enough room. return nil, zone } - ip4 := parseIPv4(s[i:]) + ip4 := parseIPv4(s) if ip4 == nil { return nil, zone } - ip[j] = ip4[12] - ip[j+1] = ip4[13] - ip[j+2] = ip4[14] - ip[j+3] = ip4[15] - i = len(s) - j += IPv4len + ip[i] = ip4[12] + ip[i+1] = ip4[13] + ip[i+2] = ip4[14] + ip[i+3] = ip4[15] + s = "" + i += IPv4len break } // Save this 16-bit chunk. - ip[j] = byte(n >> 8) - ip[j+1] = byte(n) - j += 2 + ip[i] = byte(n >> 8) + ip[i+1] = byte(n) + i += 2 // Stop at end of string. - i = i1 - if i == len(s) { + s = s[c:] + if len(s) == 0 { break } // Otherwise must be followed by colon and more. - if s[i] != ':' || i+1 == len(s) { + if s[0] != ':' || len(s) == 1 { return nil, zone } - i++ + s = s[1:] // Look for ellipsis. - if s[i] == ':' { + if s[0] == ':' { if ellipsis >= 0 { // already have one return nil, zone } - ellipsis = j - if i++; i == len(s) { // can be at end + ellipsis = i + s = s[1:] + if len(s) == 0 { // can be at end break } } } // Must have used entire string. - if i != len(s) { + if len(s) != 0 { return nil, zone } // If didn't parse enough, expand ellipsis. - if j < IPv6len { + if i < IPv6len { if ellipsis < 0 { return nil, zone } - n := IPv6len - j - for k := j - 1; k >= ellipsis; k-- { - ip[k+n] = ip[k] + n := IPv6len - i + for j := i - 1; j >= ellipsis; j-- { + ip[j+n] = ip[j] } - for k := ellipsis + n - 1; k >= ellipsis; k-- { - ip[k] = 0 + for j := ellipsis + n - 1; j >= ellipsis; j-- { + ip[j] = 0 } } else if ellipsis >= 0 { // Ellipsis must represent at least one 0 group. @@ -658,13 +660,14 @@ func ParseIP(s string) IP { return nil } -// ParseCIDR parses s as a CIDR notation IP address and mask, +// ParseCIDR parses s as a CIDR notation IP address and prefix length, // like "192.0.2.0/24" or "2001:db8::/32", as defined in // RFC 4632 and RFC 4291. // -// It returns the IP address and the network implied by the IP -// and mask. For example, ParseCIDR("198.51.100.1/24") returns -// the IP address 198.51.100.1 and the network 198.51.100.0/24. +// It returns the IP address and the network implied by the IP and +// prefix length. +// For example, ParseCIDR("192.0.2.1/24") returns the IP address +// 198.0.2.1 and the network 198.0.2.0/24. func ParseCIDR(s string) (IP, *IPNet, error) { i := byteIndex(s, '/') if i < 0 { @@ -677,7 +680,7 @@ func ParseCIDR(s string) (IP, *IPNet, error) { iplen = IPv6len ip, _ = parseIPv6(addr, false) } - n, i, ok := dtoi(mask, 0) + n, i, ok := dtoi(mask) if ip == nil || !ok || i != len(mask) || n < 0 || n > 8*iplen { return nil, nil, &ParseError{Type: "CIDR address", Text: s} } diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go index b6ac26d..4655163 100644 --- a/libgo/go/net/ip_test.go +++ b/libgo/go/net/ip_test.go @@ -28,6 +28,10 @@ var parseIPTests = []struct { {"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}}, {"2001:4860:0000:2001:0000:0000:0000:0068", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}}, + {"-0.0.0.0", nil}, + {"0.-1.0.0", nil}, + {"0.0.-2.0", nil}, + {"0.0.0.-3", nil}, {"127.0.0.256", nil}, {"abc", nil}, {"123:", nil}, @@ -242,13 +246,15 @@ func TestIPString(t *testing.T) { } } +var sink string + func BenchmarkIPString(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) for i := 0; i < b.N; i++ { for _, tt := range ipStringTests { if tt.in != nil { - tt.in.String() + sink = tt.in.String() } } } @@ -299,7 +305,7 @@ func BenchmarkIPMaskString(b *testing.B) { for i := 0; i < b.N; i++ { for _, tt := range ipMaskStringTests { - tt.in.String() + sink = tt.in.String() } } } @@ -330,6 +336,12 @@ var parseCIDRTests = []struct { {"192.168.1.1/255.255.255.0", nil, nil, &ParseError{Type: "CIDR address", Text: "192.168.1.1/255.255.255.0"}}, {"192.168.1.1/35", nil, nil, &ParseError{Type: "CIDR address", Text: "192.168.1.1/35"}}, {"2001:db8::1/-1", nil, nil, &ParseError{Type: "CIDR address", Text: "2001:db8::1/-1"}}, + {"2001:db8::1/-0", nil, nil, &ParseError{Type: "CIDR address", Text: "2001:db8::1/-0"}}, + {"-0.0.0.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "-0.0.0.0/32"}}, + {"0.-1.0.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.-1.0.0/32"}}, + {"0.0.-2.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.-2.0/32"}}, + {"0.0.0.-3/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.0.-3/32"}}, + {"0.0.0.0/-0", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.0.0/-0"}}, {"", nil, nil, &ParseError{Type: "CIDR address", Text: ""}}, } diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go index 173b3cb..d994fc6 100644 --- a/libgo/go/net/iprawsock.go +++ b/libgo/go/net/iprawsock.go @@ -9,6 +9,24 @@ import ( "syscall" ) +// BUG(mikio): On every POSIX platform, reads from the "ip4" network +// using the ReadFrom or ReadFromIP method might not return a complete +// IPv4 packet, including its header, even if there is space +// available. This can occur even in cases where Read or ReadMsgIP +// could return a complete packet. For this reason, it is recommended +// that you do not use these methods if it is important to receive a +// full packet. +// +// The Go 1 compatibility guidelines make it impossible for us to +// change the behavior of these methods; use Read or ReadMsgIP +// instead. + +// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgIP and +// WriteMsgIP methods of IPConn are not implemented. + +// BUG(mikio): On Windows, the File method of IPConn is not +// implemented. + // IPAddr represents the address of an IP end point. type IPAddr struct { IP IP @@ -46,6 +64,9 @@ func (a *IPAddr) opAddr() Addr { // ResolveIPAddr parses addr as an IP address of the form "host" or // "ipv6-host%zone" and resolves the domain name on the network net, // which must be "ip", "ip4" or "ip6". +// +// Resolving a hostname is not recommended because this returns at most +// one of its IP addresses. func ResolveIPAddr(net, addr string) (*IPAddr, error) { if net == "" { // a hint wildcard for Go 1.0 undocumented behavior net = "ip" @@ -59,7 +80,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) { default: return nil, UnknownNetworkError(net) } - addrs, err := internetAddrList(context.Background(), afnet, addr) + addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, addr) if err != nil { return nil, err } diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go index 3e0b060..8f4b702 100644 --- a/libgo/go/net/iprawsock_posix.go +++ b/libgo/go/net/iprawsock_posix.go @@ -11,18 +11,6 @@ import ( "syscall" ) -// BUG(mikio): On every POSIX platform, reads from the "ip4" network -// using the ReadFrom or ReadFromIP method might not return a complete -// IPv4 packet, including its header, even if there is space -// available. This can occur even in cases where Read or ReadMsgIP -// could return a complete packet. For this reason, it is recommended -// that you do not uses these methods if it is important to receive a -// full packet. -// -// The Go 1 compatibility guidelines make it impossible for us to -// change the behavior of these methods; use Read or ReadMsgIP -// instead. - func sockaddrToIP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -50,6 +38,10 @@ func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) { return ipToSockaddr(family, a.IP, 0, a.Zone) } +func (a *IPAddr) toLocal(net string) sockaddr { + return &IPAddr{loopbackIP(net), a.Zone} +} + func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) { // TODO(cw,rsc): consider using readv if we know the family // type to avoid the header trim/copy diff --git a/libgo/go/net/iprawsock_test.go b/libgo/go/net/iprawsock_test.go index 29cd4b6..5d33b26 100644 --- a/libgo/go/net/iprawsock_test.go +++ b/libgo/go/net/iprawsock_test.go @@ -43,6 +43,13 @@ var resolveIPAddrTests = []resolveIPAddrTest{ {"l2tp", "127.0.0.1", nil, UnknownNetworkError("l2tp")}, {"l2tp:gre", "127.0.0.1", nil, UnknownNetworkError("l2tp:gre")}, {"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")}, + + {"ip4", "2001:db8::1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}}, + {"ip4:icmp", "2001:db8::1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}}, + {"ip6", "127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}}, + {"ip6", "::ffff:127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}}, + {"ip6:ipv6-icmp", "127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}}, + {"ip6:ipv6-icmp", "::ffff:127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}}, } func TestResolveIPAddr(t *testing.T) { @@ -54,21 +61,17 @@ func TestResolveIPAddr(t *testing.T) { defer func() { testHookLookupIP = origTestHookLookupIP }() testHookLookupIP = lookupLocalhost - for i, tt := range resolveIPAddrTests { + for _, tt := range resolveIPAddrTests { addr, err := ResolveIPAddr(tt.network, tt.litAddrOrName) - if err != tt.err { - t.Errorf("#%d: %v", i, err) - } else if !reflect.DeepEqual(addr, tt.addr) { - t.Errorf("#%d: got %#v; want %#v", i, addr, tt.addr) - } - if err != nil { + if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) { + t.Errorf("ResolveIPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err) continue } - rtaddr, err := ResolveIPAddr(addr.Network(), addr.String()) - if err != nil { - t.Errorf("#%d: %v", i, err) - } else if !reflect.DeepEqual(rtaddr, addr) { - t.Errorf("#%d: got %#v; want %#v", i, rtaddr, addr) + if err == nil { + addr2, err := ResolveIPAddr(addr.Network(), addr.String()) + if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err { + t.Errorf("(%q, %q): ResolveIPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err) + } } } } diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go index 24daf17..f1394a7 100644 --- a/libgo/go/net/ipsock.go +++ b/libgo/go/net/ipsock.go @@ -10,6 +10,13 @@ import ( "context" ) +// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the +// "tcp" and "udp" networks does not listen for both IPv4 and IPv6 +// connections. This is due to the fact that IPv4 traffic will not be +// routed to an IPv6 socket - two separate sockets are required if +// both address families are to be supported. +// See inet6(4) for details. + var ( // supportsIPv4 reports whether the platform supports IPv4 // networking functionality. @@ -76,7 +83,7 @@ func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks // yielding a list of Addr objects. Known filters are nil, ipv4only, // and ipv6only. It returns every address when the filter is nil. // The result contains at least one address when error is nil. -func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr) Addr) (addrList, error) { +func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr) Addr, originalAddr string) (addrList, error) { var addrs addrList for _, ip := range ips { if filter == nil || filter(ip) { @@ -84,21 +91,19 @@ func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr } } if len(addrs) == 0 { - return nil, errNoSuitableAddress + return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: originalAddr} } return addrs, nil } -// ipv4only reports whether the kernel supports IPv4 addressing mode -// and addr is an IPv4 address. +// ipv4only reports whether addr is an IPv4 address. func ipv4only(addr IPAddr) bool { - return supportsIPv4 && addr.IP.To4() != nil + return addr.IP.To4() != nil } -// ipv6only reports whether the kernel supports IPv6 addressing mode -// and addr is an IPv6 address except IPv4-mapped IPv6 address. +// ipv6only reports whether addr is an IPv6 address except IPv4-mapped IPv6 address. func ipv6only(addr IPAddr) bool { - return supportsIPv6 && len(addr.IP) == IPv6len && addr.IP.To4() == nil + return len(addr.IP) == IPv6len && addr.IP.To4() == nil } // SplitHostPort splits a network address of the form "host:port", @@ -190,7 +195,7 @@ func JoinHostPort(host, port string) string { // address or a DNS name, and returns a list of internet protocol // family addresses. The result contains at least one address when // error is nil. -func internetAddrList(ctx context.Context, net, addr string) (addrList, error) { +func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addrList, error) { var ( err error host, port string @@ -202,7 +207,7 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) { if host, port, err = SplitHostPort(addr); err != nil { return nil, err } - if portnum, err = LookupPort(net, port); err != nil { + if portnum, err = r.LookupPort(ctx, net, port); err != nil { return nil, err } } @@ -228,20 +233,21 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) { if host == "" { return addrList{inetaddr(IPAddr{})}, nil } - // Try as a literal IP address. - var ip IP - if ip = parseIPv4(host); ip != nil { - return addrList{inetaddr(IPAddr{IP: ip})}, nil - } - var zone string - if ip, zone = parseIPv6(host, true); ip != nil { - return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil - } - // Try as a DNS name. - ips, err := lookupIPContext(ctx, host) - if err != nil { - return nil, err + + // Try as a literal IP address, then as a DNS name. + var ips []IPAddr + if ip := parseIPv4(host); ip != nil { + ips = []IPAddr{{IP: ip}} + } else if ip, zone := parseIPv6(host, true); ip != nil { + ips = []IPAddr{{IP: ip, Zone: zone}} + } else { + // Try as a DNS name. + ips, err = r.LookupIPAddr(ctx, host) + if err != nil { + return nil, err + } } + var filter func(IPAddr) bool if net != "" && net[len(net)-1] == '4' { filter = ipv4only @@ -249,5 +255,12 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) { if net != "" && net[len(net)-1] == '6' { filter = ipv6only } - return filterAddrList(filter, ips, inetaddr) + return filterAddrList(filter, ips, inetaddr, host) +} + +func loopbackIP(net string) IP { + if net != "" && net[len(net)-1] == '6' { + return IPv6loopback + } + return IP{127, 0, 0, 1} } diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go index 2b84683..b7fd344 100644 --- a/libgo/go/net/ipsock_plan9.go +++ b/libgo/go/net/ipsock_plan9.go @@ -63,7 +63,7 @@ func parsePlan9Addr(s string) (ip IP, iport int, err error) { return nil, 0, &ParseError{Type: "IP address", Text: s} } } - p, _, ok := dtoi(s[i+1:], 0) + p, _, ok := dtoi(s[i+1:]) if !ok { return nil, 0, &ParseError{Type: "port", Text: s} } @@ -119,6 +119,11 @@ func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest, return } + if port > 65535 { + err = InvalidAddrError("port should be < 65536") + return + } + clone, dest, err := queryCS1(ctx, proto, ip, port) if err != nil { return @@ -193,6 +198,9 @@ func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, e } func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) { + if isWildcard(raddr) { + raddr = toLocal(raddr, net) + } f, dest, proto, name, err := startPlan9(ctx, net, raddr) if err != nil { return nil, err @@ -213,7 +221,7 @@ func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd * f.Close() return nil, err } - return newFD(proto, name, f, data, laddr, raddr) + return newFD(proto, name, nil, f, data, laddr, raddr) } func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) { @@ -232,11 +240,11 @@ func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err er f.Close() return nil, err } - return newFD(proto, name, f, nil, laddr, nil) + return newFD(proto, name, nil, f, nil, laddr, nil) } func (fd *netFD) netFD() (*netFD, error) { - return newFD(fd.net, fd.n, fd.ctl, fd.data, fd.laddr, fd.raddr) + return newFD(fd.net, fd.n, fd.listen, fd.ctl, fd.data, fd.laddr, fd.raddr) } func (fd *netFD) acceptPlan9() (nfd *netFD, err error) { @@ -245,27 +253,59 @@ func (fd *netFD) acceptPlan9() (nfd *netFD, err error) { return nil, err } defer fd.readUnlock() - f, err := os.Open(fd.dir + "/listen") + listen, err := os.Open(fd.dir + "/listen") if err != nil { return nil, err } var buf [16]byte - n, err := f.Read(buf[:]) + n, err := listen.Read(buf[:]) if err != nil { - f.Close() + listen.Close() return nil, err } name := string(buf[:n]) + ctl, err := os.OpenFile(netdir+"/"+fd.net+"/"+name+"/ctl", os.O_RDWR, 0) + if err != nil { + listen.Close() + return nil, err + } data, err := os.OpenFile(netdir+"/"+fd.net+"/"+name+"/data", os.O_RDWR, 0) if err != nil { - f.Close() + listen.Close() + ctl.Close() return nil, err } raddr, err := readPlan9Addr(fd.net, netdir+"/"+fd.net+"/"+name+"/remote") if err != nil { + listen.Close() + ctl.Close() data.Close() - f.Close() return nil, err } - return newFD(fd.net, name, f, data, fd.laddr, raddr) + return newFD(fd.net, name, listen, ctl, data, fd.laddr, raddr) +} + +func isWildcard(a Addr) bool { + var wildcard bool + switch a := a.(type) { + case *TCPAddr: + wildcard = a.isWildcard() + case *UDPAddr: + wildcard = a.isWildcard() + case *IPAddr: + wildcard = a.isWildcard() + } + return wildcard +} + +func toLocal(a Addr, net string) Addr { + switch a := a.(type) { + case *TCPAddr: + a.IP = loopbackIP(net) + case *UDPAddr: + a.IP = loopbackIP(net) + case *IPAddr: + a.IP = loopbackIP(net) + } + return a } diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go index abe90ac..ff280c3 100644 --- a/libgo/go/net/ipsock_posix.go +++ b/libgo/go/net/ipsock_posix.go @@ -12,13 +12,6 @@ import ( "syscall" ) -// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the -// "tcp" and "udp" networks does not listen for both IPv4 and IPv6 -// connections. This is due to the fact that IPv4 traffic will not be -// routed to an IPv6 socket - two separate sockets are required if -// both address families are to be supported. -// See inet6(4) for details. - func probeIPv4Stack() bool { s, err := socketFunc(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) switch err { @@ -154,6 +147,9 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family // Internet sockets (TCP, UDP, IP) func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) { + if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() { + raddr = raddr.toLocal(net) + } family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode) return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr) } diff --git a/libgo/go/net/ipsock_test.go b/libgo/go/net/ipsock_test.go index b36557a..1d0f00f 100644 --- a/libgo/go/net/ipsock_test.go +++ b/libgo/go/net/ipsock_test.go @@ -205,13 +205,13 @@ var addrListTests = []struct { nil, }, - {nil, nil, testInetaddr, nil, nil, nil, errNoSuitableAddress}, + {nil, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}}, - {ipv4only, nil, testInetaddr, nil, nil, nil, errNoSuitableAddress}, - {ipv4only, []IPAddr{{IP: IPv6loopback}}, testInetaddr, nil, nil, nil, errNoSuitableAddress}, + {ipv4only, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}}, + {ipv4only, []IPAddr{{IP: IPv6loopback}}, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}}, - {ipv6only, nil, testInetaddr, nil, nil, nil, errNoSuitableAddress}, - {ipv6only, []IPAddr{{IP: IPv4(127, 0, 0, 1)}}, testInetaddr, nil, nil, nil, errNoSuitableAddress}, + {ipv6only, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}}, + {ipv6only, []IPAddr{{IP: IPv4(127, 0, 0, 1)}}, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}}, } func TestAddrList(t *testing.T) { @@ -220,8 +220,8 @@ func TestAddrList(t *testing.T) { } for i, tt := range addrListTests { - addrs, err := filterAddrList(tt.filter, tt.ips, tt.inetaddr) - if err != tt.err { + addrs, err := filterAddrList(tt.filter, tt.ips, tt.inetaddr, "ADDR") + if !reflect.DeepEqual(err, tt.err) { t.Errorf("#%v: got %v; want %v", i, err, tt.err) } if tt.err != nil { diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go index c169e9e..cc2013e 100644 --- a/libgo/go/net/lookup.go +++ b/libgo/go/net/lookup.go @@ -15,94 +15,137 @@ import ( // protocol numbers. // // See http://www.iana.org/assignments/protocol-numbers +// +// On Unix, this map is augmented by readProtocols via lookupProtocol. var protocols = map[string]int{ - "icmp": 1, "ICMP": 1, - "igmp": 2, "IGMP": 2, - "tcp": 6, "TCP": 6, - "udp": 17, "UDP": 17, - "ipv6-icmp": 58, "IPV6-ICMP": 58, "IPv6-ICMP": 58, + "icmp": 1, + "igmp": 2, + "tcp": 6, + "udp": 17, + "ipv6-icmp": 58, } -// LookupHost looks up the given host using the local resolver. -// It returns an array of that host's addresses. -func LookupHost(host string) (addrs []string, err error) { - // Make sure that no matter what we do later, host=="" is rejected. - // ParseIP, for example, does accept empty strings. - if host == "" { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host} +// services contains minimal mappings between services names and port +// numbers for platforms that don't have a complete list of port numbers +// (some Solaris distros, nacl, etc). +// On Unix, this map is augmented by readServices via goLookupPort. +var services = map[string]map[string]int{ + "udp": { + "domain": 53, + }, + "tcp": { + "ftp": 21, + "ftps": 990, + "gopher": 70, // ʕ◔ϖ◔ʔ + "http": 80, + "https": 443, + "imap2": 143, + "imap3": 220, + "imaps": 993, + "pop3": 110, + "pop3s": 995, + "smtp": 25, + "ssh": 22, + "telnet": 23, + }, +} + +const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow + +func lookupProtocolMap(name string) (int, error) { + var lowerProtocol [maxProtoLength]byte + n := copy(lowerProtocol[:], name) + lowerASCIIBytes(lowerProtocol[:n]) + proto, found := protocols[string(lowerProtocol[:n])] + if !found || n != len(name) { + return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name} } - if ip := ParseIP(host); ip != nil { - return []string{host}, nil + return proto, nil +} + +const maxServiceLength = len("mobility-header") + 10 // with room to grow + +func lookupPortMap(network, service string) (port int, error error) { + switch network { + case "tcp4", "tcp6": + network = "tcp" + case "udp4", "udp6": + network = "udp" } - return lookupHost(context.Background(), host) + + if m, ok := services[network]; ok { + var lowerService [maxServiceLength]byte + n := copy(lowerService[:], service) + lowerASCIIBytes(lowerService[:n]) + if port, ok := m[string(lowerService[:n])]; ok && n == len(service) { + return port, nil + } + } + return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service} } -// LookupIP looks up host using the local resolver. -// It returns an array of that host's IPv4 and IPv6 addresses. -func LookupIP(host string) (ips []IP, err error) { +// DefaultResolver is the resolver used by the package-level Lookup +// functions and by Dialers without a specified Resolver. +var DefaultResolver = &Resolver{} + +// A Resolver looks up names and numbers. +// +// A nil *Resolver is equivalent to a zero Resolver. +type Resolver struct { + // PreferGo controls whether Go's built-in DNS resolver is preferred + // on platforms where it's available. It is equivalent to setting + // GODEBUG=netdns=go, but scoped to just this resolver. + PreferGo bool + + // TODO(bradfitz): optional interface impl override hook + // TODO(bradfitz): Timeout time.Duration? +} + +// LookupHost looks up the given host using the local resolver. +// It returns a slice of that host's addresses. +func LookupHost(host string) (addrs []string, err error) { + return DefaultResolver.LookupHost(context.Background(), host) +} + +// LookupHost looks up the given host using the local resolver. +// It returns a slice of that host's addresses. +func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { // Make sure that no matter what we do later, host=="" is rejected. // ParseIP, for example, does accept empty strings. if host == "" { return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host} } if ip := ParseIP(host); ip != nil { - return []IP{ip}, nil - } - addrs, err := lookupIPMerge(context.Background(), host) - if err != nil { - return - } - ips = make([]IP, len(addrs)) - for i, addr := range addrs { - ips[i] = addr.IP + return []string{host}, nil } - return -} - -var lookupGroup singleflight.Group - -// lookupIPMerge wraps lookupIP, but makes sure that for any given -// host, only one lookup is in-flight at a time. The returned memory -// is always owned by the caller. -func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) { - addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) { - return testHookLookupIP(ctx, lookupIP, host) - }) - return lookupIPReturn(addrsi, err, shared) + return r.lookupHost(ctx, host) } -// lookupIPReturn turns the return values from singleflight.Do into -// the return values from LookupIP. -func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) { +// LookupIP looks up host using the local resolver. +// It returns a slice of that host's IPv4 and IPv6 addresses. +func LookupIP(host string) ([]IP, error) { + addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host) if err != nil { return nil, err } - addrs := addrsi.([]IPAddr) - if shared { - clone := make([]IPAddr, len(addrs)) - copy(clone, addrs) - addrs = clone + ips := make([]IP, len(addrs)) + for i, ia := range addrs { + ips[i] = ia.IP } - return addrs, nil + return ips, nil } -// ipAddrsEface returns an empty interface slice of addrs. -func ipAddrsEface(addrs []IPAddr) []interface{} { - s := make([]interface{}, len(addrs)) - for i, v := range addrs { - s[i] = v +// LookupIPAddr looks up host using the local resolver. +// It returns a slice of that host's IPv4 and IPv6 addresses. +func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) { + // Make sure that no matter what we do later, host=="" is rejected. + // ParseIP, for example, does accept empty strings. + if host == "" { + return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host} + } + if ip := ParseIP(host); ip != nil { + return []IPAddr{{IP: ip}}, nil } - return s -} - -// lookupIPContext looks up a hostname with a context. -// -// TODO(bradfitz): rename this function. All the other -// build-tag-specific lookupIP funcs also take a context now, so this -// name is no longer great. Maybe make this lookupIPMerge and ditch -// the other one, making its callers call this instead with a -// context.Background(). -func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) { trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace) if trace != nil && trace.DNSStart != nil { trace.DNSStart(host) @@ -110,7 +153,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro // The underlying resolver func is lookupIP by default but it // can be overridden by tests. This is needed by net/http, so it // uses a context key instead of unexported variables. - resolverFunc := lookupIP + resolverFunc := r.lookupIP if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil { resolverFunc = alt } @@ -140,11 +183,46 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro } } +// lookupGroup merges LookupIPAddr calls together for lookups +// for the same host. The lookupGroup key is is the LookupIPAddr.host +// argument. +// The return values are ([]IPAddr, error). +var lookupGroup singleflight.Group + +// lookupIPReturn turns the return values from singleflight.Do into +// the return values from LookupIP. +func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) { + if err != nil { + return nil, err + } + addrs := addrsi.([]IPAddr) + if shared { + clone := make([]IPAddr, len(addrs)) + copy(clone, addrs) + addrs = clone + } + return addrs, nil +} + +// ipAddrsEface returns an empty interface slice of addrs. +func ipAddrsEface(addrs []IPAddr) []interface{} { + s := make([]interface{}, len(addrs)) + for i, v := range addrs { + s[i] = v + } + return s +} + // LookupPort looks up the port for the given network and service. func LookupPort(network, service string) (port int, err error) { + return DefaultResolver.LookupPort(context.Background(), network, service) +} + +// LookupPort looks up the port for the given network and service. +func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) { port, needsLookup := parsePort(service) if needsLookup { - port, err = lookupPort(context.Background(), network, service) + port, err = r.lookupPort(ctx, network, service) if err != nil { return 0, err } @@ -155,12 +233,32 @@ func LookupPort(network, service string) (port int, err error) { return port, nil } -// LookupCNAME returns the canonical DNS host for the given name. +// LookupCNAME returns the canonical name for the given host. +// Callers that do not care about the canonical name can call +// LookupHost or LookupIP directly; both take care of resolving +// the canonical name as part of the lookup. +// +// A canonical name is the final name after following zero +// or more CNAME records. +// LookupCNAME does not return an error if host does not +// contain DNS "CNAME" records, as long as host resolves to +// address records. +func LookupCNAME(host string) (cname string, err error) { + return DefaultResolver.lookupCNAME(context.Background(), host) +} + +// LookupCNAME returns the canonical name for the given host. // Callers that do not care about the canonical name can call // LookupHost or LookupIP directly; both take care of resolving // the canonical name as part of the lookup. -func LookupCNAME(name string) (cname string, err error) { - return lookupCNAME(context.Background(), name) +// +// A canonical name is the final name after following zero +// or more CNAME records. +// LookupCNAME does not return an error if host does not +// contain DNS "CNAME" records, as long as host resolves to +// address records. +func (r *Resolver) LookupCNAME(ctx context.Context, host string) (cname string, err error) { + return r.lookupCNAME(ctx, host) } // LookupSRV tries to resolve an SRV query of the given service, @@ -173,26 +271,63 @@ func LookupCNAME(name string) (cname string, err error) { // publishing SRV records under non-standard names, if both service // and proto are empty strings, LookupSRV looks up name directly. func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) { - return lookupSRV(context.Background(), service, proto, name) + return DefaultResolver.lookupSRV(context.Background(), service, proto, name) +} + +// LookupSRV tries to resolve an SRV query of the given service, +// protocol, and domain name. The proto is "tcp" or "udp". +// The returned records are sorted by priority and randomized +// by weight within a priority. +// +// LookupSRV constructs the DNS name to look up following RFC 2782. +// That is, it looks up _service._proto.name. To accommodate services +// publishing SRV records under non-standard names, if both service +// and proto are empty strings, LookupSRV looks up name directly. +func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { + return r.lookupSRV(ctx, service, proto, name) +} + +// LookupMX returns the DNS MX records for the given domain name sorted by preference. +func LookupMX(name string) ([]*MX, error) { + return DefaultResolver.lookupMX(context.Background(), name) } // LookupMX returns the DNS MX records for the given domain name sorted by preference. -func LookupMX(name string) (mxs []*MX, err error) { - return lookupMX(context.Background(), name) +func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) { + return r.lookupMX(ctx, name) } // LookupNS returns the DNS NS records for the given domain name. -func LookupNS(name string) (nss []*NS, err error) { - return lookupNS(context.Background(), name) +func LookupNS(name string) ([]*NS, error) { + return DefaultResolver.lookupNS(context.Background(), name) +} + +// LookupNS returns the DNS NS records for the given domain name. +func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) { + return r.lookupNS(ctx, name) } // LookupTXT returns the DNS TXT records for the given domain name. -func LookupTXT(name string) (txts []string, err error) { - return lookupTXT(context.Background(), name) +func LookupTXT(name string) ([]string, error) { + return DefaultResolver.lookupTXT(context.Background(), name) +} + +// LookupTXT returns the DNS TXT records for the given domain name. +func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + return r.lookupTXT(ctx, name) } // LookupAddr performs a reverse lookup for the given address, returning a list // of names mapping to that address. +// +// When using the host C library resolver, at most one result will be +// returned. To bypass the host resolver, use a custom Resolver. func LookupAddr(addr string) (names []string, err error) { - return lookupAddr(context.Background(), addr) + return DefaultResolver.lookupAddr(context.Background(), addr) +} + +// LookupAddr performs a reverse lookup for the given address, returning a list +// of names mapping to that address. +func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) { + return r.lookupAddr(ctx, addr) } diff --git a/libgo/go/net/lookup_nacl.go b/libgo/go/net/lookup_nacl.go new file mode 100644 index 0000000..43cebad --- /dev/null +++ b/libgo/go/net/lookup_nacl.go @@ -0,0 +1,52 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build nacl + +package net + +import ( + "context" + "syscall" +) + +func lookupProtocol(ctx context.Context, name string) (proto int, err error) { + return lookupProtocolMap(name) +} + +func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { + return nil, syscall.ENOPROTOOPT +} + +func (*Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + return nil, syscall.ENOPROTOOPT +} + +func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) { + return goLookupPort(network, service) +} + +func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { + return "", syscall.ENOPROTOOPT +} + +func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) { + return "", nil, syscall.ENOPROTOOPT +} + +func (*Resolver) lookupMX(ctx context.Context, name string) (mxs []*MX, err error) { + return nil, syscall.ENOPROTOOPT +} + +func (*Resolver) lookupNS(ctx context.Context, name string) (nss []*NS, err error) { + return nil, syscall.ENOPROTOOPT +} + +func (*Resolver) lookupTXT(ctx context.Context, name string) (txts []string, err error) { + return nil, syscall.ENOPROTOOPT +} + +func (*Resolver) lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) { + return nil, syscall.ENOPROTOOPT +} diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go index 3f7af2a..f81e220 100644 --- a/libgo/go/net/lookup_plan9.go +++ b/libgo/go/net/lookup_plan9.go @@ -111,17 +111,20 @@ func lookupProtocol(ctx context.Context, name string) (proto int, err error) { return 0, UnknownNetworkError(name) } s := f[1] - if n, _, ok := dtoi(s, byteIndex(s, '=')+1); ok { + if n, _, ok := dtoi(s[byteIndex(s, '=')+1:]); ok { return n, nil } return 0, UnknownNetworkError(name) } -func lookupHost(ctx context.Context, host string) (addrs []string, err error) { +func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { // Use netdir/cs instead of netdir/dns because cs knows about // host names in local network (e.g. from /lib/ndb/local) lines, err := queryCS(ctx, "net", host, "1") if err != nil { + if stringsHasSuffix(err.Error(), "dns failure") { + err = errNoSuchHost + } return } loop: @@ -148,8 +151,8 @@ loop: return } -func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { - lits, err := lookupHost(ctx, host) +func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + lits, err := r.lookupHost(ctx, host) if err != nil { return } @@ -163,14 +166,14 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { return } -func lookupPort(ctx context.Context, network, service string) (port int, err error) { +func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - lines, err := queryCS(ctx, network, "127.0.0.1", service) + lines, err := queryCS(ctx, network, "127.0.0.1", toLower(service)) if err != nil { return } @@ -186,15 +189,19 @@ func lookupPort(ctx context.Context, network, service string) (port int, err err if i := byteIndex(s, '!'); i >= 0 { s = s[i+1:] // remove address } - if n, _, ok := dtoi(s, 0); ok { + if n, _, ok := dtoi(s); ok { return n, nil } return 0, unknownPortError } -func lookupCNAME(ctx context.Context, name string) (cname string, err error) { +func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) { lines, err := queryDNS(ctx, name, "cname") if err != nil { + if stringsHasSuffix(err.Error(), "dns failure") { + cname = name + "." + err = nil + } return } if len(lines) > 0 { @@ -205,7 +212,7 @@ func lookupCNAME(ctx context.Context, name string) (cname string, err error) { return "", errors.New("bad response from ndb/dns") } -func lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { +func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) { var target string if service == "" && proto == "" { target = name @@ -221,9 +228,9 @@ func lookupSRV(ctx context.Context, service, proto, name string) (cname string, if len(f) < 6 { continue } - port, _, portOk := dtoi(f[4], 0) - priority, _, priorityOk := dtoi(f[3], 0) - weight, _, weightOk := dtoi(f[2], 0) + port, _, portOk := dtoi(f[4]) + priority, _, priorityOk := dtoi(f[3]) + weight, _, weightOk := dtoi(f[2]) if !(portOk && priorityOk && weightOk) { continue } @@ -234,7 +241,7 @@ func lookupSRV(ctx context.Context, service, proto, name string) (cname string, return } -func lookupMX(ctx context.Context, name string) (mx []*MX, err error) { +func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) { lines, err := queryDNS(ctx, name, "mx") if err != nil { return @@ -244,7 +251,7 @@ func lookupMX(ctx context.Context, name string) (mx []*MX, err error) { if len(f) < 4 { continue } - if pref, _, ok := dtoi(f[2], 0); ok { + if pref, _, ok := dtoi(f[2]); ok { mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)}) } } @@ -252,7 +259,7 @@ func lookupMX(ctx context.Context, name string) (mx []*MX, err error) { return } -func lookupNS(ctx context.Context, name string) (ns []*NS, err error) { +func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) { lines, err := queryDNS(ctx, name, "ns") if err != nil { return @@ -267,7 +274,7 @@ func lookupNS(ctx context.Context, name string) (ns []*NS, err error) { return } -func lookupTXT(ctx context.Context, name string) (txt []string, err error) { +func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) { lines, err := queryDNS(ctx, name, "txt") if err != nil { return @@ -280,7 +287,7 @@ func lookupTXT(ctx context.Context, name string) (txt []string, err error) { return } -func lookupAddr(ctx context.Context, addr string) (name []string, err error) { +func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) { arpa, err := reverseaddr(addr) if err != nil { return diff --git a/libgo/go/net/lookup_stub.go b/libgo/go/net/lookup_stub.go deleted file mode 100644 index bd096b3..0000000 --- a/libgo/go/net/lookup_stub.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build nacl - -package net - -import ( - "context" - "syscall" -) - -func lookupProtocol(ctx context.Context, name string) (proto int, err error) { - return 0, syscall.ENOPROTOOPT -} - -func lookupHost(ctx context.Context, host string) (addrs []string, err error) { - return nil, syscall.ENOPROTOOPT -} - -func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { - return nil, syscall.ENOPROTOOPT -} - -func lookupPort(ctx context.Context, network, service string) (port int, err error) { - return 0, syscall.ENOPROTOOPT -} - -func lookupCNAME(ctx context.Context, name string) (cname string, err error) { - return "", syscall.ENOPROTOOPT -} - -func lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) { - return "", nil, syscall.ENOPROTOOPT -} - -func lookupMX(ctx context.Context, name string) (mxs []*MX, err error) { - return nil, syscall.ENOPROTOOPT -} - -func lookupNS(ctx context.Context, name string) (nss []*NS, err error) { - return nil, syscall.ENOPROTOOPT -} - -func lookupTXT(ctx context.Context, name string) (txts []string, err error) { - return nil, syscall.ENOPROTOOPT -} - -func lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) { - return nil, syscall.ENOPROTOOPT -} diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go index b3aeb85..36db56a 100644 --- a/libgo/go/net/lookup_test.go +++ b/libgo/go/net/lookup_test.go @@ -243,14 +243,15 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) { } } -var lookupIANACNAMETests = []struct { +var lookupCNAMETests = []struct { name, cname string }{ {"www.iana.org", "icann.org."}, {"www.iana.org.", "icann.org."}, + {"www.google.com", "google.com."}, } -func TestLookupIANACNAME(t *testing.T) { +func TestLookupCNAME(t *testing.T) { if testenv.Builder() == "" { testenv.MustHaveExternalNetwork(t) } @@ -259,7 +260,7 @@ func TestLookupIANACNAME(t *testing.T) { t.Skip("IPv4 is required") } - for _, tt := range lookupIANACNAMETests { + for _, tt := range lookupCNAMETests { cname, err := LookupCNAME(tt.name) if err != nil { t.Fatal(err) @@ -398,11 +399,11 @@ func TestDNSFlood(t *testing.T) { for i := 0; i < N; i++ { name := fmt.Sprintf("%d.net-test.golang.org", i) go func() { - _, err := lookupIPContext(ctxHalfTimeout, name) + _, err := DefaultResolver.LookupIPAddr(ctxHalfTimeout, name) c <- err }() go func() { - _, err := lookupIPContext(ctxTimeout, name) + _, err := DefaultResolver.LookupIPAddr(ctxTimeout, name) c <- err }() } @@ -616,7 +617,7 @@ func srvString(srvs []*SRV) string { func TestLookupPort(t *testing.T) { // See http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml // - // Please be careful about adding new mappings for testings. + // Please be careful about adding new test cases. // There are platforms having incomplete mappings for // restricted resource access and security reasons. type test struct { @@ -648,8 +649,6 @@ func TestLookupPort(t *testing.T) { } switch runtime.GOOS { - case "nacl": - t.Skipf("not supported on %s", runtime.GOOS) case "android": if netGo { t.Skipf("not supported on %s without cgo; see golang.org/issues/14576", runtime.GOOS) @@ -670,3 +669,73 @@ func TestLookupPort(t *testing.T) { } } } + +// Like TestLookupPort but with minimal tests that should always pass +// because the answers are baked-in to the net package. +func TestLookupPort_Minimal(t *testing.T) { + type test struct { + network string + name string + port int + } + var tests = []test{ + {"tcp", "http", 80}, + {"tcp", "HTTP", 80}, // case shouldn't matter + {"tcp", "https", 443}, + {"tcp", "ssh", 22}, + {"tcp", "gopher", 70}, + {"tcp4", "http", 80}, + {"tcp6", "http", 80}, + } + + for _, tt := range tests { + port, err := LookupPort(tt.network, tt.name) + if port != tt.port || err != nil { + t.Errorf("LookupPort(%q, %q) = %d, %v; want %d, error=nil", tt.network, tt.name, port, err, tt.port) + } + } +} + +func TestLookupProtocol_Minimal(t *testing.T) { + type test struct { + name string + want int + } + var tests = []test{ + {"tcp", 6}, + {"TcP", 6}, // case shouldn't matter + {"icmp", 1}, + {"igmp", 2}, + {"udp", 17}, + {"ipv6-icmp", 58}, + } + + for _, tt := range tests { + got, err := lookupProtocol(context.Background(), tt.name) + if got != tt.want || err != nil { + t.Errorf("LookupProtocol(%q) = %d, %v; want %d, error=nil", tt.name, got, err, tt.want) + } + } + +} + +func TestLookupNonLDH(t *testing.T) { + if runtime.GOOS == "nacl" { + t.Skip("skip on nacl") + } + if fixup := forceGoDNS(); fixup != nil { + defer fixup() + } + + // "LDH" stands for letters, digits, and hyphens and is the usual + // description of standard DNS names. + // This test is checking that other kinds of names are reported + // as not found, not reported as invalid names. + addrs, err := LookupHost("!!!.###.bogus..domain.") + if err == nil { + t.Fatalf("lookup succeeded: %v", addrs) + } + if !strings.HasSuffix(err.Error(), errNoSuchHost.Error()) { + t.Fatalf("lookup error = %v, want %v", err, errNoSuchHost) + } +} diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go index 15397e8..be2ced9 100644 --- a/libgo/go/net/lookup_unix.go +++ b/libgo/go/net/lookup_unix.go @@ -26,7 +26,7 @@ func readProtocols() { if len(f) < 2 { continue } - if proto, _, ok := dtoi(f[1], 0); ok { + if proto, _, ok := dtoi(f[1]); ok { if _, ok := protocols[f[0]]; !ok { protocols[f[0]] = proto } @@ -45,16 +45,12 @@ func readProtocols() { // returns correspondent protocol number. func lookupProtocol(_ context.Context, name string) (int, error) { onceReadProtocols.Do(readProtocols) - proto, found := protocols[name] - if !found { - return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name} - } - return proto, nil + return lookupProtocolMap(name) } -func lookupHost(ctx context.Context, host string) (addrs []string, err error) { +func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) { order := systemConf().hostLookupOrder(host) - if order == hostLookupCgo { + if !r.PreferGo && order == hostLookupCgo { if addrs, err, ok := cgoLookupHost(ctx, host); ok { return addrs, err } @@ -64,7 +60,10 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) { return goLookupHostOrder(ctx, host, order) } -func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { +func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { + if r.PreferGo { + return goLookupIP(ctx, host) + } order := systemConf().hostLookupOrder(host) if order == hostLookupCgo { if addrs, err, ok := cgoLookupIP(ctx, host); ok { @@ -73,25 +72,28 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) { // cgo not available (or netgo); fall back to Go's DNS resolver order = hostLookupFilesDNS } - return goLookupIPOrder(ctx, host, order) + addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order) + return } -func lookupPort(ctx context.Context, network, service string) (int, error) { - // TODO: use the context if there ever becomes a need. Related - // is issue 15321. But port lookup generally just involves - // local files, and the os package has no context support. The - // files might be on a remote filesystem, though. This should - // probably race goroutines if ctx != context.Background(). - if systemConf().canUseCgo() { +func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) { + if !r.PreferGo && systemConf().canUseCgo() { if port, err, ok := cgoLookupPort(ctx, network, service); ok { + if err != nil { + // Issue 18213: if cgo fails, first check to see whether we + // have the answer baked-in to the net package. + if port, err := goLookupPort(network, service); err == nil { + return port, nil + } + } return port, err } } return goLookupPort(network, service) } -func lookupCNAME(ctx context.Context, name string) (string, error) { - if systemConf().canUseCgo() { +func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { + if !r.PreferGo && systemConf().canUseCgo() { if cname, err, ok := cgoLookupCNAME(ctx, name); ok { return cname, err } @@ -99,7 +101,7 @@ func lookupCNAME(ctx context.Context, name string) (string, error) { return goLookupCNAME(ctx, name) } -func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { +func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { var target string if service == "" && proto == "" { target = name @@ -119,7 +121,7 @@ func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV return cname, srvs, nil } -func lookupMX(ctx context.Context, name string) ([]*MX, error) { +func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { _, rrs, err := lookup(ctx, name, dnsTypeMX) if err != nil { return nil, err @@ -133,7 +135,7 @@ func lookupMX(ctx context.Context, name string) ([]*MX, error) { return mxs, nil } -func lookupNS(ctx context.Context, name string) ([]*NS, error) { +func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { _, rrs, err := lookup(ctx, name, dnsTypeNS) if err != nil { return nil, err @@ -145,7 +147,7 @@ func lookupNS(ctx context.Context, name string) ([]*NS, error) { return nss, nil } -func lookupTXT(ctx context.Context, name string) ([]string, error) { +func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { _, rrs, err := lookup(ctx, name, dnsTypeTXT) if err != nil { return nil, err @@ -157,8 +159,8 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) { return txts, nil } -func lookupAddr(ctx context.Context, addr string) ([]string, error) { - if systemConf().canUseCgo() { +func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { + if !r.PreferGo && systemConf().canUseCgo() { if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok { return ptrs, err } diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go index 5f65c2d..5808293 100644 --- a/libgo/go/net/lookup_windows.go +++ b/libgo/go/net/lookup_windows.go @@ -12,10 +12,20 @@ import ( "unsafe" ) +const _WSAHOST_NOT_FOUND = syscall.Errno(11001) + +func winError(call string, err error) error { + switch err { + case _WSAHOST_NOT_FOUND: + return errNoSuchHost + } + return os.NewSyscallError(call, err) +} + func getprotobyname(name string) (proto int, err error) { p, err := syscall.GetProtoByName(name) if err != nil { - return 0, os.NewSyscallError("getprotobyname", err) + return 0, winError("getprotobyname", err) } return int(p.Proto), nil } @@ -43,7 +53,7 @@ func lookupProtocol(ctx context.Context, name string) (int, error) { select { case r := <-ch: if r.err != nil { - if proto, ok := protocols[name]; ok { + if proto, err := lookupProtocolMap(name); err == nil { return proto, nil } r.err = &DNSError{Err: r.err.Error(), Name: name} @@ -54,8 +64,8 @@ func lookupProtocol(ctx context.Context, name string) (int, error) { } } -func lookupHost(ctx context.Context, name string) ([]string, error) { - ips, err := lookupIP(ctx, name) +func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) { + ips, err := r.lookupIP(ctx, name) if err != nil { return nil, err } @@ -66,8 +76,8 @@ func lookupHost(ctx context.Context, name string) ([]string, error) { return addrs, nil } -func lookupIP(ctx context.Context, name string) ([]IPAddr, error) { - // TODO(bradfitz,brainman): use ctx? +func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) { + // TODO(bradfitz,brainman): use ctx more. See TODO below. type ret struct { addrs []IPAddr @@ -85,7 +95,7 @@ func lookupIP(ctx context.Context, name string) ([]IPAddr, error) { var result *syscall.AddrinfoW e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result) if e != nil { - ch <- ret{err: &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}} + ch <- ret{err: &DNSError{Err: winError("getaddrinfow", e).Error(), Name: name}} } defer syscall.FreeAddrInfoW(result) addrs := make([]IPAddr, 0, 5) @@ -125,7 +135,11 @@ func lookupIP(ctx context.Context, name string) ([]IPAddr, error) { } } -func lookupPort(ctx context.Context, network, service string) (int, error) { +func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) { + if r.PreferGo { + return lookupPortMap(network, service) + } + // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() @@ -144,7 +158,10 @@ func lookupPort(ctx context.Context, network, service string) (int, error) { var result *syscall.AddrinfoW e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result) if e != nil { - return 0, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: network + "/" + service} + if port, err := lookupPortMap(network, service); err == nil { + return port, nil + } + return 0, &DNSError{Err: winError("getaddrinfow", e).Error(), Name: network + "/" + service} } defer syscall.FreeAddrInfoW(result) if result == nil { @@ -162,7 +179,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) { return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} } -func lookupCNAME(ctx context.Context, name string) (string, error) { +func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() @@ -174,7 +191,7 @@ func lookupCNAME(ctx context.Context, name string) (string, error) { return absDomainName([]byte(name)), nil } if e != nil { - return "", &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name} + return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } defer syscall.DnsRecordListFree(r, 1) @@ -183,7 +200,7 @@ func lookupCNAME(ctx context.Context, name string) (string, error) { return absDomainName([]byte(cname)), nil } -func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { +func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) { // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() @@ -196,7 +213,7 @@ func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV var r *syscall.DNSRecord e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil) if e != nil { - return "", nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: target} + return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target} } defer syscall.DnsRecordListFree(r, 1) @@ -209,14 +226,14 @@ func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV return absDomainName([]byte(target)), srvs, nil } -func lookupMX(ctx context.Context, name string) ([]*MX, error) { +func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil) if e != nil { - return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name} + return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } defer syscall.DnsRecordListFree(r, 1) @@ -229,14 +246,14 @@ func lookupMX(ctx context.Context, name string) ([]*MX, error) { return mxs, nil } -func lookupNS(ctx context.Context, name string) ([]*NS, error) { +func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil) if e != nil { - return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name} + return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } defer syscall.DnsRecordListFree(r, 1) @@ -248,14 +265,14 @@ func lookupNS(ctx context.Context, name string) ([]*NS, error) { return nss, nil } -func lookupTXT(ctx context.Context, name string) ([]string, error) { +func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) { // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() var r *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil) if e != nil { - return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name} + return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name} } defer syscall.DnsRecordListFree(r, 1) @@ -270,7 +287,7 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) { return txts, nil } -func lookupAddr(ctx context.Context, addr string) ([]string, error) { +func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) { // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. acquireThread() defer releaseThread() @@ -281,7 +298,7 @@ func lookupAddr(ctx context.Context, addr string) ([]string, error) { var r *syscall.DNSRecord e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil) if e != nil { - return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: addr} + return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr} } defer syscall.DnsRecordListFree(r, 1) diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go index 0c00069..702b765 100644 --- a/libgo/go/net/mail/message.go +++ b/libgo/go/net/mail/message.go @@ -92,7 +92,8 @@ func init() { } } -func parseDate(date string) (time.Time, error) { +// ParseDate parses an RFC 5322 date string. +func ParseDate(date string) (time.Time, error) { for _, layout := range dateLayouts { t, err := time.Parse(layout, date) if err == nil { @@ -106,7 +107,11 @@ func parseDate(date string) (time.Time, error) { type Header map[string][]string // Get gets the first value associated with the given key. +// It is case insensitive; CanonicalMIMEHeaderKey is used +// to canonicalize the provided key. // If there are no values associated with the key, Get returns "". +// To access multiple values of a key, or to use non-canonical keys, +// access the map directly. func (h Header) Get(key string) string { return textproto.MIMEHeader(h).Get(key) } @@ -119,7 +124,7 @@ func (h Header) Date() (time.Time, error) { if hdr == "" { return time.Time{}, ErrHeaderNotPresent } - return parseDate(hdr) + return ParseDate(hdr) } // AddressList parses the named header field as a list of addresses. @@ -345,6 +350,9 @@ func (p *addrParser) consumeAddrSpec() (spec string, err error) { // quoted-string debug.Printf("consumeAddrSpec: parsing quoted-string") localPart, err = p.consumeQuotedString() + if localPart == "" { + err = errors.New("mail: empty quoted string in addr-spec") + } } else { // dot-atom debug.Printf("consumeAddrSpec: parsing dot-atom") @@ -462,9 +470,6 @@ Loop: i += size } p.s = p.s[i+1:] - if len(qsb) == 0 { - return "", errors.New("mail: empty quoted-string") - } return string(qsb), nil } diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go index bbbba6b..f0761ab 100644 --- a/libgo/go/net/mail/message_test.go +++ b/libgo/go/net/mail/message_test.go @@ -110,11 +110,16 @@ func TestDateParsing(t *testing.T) { } date, err := hdr.Date() if err != nil { - t.Errorf("Failed parsing %q: %v", test.dateStr, err) - continue + t.Errorf("Header(Date: %s).Date(): %v", test.dateStr, err) + } else if !date.Equal(test.exp) { + t.Errorf("Header(Date: %s).Date() = %+v, want %+v", test.dateStr, date, test.exp) } - if !date.Equal(test.exp) { - t.Errorf("Parse of %q: got %+v, want %+v", test.dateStr, date, test.exp) + + date, err = ParseDate(test.dateStr) + if err != nil { + t.Errorf("ParseDate(%s): %v", test.dateStr, err) + } else if !date.Equal(test.exp) { + t.Errorf("ParseDate(%s) = %+v, want %+v", test.dateStr, date, test.exp) } } } @@ -310,6 +315,16 @@ func TestAddressParsing(t *testing.T) { }, }, }, + // Issue 14866 + { + `"" <emptystring@example.com>`, + []*Address{ + { + Name: "", + Address: "emptystring@example.com", + }, + }, + }, } for _, test := range tests { if len(test.exp) == 1 { diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go index 7573ded..28a8ff6 100644 --- a/libgo/go/net/main_test.go +++ b/libgo/go/net/main_test.go @@ -24,6 +24,8 @@ var ( ) var ( + testTCPBig = flag.Bool("tcpbig", false, "whether to test massive size of data per read or write call on TCP connection") + testDNSFlood = flag.Bool("dnsflood", false, "whether to test DNS query flooding") // If external IPv4 connectivity exists, we can try dialing diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go index d6812d1..81206ea 100644 --- a/libgo/go/net/net.go +++ b/libgo/go/net/net.go @@ -102,9 +102,13 @@ func init() { } // Addr represents a network end point address. +// +// The two methods Network and String conventionally return strings +// that can be passed as the arguments to Dial, but the exact form +// and meaning of the strings is up to the implementation. type Addr interface { - Network() string // name of the network - String() string // string form of address + Network() string // name of the network (for example, "tcp", "udp") + String() string // string form of address (for example, "192.0.2.1:25", "[2001:db8::1]:80") } // Conn is a generic stream-oriented network connection. @@ -112,12 +116,12 @@ type Addr interface { // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn interface { // Read reads data from the connection. - // Read can be made to time out and return a Error with Timeout() == true + // Read can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. Read(b []byte) (n int, err error) // Write writes data to the connection. - // Write can be made to time out and return a Error with Timeout() == true + // Write can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. Write(b []byte) (n int, err error) @@ -137,8 +141,10 @@ type Conn interface { // // A deadline is an absolute time after which I/O operations // fail with a timeout (see type Error) instead of - // blocking. The deadline applies to all future I/O, not just - // the immediately following call to Read or Write. + // blocking. The deadline applies to all future and pending + // I/O, not just the immediately following call to Read or + // Write. After a deadline has been exceeded, the connection + // can be refreshed by setting a deadline in the future. // // An idle timeout can be implemented by repeatedly extending // the deadline after successful Read or Write calls. @@ -146,11 +152,13 @@ type Conn interface { // A zero value for t means I/O operations will not time out. SetDeadline(t time.Time) error - // SetReadDeadline sets the deadline for future Read calls. + // SetReadDeadline sets the deadline for future Read calls + // and any currently-blocked Read call. // A zero value for t means Read will not time out. SetReadDeadline(t time.Time) error - // SetWriteDeadline sets the deadline for future Write calls. + // SetWriteDeadline sets the deadline for future Write calls + // and any currently-blocked Write call. // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. // A zero value for t means Write will not time out. @@ -302,13 +310,13 @@ type PacketConn interface { // bytes copied into b and the return address that // was on the packet. // ReadFrom can be made to time out and return - // an error with Timeout() == true after a fixed time limit; + // an Error with Timeout() == true after a fixed time limit; // see SetDeadline and SetReadDeadline. ReadFrom(b []byte) (n int, addr Addr, err error) // WriteTo writes a packet with payload b to addr. // WriteTo can be made to time out and return - // an error with Timeout() == true after a fixed time limit; + // an Error with Timeout() == true after a fixed time limit; // see SetDeadline and SetWriteDeadline. // On packet-oriented connections, write timeouts are rare. WriteTo(b []byte, addr Addr) (n int, err error) @@ -321,21 +329,32 @@ type PacketConn interface { LocalAddr() Addr // SetDeadline sets the read and write deadlines associated - // with the connection. + // with the connection. It is equivalent to calling both + // SetReadDeadline and SetWriteDeadline. + // + // A deadline is an absolute time after which I/O operations + // fail with a timeout (see type Error) instead of + // blocking. The deadline applies to all future and pending + // I/O, not just the immediately following call to ReadFrom or + // WriteTo. After a deadline has been exceeded, the connection + // can be refreshed by setting a deadline in the future. + // + // An idle timeout can be implemented by repeatedly extending + // the deadline after successful ReadFrom or WriteTo calls. + // + // A zero value for t means I/O operations will not time out. SetDeadline(t time.Time) error - // SetReadDeadline sets the deadline for future Read calls. - // If the deadline is reached, Read will fail with a timeout - // (see type Error) instead of blocking. - // A zero value for t means Read will not time out. + // SetReadDeadline sets the deadline for future ReadFrom calls + // and any currently-blocked ReadFrom call. + // A zero value for t means ReadFrom will not time out. SetReadDeadline(t time.Time) error - // SetWriteDeadline sets the deadline for future Write calls. - // If the deadline is reached, Write will fail with a timeout - // (see type Error) instead of blocking. - // A zero value for t means Write will not time out. + // SetWriteDeadline sets the deadline for future WriteTo calls + // and any currently-blocked WriteTo call. // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. + // A zero value for t means WriteTo will not time out. SetWriteDeadline(t time.Time) error } @@ -512,7 +531,7 @@ func (e *AddrError) Error() string { } s := e.Err if e.Addr != "" { - s += " " + e.Addr + s = "address " + e.Addr + ": " + s } return s } @@ -604,3 +623,66 @@ func acquireThread() { func releaseThread() { <-threadLimit } + +// buffersWriter is the interface implemented by Conns that support a +// "writev"-like batch write optimization. +// writeBuffers should fully consume and write all chunks from the +// provided Buffers, else it should report a non-nil error. +type buffersWriter interface { + writeBuffers(*Buffers) (int64, error) +} + +var testHookDidWritev = func(wrote int) {} + +// Buffers contains zero or more runs of bytes to write. +// +// On certain machines, for certain types of connections, this is +// optimized into an OS-specific batch write operation (such as +// "writev"). +type Buffers [][]byte + +var ( + _ io.WriterTo = (*Buffers)(nil) + _ io.Reader = (*Buffers)(nil) +) + +func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) { + if wv, ok := w.(buffersWriter); ok { + return wv.writeBuffers(v) + } + for _, b := range *v { + nb, err := w.Write(b) + n += int64(nb) + if err != nil { + v.consume(n) + return n, err + } + } + v.consume(n) + return n, nil +} + +func (v *Buffers) Read(p []byte) (n int, err error) { + for len(p) > 0 && len(*v) > 0 { + n0 := copy(p, (*v)[0]) + v.consume(int64(n0)) + p = p[n0:] + n += n0 + } + if len(*v) == 0 { + err = io.EOF + } + return +} + +func (v *Buffers) consume(n int64) { + for len(*v) > 0 { + ln0 := int64(len((*v)[0])) + if ln0 > n { + (*v)[0] = (*v)[0][n:] + return + } + n -= ln0 + *v = (*v)[1:] + } +} diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go index b2f825d..9a9a7e5 100644 --- a/libgo/go/net/net_test.go +++ b/libgo/go/net/net_test.go @@ -5,6 +5,8 @@ package net import ( + "errors" + "fmt" "io" "net/internal/socktest" "os" @@ -15,7 +17,7 @@ import ( func TestCloseRead(t *testing.T) { switch runtime.GOOS { - case "nacl", "plan9": + case "plan9": t.Skipf("not supported on %s", runtime.GOOS) } @@ -414,3 +416,103 @@ func TestZeroByteRead(t *testing.T) { } } } + +// withTCPConnPair sets up a TCP connection between two peers, then +// runs peer1 and peer2 concurrently. withTCPConnPair returns when +// both have completed. +func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + errc := make(chan error, 2) + go func() { + c1, err := ln.Accept() + if err != nil { + errc <- err + return + } + defer c1.Close() + errc <- peer1(c1.(*TCPConn)) + }() + go func() { + c2, err := Dial("tcp", ln.Addr().String()) + if err != nil { + errc <- err + return + } + defer c2.Close() + errc <- peer2(c2.(*TCPConn)) + }() + for i := 0; i < 2; i++ { + if err := <-errc; err != nil { + t.Fatal(err) + } + } +} + +// Tests that a blocked Read is interrupted by a concurrent SetReadDeadline +// modifying that Conn's read deadline to the past. +// See golang.org/cl/30164 which documented this. The net/http package +// depends on this. +func TestReadTimeoutUnblocksRead(t *testing.T) { + serverDone := make(chan struct{}) + server := func(cs *TCPConn) error { + defer close(serverDone) + errc := make(chan error, 1) + go func() { + defer close(errc) + go func() { + // TODO: find a better way to wait + // until we're blocked in the cs.Read + // call below. Sleep is lame. + time.Sleep(100 * time.Millisecond) + + // Interrupt the upcoming Read, unblocking it: + cs.SetReadDeadline(time.Unix(123, 0)) // time in the past + }() + var buf [1]byte + n, err := cs.Read(buf[:1]) + if n != 0 || err == nil { + errc <- fmt.Errorf("Read = %v, %v; want 0, non-nil", n, err) + } + }() + select { + case err := <-errc: + return err + case <-time.After(5 * time.Second): + buf := make([]byte, 2<<20) + buf = buf[:runtime.Stack(buf, true)] + println("Stacks at timeout:\n", string(buf)) + return errors.New("timeout waiting for Read to finish") + } + + } + // Do nothing in the client. Never write. Just wait for the + // server's half to be done. + client := func(*TCPConn) error { + <-serverDone + return nil + } + withTCPConnPair(t, client, server) +} + +// Issue 17695: verify that a blocked Read is woken up by a Close. +func TestCloseUnblocksRead(t *testing.T) { + t.Parallel() + server := func(cs *TCPConn) error { + // Give the client time to get stuck in a Read: + time.Sleep(20 * time.Millisecond) + cs.Close() + return nil + } + client := func(ss *TCPConn) error { + n, err := ss.Read([]byte{0}) + if n != 0 || err != io.EOF { + return fmt.Errorf("Read = %v, %v; want 0, EOF", n, err) + } + return nil + } + withTCPConnPair(t, client, server) +} diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go index 2c6b98a..b270159 100644 --- a/libgo/go/net/parse.go +++ b/libgo/go/net/parse.go @@ -124,39 +124,27 @@ func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") } // Bigger than we need, not too big to worry about overflow const big = 0xFFFFFF -// Decimal to integer starting at &s[i0]. -// Returns number, new offset, success. -func dtoi(s string, i0 int) (n int, i int, ok bool) { +// Decimal to integer. +// Returns number, characters consumed, success. +func dtoi(s string) (n int, i int, ok bool) { n = 0 - neg := false - if len(s) > 0 && s[0] == '-' { - neg = true - s = s[1:] - } - for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { + for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ { n = n*10 + int(s[i]-'0') if n >= big { - if neg { - return -big, i + 1, false - } return big, i, false } } - if i == i0 { - return 0, i, false - } - if neg { - n = -n - i++ + if i == 0 { + return 0, 0, false } return n, i, true } -// Hexadecimal to integer starting at &s[i0]. -// Returns number, new offset, success. -func xtoi(s string, i0 int) (n int, i int, ok bool) { +// Hexadecimal to integer. +// Returns number, characters consumed, success. +func xtoi(s string) (n int, i int, ok bool) { n = 0 - for i = i0; i < len(s); i++ { + for i = 0; i < len(s); i++ { if '0' <= s[i] && s[i] <= '9' { n *= 16 n += int(s[i] - '0') @@ -173,7 +161,7 @@ func xtoi(s string, i0 int) (n int, i int, ok bool) { return 0, i, false } } - if i == i0 { + if i == 0 { return 0, i, false } return n, i, true @@ -187,7 +175,7 @@ func xtoi2(s string, e byte) (byte, bool) { if len(s) > 2 && s[2] != e { return 0, false } - n, ei, ok := xtoi(s[:2], 0) + n, ei, ok := xtoi(s[:2]) return byte(n), ok && ei == 2 } @@ -348,22 +336,28 @@ func stringsHasSuffix(s, suffix string) bool { // stringsHasSuffixFold reports whether s ends in suffix, // ASCII-case-insensitively. func stringsHasSuffixFold(s, suffix string) bool { - if len(suffix) > len(s) { + return len(s) >= len(suffix) && stringsEqualFold(s[len(s)-len(suffix):], suffix) +} + +// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix. +func stringsHasPrefix(s, prefix string) bool { + return len(s) >= len(prefix) && s[:len(prefix)] == prefix +} + +// stringsEqualFold is strings.EqualFold, ASCII only. It reports whether s and t +// are equal, ASCII-case-insensitively. +func stringsEqualFold(s, t string) bool { + if len(s) != len(t) { return false } - for i := 0; i < len(suffix); i++ { - if lowerASCII(suffix[i]) != lowerASCII(s[len(s)-len(suffix)+i]) { + for i := 0; i < len(s); i++ { + if lowerASCII(s[i]) != lowerASCII(t[i]) { return false } } return true } -// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix. -func stringsHasPrefix(s, prefix string) bool { - return len(s) >= len(prefix) && s[:len(prefix)] == prefix -} - func readFull(r io.Reader) (all []byte, err error) { buf := make([]byte, 1024) for { diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go index fec9200..c5f8bfd 100644 --- a/libgo/go/net/parse_test.go +++ b/libgo/go/net/parse_test.go @@ -86,14 +86,13 @@ func TestDtoi(t *testing.T) { ok bool }{ {"", 0, 0, false}, - - {"-123456789", -big, 9, false}, - {"-1", -1, 2, true}, {"0", 0, 1, true}, {"65536", 65536, 5, true}, {"123456789", big, 8, false}, + {"-0", 0, 0, false}, + {"-1234", 0, 0, false}, } { - n, i, ok := dtoi(tt.in, 0) + n, i, ok := dtoi(tt.in) if n != tt.out || i != tt.off || ok != tt.ok { t.Errorf("got %d, %d, %v; want %d, %d, %v", n, i, ok, tt.out, tt.off, tt.ok) } diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go index badf8ab..868d1e4 100644 --- a/libgo/go/net/port_unix.go +++ b/libgo/go/net/port_unix.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build darwin dragonfly freebsd linux netbsd openbsd solaris +// +build darwin dragonfly freebsd linux netbsd openbsd solaris nacl // Read system port mappings from /etc/services @@ -10,31 +10,24 @@ package net import "sync" -// services contains minimal mappings between services names and port -// numbers for platforms that don't have a complete list of port numbers -// (some Solaris distros). -var services = map[string]map[string]int{ - "tcp": {"http": 80}, -} -var servicesError error var onceReadServices sync.Once func readServices() { - var file *file - if file, servicesError = open("/etc/services"); servicesError != nil { + file, err := open("/etc/services") + if err != nil { return } for line, ok := file.readLine(); ok; line, ok = file.readLine() { // "http 80/tcp www www-http # World Wide Web HTTP" if i := byteIndex(line, '#'); i >= 0 { - line = line[0:i] + line = line[:i] } f := getFields(line) if len(f) < 2 { continue } portnet := f[1] // "80/tcp" - port, j, ok := dtoi(portnet, 0) + port, j, ok := dtoi(portnet) if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' { continue } @@ -56,18 +49,5 @@ func readServices() { // goLookupPort is the native Go implementation of LookupPort. func goLookupPort(network, service string) (port int, err error) { onceReadServices.Do(readServices) - - switch network { - case "tcp4", "tcp6": - network = "tcp" - case "udp4", "udp6": - network = "udp" - } - - if m, ok := services[network]; ok { - if port, ok = m[service]; ok { - return - } - } - return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service} + return lookupPortMap(network, service) } diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go index 862fb1a..fce6a48 100644 --- a/libgo/go/net/rpc/client.go +++ b/libgo/go/net/rpc/client.go @@ -274,6 +274,8 @@ func Dial(network, address string) (*Client, error) { return NewClient(conn), nil } +// Close calls the underlying codec's Close method. If the connection is already +// shutting down, ErrShutdown is returned. func (client *Client) Close() error { client.mutex.Lock() if client.closing { diff --git a/libgo/go/net/rpc/client_test.go b/libgo/go/net/rpc/client_test.go index ba11ff8..d116d2a 100644 --- a/libgo/go/net/rpc/client_test.go +++ b/libgo/go/net/rpc/client_test.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net" - "runtime" "strings" "testing" ) @@ -53,9 +52,6 @@ func (s *S) Recv(nul *struct{}, reply *R) error { } func TestGobError(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/8908") - } defer func() { err := recover() if err == nil { diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go index cff3241..18ea629 100644 --- a/libgo/go/net/rpc/server.go +++ b/libgo/go/net/rpc/server.go @@ -23,7 +23,7 @@ func (t *T) MethodName(argType T1, replyType *T2) error - where T, T1 and T2 can be marshaled by encoding/gob. + where T1 and T2 can be marshaled by encoding/gob. These requirements apply even if a different codec is used. (In the future, these requirements may soften for custom codecs.) @@ -55,6 +55,8 @@ package server + import "errors" + type Args struct { A, B int } @@ -119,6 +121,8 @@ A server implementation will often provide a simple, type-safe wrapper for the client. + + The net/rpc package is frozen and is not accepting new features. */ package rpc diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go index d04271d..8369c9d 100644 --- a/libgo/go/net/rpc/server_test.go +++ b/libgo/go/net/rpc/server_test.go @@ -693,7 +693,8 @@ func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { B := call.Args.(*Args).B C := call.Reply.(*Reply).C if A+B != C { - b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C) + b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C) + return } <-gate if atomic.AddInt32(&recv, -1) == 0 { diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go index 9e04dd7..a408fa5 100644 --- a/libgo/go/net/smtp/smtp.go +++ b/libgo/go/net/smtp/smtp.go @@ -9,7 +9,7 @@ // STARTTLS RFC 3207 // Additional extensions may be handled by clients. // -// The smtp package is frozen and not accepting new features. +// The smtp package is frozen and is not accepting new features. // Some external packages provide more functionality. See: // // https://godoc.org/?q=smtp @@ -19,6 +19,7 @@ import ( "crypto/tls" "encoding/base64" "errors" + "fmt" "io" "net" "net/textproto" @@ -200,7 +201,7 @@ func (c *Client) Auth(a Auth) error { } resp64 := make([]byte, encoding.EncodedLen(len(resp))) encoding.Encode(resp64, resp) - code, msg64, err := c.cmd(0, "AUTH %s %s", mech, resp64) + code, msg64, err := c.cmd(0, strings.TrimSpace(fmt.Sprintf("AUTH %s %s", mech, resp64))) for err == nil { var msg []byte switch code { diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go index 3ae0d5b..c48fae6d 100644 --- a/libgo/go/net/smtp/smtp_test.go +++ b/libgo/go/net/smtp/smtp_test.go @@ -94,6 +94,46 @@ func TestAuthPlain(t *testing.T) { } } +// Issue 17794: don't send a trailing space on AUTH command when there's no password. +func TestClientAuthTrimSpace(t *testing.T) { + server := "220 hello world\r\n" + + "200 some more" + var wrote bytes.Buffer + var fake faker + fake.ReadWriter = struct { + io.Reader + io.Writer + }{ + strings.NewReader(server), + &wrote, + } + c, err := NewClient(fake, "fake.host") + if err != nil { + t.Fatalf("NewClient: %v", err) + } + c.tls = true + c.didHello = true + c.Auth(toServerEmptyAuth{}) + c.Close() + if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want { + t.Errorf("wrote %q; want %q", got, want) + } +} + +// toServerEmptyAuth is an implementation of Auth that only implements +// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns +// zero bytes for "toServer" so we can test that we don't send spaces at +// the end of the line. See TestClientAuthTrimSpace. +type toServerEmptyAuth struct{} + +func (toServerEmptyAuth) Start(server *ServerInfo) (proto string, toServer []byte, err error) { + return "FOOAUTH", nil, nil +} + +func (toServerEmptyAuth) Next(fromServer []byte, more bool) (toServer []byte, err error) { + panic("unexpected call") +} + type faker struct { io.ReadWriter } @@ -716,23 +756,24 @@ func sendMail(hostPort string) error { // generated from src/crypto/tls: // go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD -bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj -bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa -IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA -AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud -EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA -AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk -Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA== +MIIBjjCCATigAwIBAgIQMon9v0s3pDFXvAMnPgelpzANBgkqhkiG9w0BAQsFADAS +MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw +MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJB +AM0u/mNXKkhAzNsFkwKZPSpC4lZZaePQ55IyaJv3ovMM2smvthnlqaUfVKVmz7FF +wLP9csX6vGtvkZg1uWAtvfkCAwEAAaNoMGYwDgYDVR0PAQH/BAQDAgKkMBMGA1Ud +JQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhh +bXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZIhvcNAQELBQAD +QQBOZsFVC7IwX+qibmSbt2IPHkUgXhfbq0a9MYhD6tHcj4gbDcTXh4kZCbgHCz22 +gfSj2/G2wxzopoISVDucuncj -----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- -MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0 -0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV -NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d -AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW -MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD -EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA -1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE= +MIIBOwIBAAJBAM0u/mNXKkhAzNsFkwKZPSpC4lZZaePQ55IyaJv3ovMM2smvthnl +qaUfVKVmz7FFwLP9csX6vGtvkZg1uWAtvfkCAwEAAQJART2qkxODLUbQ2siSx7m2 +rmBLyR/7X+nLe8aPDrMOxj3heDNl4YlaAYLexbcY8d7VDfCRBKYoAOP0UCP1Vhuf +UQIhAO6PEI55K3SpNIdc2k5f0xz+9rodJCYzu51EwWX7r8ufAiEA3C9EkLiU2NuK +3L3DHCN5IlUSN1Nr/lw8NIt50Yorj2cCIQCDw1VbvCV6bDLtSSXzAA51B4ZzScE7 +sHtB5EYF9Dwm9QIhAJuCquuH4mDzVjUntXjXOQPdj7sRqVGCNWdrJwOukat7AiAy +LXLEwb77DIPoI5ZuaXQC+MnyyJj1ExC9RFcGz+bexA== -----END RSA PRIVATE KEY-----`) diff --git a/libgo/go/net/sock_linux.go b/libgo/go/net/sock_linux.go index e2732c5..7bca376 100644 --- a/libgo/go/net/sock_linux.go +++ b/libgo/go/net/sock_linux.go @@ -17,7 +17,7 @@ func maxListenerBacklog() int { return syscall.SOMAXCONN } f := getFields(l) - n, _, ok := dtoi(f[0], 0) + n, _, ok := dtoi(f[0]) if n == 0 || !ok { return syscall.SOMAXCONN } diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go index c3af27b..16351e1 100644 --- a/libgo/go/net/sock_posix.go +++ b/libgo/go/net/sock_posix.go @@ -30,6 +30,9 @@ type sockaddr interface { // interface. It returns a nil interface when the address is // nil. sockaddr(family int) (syscall.Sockaddr, error) + + // toLocal maps the zero address to a local system address (127.0.0.1 or ::1) + toLocal(net string) sockaddr } // socket returns a network file descriptor that is ready for diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go index 7cffcc5..69731eb 100644 --- a/libgo/go/net/tcpsock.go +++ b/libgo/go/net/tcpsock.go @@ -12,6 +12,9 @@ import ( "time" ) +// BUG(mikio): On Windows, the File method of TCPListener is not +// implemented. + // TCPAddr represents the address of a TCP end point. type TCPAddr struct { IP IP @@ -53,6 +56,9 @@ func (a *TCPAddr) opAddr() Addr { // "tcp6". A literal address or host name for IPv6 must be enclosed // in square brackets, as in "[::1]:80", "[ipv6-host]:http" or // "[ipv6-host%zone]:80". +// +// Resolving a hostname is not recommended because this returns at most +// one of its IP addresses. func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { switch net { case "tcp", "tcp4", "tcp6": @@ -61,7 +67,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) { default: return nil, UnknownNetworkError(net) } - addrs, err := internetAddrList(context.Background(), net, addr) + addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr) if err != nil { return nil, err } @@ -81,7 +87,7 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { } n, err := c.readFrom(r) if err != nil && err != io.EOF { - err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + err = &OpError{Op: "readfrom", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} } return n, err } diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go index c9a8b68..9641e5c 100644 --- a/libgo/go/net/tcpsock_posix.go +++ b/libgo/go/net/tcpsock_posix.go @@ -40,6 +40,10 @@ func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) { return ipToSockaddr(family, a.IP, a.Port, a.Zone) } +func (a *TCPAddr) toLocal(net string) sockaddr { + return &TCPAddr{loopbackIP(net), a.Port, a.Zone} +} + func (c *TCPConn) readFrom(r io.Reader) (int64, error) { if n, err, handled := sendFile(c.fd, r); handled { return n, err diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go index a8d93b0..5115422 100644 --- a/libgo/go/net/tcpsock_test.go +++ b/libgo/go/net/tcpsock_test.go @@ -5,6 +5,7 @@ package net import ( + "fmt" "internal/testenv" "io" "reflect" @@ -310,6 +311,17 @@ var resolveTCPAddrTests = []resolveTCPAddrTest{ {"tcp", ":12345", &TCPAddr{Port: 12345}, nil}, {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")}, + + {"tcp", "127.0.0.1:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil}, + {"tcp", "[::ffff:127.0.0.1]:http", &TCPAddr{IP: ParseIP("::ffff:127.0.0.1"), Port: 80}, nil}, + {"tcp", "[2001:db8::1]:http", &TCPAddr{IP: ParseIP("2001:db8::1"), Port: 80}, nil}, + {"tcp4", "127.0.0.1:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil}, + {"tcp4", "[::ffff:127.0.0.1]:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil}, + {"tcp6", "[2001:db8::1]:http", &TCPAddr{IP: ParseIP("2001:db8::1"), Port: 80}, nil}, + + {"tcp4", "[2001:db8::1]:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}}, + {"tcp6", "127.0.0.1:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}}, + {"tcp6", "[::ffff:127.0.0.1]:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}}, } func TestResolveTCPAddr(t *testing.T) { @@ -317,21 +329,17 @@ func TestResolveTCPAddr(t *testing.T) { defer func() { testHookLookupIP = origTestHookLookupIP }() testHookLookupIP = lookupLocalhost - for i, tt := range resolveTCPAddrTests { + for _, tt := range resolveTCPAddrTests { addr, err := ResolveTCPAddr(tt.network, tt.litAddrOrName) - if err != tt.err { - t.Errorf("#%d: %v", i, err) - } else if !reflect.DeepEqual(addr, tt.addr) { - t.Errorf("#%d: got %#v; want %#v", i, addr, tt.addr) - } - if err != nil { + if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) { + t.Errorf("ResolveTCPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err) continue } - rtaddr, err := ResolveTCPAddr(addr.Network(), addr.String()) - if err != nil { - t.Errorf("#%d: %v", i, err) - } else if !reflect.DeepEqual(rtaddr, addr) { - t.Errorf("#%d: got %#v; want %#v", i, rtaddr, addr) + if err == nil { + addr2, err := ResolveTCPAddr(addr.Network(), addr.String()) + if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err { + t.Errorf("(%q, %q): ResolveTCPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err) + } } } } @@ -459,12 +467,19 @@ func TestTCPConcurrentAccept(t *testing.T) { } func TestTCPReadWriteAllocs(t *testing.T) { + if runtime.Compiler == "gccgo" { + t.Skip("skipping for gccgo until escape analysis is enabled") + } + switch runtime.GOOS { - case "nacl", "windows": + case "plan9": + // The implementation of asynchronous cancelable + // I/O on Plan 9 allocates memory. + // See net/fd_io_plan9.go. + t.Skipf("not supported on %s", runtime.GOOS) + case "nacl": // NaCl needs to allocate pseudo file descriptor // stuff. See syscall/fd_nacl.go. - // Windows uses closures and channels for IO - // completion port-based netpoll. See fd_windows.go. t.Skipf("not supported on %s", runtime.GOOS) } @@ -474,7 +489,7 @@ func TestTCPReadWriteAllocs(t *testing.T) { } defer ln.Close() var server Conn - errc := make(chan error) + errc := make(chan error, 1) go func() { var err error server, err = ln.Accept() @@ -489,6 +504,7 @@ func TestTCPReadWriteAllocs(t *testing.T) { t.Fatal(err) } defer server.Close() + var buf [128]byte allocs := testing.AllocsPerRun(1000, func() { _, err := server.Write(buf[:]) @@ -504,6 +520,28 @@ func TestTCPReadWriteAllocs(t *testing.T) { if allocs > 7 { t.Fatalf("got %v; want 0", allocs) } + + var bufwrt [128]byte + ch := make(chan bool) + defer close(ch) + go func() { + for <-ch { + _, err := server.Write(bufwrt[:]) + errc <- err + } + }() + allocs = testing.AllocsPerRun(1000, func() { + ch <- true + if _, err = io.ReadFull(client, buf[:]); err != nil { + t.Fatal(err) + } + if err := <-errc; err != nil { + t.Fatal(err) + } + }) + if allocs > 0 { + t.Fatalf("got %v; want 0", allocs) + } } func TestTCPStress(t *testing.T) { @@ -634,3 +672,58 @@ func TestTCPSelfConnect(t *testing.T) { } } } + +// Test that >32-bit reads work on 64-bit systems. +// On 32-bit systems this tests that maxint reads work. +func TestTCPBig(t *testing.T) { + if !*testTCPBig { + t.Skip("test disabled; use -tcpbig to enable") + } + + for _, writev := range []bool{false, true} { + t.Run(fmt.Sprintf("writev=%v", writev), func(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + x := int(1 << 30) + x = x*5 + 1<<20 // just over 5 GB on 64-bit, just over 1GB on 32-bit + done := make(chan int) + go func() { + defer close(done) + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + buf := make([]byte, x) + var n int + if writev { + var n64 int64 + n64, err = (&Buffers{buf}).WriteTo(c) + n = int(n64) + } else { + n, err = c.Write(buf) + } + if n != len(buf) || err != nil { + t.Errorf("Write(buf) = %d, %v, want %d, nil", n, err, x) + } + c.Close() + }() + + c, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, x) + n, err := io.ReadFull(c, buf) + if n != len(buf) || err != nil { + t.Errorf("Read(buf) = %d, %v, want %d, nil", n, err, x) + } + c.Close() + <-done + }) + } +} diff --git a/libgo/go/net/tcpsock_unix_test.go b/libgo/go/net/tcpsock_unix_test.go index c07f7d7..2375fe2 100644 --- a/libgo/go/net/tcpsock_unix_test.go +++ b/libgo/go/net/tcpsock_unix_test.go @@ -15,7 +15,7 @@ import ( ) // See golang.org/issue/14548. -func TestTCPSupriousConnSetupCompletion(t *testing.T) { +func TestTCPSpuriousConnSetupCompletion(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } @@ -57,7 +57,7 @@ func TestTCPSupriousConnSetupCompletion(t *testing.T) { c, err := d.Dial(ln.Addr().Network(), ln.Addr().String()) if err != nil { if perr := parseDialError(err); perr != nil { - t.Errorf("#%d: %v", i, err) + t.Errorf("#%d: %v (original error: %v)", i, perr, err) } return } diff --git a/libgo/go/net/testdata/invalid-ndots-resolv.conf b/libgo/go/net/testdata/invalid-ndots-resolv.conf new file mode 100644 index 0000000..084c164 --- /dev/null +++ b/libgo/go/net/testdata/invalid-ndots-resolv.conf @@ -0,0 +1 @@ +options ndots:invalid
\ No newline at end of file diff --git a/libgo/go/net/testdata/large-ndots-resolv.conf b/libgo/go/net/testdata/large-ndots-resolv.conf new file mode 100644 index 0000000..72968ee --- /dev/null +++ b/libgo/go/net/testdata/large-ndots-resolv.conf @@ -0,0 +1 @@ +options ndots:16
\ No newline at end of file diff --git a/libgo/go/net/testdata/negative-ndots-resolv.conf b/libgo/go/net/testdata/negative-ndots-resolv.conf new file mode 100644 index 0000000..c11e0cc --- /dev/null +++ b/libgo/go/net/testdata/negative-ndots-resolv.conf @@ -0,0 +1 @@ +options ndots:-1
\ No newline at end of file diff --git a/libgo/go/net/textproto/header.go b/libgo/go/net/textproto/header.go index 2e2752a..ed096d9 100644 --- a/libgo/go/net/textproto/header.go +++ b/libgo/go/net/textproto/header.go @@ -23,8 +23,10 @@ func (h MIMEHeader) Set(key, value string) { } // Get gets the first value associated with the given key. +// It is case insensitive; CanonicalMIMEHeaderKey is used +// to canonicalize the provided key. // If there are no values associated with the key, Get returns "". -// Get is a convenience method. For more complex queries, +// To access multiple values of a key, or to use non-canonical keys, // access the map directly. func (h MIMEHeader) Get(key string) string { if h == nil { diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go index ed26f2a..55bbf44 100644 --- a/libgo/go/net/timeout_test.go +++ b/libgo/go/net/timeout_test.go @@ -5,7 +5,6 @@ package net import ( - "context" "fmt" "internal/testenv" "io" @@ -152,6 +151,7 @@ var acceptTimeoutTests = []struct { } func TestAcceptTimeout(t *testing.T) { + testenv.SkipFlaky(t, 17948) t.Parallel() switch runtime.GOOS { @@ -165,19 +165,18 @@ func TestAcceptTimeout(t *testing.T) { } defer ln.Close() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + var wg sync.WaitGroup for i, tt := range acceptTimeoutTests { if tt.timeout < 0 { + wg.Add(1) go func() { - var d Dialer - c, err := d.DialContext(ctx, ln.Addr().Network(), ln.Addr().String()) + defer wg.Done() + d := Dialer{Timeout: 100 * time.Millisecond} + c, err := d.Dial(ln.Addr().Network(), ln.Addr().String()) if err != nil { t.Error(err) return } - var b [1]byte - c.Read(b[:]) c.Close() }() } @@ -198,13 +197,14 @@ func TestAcceptTimeout(t *testing.T) { } if err == nil { c.Close() - time.Sleep(tt.timeout / 3) + time.Sleep(10 * time.Millisecond) continue } break } } } + wg.Wait() } func TestAcceptTimeoutMustReturn(t *testing.T) { @@ -305,11 +305,6 @@ var readTimeoutTests = []struct { } func TestReadTimeout(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("not supported on %s", runtime.GOOS) - } - handler := func(ls *localServer, ln Listener) { c, err := ln.Accept() if err != nil { @@ -435,7 +430,7 @@ var readFromTimeoutTests = []struct { func TestReadFromTimeout(t *testing.T) { switch runtime.GOOS { - case "nacl", "plan9": + case "nacl": t.Skipf("not supported on %s", runtime.GOOS) // see golang.org/issue/8916 } @@ -509,11 +504,6 @@ var writeTimeoutTests = []struct { func TestWriteTimeout(t *testing.T) { t.Parallel() - switch runtime.GOOS { - case "plan9": - t.Skipf("not supported on %s", runtime.GOOS) - } - ln, err := newLocalListener("tcp") if err != nil { t.Fatal(err) @@ -629,7 +619,7 @@ func TestWriteToTimeout(t *testing.T) { t.Parallel() switch runtime.GOOS { - case "nacl", "plan9": + case "nacl": t.Skipf("not supported on %s", runtime.GOOS) } @@ -681,11 +671,6 @@ func TestWriteToTimeout(t *testing.T) { func TestReadTimeoutFluctuation(t *testing.T) { t.Parallel() - switch runtime.GOOS { - case "plan9": - t.Skipf("not supported on %s", runtime.GOOS) - } - ln, err := newLocalListener("tcp") if err != nil { t.Fatal(err) @@ -719,11 +704,6 @@ func TestReadTimeoutFluctuation(t *testing.T) { func TestReadFromTimeoutFluctuation(t *testing.T) { t.Parallel() - switch runtime.GOOS { - case "plan9": - t.Skipf("not supported on %s", runtime.GOOS) - } - c1, err := newLocalPacketListener("udp") if err != nil { t.Fatal(err) @@ -829,11 +809,6 @@ func (b neverEnding) Read(p []byte) (int, error) { } func testVariousDeadlines(t *testing.T) { - switch runtime.GOOS { - case "plan9": - t.Skipf("not supported on %s", runtime.GOOS) - } - type result struct { n int64 err error @@ -1030,7 +1005,7 @@ func TestReadWriteDeadlineRace(t *testing.T) { t.Parallel() switch runtime.GOOS { - case "nacl", "plan9": + case "nacl": t.Skipf("not supported on %s", runtime.GOOS) } diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go index 980f67c..841ef53 100644 --- a/libgo/go/net/udpsock.go +++ b/libgo/go/net/udpsock.go @@ -9,6 +9,15 @@ import ( "syscall" ) +// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgUDP and +// WriteMsgUDP methods of UDPConn are not implemented. + +// BUG(mikio): On Windows, the File method of UDPConn is not +// implemented. + +// BUG(mikio): On NaCl, the ListenMulticastUDP function is not +// implemented. + // UDPAddr represents the address of a UDP end point. type UDPAddr struct { IP IP @@ -50,6 +59,9 @@ func (a *UDPAddr) opAddr() Addr { // "udp6". A literal address or host name for IPv6 must be enclosed // in square brackets, as in "[::1]:80", "[ipv6-host]:http" or // "[ipv6-host%zone]:80". +// +// Resolving a hostname is not recommended because this returns at most +// one of its IP addresses. func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { switch net { case "udp", "udp4", "udp6": @@ -58,7 +70,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) { default: return nil, UnknownNetworkError(net) } - addrs, err := internetAddrList(context.Background(), net, addr) + addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr) if err != nil { return nil, err } diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go index 666f206..1ce7f88 100644 --- a/libgo/go/net/udpsock_plan9.go +++ b/libgo/go/net/udpsock_plan9.go @@ -109,5 +109,41 @@ func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, e } func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) { - return nil, syscall.EPLAN9 + l, err := listenPlan9(ctx, network, gaddr) + if err != nil { + return nil, err + } + _, err = l.ctl.WriteString("headers") + if err != nil { + return nil, err + } + var addrs []Addr + if ifi != nil { + addrs, err = ifi.Addrs() + if err != nil { + return nil, err + } + } else { + addrs, err = InterfaceAddrs() + if err != nil { + return nil, err + } + } + for _, addr := range addrs { + if ipnet, ok := addr.(*IPNet); ok { + _, err = l.ctl.WriteString("addmulti " + ipnet.IP.String() + " " + gaddr.IP.String()) + if err != nil { + return nil, err + } + } + } + l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0) + if err != nil { + return nil, err + } + fd, err := l.netFD() + if err != nil { + return nil, err + } + return newUDPConn(fd), nil } diff --git a/libgo/go/net/udpsock_plan9_test.go b/libgo/go/net/udpsock_plan9_test.go new file mode 100644 index 0000000..09f5a5d --- /dev/null +++ b/libgo/go/net/udpsock_plan9_test.go @@ -0,0 +1,69 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "internal/testenv" + "runtime" + "testing" +) + +func TestListenMulticastUDP(t *testing.T) { + testenv.MustHaveExternalNetwork(t) + + ifcs, err := Interfaces() + if err != nil { + t.Skip(err.Error()) + } + if len(ifcs) == 0 { + t.Skip("no network interfaces found") + } + + var mifc *Interface + for _, ifc := range ifcs { + if ifc.Flags&FlagUp|FlagMulticast != FlagUp|FlagMulticast { + continue + } + mifc = &ifc + break + } + + if mifc == nil { + t.Skipf("no multicast interfaces found") + } + + c1, err := ListenMulticastUDP("udp4", mifc, &UDPAddr{IP: ParseIP("224.0.0.254")}) + if err != nil { + t.Fatalf("multicast not working on %s", runtime.GOOS) + } + c1addr := c1.LocalAddr().(*UDPAddr) + if err != nil { + t.Fatal(err) + } + defer c1.Close() + + c2, err := ListenUDP("udp4", &UDPAddr{IP: IPv4zero, Port: 0}) + c2addr := c2.LocalAddr().(*UDPAddr) + if err != nil { + t.Fatal(err) + } + defer c2.Close() + + n, err := c2.WriteToUDP([]byte("data"), c1addr) + if err != nil { + t.Fatal(err) + } + if n != 4 { + t.Fatalf("got %d; want 4", n) + } + + n, err = c1.WriteToUDP([]byte("data"), c2addr) + if err != nil { + t.Fatal(err) + } + if n != 4 { + t.Fatalf("got %d; want 4", n) + } +} diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go index 4924801..72aadca 100644 --- a/libgo/go/net/udpsock_posix.go +++ b/libgo/go/net/udpsock_posix.go @@ -38,6 +38,10 @@ func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) { return ipToSockaddr(family, a.IP, a.Port, a.Zone) } +func (a *UDPAddr) toLocal(net string) sockaddr { + return &UDPAddr{loopbackIP(net), a.Port, a.Zone} +} + func (c *UDPConn) readFrom(b []byte) (int, *UDPAddr, error) { var addr *UDPAddr n, sa, err := c.fd.readFrom(b) diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go index 29d769c..708cc10 100644 --- a/libgo/go/net/udpsock_test.go +++ b/libgo/go/net/udpsock_test.go @@ -72,6 +72,17 @@ var resolveUDPAddrTests = []resolveUDPAddrTest{ {"udp", ":12345", &UDPAddr{Port: 12345}, nil}, {"http", "127.0.0.1:0", nil, UnknownNetworkError("http")}, + + {"udp", "127.0.0.1:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil}, + {"udp", "[::ffff:127.0.0.1]:domain", &UDPAddr{IP: ParseIP("::ffff:127.0.0.1"), Port: 53}, nil}, + {"udp", "[2001:db8::1]:domain", &UDPAddr{IP: ParseIP("2001:db8::1"), Port: 53}, nil}, + {"udp4", "127.0.0.1:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil}, + {"udp4", "[::ffff:127.0.0.1]:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil}, + {"udp6", "[2001:db8::1]:domain", &UDPAddr{IP: ParseIP("2001:db8::1"), Port: 53}, nil}, + + {"udp4", "[2001:db8::1]:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}}, + {"udp6", "127.0.0.1:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}}, + {"udp6", "[::ffff:127.0.0.1]:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}}, } func TestResolveUDPAddr(t *testing.T) { @@ -79,21 +90,17 @@ func TestResolveUDPAddr(t *testing.T) { defer func() { testHookLookupIP = origTestHookLookupIP }() testHookLookupIP = lookupLocalhost - for i, tt := range resolveUDPAddrTests { + for _, tt := range resolveUDPAddrTests { addr, err := ResolveUDPAddr(tt.network, tt.litAddrOrName) - if err != tt.err { - t.Errorf("#%d: %v", i, err) - } else if !reflect.DeepEqual(addr, tt.addr) { - t.Errorf("#%d: got %#v; want %#v", i, addr, tt.addr) - } - if err != nil { + if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) { + t.Errorf("ResolveUDPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err) continue } - rtaddr, err := ResolveUDPAddr(addr.Network(), addr.String()) - if err != nil { - t.Errorf("#%d: %v", i, err) - } else if !reflect.DeepEqual(rtaddr, addr) { - t.Errorf("#%d: got %#v; want %#v", i, rtaddr, addr) + if err == nil { + addr2, err := ResolveUDPAddr(addr.Network(), addr.String()) + if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err { + t.Errorf("(%q, %q): ResolveUDPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err) + } } } } diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go index bacdaa4..b25d492 100644 --- a/libgo/go/net/unixsock.go +++ b/libgo/go/net/unixsock.go @@ -7,6 +7,7 @@ package net import ( "context" "os" + "sync" "syscall" "time" ) @@ -120,6 +121,9 @@ func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) { // the associated out-of-band data into oob. It returns the number of // bytes copied into b, the number of bytes copied into oob, the flags // that were set on the packet, and the source address of the packet. +// +// Note that if len(b) == 0 and len(oob) > 0, this function will still +// read (and discard) 1 byte from the connection. func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) { if !c.ok() { return 0, 0, 0, nil, syscall.EINVAL @@ -167,6 +171,9 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) { // WriteMsgUnix writes a packet to addr via c, copying the payload // from b and the associated out-of-band data from oob. It returns // the number of payload and out-of-band bytes written. +// +// Note that if len(b) == 0 and len(oob) > 0, this function will still +// write 1 byte to the connection. func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) { if !c.ok() { return 0, 0, syscall.EINVAL @@ -200,9 +207,10 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) { // typically use variables of type Listener instead of assuming Unix // domain sockets. type UnixListener struct { - fd *netFD - path string - unlink bool + fd *netFD + path string + unlink bool + unlinkOnce sync.Once } func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil } diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go index 5f0999c..a8f892e 100644 --- a/libgo/go/net/unixsock_posix.go +++ b/libgo/go/net/unixsock_posix.go @@ -94,6 +94,10 @@ func (a *UnixAddr) sockaddr(family int) (syscall.Sockaddr, error) { return &syscall.SockaddrUnix{Name: a.Name}, nil } +func (a *UnixAddr) toLocal(net string) sockaddr { + return a +} + func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) { var addr *UnixAddr n, sa, err := c.fd.readFrom(b) @@ -173,9 +177,12 @@ func (ln *UnixListener) close() error { // is at least compatible with the auto-remove // sequence in ListenUnix. It's only non-Go // programs that can mess us up. - if ln.path[0] != '@' && ln.unlink { - syscall.Unlink(ln.path) - } + // Even if there are racy calls to Close, we want to unlink only for the first one. + ln.unlinkOnce.Do(func() { + if ln.path[0] != '@' && ln.unlink { + syscall.Unlink(ln.path) + } + }) return ln.fd.Close() } @@ -187,6 +194,18 @@ func (ln *UnixListener) file() (*os.File, error) { return f, nil } +// SetUnlinkOnClose sets whether the underlying socket file should be removed +// from the file system when the listener is closed. +// +// The default behavior is to unlink the socket file only when package net created it. +// That is, when the listener and the underlying socket file were created by a call to +// Listen or ListenUnix, then by default closing the listener will remove the socket file. +// but if the listener was created by a call to FileListener to use an already existing +// socket file, then by default closing the listener will not remove the socket file. +func (l *UnixListener) SetUnlinkOnClose(unlink bool) { + l.unlink = unlink +} + func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) { fd, err := unixSocket(ctx, network, laddr, nil, "listen") if err != nil { diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go index f0f88ed..489a29b 100644 --- a/libgo/go/net/unixsock_test.go +++ b/libgo/go/net/unixsock_test.go @@ -9,6 +9,7 @@ package net import ( "bytes" "internal/testenv" + "io/ioutil" "os" "reflect" "runtime" @@ -414,33 +415,104 @@ func TestUnixUnlink(t *testing.T) { t.Skip("unix test") } name := testUnixAddr() - l, err := Listen("unix", name) - if err != nil { - t.Fatal(err) - } - if _, err := os.Stat(name); err != nil { - t.Fatalf("cannot stat unix socket after ListenUnix: %v", err) - } - f, _ := l.(*UnixListener).File() - l1, err := FileListener(f) - if err != nil { - t.Fatal(err) - } - if _, err := os.Stat(name); err != nil { - t.Fatalf("cannot stat unix socket after FileListener: %v", err) - } - if err := l1.Close(); err != nil { - t.Fatalf("closing file listener: %v", err) - } - if _, err := os.Stat(name); err != nil { - t.Fatalf("cannot stat unix socket after closing FileListener: %v", err) + + listen := func(t *testing.T) *UnixListener { + l, err := Listen("unix", name) + if err != nil { + t.Fatal(err) + } + return l.(*UnixListener) } - f.Close() - if _, err := os.Stat(name); err != nil { - t.Fatalf("cannot stat unix socket after closing FileListener and fd: %v", err) + checkExists := func(t *testing.T, desc string) { + if _, err := os.Stat(name); err != nil { + t.Fatalf("unix socket does not exist %s: %v", desc, err) + } } - l.Close() - if _, err := os.Stat(name); err == nil { - t.Fatal("closing unix listener did not remove unix socket") + checkNotExists := func(t *testing.T, desc string) { + if _, err := os.Stat(name); err == nil { + t.Fatalf("unix socket does exist %s: %v", desc, err) + } } + + // Listener should remove on close. + t.Run("Listen", func(t *testing.T) { + l := listen(t) + checkExists(t, "after Listen") + l.Close() + checkNotExists(t, "after Listener close") + }) + + // FileListener should not. + t.Run("FileListener", func(t *testing.T) { + l := listen(t) + f, _ := l.File() + l1, _ := FileListener(f) + checkExists(t, "after FileListener") + f.Close() + checkExists(t, "after File close") + l1.Close() + checkExists(t, "after FileListener close") + l.Close() + checkNotExists(t, "after Listener close") + }) + + // Only first call to l.Close should remove. + t.Run("SecondClose", func(t *testing.T) { + l := listen(t) + checkExists(t, "after Listen") + l.Close() + checkNotExists(t, "after Listener close") + if err := ioutil.WriteFile(name, []byte("hello world"), 0666); err != nil { + t.Fatalf("cannot recreate socket file: %v", err) + } + checkExists(t, "after writing temp file") + l.Close() + checkExists(t, "after second Listener close") + os.Remove(name) + }) + + // SetUnlinkOnClose should do what it says. + + t.Run("Listen/SetUnlinkOnClose(true)", func(t *testing.T) { + l := listen(t) + checkExists(t, "after Listen") + l.SetUnlinkOnClose(true) + l.Close() + checkNotExists(t, "after Listener close") + }) + + t.Run("Listen/SetUnlinkOnClose(false)", func(t *testing.T) { + l := listen(t) + checkExists(t, "after Listen") + l.SetUnlinkOnClose(false) + l.Close() + checkExists(t, "after Listener close") + os.Remove(name) + }) + + t.Run("FileListener/SetUnlinkOnClose(true)", func(t *testing.T) { + l := listen(t) + f, _ := l.File() + l1, _ := FileListener(f) + checkExists(t, "after FileListener") + l1.(*UnixListener).SetUnlinkOnClose(true) + f.Close() + checkExists(t, "after File close") + l1.Close() + checkNotExists(t, "after FileListener close") + l.Close() + }) + + t.Run("FileListener/SetUnlinkOnClose(false)", func(t *testing.T) { + l := listen(t) + f, _ := l.File() + l1, _ := FileListener(f) + checkExists(t, "after FileListener") + l1.(*UnixListener).SetUnlinkOnClose(false) + f.Close() + checkExists(t, "after File close") + l1.Close() + checkExists(t, "after FileListener close") + l.Close() + }) } diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go index 30e9277..42a514b 100644 --- a/libgo/go/net/url/url.go +++ b/libgo/go/net/url/url.go @@ -74,6 +74,7 @@ type encoding int const ( encodePath encoding = 1 + iota + encodePathSegment encodeHost encodeZone encodeUserPassword @@ -132,9 +133,14 @@ func shouldEscape(c byte, mode encoding) bool { // The RFC allows : @ & = + $ but saves / ; , for assigning // meaning to individual path segments. This package // only manipulates the path as a whole, so we allow those - // last two as well. That leaves only ? to escape. + // last three as well. That leaves only ? to escape. return c == '?' + case encodePathSegment: // §3.3 + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. + return c == '/' || c == ';' || c == ',' || c == '?' + case encodeUserPassword: // §3.2.1 // The RFC allows ';', ':', '&', '=', '+', '$', and ',' in // userinfo, so we must escape only '@', '/', and '?'. @@ -164,6 +170,15 @@ func QueryUnescape(s string) (string, error) { return unescape(s, encodeQueryComponent) } +// PathUnescape does the inverse transformation of PathEscape, converting +// %AB into the byte 0xAB. It returns an error if any % is not followed by +// two hexadecimal digits. +// +// PathUnescape is identical to QueryUnescape except that it does not unescape '+' to ' ' (space). +func PathUnescape(s string) (string, error) { + return unescape(s, encodePathSegment) +} + // unescape unescapes a string; the mode specifies // which section of the URL string is being unescaped. func unescape(s string, mode encoding) (string, error) { @@ -250,6 +265,12 @@ func QueryEscape(s string) string { return escape(s, encodeQueryComponent) } +// PathEscape escapes the string so it can be safely placed +// inside a URL path segment. +func PathEscape(s string) string { + return escape(s, encodePathSegment) +} + func escape(s string, mode encoding) string { spaceCount, hexCount := 0, 0 for i := 0; i < len(s); i++ { @@ -356,10 +377,7 @@ func (u *Userinfo) Username() string { // Password returns the password in case it is set, and whether it is set. func (u *Userinfo) Password() (string, bool) { - if u.passwordSet { - return u.password, true - } - return "", false + return u.password, u.passwordSet } // String returns the encoded userinfo information in the standard form @@ -420,7 +438,7 @@ func Parse(rawurl string) (*URL, error) { u, frag := split(rawurl, "#", true) url, err := parse(u, false) if err != nil { - return nil, err + return nil, &Error{"parse", u, err} } if frag == "" { return url, nil @@ -437,31 +455,35 @@ func Parse(rawurl string) (*URL, error) { // The string rawurl is assumed not to have a #fragment suffix. // (Web browsers strip #fragment before sending the URL to a web server.) func ParseRequestURI(rawurl string) (*URL, error) { - return parse(rawurl, true) + url, err := parse(rawurl, true) + if err != nil { + return nil, &Error{"parse", rawurl, err} + } + return url, nil } // parse parses a URL from a string in one of two contexts. If // viaRequest is true, the URL is assumed to have arrived via an HTTP request, // in which case only absolute URLs or path-absolute relative URLs are allowed. // If viaRequest is false, all forms of relative URLs are allowed. -func parse(rawurl string, viaRequest bool) (url *URL, err error) { +func parse(rawurl string, viaRequest bool) (*URL, error) { var rest string + var err error if rawurl == "" && viaRequest { - err = errors.New("empty url") - goto Error + return nil, errors.New("empty url") } - url = new(URL) + url := new(URL) if rawurl == "*" { url.Path = "*" - return + return url, nil } // Split off possible leading "http:", "mailto:", etc. // Cannot contain escaped characters. if url.Scheme, rest, err = getscheme(rawurl); err != nil { - goto Error + return nil, err } url.Scheme = strings.ToLower(url.Scheme) @@ -479,8 +501,20 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { return url, nil } if viaRequest { - err = errors.New("invalid URI for request") - goto Error + return nil, errors.New("invalid URI for request") + } + + // Avoid confusion with malformed schemes, like cache_object:foo/bar. + // See golang.org/issue/16822. + // + // RFC 3986, §3.3: + // In addition, a URI reference (Section 4.1) may be a relative-path reference, + // in which case the first path segment cannot contain a colon (":") character. + colon := strings.Index(rest, ":") + slash := strings.Index(rest, "/") + if colon >= 0 && (slash < 0 || colon < slash) { + // First path segment has colon. Not allowed in relative URL. + return nil, errors.New("first path segment in URL cannot contain colon") } } @@ -489,23 +523,17 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) { authority, rest = split(rest[2:], "/", false) url.User, url.Host, err = parseAuthority(authority) if err != nil { - goto Error + return nil, err } } - if url.Path, err = unescape(rest, encodePath); err != nil { - goto Error - } - // RawPath is a hint as to the encoding of Path to use - // in url.EscapedPath. If that method already gets the - // right answer without RawPath, leave it empty. - // This will help make sure that people don't rely on it in general. - if url.EscapedPath() != rest && validEncodedPath(rest) { - url.RawPath = rest + // Set Path and, optionally, RawPath. + // RawPath is a hint of the encoding of Path. We don't want to set it if + // the default escaping of Path is equivalent, to help make sure that people + // don't rely on it in general. + if err := url.setPath(rest); err != nil { + return nil, err } return url, nil - -Error: - return nil, &Error{"parse", rawurl, err} } func parseAuthority(authority string) (user *Userinfo, host string, err error) { @@ -586,6 +614,29 @@ func parseHost(host string) (string, error) { return host, nil } +// setPath sets the Path and RawPath fields of the URL based on the provided +// escaped path p. It maintains the invariant that RawPath is only specified +// when it differs from the default encoding of the path. +// For example: +// - setPath("/foo/bar") will set Path="/foo/bar" and RawPath="" +// - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar" +// setPath will return an error only if the provided path contains an invalid +// escaping. +func (u *URL) setPath(p string) error { + path, err := unescape(p, encodePath) + if err != nil { + return err + } + u.Path = path + if escp := escape(path, encodePath); p == escp { + // Default encoding is fine. + u.RawPath = "" + } else { + u.RawPath = p + } + return nil +} + // EscapedPath returns the escaped form of u.Path. // In general there are multiple possible escaped forms of any path. // EscapedPath returns u.RawPath when it is a valid escaping of u.Path. @@ -693,6 +744,17 @@ func (u *URL) String() string { if path != "" && path[0] != '/' && u.Host != "" { buf.WriteByte('/') } + if buf.Len() == 0 { + // RFC 3986 §4.2 + // A path segment that contains a colon character (e.g., "this:that") + // cannot be used as the first segment of a relative-path reference, as + // it would be mistaken for a scheme name. Such a segment must be + // preceded by a dot-segment (e.g., "./this:that") to make a relative- + // path reference. + if i := strings.IndexByte(path, ':'); i > -1 && strings.IndexByte(path[:i], '/') == -1 { + buf.WriteString("./") + } + } buf.WriteString(path) } if u.ForceQuery || u.RawQuery != "" { @@ -749,6 +811,10 @@ func (v Values) Del(key string) { // ParseQuery always returns a non-nil map containing all the // valid query parameters found; err describes the first decoding error // encountered, if any. +// +// Query is expected to be a list of key=value settings separated by +// ampersands or semicolons. A setting without an equals sign is +// interpreted as a key set to an empty value. func ParseQuery(query string) (Values, error) { m := make(Values) err := parseQuery(m, query) @@ -852,6 +918,7 @@ func resolvePath(base, ref string) string { } // IsAbs reports whether the URL is absolute. +// Absolute means that it has a non-empty scheme. func (u *URL) IsAbs() bool { return u.Scheme != "" } @@ -880,7 +947,9 @@ func (u *URL) ResolveReference(ref *URL) *URL { } if ref.Scheme != "" || ref.Host != "" || ref.User != nil { // The "absoluteURI" or "net_path" cases. - url.Path = resolvePath(ref.Path, "") + // We can ignore the error from setPath since we know we provided a + // validly-escaped path. + url.setPath(resolvePath(ref.EscapedPath(), "")) return &url } if ref.Opaque != "" { @@ -900,7 +969,7 @@ func (u *URL) ResolveReference(ref *URL) *URL { // The "abs_path" or "rel_path" cases. url.Host = u.Host url.User = u.User - url.Path = resolvePath(u.Path, ref.Path) + url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath())) return &url } @@ -929,3 +998,59 @@ func (u *URL) RequestURI() string { } return result } + +// Hostname returns u.Host, without any port number. +// +// If Host is an IPv6 literal with a port number, Hostname returns the +// IPv6 literal without the square brackets. IPv6 literals may include +// a zone identifier. +func (u *URL) Hostname() string { + return stripPort(u.Host) +} + +// Port returns the port part of u.Host, without the leading colon. +// If u.Host doesn't contain a port, Port returns an empty string. +func (u *URL) Port() string { + return portOnly(u.Host) +} + +func stripPort(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 { + return hostport + } + if i := strings.IndexByte(hostport, ']'); i != -1 { + return strings.TrimPrefix(hostport[:i], "[") + } + return hostport[:colon] +} + +func portOnly(hostport string) string { + colon := strings.IndexByte(hostport, ':') + if colon == -1 { + return "" + } + if i := strings.Index(hostport, "]:"); i != -1 { + return hostport[i+len("]:"):] + } + if strings.Contains(hostport, "]") { + return "" + } + return hostport[colon+len(":"):] +} + +// Marshaling interface implementations. +// Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs. + +func (u *URL) MarshalBinary() (text []byte, err error) { + return []byte(u.String()), nil +} + +func (u *URL) UnmarshalBinary(text []byte) error { + u1, err := Parse(string(text)) + if err != nil { + return err + } + *u = *u1 + return nil +} diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go index 7560f22..6c3bb21 100644 --- a/libgo/go/net/url/url_test.go +++ b/libgo/go/net/url/url_test.go @@ -5,6 +5,10 @@ package url import ( + "bytes" + encodingPkg "encoding" + "encoding/gob" + "encoding/json" "fmt" "io" "net" @@ -579,20 +583,6 @@ func ufmt(u *URL) string { u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawPath, u.RawQuery, u.Fragment, u.ForceQuery) } -func DoTest(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) { - for _, tt := range tests { - u, err := parse(tt.in) - if err != nil { - t.Errorf("%s(%q) returned error %s", name, tt.in, err) - continue - } - if !reflect.DeepEqual(u, tt.out) { - t.Errorf("%s(%q):\n\thave %v\n\twant %v\n", - name, tt.in, ufmt(u), ufmt(tt.out)) - } - } -} - func BenchmarkString(b *testing.B) { b.StopTimer() b.ReportAllocs() @@ -618,7 +608,16 @@ func BenchmarkString(b *testing.B) { } func TestParse(t *testing.T) { - DoTest(t, Parse, "Parse", urltests) + for _, tt := range urltests { + u, err := Parse(tt.in) + if err != nil { + t.Errorf("Parse(%q) returned error %v", tt.in, err) + continue + } + if !reflect.DeepEqual(u, tt.out) { + t.Errorf("Parse(%q):\n\tgot %v\n\twant %v\n", tt.in, ufmt(u), ufmt(tt.out)) + } + } } const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path" @@ -665,9 +664,10 @@ var parseRequestURLTests = []struct { func TestParseRequestURI(t *testing.T) { for _, test := range parseRequestURLTests { _, err := ParseRequestURI(test.url) - valid := err == nil - if valid != test.expectedValid { - t.Errorf("Expected valid=%v for %q; got %v", test.expectedValid, test.url, valid) + if test.expectedValid && err != nil { + t.Errorf("ParseRequestURI(%q) gave err %v; want no error", test.url, err) + } else if !test.expectedValid && err == nil { + t.Errorf("ParseRequestURI(%q) gave nil error; want some error", test.url) } } @@ -676,45 +676,69 @@ func TestParseRequestURI(t *testing.T) { t.Fatalf("Unexpected error %v", err) } if url.Path != pathThatLooksSchemeRelative { - t.Errorf("Expected path %q; got %q", pathThatLooksSchemeRelative, url.Path) + t.Errorf("ParseRequestURI path:\ngot %q\nwant %q", url.Path, pathThatLooksSchemeRelative) } } -func DoTestString(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) { - for _, tt := range tests { - u, err := parse(tt.in) +var stringURLTests = []struct { + url URL + want string +}{ + // No leading slash on path should prepend slash on String() call + { + url: URL{ + Scheme: "http", + Host: "www.google.com", + Path: "search", + }, + want: "http://www.google.com/search", + }, + // Relative path with first element containing ":" should be prepended with "./", golang.org/issue/17184 + { + url: URL{ + Path: "this:that", + }, + want: "./this:that", + }, + // Relative path with second element containing ":" should not be prepended with "./" + { + url: URL{ + Path: "here/this:that", + }, + want: "here/this:that", + }, + // Non-relative path with first element containing ":" should not be prepended with "./" + { + url: URL{ + Scheme: "http", + Host: "www.google.com", + Path: "this:that", + }, + want: "http://www.google.com/this:that", + }, +} + +func TestURLString(t *testing.T) { + for _, tt := range urltests { + u, err := Parse(tt.in) if err != nil { - t.Errorf("%s(%q) returned error %s", name, tt.in, err) + t.Errorf("Parse(%q) returned error %s", tt.in, err) continue } expected := tt.in - if len(tt.roundtrip) > 0 { + if tt.roundtrip != "" { expected = tt.roundtrip } s := u.String() if s != expected { - t.Errorf("%s(%q).String() == %q (expected %q)", name, tt.in, s, expected) + t.Errorf("Parse(%q).String() == %q (expected %q)", tt.in, s, expected) } } -} -func TestURLString(t *testing.T) { - DoTestString(t, Parse, "Parse", urltests) - - // no leading slash on path should prepend - // slash on String() call - noslash := URLTest{ - "http://www.google.com/search", - &URL{ - Scheme: "http", - Host: "www.google.com", - Path: "search", - }, - "", - } - s := noslash.out.String() - if s != noslash.in { - t.Errorf("Expected %s; go %s", noslash.in, s) + for _, tt := range stringURLTests { + if got := tt.url.String(); got != tt.want { + t.Errorf("%+v.String() = %q; want %q", tt.url, got, tt.want) + } } } @@ -780,6 +804,16 @@ var unescapeTests = []EscapeTest{ "", EscapeError("%zz"), }, + { + "a+b", + "a b", + nil, + }, + { + "a%20b", + "a b", + nil, + }, } func TestUnescape(t *testing.T) { @@ -788,10 +822,33 @@ func TestUnescape(t *testing.T) { if actual != tt.out || (err != nil) != (tt.err != nil) { t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", tt.in, actual, err, tt.out, tt.err) } + + in := tt.in + out := tt.out + if strings.Contains(tt.in, "+") { + in = strings.Replace(tt.in, "+", "%20", -1) + actual, err := PathUnescape(in) + if actual != tt.out || (err != nil) != (tt.err != nil) { + t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", in, actual, err, tt.out, tt.err) + } + if tt.err == nil { + s, err := QueryUnescape(strings.Replace(tt.in, "+", "XXX", -1)) + if err != nil { + continue + } + in = tt.in + out = strings.Replace(s, "XXX", "+", -1) + } + } + + actual, err = PathUnescape(in) + if actual != out || (err != nil) != (tt.err != nil) { + t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", in, actual, err, out, tt.err) + } } } -var escapeTests = []EscapeTest{ +var queryEscapeTests = []EscapeTest{ { "", "", @@ -819,8 +876,8 @@ var escapeTests = []EscapeTest{ }, } -func TestEscape(t *testing.T) { - for _, tt := range escapeTests { +func TestQueryEscape(t *testing.T) { + for _, tt := range queryEscapeTests { actual := QueryEscape(tt.in) if tt.out != actual { t.Errorf("QueryEscape(%q) = %q, want %q", tt.in, actual, tt.out) @@ -834,6 +891,54 @@ func TestEscape(t *testing.T) { } } +var pathEscapeTests = []EscapeTest{ + { + "", + "", + nil, + }, + { + "abc", + "abc", + nil, + }, + { + "abc+def", + "abc+def", + nil, + }, + { + "one two", + "one%20two", + nil, + }, + { + "10%", + "10%25", + nil, + }, + { + " ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;", + "%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B", + nil, + }, +} + +func TestPathEscape(t *testing.T) { + for _, tt := range pathEscapeTests { + actual := PathEscape(tt.in) + if tt.out != actual { + t.Errorf("PathEscape(%q) = %q, want %q", tt.in, actual, tt.out) + } + + // for bonus points, verify that escape:unescape is an identity. + roundtrip, err := PathUnescape(actual) + if roundtrip != tt.in || err != nil { + t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]") + } + } +} + //var userinfoTests = []UserinfoTest{ // {"user", "password", "user:password"}, // {"foo:bar", "~!@#$%^&*()_+{}|[]\\-=`:;'\"<>?,./", @@ -945,6 +1050,15 @@ var resolveReferenceTests = []struct { // Fragment {"http://foo.com/bar", ".#frag", "http://foo.com/#frag"}, + // Paths with escaping (issue 16947). + {"http://foo.com/foo%2fbar/", "../baz", "http://foo.com/baz"}, + {"http://foo.com/1/2%2f/3%2f4/5", "../../a/b/c", "http://foo.com/1/a/b/c"}, + {"http://foo.com/1/2/3", "./a%2f../../b/..%2fc", "http://foo.com/1/2/b/..%2fc"}, + {"http://foo.com/1/2%2f/3%2f4/5", "./a%2f../b/../c", "http://foo.com/1/2%2f/3%2f4/a%2f../c"}, + {"http://foo.com/foo%20bar/", "../baz", "http://foo.com/baz"}, + {"http://foo.com/foo", "../bar%2fbaz", "http://foo.com/bar%2fbaz"}, + {"http://foo.com/foo%2dbar/", "./baz-quux", "http://foo.com/foo%2dbar/baz-quux"}, + // RFC 3986: Normal Examples // http://tools.ietf.org/html/rfc3986#section-5.4.1 {"http://a/b/c/d;p?q", "g:h", "g:h"}, @@ -1004,7 +1118,7 @@ func TestResolveReference(t *testing.T) { mustParse := func(url string) *URL { u, err := Parse(url) if err != nil { - t.Fatalf("Expected URL to parse: %q, got error: %v", url, err) + t.Fatalf("Parse(%q) got err %v", url, err) } return u } @@ -1013,8 +1127,8 @@ func TestResolveReference(t *testing.T) { base := mustParse(test.base) rel := mustParse(test.rel) url := base.ResolveReference(rel) - if url.String() != test.expected { - t.Errorf("URL(%q).ResolveReference(%q) == %q, got %q", test.base, test.rel, test.expected, url.String()) + if got := url.String(); got != test.expected { + t.Errorf("URL(%q).ResolveReference(%q)\ngot %q\nwant %q", test.base, test.rel, got, test.expected) } // Ensure that new instances are returned. if base == url { @@ -1024,8 +1138,8 @@ func TestResolveReference(t *testing.T) { url, err := base.Parse(test.rel) if err != nil { t.Errorf("URL(%q).Parse(%q) failed: %v", test.base, test.rel, err) - } else if url.String() != test.expected { - t.Errorf("URL(%q).Parse(%q) == %q, got %q", test.base, test.rel, test.expected, url.String()) + } else if got := url.String(); got != test.expected { + t.Errorf("URL(%q).Parse(%q)\ngot %q\nwant %q", test.base, test.rel, got, test.expected) } else if base == url { // Ensure that new instances are returned for the wrapper too. t.Errorf("Expected URL.Parse to return new URL instance.") @@ -1033,14 +1147,14 @@ func TestResolveReference(t *testing.T) { // Ensure Opaque resets the URL. url = base.ResolveReference(opaque) if *url != *opaque { - t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", url, opaque) + t.Errorf("ResolveReference failed to resolve opaque URL:\ngot %#v\nwant %#v", url, opaque) } // Test the convenience wrapper with an opaque URL too. url, err = base.Parse("scheme:opaque") if err != nil { t.Errorf(`URL(%q).Parse("scheme:opaque") failed: %v`, test.base, err) } else if *url != *opaque { - t.Errorf("Parse failed to resolve opaque URL: want %#v, got %#v", url, opaque) + t.Errorf("Parse failed to resolve opaque URL:\ngot %#v\nwant %#v", opaque, url) } else if base == url { // Ensure that new instances are returned, again. t.Errorf("Expected URL.Parse to return new URL instance.") @@ -1271,7 +1385,7 @@ func TestParseFailure(t *testing.T) { } } -func TestParseAuthority(t *testing.T) { +func TestParseErrors(t *testing.T) { tests := []struct { in string wantErr bool @@ -1291,9 +1405,13 @@ func TestParseAuthority(t *testing.T) { {"http://%41:8080/", true}, // not allowed: % encoding only for non-ASCII {"mysql://x@y(z:123)/foo", false}, // golang.org/issue/12023 {"mysql://x@y(1.2.3.4:123)/foo", false}, - {"mysql://x@y([2001:db8::1]:123)/foo", false}, + {"http://[]%20%48%54%54%50%2f%31%2e%31%0a%4d%79%48%65%61%64%65%72%3a%20%31%32%33%0a%0a/", true}, // golang.org/issue/11208 {"http://a b.com/", true}, // no space in host name please + {"cache_object://foo", true}, // scheme cannot have _, relative path cannot have : in first segment + {"cache_object:foo", true}, + {"cache_object:foo/bar", true}, + {"cache_object/:foo/bar", false}, } for _, tt := range tests { u, err := Parse(tt.in) @@ -1462,11 +1580,106 @@ func TestURLErrorImplementsNetError(t *testing.T) { continue } if err.Timeout() != tt.timeout { - t.Errorf("%d: err.Timeout(): want %v, have %v", i+1, tt.timeout, err.Timeout()) + t.Errorf("%d: err.Timeout(): got %v, want %v", i+1, err.Timeout(), tt.timeout) continue } if err.Temporary() != tt.temporary { - t.Errorf("%d: err.Temporary(): want %v, have %v", i+1, tt.temporary, err.Temporary()) + t.Errorf("%d: err.Temporary(): got %v, want %v", i+1, err.Temporary(), tt.temporary) + } + } +} + +func TestURLHostname(t *testing.T) { + tests := []struct { + host string // URL.Host field + want string + }{ + {"foo.com:80", "foo.com"}, + {"foo.com", "foo.com"}, + {"FOO.COM", "FOO.COM"}, // no canonicalization (yet?) + {"1.2.3.4", "1.2.3.4"}, + {"1.2.3.4:80", "1.2.3.4"}, + {"[1:2:3:4]", "1:2:3:4"}, + {"[1:2:3:4]:80", "1:2:3:4"}, + {"[::1]:80", "::1"}, + } + for _, tt := range tests { + u := &URL{Host: tt.host} + got := u.Hostname() + if got != tt.want { + t.Errorf("Hostname for Host %q = %q; want %q", tt.host, got, tt.want) + } + } +} + +func TestURLPort(t *testing.T) { + tests := []struct { + host string // URL.Host field + want string + }{ + {"foo.com", ""}, + {"foo.com:80", "80"}, + {"1.2.3.4", ""}, + {"1.2.3.4:80", "80"}, + {"[1:2:3:4]", ""}, + {"[1:2:3:4]:80", "80"}, + } + for _, tt := range tests { + u := &URL{Host: tt.host} + got := u.Port() + if got != tt.want { + t.Errorf("Port for Host %q = %q; want %q", tt.host, got, tt.want) } } } + +var _ encodingPkg.BinaryMarshaler = (*URL)(nil) +var _ encodingPkg.BinaryUnmarshaler = (*URL)(nil) + +func TestJSON(t *testing.T) { + u, err := Parse("https://www.google.com/x?y=z") + if err != nil { + t.Fatal(err) + } + js, err := json.Marshal(u) + if err != nil { + t.Fatal(err) + } + + // If only we could implement TextMarshaler/TextUnmarshaler, + // this would work: + // + // if string(js) != strconv.Quote(u.String()) { + // t.Errorf("json encoding: %s\nwant: %s\n", js, strconv.Quote(u.String())) + // } + + u1 := new(URL) + err = json.Unmarshal(js, u1) + if err != nil { + t.Fatal(err) + } + if u1.String() != u.String() { + t.Errorf("json decoded to: %s\nwant: %s\n", u1, u) + } +} + +func TestGob(t *testing.T) { + u, err := Parse("https://www.google.com/x?y=z") + if err != nil { + t.Fatal(err) + } + var w bytes.Buffer + err = gob.NewEncoder(&w).Encode(u) + if err != nil { + t.Fatal(err) + } + + u1 := new(URL) + err = gob.NewDecoder(&w).Decode(u1) + if err != nil { + t.Fatal(err) + } + if u1.String() != u.String() { + t.Errorf("json decoded to: %s\nwant: %s\n", u1, u) + } +} diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go new file mode 100644 index 0000000..7160d28 --- /dev/null +++ b/libgo/go/net/writev_test.go @@ -0,0 +1,225 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "reflect" + "runtime" + "sync" + "testing" +) + +func TestBuffers_read(t *testing.T) { + const story = "once upon a time in Gopherland ... " + buffers := Buffers{ + []byte("once "), + []byte("upon "), + []byte("a "), + []byte("time "), + []byte("in "), + []byte("Gopherland ... "), + } + got, err := ioutil.ReadAll(&buffers) + if err != nil { + t.Fatal(err) + } + if string(got) != story { + t.Errorf("read %q; want %q", got, story) + } + if len(buffers) != 0 { + t.Errorf("len(buffers) = %d; want 0", len(buffers)) + } +} + +func TestBuffers_consume(t *testing.T) { + tests := []struct { + in Buffers + consume int64 + want Buffers + }{ + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 0, + want: Buffers{[]byte("foo"), []byte("bar")}, + }, + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 2, + want: Buffers{[]byte("o"), []byte("bar")}, + }, + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 3, + want: Buffers{[]byte("bar")}, + }, + { + in: Buffers{[]byte("foo"), []byte("bar")}, + consume: 4, + want: Buffers{[]byte("ar")}, + }, + { + in: Buffers{nil, nil, nil, []byte("bar")}, + consume: 1, + want: Buffers{[]byte("ar")}, + }, + { + in: Buffers{nil, nil, nil, []byte("foo")}, + consume: 0, + want: Buffers{[]byte("foo")}, + }, + { + in: Buffers{nil, nil, nil}, + consume: 0, + want: Buffers{}, + }, + } + for i, tt := range tests { + in := tt.in + in.consume(tt.consume) + if !reflect.DeepEqual(in, tt.want) { + t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want) + } + } +} + +func TestBuffers_WriteTo(t *testing.T) { + for _, name := range []string{"WriteTo", "Copy"} { + for _, size := range []int{0, 10, 1023, 1024, 1025} { + t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) { + testBuffer_writeTo(t, size, name == "Copy") + }) + } + } +} + +func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) { + oldHook := testHookDidWritev + defer func() { testHookDidWritev = oldHook }() + var writeLog struct { + sync.Mutex + log []int + } + testHookDidWritev = func(size int) { + writeLog.Lock() + writeLog.log = append(writeLog.log, size) + writeLog.Unlock() + } + var want bytes.Buffer + for i := 0; i < chunks; i++ { + want.WriteByte(byte(i)) + } + + withTCPConnPair(t, func(c *TCPConn) error { + buffers := make(Buffers, chunks) + for i := range buffers { + buffers[i] = want.Bytes()[i : i+1] + } + var n int64 + var err error + if useCopy { + n, err = io.Copy(c, &buffers) + } else { + n, err = buffers.WriteTo(c) + } + if err != nil { + return err + } + if len(buffers) != 0 { + return fmt.Errorf("len(buffers) = %d; want 0", len(buffers)) + } + if n != int64(want.Len()) { + return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len()) + } + return nil + }, func(c *TCPConn) error { + all, err := ioutil.ReadAll(c) + if !bytes.Equal(all, want.Bytes()) || err != nil { + return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes()) + } + + writeLog.Lock() // no need to unlock + var gotSum int + for _, v := range writeLog.log { + gotSum += v + } + + var wantSum int + switch runtime.GOOS { + case "android", "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd": + var wantMinCalls int + wantSum = want.Len() + v := chunks + for v > 0 { + wantMinCalls++ + v -= 1024 + } + if len(writeLog.log) < wantMinCalls { + t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls) + } + case "windows": + var wantCalls int + wantSum = want.Len() + if wantSum > 0 { + wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer + } + if len(writeLog.log) != wantCalls { + t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls) + } + } + if gotSum != wantSum { + t.Errorf("writev call sum = %v; want %v", gotSum, wantSum) + } + return nil + }) +} + +func TestWritevError(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skipf("skipping the test: windows does not have problem sending large chunks of data") + } + + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + ch := make(chan Conn, 1) + go func() { + defer close(ch) + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + ch <- c + }() + c1, err := Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c1.Close() + c2 := <-ch + if c2 == nil { + t.Fatal("no server side connection") + } + c2.Close() + + // 1 GB of data should be enough to notice the connection is gone. + // Just a few bytes is not enough. + // Arrange to reuse the same 1 MB buffer so that we don't allocate much. + buf := make([]byte, 1<<20) + buffers := make(Buffers, 1<<10) + for i := range buffers { + buffers[i] = buf + } + if _, err := buffers.WriteTo(c1); err == nil { + t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error") + } +} diff --git a/libgo/go/net/writev_unix.go b/libgo/go/net/writev_unix.go new file mode 100644 index 0000000..174e6bc --- /dev/null +++ b/libgo/go/net/writev_unix.go @@ -0,0 +1,95 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd + +package net + +import ( + "io" + "os" + "syscall" + "unsafe" +) + +func (c *conn) writeBuffers(v *Buffers) (int64, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + n, err := c.fd.writeBuffers(v) + if err != nil { + return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return n, nil +} + +func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) { + if err := fd.writeLock(); err != nil { + return 0, err + } + defer fd.writeUnlock() + if err := fd.pd.prepareWrite(); err != nil { + return 0, err + } + + var iovecs []syscall.Iovec + if fd.iovecs != nil { + iovecs = *fd.iovecs + } + // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is + // 1024 and this seems conservative enough for now. Darwin's + // UIO_MAXIOV also seems to be 1024. + maxVec := 1024 + + for len(*v) > 0 { + iovecs = iovecs[:0] + for _, chunk := range *v { + if len(chunk) == 0 { + continue + } + iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]}) + if fd.isStream && len(chunk) > 1<<30 { + iovecs[len(iovecs)-1].SetLen(1 << 30) + break // continue chunk on next writev + } + iovecs[len(iovecs)-1].SetLen(len(chunk)) + if len(iovecs) == maxVec { + break + } + } + if len(iovecs) == 0 { + break + } + fd.iovecs = &iovecs // cache + + wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV, + uintptr(fd.sysfd), + uintptr(unsafe.Pointer(&iovecs[0])), + uintptr(len(iovecs))) + if wrote == ^uintptr(0) { + wrote = 0 + } + testHookDidWritev(int(wrote)) + n += int64(wrote) + v.consume(int64(wrote)) + if e0 == syscall.EAGAIN { + if err = fd.pd.waitWrite(); err == nil { + continue + } + } else if e0 != 0 { + err = syscall.Errno(e0) + } + if err != nil { + break + } + if n == 0 { + err = io.ErrUnexpectedEOF + break + } + } + if _, ok := err.(syscall.Errno); ok { + err = os.NewSyscallError("writev", err) + } + return n, err +} |