aboutsummaryrefslogtreecommitdiff
path: root/libgo/go/net
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net')
-rw-r--r--libgo/go/net/cgo_unix.go2
-rw-r--r--libgo/go/net/conf.go11
-rw-r--r--libgo/go/net/conf_test.go20
-rw-r--r--libgo/go/net/conn_test.go2
-rw-r--r--libgo/go/net/dial.go176
-rw-r--r--libgo/go/net/dial_test.go88
-rw-r--r--libgo/go/net/dial_unix_test.go5
-rw-r--r--libgo/go/net/dnsclient.go70
-rw-r--r--libgo/go/net/dnsclient_test.go48
-rw-r--r--libgo/go/net/dnsclient_unix.go426
-rw-r--r--libgo/go/net/dnsclient_unix_test.go680
-rw-r--r--libgo/go/net/dnsconfig_unix.go4
-rw-r--r--libgo/go/net/dnsmsg.go884
-rw-r--r--libgo/go/net/dnsmsg_test.go481
-rw-r--r--libgo/go/net/dnsname_test.go2
-rw-r--r--libgo/go/net/error_nacl.go9
-rw-r--r--libgo/go/net/error_plan9.go9
-rw-r--r--libgo/go/net/error_posix.go2
-rw-r--r--libgo/go/net/error_test.go2
-rw-r--r--libgo/go/net/error_unix.go16
-rw-r--r--libgo/go/net/error_windows.go14
-rw-r--r--libgo/go/net/external_test.go2
-rw-r--r--libgo/go/net/fd_plan9.go13
-rw-r--r--libgo/go/net/fd_unix.go84
-rw-r--r--libgo/go/net/fd_windows.go15
-rw-r--r--libgo/go/net/file.go2
-rw-r--r--libgo/go/net/file_stub.go2
-rw-r--r--libgo/go/net/file_test.go56
-rw-r--r--libgo/go/net/file_unix.go5
-rw-r--r--libgo/go/net/hook_unix.go2
-rw-r--r--libgo/go/net/hosts.go2
-rw-r--r--libgo/go/net/http/cgi/child.go2
-rw-r--r--libgo/go/net/http/cgi/host_test.go3
-rw-r--r--libgo/go/net/http/client.go73
-rw-r--r--libgo/go/net/http/client_test.go34
-rw-r--r--libgo/go/net/http/clientserver_test.go1
-rw-r--r--libgo/go/net/http/cookie.go50
-rw-r--r--libgo/go/net/http/cookie_test.go39
-rw-r--r--libgo/go/net/http/cookiejar/jar.go10
-rw-r--r--libgo/go/net/http/example_test.go22
-rw-r--r--libgo/go/net/http/export_test.go43
-rw-r--r--libgo/go/net/http/fcgi/fcgi.go5
-rw-r--r--libgo/go/net/http/fs.go40
-rw-r--r--libgo/go/net/http/fs_test.go2
-rw-r--r--libgo/go/net/http/h2_bundle.go476
-rw-r--r--libgo/go/net/http/header.go19
-rw-r--r--libgo/go/net/http/http.go4
-rw-r--r--libgo/go/net/http/httptest/httptest_test.go75
-rw-r--r--libgo/go/net/http/httptest/recorder.go20
-rw-r--r--libgo/go/net/http/httptest/recorder_test.go46
-rw-r--r--libgo/go/net/http/httptest/server.go4
-rw-r--r--libgo/go/net/http/httptrace/trace.go14
-rw-r--r--libgo/go/net/http/httputil/httputil.go4
-rw-r--r--libgo/go/net/http/httputil/reverseproxy.go110
-rw-r--r--libgo/go/net/http/httputil/reverseproxy_test.go149
-rw-r--r--libgo/go/net/http/internal/chunked.go4
-rw-r--r--libgo/go/net/http/main_test.go10
-rw-r--r--libgo/go/net/http/pprof/pprof.go75
-rw-r--r--libgo/go/net/http/pprof/pprof_test.go12
-rw-r--r--libgo/go/net/http/proxy_test.go39
-rw-r--r--libgo/go/net/http/readrequest_test.go2
-rw-r--r--libgo/go/net/http/request.go79
-rw-r--r--libgo/go/net/http/request_test.go19
-rw-r--r--libgo/go/net/http/response.go8
-rw-r--r--libgo/go/net/http/response_test.go4
-rw-r--r--libgo/go/net/http/roundtrip.go18
-rw-r--r--libgo/go/net/http/roundtrip_js.go293
-rw-r--r--libgo/go/net/http/serve_test.go328
-rw-r--r--libgo/go/net/http/server.go401
-rw-r--r--libgo/go/net/http/sniff.go13
-rw-r--r--libgo/go/net/http/sniff_test.go12
-rw-r--r--libgo/go/net/http/socks_bundle.go472
-rw-r--r--libgo/go/net/http/status.go4
-rw-r--r--libgo/go/net/http/transfer.go61
-rw-r--r--libgo/go/net/http/transfer_test.go28
-rw-r--r--libgo/go/net/http/transport.go448
-rw-r--r--libgo/go/net/http/transport_test.go563
-rw-r--r--libgo/go/net/http/triv.go4
-rw-r--r--libgo/go/net/interface.go2
-rw-r--r--libgo/go/net/interface_stub.go2
-rw-r--r--libgo/go/net/interface_test.go4
-rw-r--r--libgo/go/net/interface_windows.go96
-rw-r--r--libgo/go/net/interface_windows_test.go132
-rw-r--r--libgo/go/net/internal/socktest/main_test.go2
-rw-r--r--libgo/go/net/internal/socktest/main_unix_test.go2
-rw-r--r--libgo/go/net/internal/socktest/switch_unix.go2
-rw-r--r--libgo/go/net/internal/socktest/sys_cloexec.go2
-rw-r--r--libgo/go/net/internal/socktest/sys_unix.go2
-rw-r--r--libgo/go/net/ip.go100
-rw-r--r--libgo/go/net/ip_test.go18
-rw-r--r--libgo/go/net/iprawsock.go16
-rw-r--r--libgo/go/net/iprawsock_plan9.go4
-rw-r--r--libgo/go/net/iprawsock_posix.go21
-rw-r--r--libgo/go/net/iprawsock_test.go2
-rw-r--r--libgo/go/net/ipsock.go28
-rw-r--r--libgo/go/net/ipsock_posix.go6
-rw-r--r--libgo/go/net/listen_test.go56
-rw-r--r--libgo/go/net/lookup.go37
-rw-r--r--libgo/go/net/lookup_fake.go (renamed from libgo/go/net/lookup_nacl.go)8
-rw-r--r--libgo/go/net/lookup_plan9.go76
-rw-r--r--libgo/go/net/lookup_test.go288
-rw-r--r--libgo/go/net/lookup_unix.go215
-rw-r--r--libgo/go/net/lookup_windows.go33
-rw-r--r--libgo/go/net/mail/message.go3
-rw-r--r--libgo/go/net/main_cloexec_test.go2
-rw-r--r--libgo/go/net/main_conf_test.go2
-rw-r--r--libgo/go/net/main_noconf_test.go2
-rw-r--r--libgo/go/net/main_posix_test.go2
-rw-r--r--libgo/go/net/main_test.go2
-rw-r--r--libgo/go/net/mockserver_test.go2
-rw-r--r--libgo/go/net/net.go37
-rw-r--r--libgo/go/net/net_fake.go284
-rw-r--r--libgo/go/net/net_test.go33
-rw-r--r--libgo/go/net/packetconn_test.go2
-rw-r--r--libgo/go/net/port_unix.go2
-rw-r--r--libgo/go/net/protoconn_test.go17
-rw-r--r--libgo/go/net/rawconn.go11
-rw-r--r--libgo/go/net/rawconn_stub_test.go28
-rw-r--r--libgo/go/net/rawconn_test.go220
-rw-r--r--libgo/go/net/rawconn_unix_test.go193
-rw-r--r--libgo/go/net/rawconn_windows_test.go159
-rw-r--r--libgo/go/net/rpc/client.go9
-rw-r--r--libgo/go/net/rpc/server.go5
-rw-r--r--libgo/go/net/sendfile_solaris.go63
-rw-r--r--libgo/go/net/sendfile_stub.go2
-rw-r--r--libgo/go/net/sendfile_test.go122
-rw-r--r--libgo/go/net/sendfile_unix_alt.go (renamed from libgo/go/net/sendfile_bsd.go)12
-rw-r--r--libgo/go/net/server_test.go2
-rw-r--r--libgo/go/net/smtp/smtp.go9
-rw-r--r--libgo/go/net/smtp/smtp_test.go54
-rw-r--r--libgo/go/net/sock_cloexec.go2
-rw-r--r--libgo/go/net/sock_posix.go117
-rw-r--r--libgo/go/net/sock_stub.go2
-rw-r--r--libgo/go/net/sockaddr_posix.go34
-rw-r--r--libgo/go/net/sockopt_stub.go2
-rw-r--r--libgo/go/net/sockoptip_stub.go2
-rw-r--r--libgo/go/net/splice_linux.go35
-rw-r--r--libgo/go/net/splice_stub.go13
-rw-r--r--libgo/go/net/splice_test.go489
-rw-r--r--libgo/go/net/sys_cloexec.go2
-rw-r--r--libgo/go/net/tcpsock.go14
-rw-r--r--libgo/go/net/tcpsock_plan9.go18
-rw-r--r--libgo/go/net/tcpsock_posix.go23
-rw-r--r--libgo/go/net/tcpsock_test.go2
-rw-r--r--libgo/go/net/tcpsock_unix_test.go8
-rw-r--r--libgo/go/net/tcpsockopt_darwin.go4
-rw-r--r--libgo/go/net/tcpsockopt_stub.go2
-rw-r--r--libgo/go/net/textproto/reader.go2
-rw-r--r--libgo/go/net/textproto/reader_test.go2
-rw-r--r--libgo/go/net/timeout_test.go4
-rw-r--r--libgo/go/net/udpsock.go12
-rw-r--r--libgo/go/net/udpsock_plan9.go12
-rw-r--r--libgo/go/net/udpsock_posix.go14
-rw-r--r--libgo/go/net/udpsock_test.go61
-rw-r--r--libgo/go/net/unixsock.go16
-rw-r--r--libgo/go/net/unixsock_plan9.go6
-rw-r--r--libgo/go/net/unixsock_posix.go18
-rw-r--r--libgo/go/net/unixsock_test.go2
-rw-r--r--libgo/go/net/url/url.go32
-rw-r--r--libgo/go/net/url/url_test.go5
-rw-r--r--libgo/go/net/writev_test.go2
161 files changed, 6938 insertions, 3728 deletions
diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go
index 5ea13bc..a4be3ba 100644
--- a/libgo/go/net/cgo_unix.go
+++ b/libgo/go/net/cgo_unix.go
@@ -278,7 +278,7 @@ func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error,
var zone string
ip := parseIPv4(addr)
if ip == nil {
- ip, zone = parseIPv6(addr, true)
+ ip, zone = parseIPv6Zone(addr)
}
if ip == nil {
return nil, &DNSError{Err: "invalid address", Name: addr}, true
diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go
index a798699..6cc4a99 100644
--- a/libgo/go/net/conf.go
+++ b/libgo/go/net/conf.go
@@ -114,18 +114,19 @@ func initConfVal() {
// canUseCgo reports whether calling cgo functions is allowed
// for non-hostname lookups.
func (c *conf) canUseCgo() bool {
- return c.hostLookupOrder("") == hostLookupCgo
+ return c.hostLookupOrder(nil, "") == hostLookupCgo
}
// hostLookupOrder determines which strategy to use to resolve hostname.
-func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) {
+// The provided Resolver is optional. nil means to not consider its options.
+func (c *conf) hostLookupOrder(r *Resolver, hostname string) (ret hostLookupOrder) {
if c.dnsDebugLevel > 1 {
defer func() {
print("go package net: hostLookupOrder(", hostname, ") = ", ret.String(), "\n")
}()
}
fallbackOrder := hostLookupCgo
- if c.netGo {
+ if c.netGo || r.preferGo() {
fallbackOrder = hostLookupFilesDNS
}
if c.forceCgoLookupHost || c.resolv.unknownOpt || c.goos == "android" {
@@ -148,7 +149,7 @@ func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) {
}
lookup := c.resolv.lookup
if len(lookup) == 0 {
- // http://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
+ // https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
// "If the lookup keyword is not used in the
// system's resolv.conf file then the assumed
// order is 'bind file'"
@@ -202,7 +203,7 @@ func (c *conf) hostLookupOrder(hostname string) (ret hostLookupOrder) {
}
if c.goos == "linux" {
// glibc says the default is "dns [!UNAVAIL=return] files"
- // http://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html.
+ // https://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html.
return hostLookupDNSFiles
}
return hostLookupFilesDNS
diff --git a/libgo/go/net/conf_test.go b/libgo/go/net/conf_test.go
index 17d03f4..3c7403e 100644
--- a/libgo/go/net/conf_test.go
+++ b/libgo/go/net/conf_test.go
@@ -33,6 +33,7 @@ func TestConfHostLookupOrder(t *testing.T) {
tests := []struct {
name string
c *conf
+ resolver *Resolver
hostTests []nssHostTest
}{
{
@@ -170,7 +171,7 @@ func TestConfHostLookupOrder(t *testing.T) {
hostTests: []nssHostTest{{"google.com", "myhostname", hostLookupDNSFiles}},
},
// glibc lacking an nsswitch.conf, per
- // http://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html
+ // https://www.gnu.org/software/libc/manual/html_node/Notes-on-NSS-Configuration-File.html
{
name: "linux_no_nsswitch.conf",
c: &conf{
@@ -322,6 +323,21 @@ func TestConfHostLookupOrder(t *testing.T) {
{"x.com", "myhostname", hostLookupCgo},
},
},
+ // Issue 24393: make sure "Resolver.PreferGo = true" acts like netgo.
+ {
+ name: "resolver-prefergo",
+ resolver: &Resolver{PreferGo: true},
+ c: &conf{
+ goos: "darwin",
+ forceCgoLookupHost: true, // always true for darwin
+ resolv: defaultResolvConf,
+ nss: nssStr(""),
+ netCgo: true,
+ },
+ hostTests: []nssHostTest{
+ {"localhost", "myhostname", hostLookupFilesDNS},
+ },
+ },
}
origGetHostname := getHostname
@@ -331,7 +347,7 @@ func TestConfHostLookupOrder(t *testing.T) {
for _, ht := range tt.hostTests {
getHostname = func() (string, error) { return ht.localhost, nil }
- gotOrder := tt.c.hostLookupOrder(ht.host)
+ gotOrder := tt.c.hostLookupOrder(tt.resolver, ht.host)
if gotOrder != ht.want {
t.Errorf("%s: hostLookupOrder(%q) = %v; want %v", tt.name, ht.host, gotOrder, ht.want)
}
diff --git a/libgo/go/net/conn_test.go b/libgo/go/net/conn_test.go
index 16cf69e..6854898 100644
--- a/libgo/go/net/conn_test.go
+++ b/libgo/go/net/conn_test.go
@@ -5,6 +5,8 @@
// This file implements API tests across platforms and will never have a build
// tag.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go
index f8b4aa2..b1a5ca7 100644
--- a/libgo/go/net/dial.go
+++ b/libgo/go/net/dial.go
@@ -8,6 +8,7 @@ import (
"context"
"internal/nettrace"
"internal/poll"
+ "syscall"
"time"
)
@@ -70,6 +71,14 @@ type Dialer struct {
//
// Deprecated: Use DialContext instead.
Cancel <-chan struct{}
+
+ // If Control is not nil, it is called after creating the network
+ // connection but before actually dialing.
+ //
+ // Network and address parameters passed to Control method are not
+ // necessarily the ones passed to Dial. For example, passing "tcp" to Dial
+ // will cause the Control function to be called with "tcp4" or "tcp6".
+ Control func(network, address string, c syscall.RawConn) error
}
func minNonzeroTime(a, b time.Time) time.Time {
@@ -306,8 +315,8 @@ func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
return d.Dial(network, address)
}
-// dialParam contains a Dial's parameters and configuration.
-type dialParam struct {
+// sysDialer contains a Dial's parameters and configuration.
+type sysDialer struct {
Dialer
network, address string
}
@@ -377,7 +386,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
}
- dp := &dialParam{
+ sd := &sysDialer{
Dialer: *d,
network: network,
address: address,
@@ -392,9 +401,9 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
var c Conn
if len(fallbacks) > 0 {
- c, err = dialParallel(ctx, dp, primaries, fallbacks)
+ c, err = sd.dialParallel(ctx, primaries, fallbacks)
} else {
- c, err = dialSerial(ctx, dp, primaries)
+ c, err = sd.dialSerial(ctx, primaries)
}
if err != nil {
return nil, err
@@ -412,9 +421,9 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
// head start. It returns the first established connection and
// closes the others. Otherwise it returns an error from the first
// primary address.
-func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
+func (sd *sysDialer) dialParallel(ctx context.Context, primaries, fallbacks addrList) (Conn, error) {
if len(fallbacks) == 0 {
- return dialSerial(ctx, dp, primaries)
+ return sd.dialSerial(ctx, primaries)
}
returned := make(chan struct{})
@@ -433,7 +442,7 @@ func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrL
if !primary {
ras = fallbacks
}
- c, err := dialSerial(ctx, dp, ras)
+ c, err := sd.dialSerial(ctx, ras)
select {
case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
case <-returned:
@@ -451,7 +460,7 @@ func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrL
go startRacer(primaryCtx, true)
// Start the timer for the fallback racer.
- fallbackTimer := time.NewTimer(dp.fallbackDelay())
+ fallbackTimer := time.NewTimer(sd.fallbackDelay())
defer fallbackTimer.Stop()
for {
@@ -486,13 +495,13 @@ func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrL
// dialSerial connects to a list of addresses in sequence, returning
// either the first successful connection, or the first error.
-func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
+func (sd *sysDialer) dialSerial(ctx context.Context, ras addrList) (Conn, error) {
var firstErr error // The error from the first address is most relevant.
for i, ra := range ras {
select {
case <-ctx.Done():
- return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
+ return nil, &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
default:
}
@@ -501,7 +510,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
if err != nil {
// Ran out of time.
if firstErr == nil {
- firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
+ firstErr = &OpError{Op: "dial", Net: sd.network, Source: sd.LocalAddr, Addr: ra, Err: err}
}
break
}
@@ -512,7 +521,7 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
defer cancel()
}
- c, err := dialSingle(dialCtx, dp, ra)
+ c, err := sd.dialSingle(dialCtx, ra)
if err == nil {
return c, nil
}
@@ -522,47 +531,126 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
}
if firstErr == nil {
- firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
+ firstErr = &OpError{Op: "dial", Net: sd.network, Source: nil, Addr: nil, Err: errMissingAddress}
}
return nil, firstErr
}
// dialSingle attempts to establish and returns a single connection to
// the destination address.
-func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
+func (sd *sysDialer) dialSingle(ctx context.Context, ra Addr) (c Conn, err error) {
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil {
raStr := ra.String()
if trace.ConnectStart != nil {
- trace.ConnectStart(dp.network, raStr)
+ trace.ConnectStart(sd.network, raStr)
}
if trace.ConnectDone != nil {
- defer func() { trace.ConnectDone(dp.network, raStr, err) }()
+ defer func() { trace.ConnectDone(sd.network, raStr, err) }()
}
}
- la := dp.LocalAddr
+ la := sd.LocalAddr
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
- c, err = dialTCP(ctx, dp.network, la, ra)
+ c, err = sd.dialTCP(ctx, la, ra)
case *UDPAddr:
la, _ := la.(*UDPAddr)
- c, err = dialUDP(ctx, dp.network, la, ra)
+ c, err = sd.dialUDP(ctx, la, ra)
case *IPAddr:
la, _ := la.(*IPAddr)
- c, err = dialIP(ctx, dp.network, la, ra)
+ c, err = sd.dialIP(ctx, la, ra)
case *UnixAddr:
la, _ := la.(*UnixAddr)
- c, err = dialUnix(ctx, dp.network, la, ra)
+ c, err = sd.dialUnix(ctx, la, ra)
+ default:
+ return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: sd.address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "dial", Net: sd.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
+ }
+ return c, nil
+}
+
+// ListenConfig contains options for listening to an address.
+type ListenConfig struct {
+ // If Control is not nil, it is called after creating the network
+ // connection but before binding it to the operating system.
+ //
+ // Network and address parameters passed to Control method are not
+ // necessarily the ones passed to Listen. For example, passing "tcp" to
+ // Listen will cause the Control function to be called with "tcp4" or "tcp6".
+ Control func(network, address string, c syscall.RawConn) error
+}
+
+// Listen announces on the local network address.
+//
+// See func Listen for a description of the network and address
+// parameters.
+func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (Listener, error) {
+ addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+ sl := &sysListener{
+ ListenConfig: *lc,
+ network: network,
+ address: address,
+ }
+ var l Listener
+ la := addrs.first(isIPv4)
+ switch la := la.(type) {
+ case *TCPAddr:
+ l, err = sl.listenTCP(ctx, la)
+ case *UnixAddr:
+ l, err = sl.listenUnix(ctx, la)
+ default:
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
+ }
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // l is non-nil interface containing nil pointer
+ }
+ return l, nil
+}
+
+// ListenPacket announces on the local network address.
+//
+// See func ListenPacket for a description of the network and address
+// parameters.
+func (lc *ListenConfig) ListenPacket(ctx context.Context, network, address string) (PacketConn, error) {
+ addrs, err := DefaultResolver.resolveAddrList(ctx, "listen", network, address, nil)
+ if err != nil {
+ return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
+ }
+ sl := &sysListener{
+ ListenConfig: *lc,
+ network: network,
+ address: address,
+ }
+ var c PacketConn
+ la := addrs.first(isIPv4)
+ switch la := la.(type) {
+ case *UDPAddr:
+ c, err = sl.listenUDP(ctx, la)
+ case *IPAddr:
+ c, err = sl.listenIP(ctx, la)
+ case *UnixAddr:
+ c, err = sl.listenUnixgram(ctx, la)
default:
- return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
}
if err != nil {
- return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
+ return nil, &OpError{Op: "listen", Net: sl.network, Source: nil, Addr: la, Err: err} // c is non-nil interface containing nil pointer
}
return c, nil
}
+// sysListener contains a Listen's parameters and configuration.
+type sysListener struct {
+ ListenConfig
+ network, address string
+}
+
// Listen announces on the local network address.
//
// The network must be "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
@@ -582,23 +670,8 @@ func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error)
// See func Dial for a description of the network and address
// parameters.
func Listen(network, address string) (Listener, error) {
- addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
- if err != nil {
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
- }
- var l Listener
- switch la := addrs.first(isIPv4).(type) {
- case *TCPAddr:
- l, err = ListenTCP(network, la)
- case *UnixAddr:
- l, err = ListenUnix(network, la)
- default:
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
- }
- if err != nil {
- return nil, err // l is non-nil interface containing nil pointer
- }
- return l, nil
+ var lc ListenConfig
+ return lc.Listen(context.Background(), network, address)
}
// ListenPacket announces on the local network address.
@@ -624,23 +697,6 @@ func Listen(network, address string) (Listener, error) {
// See func Dial for a description of the network and address
// parameters.
func ListenPacket(network, address string) (PacketConn, error) {
- addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", network, address, nil)
- if err != nil {
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: err}
- }
- var l PacketConn
- switch la := addrs.first(isIPv4).(type) {
- case *UDPAddr:
- l, err = ListenUDP(network, la)
- case *IPAddr:
- l, err = ListenIP(network, la)
- case *UnixAddr:
- l, err = ListenUnixgram(network, la)
- default:
- return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: address}}
- }
- if err != nil {
- return nil, err // l is non-nil interface containing nil pointer
- }
- return l, nil
+ var lc ListenConfig
+ return lc.ListenPacket(context.Background(), network, address)
}
diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go
index b5f1dc9..00a84d1 100644
--- a/libgo/go/net/dial_test.go
+++ b/libgo/go/net/dial_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -142,8 +144,9 @@ const (
// In some environments, the slow IPs may be explicitly unreachable, and fail
// more quickly than expected. This test hook prevents dialTCP from returning
// before the deadline.
-func slowDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
- c, err := doDialTCP(ctx, net, laddr, raddr)
+func slowDialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.doDialTCP(ctx, laddr, raddr)
if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
// Wait for the deadline, or indefinitely if none exists.
<-ctx.Done()
@@ -295,12 +298,12 @@ func TestDialParallel(t *testing.T) {
FallbackDelay: fallbackDelay,
}
startTime := time.Now()
- dp := &dialParam{
+ sd := &sysDialer{
Dialer: d,
network: "tcp",
address: "?",
}
- c, err := dialParallel(context.Background(), dp, primaries, fallbacks)
+ c, err := sd.dialParallel(context.Background(), primaries, fallbacks)
elapsed := time.Since(startTime)
if c != nil {
@@ -331,7 +334,7 @@ func TestDialParallel(t *testing.T) {
wg.Done()
}()
startTime = time.Now()
- c, err = dialParallel(ctx, dp, primaries, fallbacks)
+ c, err = sd.dialParallel(ctx, primaries, fallbacks)
if c != nil {
c.Close()
}
@@ -467,13 +470,14 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
// Now ignore the provided context (which will be canceled) and use a
// different one to make sure this completes with a valid connection,
// which we hope to be closed below:
- return doDialTCP(context.Background(), net, laddr, raddr)
+ sd := &sysDialer{network: net, address: raddr.String()}
+ return sd.doDialTCP(context.Background(), laddr, raddr)
}
d := Dialer{
FallbackDelay: fallbackDelay,
}
- dp := &dialParam{
+ sd := &sysDialer{
Dialer: d,
network: "tcp",
address: "?",
@@ -488,7 +492,7 @@ func TestDialParallelSpuriousConnection(t *testing.T) {
}
// dialParallel returns one connection (and closes the other.)
- c, err := dialParallel(context.Background(), dp, makeAddr("127.0.0.1"), makeAddr("::1"))
+ c, err := sd.dialParallel(context.Background(), makeAddr("127.0.0.1"), makeAddr("::1"))
if err != nil {
t.Fatal(err)
}
@@ -749,9 +753,8 @@ func TestDialCancel(t *testing.T) {
switch testenv.Builder() {
case "linux-arm64-buildlet":
t.Skip("skipping on linux-arm64-buildlet; incompatible network config? issue 15191")
- case "":
- testenv.MustHaveExternalNetwork(t)
}
+ mustHaveExternalNetwork(t)
if runtime.GOOS == "nacl" {
// nacl doesn't have external network access.
@@ -897,9 +900,7 @@ func TestCancelAfterDial(t *testing.T) {
// if the machine has halfway configured IPv6 such that it can bind on
// "::" not connect back to that same address.
func TestDialListenerAddr(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
ln, err := Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
@@ -912,3 +913,64 @@ func TestDialListenerAddr(t *testing.T) {
}
c.Close()
}
+
+func TestDialerControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("StreamDial", func(t *testing.T) {
+ for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln, err := newLocalListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ defer ln.Close()
+ d := Dialer{Control: controlOnConnSetup}
+ c, err := d.Dial(network, ln.Addr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c.Close()
+ }
+ })
+ t.Run("PacketDial", func(t *testing.T) {
+ for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c1, err := newLocalPacketListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ if network == "unixgram" {
+ defer os.Remove(c1.LocalAddr().String())
+ }
+ defer c1.Close()
+ d := Dialer{Control: controlOnConnSetup}
+ c2, err := d.Dial(network, c1.LocalAddr().String())
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c2.Close()
+ }
+ })
+}
+
+// mustHaveExternalNetwork is like testenv.MustHaveExternalNetwork
+// except that it won't skip testing on non-iOS builders.
+func mustHaveExternalNetwork(t *testing.T) {
+ t.Helper()
+ ios := runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64")
+ if testenv.Builder() == "" || ios {
+ testenv.MustHaveExternalNetwork(t)
+ }
+}
diff --git a/libgo/go/net/dial_unix_test.go b/libgo/go/net/dial_unix_test.go
index d5c6dde2..3cfc9d8 100644
--- a/libgo/go/net/dial_unix_test.go
+++ b/libgo/go/net/dial_unix_test.go
@@ -102,7 +102,8 @@ func TestDialContextCancelRace(t *testing.T) {
if !ok || oe.Op != "dial" {
t.Fatalf("Dial error = %#v; want dial *OpError", err)
}
- if oe.Err != ctx.Err() {
- t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, ctx.Err())
+
+ if oe.Err != errCanceled {
+ t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, errCanceled)
}
}
diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go
index 2ab5639..2e4bffa 100644
--- a/libgo/go/net/dnsclient.go
+++ b/libgo/go/net/dnsclient.go
@@ -7,6 +7,8 @@ package net
import (
"math/rand"
"sort"
+
+ "golang_org/x/net/dns/dnsmessage"
)
// reverseaddr returns the in-addr.arpa. or ip6.arpa. hostname of the IP
@@ -35,71 +37,13 @@ func reverseaddr(addr string) (arpa string, err error) {
return string(buf), nil
}
-// Find answer for name in dns message.
-// On return, if err == nil, addrs != nil.
-func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) {
- addrs = make([]dnsRR, 0, len(dns.answer))
-
- if dns.rcode == dnsRcodeNameError {
- return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
- }
- if dns.rcode != dnsRcodeSuccess {
- // None of the error codes make sense
- // for the query we sent. If we didn't get
- // a name error and we didn't get success,
- // the server is behaving incorrectly or
- // having temporary trouble.
- err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
- if dns.rcode == dnsRcodeServerFailure {
- err.IsTemporary = true
- }
- return "", nil, err
- }
-
- // Look for the name.
- // Presotto says it's okay to assume that servers listed in
- // /etc/resolv.conf are recursive resolvers.
- // We asked for recursion, so it should have included
- // all the answers we need in this one packet.
-Cname:
- for cnameloop := 0; cnameloop < 10; cnameloop++ {
- addrs = addrs[0:0]
- for _, rr := range dns.answer {
- if _, justHeader := rr.(*dnsRR_Header); justHeader {
- // Corrupt record: we only have a
- // header. That header might say it's
- // of type qtype, but we don't
- // actually have it. Skip.
- continue
- }
- h := rr.Header()
- if h.Class == dnsClassINET && equalASCIILabel(h.Name, name) {
- switch h.Rrtype {
- case qtype:
- addrs = append(addrs, rr)
- case dnsTypeCNAME:
- // redirect to cname
- name = rr.(*dnsRR_CNAME).Cname
- continue Cname
- }
- }
- }
- if len(addrs) == 0 {
- return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
- }
- return name, addrs, nil
- }
-
- return "", nil, &DNSError{Err: "too many redirects", Name: name, Server: server}
-}
-
-func equalASCIILabel(x, y string) bool {
- if len(x) != len(y) {
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
return false
}
- for i := 0; i < len(x); i++ {
- a := x[i]
- b := y[i]
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
if 'A' <= a && a <= 'Z' {
a += 0x20
}
diff --git a/libgo/go/net/dnsclient_test.go b/libgo/go/net/dnsclient_test.go
index 7308fb0..3ab2b83 100644
--- a/libgo/go/net/dnsclient_test.go
+++ b/libgo/go/net/dnsclient_test.go
@@ -67,51 +67,3 @@ func testWeighting(t *testing.T, margin float64) {
func TestWeighting(t *testing.T) {
testWeighting(t, 0.05)
}
-
-// Issue 8434: verify that Temporary returns true on an error when rcode
-// is SERVFAIL
-func TestIssue8434(t *testing.T) {
- msg := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- rcode: dnsRcodeServerFailure,
- },
- }
-
- _, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
- if err == nil {
- t.Fatal("expected an error")
- }
- if ne, ok := err.(Error); !ok {
- t.Fatalf("err = %#v; wanted something supporting net.Error", err)
- } else if !ne.Temporary() {
- t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
- }
- if de, ok := err.(*DNSError); !ok {
- t.Fatalf("err = %#v; wanted a *net.DNSError", err)
- } else if !de.IsTemporary {
- t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
- }
-}
-
-// Issue 12778: verify that NXDOMAIN without RA bit errors as
-// "no such host" and not "server misbehaving"
-func TestIssue12778(t *testing.T) {
- msg := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- rcode: dnsRcodeNameError,
- recursion_available: false,
- },
- }
-
- _, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
- if err == nil {
- t.Fatal("expected an error")
- }
- de, ok := err.(*DNSError)
- if !ok {
- t.Fatalf("err = %#v; wanted a *net.DNSError", err)
- }
- if de.Err != errNoSuchHost.Error() {
- t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
- }
-}
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go
index 73a507e..6ec2f44 100644
--- a/libgo/go/net/dnsclient_unix.go
+++ b/libgo/go/net/dnsclient_unix.go
@@ -23,142 +23,225 @@ import (
"os"
"sync"
"time"
-)
-
-// A dnsConn represents a DNS transport endpoint.
-type dnsConn interface {
- io.Closer
- SetDeadline(time.Time) error
+ "golang_org/x/net/dns/dnsmessage"
+)
- // dnsRoundTrip executes a single DNS transaction, returning a
- // DNS response message for the provided DNS query message.
- dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
+func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b.EnableCompression()
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+ tcpReq, err = b.Finish()
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, err
}
-// dnsPacketConn implements the dnsConn interface for RFC 1035's
-// "UDP usage" transport mechanism. Conn is a packet-oriented connection,
-// such as a *UDPConn.
-type dnsPacketConn struct {
- Conn
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
}
-func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
- b, ok := query.Pack()
- if !ok {
- return nil, errors.New("cannot marshal DNS message")
- }
+func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 512) // see RFC 1035
for {
n, err := c.Read(b)
if err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- resp := &dnsMsg{}
- if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
- // Ignore invalid responses as they may be malicious
- // forgery attempts. Instead continue waiting until
- // timeout. See golang.org/issue/13281.
+ var p dnsmessage.Parser
+ // Ignore invalid responses as they may be malicious
+ // forgery attempts. Instead continue waiting until
+ // timeout. See golang.org/issue/13281.
+ h, err := p.Start(b[:n])
+ if err != nil {
+ continue
+ }
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
continue
}
- return resp, nil
+ return p, h, nil
}
}
-// dnsStreamConn implements the dnsConn interface for RFC 1035's
-// "TCP usage" transport mechanism. Conn is a stream-oriented connection,
-// such as a *TCPConn.
-type dnsStreamConn struct {
- Conn
-}
-
-func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
- b, ok := query.Pack()
- if !ok {
- return nil, errors.New("cannot marshal DNS message")
- }
- l := len(b)
- b = append([]byte{byte(l >> 8), byte(l)}, b...)
+func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
if _, err := c.Write(b); err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
if _, err := io.ReadFull(c, b[:2]); err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- l = int(b[0])<<8 | int(b[1])
+ l := int(b[0])<<8 | int(b[1])
if l > len(b) {
b = make([]byte, l)
}
n, err := io.ReadFull(c, b[:l])
if err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- resp := &dnsMsg{}
- if !resp.Unpack(b[:n]) {
- return nil, errors.New("cannot unmarshal DNS message")
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
}
- if !resp.IsResponseTo(query) {
- return nil, errors.New("invalid DNS response")
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message")
+ }
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
}
- return resp, nil
+ return p, h, nil
}
// exchange sends a query on the connection and hopes for a response.
-func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
- out := dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- recursion_desired: true,
- },
- question: []dnsQuestion{
- {name, qtype, dnsClassINET},
- },
+func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot marshal DNS message")
}
for _, network := range []string{"udp", "tcp"} {
- // TODO(mdempsky): Refactor so defers from UDP-based
- // exchanges happen before TCP-based exchange.
-
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
defer cancel()
c, err := r.dial(ctx, network, server)
if err != nil {
- return nil, err
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
}
- defer c.Close()
if d, ok := ctx.Deadline(); ok && !d.IsZero() {
c.SetDeadline(d)
}
- out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
- in, err := c.dnsRoundTrip(&out)
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if _, ok := c.(PacketConn); ok {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ } else {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ }
+ c.Close()
if err != nil {
- return nil, mapErr(err)
+ return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
}
- if in.truncated { // see RFC 5966
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response")
+ }
+ if h.Truncated { // see RFC 5966
continue
}
- return in, nil
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("no answer from DNS server")
+}
+
+// checkHeader performs basic sanity checks on the header.
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error {
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+
+ // libresolv continues to the next server when it receives
+ // an invalid referral response. See golang.org/issue/15434.
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return &DNSError{Err: "lame referral", Name: name, Server: server}
+ }
+
+ if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ // None of the error codes make sense
+ // for the query we sent. If we didn't get
+ // a name error and we didn't get success,
+ // the server is behaving incorrectly or
+ // having temporary trouble.
+ err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ err.IsTemporary = true
+ }
+ return err
+ }
+
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return &DNSError{
+ Err: errNoSuchHost.Error(),
+ Name: name,
+ Server: server,
+ }
+ }
+ if err != nil {
+ return &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
}
- return nil, errors.New("no answer from DNS server")
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
-func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
+func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
var lastErr error
serverOffset := cfg.serverOffset()
sLen := uint32(len(cfg.servers))
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errors.New("cannot marshal DNS message")
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
for i := 0; i < cfg.attempts; i++ {
for j := uint32(0); j < sLen; j++ {
server := cfg.servers[(serverOffset+j)%sLen]
- msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout)
+ p, h, err := r.exchange(ctx, server, q, cfg.timeout)
if err != nil {
lastErr = &DNSError{
Err: err.Error(),
@@ -175,41 +258,26 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
}
continue
}
- // libresolv continues to the next server when it receives
- // an invalid referral response. See golang.org/issue/15434.
- if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 {
- lastErr = &DNSError{Err: "lame referral", Name: name, Server: server}
+
+ // The name does not exist, so trying another server won't help.
+ //
+ // TODO: indicate this in a more obvious way, such as a field on DNSError?
+ if h.RCode == dnsmessage.RCodeNameError {
+ return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
+ }
+
+ lastErr = checkHeader(&p, h, name, server)
+ if lastErr != nil {
continue
}
- cname, rrs, err := answer(name, server, msg, qtype)
- // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
- // it means the response in msg was not useful and trying another
- // server probably won't help. Return now in those cases.
- // TODO: indicate this in a more obvious way, such as a field on DNSError?
- if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
- return cname, rrs, err
+
+ lastErr = skipToAnswer(&p, qtype, name, server)
+ if lastErr == nil {
+ return p, server, nil
}
- lastErr = err
}
}
- return "", nil, lastErr
-}
-
-// addrRecordList converts and returns a list of IP addresses from DNS
-// address records (both A and AAAA). Other record types are ignored.
-func addrRecordList(rrs []dnsRR) []IPAddr {
- addrs := make([]IPAddr, 0, 4)
- for _, rr := range rrs {
- switch rr := rr.(type) {
- case *dnsRR_A:
- addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
- case *dnsRR_AAAA:
- ip := make(IP, IPv6len)
- copy(ip, rr.AAAA[:])
- addrs = append(addrs, IPAddr{IP: ip})
- }
- }
- return addrs
+ return dnsmessage.Parser{}, "", lastErr
}
// A resolverConfig represents a DNS stub resolver configuration.
@@ -287,37 +355,45 @@ func (conf *resolverConfig) releaseSema() {
<-conf.ch
}
-func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
+func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
if !isDomainName(name) {
// We used to use "invalid domain name" as the error,
// but that is a detail of the specific lookup mechanism.
// Other lookups might allow broader name syntax
// (for example Multicast DNS allows UTF-8; see RFC 6762).
// For consistency with libc resolvers, report no such host.
- return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name}
+ return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
+ var (
+ p dnsmessage.Parser
+ server string
+ err error
+ )
for _, fqdn := range conf.nameList(name) {
- cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype)
+ p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
if err == nil {
break
}
- if nerr, ok := err.(Error); ok && nerr.Temporary() && r.StrictErrors {
+ if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
// If we hit a temporary error with StrictErrors enabled,
// stop immediately instead of trying more names.
break
}
}
+ if err == nil {
+ return p, server, nil
+ }
if err, ok := err.(*DNSError); ok {
// Show original name passed to lookup, not suffixed one.
// In general we might have tried many suffixes; showing
// just one is misleading. See also golang.org/issue/6324.
err.Name = name
}
- return
+ return dnsmessage.Parser{}, "", err
}
// avoidDNS reports whether this is a hostname for which we should not
@@ -449,48 +525,48 @@ func goLookupIPFiles(name string) (addrs []IPAddr) {
// goLookupIP is the native Go implementation of LookupIP.
// The libc versions are in cgo_*.go.
func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
- order := systemConf().hostLookupOrder(host)
+ order := systemConf().hostLookupOrder(r, host)
addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order)
return
}
-func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) {
+func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
if order == hostLookupFilesDNS || order == hostLookupFiles {
addrs = goLookupIPFiles(name)
if len(addrs) > 0 || order == hostLookupFiles {
- return addrs, name, nil
+ return addrs, dnsmessage.Name{}, nil
}
}
if !isDomainName(name) {
// See comment in func lookup above about use of errNoSuchHost.
- return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
+ return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name}
}
resolvConf.tryUpdate("/etc/resolv.conf")
resolvConf.mu.RLock()
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
type racer struct {
- cname string
- rrs []dnsRR
+ p dnsmessage.Parser
+ server string
error
}
lane := make(chan racer, 1)
- qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
+ qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
var lastErr error
for _, fqdn := range conf.nameList(name) {
for _, qtype := range qtypes {
dnsWaitGroup.Add(1)
- go func(qtype uint16) {
- defer dnsWaitGroup.Done()
- cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype)
- lane <- racer{cname, rrs, err}
+ go func(qtype dnsmessage.Type) {
+ p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
+ lane <- racer{p, server, err}
+ dnsWaitGroup.Done()
}(qtype)
}
hitStrictError := false
for range qtypes {
racer := <-lane
if racer.error != nil {
- if nerr, ok := racer.error.(Error); ok && nerr.Temporary() && r.StrictErrors {
+ if nerr, ok := racer.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
// This error will abort the nameList loop.
hitStrictError = true
lastErr = racer.error
@@ -500,9 +576,74 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
}
continue
}
- addrs = append(addrs, addrRecordList(racer.rrs)...)
- if cname == "" {
- cname = racer.cname
+
+ // Presotto says it's okay to assume that servers listed in
+ // /etc/resolv.conf are recursive resolvers.
+ //
+ // We asked for recursion, so it should have included all the
+ // answers we need in this one packet.
+ //
+ // Further, RFC 1035 section 4.3.1 says that "the recursive
+ // response to a query will be... The answer to the query,
+ // possibly preface by one or more CNAME RRs that specify
+ // aliases encountered on the way to an answer."
+ //
+ // Therefore, we should be able to assume that we can ignore
+ // CNAMEs and that the A and AAAA records we requested are
+ // for the canonical name.
+
+ loop:
+ for {
+ h, err := racer.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := racer.p.AResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ break loop
+ }
+ addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := racer.p.AAAAResource()
+ if err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ break loop
+ }
+ addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
+
+ default:
+ if err := racer.p.SkipAnswer(); err != nil {
+ lastErr = &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: name,
+ Server: racer.server,
+ }
+ break loop
+ }
+ continue
+ }
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
}
}
if hitStrictError {
@@ -528,17 +669,17 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order
addrs = goLookupIPFiles(name)
}
if len(addrs) == 0 && lastErr != nil {
- return nil, "", lastErr
+ return nil, dnsmessage.Name{}, lastErr
}
}
return addrs, cname, nil
}
// goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
-func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (cname string, err error) {
- order := systemConf().hostLookupOrder(host)
- _, cname, err = r.goLookupIPCNAMEOrder(ctx, host, order)
- return
+func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) {
+ order := systemConf().hostLookupOrder(r, host)
+ _, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order)
+ return cname.String(), err
}
// goLookupPTR is the native Go implementation of LookupAddr.
@@ -555,13 +696,36 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, erro
if err != nil {
return nil, err
}
- _, rrs, err := r.lookup(ctx, arpa, dnsTypePTR)
+ p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR)
if err != nil {
return nil, err
}
- ptrs := make([]string, len(rrs))
- for i, rr := range rrs {
- ptrs[i] = rr.(*dnsRR_PTR).Ptr
+ var ptrs []string
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypePTR {
+ continue
+ }
+ ptr, err := p.PTRResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot marshal DNS message",
+ Name: addr,
+ Server: server,
+ }
+ }
+ ptrs = append(ptrs, ptr.PTR.String())
+
}
return ptrs, nil
}
diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go
index 9e4015f..f1bb09d 100644
--- a/libgo/go/net/dnsclient_unix_test.go
+++ b/libgo/go/net/dnsclient_unix_test.go
@@ -19,42 +19,59 @@ import (
"sync"
"testing"
"time"
+
+ "golang_org/x/net/dns/dnsmessage"
)
var goResolver = Resolver{PreferGo: true}
// Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
-const TestAddr uint32 = 0xc0000201
+var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
// Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation.
var VarTestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+func mustNewName(name string) dnsmessage.Name {
+ nn, err := dnsmessage.NewName(name)
+ if err != nil {
+ panic(fmt.Sprint("creating name: ", err))
+ }
+ return nn
+}
+
+func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question {
+ return dnsmessage.Question{
+ Name: mustNewName(name),
+ Type: qtype,
+ Class: class,
+ }
+}
+
var dnsTransportFallbackTests = []struct {
- server string
- name string
- qtype uint16
- timeout int
- rcode int
+ server string
+ question dnsmessage.Question
+ timeout int
+ rcode dnsmessage.RCode
}{
// Querying "com." with qtype=255 usually makes an answer
// which requires more than 512 bytes.
- {"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess},
- {"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess},
+ {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess},
+ {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess},
}
func TestDNSTransportFallback(t *testing.T) {
fake := fakeDNSServer{
- rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- rcode: dnsRcodeSuccess,
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
},
- question: q.question,
+ Questions: q.Questions,
}
if n == "udp" {
- r.truncated = true
+ r.Header.Truncated = true
}
return r, nil
},
@@ -63,15 +80,13 @@ func TestDNSTransportFallback(t *testing.T) {
for _, tt := range dnsTransportFallbackTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second)
+ _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second)
if err != nil {
t.Error(err)
continue
}
- switch msg.rcode {
- case tt.rcode:
- default:
- t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode)
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode)
continue
}
}
@@ -80,39 +95,38 @@ func TestDNSTransportFallback(t *testing.T) {
// See RFC 6761 for further information about the reserved, pseudo
// domain names.
var specialDomainNameTests = []struct {
- name string
- qtype uint16
- rcode int
+ question dnsmessage.Question
+ rcode dnsmessage.RCode
}{
// Name resolution APIs and libraries should not recognize the
// followings as special.
- {"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError},
- {"test.", dnsTypeALL, dnsRcodeNameError},
- {"example.com.", dnsTypeALL, dnsRcodeSuccess},
+ {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess},
// Name resolution APIs and libraries should recognize the
// followings as special and should not send any queries.
// Though, we test those names here for verifying negative
// answers at DNS query-response interaction level.
- {"localhost.", dnsTypeALL, dnsRcodeNameError},
- {"invalid.", dnsTypeALL, dnsRcodeNameError},
+ {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
+ {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError},
}
func TestSpecialDomainName(t *testing.T) {
- fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case "example.com.":
- r.rcode = dnsRcodeSuccess
+ r.Header.RCode = dnsmessage.RCodeSuccess
default:
- r.rcode = dnsRcodeNameError
+ r.Header.RCode = dnsmessage.RCodeNameError
}
return r, nil
@@ -122,15 +136,13 @@ func TestSpecialDomainName(t *testing.T) {
for _, tt := range specialDomainNameTests {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second)
+ _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second)
if err != nil {
t.Error(err)
continue
}
- switch msg.rcode {
- case tt.rcode, dnsRcodeServerFailure:
- default:
- t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode)
+ if h.RCode != tt.rcode {
+ t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode)
continue
}
}
@@ -177,24 +189,26 @@ func TestAvoidDNSName(t *testing.T) {
}
}
-var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
- }
- if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA {
- r.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ Questions: q.Questions,
+ }
+ if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
}
@@ -459,54 +473,57 @@ var goLookupIPWithResolverConfigTests = []struct {
func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait()
-
- fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break
default:
time.Sleep(10 * time.Millisecond)
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
}
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- for _, question := range q.question {
- switch question.Qtype {
- case dnsTypeA:
- switch question.Name {
+ for _, question := range q.Questions {
+ switch question.Type {
+ case dnsmessage.TypeA:
+ switch question.Name.String() {
case "hostname.as112.net.":
break
case "ipv4.google.com.":
- r.answer = append(r.answer, &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
})
default:
}
- case dnsTypeAAAA:
- switch question.Name {
+ case dnsmessage.TypeAAAA:
+ switch question.Name.String() {
case "hostname.as112.net.":
break
case "ipv6.google.com.":
- r.answer = append(r.answer, &dnsRR_AAAA{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeAAAA,
- Class: dnsClassINET,
- Rdlength: 16,
+ r.Answers = append(r.Answers, dnsmessage.Resource{
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: VarTestAddr6,
},
- AAAA: VarTestAddr6,
})
}
}
@@ -554,13 +571,13 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
return r, nil
}}
@@ -624,20 +641,20 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err)
}
- fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case fqdn + ".servfail.":
- r.rcode = dnsRcodeServerFailure
+ r.Header.RCode = dnsmessage.RCodeServerFailure
default:
- r.rcode = dnsRcodeNameError
+ r.Header.RCode = dnsmessage.RCodeNameError
}
return r, nil
@@ -679,28 +696,30 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err)
}
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
if s == "192.0.2.2:53" {
- r.recursion_available = true
- if q.question[0].Qtype == dnsTypeA {
- r.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ r.Header.RecursionAvailable = true
+ if q.Questions[0].Type == dnsmessage.TypeA {
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
}
@@ -727,6 +746,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
func BenchmarkGoLookupIP(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
ctx := context.Background()
+ b.ReportAllocs()
for i := 0; i < b.N; i++ {
goResolver.LookupIPAddr(ctx, "www.example.com")
@@ -736,6 +756,7 @@ func BenchmarkGoLookupIP(b *testing.B) {
func BenchmarkGoLookupIPNoSuchHost(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
ctx := context.Background()
+ b.ReportAllocs()
for i := 0; i < b.N; i++ {
goResolver.LookupIPAddr(ctx, "some.nonexistent")
@@ -759,6 +780,7 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
b.Fatal(err)
}
ctx := context.Background()
+ b.ReportAllocs()
for i := 0; i < b.N; i++ {
goResolver.LookupIPAddr(ctx, "www.example.com")
@@ -766,20 +788,26 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
}
type fakeDNSServer struct {
- rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error)
+ rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+ alwaysTCP bool
}
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
- return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil
+ if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
+ return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
+ }
+ return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
}
type fakeDNSConn struct {
Conn
+ tcp bool
server *fakeDNSServer
n string
s string
- q *dnsMsg
+ q dnsmessage.Message
t time.Time
+ buf []byte
}
func (f *fakeDNSConn) Close() error {
@@ -787,15 +815,32 @@ func (f *fakeDNSConn) Close() error {
}
func (f *fakeDNSConn) Read(b []byte) (int, error) {
+ if len(f.buf) > 0 {
+ n := copy(b, f.buf)
+ f.buf = f.buf[n:]
+ return n, nil
+ }
+
resp, err := f.server.rh(f.n, f.s, f.q, f.t)
if err != nil {
return 0, err
}
- bb, ok := resp.Pack()
- if !ok {
- return 0, errors.New("cannot marshal DNS message")
+ bb := make([]byte, 2, 514)
+ bb, err = resp.AppendPack(bb)
+ if err != nil {
+ return 0, fmt.Errorf("cannot marshal DNS message: %v", err)
+ }
+
+ if f.tcp {
+ l := len(bb) - 2
+ bb[0] = byte(l >> 8)
+ bb[1] = byte(l)
+ f.buf = bb
+ return f.Read(b)
}
+
+ bb = bb[2:]
if len(b) < len(bb) {
return 0, errors.New("read would fragment DNS message")
}
@@ -804,27 +849,34 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) {
return len(bb), nil
}
-func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
- return 0, nil, nil
-}
-
func (f *fakeDNSConn) Write(b []byte) (int, error) {
- f.q = new(dnsMsg)
- if !f.q.Unpack(b) {
- return 0, errors.New("cannot unmarshal DNS message")
+ if f.tcp && len(b) >= 2 {
+ b = b[2:]
+ }
+ if f.q.Unpack(b) != nil {
+ return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b))
}
return len(b), nil
}
-func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
- return 0, nil
-}
-
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
f.t = t
return nil
}
+type fakeDNSPacketConn struct {
+ PacketConn
+ fakeDNSConn
+}
+
+func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
+ return f.fakeDNSConn.SetDeadline(t)
+}
+
+func (f *fakeDNSPacketConn) Close() error {
+ return f.fakeDNSConn.Close()
+}
+
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) {
c, s := Pipe()
@@ -836,64 +888,75 @@ func TestIgnoreDNSForgeries(t *testing.T) {
return
}
- msg := &dnsMsg{}
- if !msg.Unpack(b[:n]) {
- t.Error("invalid DNS query")
+ var msg dnsmessage.Message
+ if msg.Unpack(b[:n]) != nil {
+ t.Error("invalid DNS query:", err)
return
}
s.Write([]byte("garbage DNS response packet"))
- msg.response = true
- msg.id++ // make invalid ID
- b, ok := msg.Pack()
- if !ok {
- t.Error("failed to pack DNS response")
+ msg.Header.Response = true
+ msg.Header.ID++ // make invalid ID
+
+ if b, err = msg.Pack(); err != nil {
+ t.Error("failed to pack DNS response:", err)
return
}
s.Write(b)
- msg.id-- // restore original ID
- msg.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: "www.example.com.",
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ msg.Header.ID-- // restore original ID
+ msg.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
- b, ok = msg.Pack()
- if !ok {
- t.Error("failed to pack DNS response")
+ b, err = msg.Pack()
+ if err != nil {
+ t.Error("failed to pack DNS response:", err)
return
}
s.Write(b)
}()
- msg := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: 42,
},
- question: []dnsQuestion{
+ Questions: []dnsmessage.Question{
{
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
+ Name: mustNewName("www.example.com."),
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
},
},
}
- dc := &dnsPacketConn{c}
- resp, err := dc.dnsRoundTrip(msg)
+ b, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Pack failed:", err)
+ }
+
+ p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b)
if err != nil {
- t.Fatalf("dnsRoundTripUDP failed: %v", err)
+ t.Fatalf("dnsPacketRoundTrip failed: %v", err)
}
- if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr {
+ p.SkipAllQuestions()
+ as, err := p.AllAnswers()
+ if err != nil {
+ t.Fatal("AllAnswers failed:", err)
+ }
+ if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr {
t.Errorf("got address %v, want %v", got, TestAddr)
}
}
@@ -918,7 +981,7 @@ func TestRetryTimeout(t *testing.T) {
var deadline0 time.Time
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline)
if deadline.IsZero() {
@@ -928,7 +991,7 @@ func TestRetryTimeout(t *testing.T) {
if s == "192.0.2.1:53" {
deadline0 = deadline
time.Sleep(10 * time.Millisecond)
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
}
if deadline.Equal(deadline0) {
@@ -979,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
}
var usedServers []string
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s)
return mockTXTResponse(q), nil
}}
@@ -997,22 +1060,24 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
}
}
-func mockTXTResponse(q *dnsMsg) *dnsMsg {
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- recursion_available: true,
+func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RecursionAvailable: true,
},
- question: q.question,
- answer: []dnsRR{
- &dnsRR_TXT{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeTXT,
- Class: dnsClassINET,
+ Questions: q.Questions,
+ Answers: []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeTXT,
+ Class: dnsmessage.ClassINET,
+ },
+ Body: &dnsmessage.TXTResource{
+ TXT: []string{"ok"},
},
- Txt: "ok",
},
},
}
@@ -1080,22 +1145,22 @@ func TestStrictErrorsLookupIP(t *testing.T) {
cases := []struct {
desc string
- resolveWhich func(quest *dnsQuestion) resolveWhichEnum
+ resolveWhich func(quest dnsmessage.Question) resolveWhichEnum
wantStrictErr error
wantLaxErr error
wantIPs []string
}{
{
desc: "No errors",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
return resolveOK
},
wantIPs: []string{ip4, ip6},
},
{
desc: "searchX error fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchX {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX {
return resolveTimeout
}
return resolveOK
@@ -1105,8 +1170,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchX IPv4-only timeout fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchX && quest.Qtype == dnsTypeA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA {
return resolveTimeout
}
return resolveOK
@@ -1116,8 +1181,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchX IPv6-only servfail fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchX && quest.Qtype == dnsTypeAAAA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA {
return resolveServfail
}
return resolveOK
@@ -1127,8 +1192,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchY error always fails",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchY {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY {
return resolveTimeout
}
return resolveOK
@@ -1138,8 +1203,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchY IPv4-only socket error fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchY && quest.Qtype == dnsTypeA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA {
return resolveOpError
}
return resolveOK
@@ -1149,8 +1214,8 @@ func TestStrictErrorsLookupIP(t *testing.T) {
},
{
desc: "searchY IPv6-only timeout fails in strict mode",
- resolveWhich: func(quest *dnsQuestion) resolveWhichEnum {
- if quest.Name == searchY && quest.Qtype == dnsTypeAAAA {
+ resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum {
+ if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA {
return resolveTimeout
}
return resolveOK
@@ -1161,80 +1226,84 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}
for i, tt := range cases {
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
- switch tt.resolveWhich(&q.question[0]) {
+ switch tt.resolveWhich(q.Questions[0]) {
case resolveOK:
// Handle below.
case resolveOpError:
- return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
+ return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")}
case resolveServfail:
- return &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- rcode: dnsRcodeServerFailure,
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeServerFailure,
},
- question: q.question,
+ Questions: q.Questions,
}, nil
case resolveTimeout:
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
default:
t.Fatal("Impossible resolveWhich")
}
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case searchX, name + ".":
// Return NXDOMAIN to utilize the search list.
- return &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
- rcode: dnsRcodeNameError,
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
},
- question: q.question,
+ Questions: q.Questions,
}, nil
case searchY:
// Return records below.
default:
- return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
}
- r := &dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: q.id,
- response: true,
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
},
- question: q.question,
+ Questions: q.Questions,
}
- switch q.question[0].Qtype {
- case dnsTypeA:
- r.answer = []dnsRR{
- &dnsRR_A{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeA,
- Class: dnsClassINET,
- Rdlength: 4,
+ switch q.Questions[0].Type {
+ case dnsmessage.TypeA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeA,
+ Class: dnsmessage.ClassINET,
+ Length: 4,
+ },
+ Body: &dnsmessage.AResource{
+ A: TestAddr,
},
- A: TestAddr,
},
}
- case dnsTypeAAAA:
- r.answer = []dnsRR{
- &dnsRR_AAAA{
- Hdr: dnsRR_Header{
- Name: q.question[0].Name,
- Rrtype: dnsTypeAAAA,
- Class: dnsClassINET,
- Rdlength: 16,
+ case dnsmessage.TypeAAAA:
+ r.Answers = []dnsmessage.Resource{
+ {
+ Header: dnsmessage.ResourceHeader{
+ Name: q.Questions[0].Name,
+ Type: dnsmessage.TypeAAAA,
+ Class: dnsmessage.ClassINET,
+ Length: 16,
+ },
+ Body: &dnsmessage.AAAAResource{
+ AAAA: VarTestAddr6,
},
- AAAA: VarTestAddr6,
},
}
default:
- return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype)
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type)
}
return r, nil
}}
@@ -1295,22 +1364,22 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org."
const txt = "Hello World"
- fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
- switch q.question[0].Name {
+ switch q.Questions[0].Name.String() {
case searchX:
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
case searchY:
return mockTXTResponse(q), nil
default:
- return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name)
+ return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name)
}
}}
for _, strict := range []bool{true, false} {
r := Resolver{StrictErrors: strict, Dial: fake.DialContext}
- _, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT)
+ p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT)
var wantErr error
var wantRRs int
if strict {
@@ -1326,8 +1395,12 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
if !reflect.DeepEqual(err, wantErr) {
t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr)
}
- if len(rrs) != wantRRs {
- t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs)
+ a, err := p.AllAnswers()
+ if err != nil {
+ a = nil
+ }
+ if len(a) != wantRRs {
+ t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs)
}
}
}
@@ -1337,9 +1410,9 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) {
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond)
- return nil, poll.ErrTimeout
+ return dnsmessage.Message{}, poll.ErrTimeout
}}
r := Resolver{PreferGo: true, Dial: fake.DialContext}
@@ -1353,3 +1426,112 @@ func TestDNSGoroutineRace(t *testing.T) {
t.Fatal("fake DNS lookup unexpectedly succeeded")
}
}
+
+// Issue 8434: verify that Temporary returns true on an error when rcode
+// is SERVFAIL
+func TestIssue8434(t *testing.T) {
+ msg := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ RCode: dnsmessage.RCodeServerFailure,
+ },
+ }
+ b, err := msg.Pack()
+ if err != nil {
+ t.Fatal("Pack failed:", err)
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b)
+ if err != nil {
+ t.Fatal("Start failed:", err)
+ }
+ if err := p.SkipAllQuestions(); err != nil {
+ t.Fatal("SkipAllQuestions failed:", err)
+ }
+
+ err = checkHeader(&p, h, "golang.org", "foo:53")
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if ne, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ } else if !ne.Temporary() {
+ t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsTemporary {
+ t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
+ }
+}
+
+// Issue 12778: verify that NXDOMAIN without RA bit errors as
+// "no such host" and not "server misbehaving"
+//
+// Issue 25336: verify that NXDOMAIN errors fail fast.
+func TestIssue12778(t *testing.T) {
+ lookups := 0
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ lookups++
+ return dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeNameError,
+ RecursionAvailable: false,
+ },
+ Questions: q.Questions,
+ }, nil
+ },
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+
+ resolvConf.mu.RLock()
+ conf := resolvConf.dnsConfig
+ resolvConf.mu.RUnlock()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, _, err := r.tryOneName(ctx, conf, ".", dnsmessage.TypeALL)
+
+ if lookups != 1 {
+ t.Errorf("got %d lookups, wanted 1", lookups)
+ }
+
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ de, ok := err.(*DNSError)
+ if !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ }
+ if de.Err != errNoSuchHost.Error() {
+ t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
+ }
+}
+
+// Issue 26573: verify that Conns that don't implement PacketConn are treated
+// as streams even when udp was requested.
+func TestDNSDialTCP(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ return r, nil
+ },
+ alwaysTCP: true,
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ ctx := context.Background()
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second)
+ if err != nil {
+ t.Fatal("exhange failed:", err)
+ }
+}
diff --git a/libgo/go/net/dnsconfig_unix.go b/libgo/go/net/dnsconfig_unix.go
index 24487af..8ae5de6 100644
--- a/libgo/go/net/dnsconfig_unix.go
+++ b/libgo/go/net/dnsconfig_unix.go
@@ -73,7 +73,7 @@ func dnsReadConfig(filename string) *dnsConfig {
// to look it up.
if parseIPv4(f[1]) != nil {
conf.servers = append(conf.servers, JoinHostPort(f[1], "53"))
- } else if ip, _ := parseIPv6(f[1], true); ip != nil {
+ } else if ip, _ := parseIPv6Zone(f[1]); ip != nil {
conf.servers = append(conf.servers, JoinHostPort(f[1], "53"))
}
}
@@ -121,7 +121,7 @@ func dnsReadConfig(filename string) *dnsConfig {
case "lookup":
// OpenBSD option:
- // http://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
+ // https://www.openbsd.org/cgi-bin/man.cgi/OpenBSD-current/man5/resolv.conf.5
// "the legal space-separated values are: bind, file, yp"
conf.lookup = f[1:]
diff --git a/libgo/go/net/dnsmsg.go b/libgo/go/net/dnsmsg.go
deleted file mode 100644
index 8f6c7b6..0000000
--- a/libgo/go/net/dnsmsg.go
+++ /dev/null
@@ -1,884 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// DNS packet assembly. See RFC 1035.
-//
-// This is intended to support name resolution during Dial.
-// It doesn't have to be blazing fast.
-//
-// Each message structure has a Walk method that is used by
-// a generic pack/unpack routine. Thus, if in the future we need
-// to define new message structs, no new pack/unpack/printing code
-// needs to be written.
-//
-// The first half of this file defines the DNS message formats.
-// The second half implements the conversion to and from wire format.
-// A few of the structure elements have string tags to aid the
-// generic pack/unpack routines.
-//
-// TODO(rsc): There are enough names defined in this file that they're all
-// prefixed with dns. Perhaps put this in its own package later.
-
-package net
-
-// Packet formats
-
-// Wire constants.
-const (
- // valid dnsRR_Header.Rrtype and dnsQuestion.qtype
- dnsTypeA = 1
- dnsTypeNS = 2
- dnsTypeMD = 3
- dnsTypeMF = 4
- dnsTypeCNAME = 5
- dnsTypeSOA = 6
- dnsTypeMB = 7
- dnsTypeMG = 8
- dnsTypeMR = 9
- dnsTypeNULL = 10
- dnsTypeWKS = 11
- dnsTypePTR = 12
- dnsTypeHINFO = 13
- dnsTypeMINFO = 14
- dnsTypeMX = 15
- dnsTypeTXT = 16
- dnsTypeAAAA = 28
- dnsTypeSRV = 33
-
- // valid dnsQuestion.qtype only
- dnsTypeAXFR = 252
- dnsTypeMAILB = 253
- dnsTypeMAILA = 254
- dnsTypeALL = 255
-
- // valid dnsQuestion.qclass
- dnsClassINET = 1
- dnsClassCSNET = 2
- dnsClassCHAOS = 3
- dnsClassHESIOD = 4
- dnsClassANY = 255
-
- // dnsMsg.rcode
- dnsRcodeSuccess = 0
- dnsRcodeFormatError = 1
- dnsRcodeServerFailure = 2
- dnsRcodeNameError = 3
- dnsRcodeNotImplemented = 4
- dnsRcodeRefused = 5
-)
-
-// A dnsStruct describes how to iterate over its fields to emulate
-// reflective marshaling.
-type dnsStruct interface {
- // Walk iterates over fields of a structure and calls f
- // with a reference to that field, the name of the field
- // and a tag ("", "domain", "ipv4", "ipv6") specifying
- // particular encodings. Possible concrete types
- // for v are *uint16, *uint32, *string, or []byte, and
- // *int, *bool in the case of dnsMsgHdr.
- // Whenever f returns false, Walk must stop and return
- // false, and otherwise return true.
- Walk(f func(v interface{}, name, tag string) (ok bool)) (ok bool)
-}
-
-// The wire format for the DNS packet header.
-type dnsHeader struct {
- Id uint16
- Bits uint16
- Qdcount, Ancount, Nscount, Arcount uint16
-}
-
-func (h *dnsHeader) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&h.Id, "Id", "") &&
- f(&h.Bits, "Bits", "") &&
- f(&h.Qdcount, "Qdcount", "") &&
- f(&h.Ancount, "Ancount", "") &&
- f(&h.Nscount, "Nscount", "") &&
- f(&h.Arcount, "Arcount", "")
-}
-
-const (
- // dnsHeader.Bits
- _QR = 1 << 15 // query/response (response=1)
- _AA = 1 << 10 // authoritative
- _TC = 1 << 9 // truncated
- _RD = 1 << 8 // recursion desired
- _RA = 1 << 7 // recursion available
-)
-
-// DNS queries.
-type dnsQuestion struct {
- Name string
- Qtype uint16
- Qclass uint16
-}
-
-func (q *dnsQuestion) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&q.Name, "Name", "domain") &&
- f(&q.Qtype, "Qtype", "") &&
- f(&q.Qclass, "Qclass", "")
-}
-
-// DNS responses (resource records).
-// There are many types of messages,
-// but they all share the same header.
-type dnsRR_Header struct {
- Name string
- Rrtype uint16
- Class uint16
- Ttl uint32
- Rdlength uint16 // length of data after header
-}
-
-func (h *dnsRR_Header) Header() *dnsRR_Header {
- return h
-}
-
-func (h *dnsRR_Header) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&h.Name, "Name", "domain") &&
- f(&h.Rrtype, "Rrtype", "") &&
- f(&h.Class, "Class", "") &&
- f(&h.Ttl, "Ttl", "") &&
- f(&h.Rdlength, "Rdlength", "")
-}
-
-type dnsRR interface {
- dnsStruct
- Header() *dnsRR_Header
-}
-
-// Specific DNS RR formats for each query type.
-
-type dnsRR_CNAME struct {
- Hdr dnsRR_Header
- Cname string
-}
-
-func (rr *dnsRR_CNAME) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_CNAME) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Cname, "Cname", "domain")
-}
-
-type dnsRR_MX struct {
- Hdr dnsRR_Header
- Pref uint16
- Mx string
-}
-
-func (rr *dnsRR_MX) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_MX) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Pref, "Pref", "") && f(&rr.Mx, "Mx", "domain")
-}
-
-type dnsRR_NS struct {
- Hdr dnsRR_Header
- Ns string
-}
-
-func (rr *dnsRR_NS) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_NS) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Ns, "Ns", "domain")
-}
-
-type dnsRR_PTR struct {
- Hdr dnsRR_Header
- Ptr string
-}
-
-func (rr *dnsRR_PTR) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_PTR) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.Ptr, "Ptr", "domain")
-}
-
-type dnsRR_SOA struct {
- Hdr dnsRR_Header
- Ns string
- Mbox string
- Serial uint32
- Refresh uint32
- Retry uint32
- Expire uint32
- Minttl uint32
-}
-
-func (rr *dnsRR_SOA) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_SOA) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) &&
- f(&rr.Ns, "Ns", "domain") &&
- f(&rr.Mbox, "Mbox", "domain") &&
- f(&rr.Serial, "Serial", "") &&
- f(&rr.Refresh, "Refresh", "") &&
- f(&rr.Retry, "Retry", "") &&
- f(&rr.Expire, "Expire", "") &&
- f(&rr.Minttl, "Minttl", "")
-}
-
-type dnsRR_TXT struct {
- Hdr dnsRR_Header
- Txt string // not domain name
-}
-
-func (rr *dnsRR_TXT) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_TXT) Walk(f func(v interface{}, name, tag string) bool) bool {
- if !rr.Hdr.Walk(f) {
- return false
- }
- var n uint16 = 0
- for n < rr.Hdr.Rdlength {
- var txt string
- if !f(&txt, "Txt", "") {
- return false
- }
- // more bytes than rr.Hdr.Rdlength said there would be
- if rr.Hdr.Rdlength-n < uint16(len(txt))+1 {
- return false
- }
- n += uint16(len(txt)) + 1
- rr.Txt += txt
- }
- return true
-}
-
-type dnsRR_SRV struct {
- Hdr dnsRR_Header
- Priority uint16
- Weight uint16
- Port uint16
- Target string
-}
-
-func (rr *dnsRR_SRV) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_SRV) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) &&
- f(&rr.Priority, "Priority", "") &&
- f(&rr.Weight, "Weight", "") &&
- f(&rr.Port, "Port", "") &&
- f(&rr.Target, "Target", "domain")
-}
-
-type dnsRR_A struct {
- Hdr dnsRR_Header
- A uint32
-}
-
-func (rr *dnsRR_A) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_A) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(&rr.A, "A", "ipv4")
-}
-
-type dnsRR_AAAA struct {
- Hdr dnsRR_Header
- AAAA [16]byte
-}
-
-func (rr *dnsRR_AAAA) Header() *dnsRR_Header {
- return &rr.Hdr
-}
-
-func (rr *dnsRR_AAAA) Walk(f func(v interface{}, name, tag string) bool) bool {
- return rr.Hdr.Walk(f) && f(rr.AAAA[:], "AAAA", "ipv6")
-}
-
-// Packing and unpacking.
-//
-// All the packers and unpackers take a (msg []byte, off int)
-// and return (off1 int, ok bool). If they return ok==false, they
-// also return off1==len(msg), so that the next unpacker will
-// also fail. This lets us avoid checks of ok until the end of a
-// packing sequence.
-
-// Map of constructors for each RR wire type.
-var rr_mk = map[int]func() dnsRR{
- dnsTypeCNAME: func() dnsRR { return new(dnsRR_CNAME) },
- dnsTypeMX: func() dnsRR { return new(dnsRR_MX) },
- dnsTypeNS: func() dnsRR { return new(dnsRR_NS) },
- dnsTypePTR: func() dnsRR { return new(dnsRR_PTR) },
- dnsTypeSOA: func() dnsRR { return new(dnsRR_SOA) },
- dnsTypeTXT: func() dnsRR { return new(dnsRR_TXT) },
- dnsTypeSRV: func() dnsRR { return new(dnsRR_SRV) },
- dnsTypeA: func() dnsRR { return new(dnsRR_A) },
- dnsTypeAAAA: func() dnsRR { return new(dnsRR_AAAA) },
-}
-
-// Pack a domain name s into msg[off:].
-// Domain names are a sequence of counted strings
-// split at the dots. They end with a zero-length string.
-func packDomainName(s string, msg []byte, off int) (off1 int, ok bool) {
- // Add trailing dot to canonicalize name.
- if n := len(s); n == 0 || s[n-1] != '.' {
- s += "."
- }
-
- // Allow root domain.
- if s == "." {
- msg[off] = 0
- off++
- return off, true
- }
-
- // Each dot ends a segment of the name.
- // We trade each dot byte for a length byte.
- // There is also a trailing zero.
- // Check that we have all the space we need.
- tot := len(s) + 1
- if off+tot > len(msg) {
- return len(msg), false
- }
-
- // Emit sequence of counted strings, chopping at dots.
- begin := 0
- for i := 0; i < len(s); i++ {
- if s[i] == '.' {
- if i-begin >= 1<<6 { // top two bits of length must be clear
- return len(msg), false
- }
- if i-begin == 0 {
- return len(msg), false
- }
-
- msg[off] = byte(i - begin)
- off++
-
- for j := begin; j < i; j++ {
- msg[off] = s[j]
- off++
- }
- begin = i + 1
- }
- }
- msg[off] = 0
- off++
- return off, true
-}
-
-// Unpack a domain name.
-// In addition to the simple sequences of counted strings above,
-// domain names are allowed to refer to strings elsewhere in the
-// packet, to avoid repeating common suffixes when returning
-// many entries in a single domain. The pointers are marked
-// by a length byte with the top two bits set. Ignoring those
-// two bits, that byte and the next give a 14 bit offset from msg[0]
-// where we should pick up the trail.
-// Note that if we jump elsewhere in the packet,
-// we return off1 == the offset after the first pointer we found,
-// which is where the next record will start.
-// In theory, the pointers are only allowed to jump backward.
-// We let them jump anywhere and stop jumping after a while.
-func unpackDomainName(msg []byte, off int) (s string, off1 int, ok bool) {
- s = ""
- ptr := 0 // number of pointers followed
-Loop:
- for {
- if off >= len(msg) {
- return "", len(msg), false
- }
- c := int(msg[off])
- off++
- switch c & 0xC0 {
- case 0x00:
- if c == 0x00 {
- // end of name
- break Loop
- }
- // literal string
- if off+c > len(msg) {
- return "", len(msg), false
- }
- s += string(msg[off:off+c]) + "."
- off += c
- case 0xC0:
- // pointer to somewhere else in msg.
- // remember location after first ptr,
- // since that's how many bytes we consumed.
- // also, don't follow too many pointers --
- // maybe there's a loop.
- if off >= len(msg) {
- return "", len(msg), false
- }
- c1 := msg[off]
- off++
- if ptr == 0 {
- off1 = off
- }
- if ptr++; ptr > 10 {
- return "", len(msg), false
- }
- off = (c^0xC0)<<8 | int(c1)
- default:
- // 0x80 and 0x40 are reserved
- return "", len(msg), false
- }
- }
- if len(s) == 0 {
- s = "."
- }
- if ptr == 0 {
- off1 = off
- }
- return s, off1, true
-}
-
-// packStruct packs a structure into msg at specified offset off, and
-// returns off1 such that msg[off:off1] is the encoded data.
-func packStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
- ok = any.Walk(func(field interface{}, name, tag string) bool {
- switch fv := field.(type) {
- default:
- println("net: dns: unknown packing type")
- return false
- case *uint16:
- i := *fv
- if off+2 > len(msg) {
- return false
- }
- msg[off] = byte(i >> 8)
- msg[off+1] = byte(i)
- off += 2
- case *uint32:
- i := *fv
- msg[off] = byte(i >> 24)
- msg[off+1] = byte(i >> 16)
- msg[off+2] = byte(i >> 8)
- msg[off+3] = byte(i)
- off += 4
- case []byte:
- n := len(fv)
- if off+n > len(msg) {
- return false
- }
- copy(msg[off:off+n], fv)
- off += n
- case *string:
- s := *fv
- switch tag {
- default:
- println("net: dns: unknown string tag", tag)
- return false
- case "domain":
- off, ok = packDomainName(s, msg, off)
- if !ok {
- return false
- }
- case "":
- // Counted string: 1 byte length.
- if len(s) > 255 || off+1+len(s) > len(msg) {
- return false
- }
- msg[off] = byte(len(s))
- off++
- off += copy(msg[off:], s)
- }
- }
- return true
- })
- if !ok {
- return len(msg), false
- }
- return off, true
-}
-
-// unpackStruct decodes msg[off:] into the given structure, and
-// returns off1 such that msg[off:off1] is the encoded data.
-func unpackStruct(any dnsStruct, msg []byte, off int) (off1 int, ok bool) {
- ok = any.Walk(func(field interface{}, name, tag string) bool {
- switch fv := field.(type) {
- default:
- println("net: dns: unknown packing type")
- return false
- case *uint16:
- if off+2 > len(msg) {
- return false
- }
- *fv = uint16(msg[off])<<8 | uint16(msg[off+1])
- off += 2
- case *uint32:
- if off+4 > len(msg) {
- return false
- }
- *fv = uint32(msg[off])<<24 | uint32(msg[off+1])<<16 |
- uint32(msg[off+2])<<8 | uint32(msg[off+3])
- off += 4
- case []byte:
- n := len(fv)
- if off+n > len(msg) {
- return false
- }
- copy(fv, msg[off:off+n])
- off += n
- case *string:
- var s string
- switch tag {
- default:
- println("net: dns: unknown string tag", tag)
- return false
- case "domain":
- s, off, ok = unpackDomainName(msg, off)
- if !ok {
- return false
- }
- case "":
- if off >= len(msg) || off+1+int(msg[off]) > len(msg) {
- return false
- }
- n := int(msg[off])
- off++
- b := make([]byte, n)
- for i := 0; i < n; i++ {
- b[i] = msg[off+i]
- }
- off += n
- s = string(b)
- }
- *fv = s
- }
- return true
- })
- if !ok {
- return len(msg), false
- }
- return off, true
-}
-
-// Generic struct printer. Prints fields with tag "ipv4" or "ipv6"
-// as IP addresses.
-func printStruct(any dnsStruct) string {
- s := "{"
- i := 0
- any.Walk(func(val interface{}, name, tag string) bool {
- i++
- if i > 1 {
- s += ", "
- }
- s += name + "="
- switch tag {
- case "ipv4":
- i := *val.(*uint32)
- s += IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String()
- case "ipv6":
- i := val.([]byte)
- s += IP(i).String()
- default:
- var i int64
- switch v := val.(type) {
- default:
- // can't really happen.
- s += "<unknown type>"
- return true
- case *string:
- s += *v
- return true
- case []byte:
- s += string(v)
- return true
- case *bool:
- if *v {
- s += "true"
- } else {
- s += "false"
- }
- return true
- case *int:
- i = int64(*v)
- case *uint:
- i = int64(*v)
- case *uint8:
- i = int64(*v)
- case *uint16:
- i = int64(*v)
- case *uint32:
- i = int64(*v)
- case *uint64:
- i = int64(*v)
- case *uintptr:
- i = int64(*v)
- }
- s += itoa(int(i))
- }
- return true
- })
- s += "}"
- return s
-}
-
-// Resource record packer.
-func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
- var off1 int
- // pack twice, once to find end of header
- // and again to find end of packet.
- // a bit inefficient but this doesn't need to be fast.
- // off1 is end of header
- // off2 is end of rr
- off1, ok = packStruct(rr.Header(), msg, off)
- if !ok {
- return len(msg), false
- }
- off2, ok = packStruct(rr, msg, off)
- if !ok {
- return len(msg), false
- }
- // pack a third time; redo header with correct data length
- rr.Header().Rdlength = uint16(off2 - off1)
- packStruct(rr.Header(), msg, off)
- return off2, true
-}
-
-// Resource record unpacker.
-func unpackRR(msg []byte, off int) (rr dnsRR, off1 int, ok bool) {
- // unpack just the header, to find the rr type and length
- var h dnsRR_Header
- off0 := off
- if off, ok = unpackStruct(&h, msg, off); !ok {
- return nil, len(msg), false
- }
- end := off + int(h.Rdlength)
-
- // make an rr of that type and re-unpack.
- // again inefficient but doesn't need to be fast.
- mk, known := rr_mk[int(h.Rrtype)]
- if !known {
- return &h, end, true
- }
- rr = mk()
- off, ok = unpackStruct(rr, msg, off0)
- if off != end {
- return &h, end, true
- }
- return rr, off, ok
-}
-
-// Usable representation of a DNS packet.
-
-// A manually-unpacked version of (id, bits).
-// This is in its own struct for easy printing.
-type dnsMsgHdr struct {
- id uint16
- response bool
- opcode int
- authoritative bool
- truncated bool
- recursion_desired bool
- recursion_available bool
- rcode int
-}
-
-func (h *dnsMsgHdr) Walk(f func(v interface{}, name, tag string) bool) bool {
- return f(&h.id, "id", "") &&
- f(&h.response, "response", "") &&
- f(&h.opcode, "opcode", "") &&
- f(&h.authoritative, "authoritative", "") &&
- f(&h.truncated, "truncated", "") &&
- f(&h.recursion_desired, "recursion_desired", "") &&
- f(&h.recursion_available, "recursion_available", "") &&
- f(&h.rcode, "rcode", "")
-}
-
-type dnsMsg struct {
- dnsMsgHdr
- question []dnsQuestion
- answer []dnsRR
- ns []dnsRR
- extra []dnsRR
-}
-
-func (dns *dnsMsg) Pack() (msg []byte, ok bool) {
- var dh dnsHeader
-
- // Convert convenient dnsMsg into wire-like dnsHeader.
- dh.Id = dns.id
- dh.Bits = uint16(dns.opcode)<<11 | uint16(dns.rcode)
- if dns.recursion_available {
- dh.Bits |= _RA
- }
- if dns.recursion_desired {
- dh.Bits |= _RD
- }
- if dns.truncated {
- dh.Bits |= _TC
- }
- if dns.authoritative {
- dh.Bits |= _AA
- }
- if dns.response {
- dh.Bits |= _QR
- }
-
- // Prepare variable sized arrays.
- question := dns.question
- answer := dns.answer
- ns := dns.ns
- extra := dns.extra
-
- dh.Qdcount = uint16(len(question))
- dh.Ancount = uint16(len(answer))
- dh.Nscount = uint16(len(ns))
- dh.Arcount = uint16(len(extra))
-
- // Could work harder to calculate message size,
- // but this is far more than we need and not
- // big enough to hurt the allocator.
- msg = make([]byte, 2000)
-
- // Pack it in: header and then the pieces.
- off := 0
- off, ok = packStruct(&dh, msg, off)
- if !ok {
- return nil, false
- }
- for i := 0; i < len(question); i++ {
- off, ok = packStruct(&question[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- for i := 0; i < len(answer); i++ {
- off, ok = packRR(answer[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- for i := 0; i < len(ns); i++ {
- off, ok = packRR(ns[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- for i := 0; i < len(extra); i++ {
- off, ok = packRR(extra[i], msg, off)
- if !ok {
- return nil, false
- }
- }
- return msg[0:off], true
-}
-
-func (dns *dnsMsg) Unpack(msg []byte) bool {
- // Header.
- var dh dnsHeader
- off := 0
- var ok bool
- if off, ok = unpackStruct(&dh, msg, off); !ok {
- return false
- }
- dns.id = dh.Id
- dns.response = (dh.Bits & _QR) != 0
- dns.opcode = int(dh.Bits>>11) & 0xF
- dns.authoritative = (dh.Bits & _AA) != 0
- dns.truncated = (dh.Bits & _TC) != 0
- dns.recursion_desired = (dh.Bits & _RD) != 0
- dns.recursion_available = (dh.Bits & _RA) != 0
- dns.rcode = int(dh.Bits & 0xF)
-
- // Arrays.
- dns.question = make([]dnsQuestion, dh.Qdcount)
- dns.answer = make([]dnsRR, 0, dh.Ancount)
- dns.ns = make([]dnsRR, 0, dh.Nscount)
- dns.extra = make([]dnsRR, 0, dh.Arcount)
-
- var rec dnsRR
-
- for i := 0; i < len(dns.question); i++ {
- off, ok = unpackStruct(&dns.question[i], msg, off)
- if !ok {
- return false
- }
- }
- for i := 0; i < int(dh.Ancount); i++ {
- rec, off, ok = unpackRR(msg, off)
- if !ok {
- return false
- }
- dns.answer = append(dns.answer, rec)
- }
- for i := 0; i < int(dh.Nscount); i++ {
- rec, off, ok = unpackRR(msg, off)
- if !ok {
- return false
- }
- dns.ns = append(dns.ns, rec)
- }
- for i := 0; i < int(dh.Arcount); i++ {
- rec, off, ok = unpackRR(msg, off)
- if !ok {
- return false
- }
- dns.extra = append(dns.extra, rec)
- }
- // if off != len(msg) {
- // println("extra bytes in dns packet", off, "<", len(msg));
- // }
- return true
-}
-
-func (dns *dnsMsg) String() string {
- s := "DNS: " + printStruct(&dns.dnsMsgHdr) + "\n"
- if len(dns.question) > 0 {
- s += "-- Questions\n"
- for i := 0; i < len(dns.question); i++ {
- s += printStruct(&dns.question[i]) + "\n"
- }
- }
- if len(dns.answer) > 0 {
- s += "-- Answers\n"
- for i := 0; i < len(dns.answer); i++ {
- s += printStruct(dns.answer[i]) + "\n"
- }
- }
- if len(dns.ns) > 0 {
- s += "-- Name servers\n"
- for i := 0; i < len(dns.ns); i++ {
- s += printStruct(dns.ns[i]) + "\n"
- }
- }
- if len(dns.extra) > 0 {
- s += "-- Extra\n"
- for i := 0; i < len(dns.extra); i++ {
- s += printStruct(dns.extra[i]) + "\n"
- }
- }
- return s
-}
-
-// IsResponseTo reports whether m is an acceptable response to query.
-func (m *dnsMsg) IsResponseTo(query *dnsMsg) bool {
- if !m.response {
- return false
- }
- if m.id != query.id {
- return false
- }
- if len(m.question) != len(query.question) {
- return false
- }
- for i, q := range m.question {
- q2 := query.question[i]
- if !equalASCIILabel(q.Name, q2.Name) || q.Qtype != q2.Qtype || q.Qclass != q2.Qclass {
- return false
- }
- }
- return true
-}
diff --git a/libgo/go/net/dnsmsg_test.go b/libgo/go/net/dnsmsg_test.go
deleted file mode 100644
index 2a25a21..0000000
--- a/libgo/go/net/dnsmsg_test.go
+++ /dev/null
@@ -1,481 +0,0 @@
-// Copyright 2011 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package net
-
-import (
- "encoding/hex"
- "reflect"
- "testing"
-)
-
-func TestStructPackUnpack(t *testing.T) {
- want := dnsQuestion{
- Name: ".",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- }
- buf := make([]byte, 50)
- n, ok := packStruct(&want, buf, 0)
- if !ok {
- t.Fatal("packing failed")
- }
- buf = buf[:n]
- got := dnsQuestion{}
- n, ok = unpackStruct(&got, buf, 0)
- if !ok {
- t.Fatal("unpacking failed")
- }
- if n != len(buf) {
- t.Errorf("unpacked different amount than packed: got n = %d, want = %d", n, len(buf))
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("got = %+v, want = %+v", got, want)
- }
-}
-
-func TestDomainNamePackUnpack(t *testing.T) {
- tests := []struct {
- in string
- want string
- ok bool
- }{
- {"", ".", true},
- {".", ".", true},
- {"google..com", "", false},
- {"google.com", "google.com.", true},
- {"google..com.", "", false},
- {"google.com.", "google.com.", true},
- {".google.com.", "", false},
- {"www..google.com.", "", false},
- {"www.google.com.", "www.google.com.", true},
- }
-
- for _, test := range tests {
- buf := make([]byte, 30)
- n, ok := packDomainName(test.in, buf, 0)
- if ok != test.ok {
- t.Errorf("packing of %s: got ok = %t, want = %t", test.in, ok, test.ok)
- continue
- }
- if !test.ok {
- continue
- }
- buf = buf[:n]
- got, n, ok := unpackDomainName(buf, 0)
- if !ok {
- t.Errorf("unpacking for %s failed", test.in)
- continue
- }
- if n != len(buf) {
- t.Errorf(
- "unpacked different amount than packed for %s: got n = %d, want = %d",
- test.in,
- n,
- len(buf),
- )
- }
- if got != test.want {
- t.Errorf("unpacking packing of %s: got = %s, want = %s", test.in, got, test.want)
- }
- }
-}
-
-func TestDNSPackUnpack(t *testing.T) {
- want := dnsMsg{
- question: []dnsQuestion{{
- Name: ".",
- Qtype: dnsTypeAAAA,
- Qclass: dnsClassINET,
- }},
- answer: []dnsRR{},
- ns: []dnsRR{},
- extra: []dnsRR{},
- }
- b, ok := want.Pack()
- if !ok {
- t.Fatal("packing failed")
- }
- var got dnsMsg
- ok = got.Unpack(b)
- if !ok {
- t.Fatal("unpacking failed")
- }
- if !reflect.DeepEqual(got, want) {
- t.Errorf("got = %+v, want = %+v", got, want)
- }
-}
-
-func TestDNSParseSRVReply(t *testing.T) {
- data, err := hex.DecodeString(dnsSRVReply)
- if err != nil {
- t.Fatal(err)
- }
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if !ok {
- t.Fatal("unpacking packet failed")
- }
- _ = msg.String() // exercise this code path
- if g, e := len(msg.answer), 5; g != e {
- t.Errorf("len(msg.answer) = %d; want %d", g, e)
- }
- for idx, rr := range msg.answer {
- if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
- t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
- }
- if _, ok := rr.(*dnsRR_SRV); !ok {
- t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
- }
- }
- for _, name := range [...]string{
- "_xmpp-server._tcp.google.com.",
- "_XMPP-Server._TCP.Google.COM.",
- "_XMPP-SERVER._TCP.GOOGLE.COM.",
- } {
- _, addrs, err := answer(name, "foo:53", msg, uint16(dnsTypeSRV))
- if err != nil {
- t.Error(err)
- }
- if g, e := len(addrs), 5; g != e {
- t.Errorf("len(addrs) = %d; want %d", g, e)
- t.Logf("addrs = %#v", addrs)
- }
- }
- // repack and unpack.
- data2, ok := msg.Pack()
- msg2 := new(dnsMsg)
- msg2.Unpack(data2)
- switch {
- case !ok:
- t.Error("failed to repack message")
- case !reflect.DeepEqual(msg, msg2):
- t.Error("repacked message differs from original")
- }
-}
-
-func TestDNSParseCorruptSRVReply(t *testing.T) {
- data, err := hex.DecodeString(dnsSRVCorruptReply)
- if err != nil {
- t.Fatal(err)
- }
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if !ok {
- t.Fatal("unpacking packet failed")
- }
- _ = msg.String() // exercise this code path
- if g, e := len(msg.answer), 5; g != e {
- t.Errorf("len(msg.answer) = %d; want %d", g, e)
- }
- for idx, rr := range msg.answer {
- if g, e := rr.Header().Rrtype, uint16(dnsTypeSRV); g != e {
- t.Errorf("rr[%d].Header().Rrtype = %d; want %d", idx, g, e)
- }
- if idx == 4 {
- if _, ok := rr.(*dnsRR_Header); !ok {
- t.Errorf("answer[%d] = %T; want *dnsRR_Header", idx, rr)
- }
- } else {
- if _, ok := rr.(*dnsRR_SRV); !ok {
- t.Errorf("answer[%d] = %T; want *dnsRR_SRV", idx, rr)
- }
- }
- }
- _, addrs, err := answer("_xmpp-server._tcp.google.com.", "foo:53", msg, uint16(dnsTypeSRV))
- if err != nil {
- t.Fatalf("answer: %v", err)
- }
- if g, e := len(addrs), 4; g != e {
- t.Errorf("len(addrs) = %d; want %d", g, e)
- t.Logf("addrs = %#v", addrs)
- }
-}
-
-func TestDNSParseTXTReply(t *testing.T) {
- expectedTxt1 := "v=spf1 redirect=_spf.google.com"
- expectedTxt2 := "v=spf1 ip4:69.63.179.25 ip4:69.63.178.128/25 ip4:69.63.184.0/25 " +
- "ip4:66.220.144.128/25 ip4:66.220.155.0/24 " +
- "ip4:69.171.232.0/25 ip4:66.220.157.0/25 " +
- "ip4:69.171.244.0/24 mx -all"
-
- replies := []string{dnsTXTReply1, dnsTXTReply2}
- expectedTxts := []string{expectedTxt1, expectedTxt2}
-
- for i := range replies {
- data, err := hex.DecodeString(replies[i])
- if err != nil {
- t.Fatal(err)
- }
-
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if !ok {
- t.Errorf("test %d: unpacking packet failed", i)
- continue
- }
-
- if len(msg.answer) != 1 {
- t.Errorf("test %d: len(rr.answer) = %d; want 1", i, len(msg.answer))
- continue
- }
-
- rr := msg.answer[0]
- rrTXT, ok := rr.(*dnsRR_TXT)
- if !ok {
- t.Errorf("test %d: answer[0] = %T; want *dnsRR_TXT", i, rr)
- continue
- }
-
- if rrTXT.Txt != expectedTxts[i] {
- t.Errorf("test %d: Txt = %s; want %s", i, rrTXT.Txt, expectedTxts[i])
- }
- }
-}
-
-func TestDNSParseTXTCorruptDataLengthReply(t *testing.T) {
- replies := []string{dnsTXTCorruptDataLengthReply1, dnsTXTCorruptDataLengthReply2}
-
- for i := range replies {
- data, err := hex.DecodeString(replies[i])
- if err != nil {
- t.Fatal(err)
- }
-
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- if ok {
- t.Errorf("test %d: expected to fail on unpacking corrupt packet", i)
- }
- }
-}
-
-func TestDNSParseTXTCorruptTXTLengthReply(t *testing.T) {
- replies := []string{dnsTXTCorruptTXTLengthReply1, dnsTXTCorruptTXTLengthReply2}
-
- for i := range replies {
- data, err := hex.DecodeString(replies[i])
- if err != nil {
- t.Fatal(err)
- }
-
- msg := new(dnsMsg)
- ok := msg.Unpack(data)
- // Unpacking should succeed, but we should just get the header.
- if !ok {
- t.Errorf("test %d: unpacking packet failed", i)
- continue
- }
-
- if len(msg.answer) != 1 {
- t.Errorf("test %d: len(rr.answer) = %d; want 1", i, len(msg.answer))
- continue
- }
-
- rr := msg.answer[0]
- if _, justHeader := rr.(*dnsRR_Header); !justHeader {
- t.Errorf("test %d: rr = %T; expected *dnsRR_Header", i, rr)
- }
- }
-}
-
-func TestIsResponseTo(t *testing.T) {
- // Sample DNS query.
- query := dnsMsg{
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- },
- }
-
- resp := query
- resp.response = true
- if !resp.IsResponseTo(&query) {
- t.Error("got false, want true")
- }
-
- badResponses := []dnsMsg{
- // Different ID.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 43,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- },
- },
-
- // Different query name.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.google.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- },
- },
-
- // Different query type.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeAAAA,
- Qclass: dnsClassINET,
- },
- },
- },
-
- // Different query class.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassCSNET,
- },
- },
- },
-
- // No questions.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- },
-
- // Extra questions.
- {
- dnsMsgHdr: dnsMsgHdr{
- id: 42,
- response: true,
- },
- question: []dnsQuestion{
- {
- Name: "www.example.com.",
- Qtype: dnsTypeA,
- Qclass: dnsClassINET,
- },
- {
- Name: "www.golang.org.",
- Qtype: dnsTypeAAAA,
- Qclass: dnsClassINET,
- },
- },
- },
- }
-
- for i := range badResponses {
- if badResponses[i].IsResponseTo(&query) {
- t.Errorf("%v: got true, want false", i)
- }
- }
-}
-
-// Valid DNS SRV reply
-const dnsSRVReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
- "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
- "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
- "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
- "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
- "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
- "72016c06676f6f676c6503636f6d00c00c002100010000012c00210014000014950c78" +
- "6d70702d73657276657231016c06676f6f676c6503636f6d00"
-
-// Corrupt DNS SRV reply, with its final RR having a bogus length
-// (perhaps it was truncated, or it's malicious) The mutation is the
-// capital "FF" below, instead of the proper "21".
-const dnsSRVCorruptReply = "0901818000010005000000000c5f786d70702d736572766572045f74637006676f6f67" +
- "6c6503636f6d0000210001c00c002100010000012c00210014000014950c786d70702d" +
- "73657276657234016c06676f6f676c6503636f6d00c00c002100010000012c00210014" +
- "000014950c786d70702d73657276657232016c06676f6f676c6503636f6d00c00c0021" +
- "00010000012c00210014000014950c786d70702d73657276657233016c06676f6f676c" +
- "6503636f6d00c00c002100010000012c00200005000014950b786d70702d7365727665" +
- "72016c06676f6f676c6503636f6d00c00c002100010000012c00FF0014000014950c78" +
- "6d70702d73657276657231016c06676f6f676c6503636f6d00"
-
-// TXT reply with one <character-string>
-const dnsTXTReply1 = "b3458180000100010004000505676d61696c03636f6d0000100001c00c001000010000012c00" +
- "201f763d737066312072656469726563743d5f7370662e676f6f676c652e636f6dc00" +
- "c0002000100025d4c000d036e733406676f6f676c65c012c00c0002000100025d4c00" +
- "06036e7331c057c00c0002000100025d4c0006036e7333c057c00c0002000100025d4" +
- "c0006036e7332c057c06c00010001000248b50004d8ef200ac09000010001000248b5" +
- "0004d8ef220ac07e00010001000248b50004d8ef240ac05300010001000248b50004d" +
- "8ef260a0000291000000000000000"
-
-// TXT reply with more than one <character-string>.
-// See https://tools.ietf.org/html/rfc1035#section-3.3.14
-const dnsTXTReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// DataLength field should be sum of all TXT fields. In this case it's less.
-const dnsTXTCorruptDataLengthReply1 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000967f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// Same as above but DataLength is more than sum of TXT fields.
-const dnsTXTCorruptDataLengthReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1001227f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520692e70343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// TXT Length field is less than actual length.
-const dnsTXTCorruptTXTLengthReply1 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520691470343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
-
-// TXT Length field is more than actual length.
-const dnsTXTCorruptTXTLengthReply2 = "a0a381800001000100020002045f7370660866616365626f6f6b03636f6d0000100001c00c0010000" +
- "100000e1000af7f763d73706631206970343a36392e36332e3137392e3235206970343a36392e" +
- "36332e3137382e3132382f3235206970343a36392e36332e3138342e302f3235206970343a363" +
- "62e3232302e3134342e3132382f3235206970343a36362e3232302e3135352e302f3234206970" +
- "343a36392e3137312e3233322e302f323520693370343a36362e3232302e3135372e302f32352" +
- "06970343a36392e3137312e3234342e302f3234206d78202d616c6cc0110002000100025d1500" +
- "070161026e73c011c0110002000100025d1500040162c0ecc0ea0001000100025d15000445abe" +
- "f0cc0fd0001000100025d15000445abff0c"
diff --git a/libgo/go/net/dnsname_test.go b/libgo/go/net/dnsname_test.go
index e0f786d..806d875 100644
--- a/libgo/go/net/dnsname_test.go
+++ b/libgo/go/net/dnsname_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/error_nacl.go b/libgo/go/net/error_nacl.go
new file mode 100644
index 0000000..caad133
--- /dev/null
+++ b/libgo/go/net/error_nacl.go
@@ -0,0 +1,9 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+func isConnError(err error) bool {
+ return false
+}
diff --git a/libgo/go/net/error_plan9.go b/libgo/go/net/error_plan9.go
new file mode 100644
index 0000000..caad133
--- /dev/null
+++ b/libgo/go/net/error_plan9.go
@@ -0,0 +1,9 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+func isConnError(err error) bool {
+ return false
+}
diff --git a/libgo/go/net/error_posix.go b/libgo/go/net/error_posix.go
index d0ffaae..70efa4c 100644
--- a/libgo/go/net/error_posix.go
+++ b/libgo/go/net/error_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
package net
diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go
index 9791e6f..e09670e 100644
--- a/libgo/go/net/error_test.go
+++ b/libgo/go/net/error_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/error_unix.go b/libgo/go/net/error_unix.go
new file mode 100644
index 0000000..b5a5829
--- /dev/null
+++ b/libgo/go/net/error_unix.go
@@ -0,0 +1,16 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd js linux netbsd openbsd solaris
+
+package net
+
+import "syscall"
+
+func isConnError(err error) bool {
+ if se, ok := err.(syscall.Errno); ok {
+ return se == syscall.ECONNRESET || se == syscall.ECONNABORTED
+ }
+ return false
+}
diff --git a/libgo/go/net/error_windows.go b/libgo/go/net/error_windows.go
new file mode 100644
index 0000000..570b97b
--- /dev/null
+++ b/libgo/go/net/error_windows.go
@@ -0,0 +1,14 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import "syscall"
+
+func isConnError(err error) bool {
+ if se, ok := err.(syscall.Errno); ok {
+ return se == syscall.WSAECONNRESET || se == syscall.WSAECONNABORTED
+ }
+ return false
+}
diff --git a/libgo/go/net/external_test.go b/libgo/go/net/external_test.go
index 38788ef..f3c69c4 100644
--- a/libgo/go/net/external_test.go
+++ b/libgo/go/net/external_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go
index 46ee5d9..da41bc0 100644
--- a/libgo/go/net/fd_plan9.go
+++ b/libgo/go/net/fd_plan9.go
@@ -9,6 +9,7 @@ import (
"io"
"os"
"syscall"
+ "time"
)
// Network file descriptor.
@@ -172,3 +173,15 @@ func setReadBuffer(fd *netFD, bytes int) error {
func setWriteBuffer(fd *netFD, bytes int) error {
return syscall.EPLAN9
}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ return fd.pfd.SetDeadline(t)
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ return fd.pfd.SetReadDeadline(t)
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ return fd.pfd.SetWriteDeadline(t)
+}
diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go
index 95d5e4f..e7ab9a4 100644
--- a/libgo/go/net/fd_unix.go
+++ b/libgo/go/net/fd_unix.go
@@ -11,8 +11,8 @@ import (
"internal/poll"
"os"
"runtime"
- "sync/atomic"
"syscall"
+ "time"
)
// Network file descriptor.
@@ -22,7 +22,7 @@ type netFD struct {
// immutable until Close
family int
sotype int
- isConnected bool
+ isConnected bool // handshake completed or use of association with peer
net string
laddr Addr
raddr Addr
@@ -121,7 +121,7 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (rsa sysc
// == nil). Because we've now poisoned the connection
// by making it unwritable, don't return a successful
// dial. This was issue 16523.
- ret = ctxErr
+ ret = mapErr(ctxErr)
fd.Close() // prevent a leak
}
}()
@@ -256,74 +256,26 @@ func (fd *netFD) accept() (netfd *netFD, err error) {
return netfd, nil
}
-// Use a helper function to call fcntl. This is defined in C in
-// libgo/runtime.
-//extern __go_fcntl_uintptr
-func fcntl(uintptr, uintptr, uintptr) (uintptr, uintptr)
-
-// tryDupCloexec indicates whether F_DUPFD_CLOEXEC should be used.
-// If the kernel doesn't support it, this is set to 0.
-var tryDupCloexec = int32(1)
-
-func dupCloseOnExec(fd int) (newfd int, err error) {
- if atomic.LoadInt32(&tryDupCloexec) == 1 && syscall.F_DUPFD_CLOEXEC != 0 {
- syscall.Entersyscall()
- r0, errno := fcntl(uintptr(fd), syscall.F_DUPFD_CLOEXEC, 0)
- syscall.Exitsyscall()
- e1 := syscall.Errno(errno)
- if runtime.GOOS == "darwin" && e1 == syscall.EBADF {
- // On OS X 10.6 and below (but we only support
- // >= 10.6), F_DUPFD_CLOEXEC is unsupported
- // and fcntl there falls back (undocumented)
- // to doing an ioctl instead, returning EBADF
- // in this case because fd is not of the
- // expected device fd type. Treat it as
- // EINVAL instead, so we fall back to the
- // normal dup path.
- // TODO: only do this on 10.6 if we can detect 10.6
- // cheaply.
- e1 = syscall.EINVAL
- }
- switch e1 {
- case 0:
- return int(r0), nil
- case syscall.EINVAL:
- // Old kernel. Fall back to the portable way
- // from now on.
- atomic.StoreInt32(&tryDupCloexec, 0)
- default:
- return -1, os.NewSyscallError("fcntl", e1)
+func (fd *netFD) dup() (f *os.File, err error) {
+ ns, call, err := fd.pfd.Dup()
+ if err != nil {
+ if call != "" {
+ err = os.NewSyscallError(call, err)
}
+ return nil, err
}
- return dupCloseOnExecOld(fd)
-}
-// dupCloseOnExecUnixOld is the traditional way to dup an fd and
-// set its O_CLOEXEC bit, using two system calls.
-func dupCloseOnExecOld(fd int) (newfd int, err error) {
- syscall.ForkLock.RLock()
- defer syscall.ForkLock.RUnlock()
- newfd, err = syscall.Dup(fd)
- if err != nil {
- return -1, os.NewSyscallError("dup", err)
- }
- syscall.CloseOnExec(newfd)
- return
+ return os.NewFile(uintptr(ns), fd.name()), nil
}
-func (fd *netFD) dup() (f *os.File, err error) {
- ns, err := dupCloseOnExec(fd.pfd.Sysfd)
- if err != nil {
- return nil, err
- }
+func (fd *netFD) SetDeadline(t time.Time) error {
+ return fd.pfd.SetDeadline(t)
+}
- // We want blocking mode for the new fd, hence the double negative.
- // This also puts the old fd into blocking mode, meaning that
- // I/O will block the thread instead of letting us use the epoll server.
- // Everything will still work, just with more threads.
- if err = fd.pfd.SetBlocking(); err != nil {
- return nil, os.NewSyscallError("setnonblock", err)
- }
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ return fd.pfd.SetReadDeadline(t)
+}
- return os.NewFile(uintptr(ns), fd.name()), nil
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ return fd.pfd.SetWriteDeadline(t)
}
diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go
index e5f8da1..3cc4c7a6 100644
--- a/libgo/go/net/fd_windows.go
+++ b/libgo/go/net/fd_windows.go
@@ -10,6 +10,7 @@ import (
"os"
"runtime"
"syscall"
+ "time"
"unsafe"
)
@@ -31,7 +32,7 @@ type netFD struct {
// immutable until Close
family int
sotype int
- isConnected bool
+ isConnected bool // handshake completed or use of association with peer
net string
laddr Addr
raddr Addr
@@ -241,3 +242,15 @@ func (fd *netFD) dup() (*os.File, error) {
// TODO: Implement this
return nil, syscall.EWINDOWS
}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ return fd.pfd.SetDeadline(t)
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ return fd.pfd.SetReadDeadline(t)
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ return fd.pfd.SetWriteDeadline(t)
+}
diff --git a/libgo/go/net/file.go b/libgo/go/net/file.go
index 07099851..81a44e1 100644
--- a/libgo/go/net/file.go
+++ b/libgo/go/net/file.go
@@ -6,7 +6,7 @@ package net
import "os"
-// BUG(mikio): On NaCl and Windows, the FileConn, FileListener and
+// BUG(mikio): On JS, NaCl and Windows, the FileConn, FileListener and
// FilePacketConn functions are not implemented.
type fileAddr string
diff --git a/libgo/go/net/file_stub.go b/libgo/go/net/file_stub.go
index 0f7460c..2256608 100644
--- a/libgo/go/net/file_stub.go
+++ b/libgo/go/net/file_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl
+// +build nacl js,wasm
package net
diff --git a/libgo/go/net/file_test.go b/libgo/go/net/file_test.go
index abf8b3a..cd71774 100644
--- a/libgo/go/net/file_test.go
+++ b/libgo/go/net/file_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -291,3 +293,57 @@ func TestFilePacketConn(t *testing.T) {
}
}
}
+
+// Issue 24483.
+func TestFileCloseRace(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9", "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if !testableNetwork("tcp") {
+ t.Skip("tcp not supported")
+ }
+
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer c.Close()
+ var b [1]byte
+ c.Read(b[:])
+ }
+
+ ls, err := newLocalServer("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ const tries = 100
+ for i := 0; i < tries; i++ {
+ c1, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ tc := c1.(*TCPConn)
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ f, err := tc.File()
+ if err == nil {
+ f.Close()
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ c1.Close()
+ }()
+ wg.Wait()
+ }
+}
diff --git a/libgo/go/net/file_unix.go b/libgo/go/net/file_unix.go
index 3655a89..452a079 100644
--- a/libgo/go/net/file_unix.go
+++ b/libgo/go/net/file_unix.go
@@ -13,8 +13,11 @@ import (
)
func dupSocket(f *os.File) (int, error) {
- s, err := dupCloseOnExec(int(f.Fd()))
+ s, call, err := poll.DupCloseOnExec(int(f.Fd()))
if err != nil {
+ if call != "" {
+ err = os.NewSyscallError(call, err)
+ }
return -1, err
}
if err := syscall.SetNonblock(s, true); err != nil {
diff --git a/libgo/go/net/hook_unix.go b/libgo/go/net/hook_unix.go
index 7d58d0f..a156831 100644
--- a/libgo/go/net/hook_unix.go
+++ b/libgo/go/net/hook_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris
package net
diff --git a/libgo/go/net/hosts.go b/libgo/go/net/hosts.go
index 9c101c6..ebc0353 100644
--- a/libgo/go/net/hosts.go
+++ b/libgo/go/net/hosts.go
@@ -16,7 +16,7 @@ func parseLiteralIP(addr string) string {
var zone string
ip = parseIPv4(addr)
if ip == nil {
- ip, zone = parseIPv6(addr, true)
+ ip, zone = parseIPv6Zone(addr)
}
if ip == nil {
return ""
diff --git a/libgo/go/net/http/cgi/child.go b/libgo/go/net/http/cgi/child.go
index ec10108..da12ac3 100644
--- a/libgo/go/net/http/cgi/child.go
+++ b/libgo/go/net/http/cgi/child.go
@@ -102,7 +102,7 @@ func RequestFromMap(params map[string]string) (*http.Request, error) {
}
// There's apparently a de-facto standard for this.
- // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
+ // https://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636
if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" {
r.TLS = &tls.ConnectionState{HandshakeComplete: true}
}
diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go
index 1336300..25882de 100644
--- a/libgo/go/net/http/cgi/host_test.go
+++ b/libgo/go/net/http/cgi/host_test.go
@@ -503,6 +503,7 @@ func TestDirWindows(t *testing.T) {
}
func TestEnvOverride(t *testing.T) {
+ check(t)
cgifile, _ := filepath.Abs("testdata/test.cgi")
var perl string
@@ -525,7 +526,7 @@ func TestEnvOverride(t *testing.T) {
"PATH=/wibble"},
}
expectedMap := map[string]string{
- "cwd": cwd,
+ "cwd": cwd,
"env-SCRIPT_FILENAME": cgifile,
"env-REQUEST_URI": "/foo/bar",
"env-PATH": "/wibble",
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go
index 6f6024e..8f69a29 100644
--- a/libgo/go/net/http/client.go
+++ b/libgo/go/net/http/client.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// HTTP client. See RFC 2616.
+// HTTP client. See RFC 7230 through 7235.
//
// This is the high-level Client interface.
// The low-level implementation is in transport.go.
@@ -95,14 +95,12 @@ type Client struct {
// A Timeout of zero means no timeout.
//
// The Client cancels requests to the underlying Transport
- // using the Request.Cancel mechanism. Requests passed
- // to Client.Do may still set Request.Cancel; both will
- // cancel the request.
+ // as if the Request's Context ended.
//
// For compatibility, the Client will also use the deprecated
// CancelRequest method on Transport if found. New
- // RoundTripper implementations should use Request.Cancel
- // instead of implementing CancelRequest.
+ // RoundTripper implementations should use the Request's Context
+ // for cancelation instead of implementing CancelRequest.
Timeout time.Duration
}
@@ -129,8 +127,8 @@ type RoundTripper interface {
// RoundTrip should not modify the request, except for
// consuming and closing the Request's Body. RoundTrip may
// read fields of the request in a separate goroutine. Callers
- // should not mutate the request until the Response's Body has
- // been closed.
+ // should not mutate or reuse the request until the Response's
+ // Body has been closed.
//
// RoundTrip must always close the body, including on errors,
// but depending on the implementation may do so in a separate
@@ -335,7 +333,7 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi
return stopTimer, timedOut.isSet
}
-// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
+// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64
// encoded string in the credentials."
@@ -357,7 +355,9 @@ func basicAuth(username, password string) string {
//
// An error is returned if there were too many redirects or if there
// was an HTTP protocol error. A non-2xx response doesn't cause an
-// error.
+// error. Any returned error will be of type *url.Error. The url.Error
+// value's Timeout method will report true if request timed out or was
+// canceled.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
@@ -382,7 +382,9 @@ func Get(url string) (resp *Response, err error) {
//
// An error is returned if the Client's CheckRedirect function fails
// or if there was an HTTP protocol error. A non-2xx response doesn't
-// cause an error.
+// cause an error. Any returned error will be of type *url.Error. The
+// url.Error value's Timeout method will report true if request timed
+// out or was canceled.
//
// When err is nil, resp always contains a non-nil resp.Body.
// Caller should close resp.Body when done reading from it.
@@ -457,6 +459,15 @@ func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirect
return redirectMethod, shouldRedirect, includeBody
}
+// urlErrorOp returns the (*url.Error).Op value to use for the
+// provided (*Request).Method value.
+func urlErrorOp(method string) string {
+ if method == "" {
+ return "Get"
+ }
+ return method[:1] + strings.ToLower(method[1:])
+}
+
// Do sends an HTTP request and returns an HTTP response, following
// policy (such as redirects, cookies, auth) as configured on the
// client.
@@ -490,10 +501,26 @@ func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirect
// provided that the Request.GetBody function is defined.
// The NewRequest function automatically sets GetBody for common
// standard library body types.
+//
+// Any returned error will be of type *url.Error. The url.Error
+// value's Timeout method will report true if request timed out or was
+// canceled.
func (c *Client) Do(req *Request) (*Response, error) {
+ return c.do(req)
+}
+
+var testHookClientDoResult func(retres *Response, reterr error)
+
+func (c *Client) do(req *Request) (retres *Response, reterr error) {
+ if testHookClientDoResult != nil {
+ defer func() { testHookClientDoResult(retres, reterr) }()
+ }
if req.URL == nil {
req.closeBody()
- return nil, errors.New("http: nil Request.URL")
+ return nil, &url.Error{
+ Op: urlErrorOp(req.Method),
+ Err: errors.New("http: nil Request.URL"),
+ }
}
var (
@@ -512,15 +539,14 @@ func (c *Client) Do(req *Request) (*Response, error) {
if !reqBodyClosed {
req.closeBody()
}
- method := valueOrDefault(reqs[0].Method, "GET")
var urlStr string
if resp != nil && resp.Request != nil {
- urlStr = resp.Request.URL.String()
+ urlStr = stripPassword(resp.Request.URL)
} else {
- urlStr = req.URL.String()
+ urlStr = stripPassword(req.URL)
}
return &url.Error{
- Op: method[:1] + strings.ToLower(method[1:]),
+ Op: urlErrorOp(reqs[0].Method),
URL: urlStr,
Err: err,
}
@@ -617,6 +643,7 @@ func (c *Client) Do(req *Request) (*Response, error) {
reqBodyClosed = true
if !deadline.IsZero() && didTimeout() {
err = &httpError{
+ // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancelation/
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true,
}
@@ -718,7 +745,7 @@ func defaultCheckRedirect(req *Request, via []*Request) error {
//
// See the Client.Do method documentation for details on how redirects
// are handled.
-func Post(url string, contentType string, body io.Reader) (resp *Response, err error) {
+func Post(url, contentType string, body io.Reader) (resp *Response, err error) {
return DefaultClient.Post(url, contentType, body)
}
@@ -733,7 +760,7 @@ func Post(url string, contentType string, body io.Reader) (resp *Response, err e
//
// See the Client.Do method documentation for details on how redirects
// are handled.
-func (c *Client) Post(url string, contentType string, body io.Reader) (resp *Response, err error) {
+func (c *Client) Post(url, contentType string, body io.Reader) (resp *Response, err error) {
req, err := NewRequest("POST", url, body)
if err != nil {
return nil, err
@@ -827,6 +854,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
}
if b.reqDidTimeout() {
err = &httpError{
+ // TODO: early in cycle: s/Client.Timeout exceeded/timeout or context cancelation/
err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true,
}
@@ -880,3 +908,12 @@ func isDomainOrSubdomain(sub, parent string) bool {
}
return sub[len(sub)-len(parent)-1] == '.'
}
+
+func stripPassword(u *url.URL) string {
+ pass, passSet := u.User.Password()
+ if passSet {
+ return strings.Replace(u.String(), pass+"@", "***@", 1)
+ }
+
+ return u.String()
+}
diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go
index eea3b16..bfc793e 100644
--- a/libgo/go/net/http/client_test.go
+++ b/libgo/go/net/http/client_test.go
@@ -1162,6 +1162,40 @@ func TestBasicAuthHeadersPreserved(t *testing.T) {
}
+func TestStripPasswordFromError(t *testing.T) {
+ client := &Client{Transport: &recordingTransport{}}
+ testCases := []struct {
+ desc string
+ in string
+ out string
+ }{
+ {
+ desc: "Strip password from error message",
+ in: "http://user:password@dummy.faketld/",
+ out: "Get http://user:***@dummy.faketld/: dummy impl",
+ },
+ {
+ desc: "Don't Strip password from domain name",
+ in: "http://user:password@password.faketld/",
+ out: "Get http://user:***@password.faketld/: dummy impl",
+ },
+ {
+ desc: "Don't Strip password from path",
+ in: "http://user:password@dummy.faketld/password",
+ out: "Get http://user:***@dummy.faketld/password: dummy impl",
+ },
+ }
+ for _, tC := range testCases {
+ t.Run(tC.desc, func(t *testing.T) {
+ _, err := client.Get(tC.in)
+ if err.Error() != tC.out {
+ t.Errorf("Unexpected output for %q: expected %q, actual %q",
+ tC.in, tC.out, err.Error())
+ }
+ })
+ }
+}
+
func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go
index 8f2e574..6513b2d 100644
--- a/libgo/go/net/http/clientserver_test.go
+++ b/libgo/go/net/http/clientserver_test.go
@@ -1236,7 +1236,6 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) {
h := w.Header()
h.Set("Content-Encoding", "gzip")
h.Set("Content-Length", "23")
- h.Set("Connection", "keep-alive")
io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00")
},
EarlyCheckResponse: func(proto string, res *Response) {
diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go
index 38b1b36..b1a6cef 100644
--- a/libgo/go/net/http/cookie.go
+++ b/libgo/go/net/http/cookie.go
@@ -5,7 +5,6 @@
package http
import (
- "bytes"
"log"
"net"
"strconv"
@@ -16,7 +15,7 @@ import (
// A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an
// HTTP response or the Cookie header of an HTTP request.
//
-// See http://tools.ietf.org/html/rfc6265 for details.
+// See https://tools.ietf.org/html/rfc6265 for details.
type Cookie struct {
Name string
Value string
@@ -32,10 +31,25 @@ type Cookie struct {
MaxAge int
Secure bool
HttpOnly bool
+ SameSite SameSite
Raw string
Unparsed []string // Raw text of unparsed attribute-value pairs
}
+// SameSite allows a server define a cookie attribute making it impossible to
+// the browser send this cookie along with cross-site requests. The main goal
+// is mitigate the risk of cross-origin information leakage, and provides some
+// protection against cross-site request forgery attacks.
+//
+// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
+type SameSite int
+
+const (
+ SameSiteDefaultMode SameSite = iota + 1
+ SameSiteLaxMode
+ SameSiteStrictMode
+)
+
// readSetCookies parses all "Set-Cookie" values from
// the header h and returns the successfully parsed Cookies.
func readSetCookies(h Header) []*Cookie {
@@ -84,6 +98,17 @@ func readSetCookies(h Header) []*Cookie {
continue
}
switch lowerAttr {
+ case "samesite":
+ lowerVal := strings.ToLower(val)
+ switch lowerVal {
+ case "lax":
+ c.SameSite = SameSiteLaxMode
+ case "strict":
+ c.SameSite = SameSiteStrictMode
+ default:
+ c.SameSite = SameSiteDefaultMode
+ }
+ continue
case "secure":
c.Secure = true
continue
@@ -143,7 +168,7 @@ func (c *Cookie) String() string {
if c == nil || !isCookieNameValid(c.Name) {
return ""
}
- var b bytes.Buffer
+ var b strings.Builder
b.WriteString(sanitizeCookieName(c.Name))
b.WriteRune('=')
b.WriteString(sanitizeCookieValue(c.Value))
@@ -168,17 +193,14 @@ func (c *Cookie) String() string {
log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain)
}
}
+ var buf [len(TimeFormat)]byte
if validCookieExpires(c.Expires) {
b.WriteString("; Expires=")
- b2 := b.Bytes()
- b.Reset()
- b.Write(c.Expires.UTC().AppendFormat(b2, TimeFormat))
+ b.Write(c.Expires.UTC().AppendFormat(buf[:0], TimeFormat))
}
if c.MaxAge > 0 {
b.WriteString("; Max-Age=")
- b2 := b.Bytes()
- b.Reset()
- b.Write(strconv.AppendInt(b2, int64(c.MaxAge), 10))
+ b.Write(strconv.AppendInt(buf[:0], int64(c.MaxAge), 10))
} else if c.MaxAge < 0 {
b.WriteString("; Max-Age=0")
}
@@ -188,6 +210,14 @@ func (c *Cookie) String() string {
if c.Secure {
b.WriteString("; Secure")
}
+ switch c.SameSite {
+ case SameSiteDefaultMode:
+ b.WriteString("; SameSite")
+ case SameSiteLaxMode:
+ b.WriteString("; SameSite=Lax")
+ case SameSiteStrictMode:
+ b.WriteString("; SameSite=Strict")
+ }
return b.String()
}
@@ -311,7 +341,7 @@ func sanitizeCookieName(n string) string {
return cookieNameSanitizer.Replace(n)
}
-// http://tools.ietf.org/html/rfc6265#section-4.1.1
+// https://tools.ietf.org/html/rfc6265#section-4.1.1
// cookie-value = *cookie-octet / ( DQUOTE *cookie-octet DQUOTE )
// cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E
// ; US-ASCII characters excluding CTLs,
diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go
index 9d199a3..022adaa 100644
--- a/libgo/go/net/http/cookie_test.go
+++ b/libgo/go/net/http/cookie_test.go
@@ -65,6 +65,18 @@ var writeSetCookiesTests = []struct {
&Cookie{Name: "cookie-11", Value: "invalid-expiry", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)},
"cookie-11=invalid-expiry",
},
+ {
+ &Cookie{Name: "cookie-12", Value: "samesite-default", SameSite: SameSiteDefaultMode},
+ "cookie-12=samesite-default; SameSite",
+ },
+ {
+ &Cookie{Name: "cookie-13", Value: "samesite-lax", SameSite: SameSiteLaxMode},
+ "cookie-13=samesite-lax; SameSite=Lax",
+ },
+ {
+ &Cookie{Name: "cookie-14", Value: "samesite-strict", SameSite: SameSiteStrictMode},
+ "cookie-14=samesite-strict; SameSite=Strict",
+ },
// The "special" cookies have values containing commas or spaces which
// are disallowed by RFC 6265 but are common in the wild.
{
@@ -241,6 +253,33 @@ var readSetCookiesTests = []struct {
Raw: "ASP.NET_SessionId=foo; path=/; HttpOnly",
}},
},
+ {
+ Header{"Set-Cookie": {"samesitedefault=foo; SameSite"}},
+ []*Cookie{{
+ Name: "samesitedefault",
+ Value: "foo",
+ SameSite: SameSiteDefaultMode,
+ Raw: "samesitedefault=foo; SameSite",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesitelax=foo; SameSite=Lax"}},
+ []*Cookie{{
+ Name: "samesitelax",
+ Value: "foo",
+ SameSite: SameSiteLaxMode,
+ Raw: "samesitelax=foo; SameSite=Lax",
+ }},
+ },
+ {
+ Header{"Set-Cookie": {"samesitestrict=foo; SameSite=Strict"}},
+ []*Cookie{{
+ Name: "samesitestrict",
+ Value: "foo",
+ SameSite: SameSiteStrictMode,
+ Raw: "samesitestrict=foo; SameSite=Strict",
+ }},
+ },
// Make sure we can properly read back the Set-Cookie headers we create
// for values containing spaces or commas:
{
diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go
index ef8c35b..9f19917 100644
--- a/libgo/go/net/http/cookiejar/jar.go
+++ b/libgo/go/net/http/cookiejar/jar.go
@@ -93,6 +93,7 @@ type entry struct {
Value string
Domain string
Path string
+ SameSite string
Secure bool
HttpOnly bool
Persistent bool
@@ -418,6 +419,15 @@ func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e e
e.Secure = c.Secure
e.HttpOnly = c.HttpOnly
+ switch c.SameSite {
+ case http.SameSiteDefaultMode:
+ e.SameSite = "SameSite"
+ case http.SameSiteStrictMode:
+ e.SameSite = "SameSite=Strict"
+ case http.SameSiteLaxMode:
+ e.SameSite = "SameSite=Lax"
+ }
+
return e, false, nil
}
diff --git a/libgo/go/net/http/example_test.go b/libgo/go/net/http/example_test.go
index 9de0893..53fb0bb 100644
--- a/libgo/go/net/http/example_test.go
+++ b/libgo/go/net/http/example_test.go
@@ -137,3 +137,25 @@ func ExampleServer_Shutdown() {
<-idleConnsClosed
}
+
+func ExampleListenAndServeTLS() {
+ http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, "Hello, TLS!\n")
+ })
+
+ // One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem.
+ log.Printf("About to listen on 8443. Go to https://127.0.0.1:8443/")
+ err := http.ListenAndServeTLS(":8443", "cert.pem", "key.pem", nil)
+ log.Fatal(err)
+}
+
+func ExampleListenAndServe() {
+ // Hello world, the web server
+
+ helloHandler := func(w http.ResponseWriter, req *http.Request) {
+ io.WriteString(w, "Hello, world!\n")
+ }
+
+ http.HandleFunc("/hello", helloHandler)
+ log.Fatal(http.ListenAndServe(":8080", nil))
+}
diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go
index 1825acd..bc0db53 100644
--- a/libgo/go/net/http/export_test.go
+++ b/libgo/go/net/http/export_test.go
@@ -9,7 +9,9 @@ package http
import (
"context"
+ "fmt"
"net"
+ "net/url"
"sort"
"sync"
"testing"
@@ -33,11 +35,28 @@ var (
Export_writeStatusLine = writeStatusLine
)
+const MaxWriteWaitBeforeConnReuse = maxWriteWaitBeforeConnReuse
+
func init() {
// We only want to pay for this cost during testing.
// When not under test, these values are always nil
// and never assigned to.
testHookMu = new(sync.Mutex)
+
+ testHookClientDoResult = func(res *Response, err error) {
+ if err != nil {
+ if _, ok := err.(*url.Error); !ok {
+ panic(fmt.Sprintf("unexpected Client.Do error of type %T; want *url.Error", err))
+ }
+ } else {
+ if res == nil {
+ panic("Client.Do returned nil, nil")
+ }
+ if res.Body == nil {
+ panic("Client.Do returned nil res.Body and no error")
+ }
+ }
+ }
}
var (
@@ -76,9 +95,7 @@ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
}
func ResetCachedEnvironment() {
- httpProxyEnv.reset()
- httpsProxyEnv.reset()
- noProxyEnv.reset()
+ resetProxyConfig()
}
func (t *Transport) NumPendingRequestsForTesting() int {
@@ -119,7 +136,7 @@ func (t *Transport) IdleConnStrsForTesting() []string {
func (t *Transport) IdleConnStrsForTesting_h2() []string {
var ret []string
- noDialPool := t.h2transport.ConnPool.(http2noDialClientConnPool)
+ noDialPool := t.h2transport.(*http2Transport).ConnPool.(http2noDialClientConnPool)
pool := noDialPool.http2clientConnPool
pool.mu.Lock()
@@ -135,9 +152,11 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string {
return ret
}
-func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
+func (t *Transport) IdleConnCountForTesting(scheme, addr string) int {
t.idleMu.Lock()
defer t.idleMu.Unlock()
+ key := connectMethodKey{"", scheme, addr}
+ cacheKey := key.String()
for k, conns := range t.idleConn {
if k.String() == cacheKey {
return len(conns)
@@ -162,13 +181,19 @@ func (t *Transport) RequestIdleConnChForTesting() {
t.getIdleConnCh(connectMethod{nil, "http", "example.com"})
}
-func (t *Transport) PutIdleTestConn() bool {
+func (t *Transport) PutIdleTestConn(scheme, addr string) bool {
c, _ := net.Pipe()
+ key := connectMethodKey{"", scheme, addr}
+ select {
+ case <-t.incHostConnCount(key):
+ default:
+ return false
+ }
return t.tryPutIdleConn(&persistConn{
t: t,
conn: c, // dummy
closech: make(chan struct{}), // so it can be closed
- cacheKey: connectMethodKey{"", "http", "example.com"},
+ cacheKey: key,
}) == nil
}
@@ -200,8 +225,8 @@ func (s *Server) ExportAllConnsIdle() bool {
s.mu.Lock()
defer s.mu.Unlock()
for c := range s.activeConn {
- st, ok := c.curState.Load().(ConnState)
- if !ok || st != StateIdle {
+ st, unixSec := c.getState()
+ if unixSec == 0 || st != StateIdle {
return false
}
}
diff --git a/libgo/go/net/http/fcgi/fcgi.go b/libgo/go/net/http/fcgi/fcgi.go
index 8f3449a..fb822f8 100644
--- a/libgo/go/net/http/fcgi/fcgi.go
+++ b/libgo/go/net/http/fcgi/fcgi.go
@@ -4,9 +4,8 @@
// Package fcgi implements the FastCGI protocol.
//
-// The protocol is not an official standard and the original
-// documentation is no longer online. See the Internet Archive's
-// mirror at: https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22
+// See https://fast-cgi.github.io/ for an unofficial mirror of the
+// original documentation.
//
// Currently only the responder role is supported.
package fcgi
diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go
index ecad14a..db44d6b 100644
--- a/libgo/go/net/http/fs.go
+++ b/libgo/go/net/http/fs.go
@@ -235,17 +235,17 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
switch {
case len(ranges) == 1:
- // RFC 2616, Section 14.16:
- // "When an HTTP message includes the content of a single
- // range (for example, a response to a request for a
- // single range, or to a request for a set of ranges
- // that overlap without any holes), this content is
- // transmitted with a Content-Range header, and a
- // Content-Length header showing the number of bytes
- // actually transferred.
+ // RFC 7233, Section 4.1:
+ // "If a single part is being transferred, the server
+ // generating the 206 response MUST generate a
+ // Content-Range header field, describing what range
+ // of the selected representation is enclosed, and a
+ // payload consisting of the range.
// ...
- // A response to a request for a single range MUST NOT
- // be sent using the multipart/byteranges media type."
+ // A server MUST NOT generate a multipart response to
+ // a request for a single range, since a client that
+ // does not request multiple parts might not support
+ // multipart responses."
ra := ranges[0]
if _, err := content.Seek(ra.start, io.SeekStart); err != nil {
Error(w, err.Error(), StatusRequestedRangeNotSatisfiable)
@@ -650,15 +650,23 @@ func localRedirect(w ResponseWriter, r *Request, newPath string) {
// file or directory.
//
// If the provided file or directory name is a relative path, it is
-// interpreted relative to the current directory and may ascend to parent
-// directories. If the provided name is constructed from user input, it
-// should be sanitized before calling ServeFile. As a precaution, ServeFile
-// will reject requests where r.URL.Path contains a ".." path element.
+// interpreted relative to the current directory and may ascend to
+// parent directories. If the provided name is constructed from user
+// input, it should be sanitized before calling ServeFile.
//
-// As a special case, ServeFile redirects any request where r.URL.Path
+// As a precaution, ServeFile will reject requests where r.URL.Path
+// contains a ".." path element; this protects against callers who
+// might unsafely use filepath.Join on r.URL.Path without sanitizing
+// it and then use that filepath.Join result as the name argument.
+//
+// As another special case, ServeFile redirects any request where r.URL.Path
// ends in "/index.html" to the same path, without the final
// "index.html". To avoid such redirects either modify the path or
// use ServeContent.
+//
+// Outside of those two special cases, ServeFile does not use
+// r.URL.Path for selecting the file or directory to serve; only the
+// file or directory provided in the name argument is used.
func ServeFile(w ResponseWriter, r *Request, name string) {
if containsDotDot(r.URL.Path) {
// Too many programs use r.URL.Path to construct the argument to
@@ -731,7 +739,7 @@ func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHead
}
}
-// parseRange parses a Range header string as per RFC 2616.
+// parseRange parses a Range header string as per RFC 7233.
// errNoOverlap is returned if none of the ranges overlap.
func parseRange(s string, size int64) ([]httpRange, error) {
if s == "" {
diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go
index de772f9..1d6380d 100644
--- a/libgo/go/net/http/fs_test.go
+++ b/libgo/go/net/http/fs_test.go
@@ -993,7 +993,7 @@ func TestServeContent(t *testing.T) {
for _, method := range []string{"GET", "HEAD"} {
//restore content in case it is consumed by previous method
if content, ok := content.(*strings.Reader); ok {
- content.Seek(io.SeekStart, 0)
+ content.Seek(0, io.SeekStart)
}
servec <- serveParam{
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go
index 3671875..4268f2f 100644
--- a/libgo/go/net/http/h2_bundle.go
+++ b/libgo/go/net/http/h2_bundle.go
@@ -44,13 +44,13 @@ import (
"sync"
"time"
+ "golang_org/x/net/http/httpguts"
"golang_org/x/net/http2/hpack"
"golang_org/x/net/idna"
- "golang_org/x/net/lex/httplex"
)
// A list of the possible cipher suite ids. Taken from
-// http://www.iana.org/assignments/tls-parameters/tls-parameters.txt
+// https://www.iana.org/assignments/tls-parameters/tls-parameters.txt
const (
http2cipher_TLS_NULL_WITH_NULL_NULL uint16 = 0x0000
@@ -725,9 +725,31 @@ const (
http2noDialOnMiss = false
)
+// shouldTraceGetConn reports whether getClientConn should call any
+// ClientTrace.GetConn hook associated with the http.Request.
+//
+// This complexity is needed to avoid double calls of the GetConn hook
+// during the back-and-forth between net/http and x/net/http2 (when the
+// net/http.Transport is upgraded to also speak http2), as well as support
+// the case where x/net/http2 is being used directly.
+func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool {
+ // If our Transport wasn't made via ConfigureTransport, always
+ // trace the GetConn hook if provided, because that means the
+ // http2 package is being used directly and it's the one
+ // dialing, as opposed to net/http.
+ if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok {
+ return true
+ }
+ // Otherwise, only use the GetConn hook if this connection has
+ // been used previously for other requests. For fresh
+ // connections, the net/http package does the dialing.
+ return !st.freshConn
+}
+
func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
if http2isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection.
+ http2traceGetConn(req, addr)
const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse)
if err != nil {
@@ -737,7 +759,10 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis
}
p.mu.Lock()
for _, cc := range p.conns[addr] {
- if cc.CanTakeNewRequest() {
+ if st := cc.idleState(); st.canTakeNewRequest {
+ if p.shouldTraceGetConn(st) {
+ http2traceGetConn(req, addr)
+ }
p.mu.Unlock()
return cc, nil
}
@@ -746,6 +771,7 @@ func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMis
p.mu.Unlock()
return nil, http2ErrNoCachedConn
}
+ http2traceGetConn(req, addr)
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
@@ -973,7 +999,7 @@ func http2configureTransport(t1 *Transport) (*http2Transport, error) {
// registerHTTPSProtocol calls Transport.RegisterProtocol but
// converting panics into errors.
-func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) {
+func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) {
defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("%v", e)
@@ -985,10 +1011,12 @@ func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) {
// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
// if there's already has a cached connection to the host.
-type http2noDialH2RoundTripper struct{ t *http2Transport }
+// (The field is exported so it can be accessed via reflect from net/http; tested
+// by TestNoDialH2RoundTripperType)
+type http2noDialH2RoundTripper struct{ *http2Transport }
func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
- res, err := rt.t.RoundTrip(req)
+ res, err := rt.http2Transport.RoundTrip(req)
if http2isNoCachedConnError(err) {
return nil, ErrSkipAltProtocol
}
@@ -1290,12 +1318,12 @@ func (f *http2flow) take(n int32) {
// add adds n bytes (positive or negative) to the flow control window.
// It returns false if the sum would exceed 2^31-1.
func (f *http2flow) add(n int32) bool {
- remain := (1<<31 - 1) - f.n
- if n > remain {
- return false
+ sum := f.n + n
+ if (sum > n) == (f.n > 0) {
+ f.n = sum
+ return true
}
- f.n += n
- return true
+ return false
}
const http2frameHeaderLen = 9
@@ -2016,32 +2044,67 @@ func (f *http2SettingsFrame) IsAck() bool {
return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck)
}
-func (f *http2SettingsFrame) Value(s http2SettingID) (v uint32, ok bool) {
+func (f *http2SettingsFrame) Value(id http2SettingID) (v uint32, ok bool) {
f.checkValid()
- buf := f.p
- for len(buf) > 0 {
- settingID := http2SettingID(binary.BigEndian.Uint16(buf[:2]))
- if settingID == s {
- return binary.BigEndian.Uint32(buf[2:6]), true
+ for i := 0; i < f.NumSettings(); i++ {
+ if s := f.Setting(i); s.ID == id {
+ return s.Val, true
}
- buf = buf[6:]
}
return 0, false
}
+// Setting returns the setting from the frame at the given 0-based index.
+// The index must be >= 0 and less than f.NumSettings().
+func (f *http2SettingsFrame) Setting(i int) http2Setting {
+ buf := f.p
+ return http2Setting{
+ ID: http2SettingID(binary.BigEndian.Uint16(buf[i*6 : i*6+2])),
+ Val: binary.BigEndian.Uint32(buf[i*6+2 : i*6+6]),
+ }
+}
+
+func (f *http2SettingsFrame) NumSettings() int { return len(f.p) / 6 }
+
+// HasDuplicates reports whether f contains any duplicate setting IDs.
+func (f *http2SettingsFrame) HasDuplicates() bool {
+ num := f.NumSettings()
+ if num == 0 {
+ return false
+ }
+ // If it's small enough (the common case), just do the n^2
+ // thing and avoid a map allocation.
+ if num < 10 {
+ for i := 0; i < num; i++ {
+ idi := f.Setting(i).ID
+ for j := i + 1; j < num; j++ {
+ idj := f.Setting(j).ID
+ if idi == idj {
+ return true
+ }
+ }
+ }
+ return false
+ }
+ seen := map[http2SettingID]bool{}
+ for i := 0; i < num; i++ {
+ id := f.Setting(i).ID
+ if seen[id] {
+ return true
+ }
+ seen[id] = true
+ }
+ return false
+}
+
// ForeachSetting runs fn for each setting.
// It stops and returns the first error.
func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error {
f.checkValid()
- buf := f.p
- for len(buf) > 0 {
- if err := fn(http2Setting{
- http2SettingID(binary.BigEndian.Uint16(buf[:2])),
- binary.BigEndian.Uint32(buf[2:6]),
- }); err != nil {
+ for i := 0; i < f.NumSettings(); i++ {
+ if err := fn(f.Setting(i)); err != nil {
return err
}
- buf = buf[6:]
}
return nil
}
@@ -2745,7 +2808,7 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr
if http2VerboseLogs && fr.logReads {
fr.debugReadLoggerf("http2: decoded hpack field %+v", hf)
}
- if !httplex.ValidHeaderFieldValue(hf.Value) {
+ if !httpguts.ValidHeaderFieldValue(hf.Value) {
invalid = http2headerFieldValueError(hf.Value)
}
isPseudo := strings.HasPrefix(hf.Name, ":")
@@ -2861,6 +2924,23 @@ func http2summarizeFrame(f http2Frame) string {
return buf.String()
}
+func http2traceHasWroteHeaderField(trace *http2clientTrace) bool {
+ return trace != nil && trace.WroteHeaderField != nil
+}
+
+func http2traceWroteHeaderField(trace *http2clientTrace, k, v string) {
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(k, []string{v})
+ }
+}
+
+func http2traceGot1xxResponseFunc(trace *http2clientTrace) func(int, textproto.MIMEHeader) error {
+ if trace != nil {
+ return trace.Got1xxResponse
+ }
+ return nil
+}
+
func http2transportExpectContinueTimeout(t1 *Transport) time.Duration {
return t1.ExpectContinueTimeout
}
@@ -2869,6 +2949,8 @@ type http2contextContext interface {
context.Context
}
+var http2errCanceled = context.Canceled
+
func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) {
ctx, cancel = context.WithCancel(context.Background())
ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
@@ -2899,6 +2981,14 @@ func (t *http2Transport) idleConnTimeout() time.Duration {
func http2setResponseUncompressed(res *Response) { res.Uncompressed = true }
+func http2traceGetConn(req *Request, hostPort string) {
+ trace := httptrace.ContextClientTrace(req.Context())
+ if trace == nil || trace.GetConn == nil {
+ return
+ }
+ trace.GetConn(hostPort)
+}
+
func http2traceGotConn(req *Request, cc *http2ClientConn) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GotConn == nil {
@@ -2956,6 +3046,11 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error {
return cc.ping(ctx)
}
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
+func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
+ return cc.shutdown(ctx)
+}
+
func http2cloneTLSConfig(c *tls.Config) *tls.Config {
c2 := c.Clone()
c2.GetClientCertificate = c.GetClientCertificate // golang.org/issue/19264
@@ -3371,7 +3466,7 @@ var (
)
// validWireHeaderFieldName reports whether v is a valid header field
-// name (key). See httplex.ValidHeaderName for the base rules.
+// name (key). See httpguts.ValidHeaderName for the base rules.
//
// Further, http2 says:
// "Just as in HTTP/1.x, header field names are strings of ASCII
@@ -3383,7 +3478,7 @@ func http2validWireHeaderFieldName(v string) bool {
return false
}
for _, r := range v {
- if !httplex.IsTokenRune(r) {
+ if !httpguts.IsTokenRune(r) {
return false
}
if 'A' <= r && r <= 'Z' {
@@ -3505,7 +3600,7 @@ func http2mustUint31(v int32) uint32 {
}
// bodyAllowedForStatus reports whether a given response status code
-// permits a body. See RFC 2616, section 4.4.
+// permits a body. See RFC 7230, section 3.3.
func http2bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
@@ -4096,7 +4191,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
// addresses during development.
//
// TODO: optionally enforce? Or enforce at the time we receive
- // a new request, and verify the the ServerName matches the :authority?
+ // a new request, and verify the ServerName matches the :authority?
// But that precludes proxy situations, perhaps.
//
// So for now, do nothing here again.
@@ -5181,6 +5276,12 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
}
return nil
}
+ if f.NumSettings() > 100 || f.HasDuplicates() {
+ // This isn't actually in the spec, but hang up on
+ // suspiciously large settings frames or those with
+ // duplicate entries.
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
if err := f.ForeachSetting(sc.processSetting); err != nil {
return err
}
@@ -5269,6 +5370,12 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
// type PROTOCOL_ERROR."
return http2ConnectionError(http2ErrCodeProtocol)
}
+ // RFC 7540, sec 6.1: If a DATA frame is received whose stream is not in
+ // "open" or "half-closed (local)" state, the recipient MUST respond with a
+ // stream error (Section 5.4.2) of type STREAM_CLOSED.
+ if state == http2stateClosed {
+ return http2streamError(id, http2ErrCodeStreamClosed)
+ }
if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued {
// This includes sending a RST_STREAM if the stream is
// in stateHalfClosedLocal (which currently means that
@@ -5302,7 +5409,10 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error {
// Sender sending more than they'd declared?
if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
- return http2streamError(id, http2ErrCodeStreamClosed)
+ // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
+ // value of a content-length header field does not equal the sum of the
+ // DATA frame payload lengths that form the body.
+ return http2streamError(id, http2ErrCodeProtocol)
}
if f.Length > 0 {
// Check whether the client has flow control quota.
@@ -5412,6 +5522,13 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
// processing this frame.
return nil
}
+ // RFC 7540, sec 5.1: If an endpoint receives additional frames, other than
+ // WINDOW_UPDATE, PRIORITY, or RST_STREAM, for a stream that is in
+ // this state, it MUST respond with a stream error (Section 5.4.2) of
+ // type STREAM_CLOSED.
+ if st.state == http2stateHalfClosedRemote {
+ return http2streamError(id, http2ErrCodeStreamClosed)
+ }
return st.processTrailerHeaders(f)
}
@@ -5512,7 +5629,7 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error {
if st.trailer != nil {
for _, hf := range f.RegularFields() {
key := sc.canonicalHeader(hf.Name)
- if !http2ValidTrailerHeader(key) {
+ if !httpguts.ValidTrailerHeader(key) {
// TODO: send more details to the peer somehow. But http2 has
// no way to send debug data at a stream level. Discuss with
// HTTP folk.
@@ -5979,8 +6096,8 @@ func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailer
// written in the trailers at the end of the response.
func (rws *http2responseWriterState) declareTrailer(k string) {
k = CanonicalHeaderKey(k)
- if !http2ValidTrailerHeader(k) {
- // Forbidden by RFC 2616 14.40.
+ if !httpguts.ValidTrailerHeader(k) {
+ // Forbidden by RFC 7230, section 4.1.2.
rws.conn.logf("ignoring invalid trailer %q", k)
return
}
@@ -6030,6 +6147,19 @@ func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
http2foreachHeaderElement(v, rws.declareTrailer)
}
+ // "Connection" headers aren't allowed in HTTP/2 (RFC 7540, 8.1.2.2),
+ // but respect "Connection" == "close" to mean sending a GOAWAY and tearing
+ // down the TCP connection when idle, like we do for HTTP/1.
+ // TODO: remove more Connection-specific header fields here, in addition
+ // to "Connection".
+ if _, ok := rws.snapHeader["Connection"]; ok {
+ v := rws.snapHeader.Get("Connection")
+ delete(rws.snapHeader, "Connection")
+ if v == "close" {
+ rws.conn.startGracefulShutdown()
+ }
+ }
+
endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
streamID: rws.stream.id,
@@ -6101,7 +6231,7 @@ const http2TrailerPrefix = "Trailer:"
// after the header has already been flushed. Because the Go
// ResponseWriter interface has no way to set Trailers (only the
// Header), and because we didn't want to expand the ResponseWriter
-// interface, and because nobody used trailers, and because RFC 2616
+// interface, and because nobody used trailers, and because RFC 7230
// says you SHOULD (but not must) predeclare any trailers in the
// header, the official ResponseWriter rules said trailers in Go must
// be predeclared, and then we reuse the same ResponseWriter.Header()
@@ -6485,7 +6615,7 @@ func (sc *http2serverConn) startPush(msg *http2startPushRequest) {
}
// foreachHeaderElement splits v according to the "#rule" construction
-// in RFC 2616 section 2.1 and calls fn for each non-empty element.
+// in RFC 7230 section 7 and calls fn for each non-empty element.
func http2foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
@@ -6533,41 +6663,6 @@ func http2new400Handler(err error) HandlerFunc {
}
}
-// ValidTrailerHeader reports whether name is a valid header field name to appear
-// in trailers.
-// See: http://tools.ietf.org/html/rfc7230#section-4.1.2
-func http2ValidTrailerHeader(name string) bool {
- name = CanonicalHeaderKey(name)
- if strings.HasPrefix(name, "If-") || http2badTrailer[name] {
- return false
- }
- return true
-}
-
-var http2badTrailer = map[string]bool{
- "Authorization": true,
- "Cache-Control": true,
- "Connection": true,
- "Content-Encoding": true,
- "Content-Length": true,
- "Content-Range": true,
- "Content-Type": true,
- "Expect": true,
- "Host": true,
- "Keep-Alive": true,
- "Max-Forwards": true,
- "Pragma": true,
- "Proxy-Authenticate": true,
- "Proxy-Authorization": true,
- "Proxy-Connection": true,
- "Range": true,
- "Realm": true,
- "Te": true,
- "Trailer": true,
- "Transfer-Encoding": true,
- "Www-Authenticate": true,
-}
-
// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives
// disabled. See comments on h1ServerShutdownChan above for why
// the code is written this way.
@@ -6709,6 +6804,7 @@ type http2ClientConn struct {
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow http2flow // our conn-level flow control quota (cs.flow is per stream)
inflow http2flow // peer's conn-level flow control
+ closing bool
closed bool
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
@@ -6761,9 +6857,10 @@ type http2clientStream struct {
done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
// owned by clientConnReadLoop:
- firstByte bool // got the first response byte
- pastHeaders bool // got first MetaHeadersFrame (actual headers)
- pastTrailers bool // got optional second MetaHeadersFrame (trailers)
+ firstByte bool // got the first response byte
+ pastHeaders bool // got first MetaHeadersFrame (actual headers)
+ pastTrailers bool // got optional second MetaHeadersFrame (trailers)
+ num1xx uint8 // number of 1xx responses seen
trailer Header // accumulated trailers
resTrailer *Header // client's Response.Trailer
@@ -6787,6 +6884,17 @@ func http2awaitRequestCancel(req *Request, done <-chan struct{}) error {
}
}
+var http2got1xxFuncForTests func(int, textproto.MIMEHeader) error
+
+// get1xxTraceFunc returns the value of request's httptrace.ClientTrace.Got1xxResponse func,
+// if any. It returns nil if not set or if the Go version is too old.
+func (cs *http2clientStream) get1xxTraceFunc() func(int, textproto.MIMEHeader) error {
+ if fn := http2got1xxFuncForTests; fn != nil {
+ return fn
+ }
+ return http2traceGot1xxResponseFunc(cs.trace)
+}
+
// awaitRequestCancel waits for the user to cancel a request, its context to
// expire, or for the request to be done (any way it might be removed from the
// cc.streams map: peer reset, successful completion, TCP connection breakage,
@@ -6856,10 +6964,12 @@ func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) {
return
}
-// noCachedConnError is the concrete type of ErrNoCachedConn, needs to be detected
-// by net/http regardless of whether it's its bundled version (in h2_bundle.go with a rewritten type name)
-// or from a user's x/net/http2. As such, as it has a unique method name (IsHTTP2NoCachedConnError) that
-// net/http sniffs for via func isNoCachedConnError.
+// noCachedConnError is the concrete type of ErrNoCachedConn, which
+// needs to be detected by net/http regardless of whether it's its
+// bundled version (in h2_bundle.go with a rewritten type name) or
+// from a user's x/net/http2. As such, as it has a unique method name
+// (IsHTTP2NoCachedConnError) that net/http sniffs for via func
+// isNoCachedConnError.
type http2noCachedConnError struct{}
func (http2noCachedConnError) IsHTTP2NoCachedConnError() {}
@@ -6870,9 +6980,7 @@ func (http2noCachedConnError) Error() string { return "http2: no cached connecti
// or its equivalent renamed type in net/http2's h2_bundle.go. Both types
// may coexist in the same running program.
func http2isNoCachedConnError(err error) bool {
- _, ok := err.(interface {
- IsHTTP2NoCachedConnError()
- })
+ _, ok := err.(interface{ IsHTTP2NoCachedConnError() })
return ok
}
@@ -6974,27 +7082,36 @@ func http2shouldRetryRequest(req *Request, err error, afterBodyWrite bool) (*Req
if !http2canRetryError(err) {
return nil, err
}
- if !afterBodyWrite {
- return req, nil
- }
// If the Body is nil (or http.NoBody), it's safe to reuse
// this request and its Body.
if req.Body == nil || http2reqBodyIsNoBody(req.Body) {
return req, nil
}
- // Otherwise we depend on the Request having its GetBody
- // func defined.
+
+ // If the request body can be reset back to its original
+ // state via the optional req.GetBody, do that.
getBody := http2reqGetBody(req) // Go 1.8: getBody = req.GetBody
- if getBody == nil {
- return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err)
+ if getBody != nil {
+ // TODO: consider a req.Body.Close here? or audit that all caller paths do?
+ body, err := getBody()
+ if err != nil {
+ return nil, err
+ }
+ newReq := *req
+ newReq.Body = body
+ return &newReq, nil
}
- body, err := getBody()
- if err != nil {
- return nil, err
+
+ // The Request.Body can't reset back to the beginning, but we
+ // don't seem to have started to read from it yet, so reuse
+ // the request directly. The "afterBodyWrite" means the
+ // bodyWrite process has started, which becomes true before
+ // the first Read.
+ if !afterBodyWrite {
+ return req, nil
}
- newReq := *req
- newReq.Body = body
- return &newReq, nil
+
+ return nil, fmt.Errorf("http2: Transport: cannot retry err [%v] after Request.Body was written; define Request.GetBody to avoid this error", err)
}
func http2canRetryError(err error) bool {
@@ -7118,6 +7235,10 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client
// henc in response to SETTINGS frames?
cc.henc = hpack.NewEncoder(&cc.hbuf)
+ if t.AllowHTTP {
+ cc.nextStreamID = 3
+ }
+
if cs, ok := c.(http2connectionStater); ok {
state := cs.ConnectionState()
cc.tlsState = &state
@@ -7177,12 +7298,32 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool {
return cc.canTakeNewRequestLocked()
}
-func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+// clientConnIdleState describes the suitability of a client
+// connection to initiate a new RoundTrip request.
+type http2clientConnIdleState struct {
+ canTakeNewRequest bool
+ freshConn bool // whether it's unused by any previous request
+}
+
+func (cc *http2ClientConn) idleState() http2clientConnIdleState {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.idleStateLocked()
+}
+
+func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) {
if cc.singleUse && cc.nextStreamID > 1 {
- return false
+ return
}
- return cc.goAway == nil && !cc.closed &&
+ st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing &&
int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
+ st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
+ return
+}
+
+func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+ st := cc.idleStateLocked()
+ return st.canTakeNewRequest
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
@@ -7212,6 +7353,88 @@ func (cc *http2ClientConn) closeIfIdle() {
cc.tconn.Close()
}
+var http2shutdownEnterWaitStateHook = func() {}
+
+// Shutdown gracefully close the client connection, waiting for running streams to complete.
+// Public implementation is in go17.go and not_go17.go
+func (cc *http2ClientConn) shutdown(ctx http2contextContext) error {
+ if err := cc.sendGoAway(); err != nil {
+ return err
+ }
+ // Wait for all in-flight streams to complete or connection to close
+ done := make(chan error, 1)
+ cancelled := false // guarded by cc.mu
+ go func() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ for {
+ if len(cc.streams) == 0 || cc.closed {
+ cc.closed = true
+ done <- cc.tconn.Close()
+ break
+ }
+ if cancelled {
+ break
+ }
+ cc.cond.Wait()
+ }
+ }()
+ http2shutdownEnterWaitStateHook()
+ select {
+ case err := <-done:
+ return err
+ case <-ctx.Done():
+ cc.mu.Lock()
+ // Free the goroutine above
+ cancelled = true
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+ return ctx.Err()
+ }
+}
+
+func (cc *http2ClientConn) sendGoAway() error {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if cc.closing {
+ // GOAWAY sent already
+ return nil
+ }
+ // Send a graceful shutdown frame to server
+ maxStreamID := cc.nextStreamID
+ if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil {
+ return err
+ }
+ if err := cc.bw.Flush(); err != nil {
+ return err
+ }
+ // Prevent new requests
+ cc.closing = true
+ return nil
+}
+
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *http2ClientConn) Close() error {
+ cc.mu.Lock()
+ defer cc.cond.Broadcast()
+ defer cc.mu.Unlock()
+ err := errors.New("http2: client connection force closed via ClientConn.Close")
+ for id, cs := range cc.streams {
+ select {
+ case cs.resc <- http2resAndError{err: err}:
+ default:
+ }
+ cs.bufPipe.CloseWithError(err)
+ delete(cc.streams, id)
+ }
+ cc.closed = true
+ return cc.tconn.Close()
+}
+
const http2maxAllocFrameSize = 512 << 10
// frameBuffer returns a scratch buffer suitable for writing DATA frames.
@@ -7294,7 +7517,7 @@ func http2checkConnHeaders(req *Request) error {
if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
}
- if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") {
+ if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !strings.EqualFold(vv[0], "close") && !strings.EqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("http2: invalid Connection request header: %q", vv)
}
return nil
@@ -7517,6 +7740,9 @@ func (cc *http2ClientConn) awaitOpenSlotForRequest(req *Request) error {
for {
cc.lastActive = time.Now()
if cc.closed || !cc.canTakeNewRequestLocked() {
+ if waitingForConn != nil {
+ close(waitingForConn)
+ }
return http2errClientConnUnusable
}
if int64(len(cc.streams))+1 <= int64(cc.maxConcurrentStreams) {
@@ -7740,7 +7966,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
if host == "" {
host = req.URL.Host
}
- host, err := httplex.PunycodeHostPort(host)
+ host, err := httpguts.PunycodeHostPort(host)
if err != nil {
return nil, err
}
@@ -7765,11 +7991,11 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
- if !httplex.ValidHeaderFieldName(k) {
+ if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
- if !httplex.ValidHeaderFieldValue(v) {
+ if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
@@ -7850,9 +8076,16 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail
return nil, http2errRequestHeaderListSize
}
+ trace := http2requestTrace(req)
+ traceHeaders := http2traceHasWroteHeaderField(trace)
+
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
- cc.writeHeader(strings.ToLower(name), value)
+ name = strings.ToLower(name)
+ cc.writeHeader(name, value)
+ if traceHeaders {
+ http2traceWroteHeaderField(trace, name, value)
+ }
})
return cc.hbuf.Bytes(), nil
@@ -8174,8 +8407,7 @@ func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) erro
// is the detail.
//
// As a special case, handleResponse may return (nil, nil) to skip the
-// frame (currently only used for 100 expect continue). This special
-// case is going away after Issue 13851 is fixed.
+// frame (currently only used for 1xx responses).
func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*Response, error) {
if f.Truncated {
return nil, http2errResponseHeaderListSize
@@ -8190,15 +8422,6 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http
return nil, errors.New("malformed response from server: malformed non-numeric status pseudo header")
}
- if statusCode == 100 {
- http2traceGot100Continue(cs.trace)
- if cs.on100 != nil {
- cs.on100() // forces any write delay timer to fire
- }
- cs.pastHeaders = false // do it all again
- return nil, nil
- }
-
header := make(Header)
res := &Response{
Proto: "HTTP/2.0",
@@ -8223,6 +8446,27 @@ func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http
}
}
+ if statusCode >= 100 && statusCode <= 199 {
+ cs.num1xx++
+ const max1xxResponses = 5 // arbitrary bound on number of informational responses, same as net/http
+ if cs.num1xx > max1xxResponses {
+ return nil, errors.New("http2: too many 1xx informational responses")
+ }
+ if fn := cs.get1xxTraceFunc(); fn != nil {
+ if err := fn(statusCode, textproto.MIMEHeader(header)); err != nil {
+ return nil, err
+ }
+ }
+ if statusCode == 100 {
+ http2traceGot100Continue(cs.trace)
+ if cs.on100 != nil {
+ cs.on100() // forces any write delay timer to fire
+ }
+ }
+ cs.pastHeaders = false // do it all again
+ return nil, nil
+ }
+
streamEnded := f.StreamEnded()
isHead := cs.req.Method == "HEAD"
if !streamEnded || isHead {
@@ -8810,7 +9054,7 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade
}
s.delay = t.expectContinueTimeout()
if s.delay == 0 ||
- !httplex.HeaderValuesContainsToken(
+ !httpguts.HeaderValuesContainsToken(
cs.req.Header["Expect"],
"100-continue") {
return
@@ -8865,7 +9109,7 @@ func (s http2bodyWriterState) scheduleBodyWrite() {
// isConnectionCloseRequest reports whether req should use its own
// connection for a single request and then close the connection.
func http2isConnectionCloseRequest(req *Request) bool {
- return req.Close || httplex.HeaderValuesContainsToken(req.Header["Connection"], "close")
+ return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close")
}
// writeFramer is implemented by any type that is used to write frames.
@@ -9205,7 +9449,7 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) {
}
isTE := k == "transfer-encoding"
for _, v := range vv {
- if !httplex.ValidHeaderFieldValue(v) {
+ if !httpguts.ValidHeaderFieldValue(v) {
// TODO: return an error? golang.org/issue/14048
// For now just omit it.
continue
diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go
index 622ad28..461ae93 100644
--- a/libgo/go/net/http/header.go
+++ b/libgo/go/net/http/header.go
@@ -6,6 +6,7 @@ package http
import (
"io"
+ "net/http/httptrace"
"net/textproto"
"sort"
"strings"
@@ -56,7 +57,11 @@ func (h Header) Del(key string) {
// Write writes a header in wire format.
func (h Header) Write(w io.Writer) error {
- return h.WriteSubset(w, nil)
+ return h.write(w, nil)
+}
+
+func (h Header) write(w io.Writer, trace *httptrace.ClientTrace) error {
+ return h.writeSubset(w, nil, trace)
}
func (h Header) clone() Header {
@@ -145,11 +150,16 @@ func (h Header) sortedKeyValues(exclude map[string]bool) (kvs []keyValues, hs *h
// WriteSubset writes a header in wire format.
// If exclude is not nil, keys where exclude[key] == true are not written.
func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
+ return h.writeSubset(w, exclude, nil)
+}
+
+func (h Header) writeSubset(w io.Writer, exclude map[string]bool, trace *httptrace.ClientTrace) error {
ws, ok := w.(writeStringer)
if !ok {
ws = stringWriter{w}
}
kvs, sorter := h.sortedKeyValues(exclude)
+ var formattedVals []string
for _, kv := range kvs {
for _, v := range kv.values {
v = headerNewlineToSpace.Replace(v)
@@ -160,6 +170,13 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error {
return err
}
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ formattedVals = append(formattedVals, v)
+ }
+ }
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(kv.key, formattedVals)
+ formattedVals = nil
}
}
headerSorterPool.Put(sorter)
diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go
index b95ca89..ce0eceb 100644
--- a/libgo/go/net/http/http.go
+++ b/libgo/go/net/http/http.go
@@ -11,7 +11,7 @@ import (
"time"
"unicode/utf8"
- "golang_org/x/net/lex/httplex"
+ "golang_org/x/net/http/httpguts"
)
// maxInt64 is the effective "infinite" value for the Server and
@@ -47,7 +47,7 @@ func removeEmptyPort(host string) string {
}
func isNotToken(r rune) bool {
- return !httplex.IsTokenRune(r)
+ return !httpguts.IsTokenRune(r)
}
func isASCII(s string) bool {
diff --git a/libgo/go/net/http/httptest/httptest_test.go b/libgo/go/net/http/httptest/httptest_test.go
index 4f9ecbd..ef7d943 100644
--- a/libgo/go/net/http/httptest/httptest_test.go
+++ b/libgo/go/net/http/httptest/httptest_test.go
@@ -16,15 +16,17 @@ import (
)
func TestNewRequest(t *testing.T) {
- tests := [...]struct {
+ for _, tt := range [...]struct {
+ name string
+
method, uri string
body io.Reader
want *http.Request
wantBody string
}{
- // Empty method means GET:
- 0: {
+ {
+ name: "Empty method means GET",
method: "",
uri: "/",
body: nil,
@@ -42,8 +44,8 @@ func TestNewRequest(t *testing.T) {
wantBody: "",
},
- // GET with full URL:
- 1: {
+ {
+ name: "GET with full URL",
method: "GET",
uri: "http://foo.com/path/%2f/bar/",
body: nil,
@@ -66,8 +68,8 @@ func TestNewRequest(t *testing.T) {
wantBody: "",
},
- // GET with full https URL:
- 2: {
+ {
+ name: "GET with full https URL",
method: "GET",
uri: "https://foo.com/path/",
body: nil,
@@ -94,8 +96,8 @@ func TestNewRequest(t *testing.T) {
wantBody: "",
},
- // Post with known length
- 3: {
+ {
+ name: "Post with known length",
method: "POST",
uri: "/",
body: strings.NewReader("foo"),
@@ -114,8 +116,8 @@ func TestNewRequest(t *testing.T) {
wantBody: "foo",
},
- // Post with unknown length
- 4: {
+ {
+ name: "Post with unknown length",
method: "POST",
uri: "/",
body: struct{ io.Reader }{strings.NewReader("foo")},
@@ -134,8 +136,8 @@ func TestNewRequest(t *testing.T) {
wantBody: "foo",
},
- // OPTIONS *
- 5: {
+ {
+ name: "OPTIONS *",
method: "OPTIONS",
uri: "*",
want: &http.Request{
@@ -150,28 +152,29 @@ func TestNewRequest(t *testing.T) {
RequestURI: "*",
},
},
- }
- for i, tt := range tests {
- got := NewRequest(tt.method, tt.uri, tt.body)
- slurp, err := ioutil.ReadAll(got.Body)
- if err != nil {
- t.Errorf("%d. ReadAll: %v", i, err)
- }
- if string(slurp) != tt.wantBody {
- t.Errorf("%d. Body = %q; want %q", i, slurp, tt.wantBody)
- }
- got.Body = nil // before DeepEqual
- if !reflect.DeepEqual(got.URL, tt.want.URL) {
- t.Errorf("%d. Request.URL mismatch:\n got: %#v\nwant: %#v", i, got.URL, tt.want.URL)
- }
- if !reflect.DeepEqual(got.Header, tt.want.Header) {
- t.Errorf("%d. Request.Header mismatch:\n got: %#v\nwant: %#v", i, got.Header, tt.want.Header)
- }
- if !reflect.DeepEqual(got.TLS, tt.want.TLS) {
- t.Errorf("%d. Request.TLS mismatch:\n got: %#v\nwant: %#v", i, got.TLS, tt.want.TLS)
- }
- if !reflect.DeepEqual(got, tt.want) {
- t.Errorf("%d. Request mismatch:\n got: %#v\nwant: %#v", i, got, tt.want)
- }
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ got := NewRequest(tt.method, tt.uri, tt.body)
+ slurp, err := ioutil.ReadAll(got.Body)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ if string(slurp) != tt.wantBody {
+ t.Errorf("Body = %q; want %q", slurp, tt.wantBody)
+ }
+ got.Body = nil // before DeepEqual
+ if !reflect.DeepEqual(got.URL, tt.want.URL) {
+ t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL)
+ }
+ if !reflect.DeepEqual(got.Header, tt.want.Header) {
+ t.Errorf("Request.Header mismatch:\n got: %#v\nwant: %#v", got.Header, tt.want.Header)
+ }
+ if !reflect.DeepEqual(got.TLS, tt.want.TLS) {
+ t.Errorf("Request.TLS mismatch:\n got: %#v\nwant: %#v", got.TLS, tt.want.TLS)
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, tt.want)
+ }
+ })
}
}
diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go
index 741f076..67f90b8 100644
--- a/libgo/go/net/http/httptest/recorder.go
+++ b/libgo/go/net/http/httptest/recorder.go
@@ -11,6 +11,8 @@ import (
"net/http"
"strconv"
"strings"
+
+ "golang_org/x/net/http/httpguts"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
@@ -25,9 +27,11 @@ type ResponseRecorder struct {
Code int
// HeaderMap contains the headers explicitly set by the Handler.
+ // It is an internal detail.
//
- // To get the implicit headers set by the server (such as
- // automatic Content-Type), use the Result method.
+ // Deprecated: HeaderMap exists for historical compatibility
+ // and should not be used. To access the headers returned by a handler,
+ // use the Response.Header map as returned by the Result method.
HeaderMap http.Header
// Body is the buffer to which the Handler's Write calls are sent.
@@ -180,21 +184,19 @@ func (rw *ResponseRecorder) Result() *http.Response {
res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
+ } else {
+ res.Body = http.NoBody
}
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers))
for _, k := range trailers {
- // TODO: use http2.ValidTrailerHeader, but we can't
- // get at it easily because it's bundled into net/http
- // unexported. This is good enough for now:
- switch k {
- case "Transfer-Encoding", "Content-Length", "Trailer":
- // Ignore since forbidden by RFC 2616 14.40.
+ k = http.CanonicalHeaderKey(k)
+ if !httpguts.ValidTrailerHeader(k) {
+ // Ignore since forbidden by RFC 7230, section 4.1.2.
continue
}
- k = http.CanonicalHeaderKey(k)
vv, ok := rw.HeaderMap[k]
if !ok {
continue
diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go
index a6259eb..0986554 100644
--- a/libgo/go/net/http/httptest/recorder_test.go
+++ b/libgo/go/net/http/httptest/recorder_test.go
@@ -7,6 +7,7 @@ package httptest
import (
"fmt"
"io"
+ "io/ioutil"
"net/http"
"testing"
)
@@ -39,6 +40,19 @@ func TestRecorder(t *testing.T) {
return nil
}
}
+ hasResultContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ contentBytes, err := ioutil.ReadAll(rec.Result().Body)
+ if err != nil {
+ return err
+ }
+ contents := string(contentBytes)
+ if contents != want {
+ return fmt.Errorf("Result().Body = %s; want %s", contents, want)
+ }
+ return nil
+ }
+ }
hasContents := func(want string) checkFunc {
return func(rec *ResponseRecorder) error {
if rec.Body.String() != want {
@@ -111,7 +125,7 @@ func TestRecorder(t *testing.T) {
}
}
- tests := []struct {
+ for _, tt := range [...]struct {
name string
h func(w http.ResponseWriter, r *http.Request)
checks []checkFunc
@@ -273,16 +287,26 @@ func TestRecorder(t *testing.T) {
},
check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
},
- }
- r, _ := http.NewRequest("GET", "http://foo.com/", nil)
- for _, tt := range tests {
- h := http.HandlerFunc(tt.h)
- rec := NewRecorder()
- h.ServeHTTP(rec, r)
- for _, check := range tt.checks {
- if err := check(rec); err != nil {
- t.Errorf("%s: %v", tt.name, err)
+ {
+ "nil ResponseRecorder.Body", // Issue 26642
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(*ResponseRecorder).Body = nil
+ io.WriteString(w, "hi")
+ },
+ check(hasResultContents("")), // check we don't crash reading the body
+
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Error(err)
+ }
}
- }
+ })
}
}
diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go
index e543672..f6bcf3a 100644
--- a/libgo/go/net/http/httptest/server.go
+++ b/libgo/go/net/http/httptest/server.go
@@ -7,7 +7,6 @@
package httptest
import (
- "bytes"
"crypto/tls"
"crypto/x509"
"flag"
@@ -17,6 +16,7 @@ import (
"net/http"
"net/http/internal"
"os"
+ "strings"
"sync"
"time"
)
@@ -224,7 +224,7 @@ func (s *Server) Close() {
func (s *Server) logCloseHangDebugInfo() {
s.mu.Lock()
defer s.mu.Unlock()
- var buf bytes.Buffer
+ var buf strings.Builder
buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
for c, st := range s.conns {
fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go
index ea7b38c..3a62741 100644
--- a/libgo/go/net/http/httptrace/trace.go
+++ b/libgo/go/net/http/httptrace/trace.go
@@ -11,6 +11,7 @@ import (
"crypto/tls"
"internal/nettrace"
"net"
+ "net/textproto"
"reflect"
"time"
)
@@ -107,6 +108,12 @@ type ClientTrace struct {
// Continue" response.
Got100Continue func()
+ // Got1xxResponse is called for each 1xx informational response header
+ // returned before the final non-1xx response. Got1xxResponse is called
+ // for "100 Continue" responses, even if Got100Continue is also defined.
+ // If it returns an error, the client request is aborted with that error value.
+ Got1xxResponse func(code int, header textproto.MIMEHeader) error
+
// DNSStart is called when a DNS lookup begins.
DNSStart func(DNSStartInfo)
@@ -135,8 +142,13 @@ type ClientTrace struct {
// failure.
TLSHandshakeDone func(tls.ConnectionState, error)
+ // WroteHeaderField is called after the Transport has written
+ // each request header. At the time of this call the values
+ // might be buffered and not yet written to the network.
+ WroteHeaderField func(key string, value []string)
+
// WroteHeaders is called after the Transport has written
- // the request headers.
+ // all request headers.
WroteHeaders func()
// Wait100Continue is called if the Request specified
diff --git a/libgo/go/net/http/httputil/httputil.go b/libgo/go/net/http/httputil/httputil.go
index 2e523e9..09ea74d 100644
--- a/libgo/go/net/http/httputil/httputil.go
+++ b/libgo/go/net/http/httputil/httputil.go
@@ -23,7 +23,9 @@ func NewChunkedReader(r io.Reader) io.Reader {
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
-// sends the final 0-length chunk that marks the end of the stream.
+// sends the final 0-length chunk that marks the end of the stream but does
+// not send the final CRLF that appears after trailers; trailers and the last
+// CRLF must be written separately.
//
// NewChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index b96bb21..1dddaa9 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -55,10 +55,23 @@ type ReverseProxy struct {
// copying HTTP response bodies.
BufferPool BufferPool
- // ModifyResponse is an optional function that
- // modifies the Response from the backend.
- // If it returns an error, the proxy returns a StatusBadGateway error.
+ // ModifyResponse is an optional function that modifies the
+ // Response from the backend. It is called if the backend
+ // returns a response at all, with any HTTP status code.
+ // If the backend is unreachable, the optional ErrorHandler is
+ // called without any call to ModifyResponse.
+ //
+ // If ModifyResponse returns an error, ErrorHandler is called
+ // with its error value. If ErrorHandler is nil, its default
+ // implementation is used.
ModifyResponse func(*http.Response) error
+
+ // ErrorHandler is an optional function that handles errors
+ // reaching the backend or errors from ModifyResponse.
+ //
+ // If nil, the default is to log the provided error and return
+ // a 502 Status Bad Gateway response.
+ ErrorHandler func(http.ResponseWriter, *http.Request, error)
}
// A BufferPool is an interface for getting and returning temporary
@@ -125,7 +138,10 @@ func cloneHeader(h http.Header) http.Header {
}
// Hop-by-hop headers. These are removed when sent to the backend.
-// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
+// As of RFC 7230, hop-by-hop headers are required to appear in the
+// Connection header field. These are the headers defined by the
+// obsoleted RFC 2616 (section 13.5.1) and are used for backward
+// compatibility.
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
@@ -133,11 +149,23 @@ var hopHeaders = []string{
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
- "Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
+ "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
+func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
+ p.logf("http: proxy error: %v", err)
+ rw.WriteHeader(http.StatusBadGateway)
+}
+
+func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
+ if p.ErrorHandler != nil {
+ return p.ErrorHandler
+ }
+ return p.defaultErrorHandler
+}
+
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
if transport == nil {
@@ -175,9 +203,20 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us.
for _, h := range hopHeaders {
- if outreq.Header.Get(h) != "" {
- outreq.Header.Del(h)
+ hv := outreq.Header.Get(h)
+ if hv == "" {
+ continue
+ }
+ if h == "Te" && hv == "trailers" {
+ // Issue 21096: tell backend applications that
+ // care about trailer support that we support
+ // trailers. (We do, but we don't go out of
+ // our way to advertise that unless the
+ // incoming client request thought it was
+ // worth mentioning)
+ continue
}
+ outreq.Header.Del(h)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
@@ -192,8 +231,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
res, err := transport.RoundTrip(outreq)
if err != nil {
- p.logf("http: proxy error: %v", err)
- rw.WriteHeader(http.StatusBadGateway)
+ p.getErrorHandler()(rw, outreq, err)
return
}
@@ -205,9 +243,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if p.ModifyResponse != nil {
if err := p.ModifyResponse(res); err != nil {
- p.logf("http: proxy error: %v", err)
- rw.WriteHeader(http.StatusBadGateway)
res.Body.Close()
+ p.getErrorHandler()(rw, outreq, err)
return
}
}
@@ -234,7 +271,18 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
fl.Flush()
}
}
- p.copyResponse(rw, res.Body)
+ err = p.copyResponse(rw, res.Body)
+ if err != nil {
+ defer res.Body.Close()
+ // Since we're streaming the response, if we run into an error all we can do
+ // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
+ // on read error while copying body.
+ if !shouldPanicOnCopyError(req) {
+ p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
+ return
+ }
+ panic(http.ErrAbortHandler)
+ }
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) == announcedTrailers {
@@ -250,8 +298,30 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
+var inOurTests bool // whether we're in our own tests
+
+// shouldPanicOnCopyError reports whether the reverse proxy should
+// panic with http.ErrAbortHandler. This is the right thing to do by
+// default, but Go 1.10 and earlier did not, so existing unit tests
+// weren't expecting panics. Only panic in our own tests, or when
+// running under the HTTP server.
+func shouldPanicOnCopyError(req *http.Request) bool {
+ if inOurTests {
+ // Our tests know to handle this panic.
+ return true
+ }
+ if req.Context().Value(http.ServerContextKey) != nil {
+ // We seem to be running under an HTTP server, so
+ // it'll recover the panic.
+ return true
+ }
+ // Otherwise act like Go 1.10 and earlier to not break
+ // existing tests.
+ return false
+}
+
// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
-// See RFC 2616, section 14.10.
+// See RFC 7230, section 6.1
func removeConnectionHeaders(h http.Header) {
if c := h.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
@@ -262,7 +332,7 @@ func removeConnectionHeaders(h http.Header) {
}
}
-func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
+func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) error {
if p.FlushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{
@@ -279,13 +349,14 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
var buf []byte
if p.BufferPool != nil {
buf = p.BufferPool.Get()
+ defer p.BufferPool.Put(buf)
}
- p.copyBuffer(dst, src, buf)
- if p.BufferPool != nil {
- p.BufferPool.Put(buf)
- }
+ _, err := p.copyBuffer(dst, src, buf)
+ return err
}
+// copyBuffer returns any write errors or non-EOF read errors, and the amount
+// of bytes written.
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
if len(buf) == 0 {
buf = make([]byte, 32*1024)
@@ -309,6 +380,9 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int
}
}
if rerr != nil {
+ if rerr == io.EOF {
+ rerr = nil
+ }
return written, rerr
}
}
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go
index 2232042..2f75b4e 100644
--- a/libgo/go/net/http/httputil/reverseproxy_test.go
+++ b/libgo/go/net/http/httputil/reverseproxy_test.go
@@ -17,6 +17,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
+ "os"
"reflect"
"strconv"
"strings"
@@ -28,6 +29,7 @@ import (
const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
func init() {
+ inOurTests = true
hopHeaders = append(hopHeaders, fakeHopHeader)
}
@@ -49,6 +51,9 @@ func TestReverseProxy(t *testing.T) {
if c := r.Header.Get("Connection"); c != "" {
t.Errorf("handler got Connection header value %q", c)
}
+ if c := r.Header.Get("Te"); c != "trailers" {
+ t.Errorf("handler got Te header value %q; want 'trailers'", c)
+ }
if c := r.Header.Get("Upgrade"); c != "" {
t.Errorf("handler got Upgrade header value %q", c)
}
@@ -85,6 +90,7 @@ func TestReverseProxy(t *testing.T) {
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
getReq.Host = "some-name"
getReq.Header.Set("Connection", "close")
+ getReq.Header.Set("Te", "trailers")
getReq.Header.Set("Proxy-Connection", "should be deleted")
getReq.Header.Set("Upgrade", "foo")
getReq.Close = true
@@ -631,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) {
}
}
+type failingRoundTripper struct{}
+
+func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return nil, errors.New("some error")
+}
+
+type staticResponseRoundTripper struct{ res *http.Response }
+
+func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
+ return rt.res, nil
+}
+
+func TestReverseProxyErrorHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ wantCode int
+ errorHandler func(http.ResponseWriter, *http.Request, error)
+ transport http.RoundTripper // defaults to failingRoundTripper
+ modifyResponse func(*http.Response) error
+ }{
+ {
+ name: "default",
+ wantCode: http.StatusBadGateway,
+ },
+ {
+ name: "errorhandler",
+ wantCode: http.StatusTeapot,
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ },
+ {
+ name: "modifyresponse_noerr",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return nil
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: 346,
+ },
+ {
+ name: "modifyresponse_err",
+ transport: staticResponseRoundTripper{
+ &http.Response{StatusCode: 345, Body: http.NoBody},
+ },
+ modifyResponse: func(res *http.Response) error {
+ res.StatusCode++
+ return errors.New("some error to trigger errorHandler")
+ },
+ errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
+ wantCode: http.StatusTeapot,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ target := &url.URL{
+ Scheme: "http",
+ Host: "dummy.tld",
+ Path: "/",
+ }
+ rproxy := NewSingleHostReverseProxy(target)
+ rproxy.Transport = tt.transport
+ rproxy.ModifyResponse = tt.modifyResponse
+ if rproxy.Transport == nil {
+ rproxy.Transport = failingRoundTripper{}
+ }
+ rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
+ if tt.errorHandler != nil {
+ rproxy.ErrorHandler = tt.errorHandler
+ }
+ frontendProxy := httptest.NewServer(rproxy)
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL + "/test")
+ if err != nil {
+ t.Fatalf("failed to reach proxy: %v", err)
+ }
+ if g, e := resp.StatusCode, tt.wantCode; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ resp.Body.Close()
+ })
+ }
+}
+
// Issue 16659: log errors from short read
func TestReverseProxy_CopyBuffer(t *testing.T) {
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -649,18 +742,22 @@ func TestReverseProxy_CopyBuffer(t *testing.T) {
var proxyLog bytes.Buffer
rproxy := NewSingleHostReverseProxy(rpURL)
rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile)
- frontendProxy := httptest.NewServer(rproxy)
+ donec := make(chan bool, 1)
+ frontendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() { donec <- true }()
+ rproxy.ServeHTTP(w, r)
+ }))
defer frontendProxy.Close()
- resp, err := http.Get(frontendProxy.URL)
- if err != nil {
- t.Fatalf("failed to reach proxy: %v", err)
- }
- defer resp.Body.Close()
-
- if _, err := ioutil.ReadAll(resp.Body); err == nil {
+ if _, err = frontendProxy.Client().Get(frontendProxy.URL); err == nil {
t.Fatalf("want non-nil error")
}
+ // The race detector complains about the proxyLog usage in logf in copyBuffer
+ // and our usage below with proxyLog.Bytes() so we're explicitly using a
+ // channel to ensure that the ReverseProxy's ServeHTTP is done before we
+ // continue after Get.
+ <-donec
+
expected := []string{
"EOF",
"read",
@@ -740,6 +837,8 @@ func TestServeHTTPDeepCopy(t *testing.T) {
// Issue 18327: verify we always do a deep copy of the Request.Header map
// before any mutations.
func TestClonesRequestHeaders(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
req.RemoteAddr = "1.2.3.4:56789"
rp := &ReverseProxy{
@@ -813,3 +912,37 @@ func (cc *checkCloser) Close() error {
func (cc *checkCloser) Read(b []byte) (int, error) {
return len(b), nil
}
+
+// Issue 23643: panic on body copy error
+func TestReverseProxy_PanicBodyError(t *testing.T) {
+ log.SetOutput(ioutil.Discard)
+ defer log.SetOutput(os.Stderr)
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ out := "this call was relayed by the reverse proxy"
+ // Coerce a wrong content length to induce io.ErrUnexpectedEOF
+ w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2))
+ fmt.Fprintln(w, out)
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rproxy := NewSingleHostReverseProxy(rpURL)
+
+ // Ensure that the handler panics when the body read encounters an
+ // io.ErrUnexpectedEOF
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Fatal("handler should have panicked")
+ }
+ if err != http.ErrAbortHandler {
+ t.Fatal("expected ErrAbortHandler, got", err)
+ }
+ }()
+ req, _ := http.NewRequest("GET", "http://foo.tld/", nil)
+ rproxy.ServeHTTP(httptest.NewRecorder(), req)
+}
diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go
index 63f321d..f06e572 100644
--- a/libgo/go/net/http/internal/chunked.go
+++ b/libgo/go/net/http/internal/chunked.go
@@ -171,7 +171,9 @@ func removeChunkExtension(p []byte) ([]byte, error) {
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
-// sends the final 0-length chunk that marks the end of the stream.
+// sends the final 0-length chunk that marks the end of the stream but does
+// not send the final CRLF that appears after trailers; trailers and the last
+// CRLF must be written separately.
//
// NewChunkedWriter is not needed by normal applications. The http
// package adds chunking automatically if handlers don't set a
diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go
index 21c8505..7936fb3 100644
--- a/libgo/go/net/http/main_test.go
+++ b/libgo/go/net/http/main_test.go
@@ -114,12 +114,12 @@ func afterTest(t testing.TB) {
}
var bad string
badSubstring := map[string]string{
- ").readLoop(": "a Transport",
- ").writeLoop(": "a Transport",
+ ").readLoop(": "a Transport",
+ ").writeLoop(": "a Transport",
"created by net/http/httptest.(*Server).Start": "an httptest.Server",
- "timeoutHandler": "a TimeoutHandler",
- "net.(*netFD).connect(": "a timing out dial",
- ").noteClientGone(": "a closenotifier sender",
+ "timeoutHandler": "a TimeoutHandler",
+ "net.(*netFD).connect(": "a timing out dial",
+ ").noteClientGone(": "a closenotifier sender",
}
var stacks string
for i := 0; i < 4; i++ {
diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go
index 77e0bcd..35b3285 100644
--- a/libgo/go/net/http/pprof/pprof.go
+++ b/libgo/go/net/http/pprof/pprof.go
@@ -26,7 +26,7 @@
//
// Or to look at a 30-second CPU profile:
//
-// go tool pprof http://localhost:6060/debug/pprof/profile
+// go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
//
// Or to look at the goroutine blocking profile, after calling
// runtime.SetBlockProfileRate in your program:
@@ -63,6 +63,7 @@ import (
"runtime"
"runtime/pprof"
"runtime/trace"
+ "sort"
"strconv"
"strings"
"time"
@@ -110,11 +111,12 @@ func serveError(w http.ResponseWriter, status int, txt string) {
}
// Profile responds with the pprof-formatted cpu profile.
+// Profiling lasts for duration specified in seconds GET parameter, or for 30 seconds if not specified.
// The package initialization registers it as /debug/pprof/profile.
func Profile(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
- sec, _ := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
- if sec == 0 {
+ sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64)
+ if sec <= 0 || err != nil {
sec = 30
}
@@ -243,6 +245,18 @@ func (name handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
p.WriteTo(w, debug)
}
+var profileDescriptions = map[string]string{
+ "allocs": "A sampling of all past memory allocations",
+ "block": "Stack traces that led to blocking on synchronization primitives",
+ "cmdline": "The command line invocation of the current program",
+ "goroutine": "Stack traces of all current goroutines",
+ "heap": "A sampling of memory allocations of live objects. You can specify the gc GET parameter to run GC before taking the heap sample.",
+ "mutex": "Stack traces of holders of contended mutexes",
+ "profile": "CPU profile. You can specify the duration in the seconds GET parameter. After you get the profile file, use the go tool pprof command to investigate the profile.",
+ "threadcreate": "Stack traces that led to the creation of new OS threads",
+ "trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.",
+}
+
// Index responds with the pprof-formatted profile named by the request.
// For example, "/debug/pprof/heap" serves the "heap" profile.
// Index responds to a request for "/debug/pprof/" with an HTML page
@@ -256,7 +270,35 @@ func Index(w http.ResponseWriter, r *http.Request) {
}
}
- profiles := pprof.Profiles()
+ type profile struct {
+ Name string
+ Href string
+ Desc string
+ Count int
+ }
+ var profiles []profile
+ for _, p := range pprof.Profiles() {
+ profiles = append(profiles, profile{
+ Name: p.Name(),
+ Href: p.Name() + "?debug=1",
+ Desc: profileDescriptions[p.Name()],
+ Count: p.Count(),
+ })
+ }
+
+ // Adding other profiles exposed from within this package
+ for _, p := range []string{"cmdline", "profile", "trace"} {
+ profiles = append(profiles, profile{
+ Name: p,
+ Href: p,
+ Desc: profileDescriptions[p],
+ })
+ }
+
+ sort.Slice(profiles, func(i, j int) bool {
+ return profiles[i].Name < profiles[j].Name
+ })
+
if err := indexTmpl.Execute(w, profiles); err != nil {
log.Print(err)
}
@@ -265,18 +307,35 @@ func Index(w http.ResponseWriter, r *http.Request) {
var indexTmpl = template.Must(template.New("index").Parse(`<html>
<head>
<title>/debug/pprof/</title>
+<style>
+.profile-name{
+ display:inline-block;
+ width:6rem;
+}
+</style>
</head>
<body>
/debug/pprof/<br>
<br>
-profiles:<br>
+Types of profiles available:
<table>
+<thead><td>Count</td><td>Profile</td></thead>
{{range .}}
-<tr><td align=right>{{.Count}}<td><a href="{{.Name}}?debug=1">{{.Name}}</a>
+ <tr>
+ <td>{{.Count}}</td><td><a href={{.Href}}>{{.Name}}</a></td>
+ </tr>
{{end}}
</table>
-<br>
-<a href="goroutine?debug=2">full goroutine stack dump</a><br>
+<a href="goroutine?debug=2">full goroutine stack dump</a>
+<br/>
+<p>
+Profile Descriptions:
+<ul>
+{{range .}}
+<li><div class=profile-name>{{.Name}}:</div> {{.Desc}}</li>
+{{end}}
+</ul>
+</p>
</body>
</html>
`))
diff --git a/libgo/go/net/http/pprof/pprof_test.go b/libgo/go/net/http/pprof/pprof_test.go
index 47dd35b..dbb6fef 100644
--- a/libgo/go/net/http/pprof/pprof_test.go
+++ b/libgo/go/net/http/pprof/pprof_test.go
@@ -9,9 +9,21 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
+ "runtime/pprof"
"testing"
)
+// TestDescriptions checks that the profile names under runtime/pprof package
+// have a key in the description map.
+func TestDescriptions(t *testing.T) {
+ for _, p := range pprof.Profiles() {
+ _, ok := profileDescriptions[p.Name()]
+ if ok != true {
+ t.Errorf("%s does not exist in profileDescriptions map\n", p.Name())
+ }
+ }
+}
+
func TestHandlers(t *testing.T) {
testCases := []struct {
path string
diff --git a/libgo/go/net/http/proxy_test.go b/libgo/go/net/http/proxy_test.go
index f59a551..eef0ca8 100644
--- a/libgo/go/net/http/proxy_test.go
+++ b/libgo/go/net/http/proxy_test.go
@@ -13,37 +13,6 @@ import (
// TODO(mattn):
// test ProxyAuth
-var UseProxyTests = []struct {
- host string
- match bool
-}{
- // Never proxy localhost:
- {"localhost", false},
- {"127.0.0.1", false},
- {"127.0.0.2", false},
- {"[::1]", false},
- {"[::2]", true}, // not a loopback address
-
- {"barbaz.net", false}, // match as .barbaz.net
- {"foobar.com", false}, // have a port but match
- {"foofoobar.com", true}, // not match as a part of foobar.com
- {"baz.com", true}, // not match as a part of barbaz.com
- {"localhost.net", true}, // not match as suffix of address
- {"local.localhost", true}, // not match as prefix as address
- {"barbarbaz.net", true}, // not match because NO_PROXY have a '.'
- {"www.foobar.com", false}, // match because NO_PROXY includes "foobar.com"
-}
-
-func TestUseProxy(t *testing.T) {
- ResetProxyEnv()
- os.Setenv("NO_PROXY", "foobar.com, .barbaz.net")
- for _, test := range UseProxyTests {
- if useProxy(test.host+":80") != test.match {
- t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
- }
- }
-}
-
var cacheKeysTests = []struct {
proxy string
scheme string
@@ -74,14 +43,8 @@ func TestCacheKeys(t *testing.T) {
}
func ResetProxyEnv() {
- for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy"} {
+ for _, v := range []string{"HTTP_PROXY", "http_proxy", "NO_PROXY", "no_proxy", "REQUEST_METHOD"} {
os.Unsetenv(v)
}
ResetCachedEnvironment()
}
-
-func TestInvalidNoProxy(t *testing.T) {
- ResetProxyEnv()
- os.Setenv("NO_PROXY", ":1")
- useProxy("example.com:80") // should not panic
-}
diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go
index 22a9c2e..18eed34 100644
--- a/libgo/go/net/http/readrequest_test.go
+++ b/libgo/go/net/http/readrequest_test.go
@@ -126,7 +126,7 @@ var reqTests = []reqTest{
noError,
},
- // Tests a bogus abs_path on the Request-Line (RFC 2616 section 5.1.2)
+ // Tests a bogus absolute-path on the Request-Line (RFC 7230 section 5.3.1)
{
"GET ../../../../etc/passwd HTTP/1.1\r\n" +
"Host: test\r\n\r\n",
diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go
index c9642e5..a40b0a3 100644
--- a/libgo/go/net/http/request.go
+++ b/libgo/go/net/http/request.go
@@ -65,11 +65,19 @@ var (
// request's Content-Type is not multipart/form-data.
ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"}
- // Deprecated: ErrHeaderTooLong is not used.
+ // Deprecated: ErrHeaderTooLong is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
ErrHeaderTooLong = &ProtocolError{"header too long"}
- // Deprecated: ErrShortBody is not used.
+
+ // Deprecated: ErrShortBody is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
ErrShortBody = &ProtocolError{"entity body too short"}
- // Deprecated: ErrMissingContentLength is not used.
+
+ // Deprecated: ErrMissingContentLength is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"}
)
@@ -110,7 +118,7 @@ type Request struct {
// For server requests the URL is parsed from the URI
// supplied on the Request-Line as stored in RequestURI. For
// most requests, fields other than Path and RawQuery will be
- // empty. (See RFC 2616, Section 5.1.2)
+ // empty. (See RFC 7230, Section 5.3)
//
// For client requests, the URL's Host specifies the server to
// connect to, while the Request's Host field optionally
@@ -207,13 +215,18 @@ type Request struct {
// Transport.DisableKeepAlives were set.
Close bool
- // For server requests Host specifies the host on which the
- // URL is sought. Per RFC 2616, this is either the value of
- // the "Host" header or the host name given in the URL itself.
+ // For server requests Host specifies the host on which the URL
+ // is sought. Per RFC 7230, section 5.4, this is either the value
+ // of the "Host" header or the host name given in the URL itself.
// It may be of the form "host:port". For international domain
// names, Host may be in Punycode or Unicode form. Use
// golang.org/x/net/idna to convert it to either format if
// needed.
+ // To prevent DNS rebinding attacks, server Handlers should
+ // validate that the Host header has a value for which the
+ // Handler considers itself authoritative. The included
+ // ServeMux supports patterns registered to particular host
+ // names and thus protects its registered Handlers.
//
// For client requests Host optionally overrides the Host
// header to send. If empty, the Request.Write method uses
@@ -268,8 +281,8 @@ type Request struct {
// This field is ignored by the HTTP client.
RemoteAddr string
- // RequestURI is the unmodified Request-URI of the
- // Request-Line (RFC 2616, Section 5.1) as sent by the client
+ // RequestURI is the unmodified request-target of the
+ // Request-Line (RFC 7230, Section 3.1.1) as sent by the client
// to a server. Usually the URL field should be used instead.
// It is an error to set this field in an HTTP client request.
RequestURI string
@@ -326,6 +339,10 @@ func (r *Request) Context() context.Context {
// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
+//
+// For outgoing client request, the context controls the entire
+// lifetime of a request and its response: obtaining a connection,
+// sending the request, and reading the response headers and body.
func (r *Request) WithContext(ctx context.Context) *Request {
if ctx == nil {
panic("nil context")
@@ -411,7 +428,7 @@ var multipartByReader = &multipart.Form{
}
// MultipartReader returns a MIME multipart reader if this is a
-// multipart/form-data POST request, else returns nil and an error.
+// multipart/form-data or a multipart/mixed POST request, else returns nil and an error.
// Use this function instead of ParseMultipartForm to
// process the request body as a stream.
func (r *Request) MultipartReader() (*multipart.Reader, error) {
@@ -422,16 +439,16 @@ func (r *Request) MultipartReader() (*multipart.Reader, error) {
return nil, errors.New("http: multipart handled by ParseMultipartForm")
}
r.MultipartForm = multipartByReader
- return r.multipartReader()
+ return r.multipartReader(true)
}
-func (r *Request) multipartReader() (*multipart.Reader, error) {
+func (r *Request) multipartReader(allowMixed bool) (*multipart.Reader, error) {
v := r.Header.Get("Content-Type")
if v == "" {
return nil, ErrNotMultipart
}
d, params, err := mime.ParseMediaType(v)
- if err != nil || d != "multipart/form-data" {
+ if err != nil || !(d == "multipart/form-data" || allowMixed && d == "multipart/mixed") {
return nil, ErrNotMultipart
}
boundary, ok := params["boundary"]
@@ -481,7 +498,7 @@ func (r *Request) Write(w io.Writer) error {
// WriteProxy is like Write but writes the request in the form
// expected by an HTTP proxy. In particular, WriteProxy writes the
// initial Request-URI line of the request with an absolute URI, per
-// section 5.1.2 of RFC 2616, including the scheme and host.
+// section 5.3 of RFC 7230, including the scheme and host.
// In either case, WriteProxy also writes a Host header, using
// either r.Host or r.URL.Host.
func (r *Request) WriteProxy(w io.Writer) error {
@@ -550,6 +567,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
if err != nil {
return err
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Host", []string{host})
+ }
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
@@ -562,6 +582,9 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
if err != nil {
return err
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("User-Agent", []string{userAgent})
+ }
}
// Process Body,ContentLength,Close,Trailer
@@ -569,18 +592,18 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
if err != nil {
return err
}
- err = tw.WriteHeader(w)
+ err = tw.writeHeader(w, trace)
if err != nil {
return err
}
- err = r.Header.WriteSubset(w, reqWriteExcludeHeader)
+ err = r.Header.writeSubset(w, reqWriteExcludeHeader, trace)
if err != nil {
return err
}
if extraHeaders != nil {
- err = extraHeaders.Write(w)
+ err = extraHeaders.write(w, trace)
if err != nil {
return err
}
@@ -619,7 +642,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF
}
// Write body and trailer
- err = tw.WriteBody(w)
+ err = tw.writeBody(w)
if err != nil {
if tw.bodyReadError == err {
err = requestBodyReadError{err}
@@ -858,7 +881,8 @@ func (r *Request) BasicAuth() (username, password string, ok bool) {
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
- if !strings.HasPrefix(auth, prefix) {
+ // Case insensitive prefix match. See Issue 22736.
+ if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
@@ -910,6 +934,11 @@ func putTextprotoReader(r *textproto.Reader) {
}
// ReadRequest reads and parses an incoming request from b.
+//
+// ReadRequest is a low-level function and should only be used for
+// specialized applications; most code should use the Server to read
+// requests and handle them via the Handler interface. ReadRequest
+// only supports HTTP/1.x requests. For HTTP/2, use golang.org/x/net/http2.
func ReadRequest(b *bufio.Reader) (*Request, error) {
return readRequest(b, deleteHostHeader)
}
@@ -979,7 +1008,7 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro
}
req.Header = Header(mimeHeader)
- // RFC 2616: Must treat
+ // RFC 7230, section 5.3: Must treat
// GET /index.html HTTP/1.1
// Host: www.google.com
// and
@@ -1094,8 +1123,8 @@ func parsePostForm(r *Request) (vs url.Values, err error) {
return
}
ct := r.Header.Get("Content-Type")
- // RFC 2616, section 7.2.1 - empty type
- // SHOULD be treated as application/octet-stream
+ // RFC 7231, section 3.1.1.5 - empty type
+ // MAY be treated as application/octet-stream
if ct == "" {
ct = "application/octet-stream"
}
@@ -1207,7 +1236,7 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error {
return nil
}
- mr, err := r.multipartReader()
+ mr, err := r.multipartReader(false)
if err != nil {
return err
}
@@ -1248,8 +1277,8 @@ func (r *Request) FormValue(key string) string {
return ""
}
-// PostFormValue returns the first value for the named component of the POST
-// or PUT request body. URL query parameters are ignored.
+// PostFormValue returns the first value for the named component of the POST,
+// PATCH, or PUT request body. URL query parameters are ignored.
// PostFormValue calls ParseMultipartForm and ParseForm if necessary and ignores
// any errors returned by these functions.
// If key is not present, PostFormValue returns the empty string.
diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go
index 967156b..7a83ae5 100644
--- a/libgo/go/net/http/request_test.go
+++ b/libgo/go/net/http/request_test.go
@@ -91,8 +91,8 @@ type parseContentTypeTest struct {
var parseContentTypeTests = []parseContentTypeTest{
{false, stringMap{"Content-Type": {"text/plain"}}},
- // Empty content type is legal - should be treated as
- // application/octet-stream (RFC 2616, section 7.2.1)
+ // Empty content type is legal - may be treated as
+ // application/octet-stream (RFC 7231, section 3.1.1.5)
{false, stringMap{}},
{true, stringMap{"Content-Type": {"text/plain; boundary="}}},
{false, stringMap{"Content-Type": {"application/unknown"}}},
@@ -143,6 +143,16 @@ func TestMultipartReader(t *testing.T) {
t.Errorf("expected multipart; error: %v", err)
}
+ req = &Request{
+ Method: "POST",
+ Header: Header{"Content-Type": {`multipart/mixed; boundary="foo123"`}},
+ Body: ioutil.NopCloser(new(bytes.Buffer)),
+ }
+ multipart, err = req.MultipartReader()
+ if multipart == nil {
+ t.Errorf("expected multipart; error: %v", err)
+ }
+
req.Header = Header{"Content-Type": {"text/plain"}}
multipart, err = req.MultipartReader()
if multipart != nil {
@@ -597,6 +607,11 @@ var parseBasicAuthTests = []struct {
ok bool
}{
{"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+
+ // Case doesn't matter:
+ {"BASIC " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+ {"basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true},
+
{"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true},
{"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true},
{"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false},
diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go
index a91efcf..bf1e13c 100644
--- a/libgo/go/net/http/response.go
+++ b/libgo/go/net/http/response.go
@@ -39,7 +39,7 @@ type Response struct {
// Header maps header keys to values. If the response had multiple
// headers with the same key, they may be concatenated, with comma
- // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers
+ // delimiters. (RFC 7230, section 3.2.2 requires that multiple headers
// be semantically equivalent to a comma-delimited sequence.) When
// Header values are duplicated by other fields in this struct (e.g.,
// ContentLength, TransferEncoding, Trailer), the field values are
@@ -201,7 +201,7 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
return resp, nil
}
-// RFC 2616: Should treat
+// RFC 7234, section 5.4: Should treat
// Pragma: no-cache
// like
// Cache-Control: no-cache
@@ -293,7 +293,7 @@ func (r *Response) Write(w io.Writer) error {
if err != nil {
return err
}
- err = tw.WriteHeader(w)
+ err = tw.writeHeader(w, nil)
if err != nil {
return err
}
@@ -319,7 +319,7 @@ func (r *Response) Write(w io.Writer) error {
}
// Write body and trailer
- err = tw.WriteBody(w)
+ err = tw.writeBody(w)
if err != nil {
return err
}
diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go
index 1ea1961..c28b0cb 100644
--- a/libgo/go/net/http/response_test.go
+++ b/libgo/go/net/http/response_test.go
@@ -295,7 +295,7 @@ var respTests = []respTest{
},
// Status line without a Reason-Phrase, but trailing space.
- // (permitted by RFC 2616)
+ // (permitted by RFC 7230, section 3.1.2)
{
"HTTP/1.0 303 \r\n\r\n",
Response{
@@ -314,7 +314,7 @@ var respTests = []respTest{
},
// Status line without a Reason-Phrase, and no trailing space.
- // (not permitted by RFC 2616, but we'll accept it anyway)
+ // (not permitted by RFC 7230, but we'll accept it anyway)
{
"HTTP/1.0 303\r\n\r\n",
Response{
diff --git a/libgo/go/net/http/roundtrip.go b/libgo/go/net/http/roundtrip.go
new file mode 100644
index 0000000..2ec736b
--- /dev/null
+++ b/libgo/go/net/http/roundtrip.go
@@ -0,0 +1,18 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !js !wasm
+
+package http
+
+// RoundTrip implements the RoundTripper interface.
+//
+// For higher-level HTTP client support (such as handling of cookies
+// and redirects), see Get, Post, and the Client type.
+//
+// Like the RoundTripper interface, the error types returned
+// by RoundTrip are unspecified.
+func (t *Transport) RoundTrip(req *Request) (*Response, error) {
+ return t.roundTrip(req)
+}
diff --git a/libgo/go/net/http/roundtrip_js.go b/libgo/go/net/http/roundtrip_js.go
new file mode 100644
index 0000000..16b7b89
--- /dev/null
+++ b/libgo/go/net/http/roundtrip_js.go
@@ -0,0 +1,293 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build js,wasm
+
+package http
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "strconv"
+ "strings"
+ "syscall/js"
+)
+
+// jsFetchMode is a Request.Header map key that, if present,
+// signals that the map entry is actually an option to the Fetch API mode setting.
+// Valid values are: "cors", "no-cors", "same-origin", "navigate"
+// The default is "same-origin".
+//
+// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
+const jsFetchMode = "js.fetch:mode"
+
+// jsFetchCreds is a Request.Header map key that, if present,
+// signals that the map entry is actually an option to the Fetch API credentials setting.
+// Valid values are: "omit", "same-origin", "include"
+// The default is "same-origin".
+//
+// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch#Parameters
+const jsFetchCreds = "js.fetch:credentials"
+
+// RoundTrip implements the RoundTripper interface using the WHATWG Fetch API.
+func (t *Transport) RoundTrip(req *Request) (*Response, error) {
+ if useFakeNetwork() {
+ return t.roundTrip(req)
+ }
+
+ ac := js.Global().Get("AbortController")
+ if ac != js.Undefined() {
+ // Some browsers that support WASM don't necessarily support
+ // the AbortController. See
+ // https://developer.mozilla.org/en-US/docs/Web/API/AbortController#Browser_compatibility.
+ ac = ac.New()
+ }
+
+ opt := js.Global().Get("Object").New()
+ // See https://developer.mozilla.org/en-US/docs/Web/API/WindowOrWorkerGlobalScope/fetch
+ // for options available.
+ opt.Set("method", req.Method)
+ opt.Set("credentials", "same-origin")
+ if h := req.Header.Get(jsFetchCreds); h != "" {
+ opt.Set("credentials", h)
+ req.Header.Del(jsFetchCreds)
+ }
+ if h := req.Header.Get(jsFetchMode); h != "" {
+ opt.Set("mode", h)
+ req.Header.Del(jsFetchMode)
+ }
+ if ac != js.Undefined() {
+ opt.Set("signal", ac.Get("signal"))
+ }
+ headers := js.Global().Get("Headers").New()
+ for key, values := range req.Header {
+ for _, value := range values {
+ headers.Call("append", key, value)
+ }
+ }
+ opt.Set("headers", headers)
+
+ if req.Body != nil {
+ // TODO(johanbrandhorst): Stream request body when possible.
+ // See https://bugs.chromium.org/p/chromium/issues/detail?id=688906 for Blink issue.
+ // See https://bugzilla.mozilla.org/show_bug.cgi?id=1387483 for Firefox issue.
+ // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue.
+ // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API
+ // and browser support.
+ body, err := ioutil.ReadAll(req.Body)
+ if err != nil {
+ req.Body.Close() // RoundTrip must always close the body, including on errors.
+ return nil, err
+ }
+ req.Body.Close()
+ a := js.TypedArrayOf(body)
+ defer a.Release()
+ opt.Set("body", a)
+ }
+ respPromise := js.Global().Call("fetch", req.URL.String(), opt)
+ var (
+ respCh = make(chan *Response, 1)
+ errCh = make(chan error, 1)
+ )
+ success := js.NewCallback(func(args []js.Value) {
+ result := args[0]
+ header := Header{}
+ // https://developer.mozilla.org/en-US/docs/Web/API/Headers/entries
+ headersIt := result.Get("headers").Call("entries")
+ for {
+ n := headersIt.Call("next")
+ if n.Get("done").Bool() {
+ break
+ }
+ pair := n.Get("value")
+ key, value := pair.Index(0).String(), pair.Index(1).String()
+ ck := CanonicalHeaderKey(key)
+ header[ck] = append(header[ck], value)
+ }
+
+ contentLength := int64(0)
+ if cl, err := strconv.ParseInt(header.Get("Content-Length"), 10, 64); err == nil {
+ contentLength = cl
+ }
+
+ b := result.Get("body")
+ var body io.ReadCloser
+ if b != js.Undefined() {
+ body = &streamReader{stream: b.Call("getReader")}
+ } else {
+ // Fall back to using ArrayBuffer
+ // https://developer.mozilla.org/en-US/docs/Web/API/Body/arrayBuffer
+ body = &arrayReader{arrayPromise: result.Call("arrayBuffer")}
+ }
+
+ select {
+ case respCh <- &Response{
+ Status: result.Get("status").String() + " " + StatusText(result.Get("status").Int()),
+ StatusCode: result.Get("status").Int(),
+ Header: header,
+ ContentLength: contentLength,
+ Body: body,
+ Request: req,
+ }:
+ case <-req.Context().Done():
+ }
+ })
+ defer success.Release()
+ failure := js.NewCallback(func(args []js.Value) {
+ err := fmt.Errorf("net/http: fetch() failed: %s", args[0].String())
+ select {
+ case errCh <- err:
+ case <-req.Context().Done():
+ }
+ })
+ defer failure.Release()
+ respPromise.Call("then", success, failure)
+ select {
+ case <-req.Context().Done():
+ if ac != js.Undefined() {
+ // Abort the Fetch request
+ ac.Call("abort")
+ }
+ return nil, req.Context().Err()
+ case resp := <-respCh:
+ return resp, nil
+ case err := <-errCh:
+ return nil, err
+ }
+}
+
+var errClosed = errors.New("net/http: reader is closed")
+
+// useFakeNetwork is used to determine whether the request is made
+// by a test and should be made to use the fake in-memory network.
+func useFakeNetwork() bool {
+ return len(os.Args) > 0 && strings.HasSuffix(os.Args[0], ".test")
+}
+
+// streamReader implements an io.ReadCloser wrapper for ReadableStream.
+// See https://fetch.spec.whatwg.org/#readablestream for more information.
+type streamReader struct {
+ pending []byte
+ stream js.Value
+ err error // sticky read error
+}
+
+func (r *streamReader) Read(p []byte) (n int, err error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+ if len(r.pending) == 0 {
+ var (
+ bCh = make(chan []byte, 1)
+ errCh = make(chan error, 1)
+ )
+ success := js.NewCallback(func(args []js.Value) {
+ result := args[0]
+ if result.Get("done").Bool() {
+ errCh <- io.EOF
+ return
+ }
+ value := make([]byte, result.Get("value").Get("byteLength").Int())
+ a := js.TypedArrayOf(value)
+ a.Call("set", result.Get("value"))
+ a.Release()
+ bCh <- value
+ })
+ defer success.Release()
+ failure := js.NewCallback(func(args []js.Value) {
+ // Assumes it's a TypeError. See
+ // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError
+ // for more information on this type. See
+ // https://streams.spec.whatwg.org/#byob-reader-read for the spec on
+ // the read method.
+ errCh <- errors.New(args[0].Get("message").String())
+ })
+ defer failure.Release()
+ r.stream.Call("read").Call("then", success, failure)
+ select {
+ case b := <-bCh:
+ r.pending = b
+ case err := <-errCh:
+ r.err = err
+ return 0, err
+ }
+ }
+ n = copy(p, r.pending)
+ r.pending = r.pending[n:]
+ return n, nil
+}
+
+func (r *streamReader) Close() error {
+ // This ignores any error returned from cancel method. So far, I did not encounter any concrete
+ // situation where reporting the error is meaningful. Most users ignore error from resp.Body.Close().
+ // If there's a need to report error here, it can be implemented and tested when that need comes up.
+ r.stream.Call("cancel")
+ if r.err == nil {
+ r.err = errClosed
+ }
+ return nil
+}
+
+// arrayReader implements an io.ReadCloser wrapper for ArrayBuffer.
+// https://developer.mozilla.org/en-US/docs/Web/API/Body/arrayBuffer.
+type arrayReader struct {
+ arrayPromise js.Value
+ pending []byte
+ read bool
+ err error // sticky read error
+}
+
+func (r *arrayReader) Read(p []byte) (n int, err error) {
+ if r.err != nil {
+ return 0, r.err
+ }
+ if !r.read {
+ r.read = true
+ var (
+ bCh = make(chan []byte, 1)
+ errCh = make(chan error, 1)
+ )
+ success := js.NewCallback(func(args []js.Value) {
+ // Wrap the input ArrayBuffer with a Uint8Array
+ uint8arrayWrapper := js.Global().Get("Uint8Array").New(args[0])
+ value := make([]byte, uint8arrayWrapper.Get("byteLength").Int())
+ a := js.TypedArrayOf(value)
+ a.Call("set", uint8arrayWrapper)
+ a.Release()
+ bCh <- value
+ })
+ defer success.Release()
+ failure := js.NewCallback(func(args []js.Value) {
+ // Assumes it's a TypeError. See
+ // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/TypeError
+ // for more information on this type.
+ // See https://fetch.spec.whatwg.org/#concept-body-consume-body for reasons this might error.
+ errCh <- errors.New(args[0].Get("message").String())
+ })
+ defer failure.Release()
+ r.arrayPromise.Call("then", success, failure)
+ select {
+ case b := <-bCh:
+ r.pending = b
+ case err := <-errCh:
+ return 0, err
+ }
+ }
+ if len(r.pending) == 0 {
+ return 0, io.EOF
+ }
+ n = copy(p, r.pending)
+ r.pending = r.pending[n:]
+ return n, nil
+}
+
+func (r *arrayReader) Close() error {
+ if r.err == nil {
+ r.err = errClosed
+ }
+ return nil
+}
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go
index 9cbfe87..a438541 100644
--- a/libgo/go/net/http/serve_test.go
+++ b/libgo/go/net/http/serve_test.go
@@ -134,14 +134,15 @@ func reqBytes(req string) []byte {
}
type handlerTest struct {
+ logbuf bytes.Buffer
handler Handler
}
func newHandlerTest(h Handler) handlerTest {
- return handlerTest{h}
+ return handlerTest{handler: h}
}
-func (ht handlerTest) rawResponse(req string) string {
+func (ht *handlerTest) rawResponse(req string) string {
reqb := reqBytes(req)
var output bytes.Buffer
conn := &rwTestConn{
@@ -150,7 +151,11 @@ func (ht handlerTest) rawResponse(req string) string {
closec: make(chan bool, 1),
}
ln := &oneConnListener{conn: conn}
- go Serve(ln, ht.handler)
+ srv := &Server{
+ ErrorLog: log.New(&ht.logbuf, "", 0),
+ Handler: ht.handler,
+ }
+ go srv.Serve(ln)
<-conn.closec
return output.String()
}
@@ -379,6 +384,18 @@ func TestServeMuxHandler(t *testing.T) {
}
}
+// Issue 24297
+func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
+ setParallel(t)
+ defer func() {
+ if err := recover(); err == nil {
+ t.Error("expected call to mux.HandleFunc to panic")
+ }
+ }()
+ mux := NewServeMux()
+ mux.HandleFunc("/", nil)
+}
+
var serveMuxTests2 = []struct {
method string
host string
@@ -581,8 +598,19 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
}
}
-func BenchmarkServeMux(b *testing.B) {
+func TestShouldRedirectConcurrency(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ mux := NewServeMux()
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+ mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
+}
+func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
+func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
+func benchmarkServeMux(b *testing.B, runHandler bool) {
type test struct {
path string
code int
@@ -614,9 +642,11 @@ func BenchmarkServeMux(b *testing.B) {
for _, tt := range tests {
*rw = httptest.ResponseRecorder{}
h, pattern := mux.Handler(tt.req)
- h.ServeHTTP(rw, tt.req)
- if pattern != tt.path || rw.Code != tt.code {
- b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
+ if runHandler {
+ h.ServeHTTP(rw, tt.req)
+ if pattern != tt.path || rw.Code != tt.code {
+ b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
+ }
}
}
}
@@ -931,7 +961,7 @@ func TestOnlyWriteTimeout(t *testing.T) {
if err == nil {
t.Errorf("expected an error from Get request")
}
- case <-time.After(5 * time.Second):
+ case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for Get error")
}
if err := <-afterTimeoutErrc; err == nil {
@@ -2294,6 +2324,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
if !strings.Contains(string(body), "<title>Timeout</title>") {
t.Errorf("expected timeout body; got %q", string(body))
}
+ if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
+ t.Errorf("response content-type = %q; want %q", g, w)
+ }
// Now make the previously-timed out handler speak again,
// which verifies the panic is handled:
@@ -2554,31 +2587,49 @@ func TestRedirect(t *testing.T) {
for _, tt := range tests {
rec := httptest.NewRecorder()
Redirect(rec, req, tt.in, 302)
+ if got, want := rec.Code, 302; got != want {
+ t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
+ }
if got := rec.Header().Get("Location"); got != tt.want {
t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
}
}
}
-// Test that Content-Type header is set for GET and HEAD requests.
-func TestRedirectContentTypeAndBody(t *testing.T) {
+// Test that Redirect sets Content-Type header for GET and HEAD requests
+// and writes a short HTML body, unless the request already has a Content-Type header.
+func TestRedirect_contentTypeAndBody(t *testing.T) {
+ type ctHeader struct {
+ Values []string
+ }
+
var tests = []struct {
method string
+ ct *ctHeader // Optional Content-Type header to set.
wantCT string
wantBody string
}{
- {MethodGet, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
- {MethodHead, "text/html; charset=utf-8", ""},
- {MethodPost, "", ""},
- {MethodDelete, "", ""},
- {"foo", "", ""},
+ {MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
+ {MethodHead, nil, "text/html; charset=utf-8", ""},
+ {MethodPost, nil, "", ""},
+ {MethodDelete, nil, "", ""},
+ {"foo", nil, "", ""},
+ {MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
+ {MethodGet, &ctHeader{[]string{}}, "", ""},
+ {MethodGet, &ctHeader{nil}, "", ""},
}
for _, tt := range tests {
req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)
rec := httptest.NewRecorder()
+ if tt.ct != nil {
+ rec.Header()["Content-Type"] = tt.ct.Values
+ }
Redirect(rec, req, "/foo", 302)
+ if got, want := rec.Code, 302; got != want {
+ t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, got, want)
+ }
if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want {
- t.Errorf("Redirect(%q) generated Content-Type header %q; want %q", tt.method, got, want)
+ t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want)
}
resp := rec.Result()
body, err := ioutil.ReadAll(resp.Body)
@@ -2586,7 +2637,7 @@ func TestRedirectContentTypeAndBody(t *testing.T) {
t.Fatal(err)
}
if got, want := string(body), tt.wantBody; got != want {
- t.Errorf("Redirect(%q) generated Body %q; want %q", tt.method, got, want)
+ t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, want)
}
}
}
@@ -3127,25 +3178,32 @@ For:
ts.Close()
}
-// Tests that a pipelined request causes the first request's Handler's CloseNotify
-// channel to fire. Previously it deadlocked.
+// Tests that a pipelined request does not cause the first request's
+// Handler's CloseNotify channel to fire.
//
-// Issue 13165
+// Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921.
func TestCloseNotifierPipelined(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
gotReq := make(chan bool, 2)
sawClose := make(chan bool, 2)
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
gotReq <- true
cc := rw.(CloseNotifier).CloseNotify()
- <-cc
+ select {
+ case <-cc:
+ t.Error("unexpected CloseNotify")
+ case <-time.After(100 * time.Millisecond):
+ }
sawClose <- true
}))
+ defer ts.Close()
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error dialing: %v", err)
}
diec := make(chan bool, 1)
+ defer close(diec)
go func() {
const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
_, err = io.WriteString(conn, req+req) // two requests
@@ -3158,27 +3216,23 @@ func TestCloseNotifierPipelined(t *testing.T) {
}()
reqs := 0
closes := 0
-For:
for {
select {
case <-gotReq:
reqs++
if reqs > 2 {
t.Fatal("too many requests")
- } else if reqs > 1 {
- diec <- true
}
case <-sawClose:
closes++
if closes > 1 {
- break For
+ return
}
case <-time.After(5 * time.Second):
ts.CloseClientConnections()
t.Fatal("timeout")
}
}
- ts.Close()
}
func TestCloseNotifierChanLeak(t *testing.T) {
@@ -3377,14 +3431,14 @@ func TestHeaderToWire(t *testing.T) {
tests := []struct {
name string
handler func(ResponseWriter, *Request)
- check func(output string) error
+ check func(got, logs string) error
}{
{
name: "write without Header",
handler: func(rw ResponseWriter, r *Request) {
rw.Write([]byte("hello world"))
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Content-Length:") {
return errors.New("no content-length")
}
@@ -3402,7 +3456,7 @@ func TestHeaderToWire(t *testing.T) {
rw.Write([]byte("hello world"))
h.Set("Too-Late", "bogus")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Content-Length:") {
return errors.New("no content-length")
}
@@ -3421,7 +3475,7 @@ func TestHeaderToWire(t *testing.T) {
rw.Write([]byte("hello world"))
rw.Header().Set("Too-Late", "Write already wrote headers")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if strings.Contains(got, "Too-Late") {
return errors.New("header appeared from after WriteHeader")
}
@@ -3435,7 +3489,7 @@ func TestHeaderToWire(t *testing.T) {
rw.Write([]byte("post-flush"))
rw.Header().Set("Too-Late", "Write already wrote headers")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Transfer-Encoding: chunked") {
return errors.New("not chunked")
}
@@ -3453,7 +3507,7 @@ func TestHeaderToWire(t *testing.T) {
rw.Write([]byte("post-flush"))
rw.Header().Set("Too-Late", "Write already wrote headers")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Transfer-Encoding: chunked") {
return errors.New("not chunked")
}
@@ -3472,7 +3526,7 @@ func TestHeaderToWire(t *testing.T) {
rw.Write([]byte("<html><head></head><body>some html</body></html>"))
rw.Header().Set("Content-Type", "x/wrong")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Content-Type: text/html") {
return errors.New("wrong content-type; want html")
}
@@ -3485,7 +3539,7 @@ func TestHeaderToWire(t *testing.T) {
rw.Header().Set("Content-Type", "some/type")
rw.Write([]byte("<html><head></head><body>some html</body></html>"))
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Content-Type: some/type") {
return errors.New("wrong content-type; want html")
}
@@ -3496,7 +3550,7 @@ func TestHeaderToWire(t *testing.T) {
name: "empty handler",
handler: func(rw ResponseWriter, r *Request) {
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Content-Length: 0") {
return errors.New("want 0 content-length")
}
@@ -3508,7 +3562,7 @@ func TestHeaderToWire(t *testing.T) {
handler: func(rw ResponseWriter, r *Request) {
rw.Header().Set("Some-Header", "some-value")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "Some-Header") {
return errors.New("didn't get header")
}
@@ -3521,7 +3575,7 @@ func TestHeaderToWire(t *testing.T) {
rw.WriteHeader(404)
rw.Header().Set("Too-Late", "some-value")
},
- check: func(got string) error {
+ check: func(got, logs string) error {
if !strings.Contains(got, "404") {
return errors.New("wrong status")
}
@@ -3535,8 +3589,9 @@ func TestHeaderToWire(t *testing.T) {
for _, tc := range tests {
ht := newHandlerTest(HandlerFunc(tc.handler))
got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
- if err := tc.check(got); err != nil {
- t.Errorf("%s: %v\nGot response:\n%s", tc.name, err, got)
+ logs := ht.logbuf.String()
+ if err := tc.check(got, logs); err != nil {
+ t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
}
}
}
@@ -5493,6 +5548,76 @@ func testServerShutdown(t *testing.T, h2 bool) {
}
}
+func TestServerShutdownStateNew(t *testing.T) {
+ if testing.Short() {
+ t.Skip("test takes 5-6 seconds; skipping in short mode")
+ }
+ setParallel(t)
+ defer afterTest(t)
+
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ // nothing.
+ }))
+ var connAccepted sync.WaitGroup
+ ts.Config.ConnState = func(conn net.Conn, state ConnState) {
+ if state == StateNew {
+ connAccepted.Done()
+ }
+ }
+ ts.Start()
+ defer ts.Close()
+
+ // Start a connection but never write to it.
+ connAccepted.Add(1)
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ // Wait for the connection to be accepted by the server. Otherwise, if
+ // Shutdown happens to run first, the server will be closed when
+ // encountering the connection, in which case it will be rejected
+ // immediately.
+ connAccepted.Wait()
+
+ shutdownRes := make(chan error, 1)
+ go func() {
+ shutdownRes <- ts.Config.Shutdown(context.Background())
+ }()
+ readRes := make(chan error, 1)
+ go func() {
+ _, err := c.Read([]byte{0})
+ readRes <- err
+ }()
+
+ const expectTimeout = 5 * time.Second
+ t0 := time.Now()
+ select {
+ case got := <-shutdownRes:
+ d := time.Since(t0)
+ if got != nil {
+ t.Fatalf("shutdown error after %v: %v", d, err)
+ }
+ if d < expectTimeout/2 {
+ t.Errorf("shutdown too soon after %v", d)
+ }
+ case <-time.After(expectTimeout * 3 / 2):
+ t.Fatalf("timeout waiting for shutdown")
+ }
+
+ // Wait for c.Read to unblock; should be already done at this point,
+ // or within a few milliseconds.
+ select {
+ case err := <-readRes:
+ if err == nil {
+ t.Error("expected error from Read")
+ }
+ case <-time.After(2 * time.Second):
+ t.Errorf("timeout waiting for Read to unblock")
+ }
+}
+
// Issue 17878: tests that we can call Close twice.
func TestServerCloseDeadlock(t *testing.T) {
var s Server
@@ -5590,6 +5715,10 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *
// Issue 18535: test that the Server doesn't try to do a background
// read if it's already done one.
func TestServerDuplicateBackgroundRead(t *testing.T) {
+ if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
+ testenv.SkipFlaky(t, 24826)
+ }
+
setParallel(t)
defer afterTest(t)
@@ -5653,31 +5782,23 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) {
// Tell the client to send more data after the GET request.
inHandler <- true
- // Wait until the HTTP server sees the extra data
- // after the GET request. The HTTP server fires the
- // close notifier here, assuming it's a pipelined
- // request, as documented.
- select {
- case <-w.(CloseNotifier).CloseNotify():
- case <-time.After(5 * time.Second):
- t.Error("timeout")
- return
- }
-
conn, buf, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
return
}
defer conn.Close()
- n := buf.Reader.Buffered()
- if n != 1 {
- t.Errorf("buffered data = %d; want 1", n)
- }
+
peek, err := buf.Reader.Peek(3)
if string(peek) != "foo" || err != nil {
t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
}
+
+ select {
+ case <-r.Context().Done():
+ t.Error("context unexpectedly canceled")
+ default:
+ }
}))
defer ts.Close()
@@ -5718,17 +5839,6 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
defer close(done)
- // Wait until the HTTP server sees the extra data
- // after the GET request. The HTTP server fires the
- // close notifier here, assuming it's a pipelined
- // request, as documented.
- select {
- case <-w.(CloseNotifier).CloseNotify():
- case <-time.After(5 * time.Second):
- t.Error("timeout")
- return
- }
-
conn, buf, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
@@ -5800,6 +5910,94 @@ func TestServerValidatesMethod(t *testing.T) {
}
}
+// Listener for TestServerListenNotComparableListener.
+type eofListenerNotComparable []int
+
+func (eofListenerNotComparable) Accept() (net.Conn, error) { return nil, io.EOF }
+func (eofListenerNotComparable) Addr() net.Addr { return nil }
+func (eofListenerNotComparable) Close() error { return nil }
+
+// Issue 24812: don't crash on non-comparable Listener
+func TestServerListenNotComparableListener(t *testing.T) {
+ var s Server
+ s.Serve(make(eofListenerNotComparable, 1)) // used to panic
+}
+
+// countCloseListener is a Listener wrapper that counts the number of Close calls.
+type countCloseListener struct {
+ net.Listener
+ closes int32 // atomic
+}
+
+func (p *countCloseListener) Close() error {
+ atomic.AddInt32(&p.closes, 1)
+ return nil
+}
+
+// Issue 24803: don't call Listener.Close on Server.Shutdown.
+func TestServerCloseListenerOnce(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ cl := &countCloseListener{Listener: ln}
+ server := &Server{}
+ sdone := make(chan bool, 1)
+
+ go func() {
+ server.Serve(cl)
+ sdone <- true
+ }()
+ time.Sleep(10 * time.Millisecond)
+ server.Shutdown(context.Background())
+ ln.Close()
+ <-sdone
+
+ nclose := atomic.LoadInt32(&cl.closes)
+ if nclose != 1 {
+ t.Errorf("Close calls = %v; want 1", nclose)
+ }
+}
+
+// Issue 20239: don't block in Serve if Shutdown is called first.
+func TestServerShutdownThenServe(t *testing.T) {
+ var srv Server
+ cl := &countCloseListener{Listener: nil}
+ srv.Shutdown(context.Background())
+ got := srv.Serve(cl)
+ if got != ErrServerClosed {
+ t.Errorf("Serve err = %v; want ErrServerClosed", got)
+ }
+ nclose := atomic.LoadInt32(&cl.closes)
+ if nclose != 1 {
+ t.Errorf("Close calls = %v; want 1", nclose)
+ }
+}
+
+// Issue 23351: document and test behavior of ServeMux with ports
+func TestStripPortFromHost(t *testing.T) {
+ mux := NewServeMux()
+
+ mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "OK")
+ })
+ mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "uh-oh!")
+ })
+
+ req := httptest.NewRequest("GET", "http://example.com:9000/", nil)
+ rw := httptest.NewRecorder()
+
+ mux.ServeHTTP(rw, req)
+
+ response := rw.Body.String()
+ if response != "OK" {
+ t.Errorf("Response gotten was %q", response)
+ }
+}
+
func BenchmarkResponseStatusLine(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go
index 57e1b5d..c24ad75 100644
--- a/libgo/go/net/http/server.go
+++ b/libgo/go/net/http/server.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// HTTP server. See RFC 2616.
+// HTTP server. See RFC 7230 through 7235.
package http
@@ -28,7 +28,7 @@ import (
"sync/atomic"
"time"
- "golang_org/x/net/lex/httplex"
+ "golang_org/x/net/http/httpguts"
)
// Errors used by the HTTP server.
@@ -51,7 +51,9 @@ var (
// declared.
ErrContentLength = errors.New("http: wrote more than the declared Content-Length")
- // Deprecated: ErrWriteAfterFlush is no longer used.
+ // Deprecated: ErrWriteAfterFlush is no longer returned by
+ // anything in the net/http package. Callers should not
+ // compare errors against this variable.
ErrWriteAfterFlush = errors.New("unused")
)
@@ -107,7 +109,7 @@ type ResponseWriter interface {
// is to prefix the Header map keys with the TrailerPrefix
// constant value. See TrailerPrefix.
//
- // To suppress implicit response headers (such as "Date"), set
+ // To suppress automatic response headers (such as "Date"), set
// their value to nil.
Header() Header
@@ -117,7 +119,9 @@ type ResponseWriter interface {
// WriteHeader(http.StatusOK) before writing the data. If the Header
// does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to
- // DetectContentType.
+ // DetectContentType. Additionally, if the total size of all written
+ // data is under a few KB and there are no Flush calls, the
+ // Content-Length header is added automatically.
//
// Depending on the HTTP protocol version and the client, calling
// Write or WriteHeader may prevent future reads on the
@@ -187,8 +191,10 @@ type Hijacker interface {
// The returned bufio.Reader may contain unprocessed buffered
// data from the client.
//
- // After a call to Hijack, the original Request.Body must
- // not be used.
+ // After a call to Hijack, the original Request.Body must not
+ // be used. The original Request's Context remains valid and
+ // is not canceled until the Request's ServeHTTP method
+ // returns.
Hijack() (net.Conn, *bufio.ReadWriter, error)
}
@@ -197,6 +203,9 @@ type Hijacker interface {
//
// This mechanism can be used to cancel long operations on the server
// if the client has disconnected before the response is ready.
+//
+// Deprecated: the CloseNotifier interface predates Go's context package.
+// New code should use Request.Context instead.
type CloseNotifier interface {
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone
@@ -227,8 +236,8 @@ var (
ServerContextKey = &contextKey{"http-server"}
// LocalAddrContextKey is a context key. It can be used in
- // HTTP handlers with context.WithValue to access the address
- // the local address the connection arrived on.
+ // HTTP handlers with context.WithValue to access the local
+ // address the connection arrived on.
// The associated value will be of type net.Addr.
LocalAddrContextKey = &contextKey{"local-addr"}
)
@@ -279,7 +288,7 @@ type conn struct {
curReq atomic.Value // of *response (which has a Request in it)
- curState atomic.Value // of ConnState
+ curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState))
// mu guards hijackedv
mu sync.Mutex
@@ -334,7 +343,7 @@ type chunkWriter struct {
res *response
// header is either nil or a deep clone of res.handlerHeader
- // at the time of res.WriteHeader, if res.WriteHeader is
+ // at the time of res.writeHeader, if res.writeHeader is
// called and extra buffering is being done to calculate
// Content-Type and/or Content-Length.
header Header
@@ -510,9 +519,8 @@ func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
// written in the trailers at the end of the response.
func (w *response) declareTrailer(k string) {
k = CanonicalHeaderKey(k)
- switch k {
- case "Transfer-Encoding", "Content-Length", "Trailer":
- // Forbidden by RFC 2616 14.40.
+ if !httpguts.ValidTrailerHeader(k) {
+ // Forbidden by RFC 7230, section 4.1.2
return
}
w.trailers = append(w.trailers, k)
@@ -669,10 +677,28 @@ func (cr *connReader) backgroundRead() {
cr.lock()
if n == 1 {
cr.hasByte = true
- // We were at EOF already (since we wouldn't be in a
- // background read otherwise), so this is a pipelined
- // HTTP request.
- cr.closeNotifyFromPipelinedRequest()
+ // We were past the end of the previous request's body already
+ // (since we wouldn't be in a background read otherwise), so
+ // this is a pipelined HTTP request. Prior to Go 1.11 we used to
+ // send on the CloseNotify channel and cancel the context here,
+ // but the behavior was documented as only "may", and we only
+ // did that because that's how CloseNotify accidentally behaved
+ // in very early Go releases prior to context support. Once we
+ // added context support, people used a Handler's
+ // Request.Context() and passed it along. Having that context
+ // cancel on pipelined HTTP requests caused problems.
+ // Fortunately, almost nothing uses HTTP/1.x pipelining.
+ // Unfortunately, apt-get does, or sometimes does.
+ // New Go 1.11 behavior: don't fire CloseNotify or cancel
+ // contexts on pipelined requests. Shouldn't affect people, but
+ // fixes cases like Issue 23921. This does mean that a client
+ // closing their TCP connection after sending a pipelined
+ // request won't cancel the context, but we'll catch that on any
+ // write failure (in checkConnErrorWriter.Write).
+ // If the server never writes, yes, there are still contrived
+ // server & client behaviors where this fails to ever cancel the
+ // context, but that's kinda why HTTP/1.x pipelining died
+ // anyway.
}
if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() {
// Ignore this error. It's the expected error from
@@ -704,22 +730,18 @@ func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
-// may be called from multiple goroutines.
-func (cr *connReader) handleReadError(err error) {
- cr.conn.cancelCtx()
- cr.closeNotify()
-}
-
-// closeNotifyFromPipelinedRequest simply calls closeNotify.
+// handleReadError is called whenever a Read from the client returns a
+// non-nil error.
//
-// This method wrapper is here for documentation. The callers are the
-// cases where we send on the closenotify channel because of a
-// pipelined HTTP request, per the previous Go behavior and
-// documentation (that this "MAY" happen).
+// The provided non-nil err is almost always io.EOF or a "use of
+// closed network connection". In any case, the error is not
+// particularly interesting, except perhaps for debugging during
+// development. Any error means the connection is dead and we should
+// down its context.
//
-// TODO: consider changing this behavior and making context
-// cancelation and closenotify work the same.
-func (cr *connReader) closeNotifyFromPipelinedRequest() {
+// It may be called from multiple goroutines.
+func (cr *connReader) handleReadError(_ error) {
+ cr.conn.cancelCtx()
cr.closeNotify()
}
@@ -937,7 +959,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
c.r.setReadLimit(c.server.initialReadLimitSize())
if c.lastMethod == "POST" {
- // RFC 2616 section 4.1 tolerance for old buggy clients.
+ // RFC 7230 section 3 tolerance for old buggy clients.
peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
c.bufr.Discard(numLeadingCRorLF(peek))
}
@@ -964,15 +986,15 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
if len(hosts) > 1 {
return nil, badRequestError("too many Host headers")
}
- if len(hosts) == 1 && !httplex.ValidHostHeader(hosts[0]) {
+ if len(hosts) == 1 && !httpguts.ValidHostHeader(hosts[0]) {
return nil, badRequestError("malformed Host header")
}
for k, vv := range req.Header {
- if !httplex.ValidHeaderFieldName(k) {
+ if !httpguts.ValidHeaderFieldName(k) {
return nil, badRequestError("invalid header name")
}
for _, v := range vv {
- if !httplex.ValidHeaderFieldValue(v) {
+ if !httpguts.ValidHeaderFieldValue(v) {
return nil, badRequestError("invalid header value")
}
}
@@ -1058,7 +1080,7 @@ func checkWriteHeaderCode(code int) {
// Issue 22880: require valid WriteHeader status codes.
// For now we only enforce that it's three digits.
// In the future we might block things over 599 (600 and above aren't defined
- // at http://httpwg.org/specs/rfc7231.html#status.codes)
+ // at https://httpwg.org/specs/rfc7231.html#status.codes)
// and we might block under 200 (once we have more mature 1xx support).
// But for now any three digits.
//
@@ -1414,7 +1436,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
}
// foreachHeaderElement splits v according to the "#rule" construction
-// in RFC 2616 section 2.1 and calls fn for each non-empty element.
+// in RFC 7230 section 7 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
@@ -1431,7 +1453,7 @@ func foreachHeaderElement(v string, fn func(string)) {
}
}
-// writeStatusLine writes an HTTP/1.x Status-Line (RFC 2616 Section 6.1)
+// writeStatusLine writes an HTTP/1.x Status-Line (RFC 7230 Section 3.1.2)
// to bw. is11 is whether the HTTP request is HTTP/1.1. false means HTTP/1.0.
// code is the response status code.
// scratch is an optional scratch buffer. If it has at least capacity 3, it's used.
@@ -1668,21 +1690,19 @@ func (c *conn) setState(nc net.Conn, state ConnState) {
case StateHijacked, StateClosed:
srv.trackConn(c, false)
}
- c.curState.Store(connStateInterface[state])
+ if state > 0xff || state < 0 {
+ panic("internal error")
+ }
+ packedState := uint64(time.Now().Unix()<<8) | uint64(state)
+ atomic.StoreUint64(&c.curState.atomic, packedState)
if hook := srv.ConnState; hook != nil {
hook(nc, state)
}
}
-// connStateInterface is an array of the interface{} versions of
-// ConnState values, so we can use them in atomic.Values later without
-// paying the cost of shoving their integers in an interface{}.
-var connStateInterface = [...]interface{}{
- StateNew: StateNew,
- StateActive: StateActive,
- StateIdle: StateIdle,
- StateHijacked: StateHijacked,
- StateClosed: StateClosed,
+func (c *conn) getState() (state ConnState, unixSec int64) {
+ packedState := atomic.LoadUint64(&c.curState.atomic)
+ return ConnState(packedState & 0xff), int64(packedState >> 8)
}
// badRequestError is a literal string (used by in the server in HTML,
@@ -1814,9 +1834,6 @@ func (c *conn) serve(ctx context.Context) {
if requestBodyRemains(req.Body) {
registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead)
} else {
- if w.conn.bufr.Buffered() > 0 {
- w.conn.r.closeNotifyFromPipelinedRequest()
- }
w.conn.r.startBackgroundRead()
}
@@ -1868,11 +1885,11 @@ func (w *response) sendExpectationFailed() {
// make the ResponseWriter an optional
// "ExpectReplier" interface or something.
//
- // For now we'll just obey RFC 2616 14.20 which says
- // "If a server receives a request containing an
- // Expect field that includes an expectation-
- // extension that it does not support, it MUST
- // respond with a 417 (Expectation Failed) status."
+ // For now we'll just obey RFC 7231 5.1.1 which says
+ // "A server that receives an Expect field-value other
+ // than 100-continue MAY respond with a 417 (Expectation
+ // Failed) status code to indicate that the unexpected
+ // expectation cannot be met."
w.Header().Set("Connection", "close")
w.WriteHeader(StatusExpectationFailed)
w.finishRequest()
@@ -1995,25 +2012,19 @@ func StripPrefix(prefix string, h Handler) Handler {
//
// The provided code should be in the 3xx range and is usually
// StatusMovedPermanently, StatusFound or StatusSeeOther.
+//
+// If the Content-Type header has not been set, Redirect sets it
+// to "text/html; charset=utf-8" and writes a small HTML body.
+// Setting the Content-Type header to any value, including nil,
+// disables that behavior.
func Redirect(w ResponseWriter, r *Request, url string, code int) {
// parseURL is just url.Parse (url is shadowed for godoc).
if u, err := parseURL(url); err == nil {
- // If url was relative, make absolute by
+ // If url was relative, make its path absolute by
// combining with request path.
- // The browser would probably do this for us,
+ // The client would probably do this for us,
// but doing it ourselves is more reliable.
-
- // NOTE(rsc): RFC 2616 says that the Location
- // line must be an absolute URI, like
- // "http://www.google.com/redirect/",
- // not a path like "/redirect/".
- // Unfortunately, we don't know what to
- // put in the host name section to get the
- // client to connect to us again, so we can't
- // know the right absolute URI to send back.
- // Because of this problem, no one pays attention
- // to the RFC; they all send back just a new path.
- // So do we.
+ // See RFC 7231, section 7.1.2
if u.Scheme == "" && u.Host == "" {
oldpath := r.URL.Path
if oldpath == "" { // should not happen, but avoid a crash if it does
@@ -2042,18 +2053,23 @@ func Redirect(w ResponseWriter, r *Request, url string, code int) {
}
}
- w.Header().Set("Location", hexEscapeNonASCII(url))
- if r.Method == "GET" || r.Method == "HEAD" {
- w.Header().Set("Content-Type", "text/html; charset=utf-8")
+ h := w.Header()
+
+ // RFC 7231 notes that a short HTML body is usually included in
+ // the response because older user agents may not understand 301/307.
+ // Do it only if the request didn't already have a Content-Type header.
+ _, hadCT := h["Content-Type"]
+
+ h.Set("Location", hexEscapeNonASCII(url))
+ if !hadCT && (r.Method == "GET" || r.Method == "HEAD") {
+ h.Set("Content-Type", "text/html; charset=utf-8")
}
w.WriteHeader(code)
- // RFC 2616 recommends that a short note "SHOULD" be included in the
- // response because older user agents may not understand 301/307.
- // Shouldn't send the response for POST or HEAD; that leaves GET.
- if r.Method == "GET" {
- note := "<a href=\"" + htmlEscape(url) + "\">" + statusText[code] + "</a>.\n"
- fmt.Fprintln(w, note)
+ // Shouldn't send the body for POST or HEAD; that leaves GET.
+ if !hadCT && r.Method == "GET" {
+ body := "<a href=\"" + htmlEscape(url) + "\">" + statusText[code] + "</a>.\n"
+ fmt.Fprintln(w, body)
}
}
@@ -2127,9 +2143,9 @@ func RedirectHandler(url string, code int) Handler {
// "/codesearch" and "codesearch.google.com/" without also taking over
// requests for "http://www.google.com/".
//
-// ServeMux also takes care of sanitizing the URL request path,
-// redirecting any request containing . or .. elements or repeated slashes
-// to an equivalent, cleaner URL.
+// ServeMux also takes care of sanitizing the URL request path and the Host
+// header, stripping the port number and redirecting any request containing . or
+// .. elements or repeated slashes to an equivalent, cleaner URL.
type ServeMux struct {
mu sync.RWMutex
m map[string]muxEntry
@@ -2162,7 +2178,7 @@ func pathMatch(pattern, path string) bool {
return len(path) >= n && path[0:n] == pattern
}
-// Return the canonical path for p, eliminating . and .. elements.
+// cleanPath returns the canonical path for p, eliminating . and .. elements.
func cleanPath(p string) string {
if p == "" {
return "/"
@@ -2174,7 +2190,12 @@ func cleanPath(p string) string {
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
- np += "/"
+ // Fast path for common case of p being the string we want:
+ if len(p) == len(np)+1 && strings.HasPrefix(p, np) {
+ np = p
+ } else {
+ np += "/"
+ }
}
return np
}
@@ -2221,7 +2242,10 @@ func (mux *ServeMux) match(path string) (h Handler, pattern string) {
// not for path itself. If the path needs appending to, it creates a new
// URL, setting the path to u.Path + "/" and returning true to indicate so.
func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) {
- if !mux.shouldRedirect(host, path) {
+ mux.mu.RLock()
+ shouldRedirect := mux.shouldRedirectRLocked(host, path)
+ mux.mu.RUnlock()
+ if !shouldRedirect {
return u, false
}
path = path + "/"
@@ -2229,10 +2253,10 @@ func (mux *ServeMux) redirectToPathSlash(host, path string, u *url.URL) (*url.UR
return u, true
}
-// shouldRedirect reports whether the given path and host should be redirected to
+// shouldRedirectRLocked reports whether the given path and host should be redirected to
// path+"/". This should happen if a handler is registered for path+"/" but
// not path -- see comments at ServeMux.
-func (mux *ServeMux) shouldRedirect(host, path string) bool {
+func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
p := []string{path, host + path}
for _, c := range p {
@@ -2365,6 +2389,9 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) {
// HandleFunc registers the handler function for the given pattern.
func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
+ if handler == nil {
+ panic("http: nil handler")
+ }
mux.Handle(pattern, HandlerFunc(handler))
}
@@ -2383,7 +2410,14 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) {
// Serve accepts incoming HTTP connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
-// Handler is typically nil, in which case the DefaultServeMux is used.
+//
+// The handler is typically nil, in which case the DefaultServeMux is used.
+//
+// HTTP/2 support is only enabled if the Listener returns *tls.Conn
+// connections and they were configured with "h2" in the TLS
+// Config.NextProtos.
+//
+// Serve always returns a non-nil error.
func Serve(l net.Listener, handler Handler) error {
srv := &Server{Handler: handler}
return srv.Serve(l)
@@ -2393,12 +2427,14 @@ func Serve(l net.Listener, handler Handler) error {
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
//
-// Handler is typically nil, in which case the DefaultServeMux is used.
+// The handler is typically nil, in which case the DefaultServeMux is used.
//
// Additionally, files containing a certificate and matching private key
// for the server must be provided. If the certificate is signed by a
// certificate authority, the certFile should be the concatenation
// of the server's certificate, any intermediates, and the CA's certificate.
+//
+// ServeTLS always returns a non-nil error.
func ServeTLS(l net.Listener, handler Handler, certFile, keyFile string) error {
srv := &Server{Handler: handler}
return srv.ServeTLS(l, certFile, keyFile)
@@ -2481,7 +2517,7 @@ type Server struct {
nextProtoErr error // result of http2.ConfigureServer if used
mu sync.Mutex
- listeners map[net.Listener]struct{}
+ listeners map[*net.Listener]struct{}
activeConn map[*conn]struct{}
doneChan chan struct{}
onShutdown []func()
@@ -2522,6 +2558,7 @@ func (s *Server) closeDoneChanLocked() {
// Close returns any error returned from closing the Server's
// underlying Listener(s).
func (srv *Server) Close() error {
+ atomic.StoreInt32(&srv.inShutdown, 1)
srv.mu.Lock()
defer srv.mu.Unlock()
srv.closeDoneChanLocked()
@@ -2559,9 +2596,11 @@ var shutdownPollInterval = 500 * time.Millisecond
// separately notify such long-lived connections of shutdown and wait
// for them to close, if desired. See RegisterOnShutdown for a way to
// register shutdown notification functions.
+//
+// Once Shutdown has been called on a server, it may not be reused;
+// future calls to methods such as Serve will return ErrServerClosed.
func (srv *Server) Shutdown(ctx context.Context) error {
- atomic.AddInt32(&srv.inShutdown, 1)
- defer atomic.AddInt32(&srv.inShutdown, -1)
+ atomic.StoreInt32(&srv.inShutdown, 1)
srv.mu.Lock()
lnerr := srv.closeListenersLocked()
@@ -2603,8 +2642,16 @@ func (s *Server) closeIdleConns() bool {
defer s.mu.Unlock()
quiescent := true
for c := range s.activeConn {
- st, ok := c.curState.Load().(ConnState)
- if !ok || st != StateIdle {
+ st, unixSec := c.getState()
+ // Issue 22682: treat StateNew connections as if
+ // they're idle if we haven't read the first request's
+ // header in over 5 seconds.
+ if st == StateNew && unixSec < time.Now().Unix()-5 {
+ st = StateIdle
+ }
+ if st != StateIdle || unixSec == 0 {
+ // Assume unixSec == 0 means it's a very new
+ // connection, without state set yet.
quiescent = false
continue
}
@@ -2617,7 +2664,7 @@ func (s *Server) closeIdleConns() bool {
func (s *Server) closeListenersLocked() error {
var err error
for ln := range s.listeners {
- if cerr := ln.Close(); cerr != nil && err == nil {
+ if cerr := (*ln).Close(); cerr != nil && err == nil {
err = cerr
}
delete(s.listeners, ln)
@@ -2697,9 +2744,15 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
// ListenAndServe listens on the TCP network address srv.Addr and then
// calls Serve to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
+//
// If srv.Addr is blank, ":http" is used.
-// ListenAndServe always returns a non-nil error.
+//
+// ListenAndServe always returns a non-nil error. After Shutdown or Close,
+// the returned error is ErrServerClosed.
func (srv *Server) ListenAndServe() error {
+ if srv.shuttingDown() {
+ return ErrServerClosed
+ }
addr := srv.Addr
if addr == "" {
addr = ":http"
@@ -2743,27 +2796,30 @@ var ErrServerClosed = errors.New("http: Server closed")
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
//
-// For HTTP/2 support, srv.TLSConfig should be initialized to the
-// provided listener's TLS Config before calling Serve. If
-// srv.TLSConfig is non-nil and doesn't include the string "h2" in
-// Config.NextProtos, HTTP/2 support is not enabled.
+// HTTP/2 support is only enabled if the Listener returns *tls.Conn
+// connections and they were configured with "h2" in the TLS
+// Config.NextProtos.
//
-// Serve always returns a non-nil error. After Shutdown or Close, the
-// returned error is ErrServerClosed.
+// Serve always returns a non-nil error and closes l.
+// After Shutdown or Close, the returned error is ErrServerClosed.
func (srv *Server) Serve(l net.Listener) error {
- defer l.Close()
if fn := testHookServerServe; fn != nil {
- fn(srv, l)
+ fn(srv, l) // call hook with unwrapped listener
}
- var tempDelay time.Duration // how long to sleep on accept failure
+
+ l = &onceCloseListener{Listener: l}
+ defer l.Close()
if err := srv.setupHTTP2_Serve(); err != nil {
return err
}
- srv.trackListener(l, true)
- defer srv.trackListener(l, false)
+ if !srv.trackListener(&l, true) {
+ return ErrServerClosed
+ }
+ defer srv.trackListener(&l, false)
+ var tempDelay time.Duration // how long to sleep on accept failure
baseCtx := context.Background() // base is always background, per Issue 16220
ctx := context.WithValue(baseCtx, ServerContextKey, srv)
for {
@@ -2797,19 +2853,15 @@ func (srv *Server) Serve(l net.Listener) error {
}
// ServeTLS accepts incoming connections on the Listener l, creating a
-// new service goroutine for each. The service goroutines read requests and
-// then call srv.Handler to reply to them.
+// new service goroutine for each. The service goroutines perform TLS
+// setup and then read requests, calling srv.Handler to reply to them.
//
-// Additionally, files containing a certificate and matching private key for
-// the server must be provided if neither the Server's TLSConfig.Certificates
-// nor TLSConfig.GetCertificate are populated.. If the certificate is signed by
-// a certificate authority, the certFile should be the concatenation of the
-// server's certificate, any intermediates, and the CA's certificate.
-//
-// For HTTP/2 support, srv.TLSConfig should be initialized to the
-// provided listener's TLS Config before calling ServeTLS. If
-// srv.TLSConfig is non-nil and doesn't include the string "h2" in
-// Config.NextProtos, HTTP/2 support is not enabled.
+// Files containing a certificate and matching private key for the
+// server must be provided if neither the Server's
+// TLSConfig.Certificates nor TLSConfig.GetCertificate are populated.
+// If the certificate is signed by a certificate authority, the
+// certFile should be the concatenation of the server's certificate,
+// any intermediates, and the CA's certificate.
//
// ServeTLS always returns a non-nil error. After Shutdown or Close, the
// returned error is ErrServerClosed.
@@ -2839,22 +2891,31 @@ func (srv *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
return srv.Serve(tlsListener)
}
-func (s *Server) trackListener(ln net.Listener, add bool) {
+// trackListener adds or removes a net.Listener to the set of tracked
+// listeners.
+//
+// We store a pointer to interface in the map set, in case the
+// net.Listener is not comparable. This is safe because we only call
+// trackListener via Serve and can track+defer untrack the same
+// pointer to local variable there. We never need to compare a
+// Listener from another caller.
+//
+// It reports whether the server is still up (not Shutdown or Closed).
+func (s *Server) trackListener(ln *net.Listener, add bool) bool {
s.mu.Lock()
defer s.mu.Unlock()
if s.listeners == nil {
- s.listeners = make(map[net.Listener]struct{})
+ s.listeners = make(map[*net.Listener]struct{})
}
if add {
- // If the *Server is being reused after a previous
- // Close or Shutdown, reset its doneChan:
- if len(s.listeners) == 0 && len(s.activeConn) == 0 {
- s.doneChan = nil
+ if s.shuttingDown() {
+ return false
}
s.listeners[ln] = struct{}{}
} else {
delete(s.listeners, ln)
}
+ return true
}
func (s *Server) trackConn(c *conn, add bool) {
@@ -2889,6 +2950,8 @@ func (s *Server) doKeepAlives() bool {
}
func (s *Server) shuttingDown() bool {
+ // TODO: replace inShutdown with the existing atomicBool type;
+ // see https://github.com/golang/go/issues/20239#issuecomment-381434582
return atomic.LoadInt32(&s.inShutdown) != 0
}
@@ -2906,14 +2969,7 @@ func (srv *Server) SetKeepAlivesEnabled(v bool) {
// Close idle HTTP/1 conns:
srv.closeIdleConns()
- // Close HTTP/2 conns, as soon as they become idle, but reset
- // the chan so future conns (if the listener is still active)
- // still work and don't get a GOAWAY immediately, before their
- // first request:
- srv.mu.Lock()
- defer srv.mu.Unlock()
- srv.closeDoneChanLocked() // closes http2 conns
- srv.doneChan = nil
+ // TODO: Issue 26303: close HTTP/2 conns as soon as they become idle.
}
func (s *Server) logf(format string, args ...interface{}) {
@@ -2936,32 +2992,11 @@ func logf(r *Request, format string, args ...interface{}) {
}
}
-// ListenAndServe listens on the TCP network address addr
-// and then calls Serve with handler to handle requests
-// on incoming connections.
+// ListenAndServe listens on the TCP network address addr and then calls
+// Serve with handler to handle requests on incoming connections.
// Accepted connections are configured to enable TCP keep-alives.
-// Handler is typically nil, in which case the DefaultServeMux is
-// used.
-//
-// A trivial example server is:
-//
-// package main
//
-// import (
-// "io"
-// "net/http"
-// "log"
-// )
-//
-// // hello world, the web server
-// func HelloServer(w http.ResponseWriter, req *http.Request) {
-// io.WriteString(w, "hello, world!\n")
-// }
-//
-// func main() {
-// http.HandleFunc("/hello", HelloServer)
-// log.Fatal(http.ListenAndServe(":12345", nil))
-// }
+// The handler is typically nil, in which case the DefaultServeMux is used.
//
// ListenAndServe always returns a non-nil error.
func ListenAndServe(addr string, handler Handler) error {
@@ -2974,36 +3009,13 @@ func ListenAndServe(addr string, handler Handler) error {
// matching private key for the server must be provided. If the certificate
// is signed by a certificate authority, the certFile should be the concatenation
// of the server's certificate, any intermediates, and the CA's certificate.
-//
-// A trivial example server is:
-//
-// import (
-// "log"
-// "net/http"
-// )
-//
-// func handler(w http.ResponseWriter, req *http.Request) {
-// w.Header().Set("Content-Type", "text/plain")
-// w.Write([]byte("This is an example server.\n"))
-// }
-//
-// func main() {
-// http.HandleFunc("/", handler)
-// log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/")
-// err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil)
-// log.Fatal(err)
-// }
-//
-// One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem.
-//
-// ListenAndServeTLS always returns a non-nil error.
func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServeTLS listens on the TCP network address srv.Addr and
-// then calls Serve to handle requests on incoming TLS connections.
+// then calls ServeTLS to handle requests on incoming TLS connections.
// Accepted connections are configured to enable TCP keep-alives.
//
// Filenames containing a certificate and matching private key for the
@@ -3015,8 +3027,12 @@ func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
//
// If srv.Addr is blank, ":https" is used.
//
-// ListenAndServeTLS always returns a non-nil error.
+// ListenAndServeTLS always returns a non-nil error. After Shutdown or
+// Close, the returned error is ErrServerClosed.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
+ if srv.shuttingDown() {
+ return ErrServerClosed
+ }
addr := srv.Addr
if addr == "" {
addr = ":https"
@@ -3042,8 +3058,8 @@ func (srv *Server) setupHTTP2_ServeTLS() error {
// setupHTTP2_Serve is called from (*Server).Serve and conditionally
// configures HTTP/2 on srv using a more conservative policy than
-// setupHTTP2_ServeTLS because Serve may be called
-// concurrently.
+// setupHTTP2_ServeTLS because Serve is called after tls.Listen,
+// and may be called concurrently. See shouldConfigureHTTP2ForServe.
//
// The tests named TestTransportAutomaticHTTP2* and
// TestConcurrentServerServe in server_test.go demonstrate some
@@ -3222,6 +3238,21 @@ func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
return tc, nil
}
+// onceCloseListener wraps a net.Listener, protecting it from
+// multiple Close calls.
+type onceCloseListener struct {
+ net.Listener
+ once sync.Once
+ closeErr error
+}
+
+func (oc *onceCloseListener) Close() error {
+ oc.once.Do(oc.close)
+ return oc.closeErr
+}
+
+func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() }
+
// globalOptionsHandler responds to "OPTIONS *" requests.
type globalOptionsHandler struct{}
diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go
index 365a36c..c1494ab 100644
--- a/libgo/go/net/http/sniff.go
+++ b/libgo/go/net/http/sniff.go
@@ -13,7 +13,7 @@ import (
const sniffLen = 512
// DetectContentType implements the algorithm described
-// at http://mimesniff.spec.whatwg.org/ to determine the
+// at https://mimesniff.spec.whatwg.org/ to determine the
// Content-Type of the given data. It considers at most the
// first 512 bytes of data. DetectContentType always returns
// a valid MIME type: if it cannot determine a more specific one, it
@@ -136,16 +136,19 @@ var sniffSignatures = []sniffSig{
mask: []byte("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xFF\xFF"),
ct: "application/vnd.ms-fontobject",
},
- &exactSig{[]byte("\x00\x01\x00\x00"), "application/font-ttf"},
- &exactSig{[]byte("OTTO"), "application/font-off"},
- &exactSig{[]byte("ttcf"), "application/font-cff"},
- &exactSig{[]byte("wOFF"), "application/font-woff"},
+ &exactSig{[]byte("\x00\x01\x00\x00"), "font/ttf"},
+ &exactSig{[]byte("OTTO"), "font/otf"},
+ &exactSig{[]byte("ttcf"), "font/collection"},
+ &exactSig{[]byte("wOFF"), "font/woff"},
+ &exactSig{[]byte("wOF2"), "font/woff2"},
&exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"},
&exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"},
&exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"},
&exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
+ &exactSig{[]byte("\x00\x61\x73\x6D"), "application/wasm"},
+
mp4Sig{},
textSig{}, // should be last
diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go
index bf1f6be..b4d3c9f 100644
--- a/libgo/go/net/http/sniff_test.go
+++ b/libgo/go/net/http/sniff_test.go
@@ -58,14 +58,14 @@ var sniffTests = []struct {
// Font types.
// {"MS.FontObject", []byte("\x00\x00")},
- {"TTF sample I", []byte("\x00\x01\x00\x00\x00\x17\x01\x00\x00\x04\x01\x60\x4f"), "application/font-ttf"},
- {"TTF sample II", []byte("\x00\x01\x00\x00\x00\x0e\x00\x80\x00\x03\x00\x60\x46"), "application/font-ttf"},
+ {"TTF sample I", []byte("\x00\x01\x00\x00\x00\x17\x01\x00\x00\x04\x01\x60\x4f"), "font/ttf"},
+ {"TTF sample II", []byte("\x00\x01\x00\x00\x00\x0e\x00\x80\x00\x03\x00\x60\x46"), "font/ttf"},
- {"OTTO sample I", []byte("\x4f\x54\x54\x4f\x00\x0e\x00\x80\x00\x03\x00\x60\x42\x41\x53\x45"), "application/font-off"},
+ {"OTTO sample I", []byte("\x4f\x54\x54\x4f\x00\x0e\x00\x80\x00\x03\x00\x60\x42\x41\x53\x45"), "font/otf"},
- {"woff sample I", []byte("\x77\x4f\x46\x46\x00\x01\x00\x00\x00\x00\x30\x54\x00\x0d\x00\x00"), "application/font-woff"},
- // Woff2 is not yet recognized, change this test once mime-sniff working group adds woff2
- {"woff2 not recognized", []byte("\x77\x4f\x46\x32\x00\x01\x00\x00\x00"), "application/octet-stream"},
+ {"woff sample I", []byte("\x77\x4f\x46\x46\x00\x01\x00\x00\x00\x00\x30\x54\x00\x0d\x00\x00"), "font/woff"},
+ {"woff2 sample", []byte("\x77\x4f\x46\x32\x00\x01\x00\x00\x00"), "font/woff2"},
+ {"wasm sample", []byte("\x00\x61\x73\x6d\x01\x00"), "application/wasm"},
}
func TestDetectContentType(t *testing.T) {
diff --git a/libgo/go/net/http/socks_bundle.go b/libgo/go/net/http/socks_bundle.go
new file mode 100644
index 0000000..e4314b4
--- /dev/null
+++ b/libgo/go/net/http/socks_bundle.go
@@ -0,0 +1,472 @@
+// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
+//go:generate bundle -o socks_bundle.go -dst net/http -prefix socks -underscore golang.org/x/net/internal/socks
+
+// Package socks provides a SOCKS version 5 client implementation.
+//
+// SOCKS protocol version 5 is defined in RFC 1928.
+// Username/Password authentication for SOCKS version 5 is defined in
+// RFC 1929.
+//
+
+package http
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "strconv"
+ "time"
+)
+
+var (
+ socksnoDeadline = time.Time{}
+ socksaLongTimeAgo = time.Unix(1, 0)
+)
+
+func (d *socksDialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
+ host, port, err := sockssplitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+ if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
+ c.SetDeadline(deadline)
+ defer c.SetDeadline(socksnoDeadline)
+ }
+ if ctx != context.Background() {
+ errCh := make(chan error, 1)
+ done := make(chan struct{})
+ defer func() {
+ close(done)
+ if ctxErr == nil {
+ ctxErr = <-errCh
+ }
+ }()
+ go func() {
+ select {
+ case <-ctx.Done():
+ c.SetDeadline(socksaLongTimeAgo)
+ errCh <- ctx.Err()
+ case <-done:
+ errCh <- nil
+ }
+ }()
+ }
+
+ b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
+ b = append(b, socksVersion5)
+ if len(d.AuthMethods) == 0 || d.Authenticate == nil {
+ b = append(b, 1, byte(socksAuthMethodNotRequired))
+ } else {
+ ams := d.AuthMethods
+ if len(ams) > 255 {
+ return nil, errors.New("too many authentication methods")
+ }
+ b = append(b, byte(len(ams)))
+ for _, am := range ams {
+ b = append(b, byte(am))
+ }
+ }
+ if _, ctxErr = c.Write(b); ctxErr != nil {
+ return
+ }
+
+ if _, ctxErr = io.ReadFull(c, b[:2]); ctxErr != nil {
+ return
+ }
+ if b[0] != socksVersion5 {
+ return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+ }
+ am := socksAuthMethod(b[1])
+ if am == socksAuthMethodNoAcceptableMethods {
+ return nil, errors.New("no acceptable authentication methods")
+ }
+ if d.Authenticate != nil {
+ if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
+ return
+ }
+ }
+
+ b = b[:0]
+ b = append(b, socksVersion5, byte(d.cmd), 0)
+ if ip := net.ParseIP(host); ip != nil {
+ if ip4 := ip.To4(); ip4 != nil {
+ b = append(b, socksAddrTypeIPv4)
+ b = append(b, ip4...)
+ } else if ip6 := ip.To16(); ip6 != nil {
+ b = append(b, socksAddrTypeIPv6)
+ b = append(b, ip6...)
+ } else {
+ return nil, errors.New("unknown address type")
+ }
+ } else {
+ if len(host) > 255 {
+ return nil, errors.New("FQDN too long")
+ }
+ b = append(b, socksAddrTypeFQDN)
+ b = append(b, byte(len(host)))
+ b = append(b, host...)
+ }
+ b = append(b, byte(port>>8), byte(port))
+ if _, ctxErr = c.Write(b); ctxErr != nil {
+ return
+ }
+
+ if _, ctxErr = io.ReadFull(c, b[:4]); ctxErr != nil {
+ return
+ }
+ if b[0] != socksVersion5 {
+ return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
+ }
+ if cmdErr := socksReply(b[1]); cmdErr != socksStatusSucceeded {
+ return nil, errors.New("unknown error " + cmdErr.String())
+ }
+ if b[2] != 0 {
+ return nil, errors.New("non-zero reserved field")
+ }
+ l := 2
+ var a socksAddr
+ switch b[3] {
+ case socksAddrTypeIPv4:
+ l += net.IPv4len
+ a.IP = make(net.IP, net.IPv4len)
+ case socksAddrTypeIPv6:
+ l += net.IPv6len
+ a.IP = make(net.IP, net.IPv6len)
+ case socksAddrTypeFQDN:
+ if _, err := io.ReadFull(c, b[:1]); err != nil {
+ return nil, err
+ }
+ l += int(b[0])
+ default:
+ return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
+ }
+ if cap(b) < l {
+ b = make([]byte, l)
+ } else {
+ b = b[:l]
+ }
+ if _, ctxErr = io.ReadFull(c, b); ctxErr != nil {
+ return
+ }
+ if a.IP != nil {
+ copy(a.IP, b)
+ } else {
+ a.Name = string(b[:len(b)-2])
+ }
+ a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
+ return &a, nil
+}
+
+func sockssplitHostPort(address string) (string, int, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return "", 0, err
+ }
+ portnum, err := strconv.Atoi(port)
+ if err != nil {
+ return "", 0, err
+ }
+ if 1 > portnum || portnum > 0xffff {
+ return "", 0, errors.New("port number out of range " + port)
+ }
+ return host, portnum, nil
+}
+
+// A Command represents a SOCKS command.
+type socksCommand int
+
+func (cmd socksCommand) String() string {
+ switch cmd {
+ case socksCmdConnect:
+ return "socks connect"
+ case sockscmdBind:
+ return "socks bind"
+ default:
+ return "socks " + strconv.Itoa(int(cmd))
+ }
+}
+
+// An AuthMethod represents a SOCKS authentication method.
+type socksAuthMethod int
+
+// A Reply represents a SOCKS command reply code.
+type socksReply int
+
+func (code socksReply) String() string {
+ switch code {
+ case socksStatusSucceeded:
+ return "succeeded"
+ case 0x01:
+ return "general SOCKS server failure"
+ case 0x02:
+ return "connection not allowed by ruleset"
+ case 0x03:
+ return "network unreachable"
+ case 0x04:
+ return "host unreachable"
+ case 0x05:
+ return "connection refused"
+ case 0x06:
+ return "TTL expired"
+ case 0x07:
+ return "command not supported"
+ case 0x08:
+ return "address type not supported"
+ default:
+ return "unknown code: " + strconv.Itoa(int(code))
+ }
+}
+
+// Wire protocol constants.
+const (
+ socksVersion5 = 0x05
+
+ socksAddrTypeIPv4 = 0x01
+ socksAddrTypeFQDN = 0x03
+ socksAddrTypeIPv6 = 0x04
+
+ socksCmdConnect socksCommand = 0x01 // establishes an active-open forward proxy connection
+ sockscmdBind socksCommand = 0x02 // establishes a passive-open forward proxy connection
+
+ socksAuthMethodNotRequired socksAuthMethod = 0x00 // no authentication required
+ socksAuthMethodUsernamePassword socksAuthMethod = 0x02 // use username/password
+ socksAuthMethodNoAcceptableMethods socksAuthMethod = 0xff // no acceptable authentication methods
+
+ socksStatusSucceeded socksReply = 0x00
+)
+
+// An Addr represents a SOCKS-specific address.
+// Either Name or IP is used exclusively.
+type socksAddr struct {
+ Name string // fully-qualified domain name
+ IP net.IP
+ Port int
+}
+
+func (a *socksAddr) Network() string { return "socks" }
+
+func (a *socksAddr) String() string {
+ if a == nil {
+ return "<nil>"
+ }
+ port := strconv.Itoa(a.Port)
+ if a.IP == nil {
+ return net.JoinHostPort(a.Name, port)
+ }
+ return net.JoinHostPort(a.IP.String(), port)
+}
+
+// A Conn represents a forward proxy connection.
+type socksConn struct {
+ net.Conn
+
+ boundAddr net.Addr
+}
+
+// BoundAddr returns the address assigned by the proxy server for
+// connecting to the command target address from the proxy server.
+func (c *socksConn) BoundAddr() net.Addr {
+ if c == nil {
+ return nil
+ }
+ return c.boundAddr
+}
+
+// A Dialer holds SOCKS-specific options.
+type socksDialer struct {
+ cmd socksCommand // either CmdConnect or cmdBind
+ proxyNetwork string // network between a proxy server and a client
+ proxyAddress string // proxy server address
+
+ // ProxyDial specifies the optional dial function for
+ // establishing the transport connection.
+ ProxyDial func(context.Context, string, string) (net.Conn, error)
+
+ // AuthMethods specifies the list of request authention
+ // methods.
+ // If empty, SOCKS client requests only AuthMethodNotRequired.
+ AuthMethods []socksAuthMethod
+
+ // Authenticate specifies the optional authentication
+ // function. It must be non-nil when AuthMethods is not empty.
+ // It must return an error when the authentication is failed.
+ Authenticate func(context.Context, io.ReadWriter, socksAuthMethod) error
+}
+
+// DialContext connects to the provided address on the provided
+// network.
+//
+// The returned error value may be a net.OpError. When the Op field of
+// net.OpError contains "socks", the Source field contains a proxy
+// server address and the Addr field contains a command target
+// address.
+//
+// See func Dial of the net package of standard library for a
+// description of the network and address parameters.
+func (d *socksDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress)
+ } else {
+ var dd net.Dialer
+ c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ c.Close()
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return &socksConn{Conn: c, boundAddr: a}, nil
+}
+
+// DialWithConn initiates a connection from SOCKS server to the target
+// network and address using the connection c that is already
+// connected to the SOCKS server.
+//
+// It returns the connection's local address assigned by the SOCKS
+// server.
+func (d *socksDialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if ctx == nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")}
+ }
+ a, err := d.connect(ctx, c, address)
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ return a, nil
+}
+
+// Dial connects to the provided address on the provided network.
+//
+// Unlike DialContext, it returns a raw transport connection instead
+// of a forward proxy connection.
+//
+// Deprecated: Use DialContext or DialWithConn instead.
+func (d *socksDialer) Dial(network, address string) (net.Conn, error) {
+ if err := d.validateTarget(network, address); err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ var err error
+ var c net.Conn
+ if d.ProxyDial != nil {
+ c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress)
+ } else {
+ c, err = net.Dial(d.proxyNetwork, d.proxyAddress)
+ }
+ if err != nil {
+ proxy, dst, _ := d.pathAddrs(address)
+ return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err}
+ }
+ if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil {
+ return nil, err
+ }
+ return c, nil
+}
+
+func (d *socksDialer) validateTarget(network, address string) error {
+ switch network {
+ case "tcp", "tcp6", "tcp4":
+ default:
+ return errors.New("network not implemented")
+ }
+ switch d.cmd {
+ case socksCmdConnect, sockscmdBind:
+ default:
+ return errors.New("command not implemented")
+ }
+ return nil
+}
+
+func (d *socksDialer) pathAddrs(address string) (proxy, dst net.Addr, err error) {
+ for i, s := range []string{d.proxyAddress, address} {
+ host, port, err := sockssplitHostPort(s)
+ if err != nil {
+ return nil, nil, err
+ }
+ a := &socksAddr{Port: port}
+ a.IP = net.ParseIP(host)
+ if a.IP == nil {
+ a.Name = host
+ }
+ if i == 0 {
+ proxy = a
+ } else {
+ dst = a
+ }
+ }
+ return
+}
+
+// NewDialer returns a new Dialer that dials through the provided
+// proxy server's network and address.
+func socksNewDialer(network, address string) *socksDialer {
+ return &socksDialer{proxyNetwork: network, proxyAddress: address, cmd: socksCmdConnect}
+}
+
+const (
+ socksauthUsernamePasswordVersion = 0x01
+ socksauthStatusSucceeded = 0x00
+)
+
+// UsernamePassword are the credentials for the username/password
+// authentication method.
+type socksUsernamePassword struct {
+ Username string
+ Password string
+}
+
+// Authenticate authenticates a pair of username and password with the
+// proxy server.
+func (up *socksUsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, auth socksAuthMethod) error {
+ switch auth {
+ case socksAuthMethodNotRequired:
+ return nil
+ case socksAuthMethodUsernamePassword:
+ if len(up.Username) == 0 || len(up.Username) > 255 || len(up.Password) == 0 || len(up.Password) > 255 {
+ return errors.New("invalid username/password")
+ }
+ b := []byte{socksauthUsernamePasswordVersion}
+ b = append(b, byte(len(up.Username)))
+ b = append(b, up.Username...)
+ b = append(b, byte(len(up.Password)))
+ b = append(b, up.Password...)
+ // TODO(mikio): handle IO deadlines and cancelation if
+ // necessary
+ if _, err := rw.Write(b); err != nil {
+ return err
+ }
+ if _, err := io.ReadFull(rw, b[:2]); err != nil {
+ return err
+ }
+ if b[0] != socksauthUsernamePasswordVersion {
+ return errors.New("invalid username/password version")
+ }
+ if b[1] != socksauthStatusSucceeded {
+ return errors.New("username/password authentication failed")
+ }
+ return nil
+ }
+ return errors.New("unsupported authentication method " + strconv.Itoa(int(auth)))
+}
diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go
index 98645b7..dd72d67 100644
--- a/libgo/go/net/http/status.go
+++ b/libgo/go/net/http/status.go
@@ -5,7 +5,7 @@
package http
// HTTP status codes as registered with IANA.
-// See: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
+// See: https://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml
const (
StatusContinue = 100 // RFC 7231, 6.2.1
StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2
@@ -51,6 +51,7 @@ const (
StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4
StatusExpectationFailed = 417 // RFC 7231, 6.5.14
StatusTeapot = 418 // RFC 7168, 2.3.3
+ StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2
StatusUnprocessableEntity = 422 // RFC 4918, 11.2
StatusLocked = 423 // RFC 4918, 11.3
StatusFailedDependency = 424 // RFC 4918, 11.4
@@ -117,6 +118,7 @@ var statusText = map[int]string{
StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
StatusExpectationFailed: "Expectation Failed",
StatusTeapot: "I'm a teapot",
+ StatusMisdirectedRequest: "Misdirected Request",
StatusUnprocessableEntity: "Unprocessable Entity",
StatusLocked: "Locked",
StatusFailedDependency: "Failed Dependency",
diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go
index a400a6a..2c6ba32 100644
--- a/libgo/go/net/http/transfer.go
+++ b/libgo/go/net/http/transfer.go
@@ -11,15 +11,17 @@ import (
"fmt"
"io"
"io/ioutil"
+ "net/http/httptrace"
"net/http/internal"
"net/textproto"
+ "reflect"
"sort"
"strconv"
"strings"
"sync"
"time"
- "golang_org/x/net/lex/httplex"
+ "golang_org/x/net/http/httpguts"
)
// ErrLineTooLong is returned when reading request or response bodies
@@ -105,6 +107,17 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && t.shouldSendChunkedRequestBody() {
t.TransferEncoding = []string{"chunked"}
}
+ // If there's a body, conservatively flush the headers
+ // to any bufio.Writer we're writing to, just in case
+ // the server needs the headers early, before we copy
+ // the body and possibly block. We make an exception
+ // for the common standard library in-memory types,
+ // though, to avoid unnecessary TCP packets on the
+ // wire. (Issue 22088.)
+ if t.ContentLength != 0 && !isKnownInMemoryReader(t.Body) {
+ t.FlushHeaders = true
+ }
+
atLeastHTTP11 = true // Transport requests are always 1.1 or 2.0
case *Response:
t.IsResponse = true
@@ -268,11 +281,14 @@ func (t *transferWriter) shouldSendContentLength() bool {
return false
}
-func (t *transferWriter) WriteHeader(w io.Writer) error {
+func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) error {
if t.Close && !hasToken(t.Header.get("Connection"), "close") {
if _, err := io.WriteString(w, "Connection: close\r\n"); err != nil {
return err
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Connection", []string{"close"})
+ }
}
// Write Content-Length and/or Transfer-Encoding whose values are a
@@ -285,10 +301,16 @@ func (t *transferWriter) WriteHeader(w io.Writer) error {
if _, err := io.WriteString(w, strconv.FormatInt(t.ContentLength, 10)+"\r\n"); err != nil {
return err
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Content-Length", []string{strconv.FormatInt(t.ContentLength, 10)})
+ }
} else if chunked(t.TransferEncoding) {
if _, err := io.WriteString(w, "Transfer-Encoding: chunked\r\n"); err != nil {
return err
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Transfer-Encoding", []string{"chunked"})
+ }
}
// Write Trailer header
@@ -309,13 +331,16 @@ func (t *transferWriter) WriteHeader(w io.Writer) error {
if _, err := io.WriteString(w, "Trailer: "+strings.Join(keys, ",")+"\r\n"); err != nil {
return err
}
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField("Trailer", keys)
+ }
}
}
return nil
}
-func (t *transferWriter) WriteBody(w io.Writer) error {
+func (t *transferWriter) writeBody(w io.Writer) error {
var err error
var ncopy int64
@@ -390,7 +415,7 @@ func (t *transferReader) protoAtLeast(m, n int) bool {
}
// bodyAllowedForStatus reports whether a given response status code
-// permits a body. See RFC 2616, section 4.4.
+// permits a body. See RFC 7230, section 3.3.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
@@ -411,7 +436,7 @@ var (
func suppressedHeaders(status int) []string {
switch {
case status == 304:
- // RFC 2616 section 10.3.5: "the response MUST NOT include other entity-headers"
+ // RFC 7232 section 4.1
return suppressedHeaders304
case !bodyAllowedForStatus(status):
return suppressedHeadersNoBody
@@ -482,7 +507,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
// If there is no Content-Length or chunked Transfer-Encoding on a *Response
// and the status is not 1xx, 204 or 304, then the body is unbounded.
- // See RFC 2616, section 4.4.
+ // See RFC 7230, section 3.3.
switch msg.(type) {
case *Response:
if realLength == -1 &&
@@ -601,7 +626,7 @@ func (t *transferReader) fixTransferEncoding() error {
return nil
}
-// Determine the expected body length, using RFC 2616 Section 4.4. This
+// Determine the expected body length, using RFC 7230 Section 3.3. This
// function is not a method, because ultimately it should be shared by
// ReadResponse and ReadRequest.
func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) {
@@ -667,7 +692,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header,
header.Del("Content-Length")
if isRequest {
- // RFC 2616 neither explicitly permits nor forbids an
+ // RFC 7230 neither explicitly permits nor forbids an
// entity-body on a GET request so we permit one if
// declared, but we default to 0 here (not -1 below)
// if there's no mention of a body.
@@ -690,9 +715,9 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
}
conv := header["Connection"]
- hasClose := httplex.HeaderValuesContainsToken(conv, "close")
+ hasClose := httpguts.HeaderValuesContainsToken(conv, "close")
if major == 1 && minor == 0 {
- return hasClose || !httplex.HeaderValuesContainsToken(conv, "keep-alive")
+ return hasClose || !httpguts.HeaderValuesContainsToken(conv, "keep-alive")
}
if hasClose && removeCloseHeader {
@@ -1009,3 +1034,19 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) {
}
return
}
+
+var nopCloserType = reflect.TypeOf(ioutil.NopCloser(nil))
+
+// isKnownInMemoryReader reports whether r is a type known to not
+// block on Read. Its caller uses this as an optional optimization to
+// send fewer TCP packets.
+func isKnownInMemoryReader(r io.Reader) bool {
+ switch r.(type) {
+ case *bytes.Reader, *bytes.Buffer, *strings.Reader:
+ return true
+ }
+ if reflect.TypeOf(r) == nopCloserType {
+ return isKnownInMemoryReader(reflect.ValueOf(r).Field(0).Interface().(io.Reader))
+ }
+ return false
+}
diff --git a/libgo/go/net/http/transfer_test.go b/libgo/go/net/http/transfer_test.go
index 48cd540..993ea4e 100644
--- a/libgo/go/net/http/transfer_test.go
+++ b/libgo/go/net/http/transfer_test.go
@@ -6,7 +6,9 @@ package http
import (
"bufio"
+ "bytes"
"io"
+ "io/ioutil"
"strings"
"testing"
)
@@ -62,3 +64,29 @@ func TestFinalChunkedBodyReadEOF(t *testing.T) {
t.Errorf("buf = %q; want %q", buf, want)
}
}
+
+func TestDetectInMemoryReaders(t *testing.T) {
+ pr, _ := io.Pipe()
+ tests := []struct {
+ r io.Reader
+ want bool
+ }{
+ {pr, false},
+
+ {bytes.NewReader(nil), true},
+ {bytes.NewBuffer(nil), true},
+ {strings.NewReader(""), true},
+
+ {ioutil.NopCloser(pr), false},
+
+ {ioutil.NopCloser(bytes.NewReader(nil)), true},
+ {ioutil.NopCloser(bytes.NewBuffer(nil)), true},
+ {ioutil.NopCloser(strings.NewReader("")), true},
+ }
+ for i, tt := range tests {
+ got := isKnownInMemoryReader(tt.r)
+ if got != tt.want {
+ t.Errorf("%d: got = %v; want %v", i, got, tt.want)
+ }
+ }
+}
diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go
index 7ef8f01..40947ba 100644
--- a/libgo/go/net/http/transport.go
+++ b/libgo/go/net/http/transport.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// HTTP client implementation. See RFC 2616.
+// HTTP client implementation. See RFC 7230 through 7235.
//
// This is the low-level Transport implementation of RoundTripper.
// The high-level interface is in client.go.
@@ -21,15 +21,17 @@ import (
"log"
"net"
"net/http/httptrace"
+ "net/textproto"
"net/url"
"os"
+ "reflect"
"strings"
"sync"
"sync/atomic"
"time"
- "golang_org/x/net/lex/httplex"
- "golang_org/x/net/proxy"
+ "golang_org/x/net/http/httpguts"
+ "golang_org/x/net/http/httpproxy"
)
// DefaultTransport is the default implementation of Transport and is
@@ -54,6 +56,15 @@ var DefaultTransport RoundTripper = &Transport{
// MaxIdleConnsPerHost.
const DefaultMaxIdleConnsPerHost = 2
+// connsPerHostClosedCh is a closed channel used by MaxConnsPerHost
+// for the property that receives from a closed channel return the
+// zero value.
+var connsPerHostClosedCh = make(chan struct{})
+
+func init() {
+ close(connsPerHostClosedCh)
+}
+
// Transport is an implementation of RoundTripper that supports HTTP,
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
//
@@ -82,6 +93,13 @@ const DefaultMaxIdleConnsPerHost = 2
// being written while the response body is streamed. Go's HTTP/2
// implementation does support full duplex, but many CONNECT proxies speak
// HTTP/1.x.
+//
+// Responses with status codes in the 1xx range are either handled
+// automatically (100 expect-continue) or ignored. The one
+// exception is HTTP status code 101 (Switching Protocols), which is
+// considered a terminal status and returned by RoundTrip. To see the
+// ignored 1xx responses, use the httptrace trace package's
+// ClientTrace.Got1xxResponse.
type Transport struct {
idleMu sync.Mutex
wantIdle bool // user has requested to close all idle conns
@@ -95,12 +113,16 @@ type Transport struct {
altMu sync.Mutex // guards changing altProto only
altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme
+ connCountMu sync.Mutex
+ connPerHostCount map[connectMethodKey]int
+ connPerHostAvailable map[connectMethodKey]chan struct{}
+
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
//
- // The proxy type is determined by the URL scheme. "http"
- // and "socks5" are supported. If the scheme is empty,
+ // The proxy type is determined by the URL scheme. "http",
+ // "https", and "socks5" are supported. If the scheme is empty,
// "http" is assumed.
//
// If Proxy is nil or returns a nil *URL, no proxy is used.
@@ -109,10 +131,20 @@ type Transport struct {
// DialContext specifies the dial function for creating unencrypted TCP connections.
// If DialContext is nil (and the deprecated Dial below is also nil),
// then the transport dials using package net.
+ //
+ // DialContext runs concurrently with calls to RoundTrip.
+ // A RoundTrip call that initiates a dial may end up using
+ // an connection dialed previously when the earlier connection
+ // becomes idle before the later DialContext completes.
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Dial specifies the dial function for creating unencrypted TCP connections.
//
+ // Dial runs concurrently with calls to RoundTrip.
+ // A RoundTrip call that initiates a dial may end up using
+ // an connection dialed previously when the earlier connection
+ // becomes idle before the later Dial completes.
+ //
// Deprecated: Use DialContext instead, which allows the transport
// to cancel dials as soon as they are no longer needed.
// If both are set, DialContext takes priority.
@@ -139,8 +171,11 @@ type Transport struct {
// wait for a TLS handshake. Zero means no timeout.
TLSHandshakeTimeout time.Duration
- // DisableKeepAlives, if true, prevents re-use of TCP connections
- // between different HTTP requests.
+ // DisableKeepAlives, if true, disables HTTP keep-alives and
+ // will only use the connection to the server for a single
+ // HTTP request.
+ //
+ // This is unrelated to the similarly named TCP keep-alives.
DisableKeepAlives bool
// DisableCompression, if true, prevents the Transport from
@@ -162,6 +197,18 @@ type Transport struct {
// DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
+ // MaxConnsPerHost optionally limits the total number of
+ // connections per host, including connections in the dialing,
+ // active, and idle states. On limit violation, dials will block.
+ //
+ // Zero means no limit.
+ //
+ // For HTTP/2, this currently only controls the number of new
+ // connections being created at a time, instead of the total
+ // number. In practice, hosts using HTTP/2 only have about one
+ // idle connection, though.
+ MaxConnsPerHost int
+
// IdleConnTimeout is the maximum amount of time an idle
// (keep-alive) connection will remain idle before closing
// itself.
@@ -209,9 +256,17 @@ type Transport struct {
// nextProtoOnce guards initialization of TLSNextProto and
// h2transport (via onceSetNextProtoDefaults)
nextProtoOnce sync.Once
- h2transport *http2Transport // non-nil if http2 wired up
+ h2transport h2Transport // non-nil if http2 wired up
+}
- // TODO: tunable on max per-host TCP dials in flight (Issue 13957)
+// h2Transport is the interface we expect to be able to call from
+// net/http against an *http2.Transport that's either bundled into
+// h2_bundle.go or supplied by the user via x/net/http2.
+//
+// We name it with the "h2" prefix to stay out of the "http2" prefix
+// namespace used by x/tools/cmd/bundle for h2_bundle.go.
+type h2Transport interface {
+ CloseIdleConnections()
}
// onceSetNextProtoDefaults initializes TLSNextProto.
@@ -220,6 +275,21 @@ func (t *Transport) onceSetNextProtoDefaults() {
if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") {
return
}
+
+ // If they've already configured http2 with
+ // golang.org/x/net/http2 instead of the bundled copy, try to
+ // get at its http2.Transport value (via the the "https"
+ // altproto map) so we can call CloseIdleConnections on it if
+ // requested. (Issue 22891)
+ altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+ if rv := reflect.ValueOf(altProto["https"]); rv.IsValid() && rv.Type().Kind() == reflect.Struct && rv.Type().NumField() == 1 {
+ if v := rv.Field(0); v.CanInterface() {
+ if h2i, ok := v.Interface().(h2Transport); ok {
+ t.h2transport = h2i
+ }
+ }
+ }
+
if t.TLSNextProto != nil {
// This is the documented way to disable http2 on a
// Transport.
@@ -273,39 +343,7 @@ func (t *Transport) onceSetNextProtoDefaults() {
// As a special case, if req.URL.Host is "localhost" (with or without
// a port number), then a nil URL and nil error will be returned.
func ProxyFromEnvironment(req *Request) (*url.URL, error) {
- var proxy string
- if req.URL.Scheme == "https" {
- proxy = httpsProxyEnv.Get()
- }
- if proxy == "" {
- proxy = httpProxyEnv.Get()
- if proxy != "" && os.Getenv("REQUEST_METHOD") != "" {
- return nil, errors.New("net/http: refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")
- }
- }
- if proxy == "" {
- return nil, nil
- }
- if !useProxy(canonicalAddr(req.URL)) {
- return nil, nil
- }
- proxyURL, err := url.Parse(proxy)
- if err != nil ||
- (proxyURL.Scheme != "http" &&
- proxyURL.Scheme != "https" &&
- proxyURL.Scheme != "socks5") {
- // proxy was bogus. Try prepending "http://" to it and
- // see if that parses correctly. If not, we fall
- // through and complain about the original one.
- if proxyURL, err := url.Parse("http://" + proxy); err == nil {
- return proxyURL, nil
- }
-
- }
- if err != nil {
- return nil, fmt.Errorf("invalid proxy address %q: %v", proxy, err)
- }
- return proxyURL, nil
+ return envProxyFunc()(req.URL)
}
// ProxyURL returns a proxy function (for use in a Transport)
@@ -343,11 +381,8 @@ func (tr *transportRequest) setError(err error) {
tr.mu.Unlock()
}
-// RoundTrip implements the RoundTripper interface.
-//
-// For higher-level HTTP client support (such as handling of cookies
-// and redirects), see Get, Post, and the Client type.
-func (t *Transport) RoundTrip(req *Request) (*Response, error) {
+// roundTrip implements a RoundTripper over HTTP.
+func (t *Transport) roundTrip(req *Request) (*Response, error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
ctx := req.Context()
trace := httptrace.ContextClientTrace(ctx)
@@ -364,11 +399,11 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
isHTTP := scheme == "http" || scheme == "https"
if isHTTP {
for k, vv := range req.Header {
- if !httplex.ValidHeaderFieldName(k) {
+ if !httpguts.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("net/http: invalid header field name %q", k)
}
for _, v := range vv {
- if !httplex.ValidHeaderFieldValue(v) {
+ if !httpguts.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("net/http: invalid header field value %q for key %v", v, k)
}
}
@@ -394,6 +429,13 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
}
for {
+ select {
+ case <-ctx.Done():
+ req.closeBody()
+ return nil, ctx.Err()
+ default:
+ }
+
// treq gets modified by roundTrip, so we need to recreate for each retry.
treq := &transportRequest{Request: req, trace: trace}
cm, err := t.connectMethodForRequest(treq)
@@ -416,7 +458,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
var resp *Response
if pconn.alt != nil {
// HTTP/2 path.
- t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+ t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
+ t.setReqCanceler(req, nil) // not cancelable with CancelRequest
resp, err = pconn.alt.RoundTrip(req)
} else {
resp, err = pconn.roundTrip(treq)
@@ -575,44 +618,25 @@ func (t *Transport) cancelRequest(req *Request, err error) {
//
var (
- httpProxyEnv = &envOnce{
- names: []string{"HTTP_PROXY", "http_proxy"},
- }
- httpsProxyEnv = &envOnce{
- names: []string{"HTTPS_PROXY", "https_proxy"},
- }
- noProxyEnv = &envOnce{
- names: []string{"NO_PROXY", "no_proxy"},
- }
+ // proxyConfigOnce guards proxyConfig
+ envProxyOnce sync.Once
+ envProxyFuncValue func(*url.URL) (*url.URL, error)
)
-// envOnce looks up an environment variable (optionally by multiple
-// names) once. It mitigates expensive lookups on some platforms
-// (e.g. Windows).
-type envOnce struct {
- names []string
- once sync.Once
- val string
+// defaultProxyConfig returns a ProxyConfig value looked up
+// from the environment. This mitigates expensive lookups
+// on some platforms (e.g. Windows).
+func envProxyFunc() func(*url.URL) (*url.URL, error) {
+ envProxyOnce.Do(func() {
+ envProxyFuncValue = httpproxy.FromEnvironment().ProxyFunc()
+ })
+ return envProxyFuncValue
}
-func (e *envOnce) Get() string {
- e.once.Do(e.init)
- return e.val
-}
-
-func (e *envOnce) init() {
- for _, n := range e.names {
- e.val = os.Getenv(n)
- if e.val != "" {
- return
- }
- }
-}
-
-// reset is used by tests
-func (e *envOnce) reset() {
- e.once = sync.Once{}
- e.val = ""
+// resetProxyConfig is used by tests.
+func resetProxyConfig() {
+ envProxyOnce = sync.Once{}
+ envProxyFuncValue = nil
}
func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
@@ -813,12 +837,6 @@ func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn, idleSince
// carry on.
continue
}
- if pconn.idleTimer != nil && !pconn.idleTimer.Stop() {
- // We picked this conn at the ~same time it
- // was expiring and it's trying to close
- // itself in another goroutine. Don't use it.
- continue
- }
return pconn, pconn.idleAt
}
}
@@ -934,6 +952,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
err error
}
dialc := make(chan dialRes)
+ cmKey := cm.key()
// Copy these hooks so we don't race on the postPendingDial in
// the goroutine we launch. Issue 11136.
@@ -945,6 +964,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
go func() {
if v := <-dialc; v.err == nil {
t.putOrCloseIdleConn(v.pc)
+ } else {
+ t.decHostConnCount(cmKey)
}
testHookPostPendingDial()
}()
@@ -953,6 +974,27 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
cancelc := make(chan error, 1)
t.setReqCanceler(req, func(err error) { cancelc <- err })
+ if t.MaxConnsPerHost > 0 {
+ select {
+ case <-t.incHostConnCount(cmKey):
+ // count below conn per host limit; proceed
+ case pc := <-t.getIdleConnCh(cm):
+ if trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()})
+ }
+ return pc, nil
+ case <-req.Cancel:
+ return nil, errRequestCanceledConn
+ case <-req.Context().Done():
+ return nil, req.Context().Err()
+ case err := <-cancelc:
+ if err == errRequestCanceled {
+ err = errRequestCanceledConn
+ }
+ return nil, err
+ }
+ }
+
go func() {
pc, err := t.dialConn(ctx, cm)
dialc <- dialRes{pc, err}
@@ -970,6 +1012,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
}
// Our dial failed. See why to return a nicer error
// value.
+ t.decHostConnCount(cmKey)
select {
case <-req.Cancel:
// It was an error due to cancelation, so prioritize that
@@ -1013,21 +1056,81 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
}
}
-type oneConnDialer <-chan net.Conn
-
-func newOneConnDialer(c net.Conn) proxy.Dialer {
- ch := make(chan net.Conn, 1)
- ch <- c
- return oneConnDialer(ch)
+// incHostConnCount increments the count of connections for a
+// given host. It returns an already-closed channel if the count
+// is not at its limit; otherwise it returns a channel which is
+// notified when the count is below the limit.
+func (t *Transport) incHostConnCount(cmKey connectMethodKey) <-chan struct{} {
+ if t.MaxConnsPerHost <= 0 {
+ return connsPerHostClosedCh
+ }
+ t.connCountMu.Lock()
+ defer t.connCountMu.Unlock()
+ if t.connPerHostCount[cmKey] == t.MaxConnsPerHost {
+ if t.connPerHostAvailable == nil {
+ t.connPerHostAvailable = make(map[connectMethodKey]chan struct{})
+ }
+ ch, ok := t.connPerHostAvailable[cmKey]
+ if !ok {
+ ch = make(chan struct{})
+ t.connPerHostAvailable[cmKey] = ch
+ }
+ return ch
+ }
+ if t.connPerHostCount == nil {
+ t.connPerHostCount = make(map[connectMethodKey]int)
+ }
+ t.connPerHostCount[cmKey]++
+ // return a closed channel to avoid race: if decHostConnCount is called
+ // after incHostConnCount and during the nil check, decHostConnCount
+ // will delete the channel since it's not being listened on yet.
+ return connsPerHostClosedCh
}
-func (d oneConnDialer) Dial(network, addr string) (net.Conn, error) {
+// decHostConnCount decrements the count of connections
+// for a given host.
+// See Transport.MaxConnsPerHost.
+func (t *Transport) decHostConnCount(cmKey connectMethodKey) {
+ if t.MaxConnsPerHost <= 0 {
+ return
+ }
+ t.connCountMu.Lock()
+ defer t.connCountMu.Unlock()
+ t.connPerHostCount[cmKey]--
select {
- case c := <-d:
- return c, nil
+ case t.connPerHostAvailable[cmKey] <- struct{}{}:
default:
- return nil, io.EOF
+ // close channel before deleting avoids getConn waiting forever in
+ // case getConn has reference to channel but hasn't started waiting.
+ // This could lead to more than MaxConnsPerHost in the unlikely case
+ // that > 1 go routine has fetched the channel but none started waiting.
+ if t.connPerHostAvailable[cmKey] != nil {
+ close(t.connPerHostAvailable[cmKey])
+ }
+ delete(t.connPerHostAvailable, cmKey)
+ }
+ if t.connPerHostCount[cmKey] == 0 {
+ delete(t.connPerHostCount, cmKey)
+ }
+}
+
+// connCloseListener wraps a connection, the transport that dialed it
+// and the connected-to host key so the host connection count can be
+// transparently decremented by whatever closes the embedded connection.
+type connCloseListener struct {
+ net.Conn
+ t *Transport
+ cmKey connectMethodKey
+ didClose int32
+}
+
+func (c *connCloseListener) Close() error {
+ if atomic.AddInt32(&c.didClose, 1) != 1 {
+ return nil
}
+ err := c.Conn.Close()
+ c.t.decHostConnCount(c.cmKey)
+ return err
}
// The connect method and the transport can both specify a TLS
@@ -1078,12 +1181,6 @@ func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) erro
}
return err
}
- if !cfg.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
- plainConn.Close()
- return err
- }
- }
cs := tlsConn.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
@@ -1162,18 +1259,19 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
// Do nothing. Not using a proxy.
case cm.proxyURL.Scheme == "socks5":
conn := pconn.conn
- var auth *proxy.Auth
+ d := socksNewDialer("tcp", conn.RemoteAddr().String())
if u := cm.proxyURL.User; u != nil {
- auth = &proxy.Auth{}
- auth.User = u.Username()
+ auth := &socksUsernamePassword{
+ Username: u.Username(),
+ }
auth.Password, _ = u.Password()
+ d.AuthMethods = []socksAuthMethod{
+ socksAuthMethodNotRequired,
+ socksAuthMethodUsernamePassword,
+ }
+ d.Authenticate = auth.Authenticate
}
- p, err := proxy.SOCKS5("", cm.addr(), auth, newOneConnDialer(conn))
- if err != nil {
- conn.Close()
- return nil, err
- }
- if _, err := p.Dial("tcp", cm.targetAddr); err != nil {
+ if _, err := d.DialWithConn(ctx, conn, "tcp", cm.targetAddr); err != nil {
conn.Close()
return nil, err
}
@@ -1232,6 +1330,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}
}
+ if t.MaxConnsPerHost > 0 {
+ pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey}
+ }
pconn.br = bufio.NewReader(pconn)
pconn.bw = bufio.NewWriter(persistConnWriter{pconn})
go pconn.readLoop()
@@ -1255,63 +1356,6 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
return
}
-// useProxy reports whether requests to addr should use a proxy,
-// according to the NO_PROXY or no_proxy environment variable.
-// addr is always a canonicalAddr with a host and port.
-func useProxy(addr string) bool {
- if len(addr) == 0 {
- return true
- }
- host, _, err := net.SplitHostPort(addr)
- if err != nil {
- return false
- }
- if host == "localhost" {
- return false
- }
- if ip := net.ParseIP(host); ip != nil {
- if ip.IsLoopback() {
- return false
- }
- }
-
- noProxy := noProxyEnv.Get()
- if noProxy == "*" {
- return false
- }
-
- addr = strings.ToLower(strings.TrimSpace(addr))
- if hasPort(addr) {
- addr = addr[:strings.LastIndex(addr, ":")]
- }
-
- for _, p := range strings.Split(noProxy, ",") {
- p = strings.ToLower(strings.TrimSpace(p))
- if len(p) == 0 {
- continue
- }
- if hasPort(p) {
- p = p[:strings.LastIndex(p, ":")]
- }
- if addr == p {
- return false
- }
- if len(p) == 0 {
- // There is no host part, likely the entry is malformed; ignore.
- continue
- }
- if p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:]) {
- // no_proxy ".foo.com" matches "bar.foo.com" or "foo.com"
- return false
- }
- if p[0] != '.' && strings.HasSuffix(addr, p) && addr[len(addr)-len(p)-1] == '.' {
- // no_proxy "foo.com" matches "bar.foo.com"
- return false
- }
- }
- return true
-}
-
// connectMethod is the map key (in its String form) for keeping persistent
// TCP connections alive for subsequent HTTP requests.
//
@@ -1764,26 +1808,45 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
trace.GotFirstResponseByte()
}
}
- resp, err = ReadResponse(pc.br, rc.req)
- if err != nil {
- return
- }
- if rc.continueCh != nil {
- if resp.StatusCode == 100 {
- if trace != nil && trace.Got100Continue != nil {
- trace.Got100Continue()
- }
- rc.continueCh <- struct{}{}
- } else {
- close(rc.continueCh)
- }
- }
- if resp.StatusCode == 100 {
- pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
+ num1xx := 0 // number of informational 1xx headers received
+ const max1xxResponses = 5 // arbitrary bound on number of informational responses
+
+ continueCh := rc.continueCh
+ for {
resp, err = ReadResponse(pc.br, rc.req)
if err != nil {
return
}
+ resCode := resp.StatusCode
+ if continueCh != nil {
+ if resCode == 100 {
+ if trace != nil && trace.Got100Continue != nil {
+ trace.Got100Continue()
+ }
+ continueCh <- struct{}{}
+ continueCh = nil
+ } else if resCode >= 200 {
+ close(continueCh)
+ continueCh = nil
+ }
+ }
+ is1xx := 100 <= resCode && resCode <= 199
+ // treat 101 as a terminal status, see issue 26161
+ is1xxNonTerminal := is1xx && resCode != StatusSwitchingProtocols
+ if is1xxNonTerminal {
+ num1xx++
+ if num1xx > max1xxResponses {
+ return nil, errors.New("net/http: too many 1xx informational responses")
+ }
+ pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
+ if trace != nil && trace.Got1xxResponse != nil {
+ if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil {
+ return nil, err
+ }
+ }
+ continue
+ }
+ break
}
resp.TLS = pc.tlsState
return
@@ -1855,6 +1918,11 @@ func (pc *persistConn) writeLoop() {
}
}
+// maxWriteWaitBeforeConnReuse is how long the a Transport RoundTrip
+// will wait to see the Request's Body.Write result after getting a
+// response from the server. See comments in (*persistConn).wroteRequest.
+const maxWriteWaitBeforeConnReuse = 50 * time.Millisecond
+
// wroteRequest is a check before recycling a connection that the previous write
// (from writeLoop above) happened and was successful.
func (pc *persistConn) wroteRequest() bool {
@@ -1877,7 +1945,7 @@ func (pc *persistConn) wroteRequest() bool {
select {
case err := <-pc.writeErrCh:
return err == nil
- case <-time.After(50 * time.Millisecond):
+ case <-time.After(maxWriteWaitBeforeConnReuse):
return false
}
}
@@ -1979,7 +2047,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
//
// Note that we don't request this for HEAD requests,
// due to a bug in nginx:
- // http://trac.nginx.org/nginx/ticket/358
+ // https://trac.nginx.org/nginx/ticket/358
// https://golang.org/issue/5522
//
// We don't request gzip if the request is for a range, since
diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go
index 5588077..aa8beb9 100644
--- a/libgo/go/net/http/transport_test.go
+++ b/libgo/go/net/http/transport_test.go
@@ -16,6 +16,7 @@ import (
"context"
"crypto/rand"
"crypto/tls"
+ "crypto/x509"
"encoding/binary"
"errors"
"fmt"
@@ -30,6 +31,7 @@ import (
"net/http/httptrace"
"net/http/httputil"
"net/http/internal"
+ "net/textproto"
"net/url"
"os"
"reflect"
@@ -444,27 +446,95 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
if e, g := 1, len(keys); e != g {
t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g)
}
- cacheKey := "|http|" + ts.Listener.Addr().String()
+ addr := ts.Listener.Addr().String()
+ cacheKey := "|http|" + addr
if keys[0] != cacheKey {
t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0])
}
- if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g {
+ if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g {
t.Errorf("after first response, expected %d idle conns; got %d", e, g)
}
resch <- "res2"
<-donech
- if g, w := tr.IdleConnCountForTesting(cacheKey), 2; g != w {
+ if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w {
t.Errorf("after second response, idle conns = %d; want %d", g, w)
}
resch <- "res3"
<-donech
- if g, w := tr.IdleConnCountForTesting(cacheKey), maxIdleConnsPerHost; g != w {
+ if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w {
t.Errorf("after third response, idle conns = %d; want %d", g, w)
}
}
+func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ _, err := w.Write([]byte("foo"))
+ if err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ }))
+ defer ts.Close()
+ c := ts.Client()
+ tr := c.Transport.(*Transport)
+ dialStarted := make(chan struct{})
+ stallDial := make(chan struct{})
+ tr.Dial = func(network, addr string) (net.Conn, error) {
+ dialStarted <- struct{}{}
+ <-stallDial
+ return net.Dial(network, addr)
+ }
+
+ tr.DisableKeepAlives = true
+ tr.MaxConnsPerHost = 1
+
+ preDial := make(chan struct{})
+ reqComplete := make(chan struct{})
+ doReq := func(reqId string) {
+ req, _ := NewRequest("GET", ts.URL, nil)
+ trace := &httptrace.ClientTrace{
+ GetConn: func(hostPort string) {
+ preDial <- struct{}{}
+ },
+ }
+ req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+ resp, err := tr.RoundTrip(req)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ _, err = ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("unexpected error for request %s: %v", reqId, err)
+ }
+ reqComplete <- struct{}{}
+ }
+ // get req1 to dial-in-progress
+ go doReq("req1")
+ <-preDial
+ <-dialStarted
+
+ // get req2 to waiting on conns per host to go down below max
+ go doReq("req2")
+ <-preDial
+ select {
+ case <-dialStarted:
+ t.Error("req2 dial started while req1 dial in progress")
+ return
+ default:
+ }
+
+ // let req1 complete
+ stallDial <- struct{}{}
+ <-reqComplete
+
+ // let req2 complete
+ <-dialStarted
+ stallDial <- struct{}{}
+ <-reqComplete
+}
+
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
setParallel(t)
defer afterTest(t)
@@ -958,7 +1028,7 @@ func TestTransportExpect100Continue(t *testing.T) {
}
}
-func TestSocks5Proxy(t *testing.T) {
+func TestSOCKS5Proxy(t *testing.T) {
defer afterTest(t)
ch := make(chan string, 1)
l := newLocalListener(t)
@@ -995,9 +1065,9 @@ func TestSocks5Proxy(t *testing.T) {
var ipLen int
switch buf[3] {
case 1:
- ipLen = 4
+ ipLen = net.IPv4len
case 4:
- ipLen = 16
+ ipLen = net.IPv6len
default:
t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
return
@@ -2286,6 +2356,7 @@ Content-Length: %d
c := &Client{Transport: tr}
testResponse := func(req *Request, name string, wantCode int) {
+ t.Helper()
res, err := c.Do(req)
if err != nil {
t.Fatalf("%s: Do: %v", name, err)
@@ -2308,13 +2379,90 @@ Content-Length: %d
req.Header.Set("Request-Id", reqID(i))
testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
}
+}
- // And some other informational 1xx but non-100 responses, to test
- // we return them but don't re-use the connection.
- for i := 1; i <= numReqs; i++ {
- req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i)))
- req.Header.Set("X-Want-Response-Code", "123 Sesame Street")
- testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123)
+// Issue 17739: the HTTP client must ignore any unknown 1xx
+// informational responses before the actual response.
+func TestTransportIgnore1xxResponses(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
+ buf.Flush()
+ conn.Close()
+ }))
+ defer cst.close()
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ var got bytes.Buffer
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
+ return nil
+ },
+ }))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ res.Write(&got)
+ want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
+ if got.String() != want {
+ t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want)
+ }
+}
+
+func TestTransportLimits1xxResponses(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ for i := 0; i < 10; i++ {
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
+ }
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ defer cst.close()
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ got := fmt.Sprint(err)
+ wantSub := "too many 1xx informational responses"
+ if !strings.Contains(got, wantSub) {
+ t.Errorf("Get error = %v; want substring %q", err, wantSub)
+ }
+}
+
+// Issue 26161: the HTTP client must treat 101 responses
+// as the final response.
+func TestTransportTreat101Terminal(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n"))
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ if res.StatusCode != StatusSwitchingProtocols {
+ t.Errorf("StatusCode = %v; want 101 Switching Protocols", res.StatusCode)
}
}
@@ -2380,39 +2528,61 @@ var proxyFromEnvTests = []proxyFromEnvTest{
// where HTTP_PROXY can be attacker-controlled.
{env: "http://10.1.2.3:8080", reqmeth: "POST",
want: "<nil>",
- wanterr: errors.New("net/http: refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
+ wanterr: errors.New("refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")},
{want: "<nil>"},
{noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
- {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "<nil>"},
+ {noenv: ".example.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
{noenv: "ample.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
{noenv: "example.com", req: "http://foo.example.com/", env: "proxy", want: "<nil>"},
{noenv: ".foo.com", req: "http://example.com/", env: "proxy", want: "http://proxy"},
}
+func testProxyForRequest(t *testing.T, tt proxyFromEnvTest, proxyForRequest func(req *Request) (*url.URL, error)) {
+ t.Helper()
+ reqURL := tt.req
+ if reqURL == "" {
+ reqURL = "http://example.com"
+ }
+ req, _ := NewRequest("GET", reqURL, nil)
+ url, err := proxyForRequest(req)
+ if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
+ t.Errorf("%v: got error = %q, want %q", tt, g, e)
+ return
+ }
+ if got := fmt.Sprintf("%s", url); got != tt.want {
+ t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
+ }
+}
+
func TestProxyFromEnvironment(t *testing.T) {
ResetProxyEnv()
defer ResetProxyEnv()
for _, tt := range proxyFromEnvTests {
- os.Setenv("HTTP_PROXY", tt.env)
- os.Setenv("HTTPS_PROXY", tt.httpsenv)
- os.Setenv("NO_PROXY", tt.noenv)
- os.Setenv("REQUEST_METHOD", tt.reqmeth)
- ResetCachedEnvironment()
- reqURL := tt.req
- if reqURL == "" {
- reqURL = "http://example.com"
- }
- req, _ := NewRequest("GET", reqURL, nil)
- url, err := ProxyFromEnvironment(req)
- if g, e := fmt.Sprintf("%v", err), fmt.Sprintf("%v", tt.wanterr); g != e {
- t.Errorf("%v: got error = %q, want %q", tt, g, e)
- continue
- }
- if got := fmt.Sprintf("%s", url); got != tt.want {
- t.Errorf("%v: got URL = %q, want %q", tt, url, tt.want)
- }
+ testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
+ os.Setenv("HTTP_PROXY", tt.env)
+ os.Setenv("HTTPS_PROXY", tt.httpsenv)
+ os.Setenv("NO_PROXY", tt.noenv)
+ os.Setenv("REQUEST_METHOD", tt.reqmeth)
+ ResetCachedEnvironment()
+ return ProxyFromEnvironment(req)
+ })
+ }
+}
+
+func TestProxyFromEnvironmentLowerCase(t *testing.T) {
+ ResetProxyEnv()
+ defer ResetProxyEnv()
+ for _, tt := range proxyFromEnvTests {
+ testProxyForRequest(t, tt, func(req *Request) (*url.URL, error) {
+ os.Setenv("http_proxy", tt.env)
+ os.Setenv("https_proxy", tt.httpsenv)
+ os.Setenv("no_proxy", tt.noenv)
+ os.Setenv("REQUEST_METHOD", tt.reqmeth)
+ ResetCachedEnvironment()
+ return ProxyFromEnvironment(req)
+ })
}
}
@@ -2880,9 +3050,16 @@ func TestRetryRequestsOnError(t *testing.T) {
defer SetRoundTripRetried(nil)
for i := 0; i < 3; i++ {
+ t0 := time.Now()
res, err := c.Do(tc.req())
if err != nil {
- t.Fatalf("i=%d: Do = %v", i, err)
+ if time.Since(t0) < MaxWriteWaitBeforeConnReuse/2 {
+ mu.Lock()
+ got := logbuf.String()
+ mu.Unlock()
+ t.Fatalf("i=%d: Do = %v; log:\n%s", i, err, got)
+ }
+ t.Skipf("connection likely wasn't recycled within %d, interfering with actual test; skipping", MaxWriteWaitBeforeConnReuse)
}
res.Body.Close()
}
@@ -3016,7 +3193,7 @@ func TestRoundTripReturnsProxyError(t *testing.T) {
func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
tr := &Transport{}
wantIdle := func(when string, n int) bool {
- got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn
+ got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn
if got == n {
return true
}
@@ -3024,10 +3201,10 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
return false
}
wantIdle("start", 0)
- if !tr.PutIdleTestConn() {
+ if !tr.PutIdleTestConn("http", "example.com") {
t.Fatal("put failed")
}
- if !tr.PutIdleTestConn() {
+ if !tr.PutIdleTestConn("http", "example.com") {
t.Fatal("second put failed")
}
wantIdle("after put", 2)
@@ -3036,7 +3213,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
t.Error("should be idle after CloseIdleConnections")
}
wantIdle("after close idle", 0)
- if tr.PutIdleTestConn() {
+ if tr.PutIdleTestConn("http", "example.com") {
t.Fatal("put didn't fail")
}
wantIdle("after second put", 0)
@@ -3045,7 +3222,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) {
if tr.IsIdleForTesting() {
t.Error("shouldn't be idle after RequestIdleConnChForTesting")
}
- if !tr.PutIdleTestConn() {
+ if !tr.PutIdleTestConn("http", "example.com") {
t.Fatal("after re-activation")
}
wantIdle("after final put", 1)
@@ -3263,8 +3440,8 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
defer res.Body.Close()
want := []string{
- "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n" +
- "5\r\nnum0\n\r\n",
+ "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n",
+ "5\r\nnum0\n\r\n",
"5\r\nnum1\n\r\n",
"5\r\nnum2\n\r\n",
"0\r\n\r\n",
@@ -3274,6 +3451,40 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
}
}
+// Issue 22088: flush Transport request headers if we're not sure the body won't block on read.
+func TestTransportFlushesRequestHeader(t *testing.T) {
+ defer afterTest(t)
+ gotReq := make(chan struct{})
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ close(gotReq)
+ }))
+ defer cst.close()
+
+ pr, pw := io.Pipe()
+ req, err := NewRequest("POST", cst.ts.URL, pr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ gotRes := make(chan struct{})
+ go func() {
+ defer close(gotRes)
+ res, err := cst.tr.RoundTrip(req)
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ res.Body.Close()
+ }()
+
+ select {
+ case <-gotReq:
+ pw.Close()
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for handler to get request")
+ }
+ <-gotRes
+}
+
// Issue 11745.
func TestTransportPrefersResponseOverWriteError(t *testing.T) {
if testing.Short() {
@@ -3578,8 +3789,12 @@ func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
defer afterTest(t)
const resBody = "some body"
- gotWroteReqEvent := make(chan struct{})
+ gotWroteReqEvent := make(chan struct{}, 500)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ // Do nothing for the second request.
+ return
+ }
if _, err := ioutil.ReadAll(r.Body); err != nil {
t.Error(err)
}
@@ -3620,7 +3835,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil
})
- req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader("some body"))
+ body := "some body"
+ req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader(body))
+ req.Header["X-Foo-Multiple-Vals"] = []string{"bar", "baz"}
trace := &httptrace.ClientTrace{
GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
@@ -3635,11 +3852,17 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
}
logf("ConnectDone: connected to %s %s = %v", network, addr, err)
},
+ WroteHeaderField: func(key string, value []string) {
+ logf("WroteHeaderField: %s: %v", key, value)
+ },
+ WroteHeaders: func() {
+ logf("WroteHeaders")
+ },
Wait100Continue: func() { logf("Wait100Continue") },
Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) {
logf("WroteRequest: %+v", e)
- close(gotWroteReqEvent)
+ gotWroteReqEvent <- struct{}{}
},
}
if h2 {
@@ -3704,7 +3927,15 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
wantOnce("tls handshake done")
} else {
wantOnce("PutIdleConn = <nil>")
- }
+ wantOnce("WroteHeaderField: User-Agent: [Go-http-client/1.1]")
+ // TODO(meirf): issue 19761. Make these agnostic to h1/h2. (These are not h1 specific, but the
+ // WroteHeaderField hook is not yet implemented in h2.)
+ wantOnce(fmt.Sprintf("WroteHeaderField: Host: [dns-is-faked.golang:%s]", port))
+ wantOnce(fmt.Sprintf("WroteHeaderField: Content-Length: [%d]", len(body)))
+ wantOnce("WroteHeaderField: X-Foo-Multiple-Vals: [bar baz]")
+ wantOnce("WroteHeaderField: Accept-Encoding: [gzip]")
+ }
+ wantOnce("WroteHeaders")
wantOnce("Wait100Continue")
wantOnce("Got100Continue")
wantOnce("WroteRequest: {Err:<nil>}")
@@ -3714,6 +3945,90 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
if t.Failed() {
t.Errorf("Output:\n%s", got)
}
+
+ // And do a second request:
+ req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+ res, err = cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+ res.Body.Close()
+
+ mu.Lock()
+ got = buf.String()
+ mu.Unlock()
+
+ sub := "Getting conn for dns-is-faked.golang:"
+ if gotn, want := strings.Count(got, sub), 2; gotn != want {
+ t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
+ }
+
+}
+
+func TestTransportEventTraceTLSVerify(t *testing.T) {
+ var mu sync.Mutex
+ var buf bytes.Buffer
+ logf := func(format string, args ...interface{}) {
+ mu.Lock()
+ defer mu.Unlock()
+ fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n')
+ }
+
+ ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ t.Error("Unexpected request")
+ }))
+ defer ts.Close()
+ ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) {
+ logf("%s", p)
+ return len(p), nil
+ }), "", 0)
+
+ certpool := x509.NewCertPool()
+ certpool.AddCert(ts.Certificate())
+
+ c := &Client{Transport: &Transport{
+ TLSClientConfig: &tls.Config{
+ ServerName: "dns-is-faked.golang",
+ RootCAs: certpool,
+ },
+ }}
+
+ trace := &httptrace.ClientTrace{
+ TLSHandshakeStart: func() { logf("TLSHandshakeStart") },
+ TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+ logf("TLSHandshakeDone: ConnectionState = %v \n err = %v", s, err)
+ },
+ }
+
+ req, _ := NewRequest("GET", ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+ _, err := c.Do(req)
+ if err == nil {
+ t.Error("Expected request to fail TLS verification")
+ }
+
+ mu.Lock()
+ got := buf.String()
+ mu.Unlock()
+
+ wantOnce := func(sub string) {
+ if strings.Count(got, sub) != 1 {
+ t.Errorf("expected substring %q exactly once in output.", sub)
+ }
+ }
+
+ wantOnce("TLSHandshakeStart")
+ wantOnce("TLSHandshakeDone")
+ wantOnce("err = x509: certificate is valid for example.com")
+
+ if t.Failed() {
+ t.Errorf("Output:\n%s", got)
+ }
}
var (
@@ -4365,3 +4680,161 @@ func TestNoBodyOnChunked304Response(t *testing.T) {
t.Errorf("Unexpected body on 304 response")
}
}
+
+type funcWriter func([]byte) (int, error)
+
+func (f funcWriter) Write(p []byte) (int, error) { return f(p) }
+
+type doneContext struct {
+ context.Context
+ err error
+}
+
+func (doneContext) Done() <-chan struct{} {
+ c := make(chan struct{})
+ close(c)
+ return c
+}
+
+func (d doneContext) Err() error { return d.err }
+
+// Issue 25852: Transport should check whether Context is done early.
+func TestTransportCheckContextDoneEarly(t *testing.T) {
+ tr := &Transport{}
+ req, _ := NewRequest("GET", "http://fake.example/", nil)
+ wantErr := errors.New("some error")
+ req = req.WithContext(doneContext{context.Background(), wantErr})
+ _, err := tr.RoundTrip(req)
+ if err != wantErr {
+ t.Errorf("error = %v; want %v", err, wantErr)
+ }
+}
+
+// Issue 23399: verify that if a client request times out, the Transport's
+// conn is closed so that it's not reused.
+//
+// This is the test variant that times out before the server replies with
+// any response headers.
+func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ inHandler := make(chan net.Conn, 1)
+ handlerReadReturned := make(chan bool, 1)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ inHandler <- conn
+ n, err := conn.Read([]byte{0})
+ if n != 0 || err != io.EOF {
+ t.Errorf("unexpected Read result: %v, %v", n, err)
+ }
+ handlerReadReturned <- true
+ }))
+ defer cst.close()
+
+ const timeout = 50 * time.Millisecond
+ cst.c.Timeout = timeout
+
+ _, err := cst.c.Get(cst.ts.URL)
+ if err == nil {
+ t.Fatal("unexpected Get succeess")
+ }
+
+ select {
+ case c := <-inHandler:
+ select {
+ case <-handlerReadReturned:
+ // Success.
+ return
+ case <-time.After(5 * time.Second):
+ t.Error("Handler's conn.Read seems to be stuck in Read")
+ c.Close() // close it to unblock Handler
+ }
+ case <-time.After(timeout * 10):
+ // If we didn't get into the Handler in 50ms, that probably means
+ // the builder was just slow and the the Get failed in that time
+ // but never made it to the server. That's fine. We'll usually
+ // test the part above on faster machines.
+ t.Skip("skipping test on slow builder")
+ }
+}
+
+// Issue 23399: verify that if a client request times out, the Transport's
+// conn is closed so that it's not reused.
+//
+// This is the test variant that has the server send response headers
+// first, and time out during the the write of the response body.
+func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ inHandler := make(chan net.Conn, 1)
+ handlerResult := make(chan error, 1)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "100")
+ w.(Flusher).Flush()
+ conn, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ conn.Write([]byte("foo"))
+ inHandler <- conn
+ n, err := conn.Read([]byte{0})
+ // The error should be io.EOF or "read tcp
+ // 127.0.0.1:35827->127.0.0.1:40290: read: connection
+ // reset by peer" depending on timing. Really we just
+ // care that it returns at all. But if it returns with
+ // data, that's weird.
+ if n != 0 || err == nil {
+ handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err)
+ return
+ }
+ handlerResult <- nil
+ }))
+ defer cst.close()
+
+ // Set Timeout to something very long but non-zero to exercise
+ // the codepaths that check for it. But rather than wait for it to fire
+ // (which would make the test slow), we send on the req.Cancel channel instead,
+ // which happens to exercise the same code paths.
+ cst.c.Timeout = time.Minute // just to be non-zero, not to hit it.
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ select {
+ case <-inHandler:
+ t.Fatalf("Get error: %v", err)
+ default:
+ // Failed before entering handler. Ignore result.
+ t.Skip("skipping test on slow builder")
+ }
+ }
+
+ close(cancel)
+ got, err := ioutil.ReadAll(res.Body)
+ if err == nil {
+ t.Fatalf("unexpected success; read %q, nil", got)
+ }
+
+ select {
+ case c := <-inHandler:
+ select {
+ case err := <-handlerResult:
+ if err != nil {
+ t.Errorf("handler: %v", err)
+ }
+ return
+ case <-time.After(5 * time.Second):
+ t.Error("Handler's conn.Read seems to be stuck in Read")
+ c.Close() // close it to unblock Handler
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout")
+ }
+}
diff --git a/libgo/go/net/http/triv.go b/libgo/go/net/http/triv.go
index cfbc577..23e65d5 100644
--- a/libgo/go/net/http/triv.go
+++ b/libgo/go/net/http/triv.go
@@ -107,7 +107,7 @@ func DateServer(rw http.ResponseWriter, req *http.Request) {
date, err := exec.Command("/bin/date").Output()
if err != nil {
- http.Error(rw, err.Error(), 500)
+ http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
rw.Write(date)
@@ -115,7 +115,7 @@ func DateServer(rw http.ResponseWriter, req *http.Request) {
func Logger(w http.ResponseWriter, req *http.Request) {
log.Print(req.URL)
- http.Error(w, "oops", 404)
+ http.Error(w, "oops", http.StatusNotFound)
}
var webroot = flag.String("root", os.Getenv("HOME"), "web root directory")
diff --git a/libgo/go/net/interface.go b/libgo/go/net/interface.go
index 4036a7f..375a456 100644
--- a/libgo/go/net/interface.go
+++ b/libgo/go/net/interface.go
@@ -10,7 +10,7 @@ import (
"time"
)
-// BUG(mikio): On NaCl, methods and functions related to
+// BUG(mikio): On JS and NaCl, methods and functions related to
// Interface are not implemented.
// BUG(mikio): On DragonFly BSD, NetBSD, OpenBSD, Plan 9 and Solaris,
diff --git a/libgo/go/net/interface_stub.go b/libgo/go/net/interface_stub.go
index 6d7147e..3b3b3c0 100644
--- a/libgo/go/net/interface_stub.go
+++ b/libgo/go/net/interface_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix nacl
+// +build aix nacl js,wasm
package net
diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go
index 534137a..5d183c5 100644
--- a/libgo/go/net/interface_test.go
+++ b/libgo/go/net/interface_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -202,7 +204,7 @@ func validateInterfaceUnicastAddrs(ifat []Addr) (*routeStats, error) {
if 0 >= prefixLen || prefixLen > 8*IPv4len || maxPrefixLen != 8*IPv4len {
return nil, fmt.Errorf("unexpected prefix length: %d/%d for %#v", prefixLen, maxPrefixLen, ifa)
}
- if ifa.IP.IsLoopback() && (prefixLen != 8 && prefixLen != 8*IPv4len) { // see RFC 1122
+ if ifa.IP.IsLoopback() && prefixLen < 8 { // see RFC 1122
return nil, fmt.Errorf("unexpected prefix length: %d/%d for %#v", prefixLen, maxPrefixLen, ifa)
}
stats.ipv4++
diff --git a/libgo/go/net/interface_windows.go b/libgo/go/net/interface_windows.go
index b08d158..28b0a65 100644
--- a/libgo/go/net/interface_windows.go
+++ b/libgo/go/net/interface_windows.go
@@ -11,22 +11,6 @@ import (
"unsafe"
)
-// supportsVistaIP reports whether the platform implements new IP
-// stack and ABIs supported on Windows Vista and above.
-var supportsVistaIP bool
-
-func init() {
- supportsVistaIP = probeWindowsIPStack()
-}
-
-func probeWindowsIPStack() (supportsVistaIP bool) {
- v, err := syscall.GetVersion()
- if err != nil {
- return true // Windows 10 and above will deprecate this API
- }
- return byte(v) >= 6 // major version of Windows Vista is 6
-}
-
// adapterAddresses returns a list of IP adapter and address
// structures. The structure contains an IP adapter and flattened
// multiple IP addresses including unicast, anycast and multicast
@@ -81,9 +65,8 @@ func interfaceTable(ifindex int) ([]Interface, error) {
}
// For now we need to infer link-layer service
// capabilities from media types.
- // We will be able to use
- // MIB_IF_ROW2.AccessType once we drop support
- // for Windows XP.
+ // TODO: use MIB_IF_ROW2.AccessType now that we no longer support
+ // Windows XP.
switch aa.IfType {
case windows.IF_TYPE_ETHERNET_CSMACD, windows.IF_TYPE_ISO88025_TOKENRING, windows.IF_TYPE_IEEE80211, windows.IF_TYPE_IEEE1394:
ifi.Flags |= FlagBroadcast | FlagMulticast
@@ -126,35 +109,17 @@ func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
if index == 0 { // ipv6IfIndex is a substitute for ifIndex
index = aa.Ipv6IfIndex
}
- var pfx4, pfx6 []IPNet
- if !supportsVistaIP {
- pfx4, pfx6, err = addrPrefixTable(aa)
- if err != nil {
- return nil, err
- }
- }
if ifi == nil || ifi.Index == int(index) {
for puni := aa.FirstUnicastAddress; puni != nil; puni = puni.Next {
sa, err := puni.Address.Sockaddr.Sockaddr()
if err != nil {
return nil, os.NewSyscallError("sockaddr", err)
}
- var l int
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
- if supportsVistaIP {
- l = int(puni.OnLinkPrefixLength)
- } else {
- l = addrPrefixLen(pfx4, IP(sa.Addr[:]))
- }
- ifat = append(ifat, &IPNet{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]), Mask: CIDRMask(l, 8*IPv4len)})
+ ifat = append(ifat, &IPNet{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]), Mask: CIDRMask(int(puni.OnLinkPrefixLength), 8*IPv4len)})
case *syscall.SockaddrInet6:
- if supportsVistaIP {
- l = int(puni.OnLinkPrefixLength)
- } else {
- l = addrPrefixLen(pfx6, IP(sa.Addr[:]))
- }
- ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(l, 8*IPv6len)}
+ ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(puni.OnLinkPrefixLength), 8*IPv6len)}
copy(ifa.IP, sa.Addr[:])
ifat = append(ifat, ifa)
}
@@ -178,59 +143,6 @@ func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
return ifat, nil
}
-func addrPrefixTable(aa *windows.IpAdapterAddresses) (pfx4, pfx6 []IPNet, err error) {
- for p := aa.FirstPrefix; p != nil; p = p.Next {
- sa, err := p.Address.Sockaddr.Sockaddr()
- if err != nil {
- return nil, nil, os.NewSyscallError("sockaddr", err)
- }
- switch sa := sa.(type) {
- case *syscall.SockaddrInet4:
- pfx := IPNet{IP: IP(sa.Addr[:]), Mask: CIDRMask(int(p.PrefixLength), 8*IPv4len)}
- pfx4 = append(pfx4, pfx)
- case *syscall.SockaddrInet6:
- pfx := IPNet{IP: IP(sa.Addr[:]), Mask: CIDRMask(int(p.PrefixLength), 8*IPv6len)}
- pfx6 = append(pfx6, pfx)
- }
- }
- return
-}
-
-// addrPrefixLen returns an appropriate prefix length in bits for ip
-// from pfxs. It returns 32 or 128 when no appropriate on-link address
-// prefix found.
-//
-// NOTE: This is pretty naive implementation that contains many
-// allocations and non-effective linear search, and should not be used
-// freely.
-func addrPrefixLen(pfxs []IPNet, ip IP) int {
- var l int
- var cand *IPNet
- for i := range pfxs {
- if !pfxs[i].Contains(ip) {
- continue
- }
- if cand == nil {
- l, _ = pfxs[i].Mask.Size()
- cand = &pfxs[i]
- continue
- }
- m, _ := pfxs[i].Mask.Size()
- if m > l {
- l = m
- cand = &pfxs[i]
- continue
- }
- }
- if l > 0 {
- return l
- }
- if ip.To4() != nil {
- return 8 * IPv4len
- }
- return 8 * IPv6len
-}
-
// interfaceMulticastAddrTable returns addresses for a specific
// interface.
func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
diff --git a/libgo/go/net/interface_windows_test.go b/libgo/go/net/interface_windows_test.go
deleted file mode 100644
index 03f9168..0000000
--- a/libgo/go/net/interface_windows_test.go
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright 2015 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package net
-
-import (
- "bytes"
- "internal/syscall/windows"
- "sort"
- "testing"
-)
-
-func TestWindowsInterfaces(t *testing.T) {
- aas, err := adapterAddresses()
- if err != nil {
- t.Fatal(err)
- }
- ift, err := Interfaces()
- if err != nil {
- t.Fatal(err)
- }
- for i, ifi := range ift {
- aa := aas[i]
- if len(ifi.HardwareAddr) != int(aa.PhysicalAddressLength) {
- t.Errorf("got %d; want %d", len(ifi.HardwareAddr), aa.PhysicalAddressLength)
- }
- if ifi.MTU > 0x7fffffff {
- t.Errorf("%s: got %d; want less than or equal to 1<<31 - 1", ifi.Name, ifi.MTU)
- }
- if ifi.Flags&FlagUp != 0 && aa.OperStatus != windows.IfOperStatusUp {
- t.Errorf("%s: got %v; should not include FlagUp", ifi.Name, ifi.Flags)
- }
- if ifi.Flags&FlagLoopback != 0 && aa.IfType != windows.IF_TYPE_SOFTWARE_LOOPBACK {
- t.Errorf("%s: got %v; should not include FlagLoopback", ifi.Name, ifi.Flags)
- }
- if _, _, err := addrPrefixTable(aa); err != nil {
- t.Errorf("%s: %v", ifi.Name, err)
- }
- }
-}
-
-type byAddrLen []IPNet
-
-func (ps byAddrLen) Len() int { return len(ps) }
-
-func (ps byAddrLen) Less(i, j int) bool {
- if n := bytes.Compare(ps[i].IP, ps[j].IP); n != 0 {
- return n < 0
- }
- if n := bytes.Compare(ps[i].Mask, ps[j].Mask); n != 0 {
- return n < 0
- }
- return false
-}
-
-func (ps byAddrLen) Swap(i, j int) { ps[i], ps[j] = ps[j], ps[i] }
-
-var windowsAddrPrefixLenTests = []struct {
- pfxs []IPNet
- ip IP
- out int
-}{
- {
- []IPNet{
- {IP: IPv4(172, 16, 0, 0), Mask: IPv4Mask(255, 255, 0, 0)},
- {IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)},
- {IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(255, 255, 255, 128)},
- {IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(255, 255, 255, 192)},
- },
- IPv4(192, 168, 0, 1),
- 26,
- },
- {
- []IPNet{
- {IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0"))},
- {IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff8"))},
- {IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffc"))},
- },
- ParseIP("2001:db8::1"),
- 126,
- },
-
- // Fallback cases. It may happen on Windows XP or 2003 server.
- {
- []IPNet{
- {IP: IPv4(127, 0, 0, 0).To4(), Mask: IPv4Mask(255, 0, 0, 0)},
- {IP: IPv4(10, 0, 0, 0).To4(), Mask: IPv4Mask(255, 0, 0, 0)},
- {IP: IPv4(172, 16, 0, 0).To4(), Mask: IPv4Mask(255, 255, 0, 0)},
- {IP: IPv4(192, 168, 255, 0), Mask: IPv4Mask(255, 255, 255, 0)},
- {IP: IPv4zero, Mask: IPv4Mask(0, 0, 0, 0)},
- },
- IPv4(192, 168, 0, 1),
- 8 * IPv4len,
- },
- {
- nil,
- IPv4(192, 168, 0, 1),
- 8 * IPv4len,
- },
- {
- []IPNet{
- {IP: IPv6loopback, Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))},
- {IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0"))},
- {IP: ParseIP("2001:db8:2::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff8"))},
- {IP: ParseIP("2001:db8:3::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffc"))},
- {IP: IPv6unspecified, Mask: IPMask(ParseIP("::"))},
- },
- ParseIP("2001:db8::1"),
- 8 * IPv6len,
- },
- {
- nil,
- ParseIP("2001:db8::1"),
- 8 * IPv6len,
- },
-}
-
-func TestWindowsAddrPrefixLen(t *testing.T) {
- for i, tt := range windowsAddrPrefixLenTests {
- sort.Sort(byAddrLen(tt.pfxs))
- l := addrPrefixLen(tt.pfxs, tt.ip)
- if l != tt.out {
- t.Errorf("#%d: got %d; want %d", i, l, tt.out)
- }
- sort.Sort(sort.Reverse(byAddrLen(tt.pfxs)))
- l = addrPrefixLen(tt.pfxs, tt.ip)
- if l != tt.out {
- t.Errorf("#%d: got %d; want %d", i, l, tt.out)
- }
- }
-}
diff --git a/libgo/go/net/internal/socktest/main_test.go b/libgo/go/net/internal/socktest/main_test.go
index 60e581f..3b0a48a 100644
--- a/libgo/go/net/internal/socktest/main_test.go
+++ b/libgo/go/net/internal/socktest/main_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !plan9
+// +build !js,!plan9
package socktest_test
diff --git a/libgo/go/net/internal/socktest/main_unix_test.go b/libgo/go/net/internal/socktest/main_unix_test.go
index b8eebc2..4d9d414 100644
--- a/libgo/go/net/internal/socktest/main_unix_test.go
+++ b/libgo/go/net/internal/socktest/main_unix_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !plan9,!windows
+// +build !js,!plan9,!windows
package socktest_test
diff --git a/libgo/go/net/internal/socktest/switch_unix.go b/libgo/go/net/internal/socktest/switch_unix.go
index 8fb15f3..0626aa0 100644
--- a/libgo/go/net/internal/socktest/switch_unix.go
+++ b/libgo/go/net/internal/socktest/switch_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris
package socktest
diff --git a/libgo/go/net/internal/socktest/sys_cloexec.go b/libgo/go/net/internal/socktest/sys_cloexec.go
index d1b8f4f..986d894 100644
--- a/libgo/go/net/internal/socktest/sys_cloexec.go
+++ b/libgo/go/net/internal/socktest/sys_cloexec.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build dragonfly freebsd linux
+// +build dragonfly freebsd linux netbsd openbsd
package socktest
diff --git a/libgo/go/net/internal/socktest/sys_unix.go b/libgo/go/net/internal/socktest/sys_unix.go
index 397c524..b96075b 100644
--- a/libgo/go/net/internal/socktest/sys_unix.go
+++ b/libgo/go/net/internal/socktest/sys_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris
package socktest
diff --git a/libgo/go/net/ip.go b/libgo/go/net/ip.go
index 6b7ba4c..da8dca5 100644
--- a/libgo/go/net/ip.go
+++ b/libgo/go/net/ip.go
@@ -260,6 +260,25 @@ func (ip IP) Mask(mask IPMask) IP {
return out
}
+// ubtoa encodes the string form of the integer v to dst[start:] and
+// returns the number of bytes written to dst. The caller must ensure
+// that dst has sufficient length.
+func ubtoa(dst []byte, start int, v byte) int {
+ if v < 10 {
+ dst[start] = v + '0'
+ return 1
+ } else if v < 100 {
+ dst[start+1] = v%10 + '0'
+ dst[start] = v/10 + '0'
+ return 2
+ }
+
+ dst[start+2] = v%10 + '0'
+ dst[start+1] = (v/10)%10 + '0'
+ dst[start] = v/100 + '0'
+ return 3
+}
+
// String returns the string form of the IP address ip.
// It returns one of 4 forms:
// - "<nil>", if ip has length 0
@@ -275,10 +294,23 @@ func (ip IP) String() string {
// If IPv4, use dotted notation.
if p4 := p.To4(); len(p4) == IPv4len {
- return uitoa(uint(p4[0])) + "." +
- uitoa(uint(p4[1])) + "." +
- uitoa(uint(p4[2])) + "." +
- uitoa(uint(p4[3]))
+ const maxIPv4StringLen = len("255.255.255.255")
+ b := make([]byte, maxIPv4StringLen)
+
+ n := ubtoa(b, 0, p4[0])
+ b[n] = '.'
+ n++
+
+ n += ubtoa(b, n, p4[1])
+ b[n] = '.'
+ n++
+
+ n += ubtoa(b, n, p4[2])
+ b[n] = '.'
+ n++
+
+ n += ubtoa(b, n, p4[3])
+ return string(b[:n])
}
if len(p) != IPv6len {
return "?" + hexString(ip)
@@ -530,25 +562,26 @@ func parseIPv4(s string) IP {
return IPv4(p[0], p[1], p[2], p[3])
}
-// parseIPv6 parses s as a literal IPv6 address described in RFC 4291
-// and RFC 5952. It can also parse a literal scoped IPv6 address with
-// zone identifier which is described in RFC 4007 when zoneAllowed is
-// true.
-func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
+// parseIPv6Zone parses s as a literal IPv6 address and its associated zone
+// identifier which is described in RFC 4007.
+func parseIPv6Zone(s string) (IP, string) {
+ s, zone := splitHostZone(s)
+ return parseIPv6(s), zone
+}
+
+// parseIPv6Zone parses s as a literal IPv6 address described in RFC 4291
+// and RFC 5952.
+func parseIPv6(s string) (ip IP) {
ip = make(IP, IPv6len)
ellipsis := -1 // position of ellipsis in ip
- if zoneAllowed {
- s, zone = splitHostZone(s)
- }
-
// Might have leading ellipsis
if len(s) >= 2 && s[0] == ':' && s[1] == ':' {
ellipsis = 0
s = s[2:]
// Might be only ellipsis
if len(s) == 0 {
- return ip, zone
+ return ip
}
}
@@ -558,22 +591,22 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
// Hex number.
n, c, ok := xtoi(s)
if !ok || n > 0xFFFF {
- return nil, zone
+ return nil
}
// If followed by dot, might be in trailing IPv4.
if c < len(s) && s[c] == '.' {
if ellipsis < 0 && i != IPv6len-IPv4len {
// Not the right place.
- return nil, zone
+ return nil
}
if i+IPv4len > IPv6len {
// Not enough room.
- return nil, zone
+ return nil
}
ip4 := parseIPv4(s)
if ip4 == nil {
- return nil, zone
+ return nil
}
ip[i] = ip4[12]
ip[i+1] = ip4[13]
@@ -597,14 +630,14 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
// Otherwise must be followed by colon and more.
if s[0] != ':' || len(s) == 1 {
- return nil, zone
+ return nil
}
s = s[1:]
// Look for ellipsis.
if s[0] == ':' {
if ellipsis >= 0 { // already have one
- return nil, zone
+ return nil
}
ellipsis = i
s = s[1:]
@@ -616,13 +649,13 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
// Must have used entire string.
if len(s) != 0 {
- return nil, zone
+ return nil
}
// If didn't parse enough, expand ellipsis.
if i < IPv6len {
if ellipsis < 0 {
- return nil, zone
+ return nil
}
n := IPv6len - i
for j := i - 1; j >= ellipsis; j-- {
@@ -633,9 +666,9 @@ func parseIPv6(s string, zoneAllowed bool) (ip IP, zone string) {
}
} else if ellipsis >= 0 {
// Ellipsis must represent at least one 0 group.
- return nil, zone
+ return nil
}
- return ip, zone
+ return ip
}
// ParseIP parses s as an IP address, returning the result.
@@ -649,13 +682,26 @@ func ParseIP(s string) IP {
case '.':
return parseIPv4(s)
case ':':
- ip, _ := parseIPv6(s, false)
- return ip
+ return parseIPv6(s)
}
}
return nil
}
+// parseIPZone parses s as an IP address, return it and its associated zone
+// identifier (IPv6 only).
+func parseIPZone(s string) (IP, string) {
+ for i := 0; i < len(s); i++ {
+ switch s[i] {
+ case '.':
+ return parseIPv4(s), ""
+ case ':':
+ return parseIPv6Zone(s)
+ }
+ }
+ return nil, ""
+}
+
// ParseCIDR parses s as a CIDR notation IP address and prefix length,
// like "192.0.2.0/24" or "2001:db8::/32", as defined in
// RFC 4632 and RFC 4291.
@@ -674,7 +720,7 @@ func ParseCIDR(s string) (IP, *IPNet, error) {
ip := parseIPv4(addr)
if ip == nil {
iplen = IPv6len
- ip, _ = parseIPv6(addr, false)
+ ip = parseIPv6(addr)
}
n, i, ok := dtoi(mask)
if ip == nil || !ok || i != len(mask) || n < 0 || n > 8*iplen {
diff --git a/libgo/go/net/ip_test.go b/libgo/go/net/ip_test.go
index ad13388..a5fc5e6 100644
--- a/libgo/go/net/ip_test.go
+++ b/libgo/go/net/ip_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -129,7 +131,7 @@ func TestMarshalEmptyIP(t *testing.T) {
}
}
-var ipStringTests = []struct {
+var ipStringTests = []*struct {
in IP // see RFC 791 and RFC 4291
str string // see RFC 791, RFC 4291 and RFC 5952
byt []byte
@@ -252,9 +254,21 @@ var sink string
func BenchmarkIPString(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
+ b.Run("IPv4", func(b *testing.B) {
+ benchmarkIPString(b, IPv4len)
+ })
+
+ b.Run("IPv6", func(b *testing.B) {
+ benchmarkIPString(b, IPv6len)
+ })
+}
+
+func benchmarkIPString(b *testing.B, size int) {
+ b.ReportAllocs()
+ b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tt := range ipStringTests {
- if tt.in != nil {
+ if tt.in != nil && len(tt.in) == size {
sink = tt.in.String()
}
}
diff --git a/libgo/go/net/iprawsock.go b/libgo/go/net/iprawsock.go
index 72cbc39..8a9c265 100644
--- a/libgo/go/net/iprawsock.go
+++ b/libgo/go/net/iprawsock.go
@@ -21,8 +21,8 @@ import (
// change the behavior of these methods; use Read or ReadMsgIP
// instead.
-// BUG(mikio): On NaCl and Plan 9, the ReadMsgIP and
-// WriteMsgIP methods of IPConn are not implemented.
+// BUG(mikio): On JS, NaCl and Plan 9, methods and functions related
+// to IPConn are not implemented.
// BUG(mikio): On Windows, the File method of IPConn is not
// implemented.
@@ -209,7 +209,11 @@ func newIPConn(fd *netFD) *IPConn { return &IPConn{conn{fd}} }
// If the IP field of raddr is nil or an unspecified IP address, the
// local system is assumed.
func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
- c, err := dialIP(context.Background(), network, laddr, raddr)
+ if raddr == nil {
+ return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
+ }
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialIP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
@@ -224,7 +228,11 @@ func DialIP(network string, laddr, raddr *IPAddr) (*IPConn, error) {
// ListenIP listens on all available IP addresses of the local system
// except multicast IP addresses.
func ListenIP(network string, laddr *IPAddr) (*IPConn, error) {
- c, err := listenIP(context.Background(), network, laddr)
+ if laddr == nil {
+ laddr = &IPAddr{}
+ }
+ sl := &sysListener{network: network, address: laddr.String()}
+ c, err := sl.listenIP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
diff --git a/libgo/go/net/iprawsock_plan9.go b/libgo/go/net/iprawsock_plan9.go
index 6aebea1..ebe5808 100644
--- a/libgo/go/net/iprawsock_plan9.go
+++ b/libgo/go/net/iprawsock_plan9.go
@@ -25,10 +25,10 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return 0, 0, syscall.EPLAN9
}
-func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
+func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}
-func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
+func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
return nil, syscall.EPLAN9
}
diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go
index d613e6f..2a5d49f 100644
--- a/libgo/go/net/iprawsock_posix.go
+++ b/libgo/go/net/iprawsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
package net
@@ -112,37 +112,34 @@ func (c *IPConn) writeMsg(b, oob []byte, addr *IPAddr) (n, oobn int, err error)
return c.fd.writeMsg(b, oob, sa)
}
-func dialIP(ctx context.Context, netProto string, laddr, raddr *IPAddr) (*IPConn, error) {
- network, proto, err := parseNetwork(ctx, netProto, true)
+func (sd *sysDialer) dialIP(ctx context.Context, laddr, raddr *IPAddr) (*IPConn, error) {
+ network, proto, err := parseNetwork(ctx, sd.network, true)
if err != nil {
return nil, err
}
switch network {
case "ip", "ip4", "ip6":
default:
- return nil, UnknownNetworkError(netProto)
+ return nil, UnknownNetworkError(sd.network)
}
- if raddr == nil {
- return nil, errMissingAddress
- }
- fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial")
+ fd, err := internetSocket(ctx, network, laddr, raddr, syscall.SOCK_RAW, proto, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
return newIPConn(fd), nil
}
-func listenIP(ctx context.Context, netProto string, laddr *IPAddr) (*IPConn, error) {
- network, proto, err := parseNetwork(ctx, netProto, true)
+func (sl *sysListener) listenIP(ctx context.Context, laddr *IPAddr) (*IPConn, error) {
+ network, proto, err := parseNetwork(ctx, sl.network, true)
if err != nil {
return nil, err
}
switch network {
case "ip", "ip4", "ip6":
default:
- return nil, UnknownNetworkError(netProto)
+ return nil, UnknownNetworkError(sl.network)
}
- fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen")
+ fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_RAW, proto, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/iprawsock_test.go b/libgo/go/net/iprawsock_test.go
index 8972051..8e3543d 100644
--- a/libgo/go/net/iprawsock_test.go
+++ b/libgo/go/net/iprawsock_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go
index 947bdf3..f4ff82b 100644
--- a/libgo/go/net/ipsock.go
+++ b/libgo/go/net/ipsock.go
@@ -276,24 +276,16 @@ func (r *Resolver) internetAddrList(ctx context.Context, net, addr string) (addr
}
// Try as a literal IP address, then as a DNS name.
- var ips []IPAddr
- if ip := parseIPv4(host); ip != nil {
- ips = []IPAddr{{IP: ip}}
- } else if ip, zone := parseIPv6(host, true); ip != nil {
- ips = []IPAddr{{IP: ip, Zone: zone}}
- // Issue 18806: if the machine has halfway configured
- // IPv6 such that it can bind on "::" (IPv6unspecified)
- // but not connect back to that same address, fall
- // back to dialing 0.0.0.0.
- if ip.Equal(IPv6unspecified) {
- ips = append(ips, IPAddr{IP: IPv4zero})
- }
- } else {
- // Try as a DNS name.
- ips, err = r.LookupIPAddr(ctx, host)
- if err != nil {
- return nil, err
- }
+ ips, err := r.LookupIPAddr(ctx, host)
+ if err != nil {
+ return nil, err
+ }
+ // Issue 18806: if the machine has halfway configured
+ // IPv6 such that it can bind on "::" (IPv6unspecified)
+ // but not connect back to that same address, fall
+ // back to dialing 0.0.0.0.
+ if len(ips) == 1 && ips[0].IP.Equal(IPv6unspecified) {
+ ips = append(ips, IPAddr{IP: IPv4zero})
}
var filter func(IPAddr) bool
diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go
index 9cff960..08804ca 100644
--- a/libgo/go/net/ipsock_posix.go
+++ b/libgo/go/net/ipsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
package net
@@ -133,12 +133,12 @@ func favoriteAddrFamily(network string, laddr, raddr sockaddr, mode string) (fam
return syscall.AF_INET6, false
}
-func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string) (fd *netFD, err error) {
+func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, sotype, proto int, mode string, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
if (runtime.GOOS == "windows" || runtime.GOOS == "openbsd" || runtime.GOOS == "nacl") && mode == "dial" && raddr.isWildcard() {
raddr = raddr.toLocal(net)
}
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
- return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr)
+ return socket(ctx, net, family, sotype, proto, ipv6only, laddr, raddr, ctrlFn)
}
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
diff --git a/libgo/go/net/listen_test.go b/libgo/go/net/listen_test.go
index 96624f9..ffce8e2 100644
--- a/libgo/go/net/listen_test.go
+++ b/libgo/go/net/listen_test.go
@@ -2,11 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !plan9
+// +build !js,!plan9
package net
import (
+ "context"
"fmt"
"internal/testenv"
"os"
@@ -729,3 +730,56 @@ func TestClosingListener(t *testing.T) {
}
ln2.Close()
}
+
+func TestListenConfigControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("StreamListen", func(t *testing.T) {
+ for _, network := range []string{"tcp", "tcp4", "tcp6", "unix", "unixpacket"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ ln, err := newLocalListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ address := ln.Addr().String()
+ ln.Close()
+ lc := ListenConfig{Control: controlOnConnSetup}
+ ln, err = lc.Listen(context.Background(), network, address)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ ln.Close()
+ }
+ })
+ t.Run("PacketListen", func(t *testing.T) {
+ for _, network := range []string{"udp", "udp4", "udp6", "unixgram"} {
+ if !testableNetwork(network) {
+ continue
+ }
+ c, err := newLocalPacketListener(network)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ address := c.LocalAddr().String()
+ c.Close()
+ if network == "unixgram" {
+ os.Remove(address)
+ }
+ lc := ListenConfig{Control: controlOnConnSetup}
+ c, err = lc.ListenPacket(context.Background(), network, address)
+ if err != nil {
+ t.Error(err)
+ continue
+ }
+ c.Close()
+ }
+ })
+}
diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go
index a65b735..e0f21fa 100644
--- a/libgo/go/net/lookup.go
+++ b/libgo/go/net/lookup.go
@@ -15,7 +15,7 @@ import (
// names and numbers for platforms that don't have a complete list of
// protocol numbers.
//
-// See http://www.iana.org/assignments/protocol-numbers
+// See https://www.iana.org/assignments/protocol-numbers
//
// On Unix, this map is augmented by readProtocols via lookupProtocol.
var protocols = map[string]int{
@@ -133,10 +133,25 @@ type Resolver struct {
// If nil, the default dialer is used.
Dial func(ctx context.Context, network, address string) (Conn, error)
+ // lookupGroup merges LookupIPAddr calls together for lookups for the same
+ // host. The lookupGroup key is the LookupIPAddr.host argument.
+ // The return values are ([]IPAddr, error).
+ lookupGroup singleflight.Group
+
// TODO(bradfitz): optional interface impl override hook
// TODO(bradfitz): Timeout time.Duration?
}
+func (r *Resolver) preferGo() bool { return r != nil && r.PreferGo }
+func (r *Resolver) strictErrors() bool { return r != nil && r.StrictErrors }
+
+func (r *Resolver) getLookupGroup() *singleflight.Group {
+ if r == nil {
+ return &DefaultResolver.lookupGroup
+ }
+ return &r.lookupGroup
+}
+
// LookupHost looks up the given host using the local resolver.
// It returns a slice of that host's addresses.
func LookupHost(host string) (addrs []string, err error) {
@@ -147,11 +162,11 @@ func LookupHost(host string) (addrs []string, err error) {
// It returns a slice of that host's addresses.
func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
// Make sure that no matter what we do later, host=="" is rejected.
- // ParseIP, for example, does accept empty strings.
+ // parseIP, for example, does accept empty strings.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
}
- if ip := ParseIP(host); ip != nil {
+ if ip, _ := parseIPZone(host); ip != nil {
return []string{host}, nil
}
return r.lookupHost(ctx, host)
@@ -175,12 +190,12 @@ func LookupIP(host string) ([]IP, error) {
// It returns a slice of that host's IPv4 and IPv6 addresses.
func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, error) {
// Make sure that no matter what we do later, host=="" is rejected.
- // ParseIP, for example, does accept empty strings.
+ // parseIP, for example, does accept empty strings.
if host == "" {
return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host}
}
- if ip := ParseIP(host); ip != nil {
- return []IPAddr{{IP: ip}}, nil
+ if ip, zone := parseIPZone(host); ip != nil {
+ return []IPAddr{{IP: ip, Zone: zone}}, nil
}
trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
if trace != nil && trace.DNSStart != nil {
@@ -201,7 +216,7 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
lookupGroupCtx, lookupGroupCancel := context.WithCancel(context.Background())
dnsWaitGroup.Add(1)
- ch, called := lookupGroup.DoChan(host, func() (interface{}, error) {
+ ch, called := r.getLookupGroup().DoChan(host, func() (interface{}, error) {
defer dnsWaitGroup.Done()
return testHookLookupIP(lookupGroupCtx, resolverFunc, host)
})
@@ -218,7 +233,7 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
// let the lookup continue uncanceled, and let later
// lookups with the same key share the result.
// See issues 8602, 20703, 22724.
- if lookupGroup.ForgetUnshared(host) {
+ if r.getLookupGroup().ForgetUnshared(host) {
lookupGroupCancel()
} else {
go func() {
@@ -241,12 +256,6 @@ func (r *Resolver) LookupIPAddr(ctx context.Context, host string) ([]IPAddr, err
}
}
-// lookupGroup merges LookupIPAddr calls together for lookups
-// for the same host. The lookupGroup key is is the LookupIPAddr.host
-// argument.
-// The return values are ([]IPAddr, error).
-var lookupGroup singleflight.Group
-
// lookupIPReturn turns the return values from singleflight.Do into
// the return values from LookupIP.
func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error) {
diff --git a/libgo/go/net/lookup_nacl.go b/libgo/go/net/lookup_fake.go
index 43cebad..d3d1dbc 100644
--- a/libgo/go/net/lookup_nacl.go
+++ b/libgo/go/net/lookup_fake.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl
+// +build nacl js,wasm
package net
@@ -50,3 +50,9 @@ func (*Resolver) lookupTXT(ctx context.Context, name string) (txts []string, err
func (*Resolver) lookupAddr(ctx context.Context, addr string) (ptrs []string, err error) {
return nil, syscall.ENOPROTOOPT
}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go
index 1037b81..5547f0b 100644
--- a/libgo/go/net/lookup_plan9.go
+++ b/libgo/go/net/lookup_plan9.go
@@ -11,34 +11,58 @@ import (
"os"
)
-func query(ctx context.Context, filename, query string, bufSize int) (res []string, err error) {
- file, err := os.OpenFile(filename, os.O_RDWR, 0)
- if err != nil {
- return
- }
- defer file.Close()
+func query(ctx context.Context, filename, query string, bufSize int) (addrs []string, err error) {
+ queryAddrs := func() (addrs []string, err error) {
+ file, err := os.OpenFile(filename, os.O_RDWR, 0)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
- _, err = file.Seek(0, io.SeekStart)
- if err != nil {
- return
- }
- _, err = file.WriteString(query)
- if err != nil {
- return
+ _, err = file.Seek(0, io.SeekStart)
+ if err != nil {
+ return nil, err
+ }
+ _, err = file.WriteString(query)
+ if err != nil {
+ return nil, err
+ }
+ _, err = file.Seek(0, io.SeekStart)
+ if err != nil {
+ return nil, err
+ }
+ buf := make([]byte, bufSize)
+ for {
+ n, _ := file.Read(buf)
+ if n <= 0 {
+ break
+ }
+ addrs = append(addrs, string(buf[:n]))
+ }
+ return addrs, nil
}
- _, err = file.Seek(0, io.SeekStart)
- if err != nil {
- return
+
+ type ret struct {
+ addrs []string
+ err error
}
- buf := make([]byte, bufSize)
- for {
- n, _ := file.Read(buf)
- if n <= 0 {
- break
+
+ ch := make(chan ret, 1)
+ go func() {
+ addrs, err := queryAddrs()
+ ch <- ret{addrs: addrs, err: err}
+ }()
+
+ select {
+ case r := <-ch:
+ return r.addrs, r.err
+ case <-ctx.Done():
+ return nil, &DNSError{
+ Name: query,
+ Err: ctx.Err().Error(),
+ IsTimeout: ctx.Err() == context.DeadlineExceeded,
}
- res = append(res, string(buf[:n]))
}
- return
}
func queryCS(ctx context.Context, net, host, service string) (res []string, err error) {
@@ -305,3 +329,9 @@ func (*Resolver) lookupAddr(ctx context.Context, addr string) (name []string, er
}
return
}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go
index 24787cc..5c66dfa 100644
--- a/libgo/go/net/lookup_test.go
+++ b/libgo/go/net/lookup_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -13,6 +15,7 @@ import (
"runtime"
"sort"
"strings"
+ "sync"
"testing"
"time"
)
@@ -60,19 +63,34 @@ var lookupGoogleSRVTests = []struct {
},
}
+var backoffDuration = [...]time.Duration{time.Second, 5 * time.Second, 30 * time.Second}
+
func TestLookupGoogleSRV(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
+ t.Skip("no resolv.conf on iOS")
}
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
}
- for _, tt := range lookupGoogleSRVTests {
+ attempts := 0
+ for i := 0; i < len(lookupGoogleSRVTests); i++ {
+ tt := lookupGoogleSRVTests[i]
cname, srvs, err := LookupSRV(tt.service, tt.proto, tt.name)
if err != nil {
testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
t.Fatal(err)
}
if len(srvs) == 0 {
@@ -97,19 +115,31 @@ var lookupGmailMXTests = []struct {
}
func TestLookupGmailMX(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
+ t.Skip("no resolv.conf on iOS")
}
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
}
- defer dnsWaitGroup.Wait()
-
- for _, tt := range lookupGmailMXTests {
+ attempts := 0
+ for i := 0; i < len(lookupGmailMXTests); i++ {
+ tt := lookupGmailMXTests[i]
mxs, err := LookupMX(tt.name)
if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
t.Fatal(err)
}
if len(mxs) == 0 {
@@ -131,20 +161,31 @@ var lookupGmailNSTests = []struct {
}
func TestLookupGmailNS(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
+ t.Skip("no resolv.conf on iOS")
}
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
}
- defer dnsWaitGroup.Wait()
-
- for _, tt := range lookupGmailNSTests {
+ attempts := 0
+ for i := 0; i < len(lookupGmailNSTests); i++ {
+ tt := lookupGmailNSTests[i]
nss, err := LookupNS(tt.name)
if err != nil {
testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
t.Fatal(err)
}
if len(nss) == 0 {
@@ -166,19 +207,31 @@ var lookupGmailTXTTests = []struct {
}
func TestLookupGmailTXT(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
+ t.Parallel()
+ mustHaveExternalNetwork(t)
+
+ if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
+ t.Skip("no resolv.conf on iOS")
}
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
}
- defer dnsWaitGroup.Wait()
-
- for _, tt := range lookupGmailTXTTests {
+ attempts := 0
+ for i := 0; i < len(lookupGmailTXTTests); i++ {
+ tt := lookupGmailTXTTests[i]
txts, err := LookupTXT(tt.name)
if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
t.Fatal(err)
}
if len(txts) == 0 {
@@ -203,9 +256,7 @@ var lookupGooglePublicDNSAddrTests = []struct {
}
func TestLookupGooglePublicDNSAddr(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
if !supportsIPv4() || !supportsIPv6() || !*testIPv4 || !*testIPv6 {
t.Skip("both IPv4 and IPv6 are required")
@@ -255,6 +306,32 @@ func TestLookupIPv6LinkLocalAddr(t *testing.T) {
}
}
+func TestLookupIPv6LinkLocalAddrWithZone(t *testing.T) {
+ if !supportsIPv6() || !*testIPv6 {
+ t.Skip("IPv6 is required")
+ }
+
+ ipaddrs, err := DefaultResolver.LookupIPAddr(context.Background(), "fe80::1%lo0")
+ if err != nil {
+ t.Error(err)
+ }
+ for _, addr := range ipaddrs {
+ if e, a := "lo0", addr.Zone; e != a {
+ t.Errorf("wrong zone: want %q, got %q", e, a)
+ }
+ }
+
+ addrs, err := DefaultResolver.LookupHost(context.Background(), "fe80::1%lo0")
+ if err != nil {
+ t.Error(err)
+ }
+ for _, addr := range addrs {
+ if e, a := "fe80::1%lo0", addr; e != a {
+ t.Errorf("wrong host: want %q got %q", e, a)
+ }
+ }
+}
+
var lookupCNAMETests = []struct {
name, cname string
}{
@@ -264,9 +341,7 @@ var lookupCNAMETests = []struct {
}
func TestLookupCNAME(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -274,9 +349,20 @@ func TestLookupCNAME(t *testing.T) {
defer dnsWaitGroup.Wait()
- for _, tt := range lookupCNAMETests {
+ attempts := 0
+ for i := 0; i < len(lookupCNAMETests); i++ {
+ tt := lookupCNAMETests[i]
cname, err := LookupCNAME(tt.name)
if err != nil {
+ testenv.SkipFlakyNet(t)
+ if attempts < len(backoffDuration) {
+ dur := backoffDuration[attempts]
+ t.Logf("backoff %v after failure %v\n", dur, err)
+ time.Sleep(dur)
+ attempts++
+ i--
+ continue
+ }
t.Fatal(err)
}
if !strings.HasSuffix(cname, tt.cname) {
@@ -293,9 +379,7 @@ var lookupGoogleHostTests = []struct {
}
func TestLookupGoogleHost(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -320,12 +404,8 @@ func TestLookupGoogleHost(t *testing.T) {
}
func TestLookupLongTXT(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping on plan9; see https://golang.org/issue/22857")
- }
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ testenv.SkipFlaky(t, 22857)
+ mustHaveExternalNetwork(t)
defer dnsWaitGroup.Wait()
@@ -351,9 +431,7 @@ var lookupGoogleIPTests = []struct {
}
func TestLookupGoogleIP(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
@@ -499,9 +577,7 @@ func TestLookupDotsWithLocalSource(t *testing.T) {
t.Skip("IPv4 is required")
}
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
defer dnsWaitGroup.Wait()
@@ -542,14 +618,16 @@ func TestLookupDotsWithLocalSource(t *testing.T) {
}
func TestLookupDotsWithRemoteSource(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
if !supportsIPv4() || !*testIPv4 {
t.Skip("IPv4 is required")
}
+ if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
+ t.Skip("no resolv.conf on iOS")
+ }
+
defer dnsWaitGroup.Wait()
if fixup := forceGoDNS(); fixup != nil {
@@ -664,10 +742,10 @@ func srvString(srvs []*SRV) string {
}
func TestLookupPort(t *testing.T) {
- // See http://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml
+ // See https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml
//
// Please be careful about adding new test cases.
- // There are platforms having incomplete mappings for
+ // There are platforms which have incomplete mappings for
// restricted resource access and security reasons.
type test struct {
network string
@@ -793,9 +871,7 @@ func TestLookupNonLDH(t *testing.T) {
}
func TestLookupContextCancel(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
if runtime.GOOS == "nacl" {
t.Skip("skip on nacl")
}
@@ -816,3 +892,119 @@ func TestLookupContextCancel(t *testing.T) {
t.Fatal(err)
}
}
+
+// Issue 24330: treat the nil *Resolver like a zero value. Verify nothing
+// crashes if nil is used.
+func TestNilResolverLookup(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ if runtime.GOOS == "nacl" {
+ t.Skip("skip on nacl")
+ }
+ var r *Resolver = nil
+ ctx := context.Background()
+
+ // Don't care about the results, just that nothing panics:
+ r.LookupAddr(ctx, "8.8.8.8")
+ r.LookupCNAME(ctx, "google.com")
+ r.LookupHost(ctx, "google.com")
+ r.LookupIPAddr(ctx, "google.com")
+ r.LookupMX(ctx, "gmail.com")
+ r.LookupNS(ctx, "google.com")
+ r.LookupPort(ctx, "tcp", "smtp")
+ r.LookupSRV(ctx, "service", "proto", "name")
+ r.LookupTXT(ctx, "gmail.com")
+}
+
+// TestLookupHostCancel verifies that lookup works even after many
+// canceled lookups (see golang.org/issue/24178 for details).
+func TestLookupHostCancel(t *testing.T) {
+ mustHaveExternalNetwork(t)
+ if runtime.GOOS == "nacl" {
+ t.Skip("skip on nacl")
+ }
+
+ const (
+ google = "www.google.com"
+ invalidDomain = "nonexistentdomain.golang.org"
+ n = 600 // this needs to be larger than threadLimit size
+ )
+
+ _, err := LookupHost(google)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ for i := 0; i < n; i++ {
+ addr, err := DefaultResolver.LookupHost(ctx, invalidDomain)
+ if err == nil {
+ t.Fatalf("LookupHost(%q): returns %v, but should fail", invalidDomain, addr)
+ }
+ if !strings.Contains(err.Error(), "canceled") {
+ t.Fatalf("LookupHost(%q): failed with unexpected error: %v", invalidDomain, err)
+ }
+ time.Sleep(time.Millisecond * 1)
+ }
+
+ _, err = LookupHost(google)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+type lookupCustomResolver struct {
+ *Resolver
+ mu sync.RWMutex
+ dialed bool
+}
+
+func (lcr *lookupCustomResolver) dial() func(ctx context.Context, network, address string) (Conn, error) {
+ return func(ctx context.Context, network, address string) (Conn, error) {
+ lcr.mu.Lock()
+ lcr.dialed = true
+ lcr.mu.Unlock()
+ return Dial(network, address)
+ }
+}
+
+// TestConcurrentPreferGoResolversDial tests that multiple resolvers with the
+// PreferGo option used concurrently are all dialed properly.
+func TestConcurrentPreferGoResolversDial(t *testing.T) {
+ // The windows implementation of the resolver does not use the Dial
+ // function.
+ if runtime.GOOS == "windows" {
+ t.Skip("skip on windows")
+ }
+
+ testenv.MustHaveExternalNetwork(t)
+ testenv.SkipFlakyNet(t)
+
+ defer dnsWaitGroup.Wait()
+
+ resolvers := make([]*lookupCustomResolver, 2)
+ for i := range resolvers {
+ cs := lookupCustomResolver{Resolver: &Resolver{PreferGo: true}}
+ cs.Dial = cs.dial()
+ resolvers[i] = &cs
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(len(resolvers))
+ for i, resolver := range resolvers {
+ go func(r *Resolver, index int) {
+ defer wg.Done()
+ _, err := r.LookupIPAddr(context.Background(), "google.com")
+ if err != nil {
+ t.Fatalf("lookup failed for resolver %d: %q", index, err)
+ }
+ }(resolver.Resolver, i)
+ }
+ wg.Wait()
+
+ for i, resolver := range resolvers {
+ if !resolver.dialed {
+ t.Errorf("custom resolver %d not dialed during lookup", i)
+ }
+ }
+}
diff --git a/libgo/go/net/lookup_unix.go b/libgo/go/net/lookup_unix.go
index 2813f14..76d6ae3 100644
--- a/libgo/go/net/lookup_unix.go
+++ b/libgo/go/net/lookup_unix.go
@@ -9,6 +9,9 @@ package net
import (
"context"
"sync"
+ "syscall"
+
+ "golang_org/x/net/dns/dnsmessage"
)
var onceReadProtocols sync.Once
@@ -51,7 +54,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
return lookupProtocolMap(name)
}
-func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, error) {
+func (r *Resolver) dial(ctx context.Context, network, server string) (Conn, error) {
// Calling Dial here is scary -- we have to be sure not to
// dial a name that will require a DNS lookup, or Dial will
// call back here to translate it. The DNS config parser has
@@ -59,7 +62,7 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
// addresses, which Dial will use without a DNS lookup.
var c Conn
var err error
- if r.Dial != nil {
+ if r != nil && r.Dial != nil {
c, err = r.Dial(ctx, network, server)
} else {
var d Dialer
@@ -68,15 +71,12 @@ func (r *Resolver) dial(ctx context.Context, network, server string) (dnsConn, e
if err != nil {
return nil, mapErr(err)
}
- if _, ok := c.(PacketConn); ok {
- return &dnsPacketConn{c}, nil
- }
- return &dnsStreamConn{c}, nil
+ return c, nil
}
func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string, err error) {
- order := systemConf().hostLookupOrder(host)
- if !r.PreferGo && order == hostLookupCgo {
+ order := systemConf().hostLookupOrder(r, host)
+ if !r.preferGo() && order == hostLookupCgo {
if addrs, err, ok := cgoLookupHost(ctx, host); ok {
return addrs, err
}
@@ -87,10 +87,10 @@ func (r *Resolver) lookupHost(ctx context.Context, host string) (addrs []string,
}
func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
- if r.PreferGo {
+ if r.preferGo() {
return r.goLookupIP(ctx, host)
}
- order := systemConf().hostLookupOrder(host)
+ order := systemConf().hostLookupOrder(r, host)
if order == hostLookupCgo {
if addrs, err, ok := cgoLookupIP(ctx, host); ok {
return addrs, err
@@ -98,12 +98,12 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (addrs []IPAddr, e
// cgo not available (or netgo); fall back to Go's DNS resolver
order = hostLookupFilesDNS
}
- addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order)
- return
+ ips, _, err := r.goLookupIPCNAMEOrder(ctx, host, order)
+ return ips, err
}
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
- if !r.PreferGo && systemConf().canUseCgo() {
+ if !r.preferGo() && systemConf().canUseCgo() {
if port, err, ok := cgoLookupPort(ctx, network, service); ok {
if err != nil {
// Issue 18213: if cgo fails, first check to see whether we
@@ -119,7 +119,7 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
}
func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
- if !r.PreferGo && systemConf().canUseCgo() {
+ if !r.preferGo() && systemConf().canUseCgo() {
if cname, err, ok := cgoLookupCNAME(ctx, name); ok {
return cname, err
}
@@ -134,62 +134,209 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
} else {
target = "_" + service + "._" + proto + "." + name
}
- cname, rrs, err := r.lookup(ctx, target, dnsTypeSRV)
+ p, server, err := r.lookup(ctx, target, dnsmessage.TypeSRV)
if err != nil {
return "", nil, err
}
- srvs := make([]*SRV, len(rrs))
- for i, rr := range rrs {
- rr := rr.(*dnsRR_SRV)
- srvs[i] = &SRV{Target: rr.Target, Port: rr.Port, Priority: rr.Priority, Weight: rr.Weight}
+ var srvs []*SRV
+ var cname dnsmessage.Name
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeSRV {
+ if err := p.SkipAnswer(); err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ if cname.Length == 0 && h.Name.Length != 0 {
+ cname = h.Name
+ }
+ srv, err := p.SRVResource()
+ if err != nil {
+ return "", nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ srvs = append(srvs, &SRV{Target: srv.Target.String(), Port: srv.Port, Priority: srv.Priority, Weight: srv.Weight})
}
byPriorityWeight(srvs).sort()
- return cname, srvs, nil
+ return cname.String(), srvs, nil
}
func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
- _, rrs, err := r.lookup(ctx, name, dnsTypeMX)
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeMX)
if err != nil {
return nil, err
}
- mxs := make([]*MX, len(rrs))
- for i, rr := range rrs {
- rr := rr.(*dnsRR_MX)
- mxs[i] = &MX{Host: rr.Mx, Pref: rr.Pref}
+ var mxs []*MX
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeMX {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ mx, err := p.MXResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ mxs = append(mxs, &MX{Host: mx.MX.String(), Pref: mx.Pref})
+
}
byPref(mxs).sort()
return mxs, nil
}
func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
- _, rrs, err := r.lookup(ctx, name, dnsTypeNS)
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeNS)
if err != nil {
return nil, err
}
- nss := make([]*NS, len(rrs))
- for i, rr := range rrs {
- nss[i] = &NS{Host: rr.(*dnsRR_NS).Ns}
+ var nss []*NS
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeNS {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ ns, err := p.NSResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ nss = append(nss, &NS{Host: ns.NS.String()})
}
return nss, nil
}
func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
- _, rrs, err := r.lookup(ctx, name, dnsTypeTXT)
+ p, server, err := r.lookup(ctx, name, dnsmessage.TypeTXT)
if err != nil {
return nil, err
}
- txts := make([]string, len(rrs))
- for i, rr := range rrs {
- txts[i] = rr.(*dnsRR_TXT).Txt
+ var txts []string
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ break
+ }
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if h.Type != dnsmessage.TypeTXT {
+ if err := p.SkipAnswer(); err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ continue
+ }
+ txt, err := p.TXTResource()
+ if err != nil {
+ return nil, &DNSError{
+ Err: "cannot unmarshal DNS message",
+ Name: name,
+ Server: server,
+ }
+ }
+ if len(txts) == 0 {
+ txts = txt.TXT
+ } else {
+ txts = append(txts, txt.TXT...)
+ }
}
return txts, nil
}
func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
- if !r.PreferGo && systemConf().canUseCgo() {
+ if !r.preferGo() && systemConf().canUseCgo() {
if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok {
return ptrs, err
}
}
return r.goLookupPTR(ctx, addr)
}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups via cgo. A DNS lookup may use a
+// file descriptor so we limit this to less than the number of
+// permitted open files. On some systems, notably Darwin, if
+// getaddrinfo is unable to open a file descriptor it simply returns
+// EAI_NONAME rather than a useful error. Limiting the number of
+// concurrent getaddrinfo calls to less than the permitted number of
+// file descriptors makes that error less likely. We don't bother to
+// apply the same limit to DNS lookups run directly from Go, because
+// there we will return a meaningful "too many open files" error.
+func concurrentThreadsLimit() int {
+ var rlim syscall.Rlimit
+ if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rlim); err != nil {
+ return 500
+ }
+ r := int(rlim.Cur)
+ if r > 500 {
+ r = 500
+ } else if r > 30 {
+ r -= 30
+ }
+ return r
+}
diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go
index ac1f9b4..f76e0af 100644
--- a/libgo/go/net/lookup_windows.go
+++ b/libgo/go/net/lookup_windows.go
@@ -79,12 +79,7 @@ func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error
func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
// TODO(bradfitz,brainman): use ctx more. See TODO below.
- type ret struct {
- addrs []IPAddr
- err error
- }
- ch := make(chan ret, 1)
- go func() {
+ getaddr := func() ([]IPAddr, error) {
acquireThread()
defer releaseThread()
hints := syscall.AddrinfoW{
@@ -95,7 +90,7 @@ func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error)
var result *syscall.AddrinfoW
e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
if e != nil {
- ch <- ret{err: &DNSError{Err: winError("getaddrinfow", e).Error(), Name: name}}
+ return nil, &DNSError{Err: winError("getaddrinfow", e).Error(), Name: name}
}
defer syscall.FreeAddrInfoW(result)
addrs := make([]IPAddr, 0, 5)
@@ -110,11 +105,23 @@ func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error)
zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
default:
- ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}}
+ return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
}
}
- ch <- ret{addrs: addrs}
+ return addrs, nil
+ }
+
+ type ret struct {
+ addrs []IPAddr
+ err error
+ }
+
+ ch := make(chan ret, 1)
+ go func() {
+ addr, err := getaddr()
+ ch <- ret{addrs: addr, err: err}
}()
+
select {
case r := <-ch:
return r.addrs, r.err
@@ -136,7 +143,7 @@ func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error)
}
func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
- if r.PreferGo {
+ if r.preferGo() {
return lookupPortMap(network, service)
}
@@ -357,3 +364,9 @@ Cname:
}
return name
}
+
+// concurrentThreadsLimit returns the number of threads we permit to
+// run concurrently doing DNS lookups.
+func concurrentThreadsLimit() int {
+ return 500
+}
diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go
index 4f3184f..5912b90 100644
--- a/libgo/go/net/mail/message.go
+++ b/libgo/go/net/mail/message.go
@@ -19,7 +19,6 @@ package mail
import (
"bufio"
- "bytes"
"errors"
"fmt"
"io"
@@ -735,7 +734,7 @@ func isQtext(r rune) bool {
// quoteString renders a string as an RFC 5322 quoted-string.
func quoteString(s string) string {
- var buf bytes.Buffer
+ var buf strings.Builder
buf.WriteByte('"')
for _, r := range s {
if isQtext(r) || isWSP(r) {
diff --git a/libgo/go/net/main_cloexec_test.go b/libgo/go/net/main_cloexec_test.go
index fa1ed02..5398f9e 100644
--- a/libgo/go/net/main_cloexec_test.go
+++ b/libgo/go/net/main_cloexec_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build dragonfly freebsd linux
+// +build dragonfly freebsd linux netbsd openbsd
package net
diff --git a/libgo/go/net/main_conf_test.go b/libgo/go/net/main_conf_test.go
index 9875cea..b535046 100644
--- a/libgo/go/net/main_conf_test.go
+++ b/libgo/go/net/main_conf_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !nacl,!plan9,!windows
+// +build !js,!nacl,!plan9,!windows
package net
diff --git a/libgo/go/net/main_noconf_test.go b/libgo/go/net/main_noconf_test.go
index 489477b..55e3770 100644
--- a/libgo/go/net/main_noconf_test.go
+++ b/libgo/go/net/main_noconf_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl plan9 windows
+// +build js,wasm nacl plan9 windows
package net
diff --git a/libgo/go/net/main_posix_test.go b/libgo/go/net/main_posix_test.go
index ead311c..f2484f3 100644
--- a/libgo/go/net/main_posix_test.go
+++ b/libgo/go/net/main_posix_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !plan9
+// +build !js,!plan9
package net
diff --git a/libgo/go/net/main_test.go b/libgo/go/net/main_test.go
index 3e7a85a..85a269d 100644
--- a/libgo/go/net/main_test.go
+++ b/libgo/go/net/main_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/mockserver_test.go b/libgo/go/net/mockserver_test.go
index 44581d9..5302935 100644
--- a/libgo/go/net/mockserver_test.go
+++ b/libgo/go/net/mockserver_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go
index 3ad9103..c909986 100644
--- a/libgo/go/net/net.go
+++ b/libgo/go/net/net.go
@@ -84,6 +84,7 @@ import (
"internal/poll"
"io"
"os"
+ "sync"
"syscall"
"time"
)
@@ -229,7 +230,7 @@ func (c *conn) SetDeadline(t time.Time) error {
if !c.ok() {
return syscall.EINVAL
}
- if err := c.fd.pfd.SetDeadline(t); err != nil {
+ if err := c.fd.SetDeadline(t); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
@@ -240,7 +241,7 @@ func (c *conn) SetReadDeadline(t time.Time) error {
if !c.ok() {
return syscall.EINVAL
}
- if err := c.fd.pfd.SetReadDeadline(t); err != nil {
+ if err := c.fd.SetReadDeadline(t); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
@@ -251,7 +252,7 @@ func (c *conn) SetWriteDeadline(t time.Time) error {
if !c.ok() {
return syscall.EINVAL
}
- if err := c.fd.pfd.SetWriteDeadline(t); err != nil {
+ if err := c.fd.SetWriteDeadline(t); err != nil {
return &OpError{Op: "set", Net: c.fd.net, Source: nil, Addr: c.fd.laddr, Err: err}
}
return nil
@@ -281,15 +282,13 @@ func (c *conn) SetWriteBuffer(bytes int) error {
return nil
}
-// File sets the underlying os.File to blocking mode and returns a copy.
+// File returns a copy of the underlying os.File
// It is the caller's responsibility to close f when finished.
// Closing c does not affect f, and closing f does not affect c.
//
// The returned os.File's file descriptor is different from the connection's.
// Attempting to change properties of the original using this duplicate
// may or may not have the desired effect.
-//
-// On Unix systems this will cause the SetDeadline methods to stop working.
func (c *conn) File() (f *os.File, err error) {
f, err = c.fd.dup()
if err != nil {
@@ -303,20 +302,23 @@ func (c *conn) File() (f *os.File, err error) {
// Multiple goroutines may invoke methods on a PacketConn simultaneously.
type PacketConn interface {
// ReadFrom reads a packet from the connection,
- // copying the payload into b. It returns the number of
- // bytes copied into b and the return address that
+ // copying the payload into p. It returns the number of
+ // bytes copied into p and the return address that
// was on the packet.
+ // It returns the number of bytes read (0 <= n <= len(p))
+ // and any error encountered. Callers should always process
+ // the n > 0 bytes returned before considering the error err.
// ReadFrom can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetReadDeadline.
- ReadFrom(b []byte) (n int, addr Addr, err error)
+ ReadFrom(p []byte) (n int, addr Addr, err error)
- // WriteTo writes a packet with payload b to addr.
+ // WriteTo writes a packet with payload p to addr.
// WriteTo can be made to time out and return
// an Error with Timeout() == true after a fixed time limit;
// see SetDeadline and SetWriteDeadline.
// On packet-oriented connections, write timeouts are rare.
- WriteTo(b []byte, addr Addr) (n int, err error)
+ WriteTo(p []byte, addr Addr) (n int, err error)
// Close closes the connection.
// Any blocked ReadFrom or WriteTo operations will be unblocked and return errors.
@@ -489,6 +491,12 @@ type temporary interface {
}
func (e *OpError) Temporary() bool {
+ // Treat ECONNRESET and ECONNABORTED as temporary errors when
+ // they come from calling accept. See issue 6163.
+ if e.Op == "accept" && isConnError(e.Err) {
+ return true
+ }
+
if ne, ok := e.Err.(*os.SyscallError); ok {
t, ok := ne.Err.(temporary)
return ok && t.Temporary()
@@ -603,9 +611,14 @@ func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
// server is not responding. Then the many lookups each use a different
// thread, and the system or the program runs out of threads.
-var threadLimit = make(chan struct{}, 500)
+var threadLimit chan struct{}
+
+var threadOnce sync.Once
func acquireThread() {
+ threadOnce.Do(func() {
+ threadLimit = make(chan struct{}, concurrentThreadsLimit())
+ })
threadLimit <- struct{}{}
}
diff --git a/libgo/go/net/net_fake.go b/libgo/go/net/net_fake.go
new file mode 100644
index 0000000..0c48dd5
--- /dev/null
+++ b/libgo/go/net/net_fake.go
@@ -0,0 +1,284 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Fake networking for js/wasm. It is intended to allow tests of other package to pass.
+
+// +build js,wasm
+
+package net
+
+import (
+ "context"
+ "internal/poll"
+ "io"
+ "os"
+ "sync"
+ "syscall"
+ "time"
+)
+
+var listenersMu sync.Mutex
+var listeners = make(map[string]*netFD)
+
+var portCounterMu sync.Mutex
+var portCounter = 0
+
+func nextPort() int {
+ portCounterMu.Lock()
+ defer portCounterMu.Unlock()
+ portCounter++
+ return portCounter
+}
+
+// Network file descriptor.
+type netFD struct {
+ r *bufferedPipe
+ w *bufferedPipe
+ incoming chan *netFD
+
+ closedMu sync.Mutex
+ closed bool
+
+ // immutable until Close
+ listener bool
+ family int
+ sotype int
+ net string
+ laddr Addr
+ raddr Addr
+
+ // unused
+ pfd poll.FD
+ isConnected bool // handshake completed or use of association with peer
+}
+
+// socket returns a network file descriptor that is ready for
+// asynchronous I/O using the network poller.
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
+ fd := &netFD{family: family, sotype: sotype, net: net}
+
+ if laddr != nil && raddr == nil { // listener
+ l := laddr.(*TCPAddr)
+ fd.laddr = &TCPAddr{
+ IP: l.IP,
+ Port: nextPort(),
+ Zone: l.Zone,
+ }
+ fd.listener = true
+ fd.incoming = make(chan *netFD, 1024)
+ listenersMu.Lock()
+ listeners[fd.laddr.(*TCPAddr).String()] = fd
+ listenersMu.Unlock()
+ return fd, nil
+ }
+
+ fd.laddr = &TCPAddr{
+ IP: IPv4(127, 0, 0, 1),
+ Port: nextPort(),
+ }
+ fd.raddr = raddr
+ fd.r = newBufferedPipe(65536)
+ fd.w = newBufferedPipe(65536)
+
+ fd2 := &netFD{family: fd.family, sotype: sotype, net: net}
+ fd2.laddr = fd.raddr
+ fd2.raddr = fd.laddr
+ fd2.r = fd.w
+ fd2.w = fd.r
+ listenersMu.Lock()
+ l, ok := listeners[fd.raddr.(*TCPAddr).String()]
+ if !ok {
+ listenersMu.Unlock()
+ return nil, syscall.ECONNREFUSED
+ }
+ l.incoming <- fd2
+ listenersMu.Unlock()
+
+ return fd, nil
+}
+
+func (fd *netFD) Read(p []byte) (n int, err error) {
+ return fd.r.Read(p)
+}
+
+func (fd *netFD) Write(p []byte) (nn int, err error) {
+ return fd.w.Write(p)
+}
+
+func (fd *netFD) Close() error {
+ fd.closedMu.Lock()
+ if fd.closed {
+ fd.closedMu.Unlock()
+ return nil
+ }
+ fd.closed = true
+ fd.closedMu.Unlock()
+
+ if fd.listener {
+ listenersMu.Lock()
+ delete(listeners, fd.laddr.String())
+ close(fd.incoming)
+ fd.listener = false
+ listenersMu.Unlock()
+ return nil
+ }
+
+ fd.r.Close()
+ fd.w.Close()
+ return nil
+}
+
+func (fd *netFD) closeRead() error {
+ fd.r.Close()
+ return nil
+}
+
+func (fd *netFD) closeWrite() error {
+ fd.w.Close()
+ return nil
+}
+
+func (fd *netFD) accept() (*netFD, error) {
+ c, ok := <-fd.incoming
+ if !ok {
+ return nil, syscall.EINVAL
+ }
+ return c, nil
+}
+
+func (fd *netFD) SetDeadline(t time.Time) error {
+ fd.r.SetReadDeadline(t)
+ fd.w.SetWriteDeadline(t)
+ return nil
+}
+
+func (fd *netFD) SetReadDeadline(t time.Time) error {
+ fd.r.SetReadDeadline(t)
+ return nil
+}
+
+func (fd *netFD) SetWriteDeadline(t time.Time) error {
+ fd.w.SetWriteDeadline(t)
+ return nil
+}
+
+func newBufferedPipe(softLimit int) *bufferedPipe {
+ p := &bufferedPipe{softLimit: softLimit}
+ p.rCond.L = &p.mu
+ p.wCond.L = &p.mu
+ return p
+}
+
+type bufferedPipe struct {
+ softLimit int
+ mu sync.Mutex
+ buf []byte
+ closed bool
+ rCond sync.Cond
+ wCond sync.Cond
+ rDeadline time.Time
+ wDeadline time.Time
+}
+
+func (p *bufferedPipe) Read(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ for {
+ if p.closed && len(p.buf) == 0 {
+ return 0, io.EOF
+ }
+ if !p.rDeadline.IsZero() {
+ d := time.Until(p.rDeadline)
+ if d <= 0 {
+ return 0, syscall.EAGAIN
+ }
+ time.AfterFunc(d, p.rCond.Broadcast)
+ }
+ if len(p.buf) > 0 {
+ break
+ }
+ p.rCond.Wait()
+ }
+
+ n := copy(b, p.buf)
+ p.buf = p.buf[n:]
+ p.wCond.Broadcast()
+ return n, nil
+}
+
+func (p *bufferedPipe) Write(b []byte) (int, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ for {
+ if p.closed {
+ return 0, syscall.ENOTCONN
+ }
+ if !p.wDeadline.IsZero() {
+ d := time.Until(p.wDeadline)
+ if d <= 0 {
+ return 0, syscall.EAGAIN
+ }
+ time.AfterFunc(d, p.wCond.Broadcast)
+ }
+ if len(p.buf) <= p.softLimit {
+ break
+ }
+ p.wCond.Wait()
+ }
+
+ p.buf = append(p.buf, b...)
+ p.rCond.Broadcast()
+ return len(b), nil
+}
+
+func (p *bufferedPipe) Close() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.closed = true
+ p.rCond.Broadcast()
+ p.wCond.Broadcast()
+}
+
+func (p *bufferedPipe) SetReadDeadline(t time.Time) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.rDeadline = t
+ p.rCond.Broadcast()
+}
+
+func (p *bufferedPipe) SetWriteDeadline(t time.Time) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.wDeadline = t
+ p.wCond.Broadcast()
+}
+
+func sysSocket(family, sotype, proto int) (int, error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
+ return 0, nil, syscall.ENOSYS
+}
+
+func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
+ return 0, 0, 0, nil, syscall.ENOSYS
+}
+
+func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
+ return 0, syscall.ENOSYS
+}
+
+func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
+ return 0, 0, syscall.ENOSYS
+}
+
+func (fd *netFD) dup() (f *os.File, err error) {
+ return nil, syscall.ENOSYS
+}
diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go
index 024505e..692f269 100644
--- a/libgo/go/net/net_test.go
+++ b/libgo/go/net/net_test.go
@@ -2,11 +2,14 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
"errors"
"fmt"
+ "internal/testenv"
"io"
"net/internal/socktest"
"os"
@@ -516,3 +519,33 @@ func TestCloseUnblocksRead(t *testing.T) {
}
withTCPConnPair(t, client, server)
}
+
+// Issue 24808: verify that ECONNRESET is not temporary for read.
+func TestNotTemporaryRead(t *testing.T) {
+ if runtime.GOOS == "freebsd" {
+ testenv.SkipFlaky(t, 25289)
+ }
+ t.Parallel()
+ server := func(cs *TCPConn) error {
+ cs.SetLinger(0)
+ // Give the client time to get stuck in a Read.
+ time.Sleep(20 * time.Millisecond)
+ cs.Close()
+ return nil
+ }
+ client := func(ss *TCPConn) error {
+ _, err := ss.Read([]byte{0})
+ if err == nil {
+ return errors.New("Read succeeded unexpectedly")
+ } else if err == io.EOF {
+ // This happens on NaCl and Plan 9.
+ return nil
+ } else if ne, ok := err.(Error); !ok {
+ return fmt.Errorf("unexpected error %v", err)
+ } else if ne.Temporary() {
+ return fmt.Errorf("unexpected temporary error %v", err)
+ }
+ return nil
+ }
+ withTCPConnPair(t, client, server)
+}
diff --git a/libgo/go/net/packetconn_test.go b/libgo/go/net/packetconn_test.go
index 7d50489..a377d33 100644
--- a/libgo/go/net/packetconn_test.go
+++ b/libgo/go/net/packetconn_test.go
@@ -5,6 +5,8 @@
// This file implements API tests across platforms and will never have a build
// tag.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/port_unix.go b/libgo/go/net/port_unix.go
index 8dd1c32..ea3bb02 100644
--- a/libgo/go/net/port_unix.go
+++ b/libgo/go/net/port_unix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris nacl
+// +build aix darwin dragonfly freebsd js,wasm linux netbsd openbsd solaris nacl
// Read system port mappings from /etc/services
diff --git a/libgo/go/net/protoconn_test.go b/libgo/go/net/protoconn_test.go
index def8d65..9f6772c 100644
--- a/libgo/go/net/protoconn_test.go
+++ b/libgo/go/net/protoconn_test.go
@@ -5,10 +5,11 @@
// This file implements API tests across platforms and will never have a build
// tag.
+// +build !js
+
package net
import (
- "internal/testenv"
"os"
"runtime"
"testing"
@@ -139,15 +140,11 @@ func TestUDPConnSpecificMethods(t *testing.T) {
if _, _, err := c.ReadFromUDP(rb); err != nil {
t.Fatal(err)
}
- if testenv.IsWindowsXP() {
- t.Log("skipping broken test on Windows XP (see golang.org/issue/23072)")
- } else {
- if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil {
- condFatalf(t, c.LocalAddr().Network(), "%v", err)
- }
- if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil {
- condFatalf(t, c.LocalAddr().Network(), "%v", err)
- }
+ if _, _, err := c.WriteMsgUDP(wb, nil, c.LocalAddr().(*UDPAddr)); err != nil {
+ condFatalf(t, c.LocalAddr().Network(), "%v", err)
+ }
+ if _, _, _, _, err := c.ReadMsgUDP(rb, nil); err != nil {
+ condFatalf(t, c.LocalAddr().Network(), "%v", err)
}
if f, err := c.File(); err != nil {
diff --git a/libgo/go/net/rawconn.go b/libgo/go/net/rawconn.go
index 2399c9f..c40ea4a 100644
--- a/libgo/go/net/rawconn.go
+++ b/libgo/go/net/rawconn.go
@@ -9,11 +9,14 @@ import (
"syscall"
)
-// BUG(mikio): On Windows, the Read and Write methods of
-// syscall.RawConn are not implemented.
+// BUG(tmm1): On Windows, the Write method of syscall.RawConn
+// does not integrate with the runtime's network poller. It cannot
+// wait for the connection to become writeable, and does not respect
+// deadlines. If the user-provided callback returns false, the Write
+// method will fail immediately.
-// BUG(mikio): On NaCl and Plan 9, the Control, Read and Write methods
-// of syscall.RawConn are not implemented.
+// BUG(mikio): On JS, NaCl and Plan 9, the Control, Read and Write
+// methods of syscall.RawConn are not implemented.
type rawConn struct {
fd *netFD
diff --git a/libgo/go/net/rawconn_stub_test.go b/libgo/go/net/rawconn_stub_test.go
new file mode 100644
index 0000000..0a033c1
--- /dev/null
+++ b/libgo/go/net/rawconn_stub_test.go
@@ -0,0 +1,28 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build js,wasm nacl plan9
+
+package net
+
+import (
+ "errors"
+ "syscall"
+)
+
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
+ return 0, errors.New("not supported")
+}
+
+func writeRawConn(c syscall.RawConn, b []byte) error {
+ return errors.New("not supported")
+}
+
+func controlRawConn(c syscall.RawConn, addr Addr) error {
+ return errors.New("not supported")
+}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ return nil
+}
diff --git a/libgo/go/net/rawconn_test.go b/libgo/go/net/rawconn_test.go
new file mode 100644
index 0000000..11900df
--- /dev/null
+++ b/libgo/go/net/rawconn_test.go
@@ -0,0 +1,220 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !js
+
+package net
+
+import (
+ "bytes"
+ "runtime"
+ "testing"
+ "time"
+)
+
+func TestRawConnReadWrite(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("TCP", func(t *testing.T) {
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ defer c.Close()
+
+ cc, err := ln.(*TCPListener).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ called := false
+ op := func(uintptr) bool {
+ called = true
+ return true
+ }
+ err = cc.Write(op)
+ if err == nil {
+ t.Error("Write should return an error")
+ }
+ if called {
+ t.Error("Write shouldn't call op")
+ }
+ called = false
+ err = cc.Read(op)
+ if err == nil {
+ t.Error("Read should return an error")
+ }
+ if called {
+ t.Error("Read shouldn't call op")
+ }
+
+ var b [32]byte
+ n, err := c.Read(b[:])
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if _, err := c.Write(b[:n]); err != nil {
+ t.Error(err)
+ return
+ }
+ }
+ ls, err := newLocalServer("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ cc, err := c.(*TCPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ data := []byte("HELLO-R-U-THERE")
+ if err := writeRawConn(cc, data); err != nil {
+ t.Fatal(err)
+ }
+ var b [32]byte
+ n, err := readRawConn(cc, b[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ if bytes.Compare(b[:n], data) != 0 {
+ t.Fatalf("got %q; want %q", b[:n], data)
+ }
+ })
+ t.Run("Deadline", func(t *testing.T) {
+ switch runtime.GOOS {
+ case "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ cc, err := c.(*TCPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ var b [1]byte
+
+ c.SetDeadline(noDeadline)
+ if err := c.SetDeadline(time.Now().Add(-1)); err != nil {
+ t.Fatal(err)
+ }
+ if err = writeRawConn(cc, b[:]); err == nil {
+ t.Fatal("Write should fail")
+ }
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
+ t.Errorf("got %v; want timeout", err)
+ }
+ if _, err = readRawConn(cc, b[:]); err == nil {
+ t.Fatal("Read should fail")
+ }
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
+ t.Errorf("got %v; want timeout", err)
+ }
+
+ c.SetReadDeadline(noDeadline)
+ if err := c.SetReadDeadline(time.Now().Add(-1)); err != nil {
+ t.Fatal(err)
+ }
+ if _, err = readRawConn(cc, b[:]); err == nil {
+ t.Fatal("Read should fail")
+ }
+ if perr := parseReadError(err); perr != nil {
+ t.Error(perr)
+ }
+ if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
+ t.Errorf("got %v; want timeout", err)
+ }
+
+ c.SetWriteDeadline(noDeadline)
+ if err := c.SetWriteDeadline(time.Now().Add(-1)); err != nil {
+ t.Fatal(err)
+ }
+ if err = writeRawConn(cc, b[:]); err == nil {
+ t.Fatal("Write should fail")
+ }
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ if nerr, ok := err.(Error); !ok || !nerr.Timeout() {
+ t.Errorf("got %v; want timeout", err)
+ }
+ })
+}
+
+func TestRawConnControl(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ t.Run("TCP", func(t *testing.T) {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ cc1, err := ln.(*TCPListener).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := controlRawConn(cc1, ln.Addr()); err != nil {
+ t.Fatal(err)
+ }
+
+ c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ cc2, err := c.(*TCPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := controlRawConn(cc2, c.LocalAddr()); err != nil {
+ t.Fatal(err)
+ }
+
+ ln.Close()
+ if err := controlRawConn(cc1, ln.Addr()); err == nil {
+ t.Fatal("Control after Close should fail")
+ }
+ c.Close()
+ if err := controlRawConn(cc2, c.LocalAddr()); err == nil {
+ t.Fatal("Control after Close should fail")
+ }
+ })
+}
diff --git a/libgo/go/net/rawconn_unix_test.go b/libgo/go/net/rawconn_unix_test.go
index 913ad86..a720a8a 100644
--- a/libgo/go/net/rawconn_unix_test.go
+++ b/libgo/go/net/rawconn_unix_test.go
@@ -7,138 +7,121 @@
package net
import (
- "bytes"
+ "errors"
"syscall"
- "testing"
)
-func TestRawConn(t *testing.T) {
- handler := func(ls *localServer, ln Listener) {
- c, err := ln.Accept()
- if err != nil {
- t.Error(err)
- return
- }
- defer c.Close()
- var b [32]byte
- n, err := c.Read(b[:])
- if err != nil {
- t.Error(err)
- return
- }
- if _, err := c.Write(b[:n]); err != nil {
- t.Error(err)
- return
- }
- }
- ls, err := newLocalServer("tcp")
- if err != nil {
- t.Fatal(err)
- }
- defer ls.teardown()
- if err := ls.buildup(handler); err != nil {
- t.Fatal(err)
- }
-
- c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- cc, err := c.(*TCPConn).SyscallConn()
- if err != nil {
- t.Fatal(err)
- }
-
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
var operr error
- data := []byte("HELLO-R-U-THERE")
- err = cc.Write(func(s uintptr) bool {
- _, operr = syscall.Write(int(s), data)
+ var n int
+ err := c.Read(func(s uintptr) bool {
+ n, operr = syscall.Read(int(s), b)
if operr == syscall.EAGAIN {
return false
}
return true
})
- if err != nil || operr != nil {
- t.Fatal(err, operr)
+ if err != nil {
+ return n, err
+ }
+ if operr != nil {
+ return n, operr
}
+ return n, nil
+}
- var nr int
- var b [32]byte
- err = cc.Read(func(s uintptr) bool {
- nr, operr = syscall.Read(int(s), b[:])
+func writeRawConn(c syscall.RawConn, b []byte) error {
+ var operr error
+ err := c.Write(func(s uintptr) bool {
+ _, operr = syscall.Write(int(s), b)
if operr == syscall.EAGAIN {
return false
}
return true
})
- if err != nil || operr != nil {
- t.Fatal(err, operr)
+ if err != nil {
+ return err
}
- if bytes.Compare(b[:nr], data) != 0 {
- t.Fatalf("got %#v; want %#v", b[:nr], data)
+ if operr != nil {
+ return operr
}
+ return nil
+}
+func controlRawConn(c syscall.RawConn, addr Addr) error {
+ var operr error
fn := func(s uintptr) {
- operr = syscall.SetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
+ _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
+ if operr != nil {
+ return
+ }
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ // There's no guarantee that IP-level socket
+ // options work well with dual stack sockets.
+ // A simple solution would be to take a look
+ // at the bound address to the raw connection
+ // and to classify the address family of the
+ // underlying socket by the bound address:
+ //
+ // - When IP.To16() != nil and IP.To4() == nil,
+ // we can assume that the raw connection
+ // consists of an IPv6 socket using only
+ // IPv6 addresses.
+ //
+ // - When IP.To16() == nil and IP.To4() != nil,
+ // the raw connection consists of an IPv4
+ // socket using only IPv4 addresses.
+ //
+ // - Otherwise, the raw connection is a dual
+ // stack socket, an IPv6 socket using IPv6
+ // addresses including IPv4-mapped or
+ // IPv4-embedded IPv6 addresses.
+ if addr.IP.To16() != nil && addr.IP.To4() == nil {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ } else if addr.IP.To16() == nil && addr.IP.To4() != nil {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ }
}
- err = cc.Control(fn)
- if err != nil || operr != nil {
- t.Fatal(err, operr)
+ if err := c.Control(fn); err != nil {
+ return err
}
- c.Close()
- err = cc.Control(fn)
- if err == nil {
- t.Fatal("should fail")
+ if operr != nil {
+ return operr
}
+ return nil
}
-func TestRawConnListener(t *testing.T) {
- ln, err := newLocalListener("tcp")
- if err != nil {
- t.Fatal(err)
- }
- defer ln.Close()
-
- cc, err := ln.(*TCPListener).SyscallConn()
- if err != nil {
- t.Fatal(err)
- }
-
- called := false
- op := func(uintptr) bool {
- called = true
- return true
- }
-
- err = cc.Write(op)
- if err == nil {
- t.Error("Write should return an error")
- }
- if called {
- t.Error("Write shouldn't call op")
- }
-
- called = false
- err = cc.Read(op)
- if err == nil {
- t.Error("Read should return an error")
- }
- if called {
- t.Error("Read shouldn't call op")
- }
-
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
var operr error
- fn := func(s uintptr) {
- _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
+ var fn func(uintptr)
+ switch network {
+ case "tcp", "udp", "ip":
+ return errors.New("ambiguous network: " + network)
+ case "unix", "unixpacket", "unixgram":
+ fn = func(s uintptr) {
+ _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_ERROR)
+ }
+ default:
+ switch network[len(network)-1] {
+ case '4':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ case '6':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ }
+ default:
+ return errors.New("unknown network: " + network)
+ }
}
- err = cc.Control(fn)
- if err != nil || operr != nil {
- t.Fatal(err, operr)
+ if err := c.Control(fn); err != nil {
+ return err
}
- ln.Close()
- err = cc.Control(fn)
- if err == nil {
- t.Fatal("Control after Close should fail")
+ if operr != nil {
+ return operr
}
+ return nil
}
diff --git a/libgo/go/net/rawconn_windows_test.go b/libgo/go/net/rawconn_windows_test.go
index 2ee12c3..2774c97 100644
--- a/libgo/go/net/rawconn_windows_test.go
+++ b/libgo/go/net/rawconn_windows_test.go
@@ -5,85 +5,124 @@
package net
import (
+ "errors"
"syscall"
- "testing"
"unsafe"
)
-func TestRawConn(t *testing.T) {
- c, err := newLocalPacketListener("udp")
- if err != nil {
- t.Fatal(err)
- }
- defer c.Close()
- cc, err := c.(*UDPConn).SyscallConn()
- if err != nil {
- t.Fatal(err)
- }
-
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
var operr error
- fn := func(s uintptr) {
- operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
- }
- err = cc.Control(fn)
- if err != nil || operr != nil {
- t.Fatal(err, operr)
- }
- c.Close()
- err = cc.Control(fn)
- if err == nil {
- t.Fatal("should fail")
- }
-}
-
-func TestRawConnListener(t *testing.T) {
- ln, err := newLocalListener("tcp")
+ var n int
+ err := c.Read(func(s uintptr) bool {
+ var read uint32
+ var flags uint32
+ var buf syscall.WSABuf
+ buf.Buf = &b[0]
+ buf.Len = uint32(len(b))
+ operr = syscall.WSARecv(syscall.Handle(s), &buf, 1, &read, &flags, nil, nil)
+ n = int(read)
+ return true
+ })
if err != nil {
- t.Fatal(err)
+ return n, err
}
- defer ln.Close()
-
- cc, err := ln.(*TCPListener).SyscallConn()
- if err != nil {
- t.Fatal(err)
+ if operr != nil {
+ return n, operr
}
+ return n, nil
+}
- called := false
- op := func(uintptr) bool {
- called = true
+func writeRawConn(c syscall.RawConn, b []byte) error {
+ var operr error
+ err := c.Write(func(s uintptr) bool {
+ var written uint32
+ var buf syscall.WSABuf
+ buf.Buf = &b[0]
+ buf.Len = uint32(len(b))
+ operr = syscall.WSASend(syscall.Handle(s), &buf, 1, &written, 0, nil, nil)
return true
+ })
+ if err != nil {
+ return err
}
-
- err = cc.Write(op)
- if err == nil {
- t.Error("Write should return an error")
- }
- if called {
- t.Error("Write shouldn't call op")
- }
-
- called = false
- err = cc.Read(op)
- if err == nil {
- t.Error("Read should return an error")
- }
- if called {
- t.Error("Read shouldn't call op")
+ if operr != nil {
+ return operr
}
+ return nil
+}
+func controlRawConn(c syscall.RawConn, addr Addr) error {
var operr error
fn := func(s uintptr) {
var v, l int32
l = int32(unsafe.Sizeof(v))
operr = syscall.Getsockopt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, (*byte)(unsafe.Pointer(&v)), &l)
+ if operr != nil {
+ return
+ }
+ switch addr := addr.(type) {
+ case *TCPAddr:
+ // There's no guarantee that IP-level socket
+ // options work well with dual stack sockets.
+ // A simple solution would be to take a look
+ // at the bound address to the raw connection
+ // and to classify the address family of the
+ // underlying socket by the bound address:
+ //
+ // - When IP.To16() != nil and IP.To4() == nil,
+ // we can assume that the raw connection
+ // consists of an IPv6 socket using only
+ // IPv6 addresses.
+ //
+ // - When IP.To16() == nil and IP.To4() != nil,
+ // the raw connection consists of an IPv4
+ // socket using only IPv4 addresses.
+ //
+ // - Otherwise, the raw connection is a dual
+ // stack socket, an IPv6 socket using IPv6
+ // addresses including IPv4-mapped or
+ // IPv4-embedded IPv6 addresses.
+ if addr.IP.To16() != nil && addr.IP.To4() == nil {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ } else if addr.IP.To16() == nil && addr.IP.To4() != nil {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ }
+ }
+ if err := c.Control(fn); err != nil {
+ return err
+ }
+ if operr != nil {
+ return operr
+ }
+ return nil
+}
+
+func controlOnConnSetup(network string, address string, c syscall.RawConn) error {
+ var operr error
+ var fn func(uintptr)
+ switch network {
+ case "tcp", "udp", "ip":
+ return errors.New("ambiguous network: " + network)
+ default:
+ switch network[len(network)-1] {
+ case '4':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+ }
+ case '6':
+ fn = func(s uintptr) {
+ operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+ }
+ default:
+ return errors.New("unknown network: " + network)
+ }
}
- err = cc.Control(fn)
- if err != nil || operr != nil {
- t.Fatal(err, operr)
+ if err := c.Control(fn); err != nil {
+ return err
}
- ln.Close()
- err = cc.Control(fn)
- if err == nil {
- t.Fatal("Control after Close should fail")
+ if operr != nil {
+ return operr
}
+ return nil
}
diff --git a/libgo/go/net/rpc/client.go b/libgo/go/net/rpc/client.go
index fce6a48..cad2d45 100644
--- a/libgo/go/net/rpc/client.go
+++ b/libgo/go/net/rpc/client.go
@@ -59,8 +59,8 @@ type Client struct {
// connection. ReadResponseBody may be called with a nil
// argument to force the body of the response to be read and then
// discarded.
+// See NewClient's comment for information about concurrent access.
type ClientCodec interface {
- // WriteRequest must be safe for concurrent use by multiple goroutines.
WriteRequest(*Request, interface{}) error
ReadResponseHeader(*Response) error
ReadResponseBody(interface{}) error
@@ -75,8 +75,8 @@ func (client *Client) send(call *Call) {
// Register this call.
client.mutex.Lock()
if client.shutdown || client.closing {
- call.Error = ErrShutdown
client.mutex.Unlock()
+ call.Error = ErrShutdown
call.done()
return
}
@@ -185,6 +185,11 @@ func (call *Call) done() {
// set of services at the other end of the connection.
// It adds a buffer to the write side of the connection so
// the header and payload are sent as a unit.
+//
+// The read and write halves of the connection are serialized independently,
+// so no interlocking is required. However each half may be accessed
+// concurrently so the implementation of conn should protect against
+// concurrent reads or concurrent writes.
func NewClient(conn io.ReadWriteCloser) *Client {
encBuf := bufio.NewWriter(conn)
client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go
index 96e6973..7bb6476 100644
--- a/libgo/go/net/rpc/server.go
+++ b/libgo/go/net/rpc/server.go
@@ -444,6 +444,7 @@ func (c *gobServerCodec) Close() error {
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
+// See NewClient's comment for information about concurrent access.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
buf := bufio.NewWriter(conn)
srv := &gobServerCodec{
@@ -653,12 +654,13 @@ func RegisterName(name string, rcvr interface{}) error {
// write a response back. The server calls Close when finished with the
// connection. ReadRequestBody may be called with a nil
// argument to force the body of the request to be read and discarded.
+// See NewClient's comment for information about concurrent access.
type ServerCodec interface {
ReadRequestHeader(*Request) error
ReadRequestBody(interface{}) error
- // WriteResponse must be safe for concurrent use by multiple goroutines.
WriteResponse(*Response, interface{}) error
+ // Close can be called multiple times and must be idempotent.
Close() error
}
@@ -667,6 +669,7 @@ type ServerCodec interface {
// The caller typically invokes ServeConn in a go statement.
// ServeConn uses the gob wire format (see package gob) on the
// connection. To use an alternate codec, use ServeCodec.
+// See NewClient's comment for information about concurrent access.
func ServeConn(conn io.ReadWriteCloser) {
DefaultServer.ServeConn(conn)
}
diff --git a/libgo/go/net/sendfile_solaris.go b/libgo/go/net/sendfile_solaris.go
deleted file mode 100644
index 63ca9d4..0000000
--- a/libgo/go/net/sendfile_solaris.go
+++ /dev/null
@@ -1,63 +0,0 @@
-// Copyright 2015 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package net
-
-import (
- "internal/poll"
- "io"
- "os"
-)
-
-// sendFile copies the contents of r to c using the sendfile
-// system call to minimize copies.
-//
-// if handled == true, sendFile returns the number of bytes copied and any
-// non-EOF error.
-//
-// if handled == false, sendFile performed no work.
-func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
- // Solaris uses 0 as the "until EOF" value. If you pass in more bytes than the
- // file contains, it will loop back to the beginning ad nauseam until it's sent
- // exactly the number of bytes told to. As such, we need to know exactly how many
- // bytes to send.
- var remain int64 = 0
-
- lr, ok := r.(*io.LimitedReader)
- if ok {
- remain, r = lr.N, lr.R
- if remain <= 0 {
- return 0, nil, true
- }
- }
- f, ok := r.(*os.File)
- if !ok {
- return 0, nil, false
- }
-
- if remain == 0 {
- fi, err := f.Stat()
- if err != nil {
- return 0, err, false
- }
-
- remain = fi.Size()
- }
-
- // The other quirk with Solaris's sendfile implementation is that it doesn't
- // use the current position of the file -- if you pass it offset 0, it starts
- // from offset 0. There's no way to tell it "start from current position", so
- // we have to manage that explicitly.
- pos, err := f.Seek(0, io.SeekCurrent)
- if err != nil {
- return 0, err, false
- }
-
- written, err = poll.SendFile(&c.pfd, int(f.Fd()), pos, remain)
-
- if lr != nil {
- lr.N = remain - written
- }
- return written, wrapSyscallError("sendfile", err), written > 0
-}
diff --git a/libgo/go/net/sendfile_stub.go b/libgo/go/net/sendfile_stub.go
index f043062..6d338da 100644
--- a/libgo/go/net/sendfile_stub.go
+++ b/libgo/go/net/sendfile_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin nacl netbsd openbsd
+// +build aix darwin js,wasm nacl netbsd openbsd
package net
diff --git a/libgo/go/net/sendfile_test.go b/libgo/go/net/sendfile_test.go
index 2255e7c..3b98277 100644
--- a/libgo/go/net/sendfile_test.go
+++ b/libgo/go/net/sendfile_test.go
@@ -2,9 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
+ "bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
@@ -88,3 +91,122 @@ func TestSendfile(t *testing.T) {
t.Error(err)
}
}
+
+func TestSendfileParts(t *testing.T) {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ errc := make(chan error, 1)
+ go func(ln Listener) {
+ // Wait for a connection.
+ conn, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ close(errc)
+ return
+ }
+
+ go func() {
+ defer close(errc)
+ defer conn.Close()
+
+ f, err := os.Open(twain)
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer f.Close()
+
+ for i := 0; i < 3; i++ {
+ // Return file data using io.CopyN, which should use
+ // sendFile if available.
+ _, err = io.CopyN(conn, f, 3)
+ if err != nil {
+ errc <- err
+ return
+ }
+ }
+ }()
+ }(ln)
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(c)
+
+ if want, have := "Produced ", buf.String(); have != want {
+ t.Errorf("unexpected server reply %q, want %q", have, want)
+ }
+
+ for err := range errc {
+ t.Error(err)
+ }
+}
+
+func TestSendfileSeeked(t *testing.T) {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ const seekTo = 65 << 10
+ const sendSize = 10 << 10
+
+ errc := make(chan error, 1)
+ go func(ln Listener) {
+ // Wait for a connection.
+ conn, err := ln.Accept()
+ if err != nil {
+ errc <- err
+ close(errc)
+ return
+ }
+
+ go func() {
+ defer close(errc)
+ defer conn.Close()
+
+ f, err := os.Open(twain)
+ if err != nil {
+ errc <- err
+ return
+ }
+ defer f.Close()
+ if _, err := f.Seek(seekTo, os.SEEK_SET); err != nil {
+ errc <- err
+ return
+ }
+
+ _, err = io.CopyN(conn, f, sendSize)
+ if err != nil {
+ errc <- err
+ return
+ }
+ }()
+ }(ln)
+
+ c, err := Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(c)
+
+ if buf.Len() != sendSize {
+ t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize)
+ }
+
+ for err := range errc {
+ t.Error(err)
+ }
+}
diff --git a/libgo/go/net/sendfile_bsd.go b/libgo/go/net/sendfile_unix_alt.go
index 7a2b48c..9b3ba4e 100644
--- a/libgo/go/net/sendfile_bsd.go
+++ b/libgo/go/net/sendfile_unix_alt.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build dragonfly freebsd
+// +build dragonfly freebsd solaris
package net
@@ -20,7 +20,7 @@ import (
//
// if handled == false, sendFile performed no work.
func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
- // FreeBSD and DragonFly use 0 as the "until EOF" value.
+ // FreeBSD, DragonFly and Solaris use 0 as the "until EOF" value.
// If you pass in more bytes than the file contains, it will
// loop back to the beginning ad nauseam until it's sent
// exactly the number of bytes told to. As such, we need to
@@ -48,7 +48,7 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
remain = fi.Size()
}
- // The other quirk with FreeBSD/DragonFly's sendfile
+ // The other quirk with FreeBSD/DragonFly/Solaris's sendfile
// implementation is that it doesn't use the current position
// of the file -- if you pass it offset 0, it starts from
// offset 0. There's no way to tell it "start from current
@@ -63,5 +63,11 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
if lr != nil {
lr.N = remain - written
}
+
+ _, err1 := f.Seek(written, io.SeekCurrent)
+ if err1 != nil && err == nil {
+ return written, err1, written > 0
+ }
+
return written, wrapSyscallError("sendfile", err), written > 0
}
diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go
index 2e998e2..1608beb 100644
--- a/libgo/go/net/server_test.go
+++ b/libgo/go/net/server_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/smtp/smtp.go b/libgo/go/net/smtp/smtp.go
index cf699e6..e4e12ae 100644
--- a/libgo/go/net/smtp/smtp.go
+++ b/libgo/go/net/smtp/smtp.go
@@ -343,10 +343,11 @@ func SendMail(addr string, a Auth, from string, to []string, msg []byte) error {
}
}
if a != nil && c.ext != nil {
- if _, ok := c.ext["AUTH"]; ok {
- if err = c.Auth(a); err != nil {
- return err
- }
+ if _, ok := c.ext["AUTH"]; !ok {
+ return errors.New("smtp: server doesn't support AUTH")
+ }
+ if err = c.Auth(a); err != nil {
+ return err
}
}
if err = c.Mail(from); err != nil {
diff --git a/libgo/go/net/smtp/smtp_test.go b/libgo/go/net/smtp/smtp_test.go
index d489922..000cac4 100644
--- a/libgo/go/net/smtp/smtp_test.go
+++ b/libgo/go/net/smtp/smtp_test.go
@@ -15,6 +15,7 @@ import (
"net/textproto"
"runtime"
"strings"
+ "sync"
"testing"
"time"
)
@@ -635,6 +636,50 @@ SendMail is working for me.
QUIT
`
+func TestSendMailWithAuth(t *testing.T) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("Unable to to create listener: %v", err)
+ }
+ defer l.Close()
+ wg := sync.WaitGroup{}
+ var done = make(chan struct{})
+ go func() {
+ defer wg.Done()
+ conn, err := l.Accept()
+ if err != nil {
+ t.Errorf("Accept error: %v", err)
+ return
+ }
+ defer conn.Close()
+
+ tc := textproto.NewConn(conn)
+ tc.PrintfLine("220 hello world")
+ msg, err := tc.ReadLine()
+ if msg == "EHLO localhost" {
+ tc.PrintfLine("250 mx.google.com at your service")
+ }
+ // for this test case, there should have no more traffic
+ <-done
+ }()
+ wg.Add(1)
+
+ err = SendMail(l.Addr().String(), PlainAuth("", "user", "pass", "smtp.google.com"), "test@example.com", []string{"other@example.com"}, []byte(strings.Replace(`From: test@example.com
+To: other@example.com
+Subject: SendMail test
+
+SendMail is working for me.
+`, "\n", "\r\n", -1)))
+ if err == nil {
+ t.Error("SendMail: Server doesn't support AUTH, expected to get an error, but got none ")
+ }
+ if err.Error() != "smtp: server doesn't support AUTH" {
+ t.Errorf("Expected: smtp: server doesn't support AUTH, got: %s", err)
+ }
+ close(done)
+ wg.Wait()
+}
+
func TestAuthFailed(t *testing.T) {
server := strings.Join(strings.Split(authFailedServer, "\n"), "\r\n")
client := strings.Join(strings.Split(authFailedClient, "\n"), "\r\n")
@@ -680,7 +725,7 @@ QUIT
`
func TestTLSClient(t *testing.T) {
- if runtime.GOOS == "freebsd" && runtime.GOARCH == "amd64" {
+ if (runtime.GOOS == "freebsd" && runtime.GOARCH == "amd64") || runtime.GOOS == "js" {
testenv.SkipFlaky(t, 19229)
}
ln := newLocalListener(t)
@@ -830,14 +875,9 @@ func init() {
}
func sendMail(hostPort string) error {
- host, _, err := net.SplitHostPort(hostPort)
- if err != nil {
- return err
- }
- auth := PlainAuth("", "", "", host)
from := "joe1@example.com"
to := []string{"joe2@example.com"}
- return SendMail(hostPort, auth, from, to, []byte("Subject: test\n\nhowdy!"))
+ return SendMail(hostPort, nil, from, to, []byte("Subject: test\n\nhowdy!"))
}
// (copied from net/http/httptest)
diff --git a/libgo/go/net/sock_cloexec.go b/libgo/go/net/sock_cloexec.go
index 06ff10d..0c883dc 100644
--- a/libgo/go/net/sock_cloexec.go
+++ b/libgo/go/net/sock_cloexec.go
@@ -5,7 +5,7 @@
// This file implements sysSocket and accept for platforms that
// provide a fast path for setting SetNonblock and CloseOnExec.
-// +build dragonfly freebsd linux
+// +build dragonfly freebsd linux netbsd openbsd
package net
diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go
index 4733c42..fac3ac1 100644
--- a/libgo/go/net/sock_posix.go
+++ b/libgo/go/net/sock_posix.go
@@ -13,32 +13,9 @@ import (
"syscall"
)
-// A sockaddr represents a TCP, UDP, IP or Unix network endpoint
-// address that can be converted into a syscall.Sockaddr.
-type sockaddr interface {
- Addr
-
- // family returns the platform-dependent address family
- // identifier.
- family() int
-
- // isWildcard reports whether the address is a wildcard
- // address.
- isWildcard() bool
-
- // sockaddr returns the address converted into a syscall
- // sockaddr type that implements syscall.Sockaddr
- // interface. It returns a nil interface when the address is
- // nil.
- sockaddr(family int) (syscall.Sockaddr, error)
-
- // toLocal maps the zero address to a local system address (127.0.0.1 or ::1)
- toLocal(net string) sockaddr
-}
-
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
-func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr) (fd *netFD, err error) {
+func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
@@ -77,26 +54,41 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
if laddr != nil && raddr == nil {
switch sotype {
case syscall.SOCK_STREAM, syscall.SOCK_SEQPACKET:
- if err := fd.listenStream(laddr, listenerBacklog); err != nil {
+ if err := fd.listenStream(laddr, listenerBacklog, ctrlFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
case syscall.SOCK_DGRAM:
- if err := fd.listenDatagram(laddr); err != nil {
+ if err := fd.listenDatagram(laddr, ctrlFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
}
- if err := fd.dial(ctx, laddr, raddr); err != nil {
+ if err := fd.dial(ctx, laddr, raddr, ctrlFn); err != nil {
fd.Close()
return nil, err
}
return fd, nil
}
+func (fd *netFD) ctrlNetwork() string {
+ switch fd.net {
+ case "unix", "unixgram", "unixpacket":
+ return fd.net
+ }
+ switch fd.net[len(fd.net)-1] {
+ case '4', '6':
+ return fd.net
+ }
+ if fd.family == syscall.AF_INET {
+ return fd.net + "4"
+ }
+ return fd.net + "6"
+}
+
func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
switch fd.family {
case syscall.AF_INET, syscall.AF_INET6:
@@ -121,14 +113,29 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
return func(syscall.Sockaddr) Addr { return nil }
}
-func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
+func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
+ if ctrlFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ var ctrlAddr string
+ if raddr != nil {
+ ctrlAddr = raddr.String()
+ } else if laddr != nil {
+ ctrlAddr = laddr.String()
+ }
+ if err := ctrlFn(fd.ctrlNetwork(), ctrlAddr, c); err != nil {
+ return err
+ }
+ }
var err error
var lsa syscall.Sockaddr
if laddr != nil {
if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
} else if lsa != nil {
- if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
return os.NewSyscallError("bind", err)
}
}
@@ -165,24 +172,34 @@ func (fd *netFD) dial(ctx context.Context, laddr, raddr sockaddr) error {
return nil
}
-func (fd *netFD) listenStream(laddr sockaddr, backlog int) error {
- if err := setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
+func (fd *netFD) listenStream(laddr sockaddr, backlog int, ctrlFn func(string, string, syscall.RawConn) error) error {
+ var err error
+ if err = setDefaultListenerSockopts(fd.pfd.Sysfd); err != nil {
return err
}
- if lsa, err := laddr.sockaddr(fd.family); err != nil {
+ var lsa syscall.Sockaddr
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
- } else if lsa != nil {
- if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
- return os.NewSyscallError("bind", err)
+ }
+ if ctrlFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
+ }
+ if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ return err
}
}
- if err := listenFunc(fd.pfd.Sysfd, backlog); err != nil {
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
+ }
+ if err = listenFunc(fd.pfd.Sysfd, backlog); err != nil {
return os.NewSyscallError("listen", err)
}
- if err := fd.init(); err != nil {
+ if err = fd.init(); err != nil {
return err
}
- lsa, err := syscall.Getsockname(fd.pfd.Sysfd)
+ lsa, err = syscall.Getsockname(fd.pfd.Sysfd)
if err != nil {
return os.NewSyscallError("getsockname", err)
}
@@ -190,7 +207,7 @@ func (fd *netFD) listenStream(laddr sockaddr, backlog int) error {
return nil
}
-func (fd *netFD) listenDatagram(laddr sockaddr) error {
+func (fd *netFD) listenDatagram(laddr sockaddr, ctrlFn func(string, string, syscall.RawConn) error) error {
switch addr := laddr.(type) {
case *UDPAddr:
// We provide a socket that listens to a wildcard
@@ -214,17 +231,27 @@ func (fd *netFD) listenDatagram(laddr sockaddr) error {
laddr = &addr
}
}
- if lsa, err := laddr.sockaddr(fd.family); err != nil {
+ var err error
+ var lsa syscall.Sockaddr
+ if lsa, err = laddr.sockaddr(fd.family); err != nil {
return err
- } else if lsa != nil {
- if err := syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
- return os.NewSyscallError("bind", err)
+ }
+ if ctrlFn != nil {
+ c, err := newRawConn(fd)
+ if err != nil {
+ return err
}
+ if err := ctrlFn(fd.ctrlNetwork(), laddr.String(), c); err != nil {
+ return err
+ }
+ }
+ if err = syscall.Bind(fd.pfd.Sysfd, lsa); err != nil {
+ return os.NewSyscallError("bind", err)
}
- if err := fd.init(); err != nil {
+ if err = fd.init(); err != nil {
return err
}
- lsa, err := syscall.Getsockname(fd.pfd.Sysfd)
+ lsa, err = syscall.Getsockname(fd.pfd.Sysfd)
if err != nil {
return os.NewSyscallError("getsockname", err)
}
diff --git a/libgo/go/net/sock_stub.go b/libgo/go/net/sock_stub.go
index d1ec029..bbce61b 100644
--- a/libgo/go/net/sock_stub.go
+++ b/libgo/go/net/sock_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix nacl solaris
+// +build aix nacl js,wasm solaris
package net
diff --git a/libgo/go/net/sockaddr_posix.go b/libgo/go/net/sockaddr_posix.go
new file mode 100644
index 0000000..4b8699d
--- /dev/null
+++ b/libgo/go/net/sockaddr_posix.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
+
+package net
+
+import (
+ "syscall"
+)
+
+// A sockaddr represents a TCP, UDP, IP or Unix network endpoint
+// address that can be converted into a syscall.Sockaddr.
+type sockaddr interface {
+ Addr
+
+ // family returns the platform-dependent address family
+ // identifier.
+ family() int
+
+ // isWildcard reports whether the address is a wildcard
+ // address.
+ isWildcard() bool
+
+ // sockaddr returns the address converted into a syscall
+ // sockaddr type that implements syscall.Sockaddr
+ // interface. It returns a nil interface when the address is
+ // nil.
+ sockaddr(family int) (syscall.Sockaddr, error)
+
+ // toLocal maps the zero address to a local system address (127.0.0.1 or ::1)
+ toLocal(net string) sockaddr
+}
diff --git a/libgo/go/net/sockopt_stub.go b/libgo/go/net/sockopt_stub.go
index 7e9e560..bc06675 100644
--- a/libgo/go/net/sockopt_stub.go
+++ b/libgo/go/net/sockopt_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl
+// +build nacl js,wasm
package net
diff --git a/libgo/go/net/sockoptip_stub.go b/libgo/go/net/sockoptip_stub.go
index fc20a9f..3297969 100644
--- a/libgo/go/net/sockoptip_stub.go
+++ b/libgo/go/net/sockoptip_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl
+// +build nacl js,wasm
package net
diff --git a/libgo/go/net/splice_linux.go b/libgo/go/net/splice_linux.go
new file mode 100644
index 0000000..b055f93
--- /dev/null
+++ b/libgo/go/net/splice_linux.go
@@ -0,0 +1,35 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "internal/poll"
+ "io"
+)
+
+// splice transfers data from r to c using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection. Currently, splice
+// is only enabled if r is also a TCP connection.
+//
+// If splice returns handled == false, it has performed no work.
+func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+ var remain int64 = 1 << 62 // by default, copy until EOF
+ lr, ok := r.(*io.LimitedReader)
+ if ok {
+ remain, r = lr.N, lr.R
+ if remain <= 0 {
+ return 0, nil, true
+ }
+ }
+ s, ok := r.(*TCPConn)
+ if !ok {
+ return 0, nil, false
+ }
+ written, handled, sc, err := poll.Splice(&c.pfd, &s.fd.pfd, remain)
+ if lr != nil {
+ lr.N -= written
+ }
+ return written, wrapSyscallError(sc, err), handled
+}
diff --git a/libgo/go/net/splice_stub.go b/libgo/go/net/splice_stub.go
new file mode 100644
index 0000000..9106cb2
--- /dev/null
+++ b/libgo/go/net/splice_stub.go
@@ -0,0 +1,13 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !linux
+
+package net
+
+import "io"
+
+func splice(c *netFD, r io.Reader) (int64, error, bool) {
+ return 0, nil, false
+}
diff --git a/libgo/go/net/splice_test.go b/libgo/go/net/splice_test.go
new file mode 100644
index 0000000..44a5c00
--- /dev/null
+++ b/libgo/go/net/splice_test.go
@@ -0,0 +1,489 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package net
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "sync"
+ "testing"
+)
+
+func TestSplice(t *testing.T) {
+ t.Run("simple", testSpliceSimple)
+ t.Run("multipleWrite", testSpliceMultipleWrite)
+ t.Run("big", testSpliceBig)
+ t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader)
+ t.Run("readerAtEOF", testSpliceReaderAtEOF)
+ t.Run("issue25985", testSpliceIssue25985)
+}
+
+func testSpliceSimple(t *testing.T) {
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ copyDone := srv.Copy()
+ msg := []byte("splice test")
+ if _, err := srv.Write(msg); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(msg))
+ if _, err := io.ReadFull(srv, got); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(got, msg) {
+ t.Errorf("got %q, wrote %q", got, msg)
+ }
+ srv.CloseWrite()
+ srv.CloseRead()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceMultipleWrite(t *testing.T) {
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ copyDone := srv.Copy()
+ msg1 := []byte("splice test part 1 ")
+ msg2 := []byte(" splice test part 2")
+ if _, err := srv.Write(msg1); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ if _, err := srv.Write(msg2); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(msg1)+len(msg2))
+ if _, err := io.ReadFull(srv, got); err != nil {
+ t.Fatal(err)
+ }
+ want := append(msg1, msg2...)
+ if !bytes.Equal(got, want) {
+ t.Errorf("got %q, wrote %q", got, want)
+ }
+ srv.CloseWrite()
+ srv.CloseRead()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceBig(t *testing.T) {
+ // The maximum amount of data that internal/poll.Splice will use in a
+ // splice(2) call is 4 << 20. Use a bigger size here so that we test an
+ // amount that doesn't fit in a single call.
+ size := 5 << 20
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ big := make([]byte, size)
+ copyDone := srv.Copy()
+ type readResult struct {
+ b []byte
+ err error
+ }
+ readDone := make(chan readResult)
+ go func() {
+ got := make([]byte, len(big))
+ _, err := io.ReadFull(srv, got)
+ readDone <- readResult{got, err}
+ }()
+ if _, err := srv.Write(big); err != nil {
+ t.Fatal(err)
+ }
+ res := <-readDone
+ if res.err != nil {
+ t.Fatal(res.err)
+ }
+ got := res.b
+ if !bytes.Equal(got, big) {
+ t.Errorf("input and output differ")
+ }
+ srv.CloseWrite()
+ srv.CloseRead()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceHonorsLimitedReader(t *testing.T) {
+ t.Run("stopsAfterN", testSpliceStopsAfterN)
+ t.Run("updatesN", testSpliceUpdatesN)
+}
+
+func testSpliceStopsAfterN(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientDown.Close()
+ defer serverDown.Close()
+ count := 128
+ copyDone := make(chan error)
+ lr := &io.LimitedReader{
+ N: int64(count),
+ R: serverUp,
+ }
+ go func() {
+ _, err := io.Copy(serverDown, lr)
+ serverDown.Close()
+ copyDone <- err
+ }()
+ msg := make([]byte, 2*count)
+ if _, err := clientUp.Write(msg); err != nil {
+ t.Fatal(err)
+ }
+ clientUp.Close()
+ var buf bytes.Buffer
+ if _, err := io.Copy(&buf, clientDown); err != nil {
+ t.Fatal(err)
+ }
+ if buf.Len() != count {
+ t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count)
+ }
+ clientDown.Close()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceUpdatesN(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientDown.Close()
+ defer serverDown.Close()
+ count := 128
+ copyDone := make(chan error)
+ lr := &io.LimitedReader{
+ N: int64(100 + count),
+ R: serverUp,
+ }
+ go func() {
+ _, err := io.Copy(serverDown, lr)
+ copyDone <- err
+ }()
+ msg := make([]byte, count)
+ if _, err := clientUp.Write(msg); err != nil {
+ t.Fatal(err)
+ }
+ clientUp.Close()
+ got := make([]byte, count)
+ if _, err := io.ReadFull(clientDown, got); err != nil {
+ t.Fatal(err)
+ }
+ clientDown.Close()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+ wantN := int64(100)
+ if lr.N != wantN {
+ t.Errorf("lr.N = %d, want %d", lr.N, wantN)
+ }
+}
+
+func testSpliceReaderAtEOF(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientDown.Close()
+ defer serverDown.Close()
+
+ serverUp.Close()
+ _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+ if !handled {
+ t.Errorf("closed connection: got err = %v, handled = %t, want handled = true", err, handled)
+ }
+ lr := &io.LimitedReader{
+ N: 0,
+ R: serverUp,
+ }
+ _, err, handled = splice(serverDown.(*TCPConn).fd, lr)
+ if !handled {
+ t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled)
+ }
+}
+
+func testSpliceIssue25985(t *testing.T) {
+ front, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer front.Close()
+ back, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer back.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ proxy := func() {
+ src, err := front.Accept()
+ if err != nil {
+ return
+ }
+ dst, err := Dial("tcp", back.Addr().String())
+ if err != nil {
+ return
+ }
+ defer dst.Close()
+ defer src.Close()
+ go func() {
+ io.Copy(src, dst)
+ wg.Done()
+ }()
+ go func() {
+ io.Copy(dst, src)
+ wg.Done()
+ }()
+ }
+
+ go proxy()
+
+ toFront, err := Dial("tcp", front.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ io.WriteString(toFront, "foo")
+ toFront.Close()
+
+ fromProxy, err := back.Accept()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer fromProxy.Close()
+
+ _, err = ioutil.ReadAll(fromProxy)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wg.Wait()
+}
+
+func BenchmarkTCPReadFrom(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ var chunkSizes []int
+ for i := uint(10); i <= 20; i++ {
+ chunkSizes = append(chunkSizes, 1<<i)
+ }
+ // To benchmark the genericReadFrom code path, set this to false.
+ useSplice := true
+ for _, chunkSize := range chunkSizes {
+ b.Run(fmt.Sprint(chunkSize), func(b *testing.B) {
+ benchmarkSplice(b, chunkSize, useSplice)
+ })
+ }
+}
+
+func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) {
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer srv.Close()
+ var copyDone <-chan error
+ if useSplice {
+ copyDone = srv.Copy()
+ } else {
+ copyDone = srv.CopyNoSplice()
+ }
+ chunk := make([]byte, chunkSize)
+ discardDone := make(chan struct{})
+ go func() {
+ for {
+ buf := make([]byte, chunkSize)
+ _, err := srv.Read(buf)
+ if err != nil {
+ break
+ }
+ }
+ discardDone <- struct{}{}
+ }()
+ b.SetBytes(int64(chunkSize))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ srv.Write(chunk)
+ }
+ srv.CloseWrite()
+ <-copyDone
+ srv.CloseRead()
+ <-discardDone
+}
+
+type spliceTestServer struct {
+ clientUp io.WriteCloser
+ clientDown io.ReadCloser
+ serverUp io.ReadCloser
+ serverDown io.WriteCloser
+}
+
+func newSpliceTestServer() (*spliceTestServer, error) {
+ // For now, both networks are hard-coded to TCP.
+ // If splice is enabled for non-tcp upstream connections,
+ // newSpliceTestServer will need to take a network parameter.
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ return nil, err
+ }
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ clientUp.Close()
+ serverUp.Close()
+ return nil, err
+ }
+ return &spliceTestServer{clientUp, clientDown, serverUp, serverDown}, nil
+}
+
+// Read reads from the downstream connection.
+func (srv *spliceTestServer) Read(b []byte) (int, error) {
+ return srv.clientDown.Read(b)
+}
+
+// Write writes to the upstream connection.
+func (srv *spliceTestServer) Write(b []byte) (int, error) {
+ return srv.clientUp.Write(b)
+}
+
+// Close closes the server.
+func (srv *spliceTestServer) Close() error {
+ err := srv.closeUp()
+ err1 := srv.closeDown()
+ if err == nil {
+ return err1
+ }
+ return err
+}
+
+// CloseWrite closes the client side of the upstream connection.
+func (srv *spliceTestServer) CloseWrite() error {
+ return srv.clientUp.Close()
+}
+
+// CloseRead closes the client side of the downstream connection.
+func (srv *spliceTestServer) CloseRead() error {
+ return srv.clientDown.Close()
+}
+
+// Copy copies from the server side of the upstream connection
+// to the server side of the downstream connection, in a separate
+// goroutine. Copy is done when the first send on the returned
+// channel succeeds.
+func (srv *spliceTestServer) Copy() <-chan error {
+ ch := make(chan error)
+ go func() {
+ _, err := io.Copy(srv.serverDown, srv.serverUp)
+ ch <- err
+ close(ch)
+ }()
+ return ch
+}
+
+// CopyNoSplice is like Copy, but ensures that the splice code path
+// is not reached.
+func (srv *spliceTestServer) CopyNoSplice() <-chan error {
+ type onlyReader struct {
+ io.Reader
+ }
+ ch := make(chan error)
+ go func() {
+ _, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp})
+ ch <- err
+ close(ch)
+ }()
+ return ch
+}
+
+func (srv *spliceTestServer) closeUp() error {
+ var err, err1 error
+ if srv.serverUp != nil {
+ err = srv.serverUp.Close()
+ }
+ if srv.clientUp != nil {
+ err1 = srv.clientUp.Close()
+ }
+ if err == nil {
+ return err1
+ }
+ return err
+}
+
+func (srv *spliceTestServer) closeDown() error {
+ var err, err1 error
+ if srv.serverDown != nil {
+ err = srv.serverDown.Close()
+ }
+ if srv.clientDown != nil {
+ err1 = srv.clientDown.Close()
+ }
+ if err == nil {
+ return err1
+ }
+ return err
+}
+
+func spliceTestSocketPair(net string) (client, server Conn, err error) {
+ ln, err := newLocalListener(net)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer ln.Close()
+ var cerr, serr error
+ acceptDone := make(chan struct{})
+ go func() {
+ server, serr = ln.Accept()
+ acceptDone <- struct{}{}
+ }()
+ client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
+ <-acceptDone
+ if cerr != nil {
+ if server != nil {
+ server.Close()
+ }
+ return nil, nil, cerr
+ }
+ if serr != nil {
+ if client != nil {
+ client.Close()
+ }
+ return nil, nil, serr
+ }
+ return client, server, nil
+}
diff --git a/libgo/go/net/sys_cloexec.go b/libgo/go/net/sys_cloexec.go
index def05cb..e97fb21 100644
--- a/libgo/go/net/sys_cloexec.go
+++ b/libgo/go/net/sys_cloexec.go
@@ -5,7 +5,7 @@
// This file implements sysSocket and accept for platforms that do not
// provide a fast path for setting SetNonblock and CloseOnExec.
-// +build aix darwin nacl netbsd openbsd solaris
+// +build aix darwin nacl solaris
package net
diff --git a/libgo/go/net/tcpsock.go b/libgo/go/net/tcpsock.go
index 9528140..db5d1f8 100644
--- a/libgo/go/net/tcpsock.go
+++ b/libgo/go/net/tcpsock.go
@@ -12,8 +12,8 @@ import (
"time"
)
-// BUG(mikio): On Windows, the File method of TCPListener is not
-// implemented.
+// BUG(mikio): On JS, NaCl and Windows, the File method of TCPConn and
+// TCPListener is not implemented.
// TCPAddr represents the address of a TCP end point.
type TCPAddr struct {
@@ -212,7 +212,8 @@ func DialTCP(network string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
- c, err := dialTCP(context.Background(), network, laddr, raddr)
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialTCP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
@@ -292,8 +293,8 @@ func (l *TCPListener) SetDeadline(t time.Time) error {
return nil
}
-// File returns a copy of the underlying os.File, set to blocking
-// mode. It is the caller's responsibility to close f when finished.
+// File returns a copy of the underlying os.File.
+// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
//
// The returned os.File's file descriptor is different from the
@@ -328,7 +329,8 @@ func ListenTCP(network string, laddr *TCPAddr) (*TCPListener, error) {
if laddr == nil {
laddr = &TCPAddr{}
}
- ln, err := listenTCP(context.Background(), network, laddr)
+ sl := &sysListener{network: network, address: laddr.String()}
+ ln, err := sl.listenTCP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
diff --git a/libgo/go/net/tcpsock_plan9.go b/libgo/go/net/tcpsock_plan9.go
index e37f065..f70ef6f 100644
--- a/libgo/go/net/tcpsock_plan9.go
+++ b/libgo/go/net/tcpsock_plan9.go
@@ -14,23 +14,23 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
return genericReadFrom(c, r)
}
-func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil {
- return testHookDialTCP(ctx, net, laddr, raddr)
+ return testHookDialTCP(ctx, sd.network, laddr, raddr)
}
- return doDialTCP(ctx, net, laddr, raddr)
+ return sd.doDialTCP(ctx, laddr, raddr)
}
-func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
- switch net {
+func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ switch sd.network {
case "tcp", "tcp4", "tcp6":
default:
- return nil, UnknownNetworkError(net)
+ return nil, UnknownNetworkError(sd.network)
}
if raddr == nil {
return nil, errMissingAddress
}
- fd, err := dialPlan9(ctx, net, laddr, raddr)
+ fd, err := dialPlan9(ctx, sd.network, laddr, raddr)
if err != nil {
return nil, err
}
@@ -69,8 +69,8 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil
}
-func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
- fd, err := listenPlan9(ctx, network, laddr)
+func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
+ fd, err := listenPlan9(ctx, sl.network, laddr)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go
index 9ba199d..64e71bf 100644
--- a/libgo/go/net/tcpsock_posix.go
+++ b/libgo/go/net/tcpsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
package net
@@ -45,21 +45,24 @@ func (a *TCPAddr) toLocal(net string) sockaddr {
}
func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
+ if n, err, handled := splice(c.fd, r); handled {
+ return n, err
+ }
if n, err, handled := sendFile(c.fd, r); handled {
return n, err
}
return genericReadFrom(c, r)
}
-func dialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
+func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
if testHookDialTCP != nil {
- return testHookDialTCP(ctx, net, laddr, raddr)
+ return testHookDialTCP(ctx, sd.network, laddr, raddr)
}
- return doDialTCP(ctx, net, laddr, raddr)
+ return sd.doDialTCP(ctx, laddr, raddr)
}
-func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
- fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
+func (sd *sysDialer) doDialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -77,7 +80,7 @@ func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn
// close the fd and try again. If it happens twice more, we relent and
// use the result. See also:
// https://golang.org/issue/2690
- // http://stackoverflow.com/questions/4949858/
+ // https://stackoverflow.com/questions/4949858/
//
// The opposite can also happen: if we ask the kernel to pick an appropriate
// originating local address, sometimes it picks one that is already in use.
@@ -89,7 +92,7 @@ func doDialTCP(ctx context.Context, net string, laddr, raddr *TCPAddr) (*TCPConn
if err == nil {
fd.Close()
}
- fd, err = internetSocket(ctx, net, laddr, raddr, syscall.SOCK_STREAM, 0, "dial")
+ fd, err = internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial", sd.Dialer.Control)
}
if err != nil {
@@ -152,8 +155,8 @@ func (ln *TCPListener) file() (*os.File, error) {
return f, nil
}
-func listenTCP(ctx context.Context, network string, laddr *TCPAddr) (*TCPListener, error) {
- fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_STREAM, 0, "listen")
+func (sl *sysListener) listenTCP(ctx context.Context, laddr *TCPAddr) (*TCPListener, error) {
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_STREAM, 0, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/tcpsock_test.go b/libgo/go/net/tcpsock_test.go
index 04b38b6..f8a775f 100644
--- a/libgo/go/net/tcpsock_test.go
+++ b/libgo/go/net/tcpsock_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
diff --git a/libgo/go/net/tcpsock_unix_test.go b/libgo/go/net/tcpsock_unix_test.go
index 95c02d2..2bd591b 100644
--- a/libgo/go/net/tcpsock_unix_test.go
+++ b/libgo/go/net/tcpsock_unix_test.go
@@ -2,13 +2,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !plan9,!windows
+// +build !js,!plan9,!windows
package net
import (
"context"
- "internal/testenv"
"math/rand"
"runtime"
"sync"
@@ -84,9 +83,8 @@ func TestTCPSpuriousConnSetupCompletion(t *testing.T) {
// Issue 19289.
// Test that a canceled Dial does not cause a subsequent Dial to succeed.
func TestTCPSpuriousConnSetupCompletionWithCancel(t *testing.T) {
- if testenv.Builder() == "" {
- testenv.MustHaveExternalNetwork(t)
- }
+ mustHaveExternalNetwork(t)
+
defer dnsWaitGroup.Wait()
t.Parallel()
const tries = 10000
diff --git a/libgo/go/net/tcpsockopt_darwin.go b/libgo/go/net/tcpsockopt_darwin.go
index 7415c76..5b738d2 100644
--- a/libgo/go/net/tcpsockopt_darwin.go
+++ b/libgo/go/net/tcpsockopt_darwin.go
@@ -16,9 +16,7 @@ func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
// The kernel expects seconds so round to next highest second.
d += (time.Second - time.Nanosecond)
secs := int(d.Seconds())
- switch err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs); err {
- case nil, syscall.ENOPROTOOPT: // OS X 10.7 and earlier don't support this option
- default:
+ if err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, sysTCP_KEEPINTVL, secs); err != nil {
return wrapSyscallError("setsockopt", err)
}
err := fd.pfd.SetsockoptInt(syscall.IPPROTO_TCP, syscall.TCP_KEEPALIVE, secs)
diff --git a/libgo/go/net/tcpsockopt_stub.go b/libgo/go/net/tcpsockopt_stub.go
index 19c83e6..fd7f579 100644
--- a/libgo/go/net/tcpsockopt_stub.go
+++ b/libgo/go/net/tcpsockopt_stub.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build nacl
+// +build nacl js,wasm
package net
diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go
index 8c3a052..feb464b 100644
--- a/libgo/go/net/textproto/reader.go
+++ b/libgo/go/net/textproto/reader.go
@@ -236,7 +236,7 @@ func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err err
// with the same code followed by a space. Each line in message is
// separated by a newline (\n).
//
-// See page 36 of RFC 959 (http://www.ietf.org/rfc/rfc959.txt) for
+// See page 36 of RFC 959 (https://www.ietf.org/rfc/rfc959.txt) for
// details of another form of response accepted:
//
// code-message line 1
diff --git a/libgo/go/net/textproto/reader_test.go b/libgo/go/net/textproto/reader_test.go
index c6a6ced..4f37903 100644
--- a/libgo/go/net/textproto/reader_test.go
+++ b/libgo/go/net/textproto/reader_test.go
@@ -290,7 +290,7 @@ var readResponseTests = []readResponseTest{
},
}
-// See http://www.ietf.org/rfc/rfc959.txt page 36.
+// See https://www.ietf.org/rfc/rfc959.txt page 36.
func TestRFC959Lines(t *testing.T) {
for i, tt := range readResponseTests {
r := reader(tt.in + "\nFOLLOWING DATA")
diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go
index 9de7801..7c7d0c8 100644
--- a/libgo/go/net/timeout_test.go
+++ b/libgo/go/net/timeout_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -482,7 +484,7 @@ func TestReadFromTimeout(t *testing.T) {
time.Sleep(tt.timeout / 3)
continue
}
- if n != 0 {
+ if nerr, ok := err.(Error); ok && nerr.Timeout() && n != 0 {
t.Fatalf("#%d/%d: read %d; want 0", i, j, n)
}
break
diff --git a/libgo/go/net/udpsock.go b/libgo/go/net/udpsock.go
index 158265f..b234ed8 100644
--- a/libgo/go/net/udpsock.go
+++ b/libgo/go/net/udpsock.go
@@ -18,6 +18,9 @@ import (
// BUG(mikio): On NaCl, the ListenMulticastUDP function is not
// implemented.
+// BUG(mikio): On JS, methods and functions related to UDPConn are not
+// implemented.
+
// UDPAddr represents the address of a UDP end point.
type UDPAddr struct {
IP IP
@@ -208,7 +211,8 @@ func DialUDP(network string, laddr, raddr *UDPAddr) (*UDPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
- c, err := dialUDP(context.Background(), network, laddr, raddr)
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialUDP(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
@@ -233,7 +237,8 @@ func ListenUDP(network string, laddr *UDPAddr) (*UDPConn, error) {
if laddr == nil {
laddr = &UDPAddr{}
}
- c, err := listenUDP(context.Background(), network, laddr)
+ sl := &sysListener{network: network, address: laddr.String()}
+ c, err := sl.listenUDP(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
@@ -266,7 +271,8 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
if gaddr == nil || gaddr.IP == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
}
- c, err := listenMulticastUDP(context.Background(), network, ifi, gaddr)
+ sl := &sysListener{network: network, address: gaddr.String()}
+ c, err := sl.listenMulticastUDP(context.Background(), ifi, gaddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: err}
}
diff --git a/libgo/go/net/udpsock_plan9.go b/libgo/go/net/udpsock_plan9.go
index 1ce7f88..563d943 100644
--- a/libgo/go/net/udpsock_plan9.go
+++ b/libgo/go/net/udpsock_plan9.go
@@ -55,8 +55,8 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return 0, 0, syscall.EPLAN9
}
-func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
- fd, err := dialPlan9(ctx, net, laddr, raddr)
+func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ fd, err := dialPlan9(ctx, sd.network, laddr, raddr)
if err != nil {
return nil, err
}
@@ -91,8 +91,8 @@ func unmarshalUDPHeader(b []byte) (*udpHeader, []byte) {
return h, b
}
-func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
- l, err := listenPlan9(ctx, network, laddr)
+func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
+ l, err := listenPlan9(ctx, sl.network, laddr)
if err != nil {
return nil, err
}
@@ -108,8 +108,8 @@ func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, e
return newUDPConn(fd), err
}
-func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
- l, err := listenPlan9(ctx, network, gaddr)
+func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+ l, err := listenPlan9(ctx, sl.network, gaddr)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go
index fe552ba..611fe51 100644
--- a/libgo/go/net/udpsock_posix.go
+++ b/libgo/go/net/udpsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
package net
@@ -94,24 +94,24 @@ func (c *UDPConn) writeMsg(b, oob []byte, addr *UDPAddr) (n, oobn int, err error
return c.fd.writeMsg(b, oob, sa)
}
-func dialUDP(ctx context.Context, net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, net, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial")
+func (sd *sysDialer) dialUDP(ctx context.Context, laddr, raddr *UDPAddr) (*UDPConn, error) {
+ fd, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_DGRAM, 0, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
return newUDPConn(fd), nil
}
-func listenUDP(ctx context.Context, network string, laddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen")
+func (sl *sysListener) listenUDP(ctx context.Context, laddr *UDPAddr) (*UDPConn, error) {
+ fd, err := internetSocket(ctx, sl.network, laddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
return newUDPConn(fd), nil
}
-func listenMulticastUDP(ctx context.Context, network string, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
- fd, err := internetSocket(ctx, network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen")
+func (sl *sysListener) listenMulticastUDP(ctx context.Context, ifi *Interface, gaddr *UDPAddr) (*UDPConn, error) {
+ fd, err := internetSocket(ctx, sl.network, gaddr, nil, syscall.SOCK_DGRAM, 0, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/udpsock_test.go b/libgo/go/net/udpsock_test.go
index 769576c..4940644 100644
--- a/libgo/go/net/udpsock_test.go
+++ b/libgo/go/net/udpsock_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (
@@ -163,11 +165,6 @@ func testWriteToConn(t *testing.T, raddr string) {
switch runtime.GOOS {
case "nacl": // see golang.org/issue/9252
t.Skipf("not implemented yet on %s", runtime.GOOS)
- case "windows":
- if testenv.IsWindowsXP() {
- t.Log("skipping broken test on Windows XP (see golang.org/issue/23072)")
- return
- }
default:
if err != nil {
t.Fatal(err)
@@ -211,11 +208,6 @@ func testWriteToPacketConn(t *testing.T, raddr string) {
switch runtime.GOOS {
case "nacl": // see golang.org/issue/9252
t.Skipf("not implemented yet on %s", runtime.GOOS)
- case "windows":
- if testenv.IsWindowsXP() {
- t.Log("skipping broken test on Windows XP (see golang.org/issue/23072)")
- return
- }
default:
if err != nil {
t.Fatal(err)
@@ -408,9 +400,56 @@ func TestUDPZeroByteBuffer(t *testing.T) {
switch err {
case nil: // ReadFrom succeeds
default: // Read may timeout, it depends on the platform
- if nerr, ok := err.(Error); (!ok || !nerr.Timeout()) && runtime.GOOS != "windows" { // Windows returns WSAEMSGSIZ
+ if nerr, ok := err.(Error); (!ok || !nerr.Timeout()) && runtime.GOOS != "windows" { // Windows returns WSAEMSGSIZE
+ t.Fatal(err)
+ }
+ }
+ }
+}
+
+func TestUDPReadSizeError(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ c1, err := newLocalPacketListener("udp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c1.Close()
+
+ c2, err := Dial("udp", c1.LocalAddr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ b1 := []byte("READ SIZE ERROR TEST")
+ for _, genericRead := range []bool{false, true} {
+ n, err := c2.Write(b1)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != len(b1) {
+ t.Errorf("got %d; want %d", n, len(b1))
+ }
+ c1.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
+ b2 := make([]byte, len(b1)-1)
+ if genericRead {
+ n, err = c1.(Conn).Read(b2)
+ } else {
+ n, _, err = c1.ReadFrom(b2)
+ }
+ switch err {
+ case nil: // ReadFrom succeeds
+ default: // Read may timeout, it depends on the platform
+ if nerr, ok := err.(Error); (!ok || !nerr.Timeout()) && runtime.GOOS != "windows" { // Windows returns WSAEMSGSIZE
t.Fatal(err)
}
}
+ if n != len(b1)-1 {
+ t.Fatalf("got %d; want %d", n, len(b1)-1)
+ }
}
}
diff --git a/libgo/go/net/unixsock.go b/libgo/go/net/unixsock.go
index 20326da..3ae62f6 100644
--- a/libgo/go/net/unixsock.go
+++ b/libgo/go/net/unixsock.go
@@ -12,6 +12,9 @@ import (
"time"
)
+// BUG(mikio): On JS, NaCl, Plan 9 and Windows, methods and functions
+// related to UnixConn and UnixListener are not implemented.
+
// UnixAddr represents the address of a Unix domain socket end point.
type UnixAddr struct {
Name string
@@ -200,7 +203,8 @@ func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
default:
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: UnknownNetworkError(network)}
}
- c, err := dialUnix(context.Background(), network, laddr, raddr)
+ sd := &sysDialer{network: network, address: raddr.String()}
+ c, err := sd.dialUnix(context.Background(), laddr, raddr)
if err != nil {
return nil, &OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
@@ -286,8 +290,8 @@ func (l *UnixListener) SetDeadline(t time.Time) error {
return nil
}
-// File returns a copy of the underlying os.File, set to blocking
-// mode. It is the caller's responsibility to close f when finished.
+// File returns a copy of the underlying os.File.
+// It is the caller's responsibility to close f when finished.
// Closing l does not affect f, and closing f does not affect l.
//
// The returned os.File's file descriptor is different from the
@@ -316,7 +320,8 @@ func ListenUnix(network string, laddr *UnixAddr) (*UnixListener, error) {
if laddr == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: errMissingAddress}
}
- ln, err := listenUnix(context.Background(), network, laddr)
+ sl := &sysListener{network: network, address: laddr.String()}
+ ln, err := sl.listenUnix(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
@@ -335,7 +340,8 @@ func ListenUnixgram(network string, laddr *UnixAddr) (*UnixConn, error) {
if laddr == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: nil, Err: errMissingAddress}
}
- c, err := listenUnixgram(context.Background(), network, laddr)
+ sl := &sysListener{network: network, address: laddr.String()}
+ c, err := sl.listenUnixgram(context.Background(), laddr)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: laddr.opAddr(), Err: err}
}
diff --git a/libgo/go/net/unixsock_plan9.go b/libgo/go/net/unixsock_plan9.go
index e70eb21..6ebd4d7 100644
--- a/libgo/go/net/unixsock_plan9.go
+++ b/libgo/go/net/unixsock_plan9.go
@@ -26,7 +26,7 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
return 0, 0, syscall.EPLAN9
}
-func dialUnix(ctx context.Context, network string, laddr, raddr *UnixAddr) (*UnixConn, error) {
+func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
return nil, syscall.EPLAN9
}
@@ -42,10 +42,10 @@ func (ln *UnixListener) file() (*os.File, error) {
return nil, syscall.EPLAN9
}
-func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
+func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
return nil, syscall.EPLAN9
}
-func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
+func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
return nil, syscall.EPLAN9
}
diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go
index 945aa03..74f5cc2 100644
--- a/libgo/go/net/unixsock_posix.go
+++ b/libgo/go/net/unixsock_posix.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris windows
+// +build aix darwin dragonfly freebsd js,wasm linux nacl netbsd openbsd solaris windows
package net
@@ -13,7 +13,7 @@ import (
"syscall"
)
-func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string) (*netFD, error) {
+func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode string, ctrlFn func(string, string, syscall.RawConn) error) (*netFD, error) {
var sotype int
switch net {
case "unix":
@@ -42,7 +42,7 @@ func unixSocket(ctx context.Context, net string, laddr, raddr sockaddr, mode str
return nil, errors.New("unknown mode: " + mode)
}
- fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr)
+ fd, err := socket(ctx, net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, ctrlFn)
if err != nil {
return nil, err
}
@@ -150,8 +150,8 @@ func (c *UnixConn) writeMsg(b, oob []byte, addr *UnixAddr) (n, oobn int, err err
return c.fd.writeMsg(b, oob, sa)
}
-func dialUnix(ctx context.Context, net string, laddr, raddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(ctx, net, laddr, raddr, "dial")
+func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConn, error) {
+ fd, err := unixSocket(ctx, sd.network, laddr, raddr, "dial", sd.Dialer.Control)
if err != nil {
return nil, err
}
@@ -206,16 +206,16 @@ func (l *UnixListener) SetUnlinkOnClose(unlink bool) {
l.unlink = unlink
}
-func listenUnix(ctx context.Context, network string, laddr *UnixAddr) (*UnixListener, error) {
- fd, err := unixSocket(ctx, network, laddr, nil, "listen")
+func (sl *sysListener) listenUnix(ctx context.Context, laddr *UnixAddr) (*UnixListener, error) {
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
}
-func listenUnixgram(ctx context.Context, network string, laddr *UnixAddr) (*UnixConn, error) {
- fd, err := unixSocket(ctx, network, laddr, nil, "listen")
+func (sl *sysListener) listenUnixgram(ctx context.Context, laddr *UnixAddr) (*UnixConn, error) {
+ fd, err := unixSocket(ctx, sl.network, laddr, nil, "listen", sl.ListenConfig.Control)
if err != nil {
return nil, err
}
diff --git a/libgo/go/net/unixsock_test.go b/libgo/go/net/unixsock_test.go
index 3e5c8bc..4828990 100644
--- a/libgo/go/net/unixsock_test.go
+++ b/libgo/go/net/unixsock_test.go
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// +build !nacl,!plan9,!windows
+// +build !js,!nacl,!plan9,!windows
package net
diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go
index 3e12179..80eb7a8 100644
--- a/libgo/go/net/url/url.go
+++ b/libgo/go/net/url/url.go
@@ -11,7 +11,6 @@ package url
// contain references to issue numbers with details.
import (
- "bytes"
"errors"
"fmt"
"sort"
@@ -159,13 +158,26 @@ func shouldEscape(c byte, mode encoding) bool {
}
}
+ if mode == encodeFragment {
+ // RFC 3986 §2.2 allows not escaping sub-delims. A subset of sub-delims are
+ // included in reserved from RFC 2396 §2.2. The remaining sub-delims do not
+ // need to be escaped. To minimize potential breakage, we apply two restrictions:
+ // (1) we always escape sub-delims outside of the fragment, and (2) we always
+ // escape single quote to avoid breaking callers that had previously assumed that
+ // single quotes would be escaped. See issue #19917.
+ switch c {
+ case '!', '(', ')', '*':
+ return false
+ }
+ }
+
// Everything else must be escaped.
return true
}
// QueryUnescape does the inverse transformation of QueryEscape,
// converting each 3-byte encoded substring of the form "%AB" into the
-// hex-decoded byte 0xAB. It also converts '+' into ' ' (space).
+// hex-decoded byte 0xAB.
// It returns an error if any % is not followed by two hexadecimal
// digits.
func QueryUnescape(s string) (string, error) {
@@ -174,9 +186,8 @@ func QueryUnescape(s string) (string, error) {
// PathUnescape does the inverse transformation of PathEscape,
// converting each 3-byte encoded substring of the form "%AB" into the
-// hex-decoded byte 0xAB. It also converts '+' into ' ' (space).
-// It returns an error if any % is not followed by two hexadecimal
-// digits.
+// hex-decoded byte 0xAB. It returns an error if any % is not followed
+// by two hexadecimal digits.
//
// PathUnescape is identical to QueryUnescape except that it does not
// unescape '+' to ' ' (space).
@@ -738,7 +749,7 @@ func validOptionalPort(port string) bool {
// - if u.RawQuery is empty, ?query is omitted.
// - if u.Fragment is empty, #fragment is omitted.
func (u *URL) String() string {
- var buf bytes.Buffer
+ var buf strings.Builder
if u.Scheme != "" {
buf.WriteString(u.Scheme)
buf.WriteByte(':')
@@ -879,7 +890,7 @@ func (v Values) Encode() string {
if v == nil {
return ""
}
- var buf bytes.Buffer
+ var buf strings.Builder
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
@@ -887,12 +898,13 @@ func (v Values) Encode() string {
sort.Strings(keys)
for _, k := range keys {
vs := v[k]
- prefix := QueryEscape(k) + "="
+ keyEscaped := QueryEscape(k)
for _, v := range vs {
if buf.Len() > 0 {
buf.WriteByte('&')
}
- buf.WriteString(prefix)
+ buf.WriteString(keyEscaped)
+ buf.WriteByte('=')
buf.WriteString(QueryEscape(v))
}
}
@@ -953,7 +965,7 @@ func (u *URL) Parse(ref string) (*URL, error) {
}
// ResolveReference resolves a URI reference to an absolute URI from
-// an absolute base URI, per RFC 3986 Section 5.2. The URI reference
+// an absolute base URI u, per RFC 3986 Section 5.2. The URI reference
// may be relative or absolute. ResolveReference always returns a new
// URL instance, even if the returned URL is identical to either the
// base or reference. If ref is an absolute URL, then ResolveReference
diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go
index f2d311a..9043a84 100644
--- a/libgo/go/net/url/url_test.go
+++ b/libgo/go/net/url/url_test.go
@@ -1075,6 +1075,7 @@ var resolveReferenceTests = []struct {
// Fragment
{"http://foo.com/bar", ".#frag", "http://foo.com/#frag"},
+ {"http://example.org/", "#!$&%27()*+,;=", "http://example.org/#!$&%27()*+,;="},
// Paths with escaping (issue 16947).
{"http://foo.com/foo%2fbar/", "../baz", "http://foo.com/baz"},
@@ -1433,8 +1434,8 @@ func TestParseErrors(t *testing.T) {
{"mysql://x@y(1.2.3.4:123)/foo", false},
{"http://[]%20%48%54%54%50%2f%31%2e%31%0a%4d%79%48%65%61%64%65%72%3a%20%31%32%33%0a%0a/", true}, // golang.org/issue/11208
- {"http://a b.com/", true}, // no space in host name please
- {"cache_object://foo", true}, // scheme cannot have _, relative path cannot have : in first segment
+ {"http://a b.com/", true}, // no space in host name please
+ {"cache_object://foo", true}, // scheme cannot have _, relative path cannot have : in first segment
{"cache_object:foo", true},
{"cache_object:foo/bar", true},
{"cache_object/:foo/bar", false},
diff --git a/libgo/go/net/writev_test.go b/libgo/go/net/writev_test.go
index 4c05be4..c43be84 100644
--- a/libgo/go/net/writev_test.go
+++ b/libgo/go/net/writev_test.go
@@ -2,6 +2,8 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
+// +build !js
+
package net
import (