author     Ian Lance Taylor <iant@golang.org>   2017-01-14 00:05:42 +0000
committer  Ian Lance Taylor <ian@gcc.gnu.org>   2017-01-14 00:05:42 +0000
commit     c2047754c300b68c05d65faa8dc2925fe67b71b4 (patch)
tree       e183ae81a1f48a02945cb6de463a70c5be1b06f6 /libgo/go/net
parent     829afb8f05602bb31c9c597b24df7377fed4f059 (diff)
libgo: update to Go 1.8 release candidate 1
Compiler changes:
  * Change map assignment to use mapassign and assign value directly.
  * Change string iteration to use decoderune, faster for ASCII strings.
  * Change makeslice to take int, and use makeslice64 for larger values.
  * Add new noverflow field to hmap struct used for maps.

Unresolved problems, to be fixed later:
  * Commented out test in go/types/sizes_test.go that doesn't compile.
  * Commented out reflect.TestStructOf test for padding after zero-sized field.

Reviewed-on: https://go-review.googlesource.com/35231

gotools/:
  Updates for Go 1.8rc1.
  * Makefile.am (go_cmd_go_files): Add bug.go.
  (s-zdefaultcc): Write defaultPkgConfig.
  * Makefile.in: Rebuild.

From-SVN: r244456
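
For orientation, a minimal sketch (not part of the commit) of the ordinary Go constructs those compiler changes affect; the runtime entry points named above (mapassign, decoderune, makeslice/makeslice64) are internal and never called directly from user code:

package main

import "fmt"

func main() {
	// Map assignment: now lowered to a mapassign call that yields a slot
	// pointer, so the value is stored directly rather than via a temporary.
	m := map[string]int{}
	m["answer"] = 42

	// String range iteration: decoding each rune now goes through
	// decoderune, which has a fast path for ASCII-only strings.
	for i, r := range "gofrontend" {
		fmt.Println(i, r)
	}

	// Slice allocation: makeslice now takes int; only lengths that do not
	// fit in an int are routed through makeslice64.
	s := make([]byte, 1<<10)
	fmt.Println(m["answer"], len(s))
}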
Diffstat (limited to 'libgo/go/net')
-rw-r--r--  libgo/go/net/addrselect.go  51
-rw-r--r--  libgo/go/net/addrselect_test.go  85
-rw-r--r--  libgo/go/net/cgo_unix.go  10
-rw-r--r--  libgo/go/net/conf.go  26
-rw-r--r--  libgo/go/net/conf_test.go  114
-rw-r--r--  libgo/go/net/dial.go  44
-rw-r--r--  libgo/go/net/dial_test.go  19
-rw-r--r--  libgo/go/net/dnsclient.go  16
-rw-r--r--  libgo/go/net/dnsclient_unix.go  83
-rw-r--r--  libgo/go/net/dnsclient_unix_test.go  121
-rw-r--r--  libgo/go/net/dnsconfig_unix.go  25
-rw-r--r--  libgo/go/net/dnsconfig_unix_test.go  83
-rw-r--r--  libgo/go/net/dnsmsg.go  2
-rw-r--r--  libgo/go/net/dnsmsg_test.go  6
-rw-r--r--  libgo/go/net/dnsname_test.go  27
-rw-r--r--  libgo/go/net/error_test.go  21
-rw-r--r--  libgo/go/net/fd_io_plan9.go  93
-rw-r--r--  libgo/go/net/fd_plan9.go  149
-rw-r--r--  libgo/go/net/fd_poll_nacl.go  2
-rw-r--r--  libgo/go/net/fd_poll_runtime.go  4
-rw-r--r--  libgo/go/net/fd_unix.go  15
-rw-r--r--  libgo/go/net/fd_windows.go  66
-rw-r--r--  libgo/go/net/file.go  3
-rw-r--r--  libgo/go/net/file_plan9.go  2
-rw-r--r--  libgo/go/net/http/client.go  411
-rw-r--r--  libgo/go/net/http/client_test.go  697
-rw-r--r--  libgo/go/net/http/clientserver_test.go  171
-rw-r--r--  libgo/go/net/http/cookie.go  66
-rw-r--r--  libgo/go/net/http/cookie_test.go  98
-rw-r--r--  libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go  23
-rw-r--r--  libgo/go/net/http/cookiejar/example_test.go  67
-rw-r--r--  libgo/go/net/http/cookiejar/jar.go  33
-rw-r--r--  libgo/go/net/http/doc.go  30
-rw-r--r--  libgo/go/net/http/export_test.go  39
-rw-r--r--  libgo/go/net/http/fcgi/fcgi.go  6
-rw-r--r--  libgo/go/net/http/fs.go  339
-rw-r--r--  libgo/go/net/http/fs_test.go  160
-rw-r--r--  libgo/go/net/http/h2_bundle.go  2269
-rw-r--r--  libgo/go/net/http/header.go  6
-rw-r--r--  libgo/go/net/http/http.go  98
-rw-r--r--  libgo/go/net/http/http_test.go  20
-rw-r--r--  libgo/go/net/http/httptest/httptest.go  3
-rw-r--r--  libgo/go/net/http/httptest/recorder.go  58
-rw-r--r--  libgo/go/net/http/httptest/recorder_test.go  22
-rw-r--r--  libgo/go/net/http/httptest/server.go  15
-rw-r--r--  libgo/go/net/http/httptrace/example_test.go  31
-rw-r--r--  libgo/go/net/http/httptrace/trace.go  31
-rw-r--r--  libgo/go/net/http/httptrace/trace_test.go  29
-rw-r--r--  libgo/go/net/http/httputil/dump.go  13
-rw-r--r--  libgo/go/net/http/httputil/dump_test.go  12
-rw-r--r--  libgo/go/net/http/httputil/persist.go  9
-rw-r--r--  libgo/go/net/http/httputil/reverseproxy.go  151
-rw-r--r--  libgo/go/net/http/httputil/reverseproxy_test.go  179
-rw-r--r--  libgo/go/net/http/internal/chunked.go  30
-rw-r--r--  libgo/go/net/http/internal/chunked_test.go  27
-rw-r--r--  libgo/go/net/http/main_test.go  21
-rw-r--r--  libgo/go/net/http/npn_test.go  1
-rw-r--r--  libgo/go/net/http/range_test.go  2
-rw-r--r--  libgo/go/net/http/readrequest_test.go  26
-rw-r--r--  libgo/go/net/http/request.go  212
-rw-r--r--  libgo/go/net/http/request_test.go  195
-rw-r--r--  libgo/go/net/http/requestwrite_test.go  296
-rw-r--r--  libgo/go/net/http/response.go  4
-rw-r--r--  libgo/go/net/http/response_test.go  35
-rw-r--r--  libgo/go/net/http/responsewrite_test.go  21
-rw-r--r--  libgo/go/net/http/serve_test.go  672
-rw-r--r--  libgo/go/net/http/server.go  719
-rw-r--r--  libgo/go/net/http/sniff_test.go  2
-rw-r--r--  libgo/go/net/http/transfer.go  216
-rw-r--r--  libgo/go/net/http/transport.go  286
-rw-r--r--  libgo/go/net/http/transport_internal_test.go  67
-rw-r--r--  libgo/go/net/http/transport_test.go  400
-rw-r--r--  libgo/go/net/interface.go  31
-rw-r--r--  libgo/go/net/interface_plan9.go  198
-rw-r--r--  libgo/go/net/interface_solaris.go  107
-rw-r--r--  libgo/go/net/interface_stub.go  2
-rw-r--r--  libgo/go/net/interface_test.go  11
-rw-r--r--  libgo/go/net/ip.go  109
-rw-r--r--  libgo/go/net/ip_test.go  16
-rw-r--r--  libgo/go/net/iprawsock.go  23
-rw-r--r--  libgo/go/net/iprawsock_posix.go  16
-rw-r--r--  libgo/go/net/iprawsock_test.go  27
-rw-r--r--  libgo/go/net/ipsock.go  61
-rw-r--r--  libgo/go/net/ipsock_plan9.go  60
-rw-r--r--  libgo/go/net/ipsock_posix.go  10
-rw-r--r--  libgo/go/net/ipsock_test.go  14
-rw-r--r--  libgo/go/net/lookup.go  291
-rw-r--r--  libgo/go/net/lookup_nacl.go  52
-rw-r--r--  libgo/go/net/lookup_plan9.go  41
-rw-r--r--  libgo/go/net/lookup_stub.go  52
-rw-r--r--  libgo/go/net/lookup_test.go  85
-rw-r--r--  libgo/go/net/lookup_unix.go  52
-rw-r--r--  libgo/go/net/lookup_windows.go  59
-rw-r--r--  libgo/go/net/mail/message.go  15
-rw-r--r--  libgo/go/net/mail/message_test.go  23
-rw-r--r--  libgo/go/net/main_test.go  2
-rw-r--r--  libgo/go/net/net.go  122
-rw-r--r--  libgo/go/net/net_test.go  104
-rw-r--r--  libgo/go/net/parse.go  58
-rw-r--r--  libgo/go/net/parse_test.go  7
-rw-r--r--  libgo/go/net/port_unix.go  32
-rw-r--r--  libgo/go/net/rpc/client.go  2
-rw-r--r--  libgo/go/net/rpc/client_test.go  4
-rw-r--r--  libgo/go/net/rpc/server.go  6
-rw-r--r--  libgo/go/net/rpc/server_test.go  3
-rw-r--r--  libgo/go/net/smtp/smtp.go  5
-rw-r--r--  libgo/go/net/smtp/smtp_test.go  71
-rw-r--r--  libgo/go/net/sock_linux.go  2
-rw-r--r--  libgo/go/net/sock_posix.go  3
-rw-r--r--  libgo/go/net/tcpsock.go  10
-rw-r--r--  libgo/go/net/tcpsock_posix.go  4
-rw-r--r--  libgo/go/net/tcpsock_test.go  125
-rw-r--r--  libgo/go/net/tcpsock_unix_test.go  4
-rw-r--r--  libgo/go/net/testdata/invalid-ndots-resolv.conf  1
-rw-r--r--  libgo/go/net/testdata/large-ndots-resolv.conf  1
-rw-r--r--  libgo/go/net/testdata/negative-ndots-resolv.conf  1
-rw-r--r--  libgo/go/net/textproto/header.go  4
-rw-r--r--  libgo/go/net/timeout_test.go  47
-rw-r--r--  libgo/go/net/udpsock.go  14
-rw-r--r--  libgo/go/net/udpsock_plan9.go  38
-rw-r--r--  libgo/go/net/udpsock_plan9_test.go  69
-rw-r--r--  libgo/go/net/udpsock_posix.go  4
-rw-r--r--  libgo/go/net/udpsock_test.go  31
-rw-r--r--  libgo/go/net/unixsock.go  14
-rw-r--r--  libgo/go/net/unixsock_posix.go  25
-rw-r--r--  libgo/go/net/unixsock_test.go  124
-rw-r--r--  libgo/go/net/url/url.go  185
-rw-r--r--  libgo/go/net/url/url_test.go  327
-rw-r--r--  libgo/go/net/writev_test.go  225
-rw-r--r--  libgo/go/net/writev_unix.go  95
130 files changed, 9938 insertions, 2571 deletions
diff --git a/libgo/go/net/addrselect.go b/libgo/go/net/addrselect.go
index 0b9d160..1ab9fc53 100644
--- a/libgo/go/net/addrselect.go
+++ b/libgo/go/net/addrselect.go
@@ -188,33 +188,17 @@ func (s *byRFC6724) Less(i, j int) bool {
// Rule 9: Use longest matching prefix.
// When DA and DB belong to the same address family (both are IPv6 or
- // both are IPv4): If CommonPrefixLen(Source(DA), DA) >
+ // both are IPv4 [but see below]): If CommonPrefixLen(Source(DA), DA) >
// CommonPrefixLen(Source(DB), DB), then prefer DA. Similarly, if
// CommonPrefixLen(Source(DA), DA) < CommonPrefixLen(Source(DB), DB),
// then prefer DB.
- da4 := DA.To4() != nil
- db4 := DB.To4() != nil
- if da4 == db4 {
+ //
+ // However, applying this rule to IPv4 addresses causes
+ // problems (see issues 13283 and 18518), so limit to IPv6.
+ if DA.To4() == nil && DB.To4() == nil {
commonA := commonPrefixLen(SourceDA, DA)
commonB := commonPrefixLen(SourceDB, DB)
- // CommonPrefixLen doesn't really make sense for IPv4, and even
- // causes problems for common load balancing practices
- // (e.g., https://golang.org/issue/13283). Glibc instead only
- // uses CommonPrefixLen for IPv4 when the source and destination
- // addresses are on the same subnet, but that requires extra
- // work to find the netmask for our source addresses. As a
- // simpler heuristic, we limit its use to when the source and
- // destination belong to the same special purpose block.
- if da4 {
- if !sameIPv4SpecialPurposeBlock(SourceDA, DA) {
- commonA = 0
- }
- if !sameIPv4SpecialPurposeBlock(SourceDB, DB) {
- commonB = 0
- }
- }
-
if commonA > commonB {
return preferDA
}
@@ -404,28 +388,3 @@ func commonPrefixLen(a, b IP) (cpl int) {
}
return
}
-
-// sameIPv4SpecialPurposeBlock reports whether a and b belong to the same
-// address block reserved by the IANA IPv4 Special-Purpose Address Registry:
-// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
-func sameIPv4SpecialPurposeBlock(a, b IP) bool {
- a, b = a.To4(), b.To4()
- if a == nil || b == nil || a[0] != b[0] {
- return false
- }
- // IANA defines more special-purpose blocks, but these are the only
- // ones likely to be relevant to typical Go systems.
- switch a[0] {
- case 10: // 10.0.0.0/8: Private-Use
- return true
- case 127: // 127.0.0.0/8: Loopback
- return true
- case 169: // 169.254.0.0/16: Link Local
- return a[1] == 254 && b[1] == 254
- case 172: // 172.16.0.0/12: Private-Use
- return a[1]&0xf0 == 16 && b[1]&0xf0 == 16
- case 192: // 192.168.0.0/16: Private-Use
- return a[1] == 168 && b[1] == 168
- }
- return false
-}
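
A hedged, standalone sketch of the Rule 9 comparison as now restricted to IPv6 destination pairs; commonPrefix here is an illustrative stand-in for the package's unexported commonPrefixLen:

package main

import (
	"fmt"
	"net"
)

// commonPrefix is an illustrative stand-in for net's unexported
// commonPrefixLen: the number of leading bits that a and b share.
func commonPrefix(a, b net.IP) int {
	a, b = a.To16(), b.To16()
	cpl := 0
	for i := 0; i < len(a); i++ {
		x := a[i] ^ b[i]
		if x == 0 {
			cpl += 8
			continue
		}
		for ; x&0x80 == 0; x <<= 1 {
			cpl++
		}
		break
	}
	return cpl
}

func main() {
	// Rule 9 is now consulted only when both destinations are IPv6.
	da := net.ParseIP("2001:db8::1")
	db := net.ParseIP("2001:db8:1::1")
	src := net.ParseIP("2001:db8::99")
	if da.To4() == nil && db.To4() == nil {
		fmt.Println(commonPrefix(src, da) > commonPrefix(src, db)) // true: prefer DA
	}
}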
diff --git a/libgo/go/net/addrselect_test.go b/libgo/go/net/addrselect_test.go
index 80aa4eb..d6e0e63 100644
--- a/libgo/go/net/addrselect_test.go
+++ b/libgo/go/net/addrselect_test.go
@@ -117,27 +117,6 @@ func TestSortByRFC6724(t *testing.T) {
},
reverse: false,
},
-
- // Prefer longer common prefixes, but only for IPv4 address
- // pairs in the same special-purpose block.
- {
- in: []IPAddr{
- {IP: ParseIP("1.2.3.4")},
- {IP: ParseIP("10.55.0.1")},
- {IP: ParseIP("10.66.0.1")},
- },
- srcs: []IP{
- ParseIP("1.2.3.5"),
- ParseIP("10.66.1.2"),
- ParseIP("10.66.1.2"),
- },
- want: []IPAddr{
- {IP: ParseIP("10.66.0.1")},
- {IP: ParseIP("10.55.0.1")},
- {IP: ParseIP("1.2.3.4")},
- },
- reverse: true,
- },
}
for i, tt := range tests {
inCopy := make([]IPAddr, len(tt.in))
@@ -268,67 +247,3 @@ func TestRFC6724CommonPrefixLength(t *testing.T) {
}
}
-
-func mustParseCIDRs(t *testing.T, blocks ...string) []*IPNet {
- res := make([]*IPNet, len(blocks))
- for i, block := range blocks {
- var err error
- _, res[i], err = ParseCIDR(block)
- if err != nil {
- t.Fatalf("ParseCIDR(%s) failed: %v", block, err)
- }
- }
- return res
-}
-
-func TestSameIPv4SpecialPurposeBlock(t *testing.T) {
- blocks := mustParseCIDRs(t,
- "10.0.0.0/8",
- "127.0.0.0/8",
- "169.254.0.0/16",
- "172.16.0.0/12",
- "192.168.0.0/16",
- )
-
- addrs := []struct {
- ip IP
- block int // index or -1
- }{
- {IP{1, 2, 3, 4}, -1},
- {IP{2, 3, 4, 5}, -1},
- {IP{10, 2, 3, 4}, 0},
- {IP{10, 6, 7, 8}, 0},
- {IP{127, 0, 0, 1}, 1},
- {IP{127, 255, 255, 255}, 1},
- {IP{169, 254, 77, 99}, 2},
- {IP{169, 254, 44, 22}, 2},
- {IP{169, 255, 0, 1}, -1},
- {IP{172, 15, 5, 6}, -1},
- {IP{172, 16, 32, 41}, 3},
- {IP{172, 31, 128, 9}, 3},
- {IP{172, 32, 88, 100}, -1},
- {IP{192, 168, 1, 1}, 4},
- {IP{192, 168, 128, 42}, 4},
- {IP{192, 169, 1, 1}, -1},
- }
-
- for i, addr := range addrs {
- for j, block := range blocks {
- got := block.Contains(addr.ip)
- want := addr.block == j
- if got != want {
- t.Errorf("%d/%d. %s.Contains(%s): got %v, want %v", i, j, block, addr.ip, got, want)
- }
- }
- }
-
- for i, addr1 := range addrs {
- for j, addr2 := range addrs {
- got := sameIPv4SpecialPurposeBlock(addr1.ip, addr2.ip)
- want := addr1.block >= 0 && addr1.block == addr2.block
- if got != want {
- t.Errorf("%d/%d. sameIPv4SpecialPurposeBlock(%s, %s): got %v, want %v", i, j, addr1.ip, addr2.ip, got, want)
- }
- }
- }
-}
diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go
index 525c63c..a90aaa9 100644
--- a/libgo/go/net/cgo_unix.go
+++ b/libgo/go/net/cgo_unix.go
@@ -114,7 +114,15 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err
}
func cgoLookupServicePort(hints *syscall.Addrinfo, network, service string) (port int, err error) {
- s := syscall.StringBytePtr(service)
+ s, err := syscall.BytePtrFromString(service)
+ if err != nil {
+ return 0, err
+ }
+ // Lowercase the service name in the memory passed to C.
+ for i := 0; i < len(service); i++ {
+ bp := (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer(s)) + uintptr(i)))
+ *bp = lowerASCII(*bp)
+ }
var res *syscall.Addrinfo
syscall.Entersyscall()
gerrno := libc_getaddrinfo(nil, s, hints, &res)
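
A small illustrative sketch (not from the commit) of the two ideas in this hunk: syscall.BytePtrFromString reports an error for strings containing a NUL byte instead of panicking like the deprecated StringBytePtr, and the service name is lowercased byte by byte; the unsafe pointer walk is replaced here by a plain byte slice for clarity:

package main

import (
	"fmt"
	"syscall"
)

// lowerASCII is an illustrative re-implementation of the unexported helper
// used above: it lowercases only ASCII letters, leaving other bytes alone.
func lowerASCII(b byte) byte {
	if 'A' <= b && b <= 'Z' {
		return b + ('a' - 'A')
	}
	return b
}

func main() {
	// BytePtrFromString rejects strings containing a NUL byte with an error
	// rather than panicking as the deprecated StringBytePtr did.
	if _, err := syscall.BytePtrFromString("HT\x00TP"); err != nil {
		fmt.Println("rejected:", err)
	}

	// A simplified, unsafe-free version of the lowercasing loop: copy the
	// service name into a byte slice and lowercase it in place.
	svc := []byte("HTTP")
	for i := range svc {
		svc[i] = lowerASCII(svc[i])
	}
	fmt.Println(string(svc)) // http
}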
diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go
index eb72916..c10aafe 100644
--- a/libgo/go/net/conf.go
+++ b/libgo/go/net/conf.go
@@ -179,8 +179,6 @@ func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) {
}
}
- hasDot := byteIndex(hostname, '.') != -1
-
// Canonicalize the hostname by removing any trailing dot.
if stringsHasSuffix(hostname, ".") {
hostname = hostname[:len(hostname)-1]
@@ -220,10 +218,14 @@ func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) {
var first string
for _, src := range srcs {
if src.source == "myhostname" {
- if hostname == "" || hasDot {
- continue
+ if isLocalhost(hostname) || isGateway(hostname) {
+ return fallbackOrder
}
- return fallbackOrder
+ hn, err := getHostname()
+ if err != nil || stringsEqualFold(hostname, hn) {
+ return fallbackOrder
+ }
+ continue
}
if src.source == "files" || src.source == "dns" {
if !src.standardCriteria() {
@@ -293,7 +295,7 @@ func goDebugNetDNS() (dnsMode string, debugLevel int) {
return
}
if '0' <= s[0] && s[0] <= '9' {
- debugLevel, _, _ = dtoi(s, 0)
+ debugLevel, _, _ = dtoi(s)
} else {
dnsMode = s
}
@@ -306,3 +308,15 @@ func goDebugNetDNS() (dnsMode string, debugLevel int) {
parsePart(goDebug)
return
}
+
+// isLocalhost reports whether h should be considered a "localhost"
+// name for the myhostname NSS module.
+func isLocalhost(h string) bool {
+ return stringsEqualFold(h, "localhost") || stringsEqualFold(h, "localhost.localdomain") || stringsHasSuffixFold(h, ".localhost") || stringsHasSuffixFold(h, ".localhost.localdomain")
+}
+
+// isGateway reports whether h should be considered a "gateway"
+// name for the myhostname NSS module.
+func isGateway(h string) bool {
+ return stringsEqualFold(h, "gateway")
+}
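
A simplified, standalone sketch of the decision this hunk adds: with the glibc "myhostname" NSS source listed, only names that module would synthesize (the machine's own hostname, localhost-style names, and "gateway") force the cgo/libc path. shouldUseCgoForMyhostname is an illustrative name, and trailing-dot canonicalization is omitted:

package main

import (
	"fmt"
	"os"
	"strings"
)

// shouldUseCgoForMyhostname reports whether a query should fall back to the
// cgo resolver because the "myhostname" NSS module might answer it.
func shouldUseCgoForMyhostname(query string) bool {
	eq := strings.EqualFold
	q := strings.ToLower(query)
	if eq(query, "localhost") || eq(query, "localhost.localdomain") ||
		strings.HasSuffix(q, ".localhost") ||
		strings.HasSuffix(q, ".localhost.localdomain") ||
		eq(query, "gateway") {
		return true
	}
	hn, err := os.Hostname()
	return err != nil || eq(query, hn)
}

func main() {
	for _, q := range []string{"localhost", "Gateway", "golang.org"} {
		fmt.Println(q, shouldUseCgoForMyhostname(q))
	}
}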
diff --git a/libgo/go/net/conf_test.go b/libgo/go/net/conf_test.go
index ec8814b..17d03f4 100644
--- a/libgo/go/net/conf_test.go
+++ b/libgo/go/net/conf_test.go
@@ -13,8 +13,9 @@ import (
)
type nssHostTest struct {
- host string
- want hostLookupOrder
+ host string
+ localhost string
+ want hostLookupOrder
}
func nssStr(s string) *nssConf { return parseNSSConf(strings.NewReader(s)) }
@@ -42,8 +43,8 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"foo.local", hostLookupCgo},
- {"google.com", hostLookupCgo},
+ {"foo.local", "myhostname", hostLookupCgo},
+ {"google.com", "myhostname", hostLookupCgo},
},
},
{
@@ -54,7 +55,7 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupDNSFiles},
+ {"x.com", "myhostname", hostLookupDNSFiles},
},
},
{
@@ -65,7 +66,7 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupFilesDNS},
+ {"x.com", "myhostname", hostLookupFilesDNS},
},
},
{
@@ -75,11 +76,11 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"foo.local", hostLookupCgo},
- {"foo.local.", hostLookupCgo},
- {"foo.LOCAL", hostLookupCgo},
- {"foo.LOCAL.", hostLookupCgo},
- {"google.com", hostLookupFilesDNS},
+ {"foo.local", "myhostname", hostLookupCgo},
+ {"foo.local.", "myhostname", hostLookupCgo},
+ {"foo.LOCAL", "myhostname", hostLookupCgo},
+ {"foo.LOCAL.", "myhostname", hostLookupCgo},
+ {"google.com", "myhostname", hostLookupFilesDNS},
},
},
{
@@ -89,7 +90,7 @@ func TestConfHostLookupOrder(t *testing.T) {
nss: nssStr("foo: bar"),
resolv: defaultResolvConf,
},
- hostTests: []nssHostTest{{"google.com", hostLookupFilesDNS}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
},
// On OpenBSD, no resolv.conf means no DNS.
{
@@ -98,7 +99,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: defaultResolvConf,
},
- hostTests: []nssHostTest{{"google.com", hostLookupFiles}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFiles}},
},
{
name: "solaris_no_nsswitch",
@@ -107,7 +108,7 @@ func TestConfHostLookupOrder(t *testing.T) {
nss: &nssConf{err: os.ErrNotExist},
resolv: defaultResolvConf,
},
- hostTests: []nssHostTest{{"google.com", hostLookupCgo}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
},
{
name: "openbsd_lookup_bind_file",
@@ -116,8 +117,8 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: &dnsConfig{lookup: []string{"bind", "file"}},
},
hostTests: []nssHostTest{
- {"google.com", hostLookupDNSFiles},
- {"foo.local", hostLookupDNSFiles},
+ {"google.com", "myhostname", hostLookupDNSFiles},
+ {"foo.local", "myhostname", hostLookupDNSFiles},
},
},
{
@@ -126,7 +127,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: &dnsConfig{lookup: []string{"file", "bind"}},
},
- hostTests: []nssHostTest{{"google.com", hostLookupFilesDNS}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFilesDNS}},
},
{
name: "openbsd_lookup_bind",
@@ -134,7 +135,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: &dnsConfig{lookup: []string{"bind"}},
},
- hostTests: []nssHostTest{{"google.com", hostLookupDNS}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNS}},
},
{
name: "openbsd_lookup_file",
@@ -142,7 +143,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: &dnsConfig{lookup: []string{"file"}},
},
- hostTests: []nssHostTest{{"google.com", hostLookupFiles}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupFiles}},
},
{
name: "openbsd_lookup_yp",
@@ -150,7 +151,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: &dnsConfig{lookup: []string{"file", "bind", "yp"}},
},
- hostTests: []nssHostTest{{"google.com", hostLookupCgo}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
},
{
name: "openbsd_lookup_two",
@@ -158,7 +159,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: &dnsConfig{lookup: []string{"file", "foo"}},
},
- hostTests: []nssHostTest{{"google.com", hostLookupCgo}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
},
{
name: "openbsd_lookup_empty",
@@ -166,7 +167,7 @@ func TestConfHostLookupOrder(t *testing.T) {
goos: "openbsd",
resolv: &dnsConfig{lookup: nil},
},
- hostTests: []nssHostTest{{"google.com", hostLookupDNSFiles}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}},
},
// glibc lacking an nsswitch.conf, per
// http://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html
@@ -177,7 +178,7 @@ func TestConfHostLookupOrder(t *testing.T) {
nss: &nssConf{err: os.ErrNotExist},
resolv: defaultResolvConf,
},
- hostTests: []nssHostTest{{"google.com", hostLookupDNSFiles}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}},
},
{
name: "files_mdns_dns",
@@ -186,8 +187,8 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupFilesDNS},
- {"x.local", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"x.local", "myhostname", hostLookupCgo},
},
},
{
@@ -197,9 +198,9 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupDNS},
- {"x\\.com", hostLookupCgo}, // punt on weird glibc escape
- {"foo.com%en0", hostLookupCgo}, // and IPv6 zones
+ {"x.com", "myhostname", hostLookupDNS},
+ {"x\\.com", "myhostname", hostLookupCgo}, // punt on weird glibc escape
+ {"foo.com%en0", "myhostname", hostLookupCgo}, // and IPv6 zones
},
},
{
@@ -210,8 +211,8 @@ func TestConfHostLookupOrder(t *testing.T) {
hasMDNSAllow: true,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupCgo},
- {"x.local", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupCgo},
+ {"x.local", "myhostname", hostLookupCgo},
},
},
{
@@ -221,9 +222,9 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupFilesDNS},
- {"x", hostLookupFilesDNS},
- {"x.local", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"x", "myhostname", hostLookupFilesDNS},
+ {"x.local", "myhostname", hostLookupCgo},
},
},
{
@@ -233,9 +234,9 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupDNSFiles},
- {"x", hostLookupDNSFiles},
- {"x.local", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ {"x", "myhostname", hostLookupDNSFiles},
+ {"x.local", "myhostname", hostLookupCgo},
},
},
{
@@ -245,7 +246,7 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupCgo},
},
},
{
@@ -255,9 +256,23 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupFilesDNS},
- {"somehostname", hostLookupCgo},
- {"", hostLookupFilesDNS}, // Issue 13623
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"myhostname", "myhostname", hostLookupCgo},
+ {"myHostname", "myhostname", hostLookupCgo},
+ {"myhostname.dot", "myhostname.dot", hostLookupCgo},
+ {"myHostname.dot", "myhostname.dot", hostLookupCgo},
+ {"gateway", "myhostname", hostLookupCgo},
+ {"Gateway", "myhostname", hostLookupCgo},
+ {"localhost", "myhostname", hostLookupCgo},
+ {"Localhost", "myhostname", hostLookupCgo},
+ {"anything.localhost", "myhostname", hostLookupCgo},
+ {"Anything.localhost", "myhostname", hostLookupCgo},
+ {"localhost.localdomain", "myhostname", hostLookupCgo},
+ {"Localhost.Localdomain", "myhostname", hostLookupCgo},
+ {"anything.localhost.localdomain", "myhostname", hostLookupCgo},
+ {"Anything.Localhost.Localdomain", "myhostname", hostLookupCgo},
+ {"somehostname", "myhostname", hostLookupFilesDNS},
+ {"", "myhostname", hostLookupFilesDNS}, // Issue 13623
},
},
{
@@ -267,8 +282,9 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupFilesDNS},
- {"somehostname", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupFilesDNS},
+ {"somehostname", "myhostname", hostLookupFilesDNS},
+ {"myhostname", "myhostname", hostLookupCgo},
},
},
// Debian Squeeze is just "dns,files", but lists all
@@ -282,8 +298,8 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupDNSFiles},
- {"somehostname", hostLookupDNSFiles},
+ {"x.com", "myhostname", hostLookupDNSFiles},
+ {"somehostname", "myhostname", hostLookupDNSFiles},
},
},
{
@@ -292,7 +308,7 @@ func TestConfHostLookupOrder(t *testing.T) {
nss: nssStr("foo: bar"),
resolv: &dnsConfig{servers: defaultNS, ndots: 1, timeout: 5, attempts: 2, unknownOpt: true},
},
- hostTests: []nssHostTest{{"google.com", hostLookupCgo}},
+ hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupCgo}},
},
// Android should always use cgo.
{
@@ -303,12 +319,18 @@ func TestConfHostLookupOrder(t *testing.T) {
resolv: defaultResolvConf,
},
hostTests: []nssHostTest{
- {"x.com", hostLookupCgo},
+ {"x.com", "myhostname", hostLookupCgo},
},
},
}
+
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+
for _, tt := range tests {
for _, ht := range tt.hostTests {
+ getHostname = func() (string, error) { return ht.localhost, nil }
+
gotOrder := tt.c.hostLookupOrder(ht.host)
if gotOrder != ht.want {
t.Errorf("%s: hostLookupOrder(%q) = %v; want %v", tt.name, ht.host, gotOrder, ht.want)
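
The tests above work by swapping the package-level getHostname function variable; a generic sketch of that test-hook pattern, written as a single hypothetical _test.go file:

package hooks

import (
	"os"
	"testing"
)

// getHostname is a package-level function variable; production code calls
// os.Hostname, while tests substitute a deterministic stub (the pattern
// used by the conf_test.go changes above).
var getHostname = os.Hostname

// hostIsSelf reports whether name matches the machine's own hostname.
func hostIsSelf(name string) bool {
	hn, err := getHostname()
	return err == nil && name == hn
}

func TestHostIsSelf(t *testing.T) {
	orig := getHostname
	defer func() { getHostname = orig }() // always restore the real hook
	getHostname = func() (string, error) { return "myhostname", nil }

	if !hostIsSelf("myhostname") || hostIsSelf("other") {
		t.Fatal("stubbed hostname not honored")
	}
}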
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go
index 55edb43..50bba5a 100644
--- a/libgo/go/net/dial.go
+++ b/libgo/go/net/dial.go
@@ -59,6 +59,9 @@ type Dialer struct {
// that do not support keep-alives ignore this field.
KeepAlive time.Duration
+ // Resolver optionally specifies an alternate resolver to use.
+ Resolver *Resolver
+
// Cancel is an optional channel whose closure indicates that
// the dial should be canceled. Not all types of dials support
// cancelation.
@@ -92,6 +95,13 @@ func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Tim
return minNonzeroTime(earliest, d.Deadline)
}
+func (d *Dialer) resolver() *Resolver {
+ if d.Resolver != nil {
+ return d.Resolver
+ }
+ return DefaultResolver
+}
+
// partialDeadline returns the deadline to use for a single address,
// when multiple addresses are pending.
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
@@ -141,7 +151,7 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err
switch afnet {
case "ip", "ip4", "ip6":
protostr := net[i+1:]
- proto, i, ok := dtoi(protostr, 0)
+ proto, i, ok := dtoi(protostr)
if !ok || i != len(protostr) {
proto, err = lookupProtocol(ctx, protostr)
if err != nil {
@@ -153,10 +163,10 @@ func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err
return "", 0, UnknownNetworkError(net)
}
-// resolverAddrList resolves addr using hint and returns a list of
+// resolveAddrList resolves addr using hint and returns a list of
// addresses. The result contains at least one address when error is
// nil.
-func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
+func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
afnet, _, err := parseNetwork(ctx, network)
if err != nil {
return nil, err
@@ -166,7 +176,6 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
}
switch afnet {
case "unix", "unixgram", "unixpacket":
- // TODO(bradfitz): push down context
addr, err := ResolveUnixAddr(afnet, addr)
if err != nil {
return nil, err
@@ -176,7 +185,7 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
}
return addrList{addr}, nil
}
- addrs, err := internetAddrList(ctx, afnet, addr)
+ addrs, err := r.internetAddrList(ctx, afnet, addr)
if err != nil || op != "dial" || hint == nil {
return addrs, err
}
@@ -221,7 +230,7 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
}
}
if len(naddrs) == 0 {
- return nil, errNoSuitableAddress
+ return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
}
return naddrs, nil
}
@@ -256,6 +265,9 @@ func resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (
// Dial("ip6:ipv6-icmp", "2001:db8::1")
//
// For Unix networks, the address must be a file system path.
+//
+// If the host is resolved to multiple addresses,
+// Dial will try each address in order until one succeeds.
func Dial(network, address string) (Conn, error) {
var d Dialer
return d.Dial(network, address)
@@ -290,6 +302,14 @@ func (d *Dialer) Dial(network, address string) (Conn, error) {
// connected, any expiration of the context will not affect the
// connection.
//
+// When using TCP, and the host in the address parameter resolves to multiple
+// network addresses, any dial timeout (from d.Timeout or ctx) is spread
+// over each consecutive dial, such that each is given an appropriate
+// fraction of the time to connect.
+// For example, if a host has 4 IP addresses and the timeout is 1 minute,
+// the connect to each single address will be given 15 seconds to complete
+// before trying the next one.
+//
// See func Dial for a description of the network and address
// parameters.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
@@ -326,7 +346,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
}
- addrs, err := resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
+ addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
@@ -524,8 +544,11 @@ func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error)
// If host is omitted, as in ":8080", Listen listens on all available interfaces
// instead of just the interface with the given host address.
// See Dial for more details about address syntax.
+//
+// Listening on a hostname is not recommended because this creates a socket
+// for at most one of its IP addresses.
func Listen(net, laddr string) (Listener, error) {
- addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
+ addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
@@ -551,8 +574,11 @@ func Listen(net, laddr string) (Listener, error) {
// If host is omitted, as in ":8080", ListenPacket listens on all available interfaces
// instead of just the interface with the given host address.
// See Dial for the syntax of laddr.
+//
+// Listening on a hostname is not recommended because this creates a socket
+// for at most one of its IP addresses.
func ListenPacket(net, laddr string) (PacketConn, error) {
- addrs, err := resolveAddrList(context.Background(), "listen", net, laddr, nil)
+ addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
}
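
A hedged usage example of the new Dialer.Resolver field and the timeout splitting documented above; the target address is hypothetical:

package main

import (
	"context"
	"fmt"
	"net"
	"time"
)

func main() {
	// With a 1-minute budget and a host that resolves to several addresses,
	// DialContext gives each address a fraction of the remaining time
	// rather than the full minute.
	d := &net.Dialer{
		Timeout:   time.Minute,
		KeepAlive: 30 * time.Second,
		Resolver:  net.DefaultResolver, // or a custom *net.Resolver (new field)
	}
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	conn, err := d.DialContext(ctx, "tcp", "example.com:443")
	if err != nil {
		fmt.Println("dial failed:", err)
		return
	}
	defer conn.Close()
	fmt.Println("connected to", conn.RemoteAddr())
}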
diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go
index 8b21e6b..9919d72 100644
--- a/libgo/go/net/dial_test.go
+++ b/libgo/go/net/dial_test.go
@@ -55,6 +55,23 @@ func TestProhibitionaryDialArg(t *testing.T) {
}
}
+func TestDialLocal(t *testing.T) {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ _, port, err := SplitHostPort(ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ c, err := Dial("tcp", JoinHostPort("", port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+}
+
func TestDialTimeoutFDLeak(t *testing.T) {
switch runtime.GOOS {
case "plan9":
@@ -125,6 +142,8 @@ func TestDialerDualStackFDLeak(t *testing.T) {
t.Skipf("%s does not have full support of socktest", runtime.GOOS)
case "windows":
t.Skipf("not implemented a way to cancel dial racers in TCP SYN-SENT state on %s", runtime.GOOS)
+ case "openbsd":
+ testenv.SkipFlaky(t, 15157)
}
if !supportsIPv4 || !supportsIPv6 {
t.Skip("both IPv4 and IPv6 are required")
diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go
index f1835b8..2ab5639 100644
--- a/libgo/go/net/dnsclient.go
+++ b/libgo/go/net/dnsclient.go
@@ -113,12 +113,20 @@ func equalASCIILabel(x, y string) bool {
return true
}
+// isDomainName checks if a string is a presentation-format domain name
+// (currently restricted to hostname-compatible "preferred name" LDH labels and
+// SRV-like "underscore labels"; see golang.org/issue/12421).
func isDomainName(s string) bool {
// See RFC 1035, RFC 3696.
- if len(s) == 0 {
- return false
- }
- if len(s) > 255 {
+ // Presentation format has dots before every label except the first, and the
+ // terminal empty label is optional here because we assume fully-qualified
+ // (absolute) input. We must therefore reserve space for the first and last
+ // labels' length octets in wire format, where they are necessary and the
+ // maximum total length is 255.
+ // So our _effective_ maximum is 253, but 254 is not rejected if the last
+ // character is a dot.
+ l := len(s)
+ if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
return false
}
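
A small sketch demonstrating just the length boundary described above (253 characters, or 254 with a trailing dot); validPresentationLength is an illustrative extraction of that check, not the full isDomainName:

package main

import (
	"fmt"
	"strings"
)

// validPresentationLength applies only the length part of the new
// isDomainName check: at most 253 characters, or 254 when the name ends
// in a dot (the root label costs nothing in presentation form).
func validPresentationLength(s string) bool {
	l := len(s)
	return l != 0 && l <= 254 && (l != 254 || s[l-1] == '.')
}

func main() {
	label63 := strings.Repeat("a", 63)
	long := strings.Repeat(label63+".", 5) + "example" // 327 characters

	fmt.Println(validPresentationLength(long[len(long)-253:]))       // true
	fmt.Println(validPresentationLength(long[len(long)-253:] + ".")) // true: 254 with dot
	fmt.Println(validPresentationLength(long[len(long)-254:]))       // false
}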
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go
index b5b6ffb..4dd4e16 100644
--- a/libgo/go/net/dnsclient_unix.go
+++ b/libgo/go/net/dnsclient_unix.go
@@ -125,7 +125,7 @@ func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn,
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
- // already checked that all the cfg.servers[i] are IP
+ // already checked that all the cfg.servers are IP
// addresses, which Dial will use without a DNS lookup.
c, err := d.DialContext(ctx, network, server)
if err != nil {
@@ -182,13 +182,14 @@ func exchange(ctx context.Context, server, name string, qtype uint16, timeout ti
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
- if len(cfg.servers) == 0 {
- return "", nil, &DNSError{Err: "no DNS servers", Name: name}
- }
-
var lastErr error
+ serverOffset := cfg.serverOffset()
+ sLen := uint32(len(cfg.servers))
+
for i := 0; i < cfg.attempts; i++ {
- for _, server := range cfg.servers {
+ for j := uint32(0); j < sLen; j++ {
+ server := cfg.servers[(serverOffset+j)%sLen]
+
msg, err := exchange(ctx, server, name, qtype, cfg.timeout)
if err != nil {
lastErr = &DNSError{
@@ -315,7 +316,12 @@ func (conf *resolverConfig) releaseSema() {
func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
if !isDomainName(name) {
- return "", nil, &DNSError{Err: "invalid domain name", Name: name}
+ // We used to use "invalid domain name" as the error,
+ // but that is a detail of the specific lookup mechanism.
+ // Other lookups might allow broader name syntax
+ // (for example Multicast DNS allows UTF-8; see RFC 6762).
+ // For consistency with libc resolvers, report no such host.
+ return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
@@ -356,14 +362,21 @@ func (conf *dnsConfig) nameList(name string) []string {
return nil
}
+ // Check name length (see isDomainName).
+ l := len(name)
+ rooted := l > 0 && name[l-1] == '.'
+ if l > 254 || l == 254 && rooted {
+ return nil
+ }
+
// If name is rooted (trailing dot), try only that name.
- rooted := len(name) > 0 && name[len(name)-1] == '.'
if rooted {
return []string{name}
}
hasNdots := count(name, '.') >= conf.ndots
name += "."
+ l++
// Build list of search choices.
names := make([]string, 0, 1+len(conf.search))
@@ -371,9 +384,11 @@ func (conf *dnsConfig) nameList(name string) []string {
if hasNdots {
names = append(names, name)
}
- // Try suffixes.
+ // Try suffixes that are not too long (see isDomainName).
for _, suffix := range conf.search {
- names = append(names, name+suffix)
+ if l+len(suffix) <= 254 {
+ names = append(names, name+suffix)
+ }
}
// Try unsuffixed, if not tried first above.
if !hasNdots {
@@ -429,7 +444,7 @@ func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder)
return
}
}
- ips, err := goLookupIPOrder(ctx, name, order)
+ ips, _, err := goLookupIPCNAMEOrder(ctx, name, order)
if err != nil {
return
}
@@ -455,27 +470,30 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
-func goLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error) {
- return goLookupIPOrder(ctx, name, hostLookupFilesDNS)
+func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+ order := systemConf().hostLookupOrder(host)
+ addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order)
+ return
}
-func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, err error) {
+func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles {
- return addrs, nil
+ return addrs, name, nil
}
}
if !isDomainName(name) {
- return nil, &DNSError{Err: "invalid domain name", Name: name}
+ // See comment in func lookup above about use of errNoSuchHost.
+ return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
type racer struct {
- fqdn string
- rrs []dnsRR
+ cname string
+ rrs []dnsRR
error
}
lane := make(chan racer, 1)
@@ -484,20 +502,23 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
go func(qtype uint16) {
- _, rrs, err := tryOneName(ctx, conf, fqdn, qtype)
- lane <- racer{fqdn, rrs, err}
+ cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype)
+ lane <- racer{cname, rrs, err}
}(qtype)
}
for range qtypes {
racer := <-lane
if racer.error != nil {
// Prefer error for original name.
- if lastErr == nil || racer.fqdn == name+"." {
+ if lastErr == nil || fqdn == name+"." {
lastErr = racer.error
}
continue
}
addrs = append(addrs, addrRecordList(racer.rrs)...)
+ if cname == "" {
+ cname = racer.cname
+ }
}
if len(addrs) > 0 {
break
@@ -515,24 +536,16 @@ func goLookupIPOrder(ctx context.Context, name string, order hostLookupOrder) (a
addrs = goLookupIPFiles(name)
}
if len(addrs) == 0 && lastErr != nil {
- return nil, lastErr
+ return nil, "", lastErr
}
}
- return addrs, nil
+ return addrs, cname, nil
}
-// goLookupCNAME is the native Go implementation of LookupCNAME.
-// Used only if cgoLookupCNAME refuses to handle the request
-// (that is, only if cgoLookupCNAME is the stub in cgo_stub.go).
-// Normally we let cgo use the C library resolver instead of
-// depending on our lookup code, so that Go and C get the same
-// answers.
-func goLookupCNAME(ctx context.Context, name string) (cname string, err error) {
- _, rrs, err := lookup(ctx, name, dnsTypeCNAME)
- if err != nil {
- return
- }
- cname = rrs[0].(*dnsRR_CNAME).Cname
+// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
+func goLookupCNAME(ctx context.Context, host string) (cname string, err error) {
+ order := systemConf().hostLookupOrder(host)
+ _, cname, err = goLookupIPCNAMEOrder(ctx, host, order)
return
}
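
A standalone sketch of the server rotation introduced in tryOneName above: start at an atomically advanced offset and wrap around with modulo. The server addresses here are illustrative:

package main

import (
	"fmt"
	"sync/atomic"
)

var soffset uint32 // shared rotation counter, as in dnsConfig.serverOffset

// nextServerOrder returns the order in which servers would be tried for one
// query when "options rotate" is set: each query starts one server later.
func nextServerOrder(servers []string) []string {
	off := atomic.AddUint32(&soffset, 1) - 1
	n := uint32(len(servers))
	order := make([]string, 0, n)
	for j := uint32(0); j < n; j++ {
		order = append(order, servers[(off+j)%n])
	}
	return order
}

func main() {
	servers := []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.3:53"}
	for i := 0; i < 3; i++ {
		fmt.Println(nextServerOrder(servers))
	}
}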
diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go
index 6ebeeae..85267bb 100644
--- a/libgo/go/net/dnsclient_unix_test.go
+++ b/libgo/go/net/dnsclient_unix_test.go
@@ -411,7 +411,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
// We need to take care with errors on both
// DNS message exchange layer and DNS
// transport layer because goLookupIP may fail
- // when the IP connectivty on node under test
+ // when the IP connectivity on node under test
// gets lost during its run.
if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) {
t.Errorf("got %v; want %v", err, tt.error)
@@ -455,14 +455,14 @@ func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
name := fmt.Sprintf("order %v", order)
// First ensure that we get an error when contacting a non-existent host.
- _, err := goLookupIPOrder(context.Background(), "notarealhost", order)
+ _, _, err := goLookupIPCNAMEOrder(context.Background(), "notarealhost", order)
if err == nil {
t.Errorf("%s: expected error while looking up name not in hosts file", name)
continue
}
// Now check that we get an address when the name appears in the hosts file.
- addrs, err := goLookupIPOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
+ addrs, _, err := goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts"
if err != nil {
t.Errorf("%s: expected to successfully lookup host entry", name)
continue
@@ -668,12 +668,14 @@ func TestIgnoreDNSForgeries(t *testing.T) {
b := make([]byte, 512)
n, err := s.Read(b)
if err != nil {
- t.Fatal(err)
+ t.Error(err)
+ return
}
msg := &dnsMsg{}
if !msg.Unpack(b[:n]) {
- t.Fatal("invalid DNS query")
+ t.Error("invalid DNS query")
+ return
}
s.Write([]byte("garbage DNS response packet"))
@@ -682,7 +684,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
msg.id++ // make invalid ID
b, ok := msg.Pack()
if !ok {
- t.Fatal("failed to pack DNS response")
+ t.Error("failed to pack DNS response")
+ return
}
s.Write(b)
@@ -701,7 +704,8 @@ func TestIgnoreDNSForgeries(t *testing.T) {
b, ok = msg.Pack()
if !ok {
- t.Fatal("failed to pack DNS response")
+ t.Error("failed to pack DNS response")
+ return
}
s.Write(b)
}()
@@ -740,8 +744,11 @@ func TestRetryTimeout(t *testing.T) {
}
defer conf.teardown()
- if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will timeout
- "nameserver 192.0.2.2"}); err != nil {
+ testConf := []string{
+ "nameserver 192.0.2.1", // the one that will timeout
+ "nameserver 192.0.2.2",
+ }
+ if err := conf.writeAndUpdate(testConf); err != nil {
t.Fatal(err)
}
@@ -767,28 +774,10 @@ func TestRetryTimeout(t *testing.T) {
t.Error("deadline didn't change")
}
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- recursion_available: true,
- },
- question: q.question,
- answer: []dnsRR{
- &dnsRR_CNAME{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeCNAME,
- Class: dnsClassINET,
- },
- Cname: "golang.org",
- },
- },
- }
- return r, nil
+ return mockTXTResponse(q), nil
}
- _, err = goLookupCNAME(context.Background(), "www.golang.org")
+ _, err = LookupTXT("www.golang.org")
if err != nil {
t.Fatal(err)
}
@@ -797,3 +786,77 @@ func TestRetryTimeout(t *testing.T) {
t.Error("deadline0 still zero", deadline0)
}
}
+
+func TestRotate(t *testing.T) {
+ // without rotation, always uses the first server
+ testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"})
+
+ // with rotation, rotates through back to first
+ testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"})
+}
+
+func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
+ origTestHookDNSDialer := testHookDNSDialer
+ defer func() { testHookDNSDialer = origTestHookDNSDialer }()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ var confLines []string
+ for _, ns := range nameservers {
+ confLines = append(confLines, "nameserver "+ns)
+ }
+ if rotate {
+ confLines = append(confLines, "options rotate")
+ }
+
+ if err := conf.writeAndUpdate(confLines); err != nil {
+ t.Fatal(err)
+ }
+
+ d := &fakeDNSDialer{}
+ testHookDNSDialer = func() dnsDialer { return d }
+
+ var usedServers []string
+ d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+ usedServers = append(usedServers, s)
+ return mockTXTResponse(q), nil
+ }
+
+ // len(nameservers) + 1 to allow rotation to get back to start
+ for i := 0; i < len(nameservers)+1; i++ {
+ if _, err := LookupTXT("www.golang.org"); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ if !reflect.DeepEqual(usedServers, wantServers) {
+ t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers)
+ }
+}
+
+func mockTXTResponse(q *dnsMsg) *dnsMsg {
+ r := &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ id: q.id,
+ response: true,
+ recursion_available: true,
+ },
+ question: q.question,
+ answer: []dnsRR{
+ &dnsRR_TXT{
+ Hdr: dnsRR_Header{
+ Name: q.question[0].Name,
+ Rrtype: dnsTypeTXT,
+ Class: dnsClassINET,
+ },
+ Txt: "ok",
+ },
+ },
+ }
+
+ return r
+}
diff --git a/libgo/go/net/dnsconfig_unix.go b/libgo/go/net/dnsconfig_unix.go
index aec575e..9c8108d 100644
--- a/libgo/go/net/dnsconfig_unix.go
+++ b/libgo/go/net/dnsconfig_unix.go
@@ -10,6 +10,7 @@ package net
import (
"os"
+ "sync/atomic"
"time"
)
@@ -29,6 +30,7 @@ type dnsConfig struct {
lookup []string // OpenBSD top-level database "lookup" order
err error // any error that occurs during open of resolv.conf
mtime time.Time // time of resolv.conf modification
+ soffset uint32 // used by serverOffset
}
// See resolv.conf(5) on a Linux machine.
@@ -91,19 +93,21 @@ func dnsReadConfig(filename string) *dnsConfig {
for _, s := range f[1:] {
switch {
case hasPrefix(s, "ndots:"):
- n, _, _ := dtoi(s, 6)
- if n < 1 {
- n = 1
+ n, _, _ := dtoi(s[6:])
+ if n < 0 {
+ n = 0
+ } else if n > 15 {
+ n = 15
}
conf.ndots = n
case hasPrefix(s, "timeout:"):
- n, _, _ := dtoi(s, 8)
+ n, _, _ := dtoi(s[8:])
if n < 1 {
n = 1
}
conf.timeout = time.Duration(n) * time.Second
case hasPrefix(s, "attempts:"):
- n, _, _ := dtoi(s, 9)
+ n, _, _ := dtoi(s[9:])
if n < 1 {
n = 1
}
@@ -134,6 +138,17 @@ func dnsReadConfig(filename string) *dnsConfig {
return conf
}
+// serverOffset returns an offset that can be used to determine
+// indices of servers in c.servers when making queries.
+// When the rotate option is enabled, this offset increases.
+// Otherwise it is always 0.
+func (c *dnsConfig) serverOffset() uint32 {
+ if c.rotate {
+ return atomic.AddUint32(&c.soffset, 1) - 1 // return 0 to start
+ }
+ return 0
+}
+
func dnsDefaultSearch() []string {
hn, err := getHostname()
if err != nil {
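
A hedged sketch of the ndots clamping added above, using strconv in place of the package's internal dtoi; the observable results match the new testdata files (invalid and negative values become 0, large values are capped at 15):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseNdots mimics the clamping added above: a malformed value becomes 0,
// negative values become 0, and anything above 15 is capped at 15.
func parseNdots(opt string) int {
	n, err := strconv.Atoi(strings.TrimPrefix(opt, "ndots:"))
	if err != nil || n < 0 {
		return 0
	}
	if n > 15 {
		return 15
	}
	return n
}

func main() {
	for _, opt := range []string{"ndots:3", "ndots:1024", "ndots:-1", "ndots:invalid"} {
		fmt.Printf("%-13s -> %d\n", opt, parseNdots(opt))
	}
}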
diff --git a/libgo/go/net/dnsconfig_unix_test.go b/libgo/go/net/dnsconfig_unix_test.go
index 9fd6dbf..37bdeb0 100644
--- a/libgo/go/net/dnsconfig_unix_test.go
+++ b/libgo/go/net/dnsconfig_unix_test.go
@@ -10,6 +10,7 @@ import (
"errors"
"os"
"reflect"
+ "strings"
"testing"
"time"
)
@@ -61,6 +62,36 @@ var dnsReadConfigTests = []struct {
},
},
{
+ name: "testdata/invalid-ndots-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 0,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/large-ndots-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 15,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
+ name: "testdata/negative-ndots-resolv.conf",
+ want: &dnsConfig{
+ servers: defaultNS,
+ ndots: 0,
+ timeout: 5 * time.Second,
+ attempts: 2,
+ search: []string{"domain.local."},
+ },
+ },
+ {
name: "testdata/openbsd-resolv.conf",
want: &dnsConfig{
ndots: 1,
@@ -154,3 +185,55 @@ func TestDNSDefaultSearch(t *testing.T) {
}
}
}
+
+func TestDNSNameLength(t *testing.T) {
+ origGetHostname := getHostname
+ defer func() { getHostname = origGetHostname }()
+ getHostname = func() (string, error) { return "host.domain.local", nil }
+
+ var char63 = ""
+ for i := 0; i < 63; i++ {
+ char63 += "a"
+ }
+ longDomain := strings.Repeat(char63+".", 5) + "example"
+
+ for _, tt := range dnsReadConfigTests {
+ conf := dnsReadConfig(tt.name)
+ if conf.err != nil {
+ t.Fatal(conf.err)
+ }
+
+ var shortestSuffix int
+ for _, suffix := range tt.want.search {
+ if shortestSuffix == 0 || len(suffix) < shortestSuffix {
+ shortestSuffix = len(suffix)
+ }
+ }
+
+ // Test a name that will be maximally long when prefixing the shortest
+ // suffix (accounting for the intervening dot).
+ longName := longDomain[len(longDomain)-254+1+shortestSuffix:]
+ if longName[0] == '.' || longName[1] == '.' {
+ longName = "aa." + longName[3:]
+ }
+ for _, fqdn := range conf.nameList(longName) {
+ if len(fqdn) > 254 {
+ t.Errorf("got %d; want less than or equal to 254", len(fqdn))
+ }
+ }
+
+ // Now test a name that's too long for suffixing.
+ unsuffixable := "a." + longName[1:]
+ unsuffixableResults := conf.nameList(unsuffixable)
+ if len(unsuffixableResults) != 1 {
+ t.Errorf("suffixed names %v; want []", unsuffixableResults[1:])
+ }
+
+ // Now test a name that's too long for DNS.
+ tooLong := "a." + longDomain
+ tooLongResults := conf.nameList(tooLong)
+ if tooLongResults != nil {
+ t.Errorf("suffixed names %v; want nil", tooLongResults)
+ }
+ }
+}
diff --git a/libgo/go/net/dnsmsg.go b/libgo/go/net/dnsmsg.go
index afdb44c..8f6c7b6 100644
--- a/libgo/go/net/dnsmsg.go
+++ b/libgo/go/net/dnsmsg.go
@@ -69,7 +69,7 @@ const (
)
// A dnsStruct describes how to iterate over its fields to emulate
-// reflective marshalling.
+// reflective marshaling.
type dnsStruct interface {
// Walk iterates over fields of a structure and calls f
// with a reference to that field, the name of the field
diff --git a/libgo/go/net/dnsmsg_test.go b/libgo/go/net/dnsmsg_test.go
index 25bd98c..2a25a21 100644
--- a/libgo/go/net/dnsmsg_test.go
+++ b/libgo/go/net/dnsmsg_test.go
@@ -117,7 +117,7 @@ func TestDNSParseSRVReply(t *testing.T) {
if !ok {
t.Fatal("unpacking packet failed")
}
- msg.String() // exercise this code path
+ _ = msg.String() // exercise this code path
if g, e := len(msg.answer), 5; g != e {
t.Errorf("len(msg.answer) = %d; want %d", g, e)
}
@@ -165,7 +165,7 @@ func TestDNSParseCorruptSRVReply(t *testing.T) {
if !ok {
t.Fatal("unpacking packet failed")
}
- msg.String() // exercise this code path
+ _ = msg.String() // exercise this code path
if g, e := len(msg.answer), 5; g != e {
t.Errorf("len(msg.answer) = %d; want %d", g, e)
}
@@ -393,7 +393,7 @@ func TestIsResponseTo(t *testing.T) {
for i := range badResponses {
if badResponses[i].IsResponseTo(&query) {
- t.Error("%v: got true, want false", i)
+ t.Errorf("%v: got true, want false", i)
}
}
}
diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go
index bc777b8..e0f786d 100644
--- a/libgo/go/net/dnsname_test.go
+++ b/libgo/go/net/dnsname_test.go
@@ -32,14 +32,12 @@ var dnsNameTests = []dnsNameTest{
func emitDNSNameTest(ch chan<- dnsNameTest) {
defer close(ch)
- var char59 = ""
var char63 = ""
- var char64 = ""
- for i := 0; i < 59; i++ {
- char59 += "a"
+ for i := 0; i < 63; i++ {
+ char63 += "a"
}
- char63 = char59 + "aaaa"
- char64 = char63 + "a"
+ char64 := char63 + "a"
+ longDomain := strings.Repeat(char63+".", 5) + "example"
for _, tc := range dnsNameTests {
ch <- tc
@@ -47,14 +45,15 @@ func emitDNSNameTest(ch chan<- dnsNameTest) {
ch <- dnsNameTest{char63 + ".com", true}
ch <- dnsNameTest{char64 + ".com", false}
- // 255 char name is fine:
- ch <- dnsNameTest{char59 + "." + char63 + "." + char63 + "." +
- char63 + ".com",
- true}
- // 256 char name is bad:
- ch <- dnsNameTest{char59 + "a." + char63 + "." + char63 + "." +
- char63 + ".com",
- false}
+
+ // Remember: wire format is two octets longer than presentation
+ // (length octets for the first and [root] last labels).
+ // 253 is fine:
+ ch <- dnsNameTest{longDomain[len(longDomain)-253:], true}
+ // A terminal dot doesn't contribute to length:
+ ch <- dnsNameTest{longDomain[len(longDomain)-253:] + ".", true}
+ // 254 is bad:
+ ch <- dnsNameTest{longDomain[len(longDomain)-254:], false}
}
func TestDNSName(t *testing.T) {
diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go
index d6de5a3..c23da49 100644
--- a/libgo/go/net/error_test.go
+++ b/libgo/go/net/error_test.go
@@ -97,7 +97,8 @@ second:
goto third
}
switch nestedErr {
- case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress:
+ case errCanceled, errClosing, errMissingAddress, errNoSuitableAddress,
+ context.DeadlineExceeded, context.Canceled:
return nil
}
return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
@@ -230,23 +231,27 @@ func TestDialAddrError(t *testing.T) {
} {
var err error
var c Conn
+ var op string
if tt.lit != "" {
c, err = Dial(tt.network, JoinHostPort(tt.lit, "0"))
+ op = fmt.Sprintf("Dial(%q, %q)", tt.network, JoinHostPort(tt.lit, "0"))
} else {
c, err = DialTCP(tt.network, nil, tt.addr)
+ op = fmt.Sprintf("DialTCP(%q, %q)", tt.network, tt.addr)
}
if err == nil {
c.Close()
- t.Errorf("%s %q/%v: should fail", tt.network, tt.lit, tt.addr)
+ t.Errorf("%s succeeded, want error", op)
continue
}
if perr := parseDialError(err); perr != nil {
- t.Error(perr)
+ t.Errorf("%s: %v", op, perr)
continue
}
- aerr, ok := err.(*OpError).Err.(*AddrError)
+ operr := err.(*OpError).Err
+ aerr, ok := operr.(*AddrError)
if !ok {
- t.Errorf("%s %q/%v: should be AddrError: %v", tt.network, tt.lit, tt.addr, err)
+ t.Errorf("%s: %v is %T, want *AddrError", op, err, operr)
continue
}
want := tt.lit
@@ -254,7 +259,7 @@ func TestDialAddrError(t *testing.T) {
want = tt.addr.IP.String()
}
if aerr.Addr != want {
- t.Fatalf("%s: got %q; want %q", tt.network, aerr.Addr, want)
+ t.Errorf("%s: %v, error Addr=%q, want %q", op, err, aerr.Addr, want)
}
}
}
@@ -521,6 +526,10 @@ third:
if isPlatformError(nestedErr) {
return nil
}
+ switch nestedErr {
+ case os.ErrClosed: // for Plan 9
+ return nil
+ }
return fmt.Errorf("unexpected type on 3rd nested level: %T", nestedErr)
}
diff --git a/libgo/go/net/fd_io_plan9.go b/libgo/go/net/fd_io_plan9.go
new file mode 100644
index 0000000..76da0c5
--- /dev/null
+++ b/libgo/go/net/fd_io_plan9.go
@@ -0,0 +1,93 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "os"
+ "runtime"
+ "sync"
+ "syscall"
+)
+
+// asyncIO implements asynchronous cancelable I/O.
+// An asyncIO represents a single asynchronous Read or Write
+// operation. The result is returned on the result channel.
+// The undergoing I/O system call can either complete or be
+// interrupted by a note.
+type asyncIO struct {
+ res chan result
+
+ // mu guards the pid field.
+ mu sync.Mutex
+
+ // pid holds the process id of
+ // the process running the IO operation.
+ pid int
+}
+
+// result is the return value of a Read or Write operation.
+type result struct {
+ n int
+ err error
+}
+
+// newAsyncIO returns a new asyncIO that performs an I/O
+// operation by calling fn, which must do one and only one
+// interruptible system call.
+func newAsyncIO(fn func([]byte) (int, error), b []byte) *asyncIO {
+ aio := &asyncIO{
+ res: make(chan result, 0),
+ }
+ aio.mu.Lock()
+ go func() {
+ // Lock the current goroutine to its process
+ // and store the pid in io so that Cancel can
+ // interrupt it. We ignore the "hangup" signal,
+ // so the signal does not take down the entire
+ // Go runtime.
+ runtime.LockOSThread()
+ runtime_ignoreHangup()
+ aio.pid = os.Getpid()
+ aio.mu.Unlock()
+
+ n, err := fn(b)
+
+ aio.mu.Lock()
+ aio.pid = -1
+ runtime_unignoreHangup()
+ aio.mu.Unlock()
+
+ aio.res <- result{n, err}
+ }()
+ return aio
+}
+
+var hangupNote os.Signal = syscall.Note("hangup")
+
+// Cancel interrupts the I/O operation, causing
+// the Wait function to return.
+func (aio *asyncIO) Cancel() {
+ aio.mu.Lock()
+ defer aio.mu.Unlock()
+ if aio.pid == -1 {
+ return
+ }
+ proc, err := os.FindProcess(aio.pid)
+ if err != nil {
+ return
+ }
+ proc.Signal(hangupNote)
+}
+
+// Wait for the I/O operation to complete.
+func (aio *asyncIO) Wait() (int, error) {
+ res := <-aio.res
+ return res.n, res.err
+}
+
+// The following functions, provided by the runtime, are used to
+// ignore and unignore the "hangup" signal received by the process.
+func runtime_ignoreHangup()
+func runtime_unignoreHangup()
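
A portable analogue (not Plan 9 specific) of how this asyncIO type is used by the netFD code further down: run the blocking call in its own goroutine, wait on a channel, and let the caller give up. The real implementation additionally records the process id so Cancel can interrupt the blocked system call with a "hangup" note:

package main

import (
	"errors"
	"fmt"
	"time"
)

// result mirrors the result type above.
type result struct {
	n   int
	err error
}

// asyncCall runs fn in its own goroutine and delivers its outcome on a
// buffered channel so the caller can wait for it or abandon it.
func asyncCall(fn func([]byte) (int, error), b []byte) <-chan result {
	res := make(chan result, 1)
	go func() {
		n, err := fn(b)
		res <- result{n, err}
	}()
	return res
}

func main() {
	slowRead := func(b []byte) (int, error) {
		time.Sleep(50 * time.Millisecond)
		return copy(b, "hello"), nil
	}
	buf := make([]byte, 16)
	select {
	case r := <-asyncCall(slowRead, buf):
		fmt.Println(r.n, r.err, string(buf[:r.n]))
	case <-time.After(time.Second):
		fmt.Println(errors.New("timed out"))
	}
}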
diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go
index 7533232..300d8c4 100644
--- a/libgo/go/net/fd_plan9.go
+++ b/libgo/go/net/fd_plan9.go
@@ -7,21 +7,37 @@ package net
import (
"io"
"os"
+ "sync/atomic"
"syscall"
"time"
)
+type atomicBool int32
+
+func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
+func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
+func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
+
// Network file descriptor.
type netFD struct {
// locking/lifetime of sysfd + serialize access to Read and Write methods
fdmu fdMutex
// immutable until Close
- net string
- n string
- dir string
- ctl, data *os.File
- laddr, raddr Addr
+ net string
+ n string
+ dir string
+ listen, ctl, data *os.File
+ laddr, raddr Addr
+ isStream bool
+
+ // deadlines
+ raio *asyncIO
+ waio *asyncIO
+ rtimer *time.Timer
+ wtimer *time.Timer
+ rtimedout atomicBool // set true when read deadline has been reached
+ wtimedout atomicBool // set true when write deadline has been reached
}
var (
@@ -32,8 +48,16 @@ func sysInit() {
netdir = "/net"
}
-func newFD(net, name string, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) {
- return &netFD{net: net, n: name, dir: netdir + "/" + net + "/" + name, ctl: ctl, data: data, laddr: laddr, raddr: raddr}, nil
+func newFD(net, name string, listen, ctl, data *os.File, laddr, raddr Addr) (*netFD, error) {
+ return &netFD{
+ net: net,
+ n: name,
+ dir: netdir + "/" + net + "/" + name,
+ listen: listen,
+ ctl: ctl, data: data,
+ laddr: laddr,
+ raddr: raddr,
+ }, nil
}
func (fd *netFD) init() error {
@@ -64,11 +88,20 @@ func (fd *netFD) destroy() {
err = err1
}
}
+ if fd.listen != nil {
+ if err1 := fd.listen.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
fd.ctl = nil
fd.data = nil
+ fd.listen = nil
}
func (fd *netFD) Read(b []byte) (n int, err error) {
+ if fd.rtimedout.isSet() {
+ return 0, errTimeout
+ }
if !fd.ok() || fd.data == nil {
return 0, syscall.EINVAL
}
@@ -79,10 +112,15 @@ func (fd *netFD) Read(b []byte) (n int, err error) {
if len(b) == 0 {
return 0, nil
}
- n, err = fd.data.Read(b)
+ fd.raio = newAsyncIO(fd.data.Read, b)
+ n, err = fd.raio.Wait()
+ fd.raio = nil
if isHangup(err) {
err = io.EOF
}
+ if isInterrupted(err) {
+ err = errTimeout
+ }
if fd.net == "udp" && err == io.EOF {
n = 0
err = nil
@@ -91,6 +129,9 @@ func (fd *netFD) Read(b []byte) (n int, err error) {
}
func (fd *netFD) Write(b []byte) (n int, err error) {
+ if fd.wtimedout.isSet() {
+ return 0, errTimeout
+ }
if !fd.ok() || fd.data == nil {
return 0, syscall.EINVAL
}
@@ -98,7 +139,13 @@ func (fd *netFD) Write(b []byte) (n int, err error) {
return 0, err
}
defer fd.writeUnlock()
- return fd.data.Write(b)
+ fd.waio = newAsyncIO(fd.data.Write, b)
+ n, err = fd.waio.Wait()
+ fd.waio = nil
+ if isInterrupted(err) {
+ err = errTimeout
+ }
+ return
}
func (fd *netFD) closeRead() error {
@@ -124,11 +171,10 @@ func (fd *netFD) Close() error {
}
if fd.net == "tcp" {
// The following line is required to unblock Reads.
- // For some reason, WriteString returns an error:
- // "write /net/tcp/39/listen: inappropriate use of fd"
- // But without it, Reads on dead conns hang forever.
- // See Issue 9554.
- fd.ctl.WriteString("hangup")
+ _, err := fd.ctl.WriteString("close")
+ if err != nil {
+ return err
+ }
}
err := fd.ctl.Close()
if fd.data != nil {
@@ -136,8 +182,14 @@ func (fd *netFD) Close() error {
err = err1
}
}
+ if fd.listen != nil {
+ if err1 := fd.listen.Close(); err1 != nil && err == nil {
+ err = err1
+ }
+ }
fd.ctl = nil
fd.data = nil
+ fd.listen = nil
return err
}
@@ -165,15 +217,74 @@ func (fd *netFD) file(f *os.File, s string) (*os.File, error) {
}
func (fd *netFD) setDeadline(t time.Time) error {
- return syscall.EPLAN9
+ return setDeadlineImpl(fd, t, 'r'+'w')
}
func (fd *netFD) setReadDeadline(t time.Time) error {
- return syscall.EPLAN9
+ return setDeadlineImpl(fd, t, 'r')
}
func (fd *netFD) setWriteDeadline(t time.Time) error {
- return syscall.EPLAN9
+ return setDeadlineImpl(fd, t, 'w')
+}
+
+func setDeadlineImpl(fd *netFD, t time.Time, mode int) error {
+ d := t.Sub(time.Now())
+ if mode == 'r' || mode == 'r'+'w' {
+ fd.rtimedout.setFalse()
+ }
+ if mode == 'w' || mode == 'r'+'w' {
+ fd.wtimedout.setFalse()
+ }
+ if t.IsZero() || d < 0 {
+ // Stop timer
+ if mode == 'r' || mode == 'r'+'w' {
+ if fd.rtimer != nil {
+ fd.rtimer.Stop()
+ }
+ fd.rtimer = nil
+ }
+ if mode == 'w' || mode == 'r'+'w' {
+ if fd.wtimer != nil {
+ fd.wtimer.Stop()
+ }
+ fd.wtimer = nil
+ }
+ } else {
+ // Interrupt I/O operation once timer has expired
+ if mode == 'r' || mode == 'r'+'w' {
+ fd.rtimer = time.AfterFunc(d, func() {
+ fd.rtimedout.setTrue()
+ if fd.raio != nil {
+ fd.raio.Cancel()
+ }
+ })
+ }
+ if mode == 'w' || mode == 'r'+'w' {
+ fd.wtimer = time.AfterFunc(d, func() {
+ fd.wtimedout.setTrue()
+ if fd.waio != nil {
+ fd.waio.Cancel()
+ }
+ })
+ }
+ }
+ if !t.IsZero() && d < 0 {
+ // Interrupt current I/O operation
+ if mode == 'r' || mode == 'r'+'w' {
+ fd.rtimedout.setTrue()
+ if fd.raio != nil {
+ fd.raio.Cancel()
+ }
+ }
+ if mode == 'w' || mode == 'r'+'w' {
+ fd.wtimedout.setTrue()
+ if fd.waio != nil {
+ fd.waio.Cancel()
+ }
+ }
+ }
+ return nil
}
func setReadBuffer(fd *netFD, bytes int) error {
@@ -187,3 +298,7 @@ func setWriteBuffer(fd *netFD, bytes int) error {
func isHangup(err error) bool {
return err != nil && stringsHasSuffix(err.Error(), "Hangup")
}
+
+func isInterrupted(err error) bool {
+ return err != nil && stringsHasSuffix(err.Error(), "interrupted")
+}
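
The user-visible effect of this file's changes is that deadlines now work on Plan 9 instead of failing with EPLAN9. A caller-side sketch, identical to what already works on other platforms (example.com:80 is just an illustrative endpoint that stays silent after the dial):

    package main

    import (
    	"fmt"
    	"net"
    	"time"
    )

    func main() {
    	c, err := net.Dial("tcp", "example.com:80")
    	if err != nil {
    		fmt.Println("dial:", err)
    		return
    	}
    	defer c.Close()

    	c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
    	buf := make([]byte, 1024)
    	_, err = c.Read(buf) // nothing is sent, so the deadline fires
    	if ne, ok := err.(net.Error); ok && ne.Timeout() {
    		fmt.Println("read timed out, as the new deadline code arranges")
    	} else {
    		fmt.Println("read returned:", err)
    	}
    }
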
diff --git a/libgo/go/net/fd_poll_nacl.go b/libgo/go/net/fd_poll_nacl.go
index cda8b82..8398760 100644
--- a/libgo/go/net/fd_poll_nacl.go
+++ b/libgo/go/net/fd_poll_nacl.go
@@ -5,6 +5,7 @@
package net
import (
+ "runtime"
"syscall"
"time"
)
@@ -22,6 +23,7 @@ func (pd *pollDesc) evict() {
pd.closing = true
if pd.fd != nil {
syscall.StopIO(pd.fd.sysfd)
+ runtime.KeepAlive(pd.fd)
}
}
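
The runtime.KeepAlive calls added here (and in fd_poll_runtime.go below) prevent the garbage collector from finalizing the netFD, which would close the descriptor, while a system call on that descriptor is still in flight. A small standalone Unix-like illustration of the same idiom, not taken from the patch:

    package main

    import (
    	"fmt"
    	"os"
    	"runtime"
    	"syscall"
    )

    func main() {
    	f, err := os.Open("/etc/hosts") // illustrative path
    	if err != nil {
    		fmt.Println(err)
    		return
    	}
    	fd := int(f.Fd()) // after this, f may look unreachable to the GC
    	buf := make([]byte, 128)
    	n, _ := syscall.Read(fd, buf)
    	runtime.KeepAlive(f) // keep f's finalizer (which closes fd) from running before the Read returns
    	fmt.Printf("read %d bytes\n", n)
    	f.Close()
    }
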
diff --git a/libgo/go/net/fd_poll_runtime.go b/libgo/go/net/fd_poll_runtime.go
index 6c1d095..62b69fc 100644
--- a/libgo/go/net/fd_poll_runtime.go
+++ b/libgo/go/net/fd_poll_runtime.go
@@ -7,6 +7,7 @@
package net
import (
+ "runtime"
"sync"
"syscall"
"time"
@@ -33,6 +34,7 @@ var serverInit sync.Once
func (pd *pollDesc) init(fd *netFD) error {
serverInit.Do(runtime_pollServerInit)
ctx, errno := runtime_pollOpen(uintptr(fd.sysfd))
+ runtime.KeepAlive(fd)
if errno != 0 {
return syscall.Errno(errno)
}
@@ -120,7 +122,7 @@ func (fd *netFD) setWriteDeadline(t time.Time) error {
}
func setDeadlineImpl(fd *netFD, t time.Time, mode int) error {
- diff := int64(t.Sub(time.Now()))
+ diff := int64(time.Until(t))
d := runtimeNano() + diff
if d <= 0 && diff > 0 {
// If the user has a deadline in the future, but the delay calculation
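
As a side note, time.Until is new in Go 1.8 and is simply the deadline-to-now counterpart of time.Since, so the change above is a spelling cleanup with identical behavior. A trivial standalone check, not from the patch:

    package main

    import (
    	"fmt"
    	"time"
    )

    func main() {
    	deadline := time.Now().Add(250 * time.Millisecond)
    	fmt.Println(deadline.Sub(time.Now())) // pre-1.8 spelling
    	fmt.Println(time.Until(deadline))     // Go 1.8 spelling; same value modulo the skew between the calls
    }
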
diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go
index 0309db0..9bc5ebc 100644
--- a/libgo/go/net/fd_unix.go
+++ b/libgo/go/net/fd_unix.go
@@ -24,11 +24,15 @@ type netFD struct {
sysfd int
family int
sotype int
+ isStream bool
isConnected bool
net string
laddr Addr
raddr Addr
+ // writev cache.
+ iovecs *[]syscall.Iovec
+
// wait server
pd pollDesc
}
@@ -37,7 +41,7 @@ func sysInit() {
}
func newFD(sysfd, family, sotype int, net string) (*netFD, error) {
- return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
+ return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil
}
func (fd *netFD) init() error {
@@ -235,6 +239,9 @@ func (fd *netFD) Read(p []byte) (n int, err error) {
if err := fd.pd.prepareRead(); err != nil {
return 0, err
}
+ if fd.isStream && len(p) > 1<<30 {
+ p = p[:1<<30]
+ }
for {
n, err = syscall.Read(fd.sysfd, p)
if err != nil {
@@ -318,7 +325,11 @@ func (fd *netFD) Write(p []byte) (nn int, err error) {
}
for {
var n int
- n, err = syscall.Write(fd.sysfd, p[nn:])
+ max := len(p)
+ if fd.isStream && max-nn > 1<<30 {
+ max = nn + 1<<30
+ }
+ n, err = syscall.Write(fd.sysfd, p[nn:max])
if n > 0 {
nn += n
}
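
The new isStream field and the 1<<30 checks cap single reads and writes on stream sockets at 1 GiB, since some kernels mishandle larger requests; for streams the caller simply sees a short read, and Write loops in chunks. A standalone sketch of the write loop's clamping logic (writeClamped and the stub writer are invented for illustration; the real code calls syscall.Write on fd.sysfd and also waits for pollability):

    package main

    import (
    	"fmt"
    	"io"
    )

    func writeClamped(write func([]byte) (int, error), p []byte, isStream bool) (nn int, err error) {
    	for nn < len(p) {
    		max := len(p)
    		if isStream && max-nn > 1<<30 {
    			max = nn + 1<<30 // never hand the kernel more than 1 GiB at once
    		}
    		var n int
    		n, err = write(p[nn:max])
    		nn += n
    		if err != nil {
    			return nn, err
    		}
    		if n == 0 {
    			return nn, io.ErrUnexpectedEOF
    		}
    	}
    	return nn, nil
    }

    func main() {
    	buf := make([]byte, 3)
    	n, err := writeClamped(func(b []byte) (int, error) { return len(b), nil }, buf, true)
    	fmt.Println(n, err) // 3 <nil>
    }
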
diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go
index b0b6769..a976f2a 100644
--- a/libgo/go/net/fd_windows.go
+++ b/libgo/go/net/fd_windows.go
@@ -96,6 +96,7 @@ type operation struct {
rsan int32
handle syscall.Handle
flags uint32
+ bufs []syscall.WSABuf
}
func (o *operation) InitBuf(buf []byte) {
@@ -106,6 +107,30 @@ func (o *operation) InitBuf(buf []byte) {
}
}
+func (o *operation) InitBufs(buf *Buffers) {
+ if o.bufs == nil {
+ o.bufs = make([]syscall.WSABuf, 0, len(*buf))
+ } else {
+ o.bufs = o.bufs[:0]
+ }
+ for _, b := range *buf {
+ var p *byte
+ if len(b) > 0 {
+ p = &b[0]
+ }
+ o.bufs = append(o.bufs, syscall.WSABuf{Len: uint32(len(b)), Buf: p})
+ }
+}
+
+// ClearBufs clears all pointers to the Buffers parameter captured
+// by InitBufs, so it can be released by the garbage collector.
+func (o *operation) ClearBufs() {
+ for i := range o.bufs {
+ o.bufs[i].Buf = nil
+ }
+ o.bufs = o.bufs[:0]
+}
+
// ioSrv executes net IO requests.
type ioSrv struct {
req chan ioSrvReq
@@ -239,6 +264,7 @@ type netFD struct {
sysfd syscall.Handle
family int
sotype int
+ isStream bool
isConnected bool
skipSyncNotif bool
net string
@@ -257,7 +283,7 @@ func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error)
return nil, initErr
}
onceStartServer.Do(startServer)
- return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net}, nil
+ return &netFD{sysfd: sysfd, family: family, sotype: sotype, net: net, isStream: sotype == syscall.SOCK_STREAM}, nil
}
func (fd *netFD) init() error {
@@ -483,6 +509,42 @@ func (fd *netFD) Write(buf []byte) (int, error) {
return n, err
}
+func (c *conn) writeBuffers(v *Buffers) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.fd.writeBuffers(v)
+ if err != nil {
+ return n, &OpError{Op: "WSASend", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, nil
+}
+
+func (fd *netFD) writeBuffers(buf *Buffers) (int64, error) {
+ if len(*buf) == 0 {
+ return 0, nil
+ }
+ if err := fd.writeLock(); err != nil {
+ return 0, err
+ }
+ defer fd.writeUnlock()
+ if race.Enabled {
+ race.ReleaseMerge(unsafe.Pointer(&ioSync))
+ }
+ o := &fd.wop
+ o.InitBufs(buf)
+ n, err := wsrv.ExecIO(o, "WSASend", func(o *operation) error {
+ return syscall.WSASend(o.fd.sysfd, &o.bufs[0], uint32(len(*buf)), &o.qty, 0, &o.o, nil)
+ })
+ o.ClearBufs()
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("wsasend", err)
+ }
+ testHookDidWritev(n)
+ buf.consume(int64(n))
+ return int64(n), err
+}
+
func (fd *netFD) writeTo(buf []byte, sa syscall.Sockaddr) (int, error) {
if len(buf) == 0 {
return 0, nil
@@ -541,7 +603,7 @@ func (fd *netFD) acceptOne(rawsa []syscall.RawSockaddrAny, o *operation) (*netFD
netfd.Close()
return nil, os.NewSyscallError("setsockopt", err)
}
-
+ runtime.KeepAlive(fd)
return netfd, nil
}
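
The writeBuffers path above is what backs the net.Buffers type that is new in Go 1.8: a caller hands the connection several slices and, where the platform supports it, they go out in one gathering write (WSASend here, writev on Unix). A small self-contained usage sketch, with net.Pipe standing in for a real connection; sendFrames and the frame contents are invented for illustration:

    package main

    import (
    	"fmt"
    	"net"
    )

    func sendFrames(c net.Conn, header, payload []byte) error {
    	bufs := net.Buffers{header, payload}
    	_, err := bufs.WriteTo(c) // uses the gathering write path on supported connections
    	return err
    }

    func main() {
    	c1, c2 := net.Pipe()
    	go sendFrames(c1, []byte("HDR "), []byte("payload"))
    	buf := make([]byte, 32)
    	for i := 0; i < 2; i++ { // net.Pipe has no gathering path, so each slice arrives as its own write
    		n, _ := c2.Read(buf)
    		fmt.Print(string(buf[:n]))
    	}
    	fmt.Println()
    }
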
diff --git a/libgo/go/net/file.go b/libgo/go/net/file.go
index 1aad477..07099851 100644
--- a/libgo/go/net/file.go
+++ b/libgo/go/net/file.go
@@ -6,6 +6,9 @@ package net
import "os"
+// BUG(mikio): On NaCl and Windows, the FileConn, FileListener and
+// FilePacketConn functions are not implemented.
+
type fileAddr string
func (fileAddr) Network() string { return "file+net" }
diff --git a/libgo/go/net/file_plan9.go b/libgo/go/net/file_plan9.go
index 2939c09..d16e5a1 100644
--- a/libgo/go/net/file_plan9.go
+++ b/libgo/go/net/file_plan9.go
@@ -81,7 +81,7 @@ func newFileFD(f *os.File) (net *netFD, err error) {
if err != nil {
return nil, err
}
- return newFD(comp[1], name, ctl, nil, laddr, nil)
+ return newFD(comp[1], name, nil, ctl, nil, laddr, nil)
}
func fileConn(f *os.File) (Conn, error) {
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go
index 993c247..d368bae 100644
--- a/libgo/go/net/http/client.go
+++ b/libgo/go/net/http/client.go
@@ -18,6 +18,7 @@ import (
"io/ioutil"
"log"
"net/url"
+ "sort"
"strings"
"sync"
"time"
@@ -33,6 +34,25 @@ import (
// A Client is higher-level than a RoundTripper (such as Transport)
// and additionally handles HTTP details such as cookies and
// redirects.
+//
+// When following redirects, the Client will forward all headers set on the
+// initial Request except:
+//
+// * when forwarding sensitive headers like "Authorization",
+// "WWW-Authenticate", and "Cookie" to untrusted targets.
+// These headers will be ignored when following a redirect to a domain
+// that is not a subdomain match or exact match of the initial domain.
+// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com"
+// will forward the sensitive headers, but a redirect to "bar.com" will not.
+//
+// * when forwarding the "Cookie" header with a non-nil cookie Jar.
+// Since each redirect may mutate the state of the cookie jar,
+// a redirect may possibly alter a cookie set in the initial request.
+// When forwarding the "Cookie" header, any mutated cookies will be omitted,
+// with the expectation that the Jar will insert those mutated cookies
+// with the updated values (assuming the origin matches).
+// If Jar is nil, the initial cookies are forwarded without change.
+//
type Client struct {
// Transport specifies the mechanism by which individual
// HTTP requests are made.
@@ -56,8 +76,14 @@ type Client struct {
CheckRedirect func(req *Request, via []*Request) error
// Jar specifies the cookie jar.
- // If Jar is nil, cookies are not sent in requests and ignored
- // in responses.
+ //
+ // The Jar is used to insert relevant cookies into every
+ // outbound Request and is updated with the cookie values
+ // of every inbound Response. The Jar is consulted for every
+ // redirect that the Client follows.
+ //
+ // If Jar is nil, cookies are only sent if they are explicitly
+ // set on the Request.
Jar CookieJar
// Timeout specifies a time limit for requests made by this
@@ -137,56 +163,23 @@ func refererForURL(lastReq, newReq *url.URL) string {
return referer
}
-func (c *Client) send(req *Request, deadline time.Time) (*Response, error) {
+// didTimeout is non-nil only if err != nil.
+func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
if c.Jar != nil {
for _, cookie := range c.Jar.Cookies(req.URL) {
req.AddCookie(cookie)
}
}
- resp, err := send(req, c.transport(), deadline)
+ resp, didTimeout, err = send(req, c.transport(), deadline)
if err != nil {
- return nil, err
+ return nil, didTimeout, err
}
if c.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
c.Jar.SetCookies(req.URL, rc)
}
}
- return resp, nil
-}
-
-// Do sends an HTTP request and returns an HTTP response, following
-// policy (such as redirects, cookies, auth) as configured on the
-// client.
-//
-// An error is returned if caused by client policy (such as
-// CheckRedirect), or failure to speak HTTP (such as a network
-// connectivity problem). A non-2xx status code doesn't cause an
-// error.
-//
-// If the returned error is nil, the Response will contain a non-nil
-// Body which the user is expected to close. If the Body is not
-// closed, the Client's underlying RoundTripper (typically Transport)
-// may not be able to re-use a persistent TCP connection to the server
-// for a subsequent "keep-alive" request.
-//
-// The request Body, if non-nil, will be closed by the underlying
-// Transport, even on errors.
-//
-// On error, any Response can be ignored. A non-nil Response with a
-// non-nil error only occurs when CheckRedirect fails, and even then
-// the returned Response.Body is already closed.
-//
-// Generally Get, Post, or PostForm will be used instead of Do.
-func (c *Client) Do(req *Request) (*Response, error) {
- method := valueOrDefault(req.Method, "GET")
- if method == "GET" || method == "HEAD" {
- return c.doFollowingRedirects(req, shouldRedirectGet)
- }
- if method == "POST" || method == "PUT" {
- return c.doFollowingRedirects(req, shouldRedirectPost)
- }
- return c.send(req, c.deadline())
+ return resp, nil, nil
}
func (c *Client) deadline() time.Time {
@@ -205,22 +198,22 @@ func (c *Client) transport() RoundTripper {
// send issues an HTTP request.
// Caller should close resp.Body when done reading from it.
-func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) {
+func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
req := ireq // req is either the original request, or a modified fork
if rt == nil {
req.closeBody()
- return nil, errors.New("http: no Client.Transport or DefaultTransport")
+ return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport")
}
if req.URL == nil {
req.closeBody()
- return nil, errors.New("http: nil Request.URL")
+ return nil, alwaysFalse, errors.New("http: nil Request.URL")
}
if req.RequestURI != "" {
req.closeBody()
- return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
+ return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests.")
}
// forkReq forks req into a shallow clone of ireq the first
@@ -251,9 +244,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
if !deadline.IsZero() {
forkReq()
}
- stopTimer, wasCanceled := setRequestCancel(req, rt, deadline)
+ stopTimer, didTimeout := setRequestCancel(req, rt, deadline)
- resp, err := rt.RoundTrip(req)
+ resp, err = rt.RoundTrip(req)
if err != nil {
stopTimer()
if resp != nil {
@@ -267,22 +260,27 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error)
err = errors.New("http: server gave HTTP response to HTTPS client")
}
}
- return nil, err
+ return nil, didTimeout, err
}
if !deadline.IsZero() {
resp.Body = &cancelTimerBody{
- stop: stopTimer,
- rc: resp.Body,
- reqWasCanceled: wasCanceled,
+ stop: stopTimer,
+ rc: resp.Body,
+ reqDidTimeout: didTimeout,
}
}
- return resp, nil
+ return resp, nil, nil
}
// setRequestCancel sets the Cancel field of req, if deadline is
// non-zero. The RoundTripper's type is used to determine whether the legacy
// CancelRequest behavior should be used.
-func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) {
+//
+// As background, there are three ways to cancel a request:
+// First was Transport.CancelRequest. (deprecated)
+// Second was Request.Cancel (this mechanism).
+// Third was Request.Context.
+func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) {
if deadline.IsZero() {
return nop, alwaysFalse
}
@@ -292,17 +290,8 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
cancel := make(chan struct{})
req.Cancel = cancel
- wasCanceled = func() bool {
- select {
- case <-cancel:
- return true
- default:
- return false
- }
- }
-
doCancel := func() {
- // The new way:
+ // The newer way (the second way in the func comment):
close(cancel)
// The legacy compatibility way, used only
@@ -324,19 +313,23 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
var once sync.Once
stopTimer = func() { once.Do(func() { close(stopTimerCh) }) }
- timer := time.NewTimer(deadline.Sub(time.Now()))
+ timer := time.NewTimer(time.Until(deadline))
+ var timedOut atomicBool
+
go func() {
select {
case <-initialReqCancel:
doCancel()
+ timer.Stop()
case <-timer.C:
+ timedOut.setTrue()
doCancel()
case <-stopTimerCh:
timer.Stop()
}
}()
- return stopTimer, wasCanceled
+ return stopTimer, timedOut.isSet
}
// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
@@ -349,26 +342,6 @@ func basicAuth(username, password string) string {
return base64.StdEncoding.EncodeToString([]byte(auth))
}
-// True if the specified HTTP status code is one for which the Get utility should
-// automatically redirect.
-func shouldRedirectGet(statusCode int) bool {
- switch statusCode {
- case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
- return true
- }
- return false
-}
-
-// True if the specified HTTP status code is one for which the Post utility should
-// automatically redirect.
-func shouldRedirectPost(statusCode int) bool {
- switch statusCode {
- case StatusFound, StatusSeeOther:
- return true
- }
- return false
-}
-
// Get issues a GET to the specified URL. If the response is one of
// the following redirect codes, Get follows the redirect, up to a
// maximum of 10 redirects:
@@ -377,6 +350,7 @@ func shouldRedirectPost(statusCode int) bool {
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
//
// An error is returned if there were too many redirects or if there
// was an HTTP protocol error. A non-2xx response doesn't cause an
@@ -401,6 +375,7 @@ func Get(url string) (resp *Response, err error) {
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
//
// An error is returned if the Client's CheckRedirect function fails
// or if there was an HTTP protocol error. A non-2xx response doesn't
@@ -415,7 +390,7 @@ func (c *Client) Get(url string) (resp *Response, err error) {
if err != nil {
return nil, err
}
- return c.doFollowingRedirects(req, shouldRedirectGet)
+ return c.Do(req)
}
func alwaysFalse() bool { return false }
@@ -436,16 +411,92 @@ func (c *Client) checkRedirect(req *Request, via []*Request) error {
return fn(req, via)
}
-func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) bool) (*Response, error) {
+// redirectBehavior describes what should happen when the
+// client encounters a 3xx status code from the server
+func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect bool) {
+ switch resp.StatusCode {
+ case 301, 302, 303:
+ redirectMethod = reqMethod
+ shouldRedirect = true
+
+ // RFC 2616 allowed automatic redirection only with GET and
+ // HEAD requests. RFC 7231 lifts this restriction, but we still
+ // restrict other methods to GET to maintain compatibility.
+ // See Issue 18570.
+ if reqMethod != "GET" && reqMethod != "HEAD" {
+ redirectMethod = "GET"
+ }
+ case 307, 308:
+ redirectMethod = reqMethod
+ shouldRedirect = true
+
+ // Treat 307 and 308 specially, since their handling is new in
+ // Go 1.8, and they also require re-sending the request body.
+ if resp.Header.Get("Location") == "" {
+ // 308s have been observed in the wild being served
+ // without Location headers. Since Go 1.7 and earlier
+ // didn't follow these codes, just stop here instead
+ // of returning an error.
+ // See Issue 17773.
+ shouldRedirect = false
+ break
+ }
+ if ireq.GetBody == nil && ireq.outgoingLength() != 0 {
+ // We had a request body, and 307/308 require
+ // re-sending it, but GetBody is not defined. So just
+ // return this response to the user instead of an
+ // error, like we did in Go 1.7 and earlier.
+ shouldRedirect = false
+ }
+ }
+ return redirectMethod, shouldRedirect
+}
+
+// Do sends an HTTP request and returns an HTTP response, following
+// policy (such as redirects, cookies, auth) as configured on the
+// client.
+//
+// An error is returned if caused by client policy (such as
+// CheckRedirect), or failure to speak HTTP (such as a network
+// connectivity problem). A non-2xx status code doesn't cause an
+// error.
+//
+// If the returned error is nil, the Response will contain a non-nil
+// Body which the user is expected to close. If the Body is not
+// closed, the Client's underlying RoundTripper (typically Transport)
+// may not be able to re-use a persistent TCP connection to the server
+// for a subsequent "keep-alive" request.
+//
+// The request Body, if non-nil, will be closed by the underlying
+// Transport, even on errors.
+//
+// On error, any Response can be ignored. A non-nil Response with a
+// non-nil error only occurs when CheckRedirect fails, and even then
+// the returned Response.Body is already closed.
+//
+// Generally Get, Post, or PostForm will be used instead of Do.
+//
+// If the server replies with a redirect, the Client first uses the
+// CheckRedirect function to determine whether the redirect should be
+// followed. If permitted, a 301, 302, or 303 redirect causes
+// subsequent requests to use HTTP method GET
+// (or HEAD if the original request was HEAD), with no body.
+// A 307 or 308 redirect preserves the original HTTP method and body,
+// provided that the Request.GetBody function is defined.
+// The NewRequest function automatically sets GetBody for common
+// standard library body types.
+func (c *Client) Do(req *Request) (*Response, error) {
if req.URL == nil {
req.closeBody()
return nil, errors.New("http: nil Request.URL")
}
var (
- deadline = c.deadline()
- reqs []*Request
- resp *Response
+ deadline = c.deadline()
+ reqs []*Request
+ resp *Response
+ copyHeaders = c.makeHeadersCopier(req)
+ redirectMethod string
)
uerr := func(err error) error {
req.closeBody()
@@ -476,16 +527,27 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo
}
ireq := reqs[0]
req = &Request{
- Method: ireq.Method,
+ Method: redirectMethod,
Response: resp,
URL: u,
Header: make(Header),
Cancel: ireq.Cancel,
ctx: ireq.ctx,
}
- if ireq.Method == "POST" || ireq.Method == "PUT" {
- req.Method = "GET"
+ if ireq.GetBody != nil {
+ req.Body, err = ireq.GetBody()
+ if err != nil {
+ return nil, uerr(err)
+ }
+ req.ContentLength = ireq.ContentLength
}
+
+ // Copy original headers before setting the Referer,
+ // in case the user set Referer on their first request.
+ // If they really want to override, they can do it in
+ // their CheckRedirect func.
+ copyHeaders(req)
+
// Add the Referer header from the most recent
// request URL to the new one, if it's not https->http:
if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" {
@@ -523,10 +585,10 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo
}
reqs = append(reqs, req)
-
var err error
- if resp, err = c.send(req, deadline); err != nil {
- if !deadline.IsZero() && !time.Now().Before(deadline) {
+ var didTimeout func() bool
+ if resp, didTimeout, err = c.send(req, deadline); err != nil {
+ if !deadline.IsZero() && didTimeout() {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true,
@@ -535,9 +597,77 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo
return nil, uerr(err)
}
- if !shouldRedirect(resp.StatusCode) {
+ var shouldRedirect bool
+ redirectMethod, shouldRedirect = redirectBehavior(req.Method, resp, reqs[0])
+ if !shouldRedirect {
return resp, nil
}
+
+ req.closeBody()
+ }
+}
+
+// makeHeadersCopier makes a function that copies headers from the
+// initial Request, ireq. For every redirect, this function must be called
+// so that it can copy headers into the upcoming Request.
+func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) {
+ // The headers to copy are from the very initial request.
+ // We use a closure to keep a reference to these original headers.
+ var (
+ ireqhdr = ireq.Header.clone()
+ icookies map[string][]*Cookie
+ )
+ if c.Jar != nil && ireq.Header.Get("Cookie") != "" {
+ icookies = make(map[string][]*Cookie)
+ for _, c := range ireq.Cookies() {
+ icookies[c.Name] = append(icookies[c.Name], c)
+ }
+ }
+
+ preq := ireq // The previous request
+ return func(req *Request) {
+ // If Jar is present and there was some initial cookies provided
+ // via the request header, then we may need to alter the initial
+ // cookies as we follow redirects since each redirect may end up
+ // modifying a pre-existing cookie.
+ //
+ // Since cookies already set in the request header do not contain
+ // information about the original domain and path, the logic below
+ // assumes any new set cookies override the original cookie
+ // regardless of domain or path.
+ //
+ // See https://golang.org/issue/17494
+ if c.Jar != nil && icookies != nil {
+ var changed bool
+ resp := req.Response // The response that caused the upcoming redirect
+ for _, c := range resp.Cookies() {
+ if _, ok := icookies[c.Name]; ok {
+ delete(icookies, c.Name)
+ changed = true
+ }
+ }
+ if changed {
+ ireqhdr.Del("Cookie")
+ var ss []string
+ for _, cs := range icookies {
+ for _, c := range cs {
+ ss = append(ss, c.Name+"="+c.Value)
+ }
+ }
+ sort.Strings(ss) // Ensure deterministic headers
+ ireqhdr.Set("Cookie", strings.Join(ss, "; "))
+ }
+ }
+
+ // Copy the initial request's Header values
+ // (at least the safe ones).
+ for k, vv := range ireqhdr {
+ if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) {
+ req.Header[k] = vv
+ }
+ }
+
+ preq = req // Update previous Request with the current request
}
}
@@ -558,8 +688,11 @@ func defaultCheckRedirect(req *Request, via []*Request) error {
// Post is a wrapper around DefaultClient.Post.
//
// To set custom headers, use NewRequest and DefaultClient.Do.
-func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
- return DefaultClient.Post(url, bodyType, body)
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
+func Post(url string, contentType string, body io.Reader) (resp *Response, err error) {
+ return DefaultClient.Post(url, contentType, body)
}
// Post issues a POST to the specified URL.
@@ -570,13 +703,16 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro
// request.
//
// To set custom headers, use NewRequest and Client.Do.
-func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) {
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
+func (c *Client) Post(url string, contentType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body)
if err != nil {
return nil, err
}
- req.Header.Set("Content-Type", bodyType)
- return c.doFollowingRedirects(req, shouldRedirectPost)
+ req.Header.Set("Content-Type", contentType)
+ return c.Do(req)
}
// PostForm issues a POST to the specified URL, with data's keys and
@@ -589,6 +725,9 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon
// Caller should close resp.Body when done reading from it.
//
// PostForm is a wrapper around DefaultClient.PostForm.
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
func PostForm(url string, data url.Values) (resp *Response, err error) {
return DefaultClient.PostForm(url, data)
}
@@ -601,11 +740,14 @@ func PostForm(url string, data url.Values) (resp *Response, err error) {
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
+//
+// See the Client.Do method documentation for details on how redirects
+// are handled.
func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) {
return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode()))
}
-// Head issues a HEAD to the specified URL. If the response is one of
+// Head issues a HEAD to the specified URL. If the response is one of
// the following redirect codes, Head follows the redirect, up to a
// maximum of 10 redirects:
//
@@ -613,13 +755,14 @@ func (c *Client) PostForm(url string, data url.Values) (resp *Response, err erro
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
//
// Head is a wrapper around DefaultClient.Head
func Head(url string) (resp *Response, err error) {
return DefaultClient.Head(url)
}
-// Head issues a HEAD to the specified URL. If the response is one of the
+// Head issues a HEAD to the specified URL. If the response is one of the
// following redirect codes, Head follows the redirect after calling the
// Client's CheckRedirect function:
//
@@ -627,22 +770,23 @@ func Head(url string) (resp *Response, err error) {
// 302 (Found)
// 303 (See Other)
// 307 (Temporary Redirect)
+// 308 (Permanent Redirect)
func (c *Client) Head(url string) (resp *Response, err error) {
req, err := NewRequest("HEAD", url, nil)
if err != nil {
return nil, err
}
- return c.doFollowingRedirects(req, shouldRedirectGet)
+ return c.Do(req)
}
// cancelTimerBody is an io.ReadCloser that wraps rc with two features:
// 1) on Read error or close, the stop func is called.
-// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and
+// 2) On Read failure, if reqDidTimeout is true, the error is wrapped and
// marked as net.Error that hit its timeout.
type cancelTimerBody struct {
- stop func() // stops the time.Timer waiting to cancel the request
- rc io.ReadCloser
- reqWasCanceled func() bool
+ stop func() // stops the time.Timer waiting to cancel the request
+ rc io.ReadCloser
+ reqDidTimeout func() bool
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
@@ -654,7 +798,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
if err == io.EOF {
return n, err
}
- if b.reqWasCanceled() {
+ if b.reqDidTimeout() {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true,
@@ -668,3 +812,52 @@ func (b *cancelTimerBody) Close() error {
b.stop()
return err
}
+
+func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool {
+ switch CanonicalHeaderKey(headerKey) {
+ case "Authorization", "Www-Authenticate", "Cookie", "Cookie2":
+ // Permit sending auth/cookie headers from "foo.com"
+ // to "sub.foo.com".
+
+ // Note that we don't send all cookies to subdomains
+ // automatically. This function is only used for
+ // Cookies set explicitly on the initial outgoing
+ // client request. Cookies automatically added via the
+ // CookieJar mechanism continue to follow each
+ // cookie's scope as set by Set-Cookie. But for
+ // outgoing requests with the Cookie header set
+ // directly, we don't know their scope, so we assume
+ // it's for *.domain.com.
+
+ // TODO(bradfitz): once issue 16142 is fixed, make
+ // this code use those URL accessors, and consider
+ // "http://foo.com" and "http://foo.com:80" as
+ // equivalent?
+
+ // TODO(bradfitz): better hostname canonicalization,
+ // at least once we figure out IDNA/Punycode (issue
+ // 13835).
+ ihost := strings.ToLower(initial.Host)
+ dhost := strings.ToLower(dest.Host)
+ return isDomainOrSubdomain(dhost, ihost)
+ }
+ // All other headers are copied:
+ return true
+}
+
+// isDomainOrSubdomain reports whether sub is a subdomain (or exact
+// match) of the parent domain.
+//
+// Both domains must already be in canonical form.
+func isDomainOrSubdomain(sub, parent string) bool {
+ if sub == parent {
+ return true
+ }
+ // If sub is "foo.example.com" and parent is "example.com",
+ // that means sub must end in "."+parent.
+ // Do it without allocating.
+ if !strings.HasSuffix(sub, parent) {
+ return false
+ }
+ return sub[len(sub)-len(parent)-1] == '.'
+}
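
To make the new redirect semantics concrete, here is a small self-contained sketch (the /start and /final paths are invented for illustration) showing that a 307 now re-sends the POST body because NewRequest populates Request.GetBody for common body types; with GetBody == nil the 307 response would be returned as-is, and sensitive headers such as Authorization and Cookie would additionally be dropped if the redirect target were an unrelated host, per shouldCopyHeaderOnRedirect above.

    package main

    import (
    	"fmt"
    	"io/ioutil"
    	"net/http"
    	"net/http/httptest"
    	"strings"
    )

    func main() {
    	// A toy server: /start answers 307 and /final echoes the method and body,
    	// so the client must replay the POST body via Request.GetBody.
    	mux := http.NewServeMux()
    	mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) {
    		http.Redirect(w, r, "/final", http.StatusTemporaryRedirect)
    	})
    	mux.HandleFunc("/final", func(w http.ResponseWriter, r *http.Request) {
    		body, _ := ioutil.ReadAll(r.Body)
    		fmt.Fprintf(w, "method=%s body=%s", r.Method, body)
    	})
    	ts := httptest.NewServer(mux)
    	defer ts.Close()

    	// NewRequest sets GetBody automatically for *strings.Reader
    	// (also *bytes.Reader and *bytes.Buffer).
    	req, _ := http.NewRequest("POST", ts.URL+"/start", strings.NewReader("hello"))
    	res, err := http.DefaultClient.Do(req)
    	if err != nil {
    		fmt.Println(err)
    		return
    	}
    	defer res.Body.Close()
    	out, _ := ioutil.ReadAll(res.Body)
    	fmt.Println(string(out)) // method=POST body=hello
    }
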
diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go
index a9b1948..eaf2cdc 100644
--- a/libgo/go/net/http/client_test.go
+++ b/libgo/go/net/http/client_test.go
@@ -19,11 +19,14 @@ import (
"log"
"net"
. "net/http"
+ "net/http/cookiejar"
"net/http/httptest"
"net/url"
+ "reflect"
"strconv"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
)
@@ -65,11 +68,13 @@ func (w chanWriter) Write(p []byte) (n int, err error) {
}
func TestClient(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(robotsTxtHandler)
defer ts.Close()
- r, err := Get(ts.URL)
+ c := &Client{Transport: &Transport{DisableKeepAlives: true}}
+ r, err := c.Get(ts.URL)
var b []byte
if err == nil {
b, err = pedanticReadAll(r.Body)
@@ -109,6 +114,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error)
}
func TestGetRequestFormat(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
tr := &recordingTransport{}
client := &Client{Transport: tr}
@@ -195,6 +201,7 @@ func TestPostFormRequestFormat(t *testing.T) {
}
func TestClientRedirects(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -206,14 +213,17 @@ func TestClientRedirects(t *testing.T) {
}
}
if n < 15 {
- Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound)
+ Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect)
return
}
fmt.Fprintf(w, "n=%d", n)
}))
defer ts.Close()
- c := &Client{}
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
_, err := c.Get(ts.URL)
if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
t.Errorf("with default client Get, expected error %q, got %q", e, g)
@@ -242,11 +252,14 @@ func TestClientRedirects(t *testing.T) {
var checkErr error
var lastVia []*Request
var lastReq *Request
- c = &Client{CheckRedirect: func(req *Request, via []*Request) error {
- lastReq = req
- lastVia = via
- return checkErr
- }}
+ c = &Client{
+ Transport: tr,
+ CheckRedirect: func(req *Request, via []*Request) error {
+ lastReq = req
+ lastVia = via
+ return checkErr
+ },
+ }
res, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("Get error: %v", err)
@@ -292,20 +305,27 @@ func TestClientRedirects(t *testing.T) {
}
func TestClientRedirectContext(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- Redirect(w, r, "/", StatusFound)
+ Redirect(w, r, "/", StatusTemporaryRedirect)
}))
defer ts.Close()
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
ctx, cancel := context.WithCancel(context.Background())
- c := &Client{CheckRedirect: func(req *Request, via []*Request) error {
- cancel()
- if len(via) > 2 {
- return errors.New("too many redirects")
- }
- return nil
- }}
+ c := &Client{
+ Transport: tr,
+ CheckRedirect: func(req *Request, via []*Request) error {
+ cancel()
+ if len(via) > 2 {
+ return errors.New("too many redirects")
+ }
+ return nil
+ },
+ }
req, _ := NewRequest("GET", ts.URL, nil)
req = req.WithContext(ctx)
_, err := c.Do(req)
@@ -313,12 +333,96 @@ func TestClientRedirectContext(t *testing.T) {
if !ok {
t.Fatalf("got error %T; want *url.Error", err)
}
- if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn {
- t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err)
+ if ue.Err != context.Canceled {
+ t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
}
}
+type redirectTest struct {
+ suffix string
+ want int // response code
+ redirectBody string
+}
+
func TestPostRedirects(t *testing.T) {
+ postRedirectTests := []redirectTest{
+ {"/", 200, "first"},
+ {"/?code=301&next=302", 200, "c301"},
+ {"/?code=302&next=302", 200, "c302"},
+ {"/?code=303&next=301", 200, "c303wc301"}, // Issue 9348
+ {"/?code=304", 304, "c304"},
+ {"/?code=305", 305, "c305"},
+ {"/?code=307&next=303,308,302", 200, "c307"},
+ {"/?code=308&next=302,301", 200, "c308"},
+ {"/?code=404", 404, "c404"},
+ }
+
+ wantSegments := []string{
+ `POST / "first"`,
+ `POST /?code=301&next=302 "c301"`,
+ `GET /?code=302 "c301"`,
+ `GET / "c301"`,
+ `POST /?code=302&next=302 "c302"`,
+ `GET /?code=302 "c302"`,
+ `GET / "c302"`,
+ `POST /?code=303&next=301 "c303wc301"`,
+ `GET /?code=301 "c303wc301"`,
+ `GET / "c303wc301"`,
+ `POST /?code=304 "c304"`,
+ `POST /?code=305 "c305"`,
+ `POST /?code=307&next=303,308,302 "c307"`,
+ `POST /?code=303&next=308,302 "c307"`,
+ `GET /?code=308&next=302 "c307"`,
+ `GET /?code=302 "c307"`,
+ `GET / "c307"`,
+ `POST /?code=308&next=302,301 "c308"`,
+ `POST /?code=302&next=301 "c308"`,
+ `GET /?code=301 "c308"`,
+ `GET / "c308"`,
+ `POST /?code=404 "c404"`,
+ }
+ want := strings.Join(wantSegments, "\n")
+ testRedirectsByMethod(t, "POST", postRedirectTests, want)
+}
+
+func TestDeleteRedirects(t *testing.T) {
+ deleteRedirectTests := []redirectTest{
+ {"/", 200, "first"},
+ {"/?code=301&next=302,308", 200, "c301"},
+ {"/?code=302&next=302", 200, "c302"},
+ {"/?code=303", 200, "c303"},
+ {"/?code=307&next=301,308,303,302,304", 304, "c307"},
+ {"/?code=308&next=307", 200, "c308"},
+ {"/?code=404", 404, "c404"},
+ }
+
+ wantSegments := []string{
+ `DELETE / "first"`,
+ `DELETE /?code=301&next=302,308 "c301"`,
+ `GET /?code=302&next=308 "c301"`,
+ `GET /?code=308 "c301"`,
+ `GET / "c301"`,
+ `DELETE /?code=302&next=302 "c302"`,
+ `GET /?code=302 "c302"`,
+ `GET / "c302"`,
+ `DELETE /?code=303 "c303"`,
+ `GET / "c303"`,
+ `DELETE /?code=307&next=301,308,303,302,304 "c307"`,
+ `DELETE /?code=301&next=308,303,302,304 "c307"`,
+ `GET /?code=308&next=303,302,304 "c307"`,
+ `GET /?code=303&next=302,304 "c307"`,
+ `GET /?code=302&next=304 "c307"`,
+ `GET /?code=304 "c307"`,
+ `DELETE /?code=308&next=307 "c308"`,
+ `DELETE /?code=307 "c308"`,
+ `DELETE / "c308"`,
+ `DELETE /?code=404 "c404"`,
+ }
+ want := strings.Join(wantSegments, "\n")
+ testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want)
+}
+
+func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) {
defer afterTest(t)
var log struct {
sync.Mutex
@@ -327,29 +431,35 @@ func TestPostRedirects(t *testing.T) {
var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
log.Lock()
- fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
+ slurp, _ := ioutil.ReadAll(r.Body)
+ fmt.Fprintf(&log.Buffer, "%s %s %q\n", r.Method, r.RequestURI, slurp)
log.Unlock()
- if v := r.URL.Query().Get("code"); v != "" {
+ urlQuery := r.URL.Query()
+ if v := urlQuery.Get("code"); v != "" {
+ location := ts.URL
+ if final := urlQuery.Get("next"); final != "" {
+ splits := strings.Split(final, ",")
+ first, rest := splits[0], splits[1:]
+ location = fmt.Sprintf("%s?code=%s", location, first)
+ if len(rest) > 0 {
+ location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ","))
+ }
+ }
code, _ := strconv.Atoi(v)
if code/100 == 3 {
- w.Header().Set("Location", ts.URL)
+ w.Header().Set("Location", location)
}
w.WriteHeader(code)
}
}))
defer ts.Close()
- tests := []struct {
- suffix string
- want int // response code
- }{
- {"/", 200},
- {"/?code=301", 301},
- {"/?code=302", 200},
- {"/?code=303", 200},
- {"/?code=404", 404},
- }
- for _, tt := range tests {
- res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
+
+ for _, tt := range table {
+ content := tt.redirectBody
+ req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content))
+ req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil }
+ res, err := DefaultClient.Do(req)
+
if err != nil {
t.Fatal(err)
}
@@ -360,13 +470,17 @@ func TestPostRedirects(t *testing.T) {
log.Lock()
got := log.String()
log.Unlock()
- want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
+
+ got = strings.TrimSpace(got)
+ want = strings.TrimSpace(want)
+
if got != want {
- t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
+ t.Errorf("Log differs.\n Got:\n%s\nWant:\n%s\n", got, want)
}
}
func TestClientRedirectUseResponse(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
const body = "Hello, world."
var ts *httptest.Server
@@ -381,12 +495,18 @@ func TestClientRedirectUseResponse(t *testing.T) {
}))
defer ts.Close()
- c := &Client{CheckRedirect: func(req *Request, via []*Request) error {
- if req.Response == nil {
- t.Error("expected non-nil Request.Response")
- }
- return ErrUseLastResponse
- }}
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
+ c := &Client{
+ Transport: tr,
+ CheckRedirect: func(req *Request, via []*Request) error {
+ if req.Response == nil {
+ t.Error("expected non-nil Request.Response")
+ }
+ return ErrUseLastResponse
+ },
+ }
res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
@@ -404,6 +524,57 @@ func TestClientRedirectUseResponse(t *testing.T) {
}
}
+// Issue 17773: don't follow a 308 (or 307) if the response doesn't
+// have a Location header.
+func TestClientRedirect308NoLocation(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Foo", "Bar")
+ w.WriteHeader(308)
+ }))
+ defer ts.Close()
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != 308 {
+ t.Errorf("status = %d; want %d", res.StatusCode, 308)
+ }
+ if got := res.Header.Get("Foo"); got != "Bar" {
+ t.Errorf("Foo header = %q; want Bar", got)
+ }
+}
+
+// Don't follow a 307/308 if we can't resend the request body.
+func TestClientRedirect308NoGetBody(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ const fakeURL = "https://localhost:1234/" // won't be hit
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Location", fakeURL)
+ w.WriteHeader(308)
+ }))
+ defer ts.Close()
+ req, err := NewRequest("POST", ts.URL, strings.NewReader("some body"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.GetBody = nil // so it can't rewind.
+ res, err := DefaultClient.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if res.StatusCode != 308 {
+ t.Errorf("status = %d; want %d", res.StatusCode, 308)
+ }
+ if got := res.Header.Get("Location"); got != fakeURL {
+ t.Errorf("Location header = %q; want %q", got, fakeURL)
+ }
+}
+
var expectedCookies = []*Cookie{
{Name: "ChocolateChip", Value: "tasty"},
{Name: "First", Value: "Hit"},
@@ -476,12 +647,16 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie {
}
func TestRedirectCookiesJar(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
var ts *httptest.Server
ts = httptest.NewServer(echoCookiesRedirectHandler)
defer ts.Close()
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
c := &Client{
- Jar: new(TestJar),
+ Transport: tr,
+ Jar: new(TestJar),
}
u, _ := url.Parse(ts.URL)
c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
@@ -665,6 +840,7 @@ func TestClientWrites(t *testing.T) {
}
func TestClientInsecureTransport(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Write([]byte("Hello"))
@@ -842,6 +1018,7 @@ func TestResponseSetsTLSConnectionState(t *testing.T) {
func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ ts.Config.ErrorLog = quietLog
defer ts.Close()
_, err := Get(strings.Replace(ts.URL, "http", "https", 1))
@@ -895,6 +1072,7 @@ func testClientHeadContentLength(t *testing.T, h2 bool) {
}
func TestEmptyPasswordAuth(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
gopher := "gopher"
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -915,7 +1093,9 @@ func TestEmptyPasswordAuth(t *testing.T) {
}
}))
defer ts.Close()
- c := &Client{}
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
@@ -1007,10 +1187,10 @@ func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
func testClientTimeout(t *testing.T, h2 bool) {
- if testing.Short() {
- t.Skip("skipping in short mode")
- }
+ setParallel(t)
defer afterTest(t)
+ testDone := make(chan struct{}) // closed in defer below
+
sawRoot := make(chan bool, 1)
sawSlow := make(chan bool, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1020,19 +1200,26 @@ func testClientTimeout(t *testing.T, h2 bool) {
return
}
if r.URL.Path == "/slow" {
+ sawSlow <- true
w.Write([]byte("Hello"))
w.(Flusher).Flush()
- sawSlow <- true
- time.Sleep(2 * time.Second)
+ <-testDone
return
}
}))
defer cst.close()
- const timeout = 500 * time.Millisecond
+ defer close(testDone) // before cst.close, to unblock /slow handler
+
+ // 200ms should be long enough to get a normal request (the /
+ // handler), but not so long that it makes the test slow.
+ const timeout = 200 * time.Millisecond
cst.c.Timeout = timeout
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
+ if strings.Contains(err.Error(), "Client.Timeout") {
+ t.Skipf("host too slow to get fast resource in %v", timeout)
+ }
t.Fatal(err)
}
@@ -1057,7 +1244,7 @@ func testClientTimeout(t *testing.T, h2 bool) {
res.Body.Close()
}()
- const failTime = timeout * 2
+ const failTime = 5 * time.Second
select {
case err := <-errc:
if err == nil {
@@ -1082,11 +1269,9 @@ func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h
// Client.Timeout firing before getting to the body
func testClientTimeout_Headers(t *testing.T, h2 bool) {
- if testing.Short() {
- t.Skip("skipping in short mode")
- }
+ setParallel(t)
defer afterTest(t)
- donec := make(chan bool)
+ donec := make(chan bool, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
<-donec
}))
@@ -1100,9 +1285,10 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
// doesn't know this, so synchronize explicitly.
defer func() { donec <- true }()
- cst.c.Timeout = 500 * time.Millisecond
- _, err := cst.c.Get(cst.ts.URL)
+ cst.c.Timeout = 5 * time.Millisecond
+ res, err := cst.c.Get(cst.ts.URL)
if err == nil {
+ res.Body.Close()
t.Fatal("got response from Get; expected error")
}
if _, ok := err.(*url.Error); !ok {
@@ -1120,9 +1306,40 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) {
}
}
+// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be
+// returned.
+func TestClientTimeoutCancel(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ testDone := make(chan struct{})
+ ctx, cancel := context.WithCancel(context.Background())
+
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ <-testDone
+ }))
+ defer cst.close()
+ defer close(testDone)
+
+ cst.c.Timeout = 1 * time.Hour
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req.Cancel = ctx.Done()
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ cancel()
+ _, err = io.Copy(ioutil.Discard, res.Body)
+ if err != ExportErrRequestCanceled {
+ t.Fatalf("error = %v; want errRequestCanceled", err)
+ }
+}
+
func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
func testClientRedirectEatsBody(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
saw := make(chan string, 2)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1138,10 +1355,10 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) {
t.Fatal(err)
}
_, err = ioutil.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
t.Fatal(err)
}
- res.Body.Close()
var first string
select {
@@ -1229,3 +1446,369 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) {
// Check that this doesn't crash:
c.Get("http://dummy.tld")
}
+
+// Issue 4800: copy (some) headers when Client follows a redirect
+func TestClientCopyHeadersOnRedirect(t *testing.T) {
+ const (
+ ua = "some-agent/1.2"
+ xfoo = "foo-val"
+ )
+ var ts2URL string
+ ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ want := Header{
+ "User-Agent": []string{ua},
+ "X-Foo": []string{xfoo},
+ "Referer": []string{ts2URL},
+ "Accept-Encoding": []string{"gzip"},
+ }
+ if !reflect.DeepEqual(r.Header, want) {
+ t.Errorf("Request.Header = %#v; want %#v", r.Header, want)
+ }
+ if t.Failed() {
+ w.Header().Set("Result", "got errors")
+ } else {
+ w.Header().Set("Result", "ok")
+ }
+ }))
+ defer ts1.Close()
+ ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ Redirect(w, r, ts1.URL, StatusFound)
+ }))
+ defer ts2.Close()
+ ts2URL = ts2.URL
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{
+ Transport: tr,
+ CheckRedirect: func(r *Request, via []*Request) error {
+ want := Header{
+ "User-Agent": []string{ua},
+ "X-Foo": []string{xfoo},
+ "Referer": []string{ts2URL},
+ }
+ if !reflect.DeepEqual(r.Header, want) {
+ t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
+ }
+ return nil
+ },
+ }
+
+ req, _ := NewRequest("GET", ts2.URL, nil)
+ req.Header.Add("User-Agent", ua)
+ req.Header.Add("X-Foo", xfoo)
+ req.Header.Add("Cookie", "foo=bar")
+ req.Header.Add("Authorization", "secretpassword")
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+ if got := res.Header.Get("Result"); got != "ok" {
+ t.Errorf("result = %q; want ok", got)
+ }
+}
+
+// Issue 17494: cookies should be altered when Client follows redirects.
+func TestClientAltersCookiesOnRedirect(t *testing.T) {
+ cookieMap := func(cs []*Cookie) map[string][]string {
+ m := make(map[string][]string)
+ for _, c := range cs {
+ m[c.Name] = append(m[c.Name], c.Value)
+ }
+ return m
+ }
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ var want map[string][]string
+ got := cookieMap(r.Cookies())
+
+ c, _ := r.Cookie("Cycle")
+ switch c.Value {
+ case "0":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie2": {"OldValue2"},
+ "Cookie3": {"OldValue3a", "OldValue3b"},
+ "Cookie4": {"OldValue4"},
+ "Cycle": {"0"},
+ }
+ SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"})
+ SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1}) // Delete cookie from Header
+ Redirect(w, r, "/", StatusFound)
+ case "1":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie3": {"OldValue3a", "OldValue3b"},
+ "Cookie4": {"OldValue4"},
+ "Cycle": {"1"},
+ }
+ SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"})
+ SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"}) // Modify cookie in Header
+ SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"}) // Modify cookie in Jar
+ Redirect(w, r, "/", StatusFound)
+ case "2":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie3": {"NewValue3"},
+ "Cookie4": {"NewValue4"},
+ "Cycle": {"2"},
+ }
+ SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"})
+ SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"}) // Insert cookie into Jar
+ Redirect(w, r, "/", StatusFound)
+ case "3":
+ want = map[string][]string{
+ "Cookie1": {"OldValue1a", "OldValue1b"},
+ "Cookie3": {"NewValue3"},
+ "Cookie4": {"NewValue4"},
+ "Cookie5": {"NewValue5"},
+ "Cycle": {"3"},
+ }
+ // Don't redirect to ensure the loop ends.
+ default:
+ t.Errorf("unexpected redirect cycle")
+ return
+ }
+
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want)
+ }
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ jar, _ := cookiejar.New(nil)
+ c := &Client{
+ Transport: tr,
+ Jar: jar,
+ }
+
+ u, _ := url.Parse(ts.URL)
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"})
+ req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"})
+ req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"})
+ req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"})
+ req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"})
+ jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}})
+ jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}})
+ res, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+}
+
+// Part of Issue 4800
+func TestShouldCopyHeaderOnRedirect(t *testing.T) {
+ tests := []struct {
+ header string
+ initialURL string
+ destURL string
+ want bool
+ }{
+ {"User-Agent", "http://foo.com/", "http://bar.com/", true},
+ {"X-Foo", "http://foo.com/", "http://bar.com/", true},
+
+ // Sensitive headers:
+ {"cookie", "http://foo.com/", "http://bar.com/", false},
+ {"cookie2", "http://foo.com/", "http://bar.com/", false},
+ {"authorization", "http://foo.com/", "http://bar.com/", false},
+ {"www-authenticate", "http://foo.com/", "http://bar.com/", false},
+
+ // But subdomains should work:
+ {"www-authenticate", "http://foo.com/", "http://foo.com/", true},
+ {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true},
+ {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false},
+ // TODO(bradfitz): make this test work, once issue 16142 is fixed:
+ // {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true},
+ }
+ for i, tt := range tests {
+ u0, err := url.Parse(tt.initialURL)
+ if err != nil {
+ t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err)
+ continue
+ }
+ u1, err := url.Parse(tt.destURL)
+ if err != nil {
+ t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err)
+ continue
+ }
+ got := Export_shouldCopyHeaderOnRedirect(tt.header, u0, u1)
+ if got != tt.want {
+ t.Errorf("%d. shouldCopyHeaderOnRedirect(%q, %q => %q) = %v; want %v",
+ i, tt.header, tt.initialURL, tt.destURL, got, tt.want)
+ }
+ }
+}
+
+func TestClientRedirectTypes(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ tests := [...]struct {
+ method string
+ serverStatus int
+ wantMethod string // desired subsequent client method
+ }{
+ 0: {method: "POST", serverStatus: 301, wantMethod: "GET"},
+ 1: {method: "POST", serverStatus: 302, wantMethod: "GET"},
+ 2: {method: "POST", serverStatus: 303, wantMethod: "GET"},
+ 3: {method: "POST", serverStatus: 307, wantMethod: "POST"},
+ 4: {method: "POST", serverStatus: 308, wantMethod: "POST"},
+
+ 5: {method: "HEAD", serverStatus: 301, wantMethod: "HEAD"},
+ 6: {method: "HEAD", serverStatus: 302, wantMethod: "HEAD"},
+ 7: {method: "HEAD", serverStatus: 303, wantMethod: "HEAD"},
+ 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"},
+ 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"},
+
+ 10: {method: "GET", serverStatus: 301, wantMethod: "GET"},
+ 11: {method: "GET", serverStatus: 302, wantMethod: "GET"},
+ 12: {method: "GET", serverStatus: 303, wantMethod: "GET"},
+ 13: {method: "GET", serverStatus: 307, wantMethod: "GET"},
+ 14: {method: "GET", serverStatus: 308, wantMethod: "GET"},
+
+ 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"},
+ 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"},
+ 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"},
+ 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"},
+ 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"},
+
+ 20: {method: "PUT", serverStatus: 301, wantMethod: "GET"},
+ 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"},
+ 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"},
+ 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"},
+ 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"},
+
+ 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"},
+ 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"},
+ 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"},
+ 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"},
+ 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"},
+ }
+
+ handlerc := make(chan HandlerFunc, 1)
+
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ h := <-handlerc
+ h(rw, req)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+
+ for i, tt := range tests {
+ handlerc <- func(w ResponseWriter, r *Request) {
+ w.Header().Set("Location", ts.URL)
+ w.WriteHeader(tt.serverStatus)
+ }
+
+ req, err := NewRequest(tt.method, ts.URL, nil)
+ if err != nil {
+ t.Errorf("#%d: NewRequest: %v", i, err)
+ continue
+ }
+
+ c := &Client{Transport: tr}
+ c.CheckRedirect = func(req *Request, via []*Request) error {
+ if got, want := req.Method, tt.wantMethod; got != want {
+ return fmt.Errorf("#%d: got next method %q; want %q", i, got, want)
+ }
+ handlerc <- func(rw ResponseWriter, req *Request) {
+ // TODO: Check that the body is valid when we do 307 and 308 support
+ }
+ return nil
+ }
+
+ res, err := c.Do(req)
+ if err != nil {
+ t.Errorf("#%d: Response: %v", i, err)
+ continue
+ }
+
+ res.Body.Close()
+ }
+}
+
+// issue18239Body is an io.ReadCloser for TestTransportBodyReadError.
+// Its Read returns readErr and increments *readCalls atomically.
+// Its Close returns nil and increments *closeCalls atomically.
+type issue18239Body struct {
+ readCalls *int32
+ closeCalls *int32
+ readErr error
+}
+
+func (b issue18239Body) Read([]byte) (int, error) {
+ atomic.AddInt32(b.readCalls, 1)
+ return 0, b.readErr
+}
+
+func (b issue18239Body) Close() error {
+ atomic.AddInt32(b.closeCalls, 1)
+ return nil
+}
+
+// Issue 18239: make sure the Transport doesn't retry requests with bodies.
+// (Especially if Request.GetBody is not defined.)
+func TestTransportBodyReadError(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.URL.Path == "/ping" {
+ return
+ }
+ buf := make([]byte, 1)
+ n, err := r.Body.Read(buf)
+ w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err))
+ }))
+ defer ts.Close()
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ // Do one initial successful request to create an idle TCP connection
+ // for the subsequent request to reuse. (The Transport only retries
+ // requests on reused connections.)
+ res, err := c.Get(ts.URL + "/ping")
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+
+ var readCallsAtomic, closeCallsAtomic int32 // accessed atomically
+ someErr := errors.New("some body read error")
+ body := issue18239Body{&readCallsAtomic, &closeCallsAtomic, someErr}
+
+ req, err := NewRequest("POST", ts.URL, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = tr.RoundTrip(req)
+ if err != someErr {
+ t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr)
+ }
+
+ // Verify that our Body wasn't used multiple times, which would
+ // indicate retries (as it buggily was during part of Go 1.8's
+ // dev cycle).
+ readCalls := atomic.LoadInt32(&readCallsAtomic)
+ closeCalls := atomic.LoadInt32(&closeCallsAtomic)
+ if readCalls != 1 {
+ t.Errorf("read calls = %d; want 1", readCalls)
+ }
+ if closeCalls != 1 {
+ t.Errorf("close calls = %d; want 1", closeCalls)
+ }
+}
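
The two tests above pin down Go 1.8's redirect behaviour: ordinary headers are copied onto the redirected request, sensitive ones (Cookie, Cookie2, Authorization, Www-Authenticate) are dropped unless the destination is the same host or a subdomain, and 301/302/303 rewrite the method to GET while 307/308 preserve it. Below is a minimal standalone sketch of the header rule, not part of this patch; the two httptest servers listen on different host:port strings, which Go 1.8 compares literally, so the redirect counts as cross-domain and Authorization is dropped while X-Foo survives.

package main

import (
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httptest"
)

func main() {
	// Destination server: reports which of the interesting headers arrived.
	dst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "Authorization=%q X-Foo=%q",
			r.Header.Get("Authorization"), r.Header.Get("X-Foo"))
	}))
	defer dst.Close()

	// Origin server: redirects everything to the destination, which lives
	// on a different host:port and therefore counts as a different domain.
	origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, dst.URL, http.StatusFound)
	}))
	defer origin.Close()

	req, err := http.NewRequest("GET", origin.URL, nil)
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Authorization", "Bearer secret") // sensitive: dropped on cross-domain redirect
	req.Header.Set("X-Foo", "bar")                   // ordinary: copied to the redirected request

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	body, _ := ioutil.ReadAll(res.Body)
	res.Body.Close()
	fmt.Println(string(body))
}
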
diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go
index 3d1f09c..580115c 100644
--- a/libgo/go/net/http/clientserver_test.go
+++ b/libgo/go/net/http/clientserver_test.go
@@ -44,6 +44,19 @@ func (t *clientServerTest) close() {
t.ts.Close()
}
+func (t *clientServerTest) getURL(u string) string {
+ res, err := t.c.Get(u)
+ if err != nil {
+ t.t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.t.Fatal(err)
+ }
+ return string(slurp)
+}
+
func (t *clientServerTest) scheme() string {
if t.h2 {
return "https"
@@ -56,6 +69,10 @@ const (
h2Mode = true
)
+var optQuietLog = func(ts *httptest.Server) {
+ ts.Config.ErrorLog = quietLog
+}
+
func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
cst := &clientServerTest{
t: t,
@@ -64,21 +81,23 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{})
tr: &Transport{},
}
cst.c = &Client{Transport: cst.tr}
+ cst.ts = httptest.NewUnstartedServer(h)
for _, opt := range opts {
switch opt := opt.(type) {
case func(*Transport):
opt(cst.tr)
+ case func(*httptest.Server):
+ opt(cst.ts)
default:
t.Fatalf("unhandled option type %T", opt)
}
}
if !h2 {
- cst.ts = httptest.NewServer(h)
+ cst.ts.Start()
return cst
}
- cst.ts = httptest.NewUnstartedServer(h)
ExportHttp2ConfigureServer(cst.ts.Config, nil)
cst.ts.TLS = cst.ts.Config.TLSConfig
cst.ts.StartTLS()
@@ -170,6 +189,7 @@ func (tt h12Compare) reqFunc() reqFunc {
}
func (tt h12Compare) run(t *testing.T) {
+ setParallel(t)
cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
defer cst1.close()
cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
@@ -468,7 +488,7 @@ func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
}
func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
- h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0)
+ h12requestContentLength(t, func() io.Reader { return nil }, 0)
}
func TestH12_RequestContentLength_Unknown(t *testing.T) {
@@ -938,6 +958,7 @@ func testStarRequest(t *testing.T, method string, h2 bool) {
// Issue 13957
func TestTransportDiscardsUnneededConns(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
@@ -1026,6 +1047,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) {
t.Skip("skipping on gccgo because conservative GC means that finalizer may never run")
}
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
ioutil.ReadAll(r.Body)
@@ -1072,10 +1094,11 @@ func TestTransportRejectsInvalidHeaders_h2(t *testing.T) {
testTransportRejectsInvalidHeaders(t, h2Mode)
}
func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "Handler saw headers: %q", r.Header)
- }))
+ }), optQuietLog)
defer cst.close()
cst.tr.DisableKeepAlives = true
@@ -1143,24 +1166,44 @@ func testBogusStatusWorks(t *testing.T, h2 bool) {
}
}
-func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode) }
-func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode) }
-func testInterruptWithPanic(t *testing.T, h2 bool) {
- log.SetOutput(ioutil.Discard) // is noisy otherwise
- defer log.SetOutput(os.Stderr)
-
+func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") }
+func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") }
+func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) }
+func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) }
+func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) {
+ testInterruptWithPanic(t, h1Mode, ErrAbortHandler)
+}
+func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) {
+ testInterruptWithPanic(t, h2Mode, ErrAbortHandler)
+}
+func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) {
+ setParallel(t)
const msg = "hello"
defer afterTest(t)
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+
+ var errorLog lockedBytesBuffer
+ gotHeaders := make(chan bool, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush()
- panic("no more")
- }))
+
+ select {
+ case <-gotHeaders:
+ case <-testDone:
+ }
+ panic(panicValue)
+ }), func(ts *httptest.Server) {
+ ts.Config.ErrorLog = log.New(&errorLog, "", 0)
+ })
defer cst.close()
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
+ gotHeaders <- true
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if string(slurp) != msg {
@@ -1169,6 +1212,42 @@ func testInterruptWithPanic(t *testing.T, h2 bool) {
if err == nil {
t.Errorf("client read all successfully; want some error")
}
+ logOutput := func() string {
+ errorLog.Lock()
+ defer errorLog.Unlock()
+ return errorLog.String()
+ }
+ wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler
+
+ if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error {
+ gotLog := logOutput()
+ if !wantStackLogged {
+ if gotLog == "" {
+ return nil
+ }
+ return fmt.Errorf("want no log output; got: %s", gotLog)
+ }
+ if gotLog == "" {
+ return fmt.Errorf("wanted a stack trace logged; got nothing")
+ }
+ if !strings.Contains(gotLog, "created by ") && strings.Count(gotLog, "\n") < 6 {
+ return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog)
+ }
+ return nil
+ }); err != nil {
+ t.Fatal(err)
+ }
+}
+
+type lockedBytesBuffer struct {
+ sync.Mutex
+ bytes.Buffer
+}
+
+func (b *lockedBytesBuffer) Write(p []byte) (int, error) {
+ b.Lock()
+ defer b.Unlock()
+ return b.Buffer.Write(p)
}
// Issue 15366
@@ -1204,6 +1283,7 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) }
func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) }
func testCloseIdleConnections(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("X-Addr", r.RemoteAddr)
@@ -1238,3 +1318,70 @@ func (x noteCloseConn) Close() error {
x.closeFunc()
return x.Conn.Close()
}
+
+type testErrorReader struct{ t *testing.T }
+
+func (r testErrorReader) Read(p []byte) (n int, err error) {
+ r.t.Error("unexpected Read call")
+ return 0, io.EOF
+}
+
+func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) }
+func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) }
+
+func testNoSniffExpectRequestBody(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusUnauthorized)
+ }))
+ defer cst.close()
+
+ // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write the
+ // request body before the server responds.
+ cst.tr.ExpectContinueTimeout = 10 * time.Second
+
+ req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t})
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.ContentLength = 0 // so transport is tempted to sniff it
+ req.Header.Set("Expect", "100-continue")
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusUnauthorized {
+ t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized)
+ }
+}
+
+func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) }
+func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) }
+func testServerUndeclaredTrailers(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Foo", "Bar")
+ w.Header().Set("Trailer:Foo", "Baz")
+ w.(Flusher).Flush()
+ w.Header().Add("Trailer:Foo", "Baz2")
+ w.Header().Set("Trailer:Bar", "Quux")
+ }))
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := io.Copy(ioutil.Discard, res.Body); err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ delete(res.Header, "Date")
+ delete(res.Header, "Content-Type")
+
+ if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) {
+ t.Errorf("Header = %#v; want %#v", res.Header, want)
+ }
+ if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) {
+ t.Errorf("Trailer = %#v; want %#v", res.Trailer, want)
+ }
+}
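
testServerUndeclaredTrailers exercises Go 1.8's "Trailer:" key prefix, which lets a handler emit trailers it never declared before writing the response headers. A small standalone sketch of the same mechanism follows; the Checksum name and value are made up for illustration.

package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httptest"
)

func main() {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Foo", "Bar") // ordinary header
		io.WriteString(w, "hello\n")
		w.(http.Flusher).Flush() // headers are on the wire now
		// Keys prefixed with "Trailer:" are sent as trailers even though
		// no "Trailer" header declared them up front.
		w.Header().Set("Trailer:Checksum", "abc123")
	}))
	defer ts.Close()

	res, err := http.Get(ts.URL)
	if err != nil {
		log.Fatal(err)
	}
	// Trailers are only populated once the body has been fully consumed.
	io.Copy(ioutil.Discard, res.Body)
	res.Body.Close()
	fmt.Println("Checksum trailer:", res.Trailer.Get("Checksum"))
}
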
diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go
index 1ea0e93..5a67476 100644
--- a/libgo/go/net/http/cookie.go
+++ b/libgo/go/net/http/cookie.go
@@ -6,7 +6,6 @@ package http
import (
"bytes"
- "fmt"
"log"
"net"
"strconv"
@@ -40,7 +39,11 @@ type Cookie struct {
// readSetCookies parses all "Set-Cookie" values from
// the header h and returns the successfully parsed Cookies.
func readSetCookies(h Header) []*Cookie {
- cookies := []*Cookie{}
+ cookieCount := len(h["Set-Cookie"])
+ if cookieCount == 0 {
+ return []*Cookie{}
+ }
+ cookies := make([]*Cookie, 0, cookieCount)
for _, line := range h["Set-Cookie"] {
parts := strings.Split(strings.TrimSpace(line), ";")
if len(parts) == 1 && parts[0] == "" {
@@ -55,8 +58,8 @@ func readSetCookies(h Header) []*Cookie {
if !isCookieNameValid(name) {
continue
}
- value, success := parseCookieValue(value, true)
- if !success {
+ value, ok := parseCookieValue(value, true)
+ if !ok {
continue
}
c := &Cookie{
@@ -75,8 +78,8 @@ func readSetCookies(h Header) []*Cookie {
attr, val = attr[:j], attr[j+1:]
}
lowerAttr := strings.ToLower(attr)
- val, success = parseCookieValue(val, false)
- if !success {
+ val, ok = parseCookieValue(val, false)
+ if !ok {
c.Unparsed = append(c.Unparsed, parts[i])
continue
}
@@ -96,10 +99,9 @@ func readSetCookies(h Header) []*Cookie {
break
}
if secs <= 0 {
- c.MaxAge = -1
- } else {
- c.MaxAge = secs
+ secs = -1
}
+ c.MaxAge = secs
continue
case "expires":
c.RawExpires = val
@@ -142,9 +144,13 @@ func (c *Cookie) String() string {
return ""
}
var b bytes.Buffer
- fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value))
+ b.WriteString(sanitizeCookieName(c.Name))
+ b.WriteRune('=')
+ b.WriteString(sanitizeCookieValue(c.Value))
+
if len(c.Path) > 0 {
- fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path))
+ b.WriteString("; Path=")
+ b.WriteString(sanitizeCookiePath(c.Path))
}
if len(c.Domain) > 0 {
if validCookieDomain(c.Domain) {
@@ -156,25 +162,31 @@ func (c *Cookie) String() string {
if d[0] == '.' {
d = d[1:]
}
- fmt.Fprintf(&b, "; Domain=%s", d)
+ b.WriteString("; Domain=")
+ b.WriteString(d)
} else {
- log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute",
- c.Domain)
+ log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain)
}
}
- if c.Expires.Unix() > 0 {
- fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(TimeFormat))
+ if validCookieExpires(c.Expires) {
+ b.WriteString("; Expires=")
+ b2 := b.Bytes()
+ b.Reset()
+ b.Write(c.Expires.UTC().AppendFormat(b2, TimeFormat))
}
if c.MaxAge > 0 {
- fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge)
+ b.WriteString("; Max-Age=")
+ b2 := b.Bytes()
+ b.Reset()
+ b.Write(strconv.AppendInt(b2, int64(c.MaxAge), 10))
} else if c.MaxAge < 0 {
- fmt.Fprintf(&b, "; Max-Age=0")
+ b.WriteString("; Max-Age=0")
}
if c.HttpOnly {
- fmt.Fprintf(&b, "; HttpOnly")
+ b.WriteString("; HttpOnly")
}
if c.Secure {
- fmt.Fprintf(&b, "; Secure")
+ b.WriteString("; Secure")
}
return b.String()
}
@@ -184,12 +196,12 @@ func (c *Cookie) String() string {
//
// if filter isn't empty, only cookies of that name are returned
func readCookies(h Header, filter string) []*Cookie {
- cookies := []*Cookie{}
lines, ok := h["Cookie"]
if !ok {
- return cookies
+ return []*Cookie{}
}
+ cookies := []*Cookie{}
for _, line := range lines {
parts := strings.Split(strings.TrimSpace(line), ";")
if len(parts) == 1 && parts[0] == "" {
@@ -212,8 +224,8 @@ func readCookies(h Header, filter string) []*Cookie {
if filter != "" && filter != name {
continue
}
- val, success := parseCookieValue(val, true)
- if !success {
+ val, ok := parseCookieValue(val, true)
+ if !ok {
continue
}
cookies = append(cookies, &Cookie{Name: name, Value: val})
@@ -234,6 +246,12 @@ func validCookieDomain(v string) bool {
return false
}
+// validCookieExpires returns whether v is a valid cookie expires-value.
+func validCookieExpires(t time.Time) bool {
+ // Per IETF RFC 6265 section 5.1.1.5, the year must not be less than 1601.
+ return t.Year() >= 1601
+}
+
// isCookieDomainName returns whether s is a valid domain name or a valid
// domain name with a leading dot '.'. It is almost a direct copy of
// package net's isDomainName.
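
Cookie.String now builds its output with bytes.Buffer and strconv/AppendFormat instead of fmt, and the new validCookieExpires drops Expires values whose year is below the RFC 6265 minimum of 1601. A short sketch of the user-visible effect, with arbitrary values:

package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	c := &http.Cookie{
		Name:    "session",
		Value:   "abc123",
		Path:    "/",
		MaxAge:  3600,
		Expires: time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC),
	}
	// Serializes with Path, Expires and Max-Age attributes.
	fmt.Println(c.String())

	// An expiry year below 1601 is invalid per RFC 6265 section 5.1.1,
	// so the Expires attribute is silently omitted.
	c.Expires = time.Date(1600, 12, 31, 0, 0, 0, 0, time.UTC)
	fmt.Println(c.String())
}
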
diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go
index 95e6147..b3e54f8 100644
--- a/libgo/go/net/http/cookie_test.go
+++ b/libgo/go/net/http/cookie_test.go
@@ -56,6 +56,15 @@ var writeSetCookiesTests = []struct {
&Cookie{Name: "cookie-9", Value: "expiring", Expires: time.Unix(1257894000, 0)},
"cookie-9=expiring; Expires=Tue, 10 Nov 2009 23:00:00 GMT",
},
+ // According to IETF RFC 6265 Section 5.1.1.5, the year cannot be less than 1601
+ {
+ &Cookie{Name: "cookie-10", Value: "expiring-1601", Expires: time.Date(1601, 1, 1, 1, 1, 1, 1, time.UTC)},
+ "cookie-10=expiring-1601; Expires=Mon, 01 Jan 1601 01:01:01 GMT",
+ },
+ {
+ &Cookie{Name: "cookie-11", Value: "invalid-expiry", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)},
+ "cookie-11=invalid-expiry",
+ },
// The "special" cookies have values containing commas or spaces which
// are disallowed by RFC 6265 but are common in the wild.
{
@@ -426,3 +435,92 @@ func TestCookieSanitizePath(t *testing.T) {
t.Errorf("Expected substring %q in log output. Got:\n%s", sub, got)
}
}
+
+func BenchmarkCookieString(b *testing.B) {
+ const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600`
+ c := &Cookie{
+ Name: "cookie-9",
+ Value: "i3e01nf61b6t23bvfmplnanol3",
+ Expires: time.Unix(1257894000, 0),
+ Path: "/restricted/",
+ Domain: ".example.com",
+ MaxAge: 3600,
+ }
+ var benchmarkCookieString string
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkCookieString = c.String()
+ }
+ if have, want := benchmarkCookieString, wantCookieString; have != want {
+ b.Fatalf("Have: %v Want: %v", have, want)
+ }
+}
+
+func BenchmarkReadSetCookies(b *testing.B) {
+ header := Header{
+ "Set-Cookie": {
+ "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly",
+ },
+ }
+ wantCookies := []*Cookie{
+ {
+ Name: "NID",
+ Value: "99=YsDT5i3E-CXax-",
+ Path: "/",
+ Domain: ".google.ch",
+ HttpOnly: true,
+ Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+ RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT",
+ Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly",
+ },
+ {
+ Name: ".ASPXAUTH",
+ Value: "7E3AA",
+ Path: "/",
+ Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC),
+ RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT",
+ HttpOnly: true,
+ Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly",
+ },
+ }
+ var c []*Cookie
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ c = readSetCookies(header)
+ }
+ if !reflect.DeepEqual(c, wantCookies) {
+ b.Fatalf("readSetCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies))
+ }
+}
+
+func BenchmarkReadCookies(b *testing.B) {
+ header := Header{
+ "Cookie": {
+ `de=; client_region=0; rpld1=0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|; rpld0=1:08|; backplane-channel=newspaper.com:1471; devicetype=0; osfam=0; rplmct=2; s_pers=%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B; s_sess=%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B`,
+ },
+ }
+ wantCookies := []*Cookie{
+ {Name: "de", Value: ""},
+ {Name: "client_region", Value: "0"},
+ {Name: "rpld1", Value: "0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|"},
+ {Name: "rpld0", Value: "1:08|"},
+ {Name: "backplane-channel", Value: "newspaper.com:1471"},
+ {Name: "devicetype", Value: "0"},
+ {Name: "osfam", Value: "0"},
+ {Name: "rplmct", Value: "2"},
+ {Name: "s_pers", Value: "%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B"},
+ {Name: "s_sess", Value: "%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B"},
+ }
+ var c []*Cookie
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ c = readCookies(header, "")
+ }
+ if !reflect.DeepEqual(c, wantCookies) {
+ b.Fatalf("readCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies))
+ }
+}
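
The benchmarks above call the unexported readSetCookies and readCookies directly; from outside the package the same parsers are reached through Response.Cookies and Request.Cookies. A hedged sketch using a canned response, with header values trimmed from the benchmark data:

package main

import (
	"bufio"
	"fmt"
	"log"
	"net/http"
	"strings"
)

func main() {
	raw := "HTTP/1.1 200 OK\r\n" +
		"Set-Cookie: NID=99=YsDT5i3E-CXax-; Path=/; Domain=.google.ch; HttpOnly\r\n" +
		"Set-Cookie: .ASPXAUTH=7E3AA; Path=/; HttpOnly\r\n" +
		"\r\n"
	res, err := http.ReadResponse(bufio.NewReader(strings.NewReader(raw)), nil)
	if err != nil {
		log.Fatal(err)
	}
	// Response.Cookies is backed by readSetCookies, benchmarked above.
	for _, c := range res.Cookies() {
		fmt.Printf("%s=%q (HttpOnly=%v)\n", c.Name, c.Value, c.HttpOnly)
	}
}
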
diff --git a/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go b/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go
new file mode 100644
index 0000000..748ec5c
--- /dev/null
+++ b/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go
@@ -0,0 +1,23 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+package cookiejar_test
+
+import "net/http/cookiejar"
+
+type dummypsl struct {
+ List cookiejar.PublicSuffixList
+}
+
+func (dummypsl) PublicSuffix(domain string) string {
+ return domain
+}
+
+func (dummypsl) String() string {
+ return "dummy"
+}
+
+var publicsuffix = dummypsl{}
diff --git a/libgo/go/net/http/cookiejar/example_test.go b/libgo/go/net/http/cookiejar/example_test.go
new file mode 100644
index 0000000..19a5746
--- /dev/null
+++ b/libgo/go/net/http/cookiejar/example_test.go
@@ -0,0 +1,67 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+package cookiejar_test
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/cookiejar"
+ "net/http/httptest"
+ "net/url"
+)
+
+func ExampleNew() {
+ // Start a server to give us cookies.
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if cookie, err := r.Cookie("Flavor"); err != nil {
+ http.SetCookie(w, &http.Cookie{Name: "Flavor", Value: "Chocolate Chip"})
+ } else {
+ cookie.Value = "Oatmeal Raisin"
+ http.SetCookie(w, cookie)
+ }
+ }))
+ defer ts.Close()
+
+ u, err := url.Parse(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // All users of cookiejar should import "golang.org/x/net/publicsuffix"
+ jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ client := &http.Client{
+ Jar: jar,
+ }
+
+ if _, err = client.Get(u.String()); err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Println("After 1st request:")
+ for _, cookie := range jar.Cookies(u) {
+ fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value)
+ }
+
+ if _, err = client.Get(u.String()); err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Println("After 2nd request:")
+ for _, cookie := range jar.Cookies(u) {
+ fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value)
+ }
+ // Output:
+ // After 1st request:
+ // Flavor: Chocolate Chip
+ // After 2nd request:
+ // Flavor: Oatmeal Raisin
+}
diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go
index 0e0fac9..f89abbc 100644
--- a/libgo/go/net/http/cookiejar/jar.go
+++ b/libgo/go/net/http/cookiejar/jar.go
@@ -107,7 +107,7 @@ type entry struct {
seqNum uint64
}
-// Id returns the domain;path;name triple of e as an id.
+// id returns the domain;path;name triple of e as an id.
func (e *entry) id() string {
return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
}
@@ -147,24 +147,6 @@ func hasDotSuffix(s, suffix string) bool {
return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix
}
-// byPathLength is a []entry sort.Interface that sorts according to RFC 6265
-// section 5.4 point 2: by longest path and then by earliest creation time.
-type byPathLength []entry
-
-func (s byPathLength) Len() int { return len(s) }
-
-func (s byPathLength) Less(i, j int) bool {
- if len(s[i].Path) != len(s[j].Path) {
- return len(s[i].Path) > len(s[j].Path)
- }
- if !s[i].Creation.Equal(s[j].Creation) {
- return s[i].Creation.Before(s[j].Creation)
- }
- return s[i].seqNum < s[j].seqNum
-}
-
-func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
-
// Cookies implements the Cookies method of the http.CookieJar interface.
//
// It returns an empty slice if the URL's scheme is not HTTP or HTTPS.
@@ -221,7 +203,18 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
}
}
- sort.Sort(byPathLength(selected))
+ // sort according to RFC 6265 section 5.4 point 2: by longest
+ // path and then by earliest creation time.
+ sort.Slice(selected, func(i, j int) bool {
+ s := selected
+ if len(s[i].Path) != len(s[j].Path) {
+ return len(s[i].Path) > len(s[j].Path)
+ }
+ if !s[i].Creation.Equal(s[j].Creation) {
+ return s[i].Creation.Before(s[j].Creation)
+ }
+ return s[i].seqNum < s[j].seqNum
+ })
for _, e := range selected {
cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
}
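
The dedicated byPathLength sort.Interface type is replaced by Go 1.8's sort.Slice, which takes an inline less function. A standalone sketch of the same ordering rule on a stand-in entry type; the struct below is illustrative, not the jar's real type.

package main

import (
	"fmt"
	"sort"
	"time"
)

// entry is a trimmed-down stand-in for the jar's internal type: cookies are
// ordered by longest path, then earliest creation time, then sequence number,
// per RFC 6265 section 5.4 point 2.
type entry struct {
	Path     string
	Creation time.Time
	seqNum   uint64
}

func main() {
	now := time.Now()
	selected := []entry{
		{Path: "/", Creation: now, seqNum: 2},
		{Path: "/a/b", Creation: now.Add(-time.Hour), seqNum: 1},
		{Path: "/a", Creation: now, seqNum: 0},
	}
	// sort.Slice (new in Go 1.8) replaces the dedicated byPathLength
	// sort.Interface implementation with an inline less function.
	sort.Slice(selected, func(i, j int) bool {
		s := selected
		if len(s[i].Path) != len(s[j].Path) {
			return len(s[i].Path) > len(s[j].Path)
		}
		if !s[i].Creation.Equal(s[j].Creation) {
			return s[i].Creation.Before(s[j].Creation)
		}
		return s[i].seqNum < s[j].seqNum
	})
	for _, e := range selected {
		fmt.Println(e.Path) // longest path first: /a/b, /a, /
	}
}
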
diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go
index 4ec8272..7855fea 100644
--- a/libgo/go/net/http/doc.go
+++ b/libgo/go/net/http/doc.go
@@ -44,7 +44,8 @@ For control over proxies, TLS configuration, keep-alives,
compression, and other settings, create a Transport:
tr := &http.Transport{
- TLSClientConfig: &tls.Config{RootCAs: pool},
+ MaxIdleConns: 10,
+ IdleConnTimeout: 30 * time.Second,
DisableCompression: true,
}
client := &http.Client{Transport: tr}
@@ -77,19 +78,30 @@ custom Server:
}
log.Fatal(s.ListenAndServe())
-The http package has transparent support for the HTTP/2 protocol when
-using HTTPS. Programs that must disable HTTP/2 can do so by setting
-Transport.TLSNextProto (for clients) or Server.TLSNextProto (for
-servers) to a non-nil, empty map. Alternatively, the following GODEBUG
-environment variables are currently supported:
+Starting with Go 1.6, the http package has transparent support for the
+HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2
+can do so by setting Transport.TLSNextProto (for clients) or
+Server.TLSNextProto (for servers) to a non-nil, empty
+map. Alternatively, the following GODEBUG environment variables are
+currently supported:
GODEBUG=http2client=0 # disable HTTP/2 client support
GODEBUG=http2server=0 # disable HTTP/2 server support
GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs
GODEBUG=http2debug=2 # ... even more verbose, with frame dumps
-The GODEBUG variables are not covered by Go's API compatibility promise.
-HTTP/2 support was added in Go 1.6. Please report any issues instead of
-disabling HTTP/2 support: https://golang.org/s/http2bug
+The GODEBUG variables are not covered by Go's API compatibility
+promise. Please report any issues before disabling HTTP/2
+support: https://golang.org/s/http2bug
+
+The http package's Transport and Server both automatically enable
+HTTP/2 support for simple configurations. To enable HTTP/2 for more
+complex configurations, to use lower-level HTTP/2 features, or to use
+a newer version of Go's http2 package, import "golang.org/x/net/http2"
+directly and use its ConfigureTransport and/or ConfigureServer
+functions. Manually configuring HTTP/2 via the golang.org/x/net/http2
+package takes precedence over the net/http package's built-in HTTP/2
+support.
+
*/
package http
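
The doc comment above mentions disabling the bundled HTTP/2 support by setting TLSNextProto to a non-nil, empty map. A minimal client-side sketch of that, using only the exported net/http and crypto/tls APIs; the server side is analogous via Server.TLSNextProto.

package main

import (
	"crypto/tls"
	"net/http"
)

func main() {
	// A non-nil, empty TLSNextProto map tells the Transport not to
	// negotiate HTTP/2, so HTTPS requests stay on HTTP/1.1.
	tr := &http.Transport{
		TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
	}
	client := &http.Client{Transport: tr}
	_ = client // use client.Get / client.Do as usual
}
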
diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go
index 9c5ba08..b61f58b 100644
--- a/libgo/go/net/http/export_test.go
+++ b/libgo/go/net/http/export_test.go
@@ -24,6 +24,7 @@ var (
ExportErrRequestCanceled = errRequestCanceled
ExportErrRequestCanceledConn = errRequestCanceledConn
ExportServeFile = serveFile
+ ExportScanETag = scanETag
ExportHttp2ConfigureServer = http2ConfigureServer
)
@@ -87,6 +88,12 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) {
return
}
+func (t *Transport) IdleConnKeyCountForTesting() int {
+ t.idleMu.Lock()
+ defer t.idleMu.Unlock()
+ return len(t.idleConn)
+}
+
func (t *Transport) IdleConnStrsForTesting() []string {
var ret []string
t.idleMu.Lock()
@@ -100,6 +107,24 @@ func (t *Transport) IdleConnStrsForTesting() []string {
return ret
}
+func (t *Transport) IdleConnStrsForTesting_h2() []string {
+ var ret []string
+ noDialPool := t.h2transport.ConnPool.(http2noDialClientConnPool)
+ pool := noDialPool.http2clientConnPool
+
+ pool.mu.Lock()
+ defer pool.mu.Unlock()
+
+ for k, cc := range pool.conns {
+ for range cc {
+ ret = append(ret, k)
+ }
+ }
+
+ sort.Strings(ret)
+ return ret
+}
+
func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
t.idleMu.Lock()
defer t.idleMu.Unlock()
@@ -160,3 +185,17 @@ func ExportHttp2ConfigureTransport(t *Transport) error {
t.h2transport = t2
return nil
}
+
+var Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect
+
+func (s *Server) ExportAllConnsIdle() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for c := range s.activeConn {
+ st, ok := c.curState.Load().(ConnState)
+ if !ok || st != StateIdle {
+ return false
+ }
+ }
+ return true
+}
diff --git a/libgo/go/net/http/fcgi/fcgi.go b/libgo/go/net/http/fcgi/fcgi.go
index 3374841..5057d70 100644
--- a/libgo/go/net/http/fcgi/fcgi.go
+++ b/libgo/go/net/http/fcgi/fcgi.go
@@ -3,8 +3,12 @@
// license that can be found in the LICENSE file.
// Package fcgi implements the FastCGI protocol.
+//
+// The protocol is not an official standard and the original
+// documentation is no longer online. See the Internet Archive's
+// mirror at: https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22
+//
// Currently only the responder role is supported.
-// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22
package fcgi
// This file defines the raw protocol and some utilities used by the child and
diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go
index c7a58a6..bf63bb5 100644
--- a/libgo/go/net/http/fs.go
+++ b/libgo/go/net/http/fs.go
@@ -77,7 +77,7 @@ func dirList(w ResponseWriter, f File) {
Error(w, "Error reading directory", StatusInternalServerError)
return
}
- sort.Sort(byName(dirs))
+ sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() })
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, "<pre>\n")
@@ -98,7 +98,8 @@ func dirList(w ResponseWriter, f File) {
// ServeContent replies to the request using the content in the
// provided ReadSeeker. The main benefit of ServeContent over io.Copy
// is that it handles Range requests properly, sets the MIME type, and
-// handles If-Modified-Since requests.
+// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since,
+// and If-Range requests.
//
// If the response's Content-Type header is not set, ServeContent
// first tries to deduce the type from name's file extension and,
@@ -115,8 +116,8 @@ func dirList(w ResponseWriter, f File) {
// The content's Seek method must work: ServeContent uses
// a seek to the end of the content to determine its size.
//
-// If the caller has set w's ETag header, ServeContent uses it to
-// handle requests using If-Range and If-None-Match.
+// If the caller has set w's ETag header formatted per RFC 7232, section 2.3,
+// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range.
//
// Note that *os.File implements the io.ReadSeeker interface.
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) {
@@ -140,15 +141,17 @@ func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time
// users.
var errSeeker = errors.New("seeker can't seek")
+// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of
+// all of the byte-range-spec values is greater than the content size.
+var errNoOverlap = errors.New("invalid range: failed to overlap")
+
// if name is empty, filename is unknown. (used for mime type, before sniffing)
// if modtime.IsZero(), modtime is unknown.
// content must be seeked to the beginning of the file.
// The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response.
func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) {
- if checkLastModified(w, r, modtime) {
- return
- }
- rangeReq, done := checkETag(w, r, modtime)
+ setLastModified(w, modtime)
+ done, rangeReq := checkPreconditions(w, r, modtime)
if done {
return
}
@@ -189,6 +192,9 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
if size >= 0 {
ranges, err := parseRange(rangeReq, size)
if err != nil {
+ if err == errNoOverlap {
+ w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size))
+ }
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
return
}
@@ -263,90 +269,245 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
}
-var unixEpochTime = time.Unix(0, 0)
-
-// modtime is the modification time of the resource to be served, or IsZero().
-// return value is whether this request is now complete.
-func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool {
- if modtime.IsZero() || modtime.Equal(unixEpochTime) {
- // If the file doesn't have a modtime (IsZero), or the modtime
- // is obviously garbage (Unix time == 0), then ignore modtimes
- // and don't process the If-Modified-Since header.
- return false
+// scanETag determines if a syntactically valid ETag is present at s. If so,
+// the ETag and the remaining text after consuming it are returned. Otherwise,
+// it returns "", "".
+func scanETag(s string) (etag string, remain string) {
+ s = textproto.TrimString(s)
+ start := 0
+ if strings.HasPrefix(s, "W/") {
+ start = 2
+ }
+ if len(s[start:]) < 2 || s[start] != '"' {
+ return "", ""
+ }
+ // ETag is either W/"text" or "text".
+ // See RFC 7232 2.3.
+ for i := start + 1; i < len(s); i++ {
+ c := s[i]
+ switch {
+ // Character values allowed in ETags.
+ case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80:
+ case c == '"':
+ return string(s[:i+1]), s[i+1:]
+ default:
+ break
+ }
}
+ return "", ""
+}
- // The Date-Modified header truncates sub-second precision, so
- // use mtime < t+1s instead of mtime <= t to check for unmodified.
- if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) {
- h := w.Header()
- delete(h, "Content-Type")
- delete(h, "Content-Length")
- w.WriteHeader(StatusNotModified)
- return true
- }
- w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat))
- return false
+// etagStrongMatch reports whether a and b match using strong ETag comparison.
+// Assumes a and b are valid ETags.
+func etagStrongMatch(a, b string) bool {
+ return a == b && a != "" && a[0] == '"'
}
-// checkETag implements If-None-Match and If-Range checks.
-//
-// The ETag or modtime must have been previously set in the
-// ResponseWriter's headers. The modtime is only compared at second
-// granularity and may be the zero value to mean unknown.
-//
-// The return value is the effective request "Range" header to use and
-// whether this request is now considered done.
-func checkETag(w ResponseWriter, r *Request, modtime time.Time) (rangeReq string, done bool) {
- etag := w.Header().get("Etag")
- rangeReq = r.Header.get("Range")
-
- // Invalidate the range request if the entity doesn't match the one
- // the client was expecting.
- // "If-Range: version" means "ignore the Range: header unless version matches the
- // current file."
- // We only support ETag versions.
- // The caller must have set the ETag on the response already.
- if ir := r.Header.get("If-Range"); ir != "" && ir != etag {
- // The If-Range value is typically the ETag value, but it may also be
- // the modtime date. See golang.org/issue/8367.
- timeMatches := false
- if !modtime.IsZero() {
- if t, err := ParseTime(ir); err == nil && t.Unix() == modtime.Unix() {
- timeMatches = true
- }
+// etagWeakMatch reports whether a and b match using weak ETag comparison.
+// Assumes a and b are valid ETags.
+func etagWeakMatch(a, b string) bool {
+ return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/")
+}
+
+// condResult is the result of an HTTP request precondition check.
+// See https://tools.ietf.org/html/rfc7232 section 3.
+type condResult int
+
+const (
+ condNone condResult = iota
+ condTrue
+ condFalse
+)
+
+func checkIfMatch(w ResponseWriter, r *Request) condResult {
+ im := r.Header.Get("If-Match")
+ if im == "" {
+ return condNone
+ }
+ for {
+ im = textproto.TrimString(im)
+ if len(im) == 0 {
+ break
+ }
+ if im[0] == ',' {
+ im = im[1:]
+ continue
+ }
+ if im[0] == '*' {
+ return condTrue
}
- if !timeMatches {
- rangeReq = ""
+ etag, remain := scanETag(im)
+ if etag == "" {
+ break
+ }
+ if etagStrongMatch(etag, w.Header().get("Etag")) {
+ return condTrue
}
+ im = remain
}
- if inm := r.Header.get("If-None-Match"); inm != "" {
- // Must know ETag.
+ return condFalse
+}
+
+func checkIfUnmodifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult {
+ ius := r.Header.Get("If-Unmodified-Since")
+ if ius == "" || isZeroTime(modtime) {
+ return condNone
+ }
+ if t, err := ParseTime(ius); err == nil {
+ // The HTTP date format truncates sub-second precision, so
+ // use mtime < t+1s instead of mtime <= t to check for unmodified.
+ if modtime.Before(t.Add(1 * time.Second)) {
+ return condTrue
+ }
+ return condFalse
+ }
+ return condNone
+}
+
+func checkIfNoneMatch(w ResponseWriter, r *Request) condResult {
+ inm := r.Header.get("If-None-Match")
+ if inm == "" {
+ return condNone
+ }
+ buf := inm
+ for {
+ buf = textproto.TrimString(buf)
+ if len(buf) == 0 {
+ break
+ }
+ if buf[0] == ',' {
+ buf = buf[1:]
+ }
+ if buf[0] == '*' {
+ return condFalse
+ }
+ etag, remain := scanETag(buf)
if etag == "" {
- return rangeReq, false
+ break
+ }
+ if etagWeakMatch(etag, w.Header().get("Etag")) {
+ return condFalse
}
+ buf = remain
+ }
+ return condTrue
+}
+
+func checkIfModifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult {
+ if r.Method != "GET" && r.Method != "HEAD" {
+ return condNone
+ }
+ ims := r.Header.Get("If-Modified-Since")
+ if ims == "" || isZeroTime(modtime) {
+ return condNone
+ }
+ t, err := ParseTime(ims)
+ if err != nil {
+ return condNone
+ }
+ // The HTTP date format truncates sub-second precision, so
+ // use mtime < t+1s instead of mtime <= t to check for unmodified.
+ if modtime.Before(t.Add(1 * time.Second)) {
+ return condFalse
+ }
+ return condTrue
+}
+
+func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult {
+ if r.Method != "GET" {
+ return condNone
+ }
+ ir := r.Header.get("If-Range")
+ if ir == "" {
+ return condNone
+ }
+ etag, _ := scanETag(ir)
+ if etag != "" {
+ if etagStrongMatch(etag, w.Header().Get("Etag")) {
+ return condTrue
+ } else {
+ return condFalse
+ }
+ }
+ // The If-Range value is typically the ETag value, but it may also be
+ // the modtime date. See golang.org/issue/8367.
+ if modtime.IsZero() {
+ return condFalse
+ }
+ t, err := ParseTime(ir)
+ if err != nil {
+ return condFalse
+ }
+ if t.Unix() == modtime.Unix() {
+ return condTrue
+ }
+ return condFalse
+}
+
+var unixEpochTime = time.Unix(0, 0)
+
+// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0).
+func isZeroTime(t time.Time) bool {
+ return t.IsZero() || t.Equal(unixEpochTime)
+}
+
+func setLastModified(w ResponseWriter, modtime time.Time) {
+ if !isZeroTime(modtime) {
+ w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat))
+ }
+}
- // TODO(bradfitz): non-GET/HEAD requests require more work:
- // sending a different status code on matches, and
- // also can't use weak cache validators (those with a "W/
- // prefix). But most users of ServeContent will be using
- // it on GET or HEAD, so only support those for now.
- if r.Method != "GET" && r.Method != "HEAD" {
- return rangeReq, false
+func writeNotModified(w ResponseWriter) {
+ // RFC 7232 section 4.1:
+ // a sender SHOULD NOT generate representation metadata other than the
+ // above listed fields unless said metadata exists for the purpose of
+ // guiding cache updates (e.g., Last-Modified might be useful if the
+ // response does not have an ETag field).
+ h := w.Header()
+ delete(h, "Content-Type")
+ delete(h, "Content-Length")
+ if h.Get("Etag") != "" {
+ delete(h, "Last-Modified")
+ }
+ w.WriteHeader(StatusNotModified)
+}
+
+// checkPreconditions evaluates request preconditions and reports whether a precondition
+// resulted in sending StatusNotModified or StatusPreconditionFailed.
+func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) {
+ // This function carefully follows RFC 7232 section 6.
+ ch := checkIfMatch(w, r)
+ if ch == condNone {
+ ch = checkIfUnmodifiedSince(w, r, modtime)
+ }
+ if ch == condFalse {
+ w.WriteHeader(StatusPreconditionFailed)
+ return true, ""
+ }
+ switch checkIfNoneMatch(w, r) {
+ case condFalse:
+ if r.Method == "GET" || r.Method == "HEAD" {
+ writeNotModified(w)
+ return true, ""
+ } else {
+ w.WriteHeader(StatusPreconditionFailed)
+ return true, ""
}
+ case condNone:
+ if checkIfModifiedSince(w, r, modtime) == condFalse {
+ writeNotModified(w)
+ return true, ""
+ }
+ }
- // TODO(bradfitz): deal with comma-separated or multiple-valued
- // list of If-None-match values. For now just handle the common
- // case of a single item.
- if inm == etag || inm == "*" {
- h := w.Header()
- delete(h, "Content-Type")
- delete(h, "Content-Length")
- w.WriteHeader(StatusNotModified)
- return "", true
+ rangeHeader = r.Header.get("Range")
+ if rangeHeader != "" {
+ if checkIfRange(w, r, modtime) == condFalse {
+ rangeHeader = ""
}
}
- return rangeReq, false
+ return false, rangeHeader
}
// name is '/'-separated, not filepath.Separator.
@@ -419,9 +580,11 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
// Still a directory? (we didn't find an index.html file)
if d.IsDir() {
- if checkLastModified(w, r, d.ModTime()) {
+ if checkIfModifiedSince(w, r, d.ModTime()) == condFalse {
+ writeNotModified(w)
return
}
+ w.Header().Set("Last-Modified", d.ModTime().UTC().Format(TimeFormat))
dirList(w, f)
return
}
@@ -543,6 +706,7 @@ func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHead
}
// parseRange parses a Range header string as per RFC 2616.
+// errNoOverlap is returned if none of the ranges overlap.
func parseRange(s string, size int64) ([]httpRange, error) {
if s == "" {
return nil, nil // header not present
@@ -552,6 +716,7 @@ func parseRange(s string, size int64) ([]httpRange, error) {
return nil, errors.New("invalid range")
}
var ranges []httpRange
+ noOverlap := false
for _, ra := range strings.Split(s[len(b):], ",") {
ra = strings.TrimSpace(ra)
if ra == "" {
@@ -577,9 +742,15 @@ func parseRange(s string, size int64) ([]httpRange, error) {
r.length = size - r.start
} else {
i, err := strconv.ParseInt(start, 10, 64)
- if err != nil || i >= size || i < 0 {
+ if err != nil || i < 0 {
return nil, errors.New("invalid range")
}
+ if i >= size {
+ // If the range begins after the size of the content,
+ // then it does not overlap.
+ noOverlap = true
+ continue
+ }
r.start = i
if end == "" {
// If no end is specified, range extends to end of the file.
@@ -597,6 +768,10 @@ func parseRange(s string, size int64) ([]httpRange, error) {
}
ranges = append(ranges, r)
}
+ if noOverlap && len(ranges) == 0 {
+ // The specified ranges did not overlap with the content.
+ return nil, errNoOverlap
+ }
return ranges, nil
}
@@ -628,9 +803,3 @@ func sumRangesSize(ranges []httpRange) (size int64) {
}
return
}
-
-type byName []os.FileInfo
-
-func (s byName) Len() int { return len(s) }
-func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() }
-func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
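
With the rewritten precondition handling, a handler that sets an RFC 7232 quoted ETag before calling ServeContent gets If-Match, If-None-Match and If-Range evaluation plus the new 416 Content-Range behaviour for free. A hedged sketch follows; the path, body and tag are made up.

package main

import (
	"log"
	"net/http"
	"strings"
	"time"
)

func main() {
	const body = "hello, conditional world\n" // 25 bytes
	modtime := time.Date(2017, 1, 13, 0, 0, 0, 0, time.UTC)

	http.HandleFunc("/greeting", func(w http.ResponseWriter, r *http.Request) {
		// The quoted form matters: scanETag only recognizes `"v1"` or `W/"v1"`.
		w.Header().Set("Etag", `"v1"`)
		http.ServeContent(w, r, "greeting.txt", modtime, strings.NewReader(body))
	})
	log.Fatal(http.ListenAndServe("localhost:8080", nil))
}

A request sending If-None-Match: "v1" then receives 304 Not Modified, and a Range starting past the 25-byte body (for example bytes=100-) receives 416 with Content-Range: bytes */25 via the errNoOverlap path above.
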
diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go
index 22be389..bba5682 100644
--- a/libgo/go/net/http/fs_test.go
+++ b/libgo/go/net/http/fs_test.go
@@ -68,6 +68,7 @@ var ServeFileRangeTests = []struct {
}
func TestServeFile(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ServeFile(w, r, "testdata/file")
@@ -274,6 +275,7 @@ func TestFileServerEscapesNames(t *testing.T) {
{`"'<>&`, `<a href="%22%27%3C%3E&">&#34;&#39;&lt;&gt;&amp;</a>`},
{`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`},
{`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo">&lt;combo&gt;?foo</a>`},
+ {`foo:bar`, `<a href="./foo:bar">foo:bar</a>`},
}
// We put each test file in its own directory in the fakeFS so we can look at it in isolation.
@@ -765,6 +767,7 @@ func TestServeContent(t *testing.T) {
reqHeader map[string]string
wantLastMod string
wantContentType string
+ wantContentRange string
wantStatus int
}
htmlModTime := mustStat(t, "testdata/index.html").ModTime()
@@ -782,8 +785,9 @@ func TestServeContent(t *testing.T) {
wantStatus: 200,
},
"not_modified_modtime": {
- file: "testdata/style.css",
- modtime: htmlModTime,
+ file: "testdata/style.css",
+ serveETag: `"foo"`, // Last-Modified sent only when no ETag
+ modtime: htmlModTime,
reqHeader: map[string]string{
"If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
},
@@ -792,6 +796,7 @@ func TestServeContent(t *testing.T) {
"not_modified_modtime_with_contenttype": {
file: "testdata/style.css",
serveContentType: "text/css", // explicit content type
+ serveETag: `"foo"`, // Last-Modified sent only when no ETag
modtime: htmlModTime,
reqHeader: map[string]string{
"If-Modified-Since": htmlModTime.UTC().Format(TimeFormat),
@@ -808,21 +813,62 @@ func TestServeContent(t *testing.T) {
},
"not_modified_etag_no_seek": {
content: panicOnSeek{nil}, // should never be called
- serveETag: `"foo"`,
+ serveETag: `W/"foo"`, // If-None-Match uses weak ETag comparison
reqHeader: map[string]string{
- "If-None-Match": `"foo"`,
+ "If-None-Match": `"baz", W/"foo"`,
},
wantStatus: 304,
},
+ "if_none_match_mismatch": {
+ file: "testdata/style.css",
+ serveETag: `"foo"`,
+ reqHeader: map[string]string{
+ "If-None-Match": `"Foo"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
"range_good": {
file: "testdata/style.css",
serveETag: `"A"`,
reqHeader: map[string]string{
"Range": "bytes=0-4",
},
- wantStatus: StatusPartialContent,
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ },
+ "range_match": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `"A"`,
+ },
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ },
+ "range_match_weak_etag": {
+ file: "testdata/style.css",
+ serveETag: `W/"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=0-4",
+ "If-Range": `W/"A"`,
+ },
+ wantStatus: 200,
wantContentType: "text/css; charset=utf-8",
},
+ "range_no_overlap": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "Range": "bytes=10-20",
+ },
+ wantStatus: StatusRequestedRangeNotSatisfiable,
+ wantContentType: "text/plain; charset=utf-8",
+ wantContentRange: "bytes */8",
+ },
// An If-Range resource for entity "A", but entity "B" is now current.
// The Range request should be ignored.
"range_no_match": {
@@ -842,9 +888,10 @@ func TestServeContent(t *testing.T) {
"Range": "bytes=0-4",
"If-Range": "Wed, 25 Jun 2014 17:12:18 GMT",
},
- wantStatus: StatusPartialContent,
- wantContentType: "text/css; charset=utf-8",
- wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
},
"range_with_modtime_nanos": {
file: "testdata/style.css",
@@ -853,9 +900,10 @@ func TestServeContent(t *testing.T) {
"Range": "bytes=0-4",
"If-Range": "Wed, 25 Jun 2014 17:12:18 GMT",
},
- wantStatus: StatusPartialContent,
- wantContentType: "text/css; charset=utf-8",
- wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
+ wantStatus: StatusPartialContent,
+ wantContentType: "text/css; charset=utf-8",
+ wantContentRange: "bytes 0-4/8",
+ wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT",
},
"unix_zero_modtime": {
content: strings.NewReader("<html>foo"),
@@ -863,6 +911,62 @@ func TestServeContent(t *testing.T) {
wantStatus: StatusOK,
wantContentType: "text/html; charset=utf-8",
},
+ "ifmatch_matches": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `"Z", "A"`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "ifmatch_star": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `*`,
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ },
+ "ifmatch_failed": {
+ file: "testdata/style.css",
+ serveETag: `"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `"B"`,
+ },
+ wantStatus: 412,
+ wantContentType: "text/plain; charset=utf-8",
+ },
+ "ifmatch_fails_on_weak_etag": {
+ file: "testdata/style.css",
+ serveETag: `W/"A"`,
+ reqHeader: map[string]string{
+ "If-Match": `W/"A"`,
+ },
+ wantStatus: 412,
+ wantContentType: "text/plain; charset=utf-8",
+ },
+ "if_unmodified_since_true": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Unmodified-Since": htmlModTime.UTC().Format(TimeFormat),
+ },
+ wantStatus: 200,
+ wantContentType: "text/css; charset=utf-8",
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ },
+ "if_unmodified_since_false": {
+ file: "testdata/style.css",
+ modtime: htmlModTime,
+ reqHeader: map[string]string{
+ "If-Unmodified-Since": htmlModTime.Add(-2 * time.Second).UTC().Format(TimeFormat),
+ },
+ wantStatus: 412,
+ wantContentType: "text/plain; charset=utf-8",
+ wantLastMod: htmlModTime.UTC().Format(TimeFormat),
+ },
}
for testName, tt := range tests {
var content io.ReadSeeker
@@ -903,6 +1007,9 @@ func TestServeContent(t *testing.T) {
if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e {
t.Errorf("test %q: content-type = %q, want %q", testName, g, e)
}
+ if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e {
+ t.Errorf("test %q: content-range = %q, want %q", testName, g, e)
+ }
if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
t.Errorf("test %q: last-modified = %q, want %q", testName, g, e)
}
@@ -958,6 +1065,7 @@ func TestServeContentErrorMessages(t *testing.T) {
// verifies that sendfile is being used on Linux
func TestLinuxSendfile(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
if runtime.GOOS != "linux" {
t.Skip("skipping; linux-only test")
@@ -982,6 +1090,8 @@ func TestLinuxSendfile(t *testing.T) {
// strace on the above platforms doesn't support sendfile64
// and will error out if we specify that with `-e trace='.
syscalls = "sendfile"
+ case "mips64":
+ t.Skip("TODO: update this test to be robust against various versions of strace on mips64. See golang.org/issue/33430")
}
var buf bytes.Buffer
@@ -1008,10 +1118,9 @@ func TestLinuxSendfile(t *testing.T) {
Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil)
child.Wait()
- rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`)
- rxResume := regexp.MustCompile(`<\.\.\. sendfile(64)? resumed> \)\s*=\s*\d+\s*\n`)
+ rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+`)
out := buf.String()
- if !rx.MatchString(out) && !rxResume.MatchString(out) {
+ if !rx.MatchString(out) {
t.Errorf("no sendfile system call found in:\n%s", out)
}
}
@@ -1090,3 +1199,26 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) {
}
type panicOnSeek struct{ io.ReadSeeker }
+
+func Test_scanETag(t *testing.T) {
+ tests := []struct {
+ in string
+ wantETag string
+ wantRemain string
+ }{
+ {`W/"etag-1"`, `W/"etag-1"`, ""},
+ {`"etag-2"`, `"etag-2"`, ""},
+ {`"etag-1", "etag-2"`, `"etag-1"`, `, "etag-2"`},
+ {"", "", ""},
+ {"", "", ""},
+ {"W/", "", ""},
+ {`W/"truc`, "", ""},
+ {`w/"case-sensitive"`, "", ""},
+ }
+ for _, test := range tests {
+ etag, remain := ExportScanETag(test.in)
+ if etag != test.wantETag || remain != test.wantRemain {
+ t.Errorf("scanETag(%q)=%q %q, want %q %q", test.in, etag, remain, test.wantETag, test.wantRemain)
+ }
+ }
+}
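
Test_scanETag and the If-Match/If-None-Match cases above hinge on the strong versus weak comparison rules of RFC 7232 section 2.3.2. The helpers are unexported, so the following standalone sketch restates them outside the package purely for illustration, assuming they keep the semantics shown in the patch.

package main

import (
	"fmt"
	"strings"
)

// Strong comparison (used for If-Match and ETag-based If-Range):
// both tags must be non-weak and byte-identical.
func etagStrongMatch(a, b string) bool {
	return a == b && a != "" && a[0] == '"'
}

// Weak comparison (used for If-None-Match): any W/ prefix is ignored.
func etagWeakMatch(a, b string) bool {
	return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/")
}

func main() {
	fmt.Println(etagStrongMatch(`"v1"`, `W/"v1"`)) // false: a weak tag never strong-matches
	fmt.Println(etagWeakMatch(`"v1"`, `W/"v1"`))   // true: weak comparison strips W/
	fmt.Println(etagStrongMatch(`"v1"`, `"v1"`))   // true
}
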
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go
index 5826bb7..25fdf09 100644
--- a/libgo/go/net/http/h2_bundle.go
+++ b/libgo/go/net/http/h2_bundle.go
@@ -1,5 +1,5 @@
// Code generated by golang.org/x/tools/cmd/bundle.
-//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2
+//go:generate bundle -o h2_bundle.go -prefix http2 -underscore golang.org/x/net/http2
// Package http2 implements the HTTP/2 protocol.
//
@@ -21,6 +21,7 @@ import (
"bytes"
"compress/gzip"
"context"
+ "crypto/rand"
"crypto/tls"
"encoding/binary"
"errors"
@@ -43,6 +44,7 @@ import (
"time"
"golang_org/x/net/http2/hpack"
+ "golang_org/x/net/idna"
"golang_org/x/net/lex/httplex"
)
@@ -853,10 +855,12 @@ type http2Framer struct {
// If the limit is hit, MetaHeadersFrame.Truncated is set true.
MaxHeaderListSize uint32
- logReads bool
+ logReads, logWrites bool
- debugFramer *http2Framer // only use for logging written writes
- debugFramerBuf *bytes.Buffer
+ debugFramer *http2Framer // only use for logging written writes
+ debugFramerBuf *bytes.Buffer
+ debugReadLoggerf func(string, ...interface{})
+ debugWriteLoggerf func(string, ...interface{})
}
func (fr *http2Framer) maxHeaderListSize() uint32 {
@@ -890,7 +894,7 @@ func (f *http2Framer) endWrite() error {
byte(length>>16),
byte(length>>8),
byte(length))
- if http2logFrameWrites {
+ if f.logWrites {
f.logWrite()
}
@@ -912,10 +916,10 @@ func (f *http2Framer) logWrite() {
f.debugFramerBuf.Write(f.wbuf)
fr, err := f.debugFramer.ReadFrame()
if err != nil {
- log.Printf("http2: Framer %p: failed to decode just-written frame", f)
+ f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f)
return
}
- log.Printf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr))
+ f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr))
}
func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) }
@@ -936,9 +940,12 @@ const (
// NewFramer returns a Framer that writes frames to w and reads them from r.
func http2NewFramer(w io.Writer, r io.Reader) *http2Framer {
fr := &http2Framer{
- w: w,
- r: r,
- logReads: http2logFrameReads,
+ w: w,
+ r: r,
+ logReads: http2logFrameReads,
+ logWrites: http2logFrameWrites,
+ debugReadLoggerf: log.Printf,
+ debugWriteLoggerf: log.Printf,
}
fr.getReadBuf = func(size uint32) []byte {
if cap(fr.readBuf) >= int(size) {
@@ -1020,7 +1027,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) {
return nil, err
}
if fr.logReads {
- log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f))
+ fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f))
}
if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil {
return fr.readMetaFrame(f.(*http2HeadersFrame))
@@ -1254,7 +1261,7 @@ func (f *http2Framer) WriteSettings(settings ...http2Setting) error {
return f.endWrite()
}
-// WriteSettings writes an empty SETTINGS frame with the ACK bit set.
+// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set.
//
// It will perform exactly one Write to the underlying Writer.
// It is the caller's responsibility to not call other Write methods concurrently.
@@ -1920,8 +1927,8 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr
hdec.SetEmitEnabled(true)
hdec.SetMaxStringLength(fr.maxHeaderStringLen())
hdec.SetEmitFunc(func(hf hpack.HeaderField) {
- if http2VerboseLogs && http2logFrameReads {
- log.Printf("http2: decoded hpack field %+v", hf)
+ if http2VerboseLogs && fr.logReads {
+ fr.debugReadLoggerf("http2: decoded hpack field %+v", hf)
}
if !httplex.ValidHeaderFieldValue(hf.Value) {
invalid = http2headerFieldValueError(hf.Value)
@@ -2091,6 +2098,13 @@ type http2clientTrace httptrace.ClientTrace
func http2reqContext(r *Request) context.Context { return r.Context() }
+func (t *http2Transport) idleConnTimeout() time.Duration {
+ if t.t1 != nil {
+ return t.t1.IdleConnTimeout
+ }
+ return 0
+}
+
func http2setResponseUncompressed(res *Response) { res.Uncompressed = true }
func http2traceGotConn(req *Request, cc *http2ClientConn) {
@@ -2145,6 +2159,48 @@ func http2requestTrace(req *Request) *http2clientTrace {
return (*http2clientTrace)(trace)
}
+// Ping sends a PING frame to the server and waits for the ack.
+func (cc *http2ClientConn) Ping(ctx context.Context) error {
+ return cc.ping(ctx)
+}
+
+func http2cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() }
+
+var _ Pusher = (*http2responseWriter)(nil)
+
+// Push implements http.Pusher.
+func (w *http2responseWriter) Push(target string, opts *PushOptions) error {
+ internalOpts := http2pushOptions{}
+ if opts != nil {
+ internalOpts.Method = opts.Method
+ internalOpts.Header = opts.Header
+ }
+ return w.push(target, internalOpts)
+}
+
+func http2configureServer18(h1 *Server, h2 *http2Server) error {
+ if h2.IdleTimeout == 0 {
+ if h1.IdleTimeout != 0 {
+ h2.IdleTimeout = h1.IdleTimeout
+ } else {
+ h2.IdleTimeout = h1.ReadTimeout
+ }
+ }
+ return nil
+}
+
+func http2shouldLogPanic(panicValue interface{}) bool {
+ return panicValue != nil && panicValue != ErrAbortHandler
+}
+
+func http2reqGetBody(req *Request) func() (io.ReadCloser, error) {
+ return req.GetBody
+}
+
+func http2reqBodyIsNoBody(body io.ReadCloser) bool {
+ return body == NoBody
+}
+
var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
type http2goroutineLock uint64
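
The helpers above bridge the bundle to the Go 1.8 net/http surface: Pusher/PushOptions for server push, GetBody for replayable request bodies, NoBody, and ErrAbortHandler for silent panics. From a handler's point of view, push is used roughly as follows (a sketch against the public Go 1.8 API; the file names and paths are placeholders):

package main

import (
	"log"
	"net/http"
)

func handler(w http.ResponseWriter, r *http.Request) {
	// Push is best-effort: it only works when the client speaks HTTP/2 and
	// has push enabled, so both the type assertion and the error are checked.
	if p, ok := w.(http.Pusher); ok {
		if err := p.Push("/static/app.css", &http.PushOptions{
			Header: http.Header{"Accept-Encoding": r.Header["Accept-Encoding"]},
		}); err != nil {
			log.Printf("push failed: %v", err)
		}
	}
	w.Write([]byte("<html><head><link rel=stylesheet href=/static/app.css></head></html>"))
}

func main() {
	http.HandleFunc("/", handler)
	// cert.pem/key.pem are placeholders; HTTP/2 in net/http requires TLS.
	log.Fatal(http.ListenAndServeTLS(":8443", "cert.pem", "key.pem", nil))
}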
@@ -2368,6 +2424,7 @@ var (
http2VerboseLogs bool
http2logFrameWrites bool
http2logFrameReads bool
+ http2inTests bool
)
func init() {
@@ -2409,13 +2466,23 @@ var (
type http2streamState int
+// HTTP/2 stream states.
+//
+// See http://tools.ietf.org/html/rfc7540#section-5.1.
+//
+// For simplicity, the server code merges "reserved (local)" into
+// "half-closed (remote)". This is one less state transition to track.
+// The only downside is that we send PUSH_PROMISEs slightly less
+// liberally than allowable. More discussion here:
+// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html
+//
+// "reserved (remote)" is omitted since the client code does not
+// support server push.
const (
http2stateIdle http2streamState = iota
http2stateOpen
http2stateHalfClosedLocal
http2stateHalfClosedRemote
- http2stateResvLocal
- http2stateResvRemote
http2stateClosed
)
@@ -2424,8 +2491,6 @@ var http2stateName = [...]string{
http2stateOpen: "Open",
http2stateHalfClosedLocal: "HalfClosedLocal",
http2stateHalfClosedRemote: "HalfClosedRemote",
- http2stateResvLocal: "ResvLocal",
- http2stateResvRemote: "ResvRemote",
http2stateClosed: "Closed",
}
@@ -2586,13 +2651,27 @@ func http2newBufferedWriter(w io.Writer) *http2bufferedWriter {
return &http2bufferedWriter{w: w}
}
+// bufWriterPoolBufferSize is the size of bufio.Writer's
+// buffers created using bufWriterPool.
+//
+// TODO: pick a less arbitrary value? this is a bit under
+// (3 x typical 1500 byte MTU) at least. Other than that,
+// not much thought went into it.
+const http2bufWriterPoolBufferSize = 4 << 10
+
var http2bufWriterPool = sync.Pool{
New: func() interface{} {
-
- return bufio.NewWriterSize(nil, 4<<10)
+ return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize)
},
}
+func (w *http2bufferedWriter) Available() int {
+ if w.bw == nil {
+ return http2bufWriterPoolBufferSize
+ }
+ return w.bw.Available()
+}
+
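
Giving the pool's buffer size a name lets Available report something sensible even before a *bufio.Writer has been checked out, which the new synchronous write path uses later to decide whether a frame fits in the buffer. The pooling idiom in isolation (independent of the bundle's types) is roughly:

package main

import (
	"bufio"
	"fmt"
	"os"
	"sync"
)

const poolBufSize = 4 << 10 // must match what Available assumes when no writer is held

var bwPool = sync.Pool{
	New: func() interface{} { return bufio.NewWriterSize(nil, poolBufSize) },
}

func main() {
	bw := bwPool.Get().(*bufio.Writer)
	bw.Reset(os.Stdout)       // point the pooled writer at the real destination
	fmt.Fprintln(bw, "hello") // buffered
	bw.Flush()                // write it out before returning the writer
	bw.Reset(nil)             // drop the reference so the pool does not pin os.Stdout
	bwPool.Put(bw)
}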
func (w *http2bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := http2bufWriterPool.Get().(*bufio.Writer)
@@ -2686,6 +2765,19 @@ func (s *http2sorter) SortStrings(ss []string) {
s.v = save
}
+// validPseudoPath reports whether v is a valid :path pseudo-header
+// value. It must be either:
+//
+// *) a non-empty string starting with '/', but not with "//",
+// *) the string '*', for OPTIONS requests.
+//
+// For now this is only used as a quick check for deciding when to clean
+// up Opaque URLs before sending requests from the Transport.
+// See golang.org/issue/16847
+func http2validPseudoPath(v string) bool {
+ return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
+}
+
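
The rule in the comment is small enough to check by hand; a standalone copy of the predicate (mirroring, not importing, the bundled helper) behaves like this:

package main

import "fmt"

// validPseudoPath mirrors the bundled check: either "*" (for OPTIONS) or a
// path that starts with a single '/'.
func validPseudoPath(v string) bool {
	return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}

func main() {
	for _, v := range []string{"/", "/index.html", "*", "", "//double", "relative"} {
		fmt.Printf("%-14q %v\n", v, validPseudoPath(v)) // true, true, true, false, false, false
	}
}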
// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
// io.Pipe except there are no PipeReader/PipeWriter halves, and the
// underlying buffer is an interface. (io.Pipe is always unbuffered)
@@ -2882,6 +2974,15 @@ type http2Server struct {
// PermitProhibitedCipherSuites, if true, permits the use of
// cipher suites prohibited by the HTTP/2 spec.
PermitProhibitedCipherSuites bool
+
+ // IdleTimeout specifies how long until idle clients should be
+ // closed with a GOAWAY frame. PING frames are not considered
+ // activity for the purposes of IdleTimeout.
+ IdleTimeout time.Duration
+
+ // NewWriteScheduler constructs a write scheduler for a connection.
+ // If nil, a default scheduler is chosen.
+ NewWriteScheduler func() http2WriteScheduler
}
func (s *http2Server) maxReadFrameSize() uint32 {
@@ -2904,9 +3005,15 @@ func (s *http2Server) maxConcurrentStreams() uint32 {
//
// ConfigureServer must be called before s begins serving.
func http2ConfigureServer(s *Server, conf *http2Server) error {
+ if s == nil {
+ panic("nil *http.Server")
+ }
if conf == nil {
conf = new(http2Server)
}
+ if err := http2configureServer18(s, conf); err != nil {
+ return err
+ }
if s.TLSConfig == nil {
s.TLSConfig = new(tls.Config)
@@ -2945,8 +3052,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error {
s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS)
}
- s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14")
-
if s.TLSNextProto == nil {
s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
}
@@ -2960,7 +3065,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error {
})
}
s.TLSNextProto[http2NextProtoTLS] = protoHandler
- s.TLSNextProto["h2-14"] = protoHandler
return nil
}
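
With the h2-14 draft protocol dropped and a nil *http.Server now rejected early, callers normally reach this code through the exported golang.org/x/net/http2 package rather than the bundle. A sketch of that usage, assuming the x/net/http2 API of this era (cert/key paths are placeholders):

package main

import (
	"log"
	"net/http"
	"time"

	"golang.org/x/net/http2"
)

func main() {
	srv := &http.Server{
		Addr:        ":8443",
		ReadTimeout: 30 * time.Second,
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("proto: " + r.Proto))
		}),
	}
	// When IdleTimeout is left zero it falls back to the http.Server's
	// IdleTimeout/ReadTimeout (configureServer18 above); here it is explicit.
	if err := http2.ConfigureServer(srv, &http2.Server{IdleTimeout: 2 * time.Minute}); err != nil {
		log.Fatal(err)
	}
	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem")) // placeholder cert/key paths
}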
@@ -3014,29 +3118,39 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
defer cancel()
sc := &http2serverConn{
- srv: s,
- hs: opts.baseConfig(),
- conn: c,
- baseCtx: baseCtx,
- remoteAddrStr: c.RemoteAddr().String(),
- bw: http2newBufferedWriter(c),
- handler: opts.handler(),
- streams: make(map[uint32]*http2stream),
- readFrameCh: make(chan http2readFrameResult),
- wantWriteFrameCh: make(chan http2frameWriteMsg, 8),
- wroteFrameCh: make(chan http2frameWriteResult, 1),
- bodyReadCh: make(chan http2bodyReadMsg),
- doneServing: make(chan struct{}),
- advMaxStreams: s.maxConcurrentStreams(),
- writeSched: http2writeScheduler{
- maxFrameSize: http2initialMaxFrameSize,
- },
+ srv: s,
+ hs: opts.baseConfig(),
+ conn: c,
+ baseCtx: baseCtx,
+ remoteAddrStr: c.RemoteAddr().String(),
+ bw: http2newBufferedWriter(c),
+ handler: opts.handler(),
+ streams: make(map[uint32]*http2stream),
+ readFrameCh: make(chan http2readFrameResult),
+ wantWriteFrameCh: make(chan http2FrameWriteRequest, 8),
+ wantStartPushCh: make(chan http2startPushRequest, 8),
+ wroteFrameCh: make(chan http2frameWriteResult, 1),
+ bodyReadCh: make(chan http2bodyReadMsg),
+ doneServing: make(chan struct{}),
+ clientMaxStreams: math.MaxUint32,
+ advMaxStreams: s.maxConcurrentStreams(),
initialWindowSize: http2initialWindowSize,
+ maxFrameSize: http2initialMaxFrameSize,
headerTableSize: http2initialHeaderTableSize,
serveG: http2newGoroutineLock(),
pushEnabled: true,
}
+ if sc.hs.WriteTimeout != 0 {
+ sc.conn.SetWriteDeadline(time.Time{})
+ }
+
+ if s.NewWriteScheduler != nil {
+ sc.writeSched = s.NewWriteScheduler()
+ } else {
+ sc.writeSched = http2NewRandomWriteScheduler()
+ }
+
sc.flow.add(http2initialWindowSize)
sc.inflow.add(http2initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
@@ -3090,16 +3204,18 @@ type http2serverConn struct {
handler Handler
baseCtx http2contextContext
framer *http2Framer
- doneServing chan struct{} // closed when serverConn.serve ends
- readFrameCh chan http2readFrameResult // written by serverConn.readFrames
- wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve
- wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
- bodyReadCh chan http2bodyReadMsg // from handlers -> serve
- testHookCh chan func(int) // code to run on the serve loop
- flow http2flow // conn-wide (not stream-specific) outbound flow control
- inflow http2flow // conn-wide inbound flow control
- tlsState *tls.ConnectionState // shared by all handlers, like net/http
+ doneServing chan struct{} // closed when serverConn.serve ends
+ readFrameCh chan http2readFrameResult // written by serverConn.readFrames
+ wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve
+ wantStartPushCh chan http2startPushRequest // from handlers -> serve
+ wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
+ bodyReadCh chan http2bodyReadMsg // from handlers -> serve
+ testHookCh chan func(int) // code to run on the serve loop
+ flow http2flow // conn-wide (not stream-specific) outbound flow control
+ inflow http2flow // conn-wide inbound flow control
+ tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string
+ writeSched http2WriteScheduler
// Everything following is owned by the serve loop; use serveG.check():
serveG http2goroutineLock // used to verify funcs are on serve()
@@ -3109,22 +3225,27 @@ type http2serverConn struct {
unackedSettings int // how many SETTINGS have we sent without ACKs?
clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client
- curOpenStreams uint32 // client's number of open streams
- maxStreamID uint32 // max ever seen
+ curClientStreams uint32 // number of open streams initiated by the client
+ curPushedStreams uint32 // number of open streams initiated by server push
+ maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests
+ maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes
streams map[uint32]*http2stream
initialWindowSize int32
+ maxFrameSize int32
headerTableSize uint32
peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
- writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh
+ writingFrame bool // started writing a frame (on serve goroutine or separate)
+ writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh
needsFrameFlush bool // last frame write wasn't a flush
- writeSched http2writeScheduler
- inGoAway bool // we've started to or sent GOAWAY
- needToSendGoAway bool // we need to schedule a GOAWAY frame write
+ inGoAway bool // we've started to or sent GOAWAY
+ inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
+ needToSendGoAway bool // we need to schedule a GOAWAY frame write
goAwayCode http2ErrCode
shutdownTimerCh <-chan time.Time // nil until used
shutdownTimer *time.Timer // nil until used
- freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf
+ idleTimer *time.Timer // nil if unused
+ idleTimerCh <-chan time.Time // nil if unused
// Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer
@@ -3143,6 +3264,11 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 {
return uint32(n + typicalHeaders*perFieldOverhead)
}
+func (sc *http2serverConn) curOpenStreams() uint32 {
+ sc.serveG.check()
+ return sc.curClientStreams + sc.curPushedStreams
+}
+
// stream represents a stream. This is the minimal metadata needed by
// the serve goroutine. Most of the actual stream state is owned by
// the http.Handler's goroutine in the responseWriter. Because the
@@ -3168,11 +3294,10 @@ type http2stream struct {
numTrailerValues int64
weight uint8
state http2streamState
- sentReset bool // only true once detached from streams map
- gotReset bool // only true once detacted from streams map
- gotTrailerHeader bool // HEADER frame for trailers was seen
- wroteHeaders bool // whether we wrote headers (not status 100)
- reqBuf []byte
+ resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
+ gotTrailerHeader bool // HEADER frame for trailers was seen
+ wroteHeaders bool // whether we wrote headers (not status 100)
+ reqBuf []byte // if non-nil, body pipe buffer to return later at EOF
trailer Header // accumulated trailers
reqTrailer Header // handler's Request.Trailer
@@ -3195,8 +3320,14 @@ func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2strea
return st.state, st
}
- if streamID <= sc.maxStreamID {
- return http2stateClosed, nil
+ if streamID%2 == 1 {
+ if streamID <= sc.maxClientStreamID {
+ return http2stateClosed, nil
+ }
+ } else {
+ if streamID <= sc.maxPushPromiseID {
+ return http2stateClosed, nil
+ }
}
return http2stateIdle, nil
}
@@ -3328,17 +3459,17 @@ func (sc *http2serverConn) readFrames() {
// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine.
type http2frameWriteResult struct {
- wm http2frameWriteMsg // what was written (or attempted)
- err error // result of the writeFrame call
+ wr http2FrameWriteRequest // what was written (or attempted)
+ err error // result of the writeFrame call
}
// writeFrameAsync runs in its own goroutine and writes a single frame
// and then reports when it's done.
// At most one goroutine can be running writeFrameAsync at a time per
// serverConn.
-func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) {
- err := wm.write.writeFrame(sc)
- sc.wroteFrameCh <- http2frameWriteResult{wm, err}
+func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) {
+ err := wr.write.writeFrame(sc)
+ sc.wroteFrameCh <- http2frameWriteResult{wr, err}
}
func (sc *http2serverConn) closeAllStreamsOnConnClose() {
@@ -3382,7 +3513,7 @@ func (sc *http2serverConn) serve() {
sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
}
- sc.writeFrame(http2frameWriteMsg{
+ sc.writeFrame(http2FrameWriteRequest{
write: http2writeSettings{
{http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
{http2SettingMaxConcurrentStreams, sc.advMaxStreams},
@@ -3399,6 +3530,17 @@ func (sc *http2serverConn) serve() {
sc.setConnState(StateActive)
sc.setConnState(StateIdle)
+ if sc.srv.IdleTimeout != 0 {
+ sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout)
+ defer sc.idleTimer.Stop()
+ sc.idleTimerCh = sc.idleTimer.C
+ }
+
+ var gracefulShutdownCh <-chan struct{}
+ if sc.hs != nil {
+ gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs)
+ }
+
go sc.readFrames()
settingsTimer := time.NewTimer(http2firstSettingsTimeout)
@@ -3406,8 +3548,10 @@ func (sc *http2serverConn) serve() {
for {
loopNum++
select {
- case wm := <-sc.wantWriteFrameCh:
- sc.writeFrame(wm)
+ case wr := <-sc.wantWriteFrameCh:
+ sc.writeFrame(wr)
+ case spr := <-sc.wantStartPushCh:
+ sc.startPush(spr)
case res := <-sc.wroteFrameCh:
sc.wroteFrame(res)
case res := <-sc.readFrameCh:
@@ -3424,12 +3568,22 @@ func (sc *http2serverConn) serve() {
case <-settingsTimer.C:
sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
return
+ case <-gracefulShutdownCh:
+ gracefulShutdownCh = nil
+ sc.startGracefulShutdown()
case <-sc.shutdownTimerCh:
sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
return
+ case <-sc.idleTimerCh:
+ sc.vlogf("connection is idle")
+ sc.goAway(http2ErrCodeNo)
case fn := <-sc.testHookCh:
fn(loopNum)
}
+
+ if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame {
+ return
+ }
}
}
@@ -3477,7 +3631,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte
ch := http2errChanPool.Get().(chan error)
writeArg := http2writeDataPool.Get().(*http2writeData)
*writeArg = http2writeData{stream.id, data, endStream}
- err := sc.writeFrameFromHandler(http2frameWriteMsg{
+ err := sc.writeFrameFromHandler(http2FrameWriteRequest{
write: writeArg,
stream: stream,
done: ch,
@@ -3507,17 +3661,17 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte
return err
}
-// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts
+// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts
// if the connection has gone away.
//
// This must not be run from the serve goroutine itself, else it might
// deadlock writing to sc.wantWriteFrameCh (which is only mildly
// buffered and is read by serve itself). If you're on the serve
// goroutine, call writeFrame instead.
-func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error {
+func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error {
sc.serveG.checkNotOn()
select {
- case sc.wantWriteFrameCh <- wm:
+ case sc.wantWriteFrameCh <- wr:
return nil
case <-sc.doneServing:
@@ -3533,53 +3687,81 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error {
// make it onto the wire
//
// If you're not on the serve goroutine, use writeFrameFromHandler instead.
-func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) {
+func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) {
sc.serveG.check()
+ // If true, wr will not be written and wr.done will not be signaled.
var ignoreWrite bool
- switch wm.write.(type) {
+ if wr.StreamID() != 0 {
+ _, isReset := wr.write.(http2StreamError)
+ if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset {
+ ignoreWrite = true
+ }
+ }
+
+ switch wr.write.(type) {
case *http2writeResHeaders:
- wm.stream.wroteHeaders = true
+ wr.stream.wroteHeaders = true
case http2write100ContinueHeadersFrame:
- if wm.stream.wroteHeaders {
+ if wr.stream.wroteHeaders {
+
+ if wr.done != nil {
+ panic("wr.done != nil for write100ContinueHeadersFrame")
+ }
ignoreWrite = true
}
}
if !ignoreWrite {
- sc.writeSched.add(wm)
+ sc.writeSched.Push(wr)
}
sc.scheduleFrameWrite()
}
-// startFrameWrite starts a goroutine to write wm (in a separate
+// startFrameWrite starts a goroutine to write wr (in a separate
// goroutine since that might block on the network), and updates the
-// serve goroutine's state about the world, updated from info in wm.
-func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) {
+// serve goroutine's state about the world, updated from info in wr.
+func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) {
sc.serveG.check()
if sc.writingFrame {
panic("internal error: can only be writing one frame at a time")
}
- st := wm.stream
+ st := wr.stream
if st != nil {
switch st.state {
case http2stateHalfClosedLocal:
- panic("internal error: attempt to send frame on half-closed-local stream")
- case http2stateClosed:
- if st.sentReset || st.gotReset {
+ switch wr.write.(type) {
+ case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate:
- sc.scheduleFrameWrite()
- return
+ default:
+ panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr))
}
- panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm))
+ case http2stateClosed:
+ panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr))
+ }
+ }
+ if wpp, ok := wr.write.(*http2writePushPromise); ok {
+ var err error
+ wpp.promisedID, err = wpp.allocatePromisedID()
+ if err != nil {
+ sc.writingFrameAsync = false
+ wr.replyToWriter(err)
+ return
}
}
sc.writingFrame = true
sc.needsFrameFlush = true
- go sc.writeFrameAsync(wm)
+ if wr.write.staysWithinBuffer(sc.bw.Available()) {
+ sc.writingFrameAsync = false
+ err := wr.write.writeFrame(sc)
+ sc.wroteFrame(http2frameWriteResult{wr, err})
+ } else {
+ sc.writingFrameAsync = true
+ go sc.writeFrameAsync(wr)
+ }
}
// errHandlerPanicked is the error given to any callers blocked in a read from
@@ -3595,26 +3777,12 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
panic("internal error: expected to be already writing a frame")
}
sc.writingFrame = false
+ sc.writingFrameAsync = false
- wm := res.wm
- st := wm.stream
-
- closeStream := http2endsStream(wm.write)
-
- if _, ok := wm.write.(http2handlerPanicRST); ok {
- sc.closeStream(st, http2errHandlerPanicked)
- }
+ wr := res.wr
- if ch := wm.done; ch != nil {
- select {
- case ch <- res.err:
- default:
- panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write))
- }
- }
- wm.write = nil
-
- if closeStream {
+ if http2writeEndsStream(wr.write) {
+ st := wr.stream
if st == nil {
panic("internal error: expecting non-nil stream")
}
@@ -3622,13 +3790,24 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
case http2stateOpen:
st.state = http2stateHalfClosedLocal
- errCancel := http2streamError(st.id, http2ErrCodeCancel)
- sc.resetStream(errCancel)
+ sc.resetStream(http2streamError(st.id, http2ErrCodeCancel))
case http2stateHalfClosedRemote:
sc.closeStream(st, http2errHandlerComplete)
}
+ } else {
+ switch v := wr.write.(type) {
+ case http2StreamError:
+
+ if st, ok := sc.streams[v.StreamID]; ok {
+ sc.closeStream(st, v)
+ }
+ case http2handlerPanicRST:
+ sc.closeStream(wr.stream, http2errHandlerPanicked)
+ }
}
+ wr.replyToWriter(res.err)
+
sc.scheduleFrameWrite()
}
@@ -3646,47 +3825,68 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
// flush the write buffer.
func (sc *http2serverConn) scheduleFrameWrite() {
sc.serveG.check()
- if sc.writingFrame {
- return
- }
- if sc.needToSendGoAway {
- sc.needToSendGoAway = false
- sc.startFrameWrite(http2frameWriteMsg{
- write: &http2writeGoAway{
- maxStreamID: sc.maxStreamID,
- code: sc.goAwayCode,
- },
- })
- return
- }
- if sc.needToSendSettingsAck {
- sc.needToSendSettingsAck = false
- sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}})
+ if sc.writingFrame || sc.inFrameScheduleLoop {
return
}
- if !sc.inGoAway {
- if wm, ok := sc.writeSched.take(); ok {
- sc.startFrameWrite(wm)
- return
+ sc.inFrameScheduleLoop = true
+ for !sc.writingFrameAsync {
+ if sc.needToSendGoAway {
+ sc.needToSendGoAway = false
+ sc.startFrameWrite(http2FrameWriteRequest{
+ write: &http2writeGoAway{
+ maxStreamID: sc.maxClientStreamID,
+ code: sc.goAwayCode,
+ },
+ })
+ continue
}
+ if sc.needToSendSettingsAck {
+ sc.needToSendSettingsAck = false
+ sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}})
+ continue
+ }
+ if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo {
+ if wr, ok := sc.writeSched.Pop(); ok {
+ sc.startFrameWrite(wr)
+ continue
+ }
+ }
+ if sc.needsFrameFlush {
+ sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}})
+ sc.needsFrameFlush = false
+ continue
+ }
+ break
}
- if sc.needsFrameFlush {
- sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}})
- sc.needsFrameFlush = false
- return
- }
+ sc.inFrameScheduleLoop = false
+}
+
+// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the
+// client we're gracefully shutting down. The connection isn't closed
+// until all current streams are done.
+func (sc *http2serverConn) startGracefulShutdown() {
+ sc.goAwayIn(http2ErrCodeNo, 0)
}
func (sc *http2serverConn) goAway(code http2ErrCode) {
sc.serveG.check()
- if sc.inGoAway {
- return
- }
+ var forceCloseIn time.Duration
if code != http2ErrCodeNo {
- sc.shutDownIn(250 * time.Millisecond)
+ forceCloseIn = 250 * time.Millisecond
} else {
- sc.shutDownIn(1 * time.Second)
+ forceCloseIn = 1 * time.Second
+ }
+ sc.goAwayIn(code, forceCloseIn)
+}
+
+func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) {
+ sc.serveG.check()
+ if sc.inGoAway {
+ return
+ }
+ if forceCloseIn != 0 {
+ sc.shutDownIn(forceCloseIn)
}
sc.inGoAway = true
sc.needToSendGoAway = true
@@ -3702,10 +3902,9 @@ func (sc *http2serverConn) shutDownIn(d time.Duration) {
func (sc *http2serverConn) resetStream(se http2StreamError) {
sc.serveG.check()
- sc.writeFrame(http2frameWriteMsg{write: se})
+ sc.writeFrame(http2FrameWriteRequest{write: se})
if st, ok := sc.streams[se.StreamID]; ok {
- st.sentReset = true
- sc.closeStream(st, se)
+ st.resetQueued = true
}
}
@@ -3782,6 +3981,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error {
return sc.processResetStream(f)
case *http2PriorityFrame:
return sc.processPriority(f)
+ case *http2GoAwayFrame:
+ return sc.processGoAway(f)
case *http2PushPromiseFrame:
return http2ConnectionError(http2ErrCodeProtocol)
@@ -3801,7 +4002,10 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error {
return http2ConnectionError(http2ErrCodeProtocol)
}
- sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}})
+ if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo {
+ return nil
+ }
+ sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}})
return nil
}
@@ -3809,7 +4013,11 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error
sc.serveG.check()
switch {
case f.StreamID != 0:
- st := sc.streams[f.StreamID]
+ state, st := sc.state(f.StreamID)
+ if state == http2stateIdle {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
if st == nil {
return nil
@@ -3835,7 +4043,6 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
return http2ConnectionError(http2ErrCodeProtocol)
}
if st != nil {
- st.gotReset = true
st.cancelCtx()
sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode))
}
@@ -3848,11 +4055,21 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) {
panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
}
st.state = http2stateClosed
- sc.curOpenStreams--
- if sc.curOpenStreams == 0 {
- sc.setConnState(StateIdle)
+ if st.isPushed() {
+ sc.curPushedStreams--
+ } else {
+ sc.curClientStreams--
}
delete(sc.streams, st.id)
+ if len(sc.streams) == 0 {
+ sc.setConnState(StateIdle)
+ if sc.srv.IdleTimeout != 0 {
+ sc.idleTimer.Reset(sc.srv.IdleTimeout)
+ }
+ if http2h1ServerKeepAlivesDisabled(sc.hs) {
+ sc.startGracefulShutdown()
+ }
+ }
if p := st.body; p != nil {
sc.sendWindowUpdate(nil, p.Len())
@@ -3860,11 +4077,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) {
p.CloseWithError(err)
}
st.cw.Close()
- sc.writeSched.forgetStream(st.id)
- if st.reqBuf != nil {
-
- sc.freeRequestBodyBuf = st.reqBuf
- }
+ sc.writeSched.CloseStream(st.id)
}
func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
@@ -3904,7 +4117,7 @@ func (sc *http2serverConn) processSetting(s http2Setting) error {
case http2SettingInitialWindowSize:
return sc.processSettingInitialWindowSize(s.Val)
case http2SettingMaxFrameSize:
- sc.writeSched.maxFrameSize = s.Val
+ sc.maxFrameSize = int32(s.Val)
case http2SettingMaxHeaderListSize:
sc.peerMaxHeaderListSize = s.Val
default:
@@ -3933,11 +4146,18 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
func (sc *http2serverConn) processData(f *http2DataFrame) error {
sc.serveG.check()
+ if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo {
+ return nil
+ }
data := f.Data()
id := f.Header().StreamID
- st, ok := sc.streams[id]
- if !ok || st.state != http2stateOpen || st.gotTrailerHeader {
+ state, st := sc.state(id)
+ if id == 0 || state == http2stateIdle {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued {
if sc.inflow.available() < int32(f.Length) {
return http2streamError(id, http2ErrCodeFlowControl)
@@ -3946,6 +4166,10 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
sc.inflow.take(int32(f.Length))
sc.sendWindowUpdate(nil, int(f.Length))
+ if st != nil && st.resetQueued {
+
+ return nil
+ }
return http2streamError(id, http2ErrCodeStreamClosed)
}
if st.body == nil {
@@ -3985,6 +4209,24 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
return nil
}
+func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error {
+ sc.serveG.check()
+ if f.ErrCode != http2ErrCodeNo {
+ sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f)
+ } else {
+ sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
+ }
+ sc.startGracefulShutdown()
+
+ sc.pushEnabled = false
+ return nil
+}
+
+// isPushed reports whether the stream is server-initiated.
+func (st *http2stream) isPushed() bool {
+ return st.id%2 == 0
+}
+
// endStream closes a Request.Body's pipe. It is called when a DATA
// frame says a request body is over (or after trailers).
func (st *http2stream) endStream() {
@@ -4014,7 +4256,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() {
func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
sc.serveG.check()
- id := f.Header().StreamID
+ id := f.StreamID
if sc.inGoAway {
return nil
@@ -4024,50 +4266,43 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
return http2ConnectionError(http2ErrCodeProtocol)
}
- st := sc.streams[f.Header().StreamID]
- if st != nil {
+ if st := sc.streams[f.StreamID]; st != nil {
+ if st.resetQueued {
+
+ return nil
+ }
return st.processTrailerHeaders(f)
}
- if id <= sc.maxStreamID {
+ if id <= sc.maxClientStreamID {
return http2ConnectionError(http2ErrCodeProtocol)
}
- sc.maxStreamID = id
+ sc.maxClientStreamID = id
- ctx, cancelCtx := http2contextWithCancel(sc.baseCtx)
- st = &http2stream{
- sc: sc,
- id: id,
- state: http2stateOpen,
- ctx: ctx,
- cancelCtx: cancelCtx,
- }
- if f.StreamEnded() {
- st.state = http2stateHalfClosedRemote
+ if sc.idleTimer != nil {
+ sc.idleTimer.Stop()
}
- st.cw.Init()
- st.flow.conn = &sc.flow
- st.flow.add(sc.initialWindowSize)
- st.inflow.conn = &sc.inflow
- st.inflow.add(http2initialWindowSize)
+ if sc.curClientStreams+1 > sc.advMaxStreams {
+ if sc.unackedSettings == 0 {
- sc.streams[id] = st
- if f.HasPriority() {
- http2adjustStreamPriority(sc.streams, st.id, f.Priority)
- }
- sc.curOpenStreams++
- if sc.curOpenStreams == 1 {
- sc.setConnState(StateActive)
+ return http2streamError(id, http2ErrCodeProtocol)
+ }
+
+ return http2streamError(id, http2ErrCodeRefusedStream)
}
- if sc.curOpenStreams > sc.advMaxStreams {
- if sc.unackedSettings == 0 {
+ initialState := http2stateOpen
+ if f.StreamEnded() {
+ initialState = http2stateHalfClosedRemote
+ }
+ st := sc.newStream(id, 0, initialState)
- return http2streamError(st.id, http2ErrCodeProtocol)
+ if f.HasPriority() {
+ if err := http2checkPriority(f.StreamID, f.Priority); err != nil {
+ return err
}
-
- return http2streamError(st.id, http2ErrCodeRefusedStream)
+ sc.writeSched.AdjustStream(st.id, f.Priority)
}
rw, req, err := sc.newWriterAndRequest(st, f)
@@ -4085,10 +4320,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
if f.Truncated {
handler = http2handleHeaderListTooLong
- } else if err := http2checkValidHTTP2Request(req); err != nil {
+ } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil {
handler = http2new400Handler(err)
}
+ if sc.hs.ReadTimeout != 0 {
+ sc.conn.SetReadDeadline(time.Time{})
+ }
+
go sc.runHandler(rw, req, handler)
return nil
}
@@ -4121,90 +4360,138 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
return nil
}
-func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
- http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam)
+func http2checkPriority(streamID uint32, p http2PriorityParam) error {
+ if streamID == p.StreamDep {
+
+ return http2streamError(streamID, http2ErrCodeProtocol)
+ }
return nil
}
-func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) {
- st, ok := streams[streamID]
- if !ok {
+func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
+ if sc.inGoAway {
+ return nil
+ }
+ if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil {
+ return err
+ }
+ sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam)
+ return nil
+}
- return
+func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream {
+ sc.serveG.check()
+ if id == 0 {
+ panic("internal error: cannot create stream with id 0")
}
- st.weight = priority.Weight
- parent := streams[priority.StreamDep]
- if parent == st {
- return
+ ctx, cancelCtx := http2contextWithCancel(sc.baseCtx)
+ st := &http2stream{
+ sc: sc,
+ id: id,
+ state: state,
+ ctx: ctx,
+ cancelCtx: cancelCtx,
}
+ st.cw.Init()
+ st.flow.conn = &sc.flow
+ st.flow.add(sc.initialWindowSize)
+ st.inflow.conn = &sc.inflow
+ st.inflow.add(http2initialWindowSize)
- for piter := parent; piter != nil; piter = piter.parent {
- if piter == st {
- parent.parent = st.parent
- break
- }
+ sc.streams[id] = st
+ sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID})
+ if st.isPushed() {
+ sc.curPushedStreams++
+ } else {
+ sc.curClientStreams++
}
- st.parent = parent
- if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) {
- for _, openStream := range streams {
- if openStream != st && openStream.parent == st.parent {
- openStream.parent = st
- }
- }
+ if sc.curOpenStreams() == 1 {
+ sc.setConnState(StateActive)
}
+
+ return st
}
func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) {
sc.serveG.check()
- method := f.PseudoValue("method")
- path := f.PseudoValue("path")
- scheme := f.PseudoValue("scheme")
- authority := f.PseudoValue("authority")
+ rp := http2requestParam{
+ method: f.PseudoValue("method"),
+ scheme: f.PseudoValue("scheme"),
+ authority: f.PseudoValue("authority"),
+ path: f.PseudoValue("path"),
+ }
- isConnect := method == "CONNECT"
+ isConnect := rp.method == "CONNECT"
if isConnect {
- if path != "" || scheme != "" || authority == "" {
+ if rp.path != "" || rp.scheme != "" || rp.authority == "" {
return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
}
- } else if method == "" || path == "" ||
- (scheme != "https" && scheme != "http") {
+ } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
}
bodyOpen := !f.StreamEnded()
- if method == "HEAD" && bodyOpen {
+ if rp.method == "HEAD" && bodyOpen {
return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
}
- var tlsState *tls.ConnectionState // nil if not scheme https
- if scheme == "https" {
- tlsState = sc.tlsState
+ rp.header = make(Header)
+ for _, hf := range f.RegularFields() {
+ rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ }
+ if rp.authority == "" {
+ rp.authority = rp.header.Get("Host")
}
- header := make(Header)
- for _, hf := range f.RegularFields() {
- header.Add(sc.canonicalHeader(hf.Name), hf.Value)
+ rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
+ if err != nil {
+ return nil, nil, err
}
+ if bodyOpen {
+ st.reqBuf = http2getRequestBodyBuf()
+ req.Body.(*http2requestBody).pipe = &http2pipe{
+ b: &http2fixedBuffer{buf: st.reqBuf},
+ }
- if authority == "" {
- authority = header.Get("Host")
+ if vv, ok := rp.header["Content-Length"]; ok {
+ req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
+ } else {
+ req.ContentLength = -1
+ }
}
- needsContinue := header.Get("Expect") == "100-continue"
+ return rw, req, nil
+}
+
+type http2requestParam struct {
+ method string
+ scheme, authority, path string
+ header Header
+}
+
+func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) {
+ sc.serveG.check()
+
+ var tlsState *tls.ConnectionState // nil if not scheme https
+ if rp.scheme == "https" {
+ tlsState = sc.tlsState
+ }
+
+ needsContinue := rp.header.Get("Expect") == "100-continue"
if needsContinue {
- header.Del("Expect")
+ rp.header.Del("Expect")
}
- if cookies := header["Cookie"]; len(cookies) > 1 {
- header.Set("Cookie", strings.Join(cookies, "; "))
+ if cookies := rp.header["Cookie"]; len(cookies) > 1 {
+ rp.header.Set("Cookie", strings.Join(cookies, "; "))
}
// Setup Trailers
var trailer Header
- for _, v := range header["Trailer"] {
+ for _, v := range rp.header["Trailer"] {
for _, key := range strings.Split(v, ",") {
key = CanonicalHeaderKey(strings.TrimSpace(key))
switch key {
@@ -4218,55 +4505,42 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
}
}
}
- delete(header, "Trailer")
+ delete(rp.header, "Trailer")
- body := &http2requestBody{
- conn: sc,
- stream: st,
- needsContinue: needsContinue,
- }
var url_ *url.URL
var requestURI string
- if isConnect {
- url_ = &url.URL{Host: authority}
- requestURI = authority
+ if rp.method == "CONNECT" {
+ url_ = &url.URL{Host: rp.authority}
+ requestURI = rp.authority
} else {
var err error
- url_, err = url.ParseRequestURI(path)
+ url_, err = url.ParseRequestURI(rp.path)
if err != nil {
- return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol)
+ return nil, nil, http2streamError(st.id, http2ErrCodeProtocol)
}
- requestURI = path
+ requestURI = rp.path
+ }
+
+ body := &http2requestBody{
+ conn: sc,
+ stream: st,
+ needsContinue: needsContinue,
}
req := &Request{
- Method: method,
+ Method: rp.method,
URL: url_,
RemoteAddr: sc.remoteAddrStr,
- Header: header,
+ Header: rp.header,
RequestURI: requestURI,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
TLS: tlsState,
- Host: authority,
+ Host: rp.authority,
Body: body,
Trailer: trailer,
}
req = http2requestWithContext(req, st.ctx)
- if bodyOpen {
-
- buf := make([]byte, http2initialWindowSize)
-
- body.pipe = &http2pipe{
- b: &http2fixedBuffer{buf: buf},
- }
-
- if vv, ok := header["Content-Length"]; ok {
- req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
- } else {
- req.ContentLength = -1
- }
- }
rws := http2responseWriterStatePool.Get().(*http2responseWriterState)
bwSave := rws.bw
@@ -4282,13 +4556,22 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
return rw, req, nil
}
-func (sc *http2serverConn) getRequestBodyBuf() []byte {
- sc.serveG.check()
- if buf := sc.freeRequestBodyBuf; buf != nil {
- sc.freeRequestBodyBuf = nil
- return buf
+var http2reqBodyCache = make(chan []byte, 8)
+
+func http2getRequestBodyBuf() []byte {
+ select {
+ case b := <-http2reqBodyCache:
+ return b
+ default:
+ return make([]byte, http2initialWindowSize)
+ }
+}
+
+func http2putRequestBodyBuf(b []byte) {
+ select {
+ case http2reqBodyCache <- b:
+ default:
}
- return make([]byte, http2initialWindowSize)
}
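
Replacing the per-connection freeRequestBodyBuf with a small buffered channel gives a process-wide, non-blocking cache of request-body buffers: take one if available, otherwise allocate; return one if there is room, otherwise let it be collected. The idiom in isolation (getBuf/putBuf and bufSize are illustrative names):

package main

import "fmt"

const bufSize = 64 << 10 // stand-in for initialWindowSize

// cache holds up to 8 free buffers; both operations are non-blocking.
var cache = make(chan []byte, 8)

func getBuf() []byte {
	select {
	case b := <-cache:
		return b
	default:
		return make([]byte, bufSize)
	}
}

func putBuf(b []byte) {
	select {
	case cache <- b:
	default: // cache full: let the GC take it
	}
}

func main() {
	b := getBuf()
	putBuf(b)
	fmt.Println(len(getBuf()), cap(cache)) // 65536 8
}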
// Run on its own goroutine.
@@ -4298,15 +4581,17 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han
rw.rws.stream.cancelCtx()
if didPanic {
e := recover()
- // Same as net/http:
- const size = 64 << 10
- buf := make([]byte, size)
- buf = buf[:runtime.Stack(buf, false)]
- sc.writeFrameFromHandler(http2frameWriteMsg{
+ sc.writeFrameFromHandler(http2FrameWriteRequest{
write: http2handlerPanicRST{rw.rws.stream.id},
stream: rw.rws.stream,
})
- sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
+
+ if http2shouldLogPanic(e) {
+ const size = 64 << 10
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
+ }
return
}
rw.handlerDone()
@@ -4334,7 +4619,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR
errc = http2errChanPool.Get().(chan error)
}
- if err := sc.writeFrameFromHandler(http2frameWriteMsg{
+ if err := sc.writeFrameFromHandler(http2FrameWriteRequest{
write: headerData,
stream: st,
done: errc,
@@ -4357,7 +4642,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR
// called from handler goroutines.
func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) {
- sc.writeFrameFromHandler(http2frameWriteMsg{
+ sc.writeFrameFromHandler(http2FrameWriteRequest{
write: http2write100ContinueHeadersFrame{st.id},
stream: st,
})
@@ -4373,11 +4658,19 @@ type http2bodyReadMsg struct {
// called from handler goroutines.
// Notes that the handler for the given stream ID read n bytes of its body
// and schedules flow control tokens to be sent.
-func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) {
+func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) {
sc.serveG.checkNotOn()
- select {
- case sc.bodyReadCh <- http2bodyReadMsg{st, n}:
- case <-sc.doneServing:
+ if n > 0 {
+ select {
+ case sc.bodyReadCh <- http2bodyReadMsg{st, n}:
+ case <-sc.doneServing:
+ }
+ }
+ if err == io.EOF {
+ if buf := st.reqBuf; buf != nil {
+ st.reqBuf = nil
+ http2putRequestBodyBuf(buf)
+ }
}
}
@@ -4419,7 +4712,7 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) {
if st != nil {
streamID = st.id
}
- sc.writeFrame(http2frameWriteMsg{
+ sc.writeFrame(http2FrameWriteRequest{
write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)},
stream: st,
})
@@ -4434,16 +4727,19 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) {
}
}
+// requestBody is the Handler's Request.Body type.
+// Read and Close may be called concurrently.
type http2requestBody struct {
stream *http2stream
conn *http2serverConn
- closed bool
+ closed bool // for use by Close only
+ sawEOF bool // for use by Read only
pipe *http2pipe // non-nil if we have a HTTP entity message body
needsContinue bool // need to send a 100-continue
}
func (b *http2requestBody) Close() error {
- if b.pipe != nil {
+ if b.pipe != nil && !b.closed {
b.pipe.BreakWithError(http2errClosedBody)
}
b.closed = true
@@ -4455,13 +4751,17 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) {
b.needsContinue = false
b.conn.write100ContinueHeaders(b.stream)
}
- if b.pipe == nil {
+ if b.pipe == nil || b.sawEOF {
return 0, io.EOF
}
n, err = b.pipe.Read(p)
- if n > 0 {
- b.conn.noteBodyReadFromHandler(b.stream, n)
+ if err == io.EOF {
+ b.sawEOF = true
}
+ if b.conn == nil && http2inTests {
+ return
+ }
+ b.conn.noteBodyReadFromHandler(b.stream, n, err)
return
}
@@ -4696,8 +4996,9 @@ func (w *http2responseWriter) CloseNotify() <-chan bool {
if ch == nil {
ch = make(chan bool, 1)
rws.closeNotifierCh = ch
+ cw := rws.stream.cw
go func() {
- rws.stream.cw.Wait()
+ cw.Wait()
ch <- true
}()
}
@@ -4793,6 +5094,172 @@ func (w *http2responseWriter) handlerDone() {
http2responseWriterStatePool.Put(rws)
}
+// Push errors.
+var (
+ http2ErrRecursivePush = errors.New("http2: recursive push not allowed")
+ http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
+)
+
+// pushOptions is the internal version of http.PushOptions, which we
+// cannot include here because it's only defined in Go 1.8 and later.
+type http2pushOptions struct {
+ Method string
+ Header Header
+}
+
+func (w *http2responseWriter) push(target string, opts http2pushOptions) error {
+ st := w.rws.stream
+ sc := st.sc
+ sc.serveG.checkNotOn()
+
+ if st.isPushed() {
+ return http2ErrRecursivePush
+ }
+
+ if opts.Method == "" {
+ opts.Method = "GET"
+ }
+ if opts.Header == nil {
+ opts.Header = Header{}
+ }
+ wantScheme := "http"
+ if w.rws.req.TLS != nil {
+ wantScheme = "https"
+ }
+
+ u, err := url.Parse(target)
+ if err != nil {
+ return err
+ }
+ if u.Scheme == "" {
+ if !strings.HasPrefix(target, "/") {
+ return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
+ }
+ u.Scheme = wantScheme
+ u.Host = w.rws.req.Host
+ } else {
+ if u.Scheme != wantScheme {
+ return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
+ }
+ if u.Host == "" {
+ return errors.New("URL must have a host")
+ }
+ }
+ for k := range opts.Header {
+ if strings.HasPrefix(k, ":") {
+ return fmt.Errorf("promised request headers cannot include pseudo header %q", k)
+ }
+
+ switch strings.ToLower(k) {
+ case "content-length", "content-encoding", "trailer", "te", "expect", "host":
+ return fmt.Errorf("promised request headers cannot include %q", k)
+ }
+ }
+ if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil {
+ return err
+ }
+
+ if opts.Method != "GET" && opts.Method != "HEAD" {
+ return fmt.Errorf("method %q must be GET or HEAD", opts.Method)
+ }
+
+ msg := http2startPushRequest{
+ parent: st,
+ method: opts.Method,
+ url: u,
+ header: http2cloneHeader(opts.Header),
+ done: http2errChanPool.Get().(chan error),
+ }
+
+ select {
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-st.cw:
+ return http2errStreamClosed
+ case sc.wantStartPushCh <- msg:
+ }
+
+ select {
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-st.cw:
+ return http2errStreamClosed
+ case err := <-msg.done:
+ http2errChanPool.Put(msg.done)
+ return err
+ }
+}
+
+type http2startPushRequest struct {
+ parent *http2stream
+ method string
+ url *url.URL
+ header Header
+ done chan error
+}
+
+func (sc *http2serverConn) startPush(msg http2startPushRequest) {
+ sc.serveG.check()
+
+ if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote {
+
+ msg.done <- http2errStreamClosed
+ return
+ }
+
+ if !sc.pushEnabled {
+ msg.done <- ErrNotSupported
+ return
+ }
+
+ allocatePromisedID := func() (uint32, error) {
+ sc.serveG.check()
+
+ if !sc.pushEnabled {
+ return 0, ErrNotSupported
+ }
+
+ if sc.curPushedStreams+1 > sc.clientMaxStreams {
+ return 0, http2ErrPushLimitReached
+ }
+
+ if sc.maxPushPromiseID+2 >= 1<<31 {
+ sc.startGracefulShutdown()
+ return 0, http2ErrPushLimitReached
+ }
+ sc.maxPushPromiseID += 2
+ promisedID := sc.maxPushPromiseID
+
+ promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote)
+ rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{
+ method: msg.method,
+ scheme: msg.url.Scheme,
+ authority: msg.url.Host,
+ path: msg.url.RequestURI(),
+ header: http2cloneHeader(msg.header),
+ })
+ if err != nil {
+
+ panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
+ }
+
+ go sc.runHandler(rw, req, sc.handler.ServeHTTP)
+ return promisedID, nil
+ }
+
+ sc.writeFrame(http2FrameWriteRequest{
+ write: &http2writePushPromise{
+ streamID: msg.parent.id,
+ method: msg.method,
+ url: msg.url,
+ h: msg.header,
+ allocatePromisedID: allocatePromisedID,
+ },
+ stream: msg.parent,
+ done: msg.done,
+ })
+}
+
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func http2foreachHeaderElement(v string, fn func(string)) {
@@ -4820,16 +5287,16 @@ var http2connHeaders = []string{
"Upgrade",
}
-// checkValidHTTP2Request checks whether req is a valid HTTP/2 request,
+// checkValidHTTP2RequestHeaders checks whether h contains valid HTTP/2 request headers,
// per RFC 7540 Section 8.1.2.2.
// The returned error is reported to users.
-func http2checkValidHTTP2Request(req *Request) error {
- for _, h := range http2connHeaders {
- if _, ok := req.Header[h]; ok {
- return fmt.Errorf("request header %q is not valid in HTTP/2", h)
+func http2checkValidHTTP2RequestHeaders(h Header) error {
+ for _, k := range http2connHeaders {
+ if _, ok := h[k]; ok {
+ return fmt.Errorf("request header %q is not valid in HTTP/2", k)
}
}
- te := req.Header["Te"]
+ te := h["Te"]
if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
}
@@ -4877,6 +5344,45 @@ var http2badTrailer = map[string]bool{
"Www-Authenticate": true,
}
+// h1ServerShutdownChan returns a channel that will be closed when the
+// provided *http.Server wants to shut down.
+//
+// This is a somewhat hacky way to get at http1 innards. It works
+// when the http2 code is bundled into the net/http package in the
+// standard library. The alternatives ended up making the cmd/go tool
+// depend on http Servers. This is the lightest option for now.
+// This is tested via the TestServeShutdown* tests in net/http.
+func http2h1ServerShutdownChan(hs *Server) <-chan struct{} {
+ if fn := http2testh1ServerShutdownChan; fn != nil {
+ return fn(hs)
+ }
+ var x interface{} = hs
+ type I interface {
+ getDoneChan() <-chan struct{}
+ }
+ if hs, ok := x.(I); ok {
+ return hs.getDoneChan()
+ }
+ return nil
+}
+
+// optional test hook for h1ServerShutdownChan.
+var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{}
+
+// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
+// disabled. See comments on h1ServerShutdownChan above for why
+// the code is written this way.
+func http2h1ServerKeepAlivesDisabled(hs *Server) bool {
+ var x interface{} = hs
+ type I interface {
+ doKeepAlives() bool
+ }
+ if hs, ok := x.(I); ok {
+ return !hs.doKeepAlives()
+ }
+ return false
+}
+
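
Both helpers reach unexported *http.Server behaviour (getDoneChan, doKeepAlives) by asserting the server to a locally declared interface, which works only because the bundle is compiled inside the net/http package itself. The same pattern on an unrelated toy type, purely for illustration:

package main

import "fmt"

type server struct{ keepAlives bool }

// doKeepAlives is unexported; other packages cannot name it, but code
// compiled in the same package can still reach it through an interface.
func (s *server) doKeepAlives() bool { return s.keepAlives }

func keepAlivesDisabled(v interface{}) bool {
	// The assertion succeeds because the method and this interface are
	// declared in the same package -- the same reason the bundle can see
	// http.Server's unexported methods.
	type keepAliver interface {
		doKeepAlives() bool
	}
	if ka, ok := v.(keepAliver); ok {
		return !ka.doKeepAlives()
	}
	return false // unknown type: assume keep-alives are on
}

func main() {
	fmt.Println(keepAlivesDisabled(&server{keepAlives: false})) // true
	fmt.Println(keepAlivesDisabled(42))                         // false
}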
const (
// transportDefaultConnFlow is how many connection-level flow control
// tokens we give the server at start-up, past the default 64k.
@@ -4997,6 +5503,9 @@ type http2ClientConn struct {
readerDone chan struct{} // closed on error
readerErr error // set before readerDone is closed
+ idleTimeout time.Duration // or 0 for never
+ idleTimer *time.Timer
+
mu sync.Mutex // guards following
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow http2flow // our conn-level flow control quota (cs.flow is per stream)
@@ -5007,6 +5516,7 @@ type http2ClientConn struct {
goAwayDebug string // goAway frame's debug data, retained as a string
streams map[uint32]*http2clientStream // client-initiated
nextStreamID uint32
+ pings map[[8]byte]chan struct{} // in flight ping data to notification channel
bw *bufio.Writer
br *bufio.Reader
fr *http2Framer
@@ -5033,6 +5543,7 @@ type http2clientStream struct {
ID uint32
resc chan http2resAndError
bufPipe http2pipe // buffered pipe with the flow-controlled response payload
+ startedWrite bool // started request body write; guarded by cc.mu
requestedGzip bool
on100 func() // optional code to run if get a 100 continue response
@@ -5041,6 +5552,7 @@ type http2clientStream struct {
bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
readErr error // sticky read error; owned by transportResponseBody.Read
stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
+ didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu
peerReset chan struct{} // closed on peer reset
resetErr error // populated before peerReset is closed
@@ -5068,15 +5580,26 @@ func (cs *http2clientStream) awaitRequestCancel(req *Request) {
}
select {
case <-req.Cancel:
+ cs.cancelStream()
cs.bufPipe.CloseWithError(http2errRequestCanceled)
- cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
case <-ctx.Done():
+ cs.cancelStream()
cs.bufPipe.CloseWithError(ctx.Err())
- cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
case <-cs.done:
}
}
+func (cs *http2clientStream) cancelStream() {
+ cs.cc.mu.Lock()
+ didReset := cs.didReset
+ cs.didReset = true
+ cs.cc.mu.Unlock()
+
+ if !didReset {
+ cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ }
+}
+
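
cancelStream guarantees that at most one RST_STREAM(CANCEL) is written whether the cancellation came from Request.Cancel or from the request's context. From a caller's perspective this is ordinary Go 1.7+ cancellation; a sketch (the URL and timeout are placeholders):

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
	defer cancel()

	req, err := http.NewRequest("GET", "https://example.com/slow", nil)
	if err != nil {
		panic(err)
	}
	// When the deadline fires, the HTTP/2 transport resets the stream
	// (at most once) and the pending RoundTrip returns an error.
	resp, err := http.DefaultClient.Do(req.WithContext(ctx))
	if err != nil {
		fmt.Println("request aborted:", err)
		return
	}
	resp.Body.Close()
}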
// checkResetOrDone reports any error sent in a RST_STREAM frame by the
// server, or errStreamClosed if the stream is complete.
func (cs *http2clientStream) checkResetOrDone() error {
@@ -5133,14 +5656,22 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) {
// authorityAddr takes a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func http2authorityAddr(scheme string, authority string) (addr string) {
- if _, _, err := net.SplitHostPort(authority); err == nil {
- return authority
+ host, port, err := net.SplitHostPort(authority)
+ if err != nil {
+ port = "443"
+ if scheme == "http" {
+ port = "80"
+ }
+ host = authority
+ }
+ if a, err := idna.ToASCII(host); err == nil {
+ host = a
}
- port := "443"
- if scheme == "http" {
- port = "80"
+
+ if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
+ return host + ":" + port
}
- return net.JoinHostPort(authority, port)
+ return net.JoinHostPort(host, port)
}
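
authorityAddr now splits host and port so it can run the host through an IDNA ASCII conversion and re-bracket IPv6 literals before joining. A hedged standalone illustration using net and golang.org/x/net/idna directly (addrFor is an illustrative name, not the bundled function):

package main

import (
	"fmt"
	"net"
	"strings"

	"golang.org/x/net/idna"
)

// addrFor mimics the shape of the bundled authorityAddr for scheme "https".
func addrFor(authority string) string {
	host, port, err := net.SplitHostPort(authority)
	if err != nil { // no port in the authority
		host, port = authority, "443"
	}
	if a, err := idna.ToASCII(host); err == nil {
		host = a // non-ASCII hosts become their Punycode (xn--...) form
	}
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		return host + ":" + port // already-bracketed IPv6 literal, no port given
	}
	return net.JoinHostPort(host, port)
}

func main() {
	fmt.Println(addrFor("example.com")) // example.com:443
	fmt.Println(addrFor("例え.jp:8443")) // Punycode host, port preserved
	fmt.Println(addrFor("[::1]"))       // [::1]:443
}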
// RoundTripOpt is like RoundTrip, but takes options.
@@ -5158,8 +5689,10 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res
}
http2traceGotConn(req, cc)
res, err := cc.RoundTrip(req)
- if http2shouldRetryRequest(req, err) {
- continue
+ if err != nil {
+ if req, err = http2shouldRetryRequest(req, err); err == nil {
+ continue
+ }
}
if err != nil {
t.vlogf("RoundTrip failure: %v", err)
@@ -5181,11 +5714,39 @@ func (t *http2Transport) CloseIdleConnections() {
var (
http2errClientConnClosed = errors.New("http2: client conn is closed")
http2errClientConnUnusable = errors.New("http2: client conn not usable")
+
+ http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
+ http2errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written")
)
-func http2shouldRetryRequest(req *Request, err error) bool {
+// shouldRetryRequest is called by RoundTrip when a request fails to get
+// response headers. It is always called with a non-nil error.
+// It returns either a request to retry (either the same request, or a
+// modified clone), or an error if the request can't be replayed.
+func http2shouldRetryRequest(req *Request, err error) (*Request, error) {
+ switch err {
+ default:
+ return nil, err
+ case http2errClientConnUnusable, http2errClientConnGotGoAway:
+ return req, nil
+ case http2errClientConnGotGoAwayAfterSomeReqBody:
+
+ if req.Body == nil || http2reqBodyIsNoBody(req.Body) {
+ return req, nil
+ }
- return err == http2errClientConnUnusable
+ getBody := http2reqGetBody(req)
+ if getBody == nil {
+ return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error")
+ }
+ body, err := getBody()
+ if err != nil {
+ return nil, err
+ }
+ newReq := *req
+ newReq.Body = body
+ return &newReq, nil
+ }
}
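
A request whose body was partly written can only be retried after a graceful GOAWAY if a fresh copy of the body is available, hence the reliance on Request.GetBody. In Go 1.8, http.NewRequest fills in GetBody for the common in-memory body types; a sketch (URLs are placeholders):

package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"strings"
)

func main() {
	// NewRequest sets req.GetBody for *strings.Reader, *bytes.Reader and
	// *bytes.Buffer bodies, so the HTTP/2 transport can replay the body
	// after a graceful-shutdown GOAWAY.
	req, err := http.NewRequest("POST", "https://example.com/upload", strings.NewReader("payload"))
	if err != nil {
		panic(err)
	}
	fmt.Println("replayable:", req.GetBody != nil) // true

	// For a custom body type, GetBody has to be provided by hand.
	req2, _ := http.NewRequest("POST", "https://example.com/upload", ioutil.NopCloser(strings.NewReader("payload")))
	req2.GetBody = func() (io.ReadCloser, error) {
		return ioutil.NopCloser(strings.NewReader("payload")), nil
	}
}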
func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) {
@@ -5203,7 +5764,7 @@ func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2Clie
func (t *http2Transport) newTLSConfig(host string) *tls.Config {
cfg := new(tls.Config)
if t.TLSClientConfig != nil {
- *cfg = *t.TLSClientConfig
+ *cfg = *http2cloneTLSConfig(t.TLSClientConfig)
}
if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) {
cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...)
@@ -5273,6 +5834,11 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client
streams: make(map[uint32]*http2clientStream),
singleUse: singleUse,
wantSettingsAck: true,
+ pings: make(map[[8]byte]chan struct{}),
+ }
+ if d := t.idleConnTimeout(); d != 0 {
+ cc.idleTimeout = d
+ cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout)
}
if http2VerboseLogs {
t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
@@ -5328,6 +5894,15 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
if old != nil && old.ErrCode != http2ErrCodeNo {
cc.goAway.ErrCode = old.ErrCode
}
+ last := f.LastStreamID
+ for streamID, cs := range cc.streams {
+ if streamID > last {
+ select {
+ case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}:
+ default:
+ }
+ }
+ }
}
func (cc *http2ClientConn) CanTakeNewRequest() bool {
@@ -5345,6 +5920,16 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
cc.nextStreamID < math.MaxInt32
}
+// onIdleTimeout is called from a time.AfterFunc goroutine. It will
+// only be called when we're idle, but because we're coming from a new
+// goroutine, there could be a new request coming in at the same time,
+// so this simply calls the synchronized closeIfIdle to shut down this
+// connection. The timer could just call closeIfIdle, but this is more
+// clear.
+func (cc *http2ClientConn) onIdleTimeout() {
+ cc.closeIfIdle()
+}
+
func (cc *http2ClientConn) closeIfIdle() {
cc.mu.Lock()
if len(cc.streams) > 0 {
@@ -5437,48 +6022,37 @@ func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
// Certain headers are special-cased as okay but not transmitted later.
func http2checkConnHeaders(req *Request) error {
if v := req.Header.Get("Upgrade"); v != "" {
- return errors.New("http2: invalid Upgrade request header")
+ return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"])
}
- if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 {
- return errors.New("http2: invalid Transfer-Encoding request header")
+ if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
+ return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
}
- if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 {
- return errors.New("http2: invalid Connection request header")
+ if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") {
+ return fmt.Errorf("http2: invalid Connection request header: %q", vv)
}
return nil
}
-func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) {
- body = req.Body
- if body == nil {
- return nil, 0
+// actualContentLength returns a sanitized version of
+// req.ContentLength, where 0 actually means zero (not unknown) and -1
+// means unknown.
+func http2actualContentLength(req *Request) int64 {
+ if req.Body == nil {
+ return 0
}
if req.ContentLength != 0 {
- return req.Body, req.ContentLength
- }
-
- // We have a body but a zero content length. Test to see if
- // it's actually zero or just unset.
- var buf [1]byte
- n, rerr := body.Read(buf[:])
- if rerr != nil && rerr != io.EOF {
- return http2errorReader{rerr}, -1
- }
- if n == 1 {
-
- if rerr == io.EOF {
- return bytes.NewReader(buf[:]), 1
- }
- return io.MultiReader(bytes.NewReader(buf[:]), body), -1
+ return req.ContentLength
}
-
- return nil, 0
+ return -1
}
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
if err := http2checkConnHeaders(req); err != nil {
return nil, err
}
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
trailers, err := http2commaSeparatedTrailers(req)
if err != nil {
@@ -5486,9 +6060,6 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
}
hasTrailers := trailers != ""
- body, contentLen := http2bodyAndLength(req)
- hasBody := body != nil
-
cc.mu.Lock()
cc.lastActive = time.Now()
if cc.closed || !cc.canTakeNewRequestLocked() {
@@ -5496,6 +6067,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
return nil, http2errClientConnUnusable
}
+ body := req.Body
+ hasBody := body != nil
+ contentLen := http2actualContentLength(req)
+
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
var requestedGzip bool
if !cc.t.disableCompression() &&
@@ -5561,6 +6136,13 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
}
if re.err != nil {
+ if re.err == http2errClientConnGotGoAway {
+ cc.mu.Lock()
+ if cs.startedWrite {
+ re.err = http2errClientConnGotGoAwayAfterSomeReqBody
+ }
+ cc.mu.Unlock()
+ }
cc.forgetStreamID(cs.ID)
return nil, re.err
}
@@ -5806,6 +6388,26 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
if host == "" {
host = req.URL.Host
}
+ host, err := httplex.PunycodeHostPort(host)
+ if err != nil {
+ return nil, err
+ }
+
+ var path string
+ if req.Method != "CONNECT" {
+ path = req.URL.RequestURI()
+ if !http2validPseudoPath(path) {
+ orig := path
+ path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
+ if !http2validPseudoPath(path) {
+ if req.URL.Opaque != "" {
+ return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
+ } else {
+ return nil, fmt.Errorf("invalid request :path %q", orig)
+ }
+ }
+ }
+ }
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
@@ -5821,8 +6423,8 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
cc.writeHeader(":authority", host)
cc.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
- cc.writeHeader(":path", req.URL.RequestURI())
- cc.writeHeader(":scheme", "https")
+ cc.writeHeader(":path", path)
+ cc.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
cc.writeHeader("trailer", trailers)
@@ -5940,6 +6542,9 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr
if andRemove && cs != nil && !cc.closed {
cc.lastActive = time.Now()
delete(cc.streams, id)
+ if len(cc.streams) == 0 && cc.idleTimer != nil {
+ cc.idleTimer.Reset(cc.idleTimeout)
+ }
close(cs.done)
cc.cond.Broadcast()
}
@@ -5996,6 +6601,10 @@ func (rl *http2clientConnReadLoop) cleanup() {
defer cc.t.connPool().MarkDead(cc)
defer close(cc.readerDone)
+ if cc.idleTimer != nil {
+ cc.idleTimer.Stop()
+ }
+
err := cc.readerErr
cc.mu.Lock()
if cc.goAway != nil && http2isEOFOrNetReadError(err) {
@@ -6398,9 +7007,10 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
cc.bw.Flush()
cc.wmu.Unlock()
}
+ didReset := cs.didReset
cc.mu.Unlock()
- if len(data) > 0 {
+ if len(data) > 0 && !didReset {
if _, err := cs.bufPipe.Write(data); err != nil {
rl.endStreamError(cs, err)
return err
@@ -6551,9 +7161,56 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er
return nil
}
+// Ping sends a PING frame to the server and waits for the ack.
+// The public implementation is in go17.go and not_go17.go.
+func (cc *http2ClientConn) ping(ctx http2contextContext) error {
+ c := make(chan struct{})
+ // Generate a random payload
+ var p [8]byte
+ for {
+ if _, err := rand.Read(p[:]); err != nil {
+ return err
+ }
+ cc.mu.Lock()
+
+ if _, found := cc.pings[p]; !found {
+ cc.pings[p] = c
+ cc.mu.Unlock()
+ break
+ }
+ cc.mu.Unlock()
+ }
+ cc.wmu.Lock()
+ if err := cc.fr.WritePing(false, p); err != nil {
+ cc.wmu.Unlock()
+ return err
+ }
+ if err := cc.bw.Flush(); err != nil {
+ cc.wmu.Unlock()
+ return err
+ }
+ cc.wmu.Unlock()
+ select {
+ case <-c:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-cc.readerDone:
+
+ return cc.readerErr
+ }
+}
+
func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error {
if f.IsAck() {
+ cc := rl.cc
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ if c, ok := cc.pings[f.Data]; ok {
+ close(c)
+ delete(cc.pings, f.Data)
+ }
return nil
}
cc := rl.cc
@@ -6666,6 +7323,9 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade
resc := make(chan error, 1)
s.resc = resc
s.fn = func() {
+ cs.cc.mu.Lock()
+ cs.startedWrite = true
+ cs.cc.mu.Unlock()
resc <- cs.writeRequestBody(body, cs.req.Body)
}
s.delay = t.expectContinueTimeout()
@@ -6728,6 +7388,11 @@ func http2isConnectionCloseRequest(req *Request) bool {
// writeFramer is implemented by any type that is used to write frames.
type http2writeFramer interface {
writeFrame(http2writeContext) error
+
+ // staysWithinBuffer reports whether this writer promises that
+ // it will only write less than or equal to size bytes, and it
+ // won't Flush the write context.
+ staysWithinBuffer(size int) bool
}
// writeContext is the interface needed by the various frame writer
@@ -6749,9 +7414,10 @@ type http2writeContext interface {
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
}
-// endsStream reports whether the given frame writer w will locally
-// close the stream.
-func http2endsStream(w http2writeFramer) bool {
+// writeEndsStream reports whether w writes a frame that will transition
+// the stream to a half-closed local state. This returns false for RST_STREAM,
+// which closes the entire stream (not just the local half).
+func http2writeEndsStream(w http2writeFramer) bool {
switch v := w.(type) {
case *http2writeData:
return v.endStream
@@ -6759,7 +7425,7 @@ func http2endsStream(w http2writeFramer) bool {
return v.endStream
case nil:
- panic("endsStream called on nil writeFramer")
+ panic("writeEndsStream called on nil writeFramer")
}
return false
}
@@ -6770,8 +7436,16 @@ func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error {
return ctx.Flush()
}
+func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false }
+
type http2writeSettings []http2Setting
+func (s http2writeSettings) staysWithinBuffer(max int) bool {
+ const settingSize = 6 // uint16 + uint32
+ return http2frameHeaderLen+settingSize*len(s) <= max
+
+}
+
func (s http2writeSettings) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteSettings([]http2Setting(s)...)
}
@@ -6791,6 +7465,8 @@ func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error {
return err
}
+func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false }
+
type http2writeData struct {
streamID uint32
p []byte
@@ -6805,6 +7481,10 @@ func (w *http2writeData) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
}
+func (w *http2writeData) staysWithinBuffer(max int) bool {
+ return http2frameHeaderLen+len(w.p) <= max
+}
+
// handlerPanicRST is the message sent from handler goroutines when
// the handler panics.
type http2handlerPanicRST struct {
@@ -6815,22 +7495,59 @@ func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal)
}
+func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
+
func (se http2StreamError) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
}
+func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
+
type http2writePingAck struct{ pf *http2PingFrame }
func (w http2writePingAck) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WritePing(true, w.pf.Data)
}
+func (w http2writePingAck) staysWithinBuffer(max int) bool {
+ return http2frameHeaderLen+len(w.pf.Data) <= max
+}
+
type http2writeSettingsAck struct{}
func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteSettingsAck()
}
+func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max }
+
+// splitHeaderBlock splits headerBlock into fragments so that each fragment fits
+// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true
+// for the first/last fragment, respectively.
+func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error {
+ // For now we're lazy and just pick the minimum MAX_FRAME_SIZE
+ // that all peers must support (16KB). Later we could care
+ // more and send larger frames if the peer advertised it, but
+ // there's little point. Most headers are small anyway (so we
+ // generally won't have CONTINUATION frames), and extra frames
+ // only waste 9 bytes anyway.
+ const maxFrameSize = 16384
+
+ first := true
+ for len(headerBlock) > 0 {
+ frag := headerBlock
+ if len(frag) > maxFrameSize {
+ frag = frag[:maxFrameSize]
+ }
+ headerBlock = headerBlock[len(frag):]
+ if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil {
+ return err
+ }
+ first = false
+ }
+ return nil
+}
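A standalone sketch of the fragmentation rule used by splitHeaderBlock (the 40000-byte block is illustrative): the encoded header block is cut into fragments of at most 16KB, the first of which becomes the HEADERS or PUSH_PROMISE frame while the rest become CONTINUATION frames.

package main

import "fmt"

// splitFragments mirrors the loop above, minus the frame-writing callback.
func splitFragments(block []byte, maxFrameSize int) [][]byte {
	var frags [][]byte
	for len(block) > 0 {
		frag := block
		if len(frag) > maxFrameSize {
			frag = frag[:maxFrameSize]
		}
		block = block[len(frag):]
		frags = append(frags, frag)
	}
	return frags
}

func main() {
	block := make([]byte, 40000) // pretend this is an encoded header block
	for i, f := range splitFragments(block, 16384) {
		fmt.Printf("fragment %d: %d bytes\n", i, len(f))
	}
	// Prints fragments of 16384, 16384 and 7232 bytes.
}
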
+
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
// for HTTP response headers or trailers from a server handler.
type http2writeResHeaders struct {
@@ -6852,6 +7569,11 @@ func http2encKV(enc *hpack.Encoder, k, v string) {
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
+func (w *http2writeResHeaders) staysWithinBuffer(max int) bool {
+
+ return false
+}
+
func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
@@ -6877,39 +7599,69 @@ func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error {
panic("unexpected empty hpack")
}
- // For now we're lazy and just pick the minimum MAX_FRAME_SIZE
- // that all peers must support (16KB). Later we could care
- // more and send larger frames if the peer advertised it, but
- // there's little point. Most headers are small anyway (so we
- // generally won't have CONTINUATION frames), and extra frames
- // only waste 9 bytes anyway.
- const maxFrameSize = 16384
+ return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
+}
- first := true
- for len(headerBlock) > 0 {
- frag := headerBlock
- if len(frag) > maxFrameSize {
- frag = frag[:maxFrameSize]
- }
- headerBlock = headerBlock[len(frag):]
- endHeaders := len(headerBlock) == 0
- var err error
- if first {
- first = false
- err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{
- StreamID: w.streamID,
- BlockFragment: frag,
- EndStream: w.endStream,
- EndHeaders: endHeaders,
- })
- } else {
- err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
- }
- if err != nil {
- return err
- }
+func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error {
+ if firstFrag {
+ return ctx.Framer().WriteHeaders(http2HeadersFrameParam{
+ StreamID: w.streamID,
+ BlockFragment: frag,
+ EndStream: w.endStream,
+ EndHeaders: lastFrag,
+ })
+ } else {
+ return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
+ }
+}
+
+// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames.
+type http2writePushPromise struct {
+ streamID uint32 // pusher stream
+ method string // for :method
+ url *url.URL // for :scheme, :authority, :path
+ h Header
+
+ // Creates an ID for a pushed stream. This runs on serveG just before
+ // the frame is written. The returned ID is copied to promisedID.
+ allocatePromisedID func() (uint32, error)
+ promisedID uint32
+}
+
+func (w *http2writePushPromise) staysWithinBuffer(max int) bool {
+
+ return false
+}
+
+func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+
+ http2encKV(enc, ":method", w.method)
+ http2encKV(enc, ":scheme", w.url.Scheme)
+ http2encKV(enc, ":authority", w.url.Host)
+ http2encKV(enc, ":path", w.url.RequestURI())
+ http2encodeHeaders(enc, w.h, nil)
+
+ headerBlock := buf.Bytes()
+ if len(headerBlock) == 0 {
+ panic("unexpected empty hpack")
+ }
+
+ return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock)
+}
+
+func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error {
+ if firstFrag {
+ return ctx.Framer().WritePushPromise(http2PushPromiseParam{
+ StreamID: w.streamID,
+ PromiseID: w.promisedID,
+ BlockFragment: frag,
+ EndHeaders: lastFrag,
+ })
+ } else {
+ return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag)
}
- return nil
}
type http2write100ContinueHeadersFrame struct {
@@ -6928,15 +7680,24 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err
})
}
+func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool {
+
+ return 9+2*(len(":status")+len("100")) <= max
+}
+
type http2writeWindowUpdate struct {
streamID uint32 // or 0 for conn-level
n uint32
}
+func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max }
+
func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error {
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
}
+// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k])
+// is encoded only if k is in keys.
func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) {
if keys == nil {
sorter := http2sorterPool.Get().(*http2sorter)
@@ -6966,14 +7727,53 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) {
}
}
-// frameWriteMsg is a request to write a frame.
-type http2frameWriteMsg struct {
+// WriteScheduler is the interface implemented by HTTP/2 write schedulers.
+// Methods are never called concurrently.
+type http2WriteScheduler interface {
+ // OpenStream opens a new stream in the write scheduler.
+ // It is illegal to call this with streamID=0 or with a streamID that is
+ // already open -- the call may panic.
+ OpenStream(streamID uint32, options http2OpenStreamOptions)
+
+ // CloseStream closes a stream in the write scheduler. Any frames queued on
+ // this stream should be discarded. It is illegal to call this on a stream
+ // that is not open -- the call may panic.
+ CloseStream(streamID uint32)
+
+ // AdjustStream adjusts the priority of the given stream. This may be called
+ // on a stream that has not yet been opened or has been closed. Note that
+ // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See:
+ // https://tools.ietf.org/html/rfc7540#section-5.1
+ AdjustStream(streamID uint32, priority http2PriorityParam)
+
+ // Push queues a frame in the scheduler. In most cases, this will not be
+ // called with wr.StreamID()!=0 unless that stream is currently open. The one
+ // exception is RST_STREAM frames, which may be sent on idle or closed streams.
+ Push(wr http2FrameWriteRequest)
+
+ // Pop dequeues the next frame to write. Returns false if no frames can
+ // be written. Frames with a given wr.StreamID() are Pop'd in the same
+ // order they are Push'd.
+ Pop() (wr http2FrameWriteRequest, ok bool)
+}
+
+// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream.
+type http2OpenStreamOptions struct {
+ // PusherID is zero if the stream was initiated by the client. Otherwise,
+ // PusherID names the stream that pushed the newly opened stream.
+ PusherID uint32
+}
+
+// FrameWriteRequest is a request to write a frame.
+type http2FrameWriteRequest struct {
// write is the interface value that does the writing, once the
- // writeScheduler (below) has decided to select this frame
- // to write. The write functions are all defined in write.go.
+ // WriteScheduler has selected this frame to write. The write
+ // functions are all defined in write.go.
write http2writeFramer
- stream *http2stream // used for prioritization. nil for non-stream frames.
+ // stream is the stream on which this frame will be written.
+ // nil for non-stream frames like PING and SETTINGS.
+ stream *http2stream
// done, if non-nil, must be a buffered channel with space for
// 1 message and is sent the return value from write (or an
@@ -6981,247 +7781,644 @@ type http2frameWriteMsg struct {
done chan error
}
-// for debugging only:
-func (wm http2frameWriteMsg) String() string {
- var streamID uint32
- if wm.stream != nil {
- streamID = wm.stream.id
+// StreamID returns the id of the stream this frame will be written to.
+// 0 is used for non-stream frames such as PING and SETTINGS.
+func (wr http2FrameWriteRequest) StreamID() uint32 {
+ if wr.stream == nil {
+ if se, ok := wr.write.(http2StreamError); ok {
+
+ return se.StreamID
+ }
+ return 0
+ }
+ return wr.stream.id
+}
+
+// DataSize returns the number of flow control bytes that must be consumed
+// to write this entire frame. This is 0 for non-DATA frames.
+func (wr http2FrameWriteRequest) DataSize() int {
+ if wd, ok := wr.write.(*http2writeData); ok {
+ return len(wd.p)
+ }
+ return 0
+}
+
+// Consume consumes min(n, available) bytes from this frame, where available
+// is the number of flow control bytes available on the stream. Consume returns
+// 0, 1, or 2 frames, where the integer return value gives the number of frames
+// returned.
+//
+// If flow control prevents consuming any bytes, this returns (_, _, 0). If
+// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this
+// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and
+// 'rest' contains the remaining bytes. The consumed bytes are deducted from the
+// underlying stream's flow control budget.
+func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) {
+ var empty http2FrameWriteRequest
+
+ wd, ok := wr.write.(*http2writeData)
+ if !ok || len(wd.p) == 0 {
+ return wr, empty, 1
+ }
+
+ allowed := wr.stream.flow.available()
+ if n < allowed {
+ allowed = n
+ }
+ if wr.stream.sc.maxFrameSize < allowed {
+ allowed = wr.stream.sc.maxFrameSize
+ }
+ if allowed <= 0 {
+ return empty, empty, 0
+ }
+ if len(wd.p) > int(allowed) {
+ wr.stream.flow.take(allowed)
+ consumed := http2FrameWriteRequest{
+ stream: wr.stream,
+ write: &http2writeData{
+ streamID: wd.streamID,
+ p: wd.p[:allowed],
+
+ endStream: false,
+ },
+
+ done: nil,
+ }
+ rest := http2FrameWriteRequest{
+ stream: wr.stream,
+ write: &http2writeData{
+ streamID: wd.streamID,
+ p: wd.p[allowed:],
+ endStream: wd.endStream,
+ },
+ done: wr.done,
+ }
+ return consumed, rest, 2
}
+
+ wr.stream.flow.take(int32(len(wd.p)))
+ return wr, empty, 1
+}
+
+// String is for debugging only.
+func (wr http2FrameWriteRequest) String() string {
var des string
- if s, ok := wm.write.(fmt.Stringer); ok {
+ if s, ok := wr.write.(fmt.Stringer); ok {
des = s.String()
} else {
- des = fmt.Sprintf("%T", wm.write)
+ des = fmt.Sprintf("%T", wr.write)
}
- return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des)
+ return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des)
}
-// writeScheduler tracks pending frames to write, priorities, and decides
-// the next one to use. It is not thread-safe.
-type http2writeScheduler struct {
- // zero are frames not associated with a specific stream.
- // They're sent before any stream-specific freams.
- zero http2writeQueue
+// replyToWriter sends err to wr.done and panics if the send must block.
+// This does nothing if wr.done is nil.
+func (wr *http2FrameWriteRequest) replyToWriter(err error) {
+ if wr.done == nil {
+ return
+ }
+ select {
+ case wr.done <- err:
+ default:
+ panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write))
+ }
+ wr.write = nil
+}
- // maxFrameSize is the maximum size of a DATA frame
- // we'll write. Must be non-zero and between 16K-16M.
- maxFrameSize uint32
+// writeQueue is used by implementations of WriteScheduler.
+type http2writeQueue struct {
+ s []http2FrameWriteRequest
+}
- // sq contains the stream-specific queues, keyed by stream ID.
- // when a stream is idle, it's deleted from the map.
- sq map[uint32]*http2writeQueue
+func (q *http2writeQueue) empty() bool { return len(q.s) == 0 }
- // canSend is a slice of memory that's reused between frame
- // scheduling decisions to hold the list of writeQueues (from sq)
- // which have enough flow control data to send. After canSend is
- // built, the best is selected.
- canSend []*http2writeQueue
+func (q *http2writeQueue) push(wr http2FrameWriteRequest) {
+ q.s = append(q.s, wr)
+}
- // pool of empty queues for reuse.
- queuePool []*http2writeQueue
+func (q *http2writeQueue) shift() http2FrameWriteRequest {
+ if len(q.s) == 0 {
+ panic("invalid use of queue")
+ }
+ wr := q.s[0]
+
+ copy(q.s, q.s[1:])
+ q.s[len(q.s)-1] = http2FrameWriteRequest{}
+ q.s = q.s[:len(q.s)-1]
+ return wr
}
-func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) {
- if len(q.s) != 0 {
- panic("queue must be empty")
+// consume consumes up to n bytes from q.s[0]. If the frame is
+// entirely consumed, it is removed from the queue. If the frame
+// is partially consumed, the frame is kept with the consumed
+// bytes removed. Returns true iff any bytes were consumed.
+func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) {
+ if len(q.s) == 0 {
+ return http2FrameWriteRequest{}, false
}
- ws.queuePool = append(ws.queuePool, q)
+ consumed, rest, numresult := q.s[0].Consume(n)
+ switch numresult {
+ case 0:
+ return http2FrameWriteRequest{}, false
+ case 1:
+ q.shift()
+ case 2:
+ q.s[0] = rest
+ }
+ return consumed, true
}
-func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue {
- ln := len(ws.queuePool)
+type http2writeQueuePool []*http2writeQueue
+
+// put inserts an unused writeQueue into the pool.
+func (p *http2writeQueuePool) put(q *http2writeQueue) {
+ for i := range q.s {
+ q.s[i] = http2FrameWriteRequest{}
+ }
+ q.s = q.s[:0]
+ *p = append(*p, q)
+}
+
+// get returns an empty writeQueue.
+func (p *http2writeQueuePool) get() *http2writeQueue {
+ ln := len(*p)
if ln == 0 {
return new(http2writeQueue)
}
- q := ws.queuePool[ln-1]
- ws.queuePool = ws.queuePool[:ln-1]
+ x := ln - 1
+ q := (*p)[x]
+ (*p)[x] = nil
+ *p = (*p)[:x]
return q
}
-func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 }
+// RFC 7540, Section 5.3.5: the default weight is 16.
+const http2priorityDefaultWeight = 15 // 16 = 15 + 1
-func (ws *http2writeScheduler) add(wm http2frameWriteMsg) {
- st := wm.stream
- if st == nil {
- ws.zero.push(wm)
+// PriorityWriteSchedulerConfig configures a priorityWriteScheduler.
+type http2PriorityWriteSchedulerConfig struct {
+ // MaxClosedNodesInTree controls the maximum number of closed streams to
+ // retain in the priority tree. Setting this to zero saves a small amount
+ // of memory at the cost of performance.
+ //
+ // See RFC 7540, Section 5.3.4:
+ // "It is possible for a stream to become closed while prioritization
+ // information ... is in transit. ... This potentially creates suboptimal
+ // prioritization, since the stream could be given a priority that is
+ // different from what is intended. To avoid these problems, an endpoint
+ // SHOULD retain stream prioritization state for a period after streams
+ // become closed. The longer state is retained, the lower the chance that
+ // streams are assigned incorrect or default priority values."
+ MaxClosedNodesInTree int
+
+ // MaxIdleNodesInTree controls the maximum number of idle streams to
+ // retain in the priority tree. Setting this to zero saves a small amount
+ // of memory at the cost of performance.
+ //
+ // See RFC 7540, Section 5.3.4:
+ // Similarly, streams that are in the "idle" state can be assigned
+ // priority or become a parent of other streams. This allows for the
+ // creation of a grouping node in the dependency tree, which enables
+ // more flexible expressions of priority. Idle streams begin with a
+ // default priority (Section 5.3.5).
+ MaxIdleNodesInTree int
+
+ // ThrottleOutOfOrderWrites enables write throttling to help ensure that
+ // data is delivered in priority order. This works around a race where
+ // stream B depends on stream A and both streams are about to call Write
+ // to queue DATA frames. If B wins the race, a naive scheduler would eagerly
+ // write as much data from B as possible, but this is suboptimal because A
+ // is a higher-priority stream. With throttling enabled, we write a small
+ // amount of data from B to minimize the amount of bandwidth that B can
+ // steal from A.
+ ThrottleOutOfOrderWrites bool
+}
+
+// NewPriorityWriteScheduler constructs a WriteScheduler that schedules
+// frames by following HTTP/2 priorities as described in RFC 7540, Section 5.3.
+// If cfg is nil, default options are used.
+func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler {
+ if cfg == nil {
+
+ cfg = &http2PriorityWriteSchedulerConfig{
+ MaxClosedNodesInTree: 10,
+ MaxIdleNodesInTree: 10,
+ ThrottleOutOfOrderWrites: false,
+ }
+ }
+
+ ws := &http2priorityWriteScheduler{
+ nodes: make(map[uint32]*http2priorityNode),
+ maxClosedNodesInTree: cfg.MaxClosedNodesInTree,
+ maxIdleNodesInTree: cfg.MaxIdleNodesInTree,
+ enableWriteThrottle: cfg.ThrottleOutOfOrderWrites,
+ }
+ ws.nodes[0] = &ws.root
+ if cfg.ThrottleOutOfOrderWrites {
+ ws.writeThrottleLimit = 1024
} else {
- ws.streamQueue(st.id).push(wm)
+ ws.writeThrottleLimit = math.MaxInt32
}
+ return ws
+}
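In the standalone golang.org/x/net/http2 package these scheduler types are exported, and a server selects one through the Server.NewWriteScheduler hook; inside this bundle everything stays unexported in net/http. A hedged sketch against that external package (addresses and certificate paths are placeholders):

package main

import (
	"log"
	"net/http"

	"golang.org/x/net/http2"
)

func main() {
	srv := &http.Server{Addr: ":8443", Handler: http.DefaultServeMux}

	// Keep more closed/idle streams in the priority tree than the defaults
	// shown above, and enable write throttling; a nil config uses the defaults.
	err := http2.ConfigureServer(srv, &http2.Server{
		NewWriteScheduler: func() http2.WriteScheduler {
			return http2.NewPriorityWriteScheduler(&http2.PriorityWriteSchedulerConfig{
				MaxClosedNodesInTree:     32,
				MaxIdleNodesInTree:       32,
				ThrottleOutOfOrderWrites: true,
			})
		},
	})
	if err != nil {
		log.Fatal(err)
	}
	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))
}
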
+
+type http2priorityNodeState int
+
+const (
+ http2priorityNodeOpen http2priorityNodeState = iota
+ http2priorityNodeClosed
+ http2priorityNodeIdle
+)
+
+// priorityNode is a node in an HTTP/2 priority tree.
+// Each node is associated with a single stream ID.
+// See RFC 7540, Section 5.3.
+type http2priorityNode struct {
+ q http2writeQueue // queue of pending frames to write
+ id uint32 // id of the stream, or 0 for the root of the tree
+ weight uint8 // the actual weight is weight+1, so the value is in [1,256]
+ state http2priorityNodeState // open | closed | idle
+ bytes int64 // number of bytes written by this node, or 0 if closed
+ subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree
+
+ // These links form the priority tree.
+ parent *http2priorityNode
+ kids *http2priorityNode // start of the kids list
+ prev, next *http2priorityNode // doubly-linked list of siblings
}
-func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue {
- if q, ok := ws.sq[streamID]; ok {
- return q
+func (n *http2priorityNode) setParent(parent *http2priorityNode) {
+ if n == parent {
+ panic("setParent to self")
}
- if ws.sq == nil {
- ws.sq = make(map[uint32]*http2writeQueue)
+ if n.parent == parent {
+ return
+ }
+
+ if parent := n.parent; parent != nil {
+ if n.prev == nil {
+ parent.kids = n.next
+ } else {
+ n.prev.next = n.next
+ }
+ if n.next != nil {
+ n.next.prev = n.prev
+ }
+ }
+
+ n.parent = parent
+ if parent == nil {
+ n.next = nil
+ n.prev = nil
+ } else {
+ n.next = parent.kids
+ n.prev = nil
+ if n.next != nil {
+ n.next.prev = n
+ }
+ parent.kids = n
}
- q := ws.getEmptyQueue()
- ws.sq[streamID] = q
- return q
}
-// take returns the most important frame to write and removes it from the scheduler.
-// It is illegal to call this if the scheduler is empty or if there are no connection-level
-// flow control bytes available.
-func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) {
- if ws.maxFrameSize == 0 {
- panic("internal error: ws.maxFrameSize not initialized or invalid")
+func (n *http2priorityNode) addBytes(b int64) {
+ n.bytes += b
+ for ; n != nil; n = n.parent {
+ n.subtreeBytes += b
}
+}
- if !ws.zero.empty() {
- return ws.zero.shift(), true
+// walkReadyInOrder iterates over the tree in priority order, calling f for each node
+// with a non-empty write queue. When f returns true, this function returns true and the
+// walk halts. tmp is used as scratch space for sorting.
+//
+// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true
+// if any ancestor p of n is still open (ignoring the root node).
+func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool {
+ if !n.q.empty() && f(n, openParent) {
+ return true
}
- if len(ws.sq) == 0 {
- return
+ if n.kids == nil {
+ return false
+ }
+
+ if n.id != 0 {
+ openParent = openParent || (n.state == http2priorityNodeOpen)
}
- for id, q := range ws.sq {
- if q.firstIsNoCost() {
- return ws.takeFrom(id, q)
+ w := n.kids.weight
+ needSort := false
+ for k := n.kids.next; k != nil; k = k.next {
+ if k.weight != w {
+ needSort = true
+ break
}
}
+ if !needSort {
+ for k := n.kids; k != nil; k = k.next {
+ if k.walkReadyInOrder(openParent, tmp, f) {
+ return true
+ }
+ }
+ return false
+ }
- if len(ws.canSend) != 0 {
- panic("should be empty")
+ *tmp = (*tmp)[:0]
+ for n.kids != nil {
+ *tmp = append(*tmp, n.kids)
+ n.kids.setParent(nil)
}
- for _, q := range ws.sq {
- if n := ws.streamWritableBytes(q); n > 0 {
- ws.canSend = append(ws.canSend, q)
+ sort.Sort(http2sortPriorityNodeSiblings(*tmp))
+ for i := len(*tmp) - 1; i >= 0; i-- {
+ (*tmp)[i].setParent(n)
+ }
+ for k := n.kids; k != nil; k = k.next {
+ if k.walkReadyInOrder(openParent, tmp, f) {
+ return true
}
}
- if len(ws.canSend) == 0 {
- return
+ return false
+}
+
+type http2sortPriorityNodeSiblings []*http2priorityNode
+
+func (z http2sortPriorityNodeSiblings) Len() int { return len(z) }
+
+func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] }
+
+func (z http2sortPriorityNodeSiblings) Less(i, k int) bool {
+
+ wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes)
+ wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes)
+ if bi == 0 && bk == 0 {
+ return wi >= wk
}
- defer ws.zeroCanSend()
+ if bk == 0 {
+ return false
+ }
+ return bi/bk <= wi/wk
+}
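The comparison above orders siblings by bytes sent per unit of weight, so a heavier sibling is allowed to send proportionally more before yielding. A quick standalone check of that rule with hypothetical byte counts:

package main

import "fmt"

// less mirrors the Less method above; wi and wk are effective weights
// (the tree stores weight-1, hence the +1 there), bi and bk are subtree
// byte counts.
func less(wi, bi, wk, bk float64) bool {
	if bi == 0 && bk == 0 {
		return wi >= wk
	}
	if bk == 0 {
		return false
	}
	return bi/bk <= wi/wk
}

func main() {
	// Stream A: weight 16, 8000 bytes sent (500 bytes per weight unit).
	// Stream B: weight 8, 2000 bytes sent (250 bytes per weight unit).
	fmt.Println(less(16, 8000, 8, 2000)) // false: A has used more of its share
	fmt.Println(less(8, 2000, 16, 8000)) // true: B is scheduled first
}
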
- q := ws.canSend[0]
+type http2priorityWriteScheduler struct {
+ // root is the root of the priority tree, where root.id = 0.
+ // The root queues control frames that are not associated with any stream.
+ root http2priorityNode
- return ws.takeFrom(q.streamID(), q)
+ // nodes maps stream ids to priority tree nodes.
+ nodes map[uint32]*http2priorityNode
+
+ // maxID is the maximum stream id in nodes.
+ maxID uint32
+
+ // lists of nodes that have been closed or are idle, but are kept in
+ // the tree for improved prioritization. When the lengths exceed either
+ // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded.
+ closedNodes, idleNodes []*http2priorityNode
+
+ // From the config.
+ maxClosedNodesInTree int
+ maxIdleNodesInTree int
+ writeThrottleLimit int32
+ enableWriteThrottle bool
+
+ // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations.
+ tmp []*http2priorityNode
+
+ // pool of empty queues for reuse.
+ queuePool http2writeQueuePool
}
-// zeroCanSend is defered from take.
-func (ws *http2writeScheduler) zeroCanSend() {
- for i := range ws.canSend {
- ws.canSend[i] = nil
+func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
+
+ if curr := ws.nodes[streamID]; curr != nil {
+ if curr.state != http2priorityNodeIdle {
+ panic(fmt.Sprintf("stream %d already opened", streamID))
+ }
+ curr.state = http2priorityNodeOpen
+ return
+ }
+
+ parent := ws.nodes[options.PusherID]
+ if parent == nil {
+ parent = &ws.root
+ }
+ n := &http2priorityNode{
+ q: *ws.queuePool.get(),
+ id: streamID,
+ weight: http2priorityDefaultWeight,
+ state: http2priorityNodeOpen,
+ }
+ n.setParent(parent)
+ ws.nodes[streamID] = n
+ if streamID > ws.maxID {
+ ws.maxID = streamID
}
- ws.canSend = ws.canSend[:0]
}
-// streamWritableBytes returns the number of DATA bytes we could write
-// from the given queue's stream, if this stream/queue were
-// selected. It is an error to call this if q's head isn't a
-// *writeData.
-func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 {
- wm := q.head()
- ret := wm.stream.flow.available()
- if ret == 0 {
- return 0
+func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) {
+ if streamID == 0 {
+ panic("violation of WriteScheduler interface: cannot close stream 0")
}
- if int32(ws.maxFrameSize) < ret {
- ret = int32(ws.maxFrameSize)
+ if ws.nodes[streamID] == nil {
+ panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID))
}
- if ret == 0 {
- panic("internal error: ws.maxFrameSize not initialized or invalid")
+ if ws.nodes[streamID].state != http2priorityNodeOpen {
+ panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID))
}
- wd := wm.write.(*http2writeData)
- if len(wd.p) < int(ret) {
- ret = int32(len(wd.p))
+
+ n := ws.nodes[streamID]
+ n.state = http2priorityNodeClosed
+ n.addBytes(-n.bytes)
+
+ q := n.q
+ ws.queuePool.put(&q)
+ n.q.s = nil
+ if ws.maxClosedNodesInTree > 0 {
+ ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n)
+ } else {
+ ws.removeNode(n)
}
- return ret
}
-func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) {
- wm = q.head()
-
- if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 {
- allowed := wm.stream.flow.available()
- if allowed == 0 {
+func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {
+ if streamID == 0 {
+ panic("adjustPriority on root")
+ }
- return http2frameWriteMsg{}, false
+ n := ws.nodes[streamID]
+ if n == nil {
+ if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 {
+ return
}
- if int32(ws.maxFrameSize) < allowed {
- allowed = int32(ws.maxFrameSize)
+ ws.maxID = streamID
+ n = &http2priorityNode{
+ q: *ws.queuePool.get(),
+ id: streamID,
+ weight: http2priorityDefaultWeight,
+ state: http2priorityNodeIdle,
}
+ n.setParent(&ws.root)
+ ws.nodes[streamID] = n
+ ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n)
+ }
- if len(wd.p) > int(allowed) {
- wm.stream.flow.take(allowed)
- chunk := wd.p[:allowed]
- wd.p = wd.p[allowed:]
+ parent := ws.nodes[priority.StreamDep]
+ if parent == nil {
+ n.setParent(&ws.root)
+ n.weight = http2priorityDefaultWeight
+ return
+ }
- return http2frameWriteMsg{
- stream: wm.stream,
- write: &http2writeData{
- streamID: wd.streamID,
- p: chunk,
+ if n == parent {
+ return
+ }
- endStream: false,
- },
+ for x := parent.parent; x != nil; x = x.parent {
+ if x == n {
+ parent.setParent(n.parent)
+ break
+ }
+ }
- done: nil,
- }, true
+ if priority.Exclusive {
+ k := parent.kids
+ for k != nil {
+ next := k.next
+ if k != n {
+ k.setParent(n)
+ }
+ k = next
}
- wm.stream.flow.take(int32(len(wd.p)))
}
- q.shift()
- if q.empty() {
- ws.putEmptyQueue(q)
- delete(ws.sq, id)
+ n.setParent(parent)
+ n.weight = priority.Weight
+}
+
+func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) {
+ var n *http2priorityNode
+ if id := wr.StreamID(); id == 0 {
+ n = &ws.root
+ } else {
+ n = ws.nodes[id]
+ if n == nil {
+
+ if wr.DataSize() > 0 {
+ panic("add DATA on non-open stream")
+ }
+ n = &ws.root
+ }
}
- return wm, true
+ n.q.push(wr)
}
-func (ws *http2writeScheduler) forgetStream(id uint32) {
- q, ok := ws.sq[id]
- if !ok {
+func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) {
+ ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool {
+ limit := int32(math.MaxInt32)
+ if openParent {
+ limit = ws.writeThrottleLimit
+ }
+ wr, ok = n.q.consume(limit)
+ if !ok {
+ return false
+ }
+ n.addBytes(int64(wr.DataSize()))
+
+ if openParent {
+ ws.writeThrottleLimit += 1024
+ if ws.writeThrottleLimit < 0 {
+ ws.writeThrottleLimit = math.MaxInt32
+ }
+ } else if ws.enableWriteThrottle {
+ ws.writeThrottleLimit = 1024
+ }
+ return true
+ })
+ return wr, ok
+}
+
+func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) {
+ if maxSize == 0 {
return
}
- delete(ws.sq, id)
+ if len(*list) == maxSize {
- for i := range q.s {
- q.s[i] = http2frameWriteMsg{}
+ ws.removeNode((*list)[0])
+ x := (*list)[1:]
+ copy(*list, x)
+ *list = (*list)[:len(x)]
}
- q.s = q.s[:0]
- ws.putEmptyQueue(q)
+ *list = append(*list, n)
}
-type http2writeQueue struct {
- s []http2frameWriteMsg
+func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) {
+ for k := n.kids; k != nil; k = k.next {
+ k.setParent(n.parent)
+ }
+ n.setParent(nil)
+ delete(ws.nodes, n.id)
}
-// streamID returns the stream ID for a non-empty stream-specific queue.
-func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id }
+// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2
+// priorities. Control frames like SETTINGS and PING are written before DATA
+// frames, but if no control frames are queued and multiple streams have queued
+// HEADERS or DATA frames, Pop selects a ready stream arbitrarily.
+func http2NewRandomWriteScheduler() http2WriteScheduler {
+ return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)}
+}
-func (q *http2writeQueue) empty() bool { return len(q.s) == 0 }
+type http2randomWriteScheduler struct {
+ // zero are frames not associated with a specific stream.
+ zero http2writeQueue
+
+ // sq contains the stream-specific queues, keyed by stream ID.
+ // When a stream is idle or closed, it's deleted from the map.
+ sq map[uint32]*http2writeQueue
-func (q *http2writeQueue) push(wm http2frameWriteMsg) {
- q.s = append(q.s, wm)
+ // pool of empty queues for reuse.
+ queuePool http2writeQueuePool
}
-// head returns the next item that would be removed by shift.
-func (q *http2writeQueue) head() http2frameWriteMsg {
- if len(q.s) == 0 {
- panic("invalid use of queue")
- }
- return q.s[0]
+func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) {
+
}
-func (q *http2writeQueue) shift() http2frameWriteMsg {
- if len(q.s) == 0 {
- panic("invalid use of queue")
+func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) {
+ q, ok := ws.sq[streamID]
+ if !ok {
+ return
}
- wm := q.s[0]
+ delete(ws.sq, streamID)
+ ws.queuePool.put(q)
+}
+
+func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) {
- copy(q.s, q.s[1:])
- q.s[len(q.s)-1] = http2frameWriteMsg{}
- q.s = q.s[:len(q.s)-1]
- return wm
}
-func (q *http2writeQueue) firstIsNoCost() bool {
- if df, ok := q.s[0].write.(*http2writeData); ok {
- return len(df.p) == 0
+func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) {
+ id := wr.StreamID()
+ if id == 0 {
+ ws.zero.push(wr)
+ return
}
- return true
+ q, ok := ws.sq[id]
+ if !ok {
+ q = ws.queuePool.get()
+ ws.sq[id] = q
+ }
+ q.push(wr)
+}
+
+func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) {
+
+ if !ws.zero.empty() {
+ return ws.zero.shift(), true
+ }
+
+ for _, q := range ws.sq {
+ if wr, ok := q.consume(math.MaxInt32); ok {
+ return wr, true
+ }
+ }
+ return http2FrameWriteRequest{}, false
}
diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go
index 6343165..8321692 100644
--- a/libgo/go/net/http/header.go
+++ b/libgo/go/net/http/header.go
@@ -32,9 +32,11 @@ func (h Header) Set(key, value string) {
}
// Get gets the first value associated with the given key.
+// It is case insensitive; textproto.CanonicalMIMEHeaderKey is used
+// to canonicalize the provided key.
// If there are no values associated with the key, Get returns "".
-// To access multiple values of a key, access the map directly
-// with CanonicalHeaderKey.
+// To access multiple values of a key, or to use non-canonical keys,
+// access the map directly.
func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
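A short illustration of the canonicalization behaviour documented above:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	h := http.Header{}
	h.Set("content-type", "text/plain") // stored under the canonical key "Content-Type"

	fmt.Println(h.Get("CONTENT-TYPE")) // "text/plain": Get canonicalizes the key first
	fmt.Println(h["Content-Type"])     // ["text/plain"]: direct map access needs the canonical key
	fmt.Println(h["content-type"])     // []: Set canonicalized the key, so nothing lives under the lowercase form
}
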
diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go
index b34ae41..826f7ff 100644
--- a/libgo/go/net/http/http.go
+++ b/libgo/go/net/http/http.go
@@ -5,7 +5,11 @@
package http
import (
+ "io"
+ "strconv"
"strings"
+ "time"
+ "unicode/utf8"
"golang_org/x/net/lex/httplex"
)
@@ -14,6 +18,10 @@ import (
// Transport's byte-limiting readers.
const maxInt64 = 1<<63 - 1
+// aLongTimeAgo is a non-zero time, far in the past, used for
+// immediate cancelation of network operations.
+var aLongTimeAgo = time.Unix(233431200, 0)
+
// TODO(bradfitz): move common stuff here. The other files have accumulated
// generic http stuff in random places.
@@ -41,3 +49,93 @@ func removeEmptyPort(host string) string {
func isNotToken(r rune) bool {
return !httplex.IsTokenRune(r)
}
+
+func isASCII(s string) bool {
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ return false
+ }
+ }
+ return true
+}
+
+func hexEscapeNonASCII(s string) string {
+ newLen := 0
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ newLen += 3
+ } else {
+ newLen++
+ }
+ }
+ if newLen == len(s) {
+ return s
+ }
+ b := make([]byte, 0, newLen)
+ for i := 0; i < len(s); i++ {
+ if s[i] >= utf8.RuneSelf {
+ b = append(b, '%')
+ b = strconv.AppendInt(b, int64(s[i]), 16)
+ } else {
+ b = append(b, s[i])
+ }
+ }
+ return string(b)
+}
+
+// NoBody is an io.ReadCloser with no bytes. Read always returns EOF
+// and Close always returns nil. It can be used in an outgoing client
+// request to explicitly signal that a request has zero bytes.
+// An alternative, however, is to simply set Request.Body to nil.
+var NoBody = noBody{}
+
+type noBody struct{}
+
+func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
+func (noBody) Close() error { return nil }
+func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil }
+
+var (
+ // verify that an io.Copy from NoBody won't require a buffer:
+ _ io.WriterTo = NoBody
+ _ io.ReadCloser = NoBody
+)
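A hedged usage sketch of NoBody (the URL is illustrative): it marks the body as explicitly empty where an API insists on a non-nil io.ReadCloser, while passing nil remains the simpler choice for ordinary client code.

package main

import (
	"fmt"
	"log"
	"net/http"
)

func main() {
	req, err := http.NewRequest("POST", "https://example.com/ping", http.NoBody)
	if err != nil {
		log.Fatal(err)
	}
	// Body is non-nil but known to hold zero bytes.
	fmt.Println(req.Body != nil, req.ContentLength) // true 0
}
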
+
+// PushOptions describes options for Pusher.Push.
+type PushOptions struct {
+ // Method specifies the HTTP method for the promised request.
+ // If set, it must be "GET" or "HEAD". Empty means "GET".
+ Method string
+
+ // Header specifies additional promised request headers. This cannot
+ // include HTTP/2 pseudo header fields like ":path" and ":scheme",
+ // which will be added automatically.
+ Header Header
+}
+
+// Pusher is the interface implemented by ResponseWriters that support
+// HTTP/2 server push. For more background, see
+// https://tools.ietf.org/html/rfc7540#section-8.2.
+type Pusher interface {
+ // Push initiates an HTTP/2 server push. This constructs a synthetic
+ // request using the given target and options, serializes that request
+ // into a PUSH_PROMISE frame, then dispatches that request using the
+ // server's request handler. If opts is nil, default options are used.
+ //
+ // The target must either be an absolute path (like "/path") or an absolute
+ // URL that contains a valid host and the same scheme as the parent request.
+ // If the target is a path, it will inherit the scheme and host of the
+ // parent request.
+ //
+ // The HTTP/2 spec disallows recursive pushes and cross-authority pushes.
+ // Push may or may not detect these invalid pushes; however, invalid
+ // pushes will be detected and canceled by conforming clients.
+ //
+ // Handlers that wish to push URL X should call Push before sending any
+ // data that may trigger a request for URL X. This avoids a race where the
+ // client issues requests for X before receiving the PUSH_PROMISE for X.
+ //
+ // Push returns ErrNotSupported if the client has disabled push or if push
+ // is not supported on the underlying connection.
+ Push(target string, opts *PushOptions) error
+}
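A hedged handler sketch using the new Pusher interface (paths and markup are illustrative); the type assertion simply fails on HTTP/1.x connections or when the client has disabled push:

package main

import (
	"io"
	"log"
	"net/http"
)

func handler(w http.ResponseWriter, r *http.Request) {
	if pusher, ok := w.(http.Pusher); ok {
		// Push the stylesheet before writing the page that references it,
		// so the PUSH_PROMISE reaches the client before it would ask for it.
		if err := pusher.Push("/static/app.css", nil); err != nil {
			log.Printf("push failed: %v", err)
		}
	}
	io.WriteString(w, `<html><link rel="stylesheet" href="/static/app.css">hello</html>`)
}

func main() {
	http.HandleFunc("/", handler)
	// Certificate paths are placeholders; HTTP/2 (and hence push) needs TLS here.
	log.Fatal(http.ListenAndServeTLS(":8443", "cert.pem", "key.pem", nil))
}
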
diff --git a/libgo/go/net/http/http_test.go b/libgo/go/net/http/http_test.go
index 34da4bb..8f466bb 100644
--- a/libgo/go/net/http/http_test.go
+++ b/libgo/go/net/http/http_test.go
@@ -12,8 +12,13 @@ import (
"os/exec"
"reflect"
"testing"
+ "time"
)
+func init() {
+ shutdownPollInterval = 5 * time.Millisecond
+}
+
func TestForeachHeaderElement(t *testing.T) {
tests := []struct {
in string
@@ -51,6 +56,18 @@ func TestCleanHost(t *testing.T) {
{"www.google.com foo", "www.google.com"},
{"www.google.com/foo", "www.google.com"},
{" first character is a space", ""},
+ {"[1::6]:8080", "[1::6]:8080"},
+
+ // Punycode:
+ {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"},
+ {"bücher.de", "xn--bcher-kva.de"},
+ {"bücher.de:8080", "xn--bcher-kva.de:8080"},
+ // Verify we convert to lowercase before punycode:
+ {"BÜCHER.de", "xn--bcher-kva.de"},
+ {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"},
+ // Verify we normalize to NFC before punycode:
+ {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed
+ {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input
}
for _, tt := range tests {
got := cleanHost(tt.in)
@@ -65,8 +82,9 @@ func TestCleanHost(t *testing.T) {
// This catches accidental dependencies between the HTTP transport and
// server code.
func TestCmdGoNoHTTPServer(t *testing.T) {
+ t.Parallel()
goBin := testenv.GoToolPath(t)
- out, err := exec.Command("go", "tool", "nm", goBin).CombinedOutput()
+ out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput()
if err != nil {
t.Fatalf("go tool nm: %v: %s", err, out)
}
diff --git a/libgo/go/net/http/httptest/httptest.go b/libgo/go/net/http/httptest/httptest.go
index e2148a6..f7202da 100644
--- a/libgo/go/net/http/httptest/httptest.go
+++ b/libgo/go/net/http/httptest/httptest.go
@@ -35,6 +35,9 @@ import (
//
// NewRequest panics on error for ease of use in testing, where a
// panic is acceptable.
+//
+// To generate a client HTTP request instead of a server request, see
+// the NewRequest function in the net/http package.
func NewRequest(method, target string, body io.Reader) *http.Request {
if method == "" {
method = "GET"
diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go
index 0ad26a3..5f1aa6a 100644
--- a/libgo/go/net/http/httptest/recorder.go
+++ b/libgo/go/net/http/httptest/recorder.go
@@ -8,15 +8,33 @@ import (
"bytes"
"io/ioutil"
"net/http"
+ "strconv"
+ "strings"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
// records its mutations for later inspection in tests.
type ResponseRecorder struct {
- Code int // the HTTP response code from WriteHeader
- HeaderMap http.Header // the HTTP response headers
- Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
- Flushed bool
+ // Code is the HTTP response code set by WriteHeader.
+ //
+ // Note that if a Handler never calls WriteHeader or Write,
+ // this might end up being 0, rather than the implicit
+ // http.StatusOK. To get the implicit value, use the Result
+ // method.
+ Code int
+
+ // HeaderMap contains the headers explicitly set by the Handler.
+ //
+ // To get the implicit headers set by the server (such as
+ // automatic Content-Type), use the Result method.
+ HeaderMap http.Header
+
+ // Body is the buffer to which the Handler's Write calls are sent.
+ // If nil, the Writes are silently discarded.
+ Body *bytes.Buffer
+
+ // Flushed is whether the Handler called Flush.
+ Flushed bool
result *http.Response // cache of Result's return value
snapHeader http.Header // snapshot of HeaderMap at first Write
@@ -136,6 +154,9 @@ func (rw *ResponseRecorder) Flush() {
// first write call, or at the time of this call, if the handler never
// did a write.
//
+// The Response.Body is guaranteed to be non-nil, and Body.Read calls are
+// guaranteed not to return any error other than io.EOF.
+//
// Result must only be called after the handler has finished running.
func (rw *ResponseRecorder) Result() *http.Response {
if rw.result != nil {
@@ -159,6 +180,7 @@ func (rw *ResponseRecorder) Result() *http.Response {
if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
}
+ res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers))
@@ -181,5 +203,33 @@ func (rw *ResponseRecorder) Result() *http.Response {
res.Trailer[k] = vv2
}
}
+ for k, vv := range rw.HeaderMap {
+ if !strings.HasPrefix(k, http.TrailerPrefix) {
+ continue
+ }
+ if res.Trailer == nil {
+ res.Trailer = make(http.Header)
+ }
+ for _, v := range vv {
+ res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
+ }
+ }
return res
}
+
+// parseContentLength trims whitespace from cl and returns -1 if no value
+// is set, or the value if it's >= 0.
+//
+// This is a modified version of the same function found in net/http/transfer.go. This
+// one just ignores an invalid header.
+func parseContentLength(cl string) int64 {
+ cl = strings.TrimSpace(cl)
+ if cl == "" {
+ return -1
+ }
+ n, err := strconv.ParseInt(cl, 10, 64)
+ if err != nil {
+ return -1
+ }
+ return n
+}
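A hedged test sketch of the Result path described above (handler and values are illustrative); Result supplies what the recorder's raw fields leave implicit, such as the default 200 status and the parsed Content-Length:

package example

import (
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestHelloHandler(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Length", "5")
		io.WriteString(w, "hello")
	}

	req := httptest.NewRequest("GET", "http://example.com/hello", nil)
	rec := httptest.NewRecorder()
	handler(rec, req)

	res := rec.Result() // non-nil Body, implicit 200, parsed ContentLength
	if res.StatusCode != http.StatusOK || res.ContentLength != 5 {
		t.Fatalf("got status %d, length %d", res.StatusCode, res.ContentLength)
	}
	if body, _ := ioutil.ReadAll(res.Body); string(body) != "hello" {
		t.Fatalf("body = %q", body)
	}
}
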
diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go
index d4e7137..9afba4e 100644
--- a/libgo/go/net/http/httptest/recorder_test.go
+++ b/libgo/go/net/http/httptest/recorder_test.go
@@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) {
return nil
}
}
+ hasContentLength := func(length int64) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.Result().ContentLength; got != length {
+ return fmt.Errorf("ContentLength = %d; want %d", got, length)
+ }
+ return nil
+ }
+ }
tests := []struct {
name string
@@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) {
w.(http.Flusher).Flush() // also sends a 200
w.WriteHeader(201)
},
- check(hasStatus(200), hasFlush(true)),
+ check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
},
{
"Content-Type detection",
@@ -199,6 +207,7 @@ func TestRecorder(t *testing.T) {
w.Header().Set("Trailer-A", "valuea")
w.Header().Set("Trailer-C", "valuec")
w.Header().Set("Trailer-NotDeclared", "should be omitted")
+ w.Header().Set("Trailer:Trailer-D", "with prefix")
},
check(
hasStatus(200),
@@ -208,6 +217,7 @@ func TestRecorder(t *testing.T) {
hasTrailer("Trailer-A", "valuea"),
hasTrailer("Trailer-C", "valuec"),
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
+ hasTrailer("Trailer-D", "with prefix"),
),
},
{
@@ -244,6 +254,16 @@ func TestRecorder(t *testing.T) {
hasNotHeaders("X-Bar"),
),
},
+ {
+ "setting Content-Length header",
+ func(w http.ResponseWriter, r *http.Request) {
+ body := "Some body"
+ contentLength := fmt.Sprintf("%d", len(body))
+ w.Header().Set("Content-Length", contentLength)
+ io.WriteString(w, body)
+ },
+ check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
+ },
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {
diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go
index e27526a..5e9ace5 100644
--- a/libgo/go/net/http/httptest/server.go
+++ b/libgo/go/net/http/httptest/server.go
@@ -16,7 +16,6 @@ import (
"net/http"
"net/http/internal"
"os"
- "runtime"
"sync"
"time"
)
@@ -114,9 +113,10 @@ func (s *Server) StartTLS() {
}
existingConfig := s.TLS
- s.TLS = new(tls.Config)
if existingConfig != nil {
- *s.TLS = *existingConfig
+ s.TLS = existingConfig.Clone()
+ } else {
+ s.TLS = new(tls.Config)
}
if s.TLS.NextProtos == nil {
s.TLS.NextProtos = []string{"http/1.1"}
@@ -293,15 +293,6 @@ func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
// closeConnChan is like closeConn, but takes an optional channel to receive a value
// when the goroutine closing c is done.
func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
- if runtime.GOOS == "plan9" {
- // Go's Plan 9 net package isn't great at unblocking reads when
- // their underlying TCP connections are closed. Don't trust
- // that that the ConnState state machine will get to
- // StateClosed. Instead, just go there directly. Plan 9 may leak
- // resources if the syscall doesn't end up returning. Oh well.
- s.forgetConn(c)
- }
-
c.Close()
if done != nil {
done <- struct{}{}
diff --git a/libgo/go/net/http/httptrace/example_test.go b/libgo/go/net/http/httptrace/example_test.go
new file mode 100644
index 0000000..27cdcde
--- /dev/null
+++ b/libgo/go/net/http/httptrace/example_test.go
@@ -0,0 +1,31 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+package httptrace_test
+
+import (
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptrace"
+)
+
+func Example() {
+ req, _ := http.NewRequest("GET", "http://example.com", nil)
+ trace := &httptrace.ClientTrace{
+ GotConn: func(connInfo httptrace.GotConnInfo) {
+ fmt.Printf("Got Conn: %+v\n", connInfo)
+ },
+ DNSDone: func(dnsInfo httptrace.DNSDoneInfo) {
+ fmt.Printf("DNS Info: %+v\n", dnsInfo)
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ _, err := http.DefaultTransport.RoundTrip(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go
index 6f187a7..ea7b38c 100644
--- a/libgo/go/net/http/httptrace/trace.go
+++ b/libgo/go/net/http/httptrace/trace.go
@@ -1,6 +1,6 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.h
+// license that can be found in the LICENSE file.
// Package httptrace provides mechanisms to trace the events within
// HTTP client requests.
@@ -8,6 +8,7 @@ package httptrace
import (
"context"
+ "crypto/tls"
"internal/nettrace"
"net"
"reflect"
@@ -65,11 +66,16 @@ func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context {
return ctx
}
-// ClientTrace is a set of hooks to run at various stages of an HTTP
-// client request. Any particular hook may be nil. Functions may be
-// called concurrently from different goroutines, starting after the
-// call to Transport.RoundTrip and ending either when RoundTrip
-// returns an error, or when the Response.Body is closed.
+// ClientTrace is a set of hooks to run at various stages of an outgoing
+// HTTP request. Any particular hook may be nil. Functions may be
+// called concurrently from different goroutines and some may be called
+// after the request has completed or failed.
+//
+// ClientTrace currently traces a single HTTP request & response
+// during a single round trip and has no hooks that span a series
+// of redirected requests.
+//
+// See https://blog.golang.org/http-tracing for more.
type ClientTrace struct {
// GetConn is called before a connection is created or
// retrieved from an idle pool. The hostPort is the
@@ -119,6 +125,16 @@ type ClientTrace struct {
// enabled, this may be called multiple times.
ConnectDone func(network, addr string, err error)
+ // TLSHandshakeStart is called when the TLS handshake is started. When
+ // connecting to an HTTPS site via an HTTP proxy, the handshake happens after
+ // the CONNECT request is processed by the proxy.
+ TLSHandshakeStart func()
+
+ // TLSHandshakeDone is called after the TLS handshake with either the
+ // successful handshake's connection state, or a non-nil error on handshake
+ // failure.
+ TLSHandshakeDone func(tls.ConnectionState, error)
+
// WroteHeaders is called after the Transport has written
// the request headers.
WroteHeaders func()
@@ -130,7 +146,8 @@ type ClientTrace struct {
Wait100Continue func()
// WroteRequest is called with the result of writing the
- // request and any body.
+ // request and any body. It may be called multiple times
+ // in the case of retried requests.
WroteRequest func(WroteRequestInfo)
}
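
A minimal sketch of attaching the new TLSHandshakeStart and TLSHandshakeDone hooks, in the same style as the example_test.go file added above; example.com is only a placeholder host:

package main

import (
	"crypto/tls"
	"fmt"
	"log"
	"net/http"
	"net/http/httptrace"
)

func main() {
	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		log.Fatal(err)
	}
	trace := &httptrace.ClientTrace{
		TLSHandshakeStart: func() {
			fmt.Println("TLS handshake started")
		},
		TLSHandshakeDone: func(state tls.ConnectionState, err error) {
			fmt.Printf("TLS handshake done: version=%x err=%v\n", state.Version, err)
		},
	}
	req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
	if _, err := http.DefaultTransport.RoundTrip(req); err != nil {
		log.Fatal(err)
	}
}
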
diff --git a/libgo/go/net/http/httptrace/trace_test.go b/libgo/go/net/http/httptrace/trace_test.go
index c7eaed8..bb57ada 100644
--- a/libgo/go/net/http/httptrace/trace_test.go
+++ b/libgo/go/net/http/httptrace/trace_test.go
@@ -1,14 +1,41 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.h
+// license that can be found in the LICENSE file.
package httptrace
import (
"bytes"
+ "context"
"testing"
)
+func TestWithClientTrace(t *testing.T) {
+ var buf bytes.Buffer
+ connectStart := func(b byte) func(network, addr string) {
+ return func(network, addr string) {
+ buf.WriteByte(b)
+ }
+ }
+
+ ctx := context.Background()
+ oldtrace := &ClientTrace{
+ ConnectStart: connectStart('O'),
+ }
+ ctx = WithClientTrace(ctx, oldtrace)
+ newtrace := &ClientTrace{
+ ConnectStart: connectStart('N'),
+ }
+ ctx = WithClientTrace(ctx, newtrace)
+ trace := ContextClientTrace(ctx)
+
+ buf.Reset()
+ trace.ConnectStart("net", "addr")
+ if got, want := buf.String(), "NO"; got != want {
+ t.Errorf("got %q; want %q", got, want)
+ }
+}
+
func TestCompose(t *testing.T) {
var buf bytes.Buffer
var testNum int
diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go
index 1511681..7104c37 100644
--- a/libgo/go/net/http/httputil/dump.go
+++ b/libgo/go/net/http/httputil/dump.go
@@ -18,11 +18,16 @@ import (
"time"
)
-// One of the copies, say from b to r2, could be avoided by using a more
-// elaborate trick where the other copy is made during Request/Response.Write.
-// This would complicate things too much, given that these functions are for
-// debugging only.
+// drainBody reads all of b to memory and then returns two equivalent
+// ReadClosers yielding the same bytes.
+//
+// It returns an error if the initial slurp of all bytes fails. It does not attempt
+// to make the returned ReadClosers have identical error-matching behavior.
func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
+ if b == http.NoBody {
+ // No copying needed. Preserve the magic sentinel meaning of NoBody.
+ return http.NoBody, http.NoBody, nil
+ }
var buf bytes.Buffer
if _, err = buf.ReadFrom(b); err != nil {
return nil, b, err
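
The NoBody special case in drainBody is what the dump test below exercises; a small stand-alone sketch of the same behavior, with a placeholder URL:

package main

import (
	"fmt"
	"log"
	"net/http"
	"net/http/httputil"
)

func main() {
	// http.NoBody is the Go 1.8 sentinel for an explicitly empty body.
	req, err := http.NewRequest("POST", "http://example.com/foo", http.NoBody)
	if err != nil {
		log.Fatal(err)
	}
	dump, err := httputil.DumpRequestOut(req, true)
	if err != nil {
		log.Fatal(err)
	}
	// With drainBody preserving NoBody, the dump shows "Content-Length: 0"
	// rather than a chunked body.
	fmt.Printf("%s", dump)
}
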
diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go
index 2e980d3..f881020 100644
--- a/libgo/go/net/http/httputil/dump_test.go
+++ b/libgo/go/net/http/httputil/dump_test.go
@@ -184,6 +184,18 @@ var dumpTests = []dumpTest{
WantDump: "POST /v2/api/?login HTTP/1.1\r\n" +
"Host: passport.myhost.com\r\n\r\n",
},
+
+ // Issue 18506: make drainBody recognize NoBody. Otherwise
+ // this was turning into a chunked request.
+ {
+ Req: *mustNewRequest("POST", "http://example.com/foo", http.NoBody),
+
+ WantDumpOut: "POST /foo HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Content-Length: 0\r\n" +
+ "Accept-Encoding: gzip\r\n\r\n",
+ },
}
func TestDumpRequest(t *testing.T) {
diff --git a/libgo/go/net/http/httputil/persist.go b/libgo/go/net/http/httputil/persist.go
index 87ddd52..cbedf25 100644
--- a/libgo/go/net/http/httputil/persist.go
+++ b/libgo/go/net/http/httputil/persist.go
@@ -15,9 +15,14 @@ import (
)
var (
+ // Deprecated: No longer used.
ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"}
- ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"}
- ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"}
+
+ // Deprecated: No longer used.
+ ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"}
+
+ // Deprecated: No longer used.
+ ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"}
)
// This is an API usage error - the local side is closed.
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index 49c120a..79c8fe2 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -7,6 +7,7 @@
package httputil
import (
+ "context"
"io"
"log"
"net"
@@ -29,6 +30,8 @@ type ReverseProxy struct {
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
+ // Director must not access the provided Request
+ // after returning.
Director func(*http.Request)
// The transport used to perform proxy requests.
@@ -51,6 +54,11 @@ type ReverseProxy struct {
// get byte slices for use by io.CopyBuffer when
// copying HTTP response bodies.
BufferPool BufferPool
+
+ // ModifyResponse is an optional function that
+ // modifies the Response from the backend.
+ // If it returns an error, the proxy returns a StatusBadGateway error.
+ ModifyResponse func(*http.Response) error
}
// A BufferPool is an interface for getting and returning temporary
@@ -120,76 +128,59 @@ var hopHeaders = []string{
"Upgrade",
}
-type requestCanceler interface {
- CancelRequest(*http.Request)
-}
-
-type runOnFirstRead struct {
- io.Reader // optional; nil means empty body
-
- fn func() // Run before first Read, then set to nil
-}
-
-func (c *runOnFirstRead) Read(bs []byte) (int, error) {
- if c.fn != nil {
- c.fn()
- c.fn = nil
- }
- if c.Reader == nil {
- return 0, io.EOF
- }
- return c.Reader.Read(bs)
-}
-
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
+ ctx := req.Context()
+ if cn, ok := rw.(http.CloseNotifier); ok {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithCancel(ctx)
+ defer cancel()
+ notifyChan := cn.CloseNotify()
+ go func() {
+ select {
+ case <-notifyChan:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+ }
+
outreq := new(http.Request)
*outreq = *req // includes shallow copies of maps, but okay
-
- if closeNotifier, ok := rw.(http.CloseNotifier); ok {
- if requestCanceler, ok := transport.(requestCanceler); ok {
- reqDone := make(chan struct{})
- defer close(reqDone)
-
- clientGone := closeNotifier.CloseNotify()
-
- outreq.Body = struct {
- io.Reader
- io.Closer
- }{
- Reader: &runOnFirstRead{
- Reader: outreq.Body,
- fn: func() {
- go func() {
- select {
- case <-clientGone:
- requestCanceler.CancelRequest(outreq)
- case <-reqDone:
- }
- }()
- },
- },
- Closer: outreq.Body,
- }
- }
+ if req.ContentLength == 0 {
+ outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
}
+ outreq = outreq.WithContext(ctx)
p.Director(outreq)
- outreq.Proto = "HTTP/1.1"
- outreq.ProtoMajor = 1
- outreq.ProtoMinor = 1
outreq.Close = false
- // Remove hop-by-hop headers to the backend. Especially
- // important is "Connection" because we want a persistent
- // connection, regardless of what the client sent to us. This
- // is modifying the same underlying map from req (shallow
+ // We are modifying the same underlying map from req (shallow
// copied above) so we only copy it if necessary.
copiedHeaders := false
+
+ // Remove hop-by-hop headers listed in the "Connection" header.
+ // See RFC 2616, section 14.10.
+ if c := outreq.Header.Get("Connection"); c != "" {
+ for _, f := range strings.Split(c, ",") {
+ if f = strings.TrimSpace(f); f != "" {
+ if !copiedHeaders {
+ outreq.Header = make(http.Header)
+ copyHeader(outreq.Header, req.Header)
+ copiedHeaders = true
+ }
+ outreq.Header.Del(f)
+ }
+ }
+ }
+
+ // Remove hop-by-hop headers to the backend. Especially
+ // important is "Connection" because we want a persistent
+ // connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" {
if !copiedHeaders {
@@ -218,16 +209,34 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
+ // Remove hop-by-hop headers listed in the
+ // "Connection" header of the response.
+ if c := res.Header.Get("Connection"); c != "" {
+ for _, f := range strings.Split(c, ",") {
+ if f = strings.TrimSpace(f); f != "" {
+ res.Header.Del(f)
+ }
+ }
+ }
+
for _, h := range hopHeaders {
res.Header.Del(h)
}
+ if p.ModifyResponse != nil {
+ if err := p.ModifyResponse(res); err != nil {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusBadGateway)
+ return
+ }
+ }
+
copyHeader(rw.Header(), res.Header)
// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
if len(res.Trailer) > 0 {
- var trailerKeys []string
+ trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
trailerKeys = append(trailerKeys, k)
}
@@ -266,12 +275,40 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
if p.BufferPool != nil {
buf = p.BufferPool.Get()
}
- io.CopyBuffer(dst, src, buf)
+ p.copyBuffer(dst, src, buf)
if p.BufferPool != nil {
p.BufferPool.Put(buf)
}
}
+func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
+ if len(buf) == 0 {
+ buf = make([]byte, 32*1024)
+ }
+ var written int64
+ for {
+ nr, rerr := src.Read(buf)
+ if rerr != nil && rerr != io.EOF {
+ p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
+ }
+ if nr > 0 {
+ nw, werr := dst.Write(buf[:nr])
+ if nw > 0 {
+ written += int64(nw)
+ }
+ if werr != nil {
+ return written, werr
+ }
+ if nr != nw {
+ return written, io.ErrShortWrite
+ }
+ }
+ if rerr != nil {
+ return written, rerr
+ }
+ }
+}
+
func (p *ReverseProxy) logf(format string, args ...interface{}) {
if p.ErrorLog != nil {
p.ErrorLog.Printf(format, args...)
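
A minimal sketch of the new ModifyResponse hook; the backend address, listen address, and header name are placeholders:

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

func main() {
	backend, err := url.Parse("http://127.0.0.1:8081")
	if err != nil {
		log.Fatal(err)
	}
	proxy := httputil.NewSingleHostReverseProxy(backend)
	// ModifyResponse runs on every backend response; returning a non-nil
	// error makes the proxy reply with 502 StatusBadGateway instead.
	proxy.ModifyResponse = func(res *http.Response) error {
		res.Header.Set("X-Proxy", "reverseproxy-example")
		return nil
	}
	log.Fatal(http.ListenAndServe(":8080", proxy))
}
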
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go
index fe7cdb8..20c4e16 100644
--- a/libgo/go/net/http/httputil/reverseproxy_test.go
+++ b/libgo/go/net/http/httputil/reverseproxy_test.go
@@ -9,6 +9,8 @@ package httputil
import (
"bufio"
"bytes"
+ "errors"
+ "fmt"
"io"
"io/ioutil"
"log"
@@ -135,6 +137,61 @@ func TestReverseProxy(t *testing.T) {
}
+// Issue 16875: remove any proxied headers mentioned in the "Connection"
+// header value.
+func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) {
+ const fakeConnectionToken = "X-Fake-Connection-Token"
+ const backendResponse = "I am the backend"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if c := r.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+ if c := r.Header.Get("Upgrade"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
+ }
+ w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken)
+ w.Header().Set("Upgrade", "should be deleted")
+ w.Header().Set(fakeConnectionToken, "should be deleted")
+ io.WriteString(w, backendResponse)
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ proxyHandler.ServeHTTP(w, r)
+ if c := r.Header.Get("Upgrade"); c != "original value" {
+ t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value")
+ }
+ }))
+ defer frontend.Close()
+
+ getReq, _ := http.NewRequest("GET", frontend.URL, nil)
+ getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken)
+ getReq.Header.Set("Upgrade", "original value")
+ getReq.Header.Set(fakeConnectionToken, "should be deleted")
+ res, err := http.DefaultClient.Do(getReq)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+ bodyBytes, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if got, want := string(bodyBytes), backendResponse; got != want {
+ t.Errorf("got body %q; want %q", got, want)
+ }
+ if c := res.Header.Get("Upgrade"); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", "Upgrade", c)
+ }
+ if c := res.Header.Get(fakeConnectionToken); c != "" {
+ t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c)
+ }
+}
+
func TestXForwardedFor(t *testing.T) {
const prevForwardedFor = "client ip"
const backendResponse = "I am the backend"
@@ -260,14 +317,14 @@ func TestReverseProxyCancelation(t *testing.T) {
reqInFlight := make(chan struct{})
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- close(reqInFlight)
+ close(reqInFlight) // cause the client to cancel its request
select {
case <-time.After(10 * time.Second):
// Note: this should only happen in broken implementations, and the
// closenotify case should be instantaneous.
- t.Log("Failed to close backend connection")
- t.Fail()
+ t.Error("Handler never saw CloseNotify")
+ return
case <-w.(http.CloseNotifier).CloseNotify():
}
@@ -300,13 +357,13 @@ func TestReverseProxyCancelation(t *testing.T) {
}()
res, err := http.DefaultClient.Do(getReq)
if res != nil {
- t.Fatal("Non-nil response")
+ t.Errorf("got response %v; want nil", res.Status)
}
if err == nil {
// This should be an error like:
// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
// use of closed network connection
- t.Fatal("DefaultClient.Do() returned nil error")
+ t.Error("DefaultClient.Do() returned nil error; want non-nil error")
}
}
@@ -495,3 +552,115 @@ func TestReverseProxy_Post(t *testing.T) {
t.Errorf("got body %q; expected %q", g, e)
}
}
+
+type RoundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return fn(req)
+}
+
+// Issue 16036: send a Request with a nil Body when possible
+func TestReverseProxy_NilBody(t *testing.T) {
+ backendURL, _ := url.Parse("http://fake.tld/")
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+ proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
+ if req.Body != nil {
+ t.Error("Body != nil; want a nil Body")
+ }
+ return nil, errors.New("done testing the interesting part; so force a 502 Gateway error")
+ })
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ res, err := http.DefaultClient.Get(frontend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != 502 {
+ t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status)
+ }
+}
+
+// Issue 14237. Test ModifyResponse and that an error from it
+// causes the proxy to return StatusBadGateway, or StatusOK otherwise.
+func TestReverseProxyModifyResponse(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod"))
+ }))
+ defer backendServer.Close()
+
+ rpURL, _ := url.Parse(backendServer.URL)
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+ rproxy.ModifyResponse = func(resp *http.Response) error {
+ if resp.Header.Get("X-Hit-Mod") != "true" {
+ return fmt.Errorf("tried to by-pass proxy")
+ }
+ return nil
+ }
+
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ tests := []struct {
+ url string
+ wantCode int
+ }{
+ {frontendProxy.URL + "/mod", http.StatusOK},
+ {frontendProxy.URL + "/schedule", http.StatusBadGateway},
+ }
+
+ for i, tt := range tests {
+ resp, err := http.Get(tt.url)
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e)
+ }
+ resp.Body.Close()
+ }
+}
+
+// Issue 16659: log errors from short read
+func TestReverseProxy_CopyBuffer(t *testing.T) {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.UnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var proxyLog bytes.Buffer
+ rproxy := NewSingleHostReverseProxy(rpURL)
+ rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL)
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ defer resp.Body.Close()
+
+ if _, err := ioutil.ReadAll(resp.Body); err == nil {
+ t.Fatalf("want non-nil error")
+ }
+ expected := []string{
+ "EOF",
+ "read",
+ }
+ for _, phrase := range expected {
+ if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) {
+ t.Errorf("expected log to contain phrase %q", phrase)
+ }
+ }
+}
diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go
index 2e62c00..63f321d 100644
--- a/libgo/go/net/http/internal/chunked.go
+++ b/libgo/go/net/http/internal/chunked.go
@@ -35,10 +35,11 @@ func NewChunkedReader(r io.Reader) io.Reader {
}
type chunkedReader struct {
- r *bufio.Reader
- n uint64 // unread bytes in chunk
- err error
- buf [2]byte
+ r *bufio.Reader
+ n uint64 // unread bytes in chunk
+ err error
+ buf [2]byte
+ checkEnd bool // whether to check for \r\n chunk footer
}
func (cr *chunkedReader) beginChunk() {
@@ -68,6 +69,21 @@ func (cr *chunkedReader) chunkHeaderAvailable() bool {
func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
for cr.err == nil {
+ if cr.checkEnd {
+ if n > 0 && cr.r.Buffered() < 2 {
+ // We have some data. Return early (per the io.Reader
+ // contract) instead of potentially blocking while
+ // reading more.
+ break
+ }
+ if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
+ if string(cr.buf[:]) != "\r\n" {
+ cr.err = errors.New("malformed chunked encoding")
+ break
+ }
+ }
+ cr.checkEnd = false
+ }
if cr.n == 0 {
if n > 0 && !cr.chunkHeaderAvailable() {
// We've read enough. Don't potentially block
@@ -92,11 +108,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// If we're at the end of a chunk, read the next two
// bytes to verify they are "\r\n".
if cr.n == 0 && cr.err == nil {
- if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil {
- if cr.buf[0] != '\r' || cr.buf[1] != '\n' {
- cr.err = errors.New("malformed chunked encoding")
- }
- }
+ cr.checkEnd = true
}
}
return n, cr.err
diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go
index 9abe1ab..d067165 100644
--- a/libgo/go/net/http/internal/chunked_test.go
+++ b/libgo/go/net/http/internal/chunked_test.go
@@ -185,3 +185,30 @@ func TestChunkReadingIgnoresExtensions(t *testing.T) {
t.Errorf("read %q; want %q", g, e)
}
}
+
+// Issue 17355: ChunkedReader shouldn't block waiting for more data
+// if it can return something.
+func TestChunkReadPartial(t *testing.T) {
+ pr, pw := io.Pipe()
+ go func() {
+ pw.Write([]byte("7\r\n1234567"))
+ }()
+ cr := NewChunkedReader(pr)
+ readBuf := make([]byte, 7)
+ n, err := cr.Read(readBuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := "1234567"
+ if n != 7 || string(readBuf) != want {
+ t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want)
+ }
+ go func() {
+ pw.Write([]byte("xx"))
+ }()
+ _, err = cr.Read(readBuf)
+ if got := fmt.Sprint(err); !strings.Contains(got, "malformed") {
+ t.Fatalf("second read = %v; want malformed error", err)
+ }
+
+}
diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go
index aea6e12..438bd2e 100644
--- a/libgo/go/net/http/main_test.go
+++ b/libgo/go/net/http/main_test.go
@@ -6,6 +6,8 @@ package http_test
import (
"fmt"
+ "io/ioutil"
+ "log"
"net/http"
"os"
"runtime"
@@ -15,6 +17,8 @@ import (
"time"
)
+var quietLog = log.New(ioutil.Discard, "", 0)
+
func TestMain(m *testing.M) {
v := m.Run()
if v == 0 && goroutineLeaked() {
@@ -134,3 +138,20 @@ func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
}
return false
}
+
+// waitErrCondition is like waitCondition but with errors instead of bools.
+func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error {
+ deadline := time.Now().Add(waitFor)
+ var err error
+ for time.Now().Before(deadline) {
+ if err = fn(); err == nil {
+ return nil
+ }
+ time.Sleep(checkEvery)
+ }
+ return err
+}
+
+func closeClient(c *http.Client) {
+ c.Transport.(*http.Transport).CloseIdleConnections()
+}
diff --git a/libgo/go/net/http/npn_test.go b/libgo/go/net/http/npn_test.go
index e2e911d..4c1f6b5 100644
--- a/libgo/go/net/http/npn_test.go
+++ b/libgo/go/net/http/npn_test.go
@@ -18,6 +18,7 @@ import (
)
func TestNextProtoUpgrade(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
diff --git a/libgo/go/net/http/range_test.go b/libgo/go/net/http/range_test.go
index ef911af..114987e 100644
--- a/libgo/go/net/http/range_test.go
+++ b/libgo/go/net/http/range_test.go
@@ -38,7 +38,7 @@ var ParseRangeTests = []struct {
{"bytes=0-", 10, []httpRange{{0, 10}}},
{"bytes=5-", 10, []httpRange{{5, 5}}},
{"bytes=0-20", 10, []httpRange{{0, 10}}},
- {"bytes=15-,0-5", 10, nil},
+ {"bytes=15-,0-5", 10, []httpRange{{0, 6}}},
{"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}},
{"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}},
{"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}},
diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go
index 4bf646b..28a148b 100644
--- a/libgo/go/net/http/readrequest_test.go
+++ b/libgo/go/net/http/readrequest_test.go
@@ -25,7 +25,7 @@ type reqTest struct {
}
var noError = ""
-var noBody = ""
+var noBodyStr = ""
var noTrailer Header = nil
var reqTests = []reqTest{
@@ -95,7 +95,7 @@ var reqTests = []reqTest{
RequestURI: "/",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -121,7 +121,7 @@ var reqTests = []reqTest{
RequestURI: "//user@host/is/actually/a/path/",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -131,7 +131,7 @@ var reqTests = []reqTest{
"GET ../../../../etc/passwd HTTP/1.1\r\n" +
"Host: test\r\n\r\n",
nil,
- noBody,
+ noBodyStr,
noTrailer,
"parse ../../../../etc/passwd: invalid URI for request",
},
@@ -141,7 +141,7 @@ var reqTests = []reqTest{
"GET HTTP/1.1\r\n" +
"Host: test\r\n\r\n",
nil,
- noBody,
+ noBodyStr,
noTrailer,
"parse : empty url",
},
@@ -227,7 +227,7 @@ var reqTests = []reqTest{
RequestURI: "www.google.com:443",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -251,7 +251,7 @@ var reqTests = []reqTest{
RequestURI: "127.0.0.1:6060",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -275,7 +275,7 @@ var reqTests = []reqTest{
RequestURI: "/_goRPC_",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -299,7 +299,7 @@ var reqTests = []reqTest{
RequestURI: "*",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -323,7 +323,7 @@ var reqTests = []reqTest{
RequestURI: "*",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -350,7 +350,7 @@ var reqTests = []reqTest{
RequestURI: "/",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -376,7 +376,7 @@ var reqTests = []reqTest{
RequestURI: "/",
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
@@ -397,7 +397,7 @@ var reqTests = []reqTest{
ContentLength: -1,
Close: true,
},
- noBody,
+ noBodyStr,
noTrailer,
noError,
},
diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go
index dc55592..fb6bb0a 100644
--- a/libgo/go/net/http/request.go
+++ b/libgo/go/net/http/request.go
@@ -18,12 +18,17 @@ import (
"io/ioutil"
"mime"
"mime/multipart"
+ "net"
"net/http/httptrace"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
+
+ "golang_org/x/net/idna"
+ "golang_org/x/text/unicode/norm"
+ "golang_org/x/text/width"
)
const (
@@ -34,21 +39,40 @@ const (
// is either not present in the request or not a file field.
var ErrMissingFile = errors.New("http: no such file")
-// HTTP request parsing errors.
+// ProtocolError represents an HTTP protocol error.
+//
+// Deprecated: Not all errors in the http package related to protocol errors
+// are of type ProtocolError.
type ProtocolError struct {
ErrorString string
}
-func (err *ProtocolError) Error() string { return err.ErrorString }
+func (pe *ProtocolError) Error() string { return pe.ErrorString }
var (
- ErrHeaderTooLong = &ProtocolError{"header too long"}
- ErrShortBody = &ProtocolError{"entity body too short"}
- ErrNotSupported = &ProtocolError{"feature not supported"}
- ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
+ // ErrNotSupported is returned by the Push method of Pusher
+ // implementations to indicate that HTTP/2 Push support is not
+ // available.
+ ErrNotSupported = &ProtocolError{"feature not supported"}
+
+ // ErrUnexpectedTrailer is returned by the Transport when a server
+ // replies with a Trailer header, but without a chunked reply.
+ ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"}
+
+ // ErrMissingBoundary is returned by Request.MultipartReader when the
+ // request's Content-Type does not include a "boundary" parameter.
+ ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"}
+
+ // ErrNotMultipart is returned by Request.MultipartReader when the
+ // request's Content-Type is not multipart/form-data.
+ ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
+
+ // Deprecated: ErrHeaderTooLong is not used.
+ ErrHeaderTooLong = &ProtocolError{"header too long"}
+ // Deprecated: ErrShortBody is not used.
+ ErrShortBody = &ProtocolError{"entity body too short"}
+ // Deprecated: ErrMissingContentLength is not used.
ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
- ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
- ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"}
)
type badStringError struct {
@@ -146,11 +170,20 @@ type Request struct {
// Handler does not need to.
Body io.ReadCloser
+ // GetBody defines an optional func to return a new copy of
+ // Body. It is used for client requests when a redirect requires
+ // reading the body more than once. Use of GetBody still
+ // requires setting Body.
+ //
+ // For server requests it is unused.
+ GetBody func() (io.ReadCloser, error)
+
// ContentLength records the length of the associated content.
// The value -1 indicates that the length is unknown.
// Values >= 0 indicate that the given number of bytes may
// be read from Body.
- // For client requests, a value of 0 means unknown if Body is not nil.
+ // For client requests, a value of 0 with a non-nil Body is
+ // also treated as unknown.
ContentLength int64
// TransferEncoding lists the transfer encodings from outermost to
@@ -175,11 +208,15 @@ type Request struct {
// For server requests Host specifies the host on which the
// URL is sought. Per RFC 2616, this is either the value of
// the "Host" header or the host name given in the URL itself.
- // It may be of the form "host:port".
+ // It may be of the form "host:port". For international domain
+ // names, Host may be in Punycode or Unicode form. Use
+ // golang.org/x/net/idna to convert it to either format if
+ // needed.
//
// For client requests Host optionally overrides the Host
// header to send. If empty, the Request.Write method uses
- // the value of URL.Host.
+ // the value of URL.Host. Host may contain an international
+ // domain name.
Host string
// Form contains the parsed form data, including both the URL
@@ -276,8 +313,8 @@ type Request struct {
// For outgoing client requests, the context controls cancelation.
//
// For incoming server requests, the context is canceled when the
-// ServeHTTP method returns. For its associated values, see
-// ServerContextKey and LocalAddrContextKey.
+// client's connection closes, the request is canceled (with HTTP/2),
+// or when the ServeHTTP method returns.
func (r *Request) Context() context.Context {
if r.ctx != nil {
return r.ctx
@@ -304,6 +341,18 @@ func (r *Request) ProtoAtLeast(major, minor int) bool {
r.ProtoMajor == major && r.ProtoMinor >= minor
}
+// protoAtLeastOutgoing is like ProtoAtLeast, but is for outgoing
+// requests (see issue 18407) where these fields aren't supposed to
+// matter. As a minor fix for Go 1.8, at least treat (0, 0) as
+// matching HTTP/1.1 or HTTP/1.0. Only HTTP/1.1 is used.
+// TODO(bradfitz): ideally remove this whole method. It shouldn't be used.
+func (r *Request) protoAtLeastOutgoing(major, minor int) bool {
+ if r.ProtoMajor == 0 && r.ProtoMinor == 0 && major == 1 && minor <= 1 {
+ return true
+ }
+ return r.ProtoAtLeast(major, minor)
+}
+
// UserAgent returns the client's User-Agent, if sent in the request.
func (r *Request) UserAgent() string {
return r.Header.Get("User-Agent")
@@ -319,6 +368,8 @@ var ErrNoCookie = errors.New("http: named cookie not present")
// Cookie returns the named cookie provided in the request or
// ErrNoCookie if not found.
+// If multiple cookies match the given name, only one cookie will
+// be returned.
func (r *Request) Cookie(name string) (*Cookie, error) {
for _, c := range readCookies(r.Header, name) {
return c, nil
@@ -561,6 +612,12 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
}
}
+ if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders {
+ if err := bw.Flush(); err != nil {
+ return err
+ }
+ }
+
// Write body and trailer
err = tw.WriteBody(w)
if err != nil {
@@ -573,7 +630,24 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
return nil
}
-// cleanHost strips anything after '/' or ' '.
+func idnaASCII(v string) (string, error) {
+ if isASCII(v) {
+ return v, nil
+ }
+ // The idna package doesn't do everything from
+ // https://tools.ietf.org/html/rfc5895 so we do it here.
+ // TODO(bradfitz): should the idna package do this instead?
+ v = strings.ToLower(v)
+ v = width.Fold.String(v)
+ v = norm.NFC.String(v)
+ return idna.ToASCII(v)
+}
+
+// cleanHost cleans up the host sent in request's Host header.
+//
+// It both strips anything after '/' or ' ', and puts the value
+// into Punycode form, if necessary.
+//
// Ideally we'd clean the Host header according to the spec:
// https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]")
// https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host)
@@ -584,9 +658,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
// first offending character.
func cleanHost(in string) string {
if i := strings.IndexAny(in, " /"); i != -1 {
- return in[:i]
+ in = in[:i]
+ }
+ host, port, err := net.SplitHostPort(in)
+ if err != nil { // input was just a host
+ a, err := idnaASCII(in)
+ if err != nil {
+ return in // garbage in, garbage out
+ }
+ return a
}
- return in
+ a, err := idnaASCII(host)
+ if err != nil {
+ return in // garbage in, garbage out
+ }
+ return net.JoinHostPort(a, port)
}
// removeZone removes IPv6 zone identifier from host.
@@ -658,11 +744,17 @@ func validMethod(method string) bool {
// methods Do, Post, and PostForm, and Transport.RoundTrip.
//
// NewRequest returns a Request suitable for use with Client.Do or
-// Transport.RoundTrip.
-// To create a request for use with testing a Server Handler use either
-// ReadRequest or manually update the Request fields. See the Request
-// type's documentation for the difference between inbound and outbound
-// request fields.
+// Transport.RoundTrip. To create a request for use with testing a
+// Server Handler, either use the NewRequest function in the
+// net/http/httptest package, use ReadRequest, or manually update the
+// Request fields. See the Request type's documentation for the
+// difference between inbound and outbound request fields.
+//
+// If body is of type *bytes.Buffer, *bytes.Reader, or
+// *strings.Reader, the returned request's ContentLength is set to its
+// exact value (instead of -1), GetBody is populated (so 307 and 308
+// redirects can replay the body), and Body is set to NoBody if the
+// ContentLength is 0.
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
if method == "" {
// We document that "" means "GET" for Request.Method, and people have
@@ -697,10 +789,43 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
switch v := body.(type) {
case *bytes.Buffer:
req.ContentLength = int64(v.Len())
+ buf := v.Bytes()
+ req.GetBody = func() (io.ReadCloser, error) {
+ r := bytes.NewReader(buf)
+ return ioutil.NopCloser(r), nil
+ }
case *bytes.Reader:
req.ContentLength = int64(v.Len())
+ snapshot := *v
+ req.GetBody = func() (io.ReadCloser, error) {
+ r := snapshot
+ return ioutil.NopCloser(&r), nil
+ }
case *strings.Reader:
req.ContentLength = int64(v.Len())
+ snapshot := *v
+ req.GetBody = func() (io.ReadCloser, error) {
+ r := snapshot
+ return ioutil.NopCloser(&r), nil
+ }
+ default:
+ // This is where we'd set it to -1 (at least
+ // if body != NoBody) to mean unknown, but
+ // that broke people during the Go 1.8 testing
+ // period. People depend on it being 0 I
+ // guess. Maybe retry later. See Issue 18117.
+ }
+ // For client requests, Request.ContentLength of 0
+ // means either actually 0, or unknown. The only way
+ // to explicitly say that the ContentLength is zero is
+ // to set the Body to nil. But turns out too much code
+ // depends on NewRequest returning a non-nil Body,
+ // so we use a well-known ReadCloser variable instead
+ // and have the http package also treat that sentinel
+ // variable to mean explicitly zero.
+ if req.GetBody != nil && req.ContentLength == 0 {
+ req.Body = NoBody
+ req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil }
}
}
@@ -1000,18 +1125,24 @@ func parsePostForm(r *Request) (vs url.Values, err error) {
return
}
-// ParseForm parses the raw query from the URL and updates r.Form.
+// ParseForm populates r.Form and r.PostForm.
+//
+// For all requests, ParseForm parses the raw query from the URL and updates
+// r.Form.
+//
+// For POST, PUT, and PATCH requests, it also parses the request body as a form
+// and puts the results into both r.PostForm and r.Form. Request body parameters
+// take precedence over URL query string values in r.Form.
//
-// For POST or PUT requests, it also parses the request body as a form and
-// put the results into both r.PostForm and r.Form.
-// POST and PUT body parameters take precedence over URL query string values
-// in r.Form.
+// For other HTTP methods, or when the Content-Type is not
+// application/x-www-form-urlencoded, the request Body is not read, and
+// r.PostForm is initialized to a non-nil, empty value.
//
// If the request Body's size has not already been limited by MaxBytesReader,
// the size is capped at 10MB.
//
// ParseMultipartForm calls ParseForm automatically.
-// It is idempotent.
+// ParseForm is idempotent.
func (r *Request) ParseForm() error {
var err error
if r.PostForm == nil {
@@ -1174,3 +1305,30 @@ func (r *Request) isReplayable() bool {
}
return false
}
+
+// outgoingLength reports the Content-Length of this outgoing (Client) request.
+// It maps 0 into -1 (unknown) when the Body is non-nil.
+func (r *Request) outgoingLength() int64 {
+ if r.Body == nil || r.Body == NoBody {
+ return 0
+ }
+ if r.ContentLength != 0 {
+ return r.ContentLength
+ }
+ return -1
+}
+
+// requestMethodUsuallyLacksBody reports whether the given request
+// method is one that typically does not involve a request body.
+// This is used by the Transport (via
+// transferWriter.shouldSendChunkedRequestBody) to determine whether
+// we try to test-read a byte from a non-nil Request.Body when
+// Request.outgoingLength() returns -1. See the comments in
+// shouldSendChunkedRequestBody.
+func requestMethodUsuallyLacksBody(method string) bool {
+ switch method {
+ case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH":
+ return true
+ }
+ return false
+}
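
A minimal sketch of the new GetBody/ContentLength behavior documented in NewRequest above, using a *strings.Reader body and a placeholder URL:

package main

import (
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strings"
)

func main() {
	req, err := http.NewRequest("POST", "http://example.com/upload", strings.NewReader("payload"))
	if err != nil {
		log.Fatal(err)
	}
	// For *strings.Reader, *bytes.Reader, and *bytes.Buffer bodies,
	// NewRequest sets ContentLength exactly and populates GetBody so
	// the client can replay the body on 307/308 redirects.
	fmt.Println("ContentLength:", req.ContentLength)

	body, err := req.GetBody()
	if err != nil {
		log.Fatal(err)
	}
	b, _ := ioutil.ReadAll(body)
	fmt.Printf("replayed body: %q\n", b)
}
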
diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go
index a4c88c0..e674837 100644
--- a/libgo/go/net/http/request_test.go
+++ b/libgo/go/net/http/request_test.go
@@ -29,9 +29,9 @@ func TestQuery(t *testing.T) {
}
}
-func TestPostQuery(t *testing.T) {
- req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
- strings.NewReader("z=post&both=y&prio=2&empty="))
+func TestParseFormQuery(t *testing.T) {
+ req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&orphan=nope&empty=not",
+ strings.NewReader("z=post&both=y&prio=2&=nokey&orphan;empty=&"))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
if q := req.FormValue("q"); q != "foo" {
@@ -55,39 +55,30 @@ func TestPostQuery(t *testing.T) {
if prio := req.FormValue("prio"); prio != "2" {
t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
}
- if empty := req.FormValue("empty"); empty != "" {
+ if orphan := req.Form["orphan"]; !reflect.DeepEqual(orphan, []string{"", "nope"}) {
+ t.Errorf(`req.FormValue("orphan") = %q, want "" (from body)`, orphan)
+ }
+ if empty := req.Form["empty"]; !reflect.DeepEqual(empty, []string{"", "not"}) {
t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
}
+ if nokey := req.Form[""]; !reflect.DeepEqual(nokey, []string{"nokey"}) {
+ t.Errorf(`req.FormValue("nokey") = %q, want "nokey" (from body)`, nokey)
+ }
}
-func TestPatchQuery(t *testing.T) {
- req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not",
- strings.NewReader("z=post&both=y&prio=2&empty="))
- req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
-
- if q := req.FormValue("q"); q != "foo" {
- t.Errorf(`req.FormValue("q") = %q, want "foo"`, q)
- }
- if z := req.FormValue("z"); z != "post" {
- t.Errorf(`req.FormValue("z") = %q, want "post"`, z)
- }
- if bq, found := req.PostForm["q"]; found {
- t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq)
- }
- if bz := req.PostFormValue("z"); bz != "post" {
- t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz)
- }
- if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) {
- t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs)
- }
- if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) {
- t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both)
- }
- if prio := req.FormValue("prio"); prio != "2" {
- t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio)
- }
- if empty := req.FormValue("empty"); empty != "" {
- t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty)
+// Tests that we only parse the form automatically for certain methods.
+func TestParseFormQueryMethods(t *testing.T) {
+ for _, method := range []string{"POST", "PATCH", "PUT", "FOO"} {
+ req, _ := NewRequest(method, "http://www.google.com/search",
+ strings.NewReader("foo=bar"))
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
+ want := "bar"
+ if method == "FOO" {
+ want = ""
+ }
+ if got := req.FormValue("foo"); got != want {
+ t.Errorf(`for method %s, FormValue("foo") = %q; want %q`, method, got, want)
+ }
}
}
@@ -374,18 +365,68 @@ func TestFormFileOrder(t *testing.T) {
var readRequestErrorTests = []struct {
in string
- err error
+ err string
+
+ header Header
}{
- {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", nil},
- {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF},
- {"", io.EOF},
+ 0: {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", "", Header{"Header": {"foo"}}},
+ 1: {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF.Error(), nil},
+ 2: {"", io.EOF.Error(), nil},
+ 3: {
+ in: "HEAD / HTTP/1.1\r\nContent-Length:4\r\n\r\n",
+ err: "http: method cannot contain a Content-Length",
+ },
+ 4: {
+ in: "HEAD / HTTP/1.1\r\n\r\n",
+ header: Header{},
+ },
+
+ // Multiple Content-Length values should either be
+ // deduplicated if same or reject otherwise
+ // See Issue 16490.
+ 5: {
+ in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 0\r\n\r\nGopher hey\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 6: {
+ in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 6\r\n\r\nGopher\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 7: {
+ in: "PUT / HTTP/1.1\r\nContent-Length: 6 \r\nContent-Length: 6\r\nContent-Length:6\r\n\r\nGopher\r\n",
+ err: "",
+ header: Header{"Content-Length": {"6"}},
+ },
+ 8: {
+ in: "PUT / HTTP/1.1\r\nContent-Length: 1\r\nContent-Length: 6 \r\n\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 9: {
+ in: "POST / HTTP/1.1\r\nContent-Length:\r\nContent-Length: 3\r\n\r\n",
+ err: "cannot contain multiple Content-Length headers",
+ },
+ 10: {
+ in: "HEAD / HTTP/1.1\r\nContent-Length:0\r\nContent-Length: 0\r\n\r\n",
+ header: Header{"Content-Length": {"0"}},
+ },
}
func TestReadRequestErrors(t *testing.T) {
for i, tt := range readRequestErrorTests {
- _, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in)))
- if err != tt.err {
- t.Errorf("%d. got error = %v; want %v", i, err, tt.err)
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in)))
+ if err == nil {
+ if tt.err != "" {
+ t.Errorf("#%d: got nil err; want %q", i, tt.err)
+ }
+
+ if !reflect.DeepEqual(tt.header, req.Header) {
+ t.Errorf("#%d: gotHeader: %q wantHeader: %q", i, req.Header, tt.header)
+ }
+ continue
+ }
+
+ if tt.err == "" || !strings.Contains(err.Error(), tt.err) {
+ t.Errorf("%d: got error = %v; want %v", i, err, tt.err)
}
}
}
@@ -456,18 +497,23 @@ func TestNewRequestContentLength(t *testing.T) {
{bytes.NewReader([]byte("123")), 3},
{bytes.NewBuffer([]byte("1234")), 4},
{strings.NewReader("12345"), 5},
- // Not detected:
+ {strings.NewReader(""), 0},
+ {NoBody, 0},
+
+ // Not detected. During Go 1.8 we tried to make these set to -1, but
+ // due to Issue 18117, we keep these returning 0, even though they're
+ // unknown.
{struct{ io.Reader }{strings.NewReader("xyz")}, 0},
{io.NewSectionReader(strings.NewReader("x"), 0, 6), 0},
{readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0},
}
- for _, tt := range tests {
+ for i, tt := range tests {
req, err := NewRequest("POST", "http://localhost/", tt.r)
if err != nil {
t.Fatal(err)
}
if req.ContentLength != tt.want {
- t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want)
+ t.Errorf("test[%d]: ContentLength(%T) = %d; want %d", i, tt.r, req.ContentLength, tt.want)
}
}
}
@@ -626,11 +672,31 @@ func TestStarRequest(t *testing.T) {
if err != nil {
return
}
+ if req.ContentLength != 0 {
+ t.Errorf("ContentLength = %d; want 0", req.ContentLength)
+ }
+ if req.Body == nil {
+ t.Errorf("Body = nil; want non-nil")
+ }
+
+ // Request.Write has Client semantics for Body/ContentLength,
+ // where ContentLength 0 means unknown if Body is non-nil, and
+ // thus chunking will happen unless we change semantics and
+ // signal that we want to serialize it as exactly zero. The
+ // only way to do that for outbound requests is with a nil
+ // Body:
+ clientReq := *req
+ clientReq.Body = nil
+
var out bytes.Buffer
- if err := req.Write(&out); err != nil {
+ if err := clientReq.Write(&out); err != nil {
t.Fatal(err)
}
- back, err := ReadRequest(bufio.NewReader(&out))
+
+ if strings.Contains(out.String(), "chunked") {
+ t.Error("wrote chunked request; want no body")
+ }
+ back, err := ReadRequest(bufio.NewReader(bytes.NewReader(out.Bytes())))
if err != nil {
t.Fatal(err)
}
@@ -719,6 +785,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) {
}
}
+// verify that NewRequest sets Request.GetBody and that it works
+func TestNewRequestGetBody(t *testing.T) {
+ tests := []struct {
+ r io.Reader
+ }{
+ {r: strings.NewReader("hello")},
+ {r: bytes.NewReader([]byte("hello"))},
+ {r: bytes.NewBuffer([]byte("hello"))},
+ }
+ for i, tt := range tests {
+ req, err := NewRequest("POST", "http://foo.tld/", tt.r)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if req.Body == nil {
+ t.Errorf("test[%d]: Body = nil", i)
+ continue
+ }
+ if req.GetBody == nil {
+ t.Errorf("test[%d]: GetBody = nil", i)
+ continue
+ }
+ slurp1, err := ioutil.ReadAll(req.Body)
+ if err != nil {
+ t.Errorf("test[%d]: ReadAll(Body) = %v", i, err)
+ }
+ newBody, err := req.GetBody()
+ if err != nil {
+ t.Errorf("test[%d]: GetBody = %v", i, err)
+ }
+ slurp2, err := ioutil.ReadAll(newBody)
+ if err != nil {
+ t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err)
+ }
+ if string(slurp1) != string(slurp2) {
+ t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2)
+ }
+ }
+}
+
func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing")
if f != nil {
diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go
index 2545f6f..eb65b9f 100644
--- a/libgo/go/net/http/requestwrite_test.go
+++ b/libgo/go/net/http/requestwrite_test.go
@@ -5,14 +5,17 @@
package http
import (
+ "bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
+ "net"
"net/url"
"strings"
"testing"
+ "time"
)
type reqWriteTest struct {
@@ -28,7 +31,7 @@ type reqWriteTest struct {
var reqWriteTests = []reqWriteTest{
// HTTP/1.1 => chunked coding; no body; no trailer
- {
+ 0: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -75,7 +78,7 @@ var reqWriteTests = []reqWriteTest{
"Proxy-Connection: keep-alive\r\n\r\n",
},
// HTTP/1.1 => chunked coding; body; empty trailer
- {
+ 1: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -104,7 +107,7 @@ var reqWriteTests = []reqWriteTest{
chunk("abcdef") + chunk(""),
},
// HTTP/1.1 POST => chunked coding; body; empty trailer
- {
+ 2: {
Req: Request{
Method: "POST",
URL: &url.URL{
@@ -137,7 +140,7 @@ var reqWriteTests = []reqWriteTest{
},
// HTTP/1.1 POST with Content-Length, no chunking
- {
+ 3: {
Req: Request{
Method: "POST",
URL: &url.URL{
@@ -172,7 +175,7 @@ var reqWriteTests = []reqWriteTest{
},
// HTTP/1.1 POST with Content-Length in headers
- {
+ 4: {
Req: Request{
Method: "POST",
URL: mustParseURL("http://example.com/"),
@@ -201,7 +204,7 @@ var reqWriteTests = []reqWriteTest{
},
// default to HTTP/1.1
- {
+ 5: {
Req: Request{
Method: "GET",
URL: mustParseURL("/search"),
@@ -215,7 +218,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a 0 ContentLength and a 0 byte body.
- {
+ 6: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -227,9 +230,32 @@ var reqWriteTests = []reqWriteTest{
Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) },
- // RFC 2616 Section 14.13 says Content-Length should be specified
- // unless body is prohibited by the request method.
- // Also, nginx expects it for POST and PUT.
+ WantWrite: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n0\r\n\r\n",
+
+ WantProxy: "POST / HTTP/1.1\r\n" +
+ "Host: example.com\r\n" +
+ "User-Agent: Go-http-client/1.1\r\n" +
+ "Transfer-Encoding: chunked\r\n" +
+ "\r\n0\r\n\r\n",
+ },
+
+ // Request with a 0 ContentLength and a nil body.
+ 7: {
+ Req: Request{
+ Method: "POST",
+ URL: mustParseURL("/"),
+ Host: "example.com",
+ ProtoMajor: 1,
+ ProtoMinor: 1,
+ ContentLength: 0, // as if unset by user
+ },
+
+ Body: func() io.ReadCloser { return nil },
+
WantWrite: "POST / HTTP/1.1\r\n" +
"Host: example.com\r\n" +
"User-Agent: Go-http-client/1.1\r\n" +
@@ -244,7 +270,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a 0 ContentLength and a 1 byte body.
- {
+ 8: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -270,7 +296,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a ContentLength of 10 but a 5 byte body.
- {
+ 9: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -284,7 +310,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a ContentLength of 4 but an 8 byte body.
- {
+ 10: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -298,7 +324,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a 5 ContentLength and nil body.
- {
+ 11: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -311,7 +337,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a 0 ContentLength and a body with 1 byte content and an error.
- {
+ 12: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -331,7 +357,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with a 0 ContentLength and a body without content and an error.
- {
+ 13: {
Req: Request{
Method: "POST",
URL: mustParseURL("/"),
@@ -352,7 +378,7 @@ var reqWriteTests = []reqWriteTest{
// Verify that DumpRequest preserves the HTTP version number, doesn't add a Host,
// and doesn't add a User-Agent.
- {
+ 14: {
Req: Request{
Method: "GET",
URL: mustParseURL("/foo"),
@@ -373,7 +399,7 @@ var reqWriteTests = []reqWriteTest{
// an empty Host header, and don't use
// Request.Header["Host"]. This is just testing that
// we don't change Go 1.0 behavior.
- {
+ 15: {
Req: Request{
Method: "GET",
Host: "",
@@ -395,7 +421,7 @@ var reqWriteTests = []reqWriteTest{
},
// Opaque test #1 from golang.org/issue/4860
- {
+ 16: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -414,7 +440,7 @@ var reqWriteTests = []reqWriteTest{
},
// Opaque test #2 from golang.org/issue/4860
- {
+ 17: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -433,7 +459,7 @@ var reqWriteTests = []reqWriteTest{
},
// Testing custom case in header keys. Issue 5022.
- {
+ 18: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -457,7 +483,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with host header field; IPv6 address with zone identifier
- {
+ 19: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -472,7 +498,7 @@ var reqWriteTests = []reqWriteTest{
},
// Request with optional host header field; IPv6 address with zone identifier
- {
+ 20: {
Req: Request{
Method: "GET",
URL: &url.URL{
@@ -543,6 +569,138 @@ func TestRequestWrite(t *testing.T) {
}
}
+func TestRequestWriteTransport(t *testing.T) {
+ t.Parallel()
+
+ matchSubstr := func(substr string) func(string) error {
+ return func(written string) error {
+ if !strings.Contains(written, substr) {
+ return fmt.Errorf("expected substring %q in request: %s", substr, written)
+ }
+ return nil
+ }
+ }
+
+ noContentLengthOrTransferEncoding := func(req string) error {
+ if strings.Contains(req, "Content-Length: ") {
+ return fmt.Errorf("unexpected Content-Length in request: %s", req)
+ }
+ if strings.Contains(req, "Transfer-Encoding: ") {
+ return fmt.Errorf("unexpected Transfer-Encoding in request: %s", req)
+ }
+ return nil
+ }
+
+ all := func(checks ...func(string) error) func(string) error {
+ return func(req string) error {
+ for _, c := range checks {
+ if err := c(req); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }
+
+ type testCase struct {
+ method string
+ clen int64 // ContentLength
+ body io.ReadCloser
+ want func(string) error
+
+ // optional:
+ init func(*testCase)
+ afterReqRead func()
+ }
+
+ tests := []testCase{
+ {
+ method: "GET",
+ want: noContentLengthOrTransferEncoding,
+ },
+ {
+ method: "GET",
+ body: ioutil.NopCloser(strings.NewReader("")),
+ want: noContentLengthOrTransferEncoding,
+ },
+ {
+ method: "GET",
+ clen: -1,
+ body: ioutil.NopCloser(strings.NewReader("")),
+ want: noContentLengthOrTransferEncoding,
+ },
+ // A GET with a body, with explicit content length:
+ {
+ method: "GET",
+ clen: 7,
+ body: ioutil.NopCloser(strings.NewReader("foobody")),
+ want: all(matchSubstr("Content-Length: 7"),
+ matchSubstr("foobody")),
+ },
+ // A GET with a body, sniffing the leading "f" from "foobody".
+ {
+ method: "GET",
+ clen: -1,
+ body: ioutil.NopCloser(strings.NewReader("foobody")),
+ want: all(matchSubstr("Transfer-Encoding: chunked"),
+ matchSubstr("\r\n1\r\nf\r\n"),
+ matchSubstr("oobody")),
+ },
+ // But a POST request is expected to have a body, so
+ // no sniffing happens:
+ {
+ method: "POST",
+ clen: -1,
+ body: ioutil.NopCloser(strings.NewReader("foobody")),
+ want: all(matchSubstr("Transfer-Encoding: chunked"),
+ matchSubstr("foobody")),
+ },
+ {
+ method: "POST",
+ clen: -1,
+ body: ioutil.NopCloser(strings.NewReader("")),
+ want: all(matchSubstr("Transfer-Encoding: chunked")),
+ },
+ // Verify that a blocking Request.Body doesn't block forever.
+ {
+ method: "GET",
+ clen: -1,
+ init: func(tt *testCase) {
+ pr, pw := io.Pipe()
+ tt.afterReqRead = func() {
+ pw.Close()
+ }
+ tt.body = ioutil.NopCloser(pr)
+ },
+ want: matchSubstr("Transfer-Encoding: chunked"),
+ },
+ }
+
+ for i, tt := range tests {
+ if tt.init != nil {
+ tt.init(&tt)
+ }
+ req := &Request{
+ Method: tt.method,
+ URL: &url.URL{
+ Scheme: "http",
+ Host: "example.com",
+ },
+ Header: make(Header),
+ ContentLength: tt.clen,
+ Body: tt.body,
+ }
+ got, err := dumpRequestOut(req, tt.afterReqRead)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if err := tt.want(string(got)); err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ }
+ }
+}
+
type closeChecker struct {
io.Reader
closed bool
@@ -553,17 +711,19 @@ func (rc *closeChecker) Close() error {
return nil
}
-// TestRequestWriteClosesBody tests that Request.Write does close its request.Body.
+// TestRequestWriteClosesBody tests that Request.Write closes its request.Body.
// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer
// inside a NopCloser, and that it serializes it correctly.
func TestRequestWriteClosesBody(t *testing.T) {
rc := &closeChecker{Reader: strings.NewReader("my body")}
- req, _ := NewRequest("POST", "http://foo.com/", rc)
- if req.ContentLength != 0 {
- t.Errorf("got req.ContentLength %d, want 0", req.ContentLength)
+ req, err := NewRequest("POST", "http://foo.com/", rc)
+ if err != nil {
+ t.Fatal(err)
}
buf := new(bytes.Buffer)
- req.Write(buf)
+ if err := req.Write(buf); err != nil {
+ t.Error(err)
+ }
if !rc.closed {
t.Error("body not closed after write")
}
@@ -571,12 +731,7 @@ func TestRequestWriteClosesBody(t *testing.T) {
"Host: foo.com\r\n" +
"User-Agent: Go-http-client/1.1\r\n" +
"Transfer-Encoding: chunked\r\n\r\n" +
- // TODO: currently we don't buffer before chunking, so we get a
- // single "m" chunk before the other chunks, as this was the 1-byte
- // read from our MultiReader where we stitched the Body back together
- // after sniffing whether the Body was 0 bytes or not.
- chunk("m") +
- chunk("y body") +
+ chunk("my body") +
chunk("")
if buf.String() != expected {
t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected)
@@ -652,3 +807,76 @@ func TestRequestWriteError(t *testing.T) {
t.Fatalf("writeCalls constant is outdated in test")
}
}
+
+// dumpRequestOut is a modified copy of net/http/httputil.DumpRequestOut.
+// Unlike the original, this version doesn't mutate the req.Body and
+// try to restore it. It always dumps the whole body.
+// And it doesn't support https.
+func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) {
+
+ // Use the actual Transport code to record what we would send
+ // on the wire, but not using TCP. Use a Transport with a
+ // custom dialer that returns a fake net.Conn that waits
+	// for the full input (and records it), and then responds
+ // with a dummy response.
+ var buf bytes.Buffer // records the output
+ pr, pw := io.Pipe()
+ defer pr.Close()
+ defer pw.Close()
+ dr := &delegateReader{c: make(chan io.Reader)}
+
+ t := &Transport{
+ Dial: func(net, addr string) (net.Conn, error) {
+ return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil
+ },
+ }
+ defer t.CloseIdleConnections()
+
+ // Wait for the request before replying with a dummy response:
+ go func() {
+ req, err := ReadRequest(bufio.NewReader(pr))
+ if err == nil {
+ if onReadHeaders != nil {
+ onReadHeaders()
+ }
+ // Ensure all the body is read; otherwise
+ // we'll get a partial dump.
+ io.Copy(ioutil.Discard, req.Body)
+ req.Body.Close()
+ }
+ dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n")
+ }()
+
+ _, err := t.RoundTrip(req)
+ if err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+// delegateReader is a reader that delegates to another reader,
+// once it arrives on a channel.
+type delegateReader struct {
+ c chan io.Reader
+ r io.Reader // nil until received from c
+}
+
+func (r *delegateReader) Read(p []byte) (int, error) {
+ if r.r == nil {
+ r.r = <-r.c
+ }
+ return r.r.Read(p)
+}
+
+// dumpConn is a net.Conn that writes to Writer and reads from Reader.
+type dumpConn struct {
+ io.Writer
+ io.Reader
+}
+
+func (c *dumpConn) Close() error { return nil }
+func (c *dumpConn) LocalAddr() net.Addr { return nil }
+func (c *dumpConn) RemoteAddr() net.Addr { return nil }
+func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
+func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
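
As a rough illustration of the request-body behavior the new test above pins down (chunked encoding for a GET whose body length is unknown), the exported net/http/httputil.DumpRequestOut can show the same wire output outside the test harness; the URL and body below are illustrative only, not part of the patch.

    package main

    import (
        "fmt"
        "net/http"
        "net/http/httputil"
        "strings"
    )

    func main() {
        // GET with a non-empty body of unknown length: the transport probes
        // the body and falls back to chunked transfer encoding.
        req, err := http.NewRequest("GET", "http://example.com/", strings.NewReader("foobody"))
        if err != nil {
            panic(err)
        }
        req.ContentLength = -1 // treat the length as unknown, like clen: -1 above

        dump, err := httputil.DumpRequestOut(req, true)
        if err != nil {
            panic(err)
        }
        fmt.Printf("%q\n", dump) // expect "Transfer-Encoding: chunked" in the dump
    }
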
diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go
index 5450d50..ae118fb 100644
--- a/libgo/go/net/http/response.go
+++ b/libgo/go/net/http/response.go
@@ -261,7 +261,7 @@ func (r *Response) Write(w io.Writer) error {
if n == 0 {
// Reset it to a known zero reader, in case underlying one
// is unhappy being read repeatedly.
- r1.Body = eofReader
+ r1.Body = NoBody
} else {
r1.ContentLength = -1
r1.Body = struct {
@@ -300,7 +300,7 @@ func (r *Response) Write(w io.Writer) error {
// contentLengthAlreadySent may have been already sent for
// POST/PUT requests, even if zero length. See Issue 8180.
contentLengthAlreadySent := tw.shouldSendContentLength()
- if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent {
+ if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent && bodyAllowedForStatus(r.StatusCode) {
if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil {
return err
}
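
The new condition above defers to bodyAllowedForStatus, net/http's unexported helper for deciding whether a status code may carry a response body. For orientation only, a sketch of that rule (1xx informational, 204, and 304 carry no body) could look like:

    // bodyAllowedForStatus (sketch): reports whether a response with the
    // given status code is permitted to carry a body.
    func bodyAllowedForStatus(status int) bool {
        switch {
        case status >= 100 && status <= 199:
            return false // 1xx informational responses have no body
        case status == 204:
            return false // No Content
        case status == 304:
            return false // Not Modified
        }
        return true
    }
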
diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go
index 126da92..660d517 100644
--- a/libgo/go/net/http/response_test.go
+++ b/libgo/go/net/http/response_test.go
@@ -589,6 +589,7 @@ var readResponseCloseInMiddleTests = []struct {
// reading only part of its contents advances the read to the end of
// the request, right up until the next request.
func TestReadResponseCloseInMiddle(t *testing.T) {
+ t.Parallel()
for _, test := range readResponseCloseInMiddleTests {
fatalf := func(format string, args ...interface{}) {
args = append([]interface{}{test.chunked, test.compressed}, args...)
@@ -792,6 +793,7 @@ func TestReadResponseErrors(t *testing.T) {
type testCase struct {
name string // optional, defaults to in
in string
+ header Header
wantErr interface{} // nil, err value, or string substring
}
@@ -817,11 +819,22 @@ func TestReadResponseErrors(t *testing.T) {
}
}
+ contentLength := func(status, body string, wantErr interface{}, header Header) testCase {
+ return testCase{
+ name: fmt.Sprintf("status %q %q", status, body),
+ in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body),
+ wantErr: wantErr,
+ header: header,
+ }
+ }
+
+ errMultiCL := "message cannot contain multiple Content-Length headers"
+
tests := []testCase{
- {"", "", io.ErrUnexpectedEOF},
- {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF},
- {"", "HTTP/1.1", "malformed HTTP response"},
- {"", "HTTP/2.0", "malformed HTTP response"},
+ {"", "", nil, io.ErrUnexpectedEOF},
+ {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", nil, io.ErrUnexpectedEOF},
+ {"", "HTTP/1.1", nil, "malformed HTTP response"},
+ {"", "HTTP/2.0", nil, "malformed HTTP response"},
status("20X Unknown", true),
status("abcd Unknown", true),
status("二百/两百 OK", true),
@@ -846,7 +859,21 @@ func TestReadResponseErrors(t *testing.T) {
version("HTTP/A.B", true),
version("HTTP/1", true),
version("http/1.1", true),
+
+ contentLength("200 OK", "Content-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey\r\n", errMultiCL, nil),
+ contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"7"}}),
+ contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL, nil),
+ contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"0"}}),
+ contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil, nil),
+ contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL, nil),
+
+ // multiple content-length headers for 204 and 304 should still be checked
+ contentLength("204 OK", "Content-Length: 7\r\nContent-Length: 8\r\n\r\n", errMultiCL, nil),
+ contentLength("204 OK", "Content-Length: 3\r\nContent-Length: 3\r\n\r\n", nil, nil),
+ contentLength("304 OK", "Content-Length: 880\r\nContent-Length: 1\r\n\r\n", errMultiCL, nil),
+ contentLength("304 OK", "Content-Length: 961\r\nContent-Length: 961\r\n\r\n", nil, nil),
}
+
for i, tt := range tests {
br := bufio.NewReader(strings.NewReader(tt.in))
_, rerr := ReadResponse(br, nil)
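
The added table entries exercise the new duplicate Content-Length handling: conflicting values are rejected, while repeated identical values are accepted and collapsed to a single header. A small standalone check of the same behavior (the error text matches errMultiCL above) might look like:

    package main

    import (
        "bufio"
        "fmt"
        "net/http"
        "strings"
    )

    func main() {
        // Conflicting values: ReadResponse returns the
        // "multiple Content-Length headers" error.
        bad := "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey"
        if _, err := http.ReadResponse(bufio.NewReader(strings.NewReader(bad)), nil); err != nil {
            fmt.Println("rejected:", err)
        }

        // Identical duplicates: accepted and collapsed to one value.
        ok := "HTTP/1.1 200 OK\r\nContent-Length: 7\r\nContent-Length: 7\r\n\r\nGophers"
        res, err := http.ReadResponse(bufio.NewReader(strings.NewReader(ok)), nil)
        if err != nil {
            panic(err)
        }
        fmt.Println(res.Header["Content-Length"]) // [7]
    }
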
diff --git a/libgo/go/net/http/responsewrite_test.go b/libgo/go/net/http/responsewrite_test.go
index 90f6767..d41d898 100644
--- a/libgo/go/net/http/responsewrite_test.go
+++ b/libgo/go/net/http/responsewrite_test.go
@@ -241,7 +241,8 @@ func TestResponseWrite(t *testing.T) {
"HTTP/1.0 007 license to violate specs\r\nContent-Length: 0\r\n\r\n",
},
- // No stutter.
+		// No stutter. A response with a status code in the 1xx range should
+ // not include a Content-Length header. See issue #16942.
{
Response{
StatusCode: 123,
@@ -253,7 +254,23 @@ func TestResponseWrite(t *testing.T) {
Body: nil,
},
- "HTTP/1.0 123 Sesame Street\r\nContent-Length: 0\r\n\r\n",
+ "HTTP/1.0 123 Sesame Street\r\n\r\n",
+ },
+
+ // Status code 204 (No content) response should not include a
+ // Content-Length header. See issue #16942.
+ {
+ Response{
+ StatusCode: 204,
+ Status: "No Content",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Body: nil,
+ },
+
+ "HTTP/1.0 204 No Content\r\n\r\n",
},
}
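
The two cases above assert that 1xx and 204 responses are written without a Content-Length header. A minimal sketch of that behavior using the exported Response.Write (the expected output follows the test's expected string):

    package main

    import (
        "net/http"
        "os"
    )

    func main() {
        res := &http.Response{
            StatusCode: 204,
            Status:     "204 No Content",
            ProtoMajor: 1,
            ProtoMinor: 0,
        }
        // Writes "HTTP/1.0 204 No Content\r\n\r\n" -- no Content-Length line.
        if err := res.Write(os.Stdout); err != nil {
            panic(err)
        }
    }
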
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go
index 13e5f28..072da25 100644
--- a/libgo/go/net/http/serve_test.go
+++ b/libgo/go/net/http/serve_test.go
@@ -156,6 +156,7 @@ func (ht handlerTest) rawResponse(req string) string {
}
func TestConsumingBodyOnNextConn(t *testing.T) {
+ t.Parallel()
defer afterTest(t)
conn := new(testConn)
for i := 0; i < 2; i++ {
@@ -237,6 +238,7 @@ var vtests = []struct {
}
func TestHostHandlers(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
mux := NewServeMux()
for _, h := range handlers {
@@ -353,6 +355,7 @@ var serveMuxTests = []struct {
}
func TestServeMuxHandler(t *testing.T) {
+ setParallel(t)
mux := NewServeMux()
for _, e := range serveMuxRegister {
mux.Handle(e.pattern, e.h)
@@ -390,15 +393,16 @@ var serveMuxTests2 = []struct {
// TestServeMuxHandlerRedirects tests that automatic redirects generated by
// mux.Handler() shouldn't clear the request's query string.
func TestServeMuxHandlerRedirects(t *testing.T) {
+ setParallel(t)
mux := NewServeMux()
for _, e := range serveMuxRegister {
mux.Handle(e.pattern, e.h)
}
for _, tt := range serveMuxTests2 {
- tries := 1
+ tries := 1 // expect at most 1 redirection if redirOk is true.
turl := tt.url
- for tries > 0 {
+ for {
u, e := url.Parse(turl)
if e != nil {
t.Fatal(e)
@@ -432,6 +436,7 @@ func TestServeMuxHandlerRedirects(t *testing.T) {
// Tests for https://golang.org/issue/900
func TestMuxRedirectLeadingSlashes(t *testing.T) {
+ setParallel(t)
paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
for _, path := range paths {
req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
@@ -456,9 +461,6 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) {
}
func TestServerTimeouts(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
setParallel(t)
defer afterTest(t)
reqNum := 0
@@ -479,11 +481,11 @@ func TestServerTimeouts(t *testing.T) {
if err != nil {
t.Fatalf("http Get #1: %v", err)
}
- got, _ := ioutil.ReadAll(r.Body)
+ got, err := ioutil.ReadAll(r.Body)
expected := "req=1"
- if string(got) != expected {
- t.Errorf("Unexpected response for request #1; got %q; expected %q",
- string(got), expected)
+ if string(got) != expected || err != nil {
+ t.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
+ string(got), err, expected)
}
// Slow client that should timeout.
@@ -494,6 +496,7 @@ func TestServerTimeouts(t *testing.T) {
}
buf := make([]byte, 1)
n, err := conn.Read(buf)
+ conn.Close()
latency := time.Since(t1)
if n != 0 || err != io.EOF {
t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
@@ -505,14 +508,14 @@ func TestServerTimeouts(t *testing.T) {
// Hit the HTTP server successfully again, verifying that the
// previous slow connection didn't run our handler. (that we
// get "req=2", not "req=3")
- r, err = Get(ts.URL)
+ r, err = c.Get(ts.URL)
if err != nil {
t.Fatalf("http Get #2: %v", err)
}
- got, _ = ioutil.ReadAll(r.Body)
+ got, err = ioutil.ReadAll(r.Body)
expected = "req=2"
- if string(got) != expected {
- t.Errorf("Get #2 got %q, want %q", string(got), expected)
+ if string(got) != expected || err != nil {
+ t.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
}
if !testing.Short() {
@@ -532,13 +535,61 @@ func TestServerTimeouts(t *testing.T) {
}
}
+// Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437)
+func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ setParallel(t)
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {}))
+ ts.Config.WriteTimeout = 250 * time.Millisecond
+ ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
+ ts.StartTLS()
+ defer ts.Close()
+
+ tr := newTLSTransport(t, ts)
+ defer tr.CloseIdleConnections()
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatal(err)
+ }
+ c := &Client{Transport: tr}
+
+ for i := 1; i <= 3; i++ {
+ req, err := NewRequest("GET", ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // fail test if no response after 1 second
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+ req = req.WithContext(ctx)
+
+ r, err := c.Do(req)
+ select {
+ case <-ctx.Done():
+ if ctx.Err() == context.DeadlineExceeded {
+ t.Fatalf("http2 Get #%d response timed out", i)
+ }
+ default:
+ }
+ if err != nil {
+ t.Fatalf("http2 Get #%d: %v", i, err)
+ }
+ r.Body.Close()
+ if r.ProtoMajor != 2 {
+ t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto)
+ }
+ time.Sleep(ts.Config.WriteTimeout / 2)
+ }
+}
+
// golang.org/issue/4741 -- setting only a write timeout that triggers
// shouldn't cause a handler to block forever on reads (next HTTP
// request) that will never happen.
func TestOnlyWriteTimeout(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
+ setParallel(t)
defer afterTest(t)
var conn net.Conn
var afterTimeoutErrc = make(chan error, 1)
@@ -598,6 +649,7 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) {
// TestIdentityResponse verifies that a handler can unset
func TestIdentityResponse(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
rw.Header().Set("Content-Length", "3")
@@ -619,13 +671,16 @@ func TestIdentityResponse(t *testing.T) {
ts := httptest.NewServer(handler)
defer ts.Close()
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
+
// Note: this relies on the assumption (which is true) that
// Get sends HTTP/1.1 or greater requests. Otherwise the
// server wouldn't have the choice to send back chunked
// responses.
for _, te := range []string{"", "identity"} {
url := ts.URL + "/?te=" + te
- res, err := Get(url)
+ res, err := c.Get(url)
if err != nil {
t.Fatalf("error with Get of %s: %v", url, err)
}
@@ -644,7 +699,7 @@ func TestIdentityResponse(t *testing.T) {
// Verify that ErrContentLength is returned
url := ts.URL + "/?overwrite=1"
- res, err := Get(url)
+ res, err := c.Get(url)
if err != nil {
t.Fatalf("error with Get of %s: %v", url, err)
}
@@ -674,6 +729,7 @@ func TestIdentityResponse(t *testing.T) {
}
func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
+ setParallel(t)
defer afterTest(t)
s := httptest.NewServer(h)
defer s.Close()
@@ -717,6 +773,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
}
func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(handler)
defer ts.Close()
@@ -750,7 +807,7 @@ func TestServeHTTP10Close(t *testing.T) {
// TestClientCanClose verifies that clients can also force a connection to close.
func TestClientCanClose(t *testing.T) {
- testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
// Nothing.
}))
}
@@ -758,7 +815,7 @@ func TestClientCanClose(t *testing.T) {
// TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close,
// even for HTTP/1.1 requests.
func TestHandlersCanSetConnectionClose11(t *testing.T) {
- testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
+ testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Connection", "close")
}))
}
@@ -796,6 +853,7 @@ func TestHTTP10KeepAlive304Response(t *testing.T) {
// Issue 15703
func TestKeepAliveFinalChunkWithEOF(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) {
w.(Flusher).Flush() // force chunked encoding
@@ -828,6 +886,7 @@ func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) }
func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) }
func testSetsRemoteAddr(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.RemoteAddr)
@@ -877,6 +936,7 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
// Issue 12943
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
@@ -948,7 +1008,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
t.Fatalf("response 1 addr = %q; want %q", g, e)
}
}
+
func TestIdentityResponseHeaders(t *testing.T) {
+ // Not parallel; changes log output.
defer afterTest(t)
log.SetOutput(ioutil.Discard) // is noisy otherwise
defer log.SetOutput(os.Stderr)
@@ -960,7 +1022,10 @@ func TestIdentityResponseHeaders(t *testing.T) {
}))
defer ts.Close()
- res, err := Get(ts.URL)
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
+
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatalf("Get error: %v", err)
}
@@ -983,6 +1048,7 @@ func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) }
func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) }
func testHeadResponses(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("<html>"))
@@ -1020,9 +1086,6 @@ func testHeadResponses(t *testing.T, h2 bool) {
}
func TestTLSHandshakeTimeout(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
setParallel(t)
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
@@ -1054,6 +1117,7 @@ func TestTLSHandshakeTimeout(t *testing.T) {
}
func TestTLSServer(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.TLS != nil {
@@ -1121,6 +1185,7 @@ func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
}
func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
+ setParallel(t)
defer afterTest(t)
ln := newLocalListener(t)
ln.Close() // immediately (not a defer!)
@@ -1136,6 +1201,7 @@ func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
}
func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ln := newLocalListener(t)
ln.Close() // immediately (not a defer!)
@@ -1177,6 +1243,7 @@ func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
}
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
+ // Not parallel: uses global test hooks.
defer afterTest(t)
defer SetTestHookServerServe(nil)
var ok bool
@@ -1280,6 +1347,7 @@ var serverExpectTests = []serverExpectTest{
// correctly.
// http2 test: TestServer_Response_Automatic100Continue
func TestServerExpect(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
// Note using r.FormValue("readbody") because for POST
@@ -1373,6 +1441,7 @@ func TestServerExpect(t *testing.T) {
// Under a ~256KB (maxPostHandlerReadBytes) threshold, the server
// should consume client request bodies that a handler didn't read.
func TestServerUnreadRequestBodyLittle(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
conn := new(testConn)
body := strings.Repeat("x", 100<<10)
@@ -1413,6 +1482,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) {
// should ignore client request bodies that a handler didn't read
// and close the connection.
func TestServerUnreadRequestBodyLarge(t *testing.T) {
+ setParallel(t)
if testing.Short() && testenv.Builder() == "" {
t.Log("skipping in short mode")
}
@@ -1546,6 +1616,7 @@ var handlerBodyCloseTests = [...]handlerBodyCloseTest{
}
func TestHandlerBodyClose(t *testing.T) {
+ setParallel(t)
if testing.Short() && testenv.Builder() == "" {
t.Skip("skipping in -short mode")
}
@@ -1625,6 +1696,7 @@ var testHandlerBodyConsumers = []testHandlerBodyConsumer{
}
func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
for _, handler := range testHandlerBodyConsumers {
conn := new(testConn)
@@ -1655,6 +1727,7 @@ func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
}
func TestInvalidTrailerClosesConnection(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
for _, handler := range testHandlerBodyConsumers {
conn := new(testConn)
@@ -1737,7 +1810,7 @@ restart:
if !c.rd.IsZero() {
// If the deadline falls in the middle of our sleep window, deduct
// part of the sleep, then return a timeout.
- if remaining := c.rd.Sub(time.Now()); remaining < cue {
+ if remaining := time.Until(c.rd); remaining < cue {
c.script[0] = cue - remaining
time.Sleep(remaining)
return 0, syscall.ETIMEDOUT
@@ -1823,6 +1896,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
func testTimeoutHandler(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
@@ -1876,6 +1950,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
// See issues 8209 and 8414.
func TestTimeoutHandlerRace(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1892,6 +1967,9 @@ func TestTimeoutHandlerRace(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
defer ts.Close()
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
+
var wg sync.WaitGroup
gate := make(chan bool, 10)
n := 50
@@ -1905,7 +1983,7 @@ func TestTimeoutHandlerRace(t *testing.T) {
go func() {
defer wg.Done()
defer func() { <-gate }()
- res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
+ res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
if err == nil {
io.Copy(ioutil.Discard, res.Body)
res.Body.Close()
@@ -1917,6 +1995,7 @@ func TestTimeoutHandlerRace(t *testing.T) {
// See issues 8209 and 8414.
func TestTimeoutHandlerRaceHeader(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1932,13 +2011,15 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
if testing.Short() {
n = 10
}
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
for i := 0; i < n; i++ {
gate <- true
wg.Add(1)
go func() {
defer wg.Done()
defer func() { <-gate }()
- res, err := Get(ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Error(err)
return
@@ -1952,6 +2033,7 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
// Issue 9162
func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
@@ -2016,11 +2098,15 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
timeout := 300 * time.Millisecond
ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
defer ts.Close()
+
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
+
	// The issue was caused by the timeout handler starting the timer when
	// it was created, not when the request arrived. So wait for more than the timeout
// to ensure that's not the case.
time.Sleep(2 * timeout)
- res, err := Get(ts.URL)
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -2032,6 +2118,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
// https://golang.org/issue/15948
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
// No response.
@@ -2040,7 +2127,10 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) {
ts := httptest.NewServer(TimeoutHandler(handler, timeout, ""))
defer ts.Close()
- res, err := Get(ts.URL)
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
+
+ res, err := c.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -2050,23 +2140,6 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) {
}
}
-// Verifies we don't path.Clean() on the wrong parts in redirects.
-func TestRedirectMunging(t *testing.T) {
- req, _ := NewRequest("GET", "http://example.com/", nil)
-
- resp := httptest.NewRecorder()
- Redirect(resp, req, "/foo?next=http://bar.com/", 302)
- if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e {
- t.Errorf("Location header was %q; want %q", g, e)
- }
-
- resp = httptest.NewRecorder()
- Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302)
- if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e {
- t.Errorf("Location header was %q; want %q", g, e)
- }
-}
-
func TestRedirectBadPath(t *testing.T) {
// This used to crash. It's not valid input (bad path), but it
// shouldn't crash.
@@ -2085,7 +2158,7 @@ func TestRedirectBadPath(t *testing.T) {
}
// Test different URL formats and schemes
-func TestRedirectURLFormat(t *testing.T) {
+func TestRedirect(t *testing.T) {
req, _ := NewRequest("GET", "http://example.com/qux/", nil)
var tests = []struct {
@@ -2108,6 +2181,14 @@ func TestRedirectURLFormat(t *testing.T) {
{"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
// incorrect number of slashes
{"///foobar.com/baz", "/foobar.com/baz"},
+
+ // Verifies we don't path.Clean() on the wrong parts in redirects:
+ {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
+ {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
+ "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
+
+ {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
+ {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
}
for _, tt := range tests {
@@ -2133,6 +2214,7 @@ func TestZeroLengthPostAndResponse_h2(t *testing.T) {
}
func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
all, err := ioutil.ReadAll(r.Body)
@@ -2252,12 +2334,58 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{})
}
}
+type terrorWriter struct{ t *testing.T }
+
+func (w terrorWriter) Write(p []byte) (int, error) {
+ w.t.Errorf("%s", p)
+ return len(p), nil
+}
+
+// Issue 16456: allow writing 0 bytes on hijacked conn to test hijack
+// without any log spam.
+func TestServerWriteHijackZeroBytes(t *testing.T) {
+ defer afterTest(t)
+ done := make(chan struct{})
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(done)
+ w.(Flusher).Flush()
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ defer conn.Close()
+ _, err = w.Write(nil)
+ if err != ErrHijacked {
+ t.Errorf("Write error = %v; want ErrHijacked", err)
+ }
+ }))
+ ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
+ ts.Start()
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ select {
+ case <-done:
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+}
+
func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") }
func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") }
func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") }
func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") }
func testServerNoHeader(t *testing.T, h2 bool, header string) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header()[header] = nil
@@ -2275,6 +2403,7 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) {
}
func TestStripPrefix(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
h := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("X-Path", r.URL.Path)
@@ -2282,7 +2411,10 @@ func TestStripPrefix(t *testing.T) {
ts := httptest.NewServer(StripPrefix("/foo", h))
defer ts.Close()
- res, err := Get(ts.URL + "/foo/bar")
+ c := &Client{Transport: new(Transport)}
+ defer closeClient(c)
+
+ res, err := c.Get(ts.URL + "/foo/bar")
if err != nil {
t.Fatal(err)
}
@@ -2304,10 +2436,11 @@ func TestStripPrefix(t *testing.T) {
func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) }
func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) }
func testRequestLimit(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
t.Fatalf("didn't expect to get request in Handler")
- }))
+ }), optQuietLog)
defer cst.close()
req, _ := NewRequest("GET", cst.ts.URL, nil)
var bytesPerHeader = len("header12345: val12345\r\n")
@@ -2350,6 +2483,7 @@ func (cr countReader) Read(p []byte) (n int, err error) {
func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) }
func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) }
func testRequestBodyLimit(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
const limit = 1 << 20
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -2399,14 +2533,14 @@ func TestClientWriteShutdown(t *testing.T) {
}
err = conn.(*net.TCPConn).CloseWrite()
if err != nil {
- t.Fatalf("Dial: %v", err)
+ t.Fatalf("CloseWrite: %v", err)
}
donec := make(chan bool)
go func() {
defer close(donec)
bs, err := ioutil.ReadAll(conn)
if err != nil {
- t.Fatalf("ReadAll: %v", err)
+ t.Errorf("ReadAll: %v", err)
}
got := string(bs)
if got != "" {
@@ -2445,6 +2579,7 @@ func TestServerBufferedChunking(t *testing.T) {
// closing the TCP connection, causing the client to get a RST.
// See https://golang.org/issue/3595
func TestServerGracefulClose(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
Error(w, "bye", StatusUnauthorized)
@@ -2557,7 +2692,8 @@ func TestCloseNotifier(t *testing.T) {
go func() {
_, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
if err != nil {
- t.Fatal(err)
+ t.Error(err)
+ return
}
<-diec
conn.Close()
@@ -2599,7 +2735,8 @@ func TestCloseNotifierPipelined(t *testing.T) {
const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
_, err = io.WriteString(conn, req+req) // two requests
if err != nil {
- t.Fatal(err)
+ t.Error(err)
+ return
}
<-diec
conn.Close()
@@ -2707,6 +2844,7 @@ func TestHijackAfterCloseNotifier(t *testing.T) {
}
func TestHijackBeforeRequestBodyRead(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
var requestBody = bytes.Repeat([]byte("a"), 1<<20)
bodyOkay := make(chan bool, 1)
@@ -3028,15 +3166,18 @@ func (l *errorListener) Addr() net.Addr {
}
func TestAcceptMaxFds(t *testing.T) {
- log.SetOutput(ioutil.Discard) // is noisy otherwise
- defer log.SetOutput(os.Stderr)
+ setParallel(t)
ln := &errorListener{[]error{
&net.OpError{
Op: "accept",
Err: syscall.EMFILE,
}}}
- err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})))
+ server := &Server{
+ Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
+ ErrorLog: log.New(ioutil.Discard, "", 0), // noisy otherwise
+ }
+ err := server.Serve(ln)
if err != io.EOF {
t.Errorf("got error %v, want EOF", err)
}
@@ -3161,6 +3302,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) {
func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) }
func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) }
func testServerReaderFromOrder(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
pr, pw := io.Pipe()
const size = 3 << 20
@@ -3265,6 +3407,7 @@ func TestTransportAndServerSharedBodyRace_h2(t *testing.T) {
testTransportAndServerSharedBodyRace(t, h2Mode)
}
func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
const bodySize = 1 << 20
@@ -3453,6 +3596,7 @@ func TestAppendTime(t *testing.T) {
}
func TestServerConnState(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
handler := map[string]func(w ResponseWriter, r *Request){
"/": func(w ResponseWriter, r *Request) {
@@ -3500,14 +3644,39 @@ func TestServerConnState(t *testing.T) {
}
ts.Start()
- mustGet(t, ts.URL+"/")
- mustGet(t, ts.URL+"/close")
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ mustGet := func(url string, headers ...string) {
+ req, err := NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for len(headers) > 0 {
+ req.Header.Add(headers[0], headers[1])
+ headers = headers[2:]
+ }
+ res, err := c.Do(req)
+ if err != nil {
+ t.Errorf("Error fetching %s: %v", url, err)
+ return
+ }
+ _, err = ioutil.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Errorf("Error reading %s: %v", url, err)
+ }
+ }
+
+ mustGet(ts.URL + "/")
+ mustGet(ts.URL + "/close")
- mustGet(t, ts.URL+"/")
- mustGet(t, ts.URL+"/", "Connection", "close")
+ mustGet(ts.URL + "/")
+ mustGet(ts.URL+"/", "Connection", "close")
- mustGet(t, ts.URL+"/hijack")
- mustGet(t, ts.URL+"/hijack-panic")
+ mustGet(ts.URL + "/hijack")
+ mustGet(ts.URL + "/hijack-panic")
// New->Closed
{
@@ -3587,31 +3756,10 @@ func TestServerConnState(t *testing.T) {
}
mu.Lock()
- t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want))
+ t.Errorf("Unexpected events.\nGot log:\n%s\n Want:\n%s\n", logString(stateLog), logString(want))
mu.Unlock()
}
-func mustGet(t *testing.T, url string, headers ...string) {
- req, err := NewRequest("GET", url, nil)
- if err != nil {
- t.Fatal(err)
- }
- for len(headers) > 0 {
- req.Header.Add(headers[0], headers[1])
- headers = headers[2:]
- }
- res, err := DefaultClient.Do(req)
- if err != nil {
- t.Errorf("Error fetching %s: %v", url, err)
- return
- }
- _, err = ioutil.ReadAll(res.Body)
- defer res.Body.Close()
- if err != nil {
- t.Errorf("Error reading %s: %v", url, err)
- }
-}
-
func TestServerKeepAlivesEnabled(t *testing.T) {
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
@@ -3632,6 +3780,7 @@ func TestServerKeepAlivesEnabled(t *testing.T) {
func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) }
func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) }
func testServerEmptyBodyRace(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
var n int32
cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -3695,6 +3844,7 @@ func (c *closeWriteTestConn) CloseWrite() error {
}
func TestCloseWrite(t *testing.T) {
+ setParallel(t)
var srv Server
var testConn closeWriteTestConn
c := ExportServerNewConn(&srv, &testConn)
@@ -3935,6 +4085,7 @@ Host: foo
// If a Handler finishes and there's an unread request body,
// verify the server tries to do an implicit read on it before replying.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
+ setParallel(t)
conn := &testConn{closec: make(chan bool)}
conn.readBuf.Write([]byte(fmt.Sprintf(
"POST / HTTP/1.1\r\n" +
@@ -4033,7 +4184,11 @@ func TestServerValidatesHostHeader(t *testing.T) {
io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
ln := &oneConnListener{conn}
- go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+ srv := Server{
+ ErrorLog: quietLog,
+ Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
+ }
+ go srv.Serve(ln)
<-conn.closec
res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
if err != nil {
@@ -4088,6 +4243,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) {
// Test that we validate the valid bytes in HTTP/1 headers.
// Issue 11207.
func TestServerValidatesHeaders(t *testing.T) {
+ setParallel(t)
tests := []struct {
header string
want int
@@ -4097,9 +4253,10 @@ func TestServerValidatesHeaders(t *testing.T) {
{"X-Foo: bar\r\n", 200},
{"Foo: a space\r\n", 200},
- {"A space: foo\r\n", 400}, // space in header
- {"foo\xffbar: foo\r\n", 400}, // binary in header
- {"foo\x00bar: foo\r\n", 400}, // binary in header
+ {"A space: foo\r\n", 400}, // space in header
+ {"foo\xffbar: foo\r\n", 400}, // binary in header
+ {"foo\x00bar: foo\r\n", 400}, // binary in header
+ {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431}, // header too large
{"foo: foo foo\r\n", 200}, // LWS space is okay
{"foo: foo\tfoo\r\n", 200}, // LWS tab is okay
@@ -4112,7 +4269,11 @@ func TestServerValidatesHeaders(t *testing.T) {
io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
ln := &oneConnListener{conn}
- go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+ srv := Server{
+ ErrorLog: quietLog,
+ Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
+ }
+ go srv.Serve(ln)
<-conn.closec
res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
if err != nil {
@@ -4132,6 +4293,7 @@ func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) {
testServerRequestContextCancel_ServeHTTPDone(t, h2Mode)
}
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
ctxc := make(chan context.Context, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -4157,13 +4319,12 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
}
}
+// Tests that the Request.Context available to the Handler is canceled
+// if the peer closes their TCP connection. This requires that the server
+// is always blocked in a Read call so it notices the EOF from the client.
+// See issues 15927 and 15224.
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
- // Currently the context is not canceled when the connection
- // is closed because we're not reading from the connection
- // until after ServeHTTP for the previous handler is done.
- // Until the server code is modified to always be in a read
- // (Issue 15224), this test doesn't work yet.
- t.Skip("TODO(bradfitz): this test doesn't yet work; golang.org/issue/15224")
+ setParallel(t)
defer afterTest(t)
inHandler := make(chan struct{})
handlerDone := make(chan struct{})
@@ -4192,7 +4353,7 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) {
select {
case <-handlerDone:
- case <-time.After(3 * time.Second):
+ case <-time.After(4 * time.Second):
t.Fatalf("timeout waiting to see ServeHTTP exit")
}
}
@@ -4204,6 +4365,7 @@ func TestServerContext_ServerContextKey_h2(t *testing.T) {
testServerContext_ServerContextKey(t, h2Mode)
}
func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
ctx := r.Context()
@@ -4229,6 +4391,7 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
// https://golang.org/issue/15960
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Transfer-Encoding", "chunked")
@@ -4243,6 +4406,7 @@ func TestHandlerSetTransferEncodingChunked(t *testing.T) {
// https://golang.org/issue/16063
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Transfer-Encoding", "gzip")
@@ -4416,13 +4580,19 @@ func BenchmarkClient(b *testing.B) {
b.StopTimer()
defer afterTest(b)
- port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user
- if port == "" {
- port = "39207"
- }
var data = []byte("Hello world.\n")
if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
// Server process mode.
+ port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user
+ if port == "" {
+ port = "0"
+ }
+ ln, err := net.Listen("tcp", "localhost:"+port)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, err.Error())
+ os.Exit(1)
+ }
+ fmt.Println(ln.Addr().String())
HandleFunc("/", func(w ResponseWriter, r *Request) {
r.ParseForm()
if r.Form.Get("stop") != "" {
@@ -4431,33 +4601,44 @@ func BenchmarkClient(b *testing.B) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Write(data)
})
- log.Fatal(ListenAndServe("localhost:"+port, nil))
+ var srv Server
+ log.Fatal(srv.Serve(ln))
}
// Start server process.
cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$")
cmd.Env = append(os.Environ(), "TEST_BENCH_SERVER=yes")
+ cmd.Stderr = os.Stderr
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ b.Fatal(err)
+ }
if err := cmd.Start(); err != nil {
b.Fatalf("subprocess failed to start: %v", err)
}
defer cmd.Process.Kill()
+
+ // Wait for the server in the child process to respond and tell us
+ // its listening address, once it's started listening:
+ timer := time.AfterFunc(10*time.Second, func() {
+ cmd.Process.Kill()
+ })
+ defer timer.Stop()
+ bs := bufio.NewScanner(stdout)
+ if !bs.Scan() {
+ b.Fatalf("failed to read listening URL from child: %v", bs.Err())
+ }
+ url := "http://" + strings.TrimSpace(bs.Text()) + "/"
+ timer.Stop()
+ if _, err := getNoBody(url); err != nil {
+ b.Fatalf("initial probe of child process failed: %v", err)
+ }
+
done := make(chan error)
go func() {
done <- cmd.Wait()
}()
- // Wait for the server process to respond.
- url := "http://localhost:" + port + "/"
- for i := 0; i < 100; i++ {
- time.Sleep(100 * time.Millisecond)
- if _, err := getNoBody(url); err == nil {
- break
- }
- if i == 99 {
- b.Fatalf("subprocess does not respond")
- }
- }
-
// Do b.N requests to the server.
b.StartTimer()
for i := 0; i < b.N; i++ {
@@ -4719,6 +4900,7 @@ func BenchmarkCloseNotifier(b *testing.B) {
// Verify this doesn't race (Issue 16505)
func TestConcurrentServerServe(t *testing.T) {
+ setParallel(t)
for i := 0; i < 100; i++ {
ln1 := &oneConnListener{conn: nil}
ln2 := &oneConnListener{conn: nil}
@@ -4727,3 +4909,267 @@ func TestConcurrentServerServe(t *testing.T) {
go func() { srv.Serve(ln2) }()
}
}
+
+func TestServerIdleTimeout(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ setParallel(t)
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.Copy(ioutil.Discard, r.Body)
+ io.WriteString(w, r.RemoteAddr)
+ }))
+ ts.Config.ReadHeaderTimeout = 1 * time.Second
+ ts.Config.IdleTimeout = 2 * time.Second
+ ts.Start()
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ get := func() string {
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+ }
+
+ a1, a2 := get(), get()
+ if a1 != a2 {
+ t.Fatalf("did requests on different connections")
+ }
+ time.Sleep(3 * time.Second)
+ a3 := get()
+ if a2 == a3 {
+ t.Fatal("request three unexpectedly on same connection")
+ }
+
+ // And test that ReadHeaderTimeout still works:
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
+ time.Sleep(2 * time.Second)
+ if _, err := io.CopyN(ioutil.Discard, conn, 1); err == nil {
+ t.Fatal("copy byte succeeded; want err")
+ }
+}
+
+func get(t *testing.T, c *Client, url string) string {
+ res, err := c.Get(url)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+}
+
+// Tests that calls to Server.SetKeepAlivesEnabled(false) close any
+// currently-open connections.
+func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, r.RemoteAddr)
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ get := func() string { return get(t, c, ts.URL) }
+
+ a1, a2 := get(), get()
+ if a1 != a2 {
+ t.Fatal("expected first two requests on same connection")
+ }
+ var idle0 int
+ if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+ idle0 = tr.IdleConnKeyCountForTesting()
+ return idle0 == 1
+ }) {
+ t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0)
+ }
+
+ ts.Config.SetKeepAlivesEnabled(false)
+
+ var idle1 int
+ if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool {
+ idle1 = tr.IdleConnKeyCountForTesting()
+ return idle1 == 0
+ }) {
+ t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1)
+ }
+
+ a3 := get()
+ if a3 == a2 {
+ t.Fatal("expected third request on new connection")
+ }
+}
+
+func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) }
+func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) }
+
+func testServerShutdown(t *testing.T, h2 bool) {
+ setParallel(t)
+ defer afterTest(t)
+ var doShutdown func() // set later
+ var shutdownRes = make(chan error, 1)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ go doShutdown()
+ // Shutdown is graceful, so it should not interrupt
+ // this in-flight response. Add a tiny sleep here to
+ // increase the odds of a failure if shutdown has
+ // bugs.
+ time.Sleep(20 * time.Millisecond)
+ io.WriteString(w, r.RemoteAddr)
+ }))
+ defer cst.close()
+
+ doShutdown = func() {
+ shutdownRes <- cst.ts.Config.Shutdown(context.Background())
+ }
+ get(t, cst.c, cst.ts.URL) // calls t.Fail on failure
+
+ if err := <-shutdownRes; err != nil {
+ t.Fatalf("Shutdown: %v", err)
+ }
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("second request should fail. server should be shut down")
+ }
+}
+
+// Issue 17878: tests that we can call Close twice.
+func TestServerCloseDeadlock(t *testing.T) {
+ var s Server
+ s.Close()
+ s.Close()
+}
+
+// Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by
+// both HTTP/1 and HTTP/2.
+func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) }
+func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) }
+func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
+ setParallel(t)
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%v", r.RemoteAddr)
+ }))
+ defer cst.close()
+ srv := cst.ts.Config
+ srv.SetKeepAlivesEnabled(false)
+ a := cst.getURL(cst.ts.URL)
+ if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
+ t.Fatalf("test server has active conns")
+ }
+ b := cst.getURL(cst.ts.URL)
+ if a == b {
+ t.Errorf("got same connection between first and second requests")
+ }
+ if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
+ t.Fatalf("test server has active conns")
+ }
+}
+
+// Issue 18447: test that the Server's ReadTimeout is stopped while
+// the server's doing its 1-byte background read between requests,
+// waiting for the connection to maybe close.
+func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ const timeout = 250 * time.Millisecond
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ select {
+ case <-time.After(2 * timeout):
+ fmt.Fprint(w, "ok")
+ case <-r.Context().Done():
+ fmt.Fprint(w, r.Context().Err())
+ }
+ }))
+ ts.Config.ReadTimeout = timeout
+ ts.Start()
+ defer ts.Close()
+
+ tr := &Transport{}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ res, err := c.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(slurp) != "ok" {
+ t.Fatalf("Got: %q, want ok", slurp)
+ }
+}
+
+// Issue 18535: test that the Server doesn't try to do a background
+// read if it's already done one.
+func TestServerDuplicateBackgroundRead(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ const goroutines = 5
+ const requests = 2000
+
+ hts := httptest.NewServer(HandlerFunc(NotFound))
+ defer hts.Close()
+
+ reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
+
+ var wg sync.WaitGroup
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ cn, err := net.Dial("tcp", hts.Listener.Addr().String())
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer cn.Close()
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ io.Copy(ioutil.Discard, cn)
+ }()
+
+ for j := 0; j < requests; j++ {
+ if t.Failed() {
+ return
+ }
+ _, err := cn.Write(reqBytes)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ }
+ }()
+ }
+ wg.Wait()
+}
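
The new tests above cover several server features added in this update in one sweep: ReadHeaderTimeout, IdleTimeout, SetKeepAlivesEnabled, Close, and graceful Shutdown. A rough usage sketch tying them together (the address, timeouts, and handler are illustrative only):

    package main

    import (
        "context"
        "io"
        "log"
        "net/http"
        "time"
    )

    func main() {
        srv := &http.Server{
            Addr:              "localhost:8080",
            Handler:           http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "ok") }),
            ReadHeaderTimeout: 1 * time.Second,  // bound the time spent reading request headers
            IdleTimeout:       30 * time.Second, // drop keep-alive connections that sit idle
        }

        go func() {
            time.Sleep(5 * time.Second) // stand-in for a real shutdown signal
            ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
            defer cancel()
            if err := srv.Shutdown(ctx); err != nil { // graceful: in-flight requests finish
                log.Printf("shutdown: %v", err)
            }
        }()

        // After Shutdown, ListenAndServe returns ErrServerClosed.
        if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
            log.Fatal(err)
        }
    }
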
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go
index 89574a8b..9623648 100644
--- a/libgo/go/net/http/server.go
+++ b/libgo/go/net/http/server.go
@@ -40,7 +40,9 @@ var (
// ErrHijacked is returned by ResponseWriter.Write calls when
// the underlying connection has been hijacked using the
- // Hijacker interfaced.
+ // Hijacker interface. A zero-byte write on a hijacked
+ // connection will return ErrHijacked without any other side
+ // effects.
ErrHijacked = errors.New("http: connection has been hijacked")
// ErrContentLength is returned by ResponseWriter.Write calls
@@ -73,7 +75,9 @@ var (
// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes
// that the effect of the panic was isolated to the active request.
// It recovers the panic, logs a stack trace to the server error log,
-// and hangs up the connection.
+// and hangs up the connection. To abort a handler so the client sees
+// an interrupted response but the server doesn't log an error, panic
+// with the value ErrAbortHandler.
type Handler interface {
ServeHTTP(ResponseWriter, *Request)
}
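
The added sentence documents ErrAbortHandler as the sanctioned way to abort a response without the server logging a stack trace. A minimal handler sketch (the route, port, and body are illustrative):

    package main

    import "net/http"

    func main() {
        http.HandleFunc("/stream", func(w http.ResponseWriter, r *http.Request) {
            w.WriteHeader(http.StatusOK)
            if _, err := w.Write([]byte("partial response")); err != nil {
                // Hang up on the client without the usual panic log.
                panic(http.ErrAbortHandler)
            }
        })
        http.ListenAndServe("localhost:8080", nil)
    }
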
@@ -85,11 +89,25 @@ type Handler interface {
// has returned.
type ResponseWriter interface {
// Header returns the header map that will be sent by
- // WriteHeader. Changing the header after a call to
- // WriteHeader (or Write) has no effect unless the modified
- // headers were declared as trailers by setting the
- // "Trailer" header before the call to WriteHeader (see example).
- // To suppress implicit response headers, set their value to nil.
+ // WriteHeader. The Header map also is the mechanism with which
+ // Handlers can set HTTP trailers.
+ //
+ // Changing the header map after a call to WriteHeader (or
+ // Write) has no effect unless the modified headers are
+ // trailers.
+ //
+ // There are two ways to set Trailers. The preferred way is to
+ // predeclare in the headers which trailers you will later
+ // send by setting the "Trailer" header to the names of the
+ // trailer keys which will come later. In this case, those
+ // keys of the Header map are treated as if they were
+ // trailers. See the example. The second way, for trailer
+ // keys not known to the Handler until after the first Write,
+ // is to prefix the Header map keys with the TrailerPrefix
+ // constant value. See TrailerPrefix.
+ //
+ // To suppress implicit response headers (such as "Date"), set
+ // their value to nil.
Header() Header
// Write writes the data to the connection as part of an HTTP reply.
@@ -206,6 +224,9 @@ type conn struct {
// Immutable; never nil.
server *Server
+ // cancelCtx cancels the connection-level context.
+ cancelCtx context.CancelFunc
+
// rwc is the underlying network connection.
// This is never wrapped by other types and is the value given out
// to CloseNotifier callers. It is usually of type *net.TCPConn or
@@ -232,7 +253,6 @@ type conn struct {
r *connReader
// bufr reads from r.
- // Users of bufr must hold mu.
bufr *bufio.Reader
// bufw writes to checkConnErrorWriter{c}, which populates werr on error.
@@ -242,7 +262,11 @@ type conn struct {
// on this connection, if any.
lastMethod string
- // mu guards hijackedv, use of bufr, (*response).closeNotifyCh.
+ curReq atomic.Value // of *response (which has a Request in it)
+
+ curState atomic.Value // of ConnState
+
+ // mu guards hijackedv
mu sync.Mutex
// hijackedv is whether this connection has been hijacked
@@ -262,8 +286,12 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if c.hijackedv {
return nil, nil, ErrHijacked
}
+ c.r.abortPendingRead()
+
c.hijackedv = true
rwc = c.rwc
+ rwc.SetDeadline(time.Time{})
+
buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
c.setState(rwc, StateHijacked)
return
@@ -346,13 +374,7 @@ func (cw *chunkWriter) close() {
bw := cw.res.conn.bufw // conn's bufio writer
// zero chunk to mark EOF
bw.WriteString("0\r\n")
- if len(cw.res.trailers) > 0 {
- trailers := make(Header)
- for _, h := range cw.res.trailers {
- if vv := cw.res.handlerHeader[h]; len(vv) > 0 {
- trailers[h] = vv
- }
- }
+ if trailers := cw.res.finalTrailers(); trailers != nil {
trailers.Write(bw) // the writer handles noting errors
}
// final blank line after the trailers (whether
@@ -413,9 +435,48 @@ type response struct {
dateBuf [len(TimeFormat)]byte
clenBuf [10]byte
- // closeNotifyCh is non-nil once CloseNotify is called.
- // Guarded by conn.mu
- closeNotifyCh <-chan bool
+ // closeNotifyCh is the channel returned by CloseNotify.
+ // TODO(bradfitz): this is currently (for Go 1.8) always
+ // non-nil. Make this lazily-created again as it used to be?
+ closeNotifyCh chan bool
+ didCloseNotify int32 // atomic (only 0->1 winner should send)
+}
+
+// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys
+// that, if present, signals that the map entry is actually for
+// the response trailers, and not the response headers. The prefix
+// is stripped after the ServeHTTP call finishes and the values are
+// sent in the trailers.
+//
+// This mechanism is intended only for trailers that are not known
+// prior to the headers being written. If the set of trailers is fixed
+// or known before the header is written, the normal Go trailers mechanism
+// is preferred:
+// https://golang.org/pkg/net/http/#ResponseWriter
+// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers
+const TrailerPrefix = "Trailer:"
+
+// finalTrailers is called after the Handler exits and returns a non-nil
+// value if the Handler set any trailers.
+func (w *response) finalTrailers() Header {
+ var t Header
+ for k, vv := range w.handlerHeader {
+ if strings.HasPrefix(k, TrailerPrefix) {
+ if t == nil {
+ t = make(Header)
+ }
+ t[strings.TrimPrefix(k, TrailerPrefix)] = vv
+ }
+ }
+ for _, k := range w.trailers {
+ if t == nil {
+ t = make(Header)
+ }
+ for _, v := range w.handlerHeader[k] {
+ t.Add(k, v)
+ }
+ }
+ return t
}
type atomicBool int32
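
The TrailerPrefix constant and finalTrailers above implement trailers whose names are not known until after the first Write. A minimal handler sketch of that path, assuming an HTTP/1.1 response where chunked encoding is forced with Flush so a trailer can actually be emitted (the trailer name and checksum are illustrative):

    package main

    import (
        "fmt"
        "hash/crc32"
        "io"
        "net/http"
    )

    func main() {
        http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
            sum := crc32.NewIEEE()
            io.WriteString(io.MultiWriter(w, sum), "hello, trailers\n")
            w.(http.Flusher).Flush() // force chunked encoding; trailers need it

            // The value wasn't known before the body was written, so it can't be
            // predeclared via the "Trailer" header; use the prefix mechanism instead.
            w.Header().Set(http.TrailerPrefix+"X-Body-Checksum", fmt.Sprintf("%08x", sum.Sum32()))
        })
        http.ListenAndServe("localhost:8080", nil)
    }
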
@@ -548,60 +609,152 @@ type readResult struct {
// call blocked in a background goroutine to wait for activity and
// trigger a CloseNotifier channel.
type connReader struct {
- r io.Reader
- remain int64 // bytes remaining
+ conn *conn
- // ch is non-nil if a background read is in progress.
- // It is guarded by conn.mu.
- ch chan readResult
+ mu sync.Mutex // guards following
+ hasByte bool
+ byteBuf [1]byte
+ bgErr error // non-nil means error happened on background read
+ cond *sync.Cond
+ inRead bool
+ aborted bool // set true before conn.rwc deadline is set to past
+ remain int64 // bytes remaining
+}
+
+func (cr *connReader) lock() {
+ cr.mu.Lock()
+ if cr.cond == nil {
+ cr.cond = sync.NewCond(&cr.mu)
+ }
+}
+
+func (cr *connReader) unlock() { cr.mu.Unlock() }
+
+func (cr *connReader) startBackgroundRead() {
+ cr.lock()
+ defer cr.unlock()
+ if cr.inRead {
+ panic("invalid concurrent Body.Read call")
+ }
+ if cr.hasByte {
+ return
+ }
+ cr.inRead = true
+ cr.conn.rwc.SetReadDeadline(time.Time{})
+ go cr.backgroundRead()
+}
+
+func (cr *connReader) backgroundRead() {
+ n, err := cr.conn.rwc.Read(cr.byteBuf[:])
+ cr.lock()
+ if n == 1 {
+ cr.hasByte = true
+ // We were at EOF already (since we wouldn't be in a
+ // background read otherwise), so this is a pipelined
+ // HTTP request.
+ cr.closeNotifyFromPipelinedRequest()
+ }
+ if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() {
+ // Ignore this error. It's the expected error from
+ // another goroutine calling abortPendingRead.
+ } else if err != nil {
+ cr.handleReadError(err)
+ }
+ cr.aborted = false
+ cr.inRead = false
+ cr.unlock()
+ cr.cond.Broadcast()
+}
+
+func (cr *connReader) abortPendingRead() {
+ cr.lock()
+ defer cr.unlock()
+ if !cr.inRead {
+ return
+ }
+ cr.aborted = true
+ cr.conn.rwc.SetReadDeadline(aLongTimeAgo)
+ for cr.inRead {
+ cr.cond.Wait()
+ }
+ cr.conn.rwc.SetReadDeadline(time.Time{})
}
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
+// may be called from multiple goroutines.
+func (cr *connReader) handleReadError(err error) {
+ cr.conn.cancelCtx()
+ cr.closeNotify()
+}
+
+// closeNotifyFromPipelinedRequest simply calls closeNotify.
+//
+// This method wrapper is here for documentation. The callers are the
+// cases where we send on the closenotify channel because of a
+// pipelined HTTP request, per the previous Go behavior and
+// documentation (that this "MAY" happen).
+//
+// TODO: consider changing this behavior and making context
+// cancelation and closenotify work the same.
+func (cr *connReader) closeNotifyFromPipelinedRequest() {
+ cr.closeNotify()
+}
+
+// may be called from multiple goroutines.
+func (cr *connReader) closeNotify() {
+ res, _ := cr.conn.curReq.Load().(*response)
+ if res != nil {
+ if atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) {
+ res.closeNotifyCh <- true
+ }
+ }
+}
+
func (cr *connReader) Read(p []byte) (n int, err error) {
+ cr.lock()
+ if cr.inRead {
+ cr.unlock()
+ panic("invalid concurrent Body.Read call")
+ }
if cr.hitReadLimit() {
+ cr.unlock()
return 0, io.EOF
}
+ if cr.bgErr != nil {
+ err = cr.bgErr
+ cr.unlock()
+ return 0, err
+ }
if len(p) == 0 {
- return
+ cr.unlock()
+ return 0, nil
}
if int64(len(p)) > cr.remain {
p = p[:cr.remain]
}
-
- // Is a background read (started by CloseNotifier) already in
- // flight? If so, wait for it and use its result.
- ch := cr.ch
- if ch != nil {
- cr.ch = nil
- res := <-ch
- if res.n == 1 {
- p[0] = res.b
- cr.remain -= 1
- }
- return res.n, res.err
+ if cr.hasByte {
+ p[0] = cr.byteBuf[0]
+ cr.hasByte = false
+ cr.unlock()
+ return 1, nil
}
- n, err = cr.r.Read(p)
- cr.remain -= int64(n)
- return
-}
+ cr.inRead = true
+ cr.unlock()
+ n, err = cr.conn.rwc.Read(p)
-func (cr *connReader) startBackgroundRead(onReadComplete func()) {
- if cr.ch != nil {
- // Background read already started.
- return
+ cr.lock()
+ cr.inRead = false
+ if err != nil {
+ cr.handleReadError(err)
}
- cr.ch = make(chan readResult, 1)
- go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete)
-}
+ cr.remain -= int64(n)
+ cr.unlock()
-func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) {
- var buf [1]byte
- n, err := cr.r.Read(buf[:1])
- onReadComplete()
- ch <- readResult{n, err, buf[0]}
+ cr.cond.Broadcast()
+ return n, err
}
var (
@@ -633,7 +786,7 @@ func newBufioReader(r io.Reader) *bufio.Reader {
br.Reset(r)
return br
}
- // Note: if this reader size is every changed, update
+ // Note: if this reader size is ever changed, update
// TestHandlerBodyClose's assumptions.
return bufio.NewReader(r)
}
@@ -746,9 +899,18 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
return nil, ErrHijacked
}
+ var (
+ wholeReqDeadline time.Time // or zero if none
+ hdrDeadline time.Time // or zero if none
+ )
+ t0 := time.Now()
+ if d := c.server.readHeaderTimeout(); d != 0 {
+ hdrDeadline = t0.Add(d)
+ }
if d := c.server.ReadTimeout; d != 0 {
- c.rwc.SetReadDeadline(time.Now().Add(d))
+ wholeReqDeadline = t0.Add(d)
}
+ c.rwc.SetReadDeadline(hdrDeadline)
if d := c.server.WriteTimeout; d != 0 {
defer func() {
c.rwc.SetWriteDeadline(time.Now().Add(d))
@@ -756,14 +918,12 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
}
c.r.setReadLimit(c.server.initialReadLimitSize())
- c.mu.Lock() // while using bufr
if c.lastMethod == "POST" {
// RFC 2616 section 4.1 tolerance for old buggy clients.
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek))
}
req, err := readRequest(c.bufr, keepHostHeader)
- c.mu.Unlock()
if err != nil {
if c.r.hitReadLimit() {
return nil, errTooLarge
@@ -809,6 +969,11 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
body.doEarlyClose = true
}
+ // Adjust the read deadline if necessary.
+ if !hdrDeadline.Equal(wholeReqDeadline) {
+ c.rwc.SetReadDeadline(wholeReqDeadline)
+ }
+
w = &response{
conn: c,
cancelCtx: cancelCtx,
@@ -816,6 +981,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
reqBody: req.Body,
handlerHeader: make(Header),
contentLength: -1,
+ closeNotifyCh: make(chan bool, 1),
// We populate these ahead of time so we're not
// reading from req.Header after their Handler starts
@@ -990,7 +1156,17 @@ func (cw *chunkWriter) writeHeader(p []byte) {
}
var setHeader extraHeader
+ // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix.
trailers := false
+ for k := range cw.header {
+ if strings.HasPrefix(k, TrailerPrefix) {
+ if excludeHeader == nil {
+ excludeHeader = make(map[string]bool)
+ }
+ excludeHeader[k] = true
+ trailers = true
+ }
+ }
for _, v := range cw.header["Trailer"] {
trailers = true
foreachHeaderElement(v, cw.res.declareTrailer)
@@ -1318,7 +1494,9 @@ func (w *response) WriteString(data string) (n int, err error) {
// either dataB or dataS is non-zero.
func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) {
if w.conn.hijacked() {
- w.conn.server.logf("http: response.Write on hijacked connection")
+ if lenData > 0 {
+ w.conn.server.logf("http: response.Write on hijacked connection")
+ }
return 0, ErrHijacked
}
if !w.wroteHeader {
@@ -1354,6 +1532,8 @@ func (w *response) finishRequest() {
w.cw.close()
w.conn.bufw.Flush()
+ w.conn.r.abortPendingRead()
+
// Close the body (regardless of w.closeAfterReply) so we can
// re-use its bufio.Reader later safely.
w.reqBody.Close()
@@ -1469,11 +1649,30 @@ func validNPN(proto string) bool {
}
func (c *conn) setState(nc net.Conn, state ConnState) {
- if hook := c.server.ConnState; hook != nil {
+ srv := c.server
+ switch state {
+ case StateNew:
+ srv.trackConn(c, true)
+ case StateHijacked, StateClosed:
+ srv.trackConn(c, false)
+ }
+ c.curState.Store(connStateInterface[state])
+ if hook := srv.ConnState; hook != nil {
hook(nc, state)
}
}
+// connStateInterface is an array of the interface{} versions of
+// ConnState values, so we can use them in atomic.Values later without
+// paying the cost of shoving their integers in an interface{}.
+var connStateInterface = [...]interface{}{
+ StateNew: StateNew,
+ StateActive: StateActive,
+ StateIdle: StateIdle,
+ StateHijacked: StateHijacked,
+ StateClosed: StateClosed,
+}
+
// badRequestError is a literal string (used by the server in HTML,
// unescaped) to tell the user why their request was bad. It should
// be plain text without user info or other embedded errors.
@@ -1481,11 +1680,34 @@ type badRequestError string
func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
+// ErrAbortHandler is a sentinel panic value to abort a handler.
+// While any panic from ServeHTTP aborts the response to the client,
+// panicking with ErrAbortHandler also suppresses logging of a stack
+// trace to the server's error log.
+var ErrAbortHandler = errors.New("net/http: abort Handler")
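
A handler can use the new sentinel as in this sketch; the backend URL is hypothetical and the usual io and net/http imports are assumed:

func relayHandler(w http.ResponseWriter, r *http.Request) {
	upstream, err := http.Get("http://backend.internal/data") // hypothetical backend
	if err != nil {
		http.Error(w, "upstream unavailable", http.StatusBadGateway)
		return
	}
	defer upstream.Body.Close()
	if _, err := io.Copy(w, upstream.Body); err != nil {
		// Abort the response; the server will not log a stack trace
		// for this particular panic value.
		panic(http.ErrAbortHandler)
	}
}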
+
+// isCommonNetReadError reports whether err is a common error
+// encountered during reading a request off the network when the
+// client has gone away or had its read fail somehow. This is used to
+// determine which logs are interesting enough to log about.
+func isCommonNetReadError(err error) bool {
+ if err == io.EOF {
+ return true
+ }
+ if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+ return true
+ }
+ if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
+ return true
+ }
+ return false
+}
+
// Serve a new connection.
func (c *conn) serve(ctx context.Context) {
c.remoteAddr = c.rwc.RemoteAddr().String()
defer func() {
- if err := recover(); err != nil {
+ if err := recover(); err != nil && err != ErrAbortHandler {
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
@@ -1521,13 +1743,14 @@ func (c *conn) serve(ctx context.Context) {
// HTTP/1.x from here on.
- c.r = &connReader{r: c.rwc}
- c.bufr = newBufioReader(c.r)
- c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
-
ctx, cancelCtx := context.WithCancel(ctx)
+ c.cancelCtx = cancelCtx
defer cancelCtx()
+ c.r = &connReader{conn: c}
+ c.bufr = newBufioReader(c.r)
+ c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
+
for {
w, err := c.readRequest(ctx)
if c.r.remain != c.server.initialReadLimitSize() {
@@ -1535,27 +1758,29 @@ func (c *conn) serve(ctx context.Context) {
c.setState(c.rwc, StateActive)
}
if err != nil {
+ const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n"
+
if err == errTooLarge {
// Their HTTP client may or may not be
// able to read this if we're
// responding to them and hanging up
// while they're still writing their
// request. Undefined behavior.
- io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large")
+ const publicErr = "431 Request Header Fields Too Large"
+ fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
c.closeWriteAndWait()
return
}
- if err == io.EOF {
- return // don't reply
- }
- if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+ if isCommonNetReadError(err) {
return // don't reply
}
- var publicErr string
+
+ publicErr := "400 Bad Request"
if v, ok := err.(badRequestError); ok {
- publicErr = ": " + string(v)
+ publicErr = publicErr + ": " + string(v)
}
- io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr)
+
+ fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
return
}
@@ -1571,11 +1796,24 @@ func (c *conn) serve(ctx context.Context) {
return
}
+ c.curReq.Store(w)
+
+ if requestBodyRemains(req.Body) {
+ registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
+ } else {
+ if w.conn.bufr.Buffered() > 0 {
+ w.conn.r.closeNotifyFromPipelinedRequest()
+ }
+ w.conn.r.startBackgroundRead()
+ }
+
// HTTP cannot have multiple simultaneous active requests.[*]
// Until the server replies to this request, it can't read another,
// so we might as well run the handler in this goroutine.
// [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized.
+ // But we're not going to implement HTTP pipelining because it
+ // was never deployed in the wild and the answer is HTTP/2.
serverHandler{c.server}.ServeHTTP(w, w.req)
w.cancelCtx()
if c.hijacked() {
@@ -1589,6 +1827,23 @@ func (c *conn) serve(ctx context.Context) {
return
}
c.setState(c.rwc, StateIdle)
+ c.curReq.Store((*response)(nil))
+
+ if !w.conn.server.doKeepAlives() {
+ // We're in shutdown mode. We might've replied
+ // to the user without "Connection: close" and
+ // they might think they can send another
+ // request, but such is life with HTTP/1.1.
+ return
+ }
+
+ if d := c.server.idleTimeout(); d != 0 {
+ c.rwc.SetReadDeadline(time.Now().Add(d))
+ if _, err := c.bufr.Peek(4); err != nil {
+ return
+ }
+ }
+ c.rwc.SetReadDeadline(time.Time{})
}
}
@@ -1624,10 +1879,6 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
c.mu.Lock()
defer c.mu.Unlock()
- if w.closeNotifyCh != nil {
- return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call")
- }
-
// Release the bufioWriter that writes to the chunk writer, it is not
// used after a connection has been hijacked.
rwc, buf, err = c.hijackLocked()
@@ -1642,50 +1893,7 @@ func (w *response) CloseNotify() <-chan bool {
if w.handlerDone.isSet() {
panic("net/http: CloseNotify called after ServeHTTP finished")
}
- c := w.conn
- c.mu.Lock()
- defer c.mu.Unlock()
-
- if w.closeNotifyCh != nil {
- return w.closeNotifyCh
- }
- ch := make(chan bool, 1)
- w.closeNotifyCh = ch
-
- if w.conn.hijackedv {
- // CloseNotify is undefined after a hijack, but we have
- // no place to return an error, so just return a channel,
- // even though it'll never receive a value.
- return ch
- }
-
- var once sync.Once
- notify := func() { once.Do(func() { ch <- true }) }
-
- if requestBodyRemains(w.reqBody) {
- // They're still consuming the request body, so we
- // shouldn't notify yet.
- registerOnHitEOF(w.reqBody, func() {
- c.mu.Lock()
- defer c.mu.Unlock()
- startCloseNotifyBackgroundRead(c, notify)
- })
- } else {
- startCloseNotifyBackgroundRead(c, notify)
- }
- return ch
-}
-
-// c.mu must be held.
-func startCloseNotifyBackgroundRead(c *conn, notify func()) {
- if c.bufr.Buffered() > 0 {
- // They've consumed the request body, so anything
- // remaining is a pipelined request, which we
- // document as firing on.
- notify()
- } else {
- c.r.startBackgroundRead(notify)
- }
+ return w.closeNotifyCh
}
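
For reference, a handler that consumes the (now always pre-allocated) channel still looks like this sketch; the timeout is arbitrary and the usual fmt, net/http and time imports are assumed:

func waitHandler(w http.ResponseWriter, r *http.Request) {
	cn, ok := w.(http.CloseNotifier)
	if !ok {
		http.Error(w, "CloseNotifier not supported", http.StatusInternalServerError)
		return
	}
	select {
	case <-cn.CloseNotify():
		// The client went away; stop doing the work.
		return
	case <-time.After(10 * time.Second):
		fmt.Fprintln(w, "done")
	}
}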
func registerOnHitEOF(rc io.ReadCloser, fn func()) {
@@ -1702,7 +1910,7 @@ func registerOnHitEOF(rc io.ReadCloser, fn func()) {
// requestBodyRemains reports whether future calls to Read
// on rc might yield more data.
func requestBodyRemains(rc io.ReadCloser) bool {
- if rc == eofReader {
+ if rc == NoBody {
return false
}
switch v := rc.(type) {
@@ -1816,7 +2024,7 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) {
}
}
- w.Header().Set("Location", urlStr)
+ w.Header().Set("Location", hexEscapeNonASCII(urlStr))
w.WriteHeader(code)
// RFC 2616 recommends that a short note "SHOULD" be included in the
@@ -2094,11 +2302,36 @@ func Serve(l net.Listener, handler Handler) error {
// A Server defines parameters for running an HTTP server.
// The zero value for Server is a valid configuration.
type Server struct {
- Addr string // TCP address to listen on, ":http" if empty
- Handler Handler // handler to invoke, http.DefaultServeMux if nil
- ReadTimeout time.Duration // maximum duration before timing out read of the request
- WriteTimeout time.Duration // maximum duration before timing out write of the response
- TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
+ Addr string // TCP address to listen on, ":http" if empty
+ Handler Handler // handler to invoke, http.DefaultServeMux if nil
+ TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
+
+ // ReadTimeout is the maximum duration for reading the entire
+ // request, including the body.
+ //
+ // Because ReadTimeout does not let Handlers make per-request
+ // decisions on each request body's acceptable deadline or
+ // upload rate, most users will prefer to use
+ // ReadHeaderTimeout. It is valid to use them both.
+ ReadTimeout time.Duration
+
+ // ReadHeaderTimeout is the amount of time allowed to read
+ // request headers. The connection's read deadline is reset
+ // after reading the headers and the Handler can decide what
+ // is considered too slow for the body.
+ ReadHeaderTimeout time.Duration
+
+ // WriteTimeout is the maximum duration before timing out
+ // writes of the response. It is reset whenever a new
+ // request's header is read. Like ReadTimeout, it does not
+ // let Handlers make decisions on a per-request basis.
+ WriteTimeout time.Duration
+
+ // IdleTimeout is the maximum amount of time to wait for the
+ // next request when keep-alives are enabled. If IdleTimeout
+ // is zero, the value of ReadTimeout is used. If both are
+ // zero, there is no timeout.
+ IdleTimeout time.Duration
// MaxHeaderBytes controls the maximum number of bytes the
// server will read parsing the request header's keys and
@@ -2114,7 +2347,8 @@ type Server struct {
// handle HTTP requests and will initialize the Request's TLS
// and RemoteAddr if not already set. The connection is
// automatically closed when the function returns.
- // If TLSNextProto is nil, HTTP/2 support is enabled automatically.
+ // If TLSNextProto is not nil, HTTP/2 support is not enabled
+ // automatically.
TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
// ConnState specifies an optional callback function that is
@@ -2129,8 +2363,132 @@ type Server struct {
ErrorLog *log.Logger
disableKeepAlives int32 // accessed atomically.
+ inShutdown int32 // accessed atomically (non-zero means we're in Shutdown)
nextProtoOnce sync.Once // guards setupHTTP2_* init
nextProtoErr error // result of http2.ConfigureServer if used
+
+ mu sync.Mutex
+ listeners map[net.Listener]struct{}
+ activeConn map[*conn]struct{}
+ doneChan chan struct{}
+}
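
A sketch of a Server wired up with the timeout fields documented above; the address, handler and certificate paths are placeholders, and the usual log, net/http and time imports are assumed:

srv := &http.Server{
	Addr:              ":8443",
	Handler:           mux, // some http.Handler, e.g. an *http.ServeMux
	ReadHeaderTimeout: 5 * time.Second,  // time allowed to read the request headers
	ReadTimeout:       30 * time.Second, // time allowed to read the whole request
	WriteTimeout:      30 * time.Second, // time allowed to write the response
	IdleTimeout:       2 * time.Minute,  // keep-alive idle limit; falls back to ReadTimeout if zero
}
log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))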
+
+func (s *Server) getDoneChan() <-chan struct{} {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.getDoneChanLocked()
+}
+
+func (s *Server) getDoneChanLocked() chan struct{} {
+ if s.doneChan == nil {
+ s.doneChan = make(chan struct{})
+ }
+ return s.doneChan
+}
+
+func (s *Server) closeDoneChanLocked() {
+ ch := s.getDoneChanLocked()
+ select {
+ case <-ch:
+ // Already closed. Don't close again.
+ default:
+ // Safe to close here. We're the only closer, guarded
+ // by s.mu.
+ close(ch)
+ }
+}
+
+// Close immediately closes all active net.Listeners and any
+// connections in state StateNew, StateActive, or StateIdle. For a
+// graceful shutdown, use Shutdown.
+//
+// Close does not attempt to close (and does not even know about)
+// any hijacked connections, such as WebSockets.
+//
+// Close returns any error returned from closing the Server's
+// underlying Listener(s).
+func (srv *Server) Close() error {
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ srv.closeDoneChanLocked()
+ err := srv.closeListenersLocked()
+ for c := range srv.activeConn {
+ c.rwc.Close()
+ delete(srv.activeConn, c)
+ }
+ return err
+}
+
+// shutdownPollInterval is how often we poll for quiescence
+// during Server.Shutdown. This is lower during tests, to
+// speed up tests.
+// Ideally we could find a solution that doesn't involve polling,
+// but which also doesn't have a high runtime cost (and doesn't
+// involve any contentious mutexes), but that is left as an
+// exercise for the reader.
+var shutdownPollInterval = 500 * time.Millisecond
+
+// Shutdown gracefully shuts down the server without interrupting any
+// active connections. Shutdown works by first closing all open
+// listeners, then closing all idle connections, and then waiting
+// indefinitely for connections to return to idle and then shut down.
+// If the provided context expires before the shutdown is complete,
+// then the context's error is returned.
+//
+// Shutdown does not attempt to close nor wait for hijacked
+// connections such as WebSockets. The caller of Shutdown should
+// separately notify such long-lived connections of shutdown and wait
+// for them to close, if desired.
+func (srv *Server) Shutdown(ctx context.Context) error {
+ atomic.AddInt32(&srv.inShutdown, 1)
+ defer atomic.AddInt32(&srv.inShutdown, -1)
+
+ srv.mu.Lock()
+ lnerr := srv.closeListenersLocked()
+ srv.closeDoneChanLocked()
+ srv.mu.Unlock()
+
+ ticker := time.NewTicker(shutdownPollInterval)
+ defer ticker.Stop()
+ for {
+ if srv.closeIdleConns() {
+ return lnerr
+ }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-ticker.C:
+ }
+ }
+}
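
Combined with Serve returning ErrServerClosed (see below), graceful shutdown from the caller's side looks roughly like this; the shutdown trigger and timeout are placeholders, and the context, log, net/http, os, os/signal and time imports are assumed:

srv := &http.Server{Addr: ":8080"} // serves DefaultServeMux

go func() {
	if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Fatalf("ListenAndServe: %v", err)
	}
}()

stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt)
<-stop // wait for Ctrl-C or another shutdown trigger

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
	// The context expired before all connections went idle; force the rest closed.
	srv.Close()
}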
+
+// closeIdleConns closes all idle connections and reports whether the
+// server is quiescent.
+func (s *Server) closeIdleConns() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ quiescent := true
+ for c := range s.activeConn {
+ st, ok := c.curState.Load().(ConnState)
+ if !ok || st != StateIdle {
+ quiescent = false
+ continue
+ }
+ c.rwc.Close()
+ delete(s.activeConn, c)
+ }
+ return quiescent
+}
+
+func (s *Server) closeListenersLocked() error {
+ var err error
+ for ln := range s.listeners {
+ if cerr := ln.Close(); cerr != nil && err == nil {
+ err = cerr
+ }
+ delete(s.listeners, ln)
+ }
+ return err
}
// A ConnState represents the state of a client connection to a server.
@@ -2243,6 +2601,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS)
}
+var ErrServerClosed = errors.New("http: Server closed")
+
// Serve accepts incoming connections on the Listener l, creating a
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
@@ -2252,7 +2612,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool {
// srv.TLSConfig is non-nil and doesn't include the string "h2" in
// Config.NextProtos, HTTP/2 support is not enabled.
//
-// Serve always returns a non-nil error.
+// Serve always returns a non-nil error. After Shutdown or Close, the
+// returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
if fn := testHookServerServe; fn != nil {
@@ -2264,14 +2625,20 @@ func (srv *Server) Serve(l net.Listener) error {
return err
}
- // TODO: allow changing base context? can't imagine concrete
- // use cases yet.
- baseCtx := context.Background()
+ srv.trackListener(l, true)
+ defer srv.trackListener(l, false)
+
+ baseCtx := context.Background() // base is always background, per Issue 16220
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr())
for {
rw, e := l.Accept()
if e != nil {
+ select {
+ case <-srv.getDoneChan():
+ return ErrServerClosed
+ default:
+ }
if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
@@ -2294,8 +2661,57 @@ func (srv *Server) Serve(l net.Listener) error {
}
}
+func (s *Server) trackListener(ln net.Listener, add bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.listeners == nil {
+ s.listeners = make(map[net.Listener]struct{})
+ }
+ if add {
+ // If the *Server is being reused after a previous
+ // Close or Shutdown, reset its doneChan:
+ if len(s.listeners) == 0 && len(s.activeConn) == 0 {
+ s.doneChan = nil
+ }
+ s.listeners[ln] = struct{}{}
+ } else {
+ delete(s.listeners, ln)
+ }
+}
+
+func (s *Server) trackConn(c *conn, add bool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.activeConn == nil {
+ s.activeConn = make(map[*conn]struct{})
+ }
+ if add {
+ s.activeConn[c] = struct{}{}
+ } else {
+ delete(s.activeConn, c)
+ }
+}
+
+func (s *Server) idleTimeout() time.Duration {
+ if s.IdleTimeout != 0 {
+ return s.IdleTimeout
+ }
+ return s.ReadTimeout
+}
+
+func (s *Server) readHeaderTimeout() time.Duration {
+ if s.ReadHeaderTimeout != 0 {
+ return s.ReadHeaderTimeout
+ }
+ return s.ReadTimeout
+}
+
func (s *Server) doKeepAlives() bool {
- return atomic.LoadInt32(&s.disableKeepAlives) == 0
+ return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown()
+}
+
+func (s *Server) shuttingDown() bool {
+ return atomic.LoadInt32(&s.inShutdown) != 0
}
// SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled.
@@ -2305,9 +2721,21 @@ func (s *Server) doKeepAlives() bool {
func (srv *Server) SetKeepAlivesEnabled(v bool) {
if v {
atomic.StoreInt32(&srv.disableKeepAlives, 0)
- } else {
- atomic.StoreInt32(&srv.disableKeepAlives, 1)
+ return
}
+ atomic.StoreInt32(&srv.disableKeepAlives, 1)
+
+ // Close idle HTTP/1 conns:
+ srv.closeIdleConns()
+
+ // Close HTTP/2 conns, as soon as they become idle, but reset
+ // the chan so future conns (if the listener is still active)
+ // still work and don't get a GOAWAY immediately, before their
+ // first request:
+ srv.mu.Lock()
+ defer srv.mu.Unlock()
+ srv.closeDoneChanLocked() // closes http2 conns
+ srv.doneChan = nil
}
func (s *Server) logf(format string, args ...interface{}) {
@@ -2630,24 +3058,6 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
}
}
-type eofReaderWithWriteTo struct{}
-
-func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil }
-func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF }
-
-// eofReader is a non-nil io.ReadCloser that always returns EOF.
-// It has a WriteTo method so io.Copy won't need a buffer.
-var eofReader = &struct {
- eofReaderWithWriteTo
- io.Closer
-}{
- eofReaderWithWriteTo{},
- ioutil.NopCloser(nil),
-}
-
-// Verify that an io.Copy from an eofReader won't require a buffer.
-var _ io.WriterTo = eofReader
-
// initNPNRequest is an HTTP handler that initializes certain
// uninitialized fields in its *Request. Such partially-initialized
// Requests come from NPN protocol handlers.
@@ -2662,7 +3072,7 @@ func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) {
*req.TLS = h.c.ConnectionState()
}
if req.Body == nil {
- req.Body = eofReader
+ req.Body = NoBody
}
if req.RemoteAddr == "" {
req.RemoteAddr = h.c.RemoteAddr().String()
@@ -2723,6 +3133,7 @@ func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
n, err = w.c.rwc.Write(p)
if err != nil && w.c.werr == nil {
w.c.werr = err
+ w.c.cancelCtx()
}
return
}
diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go
index ac404bf..38f3f81 100644
--- a/libgo/go/net/http/sniff_test.go
+++ b/libgo/go/net/http/sniff_test.go
@@ -66,6 +66,7 @@ func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) }
func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) }
func testServerContentType(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
i, _ := strconv.Atoi(r.FormValue("i"))
@@ -160,6 +161,7 @@ func testContentTypeWithCopy(t *testing.T, h2 bool) {
func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) }
func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) }
func testSniffWriteSize(t *testing.T, h2 bool) {
+ setParallel(t)
defer afterTest(t)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
size, _ := strconv.Atoi(r.FormValue("size"))
diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go
index c653467..4f47637 100644
--- a/libgo/go/net/http/transfer.go
+++ b/libgo/go/net/http/transfer.go
@@ -17,6 +17,7 @@ import (
"strconv"
"strings"
"sync"
+ "time"
"golang_org/x/net/lex/httplex"
)
@@ -33,6 +34,23 @@ func (r errorReader) Read(p []byte) (n int, err error) {
return 0, r.err
}
+type byteReader struct {
+ b byte
+ done bool
+}
+
+func (br *byteReader) Read(p []byte) (n int, err error) {
+ if br.done {
+ return 0, io.EOF
+ }
+ if len(p) == 0 {
+ return 0, nil
+ }
+ br.done = true
+ p[0] = br.b
+ return 1, io.EOF
+}
+
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
@@ -46,6 +64,9 @@ type transferWriter struct {
TransferEncoding []string
Trailer Header
IsResponse bool
+
+ FlushHeaders bool // flush headers to network before body
+ ByteReadCh chan readResult // non-nil if probeRequestBody called
}
func newTransferWriter(r interface{}) (t *transferWriter, err error) {
@@ -59,37 +80,15 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
}
t.Method = valueOrDefault(rr.Method, "GET")
- t.Body = rr.Body
- t.BodyCloser = rr.Body
- t.ContentLength = rr.ContentLength
t.Close = rr.Close
t.TransferEncoding = rr.TransferEncoding
t.Trailer = rr.Trailer
- atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
- if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 {
- if t.ContentLength == 0 {
- // Test to see if it's actually zero or just unset.
- var buf [1]byte
- n, rerr := io.ReadFull(t.Body, buf[:])
- if rerr != nil && rerr != io.EOF {
- t.ContentLength = -1
- t.Body = errorReader{rerr}
- } else if n == 1 {
- // Oh, guess there is data in this Body Reader after all.
- // The ContentLength field just wasn't set.
- // Stich the Body back together again, re-attaching our
- // consumed byte.
- t.ContentLength = -1
- t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body)
- } else {
- // Body is actually empty.
- t.Body = nil
- t.BodyCloser = nil
- }
- }
- if t.ContentLength < 0 {
- t.TransferEncoding = []string{"chunked"}
- }
+ atLeastHTTP11 = rr.protoAtLeastOutgoing(1, 1)
+ t.Body = rr.Body
+ t.BodyCloser = rr.Body
+ t.ContentLength = rr.outgoingLength()
+ if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && atLeastHTTP11 && t.shouldSendChunkedRequestBody() {
+ t.TransferEncoding = []string{"chunked"}
}
case *Response:
t.IsResponse = true
@@ -103,7 +102,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
t.TransferEncoding = rr.TransferEncoding
t.Trailer = rr.Trailer
atLeastHTTP11 = rr.ProtoAtLeast(1, 1)
- t.ResponseToHEAD = noBodyExpected(t.Method)
+ t.ResponseToHEAD = noResponseBodyExpected(t.Method)
}
// Sanitize Body,ContentLength,TransferEncoding
@@ -131,7 +130,100 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
return t, nil
}
-func noBodyExpected(requestMethod string) bool {
+// shouldSendChunkedRequestBody reports whether we should try to send a
+// chunked request body to the server. In particular, the case we really
+// want to prevent is sending a GET or other typically-bodyless request to a
+// server with a chunked body when the body has zero bytes, since GETs with
+// bodies (while acceptable according to specs), even zero-byte chunked
+// bodies, are approximately never seen in the wild and confuse most
+// servers. See Issue 18257, as one example.
+//
+// The only reason we'd send such a request is if the user set the Body to a
+// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't
+// set ContentLength, or NewRequest set it to -1 (unknown), so then we assume
+// there's bytes to send.
+//
+// This code tries to read a byte from the Request.Body in such cases to see
+// whether the body actually has content (super rare) or is actually just
+// a non-nil content-less ReadCloser (the more common case). In that more
+// common case, we act as if their Body were nil instead, and don't send
+// a body.
+func (t *transferWriter) shouldSendChunkedRequestBody() bool {
+ // Note that t.ContentLength is the corrected content length
+ // from rr.outgoingLength, so 0 actually means zero, not unknown.
+ if t.ContentLength >= 0 || t.Body == nil { // redundant checks; caller did them
+ return false
+ }
+ if requestMethodUsuallyLacksBody(t.Method) {
+ // Only probe the Request.Body for GET/HEAD/DELETE/etc
+ // requests, because it's only those types of requests
+ // that confuse servers.
+ t.probeRequestBody() // adjusts t.Body, t.ContentLength
+ return t.Body != nil
+ }
+ // For all other request types (PUT, POST, PATCH, or anything
+ // made-up we've never heard of), assume it's normal and the server
+ // can deal with a chunked request body. Maybe we'll adjust this
+ // later.
+ return true
+}
+
+// probeRequestBody reads a byte from t.Body to see whether it's empty
+// (returns io.EOF right away).
+//
+// But because we've had problems with this blocking users in the past
+// (issue 17480) when the body is a pipe (perhaps waiting on the response
+// headers before the pipe is fed data), we need to be careful and bound how
+// long we wait for it. This delay will only affect users if all the following
+// are true:
+// * the request body blocks
+// * the content length is not set (or set to -1)
+// * the method doesn't usually have a body (GET, HEAD, DELETE, ...)
+// * there is no transfer-encoding=chunked already set.
+// In other words, this delay will not normally affect anybody, and there
+// are workarounds if it does.
+func (t *transferWriter) probeRequestBody() {
+ t.ByteReadCh = make(chan readResult, 1)
+ go func(body io.Reader) {
+ var buf [1]byte
+ var rres readResult
+ rres.n, rres.err = body.Read(buf[:])
+ if rres.n == 1 {
+ rres.b = buf[0]
+ }
+ t.ByteReadCh <- rres
+ }(t.Body)
+ timer := time.NewTimer(200 * time.Millisecond)
+ select {
+ case rres := <-t.ByteReadCh:
+ timer.Stop()
+ if rres.n == 0 && rres.err == io.EOF {
+ // It was empty.
+ t.Body = nil
+ t.ContentLength = 0
+ } else if rres.n == 1 {
+ if rres.err != nil {
+ t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err})
+ } else {
+ t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body)
+ }
+ } else if rres.err != nil {
+ t.Body = errorReader{rres.err}
+ }
+ case <-timer.C:
+ // Too slow. Don't wait. Read it later, and keep
+ // assuming that this is ContentLength == -1
+ // (unknown), which means we'll send a
+ // "Transfer-Encoding: chunked" header.
+ t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body)
+ // Request that Request.Write flush the headers to the
+ // network before writing the body, since our body may not
+ // become readable until it's seen the response headers.
+ t.FlushHeaders = true
+ }
+}
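
From the client's point of view, the case being probed above is roughly the following sketch; the URL is a placeholder and the bytes, io/ioutil, log and net/http imports are assumed:

req, err := http.NewRequest("GET", "http://example.com/", nil)
if err != nil {
	log.Fatal(err)
}
// A non-nil body of unknown length that turns out to be empty: the
// transport reads one byte, sees io.EOF immediately, and sends the GET
// with no body and no "Transfer-Encoding: chunked" header.
req.Body = ioutil.NopCloser(bytes.NewReader(nil))
resp, err := http.DefaultClient.Do(req)
if err != nil {
	log.Fatal(err)
}
resp.Body.Close()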
+
+func noResponseBodyExpected(requestMethod string) bool {
return requestMethod == "HEAD"
}
@@ -214,7 +306,7 @@ func (t *transferWriter) WriteBody(w io.Writer) error {
if t.Body != nil {
if chunked(t.TransferEncoding) {
if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
- w = &internal.FlushAfterChunkWriter{bw}
+ w = &internal.FlushAfterChunkWriter{Writer: bw}
}
cw := internal.NewChunkedWriter(w)
_, err = io.Copy(cw, t.Body)
@@ -235,7 +327,9 @@ func (t *transferWriter) WriteBody(w io.Writer) error {
if err != nil {
return err
}
- if err = t.BodyCloser.Close(); err != nil {
+ }
+ if t.BodyCloser != nil {
+ if err := t.BodyCloser.Close(); err != nil {
return err
}
}
@@ -385,13 +479,13 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// or close connection when finished, since multipart is not supported yet
switch {
case chunked(t.TransferEncoding):
- if noBodyExpected(t.RequestMethod) {
- t.Body = eofReader
+ if noResponseBodyExpected(t.RequestMethod) {
+ t.Body = NoBody
} else {
t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close}
}
case realLength == 0:
- t.Body = eofReader
+ t.Body = NoBody
case realLength > 0:
t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
default:
@@ -401,7 +495,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
t.Body = &body{src: r, closing: t.Close}
} else {
// Persistent connection (i.e. HTTP/1.1)
- t.Body = eofReader
+ t.Body = NoBody
}
}
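
NoBody, which replaces the unexported eofReader here, is also usable from client code when an explicitly empty, non-nil request body is wanted; a minimal sketch with a placeholder URL, assuming the log and net/http imports:

req, err := http.NewRequest("POST", "http://example.com/ping", http.NoBody)
if err != nil {
	log.Fatal(err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
	log.Fatal(err)
}
resp.Body.Close()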
@@ -493,10 +587,31 @@ func (t *transferReader) fixTransferEncoding() error {
// function is not a method, because ultimately it should be shared by
// ReadResponse and ReadRequest.
func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) {
- contentLens := header["Content-Length"]
isRequest := !isResponse
+ contentLens := header["Content-Length"]
+
+ // Hardening against HTTP request smuggling
+ if len(contentLens) > 1 {
+ // Per RFC 7230 Section 3.3.2, prevent multiple
+ // Content-Length headers if they differ in value.
+ // If there are dups of the value, remove the dups.
+ // See Issue 16490.
+ first := strings.TrimSpace(contentLens[0])
+ for _, ct := range contentLens[1:] {
+ if first != strings.TrimSpace(ct) {
+ return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens)
+ }
+ }
+
+ // deduplicate Content-Length
+ header.Del("Content-Length")
+ header.Add("Content-Length", first)
+
+ contentLens = header["Content-Length"]
+ }
+
// Logic based on response type or status
- if noBodyExpected(requestMethod) {
+ if noResponseBodyExpected(requestMethod) {
// For HTTP requests, as part of hardening against request
// smuggling (RFC 7230), don't allow a Content-Length header for
// methods which don't permit bodies. As an exception, allow
@@ -514,11 +629,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
return 0, nil
}
- if len(contentLens) > 1 {
- // harden against HTTP request smuggling. See RFC 7230.
- return 0, errors.New("http: message cannot contain multiple Content-Length headers")
- }
-
// Logic based on Transfer-Encoding
if chunked(te) {
return -1, nil
@@ -539,7 +649,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
header.Del("Content-Length")
}
- if !isResponse {
+ if isRequest {
// RFC 2616 neither explicitly permits nor forbids an
// entity-body on a GET request so we permit one if
// declared, but we default to 0 here (not -1 below)
@@ -864,3 +974,21 @@ func parseContentLength(cl string) (int64, error) {
return n, nil
}
+
+// finishAsyncByteRead finishes reading the 1-byte sniff
+// from the ContentLength==0, Body!=nil case.
+type finishAsyncByteRead struct {
+ tw *transferWriter
+}
+
+func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) {
+ if len(p) == 0 {
+ return
+ }
+ rres := <-fr.tw.ByteReadCh
+ n, err = rres.n, rres.err
+ if n == 1 {
+ p[0] = rres.b
+ }
+ return
+}
diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go
index 1f07634..571943d6 100644
--- a/libgo/go/net/http/transport.go
+++ b/libgo/go/net/http/transport.go
@@ -25,6 +25,7 @@ import (
"os"
"strings"
"sync"
+ "sync/atomic"
"time"
"golang_org/x/net/lex/httplex"
@@ -40,6 +41,7 @@ var DefaultTransport RoundTripper = &Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
+ DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
@@ -66,8 +68,10 @@ const DefaultMaxIdleConnsPerHost = 2
// For high-level functionality, such as cookies and redirects, see Client.
//
// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2
-// for HTTPS URLs, depending on whether the server supports HTTP/2.
-// See the package docs for more about HTTP/2.
+// for HTTPS URLs, depending on whether the server supports HTTP/2,
+// and how the Transport is configured. The DefaultTransport supports HTTP/2.
+// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2
+// and call ConfigureTransport. See the package docs for more about HTTP/2.
type Transport struct {
idleMu sync.Mutex
wantIdle bool // user has requested to close all idle conns
@@ -76,10 +80,10 @@ type Transport struct {
idleLRU connLRU
reqMu sync.Mutex
- reqCanceler map[*Request]func()
+ reqCanceler map[*Request]func(error)
- altMu sync.RWMutex
- altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
+ altMu sync.Mutex // guards changing altProto only
+ altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
@@ -111,7 +115,9 @@ type Transport struct {
DialTLS func(network, addr string) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with
- // tls.Client. If nil, the default configuration is used.
+ // tls.Client.
+ // If nil, the default configuration is used.
+ // If non-nil, HTTP/2 support may not be enabled by default.
TLSClientConfig *tls.Config
// TLSHandshakeTimeout specifies the maximum amount of time waiting to
@@ -156,7 +162,9 @@ type Transport struct {
// ExpectContinueTimeout, if non-zero, specifies the amount of
// time to wait for a server's first response headers after fully
// writing the request headers if the request has an
- // "Expect: 100-continue" header. Zero means no timeout.
+ // "Expect: 100-continue" header. Zero means no timeout and
+ // causes the body to be sent immediately, without
+ // waiting for the server to approve.
// This time does not include the time to send the request header.
ExpectContinueTimeout time.Duration
@@ -168,9 +176,14 @@ type Transport struct {
// called with the request's authority (such as "example.com"
// or "example.com:1234") and the TLS connection. The function
// must return a RoundTripper that then handles the request.
- // If TLSNextProto is nil, HTTP/2 support is enabled automatically.
+ // If TLSNextProto is not nil, HTTP/2 support is not enabled
+ // automatically.
TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
+ // ProxyConnectHeader optionally specifies headers to send to
+ // proxies during CONNECT requests.
+ ProxyConnectHeader Header
+
// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
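
A sketch of the new ProxyConnectHeader field in use; the proxy address and credentials are placeholders, and the log, net/http and net/url imports are assumed:

proxyURL, err := url.Parse("http://proxy.example.com:3128")
if err != nil {
	log.Fatal(err)
}
tr := &http.Transport{
	Proxy: http.ProxyURL(proxyURL),
	// Extra headers to send on the CONNECT request to the proxy.
	ProxyConnectHeader: http.Header{
		"Proxy-Authorization": {"Basic dXNlcjpwYXNz"}, // placeholder credentials
	},
}
client := &http.Client{Transport: tr}
resp, err := client.Get("https://example.com/")
if err != nil {
	log.Fatal(err)
}
resp.Body.Close()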
@@ -330,11 +343,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
}
}
}
- // TODO(bradfitz): switch to atomic.Value for this map instead of RWMutex
- t.altMu.RLock()
- altRT := t.altProto[scheme]
- t.altMu.RUnlock()
- if altRT != nil {
+
+ altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+ if altRT := altProto[scheme]; altRT != nil {
if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
return resp, err
}
@@ -421,19 +432,15 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool {
// our request (as opposed to sending an error).
return false
}
+ if _, ok := err.(nothingWrittenError); ok {
+ // We never wrote anything, so it's safe to retry.
+ return true
+ }
if !req.isReplayable() {
// Don't retry non-idempotent requests.
-
- // TODO: swap the nothingWrittenError and isReplayable checks,
- // putting the "if nothingWrittenError => return true" case
- // first, per golang.org/issue/15723
return false
}
- switch err.(type) {
- case nothingWrittenError:
- // We never wrote anything, so it's safe to retry.
- return true
- case transportReadFromServerError:
+ if _, ok := err.(transportReadFromServerError); ok {
// We got some non-EOF net.Conn.Read failure reading
// the 1st response byte from the server.
return true
@@ -463,13 +470,16 @@ var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol")
func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
t.altMu.Lock()
defer t.altMu.Unlock()
- if t.altProto == nil {
- t.altProto = make(map[string]RoundTripper)
- }
- if _, exists := t.altProto[scheme]; exists {
+ oldMap, _ := t.altProto.Load().(map[string]RoundTripper)
+ if _, exists := oldMap[scheme]; exists {
panic("protocol " + scheme + " already registered")
}
- t.altProto[scheme] = rt
+ newMap := make(map[string]RoundTripper)
+ for k, v := range oldMap {
+ newMap[k] = v
+ }
+ newMap[scheme] = rt
+ t.altProto.Store(newMap)
}
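
RegisterProtocol itself is used roughly as follows; the scheme and round tripper are made up for illustration, and the io/ioutil, log, net/http and strings imports are assumed:

type fooRoundTripper struct{}

func (fooRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// A toy RoundTripper that answers every foo:// request itself.
	return &http.Response{
		Status:     "200 OK",
		StatusCode: http.StatusOK,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Body:       ioutil.NopCloser(strings.NewReader("hello from foo")),
		Request:    req,
	}, nil
}

tr := &http.Transport{}
tr.RegisterProtocol("foo", fooRoundTripper{})
client := &http.Client{Transport: tr}
resp, err := client.Get("foo://host/resource")
if err != nil {
	log.Fatal(err)
}
resp.Body.Close()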
// CloseIdleConnections closes any connections which were previously
@@ -502,12 +512,17 @@ func (t *Transport) CloseIdleConnections() {
// cancelable context instead. CancelRequest cannot cancel HTTP/2
// requests.
func (t *Transport) CancelRequest(req *Request) {
+ t.cancelRequest(req, errRequestCanceled)
+}
+
+// Cancel an in-flight request, recording the error value.
+func (t *Transport) cancelRequest(req *Request, err error) {
t.reqMu.Lock()
cancel := t.reqCanceler[req]
delete(t.reqCanceler, req)
t.reqMu.Unlock()
if cancel != nil {
- cancel()
+ cancel(err)
}
}
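
As the surrounding comments note, the context attached to the request is the preferred cancelation mechanism; a minimal client-side sketch with a placeholder URL and timeout, assuming the context, log, net/http and time imports:

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

req, err := http.NewRequest("GET", "http://example.com/slow", nil)
if err != nil {
	log.Fatal(err)
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
if err != nil {
	log.Fatal(err) // a fired timeout surfaces here as a context deadline error
}
defer resp.Body.Close()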
@@ -557,10 +572,18 @@ func (e *envOnce) reset() {
}
func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
+ if port := treq.URL.Port(); !validPort(port) {
+ return cm, fmt.Errorf("invalid URL port %q", port)
+ }
cm.targetScheme = treq.URL.Scheme
cm.targetAddr = canonicalAddr(treq.URL)
if t.Proxy != nil {
cm.proxyURL, err = t.Proxy(treq.Request)
+ if err == nil && cm.proxyURL != nil {
+ if port := cm.proxyURL.Port(); !validPort(port) {
+ return cm, fmt.Errorf("invalid proxy URL port %q", port)
+ }
+ }
}
return cm, err
}
@@ -787,11 +810,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
}
}
-func (t *Transport) setReqCanceler(r *Request, fn func()) {
+func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqCanceler == nil {
- t.reqCanceler = make(map[*Request]func())
+ t.reqCanceler = make(map[*Request]func(error))
}
if fn != nil {
t.reqCanceler[r] = fn
@@ -804,7 +827,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
// for the request, we don't set the function and return false.
// Since CancelRequest will clear the canceler, we can use the return value to detect if
// the request was canceled since the last setReqCancel call.
-func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
+func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
t.reqMu.Lock()
defer t.reqMu.Unlock()
_, ok := t.reqCanceler[r]
@@ -853,7 +876,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
// set request canceler to some non-nil function so we
// can detect whether it was cleared between now and when
// we enter roundTrip
- t.setReqCanceler(req, func() {})
+ t.setReqCanceler(req, func(error) {})
return pc, nil
}
@@ -878,8 +901,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
}()
}
- cancelc := make(chan struct{})
- t.setReqCanceler(req, func() { close(cancelc) })
+ cancelc := make(chan error, 1)
+ t.setReqCanceler(req, func(err error) { cancelc <- err })
go func() {
pc, err := t.dialConn(ctx, cm)
@@ -900,16 +923,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
// value.
select {
case <-req.Cancel:
+ // It was an error due to cancelation, so prioritize that
+ // error value. (Issue 16049)
+ return nil, errRequestCanceledConn
case <-req.Context().Done():
- case <-cancelc:
+ return nil, req.Context().Err()
+ case err := <-cancelc:
+ if err == errRequestCanceled {
+ err = errRequestCanceledConn
+ }
+ return nil, err
default:
// It wasn't an error due to cancelation, so
// return the original error message:
return nil, v.err
}
- // It was an error due to cancelation, so prioritize that
- // error value. (Issue 16049)
- return nil, errRequestCanceledConn
case pc := <-idleConnCh:
// Another request finished first and its net.Conn
// became available before our dial. Or somebody
@@ -926,10 +954,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
return nil, errRequestCanceledConn
case <-req.Context().Done():
handlePendingDial()
- return nil, errRequestCanceledConn
- case <-cancelc:
+ return nil, req.Context().Err()
+ case err := <-cancelc:
handlePendingDial()
- return nil, errRequestCanceledConn
+ if err == errRequestCanceled {
+ err = errRequestCanceledConn
+ }
+ return nil, err
}
}
@@ -943,6 +974,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
writeErrCh: make(chan error, 1),
writeLoopDone: make(chan struct{}),
}
+ trace := httptrace.ContextClientTrace(ctx)
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
if tlsDial {
var err error
@@ -956,18 +988,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below
// depends on it for knowing the connection state.
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
if err := tc.Handshake(); err != nil {
go pconn.conn.Close()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+ }
return nil, err
}
cs := tc.ConnectionState()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(cs, nil)
+ }
pconn.tlsState = &cs
}
} else {
conn, err := t.dial(ctx, "tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
- err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
+ // Return a typed error, per Issue 16997:
+ err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
}
return nil, err
}
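
The TLSHandshakeStart and TLSHandshakeDone hooks fired in the hunk above can be observed from client code via httptrace, as in this sketch; the URL is a placeholder and the crypto/tls, log, net/http and net/http/httptrace imports are assumed:

trace := &httptrace.ClientTrace{
	TLSHandshakeStart: func() {
		log.Println("TLS handshake starting")
	},
	TLSHandshakeDone: func(cs tls.ConnectionState, err error) {
		log.Printf("TLS handshake done: version=%x alpn=%q err=%v",
			cs.Version, cs.NegotiatedProtocol, err)
	},
}
req, err := http.NewRequest("GET", "https://example.com/", nil)
if err != nil {
	log.Fatal(err)
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
resp, err := http.DefaultTransport.RoundTrip(req)
if err != nil {
	log.Fatal(err)
}
resp.Body.Close()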
@@ -987,11 +1029,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}
case cm.targetScheme == "https":
conn := pconn.conn
+ hdr := t.ProxyConnectHeader
+ if hdr == nil {
+ hdr = make(Header)
+ }
connectReq := &Request{
Method: "CONNECT",
URL: &url.URL{Opaque: cm.targetAddr},
Host: cm.targetAddr,
- Header: make(Header),
+ Header: hdr,
}
if pa := cm.proxyAuth(); pa != "" {
connectReq.Header.Set("Proxy-Authorization", pa)
@@ -1016,7 +1062,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
if cm.targetScheme == "https" && !tlsDial {
// Initiate TLS and check remote host name against certificate.
- cfg := cloneTLSClientConfig(t.TLSClientConfig)
+ cfg := cloneTLSConfig(t.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = cm.tlsHost()
}
@@ -1030,6 +1076,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
})
}
go func() {
+ if trace != nil && trace.TLSHandshakeStart != nil {
+ trace.TLSHandshakeStart()
+ }
err := tlsConn.Handshake()
if timer != nil {
timer.Stop()
@@ -1038,6 +1087,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}()
if err := <-errc; err != nil {
plainConn.Close()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+ }
return nil, err
}
if !cfg.InsecureSkipVerify {
@@ -1047,6 +1099,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}
}
cs := tlsConn.ConnectionState()
+ if trace != nil && trace.TLSHandshakeDone != nil {
+ trace.TLSHandshakeDone(cs, nil)
+ }
pconn.tlsState = &cs
pconn.conn = tlsConn
}
@@ -1235,8 +1290,8 @@ type persistConn struct {
mu sync.Mutex // guards following fields
numExpectedResponses int
closed error // set non-nil when conn is closed, before closech is closed
+ canceledErr error // set non-nil if conn is canceled
broken bool // an error has happened on this connection; marked broken so it's not reused.
- canceled bool // whether this conn was broken due a CancelRequest
reused bool // whether conn has had successful request/response and is being reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
@@ -1274,11 +1329,12 @@ func (pc *persistConn) isBroken() bool {
return b
}
-// isCanceled reports whether this connection was closed due to CancelRequest.
-func (pc *persistConn) isCanceled() bool {
+// canceled returns non-nil if the connection was closed due to
+// CancelRequest or due to context cancelation.
+func (pc *persistConn) canceled() error {
pc.mu.Lock()
defer pc.mu.Unlock()
- return pc.canceled
+ return pc.canceledErr
}
// isReused reports whether this connection has had a successful request/response and is being reused.
@@ -1301,10 +1357,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn
return
}
-func (pc *persistConn) cancelRequest() {
+func (pc *persistConn) cancelRequest(err error) {
pc.mu.Lock()
defer pc.mu.Unlock()
- pc.canceled = true
+ pc.canceledErr = err
pc.closeLocked(errRequestCanceled)
}
@@ -1328,12 +1384,12 @@ func (pc *persistConn) closeConnIfStillIdle() {
//
// The startBytesWritten value should be the value of pc.nwrite before the roundTrip
// started writing the request.
-func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, err error) (out error) {
+func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) {
if err == nil {
return nil
}
- if pc.isCanceled() {
- return errRequestCanceled
+ if err := pc.canceled(); err != nil {
+ return err
}
if err == errServerClosedIdle {
return err
@@ -1343,7 +1399,7 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
}
if pc.isBroken() {
<-pc.writeLoopDone
- if pc.nwrite == startBytesWritten {
+ if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
return nothingWrittenError{err}
}
}
@@ -1354,9 +1410,9 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
// up to Transport.RoundTrip method when persistConn.roundTrip sees
// its pc.closech channel close, indicating the persistConn is dead.
// (after closech is closed, pc.closed is valid).
-func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error {
- if pc.isCanceled() {
- return errRequestCanceled
+func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error {
+ if err := pc.canceled(); err != nil {
+ return err
}
err := pc.closed
if err == errServerClosedIdle {
@@ -1372,7 +1428,7 @@ func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) err
// see if we actually managed to write anything. If not, we
// can retry the request.
<-pc.writeLoopDone
- if pc.nwrite == startBytesWritten {
+ if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
return nothingWrittenError{err}
}
@@ -1513,8 +1569,10 @@ func (pc *persistConn) readLoop() {
waitForBodyRead <- isEOF
if isEOF {
<-eofc // see comment above eofc declaration
- } else if err != nil && pc.isCanceled() {
- return errRequestCanceled
+ } else if err != nil {
+ if cerr := pc.canceled(); cerr != nil {
+ return cerr
+ }
}
return err
},
@@ -1554,7 +1612,7 @@ func (pc *persistConn) readLoop() {
pc.t.CancelRequest(rc.req)
case <-rc.req.Context().Done():
alive = false
- pc.t.CancelRequest(rc.req)
+ pc.t.cancelRequest(rc.req, rc.req.Context().Err())
case <-pc.closech:
alive = false
}
@@ -1652,7 +1710,7 @@ func (pc *persistConn) writeLoop() {
}
if err != nil {
wr.req.Request.closeBody()
- if pc.nwrite == startBytesWritten {
+ if pc.nwrite == startBytesWritten && wr.req.outgoingLength() == 0 {
err = nothingWrittenError{err}
}
}
@@ -1840,8 +1898,8 @@ WaitResponse:
select {
case err := <-writeErrCh:
if err != nil {
- if pc.isCanceled() {
- err = errRequestCanceled
+ if cerr := pc.canceled(); cerr != nil {
+ err = cerr
}
re = responseAndError{err: err}
pc.close(fmt.Errorf("write error: %v", err))
@@ -1853,21 +1911,20 @@ WaitResponse:
respHeaderTimer = timer.C
}
case <-pc.closech:
- re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(startBytesWritten)}
+ re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)}
break WaitResponse
case <-respHeaderTimer:
pc.close(errTimeout)
re = responseAndError{err: errTimeout}
break WaitResponse
case re = <-resc:
- re.err = pc.mapRoundTripErrorFromReadLoop(startBytesWritten, re.err)
+ re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err)
break WaitResponse
case <-cancelChan:
pc.t.CancelRequest(req.Request)
cancelChan = nil
- ctxDoneChan = nil
case <-ctxDoneChan:
- pc.t.CancelRequest(req.Request)
+ pc.t.cancelRequest(req.Request, req.Context().Err())
cancelChan = nil
ctxDoneChan = nil
}
@@ -1931,11 +1988,15 @@ var portMap = map[string]string{
// canonicalAddr returns url.Host but always with a ":port" suffix
func canonicalAddr(url *url.URL) string {
- addr := url.Host
- if !hasPort(addr) {
- return addr + ":" + portMap[url.Scheme]
+ addr := url.Hostname()
+ if v, err := idnaASCII(addr); err == nil {
+ addr = v
+ }
+ port := url.Port()
+ if port == "" {
+ port = portMap[url.Scheme]
}
- return addr
+ return net.JoinHostPort(addr, port)
}
// bodyEOFSignal is used by the HTTP/1 transport when reading response
@@ -2060,75 +2121,14 @@ type fakeLocker struct{}
func (fakeLocker) Lock() {}
func (fakeLocker) Unlock() {}
-// cloneTLSConfig returns a shallow clone of the exported
-// fields of cfg, ignoring the unexported sync.Once, which
-// contains a mutex and must not be copied.
-//
-// The cfg must not be in active use by tls.Server, or else
-// there can still be a race with tls.Server updating SessionTicketKey
-// and our copying it, and also a race with the server setting
-// SessionTicketsDisabled=false on failure to set the random
-// ticket key.
-//
-// If cfg is nil, a new zero tls.Config is returned.
+// clneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if
+// cfg is nil. This is safe to call even if cfg is in active use by a TLS
+// client or server.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
- return &tls.Config{
- Rand: cfg.Rand,
- Time: cfg.Time,
- Certificates: cfg.Certificates,
- NameToCertificate: cfg.NameToCertificate,
- GetCertificate: cfg.GetCertificate,
- RootCAs: cfg.RootCAs,
- NextProtos: cfg.NextProtos,
- ServerName: cfg.ServerName,
- ClientAuth: cfg.ClientAuth,
- ClientCAs: cfg.ClientCAs,
- InsecureSkipVerify: cfg.InsecureSkipVerify,
- CipherSuites: cfg.CipherSuites,
- PreferServerCipherSuites: cfg.PreferServerCipherSuites,
- SessionTicketsDisabled: cfg.SessionTicketsDisabled,
- SessionTicketKey: cfg.SessionTicketKey,
- ClientSessionCache: cfg.ClientSessionCache,
- MinVersion: cfg.MinVersion,
- MaxVersion: cfg.MaxVersion,
- CurvePreferences: cfg.CurvePreferences,
- DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
- Renegotiation: cfg.Renegotiation,
- }
-}
-
-// cloneTLSClientConfig is like cloneTLSConfig but omits
-// the fields SessionTicketsDisabled and SessionTicketKey.
-// This makes it safe to call cloneTLSClientConfig on a config
-// in active use by a server.
-func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
- if cfg == nil {
- return &tls.Config{}
- }
- return &tls.Config{
- Rand: cfg.Rand,
- Time: cfg.Time,
- Certificates: cfg.Certificates,
- NameToCertificate: cfg.NameToCertificate,
- GetCertificate: cfg.GetCertificate,
- RootCAs: cfg.RootCAs,
- NextProtos: cfg.NextProtos,
- ServerName: cfg.ServerName,
- ClientAuth: cfg.ClientAuth,
- ClientCAs: cfg.ClientCAs,
- InsecureSkipVerify: cfg.InsecureSkipVerify,
- CipherSuites: cfg.CipherSuites,
- PreferServerCipherSuites: cfg.PreferServerCipherSuites,
- ClientSessionCache: cfg.ClientSessionCache,
- MinVersion: cfg.MinVersion,
- MaxVersion: cfg.MaxVersion,
- CurvePreferences: cfg.CurvePreferences,
- DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
- Renegotiation: cfg.Renegotiation,
- }
+ return cfg.Clone()
}
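
The hand-maintained field copies deleted above are replaced by tls.Config.Clone, added to crypto/tls in Go 1.8; the wrapper keeps only its nil check and delegates the rest. A minimal sketch of the new call, with made-up field values for illustration:

    package main

    import (
        "crypto/tls"
        "fmt"
    )

    func main() {
        // Base config; the field values are arbitrary examples.
        base := &tls.Config{
            ServerName: "example.com",
            MinVersion: tls.VersionTLS12,
        }

        // Clone copies the config and is documented as safe to call even
        // while the config is in use by a TLS client or server, so callers
        // no longer need to enumerate exported fields by hand.
        clone := base.Clone()
        clone.ServerName = "other.example.com" // does not affect base

        fmt.Println(base.ServerName, clone.ServerName)
    }
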
type connLRU struct {
@@ -2169,3 +2169,15 @@ func (cl *connLRU) remove(pc *persistConn) {
func (cl *connLRU) len() int {
return len(cl.m)
}
+
+// validPort reports whether p (without the colon) is a valid port in
+// a URL, per RFC 3986 Section 3.2.3, which says the port may be
+// empty, or only contain digits.
+func validPort(p string) bool {
+ for _, r := range []byte(p) {
+ if r < '0' || r > '9' {
+ return false
+ }
+ }
+ return true
+}
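
The rewritten canonicalAddr above now takes the host from url.Hostname, converts an international name to its ASCII (punycode) form, fills in the scheme's default port when url.Port is empty, and joins the two with net.JoinHostPort. idnaASCII and portMap are unexported, so the sketch below stands in for them with the public golang.org/x/net/idna package and a small local table; the expected output matches the punycode constant used by the new IDNA tests later in this patch:

    package main

    import (
        "fmt"
        "net"
        "net/url"

        "golang.org/x/net/idna"
    )

    // defaultPort is an illustrative stand-in for the transport's portMap.
    var defaultPort = map[string]string{"http": "80", "https": "443"}

    // canonicalAddr approximates the patched helper: IDNA-encode the host
    // when possible and always return a "host:port" string.
    func canonicalAddr(u *url.URL) string {
        host := u.Hostname()
        if a, err := idna.ToASCII(host); err == nil {
            host = a
        }
        port := u.Port()
        if port == "" {
            port = defaultPort[u.Scheme]
        }
        return net.JoinHostPort(host, port)
    }

    func main() {
        u, _ := url.Parse("https://гофер.го/path")
        fmt.Println(canonicalAddr(u)) // expected: xn--c1ae0ajs.xn--c1aw:443
    }
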
diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go
index a05ca6e..3d24fc1 100644
--- a/libgo/go/net/http/transport_internal_test.go
+++ b/libgo/go/net/http/transport_internal_test.go
@@ -72,3 +72,70 @@ func newLocalListener(t *testing.T) net.Listener {
}
return ln
}
+
+func dummyRequest(method string) *Request {
+ req, err := NewRequest(method, "http://fake.tld/", nil)
+ if err != nil {
+ panic(err)
+ }
+ return req
+}
+
+func TestTransportShouldRetryRequest(t *testing.T) {
+ tests := []struct {
+ pc *persistConn
+ req *Request
+
+ err error
+ want bool
+ }{
+ 0: {
+ pc: &persistConn{reused: false},
+ req: dummyRequest("POST"),
+ err: nothingWrittenError{},
+ want: false,
+ },
+ 1: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: nothingWrittenError{},
+ want: true,
+ },
+ 2: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: http2ErrNoCachedConn,
+ want: true,
+ },
+ 3: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: errMissingHost,
+ want: false,
+ },
+ 4: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("POST"),
+ err: transportReadFromServerError{},
+ want: false,
+ },
+ 5: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("GET"),
+ err: transportReadFromServerError{},
+ want: true,
+ },
+ 6: {
+ pc: &persistConn{reused: true},
+ req: dummyRequest("GET"),
+ err: errServerClosedIdle,
+ want: true,
+ },
+ }
+ for i, tt := range tests {
+ got := tt.pc.shouldRetryRequest(tt.req, tt.err)
+ if got != tt.want {
+ t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
+ }
+ }
+}
diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go
index 298682d..d5ddf6a 100644
--- a/libgo/go/net/http/transport_test.go
+++ b/libgo/go/net/http/transport_test.go
@@ -441,9 +441,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
}
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/15464")
- }
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, r.RemoteAddr)
@@ -700,6 +698,7 @@ var roundTripTests = []struct {
// Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
const responseBody = "test response body"
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -758,6 +757,7 @@ func TestRoundTripGzip(t *testing.T) {
}
func TestTransportGzip(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
const nRandBytes = 1024 * 1024
@@ -856,6 +856,7 @@ func TestTransportGzip(t *testing.T) {
// If a request has an Expect: 100-continue header, the request blocks sending its body until the first response.
// Premature consumption of the request body should not occur.
func TestTransportExpect100Continue(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -966,6 +967,48 @@ func TestTransportProxy(t *testing.T) {
}
}
+// Issue 16997: test transport dial preserves typed errors
+func TestTransportDialPreservesNetOpProxyError(t *testing.T) {
+ defer afterTest(t)
+
+ var errDial = errors.New("some dial error")
+
+ tr := &Transport{
+ Proxy: func(*Request) (*url.URL, error) {
+ return url.Parse("http://proxy.fake.tld/")
+ },
+ Dial: func(string, string) (net.Conn, error) {
+ return nil, errDial
+ },
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+ req, _ := NewRequest("GET", "http://fake.tld", nil)
+ res, err := c.Do(req)
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("wanted a non-nil error")
+ }
+
+ uerr, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got %T, want *url.Error", err)
+ }
+ oe, ok := uerr.Err.(*net.OpError)
+ if !ok {
+ t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err)
+ }
+ want := &net.OpError{
+ Op: "proxyconnect",
+ Net: "tcp",
+ Err: errDial, // original error, unwrapped.
+ }
+ if !reflect.DeepEqual(oe, want) {
+ t.Errorf("Got error %#v; want %#v", oe, want)
+ }
+}
+
// TestTransportGzipRecursive sends a gzip quine and checks that the
// client gets the same value back. This is more cute than anything,
// but checks that we don't recurse forever, and checks that
@@ -1038,10 +1081,12 @@ func waitNumGoroutine(nmax int) int {
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
- setParallel(t)
+ // Not parallel: counts goroutines
defer afterTest(t)
- gotReqCh := make(chan bool)
- unblockCh := make(chan bool)
+
+ const numReq = 25
+ gotReqCh := make(chan bool, numReq)
+ unblockCh := make(chan bool, numReq)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
gotReqCh <- true
<-unblockCh
@@ -1055,14 +1100,15 @@ func TestTransportPersistConnLeak(t *testing.T) {
n0 := runtime.NumGoroutine()
- const numReq = 25
- didReqCh := make(chan bool)
+ didReqCh := make(chan bool, numReq)
+ failed := make(chan bool, numReq)
for i := 0; i < numReq; i++ {
go func() {
res, err := c.Get(ts.URL)
didReqCh <- true
if err != nil {
t.Errorf("client fetch error: %v", err)
+ failed <- true
return
}
res.Body.Close()
@@ -1071,7 +1117,13 @@ func TestTransportPersistConnLeak(t *testing.T) {
// Wait for all goroutines to be stuck in the Handler.
for i := 0; i < numReq; i++ {
- <-gotReqCh
+ select {
+ case <-gotReqCh:
+ // ok
+ case <-failed:
+ close(unblockCh)
+ return
+ }
}
nhigh := runtime.NumGoroutine()
@@ -1102,7 +1154,7 @@ func TestTransportPersistConnLeak(t *testing.T) {
// golang.org/issue/4531: Transport leaks goroutines when
// request.ContentLength is explicitly short
func TestTransportPersistConnLeakShortBody(t *testing.T) {
- setParallel(t)
+ // Not parallel: measures goroutines.
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
}))
@@ -1198,6 +1250,7 @@ func TestIssue3644(t *testing.T) {
// Test that a client receives a server's reply, even if the server doesn't read
// the entire request body.
func TestIssue3595(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
const deniedMsg = "sorry, denied."
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1246,6 +1299,7 @@ func TestChunkedNoContent(t *testing.T) {
}
func TestTransportConcurrency(t *testing.T) {
+ // Not parallel: uses global test hooks.
defer afterTest(t)
maxProcs, numReqs := 16, 500
if testing.Short() {
@@ -1306,9 +1360,7 @@ func TestTransportConcurrency(t *testing.T) {
}
func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
+ setParallel(t)
defer afterTest(t)
const debug = false
mux := NewServeMux()
@@ -1370,9 +1422,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
}
func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
+ setParallel(t)
defer afterTest(t)
const debug = false
mux := NewServeMux()
@@ -1696,12 +1746,6 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
defer ts.Close()
defer close(unblockc)
- // Don't interfere with the next test on plan9.
- // Cf. https://golang.org/issues/11476
- if runtime.GOOS == "plan9" {
- defer time.Sleep(500 * time.Millisecond)
- }
-
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
@@ -1718,8 +1762,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
}
_, err := c.Do(req)
- if err == nil || !strings.Contains(err.Error(), "canceled") {
- t.Errorf("Do error = %v; want cancelation", err)
+ if ue, ok := err.(*url.Error); ok {
+ err = ue.Err
+ }
+ if withCtx {
+ if err != context.Canceled {
+ t.Errorf("Do error = %v; want %v", err, context.Canceled)
+ }
+ } else {
+ if err == nil || !strings.Contains(err.Error(), "canceled") {
+ t.Errorf("Do error = %v; want cancelation", err)
+ }
}
}
@@ -1888,6 +1941,7 @@ func TestTransportEmptyMethod(t *testing.T) {
}
func TestTransportSocketLateBinding(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
mux := NewServeMux()
@@ -2152,6 +2206,7 @@ func TestProxyFromEnvironment(t *testing.T) {
}
func TestIdleConnChannelLeak(t *testing.T) {
+ // Not parallel: uses global test hooks.
var mu sync.Mutex
var n int
@@ -2383,6 +2438,7 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) {
// questionable state.
// golang.org/issue/7569
func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
var sconn struct {
sync.Mutex
@@ -2485,22 +2541,6 @@ type errorReader struct {
func (e errorReader) Read(p []byte) (int, error) { return 0, e.err }
-type plan9SleepReader struct{}
-
-func (plan9SleepReader) Read(p []byte) (int, error) {
- if runtime.GOOS == "plan9" {
- // After the fix to unblock TCP Reads in
- // https://golang.org/cl/15941, this sleep is required
- // on plan9 to make sure TCP Writes before an
- // immediate TCP close go out on the wire. On Plan 9,
- // it seems that a hangup of a TCP connection with
- // queued data doesn't send the queued data first.
- // https://golang.org/issue/9554
- time.Sleep(50 * time.Millisecond)
- }
- return 0, io.EOF
-}
-
type closerFunc func() error
func (f closerFunc) Close() error { return f() }
@@ -2595,7 +2635,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
io.Reader
io.Closer
}{
- io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), plan9SleepReader{}, errorReader{fakeErr}),
+ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}),
closerFunc(func() error {
select {
case didClose <- true:
@@ -2627,6 +2667,8 @@ func TestTransportClosesBodyOnError(t *testing.T) {
}
func TestTransportDialTLS(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
var mu sync.Mutex // guards following
var gotReq, didDial bool
@@ -2904,14 +2946,8 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
defer res.Body.Close()
want := []string{
- // Because Request.ContentLength = 0, the body is sniffed for 1 byte to determine whether there's content.
- // That explains the initial "num0" being split into "n" and "um0".
- // The first byte is included with the request headers Write. Perhaps in the future
- // we will want to flush the headers out early if the first byte of the request body is
- // taking a long time to arrive. But not yet.
"POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n" +
- "1\r\nn\r\n",
- "4\r\num0\n\r\n",
+ "5\r\nnum0\n\r\n",
"5\r\nnum1\n\r\n",
"5\r\nnum2\n\r\n",
"0\r\n\r\n",
@@ -3150,6 +3186,7 @@ func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) {
// Make sure we re-use underlying TCP connection for gzipped responses too.
func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
+ setParallel(t)
defer afterTest(t)
addr := make(chan string, 2)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -3185,6 +3222,7 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
}
func TestTransportResponseHeaderLength(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.URL.Path == "/long" {
@@ -3248,7 +3286,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
cst.tr.ExpectContinueTimeout = 1 * time.Second
- var mu sync.Mutex
+ var mu sync.Mutex // guards buf
var buf bytes.Buffer
logf := func(format string, args ...interface{}) {
mu.Lock()
@@ -3290,10 +3328,16 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
Wait100Continue: func() { logf("Wait100Continue") },
Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) {
- close(gotWroteReqEvent)
logf("WroteRequest: %+v", e)
+ close(gotWroteReqEvent)
},
}
+ if h2 {
+ trace.TLSHandshakeStart = func() { logf("tls handshake start") }
+ trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
+ logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
+ }
+ }
if noHooks {
// zero out all func pointers, trying to get some path to crash
*trace = httptrace.ClientTrace{}
@@ -3323,7 +3367,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
return
}
+ mu.Lock()
got := buf.String()
+ mu.Unlock()
+
wantOnce := func(sub string) {
if strings.Count(got, sub) != 1 {
t.Errorf("expected substring %q exactly once in output.", sub)
@@ -3342,7 +3389,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
wantOnce("Reused:false WasIdle:false IdleTime:0s")
wantOnce("first response byte")
- if !h2 {
+ if h2 {
+ wantOnce("tls handshake start")
+ wantOnce("tls handshake done")
+ } else {
wantOnce("PutIdleConn = <nil>")
}
wantOnce("Wait100Continue")
@@ -3357,12 +3407,21 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
}
func TestTransportEventTraceRealDNS(t *testing.T) {
+ if testing.Short() && testenv.Builder() == "" {
+ // Skip this test in short mode (the default for
+ // all.bash), in case the user is using a shady/ISP
+ // DNS server hijacking queries.
+ // See issues 16732, 16716.
+ // Our builders use 8.8.8.8, though, which correctly
+ // returns NXDOMAIN, so still run this test there.
+ t.Skip("skipping in short mode")
+ }
defer afterTest(t)
tr := &Transport{}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
- var mu sync.Mutex
+ var mu sync.Mutex // guards buf
var buf bytes.Buffer
logf := func(format string, args ...interface{}) {
mu.Lock()
@@ -3386,7 +3445,10 @@ func TestTransportEventTraceRealDNS(t *testing.T) {
t.Fatal("expected error during DNS lookup")
}
+ mu.Lock()
got := buf.String()
+ mu.Unlock()
+
wantSub := func(sub string) {
if !strings.Contains(got, sub) {
t.Errorf("expected substring %q in output.", sub)
@@ -3402,6 +3464,73 @@ func TestTransportEventTraceRealDNS(t *testing.T) {
}
}
+// Issue 14353: port can only contain digits.
+func TestTransportRejectsAlphaPort(t *testing.T) {
+ res, err := Get("http://dummy.tld:123foo/bar")
+ if err == nil {
+ res.Body.Close()
+ t.Fatal("unexpected success")
+ }
+ ue, ok := err.(*url.Error)
+ if !ok {
+ t.Fatalf("got %#v; want *url.Error", err)
+ }
+ got := ue.Err.Error()
+ want := `invalid URL port "123foo"`
+ if got != want {
+ t.Errorf("got error %q; want %q", got, want)
+ }
+}
+
+// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1
+// connections. The http2 test is done in TestTransportEventTrace_h2
+func TestTLSHandshakeTrace(t *testing.T) {
+ defer afterTest(t)
+ s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer s.Close()
+
+ var mu sync.Mutex
+ var start, done bool
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() {
+ mu.Lock()
+ defer mu.Unlock()
+ start = true
+ },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ mu.Lock()
+ defer mu.Unlock()
+ done = true
+ if err != nil {
+ t.Fatal("Expected error to be nil but was:", err)
+ }
+ },
+ }
+
+ tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ req, err := NewRequest("GET", s.URL, nil)
+ if err != nil {
+ t.Fatal("Unable to construct test request:", err)
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+ r, err := c.Do(req)
+ if err != nil {
+ t.Fatal("Unexpected error making request:", err)
+ }
+ r.Body.Close()
+ mu.Lock()
+ defer mu.Unlock()
+ if !start {
+ t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
+ }
+ if !done {
+ t.Fatal("Expected TLSHandshakeDone to be called, but wasn't")
+ }
+}
+
func TestTransportMaxIdleConns(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -3457,27 +3586,36 @@ func TestTransportMaxIdleConns(t *testing.T) {
}
}
-func TestTransportIdleConnTimeout(t *testing.T) {
+func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) }
+func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) }
+func testTransportIdleConnTimeout(t *testing.T, h2 bool) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ const timeout = 1 * time.Second
+
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
// No body for convenience.
}))
- defer ts.Close()
-
- const timeout = 1 * time.Second
- tr := &Transport{
- IdleConnTimeout: timeout,
- }
+ defer cst.close()
+ tr := cst.tr
+ tr.IdleConnTimeout = timeout
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
+ idleConns := func() []string {
+ if h2 {
+ return tr.IdleConnStrsForTesting_h2()
+ } else {
+ return tr.IdleConnStrsForTesting()
+ }
+ }
+
var conn string
doReq := func(n int) {
- req, _ := NewRequest("GET", ts.URL, nil)
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
PutIdleConn: func(err error) {
if err != nil {
@@ -3490,7 +3628,7 @@ func TestTransportIdleConnTimeout(t *testing.T) {
t.Fatal(err)
}
res.Body.Close()
- conns := tr.IdleConnStrsForTesting()
+ conns := idleConns()
if len(conns) != 1 {
t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns)
}
@@ -3506,7 +3644,7 @@ func TestTransportIdleConnTimeout(t *testing.T) {
time.Sleep(timeout / 2)
}
time.Sleep(timeout * 3 / 2)
- if got := tr.IdleConnStrsForTesting(); len(got) != 0 {
+ if got := idleConns(); len(got) != 0 {
t.Errorf("idle conns = %q; want none", got)
}
}
@@ -3523,6 +3661,7 @@ func TestTransportIdleConnTimeout(t *testing.T) {
// know the successful tls.Dial from DialTLS will need to go into the
// idle pool. Then we give it a bit of time to explode.
func TestIdleConnH2Crash(t *testing.T) {
+ setParallel(t)
cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
// nothing
}))
@@ -3531,12 +3670,12 @@ func TestIdleConnH2Crash(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- gotErr := make(chan bool, 1)
+ sawDoErr := make(chan bool, 1)
+ testDone := make(chan struct{})
+ defer close(testDone)
cst.tr.IdleConnTimeout = 5 * time.Millisecond
cst.tr.DialTLS = func(network, addr string) (net.Conn, error) {
- cancel()
- <-gotErr
c, err := tls.Dial(network, addr, &tls.Config{
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
@@ -3550,6 +3689,17 @@ func TestIdleConnH2Crash(t *testing.T) {
c.Close()
return nil, errors.New("bogus")
}
+
+ cancel()
+
+ failTimer := time.NewTimer(5 * time.Second)
+ defer failTimer.Stop()
+ select {
+ case <-sawDoErr:
+ case <-testDone:
+ case <-failTimer.C:
+ t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail")
+ }
return c, nil
}
@@ -3560,7 +3710,7 @@ func TestIdleConnH2Crash(t *testing.T) {
res.Body.Close()
t.Fatal("unexpected success")
}
- gotErr <- true
+ sawDoErr <- true
// Wait for the explosion.
time.Sleep(cst.tr.IdleConnTimeout * 10)
@@ -3605,6 +3755,122 @@ func TestTransportReturnsPeekError(t *testing.T) {
}
}
+// Issue 13835: international domain names should work
+func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) }
+func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) }
+func testTransportIDNA(t *testing.T, h2 bool) {
+ defer afterTest(t)
+
+ const uniDomain = "гофер.го"
+ const punyDomain = "xn--c1ae0ajs.xn--c1aw"
+
+ var port string
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ want := punyDomain + ":" + port
+ if r.Host != want {
+ t.Errorf("Host header = %q; want %q", r.Host, want)
+ }
+ if h2 {
+ if r.TLS == nil {
+ t.Errorf("r.TLS == nil")
+ } else if r.TLS.ServerName != punyDomain {
+ t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain)
+ }
+ }
+ w.Header().Set("Hit-Handler", "1")
+ }))
+ defer cst.close()
+
+ ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Install a fake DNS server.
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
+ if host != punyDomain {
+ t.Errorf("got DNS host lookup for %q; want %q", host, punyDomain)
+ return nil, nil
+ }
+ return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
+ })
+
+ req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ want := net.JoinHostPort(punyDomain, port)
+ if hostPort != want {
+ t.Errorf("getting conn for %q; want %q", hostPort, want)
+ }
+ },
+ DNSStart: func(e httptrace.DNSStartInfo) {
+ if e.Host != punyDomain {
+ t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain)
+ }
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.Header.Get("Hit-Handler") != "1" {
+ out, err := httputil.DumpResponse(res, true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Errorf("Response body wasn't from Handler. Got:\n%s\n", out)
+ }
+}
+
+// Issue 13290: send User-Agent in proxy CONNECT
+func TestTransportProxyConnectHeader(t *testing.T) {
+ defer afterTest(t)
+ reqc := make(chan *Request, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", r.Method)
+ }
+ reqc <- r
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack: %v", err)
+ return
+ }
+ c.Close()
+ }))
+ defer ts.Close()
+ tr := &Transport{
+ ProxyConnectHeader: Header{
+ "User-Agent": {"foo"},
+ "Other": {"bar"},
+ },
+ Proxy: func(r *Request) (*url.URL, error) {
+ return url.Parse(ts.URL)
+ },
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+ res, err := c.Get("https://dummy.tld/") // https to force a CONNECT
+ if err == nil {
+ res.Body.Close()
+ t.Errorf("unexpected success")
+ }
+ select {
+ case <-time.After(3 * time.Second):
+ t.Fatal("timeout")
+ case r := <-reqc:
+ if got, want := r.Header.Get("User-Agent"), "foo"; got != want {
+ t.Errorf("CONNECT request User-Agent = %q; want %q", got, want)
+ }
+ if got, want := r.Header.Get("Other"), "bar"; got != want {
+ t.Errorf("CONNECT request Other = %q; want %q", got, want)
+ }
+ }
+}
+
var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func()
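
Several of the new tests above (TestTLSHandshakeTrace, the h2 branch of testTransportEventTrace) exercise the TLSHandshakeStart and TLSHandshakeDone hooks that Go 1.8 adds to net/http/httptrace. A minimal client-side sketch; the URL is a placeholder:

    package main

    import (
        "crypto/tls"
        "fmt"
        "log"
        "net/http"
        "net/http/httptrace"
    )

    func main() {
        req, err := http.NewRequest("GET", "https://example.com/", nil)
        if err != nil {
            log.Fatal(err)
        }
        trace := &httptrace.ClientTrace{
            TLSHandshakeStart: func() { fmt.Println("TLS handshake started") },
            TLSHandshakeDone: func(cs tls.ConnectionState, err error) {
                fmt.Println("TLS handshake done:", cs.ServerName, err)
            },
        }
        req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

        res, err := http.DefaultClient.Do(req)
        if err != nil {
            log.Fatal(err)
        }
        res.Body.Close()
    }
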
diff --git a/libgo/go/net/interface.go b/libgo/go/net/interface.go
index 52b857c..b3297f2 100644
--- a/libgo/go/net/interface.go
+++ b/libgo/go/net/interface.go
@@ -10,6 +10,12 @@ import (
"time"
)
+// BUG(mikio): On NaCl, methods and functions related to
+// Interface are not implemented.
+
+// BUG(mikio): On DragonFly BSD, NetBSD, OpenBSD, Plan 9 and Solaris,
+// the MulticastAddrs method of Interface is not implemented.
+
var (
errInvalidInterface = errors.New("invalid network interface")
errInvalidInterfaceIndex = errors.New("invalid network interface index")
@@ -63,7 +69,8 @@ func (f Flags) String() string {
return s
}
-// Addrs returns interface addresses for a specific interface.
+// Addrs returns a list of unicast interface addresses for a specific
+// interface.
func (ifi *Interface) Addrs() ([]Addr, error) {
if ifi == nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface}
@@ -75,8 +82,8 @@ func (ifi *Interface) Addrs() ([]Addr, error) {
return ifat, err
}
-// MulticastAddrs returns multicast, joined group addresses for
-// a specific interface.
+// MulticastAddrs returns a list of multicast, joined group addresses
+// for a specific interface.
func (ifi *Interface) MulticastAddrs() ([]Addr, error) {
if ifi == nil {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterface}
@@ -100,8 +107,11 @@ func Interfaces() ([]Interface, error) {
return ift, nil
}
-// InterfaceAddrs returns a list of the system's network interface
+// InterfaceAddrs returns a list of the system's unicast interface
// addresses.
+//
+// The returned list does not identify the associated interface; use
+// Interfaces and Interface.Addrs for more detail.
func InterfaceAddrs() ([]Addr, error) {
ifat, err := interfaceAddrTable(nil)
if err != nil {
@@ -111,6 +121,10 @@ func InterfaceAddrs() ([]Addr, error) {
}
// InterfaceByIndex returns the interface specified by index.
+//
+// On Solaris, it returns one of the logical network interfaces
+// sharing the logical data link; for more precision use
+// InterfaceByName.
func InterfaceByIndex(index int) (*Interface, error) {
if index <= 0 {
return nil, &OpError{Op: "route", Net: "ip+net", Source: nil, Addr: nil, Err: errInvalidInterfaceIndex}
@@ -158,6 +172,9 @@ func InterfaceByName(name string) (*Interface, error) {
// An ipv6ZoneCache represents a cache holding partial network
// interface information. It is used for reducing the cost of IPv6
// addressing scope zone resolution.
+//
+// Multiple names sharing the same index are managed on a first-come,
+// first-served basis for consistency.
type ipv6ZoneCache struct {
sync.RWMutex // guard the following
lastFetched time.Time // last time routing information was fetched
@@ -188,7 +205,9 @@ func (zc *ipv6ZoneCache) update(ift []Interface) {
zc.toName = make(map[int]string, len(ift))
for _, ifi := range ift {
zc.toIndex[ifi.Name] = ifi.Index
- zc.toName[ifi.Index] = ifi.Name
+ if _, ok := zc.toName[ifi.Index]; !ok {
+ zc.toName[ifi.Index] = ifi.Name
+ }
}
}
@@ -215,7 +234,7 @@ func zoneToInt(zone string) int {
defer zoneCache.RUnlock()
index, ok := zoneCache.toIndex[zone]
if !ok {
- index, _, _ = dtoi(zone, 0)
+ index, _, _ = dtoi(zone)
}
return index
}
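
The reworded documentation above draws a line between InterfaceAddrs, which returns the system's unicast addresses without saying which interface owns them, and the per-interface form it now recommends. A short sketch of the recommended walk:

    package main

    import (
        "fmt"
        "log"
        "net"
    )

    func main() {
        ifs, err := net.Interfaces()
        if err != nil {
            log.Fatal(err)
        }
        for _, ifi := range ifs {
            addrs, err := ifi.Addrs() // unicast addresses of this interface
            if err != nil {
                log.Fatal(err)
            }
            for _, a := range addrs {
                fmt.Printf("%s: %v\n", ifi.Name, a)
            }
        }
    }
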
diff --git a/libgo/go/net/interface_plan9.go b/libgo/go/net/interface_plan9.go
new file mode 100644
index 0000000..e5d7739
--- /dev/null
+++ b/libgo/go/net/interface_plan9.go
@@ -0,0 +1,198 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "errors"
+ "os"
+)
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ if ifindex == 0 {
+ n, err := interfaceCount()
+ if err != nil {
+ return nil, err
+ }
+ ifcs := make([]Interface, n)
+ for i := range ifcs {
+ ifc, err := readInterface(i)
+ if err != nil {
+ return nil, err
+ }
+ ifcs[i] = *ifc
+ }
+ return ifcs, nil
+ }
+
+ ifc, err := readInterface(ifindex - 1)
+ if err != nil {
+ return nil, err
+ }
+ return []Interface{*ifc}, nil
+}
+
+func readInterface(i int) (*Interface, error) {
+ ifc := &Interface{
+ Index: i + 1, // Offset the index by one to suit the contract
+ Name: netdir + "/ipifc/" + itoa(i), // Name is the full path to the interface directory in Plan 9
+ }
+
+ ifcstat := ifc.Name + "/status"
+ ifcstatf, err := open(ifcstat)
+ if err != nil {
+ return nil, err
+ }
+ defer ifcstatf.close()
+
+ line, ok := ifcstatf.readLine()
+ if !ok {
+ return nil, errors.New("invalid interface status file: " + ifcstat)
+ }
+
+ fields := getFields(line)
+ if len(fields) < 4 {
+ return nil, errors.New("invalid interface status file: " + ifcstat)
+ }
+
+ device := fields[1]
+ mtustr := fields[3]
+
+ mtu, _, ok := dtoi(mtustr)
+ if !ok {
+ return nil, errors.New("invalid status file of interface: " + ifcstat)
+ }
+ ifc.MTU = mtu
+
+ // Not a loopback device
+ if device != "/dev/null" {
+ deviceaddrf, err := open(device + "/addr")
+ if err != nil {
+ return nil, err
+ }
+ defer deviceaddrf.close()
+
+ line, ok = deviceaddrf.readLine()
+ if !ok {
+ return nil, errors.New("invalid address file for interface: " + device + "/addr")
+ }
+
+ if len(line) > 0 && len(line)%2 == 0 {
+ ifc.HardwareAddr = make([]byte, len(line)/2)
+ var ok bool
+ for i := range ifc.HardwareAddr {
+ j := (i + 1) * 2
+ ifc.HardwareAddr[i], ok = xtoi2(line[i*2:j], 0)
+ if !ok {
+ ifc.HardwareAddr = ifc.HardwareAddr[:i]
+ break
+ }
+ }
+ }
+
+ ifc.Flags = FlagUp | FlagBroadcast | FlagMulticast
+ } else {
+ ifc.Flags = FlagUp | FlagMulticast | FlagLoopback
+ }
+
+ return ifc, nil
+}
+
+func interfaceCount() (int, error) {
+ d, err := os.Open(netdir + "/ipifc")
+ if err != nil {
+ return -1, err
+ }
+ defer d.Close()
+
+ names, err := d.Readdirnames(0)
+ if err != nil {
+ return -1, err
+ }
+
+ // Assume that the numbered entries in ipifc are exactly the
+ // consecutively numbered per-interface directories.
+ c := 0
+ for _, name := range names {
+ if _, _, ok := dtoi(name); !ok {
+ continue
+ }
+ c++
+ }
+
+ return c, nil
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ var ifcs []Interface
+ if ifi == nil {
+ var err error
+ ifcs, err = interfaceTable(0)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ ifcs = []Interface{*ifi}
+ }
+
+ addrs := make([]Addr, len(ifcs))
+ for i, ifc := range ifcs {
+ status := ifc.Name + "/status"
+ statusf, err := open(status)
+ if err != nil {
+ return nil, err
+ }
+ defer statusf.close()
+
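+ // The first status line describes the interface itself (readInterface
+ // parses it above); the addresses start on the second line, so read
+ // and discard one line before parsing.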
+ line, ok := statusf.readLine()
+ line, ok = statusf.readLine()
+ if !ok {
+ return nil, errors.New("cannot parse IP address for interface: " + status)
+ }
+
+ // This assumes only a single address for the interface.
+ fields := getFields(line)
+ if len(fields) < 1 {
+ return nil, errors.New("cannot parse IP address for interface: " + status)
+ }
+ addr := fields[0]
+ ip := ParseIP(addr)
+ if ip == nil {
+ return nil, errors.New("cannot parse IP address for interface: " + status)
+ }
+
+ // The mask is given as a CIDR prefix length relative to the IPv6
+ // address, since Plan 9's internal representation is always IPv6.
+ maskfld := fields[1]
+ maskfld = maskfld[1:]
+ pfxlen, _, ok := dtoi(maskfld)
+ if !ok {
+ return nil, errors.New("cannot parse network mask for interface: " + status)
+ }
+ var mask IPMask
+ if ip.To4() != nil { // IPv4 or IPv6 IPv4-mapped address
+ mask = CIDRMask(pfxlen-8*len(v4InV6Prefix), 8*IPv4len)
+ }
+ if ip.To16() != nil && ip.To4() == nil { // IPv6 address
+ mask = CIDRMask(pfxlen, 8*IPv6len)
+ }
+
+ addrs[i] = &IPNet{IP: ip, Mask: mask}
+ }
+
+ return addrs, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
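
interfaceAddrTable above converts the Plan 9 prefix length, which is always expressed against the 128-bit IPv6 form, into an IPv4 mask by subtracting the 96-bit IPv4-in-IPv6 mapping prefix. A quick check of that arithmetic with the public CIDRMask:

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        // A Plan 9 /120 on an IPv4-mapped address is an IPv4 /24:
        // 120 - 96 (the ::ffff:0:0/96 mapping prefix) = 24.
        v4 := net.CIDRMask(120-96, 32)
        v6 := net.CIDRMask(64, 128)
        fmt.Println(v4.String()) // ffffff00
        fmt.Println(v6.String()) // ffffffffffffffff0000000000000000
    }
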
diff --git a/libgo/go/net/interface_solaris.go b/libgo/go/net/interface_solaris.go
new file mode 100644
index 0000000..dc8ffbf
--- /dev/null
+++ b/libgo/go/net/interface_solaris.go
@@ -0,0 +1,107 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "syscall"
+
+ "golang_org/x/net/lif"
+)
+
+// If the ifindex is zero, interfaceTable returns mappings of all
+// network interfaces. Otherwise it returns a mapping of a specific
+// interface.
+func interfaceTable(ifindex int) ([]Interface, error) {
+ lls, err := lif.Links(syscall.AF_UNSPEC, "")
+ if err != nil {
+ return nil, err
+ }
+ var ift []Interface
+ for _, ll := range lls {
+ if ifindex != 0 && ifindex != ll.Index {
+ continue
+ }
+ ifi := Interface{Index: ll.Index, MTU: ll.MTU, Name: ll.Name, Flags: linkFlags(ll.Flags)}
+ if len(ll.Addr) > 0 {
+ ifi.HardwareAddr = HardwareAddr(ll.Addr)
+ }
+ ift = append(ift, ifi)
+ }
+ return ift, nil
+}
+
+const (
+ sysIFF_UP = 0x1
+ sysIFF_BROADCAST = 0x2
+ sysIFF_DEBUG = 0x4
+ sysIFF_LOOPBACK = 0x8
+ sysIFF_POINTOPOINT = 0x10
+ sysIFF_NOTRAILERS = 0x20
+ sysIFF_RUNNING = 0x40
+ sysIFF_NOARP = 0x80
+ sysIFF_PROMISC = 0x100
+ sysIFF_ALLMULTI = 0x200
+ sysIFF_INTELLIGENT = 0x400
+ sysIFF_MULTICAST = 0x800
+ sysIFF_MULTI_BCAST = 0x1000
+ sysIFF_UNNUMBERED = 0x2000
+ sysIFF_PRIVATE = 0x8000
+)
+
+func linkFlags(rawFlags int) Flags {
+ var f Flags
+ if rawFlags&sysIFF_UP != 0 {
+ f |= FlagUp
+ }
+ if rawFlags&sysIFF_BROADCAST != 0 {
+ f |= FlagBroadcast
+ }
+ if rawFlags&sysIFF_LOOPBACK != 0 {
+ f |= FlagLoopback
+ }
+ if rawFlags&sysIFF_POINTOPOINT != 0 {
+ f |= FlagPointToPoint
+ }
+ if rawFlags&sysIFF_MULTICAST != 0 {
+ f |= FlagMulticast
+ }
+ return f
+}
+
+// If the ifi is nil, interfaceAddrTable returns addresses for all
+// network interfaces. Otherwise it returns addresses for a specific
+// interface.
+func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
+ var name string
+ if ifi != nil {
+ name = ifi.Name
+ }
+ as, err := lif.Addrs(syscall.AF_UNSPEC, name)
+ if err != nil {
+ return nil, err
+ }
+ var ifat []Addr
+ for _, a := range as {
+ var ip IP
+ var mask IPMask
+ switch a := a.(type) {
+ case *lif.Inet4Addr:
+ ip = IPv4(a.IP[0], a.IP[1], a.IP[2], a.IP[3])
+ mask = CIDRMask(a.PrefixLen, 8*IPv4len)
+ case *lif.Inet6Addr:
+ ip = make(IP, IPv6len)
+ copy(ip, a.IP[:])
+ mask = CIDRMask(a.PrefixLen, 8*IPv6len)
+ }
+ ifat = append(ifat, &IPNet{IP: ip, Mask: mask})
+ }
+ return ifat, nil
+}
+
+// interfaceMulticastAddrTable returns addresses for a specific
+// interface.
+func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
+ return nil, nil
+}
diff --git a/libgo/go/net/interface_stub.go b/libgo/go/net/interface_stub.go
index f64174c..3b0a1ae 100644
--- a/libgo/go/net/interface_stub.go
+++ b/libgo/go/net/interface_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl plan9 solaris
+// +build nacl
package net
diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go
index 4c695b9..38a2ca4 100644
--- a/libgo/go/net/interface_test.go
+++ b/libgo/go/net/interface_test.go
@@ -58,8 +58,15 @@ func TestInterfaces(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if !reflect.DeepEqual(ifxi, &ifi) {
- t.Errorf("got %v; want %v", ifxi, ifi)
+ switch runtime.GOOS {
+ case "solaris":
+ if ifxi.Index != ifi.Index {
+ t.Errorf("got %v; want %v", ifxi, ifi)
+ }
+ default:
+ if !reflect.DeepEqual(ifxi, &ifi) {
+ t.Errorf("got %v; want %v", ifxi, ifi)
+ }
}
ifxn, err := InterfaceByName(ifi.Name)
if err != nil {
diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go
index d0c8263..db3364c 100644
--- a/libgo/go/net/ip.go
+++ b/libgo/go/net/ip.go
@@ -90,7 +90,7 @@ func CIDRMask(ones, bits int) IPMask {
// Well-known IPv4 addresses
var (
- IPv4bcast = IPv4(255, 255, 255, 255) // broadcast
+ IPv4bcast = IPv4(255, 255, 255, 255) // limited broadcast
IPv4allsys = IPv4(224, 0, 0, 1) // all systems
IPv4allrouter = IPv4(224, 0, 0, 2) // all routers
IPv4zero = IPv4(0, 0, 0, 0) // all zeros
@@ -153,6 +153,12 @@ func (ip IP) IsLinkLocalUnicast() bool {
// IsGlobalUnicast reports whether ip is a global unicast
// address.
+//
+// The identification of global unicast addresses uses address type
+// identification as defined in RFC 1122, RFC 4632 and RFC 4291 with
+// the exception of IPv4 directed broadcast addresses.
+// It returns true even if ip is in IPv4 private address space or
+// local IPv6 unicast address space.
func (ip IP) IsGlobalUnicast() bool {
return (len(ip) == IPv4len || len(ip) == IPv6len) &&
!ip.Equal(IPv4bcast) &&
@@ -504,29 +510,25 @@ func (n *IPNet) String() string {
// Parse IPv4 address (d.d.d.d).
func parseIPv4(s string) IP {
var p [IPv4len]byte
- i := 0
- for j := 0; j < IPv4len; j++ {
- if i >= len(s) {
+ for i := 0; i < IPv4len; i++ {
+ if len(s) == 0 {
// Missing octets.
return nil
}
- if j > 0 {
- if s[i] != '.' {
+ if i > 0 {
+ if s[0] != '.' {
return nil
}
- i++
+ s = s[1:]
}
- var (
- n int
- ok bool
- )
- n, i, ok = dtoi(s, i)
+ n, c, ok := dtoi(s)
if !ok || n > 0xFF {
return nil
}
- p[j] = byte(n)
+ s = s[c:]
+ p[i] = byte(n)
}
- if i != len(s) {
+ if len(s) != 0 {
return nil
}
return IPv4(p[0], p[1], p[2], p[3])
@@ -538,8 +540,7 @@ func parseIPv4(s string) IP {
// true.
func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
ip = make(IP, IPv6len)
- ellipsis := -1 // position of ellipsis in p
- i := 0 // index in string s
+ ellipsis := -1 // position of ellipsis in ip
if zoneAllowed {
s, zone = splitHostZone(s)
@@ -548,90 +549,91 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
// Might have leading ellipsis
if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
ellipsis = 0
- i = 2
+ s = s[2:]
// Might be only ellipsis
- if i == len(s) {
+ if len(s) == 0 {
return ip, zone
}
}
// Loop, parsing hex numbers followed by colon.
- j := 0
- for j < IPv6len {
+ i := 0
+ for i < IPv6len {
// Hex number.
- n, i1, ok := xtoi(s, i)
+ n, c, ok := xtoi(s)
if !ok || n > 0xFFFF {
return nil, zone
}
// If followed by dot, might be in trailing IPv4.
- if i1 < len(s) && s[i1] == '.' {
- if ellipsis < 0 && j != IPv6len-IPv4len {
+ if c < len(s) && s[c] == '.' {
+ if ellipsis < 0 && i != IPv6len-IPv4len {
// Not the right place.
return nil, zone
}
- if j+IPv4len > IPv6len {
+ if i+IPv4len > IPv6len {
// Not enough room.
return nil, zone
}
- ip4 := parseIPv4(s[i:])
+ ip4 := parseIPv4(s)
if ip4 == nil {
return nil, zone
}
- ip[j] = ip4[12]
- ip[j+1] = ip4[13]
- ip[j+2] = ip4[14]
- ip[j+3] = ip4[15]
- i = len(s)
- j += IPv4len
+ ip[i] = ip4[12]
+ ip[i+1] = ip4[13]
+ ip[i+2] = ip4[14]
+ ip[i+3] = ip4[15]
+ s = ""
+ i += IPv4len
break
}
// Save this 16-bit chunk.
- ip[j] = byte(n >> 8)
- ip[j+1] = byte(n)
- j += 2
+ ip[i] = byte(n >> 8)
+ ip[i+1] = byte(n)
+ i += 2
// Stop at end of string.
- i = i1
- if i == len(s) {
+ s = s[c:]
+ if len(s) == 0 {
break
}
// Otherwise must be followed by colon and more.
- if s[i] != ':' || i+1 == len(s) {
+ if s[0] != ':' || len(s) == 1 {
return nil, zone
}
- i++
+ s = s[1:]
// Look for ellipsis.
- if s[i] == ':' {
+ if s[0] == ':' {
if ellipsis >= 0 { // already have one
return nil, zone
}
- ellipsis = j
- if i++; i == len(s) { // can be at end
+ ellipsis = i
+ s = s[1:]
+ if len(s) == 0 { // can be at end
break
}
}
}
// Must have used entire string.
- if i != len(s) {
+ if len(s) != 0 {
return nil, zone
}
// If didn't parse enough, expand ellipsis.
- if j < IPv6len {
+ if i < IPv6len {
if ellipsis < 0 {
return nil, zone
}
- n := IPv6len - j
- for k := j - 1; k >= ellipsis; k-- {
- ip[k+n] = ip[k]
+ n := IPv6len - i
+ for j := i - 1; j >= ellipsis; j-- {
+ ip[j+n] = ip[j]
}
- for k := ellipsis + n - 1; k >= ellipsis; k-- {
- ip[k] = 0
+ for j := ellipsis + n - 1; j >= ellipsis; j-- {
+ ip[j] = 0
}
} else if ellipsis >= 0 {
// Ellipsis must represent at least one 0 group.
@@ -658,13 +660,14 @@ func ParseIP(s string) IP {
return nil
}
-// ParseCIDR parses s as a CIDR notation IP address and mask,
+// ParseCIDR parses s as a CIDR notation IP address and prefix length,
// like "192.0.2.0/24" or "2001:db8::/32", as defined in
// RFC 4632 and RFC 4291.
//
-// It returns the IP address and the network implied by the IP
-// and mask. For example, ParseCIDR("198.51.100.1/24") returns
-// the IP address 198.51.100.1 and the network 198.51.100.0/24.
+// It returns the IP address and the network implied by the IP and
+// prefix length.
+// For example, ParseCIDR("192.0.2.1/24") returns the IP address
+// 192.0.2.1 and the network 192.0.2.0/24.
func ParseCIDR(s string) (IP, *IPNet, error) {
i := byteIndex(s, '/')
if i < 0 {
@@ -677,7 +680,7 @@ func ParseCIDR(s string) (IP, *IPNet, error) {
iplen = IPv6len
ip, _ = parseIPv6(addr, false)
}
- n, i, ok := dtoi(mask, 0)
+ n, i, ok := dtoi(mask)
if ip == nil || !ok || i != len(mask) || n < 0 || n > 8*iplen {
return nil, nil, &ParseError{Type: "CIDR address", Text: s}
}
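
The index-free parser rewrite above keeps the behavior that the new ip_test.go cases pin down: octets with a sign are rejected, and ParseCIDR returns both the address and the implied network. A quick illustration with the public API:

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        fmt.Println(net.ParseIP("-0.0.0.0"))  // <nil>: signed octets are invalid
        fmt.Println(net.ParseIP("192.0.2.1")) // 192.0.2.1

        ip, ipnet, err := net.ParseCIDR("192.0.2.1/24")
        fmt.Println(ip, ipnet, err) // 192.0.2.1 192.0.2.0/24 <nil>
    }
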
diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go
index b6ac26d..4655163 100644
--- a/libgo/go/net/ip_test.go
+++ b/libgo/go/net/ip_test.go
@@ -28,6 +28,10 @@ var parseIPTests = []struct {
{"2001:4860:0:2001::68", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}},
{"2001:4860:0000:2001:0000:0000:0000:0068", IP{0x20, 0x01, 0x48, 0x60, 0, 0, 0x20, 0x01, 0, 0, 0, 0, 0, 0, 0x00, 0x68}},
+ {"-0.0.0.0", nil},
+ {"0.-1.0.0", nil},
+ {"0.0.-2.0", nil},
+ {"0.0.0.-3", nil},
{"127.0.0.256", nil},
{"abc", nil},
{"123:", nil},
@@ -242,13 +246,15 @@ func TestIPString(t *testing.T) {
}
}
+var sink string
+
func BenchmarkIPString(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
for i := 0; i < b.N; i++ {
for _, tt := range ipStringTests {
if tt.in != nil {
- tt.in.String()
+ sink = tt.in.String()
}
}
}
@@ -299,7 +305,7 @@ func BenchmarkIPMaskString(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, tt := range ipMaskStringTests {
- tt.in.String()
+ sink = tt.in.String()
}
}
}
@@ -330,6 +336,12 @@ var parseCIDRTests = []struct {
{"192.168.1.1/255.255.255.0", nil, nil, &ParseError{Type: "CIDR address", Text: "192.168.1.1/255.255.255.0"}},
{"192.168.1.1/35", nil, nil, &ParseError{Type: "CIDR address", Text: "192.168.1.1/35"}},
{"2001:db8::1/-1", nil, nil, &ParseError{Type: "CIDR address", Text: "2001:db8::1/-1"}},
+ {"2001:db8::1/-0", nil, nil, &ParseError{Type: "CIDR address", Text: "2001:db8::1/-0"}},
+ {"-0.0.0.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "-0.0.0.0/32"}},
+ {"0.-1.0.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.-1.0.0/32"}},
+ {"0.0.-2.0/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.-2.0/32"}},
+ {"0.0.0.-3/32", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.0.-3/32"}},
+ {"0.0.0.0/-0", nil, nil, &ParseError{Type: "CIDR address", Text: "0.0.0.0/-0"}},
{"", nil, nil, &ParseError{Type: "CIDR address", Text: ""}},
}
diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go
index 173b3cb..d994fc6 100644
--- a/libgo/go/net/iprawsock.go
+++ b/libgo/go/net/iprawsock.go
@@ -9,6 +9,24 @@ import (
"syscall"
)
+// BUG(mikio): On every POSIX platform, reads from the "ip4" network
+// using the ReadFrom or ReadFromIP method might not return a complete
+// IPv4 packet, including its header, even if there is space
+// available. This can occur even in cases where Read or ReadMsgIP
+// could return a complete packet. For this reason, it is recommended
+// that you do not use these methods if it is important to receive a
+// full packet.
+//
+// The Go 1 compatibility guidelines make it impossible for us to
+// change the behavior of these methods; use Read or ReadMsgIP
+// instead.
+
+// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgIP and
+// WriteMsgIP methods of IPConn are not implemented.
+
+// BUG(mikio): On Windows, the File method of IPConn is not
+// implemented.
+
// IPAddr represents the address of an IP end point.
type IPAddr struct {
IP IP
@@ -46,6 +64,9 @@ func (a *IPAddr) opAddr() Addr {
// ResolveIPAddr parses addr as an IP address of the form "host" or
// "ipv6-host%zone" and resolves the domain name on the network net,
// which must be "ip", "ip4" or "ip6".
+//
+// Resolving a hostname is not recommended because this returns at most
+// one of its IP addresses.
func ResolveIPAddr(net, addr string) (*IPAddr, error) {
if net == "" { // a hint wildcard for Go 1.0 undocumented behavior
net = "ip"
@@ -59,7 +80,7 @@ func ResolveIPAddr(net, addr string) (*IPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
- addrs, err := internetAddrList(context.Background(), afnet, addr)
+ addrs, err := DefaultResolver.internetAddrList(context.Background(), afnet, addr)
if err != nil {
return nil, err
}
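
The new doc note on ResolveIPAddr above warns that resolving a host name this way yields at most one of its addresses, and the call now goes through DefaultResolver. When every address is wanted, the Go 1.8 Resolver API (or LookupIP/LookupHost) is the better fit. A short sketch, assuming "localhost" resolves on the machine running it:

    package main

    import (
        "context"
        "fmt"
        "log"
        "net"
    )

    func main() {
        // ResolveIPAddr picks a single address for the name.
        one, err := net.ResolveIPAddr("ip4", "localhost")
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println("one address:", one)

        // The resolver reports every address it found.
        all, err := net.DefaultResolver.LookupIPAddr(context.Background(), "localhost")
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println("all addresses:", all)
    }
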
diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go
index 3e0b060..8f4b702 100644
--- a/libgo/go/net/iprawsock_posix.go
+++ b/libgo/go/net/iprawsock_posix.go
@@ -11,18 +11,6 @@ import (
"syscall"
)
-// BUG(mikio): On every POSIX platform, reads from the "ip4" network
-// using the ReadFrom or ReadFromIP method might not return a complete
-// IPv4 packet, including its header, even if there is space
-// available. This can occur even in cases where Read or ReadMsgIP
-// could return a complete packet. For this reason, it is recommended
-// that you do not uses these methods if it is important to receive a
-// full packet.
-//
-// The Go 1 compatibility guidelines make it impossible for us to
-// change the behavior of these methods; use Read or ReadMsgIP
-// instead.
-
func sockaddrToIP(sa syscall.Sockaddr) Addr {
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
@@ -50,6 +38,10 @@ func (a *IPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, 0, a.Zone)
}
+func (a *IPAddr) toLocal(net string) sockaddr {
+ return &IPAddr{loopbackIP(net), a.Zone}
+}
+
func (c *IPConn) readFrom(b []byte) (int, *IPAddr, error) {
// TODO(cw,rsc): consider using readv if we know the family
// type to avoid the header trim/copy
diff --git a/libgo/go/net/iprawsock_test.go b/libgo/go/net/iprawsock_test.go
index 29cd4b6..5d33b26 100644
--- a/libgo/go/net/iprawsock_test.go
+++ b/libgo/go/net/iprawsock_test.go
@@ -43,6 +43,13 @@ var resolveIPAddrTests = []resolveIPAddrTest{
{"l2tp", "127.0.0.1", nil, UnknownNetworkError("l2tp")},
{"l2tp:gre", "127.0.0.1", nil, UnknownNetworkError("l2tp:gre")},
{"tcp", "1.2.3.4:123", nil, UnknownNetworkError("tcp")},
+
+ {"ip4", "2001:db8::1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"ip4:icmp", "2001:db8::1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"ip6", "127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"ip6", "::ffff:127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
+ {"ip6:ipv6-icmp", "127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"ip6:ipv6-icmp", "::ffff:127.0.0.1", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
}
func TestResolveIPAddr(t *testing.T) {
@@ -54,21 +61,17 @@ func TestResolveIPAddr(t *testing.T) {
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = lookupLocalhost
- for i, tt := range resolveIPAddrTests {
+ for _, tt := range resolveIPAddrTests {
addr, err := ResolveIPAddr(tt.network, tt.litAddrOrName)
- if err != tt.err {
- t.Errorf("#%d: %v", i, err)
- } else if !reflect.DeepEqual(addr, tt.addr) {
- t.Errorf("#%d: got %#v; want %#v", i, addr, tt.addr)
- }
- if err != nil {
+ if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ResolveIPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err)
continue
}
- rtaddr, err := ResolveIPAddr(addr.Network(), addr.String())
- if err != nil {
- t.Errorf("#%d: %v", i, err)
- } else if !reflect.DeepEqual(rtaddr, addr) {
- t.Errorf("#%d: got %#v; want %#v", i, rtaddr, addr)
+ if err == nil {
+ addr2, err := ResolveIPAddr(addr.Network(), addr.String())
+ if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err {
+ t.Errorf("(%q, %q): ResolveIPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err)
+ }
}
}
}
diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go
index 24daf17..f1394a7 100644
--- a/libgo/go/net/ipsock.go
+++ b/libgo/go/net/ipsock.go
@@ -10,6 +10,13 @@ import (
"context"
)
+// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
+// "tcp" and "udp" networks does not listen for both IPv4 and IPv6
+// connections. This is due to the fact that IPv4 traffic will not be
+// routed to an IPv6 socket - two separate sockets are required if
+// both address families are to be supported.
+// See inet6(4) for details.
+
var (
// supportsIPv4 reports whether the platform supports IPv4
// networking functionality.
@@ -76,7 +83,7 @@ func (addrs addrList) partition(strategy func(Addr) bool) (primaries, fallbacks
// yielding a list of Addr objects. Known filters are nil, ipv4only,
// and ipv6only. It returns every address when the filter is nil.
// The result contains at least one address when error is nil.
-func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr) Addr) (addrList, error) {
+func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr) Addr, originalAddr string) (addrList, error) {
var addrs addrList
for _, ip := range ips {
if filter == nil || filter(ip) {
@@ -84,21 +91,19 @@ func filterAddrList(filter func(IPAddr) bool, ips []IPAddr, inetaddr func(IPAddr
}
}
if len(addrs) == 0 {
- return nil, errNoSuitableAddress
+ return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: originalAddr}
}
return addrs, nil
}
-// ipv4only reports whether the kernel supports IPv4 addressing mode
-// and addr is an IPv4 address.
+// ipv4only reports whether addr is an IPv4 address.
func ipv4only(addr IPAddr) bool {
- return supportsIPv4 && addr.IP.To4() != nil
+ return addr.IP.To4() != nil
}
-// ipv6only reports whether the kernel supports IPv6 addressing mode
-// and addr is an IPv6 address except IPv4-mapped IPv6 address.
+// ipv6only reports whether addr is an IPv6 address except IPv4-mapped IPv6 address.
func ipv6only(addr IPAddr) bool {
- return supportsIPv6 && len(addr.IP) == IPv6len && addr.IP.To4() == nil
+ return len(addr.IP) == IPv6len && addr.IP.To4() == nil
}
// SplitHostPort splits a network address of the form "host:port",
@@ -190,7 +195,7 @@ func JoinHostPort(host, port string) string {
// address or a DNS name, and returns a list of internet protocol
// family addresses. The result contains at least one address when
// error is nil.
-func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
+func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
var (
err error
host, port string
@@ -202,7 +207,7 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
if host, port, err = SplitHostPort(addr); err != nil {
return nil, err
}
- if portnum, err = LookupPort(net, port); err != nil {
+ if portnum, err = r.LookupPort(ctx, net, port); err != nil {
return nil, err
}
}
@@ -228,20 +233,21 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
if host == "" {
return addrList{inetaddr(IPAddr{})}, nil
}
- // Try as a literal IP address.
- var ip IP
- if ip = parseIPv4(host); ip != nil {
- return addrList{inetaddr(IPAddr{IP: ip})}, nil
- }
- var zone string
- if ip, zone = parseIPv6(host, true); ip != nil {
- return addrList{inetaddr(IPAddr{IP: ip, Zone: zone})}, nil
- }
- // Try as a DNS name.
- ips, err := lookupIPContext(ctx, host)
- if err != nil {
- return nil, err
+
+ // Try as a literal IP address, then as a DNS name.
+ var ips []IPAddr
+ if ip := parseIPv4(host); ip != nil {
+ ips = []IPAddr{{IP: ip}}
+ } else if ip, zone := parseIPv6(host, true); ip != nil {
+ ips = []IPAddr{{IP: ip, Zone: zone}}
+ } else {
+ // Try as a DNS name.
+ ips, err = r.LookupIPAddr(ctx, host)
+ if err != nil {
+ return nil, err
+ }
}
+
var filter func(IPAddr) bool
if net != "" && net[len(net)-1] == '4' {
filter = ipv4only
@@ -249,5 +255,12 @@ func internetAddrList(ctx context.Context, net, addr string) (addrList, error) {
if net != "" && net[len(net)-1] == '6' {
filter = ipv6only
}
- return filterAddrList(filter, ips, inetaddr)
+ return filterAddrList(filter, ips, inetaddr, host)
+}
+
+func loopbackIP(net string) IP {
+ if net != "" && net[len(net)-1] == '6' {
+ return IPv6loopback
+ }
+ return IP{127, 0, 0, 1}
}
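
With filterAddrList now wrapping errNoSuitableAddress in an AddrError that carries the original address, callers can recover which host was rejected. A sketch of that inspection; the literal is the same one the new iprawsock_test.go cases use, and the exact error text is not asserted:

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        // An IPv6 literal on an IPv4-only network has no suitable address.
        _, err := net.ResolveIPAddr("ip4", "2001:db8::1")
        if ae, ok := err.(*net.AddrError); ok {
            fmt.Printf("address %q rejected: %s\n", ae.Addr, ae.Err)
        } else {
            fmt.Println("unexpected result:", err)
        }
    }
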
diff --git a/libgo/go/net/ipsock_plan9.go b/libgo/go/net/ipsock_plan9.go
index 2b84683..b7fd344 100644
--- a/libgo/go/net/ipsock_plan9.go
+++ b/libgo/go/net/ipsock_plan9.go
@@ -63,7 +63,7 @@ func parsePlan9Addr(s string) (ip IP, iport int, err error) {
return nil, 0, &ParseError{Type: "IP address", Text: s}
}
}
- p, _, ok := dtoi(s[i+1:], 0)
+ p, _, ok := dtoi(s[i+1:])
if !ok {
return nil, 0, &ParseError{Type: "port", Text: s}
}
@@ -119,6 +119,11 @@ func startPlan9(ctx context.Context, net string, addr Addr) (ctl *os.File, dest,
return
}
+ if port > 65535 {
+ err = InvalidAddrError("port should be < 65536")
+ return
+ }
+
clone, dest, err := queryCS1(ctx, proto, ip, port)
if err != nil {
return
@@ -193,6 +198,9 @@ func dialPlan9(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, e
}
func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *netFD, err error) {
+ if isWildcard(raddr) {
+ raddr = toLocal(raddr, net)
+ }
f, dest, proto, name, err := startPlan9(ctx, net, raddr)
if err != nil {
return nil, err
@@ -213,7 +221,7 @@ func dialPlan9Blocking(ctx context.Context, net string, laddr, raddr Addr) (fd *
f.Close()
return nil, err
}
- return newFD(proto, name, f, data, laddr, raddr)
+ return newFD(proto, name, nil, f, data, laddr, raddr)
}
func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err error) {
@@ -232,11 +240,11 @@ func listenPlan9(ctx context.Context, net string, laddr Addr) (fd *netFD, err er
f.Close()
return nil, err
}
- return newFD(proto, name, f, nil, laddr, nil)
+ return newFD(proto, name, nil, f, nil, laddr, nil)
}
func (fd *netFD) netFD() (*netFD, error) {
- return newFD(fd.net, fd.n, fd.ctl, fd.data, fd.laddr, fd.raddr)
+ return newFD(fd.net, fd.n, fd.listen, fd.ctl, fd.data, fd.laddr, fd.raddr)
}
func (fd *netFD) acceptPlan9() (nfd *netFD, err error) {
@@ -245,27 +253,59 @@ func (fd *netFD) acceptPlan9() (nfd *netFD, err error) {
return nil, err
}
defer fd.readUnlock()
- f, err := os.Open(fd.dir + "/listen")
+ listen, err := os.Open(fd.dir + "/listen")
if err != nil {
return nil, err
}
var buf [16]byte
- n, err := f.Read(buf[:])
+ n, err := listen.Read(buf[:])
if err != nil {
- f.Close()
+ listen.Close()
return nil, err
}
name := string(buf[:n])
+ ctl, err := os.OpenFile(netdir+"/"+fd.net+"/"+name+"/ctl", os.O_RDWR, 0)
+ if err != nil {
+ listen.Close()
+ return nil, err
+ }
data, err := os.OpenFile(netdir+"/"+fd.net+"/"+name+"/data", os.O_RDWR, 0)
if err != nil {
- f.Close()
+ listen.Close()
+ ctl.Close()
return nil, err
}
raddr, err := readPlan9Addr(fd.net, netdir+"/"+fd.net+"/"+name+"/remote")
if err != nil {
+ listen.Close()
+ ctl.Close()
data.Close()
- f.Close()
return nil, err
}
- return newFD(fd.net, name, f, data, fd.laddr, raddr)
+ return newFD(fd.net, name, listen, ctl, data, fd.laddr, raddr)
+}
+
+func isWildcard(a Addr) bool {
+ var wildcard bool
+ switch a := a.(type) {
+ case *TCPAddr:
+ wildcard = a.isWildcard()
+ case *UDPAddr:
+ wildcard = a.isWildcard()
+ case *IPAddr:
+ wildcard = a.isWildcard()
+ }
+ return wildcard
+}
+
+func toLocal(a Addr, net string) Addr {
+ switch a := a.(type) {
+ case *TCPAddr:
+ a.IP = loopbackIP(net)
+ case *UDPAddr:
+ a.IP = loopbackIP(net)
+ case *IPAddr:
+ a.IP = loopbackIP(net)
+ }
+ return a
}
diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go
index abe90ac..ff280c3 100644
--- a/libgo/go/net/ipsock_posix.go
+++ b/libgo/go/net/ipsock_posix.go
@@ -12,13 +12,6 @@ import (
"syscall"
)
-// BUG(rsc,mikio): On DragonFly BSD and OpenBSD, listening on the
-// "tcp" and "udp" networks does not listen for both IPv4 and IPv6
-// connections. This is due to the fact that IPv4 traffic will not be
-// routed to an IPv6 socket - two separate sockets are required if
-// both address families are to be supported.
-// See inet6(4) for details.
-
func probeIPv4Stack() bool {
s, err := socketFunc(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
switch err {
@@ -154,6 +147,9 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
// Internet sockets (TCP, UDP, IP)
func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
+ if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() {
+ raddr = raddr.toLocal(net)
+ }
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
}
diff --git a/libgo/go/net/ipsock_test.go b/libgo/go/net/ipsock_test.go
index b36557a..1d0f00f 100644
--- a/libgo/go/net/ipsock_test.go
+++ b/libgo/go/net/ipsock_test.go
@@ -205,13 +205,13 @@ var addrListTests = []struct {
nil,
},
- {nil, nil, testInetaddr, nil, nil, nil, errNoSuitableAddress},
+ {nil, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
- {ipv4only, nil, testInetaddr, nil, nil, nil, errNoSuitableAddress},
- {ipv4only, []IPAddr{{IP: IPv6loopback}}, testInetaddr, nil, nil, nil, errNoSuitableAddress},
+ {ipv4only, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+ {ipv4only, []IPAddr{{IP: IPv6loopback}}, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
- {ipv6only, nil, testInetaddr, nil, nil, nil, errNoSuitableAddress},
- {ipv6only, []IPAddr{{IP: IPv4(127, 0, 0, 1)}}, testInetaddr, nil, nil, nil, errNoSuitableAddress},
+ {ipv6only, nil, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
+ {ipv6only, []IPAddr{{IP: IPv4(127, 0, 0, 1)}}, testInetaddr, nil, nil, nil, &AddrError{errNoSuitableAddress.Error(), "ADDR"}},
}
func TestAddrList(t *testing.T) {
@@ -220,8 +220,8 @@ func TestAddrList(t *testing.T) {
}
for i, tt := range addrListTests {
- addrs, err := filterAddrList(tt.filter, tt.ips, tt.inetaddr)
- if err != tt.err {
+ addrs, err := filterAddrList(tt.filter, tt.ips, tt.inetaddr, "ADDR")
+ if !reflect.DeepEqual(err, tt.err) {
t.Errorf("#%v: got %v; want %v", i, err, tt.err)
}
if tt.err != nil {
diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go
index c169e9e..cc2013e 100644
--- a/libgo/go/net/lookup.go
+++ b/libgo/go/net/lookup.go
@@ -15,94 +15,137 @@ import (
// protocol numbers.
//
// See http://www.iana.org/assignments/protocol-numbers
+//
+// On Unix, this map is augmented by readProtocols via lookupProtocol.
var protocols = map[string]int{
- "icmp": 1, "ICMP": 1,
- "igmp": 2, "IGMP": 2,
- "tcp": 6, "TCP": 6,
- "udp": 17, "UDP": 17,
- "ipv6-icmp": 58, "IPV6-ICMP": 58, "IPv6-ICMP": 58,
+ "icmp": 1,
+ "igmp": 2,
+ "tcp": 6,
+ "udp": 17,
+ "ipv6-icmp": 58,
}
-// LookupHost looks up the given host using the local resolver.
-// It returns an array of that host's addresses.
-func LookupHost(host string) (addrs []string, err error) {
- // Make sure that no matter what we do later, host=="" is rejected.
- // ParseIP, for example, does accept empty strings.
- if host == "" {
- return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
+// services contains minimal mappings between service names and port
+// numbers for platforms that don't have a complete list of port numbers
+// (some Solaris distros, nacl, etc).
+// On Unix, this map is augmented by readServices via goLookupPort.
+var services = map[string]map[string]int{
+ "udp": {
+ "domain": 53,
+ },
+ "tcp": {
+ "ftp": 21,
+ "ftps": 990,
+ "gopher": 70, // ʕ◔ϖ◔ʔ
+ "http": 80,
+ "https": 443,
+ "imap2": 143,
+ "imap3": 220,
+ "imaps": 993,
+ "pop3": 110,
+ "pop3s": 995,
+ "smtp": 25,
+ "ssh": 22,
+ "telnet": 23,
+ },
+}
+
+const maxProtoLength = len("RSVP-E2E-IGNORE") + 10 // with room to grow
+
+func lookupProtocolMap(name string) (int, error) {
+ var lowerProtocol [maxProtoLength]byte
+ n := copy(lowerProtocol[:], name)
+ lowerASCIIBytes(lowerProtocol[:n])
+ proto, found := protocols[string(lowerProtocol[:n])]
+ if !found || n != len(name) {
+ return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name}
}
- if ip := ParseIP(host); ip != nil {
- return []string{host}, nil
+ return proto, nil
+}
+
+const maxServiceLength = len("mobility-header") + 10 // with room to grow
+
+func lookupPortMap(network, service string) (port int, error error) {
+ switch network {
+ case "tcp4", "tcp6":
+ network = "tcp"
+ case "udp4", "udp6":
+ network = "udp"
}
- return lookupHost(context.Background(), host)
+
+ if m, ok := services[network]; ok {
+ var lowerService [maxServiceLength]byte
+ n := copy(lowerService[:], service)
+ lowerASCIIBytes(lowerService[:n])
+ if port, ok := m[string(lowerService[:n])]; ok && n == len(service) {
+ return port, nil
+ }
+ }
+ return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
}
-// LookupIP looks up host using the local resolver.
-// It returns an array of that host's IPv4 and IPv6 addresses.
-func LookupIP(host string) (ips []IP, err error) {
+// DefaultResolver is the resolver used by the package-level Lookup
+// functions and by Dialers without a specified Resolver.
+var DefaultResolver = &Resolver{}
+
+// A Resolver looks up names and numbers.
+//
+// A nil *Resolver is equivalent to a zero Resolver.
+type Resolver struct {
+ // PreferGo controls whether Go's built-in DNS resolver is preferred
+ // on platforms where it's available. It is equivalent to setting
+ // GODEBUG=netdns=go, but scoped to just this resolver.
+ PreferGo bool
+
+ // TODO(bradfitz): optional interface impl override hook
+ // TODO(bradfitz): Timeout time.Duration?
+}
+
+// LookupHost looks up the given host using the local resolver.
+// It returns a slice of that host's addresses.
+func LookupHost(host string) (addrs []string, err error) {
+ return DefaultResolver.LookupHost(context.Background(), host)
+}
+
+// LookupHost looks up the given host using the local resolver.
+// It returns a slice of that host's addresses.
+func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Make sure that no matter what we do later, host=="" is rejected.
// ParseIP, for example, does accept empty strings.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
}
if ip := ParseIP(host); ip != nil {
- return []IP{ip}, nil
- }
- addrs, err := lookupIPMerge(context.Background(), host)
- if err != nil {
- return
- }
- ips = make([]IP, len(addrs))
- for i, addr := range addrs {
- ips[i] = addr.IP
+ return []string{host}, nil
}
- return
-}
-
-var lookupGroup singleflight.Group
-
-// lookupIPMerge wraps lookupIP, but makes sure that for any given
-// host, only one lookup is in-flight at a time. The returned memory
-// is always owned by the caller.
-func lookupIPMerge(ctx context.Context, host string) (addrs []IPAddr, err error) {
- addrsi, err, shared := lookupGroup.Do(host, func() (interface{}, error) {
- return testHookLookupIP(ctx, lookupIP, host)
- })
- return lookupIPReturn(addrsi, err, shared)
+ return r.lookupHost(ctx, host)
}
-// lookupIPReturn turns the return values from singleflight.Do into
-// the return values from LookupIP.
-func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
+// LookupIP looks up host using the local resolver.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func LookupIP(host string) ([]IP, error) {
+ addrs, err := DefaultResolver.LookupIPAddr(context.Background(), host)
if err != nil {
return nil, err
}
- addrs := addrsi.([]IPAddr)
- if shared {
- clone := make([]IPAddr, len(addrs))
- copy(clone, addrs)
- addrs = clone
+ ips := make([]IP, len(addrs))
+ for i, ia := range addrs {
+ ips[i] = ia.IP
}
- return addrs, nil
+ return ips, nil
}
-// ipAddrsEface returns an empty interface slice of addrs.
-func ipAddrsEface(addrs []IPAddr) []interface{} {
- s := make([]interface{}, len(addrs))
- for i, v := range addrs {
- s[i] = v
+// LookupIPAddr looks up host using the local resolver.
+// It returns a slice of that host's IPv4 and IPv6 addresses.
+func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
+ // Make sure that no matter what we do later, host=="" is rejected.
+ // ParseIP, for example, does accept empty strings.
+ if host == "" {
+ return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
+ }
+ if ip := ParseIP(host); ip != nil {
+ return []IPAddr{{IP: ip}}, nil
}
- return s
-}
-
-// lookupIPContext looks up a hostname with a context.
-//
-// TODO(bradfitz): rename this function. All the other
-// build-tag-specific lookupIP funcs also take a context now, so this
-// name is no longer great. Maybe make this lookupIPMerge and ditch
-// the other one, making its callers call this instead with a
-// context.Background().
-func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil && trace.DNSStart != nil {
trace.DNSStart(host)
@@ -110,7 +153,7 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
// The underlying resolver func is lookupIP by default but it
// can be overridden by tests. This is needed by net/http, so it
// uses a context key instead of unexported variables.
- resolverFunc := lookupIP
+ resolverFunc := r.lookupIP
if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil {
resolverFunc = alt
}
@@ -140,11 +183,46 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
}
}
+// lookupGroup merges LookupIPAddr calls together for lookups
+// for the same host. The lookupGroup key is the LookupIPAddr.host
+// argument.
+// The return values are ([]IPAddr, error).
+var lookupGroup singleflight.Group
+
+// lookupIPReturn turns the return values from singleflight.Do into
+// the return values from LookupIP.
+func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
+ if err != nil {
+ return nil, err
+ }
+ addrs := addrsi.([]IPAddr)
+ if shared {
+ clone := make([]IPAddr, len(addrs))
+ copy(clone, addrs)
+ addrs = clone
+ }
+ return addrs, nil
+}
+
+// ipAddrsEface returns an empty interface slice of addrs.
+func ipAddrsEface(addrs []IPAddr) []interface{} {
+ s := make([]interface{}, len(addrs))
+ for i, v := range addrs {
+ s[i] = v
+ }
+ return s
+}
+
// LookupPort looks up the port for the given network and service.
func LookupPort(network, service string) (port int, err error) {
+ return DefaultResolver.LookupPort(context.Background(), network, service)
+}
+
+// LookupPort looks up the port for the given network and service.
+func (r *Resolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
port, needsLookup := parsePort(service)
if needsLookup {
- port, err = lookupPort(context.Background(), network, service)
+ port, err = r.lookupPort(ctx, network, service)
if err != nil {
return 0, err
}
@@ -155,12 +233,32 @@ func LookupPort(network, service string) (port int, err error) {
return port, nil
}
-// LookupCNAME returns the canonical DNS host for the given name.
+// LookupCNAME returns the canonical name for the given host.
+// Callers that do not care about the canonical name can call
+// LookupHost or LookupIP directly; both take care of resolving
+// the canonical name as part of the lookup.
+//
+// A canonical name is the final name after following zero
+// or more CNAME records.
+// LookupCNAME does not return an error if host does not
+// contain DNS "CNAME" records, as long as host resolves to
+// address records.
+func LookupCNAME(host string) (cname string, err error) {
+ return DefaultResolver.lookupCNAME(context.Background(), host)
+}
+
+// LookupCNAME returns the canonical name for the given host.
// Callers that do not care about the canonical name can call
// LookupHost or LookupIP directly; both take care of resolving
// the canonical name as part of the lookup.
-func LookupCNAME(name string) (cname string, err error) {
- return lookupCNAME(context.Background(), name)
+//
+// A canonical name is the final name after following zero
+// or more CNAME records.
+// LookupCNAME does not return an error if host does not
+// contain DNS "CNAME" records, as long as host resolves to
+// address records.
+func (r *Resolver) LookupCNAME(ctx context.Context, host string) (cname string, err error) {
+ return r.lookupCNAME(ctx, host)
}
// LookupSRV tries to resolve an SRV query of the given service,
@@ -173,26 +271,63 @@ func LookupCNAME(name string) (cname string, err error) {
// publishing SRV records under non-standard names, if both service
// and proto are empty strings, LookupSRV looks up name directly.
func LookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
- return lookupSRV(context.Background(), service, proto, name)
+ return DefaultResolver.lookupSRV(context.Background(), service, proto, name)
+}
+
+// LookupSRV tries to resolve an SRV query of the given service,
+// protocol, and domain name. The proto is "tcp" or "udp".
+// The returned records are sorted by priority and randomized
+// by weight within a priority.
+//
+// LookupSRV constructs the DNS name to look up following RFC 2782.
+// That is, it looks up _service._proto.name. To accommodate services
+// publishing SRV records under non-standard names, if both service
+// and proto are empty strings, LookupSRV looks up name directly.
+func (r *Resolver) LookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
+ return r.lookupSRV(ctx, service, proto, name)
+}
+
+// LookupMX returns the DNS MX records for the given domain name sorted by preference.
+func LookupMX(name string) ([]*MX, error) {
+ return DefaultResolver.lookupMX(context.Background(), name)
}
// LookupMX returns the DNS MX records for the given domain name sorted by preference.
-func LookupMX(name string) (mxs []*MX, err error) {
- return lookupMX(context.Background(), name)
+func (r *Resolver) LookupMX(ctx context.Context, name string) ([]*MX, error) {
+ return r.lookupMX(ctx, name)
}
// LookupNS returns the DNS NS records for the given domain name.
-func LookupNS(name string) (nss []*NS, err error) {
- return lookupNS(context.Background(), name)
+func LookupNS(name string) ([]*NS, error) {
+ return DefaultResolver.lookupNS(context.Background(), name)
+}
+
+// LookupNS returns the DNS NS records for the given domain name.
+func (r *Resolver) LookupNS(ctx context.Context, name string) ([]*NS, error) {
+ return r.lookupNS(ctx, name)
}
// LookupTXT returns the DNS TXT records for the given domain name.
-func LookupTXT(name string) (txts []string, err error) {
- return lookupTXT(context.Background(), name)
+func LookupTXT(name string) ([]string, error) {
+ return DefaultResolver.lookupTXT(context.Background(), name)
+}
+
+// LookupTXT returns the DNS TXT records for the given domain name.
+func (r *Resolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
+ return r.lookupTXT(ctx, name)
}
// LookupAddr performs a reverse lookup for the given address, returning a list
// of names mapping to that address.
+//
+// When using the host C library resolver, at most one result will be
+// returned. To bypass the host resolver, use a custom Resolver.
func LookupAddr(addr string) (names []string, err error) {
- return lookupAddr(context.Background(), addr)
+ return DefaultResolver.lookupAddr(context.Background(), addr)
+}
+
+// LookupAddr performs a reverse lookup for the given address, returning a list
+// of names mapping to that address.
+func (r *Resolver) LookupAddr(ctx context.Context, addr string) (names []string, err error) {
+ return r.lookupAddr(ctx, addr)
}
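
The hunks above introduce the Resolver type and route every package-level Lookup helper through DefaultResolver. A minimal sketch of the new exported surface, assuming network access and using example.com only as a placeholder host:

    package main

    import (
        "context"
        "fmt"
        "net"
        "time"
    )

    func main() {
        // PreferGo requests the pure Go resolver where it is available,
        // as described in the Resolver doc comment above.
        r := &net.Resolver{PreferGo: true}

        ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
        defer cancel()

        addrs, err := r.LookupHost(ctx, "example.com")
        fmt.Println(addrs, err)

        // The package-level helpers now delegate to DefaultResolver;
        // "https" resolves to 443 even from the baked-in services map.
        port, err := net.LookupPort("tcp", "https")
        fmt.Println(port, err)
    }
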
diff --git a/libgo/go/net/lookup_nacl.go b/libgo/go/net/lookup_nacl.go
new file mode 100644
index 0000000..43cebad
--- /dev/null
+++ b/libgo/go/net/lookup_nacl.go
@@ -0,0 +1,52 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build nacl
+
+package net
+
+import (
+ "context"
+ "syscall"
+)
+
+func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
+ return lookupProtocolMap(name)
+}
+
+func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
+ return goLookupPort(network, service)
+}
+
+func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+ return "", syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) {
+ return "", nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupMX(ctx context.Context, name string) (mxs []*MX, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupNS(ctx context.Context, name string) (nss []*NS, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupTXT(ctx context.Context, name string) (txts []string, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
+
+func (*Resolver) lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) {
+ return nil, syscall.ENOPROTOOPT
+}
diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go
index 3f7af2a..f81e220 100644
--- a/libgo/go/net/lookup_plan9.go
+++ b/libgo/go/net/lookup_plan9.go
@@ -111,17 +111,20 @@ func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
return 0, UnknownNetworkError(name)
}
s := f[1]
- if n, _, ok := dtoi(s, byteIndex(s, '=')+1); ok {
+ if n, _, ok := dtoi(s[byteIndex(s, '=')+1:]); ok {
return n, nil
}
return 0, UnknownNetworkError(name)
}
-func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
+func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Use netdir/cs instead of netdir/dns because cs knows about
// host names in local network (e.g. from /lib/ndb/local)
lines, err := queryCS(ctx, "net", host, "1")
if err != nil {
+ if stringsHasSuffix(err.Error(), "dns failure") {
+ err = errNoSuchHost
+ }
return
}
loop:
@@ -148,8 +151,8 @@ loop:
return
}
-func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
- lits, err := lookupHost(ctx, host)
+func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+ lits, err := r.lookupHost(ctx, host)
if err != nil {
return
}
@@ -163,14 +166,14 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
return
}
-func lookupPort(ctx context.Context, network, service string) (port int, err error) {
+func (*Resolver) lookupPort(ctx context.Context, network, service string) (port int, err error) {
switch network {
case "tcp4", "tcp6":
network = "tcp"
case "udp4", "udp6":
network = "udp"
}
- lines, err := queryCS(ctx, network, "127.0.0.1", service)
+ lines, err := queryCS(ctx, network, "127.0.0.1", toLower(service))
if err != nil {
return
}
@@ -186,15 +189,19 @@ func lookupPort(ctx context.Context, network, service string) (port int, err err
if i := byteIndex(s, '!'); i >= 0 {
s = s[i+1:] // remove address
}
- if n, _, ok := dtoi(s, 0); ok {
+ if n, _, ok := dtoi(s); ok {
return n, nil
}
return 0, unknownPortError
}
-func lookupCNAME(ctx context.Context, name string) (cname string, err error) {
+func (*Resolver) lookupCNAME(ctx context.Context, name string) (cname string, err error) {
lines, err := queryDNS(ctx, name, "cname")
if err != nil {
+ if stringsHasSuffix(err.Error(), "dns failure") {
+ cname = name + "."
+ err = nil
+ }
return
}
if len(lines) > 0 {
@@ -205,7 +212,7 @@ func lookupCNAME(ctx context.Context, name string) (cname string, err error) {
return "", errors.New("bad response from ndb/dns")
}
-func lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
+func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (cname string, addrs []*SRV, err error) {
var target string
if service == "" && proto == "" {
target = name
@@ -221,9 +228,9 @@ func lookupSRV(ctx context.Context, service, proto, name string) (cname string,
if len(f) < 6 {
continue
}
- port, _, portOk := dtoi(f[4], 0)
- priority, _, priorityOk := dtoi(f[3], 0)
- weight, _, weightOk := dtoi(f[2], 0)
+ port, _, portOk := dtoi(f[4])
+ priority, _, priorityOk := dtoi(f[3])
+ weight, _, weightOk := dtoi(f[2])
if !(portOk && priorityOk && weightOk) {
continue
}
@@ -234,7 +241,7 @@ func lookupSRV(ctx context.Context, service, proto, name string) (cname string,
return
}
-func lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
+func (*Resolver) lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
lines, err := queryDNS(ctx, name, "mx")
if err != nil {
return
@@ -244,7 +251,7 @@ func lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
if len(f) < 4 {
continue
}
- if pref, _, ok := dtoi(f[2], 0); ok {
+ if pref, _, ok := dtoi(f[2]); ok {
mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)})
}
}
@@ -252,7 +259,7 @@ func lookupMX(ctx context.Context, name string) (mx []*MX, err error) {
return
}
-func lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
+func (*Resolver) lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
lines, err := queryDNS(ctx, name, "ns")
if err != nil {
return
@@ -267,7 +274,7 @@ func lookupNS(ctx context.Context, name string) (ns []*NS, err error) {
return
}
-func lookupTXT(ctx context.Context, name string) (txt []string, err error) {
+func (*Resolver) lookupTXT(ctx context.Context, name string) (txt []string, err error) {
lines, err := queryDNS(ctx, name, "txt")
if err != nil {
return
@@ -280,7 +287,7 @@ func lookupTXT(ctx context.Context, name string) (txt []string, err error) {
return
}
-func lookupAddr(ctx context.Context, addr string) (name []string, err error) {
+func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, err error) {
arpa, err := reverseaddr(addr)
if err != nil {
return
diff --git a/libgo/go/net/lookup_stub.go b/libgo/go/net/lookup_stub.go
deleted file mode 100644
index bd096b3..0000000
--- a/libgo/go/net/lookup_stub.go
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build nacl
-
-package net
-
-import (
- "context"
- "syscall"
-)
-
-func lookupProtocol(ctx context.Context, name string) (proto int, err error) {
- return 0, syscall.ENOPROTOOPT
-}
-
-func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func lookupPort(ctx context.Context, network, service string) (port int, err error) {
- return 0, syscall.ENOPROTOOPT
-}
-
-func lookupCNAME(ctx context.Context, name string) (cname string, err error) {
- return "", syscall.ENOPROTOOPT
-}
-
-func lookupSRV(ctx context.Context, service, proto, name string) (cname string, srvs []*SRV, err error) {
- return "", nil, syscall.ENOPROTOOPT
-}
-
-func lookupMX(ctx context.Context, name string) (mxs []*MX, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func lookupNS(ctx context.Context, name string) (nss []*NS, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func lookupTXT(ctx context.Context, name string) (txts []string, err error) {
- return nil, syscall.ENOPROTOOPT
-}
-
-func lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) {
- return nil, syscall.ENOPROTOOPT
-}
diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go
index b3aeb85..36db56a 100644
--- a/libgo/go/net/lookup_test.go
+++ b/libgo/go/net/lookup_test.go
@@ -243,14 +243,15 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) {
}
}
-var lookupIANACNAMETests = []struct {
+var lookupCNAMETests = []struct {
name, cname string
}{
{"www.iana.org", "icann.org."},
{"www.iana.org.", "icann.org."},
+ {"www.google.com", "google.com."},
}
-func TestLookupIANACNAME(t *testing.T) {
+func TestLookupCNAME(t *testing.T) {
if testenv.Builder() == "" {
testenv.MustHaveExternalNetwork(t)
}
@@ -259,7 +260,7 @@ func TestLookupIANACNAME(t *testing.T) {
t.Skip("IPv4 is required")
}
- for _, tt := range lookupIANACNAMETests {
+ for _, tt := range lookupCNAMETests {
cname, err := LookupCNAME(tt.name)
if err != nil {
t.Fatal(err)
@@ -398,11 +399,11 @@ func TestDNSFlood(t *testing.T) {
for i := 0; i < N; i++ {
name := fmt.Sprintf("%d.net-test.golang.org", i)
go func() {
- _, err := lookupIPContext(ctxHalfTimeout, name)
+ _, err := DefaultResolver.LookupIPAddr(ctxHalfTimeout, name)
c <- err
}()
go func() {
- _, err := lookupIPContext(ctxTimeout, name)
+ _, err := DefaultResolver.LookupIPAddr(ctxTimeout, name)
c <- err
}()
}
@@ -616,7 +617,7 @@ func srvString(srvs []*SRV) string {
func TestLookupPort(t *testing.T) {
// See http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml
//
- // Please be careful about adding new mappings for testings.
+ // Please be careful about adding new test cases.
// There are platforms having incomplete mappings for
// restricted resource access and security reasons.
type test struct {
@@ -648,8 +649,6 @@ func TestLookupPort(t *testing.T) {
}
switch runtime.GOOS {
- case "nacl":
- t.Skipf("not supported on %s", runtime.GOOS)
case "android":
if netGo {
t.Skipf("not supported on %s without cgo; see golang.org/issues/14576", runtime.GOOS)
@@ -670,3 +669,73 @@ func TestLookupPort(t *testing.T) {
}
}
}
+
+// Like TestLookupPort but with minimal tests that should always pass
+// because the answers are baked-in to the net package.
+func TestLookupPort_Minimal(t *testing.T) {
+ type test struct {
+ network string
+ name string
+ port int
+ }
+ var tests = []test{
+ {"tcp", "http", 80},
+ {"tcp", "HTTP", 80}, // case shouldn't matter
+ {"tcp", "https", 443},
+ {"tcp", "ssh", 22},
+ {"tcp", "gopher", 70},
+ {"tcp4", "http", 80},
+ {"tcp6", "http", 80},
+ }
+
+ for _, tt := range tests {
+ port, err := LookupPort(tt.network, tt.name)
+ if port != tt.port || err != nil {
+ t.Errorf("LookupPort(%q, %q) = %d, %v; want %d, error=nil", tt.network, tt.name, port, err, tt.port)
+ }
+ }
+}
+
+func TestLookupProtocol_Minimal(t *testing.T) {
+ type test struct {
+ name string
+ want int
+ }
+ var tests = []test{
+ {"tcp", 6},
+ {"TcP", 6}, // case shouldn't matter
+ {"icmp", 1},
+ {"igmp", 2},
+ {"udp", 17},
+ {"ipv6-icmp", 58},
+ }
+
+ for _, tt := range tests {
+ got, err := lookupProtocol(context.Background(), tt.name)
+ if got != tt.want || err != nil {
+ t.Errorf("LookupProtocol(%q) = %d, %v; want %d, error=nil", tt.name, got, err, tt.want)
+ }
+ }
+
+}
+
+func TestLookupNonLDH(t *testing.T) {
+ if runtime.GOOS == "nacl" {
+ t.Skip("skip on nacl")
+ }
+ if fixup := forceGoDNS(); fixup != nil {
+ defer fixup()
+ }
+
+ // "LDH" stands for letters, digits, and hyphens and is the usual
+ // description of standard DNS names.
+ // This test is checking that other kinds of names are reported
+ // as not found, not reported as invalid names.
+ addrs, err := LookupHost("!!!.###.bogus..domain.")
+ if err == nil {
+ t.Fatalf("lookup succeeded: %v", addrs)
+ }
+ if !strings.HasSuffix(err.Error(), errNoSuchHost.Error()) {
+ t.Fatalf("lookup error = %v, want %v", err, errNoSuchHost)
+ }
+}
diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go
index 15397e8..be2ced9 100644
--- a/libgo/go/net/lookup_unix.go
+++ b/libgo/go/net/lookup_unix.go
@@ -26,7 +26,7 @@ func readProtocols() {
if len(f) < 2 {
continue
}
- if proto, _, ok := dtoi(f[1], 0); ok {
+ if proto, _, ok := dtoi(f[1]); ok {
if _, ok := protocols[f[0]]; !ok {
protocols[f[0]] = proto
}
@@ -45,16 +45,12 @@ func readProtocols() {
// returns correspondent protocol number.
func lookupProtocol(_ context.Context, name string) (int, error) {
onceReadProtocols.Do(readProtocols)
- proto, found := protocols[name]
- if !found {
- return 0, &AddrError{Err: "unknown IP protocol specified", Addr: name}
- }
- return proto, nil
+ return lookupProtocolMap(name)
}
-func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
+func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
order := systemConf().hostLookupOrder(host)
- if order == hostLookupCgo {
+ if !r.PreferGo && order == hostLookupCgo {
if addrs, err, ok := cgoLookupHost(ctx, host); ok {
return addrs, err
}
@@ -64,7 +60,10 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
return goLookupHostOrder(ctx, host, order)
}
-func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
+ if r.PreferGo {
+ return goLookupIP(ctx, host)
+ }
order := systemConf().hostLookupOrder(host)
if order == hostLookupCgo {
if addrs, err, ok := cgoLookupIP(ctx, host); ok {
@@ -73,25 +72,28 @@ func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
- return goLookupIPOrder(ctx, host, order)
+ addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order)
+ return
}
-func lookupPort(ctx context.Context, network, service string) (int, error) {
- // TODO: use the context if there ever becomes a need. Related
- // is issue 15321. But port lookup generally just involves
- // local files, and the os package has no context support. The
- // files might be on a remote filesystem, though. This should
- // probably race goroutines if ctx != context.Background().
- if systemConf().canUseCgo() {
+func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
+ if !r.PreferGo && systemConf().canUseCgo() {
if port, err, ok := cgoLookupPort(ctx, network, service); ok {
+ if err != nil {
+ // Issue 18213: if cgo fails, first check to see whether we
+ // have the answer baked-in to the net package.
+ if port, err := goLookupPort(network, service); err == nil {
+ return port, nil
+ }
+ }
return port, err
}
}
return goLookupPort(network, service)
}
-func lookupCNAME(ctx context.Context, name string) (string, error) {
- if systemConf().canUseCgo() {
+func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
+ if !r.PreferGo && systemConf().canUseCgo() {
if cname, err, ok := cgoLookupCNAME(ctx, name); ok {
return cname, err
}
@@ -99,7 +101,7 @@ func lookupCNAME(ctx context.Context, name string) (string, error) {
return goLookupCNAME(ctx, name)
}
-func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
var target string
if service == "" && proto == "" {
target = name
@@ -119,7 +121,7 @@ func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV
return cname, srvs, nil
}
-func lookupMX(ctx context.Context, name string) ([]*MX, error) {
+func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
_, rrs, err := lookup(ctx, name, dnsTypeMX)
if err != nil {
return nil, err
@@ -133,7 +135,7 @@ func lookupMX(ctx context.Context, name string) ([]*MX, error) {
return mxs, nil
}
-func lookupNS(ctx context.Context, name string) ([]*NS, error) {
+func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
_, rrs, err := lookup(ctx, name, dnsTypeNS)
if err != nil {
return nil, err
@@ -145,7 +147,7 @@ func lookupNS(ctx context.Context, name string) ([]*NS, error) {
return nss, nil
}
-func lookupTXT(ctx context.Context, name string) ([]string, error) {
+func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
_, rrs, err := lookup(ctx, name, dnsTypeTXT)
if err != nil {
return nil, err
@@ -157,8 +159,8 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) {
return txts, nil
}
-func lookupAddr(ctx context.Context, addr string) ([]string, error) {
- if systemConf().canUseCgo() {
+func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
+ if !r.PreferGo && systemConf().canUseCgo() {
if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok {
return ptrs, err
}
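
On Unix the PreferGo checks above decide between the cgo and Go resolver paths, and the Issue 18213 change falls back to the baked-in port table when the cgo lookup fails. A minimal sketch that exercises those paths through a Dialer with an explicit Resolver; the address is a placeholder:

    package main

    import (
        "context"
        "fmt"
        "net"
        "time"
    )

    func main() {
        // A Dialer without a Resolver uses DefaultResolver; setting one
        // here routes the host lookup through the PreferGo branches
        // added in lookup_unix.go above.
        d := &net.Dialer{
            Timeout:  5 * time.Second,
            Resolver: &net.Resolver{PreferGo: true},
        }
        conn, err := d.DialContext(context.Background(), "tcp", "example.com:80")
        if err != nil {
            fmt.Println("dial:", err)
            return
        }
        defer conn.Close()
        fmt.Println("connected via", conn.RemoteAddr())
    }
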
diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go
index 5f65c2d..5808293 100644
--- a/libgo/go/net/lookup_windows.go
+++ b/libgo/go/net/lookup_windows.go
@@ -12,10 +12,20 @@ import (
"unsafe"
)
+const _WSAHOST_NOT_FOUND = syscall.Errno(11001)
+
+func winError(call string, err error) error {
+ switch err {
+ case _WSAHOST_NOT_FOUND:
+ return errNoSuchHost
+ }
+ return os.NewSyscallError(call, err)
+}
+
func getprotobyname(name string) (proto int, err error) {
p, err := syscall.GetProtoByName(name)
if err != nil {
- return 0, os.NewSyscallError("getprotobyname", err)
+ return 0, winError("getprotobyname", err)
}
return int(p.Proto), nil
}
@@ -43,7 +53,7 @@ func lookupProtocol(ctx context.Context, name string) (int, error) {
select {
case r := <-ch:
if r.err != nil {
- if proto, ok := protocols[name]; ok {
+ if proto, err := lookupProtocolMap(name); err == nil {
return proto, nil
}
r.err = &DNSError{Err: r.err.Error(), Name: name}
@@ -54,8 +64,8 @@ func lookupProtocol(ctx context.Context, name string) (int, error) {
}
}
-func lookupHost(ctx context.Context, name string) ([]string, error) {
- ips, err := lookupIP(ctx, name)
+func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
+ ips, err := r.lookupIP(ctx, name)
if err != nil {
return nil, err
}
@@ -66,8 +76,8 @@ func lookupHost(ctx context.Context, name string) ([]string, error) {
return addrs, nil
}
-func lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
- // TODO(bradfitz,brainman): use ctx?
+func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
+ // TODO(bradfitz,brainman): use ctx more. See TODO below.
type ret struct {
addrs []IPAddr
@@ -85,7 +95,7 @@ func lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
if e != nil {
- ch <- ret{err: &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}}
+ ch <- ret{err: &DNSError{Err: winError("getaddrinfow", e).Error(), Name: name}}
}
defer syscall.FreeAddrInfoW(result)
addrs := make([]IPAddr, 0, 5)
@@ -125,7 +135,11 @@ func lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
}
}
-func lookupPort(ctx context.Context, network, service string) (int, error) {
+func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
+ if r.PreferGo {
+ return lookupPortMap(network, service)
+ }
+
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
@@ -144,7 +158,10 @@ func lookupPort(ctx context.Context, network, service string) (int, error) {
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
if e != nil {
- return 0, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: network + "/" + service}
+ if port, err := lookupPortMap(network, service); err == nil {
+ return port, nil
+ }
+ return 0, &DNSError{Err: winError("getaddrinfow", e).Error(), Name: network + "/" + service}
}
defer syscall.FreeAddrInfoW(result)
if result == nil {
@@ -162,7 +179,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) {
return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
}
-func lookupCNAME(ctx context.Context, name string) (string, error) {
+func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
@@ -174,7 +191,7 @@ func lookupCNAME(ctx context.Context, name string) (string, error) {
return absDomainName([]byte(name)), nil
}
if e != nil {
- return "", &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
+ return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
}
defer syscall.DnsRecordListFree(r, 1)
@@ -183,7 +200,7 @@ func lookupCNAME(ctx context.Context, name string) (string, error) {
return absDomainName([]byte(cname)), nil
}
-func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
+func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
@@ -196,7 +213,7 @@ func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV
var r *syscall.DNSRecord
e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
if e != nil {
- return "", nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: target}
+ return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
}
defer syscall.DnsRecordListFree(r, 1)
@@ -209,14 +226,14 @@ func lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV
return absDomainName([]byte(target)), srvs, nil
}
-func lookupMX(ctx context.Context, name string) ([]*MX, error) {
+func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
if e != nil {
- return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
}
defer syscall.DnsRecordListFree(r, 1)
@@ -229,14 +246,14 @@ func lookupMX(ctx context.Context, name string) ([]*MX, error) {
return mxs, nil
}
-func lookupNS(ctx context.Context, name string) ([]*NS, error) {
+func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
if e != nil {
- return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
}
defer syscall.DnsRecordListFree(r, 1)
@@ -248,14 +265,14 @@ func lookupNS(ctx context.Context, name string) ([]*NS, error) {
return nss, nil
}
-func lookupTXT(ctx context.Context, name string) ([]string, error) {
+func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
var r *syscall.DNSRecord
e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
if e != nil {
- return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
}
defer syscall.DnsRecordListFree(r, 1)
@@ -270,7 +287,7 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) {
return txts, nil
}
-func lookupAddr(ctx context.Context, addr string) ([]string, error) {
+func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
acquireThread()
defer releaseThread()
@@ -281,7 +298,7 @@ func lookupAddr(ctx context.Context, addr string) ([]string, error) {
var r *syscall.DNSRecord
e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
if e != nil {
- return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: addr}
+ return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
}
defer syscall.DnsRecordListFree(r, 1)
diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go
index 0c00069..702b765 100644
--- a/libgo/go/net/mail/message.go
+++ b/libgo/go/net/mail/message.go
@@ -92,7 +92,8 @@ func init() {
}
}
-func parseDate(date string) (time.Time, error) {
+// ParseDate parses an RFC 5322 date string.
+func ParseDate(date string) (time.Time, error) {
for _, layout := range dateLayouts {
t, err := time.Parse(layout, date)
if err == nil {
@@ -106,7 +107,11 @@ func parseDate(date string) (time.Time, error) {
type Header map[string][]string
// Get gets the first value associated with the given key.
+// It is case insensitive; CanonicalMIMEHeaderKey is used
+// to canonicalize the provided key.
// If there are no values associated with the key, Get returns "".
+// To access multiple values of a key, or to use non-canonical keys,
+// access the map directly.
func (h Header) Get(key string) string {
return textproto.MIMEHeader(h).Get(key)
}
@@ -119,7 +124,7 @@ func (h Header) Date() (time.Time, error) {
if hdr == "" {
return time.Time{}, ErrHeaderNotPresent
}
- return parseDate(hdr)
+ return ParseDate(hdr)
}
// AddressList parses the named header field as a list of addresses.
@@ -345,6 +350,9 @@ func (p *addrParser) consumeAddrSpec() (spec string, err error) {
// quoted-string
debug.Printf("consumeAddrSpec: parsing quoted-string")
localPart, err = p.consumeQuotedString()
+ if localPart == "" {
+ err = errors.New("mail: empty quoted string in addr-spec")
+ }
} else {
// dot-atom
debug.Printf("consumeAddrSpec: parsing dot-atom")
@@ -462,9 +470,6 @@ Loop:
i += size
}
p.s = p.s[i+1:]
- if len(qsb) == 0 {
- return "", errors.New("mail: empty quoted-string")
- }
return string(qsb), nil
}
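
The hunks above export parseDate as ParseDate and move the empty-quoted-string check so that an empty quoted display name is accepted while an empty quoted local part in an addr-spec is rejected. A small sketch of both, using a common RFC 5322 date form and the Issue 14866 address from the test below:

    package main

    import (
        "fmt"
        "net/mail"
    )

    func main() {
        // ParseDate is the newly exported RFC 5322 date parser.
        t, err := mail.ParseDate("Fri, 21 Nov 1997 09:55:06 -0600")
        fmt.Println(t, err)

        // An empty quoted display name now parses (Issue 14866).
        addr, err := mail.ParseAddress(`"" <emptystring@example.com>`)
        fmt.Println(addr, err)
    }
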
diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go
index bbbba6b..f0761ab 100644
--- a/libgo/go/net/mail/message_test.go
+++ b/libgo/go/net/mail/message_test.go
@@ -110,11 +110,16 @@ func TestDateParsing(t *testing.T) {
}
date, err := hdr.Date()
if err != nil {
- t.Errorf("Failed parsing %q: %v", test.dateStr, err)
- continue
+ t.Errorf("Header(Date: %s).Date(): %v", test.dateStr, err)
+ } else if !date.Equal(test.exp) {
+ t.Errorf("Header(Date: %s).Date() = %+v, want %+v", test.dateStr, date, test.exp)
}
- if !date.Equal(test.exp) {
- t.Errorf("Parse of %q: got %+v, want %+v", test.dateStr, date, test.exp)
+
+ date, err = ParseDate(test.dateStr)
+ if err != nil {
+ t.Errorf("ParseDate(%s): %v", test.dateStr, err)
+ } else if !date.Equal(test.exp) {
+ t.Errorf("ParseDate(%s) = %+v, want %+v", test.dateStr, date, test.exp)
}
}
}
@@ -310,6 +315,16 @@ func TestAddressParsing(t *testing.T) {
},
},
},
+ // Issue 14866
+ {
+ `"" <emptystring@example.com>`,
+ []*Address{
+ {
+ Name: "",
+ Address: "emptystring@example.com",
+ },
+ },
+ },
}
for _, test := range tests {
if len(test.exp) == 1 {
diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go
index 7573ded..28a8ff6 100644
--- a/libgo/go/net/main_test.go
+++ b/libgo/go/net/main_test.go
@@ -24,6 +24,8 @@ var (
)
var (
+ testTCPBig = flag.Bool("tcpbig", false, "whether to test massive size of data per read or write call on TCP connection")
+
testDNSFlood = flag.Bool("dnsflood", false, "whether to test DNS query flooding")
// If external IPv4 connectivity exists, we can try dialing
diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go
index d6812d1..81206ea 100644
--- a/libgo/go/net/net.go
+++ b/libgo/go/net/net.go
@@ -102,9 +102,13 @@ func init() {
}
// Addr represents a network end point address.
+//
+// The two methods Network and String conventionally return strings
+// that can be passed as the arguments to Dial, but the exact form
+// and meaning of the strings is up to the implementation.
type Addr interface {
- Network() string // name of the network
- String() string // string form of address
+ Network() string // name of the network (for example, "tcp", "udp")
+ String() string // string form of address (for example, "192.0.2.1:25", "[2001:db8::1]:80")
}
// Conn is a generic stream-oriented network connection.
@@ -112,12 +116,12 @@ type Addr interface {
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn interface {
// Read reads data from the connection.
- // Read can be made to time out and return a Error with Timeout() == true
+ // Read can be made to time out and return an Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
Read(b []byte) (n int, err error)
// Write writes data to the connection.
- // Write can be made to time out and return a Error with Timeout() == true
+ // Write can be made to time out and return an Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetWriteDeadline.
Write(b []byte) (n int, err error)
@@ -137,8 +141,10 @@ type Conn interface {
//
// A deadline is an absolute time after which I/O operations
// fail with a timeout (see type Error) instead of
- // blocking. The deadline applies to all future I/O, not just
- // the immediately following call to Read or Write.
+ // blocking. The deadline applies to all future and pending
+ // I/O, not just the immediately following call to Read or
+ // Write. After a deadline has been exceeded, the connection
+ // can be refreshed by setting a deadline in the future.
//
// An idle timeout can be implemented by repeatedly extending
// the deadline after successful Read or Write calls.
@@ -146,11 +152,13 @@ type Conn interface {
// A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error
- // SetReadDeadline sets the deadline for future Read calls.
+ // SetReadDeadline sets the deadline for future Read calls
+ // and any currently-blocked Read call.
// A zero value for t means Read will not time out.
SetReadDeadline(t time.Time) error
- // SetWriteDeadline sets the deadline for future Write calls.
+ // SetWriteDeadline sets the deadline for future Write calls
+ // and any currently-blocked Write call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
// A zero value for t means Write will not time out.
@@ -302,13 +310,13 @@ type PacketConn interface {
// bytes copied into b and the return address that
// was on the packet.
// ReadFrom can be made to time out and return
- // an error with Timeout() == true after a fixed time limit;
+ // an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetReadDeadline.
ReadFrom(b []byte) (n int, addr Addr, err error)
// WriteTo writes a packet with payload b to addr.
// WriteTo can be made to time out and return
- // an error with Timeout() == true after a fixed time limit;
+ // an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
WriteTo(b []byte, addr Addr) (n int, err error)
@@ -321,21 +329,32 @@ type PacketConn interface {
LocalAddr() Addr
// SetDeadline sets the read and write deadlines associated
- // with the connection.
+ // with the connection. It is equivalent to calling both
+ // SetReadDeadline and SetWriteDeadline.
+ //
+ // A deadline is an absolute time after which I/O operations
+ // fail with a timeout (see type Error) instead of
+ // blocking. The deadline applies to all future and pending
+ // I/O, not just the immediately following call to ReadFrom or
+ // WriteTo. After a deadline has been exceeded, the connection
+ // can be refreshed by setting a deadline in the future.
+ //
+ // An idle timeout can be implemented by repeatedly extending
+ // the deadline after successful ReadFrom or WriteTo calls.
+ //
+ // A zero value for t means I/O operations will not time out.
SetDeadline(t time.Time) error
- // SetReadDeadline sets the deadline for future Read calls.
- // If the deadline is reached, Read will fail with a timeout
- // (see type Error) instead of blocking.
- // A zero value for t means Read will not time out.
+ // SetReadDeadline sets the deadline for future ReadFrom calls
+ // and any currently-blocked ReadFrom call.
+ // A zero value for t means ReadFrom will not time out.
SetReadDeadline(t time.Time) error
- // SetWriteDeadline sets the deadline for future Write calls.
- // If the deadline is reached, Write will fail with a timeout
- // (see type Error) instead of blocking.
- // A zero value for t means Write will not time out.
+ // SetWriteDeadline sets the deadline for future WriteTo calls
+ // and any currently-blocked WriteTo call.
// Even if write times out, it may return n > 0, indicating that
// some of the data was successfully written.
+ // A zero value for t means WriteTo will not time out.
SetWriteDeadline(t time.Time) error
}
@@ -512,7 +531,7 @@ func (e *AddrError) Error() string {
}
s := e.Err
if e.Addr != "" {
- s += " " + e.Addr
+ s = "address " + e.Addr + ": " + s
}
return s
}
@@ -604,3 +623,66 @@ func acquireThread() {
func releaseThread() {
<-threadLimit
}
+
+// buffersWriter is the interface implemented by Conns that support a
+// "writev"-like batch write optimization.
+// writeBuffers should fully consume and write all chunks from the
+// provided Buffers, else it should report a non-nil error.
+type buffersWriter interface {
+ writeBuffers(*Buffers) (int64, error)
+}
+
+var testHookDidWritev = func(wrote int) {}
+
+// Buffers contains zero or more runs of bytes to write.
+//
+// On certain machines, for certain types of connections, this is
+// optimized into an OS-specific batch write operation (such as
+// "writev").
+type Buffers [][]byte
+
+var (
+ _ io.WriterTo = (*Buffers)(nil)
+ _ io.Reader = (*Buffers)(nil)
+)
+
+func (v *Buffers) WriteTo(w io.Writer) (n int64, err error) {
+ if wv, ok := w.(buffersWriter); ok {
+ return wv.writeBuffers(v)
+ }
+ for _, b := range *v {
+ nb, err := w.Write(b)
+ n += int64(nb)
+ if err != nil {
+ v.consume(n)
+ return n, err
+ }
+ }
+ v.consume(n)
+ return n, nil
+}
+
+func (v *Buffers) Read(p []byte) (n int, err error) {
+ for len(p) > 0 && len(*v) > 0 {
+ n0 := copy(p, (*v)[0])
+ v.consume(int64(n0))
+ p = p[n0:]
+ n += n0
+ }
+ if len(*v) == 0 {
+ err = io.EOF
+ }
+ return
+}
+
+func (v *Buffers) consume(n int64) {
+ for len(*v) > 0 {
+ ln0 := int64(len((*v)[0]))
+ if ln0 > n {
+ (*v)[0] = (*v)[0][n:]
+ return
+ }
+ n -= ln0
+ *v = (*v)[1:]
+ }
+}
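
The new Buffers type above batches several byte slices into a single write and, for connections implementing the buffersWriter interface, can be flushed with one OS-level writev. A minimal sketch of the exported behavior; the payload strings and the bytes.Buffer sink are placeholders for a real connection:

    package main

    import (
        "bytes"
        "fmt"
        "net"
    )

    func main() {
        // Buffers implements io.WriterTo; writing to a *TCPConn can be
        // coalesced into a single writev, while any other io.Writer
        // falls back to one Write per slice, as in WriteTo above.
        bufs := net.Buffers{
            []byte("HTTP/1.1 200 OK\r\n"),
            []byte("Content-Length: 2\r\n\r\n"),
            []byte("ok"),
        }

        var sink bytes.Buffer
        n, err := bufs.WriteTo(&sink) // consumes bufs as it writes
        fmt.Println(n, err, sink.Len())
    }
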
diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go
index b2f825d..9a9a7e5 100644
--- a/libgo/go/net/net_test.go
+++ b/libgo/go/net/net_test.go
@@ -5,6 +5,8 @@
package net
import (
+ "errors"
+ "fmt"
"io"
"net/internal/socktest"
"os"
@@ -15,7 +17,7 @@ import (
func TestCloseRead(t *testing.T) {
switch runtime.GOOS {
- case "nacl", "plan9":
+ case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -414,3 +416,103 @@ func TestZeroByteRead(t *testing.T) {
}
}
}
+
+// withTCPConnPair sets up a TCP connection between two peers, then
+// runs peer1 and peer2 concurrently. withTCPConnPair returns when
+// both have completed.
+func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ errc := make(chan error, 2)
+ go func() {
+ c1, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer c1.Close()
+ errc <- peer1(c1.(*TCPConn))
+ }()
+ go func() {
+ c2, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer c2.Close()
+ errc <- peer2(c2.(*TCPConn))
+ }()
+ for i := 0; i < 2; i++ {
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ }
+}
+
+// Tests that a blocked Read is interrupted by a concurrent SetReadDeadline
+// modifying that Conn's read deadline to the past.
+// See golang.org/cl/30164 which documented this. The net/http package
+// depends on this.
+func TestReadTimeoutUnblocksRead(t *testing.T) {
+ serverDone := make(chan struct{})
+ server := func(cs *TCPConn) error {
+ defer close(serverDone)
+ errc := make(chan error, 1)
+ go func() {
+ defer close(errc)
+ go func() {
+ // TODO: find a better way to wait
+ // until we're blocked in the cs.Read
+ // call below. Sleep is lame.
+ time.Sleep(100 * time.Millisecond)
+
+ // Interrupt the upcoming Read, unblocking it:
+ cs.SetReadDeadline(time.Unix(123, 0)) // time in the past
+ }()
+ var buf [1]byte
+ n, err := cs.Read(buf[:1])
+ if n != 0 || err == nil {
+ errc <- fmt.Errorf("Read = %v, %v; want 0, non-nil", n, err)
+ }
+ }()
+ select {
+ case err := <-errc:
+ return err
+ case <-time.After(5 * time.Second):
+ buf := make([]byte, 2<<20)
+ buf = buf[:runtime.Stack(buf, true)]
+ println("Stacks at timeout:\n", string(buf))
+ return errors.New("timeout waiting for Read to finish")
+ }
+
+ }
+ // Do nothing in the client. Never write. Just wait for the
+ // server's half to be done.
+ client := func(*TCPConn) error {
+ <-serverDone
+ return nil
+ }
+ withTCPConnPair(t, client, server)
+}
+
+// Issue 17695: verify that a blocked Read is woken up by a Close.
+func TestCloseUnblocksRead(t *testing.T) {
+ t.Parallel()
+ server := func(cs *TCPConn) error {
+ // Give the client time to get stuck in a Read:
+ time.Sleep(20 * time.Millisecond)
+ cs.Close()
+ return nil
+ }
+ client := func(ss *TCPConn) error {
+ n, err := ss.Read([]byte{0})
+ if n != 0 || err != io.EOF {
+ return fmt.Errorf("Read = %v, %v; want 0, EOF", n, err)
+ }
+ return nil
+ }
+ withTCPConnPair(t, client, server)
+}
diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go
index 2c6b98a..b270159 100644
--- a/libgo/go/net/parse.go
+++ b/libgo/go/net/parse.go
@@ -124,39 +124,27 @@ func getFields(s string) []string { return splitAtBytes(s, " \r\t\n") }
// Bigger than we need, not too big to worry about overflow
const big = 0xFFFFFF
-// Decimal to integer starting at &s[i0].
-// Returns number, new offset, success.
-func dtoi(s string, i0 int) (n int, i int, ok bool) {
+// Decimal to integer.
+// Returns number, characters consumed, success.
+func dtoi(s string) (n int, i int, ok bool) {
n = 0
- neg := false
- if len(s) > 0 && s[0] == '-' {
- neg = true
- s = s[1:]
- }
- for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
+ for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
n = n*10 + int(s[i]-'0')
if n >= big {
- if neg {
- return -big, i + 1, false
- }
return big, i, false
}
}
- if i == i0 {
- return 0, i, false
- }
- if neg {
- n = -n
- i++
+ if i == 0 {
+ return 0, 0, false
}
return n, i, true
}
-// Hexadecimal to integer starting at &s[i0].
-// Returns number, new offset, success.
-func xtoi(s string, i0 int) (n int, i int, ok bool) {
+// Hexadecimal to integer.
+// Returns number, characters consumed, success.
+func xtoi(s string) (n int, i int, ok bool) {
n = 0
- for i = i0; i < len(s); i++ {
+ for i = 0; i < len(s); i++ {
if '0' <= s[i] && s[i] <= '9' {
n *= 16
n += int(s[i] - '0')
@@ -173,7 +161,7 @@ func xtoi(s string, i0 int) (n int, i int, ok bool) {
return 0, i, false
}
}
- if i == i0 {
+ if i == 0 {
return 0, i, false
}
return n, i, true
@@ -187,7 +175,7 @@ func xtoi2(s string, e byte) (byte, bool) {
if len(s) > 2 && s[2] != e {
return 0, false
}
- n, ei, ok := xtoi(s[:2], 0)
+ n, ei, ok := xtoi(s[:2])
return byte(n), ok && ei == 2
}
@@ -348,22 +336,28 @@ func stringsHasSuffix(s, suffix string) bool {
// stringsHasSuffixFold reports whether s ends in suffix,
// ASCII-case-insensitively.
func stringsHasSuffixFold(s, suffix string) bool {
- if len(suffix) > len(s) {
+ return len(s) >= len(suffix) && stringsEqualFold(s[len(s)-len(suffix):], suffix)
+}
+
+// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix.
+func stringsHasPrefix(s, prefix string) bool {
+ return len(s) >= len(prefix) && s[:len(prefix)] == prefix
+}
+
+// stringsEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
+// are equal, ASCII-case-insensitively.
+func stringsEqualFold(s, t string) bool {
+ if len(s) != len(t) {
return false
}
- for i := 0; i < len(suffix); i++ {
- if lowerASCII(suffix[i]) != lowerASCII(s[len(s)-len(suffix)+i]) {
+ for i := 0; i < len(s); i++ {
+ if lowerASCII(s[i]) != lowerASCII(t[i]) {
return false
}
}
return true
}
-// stringsHasPrefix is strings.HasPrefix. It reports whether s begins with prefix.
-func stringsHasPrefix(s, prefix string) bool {
- return len(s) >= len(prefix) && s[:len(prefix)] == prefix
-}
-
func readFull(r io.Reader) (all []byte, err error) {
buf := make([]byte, 1024)
for {
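
dtoi and xtoi above lose their offset parameter and negative-number handling; callers now slice the input themselves, and the second result is the count of characters consumed. A standalone reimplementation of the new dtoi contract, written here only for illustration since the package function is unexported:

    package main

    import "fmt"

    // dtoiSketch mirrors the new dtoi contract: parse a leading run of
    // decimal digits, returning the value, the number of bytes consumed,
    // and whether anything was parsed.
    func dtoiSketch(s string) (n, i int, ok bool) {
        const big = 0xFFFFFF // same overflow guard as in parse.go above
        for i = 0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
            n = n*10 + int(s[i]-'0')
            if n >= big {
                return big, i, false
            }
        }
        if i == 0 {
            return 0, 0, false
        }
        return n, i, true
    }

    func main() {
        fmt.Println(dtoiSketch("80/tcp")) // 80 2 true
        fmt.Println(dtoiSketch("-1"))     // 0 0 false: negatives are now rejected
    }
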
diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go
index fec9200..c5f8bfd 100644
--- a/libgo/go/net/parse_test.go
+++ b/libgo/go/net/parse_test.go
@@ -86,14 +86,13 @@ func TestDtoi(t *testing.T) {
ok bool
}{
{"", 0, 0, false},
-
- {"-123456789", -big, 9, false},
- {"-1", -1, 2, true},
{"0", 0, 1, true},
{"65536", 65536, 5, true},
{"123456789", big, 8, false},
+ {"-0", 0, 0, false},
+ {"-1234", 0, 0, false},
} {
- n, i, ok := dtoi(tt.in, 0)
+ n, i, ok := dtoi(tt.in)
if n != tt.out || i != tt.off || ok != tt.ok {
t.Errorf("got %d, %d, %v; want %d, %d, %v", n, i, ok, tt.out, tt.off, tt.ok)
}
diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go
index badf8ab..868d1e4 100644
--- a/libgo/go/net/port_unix.go
+++ b/libgo/go/net/port_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build darwin dragonfly freebsd linux netbsd openbsd solaris
+// +build darwin dragonfly freebsd linux netbsd openbsd solaris nacl
// Read system port mappings from /etc/services
@@ -10,31 +10,24 @@ package net
import "sync"
-// services contains minimal mappings between services names and port
-// numbers for platforms that don't have a complete list of port numbers
-// (some Solaris distros).
-var services = map[string]map[string]int{
- "tcp": {"http": 80},
-}
-var servicesError error
var onceReadServices sync.Once
func readServices() {
- var file *file
- if file, servicesError = open("/etc/services"); servicesError != nil {
+ file, err := open("/etc/services")
+ if err != nil {
return
}
for line, ok := file.readLine(); ok; line, ok = file.readLine() {
// "http 80/tcp www www-http # World Wide Web HTTP"
if i := byteIndex(line, '#'); i >= 0 {
- line = line[0:i]
+ line = line[:i]
}
f := getFields(line)
if len(f) < 2 {
continue
}
portnet := f[1] // "80/tcp"
- port, j, ok := dtoi(portnet, 0)
+ port, j, ok := dtoi(portnet)
if !ok || port <= 0 || j >= len(portnet) || portnet[j] != '/' {
continue
}
@@ -56,18 +49,5 @@ func readServices() {
// goLookupPort is the native Go implementation of LookupPort.
func goLookupPort(network, service string) (port int, err error) {
onceReadServices.Do(readServices)
-
- switch network {
- case "tcp4", "tcp6":
- network = "tcp"
- case "udp4", "udp6":
- network = "udp"
- }
-
- if m, ok := services[network]; ok {
- if port, ok = m[service]; ok {
- return
- }
- }
- return 0, &AddrError{Err: "unknown port", Addr: network + "/" + service}
+ return lookupPortMap(network, service)
}
diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go
index 862fb1a..fce6a48 100644
--- a/libgo/go/net/rpc/client.go
+++ b/libgo/go/net/rpc/client.go
@@ -274,6 +274,8 @@ func Dial(network, address string) (*Client, error) {
return NewClient(conn), nil
}
+// Close calls the underlying codec's Close method. If the connection is already
+// shutting down, ErrShutdown is returned.
func (client *Client) Close() error {
client.mutex.Lock()
if client.closing {
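
The new doc comment on Close states that a second Close reports ErrShutdown. A minimal sketch of that behavior over an in-memory pipe (illustrative only, not from the patch):

package main

import (
    "fmt"
    "net"
    "net/rpc"
)

func main() {
    c1, c2 := net.Pipe() // no real server needed just to exercise Close
    defer c2.Close()

    client := rpc.NewClient(c1)
    fmt.Println(client.Close())                    // <nil>: first Close succeeds
    fmt.Println(client.Close() == rpc.ErrShutdown) // true: already shutting down
}
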
diff --git a/libgo/go/net/rpc/client_test.go b/libgo/go/net/rpc/client_test.go
index ba11ff8..d116d2a 100644
--- a/libgo/go/net/rpc/client_test.go
+++ b/libgo/go/net/rpc/client_test.go
@@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net"
- "runtime"
"strings"
"testing"
)
@@ -53,9 +52,6 @@ func (s *S) Recv(nul *struct{}, reply *R) error {
}
func TestGobError(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/8908")
- }
defer func() {
err := recover()
if err == nil {
diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go
index cff3241..18ea629 100644
--- a/libgo/go/net/rpc/server.go
+++ b/libgo/go/net/rpc/server.go
@@ -23,7 +23,7 @@
func (t *T) MethodName(argType T1, replyType *T2) error
- where T, T1 and T2 can be marshaled by encoding/gob.
+ where T1 and T2 can be marshaled by encoding/gob.
These requirements apply even if a different codec is used.
(In the future, these requirements may soften for custom codecs.)
@@ -55,6 +55,8 @@
package server
+ import "errors"
+
type Args struct {
A, B int
}
@@ -119,6 +121,8 @@
A server implementation will often provide a simple, type-safe wrapper for the
client.
+
+ The net/rpc package is frozen and is not accepting new features.
*/
package rpc
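
For readers following the doc hunks above: the package example that now imports "errors" continues roughly as below (a sketch paraphrasing the surrounding net/rpc documentation, which this hunk only shows in part):

package server

import "errors"

type Args struct {
    A, B int
}

type Quotient struct {
    Quo, Rem int
}

type Arith int

func (t *Arith) Multiply(args *Args, reply *int) error {
    *reply = args.A * args.B
    return nil
}

func (t *Arith) Divide(args *Args, quo *Quotient) error {
    if args.B == 0 {
        return errors.New("divide by zero")
    }
    quo.Quo = args.A / args.B
    quo.Rem = args.A % args.B
    return nil
}
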
diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go
index d04271d..8369c9d 100644
--- a/libgo/go/net/rpc/server_test.go
+++ b/libgo/go/net/rpc/server_test.go
@@ -693,7 +693,8 @@ func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
B := call.Args.(*Args).B
C := call.Reply.(*Reply).C
if A+B != C {
- b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
+ b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C)
+ return
}
<-gate
if atomic.AddInt32(&recv, -1) == 0 {
diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go
index 9e04dd7..a408fa5 100644
--- a/libgo/go/net/smtp/smtp.go
+++ b/libgo/go/net/smtp/smtp.go
@@ -9,7 +9,7 @@
// STARTTLS RFC 3207
// Additional extensions may be handled by clients.
//
-// The smtp package is frozen and not accepting new features.
+// The smtp package is frozen and is not accepting new features.
// Some external packages provide more functionality. See:
//
// https://godoc.org/?q=smtp
@@ -19,6 +19,7 @@ import (
"crypto/tls"
"encoding/base64"
"errors"
+ "fmt"
"io"
"net"
"net/textproto"
@@ -200,7 +201,7 @@ func (c *Client) Auth(a Auth) error {
}
resp64 := make([]byte, encoding.EncodedLen(len(resp)))
encoding.Encode(resp64, resp)
- code, msg64, err := c.cmd(0, "AUTH %s %s", mech, resp64)
+ code, msg64, err := c.cmd(0, strings.TrimSpace(fmt.Sprintf("AUTH %s %s", mech, resp64)))
for err == nil {
var msg []byte
switch code {
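
The AUTH change above exists because an Auth mechanism may return no initial response, in which case the old format string produced a trailing space on the command line. A small sketch of the difference (illustrative; the mechanism name is made up):

package main

import (
    "fmt"
    "strings"
)

func main() {
    mech := "FOOAUTH"
    var resp64 []byte // empty initial response

    old := fmt.Sprintf("AUTH %s %s", mech, resp64)
    fixed := strings.TrimSpace(old)
    fmt.Printf("%q -> %q\n", old, fixed) // "AUTH FOOAUTH " -> "AUTH FOOAUTH"
}
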
diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go
index 3ae0d5b..c48fae6d 100644
--- a/libgo/go/net/smtp/smtp_test.go
+++ b/libgo/go/net/smtp/smtp_test.go
@@ -94,6 +94,46 @@ func TestAuthPlain(t *testing.T) {
}
}
+// Issue 17794: don't send a trailing space on AUTH command when there's no password.
+func TestClientAuthTrimSpace(t *testing.T) {
+ server := "220 hello world\r\n" +
+ "200 some more"
+ var wrote bytes.Buffer
+ var fake faker
+ fake.ReadWriter = struct {
+ io.Reader
+ io.Writer
+ }{
+ strings.NewReader(server),
+ &wrote,
+ }
+ c, err := NewClient(fake, "fake.host")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ c.tls = true
+ c.didHello = true
+ c.Auth(toServerEmptyAuth{})
+ c.Close()
+ if got, want := wrote.String(), "AUTH FOOAUTH\r\n*\r\nQUIT\r\n"; got != want {
+ t.Errorf("wrote %q; want %q", got, want)
+ }
+}
+
+// toServerEmptyAuth is an implementation of Auth that only implements
+// the Start method, and returns "FOOAUTH", nil, nil. Notably, it returns
+// zero bytes for "toServer" so we can test that we don't send spaces at
+// the end of the line. See TestClientAuthTrimSpace.
+type toServerEmptyAuth struct{}
+
+func (toServerEmptyAuth) Start(server *ServerInfo) (proto string, toServer []byte, err error) {
+ return "FOOAUTH", nil, nil
+}
+
+func (toServerEmptyAuth) Next(fromServer []byte, more bool) (toServer []byte, err error) {
+ panic("unexpected call")
+}
+
type faker struct {
io.ReadWriter
}
@@ -716,23 +756,24 @@ func sendMail(hostPort string) error {
// generated from src/crypto/tls:
// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
-bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
-bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa
-IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA
-AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
-EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
-AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk
-Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA==
+MIIBjjCCATigAwIBAgIQMon9v0s3pDFXvAMnPgelpzANBgkqhkiG9w0BAQsFADAS
+MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
+MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJB
+AM0u/mNXKkhAzNsFkwKZPSpC4lZZaePQ55IyaJv3ovMM2smvthnlqaUfVKVmz7FF
+wLP9csX6vGtvkZg1uWAtvfkCAwEAAaNoMGYwDgYDVR0PAQH/BAQDAgKkMBMGA1Ud
+JQQMMAoGCCsGAQUFBwMBMA8GA1UdEwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhh
+bXBsZS5jb22HBH8AAAGHEAAAAAAAAAAAAAAAAAAAAAEwDQYJKoZIhvcNAQELBQAD
+QQBOZsFVC7IwX+qibmSbt2IPHkUgXhfbq0a9MYhD6tHcj4gbDcTXh4kZCbgHCz22
+gfSj2/G2wxzopoISVDucuncj
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
-MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0
-0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV
-NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d
-AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW
-MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD
-EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA
-1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE=
+MIIBOwIBAAJBAM0u/mNXKkhAzNsFkwKZPSpC4lZZaePQ55IyaJv3ovMM2smvthnl
+qaUfVKVmz7FFwLP9csX6vGtvkZg1uWAtvfkCAwEAAQJART2qkxODLUbQ2siSx7m2
+rmBLyR/7X+nLe8aPDrMOxj3heDNl4YlaAYLexbcY8d7VDfCRBKYoAOP0UCP1Vhuf
+UQIhAO6PEI55K3SpNIdc2k5f0xz+9rodJCYzu51EwWX7r8ufAiEA3C9EkLiU2NuK
+3L3DHCN5IlUSN1Nr/lw8NIt50Yorj2cCIQCDw1VbvCV6bDLtSSXzAA51B4ZzScE7
+sHtB5EYF9Dwm9QIhAJuCquuH4mDzVjUntXjXOQPdj7sRqVGCNWdrJwOukat7AiAy
+LXLEwb77DIPoI5ZuaXQC+MnyyJj1ExC9RFcGz+bexA==
-----END RSA PRIVATE KEY-----`)
diff --git a/libgo/go/net/sock_linux.go b/libgo/go/net/sock_linux.go
index e2732c5..7bca376 100644
--- a/libgo/go/net/sock_linux.go
+++ b/libgo/go/net/sock_linux.go
@@ -17,7 +17,7 @@ func maxListenerBacklog() int {
return syscall.SOMAXCONN
}
f := getFields(l)
- n, _, ok := dtoi(f[0], 0)
+ n, _, ok := dtoi(f[0])
if n == 0 || !ok {
return syscall.SOMAXCONN
}
diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go
index c3af27b..16351e1 100644
--- a/libgo/go/net/sock_posix.go
+++ b/libgo/go/net/sock_posix.go
@@ -30,6 +30,9 @@ type sockaddr interface {
// interface. It returns a nil interface when the address is
// nil.
sockaddr(family int) (syscall.Sockaddr, error)
+
+ // toLocal maps the zero address to a local system address (127.0.0.1 or ::1)
+ toLocal(net string) sockaddr
}
// socket returns a network file descriptor that is ready for
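
The new toLocal method on the sockaddr interface maps a wildcard (zero) address to the loopback address of the given network. A hypothetical standalone illustration of the contract (this is not the net package's internal implementation):

package main

import (
    "fmt"
    "net"
)

// toLocalTCP sketches what toLocal promises for TCP: keep port and zone,
// replace the unspecified IP with the network's loopback address.
func toLocalTCP(a *net.TCPAddr, network string) *net.TCPAddr {
    ip := net.IPv4(127, 0, 0, 1)
    if network == "tcp6" {
        ip = net.IPv6loopback
    }
    return &net.TCPAddr{IP: ip, Port: a.Port, Zone: a.Zone}
}

func main() {
    a := &net.TCPAddr{Port: 8080}      // wildcard IP
    fmt.Println(toLocalTCP(a, "tcp4")) // 127.0.0.1:8080
}
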
diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go
index 7cffcc5..69731eb 100644
--- a/libgo/go/net/tcpsock.go
+++ b/libgo/go/net/tcpsock.go
@@ -12,6 +12,9 @@ import (
"time"
)
+// BUG(mikio): On Windows, the File method of TCPListener is not
+// implemented.
+
// TCPAddr represents the address of a TCP end point.
type TCPAddr struct {
IP IP
@@ -53,6 +56,9 @@ func (a *TCPAddr) opAddr() Addr {
// "tcp6". A literal address or host name for IPv6 must be enclosed
// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or
// "[ipv6-host%zone]:80".
+//
+// Resolving a hostname is not recommended because this returns at most
+// one of its IP addresses.
func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
switch net {
case "tcp", "tcp4", "tcp6":
@@ -61,7 +67,7 @@ func ResolveTCPAddr(net, addr string) (*TCPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
- addrs, err := internetAddrList(context.Background(), net, addr)
+ addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr)
if err != nil {
return nil, err
}
@@ -81,7 +87,7 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
}
n, err := c.readFrom(r)
if err != nil && err != io.EOF {
- err = &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ err = &OpError{Op: "readfrom", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
}
return n, err
}
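
The new doc note on ResolveTCPAddr warns that resolving a host name yields at most one address. A usage sketch of the difference between ResolveTCPAddr and a full lookup (host name and results are illustrative):

package main

import (
    "fmt"
    "net"
)

func main() {
    addr, err := net.ResolveTCPAddr("tcp", "localhost:80")
    if err != nil {
        fmt.Println(err)
        return
    }
    fmt.Println("single address:", addr) // e.g. 127.0.0.1:80 or [::1]:80

    ips, _ := net.LookupIP("localhost") // may return both 127.0.0.1 and ::1
    fmt.Println("all addresses:", ips)
}
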
diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go
index c9a8b68..9641e5c 100644
--- a/libgo/go/net/tcpsock_posix.go
+++ b/libgo/go/net/tcpsock_posix.go
@@ -40,6 +40,10 @@ func (a *TCPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, a.Port, a.Zone)
}
+func (a *TCPAddr) toLocal(net string) sockaddr {
+ return &TCPAddr{loopbackIP(net), a.Port, a.Zone}
+}
+
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
if n, err, handled := sendFile(c.fd, r); handled {
return n, err
diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go
index a8d93b0..5115422 100644
--- a/libgo/go/net/tcpsock_test.go
+++ b/libgo/go/net/tcpsock_test.go
@@ -5,6 +5,7 @@
package net
import (
+ "fmt"
"internal/testenv"
"io"
"reflect"
@@ -310,6 +311,17 @@ var resolveTCPAddrTests = []resolveTCPAddrTest{
{"tcp", ":12345", &TCPAddr{Port: 12345}, nil},
{"http", "127.0.0.1:0", nil, UnknownNetworkError("http")},
+
+ {"tcp", "127.0.0.1:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil},
+ {"tcp", "[::ffff:127.0.0.1]:http", &TCPAddr{IP: ParseIP("::ffff:127.0.0.1"), Port: 80}, nil},
+ {"tcp", "[2001:db8::1]:http", &TCPAddr{IP: ParseIP("2001:db8::1"), Port: 80}, nil},
+ {"tcp4", "127.0.0.1:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil},
+ {"tcp4", "[::ffff:127.0.0.1]:http", &TCPAddr{IP: ParseIP("127.0.0.1"), Port: 80}, nil},
+ {"tcp6", "[2001:db8::1]:http", &TCPAddr{IP: ParseIP("2001:db8::1"), Port: 80}, nil},
+
+ {"tcp4", "[2001:db8::1]:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"tcp6", "127.0.0.1:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"tcp6", "[::ffff:127.0.0.1]:http", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
}
func TestResolveTCPAddr(t *testing.T) {
@@ -317,21 +329,17 @@ func TestResolveTCPAddr(t *testing.T) {
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = lookupLocalhost
- for i, tt := range resolveTCPAddrTests {
+ for _, tt := range resolveTCPAddrTests {
addr, err := ResolveTCPAddr(tt.network, tt.litAddrOrName)
- if err != tt.err {
- t.Errorf("#%d: %v", i, err)
- } else if !reflect.DeepEqual(addr, tt.addr) {
- t.Errorf("#%d: got %#v; want %#v", i, addr, tt.addr)
- }
- if err != nil {
+ if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ResolveTCPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err)
continue
}
- rtaddr, err := ResolveTCPAddr(addr.Network(), addr.String())
- if err != nil {
- t.Errorf("#%d: %v", i, err)
- } else if !reflect.DeepEqual(rtaddr, addr) {
- t.Errorf("#%d: got %#v; want %#v", i, rtaddr, addr)
+ if err == nil {
+ addr2, err := ResolveTCPAddr(addr.Network(), addr.String())
+ if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err {
+ t.Errorf("(%q, %q): ResolveTCPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err)
+ }
}
}
}
@@ -459,12 +467,19 @@ func TestTCPConcurrentAccept(t *testing.T) {
}
func TestTCPReadWriteAllocs(t *testing.T) {
+ if runtime.Compiler == "gccgo" {
+ t.Skip("skipping for gccgo until escape analysis is enabled")
+ }
+
switch runtime.GOOS {
- case "nacl", "windows":
+ case "plan9":
+ // The implementation of asynchronous cancelable
+ // I/O on Plan 9 allocates memory.
+ // See net/fd_io_plan9.go.
+ t.Skipf("not supported on %s", runtime.GOOS)
+ case "nacl":
// NaCl needs to allocate pseudo file descriptor
// stuff. See syscall/fd_nacl.go.
- // Windows uses closures and channels for IO
- // completion port-based netpoll. See fd_windows.go.
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -474,7 +489,7 @@ func TestTCPReadWriteAllocs(t *testing.T) {
}
defer ln.Close()
var server Conn
- errc := make(chan error)
+ errc := make(chan error, 1)
go func() {
var err error
server, err = ln.Accept()
@@ -489,6 +504,7 @@ func TestTCPReadWriteAllocs(t *testing.T) {
t.Fatal(err)
}
defer server.Close()
+
var buf [128]byte
allocs := testing.AllocsPerRun(1000, func() {
_, err := server.Write(buf[:])
@@ -504,6 +520,28 @@ func TestTCPReadWriteAllocs(t *testing.T) {
if allocs > 7 {
t.Fatalf("got %v; want 0", allocs)
}
+
+ var bufwrt [128]byte
+ ch := make(chan bool)
+ defer close(ch)
+ go func() {
+ for <-ch {
+ _, err := server.Write(bufwrt[:])
+ errc <- err
+ }
+ }()
+ allocs = testing.AllocsPerRun(1000, func() {
+ ch <- true
+ if _, err = io.ReadFull(client, buf[:]); err != nil {
+ t.Fatal(err)
+ }
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ })
+ if allocs > 0 {
+ t.Fatalf("got %v; want 0", allocs)
+ }
}
func TestTCPStress(t *testing.T) {
@@ -634,3 +672,58 @@ func TestTCPSelfConnect(t *testing.T) {
}
}
}
+
+// Test that >32-bit reads work on 64-bit systems.
+// On 32-bit systems this tests that maxint reads work.
+func TestTCPBig(t *testing.T) {
+ if !*testTCPBig {
+ t.Skip("test disabled; use -tcpbig to enable")
+ }
+
+ for _, writev := range []bool{false, true} {
+ t.Run(fmt.Sprintf("writev=%v", writev), func(t *testing.T) {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ x := int(1 << 30)
+ x = x*5 + 1<<20 // just over 5 GB on 64-bit, just over 1GB on 32-bit
+ done := make(chan int)
+ go func() {
+ defer close(done)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ buf := make([]byte, x)
+ var n int
+ if writev {
+ var n64 int64
+ n64, err = (&Buffers{buf}).WriteTo(c)
+ n = int(n64)
+ } else {
+ n, err = c.Write(buf)
+ }
+ if n != len(buf) || err != nil {
+ t.Errorf("Write(buf) = %d, %v, want %d, nil", n, err, x)
+ }
+ c.Close()
+ }()
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, x)
+ n, err := io.ReadFull(c, buf)
+ if n != len(buf) || err != nil {
+ t.Errorf("Read(buf) = %d, %v, want %d, nil", n, err, x)
+ }
+ c.Close()
+ <-done
+ })
+ }
+}
diff --git a/libgo/go/net/tcpsock_unix_test.go b/libgo/go/net/tcpsock_unix_test.go
index c07f7d7..2375fe2 100644
--- a/libgo/go/net/tcpsock_unix_test.go
+++ b/libgo/go/net/tcpsock_unix_test.go
@@ -15,7 +15,7 @@ import (
)
// See golang.org/issue/14548.
-func TestTCPSupriousConnSetupCompletion(t *testing.T) {
+func TestTCPSpuriousConnSetupCompletion(t *testing.T) {
if testing.Short() {
t.Skip("skipping in short mode")
}
@@ -57,7 +57,7 @@ func TestTCPSupriousConnSetupCompletion(t *testing.T) {
c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
if perr := parseDialError(err); perr != nil {
- t.Errorf("#%d: %v", i, err)
+ t.Errorf("#%d: %v (original error: %v)", i, perr, err)
}
return
}
diff --git a/libgo/go/net/testdata/invalid-ndots-resolv.conf b/libgo/go/net/testdata/invalid-ndots-resolv.conf
new file mode 100644
index 0000000..084c164
--- /dev/null
+++ b/libgo/go/net/testdata/invalid-ndots-resolv.conf
@@ -0,0 +1 @@
+options ndots:invalid
\ No newline at end of file
diff --git a/libgo/go/net/testdata/large-ndots-resolv.conf b/libgo/go/net/testdata/large-ndots-resolv.conf
new file mode 100644
index 0000000..72968ee
--- /dev/null
+++ b/libgo/go/net/testdata/large-ndots-resolv.conf
@@ -0,0 +1 @@
+options ndots:16
\ No newline at end of file
diff --git a/libgo/go/net/testdata/negative-ndots-resolv.conf b/libgo/go/net/testdata/negative-ndots-resolv.conf
new file mode 100644
index 0000000..c11e0cc
--- /dev/null
+++ b/libgo/go/net/testdata/negative-ndots-resolv.conf
@@ -0,0 +1 @@
+options ndots:-1
\ No newline at end of file
diff --git a/libgo/go/net/textproto/header.go b/libgo/go/net/textproto/header.go
index 2e2752a..ed096d9 100644
--- a/libgo/go/net/textproto/header.go
+++ b/libgo/go/net/textproto/header.go
@@ -23,8 +23,10 @@ func (h MIMEHeader) Set(key, value string) {
}
// Get gets the first value associated with the given key.
+// It is case insensitive; CanonicalMIMEHeaderKey is used
+// to canonicalize the provided key.
// If there are no values associated with the key, Get returns "".
-// Get is a convenience method. For more complex queries,
+// To access multiple values of a key, or to use non-canonical keys,
// access the map directly.
func (h MIMEHeader) Get(key string) string {
if h == nil {
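
The expanded Get documentation above points at the canonicalization rule. A short sketch of what that means in practice:

package main

import (
    "fmt"
    "net/textproto"
)

func main() {
    h := textproto.MIMEHeader{}
    h.Add("content-type", "text/plain")
    h.Add("Content-Type", "text/html")

    // Get canonicalizes the key, so lookups are case-insensitive and return
    // only the first value.
    fmt.Println(h.Get("CONTENT-TYPE")) // text/plain

    // Direct map access needs the canonical form and yields every value.
    fmt.Println(h[textproto.CanonicalMIMEHeaderKey("content-type")]) // [text/plain text/html]
}
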
diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go
index ed26f2a..55bbf44 100644
--- a/libgo/go/net/timeout_test.go
+++ b/libgo/go/net/timeout_test.go
@@ -5,7 +5,6 @@
package net
import (
- "context"
"fmt"
"internal/testenv"
"io"
@@ -152,6 +151,7 @@ var acceptTimeoutTests = []struct {
}
func TestAcceptTimeout(t *testing.T) {
+ testenv.SkipFlaky(t, 17948)
t.Parallel()
switch runtime.GOOS {
@@ -165,19 +165,18 @@ func TestAcceptTimeout(t *testing.T) {
}
defer ln.Close()
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
+ var wg sync.WaitGroup
for i, tt := range acceptTimeoutTests {
if tt.timeout < 0 {
+ wg.Add(1)
go func() {
- var d Dialer
- c, err := d.DialContext(ctx, ln.Addr().Network(), ln.Addr().String())
+ defer wg.Done()
+ d := Dialer{Timeout: 100 * time.Millisecond}
+ c, err := d.Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
t.Error(err)
return
}
- var b [1]byte
- c.Read(b[:])
c.Close()
}()
}
@@ -198,13 +197,14 @@ func TestAcceptTimeout(t *testing.T) {
}
if err == nil {
c.Close()
- time.Sleep(tt.timeout / 3)
+ time.Sleep(10 * time.Millisecond)
continue
}
break
}
}
}
+ wg.Wait()
}
func TestAcceptTimeoutMustReturn(t *testing.T) {
@@ -305,11 +305,6 @@ var readTimeoutTests = []struct {
}
func TestReadTimeout(t *testing.T) {
- switch runtime.GOOS {
- case "plan9":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
-
handler := func(ls *localServer, ln Listener) {
c, err := ln.Accept()
if err != nil {
@@ -435,7 +430,7 @@ var readFromTimeoutTests = []struct {
func TestReadFromTimeout(t *testing.T) {
switch runtime.GOOS {
- case "nacl", "plan9":
+ case "nacl":
t.Skipf("not supported on %s", runtime.GOOS) // see golang.org/issue/8916
}
@@ -509,11 +504,6 @@ var writeTimeoutTests = []struct {
func TestWriteTimeout(t *testing.T) {
t.Parallel()
- switch runtime.GOOS {
- case "plan9":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
-
ln, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
@@ -629,7 +619,7 @@ func TestWriteToTimeout(t *testing.T) {
t.Parallel()
switch runtime.GOOS {
- case "nacl", "plan9":
+ case "nacl":
t.Skipf("not supported on %s", runtime.GOOS)
}
@@ -681,11 +671,6 @@ func TestWriteToTimeout(t *testing.T) {
func TestReadTimeoutFluctuation(t *testing.T) {
t.Parallel()
- switch runtime.GOOS {
- case "plan9":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
-
ln, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
@@ -719,11 +704,6 @@ func TestReadTimeoutFluctuation(t *testing.T) {
func TestReadFromTimeoutFluctuation(t *testing.T) {
t.Parallel()
- switch runtime.GOOS {
- case "plan9":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
-
c1, err := newLocalPacketListener("udp")
if err != nil {
t.Fatal(err)
@@ -829,11 +809,6 @@ func (b neverEnding) Read(p []byte) (int, error) {
}
func testVariousDeadlines(t *testing.T) {
- switch runtime.GOOS {
- case "plan9":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
-
type result struct {
n int64
err error
@@ -1030,7 +1005,7 @@ func TestReadWriteDeadlineRace(t *testing.T) {
t.Parallel()
switch runtime.GOOS {
- case "nacl", "plan9":
+ case "nacl":
t.Skipf("not supported on %s", runtime.GOOS)
}
diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go
index 980f67c..841ef53 100644
--- a/libgo/go/net/udpsock.go
+++ b/libgo/go/net/udpsock.go
@@ -9,6 +9,15 @@ import (
"syscall"
)
+// BUG(mikio): On NaCl, Plan 9 and Windows, the ReadMsgUDP and
+// WriteMsgUDP methods of UDPConn are not implemented.
+
+// BUG(mikio): On Windows, the File method of UDPConn is not
+// implemented.
+
+// BUG(mikio): On NaCl, the ListenMulticastUDP function is not
+// implemented.
+
// UDPAddr represents the address of a UDP end point.
type UDPAddr struct {
IP IP
@@ -50,6 +59,9 @@ func (a *UDPAddr) opAddr() Addr {
// "udp6". A literal address or host name for IPv6 must be enclosed
// in square brackets, as in "[::1]:80", "[ipv6-host]:http" or
// "[ipv6-host%zone]:80".
+//
+// Resolving a hostname is not recommended because this returns at most
+// one of its IP addresses.
func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
switch net {
case "udp", "udp4", "udp6":
@@ -58,7 +70,7 @@ func ResolveUDPAddr(net, addr string) (*UDPAddr, error) {
default:
return nil, UnknownNetworkError(net)
}
- addrs, err := internetAddrList(context.Background(), net, addr)
+ addrs, err := DefaultResolver.internetAddrList(context.Background(), net, addr)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go
index 666f206..1ce7f88 100644
--- a/libgo/go/net/udpsock_plan9.go
+++ b/libgo/go/net/udpsock_plan9.go
@@ -109,5 +109,41 @@ func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, e
}
func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
- return nil, syscall.EPLAN9
+ l, err := listenPlan9(ctx, network, gaddr)
+ if err != nil {
+ return nil, err
+ }
+ _, err = l.ctl.WriteString("headers")
+ if err != nil {
+ return nil, err
+ }
+ var addrs []Addr
+ if ifi != nil {
+ addrs, err = ifi.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ addrs, err = InterfaceAddrs()
+ if err != nil {
+ return nil, err
+ }
+ }
+ for _, addr := range addrs {
+ if ipnet, ok := addr.(*IPNet); ok {
+ _, err = l.ctl.WriteString("addmulti " + ipnet.IP.String() + " " + gaddr.IP.String())
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+ l.data, err = os.OpenFile(l.dir+"/data", os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ fd, err := l.netFD()
+ if err != nil {
+ return nil, err
+ }
+ return newUDPConn(fd), nil
}
diff --git a/libgo/go/net/udpsock_plan9_test.go b/libgo/go/net/udpsock_plan9_test.go
new file mode 100644
index 0000000..09f5a5d
--- /dev/null
+++ b/libgo/go/net/udpsock_plan9_test.go
@@ -0,0 +1,69 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/testenv"
+ "runtime"
+ "testing"
+)
+
+func TestListenMulticastUDP(t *testing.T) {
+ testenv.MustHaveExternalNetwork(t)
+
+ ifcs, err := Interfaces()
+ if err != nil {
+ t.Skip(err.Error())
+ }
+ if len(ifcs) == 0 {
+ t.Skip("no network interfaces found")
+ }
+
+ var mifc *Interface
+ for _, ifc := range ifcs {
+ if ifc.Flags&FlagUp|FlagMulticast != FlagUp|FlagMulticast {
+ continue
+ }
+ mifc = &ifc
+ break
+ }
+
+ if mifc == nil {
+ t.Skipf("no multicast interfaces found")
+ }
+
+ c1, err := ListenMulticastUDP("udp4", mifc, &UDPAddr{IP: ParseIP("224.0.0.254")})
+ if err != nil {
+ t.Fatalf("multicast not working on %s", runtime.GOOS)
+ }
+ c1addr := c1.LocalAddr().(*UDPAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+
+ c2, err := ListenUDP("udp4", &UDPAddr{IP: IPv4zero, Port: 0})
+ c2addr := c2.LocalAddr().(*UDPAddr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ n, err := c2.WriteToUDP([]byte("data"), c1addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 4 {
+ t.Fatalf("got %d; want 4", n)
+ }
+
+ n, err = c1.WriteToUDP([]byte("data"), c2addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 4 {
+ t.Fatalf("got %d; want 4", n)
+ }
+}
diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go
index 4924801..72aadca 100644
--- a/libgo/go/net/udpsock_posix.go
+++ b/libgo/go/net/udpsock_posix.go
@@ -38,6 +38,10 @@ func (a *UDPAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return ipToSockaddr(family, a.IP, a.Port, a.Zone)
}
+func (a *UDPAddr) toLocal(net string) sockaddr {
+ return &UDPAddr{loopbackIP(net), a.Port, a.Zone}
+}
+
func (c *UDPConn) readFrom(b []byte) (int, *UDPAddr, error) {
var addr *UDPAddr
n, sa, err := c.fd.readFrom(b)
diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go
index 29d769c..708cc10 100644
--- a/libgo/go/net/udpsock_test.go
+++ b/libgo/go/net/udpsock_test.go
@@ -72,6 +72,17 @@ var resolveUDPAddrTests = []resolveUDPAddrTest{
{"udp", ":12345", &UDPAddr{Port: 12345}, nil},
{"http", "127.0.0.1:0", nil, UnknownNetworkError("http")},
+
+ {"udp", "127.0.0.1:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil},
+ {"udp", "[::ffff:127.0.0.1]:domain", &UDPAddr{IP: ParseIP("::ffff:127.0.0.1"), Port: 53}, nil},
+ {"udp", "[2001:db8::1]:domain", &UDPAddr{IP: ParseIP("2001:db8::1"), Port: 53}, nil},
+ {"udp4", "127.0.0.1:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil},
+ {"udp4", "[::ffff:127.0.0.1]:domain", &UDPAddr{IP: ParseIP("127.0.0.1"), Port: 53}, nil},
+ {"udp6", "[2001:db8::1]:domain", &UDPAddr{IP: ParseIP("2001:db8::1"), Port: 53}, nil},
+
+ {"udp4", "[2001:db8::1]:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "2001:db8::1"}},
+ {"udp6", "127.0.0.1:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "127.0.0.1"}},
+ {"udp6", "[::ffff:127.0.0.1]:domain", nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: "::ffff:127.0.0.1"}},
}
func TestResolveUDPAddr(t *testing.T) {
@@ -79,21 +90,17 @@ func TestResolveUDPAddr(t *testing.T) {
defer func() { testHookLookupIP = origTestHookLookupIP }()
testHookLookupIP = lookupLocalhost
- for i, tt := range resolveUDPAddrTests {
+ for _, tt := range resolveUDPAddrTests {
addr, err := ResolveUDPAddr(tt.network, tt.litAddrOrName)
- if err != tt.err {
- t.Errorf("#%d: %v", i, err)
- } else if !reflect.DeepEqual(addr, tt.addr) {
- t.Errorf("#%d: got %#v; want %#v", i, addr, tt.addr)
- }
- if err != nil {
+ if !reflect.DeepEqual(addr, tt.addr) || !reflect.DeepEqual(err, tt.err) {
+ t.Errorf("ResolveUDPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr, err, tt.addr, tt.err)
continue
}
- rtaddr, err := ResolveUDPAddr(addr.Network(), addr.String())
- if err != nil {
- t.Errorf("#%d: %v", i, err)
- } else if !reflect.DeepEqual(rtaddr, addr) {
- t.Errorf("#%d: got %#v; want %#v", i, rtaddr, addr)
+ if err == nil {
+ addr2, err := ResolveUDPAddr(addr.Network(), addr.String())
+ if !reflect.DeepEqual(addr2, tt.addr) || err != tt.err {
+ t.Errorf("(%q, %q): ResolveUDPAddr(%q, %q) = %#v, %v, want %#v, %v", tt.network, tt.litAddrOrName, addr.Network(), addr.String(), addr2, err, tt.addr, tt.err)
+ }
}
}
}
diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go
index bacdaa4..b25d492 100644
--- a/libgo/go/net/unixsock.go
+++ b/libgo/go/net/unixsock.go
@@ -7,6 +7,7 @@ package net
import (
"context"
"os"
+ "sync"
"syscall"
"time"
)
@@ -120,6 +121,9 @@ func (c *UnixConn) ReadFrom(b []byte) (int, Addr, error) {
// the associated out-of-band data into oob. It returns the number of
// bytes copied into b, the number of bytes copied into oob, the flags
// that were set on the packet, and the source address of the packet.
+//
+// Note that if len(b) == 0 and len(oob) > 0, this function will still
+// read (and discard) 1 byte from the connection.
func (c *UnixConn) ReadMsgUnix(b, oob []byte) (n, oobn, flags int, addr *UnixAddr, err error) {
if !c.ok() {
return 0, 0, 0, nil, syscall.EINVAL
@@ -167,6 +171,9 @@ func (c *UnixConn) WriteTo(b []byte, addr Addr) (int, error) {
// WriteMsgUnix writes a packet to addr via c, copying the payload
// from b and the associated out-of-band data from oob. It returns
// the number of payload and out-of-band bytes written.
+//
+// Note that if len(b) == 0 and len(oob) > 0, this function will still
+// write 1 byte to the connection.
func (c *UnixConn) WriteMsgUnix(b, oob []byte, addr *UnixAddr) (n, oobn int, err error) {
if !c.ok() {
return 0, 0, syscall.EINVAL
@@ -200,9 +207,10 @@ func DialUnix(net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
// typically use variables of type Listener instead of assuming Unix
// domain sockets.
type UnixListener struct {
- fd *netFD
- path string
- unlink bool
+ fd *netFD
+ path string
+ unlink bool
+ unlinkOnce sync.Once
}
func (ln *UnixListener) ok() bool { return ln != nil && ln.fd != nil }
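
The ReadMsgUnix/WriteMsgUnix notes above document that out-of-band-only messages still carry one payload byte. A Unix-only sketch that sends a file descriptor with an empty payload and observes that byte (the temp path and the choice of os.Stdin are illustrative):

package main

import (
    "fmt"
    "io/ioutil"
    "log"
    "net"
    "os"
    "path/filepath"
    "syscall"
)

func main() {
    dir, err := ioutil.TempDir("", "uds")
    if err != nil {
        log.Fatal(err)
    }
    defer os.RemoveAll(dir)
    path := filepath.Join(dir, "sock")

    ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"})
    if err != nil {
        log.Fatal(err)
    }
    defer ln.Close()

    done := make(chan struct{})
    go func() {
        defer close(done)
        c, err := ln.AcceptUnix()
        if err != nil {
            log.Print(err)
            return
        }
        defer c.Close()
        b, oob := make([]byte, 16), make([]byte, 128)
        n, oobn, _, _, err := c.ReadMsgUnix(b, oob)
        fmt.Println("read payload:", n, "oob:", oobn, "err:", err) // payload is 1
    }()

    c, err := net.DialUnix("unix", nil, &net.UnixAddr{Name: path, Net: "unix"})
    if err != nil {
        log.Fatal(err)
    }
    defer c.Close()

    rights := syscall.UnixRights(int(os.Stdin.Fd()))
    n, oobn, err := c.WriteMsgUnix(nil, rights, nil)
    fmt.Println("wrote payload:", n, "oob:", oobn, "err:", err) // payload is 1
    <-done
}
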
diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go
index 5f0999c..a8f892e 100644
--- a/libgo/go/net/unixsock_posix.go
+++ b/libgo/go/net/unixsock_posix.go
@@ -94,6 +94,10 @@ func (a *UnixAddr) sockaddr(family int) (syscall.Sockaddr, error) {
return &syscall.SockaddrUnix{Name: a.Name}, nil
}
+func (a *UnixAddr) toLocal(net string) sockaddr {
+ return a
+}
+
func (c *UnixConn) readFrom(b []byte) (int, *UnixAddr, error) {
var addr *UnixAddr
n, sa, err := c.fd.readFrom(b)
@@ -173,9 +177,12 @@ func (ln *UnixListener) close() error {
// is at least compatible with the auto-remove
// sequence in ListenUnix. It's only non-Go
// programs that can mess us up.
- if ln.path[0] != '@' && ln.unlink {
- syscall.Unlink(ln.path)
- }
+ // Even if there are racy calls to Close, we want to unlink only for the first one.
+ ln.unlinkOnce.Do(func() {
+ if ln.path[0] != '@' && ln.unlink {
+ syscall.Unlink(ln.path)
+ }
+ })
return ln.fd.Close()
}
@@ -187,6 +194,18 @@ func (ln *UnixListener) file() (*os.File, error) {
return f, nil
}
+// SetUnlinkOnClose sets whether the underlying socket file should be removed
+// from the file system when the listener is closed.
+//
+// The default behavior is to unlink the socket file only when package net created it.
+// That is, when the listener and the underlying socket file were created by a call to
+// Listen or ListenUnix, then by default closing the listener will remove the socket file.
+// But if the listener was created by a call to FileListener to use an already existing
+// socket file, then by default closing the listener will not remove the socket file.
+func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
+ l.unlink = unlink
+}
+
func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
fd, err := unixSocket(ctx, network, laddr, nil, "listen")
if err != nil {
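
A usage sketch of the new SetUnlinkOnClose knob (the socket path is hypothetical): by default a listener created with Listen removes its socket file on Close, a FileListener-wrapped listener does not, and the knob overrides either default.

package main

import (
    "fmt"
    "log"
    "net"
    "os"
)

func main() {
    const path = "/tmp/example.sock" // hypothetical path
    os.Remove(path)

    l, err := net.Listen("unix", path)
    if err != nil {
        log.Fatal(err)
    }
    ul := l.(*net.UnixListener)

    // Keep the socket file around after Close, e.g. for an external supervisor.
    ul.SetUnlinkOnClose(false)
    ul.Close()

    _, err = os.Stat(path)
    fmt.Println("socket file still present:", err == nil) // true
    os.Remove(path)
}
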
diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go
index f0f88ed..489a29b 100644
--- a/libgo/go/net/unixsock_test.go
+++ b/libgo/go/net/unixsock_test.go
@@ -9,6 +9,7 @@ package net
import (
"bytes"
"internal/testenv"
+ "io/ioutil"
"os"
"reflect"
"runtime"
@@ -414,33 +415,104 @@ func TestUnixUnlink(t *testing.T) {
t.Skip("unix test")
}
name := testUnixAddr()
- l, err := Listen("unix", name)
- if err != nil {
- t.Fatal(err)
- }
- if _, err := os.Stat(name); err != nil {
- t.Fatalf("cannot stat unix socket after ListenUnix: %v", err)
- }
- f, _ := l.(*UnixListener).File()
- l1, err := FileListener(f)
- if err != nil {
- t.Fatal(err)
- }
- if _, err := os.Stat(name); err != nil {
- t.Fatalf("cannot stat unix socket after FileListener: %v", err)
- }
- if err := l1.Close(); err != nil {
- t.Fatalf("closing file listener: %v", err)
- }
- if _, err := os.Stat(name); err != nil {
- t.Fatalf("cannot stat unix socket after closing FileListener: %v", err)
+
+ listen := func(t *testing.T) *UnixListener {
+ l, err := Listen("unix", name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return l.(*UnixListener)
}
- f.Close()
- if _, err := os.Stat(name); err != nil {
- t.Fatalf("cannot stat unix socket after closing FileListener and fd: %v", err)
+ checkExists := func(t *testing.T, desc string) {
+ if _, err := os.Stat(name); err != nil {
+ t.Fatalf("unix socket does not exist %s: %v", desc, err)
+ }
}
- l.Close()
- if _, err := os.Stat(name); err == nil {
- t.Fatal("closing unix listener did not remove unix socket")
+ checkNotExists := func(t *testing.T, desc string) {
+ if _, err := os.Stat(name); err == nil {
+ t.Fatalf("unix socket does exist %s: %v", desc, err)
+ }
}
+
+ // Listener should remove on close.
+ t.Run("Listen", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ })
+
+ // FileListener should not.
+ t.Run("FileListener", func(t *testing.T) {
+ l := listen(t)
+ f, _ := l.File()
+ l1, _ := FileListener(f)
+ checkExists(t, "after FileListener")
+ f.Close()
+ checkExists(t, "after File close")
+ l1.Close()
+ checkExists(t, "after FileListener close")
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ })
+
+ // Only first call to l.Close should remove.
+ t.Run("SecondClose", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ if err := ioutil.WriteFile(name, []byte("hello world"), 0666); err != nil {
+ t.Fatalf("cannot recreate socket file: %v", err)
+ }
+ checkExists(t, "after writing temp file")
+ l.Close()
+ checkExists(t, "after second Listener close")
+ os.Remove(name)
+ })
+
+ // SetUnlinkOnClose should do what it says.
+
+ t.Run("Listen/SetUnlinkOnClose(true)", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.SetUnlinkOnClose(true)
+ l.Close()
+ checkNotExists(t, "after Listener close")
+ })
+
+ t.Run("Listen/SetUnlinkOnClose(false)", func(t *testing.T) {
+ l := listen(t)
+ checkExists(t, "after Listen")
+ l.SetUnlinkOnClose(false)
+ l.Close()
+ checkExists(t, "after Listener close")
+ os.Remove(name)
+ })
+
+ t.Run("FileListener/SetUnlinkOnClose(true)", func(t *testing.T) {
+ l := listen(t)
+ f, _ := l.File()
+ l1, _ := FileListener(f)
+ checkExists(t, "after FileListener")
+ l1.(*UnixListener).SetUnlinkOnClose(true)
+ f.Close()
+ checkExists(t, "after File close")
+ l1.Close()
+ checkNotExists(t, "after FileListener close")
+ l.Close()
+ })
+
+ t.Run("FileListener/SetUnlinkOnClose(false)", func(t *testing.T) {
+ l := listen(t)
+ f, _ := l.File()
+ l1, _ := FileListener(f)
+ checkExists(t, "after FileListener")
+ l1.(*UnixListener).SetUnlinkOnClose(false)
+ f.Close()
+ checkExists(t, "after File close")
+ l1.Close()
+ checkExists(t, "after FileListener close")
+ l.Close()
+ })
}
diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go
index 30e9277..42a514b 100644
--- a/libgo/go/net/url/url.go
+++ b/libgo/go/net/url/url.go
@@ -74,6 +74,7 @@ type encoding int
const (
encodePath encoding = 1 + iota
+ encodePathSegment
encodeHost
encodeZone
encodeUserPassword
@@ -132,9 +133,14 @@ func shouldEscape(c byte, mode encoding) bool {
// The RFC allows : @ & = + $ but saves / ; , for assigning
// meaning to individual path segments. This package
// only manipulates the path as a whole, so we allow those
- // last two as well. That leaves only ? to escape.
+ // last three as well. That leaves only ? to escape.
return c == '?'
+ case encodePathSegment: // §3.3
+ // The RFC allows : @ & = + $ but saves / ; , for assigning
+ // meaning to individual path segments.
+ return c == '/' || c == ';' || c == ',' || c == '?'
+
case encodeUserPassword: // §3.2.1
// The RFC allows ';', ':', '&', '=', '+', '$', and ',' in
// userinfo, so we must escape only '@', '/', and '?'.
@@ -164,6 +170,15 @@ func QueryUnescape(s string) (string, error) {
return unescape(s, encodeQueryComponent)
}
+// PathUnescape does the inverse transformation of PathEscape, converting
+// %AB into the byte 0xAB. It returns an error if any % is not followed by
+// two hexadecimal digits.
+//
+// PathUnescape is identical to QueryUnescape except that it does not unescape '+' to ' ' (space).
+func PathUnescape(s string) (string, error) {
+ return unescape(s, encodePathSegment)
+}
+
// unescape unescapes a string; the mode specifies
// which section of the URL string is being unescaped.
func unescape(s string, mode encoding) (string, error) {
@@ -250,6 +265,12 @@ func QueryEscape(s string) string {
return escape(s, encodeQueryComponent)
}
+// PathEscape escapes the string so it can be safely placed
+// inside a URL path segment.
+func PathEscape(s string) string {
+ return escape(s, encodePathSegment)
+}
+
func escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
@@ -356,10 +377,7 @@ func (u *Userinfo) Username() string {
// Password returns the password in case it is set, and whether it is set.
func (u *Userinfo) Password() (string, bool) {
- if u.passwordSet {
- return u.password, true
- }
- return "", false
+ return u.password, u.passwordSet
}
// String returns the encoded userinfo information in the standard form
@@ -420,7 +438,7 @@ func Parse(rawurl string) (*URL, error) {
u, frag := split(rawurl, "#", true)
url, err := parse(u, false)
if err != nil {
- return nil, err
+ return nil, &Error{"parse", u, err}
}
if frag == "" {
return url, nil
@@ -437,31 +455,35 @@ func Parse(rawurl string) (*URL, error) {
// The string rawurl is assumed not to have a #fragment suffix.
// (Web browsers strip #fragment before sending the URL to a web server.)
func ParseRequestURI(rawurl string) (*URL, error) {
- return parse(rawurl, true)
+ url, err := parse(rawurl, true)
+ if err != nil {
+ return nil, &Error{"parse", rawurl, err}
+ }
+ return url, nil
}
// parse parses a URL from a string in one of two contexts. If
// viaRequest is true, the URL is assumed to have arrived via an HTTP request,
// in which case only absolute URLs or path-absolute relative URLs are allowed.
// If viaRequest is false, all forms of relative URLs are allowed.
-func parse(rawurl string, viaRequest bool) (url *URL, err error) {
+func parse(rawurl string, viaRequest bool) (*URL, error) {
var rest string
+ var err error
if rawurl == "" && viaRequest {
- err = errors.New("empty url")
- goto Error
+ return nil, errors.New("empty url")
}
- url = new(URL)
+ url := new(URL)
if rawurl == "*" {
url.Path = "*"
- return
+ return url, nil
}
// Split off possible leading "http:", "mailto:", etc.
// Cannot contain escaped characters.
if url.Scheme, rest, err = getscheme(rawurl); err != nil {
- goto Error
+ return nil, err
}
url.Scheme = strings.ToLower(url.Scheme)
@@ -479,8 +501,20 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) {
return url, nil
}
if viaRequest {
- err = errors.New("invalid URI for request")
- goto Error
+ return nil, errors.New("invalid URI for request")
+ }
+
+ // Avoid confusion with malformed schemes, like cache_object:foo/bar.
+ // See golang.org/issue/16822.
+ //
+ // RFC 3986, §3.3:
+ // In addition, a URI reference (Section 4.1) may be a relative-path reference,
+ // in which case the first path segment cannot contain a colon (":") character.
+ colon := strings.Index(rest, ":")
+ slash := strings.Index(rest, "/")
+ if colon >= 0 && (slash < 0 || colon < slash) {
+ // First path segment has colon. Not allowed in relative URL.
+ return nil, errors.New("first path segment in URL cannot contain colon")
}
}
@@ -489,23 +523,17 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) {
authority, rest = split(rest[2:], "/", false)
url.User, url.Host, err = parseAuthority(authority)
if err != nil {
- goto Error
+ return nil, err
}
}
- if url.Path, err = unescape(rest, encodePath); err != nil {
- goto Error
- }
- // RawPath is a hint as to the encoding of Path to use
- // in url.EscapedPath. If that method already gets the
- // right answer without RawPath, leave it empty.
- // This will help make sure that people don't rely on it in general.
- if url.EscapedPath() != rest && validEncodedPath(rest) {
- url.RawPath = rest
+ // Set Path and, optionally, RawPath.
+ // RawPath is a hint of the encoding of Path. We don't want to set it if
+ // the default escaping of Path is equivalent, to help make sure that people
+ // don't rely on it in general.
+ if err := url.setPath(rest); err != nil {
+ return nil, err
}
return url, nil
-
-Error:
- return nil, &Error{"parse", rawurl, err}
}
func parseAuthority(authority string) (user *Userinfo, host string, err error) {
@@ -586,6 +614,29 @@ func parseHost(host string) (string, error) {
return host, nil
}
+// setPath sets the Path and RawPath fields of the URL based on the provided
+// escaped path p. It maintains the invariant that RawPath is only specified
+// when it differs from the default encoding of the path.
+// For example:
+// - setPath("/foo/bar") will set Path="/foo/bar" and RawPath=""
+// - setPath("/foo%2fbar") will set Path="/foo/bar" and RawPath="/foo%2fbar"
+// setPath will return an error only if the provided path contains an invalid
+// escaping.
+func (u *URL) setPath(p string) error {
+ path, err := unescape(p, encodePath)
+ if err != nil {
+ return err
+ }
+ u.Path = path
+ if escp := escape(path, encodePath); p == escp {
+ // Default encoding is fine.
+ u.RawPath = ""
+ } else {
+ u.RawPath = p
+ }
+ return nil
+}
+
// EscapedPath returns the escaped form of u.Path.
// In general there are multiple possible escaped forms of any path.
// EscapedPath returns u.RawPath when it is a valid escaping of u.Path.
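
The new setPath helper centralizes the Path/RawPath invariant described above: RawPath is only recorded when the default escaping of Path would differ. A short sketch of the observable effect through Parse and EscapedPath:

package main

import (
    "fmt"
    "net/url"
)

func main() {
    u1, _ := url.Parse("http://example.com/foo/bar")
    fmt.Printf("Path=%q RawPath=%q\n", u1.Path, u1.RawPath) // "/foo/bar" ""

    u2, _ := url.Parse("http://example.com/foo%2fbar")
    fmt.Printf("Path=%q RawPath=%q\n", u2.Path, u2.RawPath) // "/foo/bar" "/foo%2fbar"
    fmt.Println(u2.EscapedPath())                           // /foo%2fbar
}
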
@@ -693,6 +744,17 @@ func (u *URL) String() string {
if path != "" && path[0] != '/' && u.Host != "" {
buf.WriteByte('/')
}
+ if buf.Len() == 0 {
+ // RFC 3986 §4.2
+ // A path segment that contains a colon character (e.g., "this:that")
+ // cannot be used as the first segment of a relative-path reference, as
+ // it would be mistaken for a scheme name. Such a segment must be
+ // preceded by a dot-segment (e.g., "./this:that") to make a relative-
+ // path reference.
+ if i := strings.IndexByte(path, ':'); i > -1 && strings.IndexByte(path[:i], '/') == -1 {
+ buf.WriteString("./")
+ }
+ }
buf.WriteString(path)
}
if u.ForceQuery || u.RawQuery != "" {
@@ -749,6 +811,10 @@ func (v Values) Del(key string) {
// ParseQuery always returns a non-nil map containing all the
// valid query parameters found; err describes the first decoding error
// encountered, if any.
+//
+// Query is expected to be a list of key=value settings separated by
+// ampersands or semicolons. A setting without an equals sign is
+// interpreted as a key set to an empty value.
func ParseQuery(query string) (Values, error) {
m := make(Values)
err := parseQuery(m, query)
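
The expanded ParseQuery comment spells out the accepted syntax. A one-call sketch of it (the printed map order may vary):

package main

import (
    "fmt"
    "net/url"
)

func main() {
    v, err := url.ParseQuery("a=1;b=2&flag")
    // a maps to ["1"], b to ["2"], and flag, having no '=', to [""].
    fmt.Println(v, err)
}
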
@@ -852,6 +918,7 @@ func resolvePath(base, ref string) string {
}
// IsAbs reports whether the URL is absolute.
+// Absolute means that it has a non-empty scheme.
func (u *URL) IsAbs() bool {
return u.Scheme != ""
}
@@ -880,7 +947,9 @@ func (u *URL) ResolveReference(ref *URL) *URL {
}
if ref.Scheme != "" || ref.Host != "" || ref.User != nil {
// The "absoluteURI" or "net_path" cases.
- url.Path = resolvePath(ref.Path, "")
+ // We can ignore the error from setPath since we know we provided a
+ // validly-escaped path.
+ url.setPath(resolvePath(ref.EscapedPath(), ""))
return &url
}
if ref.Opaque != "" {
@@ -900,7 +969,7 @@ func (u *URL) ResolveReference(ref *URL) *URL {
// The "abs_path" or "rel_path" cases.
url.Host = u.Host
url.User = u.User
- url.Path = resolvePath(u.Path, ref.Path)
+ url.setPath(resolvePath(u.EscapedPath(), ref.EscapedPath()))
return &url
}
@@ -929,3 +998,59 @@ func (u *URL) RequestURI() string {
}
return result
}
+
+// Hostname returns u.Host, without any port number.
+//
+// If Host is an IPv6 literal with a port number, Hostname returns the
+// IPv6 literal without the square brackets. IPv6 literals may include
+// a zone identifier.
+func (u *URL) Hostname() string {
+ return stripPort(u.Host)
+}
+
+// Port returns the port part of u.Host, without the leading colon.
+// If u.Host doesn't contain a port, Port returns an empty string.
+func (u *URL) Port() string {
+ return portOnly(u.Host)
+}
+
+func stripPort(hostport string) string {
+ colon := strings.IndexByte(hostport, ':')
+ if colon == -1 {
+ return hostport
+ }
+ if i := strings.IndexByte(hostport, ']'); i != -1 {
+ return strings.TrimPrefix(hostport[:i], "[")
+ }
+ return hostport[:colon]
+}
+
+func portOnly(hostport string) string {
+ colon := strings.IndexByte(hostport, ':')
+ if colon == -1 {
+ return ""
+ }
+ if i := strings.Index(hostport, "]:"); i != -1 {
+ return hostport[i+len("]:"):]
+ }
+ if strings.Contains(hostport, "]") {
+ return ""
+ }
+ return hostport[colon+len(":"):]
+}
+
+// Marshaling interface implementations.
+// Would like to implement MarshalText/UnmarshalText but that will change the JSON representation of URLs.
+
+func (u *URL) MarshalBinary() (text []byte, err error) {
+ return []byte(u.String()), nil
+}
+
+func (u *URL) UnmarshalBinary(text []byte) error {
+ u1, err := Parse(string(text))
+ if err != nil {
+ return err
+ }
+ *u = *u1
+ return nil
+}
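
A combined usage sketch of the helpers this file gains (PathEscape, PathUnescape, Hostname, Port), illustrative only:

package main

import (
    "fmt"
    "net/url"
)

func main() {
    seg := url.PathEscape("a b/c")
    fmt.Println(seg) // a%20b%2Fc: '/' is escaped inside a single segment

    back, err := url.PathUnescape(seg)
    fmt.Println(back, err) // a b/c <nil>

    u, _ := url.Parse("https://[::1]:8443/index")
    fmt.Println(u.Hostname()) // ::1 (brackets stripped)
    fmt.Println(u.Port())     // 8443
}
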
diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go
index 7560f22..6c3bb21 100644
--- a/libgo/go/net/url/url_test.go
+++ b/libgo/go/net/url/url_test.go
@@ -5,6 +5,10 @@
package url
import (
+ "bytes"
+ encodingPkg "encoding"
+ "encoding/gob"
+ "encoding/json"
"fmt"
"io"
"net"
@@ -579,20 +583,6 @@ func ufmt(u *URL) string {
u.Opaque, u.Scheme, user, pass, u.Host, u.Path, u.RawPath, u.RawQuery, u.Fragment, u.ForceQuery)
}
-func DoTest(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) {
- for _, tt := range tests {
- u, err := parse(tt.in)
- if err != nil {
- t.Errorf("%s(%q) returned error %s", name, tt.in, err)
- continue
- }
- if !reflect.DeepEqual(u, tt.out) {
- t.Errorf("%s(%q):\n\thave %v\n\twant %v\n",
- name, tt.in, ufmt(u), ufmt(tt.out))
- }
- }
-}
-
func BenchmarkString(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
@@ -618,7 +608,16 @@ func BenchmarkString(b *testing.B) {
}
func TestParse(t *testing.T) {
- DoTest(t, Parse, "Parse", urltests)
+ for _, tt := range urltests {
+ u, err := Parse(tt.in)
+ if err != nil {
+ t.Errorf("Parse(%q) returned error %v", tt.in, err)
+ continue
+ }
+ if !reflect.DeepEqual(u, tt.out) {
+ t.Errorf("Parse(%q):\n\tgot %v\n\twant %v\n", tt.in, ufmt(u), ufmt(tt.out))
+ }
+ }
}
const pathThatLooksSchemeRelative = "//not.a.user@not.a.host/just/a/path"
@@ -665,9 +664,10 @@ var parseRequestURLTests = []struct {
func TestParseRequestURI(t *testing.T) {
for _, test := range parseRequestURLTests {
_, err := ParseRequestURI(test.url)
- valid := err == nil
- if valid != test.expectedValid {
- t.Errorf("Expected valid=%v for %q; got %v", test.expectedValid, test.url, valid)
+ if test.expectedValid && err != nil {
+ t.Errorf("ParseRequestURI(%q) gave err %v; want no error", test.url, err)
+ } else if !test.expectedValid && err == nil {
+ t.Errorf("ParseRequestURI(%q) gave nil error; want some error", test.url)
}
}
@@ -676,45 +676,69 @@ func TestParseRequestURI(t *testing.T) {
t.Fatalf("Unexpected error %v", err)
}
if url.Path != pathThatLooksSchemeRelative {
- t.Errorf("Expected path %q; got %q", pathThatLooksSchemeRelative, url.Path)
+ t.Errorf("ParseRequestURI path:\ngot %q\nwant %q", url.Path, pathThatLooksSchemeRelative)
}
}
-func DoTestString(t *testing.T, parse func(string) (*URL, error), name string, tests []URLTest) {
- for _, tt := range tests {
- u, err := parse(tt.in)
+var stringURLTests = []struct {
+ url URL
+ want string
+}{
+ // No leading slash on path should prepend slash on String() call
+ {
+ url: URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "search",
+ },
+ want: "http://www.google.com/search",
+ },
+ // Relative path with first element containing ":" should be prepended with "./", golang.org/issue/17184
+ {
+ url: URL{
+ Path: "this:that",
+ },
+ want: "./this:that",
+ },
+ // Relative path with second element containing ":" should not be prepended with "./"
+ {
+ url: URL{
+ Path: "here/this:that",
+ },
+ want: "here/this:that",
+ },
+ // Non-relative path with first element containing ":" should not be prepended with "./"
+ {
+ url: URL{
+ Scheme: "http",
+ Host: "www.google.com",
+ Path: "this:that",
+ },
+ want: "http://www.google.com/this:that",
+ },
+}
+
+func TestURLString(t *testing.T) {
+ for _, tt := range urltests {
+ u, err := Parse(tt.in)
if err != nil {
- t.Errorf("%s(%q) returned error %s", name, tt.in, err)
+ t.Errorf("Parse(%q) returned error %s", tt.in, err)
continue
}
expected := tt.in
- if len(tt.roundtrip) > 0 {
+ if tt.roundtrip != "" {
expected = tt.roundtrip
}
s := u.String()
if s != expected {
- t.Errorf("%s(%q).String() == %q (expected %q)", name, tt.in, s, expected)
+ t.Errorf("Parse(%q).String() == %q (expected %q)", tt.in, s, expected)
}
}
-}
-func TestURLString(t *testing.T) {
- DoTestString(t, Parse, "Parse", urltests)
-
- // no leading slash on path should prepend
- // slash on String() call
- noslash := URLTest{
- "http://www.google.com/search",
- &URL{
- Scheme: "http",
- Host: "www.google.com",
- Path: "search",
- },
- "",
- }
- s := noslash.out.String()
- if s != noslash.in {
- t.Errorf("Expected %s; go %s", noslash.in, s)
+ for _, tt := range stringURLTests {
+ if got := tt.url.String(); got != tt.want {
+ t.Errorf("%+v.String() = %q; want %q", tt.url, got, tt.want)
+ }
}
}
@@ -780,6 +804,16 @@ var unescapeTests = []EscapeTest{
"",
EscapeError("%zz"),
},
+ {
+ "a+b",
+ "a b",
+ nil,
+ },
+ {
+ "a%20b",
+ "a b",
+ nil,
+ },
}
func TestUnescape(t *testing.T) {
@@ -788,10 +822,33 @@ func TestUnescape(t *testing.T) {
if actual != tt.out || (err != nil) != (tt.err != nil) {
t.Errorf("QueryUnescape(%q) = %q, %s; want %q, %s", tt.in, actual, err, tt.out, tt.err)
}
+
+ in := tt.in
+ out := tt.out
+ if strings.Contains(tt.in, "+") {
+ in = strings.Replace(tt.in, "+", "%20", -1)
+ actual, err := PathUnescape(in)
+ if actual != tt.out || (err != nil) != (tt.err != nil) {
+ t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", in, actual, err, tt.out, tt.err)
+ }
+ if tt.err == nil {
+ s, err := QueryUnescape(strings.Replace(tt.in, "+", "XXX", -1))
+ if err != nil {
+ continue
+ }
+ in = tt.in
+ out = strings.Replace(s, "XXX", "+", -1)
+ }
+ }
+
+ actual, err = PathUnescape(in)
+ if actual != out || (err != nil) != (tt.err != nil) {
+ t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", in, actual, err, out, tt.err)
+ }
}
}
-var escapeTests = []EscapeTest{
+var queryEscapeTests = []EscapeTest{
{
"",
"",
@@ -819,8 +876,8 @@ var escapeTests = []EscapeTest{
},
}
-func TestEscape(t *testing.T) {
- for _, tt := range escapeTests {
+func TestQueryEscape(t *testing.T) {
+ for _, tt := range queryEscapeTests {
actual := QueryEscape(tt.in)
if tt.out != actual {
t.Errorf("QueryEscape(%q) = %q, want %q", tt.in, actual, tt.out)
@@ -834,6 +891,54 @@ func TestEscape(t *testing.T) {
}
}
+var pathEscapeTests = []EscapeTest{
+ {
+ "",
+ "",
+ nil,
+ },
+ {
+ "abc",
+ "abc",
+ nil,
+ },
+ {
+ "abc+def",
+ "abc+def",
+ nil,
+ },
+ {
+ "one two",
+ "one%20two",
+ nil,
+ },
+ {
+ "10%",
+ "10%25",
+ nil,
+ },
+ {
+ " ?&=#+%!<>#\"{}|\\^[]`☺\t:/@$'()*,;",
+ "%20%3F&=%23+%25%21%3C%3E%23%22%7B%7D%7C%5C%5E%5B%5D%60%E2%98%BA%09:%2F@$%27%28%29%2A%2C%3B",
+ nil,
+ },
+}
+
+func TestPathEscape(t *testing.T) {
+ for _, tt := range pathEscapeTests {
+ actual := PathEscape(tt.in)
+ if tt.out != actual {
+ t.Errorf("PathEscape(%q) = %q, want %q", tt.in, actual, tt.out)
+ }
+
+ // for bonus points, verify that escape:unescape is an identity.
+ roundtrip, err := PathUnescape(actual)
+ if roundtrip != tt.in || err != nil {
+ t.Errorf("PathUnescape(%q) = %q, %s; want %q, %s", actual, roundtrip, err, tt.in, "[no error]")
+ }
+ }
+}
+
//var userinfoTests = []UserinfoTest{
// {"user", "password", "user:password"},
// {"foo:bar", "~!@#$%^&*()_+{}|[]\\-=`:;'\"<>?,./",
@@ -945,6 +1050,15 @@ var resolveReferenceTests = []struct {
// Fragment
{"http://foo.com/bar", ".#frag", "http://foo.com/#frag"},
+ // Paths with escaping (issue 16947).
+ {"http://foo.com/foo%2fbar/", "../baz", "http://foo.com/baz"},
+ {"http://foo.com/1/2%2f/3%2f4/5", "../../a/b/c", "http://foo.com/1/a/b/c"},
+ {"http://foo.com/1/2/3", "./a%2f../../b/..%2fc", "http://foo.com/1/2/b/..%2fc"},
+ {"http://foo.com/1/2%2f/3%2f4/5", "./a%2f../b/../c", "http://foo.com/1/2%2f/3%2f4/a%2f../c"},
+ {"http://foo.com/foo%20bar/", "../baz", "http://foo.com/baz"},
+ {"http://foo.com/foo", "../bar%2fbaz", "http://foo.com/bar%2fbaz"},
+ {"http://foo.com/foo%2dbar/", "./baz-quux", "http://foo.com/foo%2dbar/baz-quux"},
+
// RFC 3986: Normal Examples
// http://tools.ietf.org/html/rfc3986#section-5.4.1
{"http://a/b/c/d;p?q", "g:h", "g:h"},
@@ -1004,7 +1118,7 @@ func TestResolveReference(t *testing.T) {
mustParse := func(url string) *URL {
u, err := Parse(url)
if err != nil {
- t.Fatalf("Expected URL to parse: %q, got error: %v", url, err)
+ t.Fatalf("Parse(%q) got err %v", url, err)
}
return u
}
@@ -1013,8 +1127,8 @@ func TestResolveReference(t *testing.T) {
base := mustParse(test.base)
rel := mustParse(test.rel)
url := base.ResolveReference(rel)
- if url.String() != test.expected {
- t.Errorf("URL(%q).ResolveReference(%q) == %q, got %q", test.base, test.rel, test.expected, url.String())
+ if got := url.String(); got != test.expected {
+ t.Errorf("URL(%q).ResolveReference(%q)\ngot %q\nwant %q", test.base, test.rel, got, test.expected)
}
// Ensure that new instances are returned.
if base == url {
@@ -1024,8 +1138,8 @@ func TestResolveReference(t *testing.T) {
url, err := base.Parse(test.rel)
if err != nil {
t.Errorf("URL(%q).Parse(%q) failed: %v", test.base, test.rel, err)
- } else if url.String() != test.expected {
- t.Errorf("URL(%q).Parse(%q) == %q, got %q", test.base, test.rel, test.expected, url.String())
+ } else if got := url.String(); got != test.expected {
+ t.Errorf("URL(%q).Parse(%q)\ngot %q\nwant %q", test.base, test.rel, got, test.expected)
} else if base == url {
// Ensure that new instances are returned for the wrapper too.
t.Errorf("Expected URL.Parse to return new URL instance.")
@@ -1033,14 +1147,14 @@ func TestResolveReference(t *testing.T) {
// Ensure Opaque resets the URL.
url = base.ResolveReference(opaque)
if *url != *opaque {
- t.Errorf("ResolveReference failed to resolve opaque URL: want %#v, got %#v", url, opaque)
+ t.Errorf("ResolveReference failed to resolve opaque URL:\ngot %#v\nwant %#v", url, opaque)
}
// Test the convenience wrapper with an opaque URL too.
url, err = base.Parse("scheme:opaque")
if err != nil {
t.Errorf(`URL(%q).Parse("scheme:opaque") failed: %v`, test.base, err)
} else if *url != *opaque {
- t.Errorf("Parse failed to resolve opaque URL: want %#v, got %#v", url, opaque)
+ t.Errorf("Parse failed to resolve opaque URL:\ngot %#v\nwant %#v", opaque, url)
} else if base == url {
// Ensure that new instances are returned, again.
t.Errorf("Expected URL.Parse to return new URL instance.")
@@ -1271,7 +1385,7 @@ func TestParseFailure(t *testing.T) {
}
}
-func TestParseAuthority(t *testing.T) {
+func TestParseErrors(t *testing.T) {
tests := []struct {
in string
wantErr bool
@@ -1291,9 +1405,13 @@ func TestParseAuthority(t *testing.T) {
{"http://%41:8080/", true}, // not allowed: % encoding only for non-ASCII
{"mysql://x@y(z:123)/foo", false}, // golang.org/issue/12023
{"mysql://x@y(1.2.3.4:123)/foo", false},
- {"mysql://x@y([2001:db8::1]:123)/foo", false},
+
{"http://[]%20%48%54%54%50%2f%31%2e%31%0a%4d%79%48%65%61%64%65%72%3a%20%31%32%33%0a%0a/", true}, // golang.org/issue/11208
{"http://a b.com/", true}, // no space in host name please
+ {"cache_object://foo", true}, // scheme cannot have _, relative path cannot have : in first segment
+ {"cache_object:foo", true},
+ {"cache_object:foo/bar", true},
+ {"cache_object/:foo/bar", false},
}
for _, tt := range tests {
u, err := Parse(tt.in)
@@ -1462,11 +1580,106 @@ func TestURLErrorImplementsNetError(t *testing.T) {
continue
}
if err.Timeout() != tt.timeout {
- t.Errorf("%d: err.Timeout(): want %v, have %v", i+1, tt.timeout, err.Timeout())
+ t.Errorf("%d: err.Timeout(): got %v, want %v", i+1, err.Timeout(), tt.timeout)
continue
}
if err.Temporary() != tt.temporary {
- t.Errorf("%d: err.Temporary(): want %v, have %v", i+1, tt.temporary, err.Temporary())
+ t.Errorf("%d: err.Temporary(): got %v, want %v", i+1, err.Temporary(), tt.temporary)
+ }
+ }
+}
+
+func TestURLHostname(t *testing.T) {
+ tests := []struct {
+ host string // URL.Host field
+ want string
+ }{
+ {"foo.com:80", "foo.com"},
+ {"foo.com", "foo.com"},
+ {"FOO.COM", "FOO.COM"}, // no canonicalization (yet?)
+ {"1.2.3.4", "1.2.3.4"},
+ {"1.2.3.4:80", "1.2.3.4"},
+ {"[1:2:3:4]", "1:2:3:4"},
+ {"[1:2:3:4]:80", "1:2:3:4"},
+ {"[::1]:80", "::1"},
+ }
+ for _, tt := range tests {
+ u := &URL{Host: tt.host}
+ got := u.Hostname()
+ if got != tt.want {
+ t.Errorf("Hostname for Host %q = %q; want %q", tt.host, got, tt.want)
+ }
+ }
+}
+
+func TestURLPort(t *testing.T) {
+ tests := []struct {
+ host string // URL.Host field
+ want string
+ }{
+ {"foo.com", ""},
+ {"foo.com:80", "80"},
+ {"1.2.3.4", ""},
+ {"1.2.3.4:80", "80"},
+ {"[1:2:3:4]", ""},
+ {"[1:2:3:4]:80", "80"},
+ }
+ for _, tt := range tests {
+ u := &URL{Host: tt.host}
+ got := u.Port()
+ if got != tt.want {
+ t.Errorf("Port for Host %q = %q; want %q", tt.host, got, tt.want)
}
}
}
+
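
The Hostname and Port accessors tested above are also new in Go 1.8: they split URL.Host without the error handling of net.SplitHostPort, and Hostname strips the brackets from IPv6 literals. A small usage sketch (illustrative only, not part of the patch):

    // Illustrative sketch only, not part of the patch.
    package main

    import (
    	"fmt"
    	"net/url"
    )

    func main() {
    	u, err := url.Parse("https://[2001:db8::1]:8443/index.html")
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(u.Hostname()) // 2001:db8::1  (brackets removed)
    	fmt.Println(u.Port())     // 8443

    	u2, _ := url.Parse("https://example.com/path")
    	fmt.Println(u2.Port() == "") // true: the URL carries no port
    }
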
+var _ encodingPkg.BinaryMarshaler = (*URL)(nil)
+var _ encodingPkg.BinaryUnmarshaler = (*URL)(nil)
+
+func TestJSON(t *testing.T) {
+ u, err := Parse("https://www.google.com/x?y=z")
+ if err != nil {
+ t.Fatal(err)
+ }
+ js, err := json.Marshal(u)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // If only we could implement TextMarshaler/TextUnmarshaler,
+ // this would work:
+ //
+ // if string(js) != strconv.Quote(u.String()) {
+ // t.Errorf("json encoding: %s\nwant: %s\n", js, strconv.Quote(u.String()))
+ // }
+
+ u1 := new(URL)
+ err = json.Unmarshal(js, u1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u1.String() != u.String() {
+ t.Errorf("json decoded to: %s\nwant: %s\n", u1, u)
+ }
+}
+
+func TestGob(t *testing.T) {
+ u, err := Parse("https://www.google.com/x?y=z")
+ if err != nil {
+ t.Fatal(err)
+ }
+ var w bytes.Buffer
+ err = gob.NewEncoder(&w).Encode(u)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ u1 := new(URL)
+ err = gob.NewDecoder(&w).Decode(u1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if u1.String() != u.String() {
+ t.Errorf("json decoded to: %s\nwant: %s\n", u1, u)
+ }
+}
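
The BinaryMarshaler assertions and the JSON/gob tests above rely on the new URL.MarshalBinary/UnmarshalBinary methods (Go 1.8); json.Marshal still emits the struct fields, which is why TestJSON only checks a decode round-trip. A sketch of the behaviour outside the test file (illustrative only; errors are ignored for brevity):

    // Illustrative sketch only, not part of the patch.
    package main

    import (
    	"bytes"
    	"encoding/gob"
    	"fmt"
    	"net/url"
    )

    func main() {
    	u, _ := url.Parse("https://www.google.com/x?y=z")

    	// MarshalBinary emits the URL in its String() form ...
    	b, _ := u.MarshalBinary()
    	fmt.Printf("%s\n", b) // https://www.google.com/x?y=z

    	// ... which is what lets encoding/gob round-trip *url.URL values.
    	var buf bytes.Buffer
    	_ = gob.NewEncoder(&buf).Encode(u)
    	var v url.URL
    	_ = gob.NewDecoder(&buf).Decode(&v)
    	fmt.Println(v.String() == u.String()) // true
    }
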
diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go
new file mode 100644
index 0000000..7160d28
--- /dev/null
+++ b/libgo/go/net/writev_test.go
@@ -0,0 +1,225 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestBuffers_read(t *testing.T) {
+ const story = "once upon a time in Gopherland ... "
+ buffers := Buffers{
+ []byte("once "),
+ []byte("upon "),
+ []byte("a "),
+ []byte("time "),
+ []byte("in "),
+ []byte("Gopherland ... "),
+ }
+ got, err := ioutil.ReadAll(&buffers)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != story {
+ t.Errorf("read %q; want %q", got, story)
+ }
+ if len(buffers) != 0 {
+ t.Errorf("len(buffers) = %d; want 0", len(buffers))
+ }
+}
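
A usage sketch for the new Buffers type exercised by this file (illustrative only, not part of the patch). Buffers gathers several byte slices; on TCP connections on the listed Unix-like platforms WriteTo is backed by the writev path added in writev_unix.go, while on any other io.Writer it falls back to one Write per slice. net.Pipe is used here only to keep the example self-contained:

    // Illustrative sketch only, not part of the patch.
    package main

    import (
    	"fmt"
    	"io/ioutil"
    	"net"
    )

    func main() {
    	c1, c2 := net.Pipe()
    	go func() {
    		defer c1.Close()
    		bufs := net.Buffers{
    			[]byte("once "),
    			[]byte("upon "),
    			[]byte("a time"),
    		}
    		// WriteTo drains bufs as it writes.
    		if _, err := bufs.WriteTo(c1); err != nil {
    			fmt.Println("write:", err)
    		}
    	}()
    	got, err := ioutil.ReadAll(c2)
    	fmt.Printf("%s %v\n", got, err) // once upon a time <nil>
    }
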
+
+func TestBuffers_consume(t *testing.T) {
+ tests := []struct {
+ in Buffers
+ consume int64
+ want Buffers
+ }{
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 0,
+ want: Buffers{[]byte("foo"), []byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 2,
+ want: Buffers{[]byte("o"), []byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 3,
+ want: Buffers{[]byte("bar")},
+ },
+ {
+ in: Buffers{[]byte("foo"), []byte("bar")},
+ consume: 4,
+ want: Buffers{[]byte("ar")},
+ },
+ {
+ in: Buffers{nil, nil, nil, []byte("bar")},
+ consume: 1,
+ want: Buffers{[]byte("ar")},
+ },
+ {
+ in: Buffers{nil, nil, nil, []byte("foo")},
+ consume: 0,
+ want: Buffers{[]byte("foo")},
+ },
+ {
+ in: Buffers{nil, nil, nil},
+ consume: 0,
+ want: Buffers{},
+ },
+ }
+ for i, tt := range tests {
+ in := tt.in
+ in.consume(tt.consume)
+ if !reflect.DeepEqual(in, tt.want) {
+ t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
+ }
+ }
+}
+
+func TestBuffers_WriteTo(t *testing.T) {
+ for _, name := range []string{"WriteTo", "Copy"} {
+ for _, size := range []int{0, 10, 1023, 1024, 1025} {
+ t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
+ testBuffer_writeTo(t, size, name == "Copy")
+ })
+ }
+ }
+}
+
+func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
+ oldHook := testHookDidWritev
+ defer func() { testHookDidWritev = oldHook }()
+ var writeLog struct {
+ sync.Mutex
+ log []int
+ }
+ testHookDidWritev = func(size int) {
+ writeLog.Lock()
+ writeLog.log = append(writeLog.log, size)
+ writeLog.Unlock()
+ }
+ var want bytes.Buffer
+ for i := 0; i < chunks; i++ {
+ want.WriteByte(byte(i))
+ }
+
+ withTCPConnPair(t, func(c *TCPConn) error {
+ buffers := make(Buffers, chunks)
+ for i := range buffers {
+ buffers[i] = want.Bytes()[i : i+1]
+ }
+ var n int64
+ var err error
+ if useCopy {
+ n, err = io.Copy(c, &buffers)
+ } else {
+ n, err = buffers.WriteTo(c)
+ }
+ if err != nil {
+ return err
+ }
+ if len(buffers) != 0 {
+ return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
+ }
+ if n != int64(want.Len()) {
+ return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
+ }
+ return nil
+ }, func(c *TCPConn) error {
+ all, err := ioutil.ReadAll(c)
+ if !bytes.Equal(all, want.Bytes()) || err != nil {
+ return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
+ }
+
+ writeLog.Lock() // no need to unlock
+ var gotSum int
+ for _, v := range writeLog.log {
+ gotSum += v
+ }
+
+ var wantSum int
+ switch runtime.GOOS {
+ case "android", "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
+ var wantMinCalls int
+ wantSum = want.Len()
+ v := chunks
+ for v > 0 {
+ wantMinCalls++
+ v -= 1024
+ }
+ if len(writeLog.log) < wantMinCalls {
+ t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
+ }
+ case "windows":
+ var wantCalls int
+ wantSum = want.Len()
+ if wantSum > 0 {
+ wantCalls = 1 // Windows always does one syscall, unless it is sending an empty buffer
+ }
+ if len(writeLog.log) != wantCalls {
+ t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
+ }
+ }
+ if gotSum != wantSum {
+ t.Errorf("writev call sum = %v; want %v", gotSum, wantSum)
+ }
+ return nil
+ })
+}
+
+func TestWritevError(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
+ }
+
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ ch := make(chan Conn, 1)
+ go func() {
+ defer close(ch)
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ ch <- c
+ }()
+ c1, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+ c2 := <-ch
+ if c2 == nil {
+ t.Fatal("no server side connection")
+ }
+ c2.Close()
+
+ // 1 GB of data should be enough to notice the connection is gone.
+ // Just a few bytes is not enough.
+ // Arrange to reuse the same 1 MB buffer so that we don't allocate much.
+ buf := make([]byte, 1<<20)
+ buffers := make(Buffers, 1<<10)
+ for i := range buffers {
+ buffers[i] = buf
+ }
+ if _, err := buffers.WriteTo(c1); err == nil {
+ t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
+ }
+}
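
For the record, the write queued above is 1<<10 buffers of 1<<20 bytes each, i.e. 2^30 bytes (1 GiB) handed to a single Buffers.WriteTo call, while only one shared 1 MB backing array is ever allocated.
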
diff --git a/libgo/go/net/writev_unix.go b/libgo/go/net/writev_unix.go
new file mode 100644
index 0000000..174e6bc
--- /dev/null
+++ b/libgo/go/net/writev_unix.go
@@ -0,0 +1,95 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux netbsd openbsd
+
+package net
+
+import (
+ "io"
+ "os"
+ "syscall"
+ "unsafe"
+)
+
+func (c *conn) writeBuffers(v *Buffers) (int64, error) {
+ if !c.ok() {
+ return 0, syscall.EINVAL
+ }
+ n, err := c.fd.writeBuffers(v)
+ if err != nil {
+ return n, &OpError{Op: "writev", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+ }
+ return n, nil
+}
+
+func (fd *netFD) writeBuffers(v *Buffers) (n int64, err error) {
+ if err := fd.writeLock(); err != nil {
+ return 0, err
+ }
+ defer fd.writeUnlock()
+ if err := fd.pd.prepareWrite(); err != nil {
+ return 0, err
+ }
+
+ var iovecs []syscall.Iovec
+ if fd.iovecs != nil {
+ iovecs = *fd.iovecs
+ }
+ // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is
+ // 1024 and this seems conservative enough for now. Darwin's
+ // UIO_MAXIOV also seems to be 1024.
+ maxVec := 1024
+
+ for len(*v) > 0 {
+ iovecs = iovecs[:0]
+ for _, chunk := range *v {
+ if len(chunk) == 0 {
+ continue
+ }
+ iovecs = append(iovecs, syscall.Iovec{Base: &chunk[0]})
+ if fd.isStream && len(chunk) > 1<<30 {
+ iovecs[len(iovecs)-1].SetLen(1 << 30)
+ break // continue chunk on next writev
+ }
+ iovecs[len(iovecs)-1].SetLen(len(chunk))
+ if len(iovecs) == maxVec {
+ break
+ }
+ }
+ if len(iovecs) == 0 {
+ break
+ }
+ fd.iovecs = &iovecs // cache
+
+ wrote, _, e0 := syscall.Syscall(syscall.SYS_WRITEV,
+ uintptr(fd.sysfd),
+ uintptr(unsafe.Pointer(&iovecs[0])),
+ uintptr(len(iovecs)))
+ if wrote == ^uintptr(0) {
+ wrote = 0
+ }
+ testHookDidWritev(int(wrote))
+ n += int64(wrote)
+ v.consume(int64(wrote))
+ if e0 == syscall.EAGAIN {
+ if err = fd.pd.waitWrite(); err == nil {
+ continue
+ }
+ } else if e0 != 0 {
+ err = syscall.Errno(e0)
+ }
+ if err != nil {
+ break
+ }
+ if n == 0 {
+ err = io.ErrUnexpectedEOF
+ break
+ }
+ }
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("writev", err)
+ }
+ return n, err
+}
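
The core of writeBuffers is the batching policy hinted at by the TODO comment: at most 1024 iovecs per writev call, empty chunks skipped, and, on stream sockets, any single chunk clamped to 1<<30 bytes. A simplified, syscall-free sketch of just the grouping step (illustrative only; it ignores partial writes and the per-chunk clamp):

    // Illustrative sketch only, not part of the patch.
    package main

    import "fmt"

    // batch groups non-empty chunks into slices of at most maxVec entries,
    // mirroring how writeBuffers caps each writev call at 1024 iovecs.
    func batch(chunks [][]byte, maxVec int) [][][]byte {
    	var batches [][][]byte
    	cur := make([][]byte, 0, maxVec)
    	for _, c := range chunks {
    		if len(c) == 0 {
    			continue // empty chunks never become iovecs
    		}
    		cur = append(cur, c)
    		if len(cur) == maxVec {
    			batches = append(batches, cur)
    			cur = make([][]byte, 0, maxVec)
    		}
    	}
    	if len(cur) > 0 {
    		batches = append(batches, cur)
    	}
    	return batches
    }

    func main() {
    	chunks := make([][]byte, 2500)
    	for i := range chunks {
    		chunks[i] = []byte{byte(i)}
    	}
    	fmt.Println(len(batch(chunks, 1024))) // 3 batches: 1024 + 1024 + 452 chunks
    }
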